From 358875a0353237531615ac23ef11c16862ccddd3 Mon Sep 17 00:00:00 2001 From: thinhlpg Date: Fri, 11 Apr 2025 17:21:51 +0000 Subject: [PATCH] feat: enhance reward_em_chunk function to match multiple paragraphs, add test --- src/rewards.py | 180 +++++++++++------------------------------- tests/test_rewards.py | 97 ++++++++++++++++++----- 2 files changed, 124 insertions(+), 153 deletions(-) diff --git a/src/rewards.py b/src/rewards.py index 9fce68f..c6c1dfa 100644 --- a/src/rewards.py +++ b/src/rewards.py @@ -347,167 +347,83 @@ def reward_retry(prompts: list, completions: list, **reward_kwargs) -> list: def reward_em_chunk(prompts: list, completions: list, **reward_kwargs) -> list: - """Reward function that checks if model's search queries hit the correct chunk content. + """Reward function checks if model's search results contain all necessary supporting paragraphs. Args: prompts: List of input prompts completions: List of completion dictionaries with messages **reward_kwargs: Additional reward parameters including: - - chunk_content: List of correct chunk contents to match against - - step: Optional step number for logging metrics + - supporting_paragraphs: List of lists of correct paragraph strings to match against. + Each inner list corresponds to a completion. Returns: - list: List of rewards (1.0 for exact match, 0.0 otherwise) + list: List of rewards (1.0 if all paragraphs found, 0.0 otherwise) Raises: - ValueError: If chunk_content is not provided in reward_kwargs + ValueError: If supporting_paragraphs is not provided or invalid. """ - logger.debug(f"Calculating rewards for {len(prompts)} prompts") - - # Get correct chunk contents from reward kwargs - correct_contents = reward_kwargs.get("chunk_content", []) - if not correct_contents: - logger.error("No chunk_content provided in reward_kwargs") - raise ValueError("chunk_content must be provided in reward_kwargs") + logger.debug(f"Calculating 'em_chunk' rewards for {len(prompts)} prompts") + + # Get correct supporting paragraphs lists from reward kwargs + all_supporting_paragraphs = reward_kwargs.get("supporting_paragraphs", []) + if not all_supporting_paragraphs or not isinstance(all_supporting_paragraphs, list): + logger.error("No 'supporting_paragraphs' list provided or it's invalid in reward_kwargs") + raise ValueError("supporting_paragraphs (list of lists) must be provided in reward_kwargs") + if len(all_supporting_paragraphs) != len(completions): + logger.error( + f"Mismatch between completions ({len(completions)}) and supporting_paragraphs ({len(all_supporting_paragraphs)})" + ) + raise ValueError("Length of supporting_paragraphs must match length of completions") rewards = [] - for i, (completion, correct_content) in enumerate(zip(completions, correct_contents)): - # Get all messages from ipython or user roles that start with - search_results = [ - msg["content"] - for msg in completion["messages"] - if msg["role"] in ("ipython", "user") and msg["content"].strip().startswith("") - ] - logger.debug(f"Found {len(search_results)} search results for prompt {i}") - - # Log ground truth and searched chunks for debugging - logger.info(f"📝 Ground Truth Chunk: {correct_content}") - for j, result in enumerate(search_results): - logger.info(f"🔍 Searched Chunk {j + 1}: {result}") + all_found_statuses = [] # Track found status for each paragraph - # Check if any search hit the correct chunk content - found_correct_chunk = any(correct_content in result for result in search_results) - - if not found_correct_chunk: + for i, (completion, required_paragraphs) in enumerate(zip(completions, all_supporting_paragraphs)): + if not isinstance(required_paragraphs, list): logger.warning( - f"Failed to find correct chunk for prompt {i}:\n" - f"Search results: {[r[:100] + '...' for r in search_results]}" + f"supporting_paragraphs for completion {i} is not a list, skipping. Got: {type(required_paragraphs)}" ) - - reward = 1.0 if found_correct_chunk else 0.0 - rewards.append(reward) - logger.debug(f"Reward for prompt {i}: {reward}") - - # Log summary metrics - logger.info("Chunk Query Rewards Summary:") - logger.info(f"Total prompts: {len(prompts)}") - logger.info(f"Correct matches: {sum(rewards)}") - logger.info(f"Average reward: {np.mean(rewards):.3f}") - logger.info(f"Reward std: {np.std(rewards):.3f}") - - # Log chat state - log_chat_state( - prompts=prompts, - completions=completions, - rewards=rewards, - reward_type="em_chunk", - correct_contents=correct_contents, - ) - - return rewards - - -def tag_count_reward(prompts: list, completions: list, **reward_kwargs) -> list: - """Reward function that checks for proper tag counts in the conversation. - - Rewards: - - 0.1 for each proper pair of think tags in each assistant message - - 0.5 for having exactly one pair of answer tags in entire conversation - - 0.1 for each proper pair of search tags - - Args: - prompts: List of input prompts - completions: List of completion dictionaries with messages - **reward_kwargs: Additional reward parameters - - Returns: - list: List of rewards between 0 and 1 - """ - rewards = [] - validation_results = { - "think_pairs_per_msg": [], # List of lists, each inner list has think pair counts per assistant msg - "answer_pairs": [], # Total answer pairs in conversation - "search_pairs": [], # Total search pairs in conversation - } - - for completion in completions: - # Get all assistant messages - assistant_msgs = [msg["content"] for msg in completion["messages"] if msg["role"] == "assistant"] - - if not assistant_msgs: rewards.append(0.0) - validation_results["think_pairs_per_msg"].append([]) - validation_results["answer_pairs"].append(0) - validation_results["search_pairs"].append(0) + all_found_statuses.append({p: False for p in required_paragraphs or []}) # Handle if None/empty continue - # Count think pairs per assistant message - think_pairs_per_msg = [] - for msg in assistant_msgs: - # Count complete think tag pairs - think_opens = len(re.findall(r"", msg)) - think_closes = len(re.findall(r"", msg)) - think_pairs = min(think_opens, think_closes) - think_pairs_per_msg.append(think_pairs) - - # Count answer tags in entire conversation (should be exactly one pair) - total_answer_opens = sum(msg.count("") for msg in assistant_msgs) - total_answer_closes = sum(msg.count("") for msg in assistant_msgs) - answer_pairs = min(total_answer_opens, total_answer_closes) - - # Count search tags - total_search_opens = sum(msg.count("") for msg in assistant_msgs) - total_search_closes = sum(msg.count("") for msg in assistant_msgs) - search_pairs = min(total_search_opens, total_search_closes) - - # Calculate reward components - think_reward = sum(min(pairs, 1) * 0.1 for pairs in think_pairs_per_msg) # 0.1 per msg with proper think pair - answer_reward = 0.5 if answer_pairs == 1 else 0.0 # 0.5 for exactly one answer pair - search_reward = min(search_pairs, 1) * 0.1 # 0.1 for having search pairs - - total_reward = min(think_reward + answer_reward + search_reward, 1.0) - rewards.append(total_reward) + # Get all content from messages starting with (from any role) + search_results_content = [ + msg["content"].strip() + for msg in completion.get("messages", []) + # Check role and prefix - allow 'user' role for potential manual info injection? + if msg.get("role") in ("ipython", "user", "tool") + and msg.get("content", "").strip().startswith("") + ] + # Combine all found information into a single text block for easier searching + combined_information = "\n".join(search_results_content) - # Store validation results - validation_results["think_pairs_per_msg"].append(think_pairs_per_msg) - validation_results["answer_pairs"].append(answer_pairs) - validation_results["search_pairs"].append(search_pairs) + # Check if *all* required paragraphs are present in the combined information + found_status = {p: (p in combined_information) for p in required_paragraphs} + all_paragraphs_found = all(found_status.values()) - # Debug logging - if total_reward < 1.0: + if not all_paragraphs_found: + missing_paragraphs = [p for p, found in found_status.items() if not found] logger.debug( - f"Tag count issues - think_pairs: {think_pairs_per_msg}, " - f"answer_pairs: {answer_pairs}, search_pairs: {search_pairs}" + f"Failed to find all required paragraphs for prompt {i}.\n" + f"Required: {len(required_paragraphs)}, Found: {len(required_paragraphs) - len(missing_paragraphs)}\n" + f"Missing paragraphs: {[p[:100] + '...' for p in missing_paragraphs]}\n" + f"Searched in: {combined_information[:200] + '...'}" ) - # Log metrics - logger.info( - f"Tag count reward metrics - Mean: {np.mean(rewards):.3f}, Perfect scores: {sum(r == 1.0 for r in rewards)}/{len(rewards)}" - ) - logger.info( - f"Average think pairs per message: {np.mean([np.mean(pairs) if pairs else 0 for pairs in validation_results['think_pairs_per_msg']]):.2f}" - ) - logger.info( - f"Conversations with exactly one answer pair: {sum(pairs == 1 for pairs in validation_results['answer_pairs'])}/{len(rewards)}" - ) + reward = 1.0 if all_paragraphs_found else 0.0 + rewards.append(reward) + all_found_statuses.append(found_status) + logger.debug(f"Reward for prompt {i}: {reward}") # Log chat state log_chat_state( prompts=prompts, completions=completions, rewards=rewards, - reward_type="tag_count", - validation_results=validation_results, + reward_type="em_chunk", + supporting_paragraphs=all_supporting_paragraphs, + found_paragraph_statuses=all_found_statuses, # Log which paragraphs were found ) return rewards diff --git a/tests/test_rewards.py b/tests/test_rewards.py index ad2b9b2..6a77b27 100644 --- a/tests/test_rewards.py +++ b/tests/test_rewards.py @@ -299,48 +299,103 @@ def test_reward_retry(): def test_reward_em_chunk(): - """Test exact match chunk reward function""" + """Test exact match reward function when all paragraphs are found.""" prompts = ["What is Python?"] completions = [ - {"messages": [{"role": "user", "content": "Python is a programming language"}]} + { + "messages": [ + {"role": "user", "content": "Python is a programming language"}, + {"role": "tool", "content": "It is widely used."}, + ] + } + ] + # Expecting a list of lists for supporting_paragraphs + required_paragraphs = [["Python is a programming language", "It is widely used."]] + + rewards = reward_em_chunk(prompts, completions, supporting_paragraphs=required_paragraphs) + assert len(rewards) == 1 + assert rewards[0] == 1.0, "Should give full reward when all required paragraphs are found" + + +def test_reward_em_chunk_partial_match(): + """Test exact match reward function when only some paragraphs are found.""" + prompts = ["What is Python?"] + completions = [ + { + "messages": [ + {"role": "user", "content": "Python is a programming language"} + # Missing the second paragraph + ] + } ] - correct_contents = ["Python is a programming language"] + required_paragraphs = [["Python is a programming language", "It is widely used."]] - rewards = reward_em_chunk(prompts, completions, chunk_content=correct_contents) + rewards = reward_em_chunk(prompts, completions, supporting_paragraphs=required_paragraphs) assert len(rewards) == 1 - assert rewards[0] == 1.0, "Should give full reward for exact chunk match" + assert rewards[0] == 0.0, "Should give zero reward if not all required paragraphs are found" -def test_reward_em_chunk_no_chunk_content(): - """Test reward EM chunk with no chunk content provided""" +def test_reward_em_chunk_no_supporting_paragraphs_kwarg(): + """Test reward EM chunk with no supporting_paragraphs kwarg provided.""" completions = [{"messages": [{"role": "ipython", "content": "Some content"}]}] - with pytest.raises(ValueError, match="chunk_content must be provided"): + # Updated match string to look for 'supporting_paragraphs' + with pytest.raises(ValueError, match="supporting_paragraphs .+ must be provided"): + # Pass empty list for prompts, as it's not used in the error check path reward_em_chunk([], completions) -def test_reward_em_chunk_multiple_chunks(): - """Test reward EM chunk with multiple chunks to match""" +def test_reward_em_chunk_multiple_completions(): + """Test reward EM chunk with multiple completions and varying paragraph requirements.""" + prompts = ["Prompt 1", "Prompt 2", "Prompt 3"] completions = [ - {"messages": [{"role": "ipython", "content": "First chunk content"}]}, - {"messages": [{"role": "user", "content": "Second chunk content"}]}, + # Completion 1: Matches both required paragraphs + { + "messages": [ + {"role": "ipython", "content": "First paragraph content"}, + {"role": "user", "content": "Second paragraph content"}, + ] + }, + # Completion 2: Matches only one of the required paragraphs + {"messages": [{"role": "user", "content": "Third paragraph content"}]}, + # Completion 3: Matches all (only one required) + {"messages": [{"role": "tool", "content": "Fourth paragraph"}]}, ] - reward_kwargs = {"chunk_content": ["First chunk content", "Second chunk content"]} + # List of lists for supporting paragraphs + reward_kwargs = { + "supporting_paragraphs": [ + ["First paragraph content", "Second paragraph content"], # Requires 2, gets 2 -> 1.0 + ["Third paragraph content", "Missing paragraph"], # Requires 2, gets 1 -> 0.0 + ["Fourth paragraph"], # Requires 1, gets 1 -> 1.0 + ] + } - rewards = reward_em_chunk([], completions, **reward_kwargs) - assert len(rewards) == 2 - assert rewards == [1.0, 1.0], "Should get reward 1.0 for each matching chunk" + rewards = reward_em_chunk(prompts, completions, **reward_kwargs) + assert len(rewards) == 3 + assert rewards == [1.0, 0.0, 1.0], "Should reward based on finding *all* required paragraphs per completion" def test_reward_em_chunk_whitespace_handling(): - """Test reward EM chunk handles whitespace properly""" + """Test reward EM chunk handles whitespace properly when checking paragraphs.""" + prompts = ["Whitespace test"] completions = [ - {"messages": [{"role": "ipython", "content": " Content with spaces "}]} + { + "messages": [ + {"role": "ipython", "content": " Paragraph with spaces. "}, + {"role": "user", "content": "\\nAnother one.\\t"}, + ] + } ] - reward_kwargs = {"chunk_content": ["Content with spaces"]} + # The check is simple `paragraph in combined_information`. + # combined_information joins stripped content with '\\n'. + # The required paragraphs must match exactly what's expected in combined_information. + required_paragraphs = [["Paragraph with spaces.", "Another one."]] + reward_kwargs = {"supporting_paragraphs": required_paragraphs} - rewards = reward_em_chunk([], completions, **reward_kwargs) - assert rewards[0] == 1.0, "Should handle whitespace in content and tags" + rewards = reward_em_chunk(prompts, completions, **reward_kwargs) + assert len(rewards) == 1 + # The function joins stripped content with newline, so "Paragraph with spaces." and "Another one." should be found. + assert rewards[0] == 1.0, "Should handle whitespace in content and still match exact paragraphs" def test_reward_format_search_or_answer_not_both():