feat: enhance evaluation script and remove deprecated shell script

- Updated eval.py to streamline model evaluation using vLLM and unsloth.
- Deleted eval.sh as its functionality is now integrated into eval.py.
- Updated .gitignore to exclude the eval_logs/ directory.
Branch: main
Author: thinhlpg (1 month ago)
Parent: 908768458c
Commit: d2f03b96ab

.gitignore

@@ -13,6 +13,7 @@ data/
models/
model/
graveyard/
eval_logs/
# Byte-compiled / optimized / DLL files
__pycache__/

@@ -1,83 +1,27 @@
"""
Compare base model with LoRA model performance.
Evaluate model performance using vLLM and unsloth.
This script evaluates and compares the performance of a base model against
the same model with a LoRA adapter applied.
This script evaluates the performance of a model using vLLM for fast inference
and unsloth for LoRA support.
"""
import argparse
import glob
import os
import re
import time
from datetime import datetime
from unsloth import FastLanguageModel
from vllm import SamplingParams
import src.rl_helpers as rl_helpers
from src.config import MODEL_NAME, OUTPUT_DIR, logger
def find_latest_checkpoint(search_dir=None):
"""
Find the latest checkpoint in the specified directory or OUTPUT_DIR.
Args:
search_dir: Directory to search for checkpoints (default: OUTPUT_DIR)
Returns:
Path to the latest checkpoint or None if no checkpoints found
"""
if search_dir is None:
search_dir = OUTPUT_DIR
logger.info(f"No search directory provided, using default: {search_dir}")
else:
logger.info(f"Searching for checkpoints in: {search_dir}")
# Check if the directory exists first
if not os.path.exists(search_dir):
logger.warning(f"Search directory {search_dir} does not exist")
return None
# First try to find checkpoints in the format checkpoint-{step}
checkpoints = glob.glob(os.path.join(search_dir, "checkpoint-*"))
if checkpoints:
# Extract checkpoint numbers and sort
checkpoint_numbers = []
for checkpoint in checkpoints:
match = re.search(r"checkpoint-(\d+)$", checkpoint)
if match:
checkpoint_numbers.append((int(match.group(1)), checkpoint))
if checkpoint_numbers:
# Sort by checkpoint number (descending)
checkpoint_numbers.sort(reverse=True)
latest = checkpoint_numbers[0][1]
logger.info(f"Found latest checkpoint: {latest}")
return latest
# If no checkpoints found, look for saved_adapter_{timestamp}.bin files
adapter_files = glob.glob(os.path.join(search_dir, "saved_adapter_*.bin"))
if adapter_files:
# Sort by modification time (newest first)
adapter_files.sort(key=os.path.getmtime, reverse=True)
latest = adapter_files[0]
logger.info(f"Found latest adapter file: {latest}")
return latest
# If all else fails, look for any .bin files
bin_files = glob.glob(os.path.join(search_dir, "*.bin"))
if bin_files:
# Sort by modification time (newest first)
bin_files.sort(key=os.path.getmtime, reverse=True)
latest = bin_files[0]
logger.info(f"Found latest .bin file: {latest}")
return latest
logger.warning(f"No checkpoints found in {search_dir}")
return None
from src import (
apply_chat_template,
build_reward_correctness_fn,
build_user_prompt,
get_qa_dataset,
get_system_prompt,
run_eval,
)
from src.config import MODEL_NAME, logger
def get_model_config():
@@ -99,7 +43,7 @@ def get_model_config():
}
def get_sampling_params(temperature: float = 0.5) -> SamplingParams:
def get_sampling_params(temperature=0.5):
"""Get sampling parameters for generation."""
return SamplingParams(
temperature=temperature,
@@ -135,94 +79,6 @@ def setup_model_and_tokenizer():
return model, tokenizer
def test_lora_functionality(model, tokenizer, lora_path):
"""
Test if LoRA is working properly by doing a direct comparison on a simple prompt.
Args:
model: The model to test
tokenizer: The tokenizer
lora_path: Path to LoRA weights
Returns:
bool: True if LoRA is working properly
"""
logger.info(f"\n{'=' * 50}")
logger.info("TESTING LORA FUNCTIONALITY")
logger.info(f"{'=' * 50}")
# First check if LoRA path exists
if not os.path.exists(lora_path):
logger.error(f"ERROR: LoRA path does not exist: {lora_path}")
return False
logger.info(f"LoRA path exists: {lora_path}")
# Test prompt
test_prompt = "Explain the concept of Low-Rank Adaptation (LoRA) in one paragraph:"
# Format prompt for model
formatted_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": test_prompt}],
tokenize=False,
add_generation_prompt=True,
)
# Sample with base model
logger.info("Generating with base model...")
sampling_params = get_sampling_params(temperature=0.7) # Higher temp to make differences more obvious
base_response = model.fast_generate(
[formatted_prompt],
sampling_params=sampling_params,
)
if hasattr(base_response[0], "outputs"):
base_text = base_response[0].outputs[0].text
else:
base_text = base_response[0]
# Sample with LoRA
logger.info(f"Loading LoRA adapter from {lora_path}...")
lora_request = model.load_lora(lora_path)
if lora_request is None:
logger.error("ERROR: Failed to load LoRA adapter")
return False
logger.info(f"LoRA adapter loaded successfully: {lora_request}")
logger.info("Generating with LoRA model...")
lora_response = model.fast_generate(
[formatted_prompt],
sampling_params=sampling_params,
lora_request=lora_request,
)
if hasattr(lora_response[0], "outputs"):
lora_text = lora_response[0].outputs[0].text
else:
lora_text = lora_response[0]
# Check if responses are different
are_identical = base_text == lora_text
logger.info(f"\nResponses are {'identical' if are_identical else 'different'}")
logger.info("\nBASE MODEL RESPONSE:")
logger.info("-" * 40)
logger.info(base_text[:500] + "..." if len(base_text) > 500 else base_text)
logger.info("-" * 40)
logger.info("\nLoRA MODEL RESPONSE:")
logger.info("-" * 40)
logger.info(lora_text[:500] + "..." if len(lora_text) > 500 else lora_text)
logger.info("-" * 40)
if are_identical:
logger.warning("\nWARNING: LoRA adapter does not seem to change the model's output")
logger.warning("This could indicate that the LoRA adapter is not being properly applied")
else:
logger.info("\nLoRA adapter is working as expected (outputs are different)")
return not are_identical
def evaluate_model(
model,
tokenizer,
@@ -237,73 +93,19 @@ def evaluate_model(
Args:
model: The model to evaluate
tokenizer: The tokenizer
lora_path: Path to LoRA weights (None or empty for base model, "auto" for auto-detect)
lora_path: Path to LoRA weights (None for base model)
temperature: Sampling temperature
output_file: File to write results to
trainer_dir: Directory containing the checkpoints (parent of checkpoint directory)
Returns:
dict: Evaluation results
trainer_dir: Directory containing the checkpoints
"""
sampling_params = get_sampling_params(temperature=temperature)
# --- Determine Trainer Output Directory ---
# Prioritize the directory passed from the shell script if available
if trainer_dir and os.path.isdir(trainer_dir):
trainer_output_dir = os.path.abspath(trainer_dir)
logger.info(f"Using trainer directory passed from arguments: {trainer_output_dir}")
# Set up output directory
if trainer_dir:
eval_log_dir = os.path.join(trainer_dir, "eval_logs")
else:
logger.warning(
f"Trainer directory not provided or invalid: {trainer_dir}. Attempting to determine automatically."
)
# Fallback logic if trainer_dir is not provided or invalid
temp_lora_path = lora_path
if temp_lora_path == "auto":
# Find latest checkpoint, searching within OUTPUT_DIR by default
temp_lora_path = find_latest_checkpoint() # Searches OUTPUT_DIR by default
if temp_lora_path and os.path.exists(temp_lora_path):
# If a LoRA path exists (provided or found), get its parent's parent
checkpoint_dir = os.path.dirname(os.path.abspath(temp_lora_path))
trainer_output_dir = os.path.dirname(checkpoint_dir)
logger.info(f"Determined trainer directory from LoRA path ({temp_lora_path}): {trainer_output_dir}")
else:
# If no LoRA path, default to current directory (should ideally not happen if called from eval.sh)
trainer_output_dir = os.path.abspath(".")
logger.warning(
f"Could not determine trainer directory automatically. Defaulting to current directory: {trainer_output_dir}"
)
# --- Auto-detect LoRA path if needed, searching within the determined trainer_output_dir ---
if lora_path == "auto":
# Pass the determined trainer_output_dir to find_latest_checkpoint
detected_checkpoint = find_latest_checkpoint(search_dir=trainer_output_dir)
if detected_checkpoint:
lora_path = detected_checkpoint
logger.info(f"Auto-detected latest checkpoint in {trainer_output_dir}: {lora_path}")
else:
logger.warning(f"No checkpoint found in {trainer_output_dir} for auto-detection. Evaluating base model.")
lora_path = None
model_type = "LoRA" if lora_path else "Base"
logger.info(f"\n{'=' * 50}")
logger.info(f"Starting evaluation of {model_type} model")
logger.info(f"Trainer Output Directory: {trainer_output_dir}") # Log the final directory
logger.info(f"{'=' * 50}")
# --- Create eval_logs directory ---
# Always create it inside the determined trainer_output_dir
eval_log_dir = os.path.join(trainer_output_dir, "eval_logs")
try:
os.makedirs(eval_log_dir, exist_ok=True)
logger.info(f"Ensured eval_logs directory exists at: {eval_log_dir}")
except OSError as e:
logger.error(f"Failed to create directory {eval_log_dir}: {e}")
# Fallback to current directory if creation fails
eval_log_dir = os.path.abspath("./eval_logs")
os.makedirs(eval_log_dir, exist_ok=True)
logger.warning(f"Fell back to creating eval_logs in current directory: {eval_log_dir}")
eval_log_dir = "eval_logs"
os.makedirs(eval_log_dir, exist_ok=True)
# Create file names based on model type
model_prefix = "lora" if lora_path else "base"
@@ -312,32 +114,40 @@ def evaluate_model(
# Define all output file paths
eval_log_file = os.path.join(eval_log_dir, f"{model_prefix}_model_eval_{timestamp}.log")
output_file = os.path.join(eval_log_dir, f"{model_prefix}_model_results.txt")
debug_file = os.path.join(eval_log_dir, f"{model_prefix}_model_results_debug.txt")
debug_file = os.path.join(eval_log_dir, f"{model_prefix}_model_results_debug.json")
logger.info(f"Writing evaluation log to: {eval_log_file}")
logger.info(f"Results will be saved to: {output_file}")
# Function to generate completions
# Function to generate completions using agentic approach
def eval_generate_fn(inputs):
start_time = time.time()
# Format inputs as chat messages with system prompt
messages = [
{
"messages": [
{"role": "system", "content": get_system_prompt()},
{"role": "user", "content": build_user_prompt(input_text)},
]
}
for input_text in inputs
]
if lora_path:
lora_request = model.load_lora(lora_path)
load_time = time.time() - start_time
logger.info(f"LoRA adapter loaded in {load_time:.2f} seconds: {lora_request}")
responses = model.fast_generate(inputs, sampling_params=sampling_params, lora_request=lora_request)
responses = model.fast_generate(
[apply_chat_template(msg, tokenizer=tokenizer)["text"] for msg in messages],
sampling_params=sampling_params,
lora_request=lora_request,
)
else:
# For base model, add additional logging
logger.info("Generating with base model (no LoRA)")
# Also write to the base model log file directly
with open(eval_log_file, "a") as f:
f.write(f"\n{'=' * 50}\n")
f.write("BASE MODEL GENERATION\n")
f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
f.write(f"Model: {MODEL_NAME}\n")
f.write(f"Temperature: {temperature}\n")
f.write(f"{'=' * 50}\n\n")
responses = model.fast_generate(inputs, sampling_params=sampling_params)
responses = model.fast_generate(
[apply_chat_template(msg, tokenizer=tokenizer)["text"] for msg in messages],
sampling_params=sampling_params,
)
gen_time = time.time() - start_time
logger.debug(f"Generation completed in {gen_time:.2f} seconds")
@@ -346,13 +156,28 @@ def evaluate_model(
def verifier_generate_fn(inputs):
# Use a lower temperature for verification to get more consistent results
verifier_params = get_sampling_params(temperature=0.1)
return model.fast_generate(inputs, sampling_params=verifier_params)
# Format inputs as chat messages with system prompt
messages = [
{
"messages": [
{"role": "system", "content": get_system_prompt()},
{"role": "user", "content": build_user_prompt(input_text)},
]
}
for input_text in inputs
]
return model.fast_generate(
[apply_chat_template(msg, tokenizer=tokenizer)["text"] for msg in messages],
sampling_params=verifier_params,
)
# Prepare the verification function
verify_fn = rl_helpers.build_reward_correctness_fn(verifier_generate_fn, tokenizer, log_file=eval_log_file)
verify_fn = build_reward_correctness_fn(verifier_generate_fn, tokenizer)
# Get the dataset and prepare questions and answers
train_dataset, test_dataset = rl_helpers.get_qa_dataset()
train_dataset, test_dataset = get_qa_dataset()
questions = test_dataset["prompt"]
inputs = questions
@@ -360,9 +185,11 @@ def evaluate_model(
# Run the evaluation
start_time = time.time()
model_type = "LoRA" if lora_path else "Base"
logger.info(f"Starting {model_type} model evaluation...")
full_chat_states = rl_helpers.run_eval(
# Run evaluation using the agentic approach
full_chat_states = run_eval(
generate_fn=eval_generate_fn,
verify_fn=verify_fn,
tokenizer=tokenizer,
@@ -422,28 +249,16 @@ def compare_models(lora_path, temperature=0.5, output_file=None, trainer_dir=Non
Compare base model with LoRA model.
Args:
lora_path: Path to LoRA weights (use "auto" for auto-detection)
lora_path: Path to LoRA weights
temperature: Sampling temperature
output_file: File to write results to (optional, will be auto-generated if None)
trainer_dir: Directory containing the trainer output (parent of checkpoint directory)
output_file: File to write results to (optional)
trainer_dir: Directory containing the trainer output
"""
# Auto-detect checkpoint if requested
if lora_path == "auto":
search_dir = trainer_dir if trainer_dir else OUTPUT_DIR
detected_checkpoint = find_latest_checkpoint(search_dir=search_dir)
if detected_checkpoint:
lora_path = detected_checkpoint
logger.info(f"Auto-detected latest checkpoint: {lora_path}")
else:
logger.warning("No checkpoint found for auto-detection. Skipping comparison.")
return
# Set up output directory in the checkpoint directory
checkpoint_dir = os.path.dirname(lora_path)
if not trainer_dir:
trainer_dir = os.path.dirname(checkpoint_dir)
eval_log_dir = os.path.join(trainer_dir, "eval_logs")
# Set up output directory
if trainer_dir:
eval_log_dir = os.path.join(trainer_dir, "eval_logs")
else:
eval_log_dir = "eval_logs"
os.makedirs(eval_log_dir, exist_ok=True)
# Define the comparison file path if not provided
@@ -456,11 +271,6 @@ def compare_models(lora_path, temperature=0.5, output_file=None, trainer_dir=Non
model, tokenizer = setup_model_and_tokenizer()
# Test if LoRA is working properly
lora_works = test_lora_functionality(model, tokenizer, lora_path)
if not lora_works:
logger.warning("LoRA adapter test failed. Results may not be reliable.")
# Evaluate both models
base_results = evaluate_model(
model,
@@ -527,66 +337,26 @@ if __name__ == "__main__":
parser.add_argument(
"--lora_path",
type=str,
default="auto",
help="Path to LoRA weights (use 'auto' for auto-detection)",
default="trainer_output_example/checkpoint-101",
help="Path to LoRA weights",
)
parser.add_argument("--temperature", type=float, default=0.5, help="Sampling temperature")
parser.add_argument(
"--output_file",
type=str,
default=None,
help="File to write results to (optional, will be auto-generated if None)",
help="File to write results to (optional)",
)
parser.add_argument(
"--trainer_dir",
type=str,
default=None,
help="Directory containing the trainer output (parent of checkpoint directory)",
help="Directory containing the trainer output",
)
args = parser.parse_args()
# Auto-detect checkpoint first to set up logging directory
checkpoint_dir = None
lora_path = args.lora_path
trainer_dir = args.trainer_dir
if trainer_dir:
if os.path.exists(trainer_dir):
logger.info(f"Using provided trainer directory: {trainer_dir}")
else:
logger.warning(f"Provided trainer directory does not exist: {trainer_dir}")
trainer_dir = None
if lora_path == "auto":
search_dir = trainer_dir if trainer_dir else OUTPUT_DIR
detected_checkpoint = find_latest_checkpoint(search_dir=search_dir)
if detected_checkpoint:
lora_path = detected_checkpoint
checkpoint_dir = os.path.dirname(lora_path)
if not trainer_dir: # Only set if not provided
trainer_dir = os.path.dirname(checkpoint_dir)
# Set up logging in the trainer directory
eval_log_dir = os.path.join(trainer_dir, "eval_logs")
os.makedirs(eval_log_dir, exist_ok=True)
# If this is imported from config, use it here
try:
from src.config import update_log_path
update_log_path(eval_log_dir)
logger.info(f"Logs will be saved to both ./logs and {eval_log_dir}")
except ImportError:
logger.info("Config's update_log_path not available, using default logging")
if trainer_dir:
logger.info(f"Using trainer directory: {trainer_dir}")
logger.info(f"All evaluation files will be stored in: {os.path.join(trainer_dir, 'eval_logs')}")
else:
logger.warning("No trainer directory found, will attempt to determine during evaluation")
logger.info(f"Starting model evaluation with temperature {args.temperature}")
results = compare_models(args.lora_path, args.temperature, args.output_file, trainer_dir=trainer_dir)
results = compare_models(args.lora_path, args.temperature, args.output_file, trainer_dir=args.trainer_dir)
if results:
logger.info("Evaluation completed successfully")
logger.info(f"Final improvement: {results['improvement']:.4f}")
@@ -599,14 +369,12 @@ if __name__ == "__main__":
logger.info(f"LoRA model results: {results['lora_output']}")
# Find and print all log files in the eval_logs directory
if trainer_dir:
eval_log_dir = os.path.join(trainer_dir, "eval_logs")
if os.path.exists(eval_log_dir):
log_files = [f for f in os.listdir(eval_log_dir) if f.endswith(".log")]
if log_files:
logger.info("\nEVALUATION LOG FILES:")
for log_file in log_files:
logger.info(f"- {os.path.join(eval_log_dir, log_file)}")
eval_log_dir = os.path.join(args.trainer_dir, "eval_logs") if args.trainer_dir else "eval_logs"
if os.path.exists(eval_log_dir):
log_files = [f for f in os.listdir(eval_log_dir) if f.endswith(".log")]
if log_files:
logger.info("\nEVALUATION LOG FILES:")
for log_file in log_files:
logger.info(f"- {os.path.join(eval_log_dir, log_file)}")
else:
logger.warning("Evaluation failed or was skipped")

@@ -1,92 +0,0 @@
#!/bin/bash
# Script to run model comparison between base model and LoRA model
# Initialize variables
LORA_PATH=""
TEMPERATURE=0.5
# Parse command line arguments
while [[ $# -gt 0 ]]; do
case $1 in
--lora_path)
LORA_PATH="$2"
shift 2
;;
--temperature)
TEMPERATURE="$2"
shift 2
;;
--output_file)
echo "Warning: Custom output_file is not recommended. Files are automatically saved in checkpoint's eval_logs directory."
# We'll silently ignore this parameter
shift 2
;;
*)
echo "Unknown option: $1"
echo "Usage: $0 [--lora_path <path_to_checkpoint>] [--temperature <value>]"
exit 1
;;
esac
done
# If LORA_PATH is not provided, try to find the latest checkpoint
if [ -z "$LORA_PATH" ]; then
echo "No checkpoint path provided, searching for latest checkpoint..."
# Look for trainer_output directories in current directory and convert to absolute path
TRAINER_DIR=$(find . -maxdepth 1 -type d -name "trainer_output_*" | sort -r | head -n 1)
if [ -z "$TRAINER_DIR" ]; then
echo "Error: No trainer output directory found. Please provide a checkpoint path with --lora_path"
echo "Usage: $0 [--lora_path <path_to_checkpoint>] [--temperature <value>]"
exit 1
fi
# Convert to absolute path
TRAINER_DIR=$(realpath "$TRAINER_DIR")
echo "Found trainer directory: ${TRAINER_DIR}"
# Get the checkpoint path, filtering out log messages but keeping the path
LORA_PATH=$(python -c "from eval import find_latest_checkpoint; print(find_latest_checkpoint('${TRAINER_DIR}') or '')" | grep -v "INFO" | grep -v "DEBUG" | grep -v "WARNING" | grep -v "ERROR" | grep -v "LangChain" | grep -v "FAISS" | grep -v "Successfully" | grep -v "Loading" | grep -v "Project root" | grep -v "Running in" | grep -v "Automatically" | grep -v "Platform" | grep -v "Torch" | grep -v "CUDA" | grep -v "Triton" | grep -v "Bfloat16" | grep -v "Free license" | grep -v "Fast downloading" | grep -v "vLLM loading" | grep -v "==" | grep -v "^$" | tail -n 1)
if [ -z "$LORA_PATH" ]; then
echo "Error: No checkpoint found in ${TRAINER_DIR}. Please provide a checkpoint path with --lora_path"
echo "Usage: $0 [--lora_path <path_to_checkpoint>] [--temperature <value>]"
exit 1
fi
echo "Found latest checkpoint: ${LORA_PATH}"
else
# If LORA_PATH is provided, convert it to absolute path
LORA_PATH=$(realpath "$LORA_PATH")
# Get the trainer directory (parent of checkpoint directory)
TRAINER_DIR=$(dirname "$(dirname "$LORA_PATH")")
fi
# Verify checkpoint and trainer directory exist
if [ ! -d "$(dirname "$LORA_PATH")" ]; then
echo "Error: Checkpoint directory does not exist: $(dirname "$LORA_PATH")"
exit 1
fi
if [ ! -d "$TRAINER_DIR" ]; then
echo "Error: Trainer directory does not exist: $TRAINER_DIR"
exit 1
fi
# Create eval_logs directory in the trainer output directory
EVAL_LOGS_DIR="$TRAINER_DIR/eval_logs"
mkdir -p "$EVAL_LOGS_DIR"
echo "Starting model comparison..."
echo "LoRA path: ${LORA_PATH}"
echo "Trainer directory: ${TRAINER_DIR}"
echo "Temperature: ${TEMPERATURE}"
echo "Evaluation logs will be saved in: ${EVAL_LOGS_DIR}"
# Run the comparison script, explicitly passing the trainer directory
python eval.py \
--lora_path "${LORA_PATH}" \
--temperature "${TEMPERATURE}" \
--trainer_dir "${TRAINER_DIR}"
echo "Model comparison completed."
echo "Evaluation logs are saved in: ${EVAL_LOGS_DIR}"

@@ -13,10 +13,12 @@ from datetime import datetime
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import SamplingParams
from src.rl_helpers import (
from src import (
apply_chat_template,
build_user_prompt,
extract_search_query,
format_search_results,
get_system_prompt,
)
from src.search_module import load_vectorstore, search
@@ -68,26 +70,15 @@ class DeepSearchCLI:
self.sampling_params = get_sampling_params(temperature)
self.history = []
self.search_history = []
self.system_prompt = (
system_prompt
or f"""Cutting Knowledge Date: December 2023
Today Date: {datetime.now().strftime("%d %b %Y")}
When you receive a tool call response, use the output to format an answer to the original user question.
You are a helpful assistant with tool calling capabilities."""
)
self.system_prompt = system_prompt or get_system_prompt()
def _run_agent_generation(self, chat_state: dict) -> dict:
"""Run a single generation step for the agent."""
formatted_prompt = self.tokenizer.apply_chat_template(
chat_state["messages"],
tokenize=False,
add_generation_prompt=True,
)
# Format the chat state using the same template as training
formatted_prompt = apply_chat_template(chat_state, tokenizer=self.tokenizer)["text"]
start_time = time.time()
inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)
inputs = self.tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False).to(self.model.device)
outputs = self.model.generate(
**inputs,
max_new_tokens=self.sampling_params.max_tokens,
@@ -118,7 +109,7 @@ You are a helpful assistant with tool calling capabilities."""
Returns:
The generated response after completing the conversation
"""
# Initialize chat state
# Initialize chat state with the same structure as training
chat_state = {
"messages": [
{"role": "system", "content": self.system_prompt},

@@ -4,7 +4,7 @@ Main package exports for RL helpers.
from trl.trainer.grpo_trainer import apply_chat_template
from src.agent import Agent
from src.agent import Agent, extract_search_query
from src.config import logger
from src.evaluation import check_student_answers, run_eval, verify
from src.prompts import build_user_prompt, format_search_results, get_system_prompt
@@ -27,6 +27,7 @@ __all__ = [
"Agent",
"LlamaTokenizerAdapter",
"R1DistilTokenizerAdapter",
"extract_search_query",
# Rewards
"build_reward_correctness_fn",
"reward_format",

@@ -15,6 +15,13 @@ from src.search_module import search
from src.tokenizer_adapter import TokenizerAdapter
def extract_search_query(text: str) -> str | None:
"""Extract search query from text between <search> tags."""
pattern = re.compile(r"<search>(.*?)</search>", re.DOTALL)
matches = pattern.findall(text)
return matches[-1] if matches else None
@dataclass
class AgenticOutputs:
"""Outputs from running the agent on a batch of questions."""
@@ -42,12 +49,6 @@ class Agent:
]
}
def extract_search_query(self, text: str) -> str | None:
"""Extract search query from text between <search> tags."""
pattern = re.compile(r"<search>(.*?)</search>", re.DOTALL)
matches = pattern.findall(text)
return matches[-1] if matches else None
def run_agent_generations(self, generate_fn, tokenizer, chat_states: list[dict]) -> list[dict]:
"""Run generation for chat states requiring assistant responses."""
logger.debug(f"Starting generation for {len(chat_states)} chat states")
@@ -109,7 +110,7 @@ class Agent:
)
try:
assistant_response = chat_state["messages"][-1]["content"]
search_query = self.extract_search_query(assistant_response)
search_query = extract_search_query(assistant_response)
if search_query:
logger.info(f"🔍 Search Query: {search_query}")
results = search(search_query, return_type=str, results=2)
