|
|
|
@ -1,7 +1,7 @@
|
|
|
|
|
"""
|
|
|
|
|
Gradio web interface for DeepSearch.
|
|
|
|
|
Gradio web interface for ReZero.
|
|
|
|
|
|
|
|
|
|
This module provides a simple web interface for interacting with the DeepSearch model
|
|
|
|
|
This module provides a simple web interface for interacting with the ReZero model
|
|
|
|
|
using Gradio. It implements the core functionality directly for better modularity.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
@ -25,8 +25,6 @@ from src import (
|
|
|
|
|
)
|
|
|
|
|
from src.search_module import get_qa_dataset, load_vectorstore, search
|
|
|
|
|
|
|
|
|
|
# TODO: check if can reuse tokenizer adapter
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_answer_tag(text: str) -> tuple[bool, str | None]:
|
|
|
|
|
"""Check if text contains an answer tag and extract the answer content if found.
|
|
|
|
@ -111,8 +109,8 @@ def get_chat_num_tokens(current_chat_state: dict, tokenizer: PreTrainedTokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_deepsearch_tab(model, tokenizer, assistant_marker, system_prompt, temperature):
|
|
|
|
|
"""Creates the UI components and logic for the DeepSearch (Vector DB) tab."""
|
|
|
|
|
logger.info("Creating DeepSearch Tab")
|
|
|
|
|
"""Creates the UI components and logic for the ReZero (Vector DB) tab."""
|
|
|
|
|
logger.info("Creating ReZero Tab")
|
|
|
|
|
# tokenizer_for_template = cast(PreTrainedTokenizer, tokenizer) # Global now
|
|
|
|
|
|
|
|
|
|
# Load QA dataset for examples and gold answers
|
|
|
|
@ -120,16 +118,16 @@ def create_deepsearch_tab(model, tokenizer, assistant_marker, system_prompt, tem
|
|
|
|
|
_, test_dataset = get_qa_dataset()
|
|
|
|
|
qa_map = {q: a for q, a in zip(test_dataset["prompt"], test_dataset["answer"])}
|
|
|
|
|
example_questions = list(qa_map.keys())
|
|
|
|
|
logger.info(f"Loaded {len(example_questions)} QA examples for DeepSearch tab.")
|
|
|
|
|
logger.info(f"Loaded {len(example_questions)} QA examples for ReZero tab.")
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Failed to load QA dataset for DeepSearch tab: {e}")
|
|
|
|
|
logger.error(f"Failed to load QA dataset for ReZero tab: {e}")
|
|
|
|
|
qa_map = {}
|
|
|
|
|
example_questions = [
|
|
|
|
|
"What year was the document approved by the Mission Evaluation Team?",
|
|
|
|
|
"Failed to load dataset examples.",
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
# --- Agent Streaming Logic for DeepSearch ---
|
|
|
|
|
# --- Agent Streaming Logic for ReZero ---
|
|
|
|
|
def stream_agent_response(
|
|
|
|
|
message: str,
|
|
|
|
|
history_gr: list[gr.ChatMessage],
|
|
|
|
@ -187,13 +185,13 @@ def create_deepsearch_tab(model, tokenizer, assistant_marker, system_prompt, tem
|
|
|
|
|
f"TOKEN LIMIT EXCEEDED (Before Generation): Current {current_length_before_gen}, Start {initial_token_length}"
|
|
|
|
|
)
|
|
|
|
|
chat_state["finished"] = True
|
|
|
|
|
messages[think_msg_idx] = gr.ChatMessage(
|
|
|
|
|
role="assistant",
|
|
|
|
|
content="Context length limit reached.",
|
|
|
|
|
metadata={"title": "⚠️ Token Limit", "status": "done"},
|
|
|
|
|
)
|
|
|
|
|
yield messages
|
|
|
|
|
break
|
|
|
|
|
messages[think_msg_idx] = gr.ChatMessage(
|
|
|
|
|
role="assistant",
|
|
|
|
|
content="Context length limit reached.",
|
|
|
|
|
metadata={"title": "⚠️ Token Limit", "status": "done"},
|
|
|
|
|
)
|
|
|
|
|
yield messages
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
generation_params = get_sampling_params(temp)
|
|
|
|
@ -369,9 +367,9 @@ def create_deepsearch_tab(model, tokenizer, assistant_marker, system_prompt, tem
|
|
|
|
|
else:
|
|
|
|
|
logger.info("No gold answer to display for this query.")
|
|
|
|
|
|
|
|
|
|
# --- UI Layout for DeepSearch Tab ---
|
|
|
|
|
# --- UI Layout for ReZero Tab ---
|
|
|
|
|
with gr.Blocks(analytics_enabled=False) as deepsearch_tab:
|
|
|
|
|
gr.Markdown("# 🧠 DeepSearch with Visible Thinking (Vector DB)")
|
|
|
|
|
gr.Markdown("# 🧠 ReZero: Enhancing LLM search ability by trying one-more-time")
|
|
|
|
|
gr.Markdown("Ask questions answered using the local vector database.")
|
|
|
|
|
|
|
|
|
|
with gr.Row():
|
|
|
|
@ -423,7 +421,7 @@ def create_deepsearch_tab(model, tokenizer, assistant_marker, system_prompt, tem
|
|
|
|
|
info="How many results to retrieve per search query",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# --- Event Handlers for DeepSearch Tab ---
|
|
|
|
|
# --- Event Handlers for ReZero Tab ---
|
|
|
|
|
def add_user_message(user_msg_text: str, history: list[gr.ChatMessage]) -> tuple[str, list[gr.ChatMessage]]:
|
|
|
|
|
"""Appends user message to chat history and clears input."""
|
|
|
|
|
if user_msg_text and user_msg_text.strip():
|
|
|
|
@ -918,7 +916,7 @@ def main():
|
|
|
|
|
|
|
|
|
|
# Combine tabs
|
|
|
|
|
interface = gr.TabbedInterface(
|
|
|
|
|
[tab1, tab2], tab_names=["DeepSearch (VectorDB)", "Tavily Search (Web)"], title="DeepSearch Agent UI"
|
|
|
|
|
[tab1, tab2], tab_names=["ReZero (VectorDB)", "Tavily Search (Web)"], title="ReZero Demo"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
logger.info("Launching Gradio Tabbed Interface...")
|
|
|
|
|