feat: integrate QA dataset loading and display gold answers in Gradio interface

Branch: main
Author: thinhlpg (1 month ago)
Parent: 7376f596a5
Commit: 41b7889a30
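In short: the interface now builds a question → gold-answer map from the QA test split, carries the matched gold answer through the Gradio event chain in a `gr.State`, and appends it after the model's reply so the two can be compared. Below is a condensed, self-contained sketch of that pattern, assuming `get_qa_dataset()` returns a `(train, test)` pair whose test split exposes `"prompt"` and `"answer"` columns (as the diff below relies on); the `qa_map` literal and the `respond` stub are placeholders, not the project's actual agent logic:

```python
import gradio as gr

# Placeholder for the real map built from get_qa_dataset()'s test split:
# qa_map = dict(zip(test_dataset["prompt"], test_dataset["answer"]))
qa_map = {"Example question from the dataset?": "Gold answer from the dataset."}


def check_gold_answer(msg_text: str):
    # None when the user typed a free-form question instead of picking an example.
    return qa_map.get(msg_text)


def respond(msg_text: str, history: list, gold_answer):
    history = (history or []) + [gr.ChatMessage(role="user", content=msg_text)]
    history.append(gr.ChatMessage(role="assistant", content="(streamed agent answer goes here)"))
    if gold_answer:
        # Surface the reference answer alongside the model's answer for comparison.
        history.append(
            gr.ChatMessage(
                role="assistant",
                content=gold_answer,
                metadata={"title": "✅ Correct Answer (For comparison)"},
            )
        )
    return history


with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")
    msg = gr.Textbox(placeholder="Type your message here...")
    gr.Examples(examples=list(qa_map), inputs=msg)
    gold_answer_state = gr.State(None)
    msg.submit(
        check_gold_answer, inputs=[msg], outputs=[gold_answer_state], queue=False
    ).then(respond, inputs=[msg, chatbot, gold_answer_state], outputs=chatbot)

if __name__ == "__main__":
    demo.launch()
```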

@@ -21,7 +21,7 @@ from src import (
     format_search_results,
     get_system_prompt,
 )
-from src.search_module import load_vectorstore, search
+from src.search_module import get_qa_dataset, load_vectorstore, search
 
 # TODO: check if can reuse tokenizer adapter
@@ -96,6 +96,20 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
     system_prompt = system_prompt or get_system_prompt()
     tokenizer_for_template = cast(PreTrainedTokenizer, tokenizer)
 
+    # Load QA dataset for examples and gold answers
+    try:
+        _, test_dataset = get_qa_dataset()
+        qa_map = {q: a for q, a in zip(test_dataset["prompt"], test_dataset["answer"])}
+        example_questions = list(qa_map.keys())
+        logger.info(f"Loaded {len(example_questions)} QA examples.")
+    except Exception as e:
+        logger.error(f"Failed to load QA dataset: {e}")
+        qa_map = {}
+        example_questions = [
+            "What year was the document approved by the Mission Evaluation Team?",
+            "Failed to load dataset examples.",
+        ]  # Provide fallback examples
+
     def get_chat_num_tokens(current_chat_state: dict) -> int:
         """Helper to get number of tokens in chat state."""
         try:
@@ -112,6 +126,7 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
         temp: float,
         max_iter: int = 20,
         num_search_results: int = 2,
+        gold_answer_state: str | None = None,
     ) -> Iterator[list[gr.ChatMessage]]:
         """Stream agent responses following agent.py/inference.py logic."""
         chat_state = {
@@ -326,6 +341,20 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
         logger.info(f"Processing finished in {total_time:.2f} seconds.")
 
+        # ---> Display Gold Answer if available (passed via gold_answer_state)
+        if gold_answer_state:
+            logger.info("Displaying gold answer.")
+            messages.append(
+                gr.ChatMessage(
+                    role="assistant",
+                    content=gold_answer_state,
+                    metadata={"title": "✅ Correct Answer (For comparison)"},
+                )
+            )
+            yield messages
+        else:
+            logger.info("No gold answer to display for this query.")
+
     with gr.Blocks(title="DeepSearch - Visible Thinking") as interface:
         gr.Markdown("# 🧠 DeepSearch with Visible Thinking")
         gr.Markdown("Watch as the AI thinks, searches, and processes information to answer your questions.")
@@ -345,13 +374,13 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
                 placeholder="Type your message here...", show_label=False, container=False, elem_id="msg-input"
             )
 
-            example_questions = [
-                "What year was the document approved by the Mission Evaluation Team?",
-                "Summarize the key findings regarding the oxygen tank failure.",
-                "Who was the commander of Apollo 13?",
-                "What were the main recommendations from the review board?",
-            ]
-            gr.Examples(examples=example_questions, inputs=msg, label="Example Questions", examples_per_page=4)
+            # Use questions from dataset as examples
+            gr.Examples(
+                examples=example_questions,
+                inputs=msg,
+                label="Example Questions with correct answer for comparison",
+                examples_per_page=4,
+            )
 
             with gr.Row():
                 clear = gr.Button("Clear Chat")
@@ -387,16 +416,24 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
             return "", history
 
         submitted_msg_state = gr.State("")
+        gold_answer_state = gr.State(None)  # State to hold the gold answer
 
+        # Function to check if submitted message is an example and store gold answer
+        def check_if_example_and_store_answer(msg_text):
+            gold_answer = qa_map.get(msg_text)  # Returns None if msg_text is not in examples
+            logger.info(f"Checking for gold answer for: '{msg_text[:50]}...'. Found: {bool(gold_answer)}")
+            return gold_answer
+
-        # Chain events:
-        # 1. User submits -> store msg text in state
-        # 2. .then() -> add_user_message (updates chatbot UI history, clears textbox)
-        # 3. .then() -> stream_agent_response (takes stored msg text and updated chatbot history)
         submit.click(
             lambda msg_text: msg_text,
             inputs=[msg],
             outputs=[submitted_msg_state],
             queue=False,
+        ).then(  # Check for gold answer immediately after storing submitted message
+            check_if_example_and_store_answer,
+            inputs=[submitted_msg_state],
+            outputs=[gold_answer_state],
+            queue=False,
         ).then(
             add_user_message,
             inputs=[submitted_msg_state, chatbot],
@@ -404,7 +441,14 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
             queue=False,
         ).then(
             stream_agent_response,
-            inputs=[submitted_msg_state, chatbot, temp_slider, max_iter_slider, num_results_slider],
+            inputs=[
+                submitted_msg_state,
+                chatbot,
+                temp_slider,
+                max_iter_slider,
+                num_results_slider,
+                gold_answer_state,
+            ],  # Pass gold answer state
             outputs=chatbot,
         )
@@ -413,6 +457,11 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
             inputs=[msg],
             outputs=[submitted_msg_state],
             queue=False,
+        ).then(  # Check for gold answer immediately after storing submitted message
+            check_if_example_and_store_answer,
+            inputs=[submitted_msg_state],
+            outputs=[gold_answer_state],
+            queue=False,
         ).then(
             add_user_message,
             inputs=[submitted_msg_state, chatbot],
@@ -420,11 +469,18 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
             queue=False,
         ).then(
             stream_agent_response,
-            inputs=[submitted_msg_state, chatbot, temp_slider, max_iter_slider, num_results_slider],
+            inputs=[
+                submitted_msg_state,
+                chatbot,
+                temp_slider,
+                max_iter_slider,
+                num_results_slider,
+                gold_answer_state,
+            ],  # Pass gold answer state
            outputs=chatbot,
         )
 
-        clear.click(lambda: ([], ""), None, [chatbot, submitted_msg_state])
+        clear.click(lambda: ([], None), None, [chatbot, gold_answer_state])  # Also clear gold answer state
 
         system_prompt_state = gr.State(system_prompt)
         # TODO: Currently, changing the system prompt mid-chat won't affect the ongoing stream_agent_response.

@@ -23,8 +23,8 @@ dependencies = [
     "langchain-community",
     "Markdown",
     "tokenizers",
-    "unsloth==2025.2.14",
-    "unsloth_zoo==2025.2.7",
+    "unsloth",
+    "unsloth_zoo",
     "unstructured",
     "vllm==0.7.2",
     "transformers==4.49.0",
@@ -34,5 +34,7 @@ dependencies = [
     "gradio",
     "tensorboard",
     "pytest",
-    "wandb"
+    "wandb",
+    "requests>=2.31.0",
+    "tqdm>=4.66.1"
 ]