@@ -5,7 +5,7 @@ Reward functions for RL training.
import json
import re
from datetime import datetime
from pathlib import Path
from difflib import SequenceMatcher

import numpy as np
@@ -145,6 +145,7 @@ def reward_format(prompts: list, completions: list, **reward_kwargs) -> list:
        "has_search": [],
        "has_invalid_tags": [],
        "has_info_tags": [],
        "ends_properly": [],  # New validation result
    }

    for completion in completions:
@@ -159,6 +160,11 @@ def reward_format(prompts: list, completions: list, **reward_kwargs) -> list:

        content = assistant_msgs[-1]

        # Check if content ends with </search> or </answer> (ignoring whitespace)
        content_stripped = content.strip()
        ends_properly = content_stripped.endswith("</search>") or content_stripped.endswith("</answer>")
        validation_results["ends_properly"].append(ends_properly)

        has_invalid_tags = any(re.search(pattern, content) for pattern in invalid_patterns)
        validation_results["has_invalid_tags"].append(has_invalid_tags)
        if has_invalid_tags:
@@ -196,15 +202,30 @@ def reward_format(prompts: list, completions: list, **reward_kwargs) -> list:
            rewards.append(0.0)
            continue

        # Check for proper tag sequence - think must come before answer/search
        if has_answer or has_search:
            last_think_pos = content.rfind("</think>")
            answer_pos = content.find("<answer>") if has_answer else float("inf")
            search_pos = content.find("<search>") if has_search else float("inf")
            tag_pos = min(answer_pos, search_pos)
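            # e.g. for a hypothetical "<think>plan</think><search>q</search>",
            # last_think_pos is 11 and tag_pos is 19, so the check below passes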

            if last_think_pos == -1 or last_think_pos > tag_pos:
                rewards.append(0.0)
                continue

        # Only reward if format is valid AND response ends properly
        reward = 1.0 if has_think and (has_answer or has_search) and ends_properly else 0.0
        rewards.append(reward)

        if not reward:
            logger.debug(
                f"Format issues - think: {has_think}, answer: {has_answer}, search: {has_search}, ends_properly: {ends_properly}"
            )
            if search_matches:
                logger.debug(f"Number of search tags: {len(search_matches)}")

    logger.info(f"Format reward metrics - Mean: {np.mean(rewards):.3f}, Valid formats: {sum(rewards)}/{len(rewards)}")
    logger.info(f"Responses ending properly: {sum(validation_results['ends_properly'])}/{len(rewards)}")

    # Log chat state with validation results
    log_chat_state(
@@ -218,12 +239,6 @@ def reward_format(prompts: list, completions: list, **reward_kwargs) -> list:
    return rewards


def reward_retry(prompts: list, completions: list, **reward_kwargs) -> list:
    """
    Reward function that encourages optimal retry behavior.
@@ -384,6 +399,402 @@ def reward_em_chunk(prompts: list, completions: list, **reward_kwargs) -> list:
    return rewards


def tag_count_reward(prompts: list, completions: list, **reward_kwargs) -> list:
    """Reward function that checks for proper tag counts in the conversation.

    Rewards:
    - 0.1 for each proper pair of think tags in each assistant message
    - 0.5 for having exactly one pair of answer tags in entire conversation
    - 0.1 for each proper pair of search tags

    Args:
        prompts: List of input prompts
        completions: List of completion dictionaries with messages
        **reward_kwargs: Additional reward parameters

    Returns:
        list: List of rewards between 0 and 1
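
    Example (illustrative; assumes the message schema used elsewhere in this
    module and ignores logging side effects):

        completion = {"messages": [{"role": "assistant",
            "content": "<think>plan</think><search>foo</search>"}]}
        # one think pair (0.1) + one search pair (0.1), no answer pair
        # -> tag_count_reward(prompts, [completion]) would return [0.2]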
    """
    rewards = []
    validation_results = {
        "think_pairs_per_msg": [],  # List of lists, each inner list has think pair counts per assistant msg
        "answer_pairs": [],  # Total answer pairs in conversation
        "search_pairs": [],  # Total search pairs in conversation
    }

    for completion in completions:
        # Get all assistant messages
        assistant_msgs = [msg["content"] for msg in completion["messages"] if msg["role"] == "assistant"]

        if not assistant_msgs:
            rewards.append(0.0)
            validation_results["think_pairs_per_msg"].append([])
            validation_results["answer_pairs"].append(0)
            validation_results["search_pairs"].append(0)
            continue

        # Count think pairs per assistant message
        think_pairs_per_msg = []
        for msg in assistant_msgs:
            # Count complete think tag pairs
            think_opens = len(re.findall(r"<think>", msg))
            think_closes = len(re.findall(r"</think>", msg))
            think_pairs = min(think_opens, think_closes)
            think_pairs_per_msg.append(think_pairs)
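            # e.g. "<think>a</think><think>b" has 2 opens but only 1 close,
            # so just the 1 complete pair counts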

        # Count answer tags in entire conversation (should be exactly one pair)
        total_answer_opens = sum(msg.count("<answer>") for msg in assistant_msgs)
        total_answer_closes = sum(msg.count("</answer>") for msg in assistant_msgs)
        answer_pairs = min(total_answer_opens, total_answer_closes)

        # Count search tags
        total_search_opens = sum(msg.count("<search>") for msg in assistant_msgs)
        total_search_closes = sum(msg.count("</search>") for msg in assistant_msgs)
        search_pairs = min(total_search_opens, total_search_closes)

        # Calculate reward components
        think_reward = sum(min(pairs, 1) * 0.1 for pairs in think_pairs_per_msg)  # 0.1 per msg with proper think pair
        answer_reward = 0.5 if answer_pairs == 1 else 0.0  # 0.5 for exactly one answer pair
        search_reward = min(search_pairs, 1) * 0.1  # 0.1 for having search pairs
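        # e.g. think pairs in 2 messages (0.2) + exactly one answer pair (0.5)
        # + at least one search pair (0.1) -> total_reward of 0.8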

        total_reward = min(think_reward + answer_reward + search_reward, 1.0)
        rewards.append(total_reward)

        # Store validation results
        validation_results["think_pairs_per_msg"].append(think_pairs_per_msg)
        validation_results["answer_pairs"].append(answer_pairs)
        validation_results["search_pairs"].append(search_pairs)

        # Debug logging
        if total_reward < 1.0:
            logger.debug(
                f"Tag count issues - think_pairs: {think_pairs_per_msg}, "
                f"answer_pairs: {answer_pairs}, search_pairs: {search_pairs}"
            )

    # Log metrics
    logger.info(
        f"Tag count reward metrics - Mean: {np.mean(rewards):.3f}, Perfect scores: {sum(r == 1.0 for r in rewards)}/{len(rewards)}"
    )
    logger.info(
        f"Average think pairs per message: {np.mean([np.mean(pairs) if pairs else 0 for pairs in validation_results['think_pairs_per_msg']]):.2f}"
    )
    logger.info(
        f"Conversations with exactly one answer pair: {sum(pairs == 1 for pairs in validation_results['answer_pairs'])}/{len(rewards)}"
    )

    # Log chat state
    log_chat_state(
        prompts=prompts,
        completions=completions,
        rewards=rewards,
        reward_type="tag_count",
        validation_results=validation_results,
    )

    return rewards


def reward_search_strategy(prompts: list, completions: list, **reward_kwargs) -> list:
    """Reward function that checks for good search strategy and query analysis steps.

    The expected conversation flow pattern is:
    1. Initial search: question -> assistant(think + search)
    2. Process info: information -> assistant(think + refined search)
    3. Final answer: information -> assistant(think + answer)

    Rewards:
    - Initial search (0.2): Starting with broad/overview search
    - Information processing (0.4): Analyzing provided info and refining search
    - Final synthesis (0.4): Analyzing all info and providing final answer

    Args:
        prompts: List of input prompts
        completions: List of completion dictionaries
        **reward_kwargs: Additional reward parameters

    Returns:
        list: List of rewards between 0 and 1
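
    Example flow (illustrative; assumes the tag conventions above): an
    assistant turn of "<think>Based on the results ...</think><search>refined
    query</search>" directly after an <information> message counts as one
    info-based refinement worth 0.2, up to the 0.4 cap.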
    """
    rewards = []
    validation_results = {
        "initial_search": [],  # First search attempt
        "info_processing": [],  # Number of info-based refinements
        "final_synthesis": [],  # Final answer with proper analysis
    }

    # Patterns for conversation flow
    think_pattern = r"<think>[^<>]+</think>"
    search_pattern = r"<search>[^<>]+</search>"
    answer_pattern = r"<answer>[^<>]+</answer>"
    info_pattern = r"<information>[^<>]+</information>"

    # Analysis patterns
    info_analysis_pattern = (
        r"<think>[^<>]*?\b(?:based|according|from|results?|found|shows?|provided|information)\b[^<>]*?</think>"
    )
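    # e.g. matches "<think>Based on the search results, I need more</think>"
    # when searched with re.IGNORECASE, as done below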

    for completion in completions:
        messages = completion.get("messages", [])
        if not messages:
            rewards.append(0.0)
            for key in validation_results:
                validation_results[key].append(False)
            continue

        # Track conversation flow
        has_initial_search = False
        has_final_synthesis = False

        # Track current state
        last_was_info = False
        search_after_info = 0
        analysis_after_info = 0

        for msg in messages:
            content = msg["content"]
            role = msg["role"]

            if role == "assistant":
                has_think = bool(re.search(think_pattern, content))
                has_search = bool(re.search(search_pattern, content))
                has_answer = bool(re.search(answer_pattern, content))
                has_info_analysis = bool(re.search(info_analysis_pattern, content, re.IGNORECASE))

                # Check initial search (first assistant message with search)
                if not has_initial_search and has_think and has_search:
                    has_initial_search = True

                # Check info-based refinement
                if last_was_info and has_think:
                    if has_search:
                        search_after_info += 1
                    if has_info_analysis:
                        analysis_after_info += 1

                # Check final synthesis
                if has_answer and has_think and has_info_analysis:
                    has_final_synthesis = True

            elif role in ["user", "ipython"] and re.search(info_pattern, content):
                last_was_info = True
            else:
                last_was_info = False

        # Calculate rewards
        initial_reward = 0.2 if has_initial_search else 0.0

        # Info processing reward: proper analysis and search after info
        info_processing = min(search_after_info, analysis_after_info)  # Must have both analysis and search
        info_reward = min(0.4, 0.2 * info_processing)  # 0.2 per proper info-based refinement, max 0.4
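        # e.g. 3 proper refinements still yield min(0.4, 0.6) = 0.4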

        # Final synthesis reward
        synthesis_reward = 0.4 if has_final_synthesis else 0.0

        total_reward = initial_reward + info_reward + synthesis_reward
        rewards.append(total_reward)

        # Store validation results
        validation_results["initial_search"].append(has_initial_search)
        validation_results["info_processing"].append(info_processing)
        validation_results["final_synthesis"].append(has_final_synthesis)

        # Debug logging
        if total_reward < 0.6:  # Log if missing significant components
            logger.debug(
                f"Search flow issues - initial: {has_initial_search}, "
                f"info_processing: {info_processing}, "
                f"final_synthesis: {has_final_synthesis}"
            )

    # Log metrics
    logger.info(
        f"Search strategy metrics - Mean: {np.mean(rewards):.3f}, Perfect scores: {sum(r == 1.0 for r in rewards)}/{len(rewards)}"
    )
    logger.info(f"Initial searches: {sum(validation_results['initial_search'])}/{len(rewards)}")
    logger.info(f"Average info processing steps: {np.mean(validation_results['info_processing']):.2f}")
    logger.info(f"Final synthesis rate: {sum(validation_results['final_synthesis'])}/{len(rewards)}")

    # Log chat state
    log_chat_state(
        prompts=prompts,
        completions=completions,
        rewards=rewards,
        reward_type="search_strategy",
        validation_results=validation_results,
    )

    return rewards


def reward_search_diversity(prompts: list, completions: list, **reward_kwargs) -> list:
    """Reward function that evaluates diversity of search queries in a conversation.

    Rewards higher diversity in search queries and penalizes repetitive searches.
    Uses string similarity to compare queries, with diminishing returns for
    similar queries.

    Scoring:
    - Base reward: 0.2 per unique query concept (max 0.4)
    - Diversity bonus: Up to 0.4 based on semantic diversity
    - Operator bonus: Up to 0.2 for proper use of search operators
    - Penalties:
        * Similar queries (>0.8 similarity): -0.1 per pair
        * Exact duplicates: -0.2 per duplicate

    Args:
        prompts: List of input prompts
        completions: List of completion dictionaries
        **reward_kwargs: Additional reward parameters

    Returns:
        list: List of rewards between 0 and 1
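
    Example (illustrative; hypothetical queries): three distinct queries, one
    using a "site:" operator, earn the full base reward (0.4) plus diversity
    and operator bonuses, while repeating a query verbatim costs 0.2 each time.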
    """

    def normalize_query(query: str) -> tuple[str, list[str]]:
        """Normalize search query for comparison."""
        # Extract operators before normalization
        operators = re.findall(r'(?:site|filetype):\S+|"[^"]+"|(?:\s+OR\s+|\s+AND\s+|-\w+)', query)
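        # e.g. the hypothetical query 'site:python.org "exact phrase" -tutorial pandas'
        # yields operators ['site:python.org', '"exact phrase"', '-tutorial']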
        # Remove operators for base comparison
        base_query = re.sub(r'(?:site|filetype):\S+|"[^"]+"|(?:\s+OR\s+|\s+AND\s+|-\w+)', "", query.lower())
        # Remove special chars and extra spaces from base query
        base_query = re.sub(r"[^\w\s]", " ", base_query)
        return " ".join(base_query.split()), operators

    def query_similarity(q1: str, q2: str) -> float:
        """Calculate similarity between two queries."""
        # Compare normalized base queries
        base1, ops1 = normalize_query(q1)
        base2, ops2 = normalize_query(q2)

        # Base similarity from query text
        base_sim = SequenceMatcher(None, base1, base2).ratio()

        # Significantly reduce similarity if using different operators
        if ops1 != ops2:
            # More operators = more different
            unique_ops = len(set(ops1) ^ set(ops2))  # XOR to get unique operators
            base_sim *= max(0.3, 1.0 - (unique_ops * 0.2))  # Each unique operator reduces similarity by 20%
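            # e.g. ops1 = ['site:a.com'] vs ops2 = ['"exact"'] differ by 2
            # unique operators, so base_sim is scaled by max(0.3, 1 - 0.4) = 0.6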

        return base_sim

    rewards = []

    for completion in completions:
        # Extract all search queries from assistant messages
        search_queries = []
        for msg in completion.get("messages", []):
            if msg["role"] == "assistant":
                # Find all search tags
                searches = re.findall(r"<search>([^<>]+)</search>", msg["content"])
                search_queries.extend(searches)

        if not search_queries:
            rewards.append(0.0)
            continue

        # Calculate diversity score
        total_queries = len(search_queries)
        if total_queries == 1:
            rewards.append(0.2)  # Base reward for single query
            continue

        # Calculate pairwise similarities and track duplicates/high similarities
        similarity_sum = 0
        pair_count = 0
        similar_pairs = 0  # Count pairs with >0.8 similarity
        exact_duplicates = 0  # Count exact matches

        # Count unique operators and track their usage
        all_operators = set()
        operator_usage = []  # Track operators per query
        for query in search_queries:
            _, ops = normalize_query(query)
            all_operators.update(ops)
            operator_usage.append(len(ops))

        # Track normalized queries to find duplicates
        seen_queries = set()
        unique_queries = []

        for i in range(total_queries):
            base_i, _ = normalize_query(search_queries[i])
            if base_i in seen_queries:
                exact_duplicates += 1
            else:
                unique_queries.append(search_queries[i])
                seen_queries.add(base_i)

            for j in range(i + 1, total_queries):
                similarity = query_similarity(search_queries[i], search_queries[j])
                similarity_sum += similarity
                pair_count += 1

                # Count highly similar pairs (ignoring operator differences)
                base_i, _ = normalize_query(search_queries[i])
                base_j, _ = normalize_query(search_queries[j])
                base_sim = SequenceMatcher(None, base_i, base_j).ratio()
                if 0.8 < base_sim < 1.0:  # Don't count exact duplicates twice
                    similar_pairs += 1

        # Average similarity (0-1), weighted less for operator differences
        avg_similarity = similarity_sum / pair_count if pair_count > 0 else 0

        # Calculate diversity score (1 - avg_similarity)
        diversity_score = 1 - avg_similarity

        # Calculate operator bonus (up to 0.2)
        # Reward both variety and consistent usage
        operator_variety_bonus = min(0.15, len(all_operators) * 0.05)  # Up to 0.15 for unique operators
        operator_usage_ratio = sum(1 for x in operator_usage if x > 0) / total_queries
        operator_usage_bonus = 0.05 * operator_usage_ratio  # Up to 0.05 for consistent usage
        operator_bonus = operator_variety_bonus + operator_usage_bonus

        # Calculate penalties
        # Reduce penalties when operators are different
        similarity_penalty = similar_pairs * 0.1  # Reduced penalty for similar pairs
        if len(all_operators) >= 2:  # If using multiple operators, reduce penalties
            similarity_penalty *= 0.5

        duplicate_penalty = exact_duplicates * 0.2  # Keep strong penalty for exact duplicates

        # Final reward calculation:
        # - Base reward per unique query (max 0.4)
        # - Diversity bonus (up to 0.4)
        # - Operator bonus (up to 0.2)
        # - Apply penalties
        unique_query_count = len(unique_queries)
        base_reward = min(0.4, 0.2 * unique_query_count)
        diversity_bonus = diversity_score * 0.4
        total_reward = base_reward + diversity_bonus + operator_bonus - similarity_penalty - duplicate_penalty
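        # e.g. 3 unique queries (0.4) + diversity_score 0.75 (bonus 0.3)
        # + operator bonus 0.1 - one similar pair (0.1) -> 0.7 before clamping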

        # Cap at 1.0 and floor at 0.0
        reward = max(0.0, min(1.0, total_reward))

        # Debug logging
        logger.debug(
            f"Search diversity metrics - "
            f"Queries: {total_queries}, "
            f"Unique: {len(seen_queries)}, "
            f"Similar pairs: {similar_pairs}, "
            f"Duplicates: {exact_duplicates}, "
            f"Avg similarity: {avg_similarity:.2f}, "
            f"Diversity score: {diversity_score:.2f}, "
            f"Operator bonus: {operator_bonus:.2f}, "
            f"Penalties: -{similarity_penalty + duplicate_penalty:.2f}, "
            f"Final reward: {reward:.2f}"
        )

        rewards.append(reward)

    # Log overall metrics
    if rewards:
        logger.info(f"Search diversity metrics - Mean reward: {np.mean(rewards):.3f}, Max reward: {max(rewards):.3f}")

    return rewards


def log_chat_state(prompts: list, completions: list, rewards: list, reward_type: str, **kwargs) -> None:
    """Log chat state and rewards to JSONL file.
