feat: enhance reward_em_chunk function to match multiple paragraphs, add test

main
thinhlpg 4 weeks ago
parent 2df9f39fda
commit 358875a035

@@ -347,167 +347,83 @@ def reward_retry(prompts: list, completions: list, **reward_kwargs) -> list:
def reward_em_chunk(prompts: list, completions: list, **reward_kwargs) -> list:
    """Reward function checks if model's search results contain all necessary supporting paragraphs.

    Args:
        prompts: List of input prompts
        completions: List of completion dictionaries with messages
        **reward_kwargs: Additional reward parameters including:
            - supporting_paragraphs: List of lists of correct paragraph strings to match against.
              Each inner list corresponds to a completion.

    Returns:
        list: List of rewards (1.0 if all paragraphs found, 0.0 otherwise)

    Raises:
        ValueError: If supporting_paragraphs is not provided or invalid.
    """
    logger.debug(f"Calculating 'em_chunk' rewards for {len(prompts)} prompts")

    # Get correct supporting paragraphs lists from reward kwargs
    all_supporting_paragraphs = reward_kwargs.get("supporting_paragraphs", [])
    if not all_supporting_paragraphs or not isinstance(all_supporting_paragraphs, list):
        logger.error("No 'supporting_paragraphs' list provided or it's invalid in reward_kwargs")
        raise ValueError("supporting_paragraphs (list of lists) must be provided in reward_kwargs")
    if len(all_supporting_paragraphs) != len(completions):
        logger.error(
            f"Mismatch between completions ({len(completions)}) and supporting_paragraphs ({len(all_supporting_paragraphs)})"
        )
        raise ValueError("Length of supporting_paragraphs must match length of completions")

    rewards = []
    all_found_statuses = []  # Track found status for each paragraph

    for i, (completion, required_paragraphs) in enumerate(zip(completions, all_supporting_paragraphs)):
        if not isinstance(required_paragraphs, list):
            logger.warning(
                f"supporting_paragraphs for completion {i} is not a list, skipping. Got: {type(required_paragraphs)}"
            )
            rewards.append(0.0)
            all_found_statuses.append({p: False for p in required_paragraphs or []})  # Handle if None/empty
            continue

        # Get all content from messages starting with <information> (from any role)
        search_results_content = [
            msg["content"].strip()
            for msg in completion.get("messages", [])
            # Check role and prefix - allow 'user' role for potential manual info injection?
            if msg.get("role") in ("ipython", "user", "tool")
            and msg.get("content", "").strip().startswith("<information>")
        ]

        # Combine all found information into a single text block for easier searching
        combined_information = "\n".join(search_results_content)

        # Check if *all* required paragraphs are present in the combined information
        found_status = {p: (p in combined_information) for p in required_paragraphs}
        all_paragraphs_found = all(found_status.values())

        # Debug logging
        if not all_paragraphs_found:
            missing_paragraphs = [p for p, found in found_status.items() if not found]
            logger.debug(
                f"Failed to find all required paragraphs for prompt {i}.\n"
                f"Required: {len(required_paragraphs)}, Found: {len(required_paragraphs) - len(missing_paragraphs)}\n"
                f"Missing paragraphs: {[p[:100] + '...' for p in missing_paragraphs]}\n"
                f"Searched in: {combined_information[:200] + '...'}"
            )

        reward = 1.0 if all_paragraphs_found else 0.0
        rewards.append(reward)
        all_found_statuses.append(found_status)
        logger.debug(f"Reward for prompt {i}: {reward}")

    # Log chat state
    log_chat_state(
        prompts=prompts,
        completions=completions,
        rewards=rewards,
        reward_type="em_chunk",
        supporting_paragraphs=all_supporting_paragraphs,
        found_paragraph_statuses=all_found_statuses,  # Log which paragraphs were found
    )

    return rewards

