feat: enhance reward_retry function to handle missing answer tags

Added logic to return 0 if the final message from the assistant does not contain answer tags (no matter how hard you try, you won't get anything if no result 💀)
main
thinhlpg 1 month ago
parent bf480574a2
commit c8714e0f6b

@@ -244,6 +244,7 @@ def reward_retry(prompts: list, completions: list, **reward_kwargs) -> list:
    Reward function that encourages optimal retry behavior.
    Rewards increase with more search attempts but caps at optimal_search_count.
    Penalizes having multiple searches in a single message.
    Returns 0 if final message doesn't contain answer tags.
    Args:
        prompts: List of input prompts
@@ -263,13 +264,30 @@ def reward_retry(prompts: list, completions: list, **reward_kwargs) -> list:
    increment = 0.15  # Reward increment per search attempt (0.2 + 5*0.15 = 0.95 max)
    violation_penalty = 0.5  # Penalty for having multiple searches in one message
    # Regex pattern for search tags
    # Regex pattern for search and answer tags
    search_pattern = r"<search>[\s\S]*?</search>"
    answer_pattern = r"<answer>[\s\S]*?</answer>"
    for completion in completions:
        # Get assistant messages
        assistant_messages = [msg["content"] for msg in completion["messages"] if msg["role"] == "assistant"]
        if not assistant_messages:
            rewards.append(0.0)
            search_queries.append(0)
            violations.append(False)
            continue
        # Check if final message contains answer tags
        final_message = assistant_messages[-1]
        has_answer = bool(re.search(answer_pattern, final_message))
        if not has_answer:
            rewards.append(0.0)
            search_queries.append(0)
            violations.append(False)
            continue
        # Count search tags in assistant messages
        message_searches = []
        for msg in assistant_messages:
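
The gating added in this commit can be tried in isolation with a small sketch like the one below. It is not the code from the repository: `gated_retry_reward` is a hypothetical helper, the base value of 0.2 and the cap of 5 searches are read off the `0.2 + 5*0.15 = 0.95 max` comment in the diff, and the capped-sum formula is an assumption, since the part of `reward_retry` that computes the final reward is not shown here. The multiple-search penalty is deliberately left out.

```python
import re

# Same patterns as in the diff: non-greedy, and they require a closing tag.
SEARCH_PATTERN = r"<search>[\s\S]*?</search>"
ANSWER_PATTERN = r"<answer>[\s\S]*?</answer>"


def gated_retry_reward(assistant_messages, base=0.2, increment=0.15, optimal_search_count=5):
    """Hypothetical helper: score one completion's assistant messages.

    base and optimal_search_count are assumptions inferred from the
    "0.2 + 5*0.15 = 0.95 max" comment; the real formula may differ.
    """
    # Gate 1: no assistant output at all means no reward.
    if not assistant_messages:
        return 0.0
    # Gate 2: the final message must contain a complete <answer>...</answer> pair.
    if not re.search(ANSWER_PATTERN, assistant_messages[-1]):
        return 0.0
    # Count search tags across all messages, then cap the count (assumed formula).
    total_searches = sum(len(re.findall(SEARCH_PATTERN, m)) for m in assistant_messages)
    return base + increment * min(total_searches, optimal_search_count)
```

Because only the last assistant message is checked, a completion that answers early and then keeps talking would score 0 under this sketch, and so would one whose `<answer>` tag is never closed. In the diff itself, the gated branches also log `search_queries` as 0 and `violations` as `False`, regardless of how many searches the completion actually made.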
