"""
|
|
Test cases for reward functions in rewards.py
|
|
"""
|
|
|
|
import pytest
|
|
|
|
from src.deepsearch.rewards import (
|
|
build_reward_correctness_fn,
|
|
reward_em_chunk,
|
|
reward_format,
|
|
reward_retry,
|
|
reward_search_diversity,
|
|
reward_search_strategy,
|
|
)
|
|
|
|
|
|
class MockResponse:
|
|
"""Mock response class that simulates vLLM response"""
|
|
|
|
def __init__(self, text):
|
|
self.outputs = [type("obj", (object,), {"text": text})()]
|
|
|
|
|
|
# Mock functions for testing
|
|
def mock_vllm_generate_func(*args, **kwargs):
|
|
"""Mock function that returns verification responses based on the input"""
|
|
# Check if the prompt contains "5" (wrong answer) or "4" (correct answer)
|
|
prompt = str(args[0]) if args else ""
|
|
if "5" in prompt:
|
|
return [MockResponse("No, the answer is incorrect")] # Return False for wrong answer
|
|
return [MockResponse("Yes, the answer is correct")] # Return True for correct answer
|
|
|
|
|
|
class MockTokenizer:
|
|
"""Mock tokenizer class that simulates the behavior of a real tokenizer"""
|
|
|
|
def __init__(self):
|
|
self.input_ids = [1, 2, 3]
|
|
|
|
def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True):
|
|
"""Mock apply_chat_template method"""
|
|
# For testing, we just return a formatted string
|
|
return "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
|
|
|
|
|
@pytest.fixture
|
|
def reward_correctness_fn():
|
|
"""Fixture to create reward correctness function"""
|
|
return build_reward_correctness_fn(mock_vllm_generate_func, MockTokenizer())
|
|
|
|
|
|
def test_reward_correctness_basic(reward_correctness_fn):
    """Test basic reward correctness functionality"""
    prompts = ["What is 2+2?"]
    completions = [{"messages": [{"role": "assistant", "content": "<answer>4</answer>"}]}]
    reward_kwargs = {"answer": ["4"]}

    rewards = reward_correctness_fn(prompts, completions, **reward_kwargs)
    assert len(rewards) == 1  # Should return one verification result per answer
    assert rewards[0] is True  # Should be True for correct answer


def test_reward_correctness_wrong_answer(reward_correctness_fn):
    """Test reward correctness with wrong answer"""
    prompts = ["What is 2+2?"]
    completions = [{"messages": [{"role": "assistant", "content": "<answer>5</answer>"}]}]
    reward_kwargs = {"answer": ["4"]}

    rewards = reward_correctness_fn(prompts, completions, **reward_kwargs)
    assert len(rewards) == 1  # Should return one verification result per answer
    assert rewards[0] is False  # Should be False for wrong answer


def test_reward_format_correct():
    """Test reward format with correct format"""
    prompts = ["Test prompt"]
    completions = [
        {
            "messages": [
                {"role": "assistant", "content": "<think>\nSome reasoning\n</think>\n<answer>\nThe answer\n</answer>"}
            ]
        }
    ]
    rewards = reward_format(prompts, completions)
    assert rewards[0] == 1.0


def test_reward_format_with_search():
    """Test reward format with search tags only (no answer tags)"""
    prompts = ["Test prompt"]
    completions = [
        {"messages": [{"role": "assistant", "content": "<think>\nSome reasoning\n</think>\n<search>query</search>"}]}
    ]
    rewards = reward_format(prompts, completions)
    assert rewards[0] == 1.0


