diff --git a/eval.py b/eval.py
index 36b7679..b9167ca 100644
--- a/eval.py
+++ b/eval.py
@@ -312,7 +312,7 @@ def evaluate_model(
     # Define all output file paths
     eval_log_file = os.path.join(eval_log_dir, f"{model_prefix}_model_eval_{timestamp}.log")
     output_file = os.path.join(eval_log_dir, f"{model_prefix}_model_results.txt")
-    debug_file = os.path.join(eval_log_dir, f"{model_prefix}_model_results_debug.json")
+    debug_file = os.path.join(eval_log_dir, f"{model_prefix}_model_results_debug.txt")
 
     logger.info(f"Writing evaluation log to: {eval_log_file}")
     logger.info(f"Results will be saved to: {output_file}")
diff --git a/inference.py b/inference.py
index 85ec4cd..5664fd8 100644
--- a/inference.py
+++ b/inference.py
@@ -6,15 +6,18 @@ and provides search functionality for data retrieval.
 """
 
 import argparse
-import json
 import os
 import time
 from datetime import datetime
-from typing import Any, Dict, List, Optional, Union
 
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from vllm import SamplingParams
 
+from src.rl_helpers import (
+    build_user_prompt,
+    extract_search_query,
+    format_search_results,
+)
 from src.search_module import load_vectorstore, search
 
 
@@ -43,72 +46,6 @@ def get_sampling_params(temperature: float = 0.7, max_tokens: int = 4096) -> Sam
     )
 
 
-def extract_function_calls(text: str) -> List[Dict[str, Any]]:
-    """Extract function calls from a text."""
-    import json
-    import re
-
-    # Pattern to match JSON objects
-    pattern = r"\{(?:[^{}]|(?:\{(?:[^{}]|(?:\{[^{}]*\}))*\}))*\}"
-    json_matches = re.findall(pattern, text)
-
-    function_calls = []
-    for json_str in json_matches:
-        try:
-            obj = json.loads(json_str)
-            if "function" in obj:
-                function_calls.append(obj)
-        except json.JSONDecodeError:
-            continue
-
-    return function_calls
-
-
-def build_user_prompt(q):
-    """
-    Build a user prompt with the question and search tool definition.
-
-    Args:
-        q (str): The question to ask
-
-    Returns:
-        str: Formatted user prompt
-    """
-    user_prompt = f"""You are a research assistant, and you use the search_corpus tool to find answers to questions.
-Given a question, answer it using by doing searches using the search_corpus tool.
-To use the search_corpus tool, respond with a JSON for a function call with its proper arguments.
-
-PLEASE CONSIDER CHAT HISTORY WHEN ANSWERING THE QUESTION.
-ONLY ANSWER WHEN YOU HAVE 100% CONFIDENCE IN THE SEARCH RESULTS, ELSE CONTINUE SEARCHING.
-PLEASE SEARCH MULTIPLE TIMES WITH DIFFERENT QUERIES.
-
-You may also reason in any message, think step by step about how to answer the question. Wrap your reasoning in <reasoning> and </reasoning> tags.
-
-{json.dumps(SEARCH_TOOL_DEFINITION, indent=2)}
-
-Question: {q}
-"""
-    return user_prompt
-
-
-def format_search_results(results: Union[str, List[str]]) -> str:
-    """
-    Format search results for display.
-
-    Args:
-        results: Search results as string or list of strings
-
-    Returns:
-        Formatted search results
-    """
-    if isinstance(results, list):
-        content = "\n".join([f"Result {i + 1}:\n{r}\n------" for i, r in enumerate(results)])
-    else:
-        content = results
-
-    return f"\n===== SEARCH RESULTS =====\n{content}\n===========================\n"
-
-
 class DeepSearchCLI:
     """CLI for interacting with the model and search functionality."""
 
@@ -116,7 +53,7 @@ class DeepSearchCLI:
         self,
         model_path: str,
         temperature: float = 0.7,
-        system_prompt: Optional[str] = None,
+        system_prompt: str | None = None,
     ):
         """
         Initialize the CLI.
@@ -213,16 +150,16 @@ You are a helpful assistant with tool calling capabilities."""
         return final_response
 
     def _check_finished_chat(self, chat_state: dict) -> dict:
-        """Check if the chat is finished (no more function calls)."""
+        """Check if the chat is finished (no more search queries)."""
         if chat_state.get("finished"):
             return chat_state
 
         assert chat_state["messages"][-1]["role"] == "assistant", "Expected the last role to be assistant"
 
         assistant_response = chat_state["messages"][-1]["content"]
-        function_calls = extract_function_calls(assistant_response)
+        search_query = extract_search_query(assistant_response)
 
-        if len(function_calls) == 0:
+        if not search_query:
             chat_state["finished"] = True
 
         return chat_state
 
@@ -234,32 +171,26 @@ You are a helpful assistant with tool calling capabilities."""
         try:
             assistant_response = chat_state["messages"][-1]["content"]
-            function_calls = extract_function_calls(assistant_response)
-
-            if len(function_calls) > 1:
-                print("Multiple function calls found in assistant response")
-                raise ValueError("Expected only one function call in assistant response")
+            search_query = extract_search_query(assistant_response)
 
-            elif len(function_calls) == 1:
-                function_call = function_calls[0]
-                query = function_call["function"]["parameters"]["query"]
-                print(f"🔍 Search Query: {query}")
+            if search_query:
+                print(f"🔍 Search Query: {search_query}")
 
-                results = search(query, return_type=str, results=2)
+                results = search(search_query, return_type=str, results=2)
+                # Wrap results in <information> tags
+                formatted_results = f"<information>{results}</information>"
 
                 # Print search results to terminal
-                # logger.info("\n===== SEARCH RESULTS =====")
-                # logger.info(
-                #     results
-                # )  # The results are already formatted with Result 1:, Result 2:, etc.
-                # logger.info("===========================\n")
+                print("\n===== SEARCH RESULTS =====")
+                print(results)
+                print("===========================\n")
 
-                chat_state["messages"].append({"role": "ipython", "content": results})
+                chat_state["messages"].append({"role": "ipython", "content": formatted_results})
 
                 # Record search in history
                 search_entry = {
                     "turn": len(self.history) // 2,
-                    "searches": [{"query": query, "results": results}],
+                    "searches": [{"query": search_query, "results": results}],
                 }
                 self.search_history.append(search_entry)
 
@@ -414,65 +345,6 @@ You are a helpful assistant with tool calling capabilities."""
             print(f"Error saving chat history: {e}")
             return None
 
-    def save_chat_history_json(self, filepath=None):
-        """
-        Save chat history to a JSON file.
-
-        Args:
-            filepath: Path to save file (if None, auto-generate based on timestamp)
-
-        Returns:
-            Path to the saved file
-        """
-        if not self.history:
-            print("No chat history to save.")
-            return None
-
-        # Generate a default filepath if none provided
-        if filepath is None:
-            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-            filepath = os.path.join(os.getcwd(), f"chat_history_{timestamp}.json")
-
-        # Ensure the directory exists
-        os.makedirs(os.path.dirname(filepath), exist_ok=True)
-
-        # Prepare chat history data
-        history_data = {
-            "model": self.model.name_or_path,
-            "temperature": self.temperature,
-            "timestamp": datetime.now().isoformat(),
-            "turns": [],
-        }
-
-        # Group history into conversation turns
-        for i in range(0, len(self.history), 2):
-            turn_number = i // 2
-            turn_data = {
-                "turn": turn_number + 1,
-                "user": self.history[i]["content"] if i < len(self.history) else "",
-                "searches": [],
-                "assistant": self.history[i + 1]["content"] if i + 1 < len(self.history) else "",
-            }
-
-            # Add searches for this turn
-            for search_entry in self.search_history:
-                if search_entry["turn"] == turn_number:
-                    turn_data["searches"].extend(search_entry["searches"])
-
-            history_data["turns"].append(turn_data)
-
-        # Write to file
-        try:
-            with open(filepath, "w", encoding="utf-8") as f:
-                json.dump(history_data, f, indent=2, ensure_ascii=False)
-
-            print(f"Chat history saved to JSON: {filepath}")
-            return filepath
-
-        except Exception as e:
-            print(f"Error saving chat history to JSON: {e}")
-            return None
-
     def display_help(self):
         """Display help information."""
         print("\n===== Commands =====")
@@ -481,7 +353,6 @@ You are a helpful assistant with tool calling capabilities."""
         print("clear - Clear conversation history")
         print("history - Display full chat history with searches")
         print("save - Save chat history to a text file")
-        print("savejson - Save chat history to a JSON file")
         print("help - Display this help message")
         print("exit/quit - Exit the program")
         print("Any other input will be treated as a prompt to the model.")
@@ -518,10 +389,6 @@ You are a helpful assistant with tool calling capabilities."""
                     self.save_chat_history()
                     continue
 
-                if user_input.lower() == "savejson":
-                    self.save_chat_history_json()
-                    continue
-
                 if user_input.lower().startswith("system "):
                     new_prompt = user_input[7:].strip()
                     self.set_system_prompt(new_prompt)
@@ -560,69 +427,6 @@ You are a helpful assistant with tool calling capabilities."""
             print(f"Error: {e}")
 
 
-def extract_json_objects(text):
-    """
-    Extracts JSON objects (dictionaries) from a text that may contain multiple JSON objects.
-
-    Args:
-        text (str): The input text possibly containing JSON objects.
-
-    Returns:
-        list: A list of parsed JSON objects (dictionaries) extracted from the text.
-    """
-    results = []
-    length = len(text)
-    i = 0
-
-    while i < length:
-        # Look for the start of a JSON object
-        if text[i] == "{":
-            start = i
-            stack = 1
-            i += 1
-            # Continue until we find the matching closing brace
-            while i < length and stack > 0:
-                if text[i] == "{":
-                    stack += 1
-                elif text[i] == "}":
-                    stack -= 1
-                i += 1
-            # Only attempt to decode if the braces are balanced
-            if stack == 0:
-                candidate = text[start:i]
-                try:
-                    obj = json.loads(candidate)
-                    # Optionally, ensure it's a dictionary if that's what you expect
-                    if isinstance(obj, dict):
-                        results.append(obj)
-                except json.JSONDecodeError:
-                    # If it's not valid JSON, skip it.
-                    pass
-        else:
-            i += 1
-    return results
-
-
-# Tool definition for search corpus
-SEARCH_TOOL_DEFINITION = {
-    "type": "function",
-    "function": {
-        "name": "search_corpus",
-        "description": "Search over the knowledge corpus with a given query",
-        "parameters": {
-            "type": "object",
-            "properties": {
-                "query": {
-                    "type": "string",
-                    "description": "The query to search the knowledge corpus with",
-                },
-            },
-            "required": ["query"],
-        },
-    },
-}
-
-
 def main():
     """Main function."""
     parser = argparse.ArgumentParser(description="DeepSearch CLI")
diff --git a/src/config.py b/src/config.py
index ecc767b..8eb4fab 100644
--- a/src/config.py
+++ b/src/config.py
@@ -100,21 +100,21 @@ def _init_logging(env: str = "development") -> None:
 
     file_format = "{time:YYYY-MM-DD at HH:mm:ss} | {level} | {name}:{function}:{line} - {message}"
 
-    # Add console logging
+    # Add console logging with DEBUG level
     logger.add(
        sys.stderr,
         format=console_format,
-        level="DEBUG" if env == "development" else "INFO",
+        level="DEBUG",  # Always use DEBUG level
         colorize=True,
         backtrace=True,
-        diagnose=env == "development",
+        diagnose=True,  # Always enable diagnostics
     )
 
-    # Add default file logging to ./logs directory
+    # Add default file logging to ./logs directory with DEBUG level
     logger.add(
         LOG_FOLDER / "app.log",
         format=file_format,
-        level="INFO",
+        level="DEBUG",  # Always use DEBUG level
         rotation="500 MB",
         retention="7 days",
         compression="zip",
diff --git a/src/rl_helpers.py b/src/rl_helpers.py
index 117b86b..7e4a8a7 100644
--- a/src/rl_helpers.py
+++ b/src/rl_helpers.py
@@ -5,7 +5,6 @@ and calculating rewards based on the quality of responses.
 """
 
 import inspect
-import json
 import re
 from dataclasses import dataclass
 from datetime import datetime
@@ -36,29 +35,9 @@ You are a helpful assistant with tool calling capabilities.
 """
 
 
-# Tool definition for search corpus
-SEARCH_TOOL_DEFINITION = {
-    "type": "function",
-    "function": {
-        "name": "search_corpus",
-        "description": "Search over the knowledge corpus with a given query",
-        "parameters": {
-            "type": "object",
-            "properties": {
-                "query": {
-                    "type": "string",
-                    "description": "The query to search the knowledge corpus with",
-                },
-            },
-            "required": ["query"],
-        },
-    },
-}
-
-
 def build_user_prompt(q):
     """
-    Build a user prompt with the question and search tool definition.
+    Build a user prompt with the question using the new template format.
 
     Args:
         q (str): The question to ask
@@ -66,17 +45,48 @@ def build_user_prompt(q):
     Returns:
         str: Formatted user prompt
     """
-    user_prompt = f"""You are a research assistant, and you use the search_corpus tool to find answers to questions.
-Given a question, answer it using by doing searches using the search_corpus tool.
-To use the search_corpus tool, respond with a JSON for a function call with its proper arguments.
+    user_prompt = f"""Answer the given question. \
+You must conduct reasoning inside <think> and </think> first every time you get new information. \
+After reasoning, if you find you lack some knowledge, you can call a search engine by <search> query </search>. \
+You can search as many times as you want. \
+If you find no further external knowledge needed, you can directly provide the answer inside <answer> and </answer>, without detailed illustrations. For example, <answer> Beijing </answer>.
+
+IMPORTANT INSTRUCTIONS:
+1. PLEASE CONSIDER CHAT HISTORY WHEN ANSWERING THE QUESTION.
+2. ONLY ANSWER WHEN YOU HAVE 100% CONFIDENCE IN THE SEARCH RESULTS, ELSE CONTINUE SEARCHING.
+3. PLEASE SEARCH MULTIPLE TIMES WITH DIFFERENT QUERIES.
+
+Question: {q}\n"""
+    return user_prompt
 
-You may also reason in any message, think step by step about how to answer the question. Wrap your reasoning in <reasoning> and </reasoning> tags.
 
-{json.dumps(SEARCH_TOOL_DEFINITION, indent=2)}
+def format_search_results(results: str | list[str]) -> str:
+    """
+    Format search results for display, matching the format from infer.py.
+    Each result should be in the format: "Doc X(Title: Y) content"
 
-Question: {q}
-"""
-    return user_prompt
+    Args:
+        results: Search results as string or list of strings
+
+    Returns:
+        Formatted search results with document titles
+    """
+    if isinstance(results, list):
+        # If results are already in the correct format, just join them
+        if any("Doc" in r and "Title:" in r for r in results):
+            content = "\n".join(results)
+        else:
+            # If results are raw content, format them with default titles
+            content = "\n".join([f"Doc {i + 1}(Title: Document {i + 1})\n{r}" for i, r in enumerate(results)])
+    else:
+        # If single result is already formatted, use it as is
+        if "Doc" in results and "Title:" in results:
+            content = results
+        else:
+            # If single result is raw content, format it with default title
+            content = f"Doc 1(Title: Document 1)\n{results}"
+
+    return content
 
 
 def get_initial_chat(question):
@@ -97,49 +107,6 @@
     }
 
 
-def extract_json_objects(text):
-    """
-    Extracts JSON objects (dictionaries) from a text that may contain multiple JSON objects.
-
-    Args:
-        text (str): The input text possibly containing JSON objects.
-
-    Returns:
-        list: A list of parsed JSON objects (dictionaries) extracted from the text.
-    """
-    results = []
-    length = len(text)
-    i = 0
-
-    while i < length:
-        # Look for the start of a JSON object
-        if text[i] == "{":
-            start = i
-            stack = 1
-            i += 1
-            # Continue until we find the matching closing brace
-            while i < length and stack > 0:
-                if text[i] == "{":
-                    stack += 1
-                elif text[i] == "}":
-                    stack -= 1
-                i += 1
-            # Only attempt to decode if the braces are balanced
-            if stack == 0:
-                candidate = text[start:i]
-                try:
-                    obj = json.loads(candidate)
-                    # Optionally, ensure it's a dictionary if that's what you expect
-                    if isinstance(obj, dict):
-                        results.append(obj)
-                except json.JSONDecodeError:
-                    # If it's not valid JSON, skip it.
-                    pass
-        else:
-            i += 1
-    return results
-
-
 def remove_reasoning(text: str) -> str:
     """
     Removes all content between <reasoning> and </reasoning> tags,
@@ -196,7 +163,7 @@ def run_agent_generations(generate_fn, tokenizer, chat_states):
 
 def check_finished_chats(chat_states):
     """
-    Check which chat states are finished (no more function calls).
+    Check which chat states are finished (no more search queries).
 
     Args:
         chat_states: List of chat states
@@ -209,12 +176,27 @@ def check_finished_chats(chat_states):
             continue
         assert chat_state["messages"][-1]["role"] == "assistant", "Expected the last role to be assistant"
         assistant_response = chat_state["messages"][-1]["content"]
-        function_calls = extract_json_objects(assistant_response)
-        if len(function_calls) == 0:
+        # Check if there are any search queries in the response
+        if not re.search(r"<search>.*?</search>", assistant_response, re.DOTALL):
             chat_state["finished"] = True
     return chat_states
 
 
+def extract_search_query(text: str) -> str | None:
+    """
+    Extract search query from text between <search> tags.
+
+    Args:
+        text (str): Text containing search query
+
+    Returns:
+        str | None: Search query if found, None otherwise
+    """
+    pattern = re.compile(r"<search>(.*?)</search>", re.DOTALL)
+    matches = pattern.findall(text)
+    return matches[-1] if matches else None
+
+
 def run_tool_calls(chat_states):
     """
     Execute tool calls found in chat states.
@@ -231,21 +213,16 @@ def run_tool_calls(chat_states):
         )
         try:
             assistant_response = chat_state["messages"][-1]["content"]
-            function_calls = extract_json_objects(assistant_response)
-            if len(function_calls) > 1:
-                logger.warning("Multiple function calls found in assistant response")
-                raise ValueError("Expected only one function call in assistant response")
-            elif len(function_calls) == 1:
-                function_call = function_calls[0]
-                query = function_call["function"]["parameters"]["query"]
-                logger.info(f"🔍 Search Query: {query}")
-                results = search(query, return_type=str, results=2)
-                chat_state["messages"].append({"role": "ipython", "content": results})
-
-                # Count retries
-                retries = len(extract_json_objects(assistant_response))
-                total_retries += retries
-
+            search_query = extract_search_query(assistant_response)
+            if search_query:
+                logger.info(f"🔍 Search Query: {search_query}")
+                results = search(search_query, return_type=str, results=2)
+                # Wrap results in <information> tags
+                formatted_results = f"<information>{results}</information>"
+                logger.info(f"ℹ️ Information: {formatted_results}")
+
+                chat_state["messages"].append({"role": "ipython", "content": formatted_results})
+                total_retries += 1
                 logger.debug("Added search results to chat state")
         except Exception as e:
             logger.error(f"Error during tool call: {str(e)}")
@@ -598,7 +575,7 @@ def reward_formatting(prompts, completions, **reward_kwargs):
 def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[float]:
     """
     Reward function that encourages optimal retry behavior by only rewarding completions
-    where every assistant message contains at most 1 JSON object.
+    where every assistant message contains at most 1 search query.
""" rewards: list[float] = [] @@ -614,32 +591,31 @@ def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[floa rewards.append(0.0) continue - # Check if every message has at most 1 JSON object - has_multiple_json = False - total_json_objects = 0 + # Check if every message has at most 1 search query + has_multiple_searches = False + total_searches = 0 for msg in assistant_msgs: - json_objects = extract_json_objects(msg) - json_count = len(json_objects) - total_json_objects += json_count + search_count = len(re.findall(r".*?", msg, re.DOTALL)) + total_searches += search_count - if json_count > 1: - has_multiple_json = True - logger.warning(f"Message contains {json_count} JSON objects, which exceeds the limit of 1") + if search_count > 1: + has_multiple_searches = True + logger.warning(f"Message contains {search_count} search queries, which exceeds the limit of 1") break - # Only reward if no message has multiple JSON objects - if has_multiple_json: + # Only reward if no message has multiple search queries + if has_multiple_searches: rewards.append(0.0) else: # Base reward is 1.0 if constraint is met base_reward = 1.0 - # Slight penalty for having too many total JSON objects across all messages - if total_json_objects > 4: - penalty = 0.1 * (total_json_objects - 4) + # Slight penalty for having too many total searches across all messages + if total_searches > 4: + penalty = 0.1 * (total_searches - 4) base_reward = max(0.2, base_reward - penalty) - logger.debug(f"Applied penalty for {total_json_objects} total JSON objects: {penalty}") + logger.debug(f"Applied penalty for {total_searches} total searches: {penalty}") rewards.append(base_reward) @@ -647,10 +623,10 @@ def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[floa log_metric("rewards/retry_behavior", np.mean(rewards), reward_kwargs.get("step", 0)) log_metric("rewards/retry_behavior_std", np.std(rewards), reward_kwargs.get("step", 0)) log_metric( - "metrics/avg_json_per_msg", + "metrics/avg_searches_per_msg", np.mean( [ - len(extract_json_objects(msg["content"])) + len(re.findall(r".*?", msg["content"], re.DOTALL)) for completion in completions for msg in completion["messages"] if msg["role"] == "assistant" @@ -659,7 +635,7 @@ def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[floa reward_kwargs.get("step", 0), ) log_metric( - "metrics/multiple_json_violation_rate", + "metrics/multiple_search_violation_rate", np.mean([0.0 if rewards[i] > 0.0 else 1.0 for i in range(len(rewards))]), reward_kwargs.get("step", 0), )