feat: update demo from DeepSearch to ReZero, adjusting related logging and UI components

main
thinhlpg 4 weeks ago
parent 9738b80353
commit 0b4bf54833

@ -1,7 +1,7 @@
""" """
Gradio web interface for DeepSearch. Gradio web interface for ReZero.
This module provides a simple web interface for interacting with the DeepSearch model This module provides a simple web interface for interacting with the ReZero model
using Gradio. It implements the core functionality directly for better modularity. using Gradio. It implements the core functionality directly for better modularity.
""" """
@ -25,8 +25,6 @@ from src import (
) )
from src.search_module import get_qa_dataset, load_vectorstore, search from src.search_module import get_qa_dataset, load_vectorstore, search
# TODO: check if can reuse tokenizer adapter
def extract_answer_tag(text: str) -> tuple[bool, str | None]: def extract_answer_tag(text: str) -> tuple[bool, str | None]:
"""Check if text contains an answer tag and extract the answer content if found. """Check if text contains an answer tag and extract the answer content if found.
@ -111,8 +109,8 @@ def get_chat_num_tokens(current_chat_state: dict, tokenizer: PreTrainedTokenizer
def create_deepsearch_tab(model, tokenizer, assistant_marker, system_prompt, temperature): def create_deepsearch_tab(model, tokenizer, assistant_marker, system_prompt, temperature):
"""Creates the UI components and logic for the DeepSearch (Vector DB) tab.""" """Creates the UI components and logic for the ReZero (Vector DB) tab."""
logger.info("Creating DeepSearch Tab") logger.info("Creating ReZero Tab")
# tokenizer_for_template = cast(PreTrainedTokenizer, tokenizer) # Global now # tokenizer_for_template = cast(PreTrainedTokenizer, tokenizer) # Global now
# Load QA dataset for examples and gold answers # Load QA dataset for examples and gold answers
@ -120,16 +118,16 @@ def create_deepsearch_tab(model, tokenizer, assistant_marker, system_prompt, tem
_, test_dataset = get_qa_dataset() _, test_dataset = get_qa_dataset()
qa_map = {q: a for q, a in zip(test_dataset["prompt"], test_dataset["answer"])} qa_map = {q: a for q, a in zip(test_dataset["prompt"], test_dataset["answer"])}
example_questions = list(qa_map.keys()) example_questions = list(qa_map.keys())
logger.info(f"Loaded {len(example_questions)} QA examples for DeepSearch tab.") logger.info(f"Loaded {len(example_questions)} QA examples for ReZero tab.")
except Exception as e: except Exception as e:
logger.error(f"Failed to load QA dataset for DeepSearch tab: {e}") logger.error(f"Failed to load QA dataset for ReZero tab: {e}")
qa_map = {} qa_map = {}
example_questions = [ example_questions = [
"What year was the document approved by the Mission Evaluation Team?", "What year was the document approved by the Mission Evaluation Team?",
"Failed to load dataset examples.", "Failed to load dataset examples.",
] ]
# --- Agent Streaming Logic for DeepSearch --- # --- Agent Streaming Logic for ReZero ---
def stream_agent_response( def stream_agent_response(
message: str, message: str,
history_gr: list[gr.ChatMessage], history_gr: list[gr.ChatMessage],
@ -187,13 +185,13 @@ def create_deepsearch_tab(model, tokenizer, assistant_marker, system_prompt, tem
f"TOKEN LIMIT EXCEEDED (Before Generation): Current {current_length_before_gen}, Start {initial_token_length}" f"TOKEN LIMIT EXCEEDED (Before Generation): Current {current_length_before_gen}, Start {initial_token_length}"
) )
chat_state["finished"] = True chat_state["finished"] = True
messages[think_msg_idx] = gr.ChatMessage( messages[think_msg_idx] = gr.ChatMessage(
role="assistant", role="assistant",
content="Context length limit reached.", content="Context length limit reached.",
metadata={"title": "⚠️ Token Limit", "status": "done"}, metadata={"title": "⚠️ Token Limit", "status": "done"},
) )
yield messages yield messages
break break
try: try:
generation_params = get_sampling_params(temp) generation_params = get_sampling_params(temp)
@ -369,9 +367,9 @@ def create_deepsearch_tab(model, tokenizer, assistant_marker, system_prompt, tem
else: else:
logger.info("No gold answer to display for this query.") logger.info("No gold answer to display for this query.")
# --- UI Layout for DeepSearch Tab --- # --- UI Layout for ReZero Tab ---
with gr.Blocks(analytics_enabled=False) as deepsearch_tab: with gr.Blocks(analytics_enabled=False) as deepsearch_tab:
gr.Markdown("# 🧠 DeepSearch with Visible Thinking (Vector DB)") gr.Markdown("# 🧠 ReZero: Enhancing LLM search ability by trying one-more-time")
gr.Markdown("Ask questions answered using the local vector database.") gr.Markdown("Ask questions answered using the local vector database.")
with gr.Row(): with gr.Row():
@ -423,7 +421,7 @@ def create_deepsearch_tab(model, tokenizer, assistant_marker, system_prompt, tem
info="How many results to retrieve per search query", info="How many results to retrieve per search query",
) )
# --- Event Handlers for DeepSearch Tab --- # --- Event Handlers for ReZero Tab ---
def add_user_message(user_msg_text: str, history: list[gr.ChatMessage]) -> tuple[str, list[gr.ChatMessage]]: def add_user_message(user_msg_text: str, history: list[gr.ChatMessage]) -> tuple[str, list[gr.ChatMessage]]:
"""Appends user message to chat history and clears input.""" """Appends user message to chat history and clears input."""
if user_msg_text and user_msg_text.strip(): if user_msg_text and user_msg_text.strip():
@ -918,7 +916,7 @@ def main():
# Combine tabs # Combine tabs
interface = gr.TabbedInterface( interface = gr.TabbedInterface(
[tab1, tab2], tab_names=["DeepSearch (VectorDB)", "Tavily Search (Web)"], title="DeepSearch Agent UI" [tab1, tab2], tab_names=["ReZero (VectorDB)", "Tavily Search (Web)"], title="ReZero Demo"
) )
logger.info("Launching Gradio Tabbed Interface...") logger.info("Launching Gradio Tabbed Interface...")

Loading…
Cancel
Save