def test_reward_format_markdown_tags():
    """Test reward format with markdown-styled tags"""
    prompts = ["Test prompt"]
    markdown_formats = [
        {
            "messages": [
                {
                    "role": "assistant",
                    "content": "**<think>**\nSome reasoning\n**</think>**\n<answer>\nThe answer\n</answer>",
                }
            ]
        },
        {
            "messages": [
                {
                    "role": "assistant",
                    "content": "*<think>*\nSome reasoning\n*</think>*\n<answer>\nThe answer\n</answer>",
                }
            ]
        },
        {
            "messages": [
                {
                    "role": "assistant",
                    "content": "_<think>_\nSome reasoning\n_</think>_\n<answer>\nThe answer\n</answer>",
                }
            ]
        },
    ]

    for completion in markdown_formats:
        rewards = reward_format(["Test prompt"], [completion])
        assert rewards[0] == 0.0, f"Failed with: {completion['messages'][0]['content']}"


def test_reward_format_information_tags():
    """Test reward format with information tags"""
    prompts = ["Test prompt"]
    # Test different information tag variants
    info_variants = [
        "<information>Some info</information>",
        "<info>Some info</info>",
        "<Info>Some info</Info>",
        "<INFORMATION>Some info</INFORMATION>",
        "<INFO>Some info</INFO>",
    ]

    for info_tag in info_variants:
        content = f"<think>\nSome reasoning\n</think>\n{info_tag}\n<answer>\nThe answer\n</answer>"
        completions = [{"messages": [{"role": "assistant", "content": content}]}]
        rewards = reward_format(prompts, completions)
        assert rewards[0] == 0.0, f"Failed to detect information tag: {info_tag}"


def test_reward_format_real_example():
    """Test reward format with a real-world example - should fail now since it has both search and answer tags"""
    prompts = ["What cars did Paul Walker drive in Fast and Furious?"]
    content = """<think>I need to search for Paul Walker's cars in Fast and Furious movies.</think>
<search> Paul Walker's cars in Fast and Furious </search>

From the information provided, it's clear that Paul Walker was a part of the "Fast and Furious" series, but the specific list of cars is not mentioned. Since I lack this particular detail, I will call a search engine to get the specific list of cars Paul Walker drove in the "Fast and Furious" movies.

<search> list of cars paul walker drove in Fast and Furious </search>

Based on the updated information, it seems the focus was on his career, financials, and family. However, I am still missing the specific list of cars he drove in the "Fast and Furious" movies. Since it appears that the information might not be contained within the accessed documents, and I have no further search queries to make, I will provide an answer based on the details I have.

<answer> Charger </answer>"""

    completions = [{"messages": [{"role": "assistant", "content": content}]}]
    rewards = reward_format(prompts, completions)
    assert rewards[0] == 0.0, "Should reject responses with both search and answer tags"


def test_reward_format_real_example_search_only():
    """Test reward format with search-only format in a real-world example"""
    prompts = ["What cars did Paul Walker drive in Fast and Furious?"]
    content = """<think>I need to search for Paul Walker's cars in Fast and Furious movies.</think>
<search> Paul Walker's cars in Fast and Furious </search>"""

    completions = [{"messages": [{"role": "assistant", "content": content}]}]
    rewards = reward_format(prompts, completions)
    assert rewards[0] == 1.0, "Should accept responses with only search tags"


def test_reward_format_real_example_answer_only():
    """Test reward format with answer-only format in a real-world example"""
    prompts = ["What cars did Paul Walker drive in Fast and Furious?"]
    content = """<think>Based on the information provided, it seems Paul Walker drove a Charger in the Fast and Furious series.</think>
<answer> Charger </answer>"""

    completions = [{"messages": [{"role": "assistant", "content": content}]}]
    rewards = reward_format(prompts, completions)
    assert rewards[0] == 1.0, "Should accept responses with only answer tags"


