feat: add new reward functions, add less dumb data generation logic, implement better logging

main
thinhlpg 2 months ago
parent b22b02ea1d
commit 04d56325bb

@@ -0,0 +1,279 @@
"""
This script performs two main tasks:
1. It loads a markdown document, splits it into chunks, generates embeddings,
and builds a FAISS index (which is saved locally).
2. It generates QA pairs from the document using a Llama model.
For each chunk, it generates multiple question-answer pairs with difficulty labels.
The generation is performed in batch, with one retry for failed prompts.
Successfully generated QA pairs are saved to "saved_data/questions.json".
Requirements:
pip install langchain langchain-community faiss-cpu unstructured pandas loguru unsloth vllm
"""
import json
import os
import re
from typing import Dict, List, Optional, Tuple
import pandas as pd
from langchain.text_splitter import RecursiveCharacterTextSplitter
from loguru import logger
# Configure logger
logger.add(
"logs/generate_data_{time}.log",
rotation="500 MB",
retention="10 days",
level="INFO",
)
# ========= Part 1: Document Processing and Embedding Generation =========
# Load and split the markdown document using LangChain
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_community.vectorstores import FAISS
from embeddings import CustomHuggingFaceEmbeddings
# Load your markdown file (adjust the path as needed)
loader = UnstructuredMarkdownLoader("./data/mission_report.md")
docs = loader.load()
# Split the document into smaller chunks (each 1000 characters, no overlap)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
chunks = text_splitter.split_documents(docs)
# Create output directory
os.makedirs("saved_data", exist_ok=True)
# Save chunks to CSV for easy inspection
chunks_df = pd.DataFrame(
{
"chunk_id": range(1, len(chunks) + 1),
"content": [chunk.page_content for chunk in chunks],
"metadata": [chunk.metadata for chunk in chunks],
}
)
chunks_df.to_csv("saved_data/chunks.csv", index=False)
print(f"Saved {len(chunks)} chunks to saved_data/chunks.csv")
embeddings = CustomHuggingFaceEmbeddings()
# Create a FAISS vector store from the document chunks and save it locally
vectorstore = FAISS.from_documents(chunks, embeddings)
vectorstore.save_local("faiss_index")
print("Saved FAISS index to 'faiss_index'")
# TODO: add the paraphrased chunks to the vector store
# ========= Part 2: QA Generation using Llama Backend =========
# Setup Llama backend via unsloth and vLLM
from unsloth import FastLanguageModel
from vllm import SamplingParams
import rl_helpers  # Local helper module; remove this import if it is unused
# Load the Llama model (adjust parameters as needed)
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="meta-llama/meta-Llama-3.1-8B-Instruct",
max_seq_length=4096,
load_in_4bit=True, # Use 4-bit quantization if desired
fast_inference=True, # Enable fast inference
gpu_memory_utilization=0.6, # Adjust based on your GPU memory
)
# Define sampling parameters for generation
sampling_params = SamplingParams(
temperature=0.3,
top_p=0.95,
max_tokens=4096,
)
def batch_generate(prompts: List[str]) -> List[str]:
"""
Given a list of prompt strings, returns a list of generated outputs.
"""
def format_input(text: str) -> str:
return tokenizer.apply_chat_template(
[{"role": "user", "content": text}],
tokenize=False,
add_generation_prompt=True,
)
formatted = [format_input(p) for p in prompts]
outputs = model.fast_generate(formatted, sampling_params=sampling_params)
return [output.outputs[0].text for output in outputs]
def parse_qa_block(block: str) -> Optional[Tuple[str, str, str]]:
"""
Parses a QA block that should contain exactly three non-empty lines:
- A line starting with "Question:"
- A line starting with "Answer:"
- A line starting with "Difficulty:"
If the markers are not present but the block contains exactly three lines,
those are used in order.
Returns a tuple (question, answer, difficulty) or None if parsing fails.
"""
lines = [line.strip() for line in block.splitlines() if line.strip()]
if not lines:
return None
question, answer, difficulty = None, None, None
for line in lines:
lower = line.lower()
if question is None and lower.startswith("question:"):
question = line[len("question:") :].strip()
elif answer is None and lower.startswith("answer:"):
answer = line[len("answer:") :].strip()
elif difficulty is None and lower.startswith("difficulty:"):
difficulty = line[len("difficulty:") :].strip()
if question and answer and difficulty:
return question, answer, difficulty
if len(lines) == 3:
return lines[0], lines[1], lines[2]
return None
def parse_multiple_qa_output(output: str) -> List[Tuple[str, str, str]]:
"""
Splits the output into blocks (separated by one or more blank lines) and
attempts to parse each as a QA pair.
Returns a list of successfully parsed QA tuples.
"""
blocks = re.split(r"\n\s*\n", output.strip())
qa_pairs = []
for block in blocks:
parsed = parse_qa_block(block)
if parsed:
qa_pairs.append(parsed)
return qa_pairs
def generate_question_batch_for_chunks(
chunks: List, num_questions: int = 2, difficulty: Optional[str] = None
) -> List[Dict]:
"""
Generates QA pairs for multiple chunks in batch.
For each chunk, generates questions based on its content only.
Each prompt instructs the model to output exactly three lines per QA pair with markers.
Failed prompts are retried once in batch; if still unsuccessful, they are skipped.
Returns a list of dicts with keys: "chunk_id", "question", "answer", "difficulty", "chunk_content".
"""
prompts = []
chunk_ids = []
chunk_contents = []
# Prepare prompts for each chunk
for i, chunk in enumerate(chunks):
current = chunk.page_content
prompt = (
f"You are a question generator. Generate {num_questions} questions based on the following text.\n"
"Rules:\n"
"1. Questions must be answerable using ONLY the information in the text\n"
"2. Answers must be directly stated in the text\n"
"3. Each question should test understanding of a different aspect of the text\n"
"4. Questions should be clear and specific\n"
"5. Answers should be concise and factual\n\n"
"For each QA pair, output exactly three lines with no extra commentary:\n"
"Line 1: Question: <your question>\n"
"Line 2: Answer: <the answer>\n"
"Line 3: Difficulty: <easy, medium, or hard>\n"
"Do not include any additional text.\n\n"
"Text:\n"
f"{current}\n"
)
prompts.append(prompt)
chunk_ids.append(i + 1) # 1-based indexing
chunk_contents.append(current)
# First batch generation
outputs = batch_generate(prompts)
results: List[Optional[List[Tuple[str, str, str]]]] = [None] * len(outputs)
failed_indices = []
# Parse each output
for idx, output in enumerate(outputs):
qa_pairs = parse_multiple_qa_output(output)
if qa_pairs is None or len(qa_pairs) < num_questions:
failed_indices.append(idx)
logger.warning(f"Failed to generate enough QA pairs for chunk {idx + 1}")
else:
# Validate that answers exist in chunk content
valid_pairs = []
for q, a, d in qa_pairs:
if a.lower() in chunk_contents[idx].lower():
valid_pairs.append((q, a, d))
else:
logger.warning(f"Answer not found in chunk content: {a}")
if len(valid_pairs) >= num_questions:
results[idx] = valid_pairs[:num_questions]
else:
failed_indices.append(idx)
logger.warning(f"Not enough valid QA pairs for chunk {idx + 1}")
# Retry failed prompts in batch
if failed_indices:
logger.info(f"Retrying {len(failed_indices)} failed prompt(s)...")
retry_prompts = [prompts[i] for i in failed_indices]
retry_outputs = batch_generate(retry_prompts)
for j, idx in enumerate(failed_indices):
qa_pairs = parse_multiple_qa_output(retry_outputs[j])
if qa_pairs is not None and len(qa_pairs) >= num_questions:
# Validate answers again
valid_pairs = []
for q, a, d in qa_pairs:
if a.lower() in chunk_contents[idx].lower():
valid_pairs.append((q, a, d))
if len(valid_pairs) >= num_questions:
results[idx] = valid_pairs[:num_questions]
else:
results[idx] = None
logger.warning(
f"Retry failed for chunk {idx + 1}: not enough valid QA pairs"
)
else:
results[idx] = None
logger.warning(f"Retry failed for chunk {idx + 1}: parsing failed")
# Build final output, skipping prompts that failed even after retry
final_questions = []
for i, qa_list in enumerate(results):
if qa_list is not None:
for qa in qa_list:
final_questions.append(
{
"chunk_id": chunk_ids[i],
"question": qa[0],
"answer": qa[1],
"difficulty": qa[2],
"chunk_content": chunk_contents[i],
}
)
logger.info(f"Generated {len(final_questions)} valid QA pairs")
return final_questions
# Generate QA pairs in batch (one prompt per chunk; the difficulty argument is currently unused)
all_questions = generate_question_batch_for_chunks(
chunks, num_questions=2, difficulty="medium"
)
print(f"Generated {len(all_questions)} QA pairs.")
# Save the QA pairs to a JSON file
questions_path = os.path.join("saved_data", "questions.json")
with open(questions_path, "w") as f:
json.dump(all_questions, f, indent=2)
print(f"Saved questions to {questions_path}")

@@ -0,0 +1,113 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install matplotlib -q"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"def plot_reward_functions():\n",
" # Generate retry counts from 0 to 15\n",
" retries = np.linspace(0, 15, 100)\n",
" \n",
" # 1. Basic Sigmoid\n",
" basic_sigmoid = 1 / (1 + np.exp(-(retries - 4)))\n",
" \n",
" # 2. Our Modified Sigmoid\n",
" x = retries - 4 # Center at 4 retries\n",
" modified_sigmoid = 1 / (1 + np.exp(-x + abs(x)/2))\n",
" \n",
" # 3. With Penalty\n",
" penalized_reward = modified_sigmoid.copy()\n",
" for i, r in enumerate(retries):\n",
" if r > 6:\n",
" penalty = 0.2 * (r - 6)\n",
" penalized_reward[i] = max(0.1, modified_sigmoid[i] - penalty)\n",
" \n",
" # Plotting\n",
" plt.figure(figsize=(12, 6))\n",
" \n",
" plt.plot(retries, basic_sigmoid, 'b--', label='Basic Sigmoid')\n",
" plt.plot(retries, modified_sigmoid, 'g--', label='Modified Sigmoid')\n",
" plt.plot(retries, penalized_reward, 'r-', label='Final Reward (with penalty)', linewidth=2)\n",
" \n",
" # Add vertical lines for key points\n",
" plt.axvline(x=4, color='gray', linestyle=':', alpha=0.5, label='Peak (4 retries)')\n",
" plt.axvline(x=6, color='gray', linestyle=':', alpha=0.5, label='Penalty Start (6 retries)')\n",
" \n",
" plt.grid(True, alpha=0.3)\n",
" plt.xlabel('Number of Retries')\n",
" plt.ylabel('Reward')\n",
" plt.title('Reward Function Visualization')\n",
" plt.legend()\n",
" plt.ylim(-0.1, 1.1)\n",
" \n",
" # Add annotations\n",
" plt.annotate('Optimal Zone', xy=(4, 0.8), xytext=(4, 0.9),\n",
" ha='center', va='bottom',\n",
" bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.3),\n",
" arrowprops=dict(arrowstyle='->'))\n",
" \n",
" plt.annotate('Penalty Zone', xy=(8, 0.3), xytext=(8, 0.5),\n",
" ha='center', va='bottom',\n",
" bbox=dict(boxstyle='round,pad=0.5', fc='red', alpha=0.3),\n",
" arrowprops=dict(arrowstyle='->'))\n",
" \n",
" plt.show()\n",
"\n",
"# Run the visualization\n",
"plot_reward_functions()\n",
"\n",
"# Print reward values for specific retry counts\n",
"def print_reward_examples():\n",
" retry_examples = [1, 2, 3, 4, 5, 6, 7, 8, 10, 12]\n",
" print(\"\\nReward values for different retry counts:\")\n",
" print(\"Retries | Reward\")\n",
" print(\"-\" * 20)\n",
" \n",
" for retries in retry_examples:\n",
" x = retries - 4\n",
" reward = 1 / (1 + np.exp(-x + abs(x)/2))\n",
" if retries > 6:\n",
" penalty = 0.2 * (retries - 6)\n",
" reward = max(0.1, reward - penalty)\n",
" print(f\"{retries:7d} | {reward:.3f}\")\n",
"\n",
"print_reward_examples()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

@@ -9,12 +9,26 @@ import json
 import re
 from dataclasses import dataclass
 from datetime import datetime
+from pathlib import Path

 import nest_asyncio
+import numpy as np
 import torch
+from loguru import logger

 from search_module import get_qa_dataset, search

+# Setup loguru
+log_dir = Path("logs")
+log_dir.mkdir(exist_ok=True)
+logger.add(
+    log_dir / "rl_helpers_{time}.log",
+    rotation="500 MB",
+    retention="10 days",
+    level="DEBUG",
+    format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}",
+)
+
 nest_asyncio.apply()

 from typing import Callable, List
@@ -158,38 +172,41 @@ def remove_reasoning(text: str) -> str:
 def run_agent_generations(generate_fn, tokenizer, chat_states):
     """
     Run generation for chat states requiring assistant responses.
-
-    Args:
-        generate_fn: Function to generate responses
-        tokenizer: Tokenizer for processing text
-        chat_states: List of chat states
-
-    Returns:
-        list: Updated chat states
     """
+    logger.debug(f"Starting generation for {len(chat_states)} chat states")
     prompts = []
     batch_indices = []
     # Prepare prompts for chat states needing an assistant response.
     for idx, chat_state in enumerate(chat_states):
         if chat_state.get("finished"):
+            logger.debug(f"Chat state {idx} already finished, skipping")
             continue
         if chat_state["messages"][-1]["role"] in ["ipython", "user"]:
             prompt = apply_chat_template(chat_state, tokenizer=tokenizer)["text"]
             prompts.append(prompt)
             batch_indices.append(idx)
+            logger.debug(f"Added prompt for chat state {idx}")
     if prompts:
+        logger.info(f"Generating responses for {len(prompts)} prompts")
         responses = generate_fn(prompts)
         for i, idx in enumerate(batch_indices):
             chat_state = chat_states[idx]
-            full_response = responses[i].outputs[0].text
+            response = responses[i]
+            if hasattr(response, "outputs"):
+                full_response = response.outputs[0].text
+            else:
+                full_response = response
             assistant_response = full_response.split(
                 "<|start_header_id|>assistant<|end_header_id|>"
             )[-1]
             chat_state["messages"].append(
                 {"role": "assistant", "content": assistant_response}
             )
+            logger.debug(f"Added assistant response to chat state {idx}")
+    else:
+        logger.debug("No prompts to generate responses for")
     return chat_states
@@ -219,15 +236,11 @@ def check_finished_chats(chat_states):
 def run_tool_calls(chat_states):
     """
     Execute tool calls found in chat states.
-
-    Args:
-        chat_states: List of chat states
-
-    Returns:
-        list: Updated chat states with tool call results
     """
+    logger.debug(f"Running tool calls for {len(chat_states)} chat states")
     for chat_state in chat_states:
         if chat_state.get("finished"):
+            logger.debug("Chat state already finished, skipping tool calls")
             continue
         assert (
             chat_state["messages"][-1]["role"] == "assistant"
@@ -236,15 +249,19 @@ def run_tool_calls(chat_states):
             assistant_response = chat_state["messages"][-1]["content"]
             function_calls = extract_json_objects(assistant_response)
             if len(function_calls) > 1:
+                logger.warning("Multiple function calls found in assistant response")
                 raise ValueError(
                     "Expected only one function call in assistant response"
                 )
             elif len(function_calls) == 1:
                 function_call = function_calls[0]
                 query = function_call["function"]["parameters"]["query"]
+                logger.info(f"Executing search with query: {query}")
                 results = search(query, return_type=str, results=2)
                 chat_state["messages"].append({"role": "ipython", "content": results})
+                logger.debug("Added search results to chat state")
         except Exception as e:
+            logger.error(f"Error during tool call: {str(e)}")
             chat_state["messages"].append(
                 {"role": "system", "content": f"Error during post-processing: {str(e)}"}
             )
@@ -319,43 +336,49 @@
 ):
     """
     Run the agent to completion for a batch of questions.
-    Args:
-        generate_fn: Function to generate model responses
-        tokenizer: Tokenizer for processing text
-        batch: Batch of data containing questions
-        max_generations: Maximum number of generation steps
-    Returns:
-        list: Final answers for each question
     """
+    logger.info(f"Starting agent run with {len(questions)} questions")
+    logger.debug(
+        f"Max generations: {max_generations}, Max new tokens: {max_new_tokens}"
+    )
     chat_states = [get_initial_chat(q) for q in questions]
     # set the initial_prompt length
-    for chat_state in chat_states:
+    for i, chat_state in enumerate(chat_states):
         chat_state["initial_length"] = get_chat_num_tokens(chat_state, tokenizer)
+        logger.debug(f"Initial length for question {i}: {chat_state['initial_length']}")
     # agent loop
     for i in range(max_generations):
+        logger.info(f"Starting generation step {i+1}/{max_generations}")
         chat_states = run_agent_generations(generate_fn, tokenizer, chat_states)
         chat_states = check_finished_chats(chat_states)
         chat_states = run_tool_calls(chat_states)
         chat_states = check_exceeded_max_new_tokens(
             chat_states, max_new_tokens, tokenizer
         )
+        finished_count = sum(1 for state in chat_states if state.get("finished"))
+        logger.info(
+            f"Finished {finished_count}/{len(chat_states)} chat states after step {i+1}"
+        )
+    logger.info("Agent run completed")
+
+    # Process final outputs
+    logger.debug("Processing final outputs")
     answers = []
     for chat in chat_states:
         answers.append(chat["messages"][-1]["content"])
+        logger.debug(f"Final answer: {chat['messages'][-1]['content'][:100]}...")

     def split_prompt_assistant(convo_text):
         marker = "<|start_header_id|>assistant<|end_header_id|>"
         idx = convo_text.find(marker)
         if idx == -1:
+            logger.error("Could not find assistant marker in conversation text")
             raise ValueError("Could not find assistant marker in conversation text.")
             return convo_text, ""
-        # Include the marker in the prompt by slicing up to the end of the marker.
         prompt = convo_text[: idx + len(marker)]
-        # The assistant response is everything after the marker.
         assistant_response = convo_text[idx + len(marker) :]
         return prompt, assistant_response
@@ -363,7 +386,9 @@ def run_agent(
         apply_chat_template(chat, tokenizer=tokenizer)["text"] for chat in chat_states
     ]
     prompt_toks, response_toks, response_masks = [], [], []
-    for str_chat in str_chats:
+
+    logger.debug("Processing tokenization")
+    for i, str_chat in enumerate(str_chats):
         prompt, response = split_prompt_assistant(str_chat)
         prompt_toks.append(
             tokenizer(prompt, add_special_tokens=False, return_tensors="pt")[
@@ -376,12 +401,14 @@ def run_agent(
             ].squeeze()[:max_new_tokens]
         )
         mask = get_mask(str_chat, tokenizer)[len(prompt_toks[-1]) :][:max_new_tokens]
         response_masks.append(mask)
+        logger.debug(f"Processed tokens for chat {i}")

     final_response_str = [chat["messages"][-1]["content"] for chat in chat_states]
     full_chat_states = chat_states
-    agentic_outputs = AgenticOutputs(
+    logger.info("Agent run completed successfully")
+    return AgenticOutputs(
         prompt_tokens=prompt_toks,
         response_tokens=response_toks,
         response_masks=response_masks,
@@ -389,40 +416,27 @@ def run_agent(
         full_chat_states=full_chat_states,
     )
-    return agentic_outputs


 # Verification
-async def check_correctness(question, student_answer, answer):
+async def verify(student_answer: str, question: str, answer: str) -> bool:
     """
-    Calculate reward for a given student answer.
+    Verify if student's answer matches the correct answer.

     Args:
-        question (str): The original question
-        student_answer (str): The model's answer
-        answer (str): The ground truth answer
+        student_answer: The model's answer
+        question: The original question
+        answer: The ground truth answer

     Returns:
-        float: Reward value (1 for correct, 0 for incorrect)
+        bool: True if answer is correct, False otherwise
     """
-    # log to "./reward_func.log"
-    with open("reward_func.log", "a") as f:
-        f.write("\n" + "==" * 40 + "\n\n")
-        f.write(f"Question: {question}\n")
-        f.write(f"Student Answer: {student_answer}\n")
-        f.write(f"Answer: {answer}\n")
-        if student_answer.startswith("Error during"):
-            f.write(f"failed function call")
-            return 0
-        if len(student_answer) < 5:
-            f.write(f"failed Too short answer\n")
-            return 0
-        else:
-            f.write(f"last message didn't fail\n")
-        student_answer_clean = remove_reasoning(student_answer)
-        is_correct = await verify(student_answer_clean, question, answer)
-        f.write(f"Is Correct: {is_correct}, so reward is {int(is_correct)}\n")
-        return 1 if is_correct else 0
+    logger.debug(f"Verifying answer for question: {question}")
+    logger.debug(f"Student answer: {student_answer}")
+    logger.debug(f"Correct answer: {answer}")
+
+    # Simple string matching for now
+    # TODO: Implement more sophisticated matching
+    return student_answer.strip().lower() == answer.strip().lower()


 def check_student_answers(
@@ -435,28 +449,19 @@ def check_student_answers(
 ) -> List[bool]:
     """
     Evaluates a list of student answers against the true answers using a vLLM generate function.
-    The function applies the chat template to each prompt before passing it to the generate function.
-    It also appends the details of each QA pair and the verifier's response to a log file.
-
-    Args:
-        questions: A list of strings representing the questions.
-        answers: A list of strings representing the correct answers.
-        student_answers: A list of strings containing the student's answers.
-        vllm_generate_func: A function that takes a list of chat-formatted prompt strings and returns a list of generated outputs.
-        tokenizer: The tokenizer used to apply the chat template.
-        log_file: Optional; path to the file where the QA pairs and verification responses will be appended.
-
-    Returns:
-        A list of booleans indicating whether each student's answer is correct.
     """
+    logger.info(f"Checking {len(questions)} student answers")

     if not (len(questions) == len(answers) == len(student_answers)):
+        logger.error(
+            "Mismatched lengths between questions, answers, and student answers"
+        )
         raise ValueError(
             "The number of questions, answers, and student answers must be equal."
         )

     prompts = []
     for question, answer, student_ans in zip(questions, answers, student_answers):
-        # Construct the plain text prompt for each QA pair.
         prompt_text = (
             "You are grading a student's answer. For the following question, "
             "compare the student's answer to the correct answer. Reply with 'Yes' if the student's answer is correct, or 'No' if it is completely incorrect.\n\n"
@@ -464,22 +469,30 @@ def check_student_answers(
             f"Correct Answer: {answer}\n"
             f"Student Answer: {student_ans}\n"
         )
-        # Apply the chat template to the prompt.
         formatted_prompt = tokenizer.apply_chat_template(
             [{"role": "user", "content": prompt_text}],
             tokenize=False,
             add_generation_prompt=True,
         )
         prompts.append(formatted_prompt)
+        logger.debug(f"Created verification prompt for question: {question[:50]}...")

-    # Get the model responses in batch (each response should ideally be "Yes" or "No")
+    logger.info("Generating verification responses")
     responses = vllm_generate_func(prompts)
-    responses_text = [response.outputs[0].text for response in responses]
+    responses_text = []
+    for response in responses:
+        if hasattr(response, "outputs"):
+            responses_text.append(response.outputs[0].text)
+        else:
+            responses_text.append(response)
+    logger.debug(f"Got {len(responses_text)} verification responses")

-    # Evaluate each response and mark as correct if "yes" appears in the answer (case-insensitive)
     results = []
     for response in responses_text:
         results.append("yes" in response.lower())
+        logger.debug(f"Verification result: {'yes' in response.lower()}")
+    logger.info(f"Verification complete. {sum(results)}/{len(results)} answers correct")

     # Append the QA details and verifier's response to the specified log file
     with open(log_file, "a") as file:
@@ -526,24 +539,100 @@ def reward_formatting(prompts, completions, **reward_kwargs):
     return [0.7 if not e else 0 for e in has_error]


-# def reward_retry_behavior(prompts, completions, **reward_kwargs):
-#     pass
+def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[float]:
+    """
+    Reward function that encourages optimal retry behavior by counting total function calls
+    across all assistant messages in the conversation.
+    """
+    rewards: list[float] = []
+
+    for completion in completions:
+        # Get ALL assistant messages
+        assistant_msgs: list[str] = [
+            msg["content"]
+            for msg in completion["messages"]
+            if msg["role"] == "assistant" and msg["content"] is not None
+        ]
+
+        if not assistant_msgs:
+            rewards.append(0.0)
+            continue
+
+        # Count total function calls across all messages
+        total_retries: int = 0
+        for msg in assistant_msgs:
+            total_retries += len(extract_json_objects(msg))
+
+        # Calculate reward using modified sigmoid function
+        x: float = float(total_retries - 4)  # Center peak at 4 retries
+        base_reward: float = 1.0 / (1.0 + np.exp(-x + abs(x) / 2))
+
+        # Additional penalty for excessive retries
+        if total_retries > 6:
+            penalty: float = 0.2 * (total_retries - 6)
+            base_reward = max(0.1, base_reward - penalty)
+
+        rewards.append(base_reward)
+
+    return rewards


-# def reward_exact_match_chunk_query(prompts, completions, **reward_kwargs):
-#     pass
+def reward_exact_match_chunk_query(prompts, completions, **reward_kwargs):
+    """
+    Reward function that checks if the model's search queries hit the correct chunk content.
+    """
+    logger.debug(f"Calculating rewards for {len(prompts)} prompts")
+
+    # Get correct chunk contents from reward kwargs
+    correct_contents = reward_kwargs.get("chunk_content", [])
+    if not correct_contents:
+        logger.error("No chunk_content provided in reward_kwargs")
+        raise ValueError("chunk_content must be provided in reward_kwargs")
+
+    rewards = []
+    for i, (chat_state, correct_content) in enumerate(
+        zip(completions, correct_contents)
+    ):
+        # Get all ipython messages (search results) from the chat
+        search_results = [
+            msg["content"] for msg in chat_state["messages"] if msg["role"] == "ipython"
+        ]
+        logger.debug(f"Found {len(search_results)} search results for prompt {i}")
+
+        # Check if any search hit the correct chunk content
+        found_correct_chunk = False
+        for result in search_results:
+            if correct_content in result:
+                found_correct_chunk = True
+                logger.debug(
+                    f"Found correct chunk content in search results for prompt {i}"
+                )
+                break
+
+        reward = 1.0 if found_correct_chunk else 0.0
+        rewards.append(reward)
+        logger.debug(f"Reward for prompt {i}: {reward}")
+
+    logger.info(f"Average reward: {sum(rewards)/len(rewards):.3f}")
+    return rewards


 def run_eval(generate_fn, verify_fn, tokenizer):
+    logger.info("Starting evaluation")
     train_dataset, test_dataset = get_qa_dataset()
     questions = test_dataset["prompt"]
+    logger.info(f"Loaded {len(questions)} test questions")
+
     agentic_outputs = run_agent(generate_fn, tokenizer, questions)
     full_chat_states = agentic_outputs.full_chat_states
     final_responses = agentic_outputs.final_response_str
+
+    logger.info("Calculating rewards")
     rewards = verify_fn(questions, full_chat_states, answer=test_dataset["answer"])
+    avg_reward = sum(rewards) / len(rewards)

-    print("RESULTS:")
-    print("percentage of correct answers:", sum(rewards) / len(rewards))
-    print("=" * 30)
+    logger.info("EVALUATION RESULTS:")
+    logger.info(f"Percentage of correct answers: {avg_reward:.3f}")
+    logger.info("=" * 30)
+
     return full_chat_states
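
For intuition about the shape of the new retry reward, here is a standalone sketch (not part of the commit) that re-implements the same formula used in reward_retry_behavior and prints it for a few call counts; only numpy is assumed.

# Sketch (not part of this commit): the retry-reward curve in isolation.
import numpy as np

def toy_retry_reward(total_retries: int) -> float:
    # Modified sigmoid centered around 4 tool calls, as in reward_retry_behavior.
    x = float(total_retries - 4)
    reward = 1.0 / (1.0 + np.exp(-x + abs(x) / 2))
    # Linear penalty once the count exceeds 6, floored at 0.1.
    if total_retries > 6:
        reward = max(0.1, reward - 0.2 * (total_retries - 6))
    return reward

for n in [0, 2, 4, 6, 8, 10]:
    print(f"{n:2d} retries -> reward {toy_retry_reward(n):.3f}")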
