""" Evaluation utilities for RL training. """ import inspect from datetime import datetime from config import DATA_DIR, logger from src.agent import Agent from src.search_module import get_qa_dataset from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter async def verify(student_answer: str, question: str, answer: str) -> bool: """ Verify if student's answer matches the correct answer. Args: student_answer: The model's answer question: The original question answer: The ground truth answer Returns: bool: True if answer is correct, False otherwise """ logger.debug(f"Verifying answer for question: {question}") logger.debug(f"Student answer: {student_answer}") logger.debug(f"Correct answer: {answer}") # Simple string matching for now # TODO: Implement more sophisticated matching return student_answer.strip().lower() == answer.strip().lower() def check_student_answers( questions: list[str], answers: list[str], student_answers: list, # Can be strings or dicts vllm_generate_func, tokenizer, log_file=None, ) -> list[bool]: """ Evaluates a list of student answers against the true answers using a vLLM generate function. Args: questions: List of questions answers: List of correct answers student_answers: List of student answers to evaluate vllm_generate_func: Function to generate verification responses tokenizer: Tokenizer for formatting prompts log_file: Optional path to write detailed results Returns: List of boolean results (True for correct answers) """ logger.info(f"Checking {len(questions)} student answers") if not (len(questions) == len(answers) == len(student_answers)): logger.error("Mismatched lengths between questions, answers, and student answers") raise ValueError("The number of questions, answers, and student answers must be equal.") prompts = [] for question, answer, student_ans in zip(questions, answers, student_answers): prompt_text = ( "You are grading a student's answer to a question. For the following question, " "compare the student's answer to the correct answer. Reply with 'Yes' if the student's answer contains the correct information, " "even if it's not an exact match. If the student's answer doesn't contain the right information or is completely incorrect, reply with 'No'.\n\n" f"Question: {question}\n" f"Correct Answer: {answer}\n" f"Student Answer: {student_ans}\n\n" "Your response should be just 'Yes' or 'No'." ) formatted_prompt = tokenizer.apply_chat_template( [{"role": "user", "content": prompt_text}], tokenize=False, add_generation_prompt=True, ) prompts.append(formatted_prompt) logger.debug(f"Created verification prompt for question: {question[:50]}...") logger.info("Generating verification responses") responses = vllm_generate_func(prompts) responses_text = [] for response in responses: # Handle different response formats if hasattr(response, "outputs"): try: responses_text.append(response.outputs[0].text) except (AttributeError, IndexError): # Fallback for simple string responses responses_text.append(str(response)) else: responses_text.append(str(response)) logger.debug(f"Got {len(responses_text)} verification responses") results = [] for response in responses_text: results.append("yes" in response.lower()) logger.debug(f"Verification result: {'yes' in response.lower()}") logger.info(f"Verification complete. 


def check_student_answers(
    questions: list[str],
    answers: list[str],
    student_answers: list,  # Can be strings or dicts
    vllm_generate_func,
    tokenizer,
    log_file=None,
) -> list[bool]:
    """
    Evaluates a list of student answers against the true answers using a vLLM generate function.

    Args:
        questions: List of questions
        answers: List of correct answers
        student_answers: List of student answers to evaluate
        vllm_generate_func: Function to generate verification responses
        tokenizer: Tokenizer for formatting prompts
        log_file: Optional path to write detailed results

    Returns:
        List of boolean results (True for correct answers)
    """
    logger.info(f"Checking {len(questions)} student answers")

    if not (len(questions) == len(answers) == len(student_answers)):
        logger.error("Mismatched lengths between questions, answers, and student answers")
        raise ValueError("The number of questions, answers, and student answers must be equal.")

    prompts = []
    for question, answer, student_ans in zip(questions, answers, student_answers):
        prompt_text = (
            "You are grading a student's answer to a question. For the following question, "
            "compare the student's answer to the correct answer. Reply with 'Yes' if the student's answer contains the correct information, "
            "even if it's not an exact match. If the student's answer doesn't contain the right information or is completely incorrect, reply with 'No'.\n\n"
            f"Question: {question}\n"
            f"Correct Answer: {answer}\n"
            f"Student Answer: {student_ans}\n\n"
            "Your response should be just 'Yes' or 'No'."
        )

        formatted_prompt = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt_text}],
            tokenize=False,
            add_generation_prompt=True,
        )
        prompts.append(formatted_prompt)
        logger.debug(f"Created verification prompt for question: {question[:50]}...")

    logger.info("Generating verification responses")
    responses = vllm_generate_func(prompts)
    responses_text = []
    for response in responses:
        # Handle different response formats
        if hasattr(response, "outputs"):
            try:
                responses_text.append(response.outputs[0].text)
            except (AttributeError, IndexError):
                # Fallback for simple string responses
                responses_text.append(str(response))
        else:
            responses_text.append(str(response))
    logger.debug(f"Got {len(responses_text)} verification responses")

    results = []
    for response in responses_text:
        results.append("yes" in response.lower())
        logger.debug(f"Verification result: {'yes' in response.lower()}")

    logger.info(f"Verification complete. {sum(results)}/{len(results)} answers correct")

    # Append the QA details and verifier's response to the specified log file
    if log_file:
        with open(log_file, "a") as file:
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            file.write(f"\nšŸ“ === QA Evaluation at {timestamp} ===\n")
            file.write(f"šŸ“‚ File: {__file__}\n")

            # Get current frame info safely
            frame = inspect.currentframe()
            if frame:
                file.write(f"šŸ“ Line: {frame.f_lineno}\n")
                # Don't forget to delete the frame to avoid reference cycles
                del frame

            file.write("=" * 80 + "\n")

            for i, (question, answer, student_ans, verifier_response) in enumerate(
                zip(questions, answers, student_answers, responses_text)
            ):
                file.write(f"\nā“ Question {i + 1}:\n")
                file.write("-" * 40 + "\n")
                file.write(f"šŸ“‹ Question: {question}\n")
                file.write(f"āœ… Correct Answer: {answer}\n")
                file.write(f"šŸ‘Øā€šŸŽ“ Student Answer: {student_ans}\n")
                file.write(f"šŸ” Verifier said: {verifier_response}\n")

                # Add search results if available in the chat state
                if isinstance(student_ans, dict) and "messages" in student_ans:
                    # Get messages from dict
                    messages = student_ans.get("messages", [])
                    search_results = [msg.get("content", "") for msg in messages if msg.get("role") == "ipython"]
                    if search_results:
                        file.write("\nšŸ”Ž Search Results:\n")
                        for j, result in enumerate(search_results, 1):
                            file.write(f"\nSearch {j}:\n{result}\n")

                file.write("-" * 40 + "\n")

            file.write(
                f"\nšŸ“Š Summary: {sum(results)}/{len(results)} answers correct ({sum(results) / len(results) * 100:.2f}%)\n"
            )
            file.write("=" * 80 + "\n\n")

    return results
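

# A minimal usage sketch for `check_student_answers`, assuming stand-in objects:
# `StubTokenizer` and `stub_generate` below are hypothetical replacements for a
# real HF tokenizer and a vLLM batch-generate function.
#
#   class StubTokenizer:
#       def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True):
#           return messages[0]["content"]
#
#   def stub_generate(prompts):
#       return ["Yes"] * len(prompts)  # pretend the grader approves every answer
#
#   results = check_student_answers(
#       questions=["What is the capital of France?"],
#       answers=["Paris"],
#       student_answers=["The capital is Paris."],
#       vllm_generate_func=stub_generate,
#       tokenizer=StubTokenizer(),
#   )
#   # results == [True]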


def run_eval(
    generate_fn, verify_fn, tokenizer, max_generations=32, max_new_tokens=4096 * 6, output_file=None, debug_file=None
):
    """
    Run evaluation on the test dataset and return results.

    Args:
        generate_fn: Function to generate completions
        verify_fn: Function to verify results
        tokenizer: Tokenizer for processing text
        max_generations: Maximum number of agent generation rounds per question
        max_new_tokens: Maximum number of new tokens per generation call
        output_file: Path to save evaluation results summary
        debug_file: Path to save detailed debug information

    Returns:
        full_chat_states: The chat states from evaluation
    """
    train_dataset, test_dataset = get_qa_dataset(
        randomize=False,
        test_size=1,
        questions_path=DATA_DIR / "processed" / "questions_dev.jsonl",
    )
    questions = test_dataset["prompt"]

    # Create agent with appropriate adapter based on tokenizer
    tokenizer_name = tokenizer.name_or_path.lower()
    if "deepseek-ai/deepseek-r1-distill" in tokenizer_name:
        adapter = R1DistilTokenizerAdapter()
    elif "llama" in tokenizer_name:
        adapter = LlamaTokenizerAdapter()
    else:
        adapter = R1DistilTokenizerAdapter()

    agent = Agent(adapter)
    agentic_outputs = agent.run_agent(generate_fn, tokenizer, questions, max_generations, max_new_tokens)
    full_chat_states = agentic_outputs.full_chat_states
    final_responses = agentic_outputs.final_response_str
    rewards = verify_fn(questions, full_chat_states, answer=test_dataset["answer"])

    # Calculate results
    percent_correct = sum(rewards) / len(rewards) * 100

    # Log results to console
    logger.info("RESULTS:")
    logger.info(f"Percentage of correct answers: {percent_correct:.2f}%")
    logger.info("=" * 30)

    # Save results to file if specified
    if output_file:
        try:
            with open(output_file, "w") as f:
                f.write("EVALUATION RESULTS\n")
                f.write("==================\n\n")
                f.write(f"Total questions: {len(questions)}\n")
                f.write(f"Correct answers: {sum(rewards)}\n")
                f.write(f"Percentage correct: {percent_correct:.2f}%\n\n")

                f.write("Individual results:\n")
                for i, (q, r, resp) in enumerate(zip(questions, rewards, final_responses)):
                    f.write(f"\nQ{i + 1}: {q[:100]}...\n")
                    f.write(f"Correct: {'āœ“' if r else 'āœ—'}\n")
                    f.write(f"Response: {resp[:150]}...\n")
                    f.write("-" * 40 + "\n")
            logger.info(f"Saved evaluation results to {output_file}")
        except Exception as e:
            logger.error(f"Error saving results file: {e}")

    # Save debug information if specified
    if debug_file:
        try:
            import json

            debug_data = []
            for i, (q, r, resp, chat) in enumerate(zip(questions, rewards, final_responses, full_chat_states)):
                debug_data.append(
                    {
                        "question_id": i,
                        "question": q,
                        "is_correct": bool(r),
                        "final_response": resp,
                        "chat_state": {
                            k: str(v) if isinstance(v, (list, dict)) else v for k, v in chat.items() if k != "tokenizer"
                        },
                    }
                )

            with open(debug_file, "w") as f:
                json.dump(debug_data, f, indent=2)
            logger.info(f"Saved debug information to {debug_file}")
        except Exception as e:
            logger.error(f"Error saving debug file: {e}")

    return full_chat_states
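

# A rough wiring sketch for `run_eval` (illustrative; `make_generate_fn` and
# `make_verify_fn` are hypothetical helpers, not defined in this module):
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
#   generate_fn = make_generate_fn(...)  # batched completion function, e.g. backed by vLLM
#   verify_fn = make_verify_fn(...)      # verification reward, e.g. wrapping check_student_answers
#   chat_states = run_eval(
#       generate_fn,
#       verify_fn,
#       tokenizer,
#       output_file="eval_results.txt",
#       debug_file="eval_debug.json",
#   )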