test: add tests for reward_retry function scenarios

main
thinhlpg 1 month ago
parent c8714e0f6b
commit 77f121662f

@@ -627,3 +627,141 @@ def test_reward_search_diversity_exact_duplicates():
    rewards = reward_search_diversity(prompts=["test"], completions=completions)
    # Should get very low reward due to exact duplicates
    assert rewards[0] < 0.2


def test_reward_retry_no_answer():
    """Test reward_retry when final message has no answer tags - should return 0."""
    prompts = ["Test prompt"]
    completions = [
        {
            "messages": [
                {"role": "assistant", "content": "<think>Let me search</think><search>query 1</search>"},
                {"role": "assistant", "content": "<think>Let me search again</think><search>query 2</search>"},
            ]
        }
    ]
    rewards = reward_retry(prompts, completions)
    assert rewards[0] == 0.0, "Should return 0 when final message has no answer tags"


def test_reward_retry_with_answer():
    """Test reward_retry with answer in final message - should calculate reward normally."""
    prompts = ["Test prompt"]
    completions = [
        {
            "messages": [
                {"role": "assistant", "content": "<think>Let me search</think><search>query 1</search>"},
                {"role": "assistant", "content": "<think>Let me search again</think><search>query 2</search>"},
                {"role": "assistant", "content": "<think>Here's the answer</think><answer>Final answer</answer>"},
            ]
        }
    ]
    rewards = reward_retry(prompts, completions)
    expected = round(0.2 + 2 * 0.15, 3)  # base_reward + 2 searches * increment
    assert rewards[0] == expected, "Should calculate reward normally when final message has answer tags"


def test_reward_retry_violation_with_answer():
    """Test reward_retry with multiple searches in one message but answer in final message."""
    prompts = ["Test prompt"]
    completions = [
        {
            "messages": [
                {
                    "role": "assistant",
                    "content": "<think>Multiple searches</think><search>query 1</search><search>query 2</search>",
                },
                {"role": "assistant", "content": "<think>Here's the answer</think><answer>Final answer</answer>"},
            ]
        }
    ]
    rewards = reward_retry(prompts, completions)
    expected = round((0.2 + 2 * 0.15) * (1 - 0.5), 3)  # (base + 2*increment) * (1 - violation_penalty)
    assert rewards[0] == expected, "Should apply violation penalty but still calculate reward due to answer"


def test_reward_retry_optimal_searches():
    """Test reward_retry with optimal number of searches and answer."""
    prompts = ["Test prompt"]
    completions = [
        {
            "messages": [
                {"role": "assistant", "content": "<think>Search 1</think><search>query 1</search>"},
                {"role": "assistant", "content": "<think>Search 2</think><search>query 2</search>"},
                {"role": "assistant", "content": "<think>Search 3</think><search>query 3</search>"},
                {"role": "assistant", "content": "<think>Search 4</think><search>query 4</search>"},
                {"role": "assistant", "content": "<think>Search 5</think><search>query 5</search>"},
                {"role": "assistant", "content": "<think>Here's the answer</think><answer>Final answer</answer>"},
            ]
        }
    ]
    rewards = reward_retry(prompts, completions)
    expected = round(0.2 + 5 * 0.15, 3)  # base_reward + optimal_search_count * increment
    assert rewards[0] == expected, "Should cap reward at optimal search count"


def test_reward_retry_beyond_optimal():
    """Test reward_retry with more than optimal searches but still with answer."""
    prompts = ["Test prompt"]
    completions = [
        {
            "messages": [
                {"role": "assistant", "content": "<think>Search 1</think><search>query 1</search>"},
                {"role": "assistant", "content": "<think>Search 2</think><search>query 2</search>"},
                {"role": "assistant", "content": "<think>Search 3</think><search>query 3</search>"},
                {"role": "assistant", "content": "<think>Search 4</think><search>query 4</search>"},
                {"role": "assistant", "content": "<think>Search 5</think><search>query 5</search>"},
                {"role": "assistant", "content": "<think>Search 6</think><search>query 6</search>"},
                {"role": "assistant", "content": "<think>Here's the answer</think><answer>Final answer</answer>"},
            ]
        }
    ]
    rewards = reward_retry(prompts, completions)
    expected = round(0.2 + 5 * 0.15, 3)  # base_reward + optimal_search_count * increment
    assert rewards[0] == expected, "Should not exceed max reward even with more searches"


def test_reward_retry_empty_messages():
    """Test reward_retry with empty message list."""
    prompts = ["Test prompt"]
    completions = [{"messages": []}]
    rewards = reward_retry(prompts, completions)
    assert rewards[0] == 0.0, "Should return 0 for empty message list"


def test_reward_retry_no_searches_with_answer():
    """Test reward_retry with no searches but has answer."""
    prompts = ["Test prompt"]
    completions = [
        {
            "messages": [
                {"role": "assistant", "content": "<think>Direct answer</think><answer>Final answer</answer>"},
            ]
        }
    ]
    rewards = reward_retry(prompts, completions)
    assert rewards[0] == 0.0, "Should return 0 when no searches even with answer"


def test_reward_retry_multiple_completions():
    """Test reward_retry with multiple completions."""
    prompts = ["Test 1", "Test 2"]
    completions = [
        {
            "messages": [
                {"role": "assistant", "content": "<think>Search</think><search>query</search>"},
                {"role": "assistant", "content": "<think>Answer</think><answer>Final 1</answer>"},
            ]
        },
        {
            "messages": [
                {"role": "assistant", "content": "<think>Search 1</think><search>query 1</search>"},
                {"role": "assistant", "content": "<think>Search 2</think><search>query 2</search>"},
                {"role": "assistant", "content": "<think>Answer</think><answer>Final 2</answer>"},
            ]
        },
    ]
    rewards = reward_retry(prompts, completions)
    expected1 = round(0.2 + 0.15, 3)  # base_reward + 1 search * increment
    expected2 = round(0.2 + 2 * 0.15, 3)  # base_reward + 2 searches * increment
    assert rewards == [expected1, expected2], "Should handle multiple completions correctly"
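
For context, the expected values in these tests imply a reward scheme with a 0.2 base reward, a 0.15 increment per search capped at 5 searches, a 0.5 penalty when a single message issues multiple searches, and a 0.0 reward when there is no final answer or no search at all. The sketch below is a minimal reward_retry consistent with those expectations; the constant names and the implementation details are assumptions inferred from the tests, not the repository's actual code.

def reward_retry(prompts, completions):
    """Sketch of reward_retry consistent with the tests above (constants are assumptions)."""
    base_reward = 0.2          # assumed: base reward once at least one search was made
    increment = 0.15           # assumed: reward added per search
    optimal_search_count = 5   # assumed: searches beyond this add no further reward
    violation_penalty = 0.5    # assumed: penalty when one message contains multiple searches

    rewards = []
    for completion in completions:
        messages = [m for m in completion.get("messages", []) if m.get("role") == "assistant"]
        # No messages, or a final message without answer tags, earns nothing.
        if not messages or "<answer>" not in messages[-1]["content"]:
            rewards.append(0.0)
            continue

        search_counts = [m["content"].count("<search>") for m in messages]
        total_searches = sum(search_counts)
        # An answer with no preceding searches also earns nothing.
        if total_searches == 0:
            rewards.append(0.0)
            continue

        reward = base_reward + min(total_searches, optimal_search_count) * increment
        # Multiple <search> tags in a single message is treated as a violation.
        if any(count > 1 for count in search_counts):
            reward *= 1 - violation_penalty
        rewards.append(round(reward, 3))
    return rewards

Against the inputs above, this sketch reproduces every expected value, e.g. two searches followed by an answer yield round(0.2 + 2 * 0.15, 3) = 0.5, and the violation case yields 0.25.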