@ -299,48 +299,103 @@ def test_reward_retry():
def test_reward_em_chunk():
    """Test exact match reward function when all paragraphs are found."""
    prompts = ["What is Python?"]
    completions = [
        {
            "messages": [
                {"role": "user", "content": "<information>Python is a programming language</information>"},
                {"role": "tool", "content": "<information>It is widely used.</information>"},
            ]
        }
    ]
    # Expecting a list of lists for supporting_paragraphs
    required_paragraphs = [["Python is a programming language", "It is widely used."]]

    rewards = reward_em_chunk(prompts, completions, supporting_paragraphs=required_paragraphs)

    assert len(rewards) == 1
    assert rewards[0] == 1.0, "Should give full reward when all required paragraphs are found"
def test_reward_em_chunk_partial_match():
    """Test exact match reward function when only some paragraphs are found."""
    prompts = ["What is Python?"]
    completions = [
        {
            "messages": [
                {"role": "user", "content": "<information>Python is a programming language</information>"}
                # Missing the second paragraph
            ]
        }
    ]
    required_paragraphs = [["Python is a programming language", "It is widely used."]]

    rewards = reward_em_chunk(prompts, completions, supporting_paragraphs=required_paragraphs)

    assert len(rewards) == 1
    assert rewards[0] == 0.0, "Should give zero reward if not all required paragraphs are found"
def test_reward_em_chunk_no_supporting_paragraphs_kwarg():
    """Test reward EM chunk with no supporting_paragraphs kwarg provided."""
    completions = [{"messages": [{"role": "ipython", "content": "<information>Some content</information>"}]}]

    # Updated match string to look for 'supporting_paragraphs'
    with pytest.raises(ValueError, match="supporting_paragraphs .+ must be provided"):
        # Pass empty list for prompts, as it's not used in the error check path
        reward_em_chunk([], completions)
def test_reward_em_chunk_multiple_completions():
    """Test reward EM chunk with multiple completions and varying paragraph requirements."""
    prompts = ["Prompt 1", "Prompt 2", "Prompt 3"]
    completions = [
        # Completion 1: Matches both required paragraphs
        {
            "messages": [
                {"role": "ipython", "content": "<information>First paragraph content</information>"},
                {"role": "user", "content": "<information>Second paragraph content</information>"},
            ]
        },
        # Completion 2: Matches only one of the required paragraphs
        {"messages": [{"role": "user", "content": "<information>Third paragraph content</information>"}]},
        # Completion 3: Matches all (only one required)
        {"messages": [{"role": "tool", "content": "<information>Fourth paragraph</information>"}]},
    ]
    # List of lists for supporting paragraphs
    reward_kwargs = {
        "supporting_paragraphs": [
            ["First paragraph content", "Second paragraph content"],  # Requires 2, gets 2 -> 1.0
            ["Third paragraph content", "Missing paragraph"],  # Requires 2, gets 1 -> 0.0
            ["Fourth paragraph"],  # Requires 1, gets 1 -> 1.0
        ]
    }

    rewards = reward_em_chunk(prompts, completions, **reward_kwargs)

    assert len(rewards) == 3
    assert rewards == [1.0, 0.0, 1.0], "Should reward based on finding *all* required paragraphs per completion"
def test_reward_em_chunk_whitespace_handling():
    """Test reward EM chunk handles whitespace properly when checking paragraphs."""
    prompts = ["Whitespace test"]
    completions = [
        {
            "messages": [
                {"role": "ipython", "content": " <information> Paragraph with spaces. </information> "},
                {"role": "user", "content": "<information>\nAnother one.\t</information>"},
            ]
        }
    ]
    # The check is simple `paragraph in combined_information`.
    # combined_information joins stripped content with '\n'.
    # The required paragraphs must match exactly what's expected in combined_information.
    required_paragraphs = [["Paragraph with spaces.", "Another one."]]
    reward_kwargs = {"supporting_paragraphs": required_paragraphs}

    rewards = reward_em_chunk(prompts, completions, **reward_kwargs)

    assert len(rewards) == 1
    # The function joins stripped content with newline, so "Paragraph with spaces." and "Another one." should be found.
    assert rewards[0] == 1.0, "Should handle whitespace in content and still match exact paragraphs"
def test_reward_format_search_or_answer_not_both():

Loading…
Cancel
Save