diff --git a/.gitignore b/.gitignore
index b5487e9..049499e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,6 +14,7 @@ model/
 graveyard/
 eval_logs/
 downloaded_model/
+logs/
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/src/rewards.py b/src/rewards.py
index ea7b206..1f11879 100644
--- a/src/rewards.py
+++ b/src/rewards.py
@@ -2,11 +2,14 @@
 Reward functions for RL training.
 """
 
+import json
 import re
+from datetime import datetime
+from pathlib import Path
 
 import numpy as np
 
-from src.config import logger
+from src.config import LOG_FOLDER, logger
 from src.evaluation import check_student_answers
 
 
@@ -35,29 +38,46 @@ def build_reward_correctness_fn(
         Returns:
             List of correctness scores between 0 and 1
         """
+        # correctness: the final message must be an assistant response containing answer tags
+        # and must not contain search or information tags
        teacher_answers = reward_kwargs["answer"]
-        student_answers = [completion["messages"][-1]["content"] for completion in completions]
-
-        # Log non-exact matches
-        for i, (student, teacher) in enumerate(zip(student_answers, teacher_answers)):
-            if student.strip().lower() != teacher.strip().lower():
-                logger.debug(f"Non-exact match at index {i}:\nStudent: {student}\nTeacher: {teacher}")
+        student_final_messages = [completion["messages"][-1]["content"] for completion in completions]
+        student_final_message_roles = [completion["messages"][-1]["role"] for completion in completions]
+        is_assistant_response = [role == "assistant" for role in student_final_message_roles]
+        has_answer_tag = [re.search(r"<answer>[\s\S]*?</answer>", ans) is not None for ans in student_final_messages]
+        has_search_tag = [re.search(r"<search>[\s\S]*?</search>", ans) is not None for ans in student_final_messages]
+        has_information_tag = [
+            re.search(r"<information>[\s\S]*?</information>", ans) is not None for ans in student_final_messages
+        ]
 
