@@ -21,7 +21,7 @@ from src import (
format_search_results,
get_system_prompt,
)
- from src.search_module import load_vectorstore, search
+ from src.search_module import get_qa_dataset, load_vectorstore, search

# TODO: check if can reuse tokenizer adapter
@@ -96,6 +96,20 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
system_prompt = system_prompt or get_system_prompt()
tokenizer_for_template = cast(PreTrainedTokenizer, tokenizer)

+ # Load QA dataset for examples and gold answers
+ try:
+ _, test_dataset = get_qa_dataset()
+ qa_map = {q: a for q, a in zip(test_dataset["prompt"], test_dataset["answer"])}
+ example_questions = list(qa_map.keys())
+ logger.info(f"Loaded {len(example_questions)} QA examples.")
+ except Exception as e:
+ logger.error(f"Failed to load QA dataset: {e}")
+ qa_map = {}
+ example_questions = [
+ "What year was the document approved by the Mission Evaluation Team?",
+ "Failed to load dataset examples.",
+ ] # Provide fallback examples

def get_chat_num_tokens(current_chat_state: dict) -> int:
"""Helper to get number of tokens in chat state."""
try:
@@ -112,6 +126,7 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
temp: float,
max_iter: int = 20,
num_search_results: int = 2,
+ gold_answer_state: str | None = None,
) -> Iterator[list[gr.ChatMessage]]:
"""Stream agent responses following agent.py/inference.py logic."""
chat_state = {
@@ -326,6 +341,20 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
logger.info(f"Processing finished in {total_time:.2f} seconds.")

+ # ---> Display Gold Answer if available (passed via gold_answer_state)
+ if gold_answer_state:
+ logger.info("Displaying gold answer.")
+ messages.append(
+ gr.ChatMessage(
+ role="assistant",
+ content=gold_answer_state,
+ metadata={"title": "✅ Correct Answer (For comparison)"},
+ )
+ )
+ yield messages
+ else:
+ logger.info("No gold answer to display for this query.")

with gr.Blocks(title="DeepSearch - Visible Thinking") as interface:
gr.Markdown("# 🧠 DeepSearch with Visible Thinking")
gr.Markdown("Watch as the AI thinks, searches, and processes information to answer your questions.")
@@ -345,13 +374,13 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
placeholder="Type your message here...", show_label=False, container=False, elem_id="msg-input"
)

- example_questions = [
- "What year was the document approved by the Mission Evaluation Team?",
- "Summarize the key findings regarding the oxygen tank failure.",
- "Who was the commander of Apollo 13?",
- "What were the main recommendations from the review board?",
- ]
- gr.Examples(examples=example_questions, inputs=msg, label="Example Questions", examples_per_page=4)
+ # Use questions from dataset as examples
+ gr.Examples(
+ examples=example_questions,
+ inputs=msg,
+ label="Example Questions with correct answer for comparison",
+ examples_per_page=4,
+ )

with gr.Row():
clear = gr.Button("Clear Chat")
@@ -387,16 +416,24 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
return "", history

submitted_msg_state = gr.State("")
+ gold_answer_state = gr.State(None) # State to hold the gold answer

+ # Function to check if submitted message is an example and store gold answer
+ def check_if_example_and_store_answer(msg_text):
+ gold_answer = qa_map.get(msg_text) # Returns None if msg_text is not in examples
+ logger.info(f"Checking for gold answer for: '{msg_text[:50]}...'. Found: {bool(gold_answer)}")
+ return gold_answer

# Chain events:
# 1. User submits -> store msg text in state
# 2. .then() -> add_user_message (updates chatbot UI history, clears textbox)
# 3. .then() -> stream_agent_response (takes stored msg text and updated chatbot history)
submit.click(
lambda msg_text: msg_text,
inputs=[msg],
outputs=[submitted_msg_state],
queue=False,
+ ).then( # Check for gold answer immediately after storing submitted message
+ check_if_example_and_store_answer,
+ inputs=[submitted_msg_state],
+ outputs=[gold_answer_state],
+ queue=False,
).then(
add_user_message,
inputs=[submitted_msg_state, chatbot],
@@ -404,7 +441,14 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
queue=False,
).then(
stream_agent_response,
- inputs=[submitted_msg_state, chatbot, temp_slider, max_iter_slider, num_results_slider],
+ inputs=[
+ submitted_msg_state,
+ chatbot,
+ temp_slider,
+ max_iter_slider,
+ num_results_slider,
+ gold_answer_state,
+ ], # Pass gold answer state
outputs=chatbot,
)
@@ -413,6 +457,11 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
inputs=[msg],
outputs=[submitted_msg_state],
queue=False,
+ ).then( # Check for gold answer immediately after storing submitted message
+ check_if_example_and_store_answer,
+ inputs=[submitted_msg_state],
+ outputs=[gold_answer_state],
+ queue=False,
).then(
add_user_message,
inputs=[submitted_msg_state, chatbot],
@@ -420,11 +469,18 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
queue=False,
).then(
stream_agent_response,
- inputs=[submitted_msg_state, chatbot, temp_slider, max_iter_slider, num_results_slider],
+ inputs=[
+ submitted_msg_state,
+ chatbot,
+ temp_slider,
+ max_iter_slider,
+ num_results_slider,
+ gold_answer_state,
+ ], # Pass gold answer state
outputs=chatbot,
)

clear.click(lambda: ([], ""), None, [chatbot, submitted_msg_state])
+ clear.click(lambda: ([], None), None, [chatbot, gold_answer_state]) # Also clear gold answer state

system_prompt_state = gr.State(system_prompt)
# TODO: Currently, changing the system prompt mid-chat won't affect the ongoing stream_agent_response.
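Note for reviewers: the wiring this diff depends on is Gradio's event chaining ("# Chain events" comments above) — a value is captured into a gr.State on submit and consumed by a later .then() step as an extra input. The minimal sketch below shows that pattern in isolation, assuming a recent Gradio version; qa_map contents and the store_gold_answer/respond handler names are illustrative stand-ins, not code from this PR.

import gradio as gr

# Illustrative stand-in for the QA map the diff builds from get_qa_dataset().
qa_map = {"What year was the document approved by the Mission Evaluation Team?": "1970"}

def store_gold_answer(question: str):
    # Step 1 of the chain: look up the gold answer (None if the text is not an example question).
    return qa_map.get(question)

def respond(question: str, history: list | None, gold_answer: str | None):
    # Step 2 of the chain: produce the assistant reply and, if a gold answer was stored,
    # append it for comparison (the real code streams this via gr.ChatMessage instead).
    history = list(history or [])
    history.append({"role": "user", "content": question})
    history.append({"role": "assistant", "content": f"(model answer to: {question})"})
    if gold_answer:
        history.append({"role": "assistant", "content": f"✅ Correct answer: {gold_answer}"})
    return history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")
    msg = gr.Textbox()
    gold_answer_state = gr.State(None)

    # Same wiring idea as the PR: store the gold answer first, then run the
    # response handler with the state passed as an extra positional input.
    msg.submit(
        store_gold_answer, inputs=[msg], outputs=[gold_answer_state], queue=False
    ).then(
        respond, inputs=[msg, chatbot, gold_answer_state], outputs=chatbot
    )

if __name__ == "__main__":
    demo.launch()

The relevant property is that gr.State values are per-session and are passed to handlers like any other input, which is why the diff can thread gold_answer_state through stream_agent_response without changing how the chat history itself is managed.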