diff --git a/src/rewards.py b/src/rewards.py
index 2df5505..3587f3c 100644
--- a/src/rewards.py
+++ b/src/rewards.py
@@ -244,6 +244,7 @@ def reward_retry(prompts: list, completions: list, **reward_kwargs) -> list:
     Reward function that encourages optimal retry behavior.
     Rewards increase with more search attempts but caps at optimal_search_count.
     Penalizes having multiple searches in a single message.
+    Returns 0 if final message doesn't contain answer tags.
 
     Args:
         prompts: List of input prompts
@@ -263,13 +264,30 @@ def reward_retry(prompts: list, completions: list, **reward_kwargs) -> list:
     increment = 0.15 # Reward increment per search attempt (0.2 + 5*0.15 = 0.95 max)
     violation_penalty = 0.5 # Penalty for having multiple searches in one message
 
-    # Regex pattern for search tags
+    # Regex pattern for search and answer tags
     search_pattern = r"<search>[\s\S]*?</search>"
+    answer_pattern = r"<answer>[\s\S]*?</answer>"
 
     for completion in completions:
         # Get assistant messages
         assistant_messages = [msg["content"] for msg in completion["messages"] if msg["role"] == "assistant"]
 
+        if not assistant_messages:
+            rewards.append(0.0)
+            search_queries.append(0)
+            violations.append(False)
+            continue
+
+        # Check if final message contains answer tags
+        final_message = assistant_messages[-1]
+        has_answer = bool(re.search(answer_pattern, final_message))
+
+        if not has_answer:
+            rewards.append(0.0)
+            search_queries.append(0)
+            violations.append(False)
+            continue
+
         # Count search tags in assistant messages
         message_searches = []
         for msg in assistant_messages: