diff --git a/tests/test_rewards.py b/tests/test_rewards.py
index d03819d..bf8591b 100644
--- a/tests/test_rewards.py
+++ b/tests/test_rewards.py
@@ -50,7 +50,7 @@ def reward_correctness_fn():
def test_reward_correctness_basic(reward_correctness_fn):
"""Test basic reward correctness functionality"""
prompts = ["What is 2+2?"]
- completions = [{"messages": [{"content": "4"}]}]
+ completions = [{"messages": [{"role": "assistant", "content": "4"}]}]
reward_kwargs = {"answer": ["4"]}
rewards = reward_correctness_fn(prompts, completions, **reward_kwargs)
@@ -61,7 +61,7 @@ def test_reward_correctness_basic(reward_correctness_fn):
def test_reward_correctness_wrong_answer(reward_correctness_fn):
"""Test reward correctness with wrong answer"""
prompts = ["What is 2+2?"]
- completions = [{"messages": [{"content": "5"}]}]
+ completions = [{"messages": [{"role": "assistant", "content": "5"}]}]
reward_kwargs = {"answer": ["4"]}
rewards = reward_correctness_fn(prompts, completions, **reward_kwargs)
@@ -432,3 +432,46 @@ def test_reward_format_search_or_answer_not_both():
completions = [{"messages": [{"role": "assistant", "content": content_answer_only}]}]
rewards = reward_format([], completions)
assert rewards[0] == 1.0, "Should accept messages with just answer tags"
+
+
+def test_reward_correctness_validation(reward_correctness_fn):
+    """Test reward correctness validation logic for message roles and tags"""
+    prompts = ["What is 2+2?"]
+    test_cases = [
+        # Role validation: otherwise well-formed answer, but sent by a non-assistant role
+        {
+            "completion": {"messages": [{"role": "user", "content": "<answer>4</answer>"}]},
+            "expected": False,
+            "desc": "Non-assistant role should fail",
+        },
+        # Answer tag validation: right value, but not wrapped in <answer> tags
+        {
+            "completion": {"messages": [{"role": "assistant", "content": "4"}]},
+            "expected": False,
+            "desc": "Missing answer tags should fail",
+        },
+        # Search tag validation: answer accompanied by a <search> block
+        {
+            "completion": {"messages": [{"role": "assistant", "content": "<answer>4</answer><search>query</search>"}]},
+            "expected": False,
+            "desc": "Having search tags should fail",
+        },
+        # Information tag validation: answer accompanied by an <information> block
+        {
+            "completion": {
+                "messages": [{"role": "assistant", "content": "<answer>4</answer><information>info</information>"}]
+            },
+            "expected": False,
+            "desc": "Having information tags should fail",
+        },
+        # Valid case: assistant role, answer tags only (must differ from the missing-tags case)
+        {
+            "completion": {"messages": [{"role": "assistant", "content": "<answer>4</answer>"}]},
+            "expected": True,
+            "desc": "Valid format should pass",
+        },
+    ]
+
+    for case in test_cases:
+        rewards = reward_correctness_fn(prompts, [case["completion"]], answer=["4"])
+        assert rewards[0] == case["expected"], f"Failed: {case['desc']}"