test: enhance reward correctness tests with validation logic

- Updated test cases to include role and tag validation for assistant messages.
- Ensured that only properly formatted messages with answer tags are accepted.
- Added new test for validating various incorrect formats and their expected outcomes.
main
thinhlpg 1 month ago
parent 338655e563
commit 1bd609dfae

@ -50,7 +50,7 @@ def reward_correctness_fn():
def test_reward_correctness_basic(reward_correctness_fn):
"""Test basic reward correctness functionality"""
prompts = ["What is 2+2?"]
completions = [{"messages": [{"content": "4"}]}]
completions = [{"messages": [{"role": "assistant", "content": "<answer>4</answer>"}]}]
reward_kwargs = {"answer": ["4"]}
rewards = reward_correctness_fn(prompts, completions, **reward_kwargs)
@ -61,7 +61,7 @@ def test_reward_correctness_basic(reward_correctness_fn):
def test_reward_correctness_wrong_answer(reward_correctness_fn):
"""Test reward correctness with wrong answer"""
prompts = ["What is 2+2?"]
completions = [{"messages": [{"content": "5"}]}]
completions = [{"messages": [{"role": "assistant", "content": "<answer>5</answer>"}]}]
reward_kwargs = {"answer": ["4"]}
rewards = reward_correctness_fn(prompts, completions, **reward_kwargs)
@ -432,3 +432,46 @@ def test_reward_format_search_or_answer_not_both():
completions = [{"messages": [{"role": "assistant", "content": content_answer_only}]}]
rewards = reward_format([], completions)
assert rewards[0] == 1.0, "Should accept messages with just answer tags"
def test_reward_correctness_validation(reward_correctness_fn):
    """Check how the correctness reward reacts to message role and tag layout.

    Every scenario wraps a single message in a completion and scores it
    against the reference answer "4". Only the last scenario (an assistant
    message containing nothing but an answer tag) is expected to be rewarded.
    """
    question = ["What is 2+2?"]

    # One row per scenario, in evaluation order:
    # (message role, message content, expected reward, failure description)
    scenarios = [
        # Role validation: a non-assistant author must not be rewarded.
        ("user", "<answer>4</answer>", False, "Non-assistant role should fail"),
        # Answer tag validation: a bare answer without tags is rejected.
        ("assistant", "4", False, "Missing answer tags should fail"),
        # Search tag validation: an answer mixed with a search tag is rejected.
        (
            "assistant",
            "<answer>4</answer><search>query</search>",
            False,
            "Having search tags should fail",
        ),
        # Information tag validation: an answer mixed with an information tag is rejected.
        (
            "assistant",
            "<answer>4</answer><information>info</information>",
            False,
            "Having information tags should fail",
        ),
        # Valid case: assistant role with only an answer tag.
        ("assistant", "<answer>4</answer>", True, "Valid format should pass"),
    ]

    for role, content, expected, desc in scenarios:
        completion = {"messages": [{"role": role, "content": content}]}
        rewards = reward_correctness_fn(question, [completion], answer=["4"])
        assert rewards[0] == expected, f"Failed: {desc}"

Loading…
Cancel
Save