diff --git a/src/rewards.py b/src/rewards.py
index 1f11879..2df5505 100644
--- a/src/rewards.py
+++ b/src/rewards.py
@@ -5,7 +5,7 @@ Reward functions for RL training.
 import json
 import re
 from datetime import datetime
-from pathlib import Path
+from difflib import SequenceMatcher
 
 import numpy as np
 
@@ -145,6 +145,7 @@ def reward_format(prompts: list, completions: list, **reward_kwargs) -> list:
         "has_search": [],
         "has_invalid_tags": [],
         "has_info_tags": [],
+        "ends_properly": [],  # New validation result
     }
 
     for completion in completions:
@@ -159,6 +160,11 @@ def reward_format(prompts: list, completions: list, **reward_kwargs) -> list:
 
         content = assistant_msgs[-1]
 
+        # Check if content ends with </search> or </answer> (ignoring whitespace)
+        content_stripped = content.strip()
+        ends_properly = content_stripped.endswith("</search>") or content_stripped.endswith("</answer>")
+        validation_results["ends_properly"].append(ends_properly)
+
         has_invalid_tags = any(re.search(pattern, content) for pattern in invalid_patterns)
         validation_results["has_invalid_tags"].append(has_invalid_tags)
         if has_invalid_tags:
@@ -196,15 +202,30 @@ def reward_format(prompts: list, completions: list, **reward_kwargs) -> list:
             rewards.append(0.0)
             continue
 
-        reward = 1.0 if has_think and (has_answer or has_search) else 0.0
+        # Check for proper tag sequence - think must come before answer/search
+        if has_answer or has_search:
+            last_think_pos = content.rfind("</think>")
+            answer_pos = content.find("<answer>") if has_answer else float("inf")
+            search_pos = content.find("<search>") if has_search else float("inf")
+            tag_pos = min(answer_pos, search_pos)
+
+            if last_think_pos == -1 or last_think_pos > tag_pos:
+                rewards.append(0.0)
+                continue
+
+        # Only reward if format is valid AND response ends properly
+        reward = 1.0 if has_think and (has_answer or has_search) and ends_properly else 0.0
         rewards.append(reward)
 
         if not reward:
-            logger.debug(f"Format issues - think: {has_think}, answer: {has_answer}, search: {has_search}")
+            logger.debug(
+                f"Format issues - think: {has_think}, answer: {has_answer}, search: {has_search}, ends_properly: {ends_properly}"
+            )
             if search_matches:
                 logger.debug(f"Number of search tags: {len(search_matches)}")
 
     logger.info(f"Format reward metrics - Mean: {np.mean(rewards):.3f}, Valid formats: {sum(rewards)}/{len(rewards)}")
+    logger.info(f"Responses ending properly: {sum(validation_results['ends_properly'])}/{len(rewards)}")
 
     # Log chat state with validation results
     log_chat_state(
@@ -218,12 +239,6 @@ def reward_format(prompts: list, completions: list, **reward_kwargs) -> list:
     return rewards
 
 
-# TODO: Implement this reward function if the project survives
-def reward_long_query(completions, **kwargs):
-    """Reward function that checks if the query is long."""
-    pass
-
-
 def reward_retry(prompts: list, completions: list, **reward_kwargs) -> list:
     """
     Reward function that encourages optimal retry behavior.
@@ -384,6 +399,402 @@ def reward_em_chunk(prompts: list, completions: list, **reward_kwargs) -> list:
     return rewards
 
 
+def tag_count_reward(prompts: list, completions: list, **reward_kwargs) -> list:
+    """Reward function that checks for proper tag counts in the conversation.
+
+    Rewards:
+    - 0.1 for each assistant message that contains a complete pair of think tags
+    - 0.5 for having exactly one pair of answer tags in the entire conversation
+    - 0.1 for having at least one complete pair of search tags
+
+    Args:
+        prompts: List of input prompts
+        completions: List of completion dictionaries with messages
+        **reward_kwargs: Additional reward parameters
+
+    Returns:
+        list: List of rewards between 0 and 1
+    """
+    rewards = []
+    validation_results = {
+        "think_pairs_per_msg": [],  # List of lists, each inner list has think pair counts per assistant msg
+        "answer_pairs": [],  # Total answer pairs in conversation
+        "search_pairs": [],  # Total search pairs in conversation
+    }
+
+    for completion in completions:
+        # Get all assistant messages
+        assistant_msgs = [msg["content"] for msg in completion["messages"] if msg["role"] == "assistant"]
+
+        if not assistant_msgs:
+            rewards.append(0.0)
+            validation_results["think_pairs_per_msg"].append([])
+            validation_results["answer_pairs"].append(0)
+            validation_results["search_pairs"].append(0)
+            continue
+
+        # Count think pairs per assistant message
+        think_pairs_per_msg = []
+        for msg in assistant_msgs:
+            # Count complete think tag pairs
+            think_opens = len(re.findall(r"<think>", msg))
+            think_closes = len(re.findall(r"</think>", msg))
+            think_pairs = min(think_opens, think_closes)
+            think_pairs_per_msg.append(think_pairs)
+
+        # Count answer tags in entire conversation (should be exactly one pair)
+        total_answer_opens = sum(msg.count("<answer>") for msg in assistant_msgs)
+        total_answer_closes = sum(msg.count("</answer>") for msg in assistant_msgs)
+        answer_pairs = min(total_answer_opens, total_answer_closes)
+
+        # Count search tags
+        total_search_opens = sum(msg.count("<search>") for msg in assistant_msgs)
+        total_search_closes = sum(msg.count("</search>") for msg in assistant_msgs)
+        search_pairs = min(total_search_opens, total_search_closes)
+
+        # Calculate reward components
+        think_reward = sum(min(pairs, 1) * 0.1 for pairs in think_pairs_per_msg)  # 0.1 per msg with proper think pair
+        answer_reward = 0.5 if answer_pairs == 1 else 0.0  # 0.5 for exactly one answer pair
+        search_reward = min(search_pairs, 1) * 0.1  # 0.1 for having search pairs
+
+        total_reward = min(think_reward + answer_reward + search_reward, 1.0)
+        rewards.append(total_reward)
+
+        # Store validation results
+        validation_results["think_pairs_per_msg"].append(think_pairs_per_msg)
+        validation_results["answer_pairs"].append(answer_pairs)
+        validation_results["search_pairs"].append(search_pairs)
+
+        # Debug logging
+        if total_reward < 1.0:
+            logger.debug(
+                f"Tag count issues - think_pairs: {think_pairs_per_msg}, "
+                f"answer_pairs: {answer_pairs}, search_pairs: {search_pairs}"
+            )
+
+    # Log metrics
+    logger.info(
+        f"Tag count reward metrics - Mean: {np.mean(rewards):.3f}, Perfect scores: {sum(r == 1.0 for r in rewards)}/{len(rewards)}"
+    )
+    logger.info(
+        f"Average think pairs per message: {np.mean([np.mean(pairs) if pairs else 0 for pairs in validation_results['think_pairs_per_msg']]):.2f}"
+    )
+    logger.info(
+        f"Conversations with exactly one answer pair: {sum(pairs == 1 for pairs in validation_results['answer_pairs'])}/{len(rewards)}"
+    )
+
+    # Log chat state
+    log_chat_state(
+        prompts=prompts,
+        completions=completions,
+        rewards=rewards,
+        reward_type="tag_count",
+        validation_results=validation_results,
+    )
+
+    return rewards
+
+
+def reward_search_strategy(prompts: list, completions: list, **reward_kwargs) -> list:
+    """Reward function that checks for good search strategy and query analysis steps.
+
+    The expected conversation flow pattern is:
+    1. Initial search: question -> assistant(think + search)
+    2. Process info: information -> assistant(think + refined search)
+    3. Final answer: information -> assistant(think + answer)
+
+    Rewards:
+    - Initial search (0.2): Starting with broad/overview search
+    - Information processing (0.4): Analyzing provided info and refining search
+    - Final synthesis (0.4): Analyzing all info and providing final answer
+
+    Args:
+        prompts: List of input prompts
+        completions: List of completion dictionaries
+        **reward_kwargs: Additional reward parameters
+
+    Returns:
+        list: List of rewards between 0 and 1
+    """
+    rewards = []
+    validation_results = {
+        "initial_search": [],  # First search attempt
+        "info_processing": [],  # Number of info-based refinements
+        "final_synthesis": [],  # Final answer with proper analysis
+    }
+
+    # Patterns for conversation flow
+    think_pattern = r"<think>[^<>]+</think>"
+    search_pattern = r"<search>[^<>]+</search>"
+    answer_pattern = r"<answer>[^<>]+</answer>"
+    info_pattern = r"<information>[^<>]+</information>"
+
+    # Analysis patterns
+    info_analysis_pattern = (
+        r"<think>[^<>]*?\b(?:based|according|from|results?|found|shows?|provided|information)\b[^<>]*?</think>"
+    )
+
+    for completion in completions:
+        messages = completion.get("messages", [])
+        if not messages:
+            rewards.append(0.0)
+            for key in validation_results:
+                validation_results[key].append(False)
+            continue
+
+        # Track conversation flow
+        has_initial_search = False
+        info_based_searches = 0
+        has_final_synthesis = False
+
+        # Track current state
+        last_was_info = False
+        search_after_info = 0
+        analysis_after_info = 0
+
+        for i, msg in enumerate(messages):
+            content = msg["content"]
+            role = msg["role"]
+
+            if role == "assistant":
+                has_think = bool(re.search(think_pattern, content))
+                has_search = bool(re.search(search_pattern, content))
+                has_answer = bool(re.search(answer_pattern, content))
+                has_info_analysis = bool(re.search(info_analysis_pattern, content, re.IGNORECASE))
+
+                # Check initial search (first assistant message with search)
+                if not has_initial_search and has_think and has_search:
+                    has_initial_search = True
+
+                # Check info-based refinement
+                if last_was_info and has_think:
+                    if has_search:
+                        search_after_info += 1
+                    if has_info_analysis:
+                        analysis_after_info += 1
+
+                # Check final synthesis
+                if has_answer and has_think and has_info_analysis:
+                    has_final_synthesis = True
+
+            elif role in ["user", "ipython"] and re.search(info_pattern, content):
+                last_was_info = True
+            else:
+                last_was_info = False
+
+        # Calculate rewards
+        initial_reward = 0.2 if has_initial_search else 0.0
+
+        # Info processing reward: proper analysis and search after info
+        info_processing = min(search_after_info, analysis_after_info)  # Must have both analysis and search
+        info_reward = min(0.4, 0.2 * info_processing)  # 0.2 per proper info-based refinement, max 0.4
+
+        # Final synthesis reward
+        synthesis_reward = 0.4 if has_final_synthesis else 0.0
+
+        total_reward = initial_reward + info_reward + synthesis_reward
+        rewards.append(total_reward)
+
+        # Store validation results
+        validation_results["initial_search"].append(has_initial_search)
+        validation_results["info_processing"].append(info_processing)
+        validation_results["final_synthesis"].append(has_final_synthesis)
+
+        # Debug logging
+        if total_reward < 0.6:  # Log if missing significant components
+            logger.debug(
+                f"Search flow issues - initial: {has_initial_search}, "
+                f"info_processing: {info_processing}, "
+                f"final_synthesis: {has_final_synthesis}"
+            )
+
+    # Log metrics
+    logger.info(
+        f"Search strategy metrics - Mean: {np.mean(rewards):.3f}, Perfect scores: {sum(r == 1.0 for r in rewards)}/{len(rewards)}"
+    )
+    logger.info(f"Initial searches: {sum(validation_results['initial_search'])}/{len(rewards)}")
+    logger.info(f"Average info processing steps: {np.mean([r for r in validation_results['info_processing']]):.2f}")
+    logger.info(f"Final synthesis rate: {sum(validation_results['final_synthesis'])}/{len(rewards)}")
+
+    # Log chat state
+    log_chat_state(
+        prompts=prompts,
+        completions=completions,
+        rewards=rewards,
+        reward_type="search_strategy",
+        validation_results=validation_results,
+    )
+
+    return rewards
+
+
+def reward_search_diversity(prompts: list, completions: list, **reward_kwargs) -> list:
+    """Reward function that evaluates diversity of search queries in a conversation.
+
+    Rewards higher diversity in search queries and penalizes repetitive searches.
+    Uses string similarity to compare queries, with diminishing returns for
+    similar queries.
+
+    Scoring:
+    - Base reward: 0.2 per unique query concept (max 0.4)
+    - Diversity bonus: Up to 0.4 based on semantic diversity
+    - Operator bonus: Up to 0.2 for proper use of search operators
+    - Penalties:
+        * Similar queries (>0.8 similarity): -0.1 per pair
+        * Exact duplicates: -0.2 per duplicate
+
+    Args:
+        prompts: List of input prompts
+        completions: List of completion dictionaries
+        **reward_kwargs: Additional reward parameters
+
+    Returns:
+        list: List of rewards between 0 and 1
+    """
+
+    def normalize_query(query: str) -> tuple[str, list[str]]:
+        """Normalize search query for comparison."""
+        # Extract operators before normalization
+        operators = re.findall(r'(?:site|filetype):\S+|"[^"]+"|(?:\s+OR\s+|\s+AND\s+|-\w+)', query)
+        # Remove operators for base comparison
+        base_query = re.sub(r'(?:site|filetype):\S+|"[^"]+"|(?:\s+OR\s+|\s+AND\s+|-\w+)', "", query.lower())
+        # Remove special chars and extra spaces from base query
+        base_query = re.sub(r"[^\w\s]", " ", base_query)
+        return " ".join(base_query.split()), operators
+
+    def query_similarity(q1: str, q2: str) -> float:
+        """Calculate similarity between two queries."""
+        # Compare normalized base queries
+        base1, ops1 = normalize_query(q1)
+        base2, ops2 = normalize_query(q2)
+
+        # Base similarity from query text
+        base_sim = SequenceMatcher(None, base1, base2).ratio()
+
+        # Significantly reduce similarity if using different operators
+        if ops1 != ops2:
+            # More operators = more different
+            unique_ops = len(set(ops1) ^ set(ops2))  # XOR to get unique operators
+            base_sim *= max(0.3, 1.0 - (unique_ops * 0.2))  # Each unique operator reduces similarity by 20%
+
+        return base_sim
+
+    rewards = []
+
+    for completion in completions:
+        # Extract all search queries from assistant messages
+        search_queries = []
+        for msg in completion.get("messages", []):
+            if msg["role"] == "assistant":
+                # Find all search tags
+                searches = re.findall(r"<search>([^<>]+)</search>", msg["content"])
+                search_queries.extend(searches)
+
+        if not search_queries:
+            rewards.append(0.0)
+            continue
+
+        # Calculate diversity score
+        total_queries = len(search_queries)
+        if total_queries == 1:
+            rewards.append(0.2)  # Base reward for single query
+            continue
+
+        # Calculate pairwise similarities and track duplicates/high similarities
+        similarity_sum = 0
+        pair_count = 0
+        similar_pairs = 0  # Count pairs with >0.8 similarity
+        exact_duplicates = 0  # Count exact matches
+
+        # Count unique operators and track their usage
+        all_operators = set()
+        operator_usage = []  # Track operators per query
+        for query in search_queries:
+            _, ops = normalize_query(query)
+            all_operators.update(ops)
+            operator_usage.append(len(ops))
+
+        # Track normalized queries to find duplicates
+        seen_queries = set()
+        unique_queries = []
+
+        for i in range(total_queries):
+            base_i, _ = normalize_query(search_queries[i])
+            if base_i in seen_queries:
+                exact_duplicates += 1
+            else:
+                unique_queries.append(search_queries[i])
+                seen_queries.add(base_i)
+
+            for j in range(i + 1, total_queries):
+                similarity = query_similarity(search_queries[i], search_queries[j])
+                similarity_sum += similarity
+                pair_count += 1
+
+                # Count highly similar pairs (ignoring operator differences)
+                base_i, _ = normalize_query(search_queries[i])
+                base_j, _ = normalize_query(search_queries[j])
+                base_sim = SequenceMatcher(None, base_i, base_j).ratio()
+                if base_sim > 0.8 and base_sim < 1.0:  # Don't count exact duplicates twice
+                    similar_pairs += 1
+
+        # Average similarity (0-1), weighted less for operator differences
+        avg_similarity = similarity_sum / pair_count if pair_count > 0 else 0
+
+        # Calculate diversity score (1 - avg_similarity)
+        diversity_score = 1 - avg_similarity
+
+        # Calculate operator bonus (up to 0.2)
+        # Reward both variety and consistent usage
+        operator_variety_bonus = min(0.15, len(all_operators) * 0.05)  # Up to 0.15 for unique operators
+        operator_usage_ratio = sum(1 for x in operator_usage if x > 0) / total_queries
+        operator_usage_bonus = 0.05 * operator_usage_ratio  # Up to 0.05 for consistent usage
+        operator_bonus = operator_variety_bonus + operator_usage_bonus
+
+        # Calculate penalties
+        # Reduce penalties when operators are different
+        similarity_penalty = similar_pairs * 0.1  # Reduced penalty for similar pairs
+        if len(all_operators) >= 2:  # If using multiple operators, reduce penalties
+            similarity_penalty *= 0.5
+
+        duplicate_penalty = exact_duplicates * 0.2  # Keep strong penalty for exact duplicates
+
+        # Final reward calculation:
+        # - Base reward per unique query (max 0.4)
+        # - Diversity bonus (up to 0.4)
+        # - Operator bonus (up to 0.2)
+        # - Apply penalties
+        unique_query_count = len(unique_queries)
+        base_reward = min(0.4, 0.2 * unique_query_count)
+        diversity_bonus = diversity_score * 0.4
+        total_reward = base_reward + diversity_bonus + operator_bonus - similarity_penalty - duplicate_penalty
+
+        # Cap at 1.0 and floor at 0.0
+        reward = max(0.0, min(1.0, total_reward))
+
+        # Debug logging
+        logger.debug(
+            f"Search diversity metrics - "
+            f"Queries: {total_queries}, "
+            f"Unique: {len(seen_queries)}, "
+            f"Similar pairs: {similar_pairs}, "
+            f"Duplicates: {exact_duplicates}, "
+            f"Avg similarity: {avg_similarity:.2f}, "
+            f"Diversity score: {diversity_score:.2f}, "
+            f"Operator bonus: {operator_bonus:.2f}, "
+            f"Penalties: -{similarity_penalty + duplicate_penalty:.2f}, "
+            f"Final reward: {reward:.2f}"
+        )
+
+        rewards.append(reward)
+
+    # Log overall metrics
+    if rewards:
+        logger.info(f"Search diversity metrics - Mean reward: {np.mean(rewards):.3f}, Max reward: {max(rewards):.3f}")
+
+    return rewards
+
+
 def log_chat_state(prompts: list, completions: list, rewards: list, reward_type: str, **kwargs) -> None:
     """Log chat state and rewards to JSONL file.
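Reviewer note, not part of the patch: a minimal, self-contained sketch of the scoring rule that `tag_count_reward` applies to a single conversation, mirroring the arithmetic in the hunk above (0.1 per assistant message with a complete `<think>` pair, 0.5 for exactly one `<answer>` pair across the conversation, 0.1 for at least one `<search>` pair, capped at 1.0). The messages are invented for illustration.

```python
import re

# Hypothetical two-turn conversation in the shape the reward functions expect.
assistant_msgs = [
    "<think>Need background first.</think><search>transformer positional encoding</search>",
    "<think>The results answer the question.</think><answer>Rotary embeddings.</answer>",
]

# 0.1 per assistant message containing at least one complete <think>...</think> pair.
think_reward = sum(
    0.1 * min(min(len(re.findall(r"<think>", m)), len(re.findall(r"</think>", m))), 1)
    for m in assistant_msgs
)

# 0.5 only if the whole conversation contains exactly one <answer>...</answer> pair.
answer_pairs = min(
    sum(m.count("<answer>") for m in assistant_msgs),
    sum(m.count("</answer>") for m in assistant_msgs),
)
answer_reward = 0.5 if answer_pairs == 1 else 0.0

# 0.1 if at least one complete <search>...</search> pair appears anywhere.
search_pairs = min(
    sum(m.count("<search>") for m in assistant_msgs),
    sum(m.count("</search>") for m in assistant_msgs),
)
search_reward = 0.1 * min(search_pairs, 1)

total = min(think_reward + answer_reward + search_reward, 1.0)
print(round(total, 2))  # 0.2 + 0.5 + 0.1 -> 0.8
```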
diff --git a/train_grpo.py b/train_grpo.py
index 4b79d2e..44b750b 100644
--- a/train_grpo.py
+++ b/train_grpo.py
@@ -21,7 +21,14 @@ from src.config import (
     logger,
     update_log_path,
 )
-from src.rewards import build_reward_correctness_fn, reward_em_chunk, reward_retry
+from src.rewards import (
+    build_reward_correctness_fn,
+    reward_em_chunk,
+    reward_format,
+    reward_retry,
+    reward_search_diversity,
+    reward_search_strategy,
+)
 from src.search_module import get_qa_dataset
 from src.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter
 
@@ -121,6 +128,8 @@ trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(
         reward_format,
         reward_retry,
         reward_em_chunk,
+        reward_search_strategy,
+        reward_search_diversity,
     ],
     args=training_args,
     train_dataset=train_dataset,
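Reviewer note, not part of the patch: the new `difflib` import feeds the `SequenceMatcher` comparison in `reward_search_diversity`. A self-contained illustration with invented queries: near-duplicate queries score close to 1.0 (which is what triggers the similarity penalty), while genuinely different queries score much lower.

```python
from difflib import SequenceMatcher

# Invented queries: the first two are near-duplicates, the third is distinct.
queries = [
    "transformer positional encoding",
    "transformer positional encodings",
    "rotary embeddings paper",
]

for i in range(len(queries)):
    for j in range(i + 1, len(queries)):
        sim = SequenceMatcher(None, queries[i], queries[j]).ratio()
        print(f"{queries[i]!r} vs {queries[j]!r}: {sim:.2f}")
```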