"""
Reward functions for RL training.
"""
import json
import re
from datetime import datetime
import numpy as np
from src.config import LOG_FOLDER, logger
from src.evaluation import check_student_answers


def build_reward_correctness_fn(
    vllm_generate_func,
    tokenizer,
):
    """Build a reward function that checks correctness of student answers.

    Args:
        vllm_generate_func: Function to generate answers using vLLM
        tokenizer: Tokenizer for the model

    Returns:
        A reward function that takes prompts and completions and returns correctness scores
    """

    def reward_correctness(prompts: list, completions: list, **reward_kwargs) -> list:
        """Calculate reward based on correctness of student answers.

        Args:
            prompts: List of input prompts
            completions: List of model completions
            **reward_kwargs: Additional arguments for reward calculation

        Returns:
            List of correctness scores between 0 and 1
        """
        # Correct means: the final message is an assistant response that contains an
        # <answer> tag and no <search> or <information> tags
teacher_answers = reward_kwargs["answer"]
student_final_messages = [completion["messages"][-1]["content"] for completion in completions]
student_final_message_roles = [completion["messages"][-1]["role"] for completion in completions]
is_assistant_response = [role == "assistant" for role in student_final_message_roles]
        has_answer_tag = [re.search(r"<answer>[\s\S]*?</answer>", ans) is not None for ans in student_final_messages]
        has_search_tag = [re.search(r"<search>[\s\S]*?</search>", ans) is not None for ans in student_final_messages]
        has_information_tag = [
            re.search(r"<information>[\s\S]*?</information>", ans) is not None for ans in student_final_messages
        ]
might_be_correct = check_student_answers(
prompts,
teacher_answers,
student_final_messages,
vllm_generate_func=vllm_generate_func,
tokenizer=tokenizer,
)
# Convert lists to numpy arrays for element-wise operations
might_be_correct = np.array(might_be_correct)
is_assistant_response = np.array(is_assistant_response)
has_answer_tag = np.array(has_answer_tag)
has_search_tag = np.array(has_search_tag)
has_information_tag = np.array(has_information_tag)
# might be correct and is assistant response and has answer tag and no search or information tags
correct = might_be_correct & is_assistant_response & has_answer_tag & ~has_search_tag & ~has_information_tag
# Convert numpy array back to list for return
correct = correct.tolist()
        # Log correctness metrics
logger.info(f"Correctness metrics: {correct}")
logger.info(f"Average correctness: {np.mean(correct):.2f}")
logger.info(f"Standard deviation: {np.std(correct):.2f}")
# Log length metrics
student_lengths = [len(ans.strip()) for ans in student_final_messages]
teacher_lengths = [len(ans.strip()) for ans in teacher_answers]
logger.info(f"Student lengths: {student_lengths}")
logger.info(f"Teacher lengths: {teacher_lengths}")
logger.info(f"Average student length: {np.mean(student_lengths):.2f}")
logger.info(f"Average teacher length: {np.mean(teacher_lengths):.2f}")
logger.info(f"Length ratio: {np.mean(student_lengths) / np.mean(teacher_lengths):.2f}")
# Log chat state
log_chat_state(
prompts=prompts,
completions=completions,
rewards=correct,
reward_type="correctness",
teacher_answers=teacher_answers,
validation_results={
"is_assistant": is_assistant_response,
"has_answer": has_answer_tag,
"has_search": has_search_tag,
"has_info": has_information_tag,
"might_be_correct": might_be_correct,
},
)
return correct
return reward_correctness
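

# A minimal usage sketch (hypothetical data; generate_fn and tokenizer stand in for
# whatever you already pass to vLLM elsewhere in the project). Defined as a helper so
# that importing this module stays side-effect free.
def _example_correctness_usage(generate_fn, tokenizer) -> list:
    """Sketch only: the score also depends on check_student_answers (the verifier),
    not just the tag checks above."""
    reward_fn = build_reward_correctness_fn(generate_fn, tokenizer)
    return reward_fn(
        prompts=["What is the capital of France?"],
        completions=[{"messages": [{"role": "assistant", "content": "<answer>Paris</answer>"}]}],
        answer=["Paris"],
    )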


def reward_format(prompts: list, completions: list, **reward_kwargs) -> list:
    """Reward function that checks if the completion follows the required format with proper tags.

    Args:
        prompts: List of input prompts
        completions: List of completion dictionaries containing messages
        **reward_kwargs: Additional reward parameters

    Returns:
        list: List of rewards (1.0 for valid format, 0.0 for invalid)
    """
# Regex patterns for each tag type - no markdown allowed
    think_pattern = r"<think>[\s\S]*?</think>"
    search_pattern = r"<search>[\s\S]*?</search>"
    answer_pattern = r"<answer>[\s\S]*?</answer>"
# Information tag patterns - handle multiple variants
    info_patterns = [
        r"<information>[\s\S]*?</information>",  # Standard
        r"<info>[\s\S]*?</info>",  # Shortened
        r"<Information>[\s\S]*?</Information>",  # Capitalized variant
        r"<INFORMATION>[\s\S]*?</INFORMATION>",  # Uppercase
        r"<INFO>[\s\S]*?</INFO>",  # Uppercase shortened
    ]
# Invalid patterns (bold/italic tags)
invalid_patterns = [
r"\*\*<\/?(?:think|search|answer|information|info)>\*\*", # Bold tags
r"\*<\/?(?:think|search|answer|information|info)>\*", # Italic tags
r"_<\/?(?:think|search|answer|information|info)>_", # Underscore italic
]
rewards = []
validation_results = {
"has_think": [],
"has_answer": [],
"has_search": [],
"has_invalid_tags": [],
"has_info_tags": [],
}
for completion in completions:
messages = completion.get("messages", [])
assistant_msgs = [msg["content"] for msg in messages if msg["role"] == "assistant"]
if not assistant_msgs:
rewards.append(0.0)
for key in validation_results:
validation_results[key].append(False)
continue
content = assistant_msgs[-1]
has_invalid_tags = any(re.search(pattern, content) for pattern in invalid_patterns)
validation_results["has_invalid_tags"].append(has_invalid_tags)
if has_invalid_tags:
rewards.append(0.0)
for key in ["has_think", "has_answer", "has_search", "has_info_tags"]:
validation_results[key].append(False)
continue
has_info_tags = False
for pattern in info_patterns:
if re.findall(pattern, content, re.IGNORECASE):
has_info_tags = True
break
validation_results["has_info_tags"].append(has_info_tags)
if has_info_tags:
rewards.append(0.0)
for key in ["has_think", "has_answer", "has_search"]:
validation_results[key].append(False)
continue
think_matches = re.findall(think_pattern, content)
search_matches = re.findall(search_pattern, content)
answer_matches = re.findall(answer_pattern, content)
has_think = len(think_matches) >= 1
has_answer = len(answer_matches) == 1
has_search = len(search_matches) >= 1
validation_results["has_think"].append(has_think)
validation_results["has_answer"].append(has_answer)
validation_results["has_search"].append(has_search)
if has_search and has_answer:
rewards.append(0.0)
continue
reward = 1.0 if has_think and (has_answer or has_search) else 0.0
rewards.append(reward)
if not reward:
logger.debug(f"Format issues - think: {has_think}, answer: {has_answer}, search: {has_search}")
if search_matches:
logger.debug(f"Number of search tags: {len(search_matches)}")
logger.info(f"Format reward metrics - Mean: {np.mean(rewards):.3f}, Valid formats: {sum(rewards)}/{len(rewards)}")
# Log chat state with validation results
log_chat_state(
prompts=prompts,
completions=completions,
rewards=rewards,
reward_type="format",
validation_results=validation_results,
)
return rewards
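

# A minimal sketch of the format contract (hypothetical message content): the final
# assistant message must pair <think> with exactly one of <search> or <answer>.
def _example_reward_format_usage() -> list:
    """Sketch only; returns [1.0, 0.0] for the two completions below."""
    completions = [
        {"messages": [{"role": "assistant",
                       "content": "<think>I need to look this up.</think><search>capital of France</search>"}]},
        {"messages": [{"role": "assistant",
                       "content": "<answer>Paris</answer>"}]},  # missing <think> -> 0.0
    ]
    return reward_format(prompts=["q1", "q2"], completions=completions)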


# TODO: Implement this reward function if the project survives
def reward_long_query(completions, **kwargs):
    """Reward function that checks if the query is long."""
    pass


def reward_retry(prompts: list, completions: list, **reward_kwargs) -> list:
    """
    Reward function that encourages optimal retry behavior.

    Rewards increase with each additional search attempt but cap at optimal_search_count.
    Packing multiple searches into a single message is penalized.

    Args:
        prompts: List of input prompts
        completions: List of completion dictionaries with messages
        **reward_kwargs: Additional reward parameters (chunk_id, answer, etc.)

    Returns:
        List of rewards for each completion, rounded to 3 decimal places
    """
rewards = []
search_queries = []
violations = []
# Config for retry rewards
optimal_search_count = 5 # Cap rewards at this many searches
base_reward = 0.2 # Base reward for having at least one search
increment = 0.15 # Reward increment per search attempt (0.2 + 5*0.15 = 0.95 max)
violation_penalty = 0.5 # Penalty for having multiple searches in one message
# Regex pattern for search tags
    search_pattern = r"<search>[\s\S]*?</search>"
for completion in completions:
# Get assistant messages
assistant_messages = [msg["content"] for msg in completion["messages"] if msg["role"] == "assistant"]
# Count search tags in assistant messages
message_searches = []
for msg in assistant_messages:
# Find all search tags in each message
search_matches = re.findall(search_pattern, msg)
message_searches.append(len(search_matches))
# Record total search queries
total_searches = sum(message_searches)
search_queries.append(total_searches)
# Check for violations (more than one search query per message)
violation = any(count > 1 for count in message_searches)
violations.append(violation)
# Calculate reward
if total_searches == 0:
reward = 0.0 # No searches = no reward
else:
# Base reward for having at least one search
reward = base_reward
# Add incremental reward for each search up to optimal_search_count
search_bonus = min(total_searches, optimal_search_count) * increment
reward += search_bonus
# Cap reward at 1.0
reward = min(1.0, reward)
# Apply penalty if there's a violation
if violation:
reward *= 1 - violation_penalty
# Round to 3 decimal places to avoid floating point precision issues
reward = round(reward, 3)
rewards.append(reward)
# Log metrics with search distribution info
logger.info(f"Retry behavior rewards: {np.mean(rewards):.3f} ± {np.std(rewards):.3f}")
logger.info(f"Search tags per completion: {np.mean(search_queries):.2f} ± {np.std(search_queries):.2f}")
logger.info(f"Violations (>1 search per message): {sum(violations)}/{len(violations)}")
logger.info(f"Search counts distribution: {search_queries}")
# Log chat state
log_chat_state(
prompts=prompts,
completions=completions,
rewards=rewards,
reward_type="retry",
search_counts=search_queries,
violations=violations,
optimal_search_count=optimal_search_count,
)
return rewards
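

# Worked sketch of the retry schedule (hypothetical completions): one search per
# message earns 0.2 + n * 0.15 (capped at five searches), while packing two searches
# into one message triggers the 50% violation penalty.
def _example_reward_retry_schedule() -> list:
    """Sketch only; returns [0.35, 0.25] for the two completions below."""
    one_search = {"messages": [{"role": "assistant", "content": "<search>q1</search>"}]}
    two_in_one_message = {
        "messages": [{"role": "assistant", "content": "<search>q1</search><search>q2</search>"}]
    }
    # one_search: 0.2 + 1 * 0.15 = 0.35; two_in_one_message: (0.2 + 2 * 0.15) * 0.5 = 0.25
    return reward_retry(prompts=["p1", "p2"], completions=[one_search, two_in_one_message])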


def reward_em_chunk(prompts: list, completions: list, **reward_kwargs) -> list:
    """Reward function that checks if the model's search queries hit the correct chunk content.

    Args:
        prompts: List of input prompts
        completions: List of completion dictionaries with messages
        **reward_kwargs: Additional reward parameters including:
            - chunk_content: List of correct chunk contents to match against
            - step: Optional step number for logging metrics

    Returns:
        list: List of rewards (1.0 if the correct chunk content appears in a search result, 0.0 otherwise)

    Raises:
        ValueError: If chunk_content is not provided in reward_kwargs
    """
logger.debug(f"Calculating rewards for {len(prompts)} prompts")
# Get correct chunk contents from reward kwargs
correct_contents = reward_kwargs.get("chunk_content", [])
if not correct_contents:
logger.error("No chunk_content provided in reward_kwargs")
raise ValueError("chunk_content must be provided in reward_kwargs")
rewards = []
for i, (completion, correct_content) in enumerate(zip(completions, correct_contents)):
        # Get all messages from ipython or user roles that start with <information>
        search_results = [
            msg["content"]
            for msg in completion["messages"]
            if msg["role"] in ("ipython", "user") and msg["content"].strip().startswith("<information>")
        ]
logger.debug(f"Found {len(search_results)} search results for prompt {i}")
# Log ground truth and searched chunks for debugging
logger.info(f"📝 Ground Truth Chunk: {correct_content}")
for j, result in enumerate(search_results):
logger.info(f"🔍 Searched Chunk {j + 1}: {result}")
# Check if any search hit the correct chunk content
found_correct_chunk = any(correct_content in result for result in search_results)
if not found_correct_chunk:
logger.warning(
f"Failed to find correct chunk for prompt {i}:\n"
f"Search results: {[r[:100] + '...' for r in search_results]}"
)
reward = 1.0 if found_correct_chunk else 0.0
rewards.append(reward)
logger.debug(f"Reward for prompt {i}: {reward}")
# Log summary metrics
logger.info("Chunk Query Rewards Summary:")
logger.info(f"Total prompts: {len(prompts)}")
logger.info(f"Correct matches: {sum(rewards)}")
logger.info(f"Average reward: {np.mean(rewards):.3f}")
logger.info(f"Reward std: {np.std(rewards):.3f}")
# Log chat state
log_chat_state(
prompts=prompts,
completions=completions,
rewards=rewards,
reward_type="em_chunk",
correct_contents=correct_contents,
)
return rewards
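

# A minimal sketch (hypothetical chunk text): the reward fires when the ground-truth
# chunk content appears verbatim inside any returned <information> message.
def _example_reward_em_chunk() -> list:
    """Sketch only; returns [1.0] because the searched chunk contains the ground truth."""
    completion = {
        "messages": [
            {"role": "assistant", "content": "<search>Eiffel Tower height</search>"},
            {"role": "ipython", "content": "<information>The Eiffel Tower is 330 metres tall.</information>"},
        ]
    }
    return reward_em_chunk(
        prompts=["How tall is the Eiffel Tower?"],
        completions=[completion],
        chunk_content=["The Eiffel Tower is 330 metres tall."],
    )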


def log_chat_state(prompts: list, completions: list, rewards: list, reward_type: str, **kwargs) -> None:
    """Log chat states and rewards to a JSONL file.

    Args:
        prompts: List of input prompts
        completions: List of model completions
        rewards: List of calculated rewards
        reward_type: Type of reward function used
        **kwargs: Additional data to log
    """
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
chat_states_dir = LOG_FOLDER / "chat_states"
chat_states_dir.mkdir(parents=True, exist_ok=True)
# Convert numpy arrays to lists in kwargs
for key, value in kwargs.items():
if isinstance(value, dict):
for k, v in value.items():
if isinstance(v, np.ndarray):
kwargs[key][k] = v.tolist()
elif isinstance(value, np.ndarray):
kwargs[key] = value.tolist()
# Create one JSONL file per reward type
log_file = chat_states_dir / f"chat_states_{reward_type}.jsonl"
# Append each chat state as a new line
with open(log_file, "a", encoding="utf-8") as f:
for prompt, completion, reward in zip(prompts, completions, rewards):
chat_state = {
"timestamp": timestamp,
"reward_type": reward_type,
"prompt": prompt,
"messages": completion["messages"],
"reward": float(reward) if isinstance(reward, (np.number, np.ndarray)) else reward,
"metadata": kwargs,
}
f.write(json.dumps(chat_state, ensure_ascii=False) + "\n")
logger.info(f"💾 Appended {len(prompts)} chat states to {log_file}")