diff --git a/.gitignore b/.gitignore index 4341448..3f1a72f 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,7 @@ data/ models/ model/ graveyard/ +eval_logs/ # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/eval.py b/eval.py index b9167ca..d623f43 100644 --- a/eval.py +++ b/eval.py @@ -1,83 +1,27 @@ """ -Compare base model with LoRA model performance. +Evaluate model performance using vLLM and unsloth. -This script evaluates and compares the performance of a base model against -the same model with a LoRA adapter applied. +This script evaluates the performance of a model using vLLM for fast inference +and unsloth for LoRA support. """ import argparse -import glob import os -import re import time from datetime import datetime from unsloth import FastLanguageModel from vllm import SamplingParams -import src.rl_helpers as rl_helpers -from src.config import MODEL_NAME, OUTPUT_DIR, logger - - -def find_latest_checkpoint(search_dir=None): - """ - Find the latest checkpoint in the specified directory or OUTPUT_DIR. - - Args: - search_dir: Directory to search for checkpoints (default: OUTPUT_DIR) - - Returns: - Path to the latest checkpoint or None if no checkpoints found - """ - if search_dir is None: - search_dir = OUTPUT_DIR - logger.info(f"No search directory provided, using default: {search_dir}") - else: - logger.info(f"Searching for checkpoints in: {search_dir}") - - # Check if the directory exists first - if not os.path.exists(search_dir): - logger.warning(f"Search directory {search_dir} does not exist") - return None - - # First try to find checkpoints in the format checkpoint-{step} - checkpoints = glob.glob(os.path.join(search_dir, "checkpoint-*")) - - if checkpoints: - # Extract checkpoint numbers and sort - checkpoint_numbers = [] - for checkpoint in checkpoints: - match = re.search(r"checkpoint-(\d+)$", checkpoint) - if match: - checkpoint_numbers.append((int(match.group(1)), checkpoint)) - - if checkpoint_numbers: - # Sort by checkpoint number (descending) - checkpoint_numbers.sort(reverse=True) - latest = checkpoint_numbers[0][1] - logger.info(f"Found latest checkpoint: {latest}") - return latest - - # If no checkpoints found, look for saved_adapter_{timestamp}.bin files - adapter_files = glob.glob(os.path.join(search_dir, "saved_adapter_*.bin")) - if adapter_files: - # Sort by modification time (newest first) - adapter_files.sort(key=os.path.getmtime, reverse=True) - latest = adapter_files[0] - logger.info(f"Found latest adapter file: {latest}") - return latest - - # If all else fails, look for any .bin files - bin_files = glob.glob(os.path.join(search_dir, "*.bin")) - if bin_files: - # Sort by modification time (newest first) - bin_files.sort(key=os.path.getmtime, reverse=True) - latest = bin_files[0] - logger.info(f"Found latest .bin file: {latest}") - return latest - - logger.warning(f"No checkpoints found in {search_dir}") - return None +from src import ( + apply_chat_template, + build_reward_correctness_fn, + build_user_prompt, + get_qa_dataset, + get_system_prompt, + run_eval, +) +from src.config import MODEL_NAME, logger def get_model_config(): @@ -99,7 +43,7 @@ def get_model_config(): } -def get_sampling_params(temperature: float = 0.5) -> SamplingParams: +def get_sampling_params(temperature=0.5): """Get sampling parameters for generation.""" return SamplingParams( temperature=temperature, @@ -135,94 +79,6 @@ def setup_model_and_tokenizer(): return model, tokenizer -def test_lora_functionality(model, tokenizer, lora_path): - """ - Test if LoRA is 
working properly by doing a direct comparison on a simple prompt. - - Args: - model: The model to test - tokenizer: The tokenizer - lora_path: Path to LoRA weights - - Returns: - bool: True if LoRA is working properly - """ - logger.info(f"\n{'=' * 50}") - logger.info("TESTING LORA FUNCTIONALITY") - logger.info(f"{'=' * 50}") - - # First check if LoRA path exists - if not os.path.exists(lora_path): - logger.error(f"ERROR: LoRA path does not exist: {lora_path}") - return False - - logger.info(f"LoRA path exists: {lora_path}") - - # Test prompt - test_prompt = "Explain the concept of Low-Rank Adaptation (LoRA) in one paragraph:" - - # Format prompt for model - formatted_prompt = tokenizer.apply_chat_template( - [{"role": "user", "content": test_prompt}], - tokenize=False, - add_generation_prompt=True, - ) - - # Sample with base model - logger.info("Generating with base model...") - sampling_params = get_sampling_params(temperature=0.7) # Higher temp to make differences more obvious - base_response = model.fast_generate( - [formatted_prompt], - sampling_params=sampling_params, - ) - if hasattr(base_response[0], "outputs"): - base_text = base_response[0].outputs[0].text - else: - base_text = base_response[0] - - # Sample with LoRA - logger.info(f"Loading LoRA adapter from {lora_path}...") - lora_request = model.load_lora(lora_path) - if lora_request is None: - logger.error("ERROR: Failed to load LoRA adapter") - return False - - logger.info(f"LoRA adapter loaded successfully: {lora_request}") - logger.info("Generating with LoRA model...") - - lora_response = model.fast_generate( - [formatted_prompt], - sampling_params=sampling_params, - lora_request=lora_request, - ) - if hasattr(lora_response[0], "outputs"): - lora_text = lora_response[0].outputs[0].text - else: - lora_text = lora_response[0] - - # Check if responses are different - are_identical = base_text == lora_text - logger.info(f"\nResponses are {'identical' if are_identical else 'different'}") - - logger.info("\nBASE MODEL RESPONSE:") - logger.info("-" * 40) - logger.info(base_text[:500] + "..." if len(base_text) > 500 else base_text) - logger.info("-" * 40) - - logger.info("\nLoRA MODEL RESPONSE:") - logger.info("-" * 40) - logger.info(lora_text[:500] + "..." 
if len(lora_text) > 500 else lora_text) - logger.info("-" * 40) - - if are_identical: - logger.warning("\nWARNING: LoRA adapter does not seem to change the model's output") - logger.warning("This could indicate that the LoRA adapter is not being properly applied") - else: - logger.info("\nLoRA adapter is working as expected (outputs are different)") - - return not are_identical - - def evaluate_model( model, tokenizer, @@ -237,73 +93,19 @@ def evaluate_model( Args: model: The model to evaluate tokenizer: The tokenizer - lora_path: Path to LoRA weights (None or empty for base model, "auto" for auto-detect) + lora_path: Path to LoRA weights (None for base model) temperature: Sampling temperature output_file: File to write results to - trainer_dir: Directory containing the checkpoints (parent of checkpoint directory) - - Returns: - dict: Evaluation results + trainer_dir: Directory containing the checkpoints """ sampling_params = get_sampling_params(temperature=temperature) - # --- Determine Trainer Output Directory --- - # Prioritize the directory passed from the shell script if available - if trainer_dir and os.path.isdir(trainer_dir): - trainer_output_dir = os.path.abspath(trainer_dir) - logger.info(f"Using trainer directory passed from arguments: {trainer_output_dir}") + # Set up output directory + if trainer_dir: + eval_log_dir = os.path.join(trainer_dir, "eval_logs") else: - logger.warning( - f"Trainer directory not provided or invalid: {trainer_dir}. Attempting to determine automatically." - ) - # Fallback logic if trainer_dir is not provided or invalid - temp_lora_path = lora_path - if temp_lora_path == "auto": - # Find latest checkpoint, searching within OUTPUT_DIR by default - temp_lora_path = find_latest_checkpoint() # Searches OUTPUT_DIR by default - - if temp_lora_path and os.path.exists(temp_lora_path): - # If a LoRA path exists (provided or found), get its parent's parent - checkpoint_dir = os.path.dirname(os.path.abspath(temp_lora_path)) - trainer_output_dir = os.path.dirname(checkpoint_dir) - logger.info(f"Determined trainer directory from LoRA path ({temp_lora_path}): {trainer_output_dir}") - else: - # If no LoRA path, default to current directory (should ideally not happen if called from eval.sh) - trainer_output_dir = os.path.abspath(".") - logger.warning( - f"Could not determine trainer directory automatically. Defaulting to current directory: {trainer_output_dir}" - ) - - # --- Auto-detect LoRA path if needed, searching within the determined trainer_output_dir --- - if lora_path == "auto": - # Pass the determined trainer_output_dir to find_latest_checkpoint - detected_checkpoint = find_latest_checkpoint(search_dir=trainer_output_dir) - if detected_checkpoint: - lora_path = detected_checkpoint - logger.info(f"Auto-detected latest checkpoint in {trainer_output_dir}: {lora_path}") - else: - logger.warning(f"No checkpoint found in {trainer_output_dir} for auto-detection. 
Evaluating base model.") - lora_path = None - - model_type = "LoRA" if lora_path else "Base" - - logger.info(f"\n{'=' * 50}") - logger.info(f"Starting evaluation of {model_type} model") - logger.info(f"Trainer Output Directory: {trainer_output_dir}") # Log the final directory - logger.info(f"{'=' * 50}") - - # --- Create eval_logs directory --- - # Always create it inside the determined trainer_output_dir - eval_log_dir = os.path.join(trainer_output_dir, "eval_logs") - try: - os.makedirs(eval_log_dir, exist_ok=True) - logger.info(f"Ensured eval_logs directory exists at: {eval_log_dir}") - except OSError as e: - logger.error(f"Failed to create directory {eval_log_dir}: {e}") - # Fallback to current directory if creation fails - eval_log_dir = os.path.abspath("./eval_logs") - os.makedirs(eval_log_dir, exist_ok=True) - logger.warning(f"Fell back to creating eval_logs in current directory: {eval_log_dir}") + eval_log_dir = "eval_logs" + os.makedirs(eval_log_dir, exist_ok=True) # Create file names based on model type model_prefix = "lora" if lora_path else "base" @@ -312,32 +114,40 @@ def evaluate_model( # Define all output file paths eval_log_file = os.path.join(eval_log_dir, f"{model_prefix}_model_eval_{timestamp}.log") output_file = os.path.join(eval_log_dir, f"{model_prefix}_model_results.txt") - debug_file = os.path.join(eval_log_dir, f"{model_prefix}_model_results_debug.txt") + debug_file = os.path.join(eval_log_dir, f"{model_prefix}_model_results_debug.json") logger.info(f"Writing evaluation log to: {eval_log_file}") logger.info(f"Results will be saved to: {output_file}") - # Function to generate completions + # Function to generate completions using agentic approach def eval_generate_fn(inputs): start_time = time.time() + + # Format inputs as chat messages with system prompt + messages = [ + { + "messages": [ + {"role": "system", "content": get_system_prompt()}, + {"role": "user", "content": build_user_prompt(input_text)}, + ] + } + for input_text in inputs + ] + if lora_path: lora_request = model.load_lora(lora_path) load_time = time.time() - start_time logger.info(f"LoRA adapter loaded in {load_time:.2f} seconds: {lora_request}") - responses = model.fast_generate(inputs, sampling_params=sampling_params, lora_request=lora_request) + responses = model.fast_generate( + [apply_chat_template(msg, tokenizer=tokenizer)["text"] for msg in messages], + sampling_params=sampling_params, + lora_request=lora_request, + ) else: - # For base model, add additional logging - logger.info("Generating with base model (no LoRA)") - # Also write to the base model log file directly - with open(eval_log_file, "a") as f: - f.write(f"\n{'=' * 50}\n") - f.write("BASE MODEL GENERATION\n") - f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") - f.write(f"Model: {MODEL_NAME}\n") - f.write(f"Temperature: {temperature}\n") - f.write(f"{'=' * 50}\n\n") - - responses = model.fast_generate(inputs, sampling_params=sampling_params) + responses = model.fast_generate( + [apply_chat_template(msg, tokenizer=tokenizer)["text"] for msg in messages], + sampling_params=sampling_params, + ) gen_time = time.time() - start_time logger.debug(f"Generation completed in {gen_time:.2f} seconds") @@ -346,13 +156,28 @@ def evaluate_model( def verifier_generate_fn(inputs): # Use a lower temperature for verification to get more consistent results verifier_params = get_sampling_params(temperature=0.1) - return model.fast_generate(inputs, sampling_params=verifier_params) + + # Format inputs as chat messages with system 
prompt + messages = [ + { + "messages": [ + {"role": "system", "content": get_system_prompt()}, + {"role": "user", "content": build_user_prompt(input_text)}, + ] + } + for input_text in inputs + ] + + return model.fast_generate( + [apply_chat_template(msg, tokenizer=tokenizer)["text"] for msg in messages], + sampling_params=verifier_params, + ) # Prepare the verification function - verify_fn = rl_helpers.build_reward_correctness_fn(verifier_generate_fn, tokenizer, log_file=eval_log_file) + verify_fn = build_reward_correctness_fn(verifier_generate_fn, tokenizer) # Get the dataset and prepare questions and answers - train_dataset, test_dataset = rl_helpers.get_qa_dataset() + train_dataset, test_dataset = get_qa_dataset() questions = test_dataset["prompt"] inputs = questions @@ -360,9 +185,11 @@ def evaluate_model( # Run the evaluation start_time = time.time() + model_type = "LoRA" if lora_path else "Base" logger.info(f"Starting {model_type} model evaluation...") - full_chat_states = rl_helpers.run_eval( + # Run evaluation using the agentic approach + full_chat_states = run_eval( generate_fn=eval_generate_fn, verify_fn=verify_fn, tokenizer=tokenizer, @@ -422,28 +249,16 @@ def compare_models(lora_path, temperature=0.5, output_file=None, trainer_dir=Non Compare base model with LoRA model. Args: - lora_path: Path to LoRA weights (use "auto" for auto-detection) + lora_path: Path to LoRA weights temperature: Sampling temperature - output_file: File to write results to (optional, will be auto-generated if None) - trainer_dir: Directory containing the trainer output (parent of checkpoint directory) + output_file: File to write results to (optional) + trainer_dir: Directory containing the trainer output """ - # Auto-detect checkpoint if requested - if lora_path == "auto": - search_dir = trainer_dir if trainer_dir else OUTPUT_DIR - detected_checkpoint = find_latest_checkpoint(search_dir=search_dir) - if detected_checkpoint: - lora_path = detected_checkpoint - logger.info(f"Auto-detected latest checkpoint: {lora_path}") - else: - logger.warning("No checkpoint found for auto-detection. Skipping comparison.") - return - - # Set up output directory in the checkpoint directory - checkpoint_dir = os.path.dirname(lora_path) - if not trainer_dir: - trainer_dir = os.path.dirname(checkpoint_dir) - - eval_log_dir = os.path.join(trainer_dir, "eval_logs") + # Set up output directory + if trainer_dir: + eval_log_dir = os.path.join(trainer_dir, "eval_logs") + else: + eval_log_dir = "eval_logs" os.makedirs(eval_log_dir, exist_ok=True) # Define the comparison file path if not provided @@ -456,11 +271,6 @@ def compare_models(lora_path, temperature=0.5, output_file=None, trainer_dir=Non model, tokenizer = setup_model_and_tokenizer() - # Test if LoRA is working properly - lora_works = test_lora_functionality(model, tokenizer, lora_path) - if not lora_works: - logger.warning("LoRA adapter test failed. 
Results may not be reliable.") - # Evaluate both models base_results = evaluate_model( model, @@ -527,66 +337,26 @@ if __name__ == "__main__": parser.add_argument( "--lora_path", type=str, - default="auto", - help="Path to LoRA weights (use 'auto' for auto-detection)", + default="trainer_output_example/checkpoint-101", + help="Path to LoRA weights", ) parser.add_argument("--temperature", type=float, default=0.5, help="Sampling temperature") parser.add_argument( "--output_file", type=str, default=None, - help="File to write results to (optional, will be auto-generated if None)", + help="File to write results to (optional)", ) parser.add_argument( "--trainer_dir", type=str, default=None, - help="Directory containing the trainer output (parent of checkpoint directory)", + help="Directory containing the trainer output", ) args = parser.parse_args() - # Auto-detect checkpoint first to set up logging directory - checkpoint_dir = None - lora_path = args.lora_path - trainer_dir = args.trainer_dir - - if trainer_dir: - if os.path.exists(trainer_dir): - logger.info(f"Using provided trainer directory: {trainer_dir}") - else: - logger.warning(f"Provided trainer directory does not exist: {trainer_dir}") - trainer_dir = None - - if lora_path == "auto": - search_dir = trainer_dir if trainer_dir else OUTPUT_DIR - detected_checkpoint = find_latest_checkpoint(search_dir=search_dir) - if detected_checkpoint: - lora_path = detected_checkpoint - checkpoint_dir = os.path.dirname(lora_path) - if not trainer_dir: # Only set if not provided - trainer_dir = os.path.dirname(checkpoint_dir) - - # Set up logging in the trainer directory - eval_log_dir = os.path.join(trainer_dir, "eval_logs") - os.makedirs(eval_log_dir, exist_ok=True) - - # If this is imported from config, use it here - try: - from src.config import update_log_path - - update_log_path(eval_log_dir) - logger.info(f"Logs will be saved to both ./logs and {eval_log_dir}") - except ImportError: - logger.info("Config's update_log_path not available, using default logging") - - if trainer_dir: - logger.info(f"Using trainer directory: {trainer_dir}") - logger.info(f"All evaluation files will be stored in: {os.path.join(trainer_dir, 'eval_logs')}") - else: - logger.warning("No trainer directory found, will attempt to determine during evaluation") - logger.info(f"Starting model evaluation with temperature {args.temperature}") - results = compare_models(args.lora_path, args.temperature, args.output_file, trainer_dir=trainer_dir) + results = compare_models(args.lora_path, args.temperature, args.output_file, trainer_dir=args.trainer_dir) if results: logger.info("Evaluation completed successfully") logger.info(f"Final improvement: {results['improvement']:.4f}") @@ -599,14 +369,12 @@ if __name__ == "__main__": logger.info(f"LoRA model results: {results['lora_output']}") # Find and print all log files in the eval_logs directory - if trainer_dir: - eval_log_dir = os.path.join(trainer_dir, "eval_logs") - if os.path.exists(eval_log_dir): - log_files = [f for f in os.listdir(eval_log_dir) if f.endswith(".log")] - - if log_files: - logger.info("\nEVALUATION LOG FILES:") - for log_file in log_files: - logger.info(f"- {os.path.join(eval_log_dir, log_file)}") + eval_log_dir = os.path.join(args.trainer_dir, "eval_logs") if args.trainer_dir else "eval_logs" + if os.path.exists(eval_log_dir): + log_files = [f for f in os.listdir(eval_log_dir) if f.endswith(".log")] + if log_files: + logger.info("\nEVALUATION LOG FILES:") + for log_file in log_files: + logger.info(f"- 
{os.path.join(eval_log_dir, log_file)}") else: logger.warning("Evaluation failed or was skipped") diff --git a/eval.sh b/eval.sh deleted file mode 100755 index ab4757e..0000000 --- a/eval.sh +++ /dev/null @@ -1,92 +0,0 @@ -#!/bin/bash -# Script to run model comparison between base model and LoRA model - -# Initialize variables -LORA_PATH="" -TEMPERATURE=0.5 - -# Parse command line arguments -while [[ $# -gt 0 ]]; do - case $1 in - --lora_path) - LORA_PATH="$2" - shift 2 - ;; - --temperature) - TEMPERATURE="$2" - shift 2 - ;; - --output_file) - echo "Warning: Custom output_file is not recommended. Files are automatically saved in checkpoint's eval_logs directory." - # We'll silently ignore this parameter - shift 2 - ;; - *) - echo "Unknown option: $1" - echo "Usage: $0 [--lora_path ] [--temperature ]" - exit 1 - ;; - esac -done - -# If LORA_PATH is not provided, try to find the latest checkpoint -if [ -z "$LORA_PATH" ]; then - echo "No checkpoint path provided, searching for latest checkpoint..." - # Look for trainer_output directories in current directory and convert to absolute path - TRAINER_DIR=$(find . -maxdepth 1 -type d -name "trainer_output_*" | sort -r | head -n 1) - - if [ -z "$TRAINER_DIR" ]; then - echo "Error: No trainer output directory found. Please provide a checkpoint path with --lora_path" - echo "Usage: $0 [--lora_path ] [--temperature ]" - exit 1 - fi - - # Convert to absolute path - TRAINER_DIR=$(realpath "$TRAINER_DIR") - echo "Found trainer directory: ${TRAINER_DIR}" - - # Get the checkpoint path, filtering out log messages but keeping the path - LORA_PATH=$(python -c "from eval import find_latest_checkpoint; print(find_latest_checkpoint('${TRAINER_DIR}') or '')" | grep -v "INFO" | grep -v "DEBUG" | grep -v "WARNING" | grep -v "ERROR" | grep -v "LangChain" | grep -v "FAISS" | grep -v "Successfully" | grep -v "Loading" | grep -v "Project root" | grep -v "Running in" | grep -v "Automatically" | grep -v "Platform" | grep -v "Torch" | grep -v "CUDA" | grep -v "Triton" | grep -v "Bfloat16" | grep -v "Free license" | grep -v "Fast downloading" | grep -v "vLLM loading" | grep -v "==" | grep -v "^$" | tail -n 1) - - if [ -z "$LORA_PATH" ]; then - echo "Error: No checkpoint found in ${TRAINER_DIR}. Please provide a checkpoint path with --lora_path" - echo "Usage: $0 [--lora_path ] [--temperature ]" - exit 1 - fi - echo "Found latest checkpoint: ${LORA_PATH}" -else - # If LORA_PATH is provided, convert it to absolute path - LORA_PATH=$(realpath "$LORA_PATH") - # Get the trainer directory (parent of checkpoint directory) - TRAINER_DIR=$(dirname "$(dirname "$LORA_PATH")") -fi - -# Verify checkpoint and trainer directory exist -if [ ! -d "$(dirname "$LORA_PATH")" ]; then - echo "Error: Checkpoint directory does not exist: $(dirname "$LORA_PATH")" - exit 1 -fi - -if [ ! -d "$TRAINER_DIR" ]; then - echo "Error: Trainer directory does not exist: $TRAINER_DIR" - exit 1 -fi - -# Create eval_logs directory in the trainer output directory -EVAL_LOGS_DIR="$TRAINER_DIR/eval_logs" -mkdir -p "$EVAL_LOGS_DIR" - -echo "Starting model comparison..." -echo "LoRA path: ${LORA_PATH}" -echo "Trainer directory: ${TRAINER_DIR}" -echo "Temperature: ${TEMPERATURE}" -echo "Evaluation logs will be saved in: ${EVAL_LOGS_DIR}" - -# Run the comparison script, explicitly passing the trainer directory -python eval.py \ - --lora_path "${LORA_PATH}" \ - --temperature "${TEMPERATURE}" \ - --trainer_dir "${TRAINER_DIR}" - -echo "Model comparison completed." 
-echo "Evaluation logs are saved in: ${EVAL_LOGS_DIR}" \ No newline at end of file diff --git a/inference.py b/inference.py index 5664fd8..d9b7ec1 100644 --- a/inference.py +++ b/inference.py @@ -13,10 +13,12 @@ from datetime import datetime from transformers import AutoModelForCausalLM, AutoTokenizer from vllm import SamplingParams -from src.rl_helpers import ( +from src import ( + apply_chat_template, build_user_prompt, extract_search_query, format_search_results, + get_system_prompt, ) from src.search_module import load_vectorstore, search @@ -68,26 +70,15 @@ class DeepSearchCLI: self.sampling_params = get_sampling_params(temperature) self.history = [] self.search_history = [] - self.system_prompt = ( - system_prompt - or f"""Cutting Knowledge Date: December 2023 -Today Date: {datetime.now().strftime("%d %b %Y")} - -When you receive a tool call response, use the output to format an answer to the original user question. - -You are a helpful assistant with tool calling capabilities.""" - ) + self.system_prompt = system_prompt or get_system_prompt() def _run_agent_generation(self, chat_state: dict) -> dict: """Run a single generation step for the agent.""" - formatted_prompt = self.tokenizer.apply_chat_template( - chat_state["messages"], - tokenize=False, - add_generation_prompt=True, - ) + # Format the chat state using the same template as training + formatted_prompt = apply_chat_template(chat_state, tokenizer=self.tokenizer)["text"] start_time = time.time() - inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device) + inputs = self.tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False).to(self.model.device) outputs = self.model.generate( **inputs, max_new_tokens=self.sampling_params.max_tokens, @@ -118,7 +109,7 @@ You are a helpful assistant with tool calling capabilities.""" Returns: The generated response after completing the conversation """ - # Initialize chat state + # Initialize chat state with the same structure as training chat_state = { "messages": [ {"role": "system", "content": self.system_prompt}, diff --git a/src/__init__.py b/src/__init__.py index 169eccd..5e9f975 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -4,7 +4,7 @@ Main package exports for RL helpers. 
from trl.trainer.grpo_trainer import apply_chat_template -from src.agent import Agent +from src.agent import Agent, extract_search_query from src.config import logger from src.evaluation import check_student_answers, run_eval, verify from src.prompts import build_user_prompt, format_search_results, get_system_prompt @@ -27,6 +27,7 @@ __all__ = [ "Agent", "LlamaTokenizerAdapter", "R1DistilTokenizerAdapter", + "extract_search_query", # Rewards "build_reward_correctness_fn", "reward_format", diff --git a/src/agent.py b/src/agent.py index 11b4748..0f87e4d 100644 --- a/src/agent.py +++ b/src/agent.py @@ -15,6 +15,13 @@ from src.search_module import search from src.tokenizer_adapter import TokenizerAdapter +def extract_search_query(text: str) -> str | None: + """Extract search query from text between <search> tags.""" + pattern = re.compile(r"<search>(.*?)</search>", re.DOTALL) + matches = pattern.findall(text) + return matches[-1] if matches else None + + @dataclass class AgenticOutputs: """Outputs from running the agent on a batch of questions.""" @@ -42,12 +49,6 @@ class Agent: ] } - def extract_search_query(self, text: str) -> str | None: - """Extract search query from text between <search> tags.""" - pattern = re.compile(r"<search>(.*?)</search>", re.DOTALL) - matches = pattern.findall(text) - return matches[-1] if matches else None - def run_agent_generations(self, generate_fn, tokenizer, chat_states: list[dict]) -> list[dict]: """Run generation for chat states requiring assistant responses.""" logger.debug(f"Starting generation for {len(chat_states)} chat states") @@ -109,7 +110,7 @@ class Agent: ) try: assistant_response = chat_state["messages"][-1]["content"] - search_query = self.extract_search_query(assistant_response) + search_query = extract_search_query(assistant_response) if search_query: logger.info(f"🔍 Search Query: {search_query}") results = search(search_query, return_type=str, results=2)
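
Usage note: with this change extract_search_query is a module-level helper exported from src (see the src/__init__.py hunk above), so callers such as inference.py no longer need an Agent instance to parse tool calls. A minimal sketch of the new call site, assuming the <search>...</search> tag convention used in the agent code; the example string is purely illustrative:

    from src import extract_search_query

    # Returns the last <search>...</search> span in the text, or None if absent.
    text = "I need more context. <search>LoRA rank selection</search>"
    print(extract_search_query(text))            # -> "LoRA rank selection"
    print(extract_search_query("no tool call"))  # -> None

Pulling the helper out of the Agent class keeps Agent focused on rollout orchestration and lets inference.py reuse the same extraction logic as the training rollouts.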