feat: refactor the whole codebase, add logic for training R1-Distill base models, change some template and reward logic
- Break down rl_helpers into smaller modules
- Remove the deprecated rl_helpers module to streamline the codebase
- Enhance the initial user prompt template, inspired by Search-R1
parent c90c03267e
commit 31dcbf5d8a
@@ -0,0 +1,43 @@
"""
Main package exports for RL helpers.
"""

from trl.trainer.grpo_trainer import apply_chat_template

from src.agent import Agent
from src.config import logger
from src.evaluation import check_student_answers, run_eval, verify
from src.prompts import build_user_prompt, format_search_results, get_system_prompt
from src.rewards import (
    build_reward_correctness_fn,
    reward_em_chunk,
    reward_format,
    reward_retry,
)
from src.search_module import get_qa_dataset, search
from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter

__all__ = [
    # Prompts
    "get_system_prompt",
    "build_user_prompt",
    "format_search_results",
    "apply_chat_template",
    # Agent
    "Agent",
    "LlamaTokenizerAdapter",
    "R1DistilTokenizerAdapter",
    # Rewards
    "build_reward_correctness_fn",
    "reward_format",
    "reward_retry",
    "reward_em_chunk",
    # Evaluation
    "run_eval",
    "check_student_answers",
    "verify",
    # Search
    "get_qa_dataset",
    "search",
    "logger",
]
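Reviewer note: a minimal sketch of how these exports are intended to compose for evaluation, not a definitive recipe. `my_generate_fn` (a vLLM-style batched generate function) and `my_tokenizer` (a Hugging Face tokenizer) are placeholders and not part of this commit.

    from src import build_reward_correctness_fn, run_eval

    # Placeholders: supply your own vLLM generate function and tokenizer here.
    verify_fn = build_reward_correctness_fn(my_generate_fn, my_tokenizer)
    chat_states = run_eval(
        generate_fn=my_generate_fn,
        verify_fn=verify_fn,
        tokenizer=my_tokenizer,
        output_file="eval_results.txt",
        debug_file="eval_debug.json",
    )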
@@ -0,0 +1,238 @@
"""
Evaluation utilities for RL training.
"""

import inspect
from datetime import datetime

from src.agent import Agent
from src.config import logger
from src.search_module import get_qa_dataset
from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter


async def verify(student_answer: str, question: str, answer: str) -> bool:
    """
    Verify if student's answer matches the correct answer.

    Args:
        student_answer: The model's answer
        question: The original question
        answer: The ground truth answer

    Returns:
        bool: True if answer is correct, False otherwise
    """
    logger.debug(f"Verifying answer for question: {question}")
    logger.debug(f"Student answer: {student_answer}")
    logger.debug(f"Correct answer: {answer}")

    # Simple string matching for now
    # TODO: Implement more sophisticated matching
    return student_answer.strip().lower() == answer.strip().lower()


def check_student_answers(
    questions: list[str],
    answers: list[str],
    student_answers: list,  # Can be strings or dicts
    vllm_generate_func,
    tokenizer,
    log_file=None,
) -> list[bool]:
    """
    Evaluates a list of student answers against the true answers using a vLLM generate function.

    Args:
        questions: List of questions
        answers: List of correct answers
        student_answers: List of student answers to evaluate
        vllm_generate_func: Function to generate verification responses
        tokenizer: Tokenizer for formatting prompts
        log_file: Optional path to write detailed results

    Returns:
        List of boolean results (True for correct answers)
    """
    logger.info(f"Checking {len(questions)} student answers")

    if not (len(questions) == len(answers) == len(student_answers)):
        logger.error("Mismatched lengths between questions, answers, and student answers")
        raise ValueError("The number of questions, answers, and student answers must be equal.")

    prompts = []
    for question, answer, student_ans in zip(questions, answers, student_answers):
        prompt_text = (
            "You are grading a student's answer to a question. For the following question, "
            "compare the student's answer to the correct answer. Reply with 'Yes' if the student's answer contains the correct information, "
            "even if it's not an exact match. If the student's answer doesn't contain the right information or is completely incorrect, reply with 'No'.\n\n"
            f"Question: {question}\n"
            f"Correct Answer: {answer}\n"
            f"Student Answer: {student_ans}\n\n"
            "Your response should be just 'Yes' or 'No'."
        )

        formatted_prompt = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt_text}],
            tokenize=False,
            add_generation_prompt=True,
        )
        prompts.append(formatted_prompt)
        logger.debug(f"Created verification prompt for question: {question[:50]}...")

    logger.info("Generating verification responses")
    responses = vllm_generate_func(prompts)
    responses_text = []
    for response in responses:
        # Handle different response formats
        if hasattr(response, "outputs"):
            try:
                responses_text.append(response.outputs[0].text)
            except (AttributeError, IndexError):
                # Fallback for simple string responses
                responses_text.append(str(response))
        else:
            responses_text.append(str(response))
    logger.debug(f"Got {len(responses_text)} verification responses")

    results = []
    for response in responses_text:
        results.append("yes" in response.lower())
        logger.debug(f"Verification result: {'yes' in response.lower()}")

    logger.info(f"Verification complete. {sum(results)}/{len(results)} answers correct")

    # Append the QA details and verifier's response to the specified log file
    if log_file:
        with open(log_file, "a") as file:
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            file.write(f"\n📝 === QA Evaluation at {timestamp} ===\n")
            file.write(f"📂 File: {__file__}\n")

            # Get current frame info safely
            frame = inspect.currentframe()
            if frame:
                file.write(f"📍 Line: {frame.f_lineno}\n")
                # Don't forget to delete the frame to avoid reference cycles
                del frame

            file.write("=" * 80 + "\n")

            for i, (question, answer, student_ans, verifier_response) in enumerate(
                zip(questions, answers, student_answers, responses_text)
            ):
                file.write(f"\n❓ Question {i + 1}:\n")
                file.write("-" * 40 + "\n")
                file.write(f"📋 Question: {question}\n")
                file.write(f"✅ Correct Answer: {answer}\n")
                file.write(f"👨‍🎓 Student Answer: {student_ans}\n")
                file.write(f"🔍 Verifier said: {verifier_response}\n")

                # Add search results if available in the chat state
                if isinstance(student_ans, dict) and "messages" in student_ans:
                    # Get messages from dict
                    messages = student_ans.get("messages", [])
                    search_results = [msg.get("content", "") for msg in messages if msg.get("role") == "ipython"]
                    if search_results:
                        file.write("\n🔎 Search Results:\n")
                        for j, result in enumerate(search_results, 1):
                            file.write(f"\nSearch {j}:\n{result}\n")

                file.write("-" * 40 + "\n")

            file.write(
                f"\n📊 Summary: {sum(results)}/{len(results)} answers correct ({sum(results) / len(results) * 100:.2f}%)\n"
            )
            file.write("=" * 80 + "\n\n")

    return results


def run_eval(generate_fn, verify_fn, tokenizer, output_file=None, debug_file=None):
    """
    Run evaluation on the test dataset and return results.

    Args:
        generate_fn: Function to generate completions
        verify_fn: Function to verify results
        tokenizer: Tokenizer for processing text
        output_file: Path to save evaluation results summary
        debug_file: Path to save detailed debug information

    Returns:
        full_chat_states: The chat states from evaluation
    """
    train_dataset, test_dataset = get_qa_dataset()
    questions = test_dataset["prompt"]

    # Create agent with appropriate adapter based on tokenizer
    tokenizer_name = tokenizer.name_or_path.lower()
    if "deepseek-r1-distill" in tokenizer_name:
        adapter = R1DistilTokenizerAdapter()
    elif "llama" in tokenizer_name:
        adapter = LlamaTokenizerAdapter()
    else:
        adapter = R1DistilTokenizerAdapter()

    agent = Agent(adapter)
    agentic_outputs = agent.run_agent(generate_fn, tokenizer, questions)
    full_chat_states = agentic_outputs.full_chat_states
    final_responses = agentic_outputs.final_response_str
    rewards = verify_fn(questions, full_chat_states, answer=test_dataset["answer"])

    # Calculate results
    percent_correct = sum(rewards) / len(rewards) * 100

    # Log results to console
    logger.info("RESULTS:")
    logger.info(f"Percentage of correct answers: {percent_correct:.2f}%")
    logger.info("=" * 30)

    # Save results to file if specified
    if output_file:
        try:
            with open(output_file, "w") as f:
                f.write("EVALUATION RESULTS\n")
                f.write("=================\n\n")
                f.write(f"Total questions: {len(questions)}\n")
                f.write(f"Correct answers: {sum(rewards)}\n")
                f.write(f"Percentage correct: {percent_correct:.2f}%\n\n")

                f.write("Individual results:\n")
                for i, (q, r, resp) in enumerate(zip(questions, rewards, final_responses)):
                    f.write(f"\nQ{i + 1}: {q[:100]}...\n")
                    f.write(f"Correct: {'✓' if r else '✗'}\n")
                    f.write(f"Response: {resp[:150]}...\n")
                    f.write("-" * 40 + "\n")
            logger.info(f"Saved evaluation results to {output_file}")
        except Exception as e:
            logger.error(f"Error saving results file: {e}")

    # Save debug information if specified
    if debug_file:
        try:
            import json

            debug_data = []
            for i, (q, r, resp, chat) in enumerate(zip(questions, rewards, final_responses, full_chat_states)):
                debug_data.append(
                    {
                        "question_id": i,
                        "question": q,
                        "is_correct": bool(r),
                        "final_response": resp,
                        "chat_state": {
                            k: str(v) if isinstance(v, (list, dict)) else v
                            for k, v in chat.items()
                            if k != "tokenizer"
                        },
                    }
                )

            with open(debug_file, "w") as f:
                json.dump(debug_data, f, indent=2)
            logger.info(f"Saved debug information to {debug_file}")
        except Exception as e:
            logger.error(f"Error saving debug file: {e}")

    return full_chat_states
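Reviewer note: a self-contained sketch of `check_student_answers` with a stubbed generate function and tokenizer, assuming the `src` package is importable. The real caller passes a vLLM generate function and a Hugging Face tokenizer; the stubs only exercise the plain-string response path.

    from src.evaluation import check_student_answers

    class StubTokenizer:
        def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True):
            return messages[0]["content"]  # pass the grading prompt through unchanged

    def stub_generate(prompts):
        return ["Yes"] * len(prompts)  # pretend the verifier always agrees

    results = check_student_answers(
        questions=["What is the capital of France?"],
        answers=["Paris"],
        student_answers=["Paris is the capital."],
        vllm_generate_func=stub_generate,
        tokenizer=StubTokenizer(),
    )
    print(results)  # [True]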
@@ -0,0 +1,67 @@
"""
Prompt-related functions for handling system and user prompts.
"""

from datetime import datetime


def get_system_prompt():
    """Get the system prompt with current date."""
    current_date = datetime.now().strftime("%d %b %Y")
    return f"""Cutting Knowledge Date: December 2023
Today Date: {current_date}

You are a helpful assistant with search capabilities.
"""


def build_user_prompt(q):
    """
    Build a user prompt with the question using the new template format.

    Args:
        q (str): The question to ask

    Returns:
        str: Formatted user prompt
    """
    user_prompt = f"""Answer the given question. \
You must conduct reasoning inside <think> and </think> first every time you get new information. \
After reasoning, if you find you lack some knowledge, you can call a search engine by <search> query </search>. \
Based on the user's core intent, formulate the most effective search query using specific, descriptive keywords that differentiate the topic clearly. \
Aim for queries that resemble how an expert searcher might phrase it, like using "compare lithium-ion vs solid-state battery efficiency" rather than just "batteries". \
The document will be provided inside <information> and </information> tags to you later. \
You can search as many turns as you want, but only one search query per turn. \
If you find no further external knowledge needed, you can directly provide the answer inside <answer> and </answer>, without detailed illustrations. \
Only answer when you have 100% confidence in the search results, else continue searching. \
Question: {q}\n"""
    return user_prompt


def format_search_results(results: str | list[str]) -> str:
    """
    Format search results for display, matching the format from infer.py.
    Each result should be in the format: "Doc X(Title: Y) content"

    Args:
        results: Search results as string or list of strings

    Returns:
        Formatted search results with document titles
    """
    if isinstance(results, list):
        # If results are already in the correct format, just join them
        if any("Doc" in r and "Title:" in r for r in results):
            content = "\n".join(results)
        else:
            # If results are raw content, format them with default titles
            content = "\n".join([f"Doc {i + 1}(Title: Document {i + 1})\n{r}" for i, r in enumerate(results)])
    else:
        # If single result is already formatted, use it as is
        if "Doc" in results and "Title:" in results:
            content = results
        else:
            # If single result is raw content, format it with default title
            content = f"Doc 1(Title: Document 1)\n{results}"

    return content
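Reviewer note: a short, runnable illustration of the prompt helpers above (assuming the `src` package is importable). A raw chunk gets a default title, while strings already in the "Doc X(Title: Y)" form pass through unchanged.

    from src.prompts import build_user_prompt, format_search_results, get_system_prompt

    system_msg = get_system_prompt()
    user_msg = build_user_prompt("What is the capital of France?")
    print(format_search_results(["Paris is the capital and most populous city of France."]))
    # Doc 1(Title: Document 1)
    # Paris is the capital and most populous city of France.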
@@ -0,0 +1,312 @@
"""
Reward functions for RL training.
"""

import re

import numpy as np

from src.config import logger
from src.evaluation import check_student_answers


def build_reward_correctness_fn(
    vllm_generate_func,
    tokenizer,
):
    """Build a reward function that checks correctness of student answers.

    Args:
        vllm_generate_func: Function to generate answers using vLLM
        tokenizer: Tokenizer for the model

    Returns:
        A reward function that takes prompts and completions and returns correctness scores
    """

    def reward_correctness(prompts: list, completions: list, **reward_kwargs) -> list:
        """Calculate reward based on correctness of student answers.

        Args:
            prompts: List of input prompts
            completions: List of model completions
            **reward_kwargs: Additional arguments for reward calculation

        Returns:
            List of correctness scores between 0 and 1
        """
        teacher_answers = reward_kwargs["answer"]
        student_answers = [completion["messages"][-1]["content"] for completion in completions]

        # Log non-exact matches
        for i, (student, teacher) in enumerate(zip(student_answers, teacher_answers)):
            if student.strip().lower() != teacher.strip().lower():
                logger.debug(f"Non-exact match at index {i}:\nStudent: {student}\nTeacher: {teacher}")

        correct = check_student_answers(
            prompts,
            teacher_answers,
            student_answers,
            vllm_generate_func=vllm_generate_func,
            tokenizer=tokenizer,
        )

        # Log correctness metrics with length info
        logger.info(f"Correctness metrics: {correct}")
        logger.info(f"Average correctness: {np.mean(correct):.2f}")
        logger.info(f"Standard deviation: {np.std(correct):.2f}")

        # Log length metrics
        student_lengths = [len(ans.strip()) for ans in student_answers]
        teacher_lengths = [len(ans.strip()) for ans in teacher_answers]
        logger.info(f"Student lengths: {student_lengths}")
        logger.info(f"Teacher lengths: {teacher_lengths}")
        logger.info(f"Average student length: {np.mean(student_lengths):.2f}")
        logger.info(f"Average teacher length: {np.mean(teacher_lengths):.2f}")
        logger.info(f"Length ratio: {np.mean(student_lengths) / np.mean(teacher_lengths):.2f}")

        return correct

    return reward_correctness


def reward_format(prompts: list, completions: list, **reward_kwargs) -> list:
    """Reward function that checks if the completion follows the required format with proper tags.

    Args:
        prompts: List of input prompts
        completions: List of completion dictionaries containing messages
        **reward_kwargs: Additional reward parameters

    Returns:
        list: List of rewards (1.0 for valid format, 0.0 for invalid)
    """
    # Regex patterns for each tag type - no markdown allowed
    think_pattern = r"<think>[\s\S]*?</think>"
    search_pattern = r"<search>[\s\S]*?</search>"
    answer_pattern = r"<answer>[\s\S]*?</answer>"

    # Information tag patterns - handle multiple variants
    info_patterns = [
        r"<information>[\s\S]*?</information>",  # Standard
        r"<info>[\s\S]*?</info>",  # Shortened
        r"<Info[\w]*>[\s\S]*?</Info[\w]*>",  # Capitalized variants
        r"<INFORMATION>[\s\S]*?</INFORMATION>",  # Uppercase
        r"<INFO>[\s\S]*?</INFO>",  # Uppercase shortened
    ]

    # Invalid patterns (bold/italic tags)
    invalid_patterns = [
        r"\*\*<\/?(?:think|search|answer|information|info)>\*\*",  # Bold tags
        r"\*<\/?(?:think|search|answer|information|info)>\*",  # Italic tags
        r"_<\/?(?:think|search|answer|information|info)>_",  # Underscore italic
    ]

    rewards = []

    for completion in completions:
        messages = completion.get("messages", [])
        assistant_msgs = [msg["content"] for msg in messages if msg["role"] == "assistant"]

        if not assistant_msgs:
            rewards.append(0.0)
            continue

        content = assistant_msgs[-1]  # Get the last assistant message

        # Check for invalid markdown formatting
        has_invalid_tags = any(re.search(pattern, content) for pattern in invalid_patterns)
        if has_invalid_tags:
            logger.debug("Found markdown-formatted tags in response")
            rewards.append(0.0)
            continue

        # Check for any information tag variants (should not exist in assistant messages)
        has_info_tags = False
        for pattern in info_patterns:
            info_matches = re.findall(pattern, content, re.IGNORECASE)
            if info_matches:
                logger.debug(f"Found {len(info_matches)} information tag(s) of type '{pattern}' in assistant message")
                has_info_tags = True
                break

        if has_info_tags:
            rewards.append(0.0)
            continue

        # Find all tag matches
        think_matches = re.findall(think_pattern, content)
        search_matches = re.findall(search_pattern, content)
        answer_matches = re.findall(answer_pattern, content)

        # Verify tag presence and count
        has_think = len(think_matches) >= 1
        has_answer = len(answer_matches) == 1  # Must have exactly one answer
        has_search = len(search_matches) >= 1  # One or more search tags

        # Check for search and answer in the same message (not allowed)
        if has_search and has_answer:
            logger.debug("Found both search and answer tags in the same message")
            rewards.append(0.0)
            continue

        # Award reward - must have think tag and either answer or search (but not both)
        reward = 1.0 if has_think and (has_answer or has_search) else 0.0
        rewards.append(reward)

        # Log issues for debugging
        if not reward:
            logger.debug(f"Format issues - think: {has_think}, answer: {has_answer}, search: {has_search}")
            if search_matches:
                logger.debug(f"Number of search tags: {len(search_matches)}")

    # Log overall metrics
    logger.info(f"Format reward metrics - Mean: {np.mean(rewards):.3f}, Valid formats: {sum(rewards)}/{len(rewards)}")

    return rewards


# TODO: Implement this reward function if the project survives
def reward_long_query(completions, **kwargs):
    """Reward function that checks if the query is long."""
    pass


def reward_retry(prompts: list, completions: list, **reward_kwargs) -> list:
    """
    Reward function that encourages optimal retry behavior.
    Rewards increase with more search attempts but cap at optimal_search_count.
    Penalizes having multiple searches in a single message.

    Args:
        prompts: List of input prompts
        completions: List of completion dictionaries with messages
        **reward_kwargs: Additional reward parameters (chunk_id, answer, etc.)

    Returns:
        List of rewards for each completion, rounded to 3 decimal places
    """
    rewards = []
    search_queries = []
    violations = []

    # Config for retry rewards
    optimal_search_count = 5  # Cap rewards at this many searches
    base_reward = 0.2  # Base reward for having at least one search
    increment = 0.15  # Reward increment per search attempt (0.2 + 5 * 0.15 = 0.95 max)
    violation_penalty = 0.5  # Penalty for having multiple searches in one message

    # Regex pattern for search tags
    search_pattern = r"<search>[\s\S]*?</search>"

    for completion in completions:
        # Get assistant messages
        assistant_messages = [msg["content"] for msg in completion["messages"] if msg["role"] == "assistant"]

        # Count search tags in assistant messages
        message_searches = []
        for msg in assistant_messages:
            # Find all search tags in each message
            search_matches = re.findall(search_pattern, msg)
            message_searches.append(len(search_matches))

        # Record total search queries
        total_searches = sum(message_searches)
        search_queries.append(total_searches)

        # Check for violations (more than one search query per message)
        violation = any(count > 1 for count in message_searches)
        violations.append(violation)

        # Calculate reward
        if total_searches == 0:
            reward = 0.0  # No searches = no reward
        else:
            # Base reward for having at least one search
            reward = base_reward

            # Add incremental reward for each search up to optimal_search_count
            search_bonus = min(total_searches, optimal_search_count) * increment
            reward += search_bonus

            # Cap reward at 1.0
            reward = min(1.0, reward)

            # Apply penalty if there's a violation
            if violation:
                reward *= 1 - violation_penalty

            # Round to 3 decimal places to avoid floating point precision issues
            reward = round(reward, 3)

        rewards.append(reward)

    # Log metrics with search distribution info
    logger.info(f"Retry behavior rewards: {np.mean(rewards):.3f} ± {np.std(rewards):.3f}")
    logger.info(f"Search tags per completion: {np.mean(search_queries):.2f} ± {np.std(search_queries):.2f}")
    logger.info(f"Violations (>1 search per message): {sum(violations)}/{len(violations)}")
    logger.info(f"Search counts distribution: {search_queries}")

    return rewards
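
# Worked example of the retry reward above (reviewer note, illustrative only; derived from the
# constants base_reward=0.2, increment=0.15, optimal_search_count=5, violation_penalty=0.5):
#   1 search, no violation:    0.2 + 1 * 0.15 = 0.35
#   3 searches, no violation:  0.2 + 3 * 0.15 = 0.65
#   7 searches, no violation:  min(1.0, 0.2 + 5 * 0.15) = 0.95 (bonus capped at 5 searches)
#   7 searches, 2 in one turn: 0.95 * (1 - 0.5) = 0.475 (violation penalty applied)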


def reward_em_chunk(prompts: list, completions: list, **reward_kwargs) -> list:
    """Reward function that checks if model's search queries hit the correct chunk content.

    Args:
        prompts: List of input prompts
        completions: List of completion dictionaries with messages
        **reward_kwargs: Additional reward parameters including:
            - chunk_content: List of correct chunk contents to match against
            - step: Optional step number for logging metrics

    Returns:
        list: List of rewards (1.0 for exact match, 0.0 otherwise)

    Raises:
        ValueError: If chunk_content is not provided in reward_kwargs
    """
    logger.debug(f"Calculating rewards for {len(prompts)} prompts")

    # Get correct chunk contents from reward kwargs
    correct_contents = reward_kwargs.get("chunk_content", [])
    if not correct_contents:
        logger.error("No chunk_content provided in reward_kwargs")
        raise ValueError("chunk_content must be provided in reward_kwargs")

    rewards = []
    for i, (completion, correct_content) in enumerate(zip(completions, correct_contents)):
        # Get all messages from ipython or user roles that start with <information>
        search_results = [
            msg["content"]
            for msg in completion["messages"]
            if msg["role"] in ("ipython", "user") and msg["content"].strip().startswith("<information>")
        ]
        logger.debug(f"Found {len(search_results)} search results for prompt {i}")

        # Log ground truth and searched chunks for debugging
        logger.info(f"📝 Ground Truth Chunk: {correct_content}")
        for j, result in enumerate(search_results):
            logger.info(f"🔍 Searched Chunk {j + 1}: {result}")

        # Check if any search hit the correct chunk content
        found_correct_chunk = any(correct_content in result for result in search_results)

        if not found_correct_chunk:
            logger.warning(
                f"Failed to find correct chunk for prompt {i}:\n"
                f"Search results: {[r[:100] + '...' for r in search_results]}"
            )

        reward = 1.0 if found_correct_chunk else 0.0
        rewards.append(reward)
        logger.debug(f"Reward for prompt {i}: {reward}")

    # Log summary metrics
    logger.info("Chunk Query Rewards Summary:")
    logger.info(f"Total prompts: {len(prompts)}")
    logger.info(f"Correct matches: {sum(rewards)}")
    logger.info(f"Average reward: {np.mean(rewards):.3f}")
    logger.info(f"Reward std: {np.std(rewards):.3f}")

    return rewards
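Reviewer note: a quick check of `reward_format` on two hand-written completions (assuming the `src` package is importable). A think-then-search turn is valid; mixing a search and an answer in the same turn is not.

    from src.rewards import reward_format

    good = {"messages": [{"role": "assistant",
                          "content": "<think>I need to look this up.</think>\n<search>capital of France</search>"}]}
    bad = {"messages": [{"role": "assistant",
                         "content": "<think>Done.</think>\n<search>capital of France</search>\n<answer>Paris</answer>"}]}
    print(reward_format([], [good, bad]))  # [1.0, 0.0]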