diff --git a/eval.py b/eval.py
index 36b7679..b9167ca 100644
--- a/eval.py
+++ b/eval.py
@@ -312,7 +312,7 @@ def evaluate_model(
# Define all output file paths
eval_log_file = os.path.join(eval_log_dir, f"{model_prefix}_model_eval_{timestamp}.log")
output_file = os.path.join(eval_log_dir, f"{model_prefix}_model_results.txt")
- debug_file = os.path.join(eval_log_dir, f"{model_prefix}_model_results_debug.json")
+ debug_file = os.path.join(eval_log_dir, f"{model_prefix}_model_results_debug.txt")
logger.info(f"Writing evaluation log to: {eval_log_file}")
logger.info(f"Results will be saved to: {output_file}")
diff --git a/inference.py b/inference.py
index 85ec4cd..5664fd8 100644
--- a/inference.py
+++ b/inference.py
@@ -6,15 +6,18 @@ and provides search functionality for data retrieval.
"""
import argparse
-import json
import os
import time
from datetime import datetime
-from typing import Any, Dict, List, Optional, Union
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import SamplingParams
+from src.rl_helpers import (
+ build_user_prompt,
+ extract_search_query,
+ format_search_results,
+)
from src.search_module import load_vectorstore, search
@@ -43,72 +46,6 @@ def get_sampling_params(temperature: float = 0.7, max_tokens: int = 4096) -> Sam
)
-def extract_function_calls(text: str) -> List[Dict[str, Any]]:
- """Extract function calls from a text."""
- import json
- import re
-
- # Pattern to match JSON objects
- pattern = r"\{(?:[^{}]|(?:\{(?:[^{}]|(?:\{[^{}]*\}))*\}))*\}"
- json_matches = re.findall(pattern, text)
-
- function_calls = []
- for json_str in json_matches:
- try:
- obj = json.loads(json_str)
- if "function" in obj:
- function_calls.append(obj)
- except json.JSONDecodeError:
- continue
-
- return function_calls
-
-
-def build_user_prompt(q):
- """
- Build a user prompt with the question and search tool definition.
-
- Args:
- q (str): The question to ask
-
- Returns:
- str: Formatted user prompt
- """
- user_prompt = f"""You are a research assistant, and you use the search_corpus tool to find answers to questions.
-Given a question, answer it using by doing searches using the search_corpus tool.
-To use the search_corpus tool, respond with a JSON for a function call with its proper arguments.
-
-PLEASE CONSIDER CHAT HISTORY WHEN ANSWERING THE QUESTION.
-ONLY ANSWER WHEN YOU HAVE 100% CONFIDENCE IN THE SEARCH RESULTS, ELSE CONTINUE SEARCHING.
-PLEASE SEARCH MULTIPLE TIMES WITH DIFFERENT QUERIES.
-
-You may also reason in any message, think step by step about how to answer the question. Wrap your reasoning in <think> and </think> tags.
-
-{json.dumps(SEARCH_TOOL_DEFINITION, indent=2)}
-
-Question: {q}
-"""
- return user_prompt
-
-
-def format_search_results(results: Union[str, List[str]]) -> str:
- """
- Format search results for display.
-
- Args:
- results: Search results as string or list of strings
-
- Returns:
- Formatted search results
- """
- if isinstance(results, list):
- content = "\n".join([f"Result {i + 1}:\n{r}\n------" for i, r in enumerate(results)])
- else:
- content = results
-
- return f"\n===== SEARCH RESULTS =====\n{content}\n===========================\n"
-
-
class DeepSearchCLI:
"""CLI for interacting with the model and search functionality."""
@@ -116,7 +53,7 @@ class DeepSearchCLI:
self,
model_path: str,
temperature: float = 0.7,
- system_prompt: Optional[str] = None,
+ system_prompt: str | None = None,
):
"""
Initialize the CLI.
@@ -213,16 +150,16 @@ You are a helpful assistant with tool calling capabilities."""
return final_response
def _check_finished_chat(self, chat_state: dict) -> dict:
- """Check if the chat is finished (no more function calls)."""
+ """Check if the chat is finished (no more search queries)."""
if chat_state.get("finished"):
return chat_state
assert chat_state["messages"][-1]["role"] == "assistant", "Expected the last role to be assistant"
assistant_response = chat_state["messages"][-1]["content"]
- function_calls = extract_function_calls(assistant_response)
+ search_query = extract_search_query(assistant_response)
- if len(function_calls) == 0:
+ if not search_query:
chat_state["finished"] = True
return chat_state
@@ -234,32 +171,26 @@ You are a helpful assistant with tool calling capabilities."""
try:
assistant_response = chat_state["messages"][-1]["content"]
- function_calls = extract_function_calls(assistant_response)
-
- if len(function_calls) > 1:
- print("Multiple function calls found in assistant response")
- raise ValueError("Expected only one function call in assistant response")
+ search_query = extract_search_query(assistant_response)
- elif len(function_calls) == 1:
- function_call = function_calls[0]
- query = function_call["function"]["parameters"]["query"]
- print(f"đ Search Query: {query}")
+ if search_query:
+ print(f"đ Search Query: {search_query}")
- results = search(query, return_type=str, results=2)
+ results = search(search_query, return_type=str, results=2)
+            # Wrap results in <information> tags
+            formatted_results = f"<information>{results}</information>"
# Print search results to terminal
- # logger.info("\n===== SEARCH RESULTS =====")
- # logger.info(
- # results
- # ) # The results are already formatted with Result 1:, Result 2:, etc.
- # logger.info("===========================\n")
+ print("\n===== SEARCH RESULTS =====")
+ print(results)
+ print("===========================\n")
- chat_state["messages"].append({"role": "ipython", "content": results})
+ chat_state["messages"].append({"role": "ipython", "content": formatted_results})
# Record search in history
search_entry = {
"turn": len(self.history) // 2,
- "searches": [{"query": query, "results": results}],
+ "searches": [{"query": search_query, "results": results}],
}
self.search_history.append(search_entry)
@@ -414,65 +345,6 @@ You are a helpful assistant with tool calling capabilities."""
print(f"Error saving chat history: {e}")
return None
- def save_chat_history_json(self, filepath=None):
- """
- Save chat history to a JSON file.
-
- Args:
- filepath: Path to save file (if None, auto-generate based on timestamp)
-
- Returns:
- Path to the saved file
- """
- if not self.history:
- print("No chat history to save.")
- return None
-
- # Generate a default filepath if none provided
- if filepath is None:
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- filepath = os.path.join(os.getcwd(), f"chat_history_{timestamp}.json")
-
- # Ensure the directory exists
- os.makedirs(os.path.dirname(filepath), exist_ok=True)
-
- # Prepare chat history data
- history_data = {
- "model": self.model.name_or_path,
- "temperature": self.temperature,
- "timestamp": datetime.now().isoformat(),
- "turns": [],
- }
-
- # Group history into conversation turns
- for i in range(0, len(self.history), 2):
- turn_number = i // 2
- turn_data = {
- "turn": turn_number + 1,
- "user": self.history[i]["content"] if i < len(self.history) else "",
- "searches": [],
- "assistant": self.history[i + 1]["content"] if i + 1 < len(self.history) else "",
- }
-
- # Add searches for this turn
- for search_entry in self.search_history:
- if search_entry["turn"] == turn_number:
- turn_data["searches"].extend(search_entry["searches"])
-
- history_data["turns"].append(turn_data)
-
- # Write to file
- try:
- with open(filepath, "w", encoding="utf-8") as f:
- json.dump(history_data, f, indent=2, ensure_ascii=False)
-
- print(f"Chat history saved to JSON: {filepath}")
- return filepath
-
- except Exception as e:
- print(f"Error saving chat history to JSON: {e}")
- return None
-
def display_help(self):
"""Display help information."""
print("\n===== Commands =====")
@@ -481,7 +353,6 @@ You are a helpful assistant with tool calling capabilities."""
print("clear - Clear conversation history")
print("history - Display full chat history with searches")
print("save - Save chat history to a text file")
- print("savejson - Save chat history to a JSON file")
print("help - Display this help message")
print("exit/quit - Exit the program")
print("Any other input will be treated as a prompt to the model.")
@@ -518,10 +389,6 @@ You are a helpful assistant with tool calling capabilities."""
self.save_chat_history()
continue
- if user_input.lower() == "savejson":
- self.save_chat_history_json()
- continue
-
if user_input.lower().startswith("system "):
new_prompt = user_input[7:].strip()
self.set_system_prompt(new_prompt)
@@ -560,69 +427,6 @@ You are a helpful assistant with tool calling capabilities."""
print(f"Error: {e}")
-def extract_json_objects(text):
- """
- Extracts JSON objects (dictionaries) from a text that may contain multiple JSON objects.
-
- Args:
- text (str): The input text possibly containing JSON objects.
-
- Returns:
- list: A list of parsed JSON objects (dictionaries) extracted from the text.
- """
- results = []
- length = len(text)
- i = 0
-
- while i < length:
- # Look for the start of a JSON object
- if text[i] == "{":
- start = i
- stack = 1
- i += 1
- # Continue until we find the matching closing brace
- while i < length and stack > 0:
- if text[i] == "{":
- stack += 1
- elif text[i] == "}":
- stack -= 1
- i += 1
- # Only attempt to decode if the braces are balanced
- if stack == 0:
- candidate = text[start:i]
- try:
- obj = json.loads(candidate)
- # Optionally, ensure it's a dictionary if that's what you expect
- if isinstance(obj, dict):
- results.append(obj)
- except json.JSONDecodeError:
- # If it's not valid JSON, skip it.
- pass
- else:
- i += 1
- return results
-
-
-# Tool definition for search corpus
-SEARCH_TOOL_DEFINITION = {
- "type": "function",
- "function": {
- "name": "search_corpus",
- "description": "Search over the knowledge corpus with a given query",
- "parameters": {
- "type": "object",
- "properties": {
- "query": {
- "type": "string",
- "description": "The query to search the knowledge corpus with",
- },
- },
- "required": ["query"],
- },
- },
-}
-
-
def main():
"""Main function."""
parser = argparse.ArgumentParser(description="DeepSearch CLI")
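For context, a minimal sketch (outside the patch) of the tag-based tool loop these hunks converge on; `run_search` is a stand-in for `src.search_module.search`, and the message shape mirrors the new `run_tool_calls` path:

    import re

    def extract_search_query(text: str) -> str | None:
        """Return the last <search>...</search> query in text, or None."""
        matches = re.findall(r"<search>(.*?)</search>", text, re.DOTALL)
        return matches[-1] if matches else None

    def tool_turn(assistant_response: str, run_search) -> dict | None:
        """One tool step: extract the query, run the search, and wrap the
        results in <information> tags as an ipython-role message."""
        query = extract_search_query(assistant_response)
        if query is None:
            return None  # no query means the chat is finished
        results = run_search(query, return_type=str, results=2)
        return {"role": "ipython", "content": f"<information>{results}</information>"}
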
diff --git a/src/config.py b/src/config.py
index ecc767b..8eb4fab 100644
--- a/src/config.py
+++ b/src/config.py
@@ -100,21 +100,21 @@ def _init_logging(env: str = "development") -> None:
file_format = "{time:YYYY-MM-DD at HH:mm:ss} | {level} | {name}:{function}:{line} - {message}"
- # Add console logging
+ # Add console logging with DEBUG level
logger.add(
sys.stderr,
format=console_format,
- level="DEBUG" if env == "development" else "INFO",
+ level="DEBUG", # Always use DEBUG level
colorize=True,
backtrace=True,
- diagnose=env == "development",
+ diagnose=True, # Always enable diagnostics
)
- # Add default file logging to ./logs directory
+ # Add default file logging to ./logs directory with DEBUG level
logger.add(
LOG_FOLDER / "app.log",
format=file_format,
- level="INFO",
+ level="DEBUG", # Always use DEBUG level
rotation="500 MB",
retention="7 days",
compression="zip",
diff --git a/src/rl_helpers.py b/src/rl_helpers.py
index 117b86b..7e4a8a7 100644
--- a/src/rl_helpers.py
+++ b/src/rl_helpers.py
@@ -5,7 +5,6 @@ and calculating rewards based on the quality of responses.
"""
import inspect
-import json
import re
from dataclasses import dataclass
from datetime import datetime
@@ -36,29 +35,9 @@ You are a helpful assistant with tool calling capabilities.
"""
-# Tool definition for search corpus
-SEARCH_TOOL_DEFINITION = {
- "type": "function",
- "function": {
- "name": "search_corpus",
- "description": "Search over the knowledge corpus with a given query",
- "parameters": {
- "type": "object",
- "properties": {
- "query": {
- "type": "string",
- "description": "The query to search the knowledge corpus with",
- },
- },
- "required": ["query"],
- },
- },
-}
-
-
def build_user_prompt(q):
"""
- Build a user prompt with the question and search tool definition.
+ Build a user prompt with the question using the new template format.
Args:
q (str): The question to ask
@@ -66,17 +45,48 @@ def build_user_prompt(q):
Returns:
str: Formatted user prompt
"""
- user_prompt = f"""You are a research assistant, and you use the search_corpus tool to find answers to questions.
-Given a question, answer it using by doing searches using the search_corpus tool.
-To use the search_corpus tool, respond with a JSON for a function call with its proper arguments.
+ user_prompt = f"""Answer the given question. \
+You must conduct reasoning inside <think> and </think> first every time you get new information. \
+After reasoning, if you find you lack some knowledge, you can call a search engine by <search> query </search>. \
+You can search as many times as you want. \
+If you find no further external knowledge needed, you can directly provide the answer inside <answer> and </answer>, without detailed illustrations. For example, <answer> Beijing </answer>.
+
+IMPORTANT INSTRUCTIONS:
+1. PLEASE CONSIDER CHAT HISTORY WHEN ANSWERING THE QUESTION.
+2. ONLY ANSWER WHEN YOU HAVE 100% CONFIDENCE IN THE SEARCH RESULTS, ELSE CONTINUE SEARCHING.
+3. PLEASE SEARCH MULTIPLE TIMES WITH DIFFERENT QUERIES.
+
+Question: {q}\n"""
+ return user_prompt
-You may also reason in any message, think step by step about how to answer the question. Wrap your reasoning in <think> and </think> tags.
-{json.dumps(SEARCH_TOOL_DEFINITION, indent=2)}
+def format_search_results(results: str | list[str]) -> str:
+ """
+ Format search results for display, matching the format from infer.py.
+ Each result should be in the format: "Doc X(Title: Y) content"
-Question: {q}
-"""
- return user_prompt
+ Args:
+ results: Search results as string or list of strings
+
+ Returns:
+ Formatted search results with document titles
+ """
+ if isinstance(results, list):
+ # If results are already in the correct format, just join them
+ if any("Doc" in r and "Title:" in r for r in results):
+ content = "\n".join(results)
+ else:
+ # If results are raw content, format them with default titles
+ content = "\n".join([f"Doc {i + 1}(Title: Document {i + 1})\n{r}" for i, r in enumerate(results)])
+ else:
+ # If single result is already formatted, use it as is
+ if "Doc" in results and "Title:" in results:
+ content = results
+ else:
+ # If single result is raw content, format it with default title
+ content = f"Doc 1(Title: Document 1)\n{results}"
+
+ return content
def get_initial_chat(question):
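A short behavior sketch of the new `format_search_results` with hypothetical inputs (assuming the function defined in this hunk is in scope): pre-formatted results pass through untouched, raw passages get default `Doc N(Title: Document N)` headers:

    raw = ["first passage", "second passage"]
    print(format_search_results(raw))
    # Doc 1(Title: Document 1)
    # first passage
    # Doc 2(Title: Document 2)
    # second passage

    preformatted = "Doc 3(Title: Attention) transformer overview"
    print(format_search_results(preformatted))  # returned unchanged
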
@@ -97,49 +107,6 @@ def get_initial_chat(question):
}
-def extract_json_objects(text):
- """
- Extracts JSON objects (dictionaries) from a text that may contain multiple JSON objects.
-
- Args:
- text (str): The input text possibly containing JSON objects.
-
- Returns:
- list: A list of parsed JSON objects (dictionaries) extracted from the text.
- """
- results = []
- length = len(text)
- i = 0
-
- while i < length:
- # Look for the start of a JSON object
- if text[i] == "{":
- start = i
- stack = 1
- i += 1
- # Continue until we find the matching closing brace
- while i < length and stack > 0:
- if text[i] == "{":
- stack += 1
- elif text[i] == "}":
- stack -= 1
- i += 1
- # Only attempt to decode if the braces are balanced
- if stack == 0:
- candidate = text[start:i]
- try:
- obj = json.loads(candidate)
- # Optionally, ensure it's a dictionary if that's what you expect
- if isinstance(obj, dict):
- results.append(obj)
- except json.JSONDecodeError:
- # If it's not valid JSON, skip it.
- pass
- else:
- i += 1
- return results
-
-
def remove_reasoning(text: str) -> str:
"""
    Removes all content between <think> and </think> tags,
@@ -196,7 +163,7 @@ def run_agent_generations(generate_fn, tokenizer, chat_states):
def check_finished_chats(chat_states):
"""
- Check which chat states are finished (no more function calls).
+ Check which chat states are finished (no more search queries).
Args:
chat_states: List of chat states
@@ -209,12 +176,27 @@ def check_finished_chats(chat_states):
continue
assert chat_state["messages"][-1]["role"] == "assistant", "Expected the last role to be assistant"
assistant_response = chat_state["messages"][-1]["content"]
- function_calls = extract_json_objects(assistant_response)
- if len(function_calls) == 0:
+ # Check if there are any search queries in the response
+ if not re.search(r".*?", assistant_response, re.DOTALL):
chat_state["finished"] = True
return chat_states
+def extract_search_query(text: str) -> str | None:
+ """
+    Extract search query from text between <search> and </search> tags.
+
+ Args:
+ text (str): Text containing search query
+
+ Returns:
+ str | None: Search query if found, None otherwise
+ """
+ pattern = re.compile(r"(.*?)", re.DOTALL)
+ matches = pattern.findall(text)
+ return matches[-1] if matches else None
+
+
def run_tool_calls(chat_states):
"""
Execute tool calls found in chat states.
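A quick behavior check of `extract_search_query` as defined in this hunk (illustrative strings): the non-greedy `findall` collects every tagged block, and the function keeps only the last one.

    text = "<think>try both</think><search>first query</search><search>second query</search>"
    assert extract_search_query(text) == "second query"
    assert extract_search_query("no tags here") is None
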
@@ -231,21 +213,16 @@ def run_tool_calls(chat_states):
)
try:
assistant_response = chat_state["messages"][-1]["content"]
- function_calls = extract_json_objects(assistant_response)
- if len(function_calls) > 1:
- logger.warning("Multiple function calls found in assistant response")
- raise ValueError("Expected only one function call in assistant response")
- elif len(function_calls) == 1:
- function_call = function_calls[0]
- query = function_call["function"]["parameters"]["query"]
- logger.info(f"đ Search Query: {query}")
- results = search(query, return_type=str, results=2)
- chat_state["messages"].append({"role": "ipython", "content": results})
-
- # Count retries
- retries = len(extract_json_objects(assistant_response))
- total_retries += retries
-
+ search_query = extract_search_query(assistant_response)
+ if search_query:
+ logger.info(f"đ Search Query: {search_query}")
+ results = search(search_query, return_type=str, results=2)
+            # Wrap results in <information> tags
+            formatted_results = f"<information>{results}</information>"
+            logger.info(f"ℹ️ Information: {formatted_results}")
+
+ chat_state["messages"].append({"role": "ipython", "content": formatted_results})
+ total_retries += 1
logger.debug("Added search results to chat state")
except Exception as e:
logger.error(f"Error during tool call: {str(e)}")
@@ -598,7 +575,7 @@ def reward_formatting(prompts, completions, **reward_kwargs):
def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[float]:
"""
Reward function that encourages optimal retry behavior by only rewarding completions
- where every assistant message contains at most 1 JSON object.
+ where every assistant message contains at most 1 search query.
"""
rewards: list[float] = []
@@ -614,32 +591,31 @@ def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[floa
rewards.append(0.0)
continue
- # Check if every message has at most 1 JSON object
- has_multiple_json = False
- total_json_objects = 0
+ # Check if every message has at most 1 search query
+ has_multiple_searches = False
+ total_searches = 0
for msg in assistant_msgs:
- json_objects = extract_json_objects(msg)
- json_count = len(json_objects)
- total_json_objects += json_count
+            search_count = len(re.findall(r"<search>.*?</search>", msg, re.DOTALL))
+ total_searches += search_count
- if json_count > 1:
- has_multiple_json = True
- logger.warning(f"Message contains {json_count} JSON objects, which exceeds the limit of 1")
+ if search_count > 1:
+ has_multiple_searches = True
+ logger.warning(f"Message contains {search_count} search queries, which exceeds the limit of 1")
break
- # Only reward if no message has multiple JSON objects
- if has_multiple_json:
+ # Only reward if no message has multiple search queries
+ if has_multiple_searches:
rewards.append(0.0)
else:
# Base reward is 1.0 if constraint is met
base_reward = 1.0
- # Slight penalty for having too many total JSON objects across all messages
- if total_json_objects > 4:
- penalty = 0.1 * (total_json_objects - 4)
+ # Slight penalty for having too many total searches across all messages
+ if total_searches > 4:
+ penalty = 0.1 * (total_searches - 4)
base_reward = max(0.2, base_reward - penalty)
- logger.debug(f"Applied penalty for {total_json_objects} total JSON objects: {penalty}")
+ logger.debug(f"Applied penalty for {total_searches} total searches: {penalty}")
rewards.append(base_reward)
@@ -647,10 +623,10 @@ def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[floa
log_metric("rewards/retry_behavior", np.mean(rewards), reward_kwargs.get("step", 0))
log_metric("rewards/retry_behavior_std", np.std(rewards), reward_kwargs.get("step", 0))
log_metric(
- "metrics/avg_json_per_msg",
+ "metrics/avg_searches_per_msg",
np.mean(
[
- len(extract_json_objects(msg["content"]))
+                len(re.findall(r"<search>.*?</search>", msg["content"], re.DOTALL))
for completion in completions
for msg in completion["messages"]
if msg["role"] == "assistant"
@@ -659,7 +635,7 @@ def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[floa
reward_kwargs.get("step", 0),
)
log_metric(
- "metrics/multiple_json_violation_rate",
+ "metrics/multiple_search_violation_rate",
np.mean([0.0 if rewards[i] > 0.0 else 1.0 for i in range(len(rewards))]),
reward_kwargs.get("step", 0),
)
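For reference, the retry-reward arithmetic above, worked through with illustrative numbers (not from the source): the base reward of 1.0 loses 0.1 per search beyond 4 total, floored at 0.2, so 12 or more total searches pin the reward to the floor.

    total_searches = 7  # e.g. one <search> per assistant message, 7 messages
    base_reward = 1.0
    if total_searches > 4:
        penalty = 0.1 * (total_searches - 4)  # 0.1 * 3 = 0.3
        base_reward = max(0.2, base_reward - penalty)
    print(base_reward)  # 0.7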