fix: strengthen reward correctness logic to handle cases where the final message is not an answer from the assistant. Also improve reward function logging for easier debugging

- Added 'logs/' directory to .gitignore to exclude log files.
- Introduced log_chat_state function to log chat states and rewards to JSONL files.
- Updated reward functions to log chat states with validation results for better tracking and debugging.
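- For reference, each record that log_chat_state appends has roughly this shape once deserialized (a sketch; the field values below are illustrative, not taken from a real run):

    {
        "timestamp": "20250101_120000",
        "reward_type": "correctness",
        "prompt": "What is the capital of France?",
        "messages": [{"role": "assistant", "content": "<answer>Paris</answer>"}],
        "reward": 1.0,
        "metadata": {
            "teacher_answers": ["Paris"],
            "validation_results": {"is_assistant": [True], "has_answer": [True]},
        },
    }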
main
thinhlpg 1 month ago
parent 1bd609dfae
commit d0e6068055

.gitignore

@@ -14,6 +14,7 @@ model/
graveyard/
eval_logs/
downloaded_model/
logs/
# Byte-compiled / optimized / DLL files
__pycache__/

@@ -2,11 +2,14 @@
Reward functions for RL training.
"""
import json
import re
from datetime import datetime
from pathlib import Path
import numpy as np
from src.config import logger
from src.config import LOG_FOLDER, logger
from src.evaluation import check_student_answers
@@ -35,29 +38,46 @@ def build_reward_correctness_fn(
Returns:
List of correctness scores between 0 and 1
"""
# correctness: the final message must be an assistant response containing an answer tag,
# and must not contain search or information tags
teacher_answers = reward_kwargs["answer"]
student_answers = [completion["messages"][-1]["content"] for completion in completions]
# Log non-exact matches
for i, (student, teacher) in enumerate(zip(student_answers, teacher_answers)):
if student.strip().lower() != teacher.strip().lower():
logger.debug(f"Non-exact match at index {i}:\nStudent: {student}\nTeacher: {teacher}")
student_final_messages = [completion["messages"][-1]["content"] for completion in completions]
student_final_message_roles = [completion["messages"][-1]["role"] for completion in completions]
is_assistant_response = [role == "assistant" for role in student_final_message_roles]
has_answer_tag = [re.search(r"<answer>[\s\S]*?</answer>", ans) is not None for ans in student_final_messages]
has_search_tag = [re.search(r"<search>[\s\S]*?</search>", ans) is not None for ans in student_final_messages]
has_information_tag = [
re.search(r"<information>[\s\S]*?</information>", ans) is not None for ans in student_final_messages
]
correct = check_student_answers(
might_be_correct = check_student_answers(
prompts,
teacher_answers,
student_answers,
student_final_messages,
vllm_generate_func=vllm_generate_func,
tokenizer=tokenizer,
)
# Convert lists to numpy arrays for element-wise operations
might_be_correct = np.array(might_be_correct)
is_assistant_response = np.array(is_assistant_response)
has_answer_tag = np.array(has_answer_tag)
has_search_tag = np.array(has_search_tag)
has_information_tag = np.array(has_information_tag)
# might be correct and is assistant response and has answer tag and no search or information tags
correct = might_be_correct & is_assistant_response & has_answer_tag & ~has_search_tag & ~has_information_tag
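# Example: with might_be_correct=[True, True], is_assistant_response=[True, False],
# has_answer_tag=[True, True], has_search_tag=[False, False], and
# has_information_tag=[False, False], the combined mask is [True, False]: the
# second completion is rejected because its final message is not from the assistant.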
# Convert numpy array back to list for return
correct = correct.tolist()
# Log correctness metrics with length info
logger.info(f"Correctness metrics: {correct}")
logger.info(f"Average correctness: {np.mean(correct):.2f}")
logger.info(f"Standard deviation: {np.std(correct):.2f}")
# Log length metrics
student_lengths = [len(ans.strip()) for ans in student_answers]
student_lengths = [len(ans.strip()) for ans in student_final_messages]
teacher_lengths = [len(ans.strip()) for ans in teacher_answers]
logger.info(f"Student lengths: {student_lengths}")
logger.info(f"Teacher lengths: {teacher_lengths}")
@@ -65,6 +85,22 @@ def build_reward_correctness_fn(
logger.info(f"Average teacher length: {np.mean(teacher_lengths):.2f}")
logger.info(f"Length ratio: {np.mean(student_lengths) / np.mean(teacher_lengths):.2f}")
# Log chat state
log_chat_state(
prompts=prompts,
completions=completions,
rewards=correct,
reward_type="correctness",
teacher_answers=teacher_answers,
validation_results={
"is_assistant": is_assistant_response,
"has_answer": has_answer_tag,
"has_search": has_search_tag,
"has_info": has_information_tag,
"might_be_correct": might_be_correct,
},
)
return correct
return reward_correctness
@@ -103,6 +139,13 @@ def reward_format(prompts: list, completions: list, **reward_kwargs) -> list:
]
rewards = []
validation_results = {
"has_think": [],
"has_answer": [],
"has_search": [],
"has_invalid_tags": [],
"has_info_tags": [],
}
for completion in completions:
messages = completion.get("messages", [])
@@ -110,59 +153,68 @@ def reward_format(prompts: list, completions: list, **reward_kwargs) -> list:
if not assistant_msgs:
rewards.append(0.0)
for key in validation_results:
validation_results[key].append(False)
continue
content = assistant_msgs[-1] # Get the last assistant message
content = assistant_msgs[-1]
# Check for invalid markdown formatting
has_invalid_tags = any(re.search(pattern, content) for pattern in invalid_patterns)
validation_results["has_invalid_tags"].append(has_invalid_tags)
if has_invalid_tags:
logger.debug("Found markdown-formatted tags in response")
rewards.append(0.0)
for key in ["has_think", "has_answer", "has_search", "has_info_tags"]:
validation_results[key].append(False)
continue
# Check for any information tag variants (should not exist in assistant messages)
has_info_tags = False
for pattern in info_patterns:
info_matches = re.findall(pattern, content, re.IGNORECASE)
if info_matches:
logger.debug(f"Found {len(info_matches)} information tag(s) of type '{pattern}' in assistant message")
if re.findall(pattern, content, re.IGNORECASE):
has_info_tags = True
break
validation_results["has_info_tags"].append(has_info_tags)
if has_info_tags:
rewards.append(0.0)
for key in ["has_think", "has_answer", "has_search"]:
validation_results[key].append(False)
continue
# Find all tag matches
think_matches = re.findall(think_pattern, content)
search_matches = re.findall(search_pattern, content)
answer_matches = re.findall(answer_pattern, content)
# Verify tag presence and count
has_think = len(think_matches) >= 1
has_answer = len(answer_matches) == 1 # Must have exactly one answer
has_search = len(search_matches) >= 1 # One or more search tags
has_answer = len(answer_matches) == 1
has_search = len(search_matches) >= 1
validation_results["has_think"].append(has_think)
validation_results["has_answer"].append(has_answer)
validation_results["has_search"].append(has_search)
# Check for search and answer in the same message (not allowed)
if has_search and has_answer:
logger.debug("Found both search and answer tags in the same message")
rewards.append(0.0)
continue
# Award reward - must have think tag and either answer or search (but not both)
reward = 1.0 if has_think and (has_answer or has_search) else 0.0
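# Worked examples of this rule (contents are illustrative):
#   "<think>reasoning</think>\n<answer>Paris</answer>"    -> 1.0 (think + answer)
#   "<think>reasoning</think>\n<search>capital?</search>" -> 1.0 (think + search)
#   "<answer>Paris</answer>"                              -> 0.0 (missing think tag)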
rewards.append(reward)
# Log issues for debugging
if not reward:
logger.debug(f"Format issues - think: {has_think}, answer: {has_answer}, search: {has_search}")
if search_matches:
logger.debug(f"Number of search tags: {len(search_matches)}")
# Log overall metrics
logger.info(f"Format reward metrics - Mean: {np.mean(rewards):.3f}, Valid formats: {sum(rewards)}/{len(rewards)}")
# Log chat state with validation results
log_chat_state(
prompts=prompts,
completions=completions,
rewards=rewards,
reward_type="format",
validation_results=validation_results,
)
return rewards
@@ -247,6 +299,17 @@ def reward_retry(prompts: list, completions: list, **reward_kwargs) -> list:
logger.info(f"Violations (>1 search per message): {sum(violations)}/{len(violations)}")
logger.info(f"Search counts distribution: {search_queries}")
# Log chat state
log_chat_state(
prompts=prompts,
completions=completions,
rewards=rewards,
reward_type="retry",
search_counts=search_queries,
violations=violations,
optimal_search_count=optimal_search_count,
)
return rewards
@@ -309,4 +372,55 @@ def reward_em_chunk(prompts: list, completions: list, **reward_kwargs) -> list:
logger.info(f"Average reward: {np.mean(rewards):.3f}")
logger.info(f"Reward std: {np.std(rewards):.3f}")
# Log chat state
log_chat_state(
prompts=prompts,
completions=completions,
rewards=rewards,
reward_type="em_chunk",
correct_contents=correct_contents,
)
return rewards
def log_chat_state(prompts: list, completions: list, rewards: list, reward_type: str, **kwargs) -> None:
"""Log chat state and rewards to JSONL file.
Args:
prompts: List of input prompts
completions: List of model completions
rewards: List of calculated rewards
reward_type: Type of reward function used
**kwargs: Additional data to log
"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
chat_states_dir = LOG_FOLDER / "chat_states"
chat_states_dir.mkdir(parents=True, exist_ok=True)
# Convert numpy arrays to lists in kwargs
for key, value in kwargs.items():
if isinstance(value, dict):
for k, v in value.items():
if isinstance(v, np.ndarray):
kwargs[key][k] = v.tolist()
elif isinstance(value, np.ndarray):
kwargs[key] = value.tolist()
# Create one JSONL file per reward type
log_file = chat_states_dir / f"chat_states_{reward_type}.jsonl"
# Append each chat state as a new line
with open(log_file, "a", encoding="utf-8") as f:
for prompt, completion, reward in zip(prompts, completions, rewards):
chat_state = {
"timestamp": timestamp,
"reward_type": reward_type,
"prompt": prompt,
"messages": completion["messages"],
"reward": float(reward) if isinstance(reward, (np.number, np.ndarray)) else reward,
"metadata": kwargs,
}
f.write(json.dumps(chat_state, ensure_ascii=False) + "\n")
logger.info(f"💾 Appended {len(prompts)} chat states to {log_file}")

@@ -71,6 +71,7 @@ def test_reward_correctness_wrong_answer(reward_correctness_fn):
def test_reward_format_correct():
"""Test reward format with correct format"""
prompts = ["Test prompt"]
completions = [
{
"messages": [
@@ -78,21 +79,23 @@ def test_reward_format_correct():
]
}
]
rewards = reward_format([], completions)
rewards = reward_format(prompts, completions)
assert rewards[0] == 1.0
def test_reward_format_with_search():
"""Test reward format with search tags only (no answer tags)"""
prompts = ["Test prompt"]
completions = [
{"messages": [{"role": "assistant", "content": "<think>\nSome reasoning\n</think>\n<search>query</search>"}]}
]
rewards = reward_format([], completions)
rewards = reward_format(prompts, completions)
assert rewards[0] == 1.0
def test_reward_format_markdown_tags():
"""Test reward format with markdown-styled tags"""
prompts = ["Test prompt"]
markdown_formats = [
{
"messages": [
@@ -121,12 +124,13 @@ def test_reward_format_markdown_tags():
]
for completion in markdown_formats:
rewards = reward_format([], [completion])
rewards = reward_format(["Test prompt"], [completion])
assert rewards[0] == 0.0, f"Failed with: {completion['messages'][0]['content']}"
def test_reward_format_information_tags():
"""Test reward format with information tags"""
prompts = ["Test prompt"]
# Test different information tag variants
info_variants = [
"<information>Some info</information>",
@@ -139,12 +143,13 @@ def test_reward_format_information_tags():
for info_tag in info_variants:
content = f"<think>\nSome reasoning\n</think>\n{info_tag}\n<answer>\nThe answer\n</answer>"
completions = [{"messages": [{"role": "assistant", "content": content}]}]
rewards = reward_format([], completions)
rewards = reward_format(prompts, completions)
assert rewards[0] == 0.0, f"Failed to detect information tag: {info_tag}"
def test_reward_format_real_example():
"""Test reward format with a real-world example - should fail now since it has both search and answer tags"""
prompts = ["What cars did Paul Walker drive in Fast and Furious?"]
content = """<think>I need to search for Paul Walker's cars in Fast and Furious movies.</think>
<search> Paul Walker's cars in Fast and Furious </search>
@@ -157,27 +162,29 @@ Based on the updated information, it seems the focus was on his career, financia
<answer> Charger </answer>"""
completions = [{"messages": [{"role": "assistant", "content": content}]}]
rewards = reward_format([], completions)
rewards = reward_format(prompts, completions)
assert rewards[0] == 0.0, "Should reject responses with both search and answer tags"
def test_reward_format_real_example_search_only():
"""Test reward format with search-only format in a real-world example"""
prompts = ["What cars did Paul Walker drive in Fast and Furious?"]
content = """<think>I need to search for Paul Walker's cars in Fast and Furious movies.</think>
<search> Paul Walker's cars in Fast and Furious </search>"""
completions = [{"messages": [{"role": "assistant", "content": content}]}]
rewards = reward_format([], completions)
rewards = reward_format(prompts, completions)
assert rewards[0] == 1.0, "Should accept responses with only search tags"
def test_reward_format_real_example_answer_only():
"""Test reward format with answer-only format in a real-world example"""
prompts = ["What cars did Paul Walker drive in Fast and Furious?"]
content = """<think>Based on the information provided, it seems Paul Walker drove a Charger in the Fast and Furious series.</think>
<answer> Charger </answer>"""
completions = [{"messages": [{"role": "assistant", "content": content}]}]
rewards = reward_format([], completions)
rewards = reward_format(prompts, completions)
assert rewards[0] == 1.0, "Should accept responses with only answer tags"
@@ -252,134 +259,33 @@ def test_reward_format_incomplete_tags():
def test_reward_retry():
"""Test reward retry functionality with progressive rewards up to 5 searches"""
# Test case with no searches
completions = [{"messages": [{"role": "assistant", "content": "No searches here"}]}]
rewards = reward_retry([], completions)
assert rewards[0] == 0.0, "Should get 0 reward for no searches"
# Test case with one search
completions = [
{
"messages": [
{
"role": "assistant",
"content": "<think>I need more information</think>\n<search>First query</search>",
}
]
}
]
rewards = reward_retry([], completions)
assert rewards[0] == 0.35, "Should get 0.35 reward for one search"
# Test case with three searches in different messages
"""Test reward retry function"""
prompts = ["What is the capital of France?"]
completions = [
{
"messages": [
{
"role": "assistant",
"content": "<think>First search</think>\n<search>Query 1</search>",
},
{"role": "assistant", "content": "<think>Second search</think>\n<search>Query 2</search>"},
{"role": "assistant", "content": "<think>Third search</think>\n<search>Query 3</search>"},
]
}
]
rewards = reward_retry([], completions)
assert rewards[0] == 0.65, "Should get 0.65 reward for three searches"
# Test case with five searches in different messages
completions = [
{
"messages": [
{"role": "assistant", "content": "<think>Search 1</think>\n<search>Query 1</search>"},
{"role": "assistant", "content": "<think>Search 2</think>\n<search>Query 2</search>"},
{"role": "assistant", "content": "<think>Search 3</think>\n<search>Query 3</search>"},
{"role": "assistant", "content": "<think>Search 4</think>\n<search>Query 4</search>"},
{"role": "assistant", "content": "<think>Search 5</think>\n<search>Query 5</search>"},
{"role": "assistant", "content": "<think>Let me search</think>\n<search>capital of France</search>"},
{"role": "assistant", "content": "<think>Need more info</think>\n<search>Paris history</search>"},
{"role": "assistant", "content": "<think>Found it</think>\n<answer>Paris</answer>"},
]
}
]
rewards = reward_retry([], completions)
assert rewards[0] == 0.95, "Should get 0.95 reward for five searches"
# Test case with more than five searches
completions = [
{
"messages": [
{"role": "assistant", "content": "<think>Search 1</think>\n<search>Query 1</search>"},
{"role": "assistant", "content": "<think>Search 2</think>\n<search>Query 2</search>"},
{"role": "assistant", "content": "<think>Search 3</think>\n<search>Query 3</search>"},
{"role": "assistant", "content": "<think>Search 4</think>\n<search>Query 4</search>"},
{"role": "assistant", "content": "<think>Search 5</think>\n<search>Query 5</search>"},
{"role": "assistant", "content": "<think>Search 6</think>\n<search>Query 6</search>"},
]
}
]
rewards = reward_retry([], completions)
assert rewards[0] == 0.95, "Should cap at 0.95 reward for more than five searches"
# Test case with violation (multiple searches in one message)
completions = [
{
"messages": [
{
"role": "assistant",
"content": "<think>Multiple searches</think>\n<search>First query</search>\n<search>Second query</search>",
}
]
}
]
rewards = reward_retry([], completions)
assert rewards[0] == 0.25, "Should get penalized reward (0.5 * 0.5) for violation"
rewards = reward_retry(prompts, completions)
assert len(rewards) == 1
assert rewards[0] > 0, "Should give positive reward for multiple search attempts"
def test_reward_em_chunk():
"""Test reward EM chunk functionality with information tags"""
# Test case with matching content in ipython role
"""Test exact match chunk reward function"""
prompts = ["What is Python?"]
completions = [
{"messages": [{"role": "ipython", "content": "<information>This is the correct chunk content</information>"}]}
{"messages": [{"role": "user", "content": "<information>Python is a programming language</information>"}]}
]
reward_kwargs = {"chunk_content": ["This is the correct chunk content"]}
correct_contents = ["Python is a programming language"]
rewards = reward_em_chunk([], completions, **reward_kwargs)
rewards = reward_em_chunk(prompts, completions, chunk_content=correct_contents)
assert len(rewards) == 1
assert rewards[0] == 1.0, "Should get reward 1.0 for exact match in ipython role"
# Test case with matching content in user role
completions = [
{"messages": [{"role": "user", "content": "<information>This is the correct chunk content</information>"}]}
]
rewards = reward_em_chunk([], completions, **reward_kwargs)
assert rewards[0] == 1.0, "Should get reward 1.0 for exact match in user role"
# Test case with content not starting with <information> tag
completions = [{"messages": [{"role": "ipython", "content": "This is the correct chunk content"}]}]
rewards = reward_em_chunk([], completions, **reward_kwargs)
assert rewards[0] == 0.0, "Should get reward 0.0 for missing information tag"
# Test case with wrong role
completions = [
{
"messages": [
{"role": "assistant", "content": "<information>This is the correct chunk content</information>"}
]
}
]
rewards = reward_em_chunk([], completions, **reward_kwargs)
assert rewards[0] == 0.0, "Should get reward 0.0 for wrong role"
# Test case with multiple messages, only one matching
completions = [
{
"messages": [
{"role": "ipython", "content": "<information>Wrong content</information>"},
{"role": "user", "content": "<information>This is the correct chunk content</information>"},
]
}
]
rewards = reward_em_chunk([], completions, **reward_kwargs)
assert rewards[0] == 1.0, "Should get reward 1.0 if any message matches"
assert rewards[0] == 1.0, "Should give full reward for exact chunk match"
def test_reward_em_chunk_no_chunk_content():
