From c8714e0f6b149ab950f2860584ec4dfaf840eeb6 Mon Sep 17 00:00:00 2001 From: thinhlpg Date: Fri, 4 Apr 2025 09:58:44 +0700 Subject: [PATCH] feat: enhance reward_retry function to handle missing answer tags MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added logic to return 0 if the final message from the assistant does not contain answer tags (no matter how hard you try, you won't get anything if no result 💀) --- src/rewards.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/src/rewards.py b/src/rewards.py index 2df5505..3587f3c 100644 --- a/src/rewards.py +++ b/src/rewards.py @@ -244,6 +244,7 @@ def reward_retry(prompts: list, completions: list, **reward_kwargs) -> list: Reward function that encourages optimal retry behavior. Rewards increase with more search attempts but caps at optimal_search_count. Penalizes having multiple searches in a single message. + Returns 0 if final message doesn't contain answer tags. Args: prompts: List of input prompts @@ -263,13 +264,30 @@ def reward_retry(prompts: list, completions: list, **reward_kwargs) -> list: increment = 0.15 # Reward increment per search attempt (0.2 + 5*0.15 = 0.95 max) violation_penalty = 0.5 # Penalty for having multiple searches in one message - # Regex pattern for search tags + # Regex pattern for search and answer tags search_pattern = r"<search>[\s\S]*?</search>" + answer_pattern = r"<answer>[\s\S]*?</answer>" 
for completion in completions: # Get assistant messages assistant_messages = [msg["content"] for msg in completion["messages"] if msg["role"] == "assistant"] + if not assistant_messages: + rewards.append(0.0) + search_queries.append(0) + violations.append(False) + continue + + # Check if final message contains answer tags + final_message = assistant_messages[-1] + has_answer = bool(re.search(answer_pattern, final_message)) + + if not has_answer: + rewards.append(0.0) + search_queries.append(0) + violations.append(False) + continue + # Count search tags in assistant messages message_searches = [] for msg in assistant_messages: