@ -244,6 +244,7 @@ def reward_retry(prompts: list, completions: list, **reward_kwargs) -> list:
Reward function that encourages optimal retry behavior .
Reward function that encourages optimal retry behavior .
Rewards increase with more search attempts but caps at optimal_search_count .
Rewards increase with more search attempts but caps at optimal_search_count .
Penalizes having multiple searches in a single message .
Penalizes having multiple searches in a single message .
Returns 0 if the final message doesn't contain answer tags.
Args :
Args :
prompts : List of input prompts
prompts : List of input prompts
@ -263,13 +264,30 @@ def reward_retry(prompts: list, completions: list, **reward_kwargs) -> list:
increment = 0.15 # Reward increment per search attempt (0.2 + 5*0.15 = 0.95 max)
increment = 0.15 # Reward increment per search attempt (0.2 + 5*0.15 = 0.95 max)
violation_penalty = 0.5 # Penalty for having multiple searches in one message
violation_penalty = 0.5 # Penalty for having multiple searches in one message
# Regex pattern for search tags
# Regex pattern for search and answer tags
search_pattern = r " <search>[ \ s \ S]*?</search> "
search_pattern = r " <search>[ \ s \ S]*?</search> "
answer_pattern = r " <answer>[ \ s \ S]*?</answer> "
for completion in completions :
for completion in completions :
# Get assistant messages
# Get assistant messages
assistant_messages = [ msg [ " content " ] for msg in completion [ " messages " ] if msg [ " role " ] == " assistant " ]
assistant_messages = [ msg [ " content " ] for msg in completion [ " messages " ] if msg [ " role " ] == " assistant " ]
if not assistant_messages :
rewards . append ( 0.0 )
search_queries . append ( 0 )
violations . append ( False )
continue
# Check if final message contains answer tags
final_message = assistant_messages [ - 1 ]
has_answer = bool ( re . search ( answer_pattern , final_message ) )
if not has_answer :
rewards . append ( 0.0 )
search_queries . append ( 0 )
violations . append ( False )
continue
# Count search tags in assistant messages
# Count search tags in assistant messages
message_searches = [ ]
message_searches = [ ]
for msg in assistant_messages :
for msg in assistant_messages :