def test_reward_format_incorrect_tag_sequence():
    """Test reward format with incorrect tag sequence - should fail since we require proper sequence and ending"""
    formats = [
        {
            "messages": [
                {"role": "assistant", "content": "<answer>\nThe answer\n</answer>\n<think>\nSome reasoning\n</think>"}
            ]
        },
        {
            "messages": [
                {
                    "role": "assistant",
                    "content": "<search>query</search>\n<think>\nSome reasoning\n</think>",
                }
            ]
        },
    ]

    for completion in formats:
        rewards = reward_format([], [completion])
        assert rewards[0] == 0.0, f"Should fail with incorrect sequence: {completion['messages'][0]['content']}"

    # Test correct sequences
    correct_formats = [
        {
            "messages": [
                {"role": "assistant", "content": "<think>\nSome reasoning\n</think>\n<answer>\nThe answer\n</answer>"}
            ]
        },
        {
            "messages": [
                {
                    "role": "assistant",
                    "content": "<think>\nSome reasoning\n</think>\n<search>query</search>",
                }
            ]
        },
    ]

    for completion in correct_formats:
        rewards = reward_format([], [completion])
        assert rewards[0] == 1.0, f"Should pass with correct sequence: {completion['messages'][0]['content']}"


def test_reward_format_multiple_answers():
    """Test reward format with multiple answer tags"""
    completions = [
        {
            "messages": [
                {
                    "role": "assistant",
                    "content": "<think>\nSome reasoning\n</think>\n<answer>\nFirst answer\n</answer>\n<answer>\nSecond answer\n</answer>",
                }
            ]
        }
    ]
    rewards = reward_format([], completions)
    assert rewards[0] == 0.0


def test_reward_format_incomplete_tags():
    """Test reward format with incomplete tags"""
    incomplete_formats = [
        {
            "messages": [
                {"role": "assistant", "content": "<think>\nMissing closing think tag\n<answer>\nThe answer\n</answer>"}
            ]
        },
        {
            "messages": [
                {
                    "role": "assistant",
                    "content": "<think>\nSome reasoning\n</think>\n<answer>\nMissing closing answer tag",
                }
            ]
        },
        {
            "messages": [
                {
                    "role": "assistant",
                    "content": "Missing opening think tag\n</think>\n<answer>\nThe answer\n</answer>",
                }
            ]
        },
    ]

    for completion in incomplete_formats:
        rewards = reward_format([], [completion])
        assert rewards[0] == 0.0, f"Failed with: {completion['messages'][0]['content']}"


def test_reward_retry():
    """Test reward retry function"""
    prompts = ["What is the capital of France?"]
    completions = [
        {
            "messages": [
                {"role": "assistant", "content": "<think>Let me search</think>\n<search>capital of France</search>"},
                {"role": "assistant", "content": "<think>Need more info</think>\n<search>Paris history</search>"},
                {"role": "assistant", "content": "<think>Found it</think>\n<answer>Paris</answer>"},
            ]
        }
    ]
    rewards = reward_retry(prompts, completions)
    assert len(rewards) == 1
    assert rewards[0] > 0, "Should give positive reward for multiple search attempts"


def test_reward_em_chunk():
    """Test exact match chunk reward function"""
    prompts = ["What is Python?"]
    completions = [
        {"messages": [{"role": "user", "content": "<information>Python is a programming language</information>"}]}
    ]
    correct_contents = ["Python is a programming language"]

    rewards = reward_em_chunk(prompts, completions, chunk_content=correct_contents)
    assert len(rewards) == 1
    assert rewards[0] == 1.0, "Should give full reward for exact chunk match"


def test_reward_em_chunk_no_chunk_content():
    """Test reward EM chunk with no chunk content provided"""
    completions = [{"messages": [{"role": "ipython", "content": "<information>Some content</information>"}]}]

    with pytest.raises(ValueError, match="chunk_content must be provided"):
        reward_em_chunk([], completions)


def test_reward_em_chunk_multiple_chunks():
    """Test reward EM chunk with multiple chunks to match"""
    completions = [
        {"messages": [{"role": "ipython", "content": "<information>First chunk content</information>"}]},
        {"messages": [{"role": "user", "content": "<information>Second chunk content</information>"}]},
    ]
    reward_kwargs = {"chunk_content": ["First chunk content", "Second chunk content"]}

    rewards = reward_em_chunk([], completions, **reward_kwargs)
    assert len(rewards) == 2
    assert rewards == [1.0, 1.0], "Should get reward 1.0 for each matching chunk"


