|
|
|
@ -299,7 +299,7 @@ def test_reward_retry():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_reward_em_chunk():
|
|
|
|
|
"""Test exact match reward function when all paragraphs are found."""
|
|
|
|
|
"""Test reward function when the LAST required paragraph is found."""
|
|
|
|
|
prompts = ["What is Python?"]
|
|
|
|
|
completions = [
|
|
|
|
|
{
|
|
|
|
@ -309,22 +309,21 @@ def test_reward_em_chunk():
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
]
|
|
|
|
|
# Expecting a list of lists for supporting_paragraphs
|
|
|
|
|
required_paragraphs = [["Python is a programming language", "It is widely used."]]
|
|
|
|
|
|
|
|
|
|
rewards = reward_em_chunk(prompts, completions, supporting_paragraphs=required_paragraphs)
|
|
|
|
|
assert len(rewards) == 1
|
|
|
|
|
assert rewards[0] == 1.0, "Should give full reward when all required paragraphs are found"
|
|
|
|
|
assert rewards[0] == 1.0, "Should give full reward when the LAST paragraph is found (even if others are too)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_reward_em_chunk_partial_match():
|
|
|
|
|
"""Test exact match reward function when only some paragraphs are found."""
|
|
|
|
|
"""Test reward function when the LAST required paragraph is MISSING."""
|
|
|
|
|
prompts = ["What is Python?"]
|
|
|
|
|
completions = [
|
|
|
|
|
{
|
|
|
|
|
"messages": [
|
|
|
|
|
{"role": "user", "content": "<information>Python is a programming language</information>"}
|
|
|
|
|
# Missing the second paragraph
|
|
|
|
|
# Missing the last paragraph "It is widely used."
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
]
|
|
|
|
@ -332,7 +331,7 @@ def test_reward_em_chunk_partial_match():
|
|
|
|
|
|
|
|
|
|
rewards = reward_em_chunk(prompts, completions, supporting_paragraphs=required_paragraphs)
|
|
|
|
|
assert len(rewards) == 1
|
|
|
|
|
assert rewards[0] == 0.0, "Should give zero reward if not all required paragraphs are found"
|
|
|
|
|
assert rewards[0] == 0.0, "Should give zero reward if the LAST required paragraph is missing"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_reward_em_chunk_no_supporting_paragraphs_kwarg():
|
|
|
|
@ -346,56 +345,56 @@ def test_reward_em_chunk_no_supporting_paragraphs_kwarg():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_reward_em_chunk_multiple_completions():
|
|
|
|
|
"""Test reward EM chunk with multiple completions and varying paragraph requirements."""
|
|
|
|
|
"""Test reward function with multiple completions based on the LAST paragraph rule."""
|
|
|
|
|
prompts = ["Prompt 1", "Prompt 2", "Prompt 3"]
|
|
|
|
|
completions = [
|
|
|
|
|
# Completion 1: Matches both required paragraphs
|
|
|
|
|
# Completion 1: Finds the last required paragraph ("Second paragraph content")
|
|
|
|
|
{
|
|
|
|
|
"messages": [
|
|
|
|
|
{"role": "ipython", "content": "<information>First paragraph content</information>"},
|
|
|
|
|
{"role": "user", "content": "<information>Second paragraph content</information>"},
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
# Completion 2: Matches only one of the required paragraphs
|
|
|
|
|
# Completion 2: Does NOT find the last required paragraph ("Missing paragraph")
|
|
|
|
|
{"messages": [{"role": "user", "content": "<information>Third paragraph content</information>"}]},
|
|
|
|
|
# Completion 3: Matches all (only one required)
|
|
|
|
|
# Completion 3: Finds the last (and only) required paragraph ("Fourth paragraph")
|
|
|
|
|
{"messages": [{"role": "tool", "content": "<information>Fourth paragraph</information>"}]},
|
|
|
|
|
]
|
|
|
|
|
# List of lists for supporting paragraphs
|
|
|
|
|
reward_kwargs = {
|
|
|
|
|
"supporting_paragraphs": [
|
|
|
|
|
["First paragraph content", "Second paragraph content"], # Requires 2, gets 2 -> 1.0
|
|
|
|
|
["Third paragraph content", "Missing paragraph"], # Requires 2, gets 1 -> 0.0
|
|
|
|
|
["Fourth paragraph"], # Requires 1, gets 1 -> 1.0
|
|
|
|
|
["First paragraph content", "Second paragraph content"], # Last is found -> 1.0
|
|
|
|
|
["Third paragraph content", "Missing paragraph"], # Last is missing -> 0.0
|
|
|
|
|
["Fourth paragraph"], # Last is found -> 1.0
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rewards = reward_em_chunk(prompts, completions, **reward_kwargs)
|
|
|
|
|
assert len(rewards) == 3
|
|
|
|
|
assert rewards == [1.0, 0.0, 1.0], "Should reward based on finding *all* required paragraphs per completion"
|
|
|
|
|
assert rewards == [1.0, 0.0, 1.0], "Should reward based ONLY on finding the LAST required paragraph per completion"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_reward_em_chunk_whitespace_handling():
|
|
|
|
|
"""Test reward EM chunk handles whitespace properly when checking paragraphs."""
|
|
|
|
|
"""Test reward function handles whitespace properly when checking the LAST paragraph."""
|
|
|
|
|
prompts = ["Whitespace test"]
|
|
|
|
|
completions = [
|
|
|
|
|
{
|
|
|
|
|
"messages": [
|
|
|
|
|
{"role": "ipython", "content": " <information> Paragraph with spaces. </information> "},
|
|
|
|
|
{"role": "user", "content": "<information>\\nAnother one.\\t</information>"},
|
|
|
|
|
# This is the last paragraph and should be found despite leading/trailing spaces in the info block
|
|
|
|
|
{"role": "user", "content": "<information> Another one. </information>"},
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
]
|
|
|
|
|
# The check is simple `paragraph in combined_information`.
|
|
|
|
|
# combined_information joins stripped content with '\\n'.
|
|
|
|
|
# The required paragraphs must match exactly what's expected in combined_information.
|
|
|
|
|
required_paragraphs = [["Paragraph with spaces.", "Another one."]]
|
|
|
|
|
# The required paragraphs list
|
|
|
|
|
required_paragraphs = [
|
|
|
|
|
["Paragraph with spaces.", "Another one."] # We only care about "Another one."
|
|
|
|
|
]
|
|
|
|
|
reward_kwargs = {"supporting_paragraphs": required_paragraphs}
|
|
|
|
|
|
|
|
|
|
rewards = reward_em_chunk(prompts, completions, **reward_kwargs)
|
|
|
|
|
assert len(rewards) == 1
|
|
|
|
|
# The function joins stripped content with newline, so "Paragraph with spaces." and "Another one." should be found.
|
|
|
|
|
assert rewards[0] == 1.0, "Should handle whitespace in content and still match exact paragraphs"
|
|
|
|
|
# The last paragraph "Another one." should be found in the combined, stripped content.
|
|
|
|
|
assert rewards[0] == 1.0, "Should correctly match the LAST paragraph even with whitespace issues"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_reward_format_search_or_answer_not_both():
|
|
|
|
|