diff --git a/tests/test_rewards.py b/tests/test_rewards.py
index 4941f6c..49486e2 100644
--- a/tests/test_rewards.py
+++ b/tests/test_rewards.py
@@ -9,6 +9,8 @@ from src.rewards import (
reward_em_chunk,
reward_format,
reward_retry,
+ reward_search_diversity,
+ reward_search_strategy,
)
@@ -189,7 +191,7 @@ def test_reward_format_real_example_answer_only():
def test_reward_format_incorrect_tag_sequence():
- """Test reward format with incorrect tag sequence - should now pass since sequence doesn't matter"""
+ """Test reward format with incorrect tag sequence - should fail since we require proper sequence and ending"""
formats = [
{
"messages": [
@@ -208,7 +210,28 @@ def test_reward_format_incorrect_tag_sequence():
for completion in formats:
rewards = reward_format([], [completion])
- assert rewards[0] == 1.0, f"Failed with: {completion['messages'][0]['content']}"
+ assert rewards[0] == 0.0, f"Should fail with incorrect sequence: {completion['messages'][0]['content']}"
+
+ # Test correct sequences
+ correct_formats = [
+ {
+ "messages": [
+ {"role": "assistant", "content": "\nSome reasoning\n\n\nThe answer\n"}
+ ]
+ },
+ {
+ "messages": [
+ {
+ "role": "assistant",
+ "content": "\nSome reasoning\n\nquery",
+ }
+ ]
+ },
+ ]
+
+ for completion in correct_formats:
+ rewards = reward_format([], [completion])
+ assert rewards[0] == 1.0, f"Should pass with correct sequence: {completion['messages'][0]['content']}"
def test_reward_format_multiple_answers():
@@ -381,3 +404,226 @@ def test_reward_correctness_validation(reward_correctness_fn):
for case in test_cases:
rewards = reward_correctness_fn(prompts, [case["completion"]], answer=["4"])
assert rewards[0] == case["expected"], f"Failed: {case['desc']}"
+
+
+def test_reward_search_strategy_perfect():
+ """Test search strategy reward with a perfect search strategy and info processing."""
+ content = [
+ {
+ "role": "assistant",
+ "content": """Let me search for a broad overview first.
+what is quantum computing overview""",
+ },
+ {
+ "role": "user",
+ "content": "Quantum computing uses quantum mechanics principles...",
+ },
+ {
+ "role": "assistant",
+ "content": """Based on the provided information about quantum mechanics, I should look for specific examples.
+quantum computing practical examples""",
+ },
+ {
+ "role": "user",
+ "content": "Some practical examples of quantum computing include...",
+ },
+ {
+ "role": "assistant",
+ "content": """According to the examples provided, I should investigate the advantages.
+quantum computing advantages over classical computing""",
+ },
+ {"role": "user", "content": "The key advantages of quantum computing are..."},
+ {
+ "role": "assistant",
+ "content": """Based on all the information gathered about quantum computing principles, examples, and advantages, I can now provide a comprehensive answer.
+Quantum computing is a revolutionary technology that...""",
+ },
+ ]
+
+ completions = [{"messages": content}]
+ rewards = reward_search_strategy(prompts=["test"], completions=completions)
+ assert rewards[0] == 1.0 # Initial (0.2) + info processing (0.4) + final synthesis (0.4)
+
+
+def test_reward_search_strategy_no_refinement():
+ """Test search strategy reward with just initial search."""
+ content = [
+ {
+ "role": "assistant",
+ "content": """Let me search for information.
+quantum computing""",
+ }
+ ]
+
+ completions = [{"messages": content}]
+ rewards = reward_search_strategy(prompts=["test"], completions=completions)
+ assert rewards[0] == 0.2 # Only initial search reward
+
+
+def test_reward_search_strategy_multiple_refinements():
+ """Test search strategy reward with multiple info-based refinements."""
+ content = [
+ {
+ "role": "assistant",
+ "content": """Starting with a broad search.
+what is machine learning""",
+ },
+ {"role": "user", "content": "Machine learning is a branch of AI..."},
+ {
+ "role": "assistant",
+ "content": """Based on this overview, let's look for specific applications.
+machine learning examples in healthcare""",
+ },
+ {"role": "user", "content": "In healthcare, ML is used for diagnosis..."},
+ {
+ "role": "assistant",
+ "content": """According to the healthcare examples, we should explore different methods.
+supervised vs unsupervised learning in medical diagnosis""",
+ },
+ ]
+
+ completions = [{"messages": content}]
+ rewards = reward_search_strategy(prompts=["test"], completions=completions)
+ assert rewards[0] == pytest.approx(0.6) # Initial (0.2) + info processing (0.4)
+
+
+def test_reward_search_strategy_search_operators():
+ """Test search strategy reward with search operators and info processing."""
+ content = [
+ {
+ "role": "assistant",
+ "content": """Let's use specific search operators.
+"exact phrase" site:edu""",
+ },
+ {"role": "user", "content": "Educational resources show that..."},
+ {
+ "role": "assistant",
+ "content": """Based on the educational resources, we need more specific results.
+machine learning -basic filetype:pdf""",
+ },
+ {"role": "user", "content": "Advanced ML concepts include..."},
+ {
+ "role": "assistant",
+ "content": """According to these findings, let's combine key concepts.
+AI AND "deep learning" OR "neural networks"""",
+ },
+ ]
+
+ completions = [{"messages": content}]
+ rewards = reward_search_strategy(prompts=["test"], completions=completions)
+ assert rewards[0] == pytest.approx(0.6) # Initial (0.2) + info processing (0.4)
+
+
+def test_reward_search_strategy_non_assistant_messages():
+ """Test search strategy reward with mixed message roles."""
+ content = [
+ {"role": "user", "content": "user search"},
+ {
+ "role": "assistant",
+ "content": """Let me search for information.
+what is quantum computing""",
+ },
+ {"role": "system", "content": "system thinking"},
+ ]
+
+ completions = [{"messages": content}]
+ rewards = reward_search_strategy(prompts=["test"], completions=completions)
+ assert rewards[0] == 0.2 # Only counts the assistant's initial search
+
+
+def test_reward_search_strategy_final_synthesis():
+ """Test search strategy reward with final synthesis but no refinements."""
+ content = [
+ {
+ "role": "assistant",
+ "content": """Let me search for information.
+quantum computing basics""",
+ },
+ {"role": "user", "content": "Quantum computing uses qubits..."},
+ {
+ "role": "assistant",
+ "content": """Based on the information about quantum computing basics, I can now provide an answer.
+Quantum computing is a field that...""",
+ },
+ ]
+
+ completions = [{"messages": content}]
+ rewards = reward_search_strategy(prompts=["test"], completions=completions)
+ assert rewards[0] == pytest.approx(0.6) # Initial (0.2) + final synthesis (0.4)
+
+
+def test_reward_search_diversity_no_search():
+ """Test search diversity reward with no search queries."""
+ content = [
+ {
+ "role": "assistant",
+ "content": "Let me think about this.",
+ }
+ ]
+ completions = [{"messages": content}]
+ rewards = reward_search_diversity(prompts=["test"], completions=completions)
+ assert rewards[0] == 0.0
+
+
+def test_reward_search_diversity_single_query():
+ """Test search diversity reward with a single search query."""
+ content = [
+ {
+ "role": "assistant",
+ "content": """Let me search.
+what is python programming""",
+ }
+ ]
+ completions = [{"messages": content}]
+ rewards = reward_search_diversity(prompts=["test"], completions=completions)
+ assert rewards[0] == pytest.approx(0.2) # Base reward for single query
+
+
+def test_reward_search_diversity_diverse_queries():
+ """Test search diversity reward with diverse search queries."""
+ content = [
+ {
+ "role": "assistant",
+ "content": """Let's start broad.
+what is machine learning""",
+ },
+ {
+ "role": "assistant",
+ "content": """Now specific applications.
+healthcare diagnosis using neural networks""",
+ },
+ {
+ "role": "assistant",
+ "content": """Let's look at a different aspect.
+ethical concerns in AI decision making""",
+ },
+ ]
+ completions = [{"messages": content}]
+ rewards = reward_search_diversity(prompts=["test"], completions=completions)
+ # Should get high reward due to diverse queries
+ assert rewards[0] > 0.5
+
+
+def test_reward_search_diversity_exact_duplicates():
+ """Test search diversity reward with exact duplicate queries."""
+ content = [
+ {
+ "role": "assistant",
+ "content": """First search.
+python tutorial""",
+ },
+ {
+ "role": "assistant",
+ "content": """Searching again.
+python tutorial""",
+ },
+ {
+ "role": "assistant",
+ "content": """One more time.
+python tutorial""",
+ },
+ ]
+ completions = [{"messages": content}]
+ rewards = reward_search_diversity(prompts=["test"], completions=completions)
+ # Should get very low reward due to exact duplicates
+ assert rewards[0] < 0.2