test: added tests for new reward functions: search strategy and search diversity

main
thinhlpg 1 month ago
parent 4de31e0f30
commit 3081d6e36b

@ -9,6 +9,8 @@ from src.rewards import (
reward_em_chunk,
reward_format,
reward_retry,
reward_search_diversity,
reward_search_strategy,
)
@ -189,7 +191,7 @@ def test_reward_format_real_example_answer_only():
def test_reward_format_incorrect_tag_sequence():
"""Test reward format with incorrect tag sequence - should now pass since sequence doesn't matter"""
"""Test reward format with incorrect tag sequence - should fail since we require proper sequence and ending"""
formats = [
{
"messages": [
@ -208,7 +210,28 @@ def test_reward_format_incorrect_tag_sequence():
for completion in formats:
rewards = reward_format([], [completion])
assert rewards[0] == 1.0, f"Failed with: {completion['messages'][0]['content']}"
assert rewards[0] == 0.0, f"Should fail with incorrect sequence: {completion['messages'][0]['content']}"
# Test correct sequences
correct_formats = [
{
"messages": [
{"role": "assistant", "content": "<think>\nSome reasoning\n</think>\n<answer>\nThe answer\n</answer>"}
]
},
{
"messages": [
{
"role": "assistant",
"content": "<think>\nSome reasoning\n</think>\n<search>query</search>",
}
]
},
]
for completion in correct_formats:
rewards = reward_format([], [completion])
assert rewards[0] == 1.0, f"Should pass with correct sequence: {completion['messages'][0]['content']}"
def test_reward_format_multiple_answers():
@ -381,3 +404,226 @@ def test_reward_correctness_validation(reward_correctness_fn):
for case in test_cases:
rewards = reward_correctness_fn(prompts, [case["completion"]], answer=["4"])
assert rewards[0] == case["expected"], f"Failed: {case['desc']}"
def test_reward_search_strategy_perfect():
"""Test search strategy reward with a perfect search strategy and info processing."""
content = [
{
"role": "assistant",
"content": """<think>Let me search for a broad overview first.</think>
<search>what is quantum computing overview</search>""",
},
{
"role": "user",
"content": "<information>Quantum computing uses quantum mechanics principles...</information>",
},
{
"role": "assistant",
"content": """<think>Based on the provided information about quantum mechanics, I should look for specific examples.</think>
<search>quantum computing practical examples</search>""",
},
{
"role": "user",
"content": "<information>Some practical examples of quantum computing include...</information>",
},
{
"role": "assistant",
"content": """<think>According to the examples provided, I should investigate the advantages.</think>
<search>quantum computing advantages over classical computing</search>""",
},
{"role": "user", "content": "<information>The key advantages of quantum computing are...</information>"},
{
"role": "assistant",
"content": """<think>Based on all the information gathered about quantum computing principles, examples, and advantages, I can now provide a comprehensive answer.</think>
<answer>Quantum computing is a revolutionary technology that...</answer>""",
},
]
completions = [{"messages": content}]
rewards = reward_search_strategy(prompts=["test"], completions=completions)
assert rewards[0] == 1.0 # Initial (0.2) + info processing (0.4) + final synthesis (0.4)
def test_reward_search_strategy_no_refinement():
"""Test search strategy reward with just initial search."""
content = [
{
"role": "assistant",
"content": """<think>Let me search for information.</think>
<search>quantum computing</search>""",
}
]
completions = [{"messages": content}]
rewards = reward_search_strategy(prompts=["test"], completions=completions)
assert rewards[0] == 0.2 # Only initial search reward
def test_reward_search_strategy_multiple_refinements():
"""Test search strategy reward with multiple info-based refinements."""
content = [
{
"role": "assistant",
"content": """<think>Starting with a broad search.</think>
<search>what is machine learning</search>""",
},
{"role": "user", "content": "<information>Machine learning is a branch of AI...</information>"},
{
"role": "assistant",
"content": """<think>Based on this overview, let's look for specific applications.</think>
<search>machine learning examples in healthcare</search>""",
},
{"role": "user", "content": "<information>In healthcare, ML is used for diagnosis...</information>"},
{
"role": "assistant",
"content": """<think>According to the healthcare examples, we should explore different methods.</think>
<search>supervised vs unsupervised learning in medical diagnosis</search>""",
},
]
completions = [{"messages": content}]
rewards = reward_search_strategy(prompts=["test"], completions=completions)
assert rewards[0] == pytest.approx(0.6) # Initial (0.2) + info processing (0.4)
def test_reward_search_strategy_search_operators():
"""Test search strategy reward with search operators and info processing."""
content = [
{
"role": "assistant",
"content": """<think>Let's use specific search operators.</think>
<search>"exact phrase" site:edu</search>""",
},
{"role": "user", "content": "<information>Educational resources show that...</information>"},
{
"role": "assistant",
"content": """<think>Based on the educational resources, we need more specific results.</think>
<search>machine learning -basic filetype:pdf</search>""",
},
{"role": "user", "content": "<information>Advanced ML concepts include...</information>"},
{
"role": "assistant",
"content": """<think>According to these findings, let's combine key concepts.</think>
<search>AI AND "deep learning" OR "neural networks"</search>""",
},
]
completions = [{"messages": content}]
rewards = reward_search_strategy(prompts=["test"], completions=completions)
assert rewards[0] == pytest.approx(0.6) # Initial (0.2) + info processing (0.4)
def test_reward_search_strategy_non_assistant_messages():
"""Test search strategy reward with mixed message roles."""
content = [
{"role": "user", "content": "<search>user search</search>"},
{
"role": "assistant",
"content": """<think>Let me search for information.</think>
<search>what is quantum computing</search>""",
},
{"role": "system", "content": "<think>system thinking</think>"},
]
completions = [{"messages": content}]
rewards = reward_search_strategy(prompts=["test"], completions=completions)
assert rewards[0] == 0.2 # Only counts the assistant's initial search
def test_reward_search_strategy_final_synthesis():
"""Test search strategy reward with final synthesis but no refinements."""
content = [
{
"role": "assistant",
"content": """<think>Let me search for information.</think>
<search>quantum computing basics</search>""",
},
{"role": "user", "content": "<information>Quantum computing uses qubits...</information>"},
{
"role": "assistant",
"content": """<think>Based on the information about quantum computing basics, I can now provide an answer.</think>
<answer>Quantum computing is a field that...</answer>""",
},
]
completions = [{"messages": content}]
rewards = reward_search_strategy(prompts=["test"], completions=completions)
assert rewards[0] == pytest.approx(0.6) # Initial (0.2) + final synthesis (0.4)
def test_reward_search_diversity_no_search():
"""Test search diversity reward with no search queries."""
content = [
{
"role": "assistant",
"content": "<think>Let me think about this.</think>",
}
]
completions = [{"messages": content}]
rewards = reward_search_diversity(prompts=["test"], completions=completions)
assert rewards[0] == 0.0
def test_reward_search_diversity_single_query():
"""Test search diversity reward with a single search query."""
content = [
{
"role": "assistant",
"content": """<think>Let me search.</think>
<search>what is python programming</search>""",
}
]
completions = [{"messages": content}]
rewards = reward_search_diversity(prompts=["test"], completions=completions)
assert rewards[0] == pytest.approx(0.2) # Base reward for single query
def test_reward_search_diversity_diverse_queries():
"""Test search diversity reward with diverse search queries."""
content = [
{
"role": "assistant",
"content": """<think>Let's start broad.</think>
<search>what is machine learning</search>""",
},
{
"role": "assistant",
"content": """<think>Now specific applications.</think>
<search>healthcare diagnosis using neural networks</search>""",
},
{
"role": "assistant",
"content": """<think>Let's look at a different aspect.</think>
<search>ethical concerns in AI decision making</search>""",
},
]
completions = [{"messages": content}]
rewards = reward_search_diversity(prompts=["test"], completions=completions)
# Should get high reward due to diverse queries
assert rewards[0] > 0.5
def test_reward_search_diversity_exact_duplicates():
"""Test search diversity reward with exact duplicate queries."""
content = [
{
"role": "assistant",
"content": """<think>First search.</think>
<search>python tutorial</search>""",
},
{
"role": "assistant",
"content": """<think>Searching again.</think>
<search>python tutorial</search>""",
},
{
"role": "assistant",
"content": """<think>One more time.</think>
<search>python tutorial</search>""",
},
]
completions = [{"messages": content}]
rewards = reward_search_diversity(prompts=["test"], completions=completions)
# Should get very low reward due to exact duplicates
assert rewards[0] < 0.2

Loading…
Cancel
Save