From 77f121662f28408f6fe471bf1da7d84fc9a15811 Mon Sep 17 00:00:00 2001
From: thinhlpg
Date: Fri, 4 Apr 2025 09:59:07 +0700
Subject: [PATCH] test: add tests for reward_retry function scenarios

---
Reviewer note: the line breaks, the indentation and the <search>/<answer>
tags inside the message strings were reconstructed from a whitespace-collapsed,
tag-stripped copy of this patch, and the author e-mail could not be recovered.
Confirm the tag names against reward_retry and the original commit before
applying.

 tests/test_rewards.py | 138 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 138 insertions(+)

diff --git a/tests/test_rewards.py b/tests/test_rewards.py
index 49486e2..ad2b9b2 100644
--- a/tests/test_rewards.py
+++ b/tests/test_rewards.py
@@ -627,3 +627,141 @@ def test_reward_search_diversity_exact_duplicates():
     rewards = reward_search_diversity(prompts=["test"], completions=completions)
     # Should get very low reward due to exact duplicates
     assert rewards[0] < 0.2
+
+
+def test_reward_retry_no_answer():
+    """Test reward_retry when final message has no answer tags - should return 0."""
+    prompts = ["Test prompt"]
+    completions = [
+        {
+            "messages": [
+                {"role": "assistant", "content": "Let me search<search>query 1</search>"},
+                {"role": "assistant", "content": "Let me search again<search>query 2</search>"},
+            ]
+        }
+    ]
+    rewards = reward_retry(prompts, completions)
+    assert rewards[0] == 0.0, "Should return 0 when final message has no answer tags"
+
+
+def test_reward_retry_with_answer():
+    """Test reward_retry with answer in final message - should calculate reward normally."""
+    prompts = ["Test prompt"]
+    completions = [
+        {
+            "messages": [
+                {"role": "assistant", "content": "Let me search<search>query 1</search>"},
+                {"role": "assistant", "content": "Let me search again<search>query 2</search>"},
+                {"role": "assistant", "content": "Here's the answer<answer>Final answer</answer>"},
+            ]
+        }
+    ]
+    rewards = reward_retry(prompts, completions)
+    expected = round(0.2 + 2 * 0.15, 3)  # base_reward + 2 searches * increment
+    assert rewards[0] == expected, "Should calculate reward normally when final message has answer tags"
+
+
+def test_reward_retry_violation_with_answer():
+    """Test reward_retry with multiple searches in one message but answer in final message."""
+    prompts = ["Test prompt"]
+    completions = [
+        {
+            "messages": [
+                {
+                    "role": "assistant",
+                    "content": "Multiple searches<search>query 1</search><search>query 2</search>",
+                },
+                {"role": "assistant", "content": "Here's the answer<answer>Final answer</answer>"},
+            ]
+        }
+    ]
+    rewards = reward_retry(prompts, completions)
+    expected = round((0.2 + 2 * 0.15) * (1 - 0.5), 3)  # (base + 2*increment) * (1 - violation_penalty)
+    assert rewards[0] == expected, "Should apply violation penalty but still calculate reward due to answer"
+
+
+def test_reward_retry_optimal_searches():
+    """Test reward_retry with optimal number of searches and answer."""
+    prompts = ["Test prompt"]
+    completions = [
+        {
+            "messages": [
+                {"role": "assistant", "content": "Search 1<search>query 1</search>"},
+                {"role": "assistant", "content": "Search 2<search>query 2</search>"},
+                {"role": "assistant", "content": "Search 3<search>query 3</search>"},
+                {"role": "assistant", "content": "Search 4<search>query 4</search>"},
+                {"role": "assistant", "content": "Search 5<search>query 5</search>"},
+                {"role": "assistant", "content": "Here's the answer<answer>Final answer</answer>"},
+            ]
+        }
+    ]
+    rewards = reward_retry(prompts, completions)
+    expected = round(0.2 + 5 * 0.15, 3)  # base_reward + optimal_search_count * increment
+    assert rewards[0] == expected, "Should cap reward at optimal search count"
+
+
+def test_reward_retry_beyond_optimal():
+    """Test reward_retry with more than optimal searches but still with answer."""
+    prompts = ["Test prompt"]
+    completions = [
+        {
+            "messages": [
+                {"role": "assistant", "content": "Search 1<search>query 1</search>"},
+                {"role": "assistant", "content": "Search 2<search>query 2</search>"},
+                {"role": "assistant", "content": "Search 3<search>query 3</search>"},
+                {"role": "assistant", "content": "Search 4<search>query 4</search>"},
+                {"role": "assistant", "content": "Search 5<search>query 5</search>"},
+                {"role": "assistant", "content": "Search 6<search>query 6</search>"},
+                {"role": "assistant", "content": "Here's the answer<answer>Final answer</answer>"},
+            ]
+        }
+    ]
+    rewards = reward_retry(prompts, completions)
+    expected = round(0.2 + 5 * 0.15, 3)  # base_reward + optimal_search_count * increment
+    assert rewards[0] == expected, "Should not exceed max reward even with more searches"
+
+
+def test_reward_retry_empty_messages():
+    """Test reward_retry with empty message list."""
+    prompts = ["Test prompt"]
+    completions = [{"messages": []}]
+    rewards = reward_retry(prompts, completions)
+    assert rewards[0] == 0.0, "Should return 0 for empty message list"
+
+
+def test_reward_retry_no_searches_with_answer():
+    """Test reward_retry with no searches but has answer."""
+    prompts = ["Test prompt"]
+    completions = [
+        {
+            "messages": [
+                {"role": "assistant", "content": "Direct answer<answer>Final answer</answer>"},
+            ]
+        }
+    ]
+    rewards = reward_retry(prompts, completions)
+    assert rewards[0] == 0.0, "Should return 0 when no searches even with answer"
+
+
+def test_reward_retry_multiple_completions():
+    """Test reward_retry with multiple completions."""
+    prompts = ["Test 1", "Test 2"]
+    completions = [
+        {
+            "messages": [
+                {"role": "assistant", "content": "Search<search>query</search>"},
+                {"role": "assistant", "content": "Answer<answer>Final 1</answer>"},
+            ]
+        },
+        {
+            "messages": [
+                {"role": "assistant", "content": "Search 1<search>query 1</search>"},
+                {"role": "assistant", "content": "Search 2<search>query 2</search>"},
+                {"role": "assistant", "content": "Answer<answer>Final 2</answer>"},
+            ]
+        },
+    ]
+    rewards = reward_retry(prompts, completions)
+    expected1 = round(0.2 + 0.15, 3)  # base_reward + 1 search * increment
+    expected2 = round(0.2 + 2 * 0.15, 3)  # base_reward + 2 searches * increment
+    assert rewards == [expected1, expected2], "Should handle multiple completions correctly"