@@ -1,83 +1,27 @@
"""
"""
Compare base model with LoRA model performance.
Evaluate model performance using vLLM and unsloth.
This script evaluates and compares the performance of a base model against
This script evaluates the performance of a model using vLLM for fast inference
the same model with a LoRA adapter applied.
and unsloth for LoRA support.
"""
"""
import argparse
import argparse
import glob
import os
import os
import re
import time
import time
from datetime import datetime
from datetime import datetime
from unsloth import FastLanguageModel
from unsloth import FastLanguageModel
from vllm import SamplingParams
from vllm import SamplingParams
import src.rl_helpers as rl_helpers
from src import (
from src.config import MODEL_NAME, OUTPUT_DIR, logger
apply_chat_template,
build_reward_correctness_fn,
build_user_prompt,
def find_latest_checkpoint(search_dir=None):
get_qa_dataset,
"""
get_system_prompt,
Find the latest checkpoint in the specified directory or OUTPUT_DIR.
run_eval,
)
Args:
from src.config import MODEL_NAME, logger
search_dir: Directory to search for checkpoints (default: OUTPUT_DIR)
Returns:
Path to the latest checkpoint or None if no checkpoints found
"""
if search_dir is None:
search_dir = OUTPUT_DIR
logger.info(f"No search directory provided, using default: {search_dir}")
else:
logger.info(f"Searching for checkpoints in: {search_dir}")
# Check if the directory exists first
if not os.path.exists(search_dir):
logger.warning(f"Search directory {search_dir} does not exist")
return None
# First try to find checkpoints in the format checkpoint-{step}
checkpoints = glob.glob(os.path.join(search_dir, "checkpoint-*"))
if checkpoints:
# Extract checkpoint numbers and sort
checkpoint_numbers = []
for checkpoint in checkpoints:
match = re.search(r"checkpoint-(\d+)$", checkpoint)
if match:
checkpoint_numbers.append((int(match.group(1)), checkpoint))
if checkpoint_numbers:
# Sort by checkpoint number (descending)
checkpoint_numbers.sort(reverse=True)
latest = checkpoint_numbers[0][1]
logger.info(f"Found latest checkpoint: {latest}")
return latest
# If no checkpoints found, look for saved_adapter_{timestamp}.bin files
adapter_files = glob.glob(os.path.join(search_dir, "saved_adapter_*.bin"))
if adapter_files:
# Sort by modification time (newest first)
adapter_files.sort(key=os.path.getmtime, reverse=True)
latest = adapter_files[0]
logger.info(f"Found latest adapter file: {latest}")
return latest
# If all else fails, look for any .bin files
bin_files = glob.glob(os.path.join(search_dir, "*.bin"))
if bin_files:
# Sort by modification time (newest first)
bin_files.sort(key=os.path.getmtime, reverse=True)
latest = bin_files[0]
logger.info(f"Found latest .bin file: {latest}")
return latest
logger.warning(f"No checkpoints found in {search_dir}")
return None
def get_model_config():
def get_model_config():
@@ -99,7 +43,7 @@ def get_model_config():
}
}
def get_sampling_params(temperature: float = 0.5) -> SamplingParams:
def get_sampling_params(temperature=0.5):
"""Get sampling parameters for generation."""
"""Get sampling parameters for generation."""
return SamplingParams(
return SamplingParams(
temperature=temperature,
temperature=temperature,
@@ -135,94 +79,6 @@ def setup_model_and_tokenizer():
return model, tokenizer
return model, tokenizer
def test_lora_functionality(model, tokenizer, lora_path):
"""
Test if LoRA is working properly by doing a direct comparison on a simple prompt.
Args:
model: The model to test
tokenizer: The tokenizer
lora_path: Path to LoRA weights
Returns:
bool: True if LoRA is working properly
"""
logger.info(f"\n{'=' * 50}")
logger.info("TESTING LORA FUNCTIONALITY")
logger.info(f"{'=' * 50}")
# First check if LoRA path exists
if not os.path.exists(lora_path):
logger.error(f"ERROR: LoRA path does not exist: {lora_path}")
return False
logger.info(f"LoRA path exists: {lora_path}")
# Test prompt
test_prompt = "Explain the concept of Low-Rank Adaptation (LoRA) in one paragraph:"
# Format prompt for model
formatted_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": test_prompt}],
tokenize=False,
add_generation_prompt=True,
)
# Sample with base model
logger.info("Generating with base model...")
sampling_params = get_sampling_params(temperature=0.7)  # Higher temp to make differences more obvious
base_response = model.fast_generate(
[formatted_prompt],
sampling_params=sampling_params,
)
if hasattr(base_response[0], "outputs"):
base_text = base_response[0].outputs[0].text
else:
base_text = base_response[0]
# Sample with LoRA
logger.info(f"Loading LoRA adapter from {lora_path}...")
lora_request = model.load_lora(lora_path)
if lora_request is None:
logger.error("ERROR: Failed to load LoRA adapter")
return False
logger.info(f"LoRA adapter loaded successfully: {lora_request}")
logger.info("Generating with LoRA model...")
lora_response = model.fast_generate(
[formatted_prompt],
sampling_params=sampling_params,
lora_request=lora_request,
)
if hasattr(lora_response[0], "outputs"):
lora_text = lora_response[0].outputs[0].text
else:
lora_text = lora_response[0]
# Check if responses are different
are_identical = base_text == lora_text
logger.info(f"\nResponses are {'identical' if are_identical else 'different'}")
logger.info("\nBASE MODEL RESPONSE:")
logger.info("-" * 40)
logger.info(base_text[:500] + "..." if len(base_text) > 500 else base_text)
logger.info("-" * 40)
logger.info("\nLoRA MODEL RESPONSE:")
logger.info("-" * 40)
logger.info(lora_text[:500] + "..." if len(lora_text) > 500 else lora_text)
logger.info("-" * 40)
if are_identical:
logger.warning("\nWARNING: LoRA adapter does not seem to change the model's output")
logger.warning("This could indicate that the LoRA adapter is not being properly applied")
else:
logger.info("\nLoRA adapter is working as expected (outputs are different)")
return not are_identical
def evaluate_model(
def evaluate_model(
model,
model,
tokenizer,
tokenizer,
@@ -237,73 +93,19 @@ def evaluate_model(
Args:
Args:
model: The model to evaluate
model: The model to evaluate
tokenizer: The tokenizer
tokenizer: The tokenizer
lora_path: Path to LoRA weights (None or empty for base model, "auto" for auto-detect)
lora_path: Path to LoRA weights (None for base model)
temperature: Sampling temperature
temperature: Sampling temperature
output_file: File to write results to
output_file: File to write results to
trainer_dir: Directory containing the checkpoints (parent of checkpoint directory)
trainer_dir: Directory containing the checkpoints
Returns:
dict: Evaluation results
"""
"""
sampling_params = get_sampling_params(temperature=temperature)
sampling_params = get_sampling_params(temperature=temperature)
# --- Determine Trainer Output Directory ---
# Set up output directory
# Prioritize the directory passed from the shell script if available
if trainer_dir:
if trainer_dir and os.path.isdir(trainer_dir):
eval_log_dir = os.path.join(trainer_dir, "eval_logs")
trainer_output_dir = os.path.abspath(trainer_dir)
logger.info(f"Using trainer directory passed from arguments: {trainer_output_dir}")
else:
else:
logger.warning(
eval_log_dir = "eval_logs"
f"Trainer directory not provided or invalid: {trainer_dir}. Attempting to determine automatically."
os.makedirs(eval_log_dir, exist_ok=True)
)
# Fallback logic if trainer_dir is not provided or invalid
temp_lora_path = lora_path
if temp_lora_path == "auto":
# Find latest checkpoint, searching within OUTPUT_DIR by default
temp_lora_path = find_latest_checkpoint()  # Searches OUTPUT_DIR by default
if temp_lora_path and os.path.exists(temp_lora_path):
# If a LoRA path exists (provided or found), get its parent's parent
checkpoint_dir = os.path.dirname(os.path.abspath(temp_lora_path))
trainer_output_dir = os.path.dirname(checkpoint_dir)
logger.info(f"Determined trainer directory from LoRA path ({temp_lora_path}): {trainer_output_dir}")
else:
# If no LoRA path, default to current directory (should ideally not happen if called from eval.sh)
trainer_output_dir = os.path.abspath(".")
logger.warning(
f"Could not determine trainer directory automatically. Defaulting to current directory: {trainer_output_dir}"
)
# --- Auto-detect LoRA path if needed, searching within the determined trainer_output_dir ---
if lora_path == "auto":
# Pass the determined trainer_output_dir to find_latest_checkpoint
detected_checkpoint = find_latest_checkpoint(search_dir=trainer_output_dir)
if detected_checkpoint:
lora_path = detected_checkpoint
logger.info(f"Auto-detected latest checkpoint in {trainer_output_dir}: {lora_path}")
else:
logger.warning(f"No checkpoint found in {trainer_output_dir} for auto-detection. Evaluating base model.")
lora_path = None
model_type = "LoRA" if lora_path else "Base"
logger.info(f"\n{'=' * 50}")
logger.info(f"Starting evaluation of {model_type} model")
logger.info(f"Trainer Output Directory: {trainer_output_dir}")  # Log the final directory
logger.info(f"{'=' * 50}")
# --- Create eval_logs directory ---
# Always create it inside the determined trainer_output_dir
eval_log_dir = os.path.join(trainer_output_dir, "eval_logs")
try:
os.makedirs(eval_log_dir, exist_ok=True)
logger.info(f"Ensured eval_logs directory exists at: {eval_log_dir}")
except OSError as e:
logger.error(f"Failed to create directory {eval_log_dir}: {e}")
# Fallback to current directory if creation fails
eval_log_dir = os.path.abspath("./eval_logs")
os.makedirs(eval_log_dir, exist_ok=True)
logger.warning(f"Fell back to creating eval_logs in current directory: {eval_log_dir}")
# Create file names based on model type
# Create file names based on model type
model_prefix = "lora" if lora_path else "base"
model_prefix = "lora" if lora_path else "base"
@@ -312,32 +114,40 @@ def evaluate_model(
# Define all output file paths
# Define all output file paths
eval_log_file = os.path.join(eval_log_dir, f"{model_prefix}_model_eval_{timestamp}.log")
eval_log_file = os.path.join(eval_log_dir, f"{model_prefix}_model_eval_{timestamp}.log")
output_file = os.path.join(eval_log_dir, f"{model_prefix}_model_results.txt")
output_file = os.path.join(eval_log_dir, f"{model_prefix}_model_results.txt")
debug_file = os.path.join(eval_log_dir, f"{model_prefix}_model_results_debug.txt")
debug_file = os.path.join(eval_log_dir, f"{model_prefix}_model_results_debug.json")
logger.info(f"Writing evaluation log to: {eval_log_file}")
logger.info(f"Writing evaluation log to: {eval_log_file}")
logger.info(f"Results will be saved to: {output_file}")
logger.info(f"Results will be saved to: {output_file}")
# Function to generate completions
# Function to generate completions using agentic approach
def eval_generate_fn(inputs):
def eval_generate_fn(inputs):
start_time = time.time()
start_time = time.time()
# Format inputs as chat messages with system prompt
messages = [
{
"messages": [
{"role": "system", "content": get_system_prompt()},
{"role": "user", "content": build_user_prompt(input_text)},
]
}
for input_text in inputs
]
if lora_path:
if lora_path:
lora_request = model.load_lora(lora_path)
lora_request = model.load_lora(lora_path)
load_time = time.time() - start_time
load_time = time.time() - start_time
logger.info(f"LoRA adapter loaded in {load_time:.2f} seconds: {lora_request}")
logger.info(f"LoRA adapter loaded in {load_time:.2f} seconds: {lora_request}")
responses = model.fast_generate(inputs, sampling_params=sampling_params, lora_request=lora_request)
responses = model.fast_generate(
[apply_chat_template(msg, tokenizer=tokenizer)["text"] for msg in messages],
sampling_params=sampling_params,
lora_request=lora_request,
)
else:
else:
# For base model, add additional logging
responses = model.fast_generate(
logger.info("Generating with base model (no LoRA)")
[apply_chat_template(msg, tokenizer=tokenizer)["text"] for msg in messages],
# Also write to the base model log file directly
sampling_params=sampling_params,
with open(eval_log_file, "a") as f:
)
f.write(f"\n{'=' * 50}\n")
f.write("BASE MODEL GENERATION\n")
f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
f.write(f"Model: {MODEL_NAME}\n")
f.write(f"Temperature: {temperature}\n")
f.write(f"{'=' * 50}\n\n")
responses = model.fast_generate(inputs, sampling_params=sampling_params)
gen_time = time.time() - start_time
gen_time = time.time() - start_time
logger.debug(f"Generation completed in {gen_time:.2f} seconds")
logger.debug(f"Generation completed in {gen_time:.2f} seconds")
@@ -346,13 +156,28 @@ def evaluate_model(
def verifier_generate_fn(inputs):
def verifier_generate_fn(inputs):
# Use a lower temperature for verification to get more consistent results
# Use a lower temperature for verification to get more consistent results
verifier_params = get_sampling_params(temperature=0.1)
verifier_params = get_sampling_params(temperature=0.1)
return model.fast_generate(inputs, sampling_params=verifier_params)
# Format inputs as chat messages with system prompt
messages = [
{
"messages": [
{"role": "system", "content": get_system_prompt()},
{"role": "user", "content": build_user_prompt(input_text)},
]
}
for input_text in inputs
]
return model.fast_generate(
[apply_chat_template(msg, tokenizer=tokenizer)["text"] for msg in messages],
sampling_params=verifier_params,
)
# Prepare the verification function
# Prepare the verification function
verify_fn = rl_helpers.build_reward_correctness_fn(verifier_generate_fn, tokenizer, log_file=eval_log_file)
verify_fn = build_reward_correctness_fn(verifier_generate_fn, tokenizer)
# Get the dataset and prepare questions and answers
# Get the dataset and prepare questions and answers
train_dataset, test_dataset = rl_helpers.get_qa_dataset()
train_dataset, test_dataset = get_qa_dataset()
questions = test_dataset["prompt"]
questions = test_dataset["prompt"]
inputs = questions
inputs = questions
@@ -360,9 +185,11 @@ def evaluate_model(
# Run the evaluation
# Run the evaluation
start_time = time.time()
start_time = time.time()
model_type = "LoRA" if lora_path else "Base"
logger.info(f"Starting {model_type} model evaluation...")
logger.info(f"Starting {model_type} model evaluation...")
full_chat_states = rl_helpers.run_eval(
# Run evaluation using the agentic approach
full_chat_states = run_eval(
generate_fn=eval_generate_fn,
generate_fn=eval_generate_fn,
verify_fn=verify_fn,
verify_fn=verify_fn,
tokenizer=tokenizer,
tokenizer=tokenizer,
@@ -422,28 +249,16 @@ def compare_models(lora_path, temperature=0.5, output_file=None, trainer_dir=Non
Compare base model with LoRA model.
Compare base model with LoRA model.
Args:
Args:
lora_path: Path to LoRA weights (use "auto" for auto-detection)
lora_path: Path to LoRA weights
temperature: Sampling temperature
temperature: Sampling temperature
output_file: File to write results to (optional, will be auto-generated if None)
output_file: File to write results to (optional)
trainer_dir: Directory containing the trainer output (parent of checkpoint directory)
trainer_dir: Directory containing the trainer output
"""
"""
# Auto-detect checkpoint if requested
# Set up output directory
if lora_path == "auto":
if trainer_dir:
search_dir = trainer_dir if trainer_dir else OUTPUT_DIR
eval_log_dir = os.path.join(trainer_dir, "eval_logs")
detected_checkpoint = find_latest_checkpoint(search_dir=search_dir)
else:
if detected_checkpoint:
eval_log_dir = "eval_logs"
lora_path = detected_checkpoint
logger.info(f"Auto-detected latest checkpoint: {lora_path}")
else:
logger.warning("No checkpoint found for auto-detection. Skipping comparison.")
return
# Set up output directory in the checkpoint directory
checkpoint_dir = os.path.dirname(lora_path)
if not trainer_dir:
trainer_dir = os.path.dirname(checkpoint_dir)
eval_log_dir = os.path.join(trainer_dir, "eval_logs")
os.makedirs(eval_log_dir, exist_ok=True)
os.makedirs(eval_log_dir, exist_ok=True)
# Define the comparison file path if not provided
# Define the comparison file path if not provided
@@ -456,11 +271,6 @@ def compare_models(lora_path, temperature=0.5, output_file=None, trainer_dir=Non
model, tokenizer = setup_model_and_tokenizer()
model, tokenizer = setup_model_and_tokenizer()
# Test if LoRA is working properly
lora_works = test_lora_functionality(model, tokenizer, lora_path)
if not lora_works:
logger.warning("LoRA adapter test failed. Results may not be reliable.")
# Evaluate both models
# Evaluate both models
base_results = evaluate_model(
base_results = evaluate_model(
model,
model,
@@ -527,66 +337,26 @@ if __name__ == "__main__":
parser.add_argument(
parser.add_argument(
"--lora_path",
"--lora_path",
type=str,
type=str,
default="auto",
default="trainer_output_example/checkpoint-101",
help="Path to LoRA weights (use 'auto' for auto-detection)",
help="Path to LoRA weights",
)
)
parser . add_argument ( " --temperature " , type = float , default = 0.5 , help = " Sampling temperature " )
parser . add_argument ( " --temperature " , type = float , default = 0.5 , help = " Sampling temperature " )
parser.add_argument(
parser.add_argument(
"--output_file",
"--output_file",
type=str,
type=str,
default=None,
default=None,
help="File to write results to (optional, will be auto-generated if None)",
help="File to write results to (optional)",
)
)
parser.add_argument(
parser.add_argument(
"--trainer_dir",
"--trainer_dir",
type=str,
type=str,
default=None,
default=None,
help="Directory containing the trainer output (parent of checkpoint directory)",
help="Directory containing the trainer output",
)
)
args = parser.parse_args()
args = parser.parse_args()
# Auto-detect checkpoint first to set up logging directory
checkpoint_dir = None
lora_path = args.lora_path
trainer_dir = args.trainer_dir
if trainer_dir:
if os.path.exists(trainer_dir):
logger.info(f"Using provided trainer directory: {trainer_dir}")
else:
logger.warning(f"Provided trainer directory does not exist: {trainer_dir}")
trainer_dir = None
if lora_path == "auto":
search_dir = trainer_dir if trainer_dir else OUTPUT_DIR
detected_checkpoint = find_latest_checkpoint(search_dir=search_dir)
if detected_checkpoint:
lora_path = detected_checkpoint
checkpoint_dir = os.path.dirname(lora_path)
if not trainer_dir:  # Only set if not provided
trainer_dir = os.path.dirname(checkpoint_dir)
# Set up logging in the trainer directory
eval_log_dir = os.path.join(trainer_dir, "eval_logs")
os.makedirs(eval_log_dir, exist_ok=True)
# If this is imported from config, use it here
try:
from src.config import update_log_path
update_log_path(eval_log_dir)
logger.info(f"Logs will be saved to both ./logs and {eval_log_dir}")
except ImportError:
logger.info("Config's update_log_path not available, using default logging")
if trainer_dir:
logger.info(f"Using trainer directory: {trainer_dir}")
logger.info(f"All evaluation files will be stored in: {os.path.join(trainer_dir, 'eval_logs')}")
else:
logger.warning("No trainer directory found, will attempt to determine during evaluation")
logger . info ( f " Starting model evaluation with temperature { args . temperature } " )
logger . info ( f " Starting model evaluation with temperature { args . temperature } " )
results = compare_models ( args . lora_path , args . temperature , args . output_file , trainer_dir = trainer_dir)
results = compare_models ( args . lora_path , args . temperature , args . output_file , trainer_dir = args . trainer_dir )
if results:
if results:
logger.info("Evaluation completed successfully")
logger.info("Evaluation completed successfully")
logger.info(f"Final improvement: {results['improvement']:.4f}")
logger.info(f"Final improvement: {results['improvement']:.4f}")
@@ -599,14 +369,12 @@ if __name__ == "__main__":
logger . info ( f " LoRA model results: { results [ ' lora_output ' ] } " )
logger . info ( f " LoRA model results: { results [ ' lora_output ' ] } " )
# Find and print all log files in the eval_logs directory
# Find and print all log files in the eval_logs directory
if trainer_dir :
eval_log_dir = os . path . join ( args . trainer_dir , " eval_logs " ) if args . trainer_dir else " eval_logs "
eval_log_dir = os . path . join ( trainer_dir , " eval_logs " )
if os . path . exists ( eval_log_dir ) :
if os . path . exists ( eval_log_dir ) :
log_files = [ f for f in os . listdir ( eval_log_dir ) if f . endswith ( " .log " ) ]
log_files = [ f for f in os . listdir ( eval_log_dir ) if f . endswith ( " .log " ) ]
if log_files :
logger . info ( " \n EVALUATION LOG FILES: " )
if log_files :
for log_file in log_files :
logger . info ( " \n EVALUATION LOG FILES: " )
logger . info ( f " - { os . path . join ( eval_log_dir , log_file ) } " )
for log_file in log_files :
logger . info ( f " - { os . path . join ( eval_log_dir , log_file ) } " )
else :
else :
logger . warning ( " Evaluation failed or was skipped " )
logger . warning ( " Evaluation failed or was skipped " )