"""
|
|
Evaluation utilities for RL training.
|
|
"""
|
|
|
|
import inspect
|
|
from datetime import datetime
|
|
|
|
from config import DATA_DIR, logger
|
|
from src.agent import Agent
|
|
from src.search_module import get_qa_dataset
|
|
from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
|
|
|
|
|
|
async def verify(student_answer: str, question: str, answer: str) -> bool:
    """
    Verify if the student's answer matches the correct answer.

    Args:
        student_answer: The model's answer
        question: The original question
        answer: The ground truth answer

    Returns:
        bool: True if the answer is correct, False otherwise
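
    Example (illustrative sketch; the comparison is the case-insensitive
    exact match implemented below):
        result = await verify("Paris", "What is the capital of France?", "paris")
        # result is True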
    """
    logger.debug(f"Verifying answer for question: {question}")
    logger.debug(f"Student answer: {student_answer}")
    logger.debug(f"Correct answer: {answer}")

    # Simple string matching for now
    # TODO: Implement more sophisticated matching
    return student_answer.strip().lower() == answer.strip().lower()


def check_student_answers(
    questions: list[str],
    answers: list[str],
    student_answers: list,  # Can be strings or dicts
    vllm_generate_func,
    tokenizer,
    log_file=None,
) -> list[bool]:
    """
    Evaluates a list of student answers against the true answers using a vLLM generate function.

    Args:
        questions: List of questions
        answers: List of correct answers
        student_answers: List of student answers to evaluate
        vllm_generate_func: Function to generate verification responses
        tokenizer: Tokenizer for formatting prompts
        log_file: Optional path to write detailed results

    Returns:
        List of boolean results (True for correct answers)
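
    Example (illustrative sketch; my_generate_fn and my_tokenizer are
    placeholders for a real vLLM generate function and chat tokenizer):
        results = check_student_answers(
            questions=["What is 2 + 2?"],
            answers=["4"],
            student_answers=["The answer is 4."],
            vllm_generate_func=my_generate_fn,  # placeholder
            tokenizer=my_tokenizer,  # placeholder
        )
        # results -> e.g. [True] if the judge model replies "Yes"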
    """
    logger.info(f"Checking {len(questions)} student answers")

    if not (len(questions) == len(answers) == len(student_answers)):
        logger.error("Mismatched lengths between questions, answers, and student answers")
        raise ValueError("The number of questions, answers, and student answers must be equal.")

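    # Build one grading prompt per (question, correct answer, student answer)
    # triple and format it with the tokenizer's chat template as a single
    # user turn for the judge model.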
    prompts = []
    for question, answer, student_ans in zip(questions, answers, student_answers):
        prompt_text = (
            "You are grading a student's answer to a question. For the following question, "
            "compare the student's answer to the correct answer. Reply with 'Yes' if the student's answer contains the correct information, "
            "even if it's not an exact match. If the student's answer doesn't contain the right information or is completely incorrect, reply with 'No'.\n\n"
            f"Question: {question}\n"
            f"Correct Answer: {answer}\n"
            f"Student Answer: {student_ans}\n\n"
            "Your response should be just 'Yes' or 'No'."
        )

        formatted_prompt = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt_text}],
            tokenize=False,
            add_generation_prompt=True,
        )
        prompts.append(formatted_prompt)
        logger.debug(f"Created verification prompt for question: {question[:50]}...")

    logger.info("Generating verification responses")
    responses = vllm_generate_func(prompts)
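    # Normalize responses to plain strings: vLLM request outputs expose the
    # generated text at response.outputs[0].text, while simpler generate
    # functions may already return strings.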
    responses_text = []
    for response in responses:
        # Handle different response formats
        if hasattr(response, "outputs"):
            try:
                responses_text.append(response.outputs[0].text)
            except (AttributeError, IndexError):
                # Fallback for simple string responses
                responses_text.append(str(response))
        else:
            responses_text.append(str(response))
    logger.debug(f"Got {len(responses_text)} verification responses")

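    # The judge is instructed to answer 'Yes' or 'No'; grading is lenient and
    # counts any response containing "yes" (case-insensitive) as correct.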
    results = []
    for response in responses_text:
        results.append("yes" in response.lower())
        logger.debug(f"Verification result: {'yes' in response.lower()}")

    logger.info(f"Verification complete. {sum(results)}/{len(results)} answers correct")

    # Append the QA details and verifier's response to the specified log file
    if log_file:
        with open(log_file, "a") as file:
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            file.write(f"\n📝 === QA Evaluation at {timestamp} ===\n")
            file.write(f"📂 File: {__file__}\n")

            # Get current frame info safely
            frame = inspect.currentframe()
            if frame:
                file.write(f"📍 Line: {frame.f_lineno}\n")
                # Don't forget to delete the frame to avoid reference cycles
                del frame

            file.write("=" * 80 + "\n")

            for i, (question, answer, student_ans, verifier_response) in enumerate(
                zip(questions, answers, student_answers, responses_text)
            ):
                file.write(f"\n❓ Question {i + 1}:\n")
                file.write("-" * 40 + "\n")
                file.write(f"📋 Question: {question}\n")
                file.write(f"✅ Correct Answer: {answer}\n")
                file.write(f"👨‍🎓 Student Answer: {student_ans}\n")
                file.write(f"🔍 Verifier said: {verifier_response}\n")

                # Add search results if available in the chat state
                if isinstance(student_ans, dict) and "messages" in student_ans:
                    # Get messages from dict
                    messages = student_ans.get("messages", [])
                    search_results = [msg.get("content", "") for msg in messages if msg.get("role") == "ipython"]
                    if search_results:
                        file.write("\n🔎 Search Results:\n")
                        for j, result in enumerate(search_results, 1):
                            file.write(f"\nSearch {j}:\n{result}\n")

                file.write("-" * 40 + "\n")

            file.write(
                f"\n📊 Summary: {sum(results)}/{len(results)} answers correct ({sum(results) / len(results) * 100:.2f}%)\n"
            )
            file.write("=" * 80 + "\n\n")

    return results


def run_eval(
    generate_fn, verify_fn, tokenizer, max_generations=32, max_new_tokens=4096 * 6, output_file=None, debug_file=None
):
    """
    Run evaluation on the test dataset and return results.

    Args:
        generate_fn: Function to generate completions
        verify_fn: Function to verify results
        tokenizer: Tokenizer for processing text
        max_generations: Maximum number of generation rounds for the agent
        max_new_tokens: Maximum number of new tokens per generation
        output_file: Path to save evaluation results summary
        debug_file: Path to save detailed debug information

    Returns:
        full_chat_states: The chat states from evaluation
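
    Example (illustrative sketch; my_generate_fn, my_verify_fn and
    my_tokenizer are placeholders supplied by the caller):
        chat_states = run_eval(
            generate_fn=my_generate_fn,  # placeholder
            verify_fn=my_verify_fn,  # placeholder
            tokenizer=my_tokenizer,  # placeholder
            output_file="eval_results.txt",
            debug_file="eval_debug.json",
        )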
    """
    train_dataset, test_dataset = get_qa_dataset(
        randomize=False,
        test_size=1,
        questions_path=DATA_DIR / "processed" / "questions_dev.jsonl",
    )
    questions = test_dataset["prompt"]

    # Create agent with appropriate adapter based on tokenizer
    tokenizer_name = tokenizer.name_or_path.lower()
    if "deepseek-ai/deepseek-r1-distill" in tokenizer_name:
        adapter = R1DistilTokenizerAdapter()
    elif "llama" in tokenizer_name:
        adapter = LlamaTokenizerAdapter()
    else:
        # Fall back to the R1-Distil adapter for unrecognized tokenizers
        adapter = R1DistilTokenizerAdapter()

    agent = Agent(adapter)
    agentic_outputs = agent.run_agent(generate_fn, tokenizer, questions, max_generations, max_new_tokens)
    full_chat_states = agentic_outputs.full_chat_states
    final_responses = agentic_outputs.final_response_str
    rewards = verify_fn(questions, full_chat_states, answer=test_dataset["answer"])

    # Calculate results
    percent_correct = sum(rewards) / len(rewards) * 100

    # Log results to console
    logger.info("RESULTS:")
    logger.info(f"Percentage of correct answers: {percent_correct:.2f}%")
    logger.info("=" * 30)

    # Save results to file if specified
    if output_file:
        try:
            with open(output_file, "w") as f:
                f.write("EVALUATION RESULTS\n")
                f.write("=================\n\n")
                f.write(f"Total questions: {len(questions)}\n")
                f.write(f"Correct answers: {sum(rewards)}\n")
                f.write(f"Percentage correct: {percent_correct:.2f}%\n\n")

                f.write("Individual results:\n")
                for i, (q, r, resp) in enumerate(zip(questions, rewards, final_responses)):
                    f.write(f"\nQ{i + 1}: {q[:100]}...\n")
                    f.write(f"Correct: {'✓' if r else '✗'}\n")
                    f.write(f"Response: {resp[:150]}...\n")
                    f.write("-" * 40 + "\n")
            logger.info(f"Saved evaluation results to {output_file}")
        except Exception as e:
            logger.error(f"Error saving results file: {e}")

    # Save debug information if specified
    if debug_file:
        try:
            import json

            debug_data = []
            for i, (q, r, resp, chat) in enumerate(zip(questions, rewards, final_responses, full_chat_states)):
                debug_data.append(
                    {
                        "question_id": i,
                        "question": q,
                        "is_correct": bool(r),
                        "final_response": resp,
                        # Stringify list/dict values so the chat state stays
                        # JSON-serializable, and drop the tokenizer object
                        "chat_state": {
                            k: str(v) if isinstance(v, (list, dict)) else v
                            for k, v in chat.items()
                            if k != "tokenizer"
                        },
                    }
                )

            with open(debug_file, "w") as f:
                json.dump(debug_data, f, indent=2)
            logger.info(f"Saved debug information to {debug_file}")
        except Exception as e:
            logger.error(f"Error saving debug file: {e}")

    return full_chat_states