-        correct = check_student_answers(
+        might_be_correct = check_student_answers(
             prompts,
             teacher_answers,
-            student_answers,
+            student_final_messages,
             vllm_generate_func=vllm_generate_func,
             tokenizer=tokenizer,
         )
 
+        # Convert lists to numpy arrays for element-wise operations
+        might_be_correct = np.array(might_be_correct)
+        is_assistant_response = np.array(is_assistant_response)
+        has_answer_tag = np.array(has_answer_tag)
+        has_search_tag = np.array(has_search_tag)
+        has_information_tag = np.array(has_information_tag)
+
+        # might be correct AND is an assistant response AND has an answer tag AND no search or information tags
+        correct = might_be_correct & is_assistant_response & has_answer_tag & ~has_search_tag & ~has_information_tag
+
+        # Convert numpy array back to list for return
+        correct = correct.tolist()
+
         # Log correctness metrics with length info
         logger.info(f"Correctness metrics: {correct}")
         logger.info(f"Average correctness: {np.mean(correct):.2f}")
         logger.info(f"Standard deviation: {np.std(correct):.2f}")
 
         # Log length metrics
-        student_lengths = [len(ans.strip()) for ans in student_answers]
+        student_lengths = [len(ans.strip()) for ans in student_final_messages]
         teacher_lengths = [len(ans.strip()) for ans in teacher_answers]
         logger.info(f"Student lengths: {student_lengths}")
         logger.info(f"Teacher lengths: {teacher_lengths}")
@@ -65,6 +85,22 @@ def build_reward_correctness_fn(
         logger.info(f"Average teacher length: {np.mean(teacher_lengths):.2f}")
         logger.info(f"Length ratio: {np.mean(student_lengths) / np.mean(teacher_lengths):.2f}")
 
+        # Log chat state
+        log_chat_state(
+            prompts=prompts,
+            completions=completions,
+            rewards=correct,
+            reward_type="correctness",
+            teacher_answers=teacher_answers,
+            validation_results={
+                "is_assistant": is_assistant_response,
+                "has_answer": has_answer_tag,
+                "has_search": has_search_tag,
+                "has_info": has_information_tag,
+                "might_be_correct": might_be_correct,
+            },
+        )
+
         return correct
 
     return reward_correctness
@@ -103,6 +139,13 @@ def reward_format(prompts: list, completions: list, **reward_kwargs) -> list:
     ]
 
     rewards = []
+    validation_results = {
+        "has_think": [],
+        "has_answer": [],
+        "has_search": [],
+        "has_invalid_tags": [],
+        "has_info_tags": [],
+    }
 
     for completion in completions:
         messages = completion.get("messages", [])
@@ -110,59 +153,68 @@ def reward_format(prompts: list, completions: list, **reward_kwargs) -> list:
 
         if not assistant_msgs:
             rewards.append(0.0)
+            for key in validation_results:
+                validation_results[key].append(False)
             continue
 
-        content = assistant_msgs[-1]  # Get the last assistant message
+        content = assistant_msgs[-1]
 
-        # Check for invalid markdown formatting
         has_invalid_tags = any(re.search(pattern, content) for pattern in invalid_patterns)
+        validation_results["has_invalid_tags"].append(has_invalid_tags)
         if has_invalid_tags:
-            logger.debug("Found markdown-formatted tags in response")
             rewards.append(0.0)
+            for key in ["has_think", "has_answer", "has_search", "has_info_tags"]:
+                validation_results[key].append(False)
             continue
 
-        # Check for any information tag variants (should not exist in assistant messages)
         has_info_tags = False
         for pattern in info_patterns:
-            info_matches = re.findall(pattern, content, re.IGNORECASE)
-            if info_matches:
-                logger.debug(f"Found {len(info_matches)} information tag(s) of type '{pattern}' in assistant message")
+            if re.findall(pattern, content, re.IGNORECASE):
                 has_info_tags = True
                 break
+        validation_results["has_info_tags"].append(has_info_tags)
 
         if has_info_tags:
             rewards.append(0.0)
+            for key in ["has_think", "has_answer", "has_search"]:
+                validation_results[key].append(False)
             continue
 
-        # Find all tag matches
         think_matches = re.findall(think_pattern, content)
         search_matches = re.findall(search_pattern, content)
         answer_matches = re.findall(answer_pattern, content)
 
-        # Verify tag presence and count
         has_think = len(think_matches) >= 1
-        has_answer = len(answer_matches) == 1  # Must have exactly one answer
-        has_search = len(search_matches) >= 1  # One or more search tags
+        has_answer = len(answer_matches) == 1
+        has_search = len(search_matches) >= 1
+
+        validation_results["has_think"].append(has_think)
+        validation_results["has_answer"].append(has_answer)
+        validation_results["has_search"].append(has_search)
 
-        # Check for search and answer in the same message (not allowed)
         if has_search and has_answer:
-            logger.debug("Found both search and answer tags in the same message")
             rewards.append(0.0)
             continue
 
-        # Award reward - must have think tag and either answer or search (but not both)
         reward = 1.0 if has_think and (has_answer or has_search) else 0.0
         rewards.append(reward)
 
-        # Log issues for debugging
         if not reward:
             logger.debug(f"Format issues - think: {has_think}, answer: {has_answer}, search: {has_search}")
             if search_matches:
                 logger.debug(f"Number of search tags: {len(search_matches)}")
 
-    # Log overall metrics
     logger.info(f"Format reward metrics - Mean: {np.mean(rewards):.3f}, Valid formats: {sum(rewards)}/{len(rewards)}")
 
+    # Log chat state with validation results
+    log_chat_state(
+        prompts=prompts,
+        completions=completions,
+        rewards=rewards,
+        reward_type="format",
+        validation_results=validation_results,
+    )
+
     return rewards
@@ -247,6 +299,17 @@ def reward_retry(prompts: list, completions: list, **reward_kwargs) -> list:
     logger.info(f"Violations (>1 search per message): {sum(violations)}/{len(violations)}")
     logger.info(f"Search counts distribution: {search_queries}")
 
+    # Log chat state
+    log_chat_state(
+        prompts=prompts,
+        completions=completions,
+        rewards=rewards,
+        reward_type="retry",
+        search_counts=search_queries,
+        violations=violations,
+        optimal_search_count=optimal_search_count,
+    )
+
     return rewards
 
 
@@ -309,4 +372,55 @@ def reward_em_chunk(prompts: list, completions: list, **reward_kwargs) -> list:
     logger.info(f"Average reward: {np.mean(rewards):.3f}")
     logger.info(f"Reward std: {np.std(rewards):.3f}")
 
+    # Log chat state
+    log_chat_state(
+        prompts=prompts,
+        completions=completions,
+        rewards=rewards,
+        reward_type="em_chunk",
+        correct_contents=correct_contents,
+    )
+
     return rewards
+
+
+def log_chat_state(prompts: list, completions: list, rewards: list, reward_type: str, **kwargs) -> None:
+    """Log chat state and rewards to a JSONL file.
+
+    Args:
+        prompts: List of input prompts
+        completions: List of model completions
+        rewards: List of calculated rewards
+        reward_type: Type of reward function used
+        **kwargs: Additional data to log
+    """
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    chat_states_dir = LOG_FOLDER / "chat_states"
+    chat_states_dir.mkdir(parents=True, exist_ok=True)
+
+    # Convert numpy arrays to lists in kwargs
+    for key, value in kwargs.items():
+        if isinstance(value, dict):
+            for k, v in value.items():
+                if isinstance(v, np.ndarray):
+                    kwargs[key][k] = v.tolist()
+        elif isinstance(value, np.ndarray):
+            kwargs[key] = value.tolist()
+
+    # Create one JSONL file per reward type
+    log_file = chat_states_dir / f"chat_states_{reward_type}.jsonl"
+
+    # Append each chat state as a new line
+    with open(log_file, "a", encoding="utf-8") as f:
+        for prompt, completion, reward in zip(prompts, completions, rewards):
+            chat_state = {
+                "timestamp": timestamp,
+                "reward_type": reward_type,
+                "prompt": prompt,
+                "messages": completion["messages"],
+                "reward": float(reward) if isinstance(reward, (np.number, np.ndarray)) else reward,
+                "metadata": kwargs,
+            }
+            f.write(json.dumps(chat_state, ensure_ascii=False) + "\n")
+
+    logger.info(f"💾 Appended {len(prompts)} chat states to {log_file}")
diff --git a/tests/test_rewards.py b/tests/test_rewards.py
index bf8591b..4941f6c 100644
--- a/tests/test_rewards.py
+++ b/tests/test_rewards.py
@@ -71,6 +71,7 @@ def test_reward_correctness_wrong_answer(reward_correctness_fn):
 
 def test_reward_format_correct():
     """Test reward format with correct format"""
+    prompts = ["Test prompt"]
     completions = [
         {
             "messages": [
@@ -78,21 +79,23 @@ def test_reward_format_correct():
             ]
         }
     ]
-    rewards = reward_format([], completions)
+    rewards = reward_format(prompts, completions)
     assert rewards[0] == 1.0
 
 
 def test_reward_format_with_search():
     """Test reward format with search tags only (no answer tags)"""
+    prompts = ["Test prompt"]
     completions = [
        {"messages": [{"role": "assistant", "content": "<think>\nSome reasoning\n</think>\n<search>query</search>"}]}
     ]
-    rewards = reward_format([], completions)
+    rewards = reward_format(prompts, completions)
     assert rewards[0] == 1.0
 
 
 def test_reward_format_markdown_tags():
     """Test reward format with markdown-styled tags"""
+    prompts = ["Test prompt"]
     markdown_formats = [
         {
             "messages": [
@@ -121,12 +124,13 @@ def test_reward_format_markdown_tags():
             ]
         }
     ]
 
     for completion in markdown_formats:
-        rewards = reward_format([], [completion])
+        rewards = reward_format(["Test prompt"], [completion])
         assert rewards[0] == 0.0, f"Failed with: {completion['messages'][0]['content']}"
rewards[0] == 0.0, f"Failed with: {completion['messages'][0]['content']}" def test_reward_format_information_tags(): """Test reward format with information tags""" + prompts = ["Test prompt"] # Test different information tag variants info_variants = [ "Some info", @@ -139,12 +143,13 @@ def test_reward_format_information_tags(): for info_tag in info_variants: content = f"\nSome reasoning\n\n{info_tag}\n\nThe answer\n" completions = [{"messages": [{"role": "assistant", "content": content}]}] - rewards = reward_format([], completions) + rewards = reward_format(prompts, completions) assert rewards[0] == 0.0, f"Failed to detect information tag: {info_tag}" def test_reward_format_real_example(): """Test reward format with a real-world example - should fail now since it has both search and answer tags""" + prompts = ["What cars did Paul Walker drive in Fast and Furious?"] content = """I need to search for Paul Walker's cars in Fast and Furious movies. Paul Walker's cars in Fast and Furious @@ -157,27 +162,29 @@ Based on the updated information, it seems the focus was on his career, financia Charger """ completions = [{"messages": [{"role": "assistant", "content": content}]}] - rewards = reward_format([], completions) + rewards = reward_format(prompts, completions) assert rewards[0] == 0.0, "Should reject responses with both search and answer tags" def test_reward_format_real_example_search_only(): """Test reward format with search-only format in a real-world example""" + prompts = ["What cars did Paul Walker drive in Fast and Furious?"] content = """I need to search for Paul Walker's cars in Fast and Furious movies. Paul Walker's cars in Fast and Furious """ completions = [{"messages": [{"role": "assistant", "content": content}]}] - rewards = reward_format([], completions) + rewards = reward_format(prompts, completions) assert rewards[0] == 1.0, "Should accept responses with only search tags" def test_reward_format_real_example_answer_only(): """Test reward format with answer-only format in a real-world example""" + prompts = ["What cars did Paul Walker drive in Fast and Furious?"] content = """Based on the information provided, it seems Paul Walker drove a Charger in the Fast and Furious series. 
 <answer>Charger</answer>
 """
     completions = [{"messages": [{"role": "assistant", "content": content}]}]
-    rewards = reward_format([], completions)
+    rewards = reward_format(prompts, completions)
     assert rewards[0] == 1.0, "Should accept responses with only answer tags"
 
 
@@ -252,134 +259,33 @@ def test_reward_format_incomplete_tags():
 
 
 def test_reward_retry():
-    """Test reward retry functionality with progressive rewards up to 5 searches"""
-    # Test case with no searches
-    completions = [{"messages": [{"role": "assistant", "content": "No searches here"}]}]
-    rewards = reward_retry([], completions)
-    assert rewards[0] == 0.0, "Should get 0 reward for no searches"
-
-    # Test case with one search
-    completions = [
-        {
-            "messages": [
-                {
-                    "role": "assistant",
-                    "content": "<think>I need more information</think>\n<search>First query</search>",
-                }
-            ]
-        }
-    ]
-    rewards = reward_retry([], completions)
-    assert rewards[0] == 0.35, "Should get 0.35 reward for one search"
-
-    # Test case with three searches in different messages
+    """Test reward retry function"""
+    prompts = ["What is the capital of France?"]
     completions = [
         {
             "messages": [
-                {
-                    "role": "assistant",
-                    "content": "<think>First search</think>\n<search>Query 1</search>",
-                },
-                {"role": "assistant", "content": "<think>Second search</think>\n<search>Query 2</search>"},
-                {"role": "assistant", "content": "<think>Third search</think>\n<search>Query 3</search>"},
-            ]
-        }
-    ]
-    rewards = reward_retry([], completions)
-    assert rewards[0] == 0.65, "Should get 0.65 reward for three searches"
-
-    # Test case with five searches in different messages
-    completions = [
-        {
-            "messages": [
-                {"role": "assistant", "content": "<think>Search 1</think>\n<search>Query 1</search>"},
-                {"role": "assistant", "content": "<think>Search 2</think>\n<search>Query 2</search>"},
-                {"role": "assistant", "content": "<think>Search 3</think>\n<search>Query 3</search>"},
-                {"role": "assistant", "content": "<think>Search 4</think>\n<search>Query 4</search>"},
-                {"role": "assistant", "content": "<think>Search 5</think>\n<search>Query 5</search>"},
+                {"role": "assistant", "content": "<think>Let me search</think>\n<search>capital of France</search>"},
+                {"role": "assistant", "content": "<think>Need more info</think>\n<search>Paris history</search>"},
+                {"role": "assistant", "content": "<think>Found it</think>\n<search>Paris</search>"},
             ]
         }
     ]
-    rewards = reward_retry([], completions)
-    assert rewards[0] == 0.95, "Should get 0.95 reward for five searches"
-
-    # Test case with more than five searches
-    completions = [
-        {
-            "messages": [
-                {"role": "assistant", "content": "<think>Search 1</think>\n<search>Query 1</search>"},
-                {"role": "assistant", "content": "<think>Search 2</think>\n<search>Query 2</search>"},
-                {"role": "assistant", "content": "<think>Search 3</think>\n<search>Query 3</search>"},
-                {"role": "assistant", "content": "<think>Search 4</think>\n<search>Query 4</search>"},
-                {"role": "assistant", "content": "<think>Search 5</think>\n<search>Query 5</search>"},
-                {"role": "assistant", "content": "<think>Search 6</think>\n<search>Query 6</search>"},
-            ]
-        }
-    ]
-    rewards = reward_retry([], completions)
-    assert rewards[0] == 0.95, "Should cap at 0.95 reward for more than five searches"
-
-    # Test case with violation (multiple searches in one message)
-    completions = [
-        {
-            "messages": [
-                {
-                    "role": "assistant",
-                    "content": "<think>Multiple searches</think>\n<search>First query</search>\n<search>Second query</search>",
-                }
-            ]
-        }
-    ]
-    rewards = reward_retry([], completions)
-    assert rewards[0] == 0.25, "Should get penalized reward (0.5 * 0.5) for violation"
+    rewards = reward_retry(prompts, completions)
+    assert len(rewards) == 1
+    assert rewards[0] > 0, "Should give positive reward for multiple search attempts"
 
 
 def test_reward_em_chunk():
-    """Test reward EM chunk functionality with information tags"""
-    # Test case with matching content in ipython role
+    """Test exact match chunk reward function"""
+    prompts = ["What is Python?"]
     completions = [
-        {"messages": [{"role": "ipython", "content": "<information>This is the correct chunk content</information>"}]}
programming language"}]} ] - reward_kwargs = {"chunk_content": ["This is the correct chunk content"]} + correct_contents = ["Python is a programming language"] - rewards = reward_em_chunk([], completions, **reward_kwargs) + rewards = reward_em_chunk(prompts, completions, chunk_content=correct_contents) assert len(rewards) == 1 - assert rewards[0] == 1.0, "Should get reward 1.0 for exact match in ipython role" - - # Test case with matching content in user role - completions = [ - {"messages": [{"role": "user", "content": "This is the correct chunk content"}]} - ] - rewards = reward_em_chunk([], completions, **reward_kwargs) - assert rewards[0] == 1.0, "Should get reward 1.0 for exact match in user role" - - # Test case with content not starting with tag - completions = [{"messages": [{"role": "ipython", "content": "This is the correct chunk content"}]}] - rewards = reward_em_chunk([], completions, **reward_kwargs) - assert rewards[0] == 0.0, "Should get reward 0.0 for missing information tag" - - # Test case with wrong role - completions = [ - { - "messages": [ - {"role": "assistant", "content": "This is the correct chunk content"} - ] - } - ] - rewards = reward_em_chunk([], completions, **reward_kwargs) - assert rewards[0] == 0.0, "Should get reward 0.0 for wrong role" - - # Test case with multiple messages, only one matching - completions = [ - { - "messages": [ - {"role": "ipython", "content": "Wrong content"}, - {"role": "user", "content": "This is the correct chunk content"}, - ] - } - ] - rewards = reward_em_chunk([], completions, **reward_kwargs) - assert rewards[0] == 1.0, "Should get reward 1.0 if any message matches" + assert rewards[0] == 1.0, "Should give full reward for exact chunk match" def test_reward_em_chunk_no_chunk_content():