{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wAwm-SoKNjoX"
   },
   "source": [
    "# Sample GRPO Training Notebook\n",
    "- Stolen from unsloth :3 thanks!\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rv8V30lINjob"
   },
   "source": [
    "### Installation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "1iVpHRPXNjoc"
   },
   "outputs": [],
   "source": [
    "# @title Colab Extra Install { display-mode: \"form\" }\n",
    "import os\n",
    "\n",
    "if \"COLAB_\" not in \"\".join(os.environ.keys()):\n",
    "    !pip install -q unsloth vllm\n",
    "else:\n",
    "    !pip install -q --no-deps unsloth vllm\n",
    "    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]\n",
    "    # Skip restarting message in Colab\n",
    "    import sys\n",
    "    import re\n",
    "    import requests\n",
    "\n",
    "    modules = list(sys.modules.keys())\n",
    "    for x in modules:\n",
    "        sys.modules.pop(x) if \"PIL\" in x or \"google\" in x else None\n",
    "    !pip install -q --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo\n",
    "    !pip install -q sentencepiece protobuf datasets huggingface_hub hf_transfer\n",
    "\n",
    "    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy\n",
    "    f = requests.get(\n",
    "        \"https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt\"\n",
    "    ).content\n",
    "    with open(\"vllm_requirements.txt\", \"wb\") as file:\n",
    "        file.write(re.sub(rb\"(transformers|numpy|xformers)[^\\n]{1,}\\n\", b\"\", f))\n",
    "    !pip install -q -r vllm_requirements.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HnqT9aF5Njoc"
   },
   "source": [
    "### Unsloth"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "IpAO-idyNjoc"
   },
   "source": [
    "Load up `Llama 3.1 8B Instruct`, and set parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000,
     "referenced_widgets": [
      "09fca7e3b90f4b17a2344b67100e26bf",
      "6eee6e73473440a39c2a14d11b19e7b9",
      "f64c0803b3c34610a34f9c5f2b5ce6ba",
      "4bae66db2fc041278f13ae1def5d1ddd",
      "ce588ed5880240d7bfd91c150ef299ed",
      "bc92bf99bcf345b59699eb8f0423f7cd",
      "f589e115f63d4fdb95816dccb34c3e3e",
      "fb45312477f64384912ba3e843efb0b3",
      "4e541334eb26485ca5764b9bf0bbbd34",
      "fe9dbcf4e85646cb8c684d82252b93e2",
      "d36f0dda703c47cfae6f544e7e655806",
      "cc9a242d278d4ea1b0f0f44d611439a2",
      "9ec23dd10394466eb6db5336912f09de",
      "e72af3cf2aac44d3a10b692b5fd54afe",
      "7a0041d1e5c349b4b70e96554ea18113",
      "f3fe4eb2a2fc4b2fb255632272dc0415",
      "b095d3eac38847c0b087ffcc5876e0b6",
      "9ebd8c869bc144cf8142375e7d7f04dd",
      "48403d21996b4ebfa931cf260af6c283",
      "b019cf9965424856b3fd781bd9e73118",
      "ff032a23470845ba95a12227baf0e0ce",
      "18d2fd154ffc4671b76f00760eb35571",
      "bc8f8134f144435681cbc50f9f5e5274",
      "d7a1980bf7cc4895aea14df80bf92302",
      "6be20c9739de44ffb7ef2465ad17424d",
      "39411144c4c44eae935262bc6312d5b9",
      "32f3e3f08eb1441ca70191eaa6191805",
      "2df204a0982b466eb3c25bf80645b6c1",
      "9ec592e894b945cea7c2a756a96eb959",
      "7ff47ae65e944dc9bbd50f01edb7fa02",
      "2f5bc26300174430aa5b0610b5ef4c72",
      "79da5ea5195c4c63b0ec1ee354511a87",
      "a5f3704359df48e4b4c5b418a5f72021",
      "2cae8f5f92304278bcfa4a07fb5aafdc",
      "63681e5808d1443c9f04654d58952448",
      "91200f911ca64162b006235d0ec0c5f1",
      "40c4dfad109643529a16b3de307e2e1a",
      "491466071e6742c6ab30419597d78fde",
      "a67ca6e6c91244978b6b468257ffefd7",
      "781f0652012b424aa43cd39028e3a99a",
      "9e4ae8a9e1b247ffbd232d97be270cd8",
      "c558b6c8e51046c6b3e0809ccf749348",
      "0ea62e14a2d44b8ca7c6f15bca8e1a2b",
      "27e78ef30ff64be5af79b6431816f569",
      "b973da5783dc49d1802831298be3ab7e",
      "d7e3f2c2441743089ee7ed10eb0319dd",
      "7f490169dd8a4dbeb367af5259b0949e",
      "2a8b54de6a974cc3a3f0b47daa10af89",
      "70780d08fb00445da32b28b40d33412f",
      "2c83de81fc91497db859d1be6860d6f4",
      "397a49fc365b4c3a95db709b71fc32e3",
      "1e75c59df9524dec928b21d65dd43c08",
      "747b264f119746c08bbd6685c134ec4f",
      "364d9414fd5f4280be4ad621874e15cb",
      "7b3ad76804bd47a68fba0bf9f967cb04",
      "2c34cb165ce447dca52c25eff7cc91ce",
      "e8f729e29f7c4639a10a02a9f4978385",
      "f60723e017f64a2183af5f05ca827340",
      "165ed67ab04240fca77e70e6a905ad48",
      "e229f6a91dcf4b5182673cc6c7002063",
      "fe4ecfe3522c432b964608f137cfb936",
      "e1e8c601a2bb4c0b8f0cdf3bf0dd0b2e",
      "db89c098dd414fa08440071935414b63",
      "efb3aff9a97340fb8acf5f98402e7c1d",
      "d4eeb17c3a124c919058c0ffe72a60b6",
      "8b6bd37f64454464b44691ae3bef0b9f",
      "239c67254f4044259fcbb35cb78d72b3",
      "7bdc3ab924804e6badd7f403d0ce33f1",
      "2b1ff4602dee4511a450588c501a7174",
      "bfae7687d2a440189c29dc53e47b71ab",
      "c72dbf9684ed4cf0a55655ebaf645b96",
      "7083c21ed3dd4ed6b442afdf4a18a05d",
      "6e4ce08c86f24beeb63d38d436143dc9",
      "86498c04c2a347968e65867eb4ed2102",
      "dc90aea1a6e445ebaf68bb305586bf47",
      "4c69d53cf33c4865af8feead04b810ad",
      "9dbc67661c614d8a82d0790d9b694e6d",
      "97f09dcebaf94be3b115a69b1440392b",
      "a568aff772814128b0ad329594e67610",
      "92ac502bff6e48d98cbb07a65d84dca6",
      "0b9dd1c421f14c56986cf30abe0a9455",
      "50a3cb7fc0ce45bb8c1baabc2a54fff2",
      "fac3506146244766b886c715ea57ec7b",
      "d068e96c067e43ffa1777a666b0df147",
      "9e9b1a12028443c9abe6d88d0ed39d15",
      "638c59bf00ba4536ad157975faa26984",
      "73177ee1d64c4f56aa393c45c42e7499",
      "98aabe0e4167414d943ff3d7759093ad",
      "0f040a5765854b699b32bf096bb0563c",
      "813e8e40df6447c18e47414ab38f69a2",
      "533f84f90e7247e8a60da8b6ce60fe89",
      "80c5f4171f2844749b5f3e67a46178b9",
      "0722b569fa2e45f9af7d4fb8134e632b",
      "115a8f3c523c4cd3aaa472af87fe0e4a",
      "26cde13f97db4ae4a0900fb5184a51a2",
      "5c4e90490762447c9712372a18f48e20",
      "690c64fbee9f4c038a29030086c73368",
      "cbe9c4e2f8bc49dfb1d8a4cddc36d724",
      "f33aef89be7641a694fb11e3101c7d85",
      "6798efcf0f784fca91a2e84c91553296",
      "190504fd7fec4f08ba43b972ed28d23b",
      "142703403e5e4ba1ac1531d7c82c752a",
      "0a45b8dd65764272869fd1424415e6af",
      "d536479db8034c99be7e745a440baf7d",
      "594a583be494419c845fd873904d4d6d",
      "ba21e0c745714031aae12b1665847e54",
      "3491a152831245f7b03d06fe94f42116",
      "b80d4e38b5c54dd585d1da848ee9c4c8",
      "fa155f662d1c40f496daef290ff7a5ff",
      "9b2a88f9712d409bbf74bf0c2e14df16"
     ]
    },
    "id": "DkIvEkIIkEyB",
    "outputId": "e1aaf3b5-6ad0-45c7-be08-2017e74ebba9"
   },
   "outputs": [],
   "source": [
    "from unsloth import FastLanguageModel\n",
    "\n",
    "max_seq_length = 1024  # Can increase for longer reasoning traces\n",
    "lora_rank = 32  # Larger rank = smarter, but slower\n",
    "\n",
    "model, tokenizer = FastLanguageModel.from_pretrained(\n",
    "    model_name=\"meta-llama/meta-Llama-3.1-8B-Instruct\",\n",
    "    max_seq_length=max_seq_length,\n",
    "    load_in_4bit=True,  # False for LoRA 16bit\n",
    "    fast_inference=True,  # Enable vLLM fast inference\n",
    "    max_lora_rank=lora_rank,\n",
    "    gpu_memory_utilization=0.6,  # Reduce if out of memory\n",
    ")\n",
    "\n",
    "model = FastLanguageModel.get_peft_model(\n",
    "    model,\n",
    "    r=lora_rank,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n",
    "    target_modules=[\n",
    "        \"q_proj\",\n",
    "        \"k_proj\",\n",
    "        \"v_proj\",\n",
    "        \"o_proj\",\n",
    "        \"gate_proj\",\n",
    "        \"up_proj\",\n",
    "        \"down_proj\",\n",
    "    ],  # Remove QKVO if out of memory\n",
    "    lora_alpha=lora_rank,\n",
    "    use_gradient_checkpointing=\"unsloth\",  # Enable long context finetuning\n",
    "    random_state=3407,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7KGgPgk_5S8r"
   },
   "source": [
    "### Data Prep\n",
    "<a name=\"Data\"></a>\n",
    "\n",
    "We directly leverage [@willccbb](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb) for data prep and all reward functions. You are free to create your own!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 321,
     "referenced_widgets": [
      "5954f25ab3334f8c8e0576ee75922c0c",
      "5db92327420a48f4a1ad966ecd7ac8c5",
      "1a1651d5ca7d4ccd887235343e49ffb3",
      "22462b8181964257a53c85be6265eef6",
      "1c09f4dc86124b1296ed96fc48b8d856",
      "5e9d2ad532a8447a904688629a91f171",
      "2df38622a5ab46979ee53efeb8c31b91",
      "a75b75f7dec54759816845c3fda31f22",
      "b14a63b4cac142c5adaaebbfef8739fe",
      "1a090b75a50b459a8d2f822efc692dfa",
      "0da6c9fa4eeb4b4d832f8bbc614bc8d9",
      "3e1954a710bd43d1907b702c7f70c6de",
      "483db1483d1d4a97a59c1e6cdad04c6a",
      "cf88de7717504af691ca8db0fb68a6b6",
      "d4de07c4656c40329bc2c1450bd94bb5",
      "160f401dc9a74266bb610e747f8f711e",
      "269f762b5af041eab750226febe70e07",
      "f38c687b78224c8aa2227914998b4627",
      "eca5ebef74f341f79dc96964d0189862",
      "009ea1226a194cb2923ec9bc3a6e1843",
      "2aade13ff2ca4250ae02b7a12b1ef16d",
      "23531535f2fc4b99a5ecd1b63fb93d80",
      "a137a1bc2b4f437a8477e079f003ad8d",
      "05836f874c9a49d390721a05ff055ac9",
      "9032130a4a9646db92805eeb4221c83d",
      "ac58943ffd184cdca0b69cce35f54ec7",
      "8dc3c53fac5644a6a2841aca62226300",
      "fda3ed1250d34b12aae8401287322548",
      "e85f549d21dd4c3293f8239a053b0ad4",
      "a8d0ef0ef4f747e294394b04d07cbe34",
      "15b83f9309a642d2bb1bc79da539b923",
      "436282df7bcc45c4b951de60f2e4ec53",
      "911d088df0894243b69fd4f69c46e7f3",
      "eba35bc404f640bcb087de7a7ce74942",
      "6b39fce0206f4ca894b943a42eb9dc04",
      "c3ec8a75007846519d40f6726d2a01ee",
      "de3486da03594b0091b77cbf702401a5",
      "2e1b5b6b3ae84190a7f50993d78fb6ea",
      "81f43f522faa4dcf9e557793a2027e90",
      "ae46c8f17c0644fdbd1aa6d412c322b7",
      "5de480828775465bb76b674ecd713bd0",
      "a159d55b4de84020bc127d6c7758d2a9",
      "cc6097b25ec74decb51a397ccd9fcf27",
      "d2af3c19e6ec46a3915f6badf7062b7b",
      "8d2463ee53394dc097843bc733191c97",
      "f681a0ff692c4cd3b14b1e064547c52f",
      "0ecef6aade754029ba939e30c91538be",
      "87e25abbd016411199a2e976cd7ab550",
      "93f708b722194750a9f66a2597987033",
      "76dc9524ee764741a1615de6bc43492e",
      "78201e7953b647ee9656fc7e8f4abbcb",
      "0bc45e643a9b4a0e8db9f85180c108b2",
      "cfb5247cf65c4bdfbcf8207c5d8c72c8",
      "32fe42f8a11d4131a3a529b554f028fe",
      "39c33d513fb74712be8fcf73093a0c9e",
      "908333eedde74ec3a0885e52454b3627",
      "934617e655d04f85a87d6444b7c1ff0e",
      "63d0c36c4bb8469f8cf5ac410d2c11d6",
      "092d58ba2c444878a1cbfd19bdc14027",
      "c8ecf533f45f4ae29bdcfdaeca7daedf",
      "d89eae07da0e4344909900e24e3d0d09",
      "b1e00fc06b664fc59d4921388f187269",
      "b834a1ed2fc14d209c405969def96965",
      "dc681483ceeb4e878124858ffb467338",
      "ef2a701c9c594d4d9ffc8379fa9b5899",
      "22c31da7c0564f09bbdb04efdefad11a"
     ]
    },
    "id": "cXk993X6C2ZZ",
    "outputId": "d6b161df-4d15-4cae-a02c-22fb901b91f1"
   },
   "outputs": [],
   "source": [
    "import re\n",
    "from datasets import load_dataset, Dataset\n",
    "\n",
    "# Load and prep dataset\n",
    "SYSTEM_PROMPT = \"\"\"\n",
    "Respond in the following format:\n",
    "<reasoning>\n",
    "...\n",
    "</reasoning>\n",
    "<answer>\n",
    "...\n",
    "</answer>\n",
    "\"\"\"\n",
    "\n",
    "XML_COT_FORMAT = \"\"\"\\\n",
    "<reasoning>\n",
    "{reasoning}\n",
    "</reasoning>\n",
    "<answer>\n",
    "{answer}\n",
    "</answer>\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "def extract_xml_answer(text: str) -> str:\n",
    "    answer = text.split(\"<answer>\")[-1]\n",
    "    answer = answer.split(\"</answer>\")[0]\n",
    "    return answer.strip()\n",
    "\n",
    "\n",
    "def extract_hash_answer(text: str) -> str | None:\n",
    "    if \"####\" not in text:\n",
    "        return None\n",
    "    return text.split(\"####\")[1].strip()\n",
    "\n",
    "\n",
    "# uncomment middle messages for 1-shot prompting\n",
    "def get_gsm8k_questions(split=\"train\") -> Dataset:\n",
    "    data = load_dataset(\"openai/gsm8k\", \"main\")[split]  # type: ignore\n",
    "    data = data.map(\n",
    "        lambda x: {  # type: ignore\n",
    "            \"prompt\": [\n",
    "                {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
    "                {\"role\": \"user\", \"content\": x[\"question\"]},\n",
    "            ],\n",
    "            \"answer\": extract_hash_answer(x[\"answer\"]),\n",
    "        }\n",
    "    )  # type: ignore\n",
    "    return data  # type: ignore\n",
    "\n",
    "\n",
    "dataset = get_gsm8k_questions()\n",
    "\n",
    "\n",
    "# Reward functions\n",
    "def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:\n",
    "    responses = [completion[0][\"content\"] for completion in completions]\n",
    "    q = prompts[0][-1][\"content\"]\n",
    "    extracted_responses = [extract_xml_answer(r) for r in responses]\n",
    "    print(\n",
    "        \"-\" * 20,\n",
    "        f\"Question:\\n{q}\",\n",
    "        f\"\\nAnswer:\\n{answer[0]}\",\n",
    "        f\"\\nResponse:\\n{responses[0]}\",\n",
    "        f\"\\nExtracted:\\n{extracted_responses[0]}\",\n",
    "    )\n",
    "    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]\n",
    "\n",
    "\n",
    "def int_reward_func(completions, **kwargs) -> list[float]:\n",
    "    responses = [completion[0][\"content\"] for completion in completions]\n",
    "    extracted_responses = [extract_xml_answer(r) for r in responses]\n",
    "    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]\n",
    "\n",
    "\n",
    "def strict_format_reward_func(completions, **kwargs) -> list[float]:\n",
    "    \"\"\"Reward function that checks if the completion has a specific format.\"\"\"\n",
    "    pattern = r\"^<reasoning>\\n.*?\\n</reasoning>\\n<answer>\\n.*?\\n</answer>\\n$\"\n",
    "    responses = [completion[0][\"content\"] for completion in completions]\n",
    "    matches = [re.match(pattern, r) for r in responses]\n",
    "    return [0.5 if match else 0.0 for match in matches]\n",
    "\n",
    "\n",
    "def soft_format_reward_func(completions, **kwargs) -> list[float]:\n",
    "    \"\"\"Reward function that checks if the completion has a specific format.\"\"\"\n",
    "    pattern = r\"<reasoning>.*?</reasoning>\\s*<answer>.*?</answer>\"\n",
    "    responses = [completion[0][\"content\"] for completion in completions]\n",
    "    matches = [re.match(pattern, r) for r in responses]\n",
    "    return [0.5 if match else 0.0 for match in matches]\n",
    "\n",
    "\n",
    "def count_xml(text) -> float:\n",
    "    count = 0.0\n",
    "    if text.count(\"<reasoning>\\n\") == 1:\n",
    "        count += 0.125\n",
    "    if text.count(\"\\n</reasoning>\\n\") == 1:\n",
    "        count += 0.125\n",
    "    if text.count(\"\\n<answer>\\n\") == 1:\n",
    "        count += 0.125\n",
    "        count -= len(text.split(\"\\n</answer>\\n\")[-1]) * 0.001\n",
    "    if text.count(\"\\n</answer>\") == 1:\n",
    "        count += 0.125\n",
    "        count -= (len(text.split(\"\\n</answer>\")[-1]) - 1) * 0.001\n",
    "    return count\n",
    "\n",
    "\n",
    "def xmlcount_reward_func(completions, **kwargs) -> list[float]:\n",
    "    contents = [completion[0][\"content\"] for completion in completions]\n",
    "    return [count_xml(c) for c in contents]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ux6iqP7z5YOo"
   },
   "source": [
    "<a name=\"Train\"></a>\n",
    "### Train the model\n",
    "\n",
    "Now set up GRPO Trainer and all configurations!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ptqkXK2D4d6p",
    "outputId": "1e7b980b-c662-49c9-fb3a-d929edbafa09"
   },
   "outputs": [],
   "source": [
    "max_prompt_length = 256\n",
    "\n",
    "from trl import GRPOConfig, GRPOTrainer\n",
    "\n",
    "training_args = GRPOConfig(\n",
    "    learning_rate=5e-6,\n",
    "    adam_beta1=0.9,\n",
    "    adam_beta2=0.99,\n",
    "    weight_decay=0.1,\n",
    "    warmup_ratio=0.1,\n",
    "    lr_scheduler_type=\"cosine\",\n",
    "    optim=\"paged_adamw_8bit\",\n",
    "    logging_steps=1,\n",
    "    per_device_train_batch_size=1,\n",
    "    gradient_accumulation_steps=1,  # Increase to 4 for smoother training\n",
    "    num_generations=6,  # Decrease if out of memory\n",
    "    max_prompt_length=max_prompt_length,\n",
    "    max_completion_length=max_seq_length - max_prompt_length,\n",
    "    # num_train_epochs = 1, # Set to 1 for a full training run\n",
    "    max_steps=250,\n",
    "    save_steps=250,\n",
    "    max_grad_norm=0.1,\n",
    "    report_to=\"none\",  # Can use Weights & Biases\n",
    "    output_dir=\"outputs\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "r9Mv8UZO5hz-"
   },
   "source": [
    "And let's run the trainer! If you scroll up, you'll see a table of rewards. The goal is to see the `reward` column increase!\n",
    "\n",
    "You might have to wait 150 to 200 steps for any action. You'll probably get 0 reward for the first 100 steps. Please be patient!\n",
    "\n",
    "| Step | Training Loss | reward    | reward_std | completion_length | kl       |\n",
    "|------|---------------|-----------|------------|-------------------|----------|\n",
    "| 1    | 0.000000      | 0.125000  | 0.000000   | 200.000000        | 0.000000 |\n",
    "| 2    | 0.000000      | 0.072375  | 0.248112   | 200.000000        | 0.000000 |\n",
    "| 3    | 0.000000      | -0.079000 | 0.163776   | 182.500000        | 0.000005 |\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "vzOuSVCL_GA9",
    "outputId": "09405367-ccea-4e1c-aabf-2dbb9318c4fa"
   },
   "outputs": [],
   "source": [
    "trainer = GRPOTrainer(\n",
    "    model=model,\n",
    "    processing_class=tokenizer,\n",
    "    reward_funcs=[\n",
    "        xmlcount_reward_func,\n",
    "        soft_format_reward_func,\n",
    "        strict_format_reward_func,\n",
    "        int_reward_func,\n",
    "        correctness_reward_func,\n",
    "    ],\n",
    "    args=training_args,\n",
    "    train_dataset=dataset,\n",
    ")\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tlaUdxC_VHpz"
   },
   "source": [
    "<a name=\"Inference\"></a>\n",
    "### Inference\n",
    "Now let's try the model we just trained! First, let's first try the model without any GRPO trained:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 249
    },
    "id": "qtcz_lpbVC92",
    "outputId": "70e4f329-acac-4d31-a8cd-47f7c6088747"
   },
   "outputs": [],
   "source": [
    "text = tokenizer.apply_chat_template(\n",
    "    [\n",
    "        {\"role\": \"user\", \"content\": \"Calculate pi.\"},\n",
    "    ],\n",
    "    tokenize=False,\n",
    "    add_generation_prompt=True,\n",
    ")\n",
    "\n",
    "from vllm import SamplingParams\n",
    "\n",
    "sampling_params = SamplingParams(\n",
    "    temperature=0.8,\n",
    "    top_p=0.95,\n",
    "    max_tokens=1024,\n",
    ")\n",
    "output = (\n",
    "    model.fast_generate(\n",
    "        [text],\n",
    "        sampling_params=sampling_params,\n",
    "        lora_request=None,\n",
    "    )[0]\n",
    "    .outputs[0]\n",
    "    .text\n",
    ")\n",
    "\n",
    "output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Colxz9TAVMsi"
   },
   "source": [
    "And now with the LoRA we just trained with GRPO - we first save the LoRA first!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AL-BcuB1VLIv"
   },
   "outputs": [],
   "source": [
    "model.save_lora(\"grpo_saved_lora\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CwpbwnDBVRLg"
   },
   "source": [
    "Now we load the LoRA and test:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 249
    },
    "id": "zf_OY5WMVOxF",
    "outputId": "22373f16-a3bc-4c99-8a1d-1994522a5f0f"
   },
   "outputs": [],
   "source": [
    "text = tokenizer.apply_chat_template(\n",
    "    [\n",
    "        {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
    "        {\"role\": \"user\", \"content\": \"Calculate pi.\"},\n",
    "    ],\n",
    "    tokenize=False,\n",
    "    add_generation_prompt=True,\n",
    ")\n",
    "\n",
    "from vllm import SamplingParams\n",
    "\n",
    "sampling_params = SamplingParams(\n",
    "    temperature=0.8,\n",
    "    top_p=0.95,\n",
    "    max_tokens=1024,\n",
    ")\n",
    "output = (\n",
    "    model.fast_generate(\n",
    "        text,\n",
    "        sampling_params=sampling_params,\n",
    "        lora_request=model.load_lora(\"grpo_saved_lora\"),\n",
    "    )[0]\n",
    "    .outputs[0]\n",
    "    .text\n",
    ")\n",
    "\n",
    "output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6aDgFfhFYIAS"
   },
   "source": [
    "Our reasoning model is much better - it's not always correct, since we only trained it for an hour or so - it'll be better if we extend the sequence length and train for longer!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-NUEmHFSYNTp"
   },
   "source": [
    "<a name=\"Save\"></a>\n",
    "### Saving to float16 for VLLM\n",
    "\n",
    "We also support saving to `float16` directly. Select `merged_16bit` for float16 or `merged_4bit` for int4. We also allow `lora` adapters as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NjXGTkp7YNtB"
   },
   "outputs": [],
   "source": [
    "# Merge to 16bit\n",
    "if False:\n",
    "    model.save_pretrained_merged(\n",
    "        \"model\",\n",
    "        tokenizer,\n",
    "        save_method=\"merged_16bit\",\n",
    "    )\n",
    "if False:\n",
    "    model.push_to_hub_merged(\n",
    "        \"hf/model\", tokenizer, save_method=\"merged_16bit\", token=\"\"\n",
    "    )\n",
    "\n",
    "# Merge to 4bit\n",
    "if False:\n",
    "    model.save_pretrained_merged(\n",
    "        \"model\",\n",
    "        tokenizer,\n",
    "        save_method=\"merged_4bit\",\n",
    "    )\n",
    "if False:\n",
    "    model.push_to_hub_merged(\"hf/model\", tokenizer, save_method=\"merged_4bit\", token=\"\")\n",
    "\n",
    "# Just LoRA adapters\n",
    "if False:\n",
    "    model.save_pretrained_merged(\n",
    "        \"model\",\n",
    "        tokenizer,\n",
    "        save_method=\"lora\",\n",
    "    )\n",
    "if False:\n",
    "    model.push_to_hub_merged(\"hf/model\", tokenizer, save_method=\"lora\", token=\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "52WMb3k_YPt8"
   },
   "source": [
    "### GGUF / llama.cpp Conversion\n",
    "To save to `GGUF` / `llama.cpp`, we support it natively now! We clone `llama.cpp` and we default save it to `q8_0`. We allow all methods like `q4_k_m`. Use `save_pretrained_gguf` for local saving and `push_to_hub_gguf` for uploading to HF.\n",
    "\n",
    "Some supported quant methods (full list on our [Wiki page](https://github.com/unslothai/unsloth/wiki#gguf-quantization-options)):\n",
    "* `q8_0` - Fast conversion. High resource use, but generally acceptable.\n",
    "* `q4_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K.\n",
    "* `q5_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K.\n",
    "\n",
    "[**NEW**] To finetune and auto export to Ollama, try our [Ollama notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "QyEjW-WuYQIm"
   },
   "outputs": [],
   "source": [
    "# Save to 8bit Q8_0\n",
    "if False:\n",
    "    model.save_pretrained_gguf(\n",
    "        \"model\",\n",
    "        tokenizer,\n",
    "    )\n",
    "# Remember to go to https://huggingface.co/settings/tokens for a token!\n",
    "# And change hf to your username!\n",
    "if False:\n",
    "    model.push_to_hub_gguf(\"hf/model\", tokenizer, token=\"\")\n",
    "\n",
    "# Save to 16bit GGUF\n",
    "if False:\n",
    "    model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"f16\")\n",
    "if False:\n",
    "    model.push_to_hub_gguf(\"hf/model\", tokenizer, quantization_method=\"f16\", token=\"\")\n",
    "\n",
    "# Save to q4_k_m GGUF\n",
    "if False:\n",
    "    model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"q4_k_m\")\n",
    "if False:\n",
    "    model.push_to_hub_gguf(\n",
    "        \"hf/model\", tokenizer, quantization_method=\"q4_k_m\", token=\"\"\n",
    "    )\n",
    "\n",
    "# Save to multiple GGUF options - much faster if you want multiple!\n",
    "if False:\n",
    "    model.push_to_hub_gguf(\n",
    "        \"hf/model\",  # Change hf to your username!\n",
    "        tokenizer,\n",
    "        quantization_method=[\n",
    "            \"q4_k_m\",\n",
    "            \"q8_0\",\n",
    "            \"q5_k_m\",\n",
    "        ],\n",
    "        token=\"\",\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "upXTZ_0aNjof"
   },
   "source": [
    "Now, use the `model-unsloth.gguf` file or `model-unsloth-Q4_K_M.gguf` file in llama.cpp or a UI based system like Jan or Open WebUI. You can install Jan [here](https://github.com/janhq/jan) and Open WebUI [here](https://github.com/open-webui/open-webui)\n",
    "\n",
    "And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs or want to keep updated with the latest LLM stuff, or need help, join projects etc, feel free to join our Discord!\n",
    "\n",
    "Some other links:\n",
    "1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)\n",
    "2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)\n",
    "3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)\n",
    "6. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://docs.unsloth.ai/get-started/unsloth-notebooks)!\n",
    "\n",
    "<div class=\"align-center\">\n",
    "  <a href=\"https://unsloth.ai\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png\" width=\"115\"></a>\n",
    "  <a href=\"https://discord.gg/unsloth\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/Discord.png\" width=\"145\"></a>\n",
    "  <a href=\"https://docs.unsloth.ai/\"><img src=\"https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true\" width=\"125\"></a>\n",
    "\n",
    "  Join Discord if you need help + ⭐️ <i>Star us on <a href=\"https://github.com/unslothai/unsloth\">Github</a> </i> ⭐️\n",
    "</div>\n"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}