From d8e949ec7c66ae251430a5a4ee49543535537ea7 Mon Sep 17 00:00:00 2001
From: thinhlpg
Date: Wed, 9 Apr 2025 06:11:07 +0000
Subject: [PATCH] feat: add Tavily search tab and integrate TavilyClient for
 web search functionality

---
 app.py         | 553 +++++++++++++++++++++++++++++++++++++++++++------
 pyproject.toml |   3 +-
 2 files changed, 492 insertions(+), 64 deletions(-)

diff --git a/app.py b/app.py
index 467cc05..0ee2d22 100644
--- a/app.py
+++ b/app.py
@@ -5,12 +5,14 @@ This module provides a simple web interface for interacting with the DeepSearch
 using Gradio. It implements the core functionality directly for better modularity.
 """

+import os
 import re
 import sys
 import time
 from typing import Iterator, cast

 import gradio as gr
+from tavily import TavilyClient
 from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer

 # Import from config
@@ -90,36 +92,44 @@ def get_sampling_params(temperature: float = 0.7, max_tokens: int = 4096):
     }


-def create_interface(model_path: str, temperature: float = 0.7, system_prompt: str | None = None):
-    """Create Gradio interface for DeepSearch."""
-    model, tokenizer, assistant_marker = setup_model_and_tokenizer(model_path)
-    system_prompt = system_prompt or get_system_prompt()
-    tokenizer_for_template = cast(PreTrainedTokenizer, tokenizer)
+# Define token counting globally, needs tokenizer_for_template accessible
+# Note: This requires tokenizer_for_template to be defined before this is called
+# We will define tokenizer_for_template globally after model loading in main()
+_tokenizer_for_template_global = None  # Placeholder
+
+
+def get_chat_num_tokens(current_chat_state: dict, tokenizer: PreTrainedTokenizer) -> int:
+    """Helper to get number of tokens in chat state."""
+    try:
+        chat_text = apply_chat_template(current_chat_state, tokenizer=tokenizer)["text"]
+        # Use the passed tokenizer for encoding
+        input_ids = tokenizer.encode(chat_text, add_special_tokens=False)
+        return len(input_ids)
+    except Exception as e:
+        logger.error(f"Error calculating token count: {e}")
+        return sys.maxsize
+
+
+def create_deepsearch_tab(model, tokenizer, assistant_marker, system_prompt, temperature):
+    """Creates the UI components and logic for the DeepSearch (Vector DB) tab."""
+    logger.info("Creating DeepSearch Tab")
+    # tokenizer_for_template = cast(PreTrainedTokenizer, tokenizer) # Global now

     # Load QA dataset for examples and gold answers
     try:
         _, test_dataset = get_qa_dataset()
         qa_map = {q: a for q, a in zip(test_dataset["prompt"], test_dataset["answer"])}
         example_questions = list(qa_map.keys())
-        logger.info(f"Loaded {len(example_questions)} QA examples.")
+        logger.info(f"Loaded {len(example_questions)} QA examples for DeepSearch tab.")
     except Exception as e:
-        logger.error(f"Failed to load QA dataset: {e}")
+        logger.error(f"Failed to load QA dataset for DeepSearch tab: {e}")
         qa_map = {}
         example_questions = [
             "What year was the document approved by the Mission Evaluation Team?",
             "Failed to load dataset examples.",
-        ]  # Provide fallback examples
-
-    def get_chat_num_tokens(current_chat_state: dict) -> int:
-        """Helper to get number of tokens in chat state."""
-        try:
-            chat_text = apply_chat_template(current_chat_state, tokenizer=tokenizer_for_template)["text"]
-            input_ids = tokenizer.encode(chat_text, add_special_tokens=False)
-            return len(input_ids)
-        except Exception as e:
-            logger.error(f"Error calculating token count: {e}")
-            return sys.maxsize
+        ]

+    # --- Agent Streaming Logic for DeepSearch ---
     def stream_agent_response(
         message: str,
         history_gr: list[gr.ChatMessage],
@@ -129,11 +139,14 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
         gold_answer_state: str | None = None,
     ) -> Iterator[list[gr.ChatMessage]]:
         """Stream agent responses following agent.py/inference.py logic."""
+        # Pass the globally defined (and typed) tokenizer to this scope
+        local_tokenizer_for_template = _tokenizer_for_template_global
+        assert local_tokenizer_for_template is not None  # Ensure it's loaded
+
         chat_state = {
             "messages": [{"role": "system", "content": system_prompt}],
             "finished": False,
         }
-        # Convert Gradio history to internal format, skip last user msg (passed separately)
         processed_history = history_gr[:-1] if history_gr else []
         for msg_obj in processed_history:
             role = getattr(msg_obj, "role", "unknown")
@@ -145,7 +158,7 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s

         chat_state["messages"].append({"role": "user", "content": build_user_prompt(message)})

-        initial_token_length = get_chat_num_tokens(chat_state)
+        initial_token_length = get_chat_num_tokens(chat_state, local_tokenizer_for_template)  # Pass tokenizer
         max_new_tokens_allowed = get_sampling_params(temp)["max_new_tokens"]

         messages = history_gr
@@ -168,23 +181,25 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
             )
             yield messages

-            current_length_before_gen = get_chat_num_tokens(chat_state)
+            current_length_before_gen = get_chat_num_tokens(chat_state, local_tokenizer_for_template)  # Pass tokenizer
             if current_length_before_gen - initial_token_length > max_new_tokens_allowed:
                 logger.warning(
                     f"TOKEN LIMIT EXCEEDED (Before Generation): Current {current_length_before_gen}, Start {initial_token_length}"
                 )
                 chat_state["finished"] = True
-            messages[think_msg_idx] = gr.ChatMessage(
-                role="assistant",
-                content="Context length limit reached.",
-                metadata={"title": "⚠️ Token Limit", "status": "done"},
-            )
-            yield messages
-            break
+                messages[think_msg_idx] = gr.ChatMessage(
+                    role="assistant",
+                    content="Context length limit reached.",
+                    metadata={"title": "⚠️ Token Limit", "status": "done"},
+                )
+                yield messages
+                break

             try:
                 generation_params = get_sampling_params(temp)
-                formatted_prompt = apply_chat_template(chat_state, tokenizer=tokenizer_for_template)["text"]
+                formatted_prompt = apply_chat_template(chat_state, tokenizer=local_tokenizer_for_template)[
+                    "text"
+                ]  # Use local typed tokenizer
                 inputs = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

                 outputs = model.generate(**inputs, **generation_params)
@@ -278,7 +293,7 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
                     chat_state["messages"].append({"role": "system", "content": f"Error during search: {str(e)}"})
                     chat_state["finished"] = True

-            current_length_after_iter = get_chat_num_tokens(chat_state)
+            current_length_after_iter = get_chat_num_tokens(chat_state, local_tokenizer_for_template)  # Pass tokenizer
             if current_length_after_iter - initial_token_length > max_new_tokens_allowed:
                 logger.warning(
                     f"TOKEN LIMIT EXCEEDED (After Iteration): Current {current_length_after_iter}, Start {initial_token_length}"
@@ -341,7 +356,6 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s

         logger.info(f"Processing finished in {total_time:.2f} seconds.")

-        # ---> Display Gold Answer if available (passed via gold_answer_state)
         if gold_answer_state:
             logger.info("Displaying gold answer.")
             messages.append(
@@ -355,9 +369,10 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
         else:
             logger.info("No gold answer to display for this query.")

-    with gr.Blocks(title="DeepSearch - Visible Thinking") as interface:
-        gr.Markdown("# 🧠 DeepSearch with Visible Thinking")
-        gr.Markdown("Watch as the AI thinks, searches, and processes information to answer your questions.")
+    # --- UI Layout for DeepSearch Tab ---
+    with gr.Blocks(analytics_enabled=False) as deepsearch_tab:
+        gr.Markdown("# 🧠 DeepSearch with Visible Thinking (Vector DB)")
+        gr.Markdown("Ask questions answered using the local vector database.")

         with gr.Row():
             with gr.Column(scale=3):
@@ -374,7 +389,6 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
                     placeholder="Type your message here...", show_label=False, container=False, elem_id="msg-input"
                 )

-                # Use questions from dataset as examples
                 gr.Examples(
                     examples=example_questions,
                     inputs=msg,
@@ -409,6 +423,7 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
                     info="How many results to retrieve per search query",
                 )

+        # --- Event Handlers for DeepSearch Tab ---
         def add_user_message(user_msg_text: str, history: list[gr.ChatMessage]) -> tuple[str, list[gr.ChatMessage]]:
             """Appends user message to chat history and clears input."""
             if user_msg_text and user_msg_text.strip():
@@ -416,11 +431,10 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
             return "", history

         submitted_msg_state = gr.State("")
-        gold_answer_state = gr.State(None)  # State to hold the gold answer
+        gold_answer_state = gr.State(None)

-        # Function to check if submitted message is an example and store gold answer
         def check_if_example_and_store_answer(msg_text):
-            gold_answer = qa_map.get(msg_text)  # Returns None if msg_text is not in examples
+            gold_answer = qa_map.get(msg_text)
             logger.info(f"Checking for gold answer for: '{msg_text[:50]}...'. Found: {bool(gold_answer)}")
             return gold_answer

@@ -429,7 +443,7 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
             inputs=[msg],
             outputs=[submitted_msg_state],
             queue=False,
-        ).then(  # Check for gold answer immediately after storing submitted message
+        ).then(
             check_if_example_and_store_answer,
             inputs=[submitted_msg_state],
             outputs=[gold_answer_state],
@@ -440,15 +454,8 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
             outputs=[msg, chatbot],
             queue=False,
         ).then(
-            stream_agent_response,
-            inputs=[
-                submitted_msg_state,
-                chatbot,
-                temp_slider,
-                max_iter_slider,
-                num_results_slider,
-                gold_answer_state,
-            ],  # Pass gold answer state
+            stream_agent_response,  # References the function defined within this scope
+            inputs=[submitted_msg_state, chatbot, temp_slider, max_iter_slider, num_results_slider, gold_answer_state],
             outputs=chatbot,
         )

@@ -457,7 +464,7 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
             inputs=[msg],
             outputs=[submitted_msg_state],
             queue=False,
-        ).then(  # Check for gold answer immediately after storing submitted message
+        ).then(
             check_if_example_and_store_answer,
             inputs=[submitted_msg_state],
             outputs=[gold_answer_state],
@@ -468,33 +475,453 @@ def create_interface(model_path: str, temperature: float = 0.7, system_prompt: s
             outputs=[msg, chatbot],
             queue=False,
         ).then(
-            stream_agent_response,
-            inputs=[
-                submitted_msg_state,
-                chatbot,
-                temp_slider,
-                max_iter_slider,
-                num_results_slider,
-                gold_answer_state,
-            ],  # Pass gold answer state
+            stream_agent_response,  # References the function defined within this scope
+            inputs=[submitted_msg_state, chatbot, temp_slider, max_iter_slider, num_results_slider, gold_answer_state],
             outputs=chatbot,
         )

-        clear.click(lambda: ([], None), None, [chatbot, gold_answer_state])  # Also clear gold answer state
+        clear.click(lambda: ([], None), None, [chatbot, gold_answer_state])

         system_prompt_state = gr.State(system_prompt)
-        # TODO: Currently, changing the system prompt mid-chat won't affect the ongoing stream_agent_response.
         system_prompt_input.change(lambda prompt: prompt, inputs=[system_prompt_input], outputs=[system_prompt_state])

-    return interface
+    return deepsearch_tab
+
+
+def create_tavily_tab(model, tokenizer, assistant_marker, system_prompt, temperature):
+    """Creates the UI components and logic for the Tavily Search tab."""
+    logger.info("Creating Tavily Search Tab")
+    # tokenizer_for_template = cast(PreTrainedTokenizer, tokenizer) # Global now
+
+    # --- Tavily Client Setup ---
+    tavily_api_key = os.getenv("TAVILY_API_KEY")
+    if not tavily_api_key:
+        logger.error("TAVILY_API_KEY not found in environment variables.")
+        with gr.Blocks(analytics_enabled=False) as tavily_tab_error:
+            gr.Markdown("# ⚠️ Tavily Search Error")
+            gr.Markdown("TAVILY_API_KEY environment variable not set. Please set it and restart the application.")
+        return tavily_tab_error
+
+    try:
+        tavily_client = TavilyClient(api_key=tavily_api_key)
+        logger.info("TavilyClient initialized successfully.")
+    except Exception as e:
+        logger.error(f"Failed to initialize TavilyClient: {e}")
+        with gr.Blocks(analytics_enabled=False) as tavily_tab_error:
+            gr.Markdown("# ⚠️ Tavily Client Initialization Error")
+            gr.Markdown(f"Failed to initialize Tavily Client: {e}")
+        return tavily_tab_error
+
+    # --- Agent Streaming Logic for Tavily ---
+    def stream_tavily_agent_response(
+        message: str,
+        history_gr: list[gr.ChatMessage],
+        temp: float,
+        max_iter: int = 20,
+        num_search_results: int = 2,  # Tavily default/recommendation might differ
+    ) -> Iterator[list[gr.ChatMessage]]:
+        """Stream agent responses using Tavily for search."""
+        local_tokenizer_for_template = _tokenizer_for_template_global  # Use global
+        assert local_tokenizer_for_template is not None
+
+        chat_state = {
+            "messages": [{"role": "system", "content": system_prompt}],
+            "finished": False,
+        }
+        processed_history = history_gr[:-1] if history_gr else []
+        for msg_obj in processed_history:
+            role = getattr(msg_obj, "role", "unknown")
+            content = getattr(msg_obj, "content", "")
+            if role == "user":
+                chat_state["messages"].append({"role": "user", "content": build_user_prompt(content)})
+            elif role == "assistant":
+                chat_state["messages"].append({"role": "assistant", "content": content})
+
+        chat_state["messages"].append({"role": "user", "content": build_user_prompt(message)})
+
+        initial_token_length = get_chat_num_tokens(chat_state, local_tokenizer_for_template)
+        max_new_tokens_allowed = get_sampling_params(temp)["max_new_tokens"]
+
+        messages = history_gr
+
+        start_time = time.time()
+        iterations = 0
+        last_assistant_response = ""
+
+        while not chat_state.get("finished", False) and iterations < max_iter:
+            iterations += 1
+            current_turn_start_time = time.time()
+
+            think_msg_idx = len(messages)
+            messages.append(
+                gr.ChatMessage(
+                    role="assistant",
+                    content="Thinking...",
+                    metadata={"title": "🧠 Thinking", "status": "pending"},
+                )
+            )
+            yield messages
+
+            current_length_before_gen = get_chat_num_tokens(chat_state, local_tokenizer_for_template)
+            if current_length_before_gen - initial_token_length > max_new_tokens_allowed:
+                logger.warning(
+                    f"TOKEN LIMIT EXCEEDED (Before Generation): Current {current_length_before_gen}, Start {initial_token_length}"
+                )
+                chat_state["finished"] = True
+                messages[think_msg_idx] = gr.ChatMessage(
+                    role="assistant",
+                    content="Context length limit reached.",
+                    metadata={"title": "⚠️ Token Limit", "status": "done"},
+                )
+                yield messages
+                break
+
+            try:
+                generation_params = get_sampling_params(temp)
+                formatted_prompt = apply_chat_template(chat_state, tokenizer=local_tokenizer_for_template)["text"]
+                inputs = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
+
+                outputs = model.generate(**inputs, **generation_params)
+                full_response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+                if assistant_marker in full_response_text:
+                    assistant_response = full_response_text.split(assistant_marker)[-1].strip()
+                else:
+                    inputs_dict = cast(dict, inputs)
+                    input_token_length = len(inputs_dict["input_ids"][0])
+                    assistant_response = tokenizer.decode(
+                        outputs[0][input_token_length:], skip_special_tokens=True
+                    ).strip()
+                    logger.warning(
+                        f"Assistant marker '{assistant_marker}' not found in response. Extracted via token slicing fallback."
+                    )
+
+                last_assistant_response = assistant_response
+                thinking_content = extract_thinking_content(assistant_response)
+
+                gen_time = time.time() - current_turn_start_time
+
+                display_thinking = thinking_content if thinking_content else "Processing..."
+                messages[think_msg_idx] = gr.ChatMessage(
+                    role="assistant",
+                    content=display_thinking,
+                    metadata={"title": "🧠 Thinking", "status": "done", "duration": gen_time},
+                )
+                yield messages
+
+            except Exception as e:
+                logger.error(f"Error during generation: {e}")
+                chat_state["finished"] = True
+                messages[think_msg_idx] = gr.ChatMessage(
+                    role="assistant",
+                    content=f"Error during generation: {e}",
+                    metadata={"title": "❌ Generation Error", "status": "done"},
+                )
+                yield messages
+                break
+
+            chat_state["messages"].append({"role": "assistant", "content": assistant_response})
+
+            search_query = extract_search_query(assistant_response)
+
+            if not search_query:
+                chat_state["finished"] = True
+            else:
+                search_msg_idx = len(messages)
+                messages.append(
+                    gr.ChatMessage(
+                        role="assistant",
+                        content=f"Searching (Tavily) for: {search_query}",
+                        metadata={"title": "🔍 Tavily Search", "status": "pending"},
+                    )
+                )
+                yield messages
+                search_start = time.time()
+                try:
+                    # --- Tavily Search Call ---
+                    logger.info(f"Performing Tavily search for: {search_query}")
+                    tavily_response = tavily_client.search(
+                        query=search_query,
+                        search_depth="advanced",
+                        max_results=num_search_results,
+                        include_answer=False,
+                        include_raw_content=False,
+                    )
+                    search_duration = time.time() - search_start
+                    logger.info(f"Tavily search completed in {search_duration:.2f}s.")
+
+                    # --- Format Tavily Results ---
+                    results_list = tavily_response.get("results", [])
+                    formatted_tavily_results = ""
+                    if results_list:
+                        formatted_tavily_results = "\n".join(
+                            [
+                                f"Doc {i + 1} (Title: {res.get('title', 'N/A')}) URL: {res.get('url', 'N/A')}\n{res.get('content', '')}"
+                                for i, res in enumerate(results_list)
+                            ]
+                        )
+                    else:
+                        formatted_tavily_results = "No results found by Tavily."
+
+                    messages[search_msg_idx] = gr.ChatMessage(
+                        role="assistant",
+                        content=f"{search_query}",
+                        metadata={"title": "🔍 Tavily Search", "duration": search_duration},
+                    )
+                    yield messages
+
+                    display_results = formatted_tavily_results
+                    messages.append(
+                        gr.ChatMessage(
+                            role="assistant",
+                            content=display_results,
+                            metadata={"title": "ℹ️ Tavily Information", "status": "done"},
+                        )
+                    )
+                    yield messages
+
+                    formatted_results_for_llm = f"{formatted_tavily_results}"
+                    chat_state["messages"].append({"role": "user", "content": formatted_results_for_llm})
+
+                except Exception as e:
+                    search_duration = time.time() - search_start
+                    logger.error(f"Tavily Search failed: {str(e)}")
+                    messages[search_msg_idx] = gr.ChatMessage(
+                        role="assistant",
+                        content=f"Tavily Search failed: {str(e)}",
+                        metadata={"title": "❌ Tavily Search Error", "status": "done", "duration": search_duration},
+                    )
+                    yield messages
+                    chat_state["messages"].append(
+                        {"role": "system", "content": f"Error during Tavily search: {str(e)}"}
+                    )
+                    chat_state["finished"] = True
+
+            current_length_after_iter = get_chat_num_tokens(chat_state, local_tokenizer_for_template)
+            if current_length_after_iter - initial_token_length > max_new_tokens_allowed:
+                logger.warning(
+                    f"TOKEN LIMIT EXCEEDED (After Iteration): Current {current_length_after_iter}, Start {initial_token_length}"
+                )
+                chat_state["finished"] = True
+                if messages[-1].metadata.get("title") != "⚠️ Token Limit":
+                    messages.append(
+                        gr.ChatMessage(
+                            role="assistant",
+                            content="Context length limit reached during processing.",
+                            metadata={"title": "⚠️ Token Limit", "status": "done"},
+                        )
+                    )
+                    yield messages
+
+        total_time = time.time() - start_time
+
+        if not chat_state.get("finished", False) and iterations >= max_iter:
+            logger.warning(f"Reached maximum iterations ({max_iter}) without finishing")
+            messages.append(
+                gr.ChatMessage(
+                    role="assistant",
+                    content=f"Reached maximum iterations ({max_iter}). Displaying last response:\n\n{last_assistant_response}",
+                    metadata={"title": "⚠️ Max Iterations", "status": "done", "duration": total_time},
+                )
+            )
+            yield messages
+        elif chat_state.get("finished", False) and last_assistant_response:
+            has_answer, answer_content = extract_answer_tag(last_assistant_response)
+
+            if has_answer and answer_content is not None:
+                display_title = "📝 Final Answer"
+                display_content = answer_content
+            else:
+                display_title = "💡 Answer"
+                display_content = last_assistant_response
+
+            if len(messages) > 0 and messages[-1].content != display_content:
+                messages.append(
+                    gr.ChatMessage(
+                        role="assistant",
+                        content=display_content,
+                        metadata={"title": display_title, "duration": total_time},
+                    )
+                )
+                yield messages
+            elif len(messages) == 0:
+                messages.append(
+                    gr.ChatMessage(
+                        role="assistant",
+                        content=display_content,
+                        metadata={"title": display_title, "duration": total_time},
+                    )
+                )
+                yield messages
+            else:
+                messages[-1].metadata["title"] = display_title
+                messages[-1].metadata["status"] = "done"
+                messages[-1].metadata["duration"] = total_time
+
+        logger.info(f"Processing finished in {total_time:.2f} seconds.")
+
+    # --- UI Layout for Tavily Tab ---
+    with gr.Blocks(analytics_enabled=False) as tavily_tab:
+        gr.Markdown("# 🌐 Tavily Search with Visible Thinking")
+        gr.Markdown("Ask questions answered using the Tavily web search API.")
+
+        with gr.Row():
+            with gr.Column(scale=3):
+                tavily_chatbot = gr.Chatbot(
+                    [],
+                    elem_id="tavily_chatbot",
+                    type="messages",
+                    height=600,
+                    show_label=False,
+                    render_markdown=True,
+                    bubble_full_width=False,
+                )
+                tavily_msg = gr.Textbox(
+                    placeholder="Type your message here...",
+                    show_label=False,
+                    container=False,
+                    elem_id="tavily_msg-input",
+                )
+                tavily_example_questions = [
+                    "What is the weather like in London today?",
+                    "Summarize the latest news about AI advancements.",
+                    "Who won the last Formula 1 race?",
+                ]
+                gr.Examples(
+                    examples=tavily_example_questions,
+                    inputs=tavily_msg,
+                    label="Example Questions (Web Search)",
+                    examples_per_page=3,
+                )
+                with gr.Row():
+                    tavily_clear = gr.Button("Clear Chat")
+                    tavily_submit = gr.Button("Submit", variant="primary")
+
+            with gr.Column(scale=1):
+                gr.Markdown("### Settings")
+                tavily_temp_slider = gr.Slider(
+                    minimum=0.1, maximum=1.0, value=temperature, step=0.1, label="Temperature"
+                )
+                tavily_system_prompt_input = gr.Textbox(
+                    label="System Prompt", value=system_prompt, lines=3, info="Controls how the AI behaves"
+                )
+                tavily_max_iter_slider = gr.Slider(
+                    minimum=1,
+                    maximum=20,
+                    value=10,
+                    step=1,
+                    label="Max Search Iterations",
+                    info="Maximum number of search-think cycles",
+                )
+                tavily_num_results_slider = gr.Slider(
+                    minimum=1,
+                    maximum=5,
+                    value=3,
+                    step=1,
+                    label="Number of Search Results",
+                    info="How many results to retrieve per search query",
+                )
+
+        # --- Event Handlers for Tavily Tab ---
+        def tavily_add_user_message(
+            user_msg_text: str, history: list[gr.ChatMessage]
+        ) -> tuple[str, list[gr.ChatMessage]]:
+            if user_msg_text and user_msg_text.strip():
+                history.append(gr.ChatMessage(role="user", content=user_msg_text.strip()))
+            return "", history
+
+        tavily_submitted_msg_state = gr.State("")
+
+        tavily_submit.click(
+            lambda msg_text: msg_text,
+            inputs=[tavily_msg],
+            outputs=[tavily_submitted_msg_state],
+            queue=False,
+        ).then(
+            tavily_add_user_message,
+            inputs=[tavily_submitted_msg_state, tavily_chatbot],
+            outputs=[tavily_msg, tavily_chatbot],
+            queue=False,
+        ).then(
+            stream_tavily_agent_response,  # Use Tavily-specific stream function
+            inputs=[
+                tavily_submitted_msg_state,
+                tavily_chatbot,
+                tavily_temp_slider,
+                tavily_max_iter_slider,
+                tavily_num_results_slider,
+            ],
+            outputs=tavily_chatbot,
+        )
+
+        tavily_msg.submit(
+            lambda msg_text: msg_text,
+            inputs=[tavily_msg],
+            outputs=[tavily_submitted_msg_state],
+            queue=False,
+        ).then(
+            tavily_add_user_message,
+            inputs=[tavily_submitted_msg_state, tavily_chatbot],
+            outputs=[tavily_msg, tavily_chatbot],
+            queue=False,
+        ).then(
+            stream_tavily_agent_response,  # Use Tavily-specific stream function
+            inputs=[
+                tavily_submitted_msg_state,
+                tavily_chatbot,
+                tavily_temp_slider,
+                tavily_max_iter_slider,
+                tavily_num_results_slider,
+            ],
+            outputs=tavily_chatbot,
+        )
+
+        tavily_clear.click(lambda: ([], ""), None, [tavily_chatbot, tavily_submitted_msg_state])
+
+        tavily_system_prompt_state = gr.State(system_prompt)
+        tavily_system_prompt_input.change(
+            lambda prompt: prompt, inputs=[tavily_system_prompt_input], outputs=[tavily_system_prompt_state]
+        )
+
+    return tavily_tab


 def main():
-    """Run the Gradio app."""
+    """Run the Gradio app with tabs."""
     model_path = str(GENERATOR_MODEL_DIR)
     logger.info(f"Using model from config: {model_path}")

-    interface = create_interface(model_path)
+    # Shared model setup (do once)
+    try:
+        model, tokenizer, assistant_marker = setup_model_and_tokenizer(model_path)
+    except Exception as e:
+        logger.critical(f"Failed to load model/tokenizer: {e}")
+        # Display error if model fails to load
+        with gr.Blocks() as demo:
+            gr.Markdown("# Critical Error")
+            gr.Markdown(
+                f"Failed to load model or tokenizer from '{model_path}'. Check the path and ensure the model exists.\n\nError: {e}"
+            )
+        demo.launch(share=True)
+        sys.exit(1)  # Exit if model loading fails
+
+    system_prompt = get_system_prompt()
+    default_temp = 0.7
+
+    # Define tokenizer_for_template globally after successful load
+    global _tokenizer_for_template_global
+    _tokenizer_for_template_global = cast(PreTrainedTokenizer, tokenizer)
+
+    # Create content for each tab
+    tab1 = create_deepsearch_tab(model, tokenizer, assistant_marker, system_prompt, default_temp)
+    tab2 = create_tavily_tab(model, tokenizer, assistant_marker, system_prompt, default_temp)
+
+    # Combine tabs
+    interface = gr.TabbedInterface(
+        [tab1, tab2], tab_names=["DeepSearch (VectorDB)", "Tavily Search (Web)"], title="DeepSearch Agent UI"
+    )
+
+    logger.info("Launching Gradio Tabbed Interface...")
     interface.launch(share=True)


diff --git a/pyproject.toml b/pyproject.toml
index 5bab91e..d4946c9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,5 +36,6 @@ dependencies = [
     "pytest",
     "wandb",
     "requests>=2.31.0",
-    "tqdm>=4.66.1"
+    "tqdm>=4.66.1",
+    "tavily-python",
 ]
\ No newline at end of file