def test_reward_em_chunk_whitespace_handling():
    """Test reward EM chunk handles whitespace properly"""
    completions = [
        {"messages": [{"role": "ipython", "content": " <information> Content with spaces </information> "}]}
    ]
    reward_kwargs = {"chunk_content": ["Content with spaces"]}

    rewards = reward_em_chunk([], completions, **reward_kwargs)
    assert rewards[0] == 1.0, "Should handle whitespace in content and tags"


def test_reward_format_search_or_answer_not_both():
    """Test that having both search and answer tags in the same message is not allowed"""
    content = "<think>I need to search</think>\n<search>query</search>\n<answer>Final answer</answer>"
    completions = [{"messages": [{"role": "assistant", "content": content}]}]
    rewards = reward_format([], completions)
    assert rewards[0] == 0.0, "Should reject messages with both search and answer tags"

    # Verify that having just search tag is valid
    content_search_only = "<think>I need to search</think>\n<search>query</search>"
    completions = [{"messages": [{"role": "assistant", "content": content_search_only}]}]
    rewards = reward_format([], completions)
    assert rewards[0] == 1.0, "Should accept messages with just search tags"

    # Verify that having just answer tag is valid
    content_answer_only = "<think>I know the answer</think>\n<answer>Final answer</answer>"
    completions = [{"messages": [{"role": "assistant", "content": content_answer_only}]}]
    rewards = reward_format([], completions)
    assert rewards[0] == 1.0, "Should accept messages with just answer tags"


def test_reward_correctness_validation(reward_correctness_fn):
    """Test reward correctness validation logic for message roles and tags"""
    prompts = ["What is 2+2?"]
    test_cases = [
        # Test assistant role validation
        {
            "completion": {"messages": [{"role": "user", "content": "<answer>4</answer>"}]},
            "expected": False,
            "desc": "Non-assistant role should fail",
        },
        # Test answer tag validation
        {
            "completion": {"messages": [{"role": "assistant", "content": "4"}]},
            "expected": False,
            "desc": "Missing answer tags should fail",
        },
        # Test search tag validation
        {
            "completion": {"messages": [{"role": "assistant", "content": "<answer>4</answer><search>query</search>"}]},
            "expected": False,
            "desc": "Having search tags should fail",
        },
        # Test information tag validation
        {
            "completion": {
                "messages": [{"role": "assistant", "content": "<answer>4</answer><information>info</information>"}]
            },
            "expected": False,
            "desc": "Having information tags should fail",
        },
        # Test valid case
        {
            "completion": {"messages": [{"role": "assistant", "content": "<answer>4</answer>"}]},
            "expected": True,
            "desc": "Valid format should pass",
        },
    ]

    for case in test_cases:
        rewards = reward_correctness_fn(prompts, [case["completion"]], answer=["4"])
        assert rewards[0] == case["expected"], f"Failed: {case['desc']}"


def test_reward_search_strategy_perfect():
    """Test search strategy reward with a perfect search strategy and info processing."""
    content = [
        {
            "role": "assistant",
            "content": """<think>Let me search for a broad overview first.</think>
<search>what is quantum computing overview</search>""",
        },
        {
            "role": "user",
            "content": "<information>Quantum computing uses quantum mechanics principles...</information>",
        },
        {
            "role": "assistant",
            "content": """<think>Based on the provided information about quantum mechanics, I should look for specific examples.</think>
<search>quantum computing practical examples</search>""",
        },
        {
            "role": "user",
            "content": "<information>Some practical examples of quantum computing include...</information>",
        },
        {
            "role": "assistant",
            "content": """<think>According to the examples provided, I should investigate the advantages.</think>
<search>quantum computing advantages over classical computing</search>""",
        },
        {"role": "user", "content": "<information>The key advantages of quantum computing are...</information>"},
        {
            "role": "assistant",
            "content": """<think>Based on all the information gathered about quantum computing principles, examples, and advantages, I can now provide a comprehensive answer.</think>
<answer>Quantum computing is a revolutionary technology that...</answer>""",
        },
    ]

    completions = [{"messages": content}]
    rewards = reward_search_strategy(prompts=["test"], completions=completions)
    assert rewards[0] == 1.0  # Initial (0.2) + info processing (0.4) + final synthesis (0.4)


