@@ -347,167 +347,83 @@ def reward_retry(prompts: list, completions: list, **reward_kwargs) -> list:
 
 
 def reward_em_chunk(prompts: list, completions: list, **reward_kwargs) -> list:
-    """Reward function that checks if model's search queries hit the correct chunk content.
+    """Reward function that checks if the model's search results contain all necessary supporting paragraphs.
 
     Args:
         prompts: List of input prompts
         completions: List of completion dictionaries with messages
         **reward_kwargs: Additional reward parameters including:
-            - chunk_content: List of correct chunk contents to match against
-            - step: Optional step number for logging metrics
+            - supporting_paragraphs: List of lists of correct paragraph strings to match against.
+              Each inner list corresponds to a completion.
 
     Returns:
-        list: List of rewards (1.0 for exact match, 0.0 otherwise)
+        list: List of rewards (1.0 if all paragraphs found, 0.0 otherwise)
 
     Raises:
-        ValueError: If chunk_content is not provided in reward_kwargs
+        ValueError: If supporting_paragraphs is not provided or invalid.
     """
-    logger.debug(f"Calculating rewards for {len(prompts)} prompts")
-
-    # Get correct chunk contents from reward kwargs
-    correct_contents = reward_kwargs.get("chunk_content", [])
-    if not correct_contents:
-        logger.error("No chunk_content provided in reward_kwargs")
-        raise ValueError("chunk_content must be provided in reward_kwargs")
+    logger.debug(f"Calculating 'em_chunk' rewards for {len(prompts)} prompts")
+
+    # Get correct supporting paragraphs lists from reward kwargs
+    all_supporting_paragraphs = reward_kwargs.get("supporting_paragraphs", [])
+    if not all_supporting_paragraphs or not isinstance(all_supporting_paragraphs, list):
+        logger.error("No 'supporting_paragraphs' list provided or it's invalid in reward_kwargs")
+        raise ValueError("supporting_paragraphs (list of lists) must be provided in reward_kwargs")
+    if len(all_supporting_paragraphs) != len(completions):
+        logger.error(
+            f"Mismatch between completions ({len(completions)}) and supporting_paragraphs ({len(all_supporting_paragraphs)})"
+        )
+        raise ValueError("Length of supporting_paragraphs must match length of completions")
 
     rewards = []
-    for i, (completion, correct_content) in enumerate(zip(completions, correct_contents)):
-        # Get all messages from ipython or user roles that start with <information>
-        search_results = [
-            msg["content"]
-            for msg in completion["messages"]
-            if msg["role"] in ("ipython", "user") and msg["content"].strip().startswith("<information>")
-        ]
-        logger.debug(f"Found {len(search_results)} search results for prompt {i}")
-
-        # Log ground truth and searched chunks for debugging
-        logger.info(f"📝 Ground Truth Chunk: {correct_content}")
-        for j, result in enumerate(search_results):
-            logger.info(f"🔍 Searched Chunk {j + 1}: {result}")
-
-        # Check if any search hit the correct chunk content
-        found_correct_chunk = any(correct_content in result for result in search_results)
-
-        if not found_correct_chunk:
-            logger.warning(
-                f"Failed to find correct chunk for prompt {i}:\n"
-                f"Search results: {[r[:100] + '...' for r in search_results]}"
-            )
-
-        reward = 1.0 if found_correct_chunk else 0.0
-        rewards.append(reward)
-        logger.debug(f"Reward for prompt {i}: {reward}")
-
-    # Log summary metrics
-    logger.info("Chunk Query Rewards Summary:")
-    logger.info(f"Total prompts: {len(prompts)}")
-    logger.info(f"Correct matches: {sum(rewards)}")
-    logger.info(f"Average reward: {np.mean(rewards):.3f}")
-    logger.info(f"Reward std: {np.std(rewards):.3f}")
-
-    # Log chat state
-    log_chat_state(
-        prompts=prompts,
-        completions=completions,
-        rewards=rewards,
-        reward_type="em_chunk",
-        correct_contents=correct_contents,
-    )
-
-    return rewards
-
-
-def tag_count_reward(prompts: list, completions: list, **reward_kwargs) -> list:
-    """Reward function that checks for proper tag counts in the conversation.
-
-    Rewards:
-    - 0.1 for each proper pair of think tags in each assistant message
-    - 0.5 for having exactly one pair of answer tags in entire conversation
-    - 0.1 for each proper pair of search tags
-
-    Args:
-        prompts: List of input prompts
-        completions: List of completion dictionaries with messages
-        **reward_kwargs: Additional reward parameters
-
-    Returns:
-        list: List of rewards between 0 and 1
-    """
-    rewards = []
-    validation_results = {
-        "think_pairs_per_msg": [],  # List of lists, each inner list has think pair counts per assistant msg
-        "answer_pairs": [],  # Total answer pairs in conversation
-        "search_pairs": [],  # Total search pairs in conversation
-    }
-
-    for completion in completions:
-        # Get all assistant messages
-        assistant_msgs = [msg["content"] for msg in completion["messages"] if msg["role"] == "assistant"]
-
-        if not assistant_msgs:
-            rewards.append(0.0)
-            validation_results["think_pairs_per_msg"].append([])
-            validation_results["answer_pairs"].append(0)
-            validation_results["search_pairs"].append(0)
-            continue
-
-        # Count think pairs per assistant message
-        think_pairs_per_msg = []
-        for msg in assistant_msgs:
-            # Count complete think tag pairs
-            think_opens = len(re.findall(r"<think>", msg))
-            think_closes = len(re.findall(r"</think>", msg))
-            think_pairs = min(think_opens, think_closes)
-            think_pairs_per_msg.append(think_pairs)
-
-        # Count answer tags in entire conversation (should be exactly one pair)
-        total_answer_opens = sum(msg.count("<answer>") for msg in assistant_msgs)
-        total_answer_closes = sum(msg.count("</answer>") for msg in assistant_msgs)
-        answer_pairs = min(total_answer_opens, total_answer_closes)
-
-        # Count search tags
-        total_search_opens = sum(msg.count("<search>") for msg in assistant_msgs)
-        total_search_closes = sum(msg.count("</search>") for msg in assistant_msgs)
-        search_pairs = min(total_search_opens, total_search_closes)
-
-        # Calculate reward components
-        think_reward = sum(min(pairs, 1) * 0.1 for pairs in think_pairs_per_msg)  # 0.1 per msg with proper think pair
-        answer_reward = 0.5 if answer_pairs == 1 else 0.0  # 0.5 for exactly one answer pair
-        search_reward = min(search_pairs, 1) * 0.1  # 0.1 for having search pairs
-
-        total_reward = min(think_reward + answer_reward + search_reward, 1.0)
-        rewards.append(total_reward)
-
-        # Store validation results
-        validation_results["think_pairs_per_msg"].append(think_pairs_per_msg)
-        validation_results["answer_pairs"].append(answer_pairs)
-        validation_results["search_pairs"].append(search_pairs)
-
-        # Debug logging
-        if total_reward < 1.0:
-            logger.debug(
-                f"Tag count issues - think_pairs: {think_pairs_per_msg}, "
-                f"answer_pairs: {answer_pairs}, search_pairs: {search_pairs}"
-            )
-
-    # Log metrics
-    logger.info(
-        f"Tag count reward metrics - Mean: {np.mean(rewards):.3f}, Perfect scores: {sum(r == 1.0 for r in rewards)}/{len(rewards)}"
-    )
-    logger.info(
-        f"Average think pairs per message: {np.mean([np.mean(pairs) if pairs else 0 for pairs in validation_results['think_pairs_per_msg']]):.2f}"
-    )
-    logger.info(
-        f"Conversations with exactly one answer pair: {sum(pairs == 1 for pairs in validation_results['answer_pairs'])}/{len(rewards)}"
-    )
-
-    # Log chat state
-    log_chat_state(
-        prompts=prompts,
-        completions=completions,
-        rewards=rewards,
-        reward_type="tag_count",
-        validation_results=validation_results,
-    )
-
-    return rewards
+    all_found_statuses = []  # Track found status for each paragraph
+
+    for i, (completion, required_paragraphs) in enumerate(zip(completions, all_supporting_paragraphs)):
+        if not isinstance(required_paragraphs, list):
+            logger.warning(
+                f"supporting_paragraphs for completion {i} is not a list, skipping. Got: {type(required_paragraphs)}"
+            )
+            rewards.append(0.0)
+            all_found_statuses.append({p: False for p in required_paragraphs or []})  # Handle if None/empty
+            continue
+
+        # Get all content from messages starting with <information> (from any role)
+        search_results_content = [
+            msg["content"].strip()
+            for msg in completion.get("messages", [])
+            # Check role and prefix - allow 'user' role for potential manual info injection?
+            if msg.get("role") in ("ipython", "user", "tool")
+            and msg.get("content", "").strip().startswith("<information>")
+        ]
+        # Combine all found information into a single text block for easier searching
+        combined_information = "\n".join(search_results_content)
+
+        # Check if *all* required paragraphs are present in the combined information
+        found_status = {p: (p in combined_information) for p in required_paragraphs}
+        all_paragraphs_found = all(found_status.values())
+
+        if not all_paragraphs_found:
+            missing_paragraphs = [p for p, found in found_status.items() if not found]
+            logger.debug(
+                f"Failed to find all required paragraphs for prompt {i}.\n"
+                f"Required: {len(required_paragraphs)}, Found: {len(required_paragraphs) - len(missing_paragraphs)}\n"
+                f"Missing paragraphs: {[p[:100] + '...' for p in missing_paragraphs]}\n"
+                f"Searched in: {combined_information[:200] + '...'}"
+            )
+
+        reward = 1.0 if all_paragraphs_found else 0.0
+        rewards.append(reward)
+        all_found_statuses.append(found_status)
+        logger.debug(f"Reward for prompt {i}: {reward}")
+
+    # Log chat state
+    log_chat_state(
+        prompts=prompts,
+        completions=completions,
+        rewards=rewards,
+        reward_type="em_chunk",
+        supporting_paragraphs=all_supporting_paragraphs,
+        found_paragraph_statuses=all_found_statuses,  # Log which paragraphs were found
+    )
+
+    return rewards
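
Below is a minimal usage sketch of the updated reward_em_chunk contract. The import path, prompt, and paragraph strings are hypothetical, chosen only to illustrate the expected input shapes rather than taken from the repository:

    # Hypothetical import path; adjust to wherever reward_em_chunk is defined in this repo.
    from rewards import reward_em_chunk

    completions = [
        {
            "messages": [
                {"role": "assistant", "content": "<search>capital of France</search>"},
                {"role": "ipython", "content": "<information>Paris is the capital of France.</information>"},
                {"role": "assistant", "content": "<answer>Paris</answer>"},
            ]
        }
    ]

    # supporting_paragraphs is a list of lists: one list of required paragraphs per completion.
    rewards = reward_em_chunk(
        prompts=["What is the capital of France?"],
        completions=completions,
        supporting_paragraphs=[["Paris is the capital of France."]],
    )
    # rewards == [1.0] only when every required paragraph appears verbatim in the combined <information> text.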