From 3081d6e36b22baea937801fd66031d5e8e85c10f Mon Sep 17 00:00:00 2001 From: thinhlpg Date: Fri, 4 Apr 2025 00:28:04 +0700 Subject: [PATCH] test: added tests for new reward functions: search strategy and search diversity --- tests/test_rewards.py | 250 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 248 insertions(+), 2 deletions(-) diff --git a/tests/test_rewards.py b/tests/test_rewards.py index 4941f6c..49486e2 100644 --- a/tests/test_rewards.py +++ b/tests/test_rewards.py @@ -9,6 +9,8 @@ from src.rewards import ( reward_em_chunk, reward_format, reward_retry, + reward_search_diversity, + reward_search_strategy, ) @@ -189,7 +191,7 @@ def test_reward_format_real_example_answer_only(): def test_reward_format_incorrect_tag_sequence(): - """Test reward format with incorrect tag sequence - should now pass since sequence doesn't matter""" + """Test reward format with incorrect tag sequence - should fail since we require proper sequence and ending""" formats = [ { "messages": [ @@ -208,7 +210,28 @@ def test_reward_format_incorrect_tag_sequence(): for completion in formats: rewards = reward_format([], [completion]) - assert rewards[0] == 1.0, f"Failed with: {completion['messages'][0]['content']}" + assert rewards[0] == 0.0, f"Should fail with incorrect sequence: {completion['messages'][0]['content']}" + + # Test correct sequences + correct_formats = [ + { + "messages": [ + {"role": "assistant", "content": "\nSome reasoning\n\n\nThe answer\n"} + ] + }, + { + "messages": [ + { + "role": "assistant", + "content": "\nSome reasoning\n\nquery", + } + ] + }, + ] + + for completion in correct_formats: + rewards = reward_format([], [completion]) + assert rewards[0] == 1.0, f"Should pass with correct sequence: {completion['messages'][0]['content']}" def test_reward_format_multiple_answers(): @@ -381,3 +404,226 @@ def test_reward_correctness_validation(reward_correctness_fn): for case in test_cases: rewards = reward_correctness_fn(prompts, [case["completion"]], answer=["4"]) assert rewards[0] == case["expected"], f"Failed: {case['desc']}" + + +def test_reward_search_strategy_perfect(): + """Test search strategy reward with a perfect search strategy and info processing.""" + content = [ + { + "role": "assistant", + "content": """Let me search for a broad overview first. +what is quantum computing overview""", + }, + { + "role": "user", + "content": "Quantum computing uses quantum mechanics principles...", + }, + { + "role": "assistant", + "content": """Based on the provided information about quantum mechanics, I should look for specific examples. +quantum computing practical examples""", + }, + { + "role": "user", + "content": "Some practical examples of quantum computing include...", + }, + { + "role": "assistant", + "content": """According to the examples provided, I should investigate the advantages. +quantum computing advantages over classical computing""", + }, + {"role": "user", "content": "The key advantages of quantum computing are..."}, + { + "role": "assistant", + "content": """Based on all the information gathered about quantum computing principles, examples, and advantages, I can now provide a comprehensive answer. +Quantum computing is a revolutionary technology that...""", + }, + ] + + completions = [{"messages": content}] + rewards = reward_search_strategy(prompts=["test"], completions=completions) + assert rewards[0] == 1.0 # Initial (0.2) + info processing (0.4) + final synthesis (0.4) + + +def test_reward_search_strategy_no_refinement(): + """Test search strategy reward with just initial search.""" + content = [ + { + "role": "assistant", + "content": """Let me search for information. +quantum computing""", + } + ] + + completions = [{"messages": content}] + rewards = reward_search_strategy(prompts=["test"], completions=completions) + assert rewards[0] == 0.2 # Only initial search reward + + +def test_reward_search_strategy_multiple_refinements(): + """Test search strategy reward with multiple info-based refinements.""" + content = [ + { + "role": "assistant", + "content": """Starting with a broad search. +what is machine learning""", + }, + {"role": "user", "content": "Machine learning is a branch of AI..."}, + { + "role": "assistant", + "content": """Based on this overview, let's look for specific applications. +machine learning examples in healthcare""", + }, + {"role": "user", "content": "In healthcare, ML is used for diagnosis..."}, + { + "role": "assistant", + "content": """According to the healthcare examples, we should explore different methods. +supervised vs unsupervised learning in medical diagnosis""", + }, + ] + + completions = [{"messages": content}] + rewards = reward_search_strategy(prompts=["test"], completions=completions) + assert rewards[0] == pytest.approx(0.6) # Initial (0.2) + info processing (0.4) + + +def test_reward_search_strategy_search_operators(): + """Test search strategy reward with search operators and info processing.""" + content = [ + { + "role": "assistant", + "content": """Let's use specific search operators. +"exact phrase" site:edu""", + }, + {"role": "user", "content": "Educational resources show that..."}, + { + "role": "assistant", + "content": """Based on the educational resources, we need more specific results. +machine learning -basic filetype:pdf""", + }, + {"role": "user", "content": "Advanced ML concepts include..."}, + { + "role": "assistant", + "content": """According to these findings, let's combine key concepts. +AI AND "deep learning" OR "neural networks"""", + }, + ] + + completions = [{"messages": content}] + rewards = reward_search_strategy(prompts=["test"], completions=completions) + assert rewards[0] == pytest.approx(0.6) # Initial (0.2) + info processing (0.4) + + +def test_reward_search_strategy_non_assistant_messages(): + """Test search strategy reward with mixed message roles.""" + content = [ + {"role": "user", "content": "user search"}, + { + "role": "assistant", + "content": """Let me search for information. +what is quantum computing""", + }, + {"role": "system", "content": "system thinking"}, + ] + + completions = [{"messages": content}] + rewards = reward_search_strategy(prompts=["test"], completions=completions) + assert rewards[0] == 0.2 # Only counts the assistant's initial search + + +def test_reward_search_strategy_final_synthesis(): + """Test search strategy reward with final synthesis but no refinements.""" + content = [ + { + "role": "assistant", + "content": """Let me search for information. +quantum computing basics""", + }, + {"role": "user", "content": "Quantum computing uses qubits..."}, + { + "role": "assistant", + "content": """Based on the information about quantum computing basics, I can now provide an answer. +Quantum computing is a field that...""", + }, + ] + + completions = [{"messages": content}] + rewards = reward_search_strategy(prompts=["test"], completions=completions) + assert rewards[0] == pytest.approx(0.6) # Initial (0.2) + final synthesis (0.4) + + +def test_reward_search_diversity_no_search(): + """Test search diversity reward with no search queries.""" + content = [ + { + "role": "assistant", + "content": "Let me think about this.", + } + ] + completions = [{"messages": content}] + rewards = reward_search_diversity(prompts=["test"], completions=completions) + assert rewards[0] == 0.0 + + +def test_reward_search_diversity_single_query(): + """Test search diversity reward with a single search query.""" + content = [ + { + "role": "assistant", + "content": """Let me search. +what is python programming""", + } + ] + completions = [{"messages": content}] + rewards = reward_search_diversity(prompts=["test"], completions=completions) + assert rewards[0] == pytest.approx(0.2) # Base reward for single query + + +def test_reward_search_diversity_diverse_queries(): + """Test search diversity reward with diverse search queries.""" + content = [ + { + "role": "assistant", + "content": """Let's start broad. +what is machine learning""", + }, + { + "role": "assistant", + "content": """Now specific applications. +healthcare diagnosis using neural networks""", + }, + { + "role": "assistant", + "content": """Let's look at a different aspect. +ethical concerns in AI decision making""", + }, + ] + completions = [{"messages": content}] + rewards = reward_search_diversity(prompts=["test"], completions=completions) + # Should get high reward due to diverse queries + assert rewards[0] > 0.5 + + +def test_reward_search_diversity_exact_duplicates(): + """Test search diversity reward with exact duplicate queries.""" + content = [ + { + "role": "assistant", + "content": """First search. +python tutorial""", + }, + { + "role": "assistant", + "content": """Searching again. +python tutorial""", + }, + { + "role": "assistant", + "content": """One more time. +python tutorial""", + }, + ] + completions = [{"messages": content}] + rewards = reward_search_diversity(prompts=["test"], completions=completions) + # Should get very low reward due to exact duplicates + assert rewards[0] < 0.2