"""
|
|
RL helpers module for handling tool-based conversations.
|
|
This module provides utility functions for handling chat-based tool interactions
|
|
and calculating rewards based on the quality of responses.
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import re
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
import nest_asyncio
|
|
import numpy as np
|
|
import torch
|
|
from loguru import logger
|
|
|
|
from search_module import get_qa_dataset, search
|
|
|
|
# Setup loguru
|
|
log_dir = Path("logs")
|
|
log_dir.mkdir(exist_ok=True)
|
|
logger.add(
|
|
log_dir / "rl_helpers_{time}.log",
|
|
rotation="500 MB",
|
|
retention="10 days",
|
|
level="DEBUG",
|
|
format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}",
|
|
)
|
|
|
|
nest_asyncio.apply()
|
|
from typing import Callable, List
|
|
|
|
from trl.trainer.grpo_trainer import apply_chat_template
|
|
|
|
|
|
# Constants for prompts and tool definitions
def get_system_prompt():
    """Get the system prompt with current date."""
    current_date = datetime.now().strftime("%d %b %Y")
    return f"""Cutting Knowledge Date: December 2023
Today Date: {current_date}

When you receive a tool call response, use the output to format an answer to the original user question.

You are a helpful assistant with tool calling capabilities.
"""


# Tool definition for search corpus
SEARCH_TOOL_DEFINITION = {
    "type": "function",
    "function": {
        "name": "search_corpus",
        "description": "Search over the knowledge corpus with a given query",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "The query to search the knowledge corpus with",
                },
            },
            "required": ["query"],
        },
    },
}


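# Illustrative only (not part of the original prompts): the tool-call JSON the assistant is
# expected to emit, which run_tool_calls() below reads via
# function_call["function"]["parameters"]["query"]. The query value here is made up.
#
#   {
#       "function": {
#           "name": "search_corpus",
#           "parameters": {"query": "when was the Apollo 11 landing"}
#       }
#   }

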
def build_user_prompt(q):
    """
    Build a user prompt with the question and search tool definition.

    Args:
        q (str): The question to ask

    Returns:
        str: Formatted user prompt
    """
    user_prompt = f"""You are a research assistant, and you use the search_corpus tool to find answers to questions.
Given a question, answer it by doing searches using the search_corpus tool.
To use the search_corpus tool, respond with a JSON for a function call with its proper arguments.

You may also reason in any message; think step by step about how to answer the question. Wrap your reasoning in <think> and </think> tags.

{json.dumps(SEARCH_TOOL_DEFINITION, indent=2)}

Question: {q}
"""
    return user_prompt


def get_initial_chat(question):
    """
    Initialize a chat state with the question.

    Args:
        question (str): The question to ask

    Returns:
        dict: Initial chat state with system and user messages
    """
    return {
        "messages": [
            {"role": "system", "content": get_system_prompt()},
            {"role": "user", "content": build_user_prompt(question)},
        ]
    }


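# For reference, a chat state grows roughly like this over an agent run. The "messages" key
# is created here; "initial_length" and "finished" are added later by run_agent() and the
# check_* helpers. Values are illustrative.
#
#   {
#       "messages": [
#           {"role": "system", "content": "<system prompt>"},
#           {"role": "user", "content": "<user prompt with tool definition>"},
#           {"role": "assistant", "content": "<reasoning and/or tool-call JSON>"},
#           {"role": "ipython", "content": "<search results>"},
#       ],
#       "initial_length": 512,   # token count of the initial prompt
#       "finished": True,        # set once no tool call is emitted or the token budget is hit
#   }

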
def extract_json_objects(text):
    """
    Extracts JSON objects (dictionaries) from a text that may contain multiple JSON objects.

    Args:
        text (str): The input text possibly containing JSON objects.

    Returns:
        list: A list of parsed JSON objects (dictionaries) extracted from the text.
    """
    results = []
    length = len(text)
    i = 0

    while i < length:
        # Look for the start of a JSON object
        if text[i] == "{":
            start = i
            stack = 1
            i += 1
            # Continue until we find the matching closing brace
            while i < length and stack > 0:
                if text[i] == "{":
                    stack += 1
                elif text[i] == "}":
                    stack -= 1
                i += 1
            # Only attempt to decode if the braces are balanced
            if stack == 0:
                candidate = text[start:i]
                try:
                    obj = json.loads(candidate)
                    # Optionally, ensure it's a dictionary if that's what you expect
                    if isinstance(obj, dict):
                        results.append(obj)
                except json.JSONDecodeError:
                    # If it's not valid JSON, skip it.
                    pass
        else:
            i += 1
    return results


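# Example of the brace-matching extraction above (illustrative input):
#
#   text = 'Searching now. {"function": {"name": "search_corpus", "parameters": {"query": "apollo 11"}}}'
#   extract_json_objects(text)
#   # -> [{'function': {'name': 'search_corpus', 'parameters': {'query': 'apollo 11'}}}]
#
# Candidates with unbalanced braces, invalid JSON, or non-dict top-level values are skipped.

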
def remove_reasoning(text: str) -> str:
    """
    Removes all content between <think> and </think> tags,
    including the tags themselves.

    Parameters:
        text (str): The input text that may contain <think>...</think> tags.

    Returns:
        str: The text with the tags and their content removed.
    """
    # The regex pattern matches from <think> to </think> non-greedily.
    pattern = r"<think>.*?</think>"
    cleaned_text = re.sub(pattern, "", text, flags=re.DOTALL)
    return cleaned_text


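# Example (illustrative):
#
#   remove_reasoning("<think>Need to search the corpus first.</think>The answer is 42.")
#   # -> "The answer is 42."

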
def run_agent_generations(generate_fn, tokenizer, chat_states):
    """
    Run generation for chat states requiring assistant responses.
    """
    logger.debug(f"Starting generation for {len(chat_states)} chat states")
    prompts = []
    batch_indices = []
    # Prepare prompts for chat states needing an assistant response.
    for idx, chat_state in enumerate(chat_states):
        if chat_state.get("finished"):
            logger.debug(f"Chat state {idx} already finished, skipping")
            continue

        if chat_state["messages"][-1]["role"] in ["ipython", "user"]:
            prompt = apply_chat_template(chat_state, tokenizer=tokenizer)["text"]
            prompts.append(prompt)
            batch_indices.append(idx)
            logger.debug(f"Added prompt for chat state {idx}")

    if prompts:
        logger.info(f"Generating responses for {len(prompts)} prompts")
        responses = generate_fn(prompts)
        for i, idx in enumerate(batch_indices):
            chat_state = chat_states[idx]
            response = responses[i]
            if hasattr(response, "outputs"):
                full_response = response.outputs[0].text
            else:
                full_response = response
            assistant_response = full_response.split(
                "<|start_header_id|>assistant<|end_header_id|>"
            )[-1]
            chat_state["messages"].append(
                {"role": "assistant", "content": assistant_response}
            )
            logger.debug(f"Added assistant response to chat state {idx}")
    else:
        logger.debug("No prompts to generate responses for")
    return chat_states


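# Note on the generate_fn contract assumed above (and reused by check_student_answers below):
# it takes a list of prompt strings and returns either plain strings or vLLM-style
# RequestOutput objects exposing .outputs[0].text; both cases are handled by the
# hasattr(response, "outputs") check. A minimal stand-in for testing might look like:
#
#   def fake_generate(prompts):
#       return [
#           "<|start_header_id|>assistant<|end_header_id|>\n\nI don't know.<|eot_id|>"
#           for _ in prompts
#       ]

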
def check_finished_chats(chat_states):
    """
    Check which chat states are finished (no more function calls).

    Args:
        chat_states: List of chat states

    Returns:
        list: Updated chat states with finished flag
    """
    for chat_state in chat_states:
        if chat_state.get("finished"):
            continue
        assert (
            chat_state["messages"][-1]["role"] == "assistant"
        ), "Expected the last role to be assistant"
        assistant_response = chat_state["messages"][-1]["content"]
        function_calls = extract_json_objects(assistant_response)
        if len(function_calls) == 0:
            chat_state["finished"] = True
    return chat_states


def run_tool_calls(chat_states):
    """
    Execute tool calls found in chat states.
    """
    logger.debug(f"Running tool calls for {len(chat_states)} chat states")
    for chat_state in chat_states:
        if chat_state.get("finished"):
            logger.debug("Chat state already finished, skipping tool calls")
            continue
        assert (
            chat_state["messages"][-1]["role"] == "assistant"
        ), "Expected the last role to be assistant to run tool calls"
        try:
            assistant_response = chat_state["messages"][-1]["content"]
            function_calls = extract_json_objects(assistant_response)
            if len(function_calls) > 1:
                logger.warning("Multiple function calls found in assistant response")
                raise ValueError(
                    "Expected only one function call in assistant response"
                )
            elif len(function_calls) == 1:
                function_call = function_calls[0]
                query = function_call["function"]["parameters"]["query"]
                logger.info(f"Executing search with query: {query}")
                results = search(query, return_type=str, results=2)
                chat_state["messages"].append({"role": "ipython", "content": results})
                logger.debug("Added search results to chat state")
        except Exception as e:
            logger.error(f"Error during tool call: {str(e)}")
            chat_state["messages"].append(
                {"role": "system", "content": f"Error during post-processing: {str(e)}"}
            )
            chat_state["finished"] = True
    return chat_states


def get_mask(text, tokenizer):
    """
    Build a 0/1 token mask over `text` that marks assistant-generated tokens.

    Tokens between each `<|start_header_id|>assistant<|end_header_id|>` header and the
    following `<|eot_id|>` are set to 1; everything else is 0.
    """
    encoding = tokenizer(text, add_special_tokens=False)
    start_header_id = tokenizer.convert_tokens_to_ids("<|start_header_id|>")
    assistant_token = tokenizer.convert_tokens_to_ids("assistant")
    end_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>")
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    assistant_ranges = []
    i = 0
    while i < len(encoding.input_ids) - 1:
        if (
            encoding.input_ids[i] == start_header_id
            and encoding.input_ids[i + 1] == assistant_token
        ):
            i += 2
            while (
                i < len(encoding.input_ids) and encoding.input_ids[i] != end_header_id
            ):
                i += 1
            i += 2
            start_idx = i
            while i < len(encoding.input_ids) and encoding.input_ids[i] != eot_id:
                i += 1
            end_idx = i
            assistant_ranges.append((start_idx, end_idx))
        else:
            i += 1
    mask = [0] * len(encoding.input_ids)
    for start_idx, end_idx in assistant_ranges:
        for idx in range(start_idx, end_idx):
            mask[idx] = 1
    return torch.tensor(mask, dtype=torch.int)


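# Sketch of the Llama-3-style layout get_mask() walks (token boundaries depend on the
# tokenizer; this is only meant to show which span gets marked):
#
#   <|start_header_id|> assistant <|end_header_id|> \n\n  ...assistant tokens...  <|eot_id|>
#   |------------------- mask = 0 --------------------|   |----- mask = 1 -----|   mask = 0
#
# Only the tokens after the assistant header (plus one following token, typically the
# newline) and before the next <|eot_id|> are marked, so training signals can be limited
# to model-generated turns.

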
def check_exceeded_max_new_tokens(chat_states, max_new_tokens, tokenizer):
    """Mark chat states as finished once they have generated more than max_new_tokens new tokens."""
    for chat_state in chat_states:
        if chat_state.get("finished"):
            continue
        initial_length = chat_state["initial_length"]
        new_length = get_chat_num_tokens(chat_state, tokenizer)
        if new_length - initial_length > max_new_tokens:
            chat_state["finished"] = True
    return chat_states


@dataclass
class AgenticOutputs:
    prompt_tokens: list[torch.Tensor]
    response_tokens: list[torch.Tensor]
    response_masks: list[torch.Tensor]
    final_response_str: list[str]
    full_chat_states: list[dict]


def get_chat_num_tokens(chat_state, tokenizer):
    """Return the token count of the chat state after applying the chat template."""
    chat_text = apply_chat_template(chat_state, tokenizer=tokenizer)["text"]
    return (
        tokenizer(chat_text, add_special_tokens=False, return_tensors="pt")["input_ids"]
        .squeeze()
        .shape[0]
    )


def run_agent(
    generate_fn, tokenizer, questions, max_generations=5, max_new_tokens=4096
):
    """
    Run the agent to completion for a batch of questions.
    """
    logger.info(f"Starting agent run with {len(questions)} questions")
    logger.debug(
        f"Max generations: {max_generations}, Max new tokens: {max_new_tokens}"
    )

    chat_states = [get_initial_chat(q) for q in questions]
    # set the initial_prompt length
    for i, chat_state in enumerate(chat_states):
        chat_state["initial_length"] = get_chat_num_tokens(chat_state, tokenizer)
        logger.debug(f"Initial length for question {i}: {chat_state['initial_length']}")

    # agent loop
    for i in range(max_generations):
        logger.info(f"Starting generation step {i+1}/{max_generations}")
        chat_states = run_agent_generations(generate_fn, tokenizer, chat_states)
        chat_states = check_finished_chats(chat_states)
        chat_states = run_tool_calls(chat_states)
        chat_states = check_exceeded_max_new_tokens(
            chat_states, max_new_tokens, tokenizer
        )
        finished_count = sum(1 for state in chat_states if state.get("finished"))
        logger.info(
            f"Finished {finished_count}/{len(chat_states)} chat states after step {i+1}"
        )

    logger.info("Agent run completed")

    # Process final outputs
    logger.debug("Processing final outputs")
    answers = []
    for chat in chat_states:
        answers.append(chat["messages"][-1]["content"])
        logger.debug(f"Final answer: {chat['messages'][-1]['content'][:100]}...")

    def split_prompt_assistant(convo_text):
        marker = "<|start_header_id|>assistant<|end_header_id|>"
        idx = convo_text.find(marker)
        if idx == -1:
            logger.error("Could not find assistant marker in conversation text")
            raise ValueError("Could not find assistant marker in conversation text.")
        prompt = convo_text[: idx + len(marker)]
        assistant_response = convo_text[idx + len(marker) :]
        return prompt, assistant_response

    str_chats = [
        apply_chat_template(chat, tokenizer=tokenizer)["text"] for chat in chat_states
    ]
    prompt_toks, response_toks, response_masks = [], [], []

    logger.debug("Processing tokenization")
    for i, str_chat in enumerate(str_chats):
        prompt, response = split_prompt_assistant(str_chat)
        prompt_toks.append(
            tokenizer(prompt, add_special_tokens=False, return_tensors="pt")[
                "input_ids"
            ].squeeze()
        )
        response_toks.append(
            tokenizer(response, add_special_tokens=False, return_tensors="pt")[
                "input_ids"
            ].squeeze()[:max_new_tokens]
        )
        mask = get_mask(str_chat, tokenizer)[len(prompt_toks[-1]) :][:max_new_tokens]
        response_masks.append(mask)
        logger.debug(f"Processed tokens for chat {i}")

    final_response_str = [chat["messages"][-1]["content"] for chat in chat_states]
    full_chat_states = chat_states

    logger.info("Agent run completed successfully")
    return AgenticOutputs(
        prompt_tokens=prompt_toks,
        response_tokens=response_toks,
        response_masks=response_masks,
        final_response_str=final_response_str,
        full_chat_states=full_chat_states,
    )


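# Minimal usage sketch (names are illustrative; assumes a generate_fn as described near
# run_agent_generations and a tokenizer with a Llama-3-style chat template):
#
#   outputs = run_agent(generate_fn, tokenizer, ["Who wrote Dune?"], max_generations=3)
#   outputs.final_response_str[0]   # last assistant message for the first question
#   outputs.response_masks[0]       # 0/1 mask over that question's response tokens

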
# Verification
async def verify(student_answer: str, question: str, answer: str) -> bool:
    """
    Verify if student's answer matches the correct answer.

    Args:
        student_answer: The model's answer
        question: The original question
        answer: The ground truth answer

    Returns:
        bool: True if answer is correct, False otherwise
    """
    logger.debug(f"Verifying answer for question: {question}")
    logger.debug(f"Student answer: {student_answer}")
    logger.debug(f"Correct answer: {answer}")

    # Simple string matching for now
    # TODO: Implement more sophisticated matching
    return student_answer.strip().lower() == answer.strip().lower()


def check_student_answers(
    questions: List[str],
    answers: List[str],
    student_answers: List[str],
    vllm_generate_func: Callable[[List[str]], List[str]],
    tokenizer,
    log_file: str = "qa_log.txt",
) -> List[bool]:
    """
    Evaluates a list of student answers against the true answers using a vLLM generate function.
    """
    logger.info(f"Checking {len(questions)} student answers")

    if not (len(questions) == len(answers) == len(student_answers)):
        logger.error(
            "Mismatched lengths between questions, answers, and student answers"
        )
        raise ValueError(
            "The number of questions, answers, and student answers must be equal."
        )

    prompts = []
    for question, answer, student_ans in zip(questions, answers, student_answers):
        prompt_text = (
            "You are grading a student's answer. For the following question, "
            "compare the student's answer to the correct answer. Reply with 'Yes' if the student's answer is correct, or 'No' if it is completely incorrect.\n\n"
            f"Question: {question}\n"
            f"Correct Answer: {answer}\n"
            f"Student Answer: {student_ans}\n"
        )
        formatted_prompt = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt_text}],
            tokenize=False,
            add_generation_prompt=True,
        )
        prompts.append(formatted_prompt)
        logger.debug(f"Created verification prompt for question: {question[:50]}...")

    logger.info("Generating verification responses")
    responses = vllm_generate_func(prompts)
    responses_text = []
    for response in responses:
        if hasattr(response, "outputs"):
            responses_text.append(response.outputs[0].text)
        else:
            responses_text.append(response)
    logger.debug(f"Got {len(responses_text)} verification responses")

    results = []
    for response in responses_text:
        results.append("yes" in response.lower())
        logger.debug(f"Verification result: {'yes' in response.lower()}")

    logger.info(f"Verification complete. {sum(results)}/{len(results)} answers correct")

    # Append the QA details and verifier's response to the specified log file
    with open(log_file, "a") as file:
        for question, answer, student_ans, verifier_response in zip(
            questions, answers, student_answers, responses_text
        ):
            file.write("Question: " + question + "\n")
            file.write("Correct Answer: " + answer + "\n")
            file.write("Student Answer: " + student_ans + "\n")
            file.write("Verifier said: " + verifier_response + "\n")
            file.write("-" * 40 + "\n")

    return results


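# Usage sketch (illustrative): the grader model is asked to reply Yes/No, and any response
# containing "yes" (case-insensitive) counts as correct.
#
#   results = check_student_answers(
#       ["Who wrote Dune?"],
#       ["Frank Herbert"],
#       ["It was written by Frank Herbert."],
#       vllm_generate_func=generate_fn,
#       tokenizer=tokenizer,
#   )
#   # -> [True] if the grader replied "Yes"

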
# Reward Functions
def build_reward_correctness_fn(generate_fn, tokenizer):
    def reward_correctness(prompts, completions, **reward_kwargs):
        teacher_answers = reward_kwargs["answer"]
        student_answers = [
            completion["messages"][-1]["content"] for completion in completions
        ]

        correct = check_student_answers(
            prompts,
            teacher_answers,
            student_answers,
            vllm_generate_func=generate_fn,
            tokenizer=tokenizer,
        )
        return correct

    return reward_correctness


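# The closure returned above follows the (prompts, completions, **reward_kwargs) signature
# used by TRL-style reward functions, so it can be listed alongside reward_formatting,
# reward_retry_behavior, and reward_exact_match_chunk_query in a trainer's reward_funcs.
# Note it assumes completions are full chat states (dicts with "messages"), as produced by
# run_agent, rather than plain strings.

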
def reward_formatting(prompts, completions, **reward_kwargs):
    # make sure full chats don't have any error function calls
    has_error = [False] * len(completions)
    for i, chat in enumerate(completions):
        for message in chat["messages"]:
            if "Error during" in message["content"]:
                has_error[i] = True
                break
    return [0.7 if not e else 0 for e in has_error]


def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[float]:
    """
    Reward function that encourages optimal retry behavior by counting total function calls
    across all assistant messages in the conversation.
    """
    rewards: list[float] = []

    for completion in completions:
        # Get ALL assistant messages
        assistant_msgs: list[str] = [
            msg["content"]
            for msg in completion["messages"]
            if msg["role"] == "assistant" and msg["content"] is not None
        ]

        if not assistant_msgs:
            rewards.append(0.0)
            continue

        # Count total function calls across all messages
        total_retries: int = 0
        for msg in assistant_msgs:
            total_retries += len(extract_json_objects(msg))

        # Calculate reward using modified sigmoid function
        x: float = float(total_retries - 4)  # Center peak at 4 retries
        base_reward: float = 1.0 / (1.0 + np.exp(-x + abs(x) / 2))

        # Additional penalty for excessive retries
        if total_retries > 6:
            penalty: float = 0.2 * (total_retries - 6)
            base_reward = max(0.1, base_reward - penalty)

        rewards.append(base_reward)

    return rewards


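# Rough shape of the retry reward above, from 1 / (1 + exp(-x + |x|/2)) with
# x = total_retries - 4 (values approximate; the >6 penalty applies to the last column):
#
#   total_retries:   1      2      4      5      6      8
#   reward:        ~0.01  ~0.05   0.50  ~0.62  ~0.73  ~0.88 - 0.4 = ~0.48
#
# Very few tool calls score near zero, a handful of calls scores highest, and excessive
# calls are pushed back down by the penalty.

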
def reward_exact_match_chunk_query(prompts, completions, **reward_kwargs):
    """
    Reward function that checks if the model's search queries hit the correct chunk content.
    """
    logger.debug(f"Calculating rewards for {len(prompts)} prompts")

    # Get correct chunk contents from reward kwargs
    correct_contents = reward_kwargs.get("chunk_content", [])
    if not correct_contents:
        logger.error("No chunk_content provided in reward_kwargs")
        raise ValueError("chunk_content must be provided in reward_kwargs")

    rewards = []
    for i, (chat_state, correct_content) in enumerate(
        zip(completions, correct_contents)
    ):
        # Get all ipython messages (search results) from the chat
        search_results = [
            msg["content"] for msg in chat_state["messages"] if msg["role"] == "ipython"
        ]
        logger.debug(f"Found {len(search_results)} search results for prompt {i}")

        # Check if any search hit the correct chunk content
        found_correct_chunk = False
        for result in search_results:
            if correct_content in result:
                found_correct_chunk = True
                logger.debug(
                    f"Found correct chunk content in search results for prompt {i}"
                )
                break

        reward = 1.0 if found_correct_chunk else 0.0
        rewards.append(reward)
        logger.debug(f"Reward for prompt {i}: {reward}")

    logger.info(f"Average reward: {sum(rewards)/len(rewards):.3f}")
    return rewards


def run_eval(generate_fn, verify_fn, tokenizer):
    """Run the agent over the test split and report the average correctness reward."""
    logger.info("Starting evaluation")
    train_dataset, test_dataset = get_qa_dataset()
    questions = test_dataset["prompt"]
    logger.info(f"Loaded {len(questions)} test questions")

    agentic_outputs = run_agent(generate_fn, tokenizer, questions)
    full_chat_states = agentic_outputs.full_chat_states
    final_responses = agentic_outputs.final_response_str

    logger.info("Calculating rewards")
    rewards = verify_fn(questions, full_chat_states, answer=test_dataset["answer"])
    avg_reward = sum(rewards) / len(rewards)

    logger.info("EVALUATION RESULTS:")
    logger.info(f"Percentage of correct answers: {avg_reward:.3f}")
    logger.info("=" * 30)

    return full_chat_states