feat: change user prompt template to search-r1 inspired format

use <search></search> tags instead of embedding the whole tool definition, which resulted in lots of parsing errors
main
thinhlpg 1 month ago
parent 58dcf9a99d
commit c90c03267e
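
Why the change helps, in miniature: the old prompt asked the model to answer with a JSON function call, and near-miss JSON (single quotes, a stray brace) parses to nothing, while the new tag protocol needs only a regex match. A small illustrative sketch, with made-up values rather than real repo output:

```python
import json
import re

# Near-miss JSON the old protocol could not recover from:
near_miss = "{'function': {'name': 'search_corpus', 'parameters': {'query': 'capital of France'}}}"
try:
    json.loads(near_miss)
except json.JSONDecodeError as e:
    print("JSON protocol drops the call:", e)

# The tag protocol only needs a regex match:
tagged = "<think>need a lookup</think><search>capital of France</search>"
print(re.findall(r"<search>(.*?)</search>", tagged, re.DOTALL))  # ['capital of France']
```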

@@ -312,7 +312,7 @@ def evaluate_model(
     # Define all output file paths
     eval_log_file = os.path.join(eval_log_dir, f"{model_prefix}_model_eval_{timestamp}.log")
     output_file = os.path.join(eval_log_dir, f"{model_prefix}_model_results.txt")
-    debug_file = os.path.join(eval_log_dir, f"{model_prefix}_model_results_debug.json")
+    debug_file = os.path.join(eval_log_dir, f"{model_prefix}_model_results_debug.txt")
     logger.info(f"Writing evaluation log to: {eval_log_file}")
     logger.info(f"Results will be saved to: {output_file}")

@@ -6,15 +6,18 @@ and provides search functionality for data retrieval.
 """

 import argparse
-import json
 import os
 import time
 from datetime import datetime
-from typing import Any, Dict, List, Optional, Union

 from transformers import AutoModelForCausalLM, AutoTokenizer
 from vllm import SamplingParams

+from src.rl_helpers import (
+    build_user_prompt,
+    extract_search_query,
+    format_search_results,
+)
 from src.search_module import load_vectorstore, search
@@ -43,72 +46,6 @@ def get_sampling_params(temperature: float = 0.7, max_tokens: int = 4096) -> SamplingParams:
     )


-def extract_function_calls(text: str) -> List[Dict[str, Any]]:
-    """Extract function calls from a text."""
-    import json
-    import re
-
-    # Pattern to match JSON objects
-    pattern = r"\{(?:[^{}]|(?:\{(?:[^{}]|(?:\{[^{}]*\}))*\}))*\}"
-    json_matches = re.findall(pattern, text)
-
-    function_calls = []
-    for json_str in json_matches:
-        try:
-            obj = json.loads(json_str)
-            if "function" in obj:
-                function_calls.append(obj)
-        except json.JSONDecodeError:
-            continue
-
-    return function_calls
-
-
-def build_user_prompt(q):
-    """
-    Build a user prompt with the question and search tool definition.
-
-    Args:
-        q (str): The question to ask
-
-    Returns:
-        str: Formatted user prompt
-    """
-    user_prompt = f"""You are a research assistant, and you use the search_corpus tool to find answers to questions.
-Given a question, answer it using by doing searches using the search_corpus tool.
-To use the search_corpus tool, respond with a JSON for a function call with its proper arguments.
-
-PLEASE CONSIDER CHAT HISTORY WHEN ANSWERING THE QUESTION.
-ONLY ANSWER WHEN YOU HAVE 100% CONFIDENCE IN THE SEARCH RESULTS, ELSE CONTINUE SEARCHING.
-PLEASE SEARCH MULTIPLE TIMES WITH DIFFERENT QUERIES.
-
-You may also reason in any message, think step by step about how to answer the question. Wrap your reasoning in <think> and </think> tags.
-
-{json.dumps(SEARCH_TOOL_DEFINITION, indent=2)}
-
-Question: {q}
-"""
-    return user_prompt
-
-
-def format_search_results(results: Union[str, List[str]]) -> str:
-    """
-    Format search results for display.
-
-    Args:
-        results: Search results as string or list of strings
-
-    Returns:
-        Formatted search results
-    """
-    if isinstance(results, list):
-        content = "\n".join([f"Result {i + 1}:\n{r}\n------" for i, r in enumerate(results)])
-    else:
-        content = results
-
-    return f"\n===== SEARCH RESULTS =====\n{content}\n===========================\n"
-
-
 class DeepSearchCLI:
     """CLI for interacting with the model and search functionality."""
@@ -116,7 +53,7 @@ class DeepSearchCLI:
         self,
         model_path: str,
         temperature: float = 0.7,
-        system_prompt: Optional[str] = None,
+        system_prompt: str | None = None,
     ):
         """
         Initialize the CLI.
@@ -213,16 +150,16 @@ You are a helpful assistant with tool calling capabilities."""

         return final_response

     def _check_finished_chat(self, chat_state: dict) -> dict:
-        """Check if the chat is finished (no more function calls)."""
+        """Check if the chat is finished (no more search queries)."""
         if chat_state.get("finished"):
             return chat_state

         assert chat_state["messages"][-1]["role"] == "assistant", "Expected the last role to be assistant"

         assistant_response = chat_state["messages"][-1]["content"]
-        function_calls = extract_function_calls(assistant_response)
+        search_query = extract_search_query(assistant_response)

-        if len(function_calls) == 0:
+        if not search_query:
             chat_state["finished"] = True

         return chat_state
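
A quick sanity check of the finished test above, with a made-up final message; it reduces to one regex probe (the same semantics `extract_search_query` implements later in this commit):

```python
import re

# No <search> tag in the last assistant message -> the chat is marked finished:
last = "<think>I already know this.</think><answer>Paris</answer>"
finished = not re.search(r"<search>.*?</search>", last, re.DOTALL)
print(finished)  # True
```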
@@ -234,32 +171,26 @@ You are a helpful assistant with tool calling capabilities."""
         try:
             assistant_response = chat_state["messages"][-1]["content"]
-            function_calls = extract_function_calls(assistant_response)
+            search_query = extract_search_query(assistant_response)

-            if len(function_calls) > 1:
-                print("Multiple function calls found in assistant response")
-                raise ValueError("Expected only one function call in assistant response")
-            elif len(function_calls) == 1:
-                function_call = function_calls[0]
-                query = function_call["function"]["parameters"]["query"]
-                print(f"🔍 Search Query: {query}")
+            if search_query:
+                print(f"🔍 Search Query: {search_query}")

-                results = search(query, return_type=str, results=2)
+                results = search(search_query, return_type=str, results=2)
+                # Wrap results in <information> tags
+                formatted_results = f"<information>{results}</information>"

                 # Print search results to terminal
-                # logger.info("\n===== SEARCH RESULTS =====")
-                # logger.info(
-                #     results
-                # )
-                # logger.info("===========================\n")
+                print("\n===== SEARCH RESULTS =====")
+                print(results)
+                print("===========================\n")
+                # The results are already formatted with Result 1:, Result 2:, etc.

-                chat_state["messages"].append({"role": "ipython", "content": results})
+                chat_state["messages"].append({"role": "ipython", "content": formatted_results})

                 # Record search in history
                 search_entry = {
                     "turn": len(self.history) // 2,
-                    "searches": [{"query": query, "results": results}],
+                    "searches": [{"query": search_query, "results": results}],
                 }
                 self.search_history.append(search_entry)
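
The `<information>` wrapping introduced here mirrors the training-side helper later in this commit; the round-trip, with an invented result string, is just:

```python
# Sketch of the feedback message built above (result text is invented):
results = "Doc 1(Title: Document 1)\nParis is the capital of France."
formatted_results = f"<information>{results}</information>"
# The ipython-role message feeds the wrapped results back to the model next turn:
message = {"role": "ipython", "content": formatted_results}
print(message["content"])
```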
@@ -414,65 +345,6 @@ You are a helpful assistant with tool calling capabilities."""
             print(f"Error saving chat history: {e}")
             return None

-    def save_chat_history_json(self, filepath=None):
-        """
-        Save chat history to a JSON file.
-
-        Args:
-            filepath: Path to save file (if None, auto-generate based on timestamp)
-
-        Returns:
-            Path to the saved file
-        """
-        if not self.history:
-            print("No chat history to save.")
-            return None
-
-        # Generate a default filepath if none provided
-        if filepath is None:
-            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-            filepath = os.path.join(os.getcwd(), f"chat_history_{timestamp}.json")
-
-        # Ensure the directory exists
-        os.makedirs(os.path.dirname(filepath), exist_ok=True)
-
-        # Prepare chat history data
-        history_data = {
-            "model": self.model.name_or_path,
-            "temperature": self.temperature,
-            "timestamp": datetime.now().isoformat(),
-            "turns": [],
-        }
-
-        # Group history into conversation turns
-        for i in range(0, len(self.history), 2):
-            turn_number = i // 2
-            turn_data = {
-                "turn": turn_number + 1,
-                "user": self.history[i]["content"] if i < len(self.history) else "",
-                "searches": [],
-                "assistant": self.history[i + 1]["content"] if i + 1 < len(self.history) else "",
-            }
-
-            # Add searches for this turn
-            for search_entry in self.search_history:
-                if search_entry["turn"] == turn_number:
-                    turn_data["searches"].extend(search_entry["searches"])
-
-            history_data["turns"].append(turn_data)
-
-        # Write to file
-        try:
-            with open(filepath, "w", encoding="utf-8") as f:
-                json.dump(history_data, f, indent=2, ensure_ascii=False)
-            print(f"Chat history saved to JSON: {filepath}")
-            return filepath
-        except Exception as e:
-            print(f"Error saving chat history to JSON: {e}")
-            return None
-
     def display_help(self):
         """Display help information."""
         print("\n===== Commands =====")
@@ -481,7 +353,6 @@ You are a helpful assistant with tool calling capabilities."""
         print("clear - Clear conversation history")
         print("history - Display full chat history with searches")
         print("save - Save chat history to a text file")
-        print("savejson - Save chat history to a JSON file")
         print("help - Display this help message")
         print("exit/quit - Exit the program")
         print("Any other input will be treated as a prompt to the model.")
@@ -518,10 +389,6 @@ You are a helpful assistant with tool calling capabilities."""
                     self.save_chat_history()
                     continue

-                if user_input.lower() == "savejson":
-                    self.save_chat_history_json()
-                    continue
-
                 if user_input.lower().startswith("system "):
                     new_prompt = user_input[7:].strip()
                     self.set_system_prompt(new_prompt)
@@ -560,69 +427,6 @@ You are a helpful assistant with tool calling capabilities."""
                 print(f"Error: {e}")


-def extract_json_objects(text):
-    """
-    Extracts JSON objects (dictionaries) from a text that may contain multiple JSON objects.
-
-    Args:
-        text (str): The input text possibly containing JSON objects.
-
-    Returns:
-        list: A list of parsed JSON objects (dictionaries) extracted from the text.
-    """
-    results = []
-    length = len(text)
-    i = 0
-
-    while i < length:
-        # Look for the start of a JSON object
-        if text[i] == "{":
-            start = i
-            stack = 1
-            i += 1
-            # Continue until we find the matching closing brace
-            while i < length and stack > 0:
-                if text[i] == "{":
-                    stack += 1
-                elif text[i] == "}":
-                    stack -= 1
-                i += 1
-            # Only attempt to decode if the braces are balanced
-            if stack == 0:
-                candidate = text[start:i]
-                try:
-                    obj = json.loads(candidate)
-                    # Optionally, ensure it's a dictionary if that's what you expect
-                    if isinstance(obj, dict):
-                        results.append(obj)
-                except json.JSONDecodeError:
-                    # If it's not valid JSON, skip it.
-                    pass
-        else:
-            i += 1
-
-    return results
-
-
-# Tool definition for search corpus
-SEARCH_TOOL_DEFINITION = {
-    "type": "function",
-    "function": {
-        "name": "search_corpus",
-        "description": "Search over the knowledge corpus with a given query",
-        "parameters": {
-            "type": "object",
-            "properties": {
-                "query": {
-                    "type": "string",
-                    "description": "The query to search the knowledge corpus with",
-                },
-            },
-            "required": ["query"],
-        },
-    },
-}
-
-
 def main():
     """Main function."""
     parser = argparse.ArgumentParser(description="DeepSearch CLI")

@@ -100,21 +100,21 @@ def _init_logging(env: str = "development") -> None:
     file_format = "{time:YYYY-MM-DD at HH:mm:ss} | {level} | {name}:{function}:{line} - {message}"

-    # Add console logging
+    # Add console logging with DEBUG level
     logger.add(
         sys.stderr,
         format=console_format,
-        level="DEBUG" if env == "development" else "INFO",
+        level="DEBUG",  # Always use DEBUG level
         colorize=True,
         backtrace=True,
-        diagnose=env == "development",
+        diagnose=True,  # Always enable diagnostics
     )

-    # Add default file logging to ./logs directory
+    # Add default file logging to ./logs directory with DEBUG level
     logger.add(
         LOG_FOLDER / "app.log",
         format=file_format,
-        level="INFO",
+        level="DEBUG",  # Always use DEBUG level
         rotation="500 MB",
         retention="7 days",
         compression="zip",
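
Under these settings both sinks accept DEBUG records, so lines like the tool-call traces below reach the console and logs/app.log; a minimal loguru check (assuming loguru is installed):

```python
from loguru import logger

# With both sinks at DEBUG, this is no longer filtered outside development:
logger.debug("Added search results to chat state")
```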

@@ -5,7 +5,6 @@ and calculating rewards based on the quality of responses.
 """

 import inspect
-import json
 import re
 from dataclasses import dataclass
 from datetime import datetime
@@ -36,29 +35,9 @@ You are a helpful assistant with tool calling capabilities.
 """

-# Tool definition for search corpus
-SEARCH_TOOL_DEFINITION = {
-    "type": "function",
-    "function": {
-        "name": "search_corpus",
-        "description": "Search over the knowledge corpus with a given query",
-        "parameters": {
-            "type": "object",
-            "properties": {
-                "query": {
-                    "type": "string",
-                    "description": "The query to search the knowledge corpus with",
-                },
-            },
-            "required": ["query"],
-        },
-    },
-}
-

 def build_user_prompt(q):
     """
-    Build a user prompt with the question and search tool definition.
+    Build a user prompt with the question using the new template format.

     Args:
         q (str): The question to ask
@@ -66,17 +45,48 @@ def build_user_prompt(q):
     Returns:
         str: Formatted user prompt
     """
-    user_prompt = f"""You are a research assistant, and you use the search_corpus tool to find answers to questions.
-Given a question, answer it using by doing searches using the search_corpus tool.
-To use the search_corpus tool, respond with a JSON for a function call with its proper arguments.
-
-PLEASE CONSIDER CHAT HISTORY WHEN ANSWERING THE QUESTION.
-ONLY ANSWER WHEN YOU HAVE 100% CONFIDENCE IN THE SEARCH RESULTS, ELSE CONTINUE SEARCHING.
-PLEASE SEARCH MULTIPLE TIMES WITH DIFFERENT QUERIES.
-
-You may also reason in any message, think step by step about how to answer the question. Wrap your reasoning in <think> and </think> tags.
-
-{json.dumps(SEARCH_TOOL_DEFINITION, indent=2)}
-
-Question: {q}
-"""
+    user_prompt = f"""Answer the given question. \
+You must conduct reasoning inside <think> and </think> first every time you get new information. \
+After reasoning, if you find you lack some knowledge, you can call a search engine by <search> query </search>. \
+You can search as many times as you want. \
+If you find no further external knowledge needed, you can directly provide the answer inside <answer> and </answer>, without detailed illustrations. For example, <answer> Beijing </answer>.
+
+IMPORTANT INSTRUCTIONS:
+1. PLEASE CONSIDER CHAT HISTORY WHEN ANSWERING THE QUESTION.
+2. ONLY ANSWER WHEN YOU HAVE 100% CONFIDENCE IN THE SEARCH RESULTS, ELSE CONTINUE SEARCHING.
+3. PLEASE SEARCH MULTIPLE TIMES WITH DIFFERENT QUERIES.
+
+Question: {q}\n"""
     return user_prompt


+def format_search_results(results: str | list[str]) -> str:
+    """
+    Format search results for display, matching the format from infer.py.
+    Each result should be in the format: "Doc X(Title: Y) content"
+
+    Args:
+        results: Search results as string or list of strings
+
+    Returns:
+        Formatted search results with document titles
+    """
+    if isinstance(results, list):
+        # If results are already in the correct format, just join them
+        if any("Doc" in r and "Title:" in r for r in results):
+            content = "\n".join(results)
+        else:
+            # If results are raw content, format them with default titles
+            content = "\n".join([f"Doc {i + 1}(Title: Document {i + 1})\n{r}" for i, r in enumerate(results)])
+    else:
+        # If single result is already formatted, use it as is
+        if "Doc" in results and "Title:" in results:
+            content = results
+        else:
+            # If single result is raw content, format it with default title
+            content = f"Doc 1(Title: Document 1)\n{results}"
+
+    return content


 def get_initial_chat(question):
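
Both branches of format_search_results in one illustrative run (document texts invented; assumes the function above is in scope):

```python
# Raw strings get default "Doc N(Title: Document N)" headers:
raw = ["Paris is the capital of France.", "Berlin is the capital of Germany."]
print(format_search_results(raw))
# Doc 1(Title: Document 1)
# Paris is the capital of France.
# Doc 2(Title: Document 2)
# Berlin is the capital of Germany.

# Already-formatted results pass through unchanged:
print(format_search_results("Doc 7(Title: Geography) Paris is the capital of France."))
```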
@@ -97,49 +107,6 @@ def get_initial_chat(question):
     }


-def extract_json_objects(text):
-    """
-    Extracts JSON objects (dictionaries) from a text that may contain multiple JSON objects.
-
-    Args:
-        text (str): The input text possibly containing JSON objects.
-
-    Returns:
-        list: A list of parsed JSON objects (dictionaries) extracted from the text.
-    """
-    results = []
-    length = len(text)
-    i = 0
-
-    while i < length:
-        # Look for the start of a JSON object
-        if text[i] == "{":
-            start = i
-            stack = 1
-            i += 1
-            # Continue until we find the matching closing brace
-            while i < length and stack > 0:
-                if text[i] == "{":
-                    stack += 1
-                elif text[i] == "}":
-                    stack -= 1
-                i += 1
-            # Only attempt to decode if the braces are balanced
-            if stack == 0:
-                candidate = text[start:i]
-                try:
-                    obj = json.loads(candidate)
-                    # Optionally, ensure it's a dictionary if that's what you expect
-                    if isinstance(obj, dict):
-                        results.append(obj)
-                except json.JSONDecodeError:
-                    # If it's not valid JSON, skip it.
-                    pass
-        else:
-            i += 1
-
-    return results
-
-
 def remove_reasoning(text: str) -> str:
     """
     Removes all content between <think> and </think> tags,
@@ -196,7 +163,7 @@ def run_agent_generations(generate_fn, tokenizer, chat_states):

 def check_finished_chats(chat_states):
     """
-    Check which chat states are finished (no more function calls).
+    Check which chat states are finished (no more search queries).

     Args:
         chat_states: List of chat states
@@ -209,12 +176,27 @@ def check_finished_chats(chat_states):
             continue
         assert chat_state["messages"][-1]["role"] == "assistant", "Expected the last role to be assistant"
         assistant_response = chat_state["messages"][-1]["content"]
-        function_calls = extract_json_objects(assistant_response)
-        if len(function_calls) == 0:
+        # Check if there are any search queries in the response
+        if not re.search(r"<search>.*?</search>", assistant_response, re.DOTALL):
             chat_state["finished"] = True
     return chat_states


+def extract_search_query(text: str) -> str | None:
+    """
+    Extract search query from text between <search> tags.
+
+    Args:
+        text (str): Text containing search query
+
+    Returns:
+        str | None: Search query if found, None otherwise
+    """
+    pattern = re.compile(r"<search>(.*?)</search>", re.DOTALL)
+    matches = pattern.findall(text)
+    return matches[-1] if matches else None
+
+
 def run_tool_calls(chat_states):
     """
     Execute tool calls found in chat states.
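
Because extract_search_query takes matches[-1], a retry inside one message wins over earlier queries; a small check (assumes the function above is in scope, queries invented):

```python
text = "<think>Too broad, retry.</think><search>France</search><search>capital of France</search>"
assert extract_search_query(text) == "capital of France"
assert extract_search_query("<answer>Paris</answer>") is None
```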
@@ -231,21 +213,16 @@ def run_tool_calls(chat_states):
             )
         try:
             assistant_response = chat_state["messages"][-1]["content"]
-            function_calls = extract_json_objects(assistant_response)
-            if len(function_calls) > 1:
-                logger.warning("Multiple function calls found in assistant response")
-                raise ValueError("Expected only one function call in assistant response")
-            elif len(function_calls) == 1:
-                function_call = function_calls[0]
-                query = function_call["function"]["parameters"]["query"]
-                logger.info(f"🔍 Search Query: {query}")
-                results = search(query, return_type=str, results=2)
-                chat_state["messages"].append({"role": "ipython", "content": results})
-                # Count retries
-                retries = len(extract_json_objects(assistant_response))
-                total_retries += retries
+            search_query = extract_search_query(assistant_response)
+            if search_query:
+                logger.info(f"🔍 Search Query: {search_query}")
+                results = search(search_query, return_type=str, results=2)
+                # Wrap results in <information> tags
+                formatted_results = f"<information>{results}</information>"
+                logger.info(f" Information: {formatted_results}")
+                chat_state["messages"].append({"role": "ipython", "content": formatted_results})
+                total_retries += 1
                 logger.debug("Added search results to chat state")
         except Exception as e:
             logger.error(f"Error during tool call: {str(e)}")
@@ -598,7 +575,7 @@ def reward_formatting(prompts, completions, **reward_kwargs):

 def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[float]:
     """
     Reward function that encourages optimal retry behavior by only rewarding completions
-    where every assistant message contains at most 1 JSON object.
+    where every assistant message contains at most 1 search query.
     """
     rewards: list[float] = []
@@ -614,32 +591,31 @@ def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[float]:
             rewards.append(0.0)
             continue

-        # Check if every message has at most 1 JSON object
-        has_multiple_json = False
-        total_json_objects = 0
+        # Check if every message has at most 1 search query
+        has_multiple_searches = False
+        total_searches = 0

         for msg in assistant_msgs:
-            json_objects = extract_json_objects(msg)
-            json_count = len(json_objects)
-            total_json_objects += json_count
+            search_count = len(re.findall(r"<search>.*?</search>", msg, re.DOTALL))
+            total_searches += search_count

-            if json_count > 1:
-                has_multiple_json = True
-                logger.warning(f"Message contains {json_count} JSON objects, which exceeds the limit of 1")
+            if search_count > 1:
+                has_multiple_searches = True
+                logger.warning(f"Message contains {search_count} search queries, which exceeds the limit of 1")
                 break

-        # Only reward if no message has multiple JSON objects
-        if has_multiple_json:
+        # Only reward if no message has multiple search queries
+        if has_multiple_searches:
             rewards.append(0.0)
         else:
             # Base reward is 1.0 if constraint is met
             base_reward = 1.0

-            # Slight penalty for having too many total JSON objects across all messages
-            if total_json_objects > 4:
-                penalty = 0.1 * (total_json_objects - 4)
+            # Slight penalty for having too many total searches across all messages
+            if total_searches > 4:
+                penalty = 0.1 * (total_searches - 4)
                 base_reward = max(0.2, base_reward - penalty)
-                logger.debug(f"Applied penalty for {total_json_objects} total JSON objects: {penalty}")
+                logger.debug(f"Applied penalty for {total_searches} total searches: {penalty}")

             rewards.append(base_reward)
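
Worked through, the penalty rule above for a completion with six total searches and no per-message violation gives 0.8:

```python
total_searches = 6
base_reward = 1.0
if total_searches > 4:
    penalty = 0.1 * (total_searches - 4)           # 0.1 * 2 = 0.2
    base_reward = max(0.2, base_reward - penalty)  # capped below at 0.2
print(base_reward)  # 0.8
```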
@@ -647,10 +623,10 @@ def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[float]:
     log_metric("rewards/retry_behavior", np.mean(rewards), reward_kwargs.get("step", 0))
     log_metric("rewards/retry_behavior_std", np.std(rewards), reward_kwargs.get("step", 0))
     log_metric(
-        "metrics/avg_json_per_msg",
+        "metrics/avg_searches_per_msg",
         np.mean(
             [
-                len(extract_json_objects(msg["content"]))
+                len(re.findall(r"<search>.*?</search>", msg["content"], re.DOTALL))
                 for completion in completions
                 for msg in completion["messages"]
                 if msg["role"] == "assistant"
@@ -659,7 +635,7 @@ def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[float]:
         reward_kwargs.get("step", 0),
     )
     log_metric(
-        "metrics/multiple_json_violation_rate",
+        "metrics/multiple_search_violation_rate",
         np.mean([0.0 if rewards[i] > 0.0 else 1.0 for i in range(len(rewards))]),
         reward_kwargs.get("step", 0),
     )
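
The metric regex counts one hit per complete tag pair, so a two-query message scores 2 (illustrative message):

```python
import re

msg = "<search>louvre opening hours</search> and <search>louvre address</search>"
print(len(re.findall(r"<search>.*?</search>", msg, re.DOTALL)))  # 2
```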
