diff --git a/.gitignore b/.gitignore
index b5487e9..049499e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,6 +14,7 @@ model/
graveyard/
eval_logs/
downloaded_model/
+logs/
# Byte-compiled / optimized / DLL files
__pycache__/
diff --git a/src/rewards.py b/src/rewards.py
index ea7b206..1f11879 100644
--- a/src/rewards.py
+++ b/src/rewards.py
@@ -2,11 +2,14 @@
Reward functions for RL training.
"""
+import json
import re
+from datetime import datetime
+from pathlib import Path
import numpy as np
-from src.config import logger
+from src.config import LOG_FOLDER, logger
from src.evaluation import check_student_answers
@@ -35,29 +38,46 @@ def build_reward_correctness_fn(
Returns:
List of correctness scores between 0 and 1
"""
+ # correctness: must be assistant response and must have only one pair of answer tags
+ # should not have search tags and information tags
teacher_answers = reward_kwargs["answer"]
- student_answers = [completion["messages"][-1]["content"] for completion in completions]
-
- # Log non-exact matches
- for i, (student, teacher) in enumerate(zip(student_answers, teacher_answers)):
- if student.strip().lower() != teacher.strip().lower():
- logger.debug(f"Non-exact match at index {i}:\nStudent: {student}\nTeacher: {teacher}")
+ student_final_messages = [completion["messages"][-1]["content"] for completion in completions]
+ student_final_message_roles = [completion["messages"][-1]["role"] for completion in completions]
+ is_assistant_response = [role == "assistant" for role in student_final_message_roles]
+ has_answer_tag = [re.search(r"[\s\S]*?", ans) is not None for ans in student_final_messages]
+ has_search_tag = [re.search(r"[\s\S]*?", ans) is not None for ans in student_final_messages]
+ has_information_tag = [
+ re.search(r"[\s\S]*?", ans) is not None for ans in student_final_messages
+ ]
- correct = check_student_answers(
+ might_be_correct = check_student_answers(
prompts,
teacher_answers,
- student_answers,
+ student_final_messages,
vllm_generate_func=vllm_generate_func,
tokenizer=tokenizer,
)
+ # Convert lists to numpy arrays for element-wise operations
+ might_be_correct = np.array(might_be_correct)
+ is_assistant_response = np.array(is_assistant_response)
+ has_answer_tag = np.array(has_answer_tag)
+ has_search_tag = np.array(has_search_tag)
+ has_information_tag = np.array(has_information_tag)
+
+ # might be correct and is assistant response and has answer tag and no search or information tags
+ correct = might_be_correct & is_assistant_response & has_answer_tag & ~has_search_tag & ~has_information_tag
+
+ # Convert numpy array back to list for return
+ correct = correct.tolist()
+
# Log correctness metrics with length info
logger.info(f"Correctness metrics: {correct}")
logger.info(f"Average correctness: {np.mean(correct):.2f}")
logger.info(f"Standard deviation: {np.std(correct):.2f}")
# Log length metrics
- student_lengths = [len(ans.strip()) for ans in student_answers]
+ student_lengths = [len(ans.strip()) for ans in student_final_messages]
teacher_lengths = [len(ans.strip()) for ans in teacher_answers]
logger.info(f"Student lengths: {student_lengths}")
logger.info(f"Teacher lengths: {teacher_lengths}")
@@ -65,6 +85,22 @@ def build_reward_correctness_fn(
logger.info(f"Average teacher length: {np.mean(teacher_lengths):.2f}")
logger.info(f"Length ratio: {np.mean(student_lengths) / np.mean(teacher_lengths):.2f}")
+ # Log chat state
+ log_chat_state(
+ prompts=prompts,
+ completions=completions,
+ rewards=correct,
+ reward_type="correctness",
+ teacher_answers=teacher_answers,
+ validation_results={
+ "is_assistant": is_assistant_response,
+ "has_answer": has_answer_tag,
+ "has_search": has_search_tag,
+ "has_info": has_information_tag,
+ "might_be_correct": might_be_correct,
+ },
+ )
+
return correct
return reward_correctness
@@ -103,6 +139,13 @@ def reward_format(prompts: list, completions: list, **reward_kwargs) -> list:
]
rewards = []
+ validation_results = {
+ "has_think": [],
+ "has_answer": [],
+ "has_search": [],
+ "has_invalid_tags": [],
+ "has_info_tags": [],
+ }
for completion in completions:
messages = completion.get("messages", [])
@@ -110,59 +153,68 @@ def reward_format(prompts: list, completions: list, **reward_kwargs) -> list:
if not assistant_msgs:
rewards.append(0.0)
+ for key in validation_results:
+ validation_results[key].append(False)
continue
- content = assistant_msgs[-1] # Get the last assistant message
+ content = assistant_msgs[-1]
- # Check for invalid markdown formatting
has_invalid_tags = any(re.search(pattern, content) for pattern in invalid_patterns)
+ validation_results["has_invalid_tags"].append(has_invalid_tags)
if has_invalid_tags:
- logger.debug("Found markdown-formatted tags in response")
rewards.append(0.0)
+ for key in ["has_think", "has_answer", "has_search", "has_info_tags"]:
+ validation_results[key].append(False)
continue
- # Check for any information tag variants (should not exist in assistant messages)
has_info_tags = False
for pattern in info_patterns:
- info_matches = re.findall(pattern, content, re.IGNORECASE)
- if info_matches:
- logger.debug(f"Found {len(info_matches)} information tag(s) of type '{pattern}' in assistant message")
+ if re.findall(pattern, content, re.IGNORECASE):
has_info_tags = True
break
+ validation_results["has_info_tags"].append(has_info_tags)
if has_info_tags:
rewards.append(0.0)
+ for key in ["has_think", "has_answer", "has_search"]:
+ validation_results[key].append(False)
continue
- # Find all tag matches
think_matches = re.findall(think_pattern, content)
search_matches = re.findall(search_pattern, content)
answer_matches = re.findall(answer_pattern, content)
- # Verify tag presence and count
has_think = len(think_matches) >= 1
- has_answer = len(answer_matches) == 1 # Must have exactly one answer
- has_search = len(search_matches) >= 1 # One or more search tags
+ has_answer = len(answer_matches) == 1
+ has_search = len(search_matches) >= 1
+
+ validation_results["has_think"].append(has_think)
+ validation_results["has_answer"].append(has_answer)
+ validation_results["has_search"].append(has_search)
- # Check for search and answer in the same message (not allowed)
if has_search and has_answer:
- logger.debug("Found both search and answer tags in the same message")
rewards.append(0.0)
continue
- # Award reward - must have think tag and either answer or search (but not both)
reward = 1.0 if has_think and (has_answer or has_search) else 0.0
rewards.append(reward)
- # Log issues for debugging
if not reward:
logger.debug(f"Format issues - think: {has_think}, answer: {has_answer}, search: {has_search}")
if search_matches:
logger.debug(f"Number of search tags: {len(search_matches)}")
- # Log overall metrics
logger.info(f"Format reward metrics - Mean: {np.mean(rewards):.3f}, Valid formats: {sum(rewards)}/{len(rewards)}")
+ # Log chat state with validation results
+ log_chat_state(
+ prompts=prompts,
+ completions=completions,
+ rewards=rewards,
+ reward_type="format",
+ validation_results=validation_results,
+ )
+
return rewards
@@ -247,6 +299,17 @@ def reward_retry(prompts: list, completions: list, **reward_kwargs) -> list:
logger.info(f"Violations (>1 search per message): {sum(violations)}/{len(violations)}")
logger.info(f"Search counts distribution: {search_queries}")
+ # Log chat state
+ log_chat_state(
+ prompts=prompts,
+ completions=completions,
+ rewards=rewards,
+ reward_type="retry",
+ search_counts=search_queries,
+ violations=violations,
+ optimal_search_count=optimal_search_count,
+ )
+
return rewards
@@ -309,4 +372,55 @@ def reward_em_chunk(prompts: list, completions: list, **reward_kwargs) -> list:
logger.info(f"Average reward: {np.mean(rewards):.3f}")
logger.info(f"Reward std: {np.std(rewards):.3f}")
+ # Log chat state
+ log_chat_state(
+ prompts=prompts,
+ completions=completions,
+ rewards=rewards,
+ reward_type="em_chunk",
+ correct_contents=correct_contents,
+ )
+
return rewards
+
+
+def log_chat_state(prompts: list, completions: list, rewards: list, reward_type: str, **kwargs) -> None:
+ """Log chat state and rewards to JSONL file.
+
+ Args:
+ prompts: List of input prompts
+ completions: List of model completions
+ rewards: List of calculated rewards
+ reward_type: Type of reward function used
+ **kwargs: Additional data to log
+ """
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ chat_states_dir = LOG_FOLDER / "chat_states"
+ chat_states_dir.mkdir(parents=True, exist_ok=True)
+
+ # Convert numpy arrays to lists in kwargs
+ for key, value in kwargs.items():
+ if isinstance(value, dict):
+ for k, v in value.items():
+ if isinstance(v, np.ndarray):
+ kwargs[key][k] = v.tolist()
+ elif isinstance(value, np.ndarray):
+ kwargs[key] = value.tolist()
+
+ # Create one JSONL file per reward type
+ log_file = chat_states_dir / f"chat_states_{reward_type}.jsonl"
+
+ # Append each chat state as a new line
+ with open(log_file, "a", encoding="utf-8") as f:
+ for prompt, completion, reward in zip(prompts, completions, rewards):
+ chat_state = {
+ "timestamp": timestamp,
+ "reward_type": reward_type,
+ "prompt": prompt,
+ "messages": completion["messages"],
+ "reward": float(reward) if isinstance(reward, (np.number, np.ndarray)) else reward,
+ "metadata": kwargs,
+ }
+ f.write(json.dumps(chat_state, ensure_ascii=False) + "\n")
+
+ logger.info(f"💾 Appended {len(prompts)} chat states to {log_file}")
diff --git a/tests/test_rewards.py b/tests/test_rewards.py
index bf8591b..4941f6c 100644
--- a/tests/test_rewards.py
+++ b/tests/test_rewards.py
@@ -71,6 +71,7 @@ def test_reward_correctness_wrong_answer(reward_correctness_fn):
def test_reward_format_correct():
"""Test reward format with correct format"""
+ prompts = ["Test prompt"]
completions = [
{
"messages": [
@@ -78,21 +79,23 @@ def test_reward_format_correct():
]
}
]
- rewards = reward_format([], completions)
+ rewards = reward_format(prompts, completions)
assert rewards[0] == 1.0
def test_reward_format_with_search():
"""Test reward format with search tags only (no answer tags)"""
+ prompts = ["Test prompt"]
completions = [
{"messages": [{"role": "assistant", "content": "\nSome reasoning\n\nquery"}]}
]
- rewards = reward_format([], completions)
+ rewards = reward_format(prompts, completions)
assert rewards[0] == 1.0
def test_reward_format_markdown_tags():
"""Test reward format with markdown-styled tags"""
+ prompts = ["Test prompt"]
markdown_formats = [
{
"messages": [
@@ -121,12 +124,13 @@ def test_reward_format_markdown_tags():
]
for completion in markdown_formats:
- rewards = reward_format([], [completion])
+ rewards = reward_format(["Test prompt"], [completion])
assert rewards[0] == 0.0, f"Failed with: {completion['messages'][0]['content']}"
def test_reward_format_information_tags():
"""Test reward format with information tags"""
+ prompts = ["Test prompt"]
# Test different information tag variants
info_variants = [
"Some info",
@@ -139,12 +143,13 @@ def test_reward_format_information_tags():
for info_tag in info_variants:
content = f"\nSome reasoning\n\n{info_tag}\n\nThe answer\n"
completions = [{"messages": [{"role": "assistant", "content": content}]}]
- rewards = reward_format([], completions)
+ rewards = reward_format(prompts, completions)
assert rewards[0] == 0.0, f"Failed to detect information tag: {info_tag}"
def test_reward_format_real_example():
"""Test reward format with a real-world example - should fail now since it has both search and answer tags"""
+ prompts = ["What cars did Paul Walker drive in Fast and Furious?"]
content = """I need to search for Paul Walker's cars in Fast and Furious movies.
Paul Walker's cars in Fast and Furious
@@ -157,27 +162,29 @@ Based on the updated information, it seems the focus was on his career, financia
Charger """
completions = [{"messages": [{"role": "assistant", "content": content}]}]
- rewards = reward_format([], completions)
+ rewards = reward_format(prompts, completions)
assert rewards[0] == 0.0, "Should reject responses with both search and answer tags"
def test_reward_format_real_example_search_only():
"""Test reward format with search-only format in a real-world example"""
+ prompts = ["What cars did Paul Walker drive in Fast and Furious?"]
content = """I need to search for Paul Walker's cars in Fast and Furious movies.
Paul Walker's cars in Fast and Furious """
completions = [{"messages": [{"role": "assistant", "content": content}]}]
- rewards = reward_format([], completions)
+ rewards = reward_format(prompts, completions)
assert rewards[0] == 1.0, "Should accept responses with only search tags"
def test_reward_format_real_example_answer_only():
"""Test reward format with answer-only format in a real-world example"""
+ prompts = ["What cars did Paul Walker drive in Fast and Furious?"]
content = """Based on the information provided, it seems Paul Walker drove a Charger in the Fast and Furious series.
Charger """
completions = [{"messages": [{"role": "assistant", "content": content}]}]
- rewards = reward_format([], completions)
+ rewards = reward_format(prompts, completions)
assert rewards[0] == 1.0, "Should accept responses with only answer tags"
@@ -252,134 +259,33 @@ def test_reward_format_incomplete_tags():
def test_reward_retry():
- """Test reward retry functionality with progressive rewards up to 5 searches"""
- # Test case with no searches
- completions = [{"messages": [{"role": "assistant", "content": "No searches here"}]}]
- rewards = reward_retry([], completions)
- assert rewards[0] == 0.0, "Should get 0 reward for no searches"
-
- # Test case with one search
- completions = [
- {
- "messages": [
- {
- "role": "assistant",
- "content": "I need more information\nFirst query",
- }
- ]
- }
- ]
- rewards = reward_retry([], completions)
- assert rewards[0] == 0.35, "Should get 0.35 reward for one search"
-
- # Test case with three searches in different messages
+ """Test reward retry function"""
+ prompts = ["What is the capital of France?"]
completions = [
{
"messages": [
- {
- "role": "assistant",
- "content": "First search\nQuery 1",
- },
- {"role": "assistant", "content": "Second search\nQuery 2"},
- {"role": "assistant", "content": "Third search\nQuery 3"},
- ]
- }
- ]
- rewards = reward_retry([], completions)
- assert rewards[0] == 0.65, "Should get 0.65 reward for three searches"
-
- # Test case with five searches in different messages
- completions = [
- {
- "messages": [
- {"role": "assistant", "content": "Search 1\nQuery 1"},
- {"role": "assistant", "content": "Search 2\nQuery 2"},
- {"role": "assistant", "content": "Search 3\nQuery 3"},
- {"role": "assistant", "content": "Search 4\nQuery 4"},
- {"role": "assistant", "content": "Search 5\nQuery 5"},
+ {"role": "assistant", "content": "Let me search\ncapital of France"},
+ {"role": "assistant", "content": "Need more info\nParis history"},
+ {"role": "assistant", "content": "Found it\nParis"},
]
}
]
- rewards = reward_retry([], completions)
- assert rewards[0] == 0.95, "Should get 0.95 reward for five searches"
-
- # Test case with more than five searches
- completions = [
- {
- "messages": [
- {"role": "assistant", "content": "Search 1\nQuery 1"},
- {"role": "assistant", "content": "Search 2\nQuery 2"},
- {"role": "assistant", "content": "Search 3\nQuery 3"},
- {"role": "assistant", "content": "Search 4\nQuery 4"},
- {"role": "assistant", "content": "Search 5\nQuery 5"},
- {"role": "assistant", "content": "Search 6\nQuery 6"},
- ]
- }
- ]
- rewards = reward_retry([], completions)
- assert rewards[0] == 0.95, "Should cap at 0.95 reward for more than five searches"
-
- # Test case with violation (multiple searches in one message)
- completions = [
- {
- "messages": [
- {
- "role": "assistant",
- "content": "Multiple searches\nFirst query\nSecond query",
- }
- ]
- }
- ]
- rewards = reward_retry([], completions)
- assert rewards[0] == 0.25, "Should get penalized reward (0.5 * 0.5) for violation"
+ rewards = reward_retry(prompts, completions)
+ assert len(rewards) == 1
+ assert rewards[0] > 0, "Should give positive reward for multiple search attempts"
def test_reward_em_chunk():
- """Test reward EM chunk functionality with information tags"""
- # Test case with matching content in ipython role
+ """Test exact match chunk reward function"""
+ prompts = ["What is Python?"]
completions = [
- {"messages": [{"role": "ipython", "content": "This is the correct chunk content"}]}
+ {"messages": [{"role": "user", "content": "Python is a programming language"}]}
]
- reward_kwargs = {"chunk_content": ["This is the correct chunk content"]}
+ correct_contents = ["Python is a programming language"]
- rewards = reward_em_chunk([], completions, **reward_kwargs)
+ rewards = reward_em_chunk(prompts, completions, chunk_content=correct_contents)
assert len(rewards) == 1
- assert rewards[0] == 1.0, "Should get reward 1.0 for exact match in ipython role"
-
- # Test case with matching content in user role
- completions = [
- {"messages": [{"role": "user", "content": "This is the correct chunk content"}]}
- ]
- rewards = reward_em_chunk([], completions, **reward_kwargs)
- assert rewards[0] == 1.0, "Should get reward 1.0 for exact match in user role"
-
- # Test case with content not starting with tag
- completions = [{"messages": [{"role": "ipython", "content": "This is the correct chunk content"}]}]
- rewards = reward_em_chunk([], completions, **reward_kwargs)
- assert rewards[0] == 0.0, "Should get reward 0.0 for missing information tag"
-
- # Test case with wrong role
- completions = [
- {
- "messages": [
- {"role": "assistant", "content": "This is the correct chunk content"}
- ]
- }
- ]
- rewards = reward_em_chunk([], completions, **reward_kwargs)
- assert rewards[0] == 0.0, "Should get reward 0.0 for wrong role"
-
- # Test case with multiple messages, only one matching
- completions = [
- {
- "messages": [
- {"role": "ipython", "content": "Wrong content"},
- {"role": "user", "content": "This is the correct chunk content"},
- ]
- }
- ]
- rewards = reward_em_chunk([], completions, **reward_kwargs)
- assert rewards[0] == 1.0, "Should get reward 1.0 if any message matches"
+ assert rewards[0] == 1.0, "Should give full reward for exact chunk match"
def test_reward_em_chunk_no_chunk_content():