def test_reward_search_strategy_no_refinement():
    """Test search strategy reward with just initial search."""
    content = [
        {
            "role": "assistant",
            "content": """<think>Let me search for information.</think>
<search>quantum computing</search>""",
        }
    ]

    completions = [{"messages": content}]
    rewards = reward_search_strategy(prompts=["test"], completions=completions)
    assert rewards[0] == 0.2  # Only initial search reward


def test_reward_search_strategy_multiple_refinements():
    """Test search strategy reward with multiple info-based refinements."""
    content = [
        {
            "role": "assistant",
            "content": """<think>Starting with a broad search.</think>
<search>what is machine learning</search>""",
        },
        {"role": "user", "content": "<information>Machine learning is a branch of AI...</information>"},
        {
            "role": "assistant",
            "content": """<think>Based on this overview, let's look for specific applications.</think>
<search>machine learning examples in healthcare</search>""",
        },
        {"role": "user", "content": "<information>In healthcare, ML is used for diagnosis...</information>"},
        {
            "role": "assistant",
            "content": """<think>According to the healthcare examples, we should explore different methods.</think>
<search>supervised vs unsupervised learning in medical diagnosis</search>""",
        },
    ]

    completions = [{"messages": content}]
    rewards = reward_search_strategy(prompts=["test"], completions=completions)
    assert rewards[0] == pytest.approx(0.6)  # Initial (0.2) + info processing (0.4)


def test_reward_search_strategy_search_operators():
    """Test search strategy reward with search operators and info processing."""
    content = [
        {
            "role": "assistant",
            "content": """<think>Let's use specific search operators.</think>
<search>"exact phrase" site:edu</search>""",
        },
        {"role": "user", "content": "<information>Educational resources show that...</information>"},
        {
            "role": "assistant",
            "content": """<think>Based on the educational resources, we need more specific results.</think>
<search>machine learning -basic filetype:pdf</search>""",
        },
        {"role": "user", "content": "<information>Advanced ML concepts include...</information>"},
        {
            "role": "assistant",
            "content": """<think>According to these findings, let's combine key concepts.</think>
<search>AI AND "deep learning" OR "neural networks"</search>""",
        },
    ]

    completions = [{"messages": content}]
    rewards = reward_search_strategy(prompts=["test"], completions=completions)
    assert rewards[0] == pytest.approx(0.6)  # Initial (0.2) + info processing (0.4)


def test_reward_search_strategy_non_assistant_messages():
    """Test search strategy reward with mixed message roles."""
    content = [
        {"role": "user", "content": "<search>user search</search>"},
        {
            "role": "assistant",
            "content": """<think>Let me search for information.</think>
<search>what is quantum computing</search>""",
        },
        {"role": "system", "content": "<think>system thinking</think>"},
    ]

    completions = [{"messages": content}]
    rewards = reward_search_strategy(prompts=["test"], completions=completions)
    assert rewards[0] == 0.2  # Only counts the assistant's initial search


def test_reward_search_strategy_final_synthesis():
    """Test search strategy reward with final synthesis but no refinements."""
    content = [
        {
            "role": "assistant",
            "content": """<think>Let me search for information.</think>
<search>quantum computing basics</search>""",
        },
        {"role": "user", "content": "<information>Quantum computing uses qubits...</information>"},
        {
            "role": "assistant",
            "content": """<think>Based on the information about quantum computing basics, I can now provide an answer.</think>
<answer>Quantum computing is a field that...</answer>""",
        },
    ]

    completions = [{"messages": content}]
    rewards = reward_search_strategy(prompts=["test"], completions=completions)
    assert rewards[0] == pytest.approx(0.6)  # Initial (0.2) + final synthesis (0.4)


