"""Gradio web interface for DeepSearch.

This module provides a simple web interface for interacting with the DeepSearch model
using Gradio. It implements the core functionality directly for better modularity.
"""

import re
import sys
import time
from typing import Iterator, cast

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer

# Imports from the local config and src packages
from config import GENERATOR_MODEL_DIR, logger
from src import (
    apply_chat_template,
    build_user_prompt,
    format_search_results,
    get_system_prompt,
)
from src.search_module import get_qa_dataset, load_vectorstore, search

# TODO: check if the tokenizer adapter can be reused here


def extract_answer_tag(text: str) -> tuple[bool, str | None]:
    """Check if text contains an answer tag and extract the answer content if found.

    Returns:
        tuple: (has_answer, answer_content)
    """
    pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL | re.IGNORECASE)
    match = pattern.search(text)
    if match:
        content = match.group(1).strip()
        return True, content
    return False, None


def extract_thinking_content(text: str) -> str | None:
    """Extract thinking content from text between <think> tags."""
    pattern = re.compile(r"<think>(.*?)</think>", re.DOTALL | re.IGNORECASE)
    match = pattern.search(text)
    if match:
        return match.group(1).strip()
    return None


def extract_search_query(text: str) -> str | None:
    """Extract the search query from text between <search> tags."""
    pattern = re.compile(r"<search>(.*?)</search>", re.DOTALL | re.IGNORECASE)
    match = pattern.search(text)
    if match:
        return match.group(1).strip()
    return None


def setup_model_and_tokenizer(model_path: str):
    """Initialize model and tokenizer."""
    logger.info(f"Setting up model from {model_path}...")

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype="float16",
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
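
    # Note: the marker below is the Llama 3 chat template's assistant header. If
    # the loaded model uses a different template, the split in
    # stream_agent_response will not match and the token-slicing fallback takes over.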
    # Defaulting to the one used in inference.py; adjust if needed for your specific model.
    assistant_marker = "<|start_header_id|>assistant<|end_header_id|>"
    logger.info(f"Using assistant marker: '{assistant_marker}' for response splitting.")

    logger.info("Model and tokenizer setup complete.")
    return model, tokenizer, assistant_marker


def get_sampling_params(temperature: float = 0.7, max_tokens: int = 4096):
    """Get sampling parameters for generation."""
    return {
        "temperature": temperature,
        "top_p": 0.95,
        "max_new_tokens": max_tokens,
        "do_sample": True,
    }
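

# The keys above are standard `transformers` generate() keyword arguments, so the
# dict can be splatted straight into the call: model.generate(**inputs, **params).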


def create_interface(model_path: str, temperature: float = 0.7, system_prompt: str | None = None):
    """Create Gradio interface for DeepSearch."""
    model, tokenizer, assistant_marker = setup_model_and_tokenizer(model_path)
    system_prompt = system_prompt or get_system_prompt()
    tokenizer_for_template = cast(PreTrainedTokenizer, tokenizer)

    # Load QA dataset for examples and gold answers
    try:
        _, test_dataset = get_qa_dataset()
        qa_map = {q: a for q, a in zip(test_dataset["prompt"], test_dataset["answer"])}
        example_questions = list(qa_map.keys())
        logger.info(f"Loaded {len(example_questions)} QA examples.")
    except Exception as e:
        logger.error(f"Failed to load QA dataset: {e}")
        qa_map = {}
        # Provide fallback examples
        example_questions = [
            "What year was the document approved by the Mission Evaluation Team?",
            "Failed to load dataset examples.",
        ]
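
    # qa_map maps each example question to its gold answer, so the UI can show
    # the reference answer alongside the model's response when an example is used.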

    def get_chat_num_tokens(current_chat_state: dict) -> int:
        """Helper to get the number of tokens in the chat state."""
        try:
            chat_text = apply_chat_template(current_chat_state, tokenizer=tokenizer_for_template)["text"]
            input_ids = tokenizer.encode(chat_text, add_special_tokens=False)
            return len(input_ids)
        except Exception as e:
            logger.error(f"Error calculating token count: {e}")
            return sys.maxsize
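
    # Returning sys.maxsize on failure effectively trips the context-length
    # guards in stream_agent_response, so a token-counting error ends the run
    # instead of looping indefinitely.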

    def stream_agent_response(
        message: str,
        history_gr: list[gr.ChatMessage],
        temp: float,
        max_iter: int = 20,
        num_search_results: int = 2,
        gold_answer_state: str | None = None,
    ) -> Iterator[list[gr.ChatMessage]]:
        """Stream agent responses following agent.py/inference.py logic."""
        chat_state = {
            "messages": [{"role": "system", "content": system_prompt}],
            "finished": False,
        }
        # Convert Gradio history to internal format; skip the last user msg (passed separately)
        processed_history = history_gr[:-1] if history_gr else []
        for msg_obj in processed_history:
            role = getattr(msg_obj, "role", "unknown")
            content = getattr(msg_obj, "content", "")
            if role == "user":
                chat_state["messages"].append({"role": "user", "content": build_user_prompt(content)})
            elif role == "assistant":
                chat_state["messages"].append({"role": "assistant", "content": content})

        chat_state["messages"].append({"role": "user", "content": build_user_prompt(message)})

        initial_token_length = get_chat_num_tokens(chat_state)
        max_new_tokens_allowed = get_sampling_params(temp)["max_new_tokens"]

        messages = history_gr

        start_time = time.time()
        iterations = 0
        last_assistant_response = ""

        while not chat_state.get("finished", False) and iterations < max_iter:
            iterations += 1
            current_turn_start_time = time.time()

            think_msg_idx = len(messages)
            messages.append(
                gr.ChatMessage(
                    role="assistant",
                    content="Thinking...",
                    metadata={"title": "🧠 Thinking", "status": "pending"},
                )
            )
            yield messages
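
            # Context-length guard: stop once the conversation has grown by more
            # than the generation budget (max_new_tokens) since the run started.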
            current_length_before_gen = get_chat_num_tokens(chat_state)
            if current_length_before_gen - initial_token_length > max_new_tokens_allowed:
                logger.warning(
                    f"TOKEN LIMIT EXCEEDED (Before Generation): Current {current_length_before_gen}, Start {initial_token_length}"
                )
                chat_state["finished"] = True
                messages[think_msg_idx] = gr.ChatMessage(
                    role="assistant",
                    content="Context length limit reached.",
                    metadata={"title": "⚠️ Token Limit", "status": "done"},
                )
                yield messages
                break

            try:
                generation_params = get_sampling_params(temp)
                formatted_prompt = apply_chat_template(chat_state, tokenizer=tokenizer_for_template)["text"]
                inputs = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

                outputs = model.generate(**inputs, **generation_params)
                full_response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
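
                # Two ways to isolate the new assistant turn: split on the chat
                # template's assistant header when present, otherwise decode only
                # the newly generated token ids after the prompt.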
                if assistant_marker in full_response_text:
                    assistant_response = full_response_text.split(assistant_marker)[-1].strip()
                else:
                    inputs_dict = cast(dict, inputs)
                    input_token_length = len(inputs_dict["input_ids"][0])
                    assistant_response = tokenizer.decode(
                        outputs[0][input_token_length:], skip_special_tokens=True
                    ).strip()
                    logger.warning(
                        f"Assistant marker '{assistant_marker}' not found in response. Extracted via token slicing fallback."
                    )

                last_assistant_response = assistant_response
                thinking_content = extract_thinking_content(assistant_response)

                gen_time = time.time() - current_turn_start_time

                display_thinking = thinking_content if thinking_content else "Processing..."
                messages[think_msg_idx] = gr.ChatMessage(
                    role="assistant",
                    content=display_thinking,
                    metadata={"title": "🧠 Thinking", "status": "done", "duration": gen_time},
                )
                yield messages

            except Exception as e:
                logger.error(f"Error during generation: {e}")
                chat_state["finished"] = True
                messages[think_msg_idx] = gr.ChatMessage(
                    role="assistant",
                    content=f"Error during generation: {e}",
                    metadata={"title": "❌ Generation Error", "status": "done"},
                )
                yield messages
                break

            chat_state["messages"].append({"role": "assistant", "content": assistant_response})

            search_query = extract_search_query(assistant_response)

            if not search_query:
                chat_state["finished"] = True
            else:
                search_msg_idx = len(messages)
                messages.append(
                    gr.ChatMessage(
                        role="assistant",
                        content=f"Searching for: {search_query}",
                        metadata={"title": "🔍 Search", "status": "pending"},
                    )
                )
                yield messages
                search_start = time.time()
                try:
                    results = search(search_query, return_type=str, results=num_search_results)
                    search_duration = time.time() - search_start

                    messages[search_msg_idx] = gr.ChatMessage(
                        role="assistant",
                        content=f"{search_query}",
                        metadata={"title": "🔍 Search", "duration": search_duration},
                    )
                    yield messages
                    display_results = format_search_results(results)
                    messages.append(
                        gr.ChatMessage(
                            role="assistant",
                            content=display_results,
                            metadata={"title": "ℹ️ Information", "status": "done"},
                        )
                    )
                    yield messages

                    formatted_results = f"<information>{results}</information>"
                    chat_state["messages"].append({"role": "user", "content": formatted_results})

                except Exception as e:
                    search_duration = time.time() - search_start
                    logger.error(f"Search failed: {str(e)}")
                    messages[search_msg_idx] = gr.ChatMessage(
                        role="assistant",
                        content=f"Search failed: {str(e)}",
                        metadata={"title": "❌ Search Error", "status": "done", "duration": search_duration},
                    )
                    yield messages
                    chat_state["messages"].append({"role": "system", "content": f"Error during search: {str(e)}"})
                    chat_state["finished"] = True

            current_length_after_iter = get_chat_num_tokens(chat_state)
            if current_length_after_iter - initial_token_length > max_new_tokens_allowed:
                logger.warning(
                    f"TOKEN LIMIT EXCEEDED (After Iteration): Current {current_length_after_iter}, Start {initial_token_length}"
                )
                chat_state["finished"] = True
                if messages[-1].metadata.get("title") != "⚠️ Token Limit":
                    messages.append(
                        gr.ChatMessage(
                            role="assistant",
                            content="Context length limit reached during processing.",
                            metadata={"title": "⚠️ Token Limit", "status": "done"},
                        )
                    )
                yield messages

        total_time = time.time() - start_time

        if not chat_state.get("finished", False) and iterations >= max_iter:
            logger.warning(f"Reached maximum iterations ({max_iter}) without finishing")
            messages.append(
                gr.ChatMessage(
                    role="assistant",
                    content=f"Reached maximum iterations ({max_iter}). Displaying last response:\n\n{last_assistant_response}",
                    metadata={"title": "⚠️ Max Iterations", "status": "done", "duration": total_time},
                )
            )
            yield messages
        elif chat_state.get("finished", False) and last_assistant_response:
            has_answer, answer_content = extract_answer_tag(last_assistant_response)

            if has_answer and answer_content is not None:
                display_title = "📝 Final Answer"
                display_content = answer_content
            else:
                display_title = "💡 Answer"
                display_content = last_assistant_response

            if len(messages) > 0 and messages[-1].content != display_content:
                messages.append(
                    gr.ChatMessage(
                        role="assistant",
                        content=display_content,
                        metadata={"title": display_title, "duration": total_time},
                    )
                )
                yield messages
            elif len(messages) == 0:
                messages.append(
                    gr.ChatMessage(
                        role="assistant",
                        content=display_content,
                        metadata={"title": display_title, "duration": total_time},
                    )
                )
                yield messages
            else:
                # The last message already holds the final content; update its
                # metadata and re-yield so the UI reflects the change.
                messages[-1].metadata["title"] = display_title
                messages[-1].metadata["status"] = "done"
                messages[-1].metadata["duration"] = total_time
                yield messages

        logger.info(f"Processing finished in {total_time:.2f} seconds.")

        # Display the gold answer if available (passed via gold_answer_state)
        if gold_answer_state:
            logger.info("Displaying gold answer.")
            messages.append(
                gr.ChatMessage(
                    role="assistant",
                    content=gold_answer_state,
                    metadata={"title": "✅ Correct Answer (For comparison)"},
                )
            )
            yield messages
        else:
            logger.info("No gold answer to display for this query.")

    with gr.Blocks(title="DeepSearch - Visible Thinking") as interface:
        gr.Markdown("# 🧠 DeepSearch with Visible Thinking")
        gr.Markdown("Watch as the AI thinks, searches, and processes information to answer your questions.")

        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    [],
                    elem_id="chatbot",
                    type="messages",
                    height=600,
                    show_label=False,
                    render_markdown=True,
                    bubble_full_width=False,
                )
                msg = gr.Textbox(
                    placeholder="Type your message here...", show_label=False, container=False, elem_id="msg-input"
                )

                # Use questions from the dataset as examples
                gr.Examples(
                    examples=example_questions,
                    inputs=msg,
                    label="Example questions (gold answer shown for comparison)",
                    examples_per_page=4,
                )

                with gr.Row():
                    clear = gr.Button("Clear Chat")
                    submit = gr.Button("Submit", variant="primary")

            with gr.Column(scale=1):
                gr.Markdown("### Settings")
                temp_slider = gr.Slider(minimum=0.1, maximum=1.0, value=temperature, step=0.1, label="Temperature")
                system_prompt_input = gr.Textbox(
                    label="System Prompt", value=system_prompt, lines=3, info="Controls how the AI behaves"
                )
                max_iter_slider = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=20,
                    step=1,
                    label="Max Search Iterations",
                    info="Maximum number of search-think cycles",
                )
                num_results_slider = gr.Slider(
                    minimum=1,
                    maximum=5,
                    value=2,
                    step=1,
                    label="Number of Search Results",
                    info="How many results to retrieve per search query",
                )

        def add_user_message(user_msg_text: str, history: list[gr.ChatMessage]) -> tuple[str, list[gr.ChatMessage]]:
            """Append the user message to the chat history and clear the input box."""
            if user_msg_text and user_msg_text.strip():
                history.append(gr.ChatMessage(role="user", content=user_msg_text.strip()))
            return "", history

        submitted_msg_state = gr.State("")  # Holds the raw submitted text
        gold_answer_state = gr.State(None)  # Holds the gold answer, if any

        def check_if_example_and_store_answer(msg_text):
            """Return the gold answer if the submitted text is a dataset example, else None."""
            gold_answer = qa_map.get(msg_text)  # None if msg_text is not an example
            logger.info(f"Checking for gold answer for: '{msg_text[:50]}...'. Found: {bool(gold_answer)}")
            return gold_answer
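
        # Event chain for both the Submit button and pressing Enter in the textbox:
        #   1. stash the raw textbox value in submitted_msg_state,
        #   2. look up a gold answer if the text matches a dataset example,
        #   3. append the user turn to the chat and clear the textbox,
        #   4. stream the agent response.
        # Steps 1-3 run with queue=False so the UI updates immediately.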
        submit.click(
            lambda msg_text: msg_text,
            inputs=[msg],
            outputs=[submitted_msg_state],
            queue=False,
        ).then(
            check_if_example_and_store_answer,
            inputs=[submitted_msg_state],
            outputs=[gold_answer_state],
            queue=False,
        ).then(
            add_user_message,
            inputs=[submitted_msg_state, chatbot],
            outputs=[msg, chatbot],
            queue=False,
        ).then(
            stream_agent_response,
            inputs=[
                submitted_msg_state,
                chatbot,
                temp_slider,
                max_iter_slider,
                num_results_slider,
                gold_answer_state,
            ],
            outputs=chatbot,
        )

        msg.submit(
            lambda msg_text: msg_text,
            inputs=[msg],
            outputs=[submitted_msg_state],
            queue=False,
        ).then(
            check_if_example_and_store_answer,
            inputs=[submitted_msg_state],
            outputs=[gold_answer_state],
            queue=False,
        ).then(
            add_user_message,
            inputs=[submitted_msg_state, chatbot],
            outputs=[msg, chatbot],
            queue=False,
        ).then(
            stream_agent_response,
            inputs=[
                submitted_msg_state,
                chatbot,
                temp_slider,
                max_iter_slider,
                num_results_slider,
                gold_answer_state,
            ],
            outputs=chatbot,
        )

        # Clear both the chat history and any stored gold answer
        clear.click(lambda: ([], None), None, [chatbot, gold_answer_state])

        system_prompt_state = gr.State(system_prompt)
        # TODO: changing the system prompt mid-chat won't affect an ongoing stream_agent_response.
        system_prompt_input.change(lambda prompt: prompt, inputs=[system_prompt_input], outputs=[system_prompt_state])

    return interface


def main():
    """Run the Gradio app."""
    model_path = str(GENERATOR_MODEL_DIR)
    logger.info(f"Using model from config: {model_path}")

    interface = create_interface(model_path)
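    # share=True asks Gradio to open a temporary public tunnel URL in addition
    # to the local server; set it to False to keep the app local-only.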
    interface.launch(share=True)


if __name__ == "__main__":
    if load_vectorstore() is None:
        logger.warning("⚠️ FAISS vectorstore could not be loaded. Search functionality may be unavailable.")

    main()
|