feat: expand reward functions with new strategies and diversity checks

- Added reward functions for tag counting, search strategy, and search diversity
- Updated reward_format to validate that responses end properly with </search> or </answer>
main
thinhlpg 1 month ago
parent d0e6068055
commit 4de31e0f30

@@ -5,7 +5,7 @@ Reward functions for RL training.
import json
import re
from datetime import datetime
from pathlib import Path
from difflib import SequenceMatcher
import numpy as np
@@ -145,6 +145,7 @@ def reward_format(prompts: list, completions: list, **reward_kwargs) -> list:
"has_search": [],
"has_invalid_tags": [],
"has_info_tags": [],
"ends_properly": [], # New validation result
}
for completion in completions:
@@ -159,6 +160,11 @@ def reward_format(prompts: list, completions: list, **reward_kwargs) -> list:
content = assistant_msgs[-1]
# Check if content ends with </search> or </answer> (ignoring whitespace)
content_stripped = content.strip()
ends_properly = content_stripped.endswith("</search>") or content_stripped.endswith("</answer>")
validation_results["ends_properly"].append(ends_properly)
has_invalid_tags = any(re.search(pattern, content) for pattern in invalid_patterns)
validation_results["has_invalid_tags"].append(has_invalid_tags)
if has_invalid_tags:
@@ -196,15 +202,30 @@ def reward_format(prompts: list, completions: list, **reward_kwargs) -> list:
rewards.append(0.0)
continue
reward = 1.0 if has_think and (has_answer or has_search) else 0.0
# Check for proper tag sequence - think must come before answer/search
if has_answer or has_search:
last_think_pos = content.rfind("</think>")
answer_pos = content.find("<answer>") if has_answer else float("inf")
search_pos = content.find("<search>") if has_search else float("inf")
tag_pos = min(answer_pos, search_pos)
if last_think_pos == -1 or last_think_pos > tag_pos:
rewards.append(0.0)
continue
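# Illustrative example (assumption, for clarity): a completion like
# "<answer>Paris</answer><think>double-checking</think>" is rejected here, because the
# last </think> (from rfind) sits after the first <answer>, i.e. last_think_pos > tag_pos.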
# Only reward if format is valid AND response ends properly
reward = 1.0 if has_think and (has_answer or has_search) and ends_properly else 0.0
rewards.append(reward)
if not reward:
logger.debug(f"Format issues - think: {has_think}, answer: {has_answer}, search: {has_search}")
logger.debug(
f"Format issues - think: {has_think}, answer: {has_answer}, search: {has_search}, ends_properly: {ends_properly}"
)
if search_matches:
logger.debug(f"Number of search tags: {len(search_matches)}")
logger.info(f"Format reward metrics - Mean: {np.mean(rewards):.3f}, Valid formats: {sum(rewards)}/{len(rewards)}")
logger.info(f"Responses ending properly: {sum(validation_results['ends_properly'])}/{len(rewards)}")
# Log chat state with validation results
log_chat_state(
@@ -218,12 +239,6 @@ def reward_format(prompts: list, completions: list, **reward_kwargs) -> list:
return rewards
# TODO: Implement this reward function if the project survives
def reward_long_query(completions, **kwargs):
"""Reward function that checks if the query is long."""
pass
def reward_retry(prompts: list, completions: list, **reward_kwargs) -> list:
"""
Reward function that encourages optimal retry behavior.
@@ -384,6 +399,402 @@ def reward_em_chunk(prompts: list, completions: list, **reward_kwargs) -> list:
return rewards
def tag_count_reward(prompts: list, completions: list, **reward_kwargs) -> list:
"""Reward function that checks for proper tag counts in the conversation.
Rewards:
- 0.1 for each proper pair of think tags in each assistant message
- 0.5 for having exactly one pair of answer tags in entire conversation
- 0.1 for each proper pair of search tags
Args:
prompts: List of input prompts
completions: List of completion dictionaries with messages
**reward_kwargs: Additional reward parameters
Returns:
list: List of rewards between 0 and 1
"""
rewards = []
validation_results = {
"think_pairs_per_msg": [], # List of lists, each inner list has think pair counts per assistant msg
"answer_pairs": [], # Total answer pairs in conversation
"search_pairs": [], # Total search pairs in conversation
}
for completion in completions:
# Get all assistant messages
assistant_msgs = [msg["content"] for msg in completion["messages"] if msg["role"] == "assistant"]
if not assistant_msgs:
rewards.append(0.0)
validation_results["think_pairs_per_msg"].append([])
validation_results["answer_pairs"].append(0)
validation_results["search_pairs"].append(0)
continue
# Count think pairs per assistant message
think_pairs_per_msg = []
for msg in assistant_msgs:
# Count complete think tag pairs
think_opens = len(re.findall(r"<think>", msg))
think_closes = len(re.findall(r"</think>", msg))
think_pairs = min(think_opens, think_closes)
think_pairs_per_msg.append(think_pairs)
# Count answer tags in entire conversation (should be exactly one pair)
total_answer_opens = sum(msg.count("<answer>") for msg in assistant_msgs)
total_answer_closes = sum(msg.count("</answer>") for msg in assistant_msgs)
answer_pairs = min(total_answer_opens, total_answer_closes)
# Count search tags
total_search_opens = sum(msg.count("<search>") for msg in assistant_msgs)
total_search_closes = sum(msg.count("</search>") for msg in assistant_msgs)
search_pairs = min(total_search_opens, total_search_closes)
# Calculate reward components
think_reward = sum(min(pairs, 1) * 0.1 for pairs in think_pairs_per_msg) # 0.1 per msg with proper think pair
answer_reward = 0.5 if answer_pairs == 1 else 0.0 # 0.5 for exactly one answer pair
search_reward = min(search_pairs, 1) * 0.1 # 0.1 for having search pairs
total_reward = min(think_reward + answer_reward + search_reward, 1.0)
rewards.append(total_reward)
# Store validation results
validation_results["think_pairs_per_msg"].append(think_pairs_per_msg)
validation_results["answer_pairs"].append(answer_pairs)
validation_results["search_pairs"].append(search_pairs)
# Debug logging
if total_reward < 1.0:
logger.debug(
f"Tag count issues - think_pairs: {think_pairs_per_msg}, "
f"answer_pairs: {answer_pairs}, search_pairs: {search_pairs}"
)
# Log metrics
logger.info(
f"Tag count reward metrics - Mean: {np.mean(rewards):.3f}, Perfect scores: {sum(r == 1.0 for r in rewards)}/{len(rewards)}"
)
logger.info(
f"Average think pairs per message: {np.mean([np.mean(pairs) if pairs else 0 for pairs in validation_results['think_pairs_per_msg']]):.2f}"
)
logger.info(
f"Conversations with exactly one answer pair: {sum(pairs == 1 for pairs in validation_results['answer_pairs'])}/{len(rewards)}"
)
# Log chat state
log_chat_state(
prompts=prompts,
completions=completions,
rewards=rewards,
reward_type="tag_count",
validation_results=validation_results,
)
return rewards
def reward_search_strategy(prompts: list, completions: list, **reward_kwargs) -> list:
"""Reward function that checks for good search strategy and query analysis steps.
The expected conversation flow pattern is:
1. Initial search: question -> assistant(think + search)
2. Process info: information -> assistant(think + refined search)
3. Final answer: information -> assistant(think + answer)
Rewards:
- Initial search (0.2): Starting with broad/overview search
- Information processing (0.4): Analyzing provided info and refining search
- Final synthesis (0.4): Analyzing all info and providing final answer
Args:
prompts: List of input prompts
completions: List of completion dictionaries
**reward_kwargs: Additional reward parameters
Returns:
list: List of rewards between 0 and 1
"""
rewards = []
validation_results = {
"initial_search": [], # First search attempt
"info_processing": [], # Number of info-based refinements
"final_synthesis": [], # Final answer with proper analysis
}
# Patterns for conversation flow
think_pattern = r"<think>[^<>]+</think>"
search_pattern = r"<search>[^<>]+</search>"
answer_pattern = r"<answer>[^<>]+</answer>"
info_pattern = r"<information>[^<>]+</information>"
# Analysis patterns
info_analysis_pattern = (
r"<think>[^<>]*?\b(?:based|according|from|results?|found|shows?|provided|information)\b[^<>]*?</think>"
)
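# Examples of what the analysis pattern is meant to catch (illustrative):
#   "<think>Based on the results, the date is unclear.</think>"  -> matches (keywords "based", "results")
#   "<think>Let me try another query.</think>"                   -> no match (no analysis keyword)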
for completion in completions:
messages = completion.get("messages", [])
if not messages:
rewards.append(0.0)
for key in validation_results:
validation_results[key].append(False)
continue
# Track conversation flow
has_initial_search = False
info_based_searches = 0
has_final_synthesis = False
# Track current state
last_was_info = False
search_after_info = 0
analysis_after_info = 0
for i, msg in enumerate(messages):
content = msg["content"]
role = msg["role"]
if role == "assistant":
has_think = bool(re.search(think_pattern, content))
has_search = bool(re.search(search_pattern, content))
has_answer = bool(re.search(answer_pattern, content))
has_info_analysis = bool(re.search(info_analysis_pattern, content, re.IGNORECASE))
# Check initial search (first assistant message with search)
if not has_initial_search and has_think and has_search:
has_initial_search = True
# Check info-based refinement
if last_was_info and has_think:
if has_search:
search_after_info += 1
if has_info_analysis:
analysis_after_info += 1
# Check final synthesis
if has_answer and has_think and has_info_analysis:
has_final_synthesis = True
elif role in ["user", "ipython"] and re.search(info_pattern, content):
last_was_info = True
else:
last_was_info = False
# Calculate rewards
initial_reward = 0.2 if has_initial_search else 0.0
# Info processing reward: proper analysis and search after info
info_processing = min(search_after_info, analysis_after_info) # Must have both analysis and search
info_reward = min(0.4, 0.2 * info_processing) # 0.2 per proper info-based refinement, max 0.4
# Final synthesis reward
synthesis_reward = 0.4 if has_final_synthesis else 0.0
total_reward = initial_reward + info_reward + synthesis_reward
rewards.append(total_reward)
# Store validation results
validation_results["initial_search"].append(has_initial_search)
validation_results["info_processing"].append(info_processing)
validation_results["final_synthesis"].append(has_final_synthesis)
# Debug logging
if total_reward < 0.6: # Log if missing significant components
logger.debug(
f"Search flow issues - initial: {has_initial_search}, "
f"info_processing: {info_processing}, "
f"final_synthesis: {has_final_synthesis}"
)
# Log metrics
logger.info(
f"Search strategy metrics - Mean: {np.mean(rewards):.3f}, Perfect scores: {sum(r == 1.0 for r in rewards)}/{len(rewards)}"
)
logger.info(f"Initial searches: {sum(validation_results['initial_search'])}/{len(rewards)}")
logger.info(f"Average info processing steps: {np.mean([r for r in validation_results['info_processing']]):.2f}")
logger.info(f"Final synthesis rate: {sum(validation_results['final_synthesis'])}/{len(rewards)}")
# Log chat state
log_chat_state(
prompts=prompts,
completions=completions,
rewards=rewards,
reward_type="search_strategy",
validation_results=validation_results,
)
return rewards
def reward_search_diversity(prompts: list, completions: list, **reward_kwargs) -> list:
"""Reward function that evaluates diversity of search queries in a conversation.
Rewards higher diversity in search queries and penalizes repetitive searches.
Uses string similarity to compare queries, with diminishing returns for
similar queries.
Scoring:
- Base reward: 0.2 per unique query concept (max 0.4)
- Diversity bonus: Up to 0.4 based on semantic diversity
- Operator bonus: Up to 0.2 for proper use of search operators
- Penalties:
* Similar queries (>0.8 similarity): -0.1 per pair
* Exact duplicates: -0.2 per duplicate
Args:
prompts: List of input prompts
completions: List of completion dictionaries
**reward_kwargs: Additional reward parameters
Returns:
list: List of rewards between 0 and 1
"""
def normalize_query(query: str) -> tuple[str, list[str]]:
"""Normalize search query for comparison."""
# Extract operators before normalization
operators = re.findall(r'(?:site|filetype):\S+|"[^"]+"|(?:\s+OR\s+|\s+AND\s+|-\w+)', query)
# Remove operators for base comparison
base_query = re.sub(r'(?:site|filetype):\S+|"[^"]+"|(?:\s+OR\s+|\s+AND\s+|-\w+)', "", query.lower())
# Remove special chars and extra spaces from base query
base_query = re.sub(r"[^\w\s]", " ", base_query)
return " ".join(base_query.split()), operators
def query_similarity(q1: str, q2: str) -> float:
"""Calculate similarity between two queries."""
# Compare normalized base queries
base1, ops1 = normalize_query(q1)
base2, ops2 = normalize_query(q2)
# Base similarity from query text
base_sim = SequenceMatcher(None, base1, base2).ratio()
# Significantly reduce similarity if using different operators
if ops1 != ops2:
# More operators = more different
unique_ops = len(set(ops1) ^ set(ops2)) # XOR to get unique operators
base_sim *= max(0.3, 1.0 - (unique_ops * 0.2)) # Each unique operator reduces similarity by 20%
return base_sim
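# Example (illustrative): query_similarity("python asyncio tutorial",
# "python asyncio tutorial site:docs.python.org") has identical base text (ratio 1.0),
# but the single unshared operator scales it by max(0.3, 1.0 - 0.2) = 0.8, giving 0.8.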
rewards = []
for completion in completions:
# Extract all search queries from assistant messages
search_queries = []
for msg in completion.get("messages", []):
if msg["role"] == "assistant":
# Find all search tags
searches = re.findall(r"<search>([^<>]+)</search>", msg["content"])
search_queries.extend(searches)
if not search_queries:
rewards.append(0.0)
continue
# Calculate diversity score
total_queries = len(search_queries)
if total_queries == 1:
rewards.append(0.2) # Base reward for single query
continue
# Calculate pairwise similarities and track duplicates/high similarities
similarity_sum = 0
pair_count = 0
similar_pairs = 0 # Count pairs with >0.8 similarity
exact_duplicates = 0 # Count exact matches
# Count unique operators and track their usage
all_operators = set()
operator_usage = [] # Track operators per query
for query in search_queries:
_, ops = normalize_query(query)
all_operators.update(ops)
operator_usage.append(len(ops))
# Track normalized queries to find duplicates
seen_queries = set()
unique_queries = []
for i in range(total_queries):
base_i, _ = normalize_query(search_queries[i])
if base_i in seen_queries:
exact_duplicates += 1
else:
unique_queries.append(search_queries[i])
seen_queries.add(base_i)
for j in range(i + 1, total_queries):
similarity = query_similarity(search_queries[i], search_queries[j])
similarity_sum += similarity
pair_count += 1
# Count highly similar pairs (ignoring operator differences)
base_i, _ = normalize_query(search_queries[i])
base_j, _ = normalize_query(search_queries[j])
base_sim = SequenceMatcher(None, base_i, base_j).ratio()
if base_sim > 0.8 and base_sim < 1.0: # Don't count exact duplicates twice
similar_pairs += 1
# Average similarity (0-1), weighted less for operator differences
avg_similarity = similarity_sum / pair_count if pair_count > 0 else 0
# Calculate diversity score (1 - avg_similarity)
diversity_score = 1 - avg_similarity
# Calculate operator bonus (up to 0.2)
# Reward both variety and consistent usage
operator_variety_bonus = min(0.15, len(all_operators) * 0.05) # Up to 0.15 for unique operators
operator_usage_ratio = sum(1 for x in operator_usage if x > 0) / total_queries
operator_usage_bonus = 0.05 * operator_usage_ratio # Up to 0.05 for consistent usage
operator_bonus = operator_variety_bonus + operator_usage_bonus
# Calculate penalties
# Reduce penalties when operators are different
similarity_penalty = similar_pairs * 0.1 # Reduced penalty for similar pairs
if len(all_operators) >= 2: # If using multiple operators, reduce penalties
similarity_penalty *= 0.5
duplicate_penalty = exact_duplicates * 0.2 # Keep strong penalty for exact duplicates
# Final reward calculation:
# - Base reward per unique query (max 0.4)
# - Diversity bonus (up to 0.4)
# - Operator bonus (up to 0.2)
# - Apply penalties
unique_query_count = len(unique_queries)
base_reward = min(0.4, 0.2 * unique_query_count)
diversity_bonus = diversity_score * 0.4
total_reward = base_reward + diversity_bonus + operator_bonus - similarity_penalty - duplicate_penalty
# Cap at 1.0 and floor at 0.0
reward = max(0.0, min(1.0, total_reward))
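# Worked example (illustrative): three queries, all with distinct base text
# (avg pairwise similarity ~0.3), two of them carrying one operator each
# (2 unique operators), and no near-duplicate or exact-duplicate pairs:
#   base_reward     = min(0.4, 0.2 * 3)               = 0.40
#   diversity_bonus = (1 - 0.3) * 0.4                 = 0.28
#   operator_bonus  = min(0.15, 2 * 0.05) + 0.05*(2/3) ~= 0.13
#   reward ~= 0.81 after the 0.0 floor and 1.0 cap.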
# Debug logging
logger.debug(
f"Search diversity metrics - "
f"Queries: {total_queries}, "
f"Unique: {len(seen_queries)}, "
f"Similar pairs: {similar_pairs}, "
f"Duplicates: {exact_duplicates}, "
f"Avg similarity: {avg_similarity:.2f}, "
f"Diversity score: {diversity_score:.2f}, "
f"Operator bonus: {operator_bonus:.2f}, "
f"Penalties: -{similarity_penalty + duplicate_penalty:.2f}, "
f"Final reward: {reward:.2f}"
)
rewards.append(reward)
# Log overall metrics
if rewards:
logger.info(f"Search diversity metrics - Mean reward: {np.mean(rewards):.3f}, Max reward: {max(rewards):.3f}")
return rewards
def log_chat_state(prompts: list, completions: list, rewards: list, reward_type: str, **kwargs) -> None:
"""Log chat state and rewards to JSONL file.

@@ -21,7 +21,14 @@ from src.config import (
logger,
update_log_path,
)
from src.rewards import build_reward_correctness_fn, reward_em_chunk, reward_retry
from src.rewards import (
build_reward_correctness_fn,
reward_em_chunk,
reward_format,
reward_retry,
reward_search_diversity,
reward_search_strategy,
)
from src.search_module import get_qa_dataset
from src.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter
@@ -121,6 +128,8 @@ trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(
reward_format,
reward_retry,
reward_em_chunk,
reward_search_strategy,
reward_search_diversity,
],
args=training_args,
train_dataset=train_dataset,
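# Note (assumption about the underlying TRL GRPOTrainer): each function in reward_funcs is
# evaluated per completion and the results are summed (optionally weighted via reward_weights),
# so adding reward_search_strategy and reward_search_diversity raises the maximum combined
# reward attainable per sample.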
