@@ -1,7 +1,10 @@
-"""Simple script to evaluate LoRA model performance."""
+"""Script to evaluate LoRA model performance with enhanced debugging."""
 
 import argparse
+import json
+import os
 import sys
+from datetime import datetime
 from pathlib import Path
 
 # Add project root to Python path
@@ -15,9 +18,10 @@ from src import (
     apply_chat_template,
     build_reward_correctness_fn,
     build_user_prompt,
-    get_qa_dataset,
     get_system_prompt,
+    run_eval,
 )
+from src.config import logger
 
 
 def main():
@@ -25,40 +29,46 @@ def main():
     parser = argparse.ArgumentParser(description="Evaluate LoRA model")
     parser.add_argument("--model_name", type=str, required=True, help="Name/path of the base model")
     parser.add_argument("--lora_path", type=str, required=True, help="Path to LoRA weights")
+    parser.add_argument("--temperature", type=float, default=0, help="Sampling temperature")
     args = parser.parse_args()
 
-    print(f"🚀 Setting up model {args.model_name} with LoRA from {args.lora_path}...")
+    logger.info(f"🚀 Setting up model {args.model_name} with LoRA from {args.lora_path}...")
 
-    # Setup model with LoRA support
+    # Load LoRA config first to get max rank
+    with open(f"{args.lora_path}/adapter_config.json") as f:
+        lora_config = json.load(f)
+
+    # Setup model with LoRA support using config values
     model, tokenizer = FastLanguageModel.from_pretrained(
         model_name=args.model_name,
         max_seq_length=4096 * 2,
         load_in_4bit=True,
         fast_inference=True,
-        max_lora_rank=64,
+        max_lora_rank=lora_config["r"],  # Use rank from config
     )
 
-    # Setup LoRA
+    # Setup LoRA using config
     model = FastLanguageModel.get_peft_model(
         model,
-        r=64,
-        target_modules=[
-            "q_proj",
-            "k_proj",
-            "v_proj",
-            "o_proj",
-            "gate_proj",
-            "up_proj",
-            "down_proj",
-        ],
-        lora_alpha=64,
+        r=lora_config["r"],
+        target_modules=lora_config["target_modules"],
+        lora_alpha=lora_config["lora_alpha"],
+        lora_dropout=lora_config["lora_dropout"],
+        bias=lora_config["bias"],
         use_gradient_checkpointing=True,
         random_state=3407,
     )
 
     # Setup sampling params
     sampling_params = SamplingParams(
-        temperature=0.5,
+        temperature=args.temperature,
+        top_p=0.95,
+        max_tokens=4096,
+    )
+
+    # Setup verifier with lower temperature
+    verifier_params = SamplingParams(
+        temperature=0.1,  # Lower temperature for more consistent verification
         top_p=0.95,
         max_tokens=4096,
     )
@@ -81,41 +91,53 @@ def main():
             sampling_params=sampling_params,
             lora_request=lora_request,
         )
+        return outputs
 
-        # Format outputs as chat messages
-        formatted_outputs = []
-        for output in outputs:
-            formatted_outputs.append(
-                {
-                    "messages": [
-                        {"role": "system", "content": get_system_prompt()},
-                        {"role": "assistant", "content": output.outputs[0].text},
-                    ]
-                }
-            )
-        return formatted_outputs
+    def verifier_generate_fn(inputs):
+        """Generate verification responses with lower temperature."""
+        messages = [
+            {
+                "messages": [
+                    {"role": "system", "content": get_system_prompt()},
+                    {"role": "user", "content": build_user_prompt(input_text)},
+                ]
+            }
+            for input_text in inputs
+        ]
 
-    # Get dataset
-    _, test_dataset = get_qa_dataset()
-    questions = test_dataset["prompt"]
-    answers = test_dataset["answer"]
-
-    print(f"📝 Evaluating {len(questions)} questions...")
+        lora_request = model.load_lora(args.lora_path)
+        return model.fast_generate(
+            [apply_chat_template(msg, tokenizer=tokenizer)["text"] for msg in messages],
+            sampling_params=verifier_params,
+            lora_request=lora_request,
+        )
 
     # Build verifier
-    verify_fn = build_reward_correctness_fn(generate_fn, tokenizer)
+    verify_fn = build_reward_correctness_fn(verifier_generate_fn, tokenizer)
 
-    # Run evaluation
-    completions = generate_fn(questions)
-    rewards = verify_fn(questions, completions, answer=answers)
-    accuracy = sum(rewards) / len(rewards)
+    # Setup output directories
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    eval_log_dir = "eval_logs"
+    os.makedirs(eval_log_dir, exist_ok=True)
 
-    print(f"\n{'=' * 50}")
-    print("🎯 LORA MODEL EVALUATION RESULTS:")
-    print(f"{'=' * 50}")
-    print(f"✨ Base Model: {args.model_name}")
-    print(f"🔧 LoRA Path: {args.lora_path}")
-    print(f"📊 Accuracy: {accuracy:.4f} ({sum(rewards)}/{len(rewards)} correct)")
+    output_file = os.path.join(eval_log_dir, f"lora_model_results_{timestamp}.txt")
+    debug_file = os.path.join(eval_log_dir, f"lora_model_debug_{timestamp}.json")
+
+    logger.info("📝 Starting evaluation...")
+    logger.info(f"Results will be saved to: {output_file}")
+    logger.info(f"Debug info will be saved to: {debug_file}")
+
+    # Run evaluation using the agentic approach
+    full_chat_states = run_eval(
+        generate_fn=generate_fn,
+        verify_fn=verify_fn,
+        tokenizer=tokenizer,
+        output_file=output_file,
+        debug_file=debug_file,
+    )
+
+    logger.info("✨ Evaluation completed!")
+    logger.info(f"Check {output_file} for detailed results")
 
 
 if __name__ == "__main__":
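
For reference, a minimal sketch of the adapter_config.json fields the updated script now reads instead of hard-coding rank 64. The path and printed values below are illustrative only; the keys themselves are the ones the patch accesses (r, lora_alpha, lora_dropout, bias, target_modules), which PEFT/Unsloth write when an adapter is saved.

    import json

    # Hypothetical adapter directory; substitute whatever is passed via --lora_path.
    with open("outputs/lora_adapter/adapter_config.json") as f:
        lora_config = json.load(f)

    # Inspect the fields the evaluation script forwards to get_peft_model().
    for key in ("r", "lora_alpha", "lora_dropout", "bias", "target_modules"):
        print(f"{key} = {lora_config[key]}")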