From 41b7889a30cbe1e02b7c995fcc42587168e6c0b0 Mon Sep 17 00:00:00 2001
From: thinhlpg
Date: Wed, 9 Apr 2025 03:28:46 +0000
Subject: [PATCH] feat: integrate QA dataset loading and display gold answers
 in Gradio interface

---
 app.py         | 86 +++++++++++++++++++++++++++++++++++++++++---------
 pyproject.toml |  8 +++--
 2 files changed, 76 insertions(+), 18 deletions(-)

diff --git a/app.py b/app.py
index 3bbec08..467cc05 100644
--- a/app.py
+++ b/app.py
@@ -21,7 +21,7 @@ from src import (
     format_search_results,
     get_system_prompt,
 )
-from src.search_module import load_vectorstore, search
+from src.search_module import get_qa_dataset, load_vectorstore, search
 
 # TODO: check if can reuse tokenizer adapter
 
@@ -96,6 +96,20 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
     system_prompt = system_prompt or get_system_prompt()
     tokenizer_for_template = cast(PreTrainedTokenizer, tokenizer)
 
+    # Load QA dataset for examples and gold answers
+    try:
+        _, test_dataset = get_qa_dataset()
+        qa_map = {q: a for q, a in zip(test_dataset["prompt"], test_dataset["answer"])}
+        example_questions = list(qa_map.keys())
+        logger.info(f"Loaded {len(example_questions)} QA examples.")
+    except Exception as e:
+        logger.error(f"Failed to load QA dataset: {e}")
+        qa_map = {}
+        example_questions = [
+            "What year was the document approved by the Mission Evaluation Team?",
+            "Failed to load dataset examples.",
+        ]  # Provide fallback examples
+
     def get_chat_num_tokens(current_chat_state: dict) -> int:
         """Helper to get number of tokens in chat state."""
         try:
@@ -112,6 +126,7 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
         temp: float,
         max_iter: int = 20,
         num_search_results: int = 2,
+        gold_answer_state: str | None = None,
     ) -> Iterator[list[gr.ChatMessage]]:
         """Stream agent responses following agent.py/inference.py logic."""
         chat_state = {
@@ -326,6 +341,20 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
 
         logger.info(f"Processing finished in {total_time:.2f} seconds.")
 
+        # ---> Display Gold Answer if available (passed via gold_answer_state)
+        if gold_answer_state:
+            logger.info("Displaying gold answer.")
+            messages.append(
+                gr.ChatMessage(
+                    role="assistant",
+                    content=gold_answer_state,
+                    metadata={"title": "✅ Correct Answer (For comparison)"},
+                )
+            )
+            yield messages
+        else:
+            logger.info("No gold answer to display for this query.")
+
     with gr.Blocks(title="DeepSearch - Visible Thinking") as interface:
         gr.Markdown("# 🧠 DeepSearch with Visible Thinking")
         gr.Markdown("Watch as the AI thinks, searches, and processes information to answer your questions.")
@@ -345,13 +374,13 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
                 placeholder="Type your message here...", show_label=False, container=False, elem_id="msg-input"
             )
 
-            example_questions = [
-                "What year was the document approved by the Mission Evaluation Team?",
-                "Summarize the key findings regarding the oxygen tank failure.",
-                "Who was the commander of Apollo 13?",
-                "What were the main recommendations from the review board?",
-            ]
-            gr.Examples(examples=example_questions, inputs=msg, label="Example Questions", examples_per_page=4)
+            # Use questions from dataset as examples
+            gr.Examples(
+                examples=example_questions,
+                inputs=msg,
+                label="Example Questions with correct answer for comparison",
+                examples_per_page=4,
+            )
 
             with gr.Row():
                 clear = gr.Button("Clear Chat")
@@ -387,16 +416,24 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
             return "", history
 
         submitted_msg_state = gr.State("")
+        gold_answer_state = gr.State(None)  # State to hold the gold answer
+
+        # Function to check if submitted message is an example and store gold answer
+        def check_if_example_and_store_answer(msg_text):
+            gold_answer = qa_map.get(msg_text)  # Returns None if msg_text is not in examples
+            logger.info(f"Checking for gold answer for: '{msg_text[:50]}...'. Found: {bool(gold_answer)}")
+            return gold_answer
 
-        # Chain events:
-        # 1. User submits -> store msg text in state
-        # 2. .then() -> add_user_message (updates chatbot UI history, clears textbox)
-        # 3. .then() -> stream_agent_response (takes stored msg text and updated chatbot history)
         submit.click(
             lambda msg_text: msg_text,
             inputs=[msg],
             outputs=[submitted_msg_state],
             queue=False,
+        ).then(  # Check for gold answer immediately after storing submitted message
+            check_if_example_and_store_answer,
+            inputs=[submitted_msg_state],
+            outputs=[gold_answer_state],
+            queue=False,
         ).then(
             add_user_message,
             inputs=[submitted_msg_state, chatbot],
@@ -404,7 +441,14 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
             queue=False,
         ).then(
             stream_agent_response,
-            inputs=[submitted_msg_state, chatbot, temp_slider, max_iter_slider, num_results_slider],
+            inputs=[
+                submitted_msg_state,
+                chatbot,
+                temp_slider,
+                max_iter_slider,
+                num_results_slider,
+                gold_answer_state,
+            ],  # Pass gold answer state
             outputs=chatbot,
         )
 
@@ -413,6 +457,11 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
             inputs=[msg],
             outputs=[submitted_msg_state],
             queue=False,
+        ).then(  # Check for gold answer immediately after storing submitted message
+            check_if_example_and_store_answer,
+            inputs=[submitted_msg_state],
+            outputs=[gold_answer_state],
+            queue=False,
         ).then(
             add_user_message,
             inputs=[submitted_msg_state, chatbot],
@@ -420,11 +469,18 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
             queue=False,
         ).then(
             stream_agent_response,
-            inputs=[submitted_msg_state, chatbot, temp_slider, max_iter_slider, num_results_slider],
+            inputs=[
+                submitted_msg_state,
+                chatbot,
+                temp_slider,
+                max_iter_slider,
+                num_results_slider,
+                gold_answer_state,
+            ],  # Pass gold answer state
             outputs=chatbot,
         )
 
-        clear.click(lambda: ([], ""), None, [chatbot, submitted_msg_state])
+        clear.click(lambda: ([], None), None, [chatbot, gold_answer_state])  # Also clear gold answer state
 
         system_prompt_state = gr.State(system_prompt)
         # TODO: Currently, changing the system prompt mid-chat won't affect the ongoing stream_agent_response.
diff --git a/pyproject.toml b/pyproject.toml
index b2854f2..5bab91e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,8 +23,8 @@ dependencies = [
     "langchain-community",
     "Markdown",
     "tokenizers",
-    "unsloth==2025.2.14",
-    "unsloth_zoo==2025.2.7",
+    "unsloth",
+    "unsloth_zoo",
     "unstructured",
     "vllm==0.7.2",
     "transformers==4.49.0",
@@ -34,5 +34,7 @@ dependencies = [
     "gradio",
     "tensorboard",
     "pytest",
-    "wandb"
+    "wandb",
+    "requests>=2.31.0",
+    "tqdm>=4.66.1"
 ]
\ No newline at end of file
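
Reviewer note: the heart of this patch is a three-step Gradio event chain — stash the
submitted text in a State, look it up in qa_map to capture the gold answer in a second
State, then pass that State into the streaming handler so the reference answer can be
appended after the model's reply. The sketch below isolates that pattern so it can be run
without the rest of the repo. It is a minimal approximation, not the repo's code: QA_MAP,
lookup_gold_answer, and respond are hypothetical stand-ins for get_qa_dataset,
check_if_example_and_store_answer, and stream_agent_response, and it assumes a recent
Gradio release with the messages-format Chatbot.

import gradio as gr

# Hypothetical stand-in for get_qa_dataset(): example questions -> gold answers.
QA_MAP = {
    "Who was the commander of Apollo 13?": "James A. Lovell Jr.",
    "What caused the Apollo 13 accident?": "An oxygen tank ruptured in the service module.",
}


def lookup_gold_answer(msg_text: str):
    # Mirrors check_if_example_and_store_answer: returns None for free-form input,
    # so typed (non-example) questions never trigger the gold-answer display.
    return QA_MAP.get(msg_text)


def respond(msg_text: str, history: list, gold_answer):
    # Stand-in for stream_agent_response: one fake model turn, then the gold
    # answer (only when the question came from the examples).
    history = (history or []) + [{"role": "user", "content": msg_text}]
    history.append({"role": "assistant", "content": f"(model reply to: {msg_text!r})"})
    if gold_answer:
        history.append(
            {"role": "assistant", "content": f"✅ Correct Answer (For comparison): {gold_answer}"}
        )
    return history


with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")
    msg = gr.Textbox(placeholder="Type your message here...")
    gr.Examples(examples=list(QA_MAP), inputs=msg)

    submitted_msg_state = gr.State("")
    gold_answer_state = gr.State(None)

    # Same chain shape as the patch: stash the text, capture the gold answer,
    # then run the (here: fake) agent with both values available.
    msg.submit(
        lambda t: t, inputs=[msg], outputs=[submitted_msg_state], queue=False
    ).then(
        lookup_gold_answer, inputs=[submitted_msg_state], outputs=[gold_answer_state], queue=False
    ).then(
        respond,
        inputs=[submitted_msg_state, chatbot, gold_answer_state],
        outputs=chatbot,
    )

if __name__ == "__main__":
    demo.launch()

One detail worth noticing in the patch's design: the lookup step runs with queue=False
before the agent step, so the gold answer is already sitting in gold_answer_state when
stream_agent_response is invoked — the handler receives the state's value at call time,
not a live reference.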