feat: expand reward functions with new strategies and diversity checks

- Added reward functions for search strategy and search diversity, plus tag_count_reward for well-formed tag pairs.
- Updated reward_format to validate think/answer/search tag ordering and proper message endings.
- Removed the unimplemented reward_long_query stub and registered the new search rewards with the GRPO trainer.
Branch: main
Author: thinhlpg, 1 month ago
parent d0e6068055
commit 4de31e0f30

@@ -5,7 +5,7 @@ Reward functions for RL training.
import json
import re
from datetime import datetime
-from pathlib import Path
from difflib import SequenceMatcher
import numpy as np
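
The swapped-in difflib import powers the new query-similarity checks below. A quick illustration of the stdlib helper (not project code):

    from difflib import SequenceMatcher

    # ratio() returns a float in [0, 1]; 1.0 means the strings are identical.
    sim = SequenceMatcher(None, "python list sort", "python sort list").ratio()
    print(round(sim, 2))  # roughly 0.69 for these near-duplicate queries
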
@@ -145,6 +145,7 @@ def reward_format(prompts: list, completions: list, **reward_kwargs) -> list:
"has_search": [], "has_search": [],
"has_invalid_tags": [], "has_invalid_tags": [],
"has_info_tags": [], "has_info_tags": [],
"ends_properly": [], # New validation result
} }
for completion in completions: for completion in completions:
@@ -159,6 +160,11 @@ def reward_format(prompts: list, completions: list, **reward_kwargs) -> list:
        content = assistant_msgs[-1]

        # Check if content ends with </search> or </answer> (ignoring whitespace)
        content_stripped = content.strip()
        ends_properly = content_stripped.endswith("</search>") or content_stripped.endswith("</answer>")
        validation_results["ends_properly"].append(ends_properly)

        has_invalid_tags = any(re.search(pattern, content) for pattern in invalid_patterns)
        validation_results["has_invalid_tags"].append(has_invalid_tags)
        if has_invalid_tags:
@@ -196,15 +202,30 @@ def reward_format(prompts: list, completions: list, **reward_kwargs) -> list:
            rewards.append(0.0)
            continue

-        reward = 1.0 if has_think and (has_answer or has_search) else 0.0
        # Check for proper tag sequence - think must come before answer/search
        if has_answer or has_search:
            last_think_pos = content.rfind("</think>")
            answer_pos = content.find("<answer>") if has_answer else float("inf")
            search_pos = content.find("<search>") if has_search else float("inf")
            tag_pos = min(answer_pos, search_pos)

            if last_think_pos == -1 or last_think_pos > tag_pos:
                rewards.append(0.0)
                continue

        # Only reward if format is valid AND response ends properly
        reward = 1.0 if has_think and (has_answer or has_search) and ends_properly else 0.0
        rewards.append(reward)

        if not reward:
-            logger.debug(f"Format issues - think: {has_think}, answer: {has_answer}, search: {has_search}")
            logger.debug(
                f"Format issues - think: {has_think}, answer: {has_answer}, search: {has_search}, ends_properly: {ends_properly}"
            )
            if search_matches:
                logger.debug(f"Number of search tags: {len(search_matches)}")

    logger.info(f"Format reward metrics - Mean: {np.mean(rewards):.3f}, Valid formats: {sum(rewards)}/{len(rewards)}")
    logger.info(f"Responses ending properly: {sum(validation_results['ends_properly'])}/{len(rewards)}")

    # Log chat state with validation results
    log_chat_state(
@@ -218,12 +239,6 @@ def reward_format(prompts: list, completions: list, **reward_kwargs) -> list:
    return rewards

-# TODO: Implement this reward function if the project survives
-def reward_long_query(completions, **kwargs):
-    """Reward function that checks if the query is long."""
-    pass
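
The net effect of the reward_format changes above, as a minimal standalone sketch (sample content invented; the real function walks the completion's assistant messages):

    content = "<think>Need population data</think><search>population of Hanoi</search>".strip()

    # New check 1: the response must end with a closing </search> or </answer> tag.
    ends_properly = content.endswith("</search>") or content.endswith("</answer>")

    # New check 2: the last </think> must come before the first <answer>/<search> tag.
    last_think_pos = content.rfind("</think>")
    tag_pos = min(
        content.find("<answer>") if "<answer>" in content else float("inf"),
        content.find("<search>") if "<search>" in content else float("inf"),
    )
    valid_order = last_think_pos != -1 and last_think_pos < tag_pos

    print(ends_properly, valid_order)  # True True for this sample
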
def reward_retry(prompts: list, completions: list, **reward_kwargs) -> list:
    """
    Reward function that encourages optimal retry behavior.
@@ -384,6 +399,402 @@ def reward_em_chunk(prompts: list, completions: list, **reward_kwargs) -> list:
    return rewards

def tag_count_reward(prompts: list, completions: list, **reward_kwargs) -> list:
    """Reward function that checks for proper tag counts in the conversation.

    Rewards:
    - 0.1 for each proper pair of think tags in each assistant message
    - 0.5 for having exactly one pair of answer tags in the entire conversation
    - 0.1 for each proper pair of search tags

    Args:
        prompts: List of input prompts
        completions: List of completion dictionaries with messages
        **reward_kwargs: Additional reward parameters

    Returns:
        list: List of rewards between 0 and 1
    """
    rewards = []
    validation_results = {
        "think_pairs_per_msg": [],  # List of lists; each inner list has think pair counts per assistant msg
        "answer_pairs": [],  # Total answer pairs in conversation
        "search_pairs": [],  # Total search pairs in conversation
    }

    for completion in completions:
        # Get all assistant messages
        assistant_msgs = [msg["content"] for msg in completion["messages"] if msg["role"] == "assistant"]

        if not assistant_msgs:
            rewards.append(0.0)
            validation_results["think_pairs_per_msg"].append([])
            validation_results["answer_pairs"].append(0)
            validation_results["search_pairs"].append(0)
            continue

        # Count think pairs per assistant message
        think_pairs_per_msg = []
        for msg in assistant_msgs:
            # Count complete think tag pairs
            think_opens = len(re.findall(r"<think>", msg))
            think_closes = len(re.findall(r"</think>", msg))
            think_pairs = min(think_opens, think_closes)
            think_pairs_per_msg.append(think_pairs)

        # Count answer tags in the entire conversation (should be exactly one pair)
        total_answer_opens = sum(msg.count("<answer>") for msg in assistant_msgs)
        total_answer_closes = sum(msg.count("</answer>") for msg in assistant_msgs)
        answer_pairs = min(total_answer_opens, total_answer_closes)

        # Count search tags
        total_search_opens = sum(msg.count("<search>") for msg in assistant_msgs)
        total_search_closes = sum(msg.count("</search>") for msg in assistant_msgs)
        search_pairs = min(total_search_opens, total_search_closes)

        # Calculate reward components
        think_reward = sum(min(pairs, 1) * 0.1 for pairs in think_pairs_per_msg)  # 0.1 per msg with a proper think pair
        answer_reward = 0.5 if answer_pairs == 1 else 0.0  # 0.5 for exactly one answer pair
        search_reward = min(search_pairs, 1) * 0.1  # 0.1 for having search pairs

        total_reward = min(think_reward + answer_reward + search_reward, 1.0)
        rewards.append(total_reward)

        # Store validation results
        validation_results["think_pairs_per_msg"].append(think_pairs_per_msg)
        validation_results["answer_pairs"].append(answer_pairs)
        validation_results["search_pairs"].append(search_pairs)

        # Debug logging
        if total_reward < 1.0:
            logger.debug(
                f"Tag count issues - think_pairs: {think_pairs_per_msg}, "
                f"answer_pairs: {answer_pairs}, search_pairs: {search_pairs}"
            )

    # Log metrics
    logger.info(
        f"Tag count reward metrics - Mean: {np.mean(rewards):.3f}, Perfect scores: {sum(r == 1.0 for r in rewards)}/{len(rewards)}"
    )
    logger.info(
        f"Average think pairs per message: {np.mean([np.mean(pairs) if pairs else 0 for pairs in validation_results['think_pairs_per_msg']]):.2f}"
    )
    logger.info(
        f"Conversations with exactly one answer pair: {sum(pairs == 1 for pairs in validation_results['answer_pairs'])}/{len(rewards)}"
    )

    # Log chat state
    log_chat_state(
        prompts=prompts,
        completions=completions,
        rewards=rewards,
        reward_type="tag_count",
        validation_results=validation_results,
    )

    return rewards
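
A hedged usage sketch for tag_count_reward (message content invented, structure as the function expects):

    completion = {
        "messages": [
            {"role": "assistant", "content": "<think>plan the search</think><search>initial query</search>"},
            {"role": "assistant", "content": "<think>synthesize</think><answer>final answer</answer>"},
        ]
    }
    print(tag_count_reward(prompts=["question"], completions=[completion]))
    # Expected [0.8]: 0.1 + 0.1 for the two think pairs, 0.5 for exactly one
    # answer pair, and 0.1 for the search pair.
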
def reward_search_strategy(prompts: list, completions: list, **reward_kwargs) -> list:
    """Reward function that checks for good search strategy and query analysis steps.

    The expected conversation flow pattern is:
    1. Initial search: question -> assistant(think + search)
    2. Process info: information -> assistant(think + refined search)
    3. Final answer: information -> assistant(think + answer)

    Rewards:
    - Initial search (0.2): Starting with broad/overview search
    - Information processing (0.4): Analyzing provided info and refining search
    - Final synthesis (0.4): Analyzing all info and providing final answer

    Args:
        prompts: List of input prompts
        completions: List of completion dictionaries
        **reward_kwargs: Additional reward parameters

    Returns:
        list: List of rewards between 0 and 1
    """
    rewards = []
    validation_results = {
        "initial_search": [],  # First search attempt
        "info_processing": [],  # Number of info-based refinements
        "final_synthesis": [],  # Final answer with proper analysis
    }

    # Patterns for conversation flow
    think_pattern = r"<think>[^<>]+</think>"
    search_pattern = r"<search>[^<>]+</search>"
    answer_pattern = r"<answer>[^<>]+</answer>"
    info_pattern = r"<information>[^<>]+</information>"

    # Analysis patterns
    info_analysis_pattern = (
        r"<think>[^<>]*?\b(?:based|according|from|results?|found|shows?|provided|information)\b[^<>]*?</think>"
    )

    for completion in completions:
        messages = completion.get("messages", [])
        if not messages:
            rewards.append(0.0)
            for key in validation_results:
                validation_results[key].append(False)
            continue

        # Track conversation flow
        has_initial_search = False
        info_based_searches = 0
        has_final_synthesis = False

        # Track current state
        last_was_info = False
        search_after_info = 0
        analysis_after_info = 0

        for i, msg in enumerate(messages):
            content = msg["content"]
            role = msg["role"]

            if role == "assistant":
                has_think = bool(re.search(think_pattern, content))
                has_search = bool(re.search(search_pattern, content))
                has_answer = bool(re.search(answer_pattern, content))
                has_info_analysis = bool(re.search(info_analysis_pattern, content, re.IGNORECASE))

                # Check initial search (first assistant message with search)
                if not has_initial_search and has_think and has_search:
                    has_initial_search = True

                # Check info-based refinement
                if last_was_info and has_think:
                    if has_search:
                        search_after_info += 1
                    if has_info_analysis:
                        analysis_after_info += 1

                # Check final synthesis
                if has_answer and has_think and has_info_analysis:
                    has_final_synthesis = True
            elif role in ["user", "ipython"] and re.search(info_pattern, content):
                last_was_info = True
            else:
                last_was_info = False

        # Calculate rewards
        initial_reward = 0.2 if has_initial_search else 0.0

        # Info processing reward: proper analysis and search after info
        info_processing = min(search_after_info, analysis_after_info)  # Must have both analysis and search
        info_reward = min(0.4, 0.2 * info_processing)  # 0.2 per proper info-based refinement, max 0.4

        # Final synthesis reward
        synthesis_reward = 0.4 if has_final_synthesis else 0.0

        total_reward = initial_reward + info_reward + synthesis_reward
        rewards.append(total_reward)

        # Store validation results
        validation_results["initial_search"].append(has_initial_search)
        validation_results["info_processing"].append(info_processing)
        validation_results["final_synthesis"].append(has_final_synthesis)

        # Debug logging
        if total_reward < 0.6:  # Log if missing significant components
            logger.debug(
                f"Search flow issues - initial: {has_initial_search}, "
                f"info_processing: {info_processing}, "
                f"final_synthesis: {has_final_synthesis}"
            )

    # Log metrics
    logger.info(
        f"Search strategy metrics - Mean: {np.mean(rewards):.3f}, Perfect scores: {sum(r == 1.0 for r in rewards)}/{len(rewards)}"
    )
    logger.info(f"Initial searches: {sum(validation_results['initial_search'])}/{len(rewards)}")
    logger.info(f"Average info processing steps: {np.mean(validation_results['info_processing']):.2f}")
    logger.info(f"Final synthesis rate: {sum(validation_results['final_synthesis'])}/{len(rewards)}")

    # Log chat state
    log_chat_state(
        prompts=prompts,
        completions=completions,
        rewards=rewards,
        reward_type="search_strategy",
        validation_results=validation_results,
    )

    return rewards
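
A sketch of the conversation shape reward_search_strategy rewards, with invented content; one info-based refinement earns half of the 0.4 processing component:

    completion = {
        "messages": [
            {"role": "user", "content": "Who wrote the novel?"},
            {"role": "assistant", "content": "<think>Start with a broad search</think><search>novel author</search>"},
            {"role": "ipython", "content": "<information>The novel was written by Y.</information>"},
            {"role": "assistant", "content": "<think>Based on the results, confirm the year</think><search>novel publication year</search>"},
            {"role": "ipython", "content": "<information>It was published in 1867.</information>"},
            {"role": "assistant", "content": "<think>The information provided confirms Y</think><answer>Y</answer>"},
        ]
    }
    print(reward_search_strategy(prompts=["q"], completions=[completion]))
    # Expected [0.8]: 0.2 initial search + 0.2 info-based refinement + 0.4 final synthesis.
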
def reward_search_diversity(prompts: list, completions: list, **reward_kwargs) -> list:
    """Reward function that evaluates diversity of search queries in a conversation.

    Rewards higher diversity in search queries and penalizes repetitive searches.
    Uses string similarity to compare queries, with diminishing returns for
    similar queries.

    Scoring:
    - Base reward: 0.2 per unique query concept (max 0.4)
    - Diversity bonus: Up to 0.4 based on semantic diversity
    - Operator bonus: Up to 0.2 for proper use of search operators
    - Penalties:
        * Similar queries (>0.8 similarity): -0.1 per pair
        * Exact duplicates: -0.2 per duplicate

    Args:
        prompts: List of input prompts
        completions: List of completion dictionaries
        **reward_kwargs: Additional reward parameters

    Returns:
        list: List of rewards between 0 and 1
    """

    def normalize_query(query: str) -> tuple[str, list[str]]:
        """Normalize search query for comparison."""
        # Extract operators before normalization
        operators = re.findall(r'(?:site|filetype):\S+|"[^"]+"|(?:\s+OR\s+|\s+AND\s+|-\w+)', query)
        # Remove operators for base comparison
        base_query = re.sub(r'(?:site|filetype):\S+|"[^"]+"|(?:\s+OR\s+|\s+AND\s+|-\w+)', "", query.lower())
        # Remove special chars and extra spaces from base query
        base_query = re.sub(r"[^\w\s]", " ", base_query)
        return " ".join(base_query.split()), operators

    def query_similarity(q1: str, q2: str) -> float:
        """Calculate similarity between two queries."""
        # Compare normalized base queries
        base1, ops1 = normalize_query(q1)
        base2, ops2 = normalize_query(q2)

        # Base similarity from query text
        base_sim = SequenceMatcher(None, base1, base2).ratio()

        # Significantly reduce similarity if using different operators
        if ops1 != ops2:
            # More operators = more different
            unique_ops = len(set(ops1) ^ set(ops2))  # XOR to get unique operators
            base_sim *= max(0.3, 1.0 - (unique_ops * 0.2))  # Each unique operator reduces similarity by 20%

        return base_sim

    rewards = []

    for completion in completions:
        # Extract all search queries from assistant messages
        search_queries = []
        for msg in completion.get("messages", []):
            if msg["role"] == "assistant":
                # Find all search tags
                searches = re.findall(r"<search>([^<>]+)</search>", msg["content"])
                search_queries.extend(searches)

        if not search_queries:
            rewards.append(0.0)
            continue

        # Calculate diversity score
        total_queries = len(search_queries)
        if total_queries == 1:
            rewards.append(0.2)  # Base reward for single query
            continue

        # Calculate pairwise similarities and track duplicates/high similarities
        similarity_sum = 0
        pair_count = 0
        similar_pairs = 0  # Count pairs with >0.8 similarity
        exact_duplicates = 0  # Count exact matches

        # Count unique operators and track their usage
        all_operators = set()
        operator_usage = []  # Track operators per query
        for query in search_queries:
            _, ops = normalize_query(query)
            all_operators.update(ops)
            operator_usage.append(len(ops))

        # Track normalized queries to find duplicates
        seen_queries = set()
        unique_queries = []

        for i in range(total_queries):
            base_i, _ = normalize_query(search_queries[i])
            if base_i in seen_queries:
                exact_duplicates += 1
            else:
                unique_queries.append(search_queries[i])
                seen_queries.add(base_i)

            for j in range(i + 1, total_queries):
                similarity = query_similarity(search_queries[i], search_queries[j])
                similarity_sum += similarity
                pair_count += 1

                # Count highly similar pairs (ignoring operator differences)
                base_j, _ = normalize_query(search_queries[j])
                base_sim = SequenceMatcher(None, base_i, base_j).ratio()
                if 0.8 < base_sim < 1.0:  # Don't count exact duplicates twice
                    similar_pairs += 1

        # Average similarity (0-1), weighted less for operator differences
        avg_similarity = similarity_sum / pair_count if pair_count > 0 else 0

        # Calculate diversity score (1 - avg_similarity)
        diversity_score = 1 - avg_similarity

        # Calculate operator bonus (up to 0.2)
        # Reward both variety and consistent usage
        operator_variety_bonus = min(0.15, len(all_operators) * 0.05)  # Up to 0.15 for unique operators
        operator_usage_ratio = sum(1 for x in operator_usage if x > 0) / total_queries
        operator_usage_bonus = 0.05 * operator_usage_ratio  # Up to 0.05 for consistent usage
        operator_bonus = operator_variety_bonus + operator_usage_bonus

        # Calculate penalties
        # Reduce penalties when operators are different
        similarity_penalty = similar_pairs * 0.1  # Reduced penalty for similar pairs
        if len(all_operators) >= 2:  # If using multiple operators, reduce penalties
            similarity_penalty *= 0.5
        duplicate_penalty = exact_duplicates * 0.2  # Keep strong penalty for exact duplicates

        # Final reward calculation:
        # - Base reward per unique query (max 0.4)
        # - Diversity bonus (up to 0.4)
        # - Operator bonus (up to 0.2)
        # - Apply penalties
        unique_query_count = len(unique_queries)
        base_reward = min(0.4, 0.2 * unique_query_count)
        diversity_bonus = diversity_score * 0.4
        total_reward = base_reward + diversity_bonus + operator_bonus - similarity_penalty - duplicate_penalty

        # Cap at 1.0 and floor at 0.0
        reward = max(0.0, min(1.0, total_reward))

        # Debug logging
        logger.debug(
            f"Search diversity metrics - "
            f"Queries: {total_queries}, "
            f"Unique: {len(seen_queries)}, "
            f"Similar pairs: {similar_pairs}, "
            f"Duplicates: {exact_duplicates}, "
            f"Avg similarity: {avg_similarity:.2f}, "
            f"Diversity score: {diversity_score:.2f}, "
            f"Operator bonus: {operator_bonus:.2f}, "
            f"Penalties: -{similarity_penalty + duplicate_penalty:.2f}, "
            f"Final reward: {reward:.2f}"
        )

        rewards.append(reward)

    # Log overall metrics
    if rewards:
        logger.info(f"Search diversity metrics - Mean reward: {np.mean(rewards):.3f}, Max reward: {max(rewards):.3f}")

    return rewards
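
A hedged sketch of reward_search_diversity on two invented queries:

    completion = {
        "messages": [
            {"role": "assistant", "content": "<search>python GIL explained</search>"},
            {"role": "assistant", "content": '<search>site:docs.python.org threading "free threading"</search>'},
        ]
    }
    print(reward_search_diversity(prompts=["q"], completions=[completion]))
    # Two dissimilar queries earn the full 0.4 base reward, most of the 0.4
    # diversity bonus, and an operator bonus for site: plus the quoted phrase;
    # repeating a query verbatim would instead cost the -0.2 duplicate penalty.
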
def log_chat_state(prompts: list, completions: list, rewards: list, reward_type: str, **kwargs) -> None:
    """Log chat state and rewards to JSONL file.

@@ -21,7 +21,14 @@ from src.config import (
    logger,
    update_log_path,
)
-from src.rewards import build_reward_correctness_fn, reward_em_chunk, reward_retry
from src.rewards import (
    build_reward_correctness_fn,
    reward_em_chunk,
    reward_format,
    reward_retry,
    reward_search_diversity,
    reward_search_strategy,
)
from src.search_module import get_qa_dataset
from src.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter
@@ -121,6 +128,8 @@ trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(
        reward_format,
        reward_retry,
        reward_em_chunk,
        reward_search_strategy,
        reward_search_diversity,
    ],
    args=training_args,
    train_dataset=train_dataset,
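
Every entry in reward_funcs follows the same contract: take the batch prompts and completions, return one float per completion. A toy custom reward in that shape (illustrative only, not part of this commit):

    import re

    def reward_concise_answer(prompts: list, completions: list, **reward_kwargs) -> list:
        """Toy reward: 1.0 when the final answer is at most 30 words."""
        rewards = []
        for completion in completions:
            msgs = [m["content"] for m in completion.get("messages", []) if m["role"] == "assistant"]
            answers = re.findall(r"<answer>([^<>]+)</answer>", msgs[-1]) if msgs else []
            rewards.append(1.0 if answers and len(answers[0].split()) <= 30 else 0.0)
        return rewards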
