feat: enhance evaluation scripts for base and LoRA models

main
thinhlpg 1 month ago
parent da60b52bd1
commit 6d994feeb2

@@ -1,7 +1,9 @@
 """Simple script to evaluate base model performance."""
 
 import argparse
+import os
 import sys
+from datetime import datetime
 from pathlib import Path
 
 # Add project root to Python path
@@ -15,18 +17,20 @@ from src import (
     apply_chat_template,
     build_reward_correctness_fn,
     build_user_prompt,
-    get_qa_dataset,
     get_system_prompt,
+    run_eval,
 )
+from src.config import logger
 
 
 def main():
     """Run base model evaluation."""
     parser = argparse.ArgumentParser(description="Evaluate base model")
     parser.add_argument("--model_name", type=str, required=True, help="Name/path of the model to evaluate")
+    parser.add_argument("--temperature", type=float, default=0, help="Sampling temperature")
     args = parser.parse_args()
 
-    print(f"🚀 Setting up model {args.model_name}...")
+    logger.info(f"🚀 Setting up model {args.model_name}...")
 
     # Setup model
     model, tokenizer = FastLanguageModel.from_pretrained(
@@ -38,7 +42,14 @@ def main():
 
     # Setup sampling params
     sampling_params = SamplingParams(
-        temperature=0.5,
+        temperature=args.temperature,
+        top_p=0.95,
+        max_tokens=4096,
+    )
+
+    # Setup verifier with lower temperature
+    verifier_params = SamplingParams(
+        temperature=0.1,  # Lower temperature for more consistent verification
         top_p=0.95,
         max_tokens=4096,
     )
@@ -59,40 +70,51 @@ def main():
             [apply_chat_template(msg, tokenizer=tokenizer)["text"] for msg in messages],
             sampling_params=sampling_params,
         )
+        return outputs
 
-        # Format outputs as chat messages
-        formatted_outputs = []
-        for output in outputs:
-            formatted_outputs.append(
-                {
-                    "messages": [
-                        {"role": "system", "content": get_system_prompt()},
-                        {"role": "assistant", "content": output.outputs[0].text},
-                    ]
-                }
-            )
-        return formatted_outputs
+    def verifier_generate_fn(inputs):
+        """Generate verification responses with lower temperature."""
+        messages = [
+            {
+                "messages": [
+                    {"role": "system", "content": get_system_prompt()},
+                    {"role": "user", "content": build_user_prompt(input_text)},
+                ]
+            }
+            for input_text in inputs
+        ]
 
-    # Get dataset
-    _, test_dataset = get_qa_dataset()
-    questions = test_dataset["prompt"]
-    answers = test_dataset["answer"]
-
-    print(f"📝 Evaluating {len(questions)} questions...")
+        return model.fast_generate(
+            [apply_chat_template(msg, tokenizer=tokenizer)["text"] for msg in messages],
+            sampling_params=verifier_params,
+        )
 
     # Build verifier
-    verify_fn = build_reward_correctness_fn(generate_fn, tokenizer)
+    verify_fn = build_reward_correctness_fn(verifier_generate_fn, tokenizer)
 
-    # Run evaluation
-    completions = generate_fn(questions)
-    rewards = verify_fn(questions, completions, answer=answers)
-    accuracy = sum(rewards) / len(rewards)
+    # Setup output directories
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    eval_log_dir = "eval_logs"
+    os.makedirs(eval_log_dir, exist_ok=True)
+    output_file = os.path.join(eval_log_dir, f"base_model_results_{timestamp}.txt")
+    debug_file = os.path.join(eval_log_dir, f"base_model_debug_{timestamp}.json")
 
-    print(f"\n{'=' * 50}")
-    print("🎯 BASE MODEL EVALUATION RESULTS:")
-    print(f"{'=' * 50}")
-    print(f"✨ Model: {args.model_name}")
-    print(f"📊 Accuracy: {accuracy:.4f} ({sum(rewards)}/{len(rewards)} correct)")
+    logger.info("📝 Starting evaluation...")
+    logger.info(f"Results will be saved to: {output_file}")
+    logger.info(f"Debug info will be saved to: {debug_file}")
+
+    # Run evaluation using the agentic approach
+    full_chat_states = run_eval(
+        generate_fn=generate_fn,
+        verify_fn=verify_fn,
+        tokenizer=tokenizer,
+        output_file=output_file,
+        debug_file=debug_file,
+    )
+
+    logger.info("✨ Evaluation completed!")
+    logger.info(f"Check {output_file} for detailed results")
 
 
 if __name__ == "__main__":
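Note on the new output layout: instead of printing an accuracy summary, the base-model script now hands generation and verification to run_eval and writes a timestamped results file plus a debug JSON under eval_logs/. A minimal sketch for picking up the latest pair afterwards, assuming only the directory and file-name patterns visible in the diff; the structure of the debug JSON is defined by run_eval and is not shown here, so the sketch only inspects the top level:

    # Hedged sketch: find the most recent results/debug pair written by the base-model script.
    import json
    from pathlib import Path

    log_dir = Path("eval_logs")  # directory created by the script above
    latest_results = max(log_dir.glob("base_model_results_*.txt"), key=lambda p: p.stat().st_mtime)
    latest_debug = max(log_dir.glob("base_model_debug_*.json"), key=lambda p: p.stat().st_mtime)

    print(latest_results.read_text()[:500])  # first part of the human-readable results
    with latest_debug.open() as f:
        debug = json.load(f)                  # schema depends on run_eval; only check the top-level type
    print(type(debug))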

@@ -1,7 +1,10 @@
-"""Simple script to evaluate LoRA model performance."""
+"""Script to evaluate LoRA model performance with enhanced debugging."""
 
 import argparse
+import json
+import os
 import sys
+from datetime import datetime
 from pathlib import Path
 
 # Add project root to Python path
@@ -15,9 +18,10 @@ from src import (
     apply_chat_template,
     build_reward_correctness_fn,
     build_user_prompt,
-    get_qa_dataset,
     get_system_prompt,
+    run_eval,
 )
+from src.config import logger
 
 
 def main():
@@ -25,40 +29,46 @@ def main():
     parser = argparse.ArgumentParser(description="Evaluate LoRA model")
     parser.add_argument("--model_name", type=str, required=True, help="Name/path of the base model")
    parser.add_argument("--lora_path", type=str, required=True, help="Path to LoRA weights")
+    parser.add_argument("--temperature", type=float, default=0, help="Sampling temperature")
     args = parser.parse_args()
 
-    print(f"🚀 Setting up model {args.model_name} with LoRA from {args.lora_path}...")
+    logger.info(f"🚀 Setting up model {args.model_name} with LoRA from {args.lora_path}...")
 
-    # Setup model with LoRA support
+    # Load LoRA config first to get max rank
+    with open(f"{args.lora_path}/adapter_config.json") as f:
+        lora_config = json.load(f)
+
+    # Setup model with LoRA support using config values
     model, tokenizer = FastLanguageModel.from_pretrained(
         model_name=args.model_name,
         max_seq_length=4096 * 2,
         load_in_4bit=True,
         fast_inference=True,
-        max_lora_rank=64,
+        max_lora_rank=lora_config["r"],  # Use rank from config
     )
 
-    # Setup LoRA
+    # Setup LoRA using config
     model = FastLanguageModel.get_peft_model(
         model,
-        r=64,
-        target_modules=[
-            "q_proj",
-            "k_proj",
-            "v_proj",
-            "o_proj",
-            "gate_proj",
-            "up_proj",
-            "down_proj",
-        ],
-        lora_alpha=64,
+        r=lora_config["r"],
+        target_modules=lora_config["target_modules"],
+        lora_alpha=lora_config["lora_alpha"],
+        lora_dropout=lora_config["lora_dropout"],
+        bias=lora_config["bias"],
         use_gradient_checkpointing=True,
         random_state=3407,
     )
 
     # Setup sampling params
     sampling_params = SamplingParams(
-        temperature=0.5,
+        temperature=args.temperature,
+        top_p=0.95,
+        max_tokens=4096,
+    )
+
+    # Setup verifier with lower temperature
+    verifier_params = SamplingParams(
+        temperature=0.1,  # Lower temperature for more consistent verification
         top_p=0.95,
         max_tokens=4096,
     )
@@ -81,41 +91,53 @@ def main():
             sampling_params=sampling_params,
             lora_request=lora_request,
         )
+        return outputs
 
-        # Format outputs as chat messages
-        formatted_outputs = []
-        for output in outputs:
-            formatted_outputs.append(
-                {
-                    "messages": [
-                        {"role": "system", "content": get_system_prompt()},
-                        {"role": "assistant", "content": output.outputs[0].text},
-                    ]
-                }
-            )
-        return formatted_outputs
+    def verifier_generate_fn(inputs):
+        """Generate verification responses with lower temperature."""
+        messages = [
+            {
+                "messages": [
+                    {"role": "system", "content": get_system_prompt()},
+                    {"role": "user", "content": build_user_prompt(input_text)},
+                ]
+            }
+            for input_text in inputs
+        ]
 
-    # Get dataset
-    _, test_dataset = get_qa_dataset()
-    questions = test_dataset["prompt"]
-    answers = test_dataset["answer"]
-
-    print(f"📝 Evaluating {len(questions)} questions...")
+        lora_request = model.load_lora(args.lora_path)
+        return model.fast_generate(
+            [apply_chat_template(msg, tokenizer=tokenizer)["text"] for msg in messages],
+            sampling_params=verifier_params,
+            lora_request=lora_request,
+        )
 
     # Build verifier
-    verify_fn = build_reward_correctness_fn(generate_fn, tokenizer)
+    verify_fn = build_reward_correctness_fn(verifier_generate_fn, tokenizer)
 
-    # Run evaluation
-    completions = generate_fn(questions)
-    rewards = verify_fn(questions, completions, answer=answers)
-    accuracy = sum(rewards) / len(rewards)
+    # Setup output directories
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    eval_log_dir = "eval_logs"
+    os.makedirs(eval_log_dir, exist_ok=True)
+    output_file = os.path.join(eval_log_dir, f"lora_model_results_{timestamp}.txt")
+    debug_file = os.path.join(eval_log_dir, f"lora_model_debug_{timestamp}.json")
 
-    print(f"\n{'=' * 50}")
-    print("🎯 LORA MODEL EVALUATION RESULTS:")
-    print(f"{'=' * 50}")
-    print(f"✨ Base Model: {args.model_name}")
-    print(f"🔧 LoRA Path: {args.lora_path}")
-    print(f"📊 Accuracy: {accuracy:.4f} ({sum(rewards)}/{len(rewards)} correct)")
+    logger.info("📝 Starting evaluation...")
+    logger.info(f"Results will be saved to: {output_file}")
+    logger.info(f"Debug info will be saved to: {debug_file}")
+
+    # Run evaluation using the agentic approach
+    full_chat_states = run_eval(
+        generate_fn=generate_fn,
+        verify_fn=verify_fn,
+        tokenizer=tokenizer,
+        output_file=output_file,
+        debug_file=debug_file,
+    )
+
+    logger.info("✨ Evaluation completed!")
+    logger.info(f"Check {output_file} for detailed results")
 
 
 if __name__ == "__main__":
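The LoRA script now derives its adapter hyperparameters from adapter_config.json instead of hard-coding rank 64 and an explicit target-module list, which keeps max_lora_rank in sync with whatever rank the adapter was actually trained with. A small sketch of the keys it reads; these are standard PEFT LoraConfig fields, and the path below is a placeholder:

    # Hedged sketch: inspect the adapter_config.json fields the script relies on.
    import json

    with open("outputs/my_lora/adapter_config.json") as f:  # placeholder path
        lora_config = json.load(f)

    for key in ("r", "lora_alpha", "lora_dropout", "bias", "target_modules"):
        print(key, "=", lora_config.get(key))  # "r" also becomes max_lora_rank when loading the base model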

@@ -167,7 +167,7 @@ def run_eval(generate_fn, verify_fn, tokenizer, output_file=None, debug_file=None):
 
     # Create agent with appropriate adapter based on tokenizer
     tokenizer_name = tokenizer.name_or_path.lower()
-    if "deepseek-r1-distill" in tokenizer_name:
+    if "deepseek-ai/deepseek-r1-distill" in tokenizer_name:
         adapter = R1DistilTokenizerAdapter()
     elif "llama" in tokenizer_name:
         adapter = LlamaTokenizerAdapter()
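The adapter-selection change tightens the substring match from "deepseek-r1-distill" to the repo-id form "deepseek-ai/deepseek-r1-distill". Since the check runs on tokenizer.name_or_path.lower(), a Hugging Face repo id still matches, while a name without the deepseek-ai/ prefix (for example a local checkpoint directory) would now fall through to the other adapters; whether that case matters here is not visible from this diff. A tiny illustration, where the second path is hypothetical:

    # Hedged sketch: how the stricter check behaves for two example name_or_path values.
    names = [
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",   # Hugging Face repo id -> matches
        "./ckpts/deepseek-r1-distill-qwen-1.5b",       # hypothetical local path -> no longer matches
    ]
    for name in names:
        print(name, "->", "deepseek-ai/deepseek-r1-distill" in name.lower())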