def test_reward_search_diversity_no_search():
    """Test search diversity reward with no search queries."""
    content = [
        {
            "role": "assistant",
            "content": "<think>Let me think about this.</think>",
        }
    ]
    completions = [{"messages": content}]
    rewards = reward_search_diversity(prompts=["test"], completions=completions)
    assert rewards[0] == 0.0


def test_reward_search_diversity_single_query():
    """Test search diversity reward with a single search query."""
    content = [
        {
            "role": "assistant",
            "content": """<think>Let me search.</think>
<search>what is python programming</search>""",
        }
    ]
    completions = [{"messages": content}]
    rewards = reward_search_diversity(prompts=["test"], completions=completions)
    assert rewards[0] == pytest.approx(0.2)  # Base reward for single query


def test_reward_search_diversity_diverse_queries():
    """Test search diversity reward with diverse search queries."""
    content = [
        {
            "role": "assistant",
            "content": """<think>Let's start broad.</think>
<search>what is machine learning</search>""",
        },
        {
            "role": "assistant",
            "content": """<think>Now specific applications.</think>
<search>healthcare diagnosis using neural networks</search>""",
        },
        {
            "role": "assistant",
            "content": """<think>Let's look at a different aspect.</think>
<search>ethical concerns in AI decision making</search>""",
        },
    ]
    completions = [{"messages": content}]
    rewards = reward_search_diversity(prompts=["test"], completions=completions)
    # Should get high reward due to diverse queries
    assert rewards[0] > 0.5


def test_reward_search_diversity_exact_duplicates():
    """Test search diversity reward with exact duplicate queries."""
    content = [
        {
            "role": "assistant",
            "content": """<think>First search.</think>
<search>python tutorial</search>""",
        },
        {
            "role": "assistant",
            "content": """<think>Searching again.</think>
<search>python tutorial</search>""",
        },
        {
            "role": "assistant",
            "content": """<think>One more time.</think>
<search>python tutorial</search>""",
        },
    ]
    completions = [{"messages": content}]
    rewards = reward_search_diversity(prompts=["test"], completions=completions)
    # Should get very low reward due to exact duplicates
    assert rewards[0] < 0.2


def test_reward_retry_no_answer():
    """Test reward_retry when final message has no answer tags - should return 0."""
    prompts = ["Test prompt"]
    completions = [
        {
            "messages": [
                {"role": "assistant", "content": "<think>Let me search</think><search>query 1</search>"},
                {"role": "assistant", "content": "<think>Let me search again</think><search>query 2</search>"},
            ]
        }
    ]
    rewards = reward_retry(prompts, completions)
    assert rewards[0] == 0.0, "Should return 0 when final message has no answer tags"


def test_reward_retry_with_answer():
    """Test reward_retry with answer in final message - should calculate reward normally."""
    prompts = ["Test prompt"]
    completions = [
        {
            "messages": [
                {"role": "assistant", "content": "<think>Let me search</think><search>query 1</search>"},
                {"role": "assistant", "content": "<think>Let me search again</think><search>query 2</search>"},
                {"role": "assistant", "content": "<think>Here's the answer</think><answer>Final answer</answer>"},
            ]
        }
    ]
    rewards = reward_retry(prompts, completions)
    expected = round(0.2 + 2 * 0.15, 3)  # base_reward + 2 searches * increment
    assert rewards[0] == expected, "Should calculate reward normally when final message has answer tags"


def test_reward_retry_violation_with_answer():
    """Test reward_retry with multiple searches in one message but answer in final message."""
    prompts = ["Test prompt"]
    completions = [
        {
            "messages": [
                {
                    "role": "assistant",
                    "content": "<think>Multiple searches</think><search>query 1</search><search>query 2</search>",
                },
                {"role": "assistant", "content": "<think>Here's the answer</think><answer>Final answer</answer>"},
            ]
        }
    ]
    rewards = reward_retry(prompts, completions)
    expected = round((0.2 + 2 * 0.15) * (1 - 0.5), 3)  # (base + 2*increment) * (1 - violation_penalty)
    assert rewards[0] == expected, "Should apply violation penalty but still calculate reward due to answer"


