feat: integrate QA dataset loading and display gold answers in Gradio interface

main
thinhlpg 1 month ago
parent 7376f596a5
commit 41b7889a30

@@ -21,7 +21,7 @@ from src import (
     format_search_results,
     get_system_prompt,
 )
-from src.search_module import load_vectorstore, search
+from src.search_module import get_qa_dataset, load_vectorstore, search

 # TODO: check if can reuse tokenizer adapter
@@ -96,6 +96,20 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
     system_prompt = system_prompt or get_system_prompt()
     tokenizer_for_template = cast(PreTrainedTokenizer, tokenizer)

+    # Load QA dataset for examples and gold answers
+    try:
+        _, test_dataset = get_qa_dataset()
+        qa_map = {q: a for q, a in zip(test_dataset["prompt"], test_dataset["answer"])}
+        example_questions = list(qa_map.keys())
+        logger.info(f"Loaded {len(example_questions)} QA examples.")
+    except Exception as e:
+        logger.error(f"Failed to load QA dataset: {e}")
+        qa_map = {}
+        example_questions = [
+            "What year was the document approved by the Mission Evaluation Team?",
+            "Failed to load dataset examples.",
+        ]  # Provide fallback examples
+
     def get_chat_num_tokens(current_chat_state: dict) -> int:
         """Helper to get number of tokens in chat state."""
         try:
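For context, the loader above assumes `get_qa_dataset()` (imported from `src.search_module`, whose implementation is not shown in this diff) returns a (train, test) pair whose test split exposes parallel "prompt" and "answer" columns. A hypothetical stand-in with that shape:

from datasets import Dataset  # Hugging Face datasets

def get_qa_dataset():
    # Hypothetical stand-in matching the shape the loader expects:
    # (train, test) splits with parallel "prompt" and "answer" columns.
    test = Dataset.from_dict({
        "prompt": ["Who was the commander of Apollo 13?"],
        "answer": ["Jim Lovell"],
    })
    train = Dataset.from_dict({"prompt": [], "answer": []})
    return train, test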
@@ -112,6 +126,7 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
         temp: float,
         max_iter: int = 20,
         num_search_results: int = 2,
+        gold_answer_state: str | None = None,
     ) -> Iterator[list[gr.ChatMessage]]:
         """Stream agent responses following agent.py/inference.py logic."""
         chat_state = {
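A note on the return type: `Iterator[list[gr.ChatMessage]]` means Gradio re-renders the chatbot on every `yield`, so the handler yields the complete, growing message list each time. A minimal sketch of that contract (names here are illustrative, not from the commit):

from collections.abc import Iterator
import gradio as gr

def stream_demo(message: str) -> Iterator[list[gr.ChatMessage]]:
    # Yield the full message list each time; Gradio redraws the Chatbot per yield.
    messages = [gr.ChatMessage(role="user", content=message)]
    yield messages
    messages.append(gr.ChatMessage(role="assistant", content="Searching..."))
    yield messages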
@@ -326,6 +341,20 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
         logger.info(f"Processing finished in {total_time:.2f} seconds.")

+        # ---> Display Gold Answer if available (passed via gold_answer_state)
+        if gold_answer_state:
+            logger.info("Displaying gold answer.")
+            messages.append(
+                gr.ChatMessage(
+                    role="assistant",
+                    content=gold_answer_state,
+                    metadata={"title": "✅ Correct Answer (For comparison)"},
+                )
+            )
+            yield messages
+        else:
+            logger.info("No gold answer to display for this query.")
+
     with gr.Blocks(title="DeepSearch - Visible Thinking") as interface:
         gr.Markdown("# 🧠 DeepSearch with Visible Thinking")
         gr.Markdown("Watch as the AI thinks, searches, and processes information to answer your questions.")
@@ -345,13 +374,13 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
             placeholder="Type your message here...", show_label=False, container=False, elem_id="msg-input"
         )

-        example_questions = [
-            "What year was the document approved by the Mission Evaluation Team?",
-            "Summarize the key findings regarding the oxygen tank failure.",
-            "Who was the commander of Apollo 13?",
-            "What were the main recommendations from the review board?",
-        ]
-        gr.Examples(examples=example_questions, inputs=msg, label="Example Questions", examples_per_page=4)
+        # Use questions from dataset as examples
+        gr.Examples(
+            examples=example_questions,
+            inputs=msg,
+            label="Example Questions with correct answer for comparison",
+            examples_per_page=4,
+        )

         with gr.Row():
             clear = gr.Button("Clear Chat")
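One thing this hunk does not do is bound how many dataset questions get surfaced; with a large test split the examples panel grows accordingly. A possible guard, not part of this commit (`MAX_EXAMPLES` is hypothetical):

MAX_EXAMPLES = 8  # hypothetical cap on surfaced dataset questions
gr.Examples(
    examples=example_questions[:MAX_EXAMPLES],
    inputs=msg,
    label="Example Questions with correct answer for comparison",
    examples_per_page=4,
)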
@@ -387,16 +416,24 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
             return "", history

         submitted_msg_state = gr.State("")
+        gold_answer_state = gr.State(None)  # State to hold the gold answer
+
+        # Function to check if submitted message is an example and store gold answer
+        def check_if_example_and_store_answer(msg_text):
+            gold_answer = qa_map.get(msg_text)  # Returns None if msg_text is not in examples
+            logger.info(f"Checking for gold answer for: '{msg_text[:50]}...'. Found: {bool(gold_answer)}")
+            return gold_answer

         # Chain events:
         # 1. User submits -> store msg text in state
         # 2. .then() -> add_user_message (updates chatbot UI history, clears textbox)
         # 3. .then() -> stream_agent_response (takes stored msg text and updated chatbot history)
         submit.click(
             lambda msg_text: msg_text,
             inputs=[msg],
             outputs=[submitted_msg_state],
             queue=False,
+        ).then(  # Check for gold answer immediately after storing submitted message
+            check_if_example_and_store_answer,
+            inputs=[submitted_msg_state],
+            outputs=[gold_answer_state],
+            queue=False,
         ).then(
             add_user_message,
             inputs=[submitted_msg_state, chatbot],
@@ -404,7 +441,14 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
             queue=False,
         ).then(
             stream_agent_response,
-            inputs=[submitted_msg_state, chatbot, temp_slider, max_iter_slider, num_results_slider],
+            inputs=[
+                submitted_msg_state,
+                chatbot,
+                temp_slider,
+                max_iter_slider,
+                num_results_slider,
+                gold_answer_state,
+            ],  # Pass gold answer state
             outputs=chatbot,
         )
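The wiring above follows Gradio's standard pattern: each `.then()` step runs after the previous one finishes, and a `gr.State` listed in `inputs`/`outputs` carries a value between steps. A minimal sketch of that pattern in isolation:

import gradio as gr

with gr.Blocks() as demo:
    box = gr.Textbox(label="Input")
    stored = gr.State("")  # carries the submitted text between steps
    out = gr.Textbox(label="Output")
    btn = gr.Button("Go")
    # Step 1 stores the text in the State; step 2 reads the State's value.
    btn.click(lambda t: t, inputs=[box], outputs=[stored], queue=False).then(
        lambda t: t.upper(), inputs=[stored], outputs=[out]
    )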
@@ -413,6 +457,11 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
             inputs=[msg],
             outputs=[submitted_msg_state],
             queue=False,
+        ).then(  # Check for gold answer immediately after storing submitted message
+            check_if_example_and_store_answer,
+            inputs=[submitted_msg_state],
+            outputs=[gold_answer_state],
+            queue=False,
         ).then(
             add_user_message,
             inputs=[submitted_msg_state, chatbot],
@@ -420,11 +469,18 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
             queue=False,
         ).then(
             stream_agent_response,
-            inputs=[submitted_msg_state, chatbot, temp_slider, max_iter_slider, num_results_slider],
+            inputs=[
+                submitted_msg_state,
+                chatbot,
+                temp_slider,
+                max_iter_slider,
+                num_results_slider,
+                gold_answer_state,
+            ],  # Pass gold answer state
             outputs=chatbot,
         )

         clear.click(lambda: ([], ""), None, [chatbot, submitted_msg_state])
+        clear.click(lambda: ([], None), None, [chatbot, gold_answer_state])  # Also clear gold answer state

         system_prompt_state = gr.State(system_prompt)
         # TODO: Currently, changing the system prompt mid-chat won't affect the ongoing stream_agent_response.
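Since `submit.click` and `msg.submit` now register identical chains, a possible follow-up (not in this commit) is to define the chain once. `wire_submit` below is a hypothetical helper, and `add_user_message`'s outputs are inferred from its `return "", history`:

def wire_submit(trigger):
    # Register the same four-step chain for a given trigger
    # (submit.click or msg.submit).
    trigger(
        lambda msg_text: msg_text, inputs=[msg], outputs=[submitted_msg_state], queue=False
    ).then(
        check_if_example_and_store_answer,
        inputs=[submitted_msg_state],
        outputs=[gold_answer_state],
        queue=False,
    ).then(
        add_user_message, inputs=[submitted_msg_state, chatbot], outputs=[msg, chatbot], queue=False
    ).then(
        stream_agent_response,
        inputs=[submitted_msg_state, chatbot, temp_slider, max_iter_slider,
                num_results_slider, gold_answer_state],
        outputs=chatbot,
    )

wire_submit(submit.click)
wire_submit(msg.submit)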

@@ -23,8 +23,8 @@ dependencies = [
     "langchain-community",
     "Markdown",
     "tokenizers",
-    "unsloth==2025.2.14",
-    "unsloth_zoo==2025.2.7",
+    "unsloth",
+    "unsloth_zoo",
     "unstructured",
     "vllm==0.7.2",
     "transformers==4.49.0",
@@ -34,5 +34,7 @@ dependencies = [
     "gradio",
     "tensorboard",
     "pytest",
-    "wandb"
+    "wandb",
+    "requests>=2.31.0",
+    "tqdm>=4.66.1"
 ]
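Unpinning `unsloth`/`unsloth_zoo` lets the resolver pick whatever versions are compatible with the still-pinned `vllm==0.7.2` and `transformers==4.49.0`. A quick sanity check of what actually resolved (a sketch, not part of the commit):

from importlib.metadata import PackageNotFoundError, version

# Print the resolved versions of the packages this commit touches.
for pkg in ("unsloth", "unsloth_zoo", "requests", "tqdm", "vllm", "transformers"):
    try:
        print(f"{pkg}=={version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")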