feat: refactor whole codebase, add logic for training R1-Distill base models, change some template and reward logic
- Break down rl_helpers into smaller modules
- Remove the deprecated rl_helpers module to streamline the codebase
- Enhance the initial user prompt template, inspired by Search-R1
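
Minimal usage sketch of the new layout (the model name, generate_fn stub, and wiring below are illustrative placeholders; only the import names and call signatures come from this diff):

    from transformers import AutoTokenizer

    from src import build_reward_correctness_fn, reward_em_chunk, reward_format, reward_retry, run_eval

    # Example tokenizer; a Llama or R1-Distill tokenizer selects the matching adapter inside run_eval
    tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")

    def generate_fn(prompts):
        """Placeholder for the vLLM generate function used during training."""
        return ["<think>...</think>\n<answer>stub</answer>" for _ in prompts]

    # Correctness reward wraps the same generate function and tokenizer
    reward_correctness = build_reward_correctness_fn(generate_fn, tokenizer)
    reward_fns = [reward_correctness, reward_format, reward_retry, reward_em_chunk]

    # run_eval calls verify_fn(questions, chat_states, answer=...) and returns the chat states
    chat_states = run_eval(generate_fn=generate_fn, verify_fn=reward_correctness, tokenizer=tokenizer)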
parent c90c03267e
commit 31dcbf5d8a
src/__init__.py
@@ -0,0 +1,43 @@
"""
Main package exports for RL helpers.
"""

from trl.trainer.grpo_trainer import apply_chat_template

from src.agent import Agent
from src.config import logger
from src.evaluation import check_student_answers, run_eval, verify
from src.prompts import build_user_prompt, format_search_results, get_system_prompt
from src.rewards import (
    build_reward_correctness_fn,
    reward_em_chunk,
    reward_format,
    reward_retry,
)
from src.search_module import get_qa_dataset, search
from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter

__all__ = [
    # Prompts
    "get_system_prompt",
    "build_user_prompt",
    "format_search_results",
    "apply_chat_template",
    # Agent
    "Agent",
    "LlamaTokenizerAdapter",
    "R1DistilTokenizerAdapter",
    # Rewards
    "build_reward_correctness_fn",
    "reward_format",
    "reward_retry",
    "reward_em_chunk",
    # Evaluation
    "run_eval",
    "check_student_answers",
    "verify",
    # Search
    "get_qa_dataset",
    "search",
    "logger",
]
src/evaluation.py
@@ -0,0 +1,238 @@
"""
Evaluation utilities for RL training.
"""

import inspect
from datetime import datetime

from src.agent import Agent
from src.config import logger
from src.search_module import get_qa_dataset
from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter


async def verify(student_answer: str, question: str, answer: str) -> bool:
    """
    Verify if student's answer matches the correct answer.

    Args:
        student_answer: The model's answer
        question: The original question
        answer: The ground truth answer

    Returns:
        bool: True if answer is correct, False otherwise
    """
    logger.debug(f"Verifying answer for question: {question}")
    logger.debug(f"Student answer: {student_answer}")
    logger.debug(f"Correct answer: {answer}")

    # Simple string matching for now
    # TODO: Implement more sophisticated matching
    return student_answer.strip().lower() == answer.strip().lower()


def check_student_answers(
    questions: list[str],
    answers: list[str],
    student_answers: list,  # Can be strings or dicts
    vllm_generate_func,
    tokenizer,
    log_file=None,
) -> list[bool]:
    """
    Evaluates a list of student answers against the true answers using a vLLM generate function.

    Args:
        questions: List of questions
        answers: List of correct answers
        student_answers: List of student answers to evaluate
        vllm_generate_func: Function to generate verification responses
        tokenizer: Tokenizer for formatting prompts
        log_file: Optional path to write detailed results

    Returns:
        List of boolean results (True for correct answers)
    """
    logger.info(f"Checking {len(questions)} student answers")

    if not (len(questions) == len(answers) == len(student_answers)):
        logger.error("Mismatched lengths between questions, answers, and student answers")
        raise ValueError("The number of questions, answers, and student answers must be equal.")

    prompts = []
    for question, answer, student_ans in zip(questions, answers, student_answers):
        prompt_text = (
            "You are grading a student's answer to a question. For the following question, "
            "compare the student's answer to the correct answer. Reply with 'Yes' if the student's answer contains the correct information, "
            "even if it's not an exact match. If the student's answer doesn't contain the right information or is completely incorrect, reply with 'No'.\n\n"
            f"Question: {question}\n"
            f"Correct Answer: {answer}\n"
            f"Student Answer: {student_ans}\n\n"
            "Your response should be just 'Yes' or 'No'."
        )

        formatted_prompt = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt_text}],
            tokenize=False,
            add_generation_prompt=True,
        )
        prompts.append(formatted_prompt)
        logger.debug(f"Created verification prompt for question: {question[:50]}...")

    logger.info("Generating verification responses")
    responses = vllm_generate_func(prompts)
    responses_text = []
    for response in responses:
        # Handle different response formats
        if hasattr(response, "outputs"):
            try:
                responses_text.append(response.outputs[0].text)
            except (AttributeError, IndexError):
                # Fallback for simple string responses
                responses_text.append(str(response))
        else:
            responses_text.append(str(response))
    logger.debug(f"Got {len(responses_text)} verification responses")

    results = []
    for response in responses_text:
        results.append("yes" in response.lower())
        logger.debug(f"Verification result: {'yes' in response.lower()}")

    logger.info(f"Verification complete. {sum(results)}/{len(results)} answers correct")

    # Append the QA details and verifier's response to the specified log file
    if log_file:
        with open(log_file, "a") as file:
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            file.write(f"\n📝 === QA Evaluation at {timestamp} ===\n")
            file.write(f"📂 File: {__file__}\n")

            # Get current frame info safely
            frame = inspect.currentframe()
            if frame:
                file.write(f"📍 Line: {frame.f_lineno}\n")
                # Don't forget to delete the frame to avoid reference cycles
                del frame

            file.write("=" * 80 + "\n")

            for i, (question, answer, student_ans, verifier_response) in enumerate(
                zip(questions, answers, student_answers, responses_text)
            ):
                file.write(f"\n❓ Question {i + 1}:\n")
                file.write("-" * 40 + "\n")
                file.write(f"📋 Question: {question}\n")
                file.write(f"✅ Correct Answer: {answer}\n")
                file.write(f"👨‍🎓 Student Answer: {student_ans}\n")
                file.write(f"🔍 Verifier said: {verifier_response}\n")

                # Add search results if available in the chat state
                if isinstance(student_ans, dict) and "messages" in student_ans:
                    # Get messages from dict
                    messages = student_ans.get("messages", [])
                    search_results = [msg.get("content", "") for msg in messages if msg.get("role") == "ipython"]
                    if search_results:
                        file.write("\n🔎 Search Results:\n")
                        for j, result in enumerate(search_results, 1):
                            file.write(f"\nSearch {j}:\n{result}\n")

                file.write("-" * 40 + "\n")

            file.write(
                f"\n📊 Summary: {sum(results)}/{len(results)} answers correct ({sum(results) / len(results) * 100:.2f}%)\n"
            )
            file.write("=" * 80 + "\n\n")

    return results


def run_eval(generate_fn, verify_fn, tokenizer, output_file=None, debug_file=None):
    """
    Run evaluation on the test dataset and return results.

    Args:
        generate_fn: Function to generate completions
        verify_fn: Function to verify results
        tokenizer: Tokenizer for processing text
        output_file: Path to save evaluation results summary
        debug_file: Path to save detailed debug information

    Returns:
        full_chat_states: The chat states from evaluation
    """
    train_dataset, test_dataset = get_qa_dataset()
    questions = test_dataset["prompt"]

    # Create agent with appropriate adapter based on tokenizer
    tokenizer_name = tokenizer.name_or_path.lower()
    if "deepseek-r1-distill" in tokenizer_name:
        adapter = R1DistilTokenizerAdapter()
    elif "llama" in tokenizer_name:
        adapter = LlamaTokenizerAdapter()
    else:
        adapter = R1DistilTokenizerAdapter()

    agent = Agent(adapter)
    agentic_outputs = agent.run_agent(generate_fn, tokenizer, questions)
    full_chat_states = agentic_outputs.full_chat_states
    final_responses = agentic_outputs.final_response_str
    rewards = verify_fn(questions, full_chat_states, answer=test_dataset["answer"])

    # Calculate results
    percent_correct = sum(rewards) / len(rewards) * 100

    # Log results to console
    logger.info("RESULTS:")
    logger.info(f"percentage of correct answers: {percent_correct:.2f}%")
    logger.info("=" * 30)

    # Save results to file if specified
    if output_file:
        try:
            with open(output_file, "w") as f:
                f.write("EVALUATION RESULTS\n")
                f.write("=================\n\n")
                f.write(f"Total questions: {len(questions)}\n")
                f.write(f"Correct answers: {sum(rewards)}\n")
                f.write(f"Percentage correct: {percent_correct:.2f}%\n\n")

                f.write("Individual results:\n")
                for i, (q, r, resp) in enumerate(zip(questions, rewards, final_responses)):
                    f.write(f"\nQ{i + 1}: {q[:100]}...\n")
                    f.write(f"Correct: {'✓' if r else '✗'}\n")
                    f.write(f"Response: {resp[:150]}...\n")
                    f.write("-" * 40 + "\n")
            logger.info(f"Saved evaluation results to {output_file}")
        except Exception as e:
            logger.error(f"Error saving results file: {e}")

    # Save debug information if specified
    if debug_file:
        try:
            import json

            debug_data = []
            for i, (q, r, resp, chat) in enumerate(zip(questions, rewards, final_responses, full_chat_states)):
                debug_data.append(
                    {
                        "question_id": i,
                        "question": q,
                        "is_correct": bool(r),
                        "final_response": resp,
                        "chat_state": {
                            k: str(v) if isinstance(v, (list, dict)) else v
                            for k, v in chat.items()
                            if k != "tokenizer"
                        },
                    }
                )

            with open(debug_file, "w") as f:
                json.dump(debug_data, f, indent=2)
            logger.info(f"Saved debug information to {debug_file}")
        except Exception as e:
            logger.error(f"Error saving debug file: {e}")

    return full_chat_states
src/prompts.py
@@ -0,0 +1,67 @@
"""
Prompt-related functions for handling system and user prompts.
"""

from datetime import datetime


def get_system_prompt():
    """Get the system prompt with current date."""
    current_date = datetime.now().strftime("%d %b %Y")
    return f"""Cutting Knowledge Date: December 2023
Today Date: {current_date}

You are a helpful assistant with search capabilities.
"""


def build_user_prompt(q):
    """
    Build a user prompt with the question using the new template format.

    Args:
        q (str): The question to ask

    Returns:
        str: Formatted user prompt
    """
    user_prompt = f"""Answer the given question. \
You must conduct reasoning inside <think> and </think> first every time you get new information. \
After reasoning, if you find you lack some knowledge, you can call a search engine by <search> query </search>. \
Based on the user's core intent, formulate the most effective search query using specific, descriptive keywords that differentiate the topic clearly. \
Aim for queries that resemble how an expert searcher might phrase it, like using "compare lithium-ion vs solid-state battery efficiency" rather than just "batteries". \
The document will be provided inside <information> and </information> tags to you later. \
You can search as many turns as you want, but only one search query per turn. \
If you find no further external knowledge needed, you can directly provide the answer inside <answer> and </answer>, without detailed illustrations. \
Only answer when you have 100% confidence in the search results, else continue searching. \
Question: {q}\n"""
    return user_prompt


def format_search_results(results: str | list[str]) -> str:
    """
    Format search results for display, matching the format from infer.py.
    Each result should be in the format: "Doc X(Title: Y) content"

    Args:
        results: Search results as string or list of strings

    Returns:
        Formatted search results with document titles
    """
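    # Illustrative example (not part of this diff): a raw single-string result such as
    #   format_search_results("Paris is the capital of France.")
    # comes back as
    #   "Doc 1(Title: Document 1)\nParis is the capital of France."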
    if isinstance(results, list):
        # If results are already in the correct format, just join them
        if any("Doc" in r and "Title:" in r for r in results):
            content = "\n".join(results)
        else:
            # If results are raw content, format them with default titles
            content = "\n".join([f"Doc {i + 1}(Title: Document {i + 1})\n{r}" for i, r in enumerate(results)])
    else:
        # If single result is already formatted, use it as is
        if "Doc" in results and "Title:" in results:
            content = results
        else:
            # If single result is raw content, format it with default title
            content = f"Doc 1(Title: Document 1)\n{results}"

    return content
src/rewards.py
@@ -0,0 +1,312 @@
"""
Reward functions for RL training.
"""

import re

import numpy as np

from src.config import logger
from src.evaluation import check_student_answers
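
# Note (inferred from how the functions below index into completions; not stated in this diff):
# each completion is a dict shaped like
#   {"messages": [{"role": "assistant", "content": "<think>...</think><search>...</search>"}, ...]}
# and reward_kwargs carries per-example data such as answer and chunk_content.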


def build_reward_correctness_fn(
    vllm_generate_func,
    tokenizer,
):
    """Build a reward function that checks correctness of student answers.

    Args:
        vllm_generate_func: Function to generate answers using vLLM
        tokenizer: Tokenizer for the model

    Returns:
        A reward function that takes prompts and completions and returns correctness scores
    """

    def reward_correctness(prompts: list, completions: list, **reward_kwargs) -> list:
        """Calculate reward based on correctness of student answers.

        Args:
            prompts: List of input prompts
            completions: List of model completions
            **reward_kwargs: Additional arguments for reward calculation

        Returns:
            List of correctness scores between 0 and 1
        """
        teacher_answers = reward_kwargs["answer"]
        student_answers = [completion["messages"][-1]["content"] for completion in completions]

        # Log non-exact matches
        for i, (student, teacher) in enumerate(zip(student_answers, teacher_answers)):
            if student.strip().lower() != teacher.strip().lower():
                logger.debug(f"Non-exact match at index {i}:\nStudent: {student}\nTeacher: {teacher}")

        correct = check_student_answers(
            prompts,
            teacher_answers,
            student_answers,
            vllm_generate_func=vllm_generate_func,
            tokenizer=tokenizer,
        )

        # Log correctness metrics with length info
        logger.info(f"Correctness metrics: {correct}")
        logger.info(f"Average correctness: {np.mean(correct):.2f}")
        logger.info(f"Standard deviation: {np.std(correct):.2f}")

        # Log length metrics
        student_lengths = [len(ans.strip()) for ans in student_answers]
        teacher_lengths = [len(ans.strip()) for ans in teacher_answers]
        logger.info(f"Student lengths: {student_lengths}")
        logger.info(f"Teacher lengths: {teacher_lengths}")
        logger.info(f"Average student length: {np.mean(student_lengths):.2f}")
        logger.info(f"Average teacher length: {np.mean(teacher_lengths):.2f}")
        logger.info(f"Length ratio: {np.mean(student_lengths) / np.mean(teacher_lengths):.2f}")

        return correct

    return reward_correctness


def reward_format(prompts: list, completions: list, **reward_kwargs) -> list:
    """Reward function that checks if the completion follows the required format with proper tags.

    Args:
        prompts: List of input prompts
        completions: List of completion dictionaries containing messages
        **reward_kwargs: Additional reward parameters

    Returns:
        list: List of rewards (1.0 for valid format, 0.0 for invalid)
    """
    # Regex patterns for each tag type - no markdown allowed
    think_pattern = r"<think>[\s\S]*?</think>"
    search_pattern = r"<search>[\s\S]*?</search>"
    answer_pattern = r"<answer>[\s\S]*?</answer>"

    # Information tag patterns - handle multiple variants
    info_patterns = [
        r"<information>[\s\S]*?</information>",  # Standard
        r"<info>[\s\S]*?</info>",  # Shortened
        r"<Info[\w]*>[\s\S]*?</Info[\w]*>",  # Capitalized variants
        r"<INFORMATION>[\s\S]*?</INFORMATION>",  # Uppercase
        r"<INFO>[\s\S]*?</INFO>",  # Uppercase shortened
    ]

    # Invalid patterns (bold/italic tags)
    invalid_patterns = [
        r"\*\*<\/?(?:think|search|answer|information|info)>\*\*",  # Bold tags
        r"\*<\/?(?:think|search|answer|information|info)>\*",  # Italic tags
        r"_<\/?(?:think|search|answer|information|info)>_",  # Underscore italic
    ]

    rewards = []

    for completion in completions:
        messages = completion.get("messages", [])
        assistant_msgs = [msg["content"] for msg in messages if msg["role"] == "assistant"]

        if not assistant_msgs:
            rewards.append(0.0)
            continue

        content = assistant_msgs[-1]  # Get the last assistant message

        # Check for invalid markdown formatting
        has_invalid_tags = any(re.search(pattern, content) for pattern in invalid_patterns)
        if has_invalid_tags:
            logger.debug("Found markdown-formatted tags in response")
            rewards.append(0.0)
            continue

        # Check for any information tag variants (should not exist in assistant messages)
        has_info_tags = False
        for pattern in info_patterns:
            info_matches = re.findall(pattern, content, re.IGNORECASE)
            if info_matches:
                logger.debug(f"Found {len(info_matches)} information tag(s) of type '{pattern}' in assistant message")
                has_info_tags = True
                break

        if has_info_tags:
            rewards.append(0.0)
            continue

        # Find all tag matches
        think_matches = re.findall(think_pattern, content)
        search_matches = re.findall(search_pattern, content)
        answer_matches = re.findall(answer_pattern, content)

        # Verify tag presence and count
        has_think = len(think_matches) >= 1
        has_answer = len(answer_matches) == 1  # Must have exactly one answer
        has_search = len(search_matches) >= 1  # One or more search tags

        # Check for search and answer in the same message (not allowed)
        if has_search and has_answer:
            logger.debug("Found both search and answer tags in the same message")
            rewards.append(0.0)
            continue

        # Award reward - must have think tag and either answer or search (but not both)
        reward = 1.0 if has_think and (has_answer or has_search) else 0.0
        rewards.append(reward)

        # Log issues for debugging
        if not reward:
            logger.debug(f"Format issues - think: {has_think}, answer: {has_answer}, search: {has_search}")
            if search_matches:
                logger.debug(f"Number of search tags: {len(search_matches)}")

    # Log overall metrics
    logger.info(f"Format reward metrics - Mean: {np.mean(rewards):.3f}, Valid formats: {sum(rewards)}/{len(rewards)}")

    return rewards


# TODO: Implement this reward function if the project survives
def reward_long_query(completions, **kwargs):
    """Reward function that checks if the query is long."""
    pass


def reward_retry(prompts: list, completions: list, **reward_kwargs) -> list:
    """
    Reward function that encourages optimal retry behavior.
    Rewards increase with more search attempts but cap at optimal_search_count.
    Penalizes having multiple searches in a single message.

    Args:
        prompts: List of input prompts
        completions: List of completion dictionaries with messages
        **reward_kwargs: Additional reward parameters (chunk_id, answer, etc.)

    Returns:
        List of rewards for each completion, rounded to 3 decimal places
    """
    rewards = []
    search_queries = []
    violations = []

    # Config for retry rewards
    optimal_search_count = 5  # Cap rewards at this many searches
    base_reward = 0.2  # Base reward for having at least one search
    increment = 0.15  # Reward increment per search attempt (0.2 + 5*0.15 = 0.95 max)
    violation_penalty = 0.5  # Penalty for having multiple searches in one message
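
    # Worked example (illustrative, using the constants above): 3 searches across separate
    # messages give 0.2 + 3*0.15 = 0.65; 7 searches cap at 0.2 + 5*0.15 = 0.95; any message
    # containing more than one search halves the reward (violation_penalty = 0.5).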

    # Regex pattern for search tags
    search_pattern = r"<search>[\s\S]*?</search>"

    for completion in completions:
        # Get assistant messages
        assistant_messages = [msg["content"] for msg in completion["messages"] if msg["role"] == "assistant"]

        # Count search tags in assistant messages
        message_searches = []
        for msg in assistant_messages:
            # Find all search tags in each message
            search_matches = re.findall(search_pattern, msg)
            message_searches.append(len(search_matches))

        # Record total search queries
        total_searches = sum(message_searches)
        search_queries.append(total_searches)

        # Check for violations (more than one search query per message)
        violation = any(count > 1 for count in message_searches)
        violations.append(violation)

        # Calculate reward
        if total_searches == 0:
            reward = 0.0  # No searches = no reward
        else:
            # Base reward for having at least one search
            reward = base_reward

            # Add incremental reward for each search up to optimal_search_count
            search_bonus = min(total_searches, optimal_search_count) * increment
            reward += search_bonus

            # Cap reward at 1.0
            reward = min(1.0, reward)

        # Apply penalty if there's a violation
        if violation:
            reward *= 1 - violation_penalty

        # Round to 3 decimal places to avoid floating point precision issues
        reward = round(reward, 3)

        rewards.append(reward)

    # Log metrics with search distribution info
    logger.info(f"Retry behavior rewards: {np.mean(rewards):.3f} ± {np.std(rewards):.3f}")
    logger.info(f"Search tags per completion: {np.mean(search_queries):.2f} ± {np.std(search_queries):.2f}")
    logger.info(f"Violations (>1 search per message): {sum(violations)}/{len(violations)}")
    logger.info(f"Search counts distribution: {search_queries}")

    return rewards


def reward_em_chunk(prompts: list, completions: list, **reward_kwargs) -> list:
    """Reward function that checks if model's search queries hit the correct chunk content.

    Args:
        prompts: List of input prompts
        completions: List of completion dictionaries with messages
        **reward_kwargs: Additional reward parameters including:
            - chunk_content: List of correct chunk contents to match against
            - step: Optional step number for logging metrics

    Returns:
        list: List of rewards (1.0 for exact match, 0.0 otherwise)

    Raises:
        ValueError: If chunk_content is not provided in reward_kwargs
    """
    logger.debug(f"Calculating rewards for {len(prompts)} prompts")

    # Get correct chunk contents from reward kwargs
    correct_contents = reward_kwargs.get("chunk_content", [])
    if not correct_contents:
        logger.error("No chunk_content provided in reward_kwargs")
        raise ValueError("chunk_content must be provided in reward_kwargs")

    rewards = []
    for i, (completion, correct_content) in enumerate(zip(completions, correct_contents)):
        # Get all messages from ipython or user roles that start with <information>
        search_results = [
            msg["content"]
            for msg in completion["messages"]
            if msg["role"] in ("ipython", "user") and msg["content"].strip().startswith("<information>")
        ]
        logger.debug(f"Found {len(search_results)} search results for prompt {i}")

        # Log ground truth and searched chunks for debugging
        logger.info(f"📝 Ground Truth Chunk: {correct_content}")
        for j, result in enumerate(search_results):
            logger.info(f"🔍 Searched Chunk {j + 1}: {result}")

        # Check if any search hit the correct chunk content
        found_correct_chunk = any(correct_content in result for result in search_results)

        if not found_correct_chunk:
            logger.warning(
                f"Failed to find correct chunk for prompt {i}:\n"
                f"Search results: {[r[:100] + '...' for r in search_results]}"
            )

        reward = 1.0 if found_correct_chunk else 0.0
        rewards.append(reward)
        logger.debug(f"Reward for prompt {i}: {reward}")

    # Log summary metrics
    logger.info("Chunk Query Rewards Summary:")
    logger.info(f"Total prompts: {len(prompts)}")
    logger.info(f"Correct matches: {sum(rewards)}")
    logger.info(f"Average reward: {np.mean(rewards):.3f}")
    logger.info(f"Reward std: {np.std(rewards):.3f}")

    return rewards