From 504f0c6c8e137cc56806b3e84760df05bbf8e8f5 Mon Sep 17 00:00:00 2001 From: thinhlpg Date: Fri, 11 Apr 2025 18:39:18 +0000 Subject: [PATCH] feat: update reward_em_chunk to match only the LAST required paragraph of the reasoning chain and adjust related tests --- src/rewards.py | 43 +++++++++++++++++++++++++---------------- tests/test_rewards.py | 45 +++++++++++++++++++++---------------------- 2 files changed, 48 insertions(+), 40 deletions(-) diff --git a/src/rewards.py b/src/rewards.py index c6c1dfa..2bc06be 100644 --- a/src/rewards.py +++ b/src/rewards.py @@ -347,7 +347,7 @@ def reward_retry(prompts: list, completions: list, **reward_kwargs) -> list: def reward_em_chunk(prompts: list, completions: list, **reward_kwargs) -> list: - """Reward function checks if model's search results contain all necessary supporting paragraphs. + """Reward function checks if model's search results contain the LAST necessary supporting paragraph. Args: prompts: List of input prompts @@ -357,7 +357,7 @@ def reward_em_chunk(prompts: list, completions: list, **reward_kwargs) -> list: Each inner list corresponds to a completion. Returns: - list: List of rewards (1.0 if all paragraphs found, 0.0 otherwise) + list: List of rewards (1.0 if the LAST paragraph is found, 0.0 otherwise) Raises: ValueError: If supporting_paragraphs is not provided or invalid. @@ -376,7 +376,7 @@ def reward_em_chunk(prompts: list, completions: list, **reward_kwargs) -> list: raise ValueError("Length of supporting_paragraphs must match length of completions") rewards = [] - all_found_statuses = [] # Track found status for each paragraph + all_found_statuses = [] # Track found status for the last paragraph for i, (completion, required_paragraphs) in enumerate(zip(completions, all_supporting_paragraphs)): if not isinstance(required_paragraphs, list): @@ -384,9 +384,20 @@ def reward_em_chunk(prompts: list, completions: list, **reward_kwargs) -> list: f"supporting_paragraphs for completion {i} is not a list, skipping. Got: {type(required_paragraphs)}" ) rewards.append(0.0) - all_found_statuses.append({p: False for p in required_paragraphs or []}) # Handle if None/empty + # Still log status for all potential paragraphs for consistency, even if we only check the last + all_found_statuses.append({p: False for p in required_paragraphs or []}) continue + # Check if the required_paragraphs list is empty - should ideally not happen but good practice + if not required_paragraphs: + logger.warning(f"Empty required_paragraphs list for completion {i}, assigning 0.0 reward.") + rewards.append(0.0) + all_found_statuses.append({}) + continue + + # Get the last required paragraph + last_required_paragraph = required_paragraphs[-1] + # Get all content from messages starting with (from any role) search_results_content = [ msg["content"].strip() @@ -398,32 +409,30 @@ def reward_em_chunk(prompts: list, completions: list, **reward_kwargs) -> list: # Combine all found information into a single text block for easier searching combined_information = "\n".join(search_results_content) - # Check if *all* required paragraphs are present in the combined information - found_status = {p: (p in combined_information) for p in required_paragraphs} - all_paragraphs_found = all(found_status.values()) + # Check if the LAST required paragraph is present in the combined information + last_paragraph_found = last_required_paragraph in combined_information - if not all_paragraphs_found: - missing_paragraphs = [p for p, found in found_status.items() if not found] + if not last_paragraph_found: logger.debug( - f"Failed to find all required paragraphs for prompt {i}.\n" - f"Required: {len(required_paragraphs)}, Found: {len(required_paragraphs) - len(missing_paragraphs)}\n" - f"Missing paragraphs: {[p[:100] + '...' for p in missing_paragraphs]}\n" + f"Failed to find the required LAST paragraph for prompt {i}.\n" + f"Required last paragraph: {last_required_paragraph[:100] + '...'}\n" f"Searched in: {combined_information[:200] + '...'}" ) - reward = 1.0 if all_paragraphs_found else 0.0 + reward = 1.0 if last_paragraph_found else 0.0 rewards.append(reward) - all_found_statuses.append(found_status) - logger.debug(f"Reward for prompt {i}: {reward}") + # Log the found status specifically for the last paragraph + all_found_statuses.append({last_required_paragraph: last_paragraph_found}) + logger.debug(f"Reward for prompt {i} (based on last paragraph): {reward}") # Log chat state log_chat_state( prompts=prompts, completions=completions, rewards=rewards, - reward_type="em_chunk", + reward_type="em_chunk_last", # Changed type to reflect new logic supporting_paragraphs=all_supporting_paragraphs, - found_paragraph_statuses=all_found_statuses, # Log which paragraphs were found + found_last_paragraph_statuses=all_found_statuses, # Log which last paragraphs were found ) return rewards diff --git a/tests/test_rewards.py b/tests/test_rewards.py index 6a77b27..5f20a61 100644 --- a/tests/test_rewards.py +++ b/tests/test_rewards.py @@ -299,7 +299,7 @@ def test_reward_retry(): def test_reward_em_chunk(): - """Test exact match reward function when all paragraphs are found.""" + """Test reward function when the LAST required paragraph is found.""" prompts = ["What is Python?"] completions = [ { @@ -309,22 +309,21 @@ def test_reward_em_chunk(): ] } ] - # Expecting a list of lists for supporting_paragraphs required_paragraphs = [["Python is a programming language", "It is widely used."]] rewards = reward_em_chunk(prompts, completions, supporting_paragraphs=required_paragraphs) assert len(rewards) == 1 - assert rewards[0] == 1.0, "Should give full reward when all required paragraphs are found" + assert rewards[0] == 1.0, "Should give full reward when the LAST paragraph is found (even if others are too)" def test_reward_em_chunk_partial_match(): - """Test exact match reward function when only some paragraphs are found.""" + """Test reward function when the LAST required paragraph is MISSING.""" prompts = ["What is Python?"] completions = [ { "messages": [ {"role": "user", "content": "Python is a programming language"} - # Missing the second paragraph + # Missing the last paragraph "It is widely used." ] } ] @@ -332,7 +331,7 @@ def test_reward_em_chunk_partial_match(): rewards = reward_em_chunk(prompts, completions, supporting_paragraphs=required_paragraphs) assert len(rewards) == 1 - assert rewards[0] == 0.0, "Should give zero reward if not all required paragraphs are found" + assert rewards[0] == 0.0, "Should give zero reward if the LAST required paragraph is missing" def test_reward_em_chunk_no_supporting_paragraphs_kwarg(): @@ -346,56 +345,56 @@ def test_reward_em_chunk_no_supporting_paragraphs_kwarg(): def test_reward_em_chunk_multiple_completions(): - """Test reward EM chunk with multiple completions and varying paragraph requirements.""" + """Test reward function with multiple completions based on the LAST paragraph rule.""" prompts = ["Prompt 1", "Prompt 2", "Prompt 3"] completions = [ - # Completion 1: Matches both required paragraphs + # Completion 1: Finds the last required paragraph ("Second paragraph content") { "messages": [ {"role": "ipython", "content": "First paragraph content"}, {"role": "user", "content": "Second paragraph content"}, ] }, - # Completion 2: Matches only one of the required paragraphs + # Completion 2: Does NOT find the last required paragraph ("Missing paragraph") {"messages": [{"role": "user", "content": "Third paragraph content"}]}, - # Completion 3: Matches all (only one required) + # Completion 3: Finds the last (and only) required paragraph ("Fourth paragraph") {"messages": [{"role": "tool", "content": "Fourth paragraph"}]}, ] - # List of lists for supporting paragraphs reward_kwargs = { "supporting_paragraphs": [ - ["First paragraph content", "Second paragraph content"], # Requires 2, gets 2 -> 1.0 - ["Third paragraph content", "Missing paragraph"], # Requires 2, gets 1 -> 0.0 - ["Fourth paragraph"], # Requires 1, gets 1 -> 1.0 + ["First paragraph content", "Second paragraph content"], # Last is found -> 1.0 + ["Third paragraph content", "Missing paragraph"], # Last is missing -> 0.0 + ["Fourth paragraph"], # Last is found -> 1.0 ] } rewards = reward_em_chunk(prompts, completions, **reward_kwargs) assert len(rewards) == 3 - assert rewards == [1.0, 0.0, 1.0], "Should reward based on finding *all* required paragraphs per completion" + assert rewards == [1.0, 0.0, 1.0], "Should reward based ONLY on finding the LAST required paragraph per completion" def test_reward_em_chunk_whitespace_handling(): - """Test reward EM chunk handles whitespace properly when checking paragraphs.""" + """Test reward function handles whitespace properly when checking the LAST paragraph.""" prompts = ["Whitespace test"] completions = [ { "messages": [ {"role": "ipython", "content": " Paragraph with spaces. "}, - {"role": "user", "content": "\\nAnother one.\\t"}, + # This is the last paragraph and should be found despite leading/trailing spaces in the info block + {"role": "user", "content": " Another one. "}, ] } ] - # The check is simple `paragraph in combined_information`. - # combined_information joins stripped content with '\\n'. - # The required paragraphs must match exactly what's expected in combined_information. - required_paragraphs = [["Paragraph with spaces.", "Another one."]] + # The required paragraphs list + required_paragraphs = [ + ["Paragraph with spaces.", "Another one."] # We only care about "Another one." + ] reward_kwargs = {"supporting_paragraphs": required_paragraphs} rewards = reward_em_chunk(prompts, completions, **reward_kwargs) assert len(rewards) == 1 - # The function joins stripped content with newline, so "Paragraph with spaces." and "Another one." should be found. - assert rewards[0] == 1.0, "Should handle whitespace in content and still match exact paragraphs" + # The last paragraph "Another one." should be found in the combined, stripped content. + assert rewards[0] == 1.0, "Should correctly match the LAST paragraph even with whitespace issues" def test_reward_format_search_or_answer_not_both():