diff --git a/generate_data_but_less_dumb.py b/generate_data_but_less_dumb.py new file mode 100644 index 0000000..e5842bf --- /dev/null +++ b/generate_data_but_less_dumb.py @@ -0,0 +1,279 @@ +""" +This script performs two main tasks: +1. It loads a markdown document, splits it into chunks, generates embeddings, + and builds a FAISS index (which is saved locally). +2. It generates QA pairs from the document using a Llama model. + For each chunk, it generates multiple question-answer pairs, each labelled with + a difficulty level. The generation is performed in batch with one retry for failed prompts. + Successfully generated QA pairs are saved to "saved_data/questions.json". + +Requirements: + pip install langchain langchain-community unstructured pandas loguru faiss-cpu unsloth vllm +""" + +import json +import os +import re +from typing import Dict, List, Optional, Tuple + +import pandas as pd +from langchain.text_splitter import RecursiveCharacterTextSplitter +from loguru import logger + +# Configure logger +logger.add( + "logs/generate_data_{time}.log", + rotation="500 MB", + retention="10 days", + level="INFO", +) + +# ========= Part 1: Document Processing and Embedding Generation ========= +# Load and split the markdown document using LangChain +from langchain_community.document_loaders import UnstructuredMarkdownLoader +from langchain_community.vectorstores import FAISS + +from embeddings import CustomHuggingFaceEmbeddings + +# Load your markdown file (adjust the path as needed) +loader = UnstructuredMarkdownLoader("./data/mission_report.md") +docs = loader.load() + +# Split the document into smaller chunks (each 1000 characters, no overlap) +text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0) +chunks = text_splitter.split_documents(docs) + +# Create output directory +os.makedirs("saved_data", exist_ok=True) + +# Save chunks to CSV for easy inspection +chunks_df = pd.DataFrame( + { + "chunk_id": range(1, len(chunks) + 1), + "content": [chunk.page_content for chunk in chunks], + "metadata": [chunk.metadata for chunk in chunks], + } +) +chunks_df.to_csv("saved_data/chunks.csv", index=False) +print(f"Saved {len(chunks)} chunks to saved_data/chunks.csv") + +embeddings = CustomHuggingFaceEmbeddings() + +# Create a FAISS vector store from the document chunks and save it locally +vectorstore = FAISS.from_documents(chunks, embeddings) +vectorstore.save_local("faiss_index") +print("Saved FAISS index to 'faiss_index'") + +# TODO: add the paraphrased chunks to the vector store + +# ========= Part 2: QA Generation using Llama Backend ========= + +# Setup Llama backend via unsloth and vLLM +from unsloth import FastLanguageModel +from vllm import SamplingParams + +import rl_helpers # Ensure you have this or remove if not used + +# Load the Llama model (adjust parameters as needed) +model, tokenizer = FastLanguageModel.from_pretrained( + model_name="meta-llama/meta-Llama-3.1-8B-Instruct", + max_seq_length=4096, + load_in_4bit=True, # Use 4-bit quantization if desired + fast_inference=True, # Enable fast inference + gpu_memory_utilization=0.6, # Adjust based on your GPU memory +) + +# Define sampling parameters for generation +sampling_params = SamplingParams( + temperature=0.3, + top_p=0.95, + max_tokens=4096, +) + + +def batch_generate(prompts: List[str]) -> List[str]: + """ + Given a list of prompt strings, returns a list of generated outputs. 
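+ Each prompt is wrapped in the tokenizer's chat template (add_generation_prompt=True)
+ before being passed to model.fast_generate with the module-level sampling_params.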
+ """ + + def format_input(text: str) -> str: + return tokenizer.apply_chat_template( + [{"role": "user", "content": text}], + tokenize=False, + add_generation_prompt=True, + ) + + formatted = [format_input(p) for p in prompts] + outputs = model.fast_generate(formatted, sampling_params=sampling_params) + return [output.outputs[0].text for output in outputs] + + +def parse_qa_block(block: str) -> Optional[Tuple[str, str, str]]: + """ + Parses a QA block that should contain exactly three non-empty lines: + - A line starting with "Question:" + - A line starting with "Answer:" + - A line starting with "Difficulty:" + + If the markers are not present but the block contains exactly three lines, + those are used in order. + + Returns a tuple (question, answer, difficulty) or None if parsing fails. + """ + lines = [line.strip() for line in block.splitlines() if line.strip()] + if not lines: + return None + + question, answer, difficulty = None, None, None + for line in lines: + lower = line.lower() + if question is None and lower.startswith("question:"): + question = line[len("question:") :].strip() + elif answer is None and lower.startswith("answer:"): + answer = line[len("answer:") :].strip() + elif difficulty is None and lower.startswith("difficulty:"): + difficulty = line[len("difficulty:") :].strip() + + if question and answer and difficulty: + return question, answer, difficulty + if len(lines) == 3: + return lines[0], lines[1], lines[2] + return None + + +def parse_multiple_qa_output(output: str) -> List[Tuple[str, str, str]]: + """ + Splits the output into blocks (separated by one or more blank lines) and + attempts to parse each as a QA pair. + + Returns a list of successfully parsed QA tuples. + """ + blocks = re.split(r"\n\s*\n", output.strip()) + qa_pairs = [] + for block in blocks: + parsed = parse_qa_block(block) + if parsed: + qa_pairs.append(parsed) + return qa_pairs + + +def generate_question_batch_for_chunks( + chunks: List, num_questions: int = 2, difficulty: Optional[str] = None +) -> List[Dict]: + """ + Generates QA pairs for multiple chunks in batch. + + For each chunk, generates questions based on its content only. + Each prompt instructs the model to output exactly three lines per QA pair with markers. + Failed prompts are retried once in batch; if still unsuccessful, they are skipped. + + Returns a list of dicts with keys: "chunk_id", "question", "answer", "difficulty", "chunk_content". + """ + prompts = [] + chunk_ids = [] + chunk_contents = [] + + # Prepare prompts for each chunk + for i, chunk in enumerate(chunks): + current = chunk.page_content + prompt = ( + f"You are a question generator. Generate {num_questions} questions based on the following text.\n" + "Rules:\n" + "1. Questions must be answerable using ONLY the information in the text\n" + "2. Answers must be directly stated in the text\n" + "3. Each question should test understanding of a different aspect of the text\n" + "4. Questions should be clear and specific\n" + "5. 
Answers should be concise and factual\n\n" + "For each QA pair, output exactly three lines with no extra commentary:\n" + "Line 1: Question: <your question>\n" + "Line 2: Answer: <the answer>\n" + "Line 3: Difficulty: <easy/medium/hard>\n" + "Do not include any additional text.\n\n" + "Text:\n" + f"{current}\n" + ) + prompts.append(prompt) + chunk_ids.append(i + 1) # 1-based indexing + chunk_contents.append(current) + + # First batch generation + outputs = batch_generate(prompts) + results: List[Optional[List[Tuple[str, str, str]]]] = [None] * len(outputs) + failed_indices = [] + + # Parse each output + for idx, output in enumerate(outputs): + qa_pairs = parse_multiple_qa_output(output) + if qa_pairs is None or len(qa_pairs) < num_questions: + failed_indices.append(idx) + logger.warning(f"Failed to generate enough QA pairs for chunk {idx + 1}") + else: + # Validate that answers exist in chunk content + valid_pairs = [] + for q, a, d in qa_pairs: + if a.lower() in chunk_contents[idx].lower(): + valid_pairs.append((q, a, d)) + else: + logger.warning(f"Answer not found in chunk content: {a}") + + if len(valid_pairs) >= num_questions: + results[idx] = valid_pairs[:num_questions] + else: + failed_indices.append(idx) + logger.warning(f"Not enough valid QA pairs for chunk {idx + 1}") + + # Retry failed prompts in batch + if failed_indices: + logger.info(f"Retrying {len(failed_indices)} failed prompt(s)...") + retry_prompts = [prompts[i] for i in failed_indices] + retry_outputs = batch_generate(retry_prompts) + for j, idx in enumerate(failed_indices): + qa_pairs = parse_multiple_qa_output(retry_outputs[j]) + if qa_pairs is not None and len(qa_pairs) >= num_questions: + # Validate answers again + valid_pairs = [] + for q, a, d in qa_pairs: + if a.lower() in chunk_contents[idx].lower(): + valid_pairs.append((q, a, d)) + + if len(valid_pairs) >= num_questions: + results[idx] = valid_pairs[:num_questions] + else: + results[idx] = None + logger.warning( + f"Retry failed for chunk {idx + 1}: not enough valid QA pairs" + ) + else: + results[idx] = None + logger.warning(f"Retry failed for chunk {idx + 1}: parsing failed") + + # Build final output, skipping prompts that failed even after retry + final_questions = [] + for i, qa_list in enumerate(results): + if qa_list is not None: + for qa in qa_list: + final_questions.append( + { + "chunk_id": chunk_ids[i], + "question": qa[0], + "answer": qa[1], + "difficulty": qa[2], + "chunk_content": chunk_contents[i], + } + ) + + logger.info(f"Generated {len(final_questions)} valid QA pairs") + return final_questions + + +# Generate QA pairs in batch (one prompt per chunk) +all_questions = generate_question_batch_for_chunks( + chunks, num_questions=2, difficulty="medium" +) +print(f"Generated {len(all_questions)} QA pairs.") + +# Save the QA pairs to a JSON file +questions_path = os.path.join("saved_data", "questions.json") +with open(questions_path, "w") as f: + json.dump(all_questions, f, indent=2) +print(f"Saved questions to {questions_path}") diff --git a/notebooks/250325_visualize_reward_function.ipynb b/notebooks/250325_visualize_reward_function.ipynb new file mode 100644 index 0000000..04c691a --- /dev/null +++ b/notebooks/250325_visualize_reward_function.ipynb @@ -0,0 +1,113 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install matplotlib -q" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import 
matplotlib.pyplot as plt\n", + "\n", + "def plot_reward_functions():\n", + " # Generate retry counts from 0 to 15\n", + " retries = np.linspace(0, 15, 100)\n", + " \n", + " # 1. Basic Sigmoid\n", + " basic_sigmoid = 1 / (1 + np.exp(-(retries - 4)))\n", + " \n", + " # 2. Our Modified Sigmoid\n", + " x = retries - 4 # Center at 4 retries\n", + " modified_sigmoid = 1 / (1 + np.exp(-x + abs(x)/2))\n", + " \n", + " # 3. With Penalty\n", + " penalized_reward = modified_sigmoid.copy()\n", + " for i, r in enumerate(retries):\n", + " if r > 6:\n", + " penalty = 0.2 * (r - 6)\n", + " penalized_reward[i] = max(0.1, modified_sigmoid[i] - penalty)\n", + " \n", + " # Plotting\n", + " plt.figure(figsize=(12, 6))\n", + " \n", + " plt.plot(retries, basic_sigmoid, 'b--', label='Basic Sigmoid')\n", + " plt.plot(retries, modified_sigmoid, 'g--', label='Modified Sigmoid')\n", + " plt.plot(retries, penalized_reward, 'r-', label='Final Reward (with penalty)', linewidth=2)\n", + " \n", + " # Add vertical lines for key points\n", + " plt.axvline(x=4, color='gray', linestyle=':', alpha=0.5, label='Peak (4 retries)')\n", + " plt.axvline(x=6, color='gray', linestyle=':', alpha=0.5, label='Penalty Start (6 retries)')\n", + " \n", + " plt.grid(True, alpha=0.3)\n", + " plt.xlabel('Number of Retries')\n", + " plt.ylabel('Reward')\n", + " plt.title('Reward Function Visualization')\n", + " plt.legend()\n", + " plt.ylim(-0.1, 1.1)\n", + " \n", + " # Add annotations\n", + " plt.annotate('Optimal Zone', xy=(4, 0.8), xytext=(4, 0.9),\n", + " ha='center', va='bottom',\n", + " bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.3),\n", + " arrowprops=dict(arrowstyle='->'))\n", + " \n", + " plt.annotate('Penalty Zone', xy=(8, 0.3), xytext=(8, 0.5),\n", + " ha='center', va='bottom',\n", + " bbox=dict(boxstyle='round,pad=0.5', fc='red', alpha=0.3),\n", + " arrowprops=dict(arrowstyle='->'))\n", + " \n", + " plt.show()\n", + "\n", + "# Run the visualization\n", + "plot_reward_functions()\n", + "\n", + "# Print reward values for specific retry counts\n", + "def print_reward_examples():\n", + " retry_examples = [1, 2, 3, 4, 5, 6, 7, 8, 10, 12]\n", + " print(\"\\nReward values for different retry counts:\")\n", + " print(\"Retries | Reward\")\n", + " print(\"-\" * 20)\n", + " \n", + " for retries in retry_examples:\n", + " x = retries - 4\n", + " reward = 1 / (1 + np.exp(-x + abs(x)/2))\n", + " if retries > 6:\n", + " penalty = 0.2 * (retries - 6)\n", + " reward = max(0.1, reward - penalty)\n", + " print(f\"{retries:7d} | {reward:.3f}\")\n", + "\n", + "print_reward_examples()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/rl_helpers.py b/rl_helpers.py index 9df28f5..44ecee4 100644 --- a/rl_helpers.py +++ b/rl_helpers.py @@ -9,12 +9,26 @@ import json import re from dataclasses import dataclass from datetime import datetime +from pathlib import Path import nest_asyncio +import numpy as np import torch +from loguru import logger from search_module import get_qa_dataset, search +# Setup loguru +log_dir = Path("logs") +log_dir.mkdir(exist_ok=True) +logger.add( + log_dir / "rl_helpers_{time}.log", + rotation="500 MB", + retention="10 days", + 
level="DEBUG", + format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}", +) + nest_asyncio.apply() from typing import Callable, List @@ -158,38 +172,41 @@ def remove_reasoning(text: str) -> str: def run_agent_generations(generate_fn, tokenizer, chat_states): """ Run generation for chat states requiring assistant responses. - - Args: - generate_fn: Function to generate responses - tokenizer: Tokenizer for processing text - chat_states: List of chat states - - Returns: - list: Updated chat states """ + logger.debug(f"Starting generation for {len(chat_states)} chat states") prompts = [] batch_indices = [] # Prepare prompts for chat states needing an assistant response. for idx, chat_state in enumerate(chat_states): if chat_state.get("finished"): + logger.debug(f"Chat state {idx} already finished, skipping") continue if chat_state["messages"][-1]["role"] in ["ipython", "user"]: prompt = apply_chat_template(chat_state, tokenizer=tokenizer)["text"] prompts.append(prompt) batch_indices.append(idx) + logger.debug(f"Added prompt for chat state {idx}") if prompts: + logger.info(f"Generating responses for {len(prompts)} prompts") responses = generate_fn(prompts) for i, idx in enumerate(batch_indices): chat_state = chat_states[idx] - full_response = responses[i].outputs[0].text + response = responses[i] + if hasattr(response, "outputs"): + full_response = response.outputs[0].text + else: + full_response = response assistant_response = full_response.split( "<|start_header_id|>assistant<|end_header_id|>" )[-1] chat_state["messages"].append( {"role": "assistant", "content": assistant_response} ) + logger.debug(f"Added assistant response to chat state {idx}") + else: + logger.debug("No prompts to generate responses for") return chat_states @@ -219,15 +236,11 @@ def check_finished_chats(chat_states): def run_tool_calls(chat_states): """ Execute tool calls found in chat states. - - Args: - chat_states: List of chat states - - Returns: - list: Updated chat states with tool call results """ + logger.debug(f"Running tool calls for {len(chat_states)} chat states") for chat_state in chat_states: if chat_state.get("finished"): + logger.debug("Chat state already finished, skipping tool calls") continue assert ( chat_state["messages"][-1]["role"] == "assistant" @@ -236,15 +249,19 @@ def run_tool_calls(chat_states): assistant_response = chat_state["messages"][-1]["content"] function_calls = extract_json_objects(assistant_response) if len(function_calls) > 1: + logger.warning("Multiple function calls found in assistant response") raise ValueError( "Expected only one function call in assistant response" ) elif len(function_calls) == 1: function_call = function_calls[0] query = function_call["function"]["parameters"]["query"] + logger.info(f"Executing search with query: {query}") results = search(query, return_type=str, results=2) chat_state["messages"].append({"role": "ipython", "content": results}) + logger.debug("Added search results to chat state") except Exception as e: + logger.error(f"Error during tool call: {str(e)}") chat_state["messages"].append( {"role": "system", "content": f"Error during post-processing: {str(e)}"} ) @@ -319,43 +336,49 @@ def run_agent( ): """ Run the agent to completion for a batch of questions. 
- - Args: - generate_fn: Function to generate model responses - tokenizer: Tokenizer for processing text - batch: Batch of data containing questions - max_generations: Maximum number of generation steps - - Returns: - list: Final answers for each question """ + logger.info(f"Starting agent run with {len(questions)} questions") + logger.debug( + f"Max generations: {max_generations}, Max new tokens: {max_new_tokens}" + ) + chat_states = [get_initial_chat(q) for q in questions] # set the initial_prompt length - for chat_state in chat_states: + for i, chat_state in enumerate(chat_states): chat_state["initial_length"] = get_chat_num_tokens(chat_state, tokenizer) + logger.debug(f"Initial length for question {i}: {chat_state['initial_length']}") # agent loop for i in range(max_generations): + logger.info(f"Starting generation step {i+1}/{max_generations}") chat_states = run_agent_generations(generate_fn, tokenizer, chat_states) chat_states = check_finished_chats(chat_states) chat_states = run_tool_calls(chat_states) chat_states = check_exceeded_max_new_tokens( chat_states, max_new_tokens, tokenizer ) + finished_count = sum(1 for state in chat_states if state.get("finished")) + logger.info( + f"Finished {finished_count}/{len(chat_states)} chat states after step {i+1}" + ) + + logger.info("Agent loop completed") + # Process final outputs + logger.debug("Processing final outputs") answers = [] for chat in chat_states: answers.append(chat["messages"][-1]["content"]) + logger.debug(f"Final answer: {chat['messages'][-1]['content'][:100]}...") def split_prompt_assistant(convo_text): marker = "<|start_header_id|>assistant<|end_header_id|>" idx = convo_text.find(marker) if idx == -1: + logger.error("Could not find assistant marker in conversation text") raise ValueError("Could not find assistant marker in conversation text.") return convo_text, "" - # Include the marker in the prompt by slicing up to the end of the marker. prompt = convo_text[: idx + len(marker)] - # The assistant response is everything after the marker. assistant_response = convo_text[idx + len(marker) :] return prompt, assistant_response @@ -363,7 +386,9 @@ def run_agent( apply_chat_template(chat, tokenizer=tokenizer)["text"] for chat in chat_states ] prompt_toks, response_toks, response_masks = [], [], [] - for str_chat in str_chats: + + logger.debug("Processing tokenization") + for i, str_chat in enumerate(str_chats): prompt, response = split_prompt_assistant(str_chat) prompt_toks.append( tokenizer(prompt, add_special_tokens=False, return_tensors="pt")[ @@ -376,12 +401,14 @@ def run_agent( ].squeeze()[:max_new_tokens] ) mask = get_mask(str_chat, tokenizer)[len(prompt_toks[-1]) :][:max_new_tokens] response_masks.append(mask) + logger.debug(f"Processed tokens for chat {i}") final_response_str = [chat["messages"][-1]["content"] for chat in chat_states] full_chat_states = chat_states - agentic_outputs = AgenticOutputs( + + logger.info("Agent run completed successfully") + return AgenticOutputs( prompt_tokens=prompt_toks, response_tokens=response_toks, response_masks=response_masks, @@ -389,40 +416,27 @@ def run_agent( full_chat_states=full_chat_states, ) - return agentic_outputs - # Verification -async def check_correctness(question, student_answer, answer): +async def verify(student_answer: str, question: str, answer: str) -> bool: """ - Calculate reward for a given student answer. + Verify if student's answer matches the correct answer. 
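+ For now this is a plain case-insensitive exact-string comparison; LLM-based grading
+ of free-form answers is handled separately by check_student_answers.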
Args: - question (str): The original question - student_answer (str): The model's answer - answer (str): The ground truth answer + student_answer: The model's answer + question: The original question + answer: The ground truth answer Returns: - float: Reward value (1 for correct, 0 for incorrect) + bool: True if answer is correct, False otherwise """ - # log to "./reward_func.log" - with open("reward_func.log", "a") as f: - f.write("\n" + "==" * 40 + "\n\n") - f.write(f"Question: {question}\n") - f.write(f"Student Answer: {student_answer}\n") - f.write(f"Answer: {answer}\n") - if student_answer.startswith("Error during"): - f.write(f"failed function call") - return 0 - if len(student_answer) < 5: - f.write(f"failed Too short answer\n") - return 0 - else: - f.write(f"last message didn't fail\n") - student_answer_clean = remove_reasoning(student_answer) - is_correct = await verify(student_answer_clean, question, answer) - f.write(f"Is Correct: {is_correct}, so reward is {int(is_correct)}\n") - return 1 if is_correct else 0 + logger.debug(f"Verifying answer for question: {question}") + logger.debug(f"Student answer: {student_answer}") + logger.debug(f"Correct answer: {answer}") + + # Simple string matching for now + # TODO: Implement more sophisticated matching + return student_answer.strip().lower() == answer.strip().lower() def check_student_answers( @@ -435,28 +449,19 @@ def check_student_answers( ) -> List[bool]: """ Evaluates a list of student answers against the true answers using a vLLM generate function. - The function applies the chat template to each prompt before passing it to the generate function. - It also appends the details of each QA pair and the verifier's response to a log file. - - Args: - questions: A list of strings representing the questions. - answers: A list of strings representing the correct answers. - student_answers: A list of strings containing the student's answers. - vllm_generate_func: A function that takes a list of chat-formatted prompt strings and returns a list of generated outputs. - tokenizer: The tokenizer used to apply the chat template. - log_file: Optional; path to the file where the QA pairs and verification responses will be appended. - - Returns: - A list of booleans indicating whether each student's answer is correct. """ + logger.info(f"Checking {len(questions)} student answers") + if not (len(questions) == len(answers) == len(student_answers)): + logger.error( + "Mismatched lengths between questions, answers, and student answers" + ) raise ValueError( "The number of questions, answers, and student answers must be equal." ) prompts = [] for question, answer, student_ans in zip(questions, answers, student_answers): - # Construct the plain text prompt for each QA pair. prompt_text = ( "You are grading a student's answer. For the following question, " "compare the student's answer to the correct answer. Reply with 'Yes' if the student's answer is correct, or 'No' if it is completely incorrect.\n\n" @@ -464,22 +469,30 @@ def check_student_answers( f"Correct Answer: {answer}\n" f"Student Answer: {student_ans}\n" ) - # Apply the chat template to the prompt. 
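+ # Wrap the grading prompt in the chat template before sending it to the generate function.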
formatted_prompt = tokenizer.apply_chat_template( [{"role": "user", "content": prompt_text}], tokenize=False, add_generation_prompt=True, ) prompts.append(formatted_prompt) + logger.debug(f"Created verification prompt for question: {question[:50]}...") - # Get the model responses in batch (each response should ideally be "Yes" or "No") + logger.info("Generating verification responses") responses = vllm_generate_func(prompts) - responses_text = [response.outputs[0].text for response in responses] + responses_text = [] + for response in responses: + if hasattr(response, "outputs"): + responses_text.append(response.outputs[0].text) + else: + responses_text.append(response) + logger.debug(f"Got {len(responses_text)} verification responses") - # Evaluate each response and mark as correct if "yes" appears in the answer (case-insensitive) results = [] for response in responses_text: results.append("yes" in response.lower()) + logger.debug(f"Verification result: {'yes' in response.lower()}") + + logger.info(f"Verification complete. {sum(results)}/{len(results)} answers correct") # Append the QA details and verifier's response to the specified log file with open(log_file, "a") as file: @@ -526,24 +539,100 @@ def reward_formatting(prompts, completions, **reward_kwargs): return [0.7 if not e else 0 for e in has_error] -# def reward_retry_behavior(prompts, completions, **reward_kwargs): -# pass +def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[float]: + """ + Reward function that encourages optimal retry behavior by counting total function calls + across all assistant messages in the conversation. + """ + rewards: list[float] = [] + + for completion in completions: + # Get ALL assistant messages + assistant_msgs: list[str] = [ + msg["content"] + for msg in completion["messages"] + if msg["role"] == "assistant" and msg["content"] is not None + ] + + if not assistant_msgs: + rewards.append(0.0) + continue + + # Count total function calls across all messages + total_retries: int = 0 + for msg in assistant_msgs: + total_retries += len(extract_json_objects(msg)) + + # Calculate reward using modified sigmoid function + x: float = float(total_retries - 4) # Center peak at 4 retries + base_reward: float = 1.0 / (1.0 + np.exp(-x + abs(x) / 2)) + + # Additional penalty for excessive retries + if total_retries > 6: + penalty: float = 0.2 * (total_retries - 6) + base_reward = max(0.1, base_reward - penalty) + rewards.append(base_reward) -# def reward_exact_match_chunk_query(prompts, completions, **reward_kwargs): -# pass + return rewards + + +def reward_exact_match_chunk_query(prompts, completions, **reward_kwargs): + """ + Reward function that checks if the model's search queries hit the correct chunk content. 
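+ A completion earns 1.0 if any of its search results (ipython messages) contains the
+ ground-truth chunk text passed via reward_kwargs["chunk_content"], and 0.0 otherwise.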
+ """ + logger.debug(f"Calculating rewards for {len(prompts)} prompts") + + # Get correct chunk contents from reward kwargs + correct_contents = reward_kwargs.get("chunk_content", []) + if not correct_contents: + logger.error("No chunk_content provided in reward_kwargs") + raise ValueError("chunk_content must be provided in reward_kwargs") + + rewards = [] + for i, (chat_state, correct_content) in enumerate( + zip(completions, correct_contents) + ): + # Get all ipython messages (search results) from the chat + search_results = [ + msg["content"] for msg in chat_state["messages"] if msg["role"] == "ipython" + ] + logger.debug(f"Found {len(search_results)} search results for prompt {i}") + + # Check if any search hit the correct chunk content + found_correct_chunk = False + for result in search_results: + if correct_content in result: + found_correct_chunk = True + logger.debug( + f"Found correct chunk content in search results for prompt {i}" + ) + break + + reward = 1.0 if found_correct_chunk else 0.0 + rewards.append(reward) + logger.debug(f"Reward for prompt {i}: {reward}") + + logger.info(f"Average reward: {sum(rewards)/len(rewards):.3f}") + return rewards def run_eval(generate_fn, verify_fn, tokenizer): + logger.info("Starting evaluation") train_dataset, test_dataset = get_qa_dataset() questions = test_dataset["prompt"] + logger.info(f"Loaded {len(questions)} test questions") + agentic_outputs = run_agent(generate_fn, tokenizer, questions) full_chat_states = agentic_outputs.full_chat_states final_responses = agentic_outputs.final_response_str + + logger.info("Calculating rewards") rewards = verify_fn(questions, full_chat_states, answer=test_dataset["answer"]) + avg_reward = sum(rewards) / len(rewards) - print("RESULTS:") - print("percentage of correct answers:", sum(rewards) / len(rewards)) - print("=" * 30) + logger.info("EVALUATION RESULTS:") + logger.info(f"Percentage of correct answers: {avg_reward:.3f}") + logger.info("=" * 30) return full_chat_states