{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load finetuned model ans say hello"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ✅ Utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from unsloth import FastLanguageModel\n",
    "from vllm import SamplingParams\n",
    "\n",
    "\n",
    "def load_model(\n",
    "    # model_name=\"meta-llama/Llama-3.2-1B-Instruct\",\n",
    "    model_name=\"meta-llama/meta-Llama-3.1-8B-Instruct\",\n",
    "    lora_path=\"../trainer_output_meta-llama_Llama-3.1-8B-Instruct_gpu1_20250326_134236/checkpoint-101\",\n",
    "    max_seq_length=8192,\n",
    "):\n",
    "    \"\"\"Load model and tokenizer with optional LoRA weights.\"\"\"\n",
    "    # Load base model\n",
    "    model, tokenizer = FastLanguageModel.from_pretrained(\n",
    "        model_name=model_name,\n",
    "        max_seq_length=max_seq_length,\n",
    "        load_in_4bit=True,\n",
    "        fast_inference=True,\n",
    "        max_lora_rank=64,\n",
    "        gpu_memory_utilization=0.6,\n",
    "    )\n",
    "\n",
    "    # Setup LoRA if path provided\n",
    "    if lora_path:\n",
    "        model = FastLanguageModel.get_peft_model(\n",
    "            model,\n",
    "            r=64,\n",
    "            target_modules=[\n",
    "                \"q_proj\",\n",
    "                \"k_proj\",\n",
    "                \"v_proj\",\n",
    "                \"o_proj\",\n",
    "                \"gate_proj\",\n",
    "                \"up_proj\",\n",
    "                \"down_proj\",\n",
    "            ],\n",
    "            lora_alpha=64,\n",
    "            use_gradient_checkpointing=True,\n",
    "            random_state=3407,\n",
    "        )\n",
    "        model.load_lora(lora_path)\n",
    "\n",
    "    return model, tokenizer\n",
    "\n",
    "\n",
    "def get_sampling_params(\n",
    "    temperature=0.7,\n",
    "    top_p=0.95,\n",
    "    max_tokens=4096,\n",
    "):\n",
    "    \"\"\"Get sampling parameters for text generation.\"\"\"\n",
    "    return SamplingParams(\n",
    "        temperature=temperature,\n",
    "        top_p=top_p,\n",
    "        max_tokens=max_tokens,\n",
    "    )\n",
    "\n",
    "\n",
    "def generate_response(prompt, model, tokenizer, sampling_params):\n",
    "    \"\"\"Generate a response from the model.\"\"\"\n",
    "    inputs = tokenizer.apply_chat_template(\n",
    "        [{\"role\": \"user\", \"content\": prompt}],\n",
    "        tokenize=False,\n",
    "        add_generation_prompt=True,\n",
    "    )\n",
    "\n",
    "    outputs = model.fast_generate([inputs], sampling_params=sampling_params)\n",
    "\n",
    "    if hasattr(outputs[0], \"outputs\"):\n",
    "        response_text = outputs[0].outputs[0].text\n",
    "    else:\n",
    "        response_text = outputs[0]\n",
    "\n",
    "    return response_text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model, tokenizer = load_model()  # Using default hardcoded path\n",
    "sampling_params = get_sampling_params()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "response = generate_response(\"Hi! How are you?\", model, tokenizer, sampling_params)\n",
    "print(\"\\nModel response:\")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Merge lora and save model to 16 bit, then test load inference\n",
    "if False:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save_pretrained_merged(\n",
    "    \"model\",\n",
    "    tokenizer,\n",
    "    save_method=\"merged_16bit\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ✅ Test load merged model 16bit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test load merged model\n",
    "model, tokenizer = load_model(model_name=\"model\", lora_path=None)\n",
    "sampling_params = get_sampling_params()\n",
    "response = generate_response(\"Hi! How are you?\", model, tokenizer, sampling_params)\n",
    "print(\"\\nModel response:\")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ❌ Save model to Ollama format\n",
    "- bug no lllama-quantize on wsl, fix later\n",
    "https://docs.unsloth.ai/basics/runninand-saving-models/saving-to-gguf\n",
    "```bash\n",
    "git clone https://github.com/ggml-org/llama.cpp\n",
    "cd llama.cpp\n",
    "cmake -B build\n",
    "cmake --build build --config Release\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save to 8bit Q8_0\n",
    "if True:\n",
    "    model.save_pretrained_gguf(\n",
    "        \"model-gguf\",\n",
    "        tokenizer,\n",
    "    )\n",
    "# Remember to go to https://huggingface.co/settings/tokens for a token!\n",
    "# And change hf to your username!\n",
    "if False:\n",
    "    model.push_to_hub_gguf(\"hf/model\", tokenizer, token=\"\")\n",
    "\n",
    "# Save to 16bit GGUF\n",
    "if False:\n",
    "    model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"f16\")\n",
    "if False:\n",
    "    model.push_to_hub_gguf(\"hf/model\", tokenizer, quantization_method=\"f16\", token=\"\")\n",
    "\n",
    "# Save to q4_k_m GGUF\n",
    "if False:\n",
    "    model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"q4_k_m\")\n",
    "if False:\n",
    "    model.push_to_hub_gguf(\n",
    "        \"hf/model\", tokenizer, quantization_method=\"q4_k_m\", token=\"\"\n",
    "    )\n",
    "\n",
    "# Save to multiple GGUF options - much faster if you want multiple!\n",
    "if False:\n",
    "    model.push_to_hub_gguf(\n",
    "        \"hf/model\",  # Change hf to your username!\n",
    "        tokenizer,\n",
    "        quantization_method=[\n",
    "            \"q4_k_m\",\n",
    "            \"q8_0\",\n",
    "            \"q5_k_m\",\n",
    "        ],\n",
    "        token=\"\",\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "deepsearch-py311",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}