def test_reward_retry_optimal_searches():
    """Test reward_retry with optimal number of searches and answer."""
    prompts = ["Test prompt"]
    completions = [
        {
            "messages": [
                {"role": "assistant", "content": "<think>Search 1</think><search>query 1</search>"},
                {"role": "assistant", "content": "<think>Search 2</think><search>query 2</search>"},
                {"role": "assistant", "content": "<think>Search 3</think><search>query 3</search>"},
                {"role": "assistant", "content": "<think>Search 4</think><search>query 4</search>"},
                {"role": "assistant", "content": "<think>Search 5</think><search>query 5</search>"},
                {"role": "assistant", "content": "<think>Here's the answer</think><answer>Final answer</answer>"},
            ]
        }
    ]
    rewards = reward_retry(prompts, completions)
    expected = round(0.2 + 5 * 0.15, 3)  # base_reward + optimal_search_count * increment
    assert rewards[0] == expected, "Should cap reward at optimal search count"


def test_reward_retry_beyond_optimal():
    """Test reward_retry with more than optimal searches but still with answer."""
    prompts = ["Test prompt"]
    completions = [
        {
            "messages": [
                {"role": "assistant", "content": "<think>Search 1</think><search>query 1</search>"},
                {"role": "assistant", "content": "<think>Search 2</think><search>query 2</search>"},
                {"role": "assistant", "content": "<think>Search 3</think><search>query 3</search>"},
                {"role": "assistant", "content": "<think>Search 4</think><search>query 4</search>"},
                {"role": "assistant", "content": "<think>Search 5</think><search>query 5</search>"},
                {"role": "assistant", "content": "<think>Search 6</think><search>query 6</search>"},
                {"role": "assistant", "content": "<think>Here's the answer</think><answer>Final answer</answer>"},
            ]
        }
    ]
    rewards = reward_retry(prompts, completions)
    expected = round(0.2 + 5 * 0.15, 3)  # base_reward + optimal_search_count * increment
    assert rewards[0] == expected, "Should not exceed max reward even with more searches"


def test_reward_retry_empty_messages():
    """Test reward_retry with empty message list."""
    prompts = ["Test prompt"]
    completions = [{"messages": []}]
    rewards = reward_retry(prompts, completions)
    assert rewards[0] == 0.0, "Should return 0 for empty message list"


def test_reward_retry_no_searches_with_answer():
    """Test reward_retry with no searches but has answer."""
    prompts = ["Test prompt"]
    completions = [
        {
            "messages": [
                {"role": "assistant", "content": "<think>Direct answer</think><answer>Final answer</answer>"},
            ]
        }
    ]
    rewards = reward_retry(prompts, completions)
    assert rewards[0] == 0.0, "Should return 0 when no searches even with answer"


def test_reward_retry_multiple_completions():
    """Test reward_retry with multiple completions."""
    prompts = ["Test 1", "Test 2"]
    completions = [
        {
            "messages": [
                {"role": "assistant", "content": "<think>Search</think><search>query</search>"},
                {"role": "assistant", "content": "<think>Answer</think><answer>Final 1</answer>"},
            ]
        },
        {
            "messages": [
                {"role": "assistant", "content": "<think>Search 1</think><search>query 1</search>"},
                {"role": "assistant", "content": "<think>Search 2</think><search>query 2</search>"},
                {"role": "assistant", "content": "<think>Answer</think><answer>Final 2</answer>"},
            ]
        },
    ]
    rewards = reward_retry(prompts, completions)
    expected1 = round(0.2 + 0.15, 3)  # base_reward + 1 search * increment
    expected2 = round(0.2 + 2 * 0.15, 3)  # base_reward + 2 searches * increment
    assert rewards == [expected1, expected2], "Should handle multiple completions correctly"