|
|
@ -5,7 +5,6 @@ and calculating rewards based on the quality of responses.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import inspect
|
|
|
|
import inspect
|
|
|
|
import json
|
|
|
|
|
|
|
|
import re
|
|
|
|
import re
|
|
|
|
from dataclasses import dataclass
|
|
|
|
from dataclasses import dataclass
|
|
|
|
from datetime import datetime
|
|
|
|
from datetime import datetime
|
|
|
@ -36,29 +35,9 @@ You are a helpful assistant with tool calling capabilities.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Function-calling schema for the corpus search tool.
# It declares a single tool, `search_corpus`, that takes one required string
# argument named `query`.
SEARCH_TOOL_DEFINITION = {
    "type": "function",
    "function": {
        "name": "search_corpus",
        "description": "Search over the knowledge corpus with a given query",
        "parameters": {
            "type": "object",
            "properties": {
                # The only argument the tool accepts.
                "query": {
                    "type": "string",
                    "description": "The query to search the knowledge corpus with",
                },
            },
            "required": ["query"],
        },
    },
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_user_prompt(q):
|
|
|
|
def build_user_prompt(q):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Build a user prompt with the question and search tool definition.
|
|
|
|
Build a user prompt with the question using the new template format.
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
q (str): The question to ask
|
|
|
|
q (str): The question to ask
|
|
|
@ -66,17 +45,48 @@ def build_user_prompt(q):
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
|
str: Formatted user prompt
|
|
|
|
str: Formatted user prompt
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
user_prompt = f"""You are a research assistant, and you use the search_corpus tool to find answers to questions.
|
|
|
|
user_prompt = f"""Answer the given question. \
|
|
|
|
Given a question, answer it using by doing searches using the search_corpus tool.
|
|
|
|
You must conduct reasoning inside <think> and </think> first every time you get new information. \
|
|
|
|
To use the search_corpus tool, respond with a JSON for a function call with its proper arguments.
|
|
|
|
After reasoning, if you find you lack some knowledge, you can call a search engine by <search> query </search>. \
|
|
|
|
|
|
|
|
You can search as many times as your want. \
|
|
|
|
|
|
|
|
If you find no further external knowledge needed, you can directly provide the answer inside <answer> and </answer>, without detailed illustrations. For example, <answer> Beijing </answer>.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
IMPORTANT INSTRUCTIONS:
|
|
|
|
|
|
|
|
1. PLEASE CONSIDER CHAT HISTORY WHEN ANSWERING THE QUESTION.
|
|
|
|
|
|
|
|
2. ONLY ANSWER WHEN YOU HAVE 100% CONFIDENCE IN THE SEARCH RESULTS, ELSE CONTINUE SEARCHING.
|
|
|
|
|
|
|
|
3. PLEASE SEARCH MULTIPLE TIMES WITH DIFFERENT QUERIES.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Question: {q}\n"""
|
|
|
|
|
|
|
|
return user_prompt
|
|
|
|
|
|
|
|
|
|
|
|
You may also reason in any message, think step by step about how to answer the question. Wrap your reasoning in <think> and </think> tags.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
{json.dumps(SEARCH_TOOL_DEFINITION, indent=2)}
|
|
|
|
def format_search_results(results: str | list[str]) -> str:
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
Format search results for display, matching the format from infer.py.
|
|
|
|
|
|
|
|
Each result should be in the format: "Doc X(Title: Y) content"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
|
|
results: Search results as string or list of strings
|
|
|
|
|
|
|
|
|
|
|
|
Question: {q}
|
|
|
|
Returns:
|
|
|
|
|
|
|
|
Formatted search results with document titles
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
return user_prompt
|
|
|
|
if isinstance(results, list):
|
|
|
|
|
|
|
|
# If results are already in the correct format, just join them
|
|
|
|
|
|
|
|
if any("Doc" in r and "Title:" in r for r in results):
|
|
|
|
|
|
|
|
content = "\n".join(results)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
# If results are raw content, format them with default titles
|
|
|
|
|
|
|
|
content = "\n".join([f"Doc {i + 1}(Title: Document {i + 1})\n{r}" for i, r in enumerate(results)])
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
# If single result is already formatted, use it as is
|
|
|
|
|
|
|
|
if "Doc" in results and "Title:" in results:
|
|
|
|
|
|
|
|
content = results
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
# If single result is raw content, format it with default title
|
|
|
|
|
|
|
|
content = f"Doc 1(Title: Document 1)\n{results}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return content
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_initial_chat(question):
|
|
|
|
def get_initial_chat(question):
|
|
|
@ -97,49 +107,6 @@ def get_initial_chat(question):
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_json_objects(text):
    """
    Extracts JSON objects (dictionaries) from a text that may contain multiple JSON objects.

    The text is scanned for "{" characters and a JSON decoder is started at
    each one. Letting the decoder find the end of the object (instead of
    counting braces by hand) means that braces inside string values do not
    confuse the scan, and that a stray, unmatched "{" only costs that one
    character rather than swallowing the rest of the text.

    Objects nested inside a successfully decoded object are not reported
    separately: scanning resumes after the end of the decoded object.

    Args:
        text (str): The input text possibly containing JSON objects.

    Returns:
        list: A list of parsed JSON objects (dictionaries) extracted from the
        text, in order of appearance. Empty if none are found.
    """
    decoder = json.JSONDecoder()
    results = []

    position = text.find("{")
    while position != -1:
        try:
            # raw_decode parses exactly one JSON value starting at `position`
            # and reports the index just past it.
            obj, end = decoder.raw_decode(text, position)
        except json.JSONDecodeError:
            # Not a valid object at this brace; try the next candidate brace.
            position = text.find("{", position + 1)
            continue

        # A value that starts with "{" decodes to a dict; keep the explicit
        # check so only dictionaries are ever returned.
        if isinstance(obj, dict):
            results.append(obj)
        position = text.find("{", end)

    return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def remove_reasoning(text: str) -> str:
|
|
|
|
def remove_reasoning(text: str) -> str:
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Removes all content between <think> and </think> tags,
|
|
|
|
Removes all content between <think> and </think> tags,
|
|
|
@ -196,7 +163,7 @@ def run_agent_generations(generate_fn, tokenizer, chat_states):
|
|
|
|
|
|
|
|
|
|
|
|
def check_finished_chats(chat_states):
|
|
|
|
def check_finished_chats(chat_states):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Check which chat states are finished (no more function calls).
|
|
|
|
Check which chat states are finished (no more search queries).
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
chat_states: List of chat states
|
|
|
|
chat_states: List of chat states
|
|
|
@ -209,12 +176,27 @@ def check_finished_chats(chat_states):
|
|
|
|
continue
|
|
|
|
continue
|
|
|
|
assert chat_state["messages"][-1]["role"] == "assistant", "Expected the last role to be assistant"
|
|
|
|
assert chat_state["messages"][-1]["role"] == "assistant", "Expected the last role to be assistant"
|
|
|
|
assistant_response = chat_state["messages"][-1]["content"]
|
|
|
|
assistant_response = chat_state["messages"][-1]["content"]
|
|
|
|
function_calls = extract_json_objects(assistant_response)
|
|
|
|
# Check if there are any search queries in the response
|
|
|
|
if len(function_calls) == 0:
|
|
|
|
if not re.search(r"<search>.*?</search>", assistant_response, re.DOTALL):
|
|
|
|
chat_state["finished"] = True
|
|
|
|
chat_state["finished"] = True
|
|
|
|
return chat_states
|
|
|
|
return chat_states
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_search_query(text: str) -> str | None:
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
Extract search query from text between <search> tags.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
|
|
text (str): Text containing search query
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
|
|
str | None: Search query if found, None otherwise
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
pattern = re.compile(r"<search>(.*?)</search>", re.DOTALL)
|
|
|
|
|
|
|
|
matches = pattern.findall(text)
|
|
|
|
|
|
|
|
return matches[-1] if matches else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_tool_calls(chat_states):
|
|
|
|
def run_tool_calls(chat_states):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Execute tool calls found in chat states.
|
|
|
|
Execute tool calls found in chat states.
|
|
|
@ -231,21 +213,16 @@ def run_tool_calls(chat_states):
|
|
|
|
)
|
|
|
|
)
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
assistant_response = chat_state["messages"][-1]["content"]
|
|
|
|
assistant_response = chat_state["messages"][-1]["content"]
|
|
|
|
function_calls = extract_json_objects(assistant_response)
|
|
|
|
search_query = extract_search_query(assistant_response)
|
|
|
|
if len(function_calls) > 1:
|
|
|
|
if search_query:
|
|
|
|
logger.warning("Multiple function calls found in assistant response")
|
|
|
|
logger.info(f"🔍 Search Query: {search_query}")
|
|
|
|
raise ValueError("Expected only one function call in assistant response")
|
|
|
|
results = search(search_query, return_type=str, results=2)
|
|
|
|
elif len(function_calls) == 1:
|
|
|
|
# Wrap results in <information> tags
|
|
|
|
function_call = function_calls[0]
|
|
|
|
formatted_results = f"<information>{results}</information>"
|
|
|
|
query = function_call["function"]["parameters"]["query"]
|
|
|
|
logger.info(f"ℹ️ Information: {formatted_results}")
|
|
|
|
logger.info(f"🔍 Search Query: {query}")
|
|
|
|
|
|
|
|
results = search(query, return_type=str, results=2)
|
|
|
|
chat_state["messages"].append({"role": "ipython", "content": formatted_results})
|
|
|
|
chat_state["messages"].append({"role": "ipython", "content": results})
|
|
|
|
total_retries += 1
|
|
|
|
|
|
|
|
|
|
|
|
# Count retries
|
|
|
|
|
|
|
|
retries = len(extract_json_objects(assistant_response))
|
|
|
|
|
|
|
|
total_retries += retries
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger.debug("Added search results to chat state")
|
|
|
|
logger.debug("Added search results to chat state")
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
logger.error(f"Error during tool call: {str(e)}")
|
|
|
|
logger.error(f"Error during tool call: {str(e)}")
|
|
|
@ -598,7 +575,7 @@ def reward_formatting(prompts, completions, **reward_kwargs):
|
|
|
|
def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[float]:
|
|
|
|
def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[float]:
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Reward function that encourages optimal retry behavior by only rewarding completions
|
|
|
|
Reward function that encourages optimal retry behavior by only rewarding completions
|
|
|
|
where every assistant message contains at most 1 JSON object.
|
|
|
|
where every assistant message contains at most 1 search query.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
rewards: list[float] = []
|
|
|
|
rewards: list[float] = []
|
|
|
|
|
|
|
|
|
|
|
@ -614,32 +591,31 @@ def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[floa
|
|
|
|
rewards.append(0.0)
|
|
|
|
rewards.append(0.0)
|
|
|
|
continue
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
|
|
# Check if every message has at most 1 JSON object
|
|
|
|
# Check if every message has at most 1 search query
|
|
|
|
has_multiple_json = False
|
|
|
|
has_multiple_searches = False
|
|
|
|
total_json_objects = 0
|
|
|
|
total_searches = 0
|
|
|
|
|
|
|
|
|
|
|
|
for msg in assistant_msgs:
|
|
|
|
for msg in assistant_msgs:
|
|
|
|
json_objects = extract_json_objects(msg)
|
|
|
|
search_count = len(re.findall(r"<search>.*?</search>", msg, re.DOTALL))
|
|
|
|
json_count = len(json_objects)
|
|
|
|
total_searches += search_count
|
|
|
|
total_json_objects += json_count
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if json_count > 1:
|
|
|
|
if search_count > 1:
|
|
|
|
has_multiple_json = True
|
|
|
|
has_multiple_searches = True
|
|
|
|
logger.warning(f"Message contains {json_count} JSON objects, which exceeds the limit of 1")
|
|
|
|
logger.warning(f"Message contains {search_count} search queries, which exceeds the limit of 1")
|
|
|
|
break
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
|
|
# Only reward if no message has multiple JSON objects
|
|
|
|
# Only reward if no message has multiple search queries
|
|
|
|
if has_multiple_json:
|
|
|
|
if has_multiple_searches:
|
|
|
|
rewards.append(0.0)
|
|
|
|
rewards.append(0.0)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
# Base reward is 1.0 if constraint is met
|
|
|
|
# Base reward is 1.0 if constraint is met
|
|
|
|
base_reward = 1.0
|
|
|
|
base_reward = 1.0
|
|
|
|
|
|
|
|
|
|
|
|
# Slight penalty for having too many total JSON objects across all messages
|
|
|
|
# Slight penalty for having too many total searches across all messages
|
|
|
|
if total_json_objects > 4:
|
|
|
|
if total_searches > 4:
|
|
|
|
penalty = 0.1 * (total_json_objects - 4)
|
|
|
|
penalty = 0.1 * (total_searches - 4)
|
|
|
|
base_reward = max(0.2, base_reward - penalty)
|
|
|
|
base_reward = max(0.2, base_reward - penalty)
|
|
|
|
logger.debug(f"Applied penalty for {total_json_objects} total JSON objects: {penalty}")
|
|
|
|
logger.debug(f"Applied penalty for {total_searches} total searches: {penalty}")
|
|
|
|
|
|
|
|
|
|
|
|
rewards.append(base_reward)
|
|
|
|
rewards.append(base_reward)
|
|
|
|
|
|
|
|
|
|
|
@ -647,10 +623,10 @@ def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[floa
|
|
|
|
log_metric("rewards/retry_behavior", np.mean(rewards), reward_kwargs.get("step", 0))
|
|
|
|
log_metric("rewards/retry_behavior", np.mean(rewards), reward_kwargs.get("step", 0))
|
|
|
|
log_metric("rewards/retry_behavior_std", np.std(rewards), reward_kwargs.get("step", 0))
|
|
|
|
log_metric("rewards/retry_behavior_std", np.std(rewards), reward_kwargs.get("step", 0))
|
|
|
|
log_metric(
|
|
|
|
log_metric(
|
|
|
|
"metrics/avg_json_per_msg",
|
|
|
|
"metrics/avg_searches_per_msg",
|
|
|
|
np.mean(
|
|
|
|
np.mean(
|
|
|
|
[
|
|
|
|
[
|
|
|
|
len(extract_json_objects(msg["content"]))
|
|
|
|
len(re.findall(r"<search>.*?</search>", msg["content"], re.DOTALL))
|
|
|
|
for completion in completions
|
|
|
|
for completion in completions
|
|
|
|
for msg in completion["messages"]
|
|
|
|
for msg in completion["messages"]
|
|
|
|
if msg["role"] == "assistant"
|
|
|
|
if msg["role"] == "assistant"
|
|
|
@ -659,7 +635,7 @@ def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[floa
|
|
|
|
reward_kwargs.get("step", 0),
|
|
|
|
reward_kwargs.get("step", 0),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
log_metric(
|
|
|
|
log_metric(
|
|
|
|
"metrics/multiple_json_violation_rate",
|
|
|
|
"metrics/multiple_search_violation_rate",
|
|
|
|
np.mean([0.0 if rewards[i] > 0.0 else 1.0 for i in range(len(rewards))]),
|
|
|
|
np.mean([0.0 if rewards[i] > 0.0 else 1.0 for i in range(len(rewards))]),
|
|
|
|
reward_kwargs.get("step", 0),
|
|
|
|
reward_kwargs.get("step", 0),
|
|
|
|
)
|
|
|
|
)
|
|
|
|