feat: update reward_em_chunk to match only the LAST required paragraph of the reasoning chain and adjust related tests

main
thinhlpg 4 weeks ago
parent 358875a035
commit 504f0c6c8e

@ -347,7 +347,7 @@ def reward_retry(prompts: list, completions: list, **reward_kwargs) -> list:
def reward_em_chunk(prompts: list, completions: list, **reward_kwargs) -> list:
"""Reward function checks if model's search results contain all necessary supporting paragraphs.
"""Reward function checks if model's search results contain the LAST necessary supporting paragraph.
Args:
prompts: List of input prompts
@ -357,7 +357,7 @@ def reward_em_chunk(prompts: list, completions: list, **reward_kwargs) -> list:
Each inner list corresponds to a completion.
Returns:
list: List of rewards (1.0 if all paragraphs found, 0.0 otherwise)
list: List of rewards (1.0 if the LAST paragraph is found, 0.0 otherwise)
Raises:
ValueError: If supporting_paragraphs is not provided or invalid.
@ -376,7 +376,7 @@ def reward_em_chunk(prompts: list, completions: list, **reward_kwargs) -> list:
raise ValueError("Length of supporting_paragraphs must match length of completions")
rewards = []
all_found_statuses = [] # Track found status for each paragraph
all_found_statuses = [] # Track found status for the last paragraph
for i, (completion, required_paragraphs) in enumerate(zip(completions, all_supporting_paragraphs)):
if not isinstance(required_paragraphs, list):
@ -384,9 +384,20 @@ def reward_em_chunk(prompts: list, completions: list, **reward_kwargs) -> list:
f"supporting_paragraphs for completion {i} is not a list, skipping. Got: {type(required_paragraphs)}"
)
rewards.append(0.0)
all_found_statuses.append({p: False for p in required_paragraphs or []}) # Handle if None/empty
# Still log status for all potential paragraphs for consistency, even if we only check the last
all_found_statuses.append({p: False for p in required_paragraphs or []})
continue
# Check if the required_paragraphs list is empty - should ideally not happen but good practice
if not required_paragraphs:
logger.warning(f"Empty required_paragraphs list for completion {i}, assigning 0.0 reward.")
rewards.append(0.0)
all_found_statuses.append({})
continue
# Get the last required paragraph
last_required_paragraph = required_paragraphs[-1]
# Get all content from messages starting with <information> (from any role)
search_results_content = [
msg["content"].strip()
@ -398,32 +409,30 @@ def reward_em_chunk(prompts: list, completions: list, **reward_kwargs) -> list:
# Combine all found information into a single text block for easier searching
combined_information = "\n".join(search_results_content)
# Check if *all* required paragraphs are present in the combined information
found_status = {p: (p in combined_information) for p in required_paragraphs}
all_paragraphs_found = all(found_status.values())
# Check if the LAST required paragraph is present in the combined information
last_paragraph_found = last_required_paragraph in combined_information
if not all_paragraphs_found:
missing_paragraphs = [p for p, found in found_status.items() if not found]
if not last_paragraph_found:
logger.debug(
f"Failed to find all required paragraphs for prompt {i}.\n"
f"Required: {len(required_paragraphs)}, Found: {len(required_paragraphs) - len(missing_paragraphs)}\n"
f"Missing paragraphs: {[p[:100] + '...' for p in missing_paragraphs]}\n"
f"Failed to find the required LAST paragraph for prompt {i}.\n"
f"Required last paragraph: {last_required_paragraph[:100] + '...'}\n"
f"Searched in: {combined_information[:200] + '...'}"
)
reward = 1.0 if all_paragraphs_found else 0.0
reward = 1.0 if last_paragraph_found else 0.0
rewards.append(reward)
all_found_statuses.append(found_status)
logger.debug(f"Reward for prompt {i}: {reward}")
# Log the found status specifically for the last paragraph
all_found_statuses.append({last_required_paragraph: last_paragraph_found})
logger.debug(f"Reward for prompt {i} (based on last paragraph): {reward}")
# Log chat state
log_chat_state(
prompts=prompts,
completions=completions,
rewards=rewards,
reward_type="em_chunk",
reward_type="em_chunk_last", # Changed type to reflect new logic
supporting_paragraphs=all_supporting_paragraphs,
found_paragraph_statuses=all_found_statuses, # Log which paragraphs were found
found_last_paragraph_statuses=all_found_statuses, # Log which last paragraphs were found
)
return rewards

@ -299,7 +299,7 @@ def test_reward_retry():
def test_reward_em_chunk():
"""Test exact match reward function when all paragraphs are found."""
"""Test reward function when the LAST required paragraph is found."""
prompts = ["What is Python?"]
completions = [
{
@ -309,22 +309,21 @@ def test_reward_em_chunk():
]
}
]
# Expecting a list of lists for supporting_paragraphs
required_paragraphs = [["Python is a programming language", "It is widely used."]]
rewards = reward_em_chunk(prompts, completions, supporting_paragraphs=required_paragraphs)
assert len(rewards) == 1
assert rewards[0] == 1.0, "Should give full reward when all required paragraphs are found"
assert rewards[0] == 1.0, "Should give full reward when the LAST paragraph is found (even if others are too)"
def test_reward_em_chunk_partial_match():
"""Test exact match reward function when only some paragraphs are found."""
"""Test reward function when the LAST required paragraph is MISSING."""
prompts = ["What is Python?"]
completions = [
{
"messages": [
{"role": "user", "content": "<information>Python is a programming language</information>"}
# Missing the second paragraph
# Missing the last paragraph "It is widely used."
]
}
]
@ -332,7 +331,7 @@ def test_reward_em_chunk_partial_match():
rewards = reward_em_chunk(prompts, completions, supporting_paragraphs=required_paragraphs)
assert len(rewards) == 1
assert rewards[0] == 0.0, "Should give zero reward if not all required paragraphs are found"
assert rewards[0] == 0.0, "Should give zero reward if the LAST required paragraph is missing"
def test_reward_em_chunk_no_supporting_paragraphs_kwarg():
@ -346,56 +345,56 @@ def test_reward_em_chunk_no_supporting_paragraphs_kwarg():
def test_reward_em_chunk_multiple_completions():
"""Test reward EM chunk with multiple completions and varying paragraph requirements."""
"""Test reward function with multiple completions based on the LAST paragraph rule."""
prompts = ["Prompt 1", "Prompt 2", "Prompt 3"]
completions = [
# Completion 1: Matches both required paragraphs
# Completion 1: Finds the last required paragraph ("Second paragraph content")
{
"messages": [
{"role": "ipython", "content": "<information>First paragraph content</information>"},
{"role": "user", "content": "<information>Second paragraph content</information>"},
]
},
# Completion 2: Matches only one of the required paragraphs
# Completion 2: Does NOT find the last required paragraph ("Missing paragraph")
{"messages": [{"role": "user", "content": "<information>Third paragraph content</information>"}]},
# Completion 3: Matches all (only one required)
# Completion 3: Finds the last (and only) required paragraph ("Fourth paragraph")
{"messages": [{"role": "tool", "content": "<information>Fourth paragraph</information>"}]},
]
# List of lists for supporting paragraphs
reward_kwargs = {
"supporting_paragraphs": [
["First paragraph content", "Second paragraph content"], # Requires 2, gets 2 -> 1.0
["Third paragraph content", "Missing paragraph"], # Requires 2, gets 1 -> 0.0
["Fourth paragraph"], # Requires 1, gets 1 -> 1.0
["First paragraph content", "Second paragraph content"], # Last is found -> 1.0
["Third paragraph content", "Missing paragraph"], # Last is missing -> 0.0
["Fourth paragraph"], # Last is found -> 1.0
]
}
rewards = reward_em_chunk(prompts, completions, **reward_kwargs)
assert len(rewards) == 3
assert rewards == [1.0, 0.0, 1.0], "Should reward based on finding *all* required paragraphs per completion"
assert rewards == [1.0, 0.0, 1.0], "Should reward based ONLY on finding the LAST required paragraph per completion"
def test_reward_em_chunk_whitespace_handling():
"""Test reward EM chunk handles whitespace properly when checking paragraphs."""
"""Test reward function handles whitespace properly when checking the LAST paragraph."""
prompts = ["Whitespace test"]
completions = [
{
"messages": [
{"role": "ipython", "content": " <information> Paragraph with spaces. </information> "},
{"role": "user", "content": "<information>\\nAnother one.\\t</information>"},
# This is the last paragraph and should be found despite leading/trailing spaces in the info block
{"role": "user", "content": "<information> Another one. </information>"},
]
}
]
# The check is simple `paragraph in combined_information`.
# combined_information joins stripped content with '\\n'.
# The required paragraphs must match exactly what's expected in combined_information.
required_paragraphs = [["Paragraph with spaces.", "Another one."]]
# The required paragraphs list
required_paragraphs = [
["Paragraph with spaces.", "Another one."] # We only care about "Another one."
]
reward_kwargs = {"supporting_paragraphs": required_paragraphs}
rewards = reward_em_chunk(prompts, completions, **reward_kwargs)
assert len(rewards) == 1
# The function joins stripped content with newline, so "Paragraph with spaces." and "Another one." should be found.
assert rewards[0] == 1.0, "Should handle whitespace in content and still match exact paragraphs"
# The last paragraph "Another one." should be found in the combined, stripped content.
assert rewards[0] == 1.0, "Should correctly match the LAST paragraph even with whitespace issues"
def test_reward_format_search_or_answer_not_both():

Loading…
Cancel
Save