diff --git a/.gitignore b/.gitignore
index 0f7c25a..9348aa4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -17,6 +17,7 @@ downloaded_model/
 logs/
 *.code-workspace
 data/
+.gradio/
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/app.py b/app.py
new file mode 100644
index 0000000..3bbec08
--- /dev/null
+++ b/app.py
@@ -0,0 +1,449 @@
+"""
+Gradio web interface for DeepSearch.
+
+This module provides a simple web interface for interacting with the DeepSearch model
+using Gradio. It implements the core functionality directly for better modularity.
+"""
+
+import re
+import sys
+import time
+from typing import Iterator, cast
+
+import gradio as gr
+from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
+
+# Import from config
+from config import GENERATOR_MODEL_DIR, logger
+from src import (
+    apply_chat_template,
+    build_user_prompt,
+    format_search_results,
+    get_system_prompt,
+)
+from src.search_module import load_vectorstore, search
+
+# TODO: check if we can reuse the tokenizer adapter
+
+
+def extract_answer_tag(text: str) -> tuple[bool, str | None]:
+    """Check if text contains an <answer> tag and extract the answer content if found.
+
+    Returns:
+        tuple: (has_answer, answer_content)
+    """
+    pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL | re.IGNORECASE)
+    match = re.search(pattern, text)
+    if match:
+        content = match.group(1).strip()
+        return True, content
+    return False, None
+
+
+def extract_thinking_content(text: str) -> str | None:
+    """Extract thinking content from text between <think> tags."""
+    pattern = re.compile(r"<think>(.*?)</think>", re.DOTALL | re.IGNORECASE)
+    match = re.search(pattern, text)
+    if match:
+        content = match.group(1).strip()
+        return content
+    return None
+
+
+def extract_search_query(text: str) -> str | None:
+    """Extract the search query from text between <search> tags (simplified)."""
+    pattern = re.compile(r"<search>(.*?)</search>", re.DOTALL | re.IGNORECASE)
+    match = re.search(pattern, text)
+    if match:
+        content = match.group(1).strip()
+        return content
+    return None
+
+
+def setup_model_and_tokenizer(model_path: str):
+    """Initialize the model and tokenizer."""
+    logger.info(f"Setting up model from {model_path}...")
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_path,
+        torch_dtype="float16",
+        device_map="auto",
+        trust_remote_code=True,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+    # Defaulting to the marker used in inference.py; adjust if needed for your specific model.
+    assistant_marker = "<|start_header_id|>assistant<|end_header_id|>"
+    logger.info(f"Using assistant marker: '{assistant_marker}' for response splitting.")
+
+    logger.info("Model and tokenizer setup complete.")
+    return model, tokenizer, assistant_marker
+
+
+def get_sampling_params(temperature: float = 0.7, max_tokens: int = 4096):
+    """Get sampling parameters for generation."""
+    return {
+        "temperature": temperature,
+        "top_p": 0.95,
+        "max_new_tokens": max_tokens,
+        "do_sample": True,
+    }
+
+
+def create_interface(model_path: str, temperature: float = 0.7, system_prompt: str | None = None):
+    """Create the Gradio interface for DeepSearch."""
+    model, tokenizer, assistant_marker = setup_model_and_tokenizer(model_path)
+    system_prompt = system_prompt or get_system_prompt()
+    tokenizer_for_template = cast(PreTrainedTokenizer, tokenizer)
+
+    def get_chat_num_tokens(current_chat_state: dict) -> int:
+        """Helper to get the number of tokens in the chat state."""
+        try:
+            chat_text = apply_chat_template(current_chat_state, tokenizer=tokenizer_for_template)["text"]
+            input_ids = tokenizer.encode(chat_text, add_special_tokens=False)
+            return len(input_ids)
+        except Exception as e:
+            logger.error(f"Error calculating token count: {e}")
+            return sys.maxsize
+
+    def stream_agent_response(
+        message: str,
+        history_gr: list[gr.ChatMessage],
+        temp: float,
+        max_iter: int = 20,
+        num_search_results: int = 2,
+    ) -> Iterator[list[gr.ChatMessage]]:
+        """Stream agent responses following agent.py/inference.py logic."""
+        chat_state = {
+            "messages": [{"role": "system", "content": system_prompt}],
+            "finished": False,
+        }
+        # Convert Gradio history to internal format, skip last user msg (passed separately)
+        processed_history = history_gr[:-1] if history_gr else []
+        for msg_obj in processed_history:
+            role = getattr(msg_obj, "role", "unknown")
+            content = getattr(msg_obj, "content", "")
+            if role == "user":
+                chat_state["messages"].append({"role": "user", "content": build_user_prompt(content)})
+            elif role == "assistant":
+                chat_state["messages"].append({"role": "assistant", "content": content})
+
+        chat_state["messages"].append({"role": "user", "content": build_user_prompt(message)})
+
+        initial_token_length = get_chat_num_tokens(chat_state)
+        max_new_tokens_allowed = get_sampling_params(temp)["max_new_tokens"]
+
+        messages = history_gr
+
+        start_time = time.time()
+        iterations = 0
+        last_assistant_response = ""
+
+        while not chat_state.get("finished", False) and iterations < max_iter:
+            iterations += 1
+            current_turn_start_time = time.time()
+
+            think_msg_idx = len(messages)
+            messages.append(
+                gr.ChatMessage(
+                    role="assistant",
+                    content="Thinking...",
+                    metadata={"title": "🧠 Thinking", "status": "pending"},
+                )
+            )
+            yield messages
+
+            current_length_before_gen = get_chat_num_tokens(chat_state)
+            if current_length_before_gen - initial_token_length > max_new_tokens_allowed:
+                logger.warning(
+                    f"TOKEN LIMIT EXCEEDED (Before Generation): Current {current_length_before_gen}, Start {initial_token_length}"
+                )
+                chat_state["finished"] = True
+                messages[think_msg_idx] = gr.ChatMessage(
+                    role="assistant",
+                    content="Context length limit reached.",
+                    metadata={"title": "⚠️ Token Limit", "status": "done"},
+                )
+                yield messages
+                break
+
+            try:
+                generation_params = get_sampling_params(temp)
+                formatted_prompt = apply_chat_template(chat_state, tokenizer=tokenizer_for_template)["text"]
+                inputs = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
+
+                outputs = model.generate(**inputs, **generation_params)
+                full_response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+                if assistant_marker in full_response_text:
+                    assistant_response = full_response_text.split(assistant_marker)[-1].strip()
+                else:
+                    inputs_dict = cast(dict, inputs)
+                    input_token_length = len(inputs_dict["input_ids"][0])
+                    assistant_response = tokenizer.decode(
+                        outputs[0][input_token_length:], skip_special_tokens=True
+                    ).strip()
+                    logger.warning(
+                        f"Assistant marker '{assistant_marker}' not found in response. Extracted via token slicing fallback."
+                    )
+
+                last_assistant_response = assistant_response
+                thinking_content = extract_thinking_content(assistant_response)
+
+                gen_time = time.time() - current_turn_start_time
+
+                display_thinking = thinking_content if thinking_content else "Processing..."
+                messages[think_msg_idx] = gr.ChatMessage(
+                    role="assistant",
+                    content=display_thinking,
+                    metadata={"title": "🧠 Thinking", "status": "done", "duration": gen_time},
+                )
+                yield messages
+
+            except Exception as e:
+                logger.error(f"Error during generation: {e}")
+                chat_state["finished"] = True
+                messages[think_msg_idx] = gr.ChatMessage(
+                    role="assistant",
+                    content=f"Error during generation: {e}",
+                    metadata={"title": "❌ Generation Error", "status": "done"},
+                )
+                yield messages
+                break
+
+            chat_state["messages"].append({"role": "assistant", "content": assistant_response})
+
+            search_query = extract_search_query(assistant_response)
+
+            if not search_query:
+                chat_state["finished"] = True
+            else:
+                search_msg_idx = len(messages)
+                messages.append(
+                    gr.ChatMessage(
+                        role="assistant",
+                        content=f"Searching for: {search_query}",
+                        metadata={"title": "🔍 Search", "status": "pending"},
+                    )
+                )
+                yield messages
+                search_start = time.time()
+                try:
+                    results = search(search_query, return_type=str, results=num_search_results)
+                    search_duration = time.time() - search_start
+
+                    messages[search_msg_idx] = gr.ChatMessage(
+                        role="assistant",
+                        content=f"{search_query}",
+                        metadata={"title": "🔍 Search", "duration": search_duration},
+                    )
+                    yield messages
+                    display_results = format_search_results(results)
+                    messages.append(
+                        gr.ChatMessage(
+                            role="assistant",
+                            content=display_results,
+                            metadata={"title": "ℹ️ Information", "status": "done"},
+                        )
+                    )
+                    yield messages
+
+                    formatted_results = f"<information>{results}</information>"
+                    chat_state["messages"].append({"role": "user", "content": formatted_results})
+
+                except Exception as e:
+                    search_duration = time.time() - search_start
+                    logger.error(f"Search failed: {str(e)}")
+                    messages[search_msg_idx] = gr.ChatMessage(
+                        role="assistant",
+                        content=f"Search failed: {str(e)}",
+                        metadata={"title": "❌ Search Error", "status": "done", "duration": search_duration},
+                    )
+                    yield messages
+                    chat_state["messages"].append({"role": "system", "content": f"Error during search: {str(e)}"})
+                    chat_state["finished"] = True
+
+            current_length_after_iter = get_chat_num_tokens(chat_state)
+            if current_length_after_iter - initial_token_length > max_new_tokens_allowed:
+                logger.warning(
+                    f"TOKEN LIMIT EXCEEDED (After Iteration): Current {current_length_after_iter}, Start {initial_token_length}"
+                )
+                chat_state["finished"] = True
+                if messages[-1].metadata.get("title") != "⚠️ Token Limit":
+                    messages.append(
+                        gr.ChatMessage(
+                            role="assistant",
+                            content="Context length limit reached during processing.",
+                            metadata={"title": "⚠️ Token Limit", "status": "done"},
+                        )
+                    )
+                yield messages
+
+        total_time = time.time() - start_time
+
+        if not chat_state.get("finished", False) and iterations >= max_iter:
+            logger.warning(f"Reached maximum iterations ({max_iter}) without finishing")
+            messages.append(
+                gr.ChatMessage(
+                    role="assistant",
+                    content=f"Reached maximum iterations ({max_iter}). Displaying last response:\n\n{last_assistant_response}",
+                    metadata={"title": "⚠️ Max Iterations", "status": "done", "duration": total_time},
+                )
+            )
+            yield messages
+        elif chat_state.get("finished", False) and last_assistant_response:
+            has_answer, answer_content = extract_answer_tag(last_assistant_response)
+
+            if has_answer and answer_content is not None:
+                display_title = "📝 Final Answer"
+                display_content = answer_content
+            else:
+                display_title = "💡 Answer"
+                display_content = last_assistant_response
+
+            if len(messages) > 0 and messages[-1].content != display_content:
+                messages.append(
+                    gr.ChatMessage(
+                        role="assistant",
+                        content=display_content,
+                        metadata={"title": display_title, "duration": total_time},
+                    )
+                )
+                yield messages
+            elif len(messages) == 0:
+                messages.append(
+                    gr.ChatMessage(
+                        role="assistant",
+                        content=display_content,
+                        metadata={"title": display_title, "duration": total_time},
+                    )
+                )
+                yield messages
+            else:
+                messages[-1].metadata["title"] = display_title
+                messages[-1].metadata["status"] = "done"
+                messages[-1].metadata["duration"] = total_time
+
+        logger.info(f"Processing finished in {total_time:.2f} seconds.")
+
+    with gr.Blocks(title="DeepSearch - Visible Thinking") as interface:
+        gr.Markdown("# 🧠 DeepSearch with Visible Thinking")
+        gr.Markdown("Watch as the AI thinks, searches, and processes information to answer your questions.")
+
+        with gr.Row():
+            with gr.Column(scale=3):
+                chatbot = gr.Chatbot(
+                    [],
+                    elem_id="chatbot",
+                    type="messages",
+                    height=600,
+                    show_label=False,
+                    render_markdown=True,
+                    bubble_full_width=False,
+                )
+                msg = gr.Textbox(
+                    placeholder="Type your message here...", show_label=False, container=False, elem_id="msg-input"
+                )
+
+                example_questions = [
+                    "What year was the document approved by the Mission Evaluation Team?",
+                    "Summarize the key findings regarding the oxygen tank failure.",
+                    "Who was the commander of Apollo 13?",
+                    "What were the main recommendations from the review board?",
+                ]
+                gr.Examples(examples=example_questions, inputs=msg, label="Example Questions", examples_per_page=4)
+
+                with gr.Row():
+                    clear = gr.Button("Clear Chat")
+                    submit = gr.Button("Submit", variant="primary")
+
+            with gr.Column(scale=1):
+                gr.Markdown("### Settings")
+                temp_slider = gr.Slider(minimum=0.1, maximum=1.0, value=temperature, step=0.1, label="Temperature")
+                system_prompt_input = gr.Textbox(
+                    label="System Prompt", value=system_prompt, lines=3, info="Controls how the AI behaves"
+                )
+                max_iter_slider = gr.Slider(
+                    minimum=1,
+                    maximum=20,
+                    value=20,
+                    step=1,
+                    label="Max Search Iterations",
+                    info="Maximum number of search-think cycles",
+                )
+                num_results_slider = gr.Slider(
+                    minimum=1,
+                    maximum=5,
+                    value=2,
+                    step=1,
+                    label="Number of Search Results",
+                    info="How many results to retrieve per search query",
+                )
+
+        def add_user_message(user_msg_text: str, history: list[gr.ChatMessage]) -> tuple[str, list[gr.ChatMessage]]:
+            """Appends user message to chat history and clears input."""
+            if user_msg_text and user_msg_text.strip():
+                history.append(gr.ChatMessage(role="user", content=user_msg_text.strip()))
+            return "", history
+
+        submitted_msg_state = gr.State("")
+
+        # Chain events:
+        # 1. User submits -> store msg text in state
+        # 2. .then() -> add_user_message (updates chatbot UI history, clears textbox)
+        # 3. .then() -> stream_agent_response (takes stored msg text and updated chatbot history)
+        submit.click(
+            lambda msg_text: msg_text,
+            inputs=[msg],
+            outputs=[submitted_msg_state],
+            queue=False,
+        ).then(
+            add_user_message,
+            inputs=[submitted_msg_state, chatbot],
+            outputs=[msg, chatbot],
+            queue=False,
+        ).then(
+            stream_agent_response,
+            inputs=[submitted_msg_state, chatbot, temp_slider, max_iter_slider, num_results_slider],
+            outputs=chatbot,
+        )
+
+        msg.submit(
+            lambda msg_text: msg_text,
+            inputs=[msg],
+            outputs=[submitted_msg_state],
+            queue=False,
+        ).then(
+            add_user_message,
+            inputs=[submitted_msg_state, chatbot],
+            outputs=[msg, chatbot],
+            queue=False,
+        ).then(
+            stream_agent_response,
+            inputs=[submitted_msg_state, chatbot, temp_slider, max_iter_slider, num_results_slider],
+            outputs=chatbot,
+        )
+
+        clear.click(lambda: ([], ""), None, [chatbot, submitted_msg_state])
+
+        system_prompt_state = gr.State(system_prompt)
+        # TODO: Currently, changing the system prompt mid-chat won't affect the ongoing stream_agent_response.
+        system_prompt_input.change(lambda prompt: prompt, inputs=[system_prompt_input], outputs=[system_prompt_state])
+
+    return interface
+
+
+def main():
+    """Run the Gradio app."""
+    model_path = str(GENERATOR_MODEL_DIR)
+    logger.info(f"Using model from config: {model_path}")
+
+    interface = create_interface(model_path)
+    interface.launch(share=True)
+
+
+if __name__ == "__main__":
+    if load_vectorstore() is None:
+        logger.warning("⚠️ FAISS vectorstore could not be loaded. Search functionality may be unavailable.")
+
+    main()
diff --git a/config.py b/config.py
index e00ec1d..1ff243f 100644
--- a/config.py
+++ b/config.py
@@ -14,8 +14,18 @@ load_dotenv(override=True)
 # Project paths
 PROJ_ROOT = Path(__file__).resolve().parent
 DATA_DIR = PROJ_ROOT / "data"
+MODEL_DIR = PROJ_ROOT / "models"
 LOG_FOLDER = PROJ_ROOT / "logs"
 
+# Evaluations
+RETRIEVER_MODEL_REPO_ID = "intfloat/e5-base-v2"
+RETRIEVER_MODEL_DIR = MODEL_DIR / "retriever"
+RETRIEVER_SERVER_PORT = 8001
+GENERATOR_MODEL_REPO_ID = "janhq/250404-llama-3.2-3b-instruct-grpo-03-s250"
+GENERATOR_MODEL_DIR = MODEL_DIR / "generator"
+GENERATOR_SERVER_PORT = 8002
+
+
 # Model configuration
 # MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
 MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"