diff --git a/.gitignore b/.gitignore
index 0f7c25a..9348aa4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -17,6 +17,7 @@ downloaded_model/
logs/
*.code-workspace
data/
+.gradio/
# Byte-compiled / optimized / DLL files
__pycache__/
diff --git a/app.py b/app.py
new file mode 100644
index 0000000..3bbec08
--- /dev/null
+++ b/app.py
@@ -0,0 +1,449 @@
+"""
+Gradio web interface for DeepSearch.
+
+This module provides a simple web interface for interacting with the DeepSearch model
+using Gradio. It implements the core agent loop directly in this module, keeping the web UI self-contained.
+"""
+
+import re
+import sys
+import time
+from typing import Iterator, cast
+
+import gradio as gr
+from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
+
+# Import from config
+from config import GENERATOR_MODEL_DIR, logger
+from src import (
+ apply_chat_template,
+ build_user_prompt,
+ format_search_results,
+ get_system_prompt,
+)
+from src.search_module import load_vectorstore, search
+
+# TODO: check whether the tokenizer adapter can be reused here
+
+
+def extract_answer_tag(text: str) -> tuple[bool, str | None]:
+ """Check if text contains an answer tag and extract the answer content if found.
+
+ Returns:
+ tuple: (has_answer, answer_content)
+ """
+    pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL | re.IGNORECASE)
+ match = re.search(pattern, text)
+ if match:
+ content = match.group(1).strip()
+ return True, content
+ return False, None
+
+
+def extract_thinking_content(text: str) -> str | None:
+    """Extract thinking content from text between <think> tags."""
+    pattern = re.compile(r"<think>(.*?)</think>", re.DOTALL | re.IGNORECASE)
+ match = re.search(pattern, text)
+ if match:
+ content = match.group(1).strip()
+ return content
+ return None
+
+
+def extract_search_query(text: str) -> str | None:
+    """Extract search query from text between <search> tags (simplified)."""
+    pattern = re.compile(r"<search>(.*?)</search>", re.DOTALL | re.IGNORECASE)
+ match = re.search(pattern, text)
+ if match:
+ content = match.group(1).strip()
+ return content
+ return None
+
+
+def setup_model_and_tokenizer(model_path: str):
+ """Initialize model and tokenizer."""
+ logger.info(f"Setting up model from {model_path}...")
+
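+    # Load in half precision and let device_map="auto" place layers across the available devices.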
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ torch_dtype="float16",
+ device_map="auto",
+ trust_remote_code=True,
+ )
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+    # Default to the assistant marker used in inference.py; adjust if your model's chat template differs.
+ assistant_marker = "<|start_header_id|>assistant<|end_header_id|>"
+ logger.info(f"Using assistant marker: '{assistant_marker}' for response splitting.")
+
+ logger.info("Model and tokenizer setup complete.")
+ return model, tokenizer, assistant_marker
+
+
+def get_sampling_params(temperature: float = 0.7, max_tokens: int = 4096):
+ """Get sampling parameters for generation."""
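+    # Passed straight to model.generate(); max_new_tokens also serves as the conversation growth budget in stream_agent_response.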
+ return {
+ "temperature": temperature,
+ "top_p": 0.95,
+ "max_new_tokens": max_tokens,
+ "do_sample": True,
+ }
+
+
+def create_interface(model_path: str, temperature: float = 0.7, system_prompt: str | None = None):
+ """Create Gradio interface for DeepSearch."""
+ model, tokenizer, assistant_marker = setup_model_and_tokenizer(model_path)
+ system_prompt = system_prompt or get_system_prompt()
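+    # The cast below is only for static type checkers; apply_chat_template expects a PreTrainedTokenizer.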
+ tokenizer_for_template = cast(PreTrainedTokenizer, tokenizer)
+
+ def get_chat_num_tokens(current_chat_state: dict) -> int:
+ """Helper to get number of tokens in chat state."""
+ try:
+ chat_text = apply_chat_template(current_chat_state, tokenizer=tokenizer_for_template)["text"]
+ input_ids = tokenizer.encode(chat_text, add_special_tokens=False)
+ return len(input_ids)
+ except Exception as e:
+ logger.error(f"Error calculating token count: {e}")
+ return sys.maxsize
+
+ def stream_agent_response(
+ message: str,
+ history_gr: list[gr.ChatMessage],
+ temp: float,
+ max_iter: int = 20,
+ num_search_results: int = 2,
+ ) -> Iterator[list[gr.ChatMessage]]:
+ """Stream agent responses following agent.py/inference.py logic."""
+ chat_state = {
+ "messages": [{"role": "system", "content": system_prompt}],
+ "finished": False,
+ }
+ # Convert Gradio history to internal format, skip last user msg (passed separately)
+ processed_history = history_gr[:-1] if history_gr else []
+ for msg_obj in processed_history:
+ role = getattr(msg_obj, "role", "unknown")
+ content = getattr(msg_obj, "content", "")
+ if role == "user":
+ chat_state["messages"].append({"role": "user", "content": build_user_prompt(content)})
+ elif role == "assistant":
+ chat_state["messages"].append({"role": "assistant", "content": content})
+
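+        # Wrap the new question with build_user_prompt so the model-facing chat state uses the expected instruction format.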
+ chat_state["messages"].append({"role": "user", "content": build_user_prompt(message)})
+
+ initial_token_length = get_chat_num_tokens(chat_state)
+ max_new_tokens_allowed = get_sampling_params(temp)["max_new_tokens"]
+
+ messages = history_gr
+
+ start_time = time.time()
+ iterations = 0
+ last_assistant_response = ""
+
+ while not chat_state.get("finished", False) and iterations < max_iter:
+ iterations += 1
+ current_turn_start_time = time.time()
+
+ think_msg_idx = len(messages)
+ messages.append(
+ gr.ChatMessage(
+ role="assistant",
+ content="Thinking...",
+                    metadata={"title": "🧠 Thinking", "status": "pending"},
+ )
+ )
+ yield messages
+
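+            # Stop early if the conversation has already grown past the generation budget measured from the initial prompt.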
+ current_length_before_gen = get_chat_num_tokens(chat_state)
+ if current_length_before_gen - initial_token_length > max_new_tokens_allowed:
+ logger.warning(
+ f"TOKEN LIMIT EXCEEDED (Before Generation): Current {current_length_before_gen}, Start {initial_token_length}"
+ )
+ chat_state["finished"] = True
+ messages[think_msg_idx] = gr.ChatMessage(
+ role="assistant",
+ content="Context length limit reached.",
+                    metadata={"title": "⚠️ Token Limit", "status": "done"},
+ )
+ yield messages
+ break
+
+ try:
+ generation_params = get_sampling_params(temp)
+ formatted_prompt = apply_chat_template(chat_state, tokenizer=tokenizer_for_template)["text"]
+ inputs = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
+
+ outputs = model.generate(**inputs, **generation_params)
+ full_response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
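+                # Prefer splitting on the assistant marker; since it is a special token it may already have
+                # been removed by skip_special_tokens, in which case fall back to slicing off the prompt tokens.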
+ if assistant_marker in full_response_text:
+ assistant_response = full_response_text.split(assistant_marker)[-1].strip()
+ else:
+ inputs_dict = cast(dict, inputs)
+ input_token_length = len(inputs_dict["input_ids"][0])
+ assistant_response = tokenizer.decode(
+ outputs[0][input_token_length:], skip_special_tokens=True
+ ).strip()
+ logger.warning(
+ f"Assistant marker '{assistant_marker}' not found in response. Extracted via token slicing fallback."
+ )
+
+ last_assistant_response = assistant_response
+ thinking_content = extract_thinking_content(assistant_response)
+
+ gen_time = time.time() - current_turn_start_time
+
+ display_thinking = thinking_content if thinking_content else "Processing..."
+ messages[think_msg_idx] = gr.ChatMessage(
+ role="assistant",
+ content=display_thinking,
+                    metadata={"title": "🧠 Thinking", "status": "done", "duration": gen_time},
+ )
+ yield messages
+
+ except Exception as e:
+ logger.error(f"Error during generation: {e}")
+ chat_state["finished"] = True
+ messages[think_msg_idx] = gr.ChatMessage(
+ role="assistant",
+ content=f"Error during generation: {e}",
+                    metadata={"title": "❌ Generation Error", "status": "done"},
+ )
+ yield messages
+ break
+
+ chat_state["messages"].append({"role": "assistant", "content": assistant_response})
+
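+            # A search query in the reply keeps the loop going; its absence means the model has produced its final answer.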
+ search_query = extract_search_query(assistant_response)
+
+ if not search_query:
+ chat_state["finished"] = True
+ else:
+ search_msg_idx = len(messages)
+ messages.append(
+ gr.ChatMessage(
+ role="assistant",
+ content=f"Searching for: {search_query}",
+                        metadata={"title": "🔍 Search", "status": "pending"},
+ )
+ )
+ yield messages
+ search_start = time.time()
+ try:
+ results = search(search_query, return_type=str, results=num_search_results)
+ search_duration = time.time() - search_start
+
+ messages[search_msg_idx] = gr.ChatMessage(
+ role="assistant",
+ content=f"{search_query}",
+                        metadata={"title": "🔍 Search", "duration": search_duration},
+ )
+ yield messages
+ display_results = format_search_results(results)
+ messages.append(
+ gr.ChatMessage(
+ role="assistant",
+ content=display_results,
+                            metadata={"title": "ℹ️ Information", "status": "done"},
+ )
+ )
+ yield messages
+
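+                    # Hand the retrieved passages back to the model as the next user turn.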
+                    formatted_results = f"<information>{results}</information>"
+ chat_state["messages"].append({"role": "user", "content": formatted_results})
+
+ except Exception as e:
+ search_duration = time.time() - search_start
+ logger.error(f"Search failed: {str(e)}")
+ messages[search_msg_idx] = gr.ChatMessage(
+ role="assistant",
+ content=f"Search failed: {str(e)}",
+                        metadata={"title": "❌ Search Error", "status": "done", "duration": search_duration},
+ )
+ yield messages
+ chat_state["messages"].append({"role": "system", "content": f"Error during search: {str(e)}"})
+ chat_state["finished"] = True
+
+ current_length_after_iter = get_chat_num_tokens(chat_state)
+ if current_length_after_iter - initial_token_length > max_new_tokens_allowed:
+ logger.warning(
+ f"TOKEN LIMIT EXCEEDED (After Iteration): Current {current_length_after_iter}, Start {initial_token_length}"
+ )
+ chat_state["finished"] = True
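+                # Avoid appending a duplicate notice if the token-limit message was already added before generation.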
+                if messages[-1].metadata.get("title") != "⚠️ Token Limit":
+ messages.append(
+ gr.ChatMessage(
+ role="assistant",
+ content="Context length limit reached during processing.",
+                            metadata={"title": "⚠️ Token Limit", "status": "done"},
+ )
+ )
+ yield messages
+
+ total_time = time.time() - start_time
+
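+        # After the loop, surface the <answer> content if present, otherwise fall back to the last raw assistant reply.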
+ if not chat_state.get("finished", False) and iterations >= max_iter:
+ logger.warning(f"Reached maximum iterations ({max_iter}) without finishing")
+ messages.append(
+ gr.ChatMessage(
+ role="assistant",
+ content=f"Reached maximum iterations ({max_iter}). Displaying last response:\n\n{last_assistant_response}",
+                    metadata={"title": "⚠️ Max Iterations", "status": "done", "duration": total_time},
+ )
+ )
+ yield messages
+ elif chat_state.get("finished", False) and last_assistant_response:
+ has_answer, answer_content = extract_answer_tag(last_assistant_response)
+
+ if has_answer and answer_content is not None:
+                display_title = "📝 Final Answer"
+ display_content = answer_content
+ else:
+                display_title = "💡 Answer"
+ display_content = last_assistant_response
+
+ if len(messages) > 0 and messages[-1].content != display_content:
+ messages.append(
+ gr.ChatMessage(
+ role="assistant",
+ content=display_content,
+ metadata={"title": display_title, "duration": total_time},
+ )
+ )
+ yield messages
+ elif len(messages) == 0:
+ messages.append(
+ gr.ChatMessage(
+ role="assistant",
+ content=display_content,
+ metadata={"title": display_title, "duration": total_time},
+ )
+ )
+ yield messages
+ else:
+                messages[-1].metadata["title"] = display_title
+                messages[-1].metadata["status"] = "done"
+                messages[-1].metadata["duration"] = total_time
+                yield messages
+
+ logger.info(f"Processing finished in {total_time:.2f} seconds.")
+
+ with gr.Blocks(title="DeepSearch - Visible Thinking") as interface:
+        gr.Markdown("# 🧠 DeepSearch with Visible Thinking")
+ gr.Markdown("Watch as the AI thinks, searches, and processes information to answer your questions.")
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ chatbot = gr.Chatbot(
+ [],
+ elem_id="chatbot",
+ type="messages",
+ height=600,
+ show_label=False,
+ render_markdown=True,
+ bubble_full_width=False,
+ )
+ msg = gr.Textbox(
+ placeholder="Type your message here...", show_label=False, container=False, elem_id="msg-input"
+ )
+
+ example_questions = [
+ "What year was the document approved by the Mission Evaluation Team?",
+ "Summarize the key findings regarding the oxygen tank failure.",
+ "Who was the commander of Apollo 13?",
+ "What were the main recommendations from the review board?",
+ ]
+ gr.Examples(examples=example_questions, inputs=msg, label="Example Questions", examples_per_page=4)
+
+ with gr.Row():
+ clear = gr.Button("Clear Chat")
+ submit = gr.Button("Submit", variant="primary")
+
+ with gr.Column(scale=1):
+ gr.Markdown("### Settings")
+ temp_slider = gr.Slider(minimum=0.1, maximum=1.0, value=temperature, step=0.1, label="Temperature")
+ system_prompt_input = gr.Textbox(
+ label="System Prompt", value=system_prompt, lines=3, info="Controls how the AI behaves"
+ )
+ max_iter_slider = gr.Slider(
+ minimum=1,
+ maximum=20,
+ value=20,
+ step=1,
+ label="Max Search Iterations",
+ info="Maximum number of search-think cycles",
+ )
+ num_results_slider = gr.Slider(
+ minimum=1,
+ maximum=5,
+ value=2,
+ step=1,
+ label="Number of Search Results",
+ info="How many results to retrieve per search query",
+ )
+
+ def add_user_message(user_msg_text: str, history: list[gr.ChatMessage]) -> tuple[str, list[gr.ChatMessage]]:
+ """Appends user message to chat history and clears input."""
+ if user_msg_text and user_msg_text.strip():
+ history.append(gr.ChatMessage(role="user", content=user_msg_text.strip()))
+ return "", history
+
+ submitted_msg_state = gr.State("")
+
+ # Chain events:
+ # 1. User submits -> store msg text in state
+ # 2. .then() -> add_user_message (updates chatbot UI history, clears textbox)
+ # 3. .then() -> stream_agent_response (takes stored msg text and updated chatbot history)
+ submit.click(
+ lambda msg_text: msg_text,
+ inputs=[msg],
+ outputs=[submitted_msg_state],
+ queue=False,
+ ).then(
+ add_user_message,
+ inputs=[submitted_msg_state, chatbot],
+ outputs=[msg, chatbot],
+ queue=False,
+ ).then(
+ stream_agent_response,
+ inputs=[submitted_msg_state, chatbot, temp_slider, max_iter_slider, num_results_slider],
+ outputs=chatbot,
+ )
+
+ msg.submit(
+ lambda msg_text: msg_text,
+ inputs=[msg],
+ outputs=[submitted_msg_state],
+ queue=False,
+ ).then(
+ add_user_message,
+ inputs=[submitted_msg_state, chatbot],
+ outputs=[msg, chatbot],
+ queue=False,
+ ).then(
+ stream_agent_response,
+ inputs=[submitted_msg_state, chatbot, temp_slider, max_iter_slider, num_results_slider],
+ outputs=chatbot,
+ )
+
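+        # Clearing resets both the visible chat history and the cached user message.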
+ clear.click(lambda: ([], ""), None, [chatbot, submitted_msg_state])
+
+ system_prompt_state = gr.State(system_prompt)
+ # TODO: Currently, changing the system prompt mid-chat won't affect the ongoing stream_agent_response.
+ system_prompt_input.change(lambda prompt: prompt, inputs=[system_prompt_input], outputs=[system_prompt_state])
+
+ return interface
+
+
+def main():
+ """Run the Gradio app."""
+ model_path = str(GENERATOR_MODEL_DIR)
+ logger.info(f"Using model from config: {model_path}")
+
+ interface = create_interface(model_path)
+ interface.launch(share=True)
+
+
+if __name__ == "__main__":
+ if load_vectorstore() is None:
+        logger.warning("⚠️ FAISS vectorstore could not be loaded. Search functionality may be unavailable.")
+
+ main()
diff --git a/config.py b/config.py
index e00ec1d..1ff243f 100644
--- a/config.py
+++ b/config.py
@@ -14,8 +14,18 @@ load_dotenv(override=True)
# Project paths
PROJ_ROOT = Path(__file__).resolve().parent
DATA_DIR = PROJ_ROOT / "data"
+MODEL_DIR = PROJ_ROOT / "models"
LOG_FOLDER = PROJ_ROOT / "logs"
+# Retriever and generator models used for evaluation and serving
+RETRIEVER_MODEL_REPO_ID = "intfloat/e5-base-v2"
+RETRIEVER_MODEL_DIR = MODEL_DIR / "retriever"
+RETRIEVER_SERVER_PORT = 8001
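+# The Gradio app (app.py) loads the generator model from GENERATOR_MODEL_DIR.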
+GENERATOR_MODEL_REPO_ID = "janhq/250404-llama-3.2-3b-instruct-grpo-03-s250"
+GENERATOR_MODEL_DIR = MODEL_DIR / "generator"
+GENERATOR_SERVER_PORT = 8002
+
+
# Model configuration
# MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"