@ -50,7 +50,7 @@ def reward_correctness_fn():
def test_reward_correctness_basic ( reward_correctness_fn ) :
def test_reward_correctness_basic ( reward_correctness_fn ) :
""" Test basic reward correctness functionality """
""" Test basic reward correctness functionality """
prompts = [ " What is 2+2? " ]
prompts = [ " What is 2+2? " ]
completions = [ { " messages " : [ { " content" : " 4" } ] } ]
completions = [ { " messages " : [ { " role" : " assistant " , " content" : " <answer> 4</answer> " } ] } ]
reward_kwargs = { " answer " : [ " 4 " ] }
reward_kwargs = { " answer " : [ " 4 " ] }
rewards = reward_correctness_fn ( prompts , completions , * * reward_kwargs )
rewards = reward_correctness_fn ( prompts , completions , * * reward_kwargs )
@ -61,7 +61,7 @@ def test_reward_correctness_basic(reward_correctness_fn):
def test_reward_correctness_wrong_answer ( reward_correctness_fn ) :
def test_reward_correctness_wrong_answer ( reward_correctness_fn ) :
""" Test reward correctness with wrong answer """
""" Test reward correctness with wrong answer """
prompts = [ " What is 2+2? " ]
prompts = [ " What is 2+2? " ]
completions = [ { " messages " : [ { " content" : " 5" } ] } ]
completions = [ { " messages " : [ { " role" : " assistant " , " content" : " <answer> 5</answer> " } ] } ]
reward_kwargs = { " answer " : [ " 4 " ] }
reward_kwargs = { " answer " : [ " 4 " ] }
rewards = reward_correctness_fn ( prompts , completions , * * reward_kwargs )
rewards = reward_correctness_fn ( prompts , completions , * * reward_kwargs )
@ -432,3 +432,46 @@ def test_reward_format_search_or_answer_not_both():
completions = [ { " messages " : [ { " role " : " assistant " , " content " : content_answer_only } ] } ]
completions = [ { " messages " : [ { " role " : " assistant " , " content " : content_answer_only } ] } ]
rewards = reward_format ( [ ] , completions )
rewards = reward_format ( [ ] , completions )
assert rewards [ 0 ] == 1.0 , " Should accept messages with just answer tags "
assert rewards [ 0 ] == 1.0 , " Should accept messages with just answer tags "
def test_reward_correctness_validation ( reward_correctness_fn ) :
""" Test reward correctness validation logic for message roles and tags """
prompts = [ " What is 2+2? " ]
test_cases = [
# Test assistant role validation
{
" completion " : { " messages " : [ { " role " : " user " , " content " : " <answer>4</answer> " } ] } ,
" expected " : False ,
" desc " : " Non-assistant role should fail " ,
} ,
# Test answer tag validation
{
" completion " : { " messages " : [ { " role " : " assistant " , " content " : " 4 " } ] } ,
" expected " : False ,
" desc " : " Missing answer tags should fail " ,
} ,
# Test search tag validation
{
" completion " : { " messages " : [ { " role " : " assistant " , " content " : " <answer>4</answer><search>query</search> " } ] } ,
" expected " : False ,
" desc " : " Having search tags should fail " ,
} ,
# Test information tag validation
{
" completion " : {
" messages " : [ { " role " : " assistant " , " content " : " <answer>4</answer><information>info</information> " } ]
} ,
" expected " : False ,
" desc " : " Having information tags should fail " ,
} ,
# Test valid case
{
" completion " : { " messages " : [ { " role " : " assistant " , " content " : " <answer>4</answer> " } ] } ,
" expected " : True ,
" desc " : " Valid format should pass " ,
} ,
]
for case in test_cases :
rewards = reward_correctness_fn ( prompts , [ case [ " completion " ] ] , answer = [ " 4 " ] )
assert rewards [ 0 ] == case [ " expected " ] , f " Failed: { case [ ' desc ' ] } "