fix: strengthen reward correctness logic to handle cases where the final message is not an answer from the assistant. Also update logging in reward functions for better debugging

- Added 'logs/' directory to .gitignore to exclude log files.
- Introduced log_chat_state function to log chat states and rewards to JSONL files.
- Updated reward functions to log chat states with validation results for better tracking and debugging.
main
thinhlpg 1 month ago
parent 1bd609dfae
commit d0e6068055

1
.gitignore vendored

@ -14,6 +14,7 @@ model/
graveyard/ graveyard/
eval_logs/ eval_logs/
downloaded_model/ downloaded_model/
logs/
# Byte-compiled / optimized / DLL files # Byte-compiled / optimized / DLL files
__pycache__/ __pycache__/

@ -2,11 +2,14 @@
Reward functions for RL training. Reward functions for RL training.
""" """
import json
import re import re
from datetime import datetime
from pathlib import Path
import numpy as np import numpy as np
from src.config import logger from src.config import LOG_FOLDER, logger
from src.evaluation import check_student_answers from src.evaluation import check_student_answers
@ -35,29 +38,46 @@ def build_reward_correctness_fn(
Returns: Returns:
List of correctness scores between 0 and 1 List of correctness scores between 0 and 1
""" """
# correctness: must be assistant response and must have only one pair of answer tags
# should not have search tags and information tags
teacher_answers = reward_kwargs["answer"] teacher_answers = reward_kwargs["answer"]
student_answers = [completion["messages"][-1]["content"] for completion in completions] student_final_messages = [completion["messages"][-1]["content"] for completion in completions]
student_final_message_roles = [completion["messages"][-1]["role"] for completion in completions]
# Log non-exact matches is_assistant_response = [role == "assistant" for role in student_final_message_roles]
for i, (student, teacher) in enumerate(zip(student_answers, teacher_answers)): has_answer_tag = [re.search(r"<answer>[\s\S]*?</answer>", ans) is not None for ans in student_final_messages]
if student.strip().lower() != teacher.strip().lower(): has_search_tag = [re.search(r"<search>[\s\S]*?</search>", ans) is not None for ans in student_final_messages]
logger.debug(f"Non-exact match at index {i}:\nStudent: {student}\nTeacher: {teacher}") has_information_tag = [
re.search(r"<information>[\s\S]*?</information>", ans) is not None for ans in student_final_messages
]
correct = check_student_answers( might_be_correct = check_student_answers(
prompts, prompts,
teacher_answers, teacher_answers,
student_answers, student_final_messages,
vllm_generate_func=vllm_generate_func, vllm_generate_func=vllm_generate_func,
tokenizer=tokenizer, tokenizer=tokenizer,
) )
# Convert lists to numpy arrays for element-wise operations
might_be_correct = np.array(might_be_correct)
is_assistant_response = np.array(is_assistant_response)
has_answer_tag = np.array(has_answer_tag)
has_search_tag = np.array(has_search_tag)
has_information_tag = np.array(has_information_tag)
# might be correct and is assistant response and has answer tag and no search or information tags
correct = might_be_correct & is_assistant_response & has_answer_tag & ~has_search_tag & ~has_information_tag
# Convert numpy array back to list for return
correct = correct.tolist()
# Log correctness metrics with length info # Log correctness metrics with length info
logger.info(f"Correctness metrics: {correct}") logger.info(f"Correctness metrics: {correct}")
logger.info(f"Average correctness: {np.mean(correct):.2f}") logger.info(f"Average correctness: {np.mean(correct):.2f}")
logger.info(f"Standard deviation: {np.std(correct):.2f}") logger.info(f"Standard deviation: {np.std(correct):.2f}")
# Log length metrics # Log length metrics
student_lengths = [len(ans.strip()) for ans in student_answers] student_lengths = [len(ans.strip()) for ans in student_final_messages]
teacher_lengths = [len(ans.strip()) for ans in teacher_answers] teacher_lengths = [len(ans.strip()) for ans in teacher_answers]
logger.info(f"Student lengths: {student_lengths}") logger.info(f"Student lengths: {student_lengths}")
logger.info(f"Teacher lengths: {teacher_lengths}") logger.info(f"Teacher lengths: {teacher_lengths}")
@ -65,6 +85,22 @@ def build_reward_correctness_fn(
logger.info(f"Average teacher length: {np.mean(teacher_lengths):.2f}") logger.info(f"Average teacher length: {np.mean(teacher_lengths):.2f}")
logger.info(f"Length ratio: {np.mean(student_lengths) / np.mean(teacher_lengths):.2f}") logger.info(f"Length ratio: {np.mean(student_lengths) / np.mean(teacher_lengths):.2f}")
# Log chat state
log_chat_state(
prompts=prompts,
completions=completions,
rewards=correct,
reward_type="correctness",
teacher_answers=teacher_answers,
validation_results={
"is_assistant": is_assistant_response,
"has_answer": has_answer_tag,
"has_search": has_search_tag,
"has_info": has_information_tag,
"might_be_correct": might_be_correct,
},
)
return correct return correct
return reward_correctness return reward_correctness
@ -103,6 +139,13 @@ def reward_format(prompts: list, completions: list, **reward_kwargs) -> list:
] ]
rewards = [] rewards = []
validation_results = {
"has_think": [],
"has_answer": [],
"has_search": [],
"has_invalid_tags": [],
"has_info_tags": [],
}
for completion in completions: for completion in completions:
messages = completion.get("messages", []) messages = completion.get("messages", [])
@ -110,59 +153,68 @@ def reward_format(prompts: list, completions: list, **reward_kwargs) -> list:
if not assistant_msgs: if not assistant_msgs:
rewards.append(0.0) rewards.append(0.0)
for key in validation_results:
validation_results[key].append(False)
continue continue
content = assistant_msgs[-1] # Get the last assistant message content = assistant_msgs[-1]
# Check for invalid markdown formatting
has_invalid_tags = any(re.search(pattern, content) for pattern in invalid_patterns) has_invalid_tags = any(re.search(pattern, content) for pattern in invalid_patterns)
validation_results["has_invalid_tags"].append(has_invalid_tags)
if has_invalid_tags: if has_invalid_tags:
logger.debug("Found markdown-formatted tags in response")
rewards.append(0.0) rewards.append(0.0)
for key in ["has_think", "has_answer", "has_search", "has_info_tags"]:
validation_results[key].append(False)
continue continue
# Check for any information tag variants (should not exist in assistant messages)
has_info_tags = False has_info_tags = False
for pattern in info_patterns: for pattern in info_patterns:
info_matches = re.findall(pattern, content, re.IGNORECASE) if re.findall(pattern, content, re.IGNORECASE):
if info_matches:
logger.debug(f"Found {len(info_matches)} information tag(s) of type '{pattern}' in assistant message")
has_info_tags = True has_info_tags = True
break break
validation_results["has_info_tags"].append(has_info_tags)
if has_info_tags: if has_info_tags:
rewards.append(0.0) rewards.append(0.0)
for key in ["has_think", "has_answer", "has_search"]:
validation_results[key].append(False)
continue continue
# Find all tag matches
think_matches = re.findall(think_pattern, content) think_matches = re.findall(think_pattern, content)
search_matches = re.findall(search_pattern, content) search_matches = re.findall(search_pattern, content)
answer_matches = re.findall(answer_pattern, content) answer_matches = re.findall(answer_pattern, content)
# Verify tag presence and count
has_think = len(think_matches) >= 1 has_think = len(think_matches) >= 1
has_answer = len(answer_matches) == 1 # Must have exactly one answer has_answer = len(answer_matches) == 1
has_search = len(search_matches) >= 1 # One or more search tags has_search = len(search_matches) >= 1
validation_results["has_think"].append(has_think)
validation_results["has_answer"].append(has_answer)
validation_results["has_search"].append(has_search)
# Check for search and answer in the same message (not allowed)
if has_search and has_answer: if has_search and has_answer:
logger.debug("Found both search and answer tags in the same message")
rewards.append(0.0) rewards.append(0.0)
continue continue
# Award reward - must have think tag and either answer or search (but not both)
reward = 1.0 if has_think and (has_answer or has_search) else 0.0 reward = 1.0 if has_think and (has_answer or has_search) else 0.0
rewards.append(reward) rewards.append(reward)
# Log issues for debugging
if not reward: if not reward:
logger.debug(f"Format issues - think: {has_think}, answer: {has_answer}, search: {has_search}") logger.debug(f"Format issues - think: {has_think}, answer: {has_answer}, search: {has_search}")
if search_matches: if search_matches:
logger.debug(f"Number of search tags: {len(search_matches)}") logger.debug(f"Number of search tags: {len(search_matches)}")
# Log overall metrics
logger.info(f"Format reward metrics - Mean: {np.mean(rewards):.3f}, Valid formats: {sum(rewards)}/{len(rewards)}") logger.info(f"Format reward metrics - Mean: {np.mean(rewards):.3f}, Valid formats: {sum(rewards)}/{len(rewards)}")
# Log chat state with validation results
log_chat_state(
prompts=prompts,
completions=completions,
rewards=rewards,
reward_type="format",
validation_results=validation_results,
)
return rewards return rewards
@ -247,6 +299,17 @@ def reward_retry(prompts: list, completions: list, **reward_kwargs) -> list:
logger.info(f"Violations (>1 search per message): {sum(violations)}/{len(violations)}") logger.info(f"Violations (>1 search per message): {sum(violations)}/{len(violations)}")
logger.info(f"Search counts distribution: {search_queries}") logger.info(f"Search counts distribution: {search_queries}")
# Log chat state
log_chat_state(
prompts=prompts,
completions=completions,
rewards=rewards,
reward_type="retry",
search_counts=search_queries,
violations=violations,
optimal_search_count=optimal_search_count,
)
return rewards return rewards
@ -309,4 +372,55 @@ def reward_em_chunk(prompts: list, completions: list, **reward_kwargs) -> list:
logger.info(f"Average reward: {np.mean(rewards):.3f}") logger.info(f"Average reward: {np.mean(rewards):.3f}")
logger.info(f"Reward std: {np.std(rewards):.3f}") logger.info(f"Reward std: {np.std(rewards):.3f}")
# Log chat state
log_chat_state(
prompts=prompts,
completions=completions,
rewards=rewards,
reward_type="em_chunk",
correct_contents=correct_contents,
)
return rewards return rewards
def log_chat_state(prompts: list, completions: list, rewards: list, reward_type: str, **kwargs) -> None:
    """Append chat states and their rewards to a per-reward-type JSONL file.

    Each (prompt, completion, reward) triple becomes one JSON line in
    ``LOG_FOLDER/chat_states/chat_states_<reward_type>.jsonl``, together with a
    shared timestamp and any extra metadata passed via ``**kwargs``.

    Args:
        prompts: List of input prompts
        completions: List of model completions; each must be a dict with a
            "messages" key
        rewards: List of calculated rewards, one per completion
        reward_type: Type of reward function used (becomes part of the filename)
        **kwargs: Additional data to log; numpy arrays and scalars are
            converted to plain Python types so they can be JSON-serialized
    """

    def _to_jsonable(value):
        # Recursively convert numpy arrays/scalars to JSON-serializable Python
        # types. The previous version only handled top-level ndarrays and one
        # level of dict nesting, so e.g. a list of np.bool_ values would crash
        # json.dumps; it also mutated the caller's kwargs dicts in place.
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):  # np.bool_, np.int64, np.float64, ...
            return value.item()
        if isinstance(value, dict):
            return {k: _to_jsonable(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return [_to_jsonable(v) for v in value]
        return value

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    chat_states_dir = LOG_FOLDER / "chat_states"
    chat_states_dir.mkdir(parents=True, exist_ok=True)

    # Sanitize metadata once, without mutating the caller's objects.
    metadata = _to_jsonable(kwargs)

    # One JSONL file per reward type; append so repeated calls accumulate.
    log_file = chat_states_dir / f"chat_states_{reward_type}.jsonl"
    with open(log_file, "a", encoding="utf-8") as f:
        for prompt, completion, reward in zip(prompts, completions, rewards):
            chat_state = {
                "timestamp": timestamp,
                "reward_type": reward_type,
                "prompt": prompt,
                "messages": completion["messages"],
                "reward": _to_jsonable(reward),
                "metadata": metadata,
            }
            f.write(json.dumps(chat_state, ensure_ascii=False) + "\n")

    logger.info(f"💾 Appended {len(prompts)} chat states to {log_file}")

@ -71,6 +71,7 @@ def test_reward_correctness_wrong_answer(reward_correctness_fn):
def test_reward_format_correct(): def test_reward_format_correct():
"""Test reward format with correct format""" """Test reward format with correct format"""
prompts = ["Test prompt"]
completions = [ completions = [
{ {
"messages": [ "messages": [
@ -78,21 +79,23 @@ def test_reward_format_correct():
] ]
} }
] ]
rewards = reward_format([], completions) rewards = reward_format(prompts, completions)
assert rewards[0] == 1.0 assert rewards[0] == 1.0
def test_reward_format_with_search(): def test_reward_format_with_search():
"""Test reward format with search tags only (no answer tags)""" """Test reward format with search tags only (no answer tags)"""
prompts = ["Test prompt"]
completions = [ completions = [
{"messages": [{"role": "assistant", "content": "<think>\nSome reasoning\n</think>\n<search>query</search>"}]} {"messages": [{"role": "assistant", "content": "<think>\nSome reasoning\n</think>\n<search>query</search>"}]}
] ]
rewards = reward_format([], completions) rewards = reward_format(prompts, completions)
assert rewards[0] == 1.0 assert rewards[0] == 1.0
def test_reward_format_markdown_tags(): def test_reward_format_markdown_tags():
"""Test reward format with markdown-styled tags""" """Test reward format with markdown-styled tags"""
prompts = ["Test prompt"]
markdown_formats = [ markdown_formats = [
{ {
"messages": [ "messages": [
@ -121,12 +124,13 @@ def test_reward_format_markdown_tags():
] ]
for completion in markdown_formats: for completion in markdown_formats:
rewards = reward_format([], [completion]) rewards = reward_format(["Test prompt"], [completion])
assert rewards[0] == 0.0, f"Failed with: {completion['messages'][0]['content']}" assert rewards[0] == 0.0, f"Failed with: {completion['messages'][0]['content']}"
def test_reward_format_information_tags(): def test_reward_format_information_tags():
"""Test reward format with information tags""" """Test reward format with information tags"""
prompts = ["Test prompt"]
# Test different information tag variants # Test different information tag variants
info_variants = [ info_variants = [
"<information>Some info</information>", "<information>Some info</information>",
@ -139,12 +143,13 @@ def test_reward_format_information_tags():
for info_tag in info_variants: for info_tag in info_variants:
content = f"<think>\nSome reasoning\n</think>\n{info_tag}\n<answer>\nThe answer\n</answer>" content = f"<think>\nSome reasoning\n</think>\n{info_tag}\n<answer>\nThe answer\n</answer>"
completions = [{"messages": [{"role": "assistant", "content": content}]}] completions = [{"messages": [{"role": "assistant", "content": content}]}]
rewards = reward_format([], completions) rewards = reward_format(prompts, completions)
assert rewards[0] == 0.0, f"Failed to detect information tag: {info_tag}" assert rewards[0] == 0.0, f"Failed to detect information tag: {info_tag}"
def test_reward_format_real_example(): def test_reward_format_real_example():
"""Test reward format with a real-world example - should fail now since it has both search and answer tags""" """Test reward format with a real-world example - should fail now since it has both search and answer tags"""
prompts = ["What cars did Paul Walker drive in Fast and Furious?"]
content = """<think>I need to search for Paul Walker's cars in Fast and Furious movies.</think> content = """<think>I need to search for Paul Walker's cars in Fast and Furious movies.</think>
<search> Paul Walker's cars in Fast and Furious </search> <search> Paul Walker's cars in Fast and Furious </search>
@ -157,27 +162,29 @@ Based on the updated information, it seems the focus was on his career, financia
<answer> Charger </answer>""" <answer> Charger </answer>"""
completions = [{"messages": [{"role": "assistant", "content": content}]}] completions = [{"messages": [{"role": "assistant", "content": content}]}]
rewards = reward_format([], completions) rewards = reward_format(prompts, completions)
assert rewards[0] == 0.0, "Should reject responses with both search and answer tags" assert rewards[0] == 0.0, "Should reject responses with both search and answer tags"
def test_reward_format_real_example_search_only(): def test_reward_format_real_example_search_only():
"""Test reward format with search-only format in a real-world example""" """Test reward format with search-only format in a real-world example"""
prompts = ["What cars did Paul Walker drive in Fast and Furious?"]
content = """<think>I need to search for Paul Walker's cars in Fast and Furious movies.</think> content = """<think>I need to search for Paul Walker's cars in Fast and Furious movies.</think>
<search> Paul Walker's cars in Fast and Furious </search>""" <search> Paul Walker's cars in Fast and Furious </search>"""
completions = [{"messages": [{"role": "assistant", "content": content}]}] completions = [{"messages": [{"role": "assistant", "content": content}]}]
rewards = reward_format([], completions) rewards = reward_format(prompts, completions)
assert rewards[0] == 1.0, "Should accept responses with only search tags" assert rewards[0] == 1.0, "Should accept responses with only search tags"
def test_reward_format_real_example_answer_only(): def test_reward_format_real_example_answer_only():
"""Test reward format with answer-only format in a real-world example""" """Test reward format with answer-only format in a real-world example"""
prompts = ["What cars did Paul Walker drive in Fast and Furious?"]
content = """<think>Based on the information provided, it seems Paul Walker drove a Charger in the Fast and Furious series.</think> content = """<think>Based on the information provided, it seems Paul Walker drove a Charger in the Fast and Furious series.</think>
<answer> Charger </answer>""" <answer> Charger </answer>"""
completions = [{"messages": [{"role": "assistant", "content": content}]}] completions = [{"messages": [{"role": "assistant", "content": content}]}]
rewards = reward_format([], completions) rewards = reward_format(prompts, completions)
assert rewards[0] == 1.0, "Should accept responses with only answer tags" assert rewards[0] == 1.0, "Should accept responses with only answer tags"
@ -252,134 +259,33 @@ def test_reward_format_incomplete_tags():
def test_reward_retry(): def test_reward_retry():
"""Test reward retry functionality with progressive rewards up to 5 searches""" """Test reward retry function"""
# Test case with no searches prompts = ["What is the capital of France?"]
completions = [{"messages": [{"role": "assistant", "content": "No searches here"}]}]
rewards = reward_retry([], completions)
assert rewards[0] == 0.0, "Should get 0 reward for no searches"
# Test case with one search
completions = [
{
"messages": [
{
"role": "assistant",
"content": "<think>I need more information</think>\n<search>First query</search>",
}
]
}
]
rewards = reward_retry([], completions)
assert rewards[0] == 0.35, "Should get 0.35 reward for one search"
# Test case with three searches in different messages
completions = [ completions = [
{ {
"messages": [ "messages": [
{ {"role": "assistant", "content": "<think>Let me search</think>\n<search>capital of France</search>"},
"role": "assistant", {"role": "assistant", "content": "<think>Need more info</think>\n<search>Paris history</search>"},
"content": "<think>First search</think>\n<search>Query 1</search>", {"role": "assistant", "content": "<think>Found it</think>\n<answer>Paris</answer>"},
},
{"role": "assistant", "content": "<think>Second search</think>\n<search>Query 2</search>"},
{"role": "assistant", "content": "<think>Third search</think>\n<search>Query 3</search>"},
]
}
]
rewards = reward_retry([], completions)
assert rewards[0] == 0.65, "Should get 0.65 reward for three searches"
# Test case with five searches in different messages
completions = [
{
"messages": [
{"role": "assistant", "content": "<think>Search 1</think>\n<search>Query 1</search>"},
{"role": "assistant", "content": "<think>Search 2</think>\n<search>Query 2</search>"},
{"role": "assistant", "content": "<think>Search 3</think>\n<search>Query 3</search>"},
{"role": "assistant", "content": "<think>Search 4</think>\n<search>Query 4</search>"},
{"role": "assistant", "content": "<think>Search 5</think>\n<search>Query 5</search>"},
] ]
} }
] ]
rewards = reward_retry([], completions) rewards = reward_retry(prompts, completions)
assert rewards[0] == 0.95, "Should get 0.95 reward for five searches" assert len(rewards) == 1
assert rewards[0] > 0, "Should give positive reward for multiple search attempts"
# Test case with more than five searches
completions = [
{
"messages": [
{"role": "assistant", "content": "<think>Search 1</think>\n<search>Query 1</search>"},
{"role": "assistant", "content": "<think>Search 2</think>\n<search>Query 2</search>"},
{"role": "assistant", "content": "<think>Search 3</think>\n<search>Query 3</search>"},
{"role": "assistant", "content": "<think>Search 4</think>\n<search>Query 4</search>"},
{"role": "assistant", "content": "<think>Search 5</think>\n<search>Query 5</search>"},
{"role": "assistant", "content": "<think>Search 6</think>\n<search>Query 6</search>"},
]
}
]
rewards = reward_retry([], completions)
assert rewards[0] == 0.95, "Should cap at 0.95 reward for more than five searches"
# Test case with violation (multiple searches in one message)
completions = [
{
"messages": [
{
"role": "assistant",
"content": "<think>Multiple searches</think>\n<search>First query</search>\n<search>Second query</search>",
}
]
}
]
rewards = reward_retry([], completions)
assert rewards[0] == 0.25, "Should get penalized reward (0.5 * 0.5) for violation"
def test_reward_em_chunk(): def test_reward_em_chunk():
"""Test reward EM chunk functionality with information tags""" """Test exact match chunk reward function"""
# Test case with matching content in ipython role prompts = ["What is Python?"]
completions = [ completions = [
{"messages": [{"role": "ipython", "content": "<information>This is the correct chunk content</information>"}]} {"messages": [{"role": "user", "content": "<information>Python is a programming language</information>"}]}
] ]
reward_kwargs = {"chunk_content": ["This is the correct chunk content"]} correct_contents = ["Python is a programming language"]
rewards = reward_em_chunk([], completions, **reward_kwargs) rewards = reward_em_chunk(prompts, completions, chunk_content=correct_contents)
assert len(rewards) == 1 assert len(rewards) == 1
assert rewards[0] == 1.0, "Should get reward 1.0 for exact match in ipython role" assert rewards[0] == 1.0, "Should give full reward for exact chunk match"
# Test case with matching content in user role
completions = [
{"messages": [{"role": "user", "content": "<information>This is the correct chunk content</information>"}]}
]
rewards = reward_em_chunk([], completions, **reward_kwargs)
assert rewards[0] == 1.0, "Should get reward 1.0 for exact match in user role"
# Test case with content not starting with <information> tag
completions = [{"messages": [{"role": "ipython", "content": "This is the correct chunk content"}]}]
rewards = reward_em_chunk([], completions, **reward_kwargs)
assert rewards[0] == 0.0, "Should get reward 0.0 for missing information tag"
# Test case with wrong role
completions = [
{
"messages": [
{"role": "assistant", "content": "<information>This is the correct chunk content</information>"}
]
}
]
rewards = reward_em_chunk([], completions, **reward_kwargs)
assert rewards[0] == 0.0, "Should get reward 0.0 for wrong role"
# Test case with multiple messages, only one matching
completions = [
{
"messages": [
{"role": "ipython", "content": "<information>Wrong content</information>"},
{"role": "user", "content": "<information>This is the correct chunk content</information>"},
]
}
]
rewards = reward_em_chunk([], completions, **reward_kwargs)
assert rewards[0] == 1.0, "Should get reward 1.0 if any message matches"
def test_reward_em_chunk_no_chunk_content(): def test_reward_em_chunk_no_chunk_content():

Loading…
Cancel
Save