diff --git a/src/rewards.py b/src/rewards.py
index 2df5505..3587f3c 100644
--- a/src/rewards.py
+++ b/src/rewards.py
@@ -244,6 +244,7 @@ def reward_retry(prompts: list, completions: list, **reward_kwargs) -> list:
Reward function that encourages optimal retry behavior.
Rewards increase with more search attempts but caps at optimal_search_count.
Penalizes having multiple searches in a single message.
+ Returns 0 if final message doesn't contain answer tags.
Args:
prompts: List of input prompts
@@ -263,13 +264,30 @@ def reward_retry(prompts: list, completions: list, **reward_kwargs) -> list:
increment = 0.15 # Reward increment per search attempt (0.2 + 5*0.15 = 0.95 max)
violation_penalty = 0.5 # Penalty for having multiple searches in one message
- # Regex pattern for search tags
+ # Regex pattern for search and answer tags
    search_pattern = r"<search>[\s\S]*?</search>"
+    answer_pattern = r"<answer>[\s\S]*?</answer>"
for completion in completions:
# Get assistant messages
assistant_messages = [msg["content"] for msg in completion["messages"] if msg["role"] == "assistant"]
+ if not assistant_messages:
+ rewards.append(0.0)
+ search_queries.append(0)
+ violations.append(False)
+ continue
+
+ # Check if final message contains answer tags
+ final_message = assistant_messages[-1]
+ has_answer = bool(re.search(answer_pattern, final_message))
+
+ if not has_answer:
+ rewards.append(0.0)
+ search_queries.append(0)
+ violations.append(False)
+ continue
+
# Count search tags in assistant messages
message_searches = []
for msg in assistant_messages: