From 1bd609dfae2f3dcd438d8a99b98ca27755c25839 Mon Sep 17 00:00:00 2001
From: thinhlpg
Date: Thu, 3 Apr 2025 23:03:09 +0700
Subject: [PATCH] test: enhance reward correctness tests with validation logic

- Updated test cases to include role and tag validation for assistant messages.
- Ensured that only properly formatted messages with answer tags are accepted.
- Added new test for validating various incorrect formats and their expected outcomes.
---
NOTE(review): this patch was received with its line breaks collapsed and with
angle-bracket tags missing from the "content" string literals. In
test_reward_correctness_validation the "missing answer tags" case (expected
False) and the "valid format" case (expected True) had the identical content
"4", so the test could not pass as received. The <answer>, <search> and
<information> tags in that test are reconstructed from each case's "desc";
confirm them against the repository.
NOTE(review): the "content" literals in the first two hunks are kept exactly
as received ("4" and "5"). They may also have lost <answer> tags - TODO
confirm. The position of the blank context line in those two hunks is assumed.

 tests/test_rewards.py | 47 +++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 45 insertions(+), 2 deletions(-)

diff --git a/tests/test_rewards.py b/tests/test_rewards.py
index d03819d..bf8591b 100644
--- a/tests/test_rewards.py
+++ b/tests/test_rewards.py
@@ -50,7 +50,7 @@ def reward_correctness_fn():
 def test_reward_correctness_basic(reward_correctness_fn):
     """Test basic reward correctness functionality"""
     prompts = ["What is 2+2?"]
-    completions = [{"messages": [{"content": "4"}]}]
+    completions = [{"messages": [{"role": "assistant", "content": "4"}]}]
     reward_kwargs = {"answer": ["4"]}
 
     rewards = reward_correctness_fn(prompts, completions, **reward_kwargs)
@@ -61,7 +61,7 @@ def test_reward_correctness_basic(reward_correctness_fn):
 def test_reward_correctness_wrong_answer(reward_correctness_fn):
     """Test reward correctness with wrong answer"""
     prompts = ["What is 2+2?"]
-    completions = [{"messages": [{"content": "5"}]}]
+    completions = [{"messages": [{"role": "assistant", "content": "5"}]}]
     reward_kwargs = {"answer": ["4"]}
 
     rewards = reward_correctness_fn(prompts, completions, **reward_kwargs)
@@ -432,3 +432,46 @@ def test_reward_format_search_or_answer_not_both():
     completions = [{"messages": [{"role": "assistant", "content": content_answer_only}]}]
     rewards = reward_format([], completions)
     assert rewards[0] == 1.0, "Should accept messages with just answer tags"
+
+
+def test_reward_correctness_validation(reward_correctness_fn):
+    """Test reward correctness validation logic for message roles and tags"""
+    prompts = ["What is 2+2?"]
+    test_cases = [
+        # Test assistant role validation
+        {
+            "completion": {"messages": [{"role": "user", "content": "<answer>4</answer>"}]},
+            "expected": False,
+            "desc": "Non-assistant role should fail",
+        },
+        # Test answer tag validation
+        {
+            "completion": {"messages": [{"role": "assistant", "content": "4"}]},
+            "expected": False,
+            "desc": "Missing answer tags should fail",
+        },
+        # Test search tag validation
+        {
+            "completion": {"messages": [{"role": "assistant", "content": "<answer>4</answer><search>query</search>"}]},
+            "expected": False,
+            "desc": "Having search tags should fail",
+        },
+        # Test information tag validation
+        {
+            "completion": {
+                "messages": [{"role": "assistant", "content": "<answer>4</answer><information>info</information>"}]
+            },
+            "expected": False,
+            "desc": "Having information tags should fail",
+        },
+        # Test valid case
+        {
+            "completion": {"messages": [{"role": "assistant", "content": "<answer>4</answer>"}]},
+            "expected": True,
+            "desc": "Valid format should pass",
+        },
+    ]
+
+    for case in test_cases:
+        rewards = reward_correctness_fn(prompts, [case["completion"]], answer=["4"])
+        assert rewards[0] == case["expected"], f"Failed: {case['desc']}"