"""Script to evaluate LoRA model performance with enhanced debugging."""

import argparse
import json
import os
import sys
from datetime import datetime
from pathlib import Path

# Add the directory two levels above this file to the import search path.
# This has to run before the `config` and `src` imports below
# (presumably those packages live in the project root -- confirm against the
# repository layout).
project_root = str(Path(__file__).parent.parent)
sys.path.append(project_root)

from unsloth import FastLanguageModel
from vllm import SamplingParams

from config import logger
from src import (
    apply_chat_template,
    build_reward_correctness_fn,
    build_user_prompt,
    get_system_prompt,
    run_eval,
)


def main():
    """Run LoRA model evaluation.

    Parses ``--model_name``, ``--lora_path`` and ``--temperature`` from the
    command line, reads ``adapter_config.json`` from the LoRA directory to
    configure the adapter, and calls ``run_eval`` with a generation function
    and a verifier that both run on the same model. The result and debug
    file paths are timestamped and placed under ``eval_logs/``.

    Raises:
        FileNotFoundError: If ``adapter_config.json`` does not exist under
            ``--lora_path``.
        KeyError: If the adapter config lacks one of the keys read here
            (``r``, ``target_modules``, ``lora_alpha``, ``lora_dropout``,
            ``bias``).
    """
    parser = argparse.ArgumentParser(description="Evaluate LoRA model")
    parser.add_argument("--model_name", type=str, required=True, help="Name/path of the base model")
    parser.add_argument("--lora_path", type=str, required=True, help="Path to LoRA weights")
    parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature")
    args = parser.parse_args()

    logger.info(f"🚀 Setting up model {args.model_name} with LoRA from {args.lora_path}...")

    # One token budget shared by the model context length and every
    # generation limit below, so the values cannot drift apart.
    token_budget = 4096 * 6

    # Load the LoRA config first: its rank is needed when creating the model.
    # JSON is read explicitly as UTF-8 rather than the locale's encoding.
    adapter_config_path = Path(args.lora_path) / "adapter_config.json"
    with open(adapter_config_path, encoding="utf-8") as f:
        lora_config = json.load(f)

    # Setup model with LoRA support using config values
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model_name,
        max_seq_length=token_budget,
        load_in_4bit=True,
        fast_inference=True,
        max_lora_rank=lora_config["r"],  # Use rank from config
        gpu_memory_utilization=0.8,
    )

    # Attach a LoRA adapter whose hyperparameters mirror the saved config.
    model = FastLanguageModel.get_peft_model(
        model,
        r=lora_config["r"],
        target_modules=lora_config["target_modules"],
        lora_alpha=lora_config["lora_alpha"],
        lora_dropout=lora_config["lora_dropout"],
        bias=lora_config["bias"],
        use_gradient_checkpointing=True,
        random_state=3407,
    )

    # Sampling params for the evaluated generations (user-controlled temperature).
    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=0.95,
        max_tokens=token_budget,
    )

    # The verifier always samples at a fixed low temperature (0.1), regardless
    # of --temperature, for more consistent verification.
    verifier_params = SamplingParams(
        temperature=0.1,
        top_p=0.95,
        max_tokens=token_budget,
    )

    def make_generate_fn(params):
        """Return a function that generates responses using ``params``."""

        def _generate(inputs):
            """Wrap each input in the chat format and generate with the LoRA."""
            messages = [
                {
                    "messages": [
                        {"role": "system", "content": get_system_prompt()},
                        {"role": "user", "content": build_user_prompt(input_text)},
                    ]
                }
                for input_text in inputs
            ]

            # The adapter is requested on every call, as before.
            # NOTE(review): this could possibly be loaded once up front --
            # confirm `load_lora` semantics before hoisting.
            lora_request = model.load_lora(args.lora_path)
            prompts = [apply_chat_template(msg, tokenizer=tokenizer)["text"] for msg in messages]
            return model.fast_generate(
                prompts,
                sampling_params=params,
                lora_request=lora_request,
            )

        return _generate

    generate_fn = make_generate_fn(sampling_params)
    verifier_generate_fn = make_generate_fn(verifier_params)

    # Build verifier
    verify_fn = build_reward_correctness_fn(verifier_generate_fn, tokenizer)

    # Setup output directories
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    eval_log_dir = "eval_logs"
    os.makedirs(eval_log_dir, exist_ok=True)

    output_file = os.path.join(eval_log_dir, f"lora_model_results_{timestamp}.txt")
    debug_file = os.path.join(eval_log_dir, f"lora_model_debug_{timestamp}.json")

    logger.info("📝 Starting evaluation...")
    logger.info(f"Results will be saved to: {output_file}")
    logger.info(f"Debug info will be saved to: {debug_file}")

    # Run the evaluation. The return value is not used here; the output paths
    # are passed in via `output_file` and `debug_file`.
    run_eval(
        generate_fn=generate_fn,
        verify_fn=verify_fn,
        tokenizer=tokenizer,
        output_file=output_file,
        debug_file=debug_file,
        max_generations=32,
        max_new_tokens=token_budget,
    )

    logger.info("✨ Evaluation completed!")
    logger.info(f"Check {output_file} for detailed results")


# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()