{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Oh3eu2ULxWAn"
   },
   "source": [
    "To run this, press \"*Runtime*\" and press \"*Run all*\" on a **free** Tesla T4 Google Colab instance!\n",
    "<div class=\"align-center\">\n",
    "<a href=\"https://unsloth.ai/\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png\" width=\"115\"></a>\n",
    "<a href=\"https://discord.gg/unsloth\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/Discord button.png\" width=\"145\"></a>\n",
    "<a href=\"https://docs.unsloth.ai/\"><img src=\"https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true\" width=\"125\"></a></a> Join Discord if you need help + ⭐ <i>Star us on <a href=\"https://github.com/unslothai/unsloth\">Github</a> </i> ⭐\n",
    "</div>\n",
    "\n",
    "To install Unsloth on your own computer, follow the installation instructions on our Github page [here](https://docs.unsloth.ai/get-started/installing-+-updating).\n",
    "\n",
    "You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), & [how to save it](#Save)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rNxpti3ExWAp"
   },
   "source": [
    "### News"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "yo-qTlvNxWAp"
   },
   "source": [
    "**Read our [Gemma 3 blog](https://unsloth.ai/blog/gemma3) for what's new in Unsloth and our [Reasoning blog](https://unsloth.ai/blog/r1-reasoning) on how to train reasoning models.**\n",
    "\n",
    "Visit our docs for all our [model uploads](https://docs.unsloth.ai/get-started/all-our-models) and [notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks).\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MrkTn5S9xWAq"
   },
   "source": [
    "### Installation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NV0BlxygxWAq"
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "import os\n",
    "\n",
    "if \"COLAB_\" not in \"\".join(os.environ.keys()):\n",
    "    !pip install unsloth vllm\n",
    "else:\n",
    "    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]\n",
    "    !pip install --no-deps unsloth vllm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8lIsllEIxWAr"
   },
   "outputs": [],
   "source": [
    "# @title Colab Extra Install { display-mode: \"form\" }\n",
    "%%capture\n",
    "import os\n",
    "\n",
    "if \"COLAB_\" not in \"\".join(os.environ.keys()):\n",
    "    !pip install unsloth vllm\n",
    "else:\n",
    "    !pip install --no-deps unsloth vllm\n",
    "    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]\n",
    "    # Skip restarting message in Colab\n",
    "    import sys\n",
    "    import re\n",
    "    import requests\n",
    "\n",
    "    modules = list(sys.modules.keys())\n",
    "    for x in modules:\n",
    "        sys.modules.pop(x) if \"PIL\" in x or \"google\" in x else None\n",
    "    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo\n",
    "    !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer\n",
    "\n",
    "    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy\n",
    "    f = requests.get(\n",
    "        \"https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt\"\n",
    "    ).content\n",
    "    with open(\"vllm_requirements.txt\", \"wb\") as file:\n",
    "        file.write(re.sub(rb\"(transformers|numpy|xformers)[^\\n]{1,}\\n\", b\"\", f))\n",
    "    !pip install -r vllm_requirements.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "a6uTCUXFxWAr"
   },
   "source": [
    "### Unsloth"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Joje4qPsyxM9"
   },
   "source": [
    "Load up `Qwen 2.5 3B Instruct`, and set parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000,
     "referenced_widgets": [
      "b46d7118f71043109c2dd00fb0f40557",
      "32defd049d8a473c8ad8fc954c0e3d13",
      "424078e3ad6c405c925bedfb76d18b3f",
      "8d937714a44d422aa9851c02c2345e5c",
      "25b26c398e314cb79a2f8bbb7e32e158",
      "694494752ab2458ead83b0331af20122",
      "1cc89fa4d7c744c69c3be1a3d0b3ba2d",
      "15116ef21fae463ba9013c8b9dbd40b5",
      "43fdf1fe592147538ec43410a1bf8e09",
      "52ba857bdfc742a2aae9e9a949054e5c",
      "589a3857411f49d4ae019465f536da45",
      "5dafeefcf436462e81a9affe7f6c66d9",
      "98d1697fb2054fd88432f027e592ecc0",
      "b1adc20d1d8740f99a6ebd589f240688",
      "f45a696d73054cd6af181a7ce35a62e4",
      "372b687ac774400088d7252467f9cef2",
      "a422e26291bb47d08e36b7d94fffcc9b",
      "799dc8fe6f7d4e4db68ebd6b1d8fdbd5",
      "e61a718f06a045b6b028fbd33382d37a",
      "7c712c0011954872834e90cc908b8cf6",
      "9929aaffe7614506ae934a14d2b5c3ca",
      "1060550e9dec4fa59e2b9b488af61b48",
      "ebd36fe7f8dd47b99e69176ace9b985d",
      "42a3c5c410cd49a38400b788980ad428",
      "e5373b31cbfc4e87af314bf4898d2b72",
      "4379f29ff4a74df784bb2fb3989c64cc",
      "87b42e656263458cbb33ee3ca29128a8",
      "b245d8e4c115438583ef8b433d4d16be",
      "df7f9d244f70440aaf401c6a856a6ed0",
      "5d39d9e8cb724af59d7af29ec9c41cf4",
      "6292f7a4de1948d09bebaf5064fa0d42",
      "62cbedb2d67d4a2c8888b46e56e39f65",
      "57f2c072b0574f74a7dbe2a9d303cbd4",
      "78345adf6d8143e3a649a05c5d476115",
      "f89bfff6526c43bcb75c93f987bfba72",
      "a7244280fc934b8484b3dfd7eecca161",
      "ec7ae2abc92f49f98ab0c52953bf82af",
      "092355509a67474e9cc13a6b4718a192",
      "c69730dfeec14cedaf8917ee9ad27dcb",
      "00177cac24374c7d9144e3856594ddaf",
      "1b1e49195c5b48a1b6e4364db01c5c87",
      "97af130f6bd54d618e0af88ed9e85819",
      "1f88a409f3c64ea291861837e830ce39",
      "e9881edb2c7846fea2070c8e724b25cb",
      "dfc8cc5ee6704f1fb3b4df0507ea471b",
      "16bb453bd4934bd9840a9d365be604cd",
      "1ee1a56d04fa4e85a20e4cde46c164e5",
      "9872da92026e4debb96b731c892408e2",
      "451d181b9ef24b0e9fc3914a4466b9f1",
      "6a49aa5a2b3646d4b772e2398eacf72d",
      "edeb0f283b0b495eab6e06a15fa38af0",
      "0ed98c973a9d4d30b31ab751a7ecf3f6",
      "a0b7dab4f65e448f816e00d12dc3b469",
      "3b6d4b41992b43acac5daa31f119cd30",
      "c5b4960fc3ba4459a6eaff43772050a3",
      "71f9e574addd491cbbe65ddbe0eb23ad",
      "442e58ceec4e44a3a4a2ebadd145ee73",
      "4382289897b54e63b22541d0b45c9bf3",
      "ba6998ac52544199b4e828d1a7f99d0d",
      "29399d40d44d46c8b3648218e8378e89",
      "c6b2d407256549a5bc7543cb6d7475b4",
      "545251749b10408d82c9a60b914dce34",
      "1bbadef2fca5427fa1a4975153968f56",
      "d267ec5dd1954b9a88e566dff4a9af9d",
      "4ddc8fa0c9834cd3b36274a9e0e05504",
      "ec13219478284bbdbe69cbcd3dff5c34",
      "562637f216114f42aa352cabd59cab57",
      "e37bcfc1604b4a788f53825c5a7e04b7",
      "16e05d45084c44b68697316bc8791b88",
      "9b80621f94a0484791b33a13b3c6bade",
      "fc87fbccba304554be90701b39429918",
      "aaf79572906c4443b5542bef5dee8055",
      "bc81dfa869f94f059545a7d59b7eb70a",
      "b3c84c45c0804e4cb703f9c512c6f90d",
      "6520db76749d4a8aaa6d26005f07c43c",
      "29a3f867c8bd4769b8e649b8912776a2",
      "2317c9a94dfe4a62a8605038ebc2210a",
      "bd27949a1e3c43e5912a01b04b6621a4",
      "94ed75e3a9fe4cbdb373d4038af0c32b",
      "adfd8a554ada48f9a3dc2fab69467b79",
      "0f9dcd9b698b4af699cfafdf2874612c",
      "f8700d095ee04a40a15c6a88d7daf1ae",
      "1ccb12282661417383a7f96bd5ea4eac",
      "32a16b0cbbe548f4869cf665d737b9ba",
      "3f546bf0aa124e45b39daa0328a1d23d",
      "61f6966f8b6d4070a9ea3604373cb96a",
      "f8198301b2db406685189589e59efd9a",
      "e813e26451e3476a82a9b75668b638a1",
      "8aa00f7a2de64af8b0a5b8cd5fefb14e",
      "3404fb0c4c4b4ca0ae6d9eeec6722e02",
      "f268eda021e040e6b25889e72ee8dd1e",
      "b61f7faf19594e3c81e1429045f341de",
      "a4ecf424e6ed44c597f25acee7b43df2",
      "c314f7a0c6084790a2620ec231297525",
      "950c1b85a8bd456ab928df4cc3710b21",
      "806b17b8ce074563941dbba9ca2c9422",
      "0b42f72fac844c19adea7b6689e79432",
      "37e2f38471d4481eb93031ca52bdb785",
      "067aaf24d87a400b9f636762051598f3",
      "a03f0a986af84eacbb5c333f49726105",
      "abf0e1d60e9f4fadbcda876dd28f6d14",
      "81a7650cc4124b58b4e191093a5433e1",
      "74f24411d35e4c918c608a1afd799156",
      "a8501b8551d4431fbb708d6556e7dd85",
      "24757877452d48f0afe92bb66e2c48ac",
      "b64be6a0d6df44899ba6fe0a47f676e1",
      "c1588f6fe9ca48a5a7512734f26ae9a3",
      "374765208bf345089f58e925d2a70984",
      "78774135f2d1468d9172fa7588afef20",
      "9b3233c52ecc4b259deed2d5d6f332eb",
      "2567e718b28f417e834a10c5310f92b2",
      "587ea317bd0a4831a9417a3980207a73",
      "9e857c0cb85d4b50ad73c670436b48a5",
      "f8705a1680534caa801dc6e14c3601e3",
      "7e982757835048ae85abf6cba5d27ecb",
      "e14a085ca7e6410ca78fe0f27533be6d",
      "6d204a683e2d49deb5f504f91ee72044",
      "6659b3fef8644c6e96ed9d43adb744d5",
      "4c130bb196db422c8b9c309cbe59bac1",
      "32f42609c389450bbaad10f0a884c527",
      "a9e874a076234974a102b02551ec8666",
      "e505a51d928542b883743a6d48c5cbd3",
      "32ae8645c975473b81788c17c376367a",
      "7cb758be0be64a48930379b2cfa9fb44",
      "fb4bf57c071d4ef68046988a4965b312",
      "f1af6e23e42c4051a1c7fd4864450eaa",
      "5cf88ae3c2fc4419a71eb45f28f67384",
      "327fa24e70ee4cae903fe5673bcf3bfc",
      "d1d5afc36de043afb8638dec0001288e",
      "fffe74f9a1184de5b1297c659c7891b6",
      "878702a019e34789b5b2f68eb321f2c8",
      "91998be216844c079144c18278f660d7",
      "5a1c360e82204ad99238901d366828cf",
      "fc98ab1d58d84f69b6c36dbc19f5edd7",
      "e2e86147b7d44f66a2cdfe02d63ee141",
      "2f5178d399ec48648b435bac95faf188",
      "c13ff304d7bc4c67a1f51da5d2303f96",
      "0e8c4086734c4296884db9c5b67ddd33",
      "c3a97f3130bc42c1be2f8c7dc81ca6cc",
      "0043184800574bfdbfb3c82f4842c3ed",
      "32719fe47c8543f69bcddc9b9bc0f719",
      "784a5c0d201d46459a160e0dc8857fd9",
      "d753c1ecc6034237ba4401a007451bfb",
      "55fd60765d5e46c0baf1ac075820fc87",
      "771af8b28988459ba4f2dbaadd0d5ce1",
      "0cc22299405f4361b1b5f6234f35bae3",
      "0f22a4136dce46f89e9efbf0581057a2",
      "5cda6bd66e3248c4be477f1c1b40198d",
      "269b131121554544a21d47596a24aa58",
      "73758a3ca9c14d7eaa6504d8b9b6e333",
      "5a16728cee0f4b75be08696261d8d8cc",
      "da77a7b205b1446a9dacab0550aafebe",
      "d6bd8a9af6e84500956bb565f2fe2c4e",
      "6643bd32eb984a929159c8d441530aaf",
      "8358f98f790d47fdb783eb5ca210c1fc",
      "0c94630c2693424d9edd8b9921333aac",
      "edf185cd11544b2c8bce478fb2f0b565",
      "a7b4b7d3f1a34b089d72eb2a4ab169b0",
      "09a607655a054ed5ad7f36716f8ee4c6",
      "9258ed2d94924de3b944105e280fbf5f",
      "1c68fbbc833d4c3fbbb686f15b9ba939",
      "fa996ce2f8f24a5d80914c4ef5b6e788",
      "f77e9e785cc04c00822fc8b7ef2959da",
      "5d190801b10b49a698c3a69c7cd1f253",
      "0112cfe0a27e43fd9509ef8da050e188",
      "4c6b2159bbc04b8d811fadff3a22f4a6",
      "c705d699568543829441df2258f2e615",
      "eb8f921c8d67442b9bc2a6a67930d94c",
      "8a89212807f243d4b135a3188f4fb6d0",
      "cc018dfe29144d0182058816134e1fd9",
      "c404db8298364742b8ce076cddcb7641",
      "03aca6d92597461481419eac9bdb6635",
      "59b90e05822f480793685902b9c6d055",
      "d13cc9053e7a4982b6fc7f00474f25c7",
      "c9b781db7ee2422b8dc4eb6aa2575d87",
      "5100eb13df1c453980942ec78d0e810b"
     ]
    },
    "id": "DkIvEkIIkEyB",
    "outputId": "8bf13df5-67ab-4ea2-db01-5719298b2dc9"
   },
   "outputs": [],
   "source": [
    "from unsloth import FastLanguageModel, is_bfloat16_supported\n",
    "\n",
    "max_seq_length = 1024  # Can increase for longer reasoning traces\n",
    "lora_rank = 64  # Larger rank = smarter, but slower\n",
    "\n",
    "model, tokenizer = FastLanguageModel.from_pretrained(\n",
    "    model_name=\"Qwen/Qwen2.5-3B-Instruct\",\n",
    "    max_seq_length=max_seq_length,\n",
    "    load_in_4bit=True,  # False for LoRA 16bit\n",
    "    fast_inference=True,  # Enable vLLM fast inference\n",
    "    max_lora_rank=lora_rank,\n",
    "    gpu_memory_utilization=0.5,  # Reduce if out of memory\n",
    ")\n",
    "\n",
    "model = FastLanguageModel.get_peft_model(\n",
    "    model,\n",
    "    r=lora_rank,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n",
    "    target_modules=[\n",
    "        \"q_proj\",\n",
    "        \"k_proj\",\n",
    "        \"v_proj\",\n",
    "        \"o_proj\",\n",
    "        \"gate_proj\",\n",
    "        \"up_proj\",\n",
    "        \"down_proj\",\n",
    "    ],  # Remove QKVO if out of memory\n",
    "    lora_alpha=lora_rank,\n",
    "    use_gradient_checkpointing=\"unsloth\",  # Enable long context finetuning\n",
    "    random_state=3407,\n",
    ")"
   ]
  },
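  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feel for what `lora_rank` costs, here is a back-of-the-envelope sketch. LoRA adds two low-rank matrices per adapted weight, so a `d_out x d_in` projection gains `r * (d_in + d_out)` trainable parameters. The helper name `lora_params` and the layer dimensions below are illustrative assumptions, not values read from the loaded model.\n",
    "\n",
    "```python\n",
    "def lora_params(d_in: int, d_out: int, r: int) -> int:\n",
    "    # LoRA factors the weight update as B @ A, with A: (r, d_in) and B: (d_out, r).\n",
    "    return r * (d_in + d_out)\n",
    "\n",
    "\n",
    "# Hypothetical square projection (2048 -> 2048), chosen only for illustration.\n",
    "d_in, d_out = 2048, 2048\n",
    "for r in (8, 16, 64):\n",
    "    full = d_in * d_out\n",
    "    added = lora_params(d_in, d_out, r)\n",
    "    print(f\"r={r:>3}: {added:,} LoRA params ({added / full:.1%} of the full matrix)\")\n",
    "```\n",
    "\n",
    "The added parameters grow linearly with the rank, which is why a larger rank gives the adapter more capacity at the price of memory and speed."
   ]
  },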
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0Y56ln_izS9E"
   },
   "source": [
    "### Data Prep\n",
    "<a name=\"Data\"></a>\n",
    "\n",
    "We directly leverage [@willccbb](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb) for data prep and all reward functions. You are free to create your own!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 209,
     "referenced_widgets": [
      "6b2cf7a462884bc4a3c2f8670f8982ed",
      "8a358b880c194e6baab24f711937f695",
      "0ec5dfa2fa6b4be2bfb45692ac20c053",
      "1a619a0a53c149c3b9b0cbd53f1d2034",
      "fdcbb976f35e4aa383d9ebcd7304849f",
      "b131891a909743f5826429a84a74b4f9",
      "3745da58146d445abe863a248e86d9be",
      "22d29d39ef8149dba932ed9af68def67",
      "413b4818b6de4bd995944df2f73bc3d5",
      "c626fa7f5c624ec5a414e6f58d71ea96",
      "2a8fb4b980ab45be81ae39d5c6b7bad4",
      "ec1837ba407c4ecb86335d9c0c239ad9",
      "1793da35014840b08d776b4abc9a1212",
      "a06bde82466c43a49b58c0d942071be0",
      "a1a3ee663a154319ba186954816e6932",
      "97a9db845a9c45caa88ea1cd0ab6c40d",
      "231be59acb2a4921ba53ba97bdf5888a",
      "5aa0e699560c4fe89250b4ccac6401cb",
      "7529f4d84b9f46709ec5f0e22b75baf4",
      "8b8fd84911e94b5d8d627911fa9a64f7",
      "420154f0bbe344a7979769a8ef066df7",
      "1ee34b9171604805887aee3baa4ff099",
      "3c95b86a26bc49b4a65589df0e7865e2",
      "ef6d7a6a7fb344fab5befe656b98c7b7",
      "c3d523c23ebf4b769c59eda950a3fc1f",
      "bffc24c5bcdb47988f2e309c7fbc21b8",
      "11052ab3d1f441248458cd6ab8b810db",
      "0c2e39b20d5941969101b793b94fdbfa",
      "f0c74de390814c2b9f38f7cc02458426",
      "6c9502ae4cbc4248b14590fc6efe832f",
      "8972cb9e823342bcb5c493e487dc64b6",
      "71898e427b1245d682178a90c6752e8e",
      "90295f1308054c66b557da231d25d3d5",
      "54f444e6a25349fcb8fb35b9efb33e68",
      "2f9a64ad318248bf99fccdedafaddfc0",
      "d3ce3f3bee5646d893242758444acd6a",
      "5b4f99c20a8a4ecf873b0a8ef8b150cd",
      "70e141512e91492bae4713400879d019",
      "7c67cf4d83e34cd6a475405be45f345e",
      "c184e417ad134e94995a55d4bc7c94b5",
      "f1c8d7356b584fb8965d2a62ccec52c8",
      "9bcb0000da6e434d9b241a09d9b46244",
      "11bd7b89fa794a899db9c83cb2a99df5",
      "0e9d3d2244a54eb781311c01965c33e4",
      "00d1bd9a327241e8b434ef7b3bde30df",
      "cd17f032bbb44575977797caaa92bd54",
      "04b74b7851334892b0db68bea1c8fd7e",
      "44b6ecc9b9fc436e92701ec2409e0dd1",
      "46998bc2b1ee4858a8ee1108ce36b9b7",
      "95fa6a54e7e543c7b2dc5d5cf04daf58",
      "8ee15a3d02954d31899d0b6a8dbf7988",
      "1d5a58d63c51427bbbfb729216ea8468",
      "e2790d1e5e714b3e9081b4edcef24535",
      "96519710bf5047239f6800100471ee6e",
      "e9b020ce2dda442b871b41e417210b52",
      "4d18ef3f7c8046ddb0b03314861c3414",
      "176ccdb40b994ee3b381f7ea0f379c3c",
      "3519fdf3abac4d7690c45710f617e5d0",
      "dc320ae9e2f7426588d49848b0871ca4",
      "6b5a1287c457454fb367e5dc03d8b157",
      "989519f3e61b4132b3f88fbdf29703a9",
      "e0f87f6fd7c84261818624defe575a1c",
      "e6d2b6aaa9794971b01c778b1a0596db",
      "334c4833051e4ccea727dc2cdf3871c3",
      "d36027362b41413fa3ee4670dd820a88",
      "d94e3315b99747c39190fee9b79f00ac"
     ]
    },
    "id": "cXk993X6C2ZZ",
    "outputId": "2d545127-0313-405e-d8f9-a43d7dba8d11"
   },
   "outputs": [],
   "source": [
    "import re\n",
    "from datasets import load_dataset, Dataset\n",
    "\n",
    "# Load and prep dataset\n",
    "SYSTEM_PROMPT = \"\"\"\n",
    "Respond in the following format:\n",
    "<reasoning>\n",
    "...\n",
    "</reasoning>\n",
    "<answer>\n",
    "...\n",
    "</answer>\n",
    "\"\"\"\n",
    "\n",
    "XML_COT_FORMAT = \"\"\"\\\n",
    "<reasoning>\n",
    "{reasoning}\n",
    "</reasoning>\n",
    "<answer>\n",
    "{answer}\n",
    "</answer>\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "def extract_xml_answer(text: str) -> str:\n",
    "    answer = text.split(\"<answer>\")[-1]\n",
    "    answer = answer.split(\"</answer>\")[0]\n",
    "    return answer.strip()\n",
    "\n",
    "\n",
    "def extract_hash_answer(text: str) -> str | None:\n",
    "    if \"####\" not in text:\n",
    "        return None\n",
    "    return text.split(\"####\")[1].strip()\n",
    "\n",
    "\n",
    "# uncomment middle messages for 1-shot prompting\n",
    "def get_gsm8k_questions(split=\"train\") -> Dataset:\n",
    "    data = load_dataset(\"openai/gsm8k\", \"main\")[split]  # type: ignore\n",
    "    data = data.map(\n",
    "        lambda x: {  # type: ignore\n",
    "            \"prompt\": [\n",
    "                {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
    "                {\"role\": \"user\", \"content\": x[\"question\"]},\n",
    "            ],\n",
    "            \"answer\": extract_hash_answer(x[\"answer\"]),\n",
    "        }\n",
    "    )  # type: ignore\n",
    "    return data  # type: ignore\n",
    "\n",
    "\n",
    "dataset = get_gsm8k_questions()\n",
    "\n",
    "\n",
    "# Reward functions\n",
    "def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:\n",
    "    responses = [completion[0][\"content\"] for completion in completions]\n",
    "    q = prompts[0][-1][\"content\"]\n",
    "    extracted_responses = [extract_xml_answer(r) for r in responses]\n",
    "    print(\n",
    "        \"-\" * 20,\n",
    "        f\"Question:\\n{q}\",\n",
    "        f\"\\nAnswer:\\n{answer[0]}\",\n",
    "        f\"\\nResponse:\\n{responses[0]}\",\n",
    "        f\"\\nExtracted:\\n{extracted_responses[0]}\",\n",
    "    )\n",
    "    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]\n",
    "\n",
    "\n",
    "def int_reward_func(completions, **kwargs) -> list[float]:\n",
    "    responses = [completion[0][\"content\"] for completion in completions]\n",
    "    extracted_responses = [extract_xml_answer(r) for r in responses]\n",
    "    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]\n",
    "\n",
    "\n",
    "def strict_format_reward_func(completions, **kwargs) -> list[float]:\n",
    "    \"\"\"Reward function that checks if the completion has a specific format.\"\"\"\n",
    "    pattern = r\"^<reasoning>\\n.*?\\n</reasoning>\\n<answer>\\n.*?\\n</answer>\\n$\"\n",
    "    responses = [completion[0][\"content\"] for completion in completions]\n",
    "    matches = [re.match(pattern, r) for r in responses]\n",
    "    return [0.5 if match else 0.0 for match in matches]\n",
    "\n",
    "\n",
    "def soft_format_reward_func(completions, **kwargs) -> list[float]:\n",
    "    \"\"\"Reward function that checks if the completion has a specific format.\"\"\"\n",
    "    pattern = r\"<reasoning>.*?</reasoning>\\s*<answer>.*?</answer>\"\n",
    "    responses = [completion[0][\"content\"] for completion in completions]\n",
    "    matches = [re.match(pattern, r) for r in responses]\n",
    "    return [0.5 if match else 0.0 for match in matches]\n",
    "\n",
    "\n",
    "def count_xml(text) -> float:\n",
    "    count = 0.0\n",
    "    if text.count(\"<reasoning>\\n\") == 1:\n",
    "        count += 0.125\n",
    "    if text.count(\"\\n</reasoning>\\n\") == 1:\n",
    "        count += 0.125\n",
    "    if text.count(\"\\n<answer>\\n\") == 1:\n",
    "        count += 0.125\n",
    "        count -= len(text.split(\"\\n</answer>\\n\")[-1]) * 0.001\n",
    "    if text.count(\"\\n</answer>\") == 1:\n",
    "        count += 0.125\n",
    "        count -= (len(text.split(\"\\n</answer>\")[-1]) - 1) * 0.001\n",
    "    return count\n",
    "\n",
    "\n",
    "def xmlcount_reward_func(completions, **kwargs) -> list[float]:\n",
    "    contents = [completion[0][\"content\"] for completion in completions]\n",
    "    return [count_xml(c) for c in contents]"
   ]
  },
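  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the reward interface concrete, here is a minimal sketch of the data shapes the reward functions above receive. It re-declares a compact copy of `extract_xml_answer` so it runs on its own. The sample completions and the name `toy_correctness_reward` are made up for illustration, and the call convention is assumed from how the functions above index `completions`.\n",
    "\n",
    "```python\n",
    "def extract_xml_answer(text: str) -> str:\n",
    "    # Same logic as the helper above: keep the text between the answer tags.\n",
    "    answer = text.split(\"<answer>\")[-1]\n",
    "    return answer.split(\"</answer>\")[0].strip()\n",
    "\n",
    "\n",
    "def toy_correctness_reward(completions, answer, **kwargs) -> list[float]:\n",
    "    # Each completion is a list of chat messages; the reply is the first entry.\n",
    "    replies = [completion[0][\"content\"] for completion in completions]\n",
    "    extracted = [extract_xml_answer(r) for r in replies]\n",
    "    return [2.0 if e == a else 0.0 for e, a in zip(extracted, answer)]\n",
    "\n",
    "\n",
    "# Hypothetical sample: two completions generated for the same prompt.\n",
    "completions = [\n",
    "    [{\"role\": \"assistant\", \"content\": \"<reasoning>\\n2 + 2 = 4\\n</reasoning>\\n<answer>\\n4\\n</answer>\\n\"}],\n",
    "    [{\"role\": \"assistant\", \"content\": \"<reasoning>\\nguess\\n</reasoning>\\n<answer>\\n5\\n</answer>\\n\"}],\n",
    "]\n",
    "answer = [\"4\", \"4\"]  # the reference answer, repeated once per completion\n",
    "\n",
    "rewards = toy_correctness_reward(completions, answer)\n",
    "print(rewards)\n",
    "```\n",
    "\n",
    "Each reward function returns one float per completion, so several functions can be combined simply by listing them together."
   ]
  },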
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bTnL_tJnzh2L"
   },
   "source": [
    "<a name=\"Train\"></a>\n",
    "### Train the model\n",
    "\n",
    "Now set up GRPO Trainer and all configurations!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ptqkXK2D4d6p",
    "outputId": "7d36a642-4cc8-48ee-c8cc-1470978d4503"
   },
   "outputs": [],
   "source": [
    "from trl import GRPOConfig, GRPOTrainer\n",
    "\n",
    "training_args = GRPOConfig(\n",
    "    use_vllm=True,  # use vLLM for fast inference!\n",
    "    learning_rate=5e-6,\n",
    "    adam_beta1=0.9,\n",
    "    adam_beta2=0.99,\n",
    "    weight_decay=0.1,\n",
    "    warmup_ratio=0.1,\n",
    "    lr_scheduler_type=\"cosine\",\n",
    "    optim=\"adamw_8bit\",\n",
    "    logging_steps=1,\n",
    "    bf16=is_bfloat16_supported(),\n",
    "    fp16=not is_bfloat16_supported(),\n",
    "    per_device_train_batch_size=1,\n",
    "    gradient_accumulation_steps=1,  # Increase to 4 for smoother training\n",
    "    num_generations=8,  # Decrease if out of memory\n",
    "    max_prompt_length=256,\n",
    "    max_completion_length=200,\n",
    "    # num_train_epochs = 1, # Set to 1 for a full training run\n",
    "    max_steps=250,\n",
    "    save_steps=250,\n",
    "    max_grad_norm=0.1,\n",
    "    report_to=\"none\",  # Can use Weights & Biases\n",
    "    output_dir=\"outputs\",\n",
    ")"
   ]
  },
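  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`num_generations` sets how many completions are sampled per prompt. As a rough sketch of the idea behind GRPO (not TRL's actual implementation), each completion's reward is compared with the other completions for the same prompt: subtract the group mean and divide by the group standard deviation. The function name `group_relative_advantages` and the `eps` constant are illustrative assumptions.\n",
    "\n",
    "```python\n",
    "import statistics\n",
    "\n",
    "\n",
    "def group_relative_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:\n",
    "    # Normalise rewards within one group of completions for the same prompt.\n",
    "    mean = statistics.fmean(rewards)\n",
    "    std = statistics.pstdev(rewards)  # population std; a library may use the sample std\n",
    "    return [(r - mean) / (std + eps) for r in rewards]\n",
    "\n",
    "\n",
    "# Hypothetical rewards for four completions of one prompt.\n",
    "group = [2.5, 0.5, 0.5, 0.5]\n",
    "advantages = group_relative_advantages(group)\n",
    "print([round(a, 2) for a in advantages])\n",
    "```\n",
    "\n",
    "In this sketch, a group whose completions all earn the same reward yields all-zero advantages, so that prompt gives no learning signal. This is one plausible reason the `reward` column can stay flat early in training."
   ]
  },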
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "X_71Y0eKz5yE"
   },
   "source": [
    "And let's run the trainer! If you scroll up, you'll see a table of rewards. The goal is to see the `reward` column increase!\n",
    "\n",
    "You might have to wait 150 to 200 steps for any action. You'll probably get 0 reward for the first 100 steps. Please be patient!\n",
    "\n",
    "| Step | Training Loss | reward    | reward_std | completion_length | kl       |\n",
    "|------|---------------|-----------|------------|-------------------|----------|\n",
    "| 1    | 0.000000      | 0.125000  | 0.000000   | 200.000000        | 0.000000 |\n",
    "| 2    | 0.000000      | 0.072375  | 0.248112   | 200.000000        | 0.000000 |\n",
    "| 3    | 0.000000      | -0.079000 | 0.163776   | 182.500000        | 0.000005 |\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "vzOuSVCL_GA9",
    "outputId": "c6fdbbe8-7062-45a2-a48c-649f2b878ac4"
   },
   "outputs": [],
   "source": [
    "trainer = GRPOTrainer(\n",
    "    model=model,\n",
    "    processing_class=tokenizer,\n",
    "    reward_funcs=[\n",
    "        xmlcount_reward_func,\n",
    "        soft_format_reward_func,\n",
    "        strict_format_reward_func,\n",
    "        int_reward_func,\n",
    "        correctness_reward_func,\n",
    "    ],\n",
    "    args=training_args,\n",
    "    train_dataset=dataset,\n",
    ")\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "yUbluAAhD0Lg"
   },
   "source": [
    "<a name=\"Inference\"></a>\n",
    "### Inference\n",
    "Now let's try the model we just trained! First, let's first try the model without any GRPO trained:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 72
    },
    "id": "IqzsdZzeDM_m",
    "outputId": "a7637565-80fe-4ee9-cdc9-33b9affb65dc"
   },
   "outputs": [],
   "source": [
    "text = tokenizer.apply_chat_template(\n",
    "    [\n",
    "        {\"role\": \"user\", \"content\": \"How many r's are in strawberry?\"},\n",
    "    ],\n",
    "    tokenize=False,\n",
    "    add_generation_prompt=True,\n",
    ")\n",
    "\n",
    "from vllm import SamplingParams\n",
    "\n",
    "sampling_params = SamplingParams(\n",
    "    temperature=0.8,\n",
    "    top_p=0.95,\n",
    "    max_tokens=1024,\n",
    ")\n",
    "output = (\n",
    "    model.fast_generate(\n",
    "        [text],\n",
    "        sampling_params=sampling_params,\n",
    "        lora_request=None,\n",
    "    )[0]\n",
    "    .outputs[0]\n",
    "    .text\n",
    ")\n",
    "\n",
    "output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "G4lzJD7REFjs"
   },
   "source": [
    "And now with the LoRA we just trained with GRPO - we first save the LoRA first!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "YC9BBT0RESln"
   },
   "outputs": [],
   "source": [
    "model.save_lora(\"grpo_saved_lora\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LherO2vzEbMt"
   },
   "source": [
    "Now we load the LoRA and test:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 125
    },
    "id": "SDKIhhvN6lAF",
    "outputId": "8bc37894-678f-40d7-874b-365300685836"
   },
   "outputs": [],
   "source": [
    "text = tokenizer.apply_chat_template(\n",
    "    [\n",
    "        {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
    "        {\"role\": \"user\", \"content\": \"How many r's are in strawberry?\"},\n",
    "    ],\n",
    "    tokenize=False,\n",
    "    add_generation_prompt=True,\n",
    ")\n",
    "\n",
    "from vllm import SamplingParams\n",
    "\n",
    "sampling_params = SamplingParams(\n",
    "    temperature=0.8,\n",
    "    top_p=0.95,\n",
    "    max_tokens=1024,\n",
    ")\n",
    "output = (\n",
    "    model.fast_generate(\n",
    "        text,\n",
    "        sampling_params=sampling_params,\n",
    "        lora_request=model.load_lora(\"grpo_saved_lora\"),\n",
    "    )[0]\n",
    "    .outputs[0]\n",
    "    .text\n",
    ")\n",
    "\n",
    "output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_ZBnvg2f9Nlg"
   },
   "source": [
    "Our reasoning model is much better - it's not always correct, since we only trained it for an hour or so - it'll be better if we extend the sequence length and train for longer!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RphEZRSfFhru"
   },
   "source": [
    "<a name=\"Save\"></a>\n",
    "### Saving to float16 for VLLM\n",
    "\n",
    "We also support saving to `float16` directly. Select `merged_16bit` for float16 or `merged_4bit` for int4. We also allow `lora` adapters as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "GwNY9_PrFiXZ"
   },
   "outputs": [],
   "source": [
    "# Merge to 16bit\n",
    "if False:\n",
    "    model.save_pretrained_merged(\n",
    "        \"model\",\n",
    "        tokenizer,\n",
    "        save_method=\"merged_16bit\",\n",
    "    )\n",
    "if False:\n",
    "    model.push_to_hub_merged(\n",
    "        \"hf/model\", tokenizer, save_method=\"merged_16bit\", token=\"\"\n",
    "    )\n",
    "\n",
    "# Merge to 4bit\n",
    "if False:\n",
    "    model.save_pretrained_merged(\n",
    "        \"model\",\n",
    "        tokenizer,\n",
    "        save_method=\"merged_4bit\",\n",
    "    )\n",
    "if False:\n",
    "    model.push_to_hub_merged(\"hf/model\", tokenizer, save_method=\"merged_4bit\", token=\"\")\n",
    "\n",
    "# Just LoRA adapters\n",
    "if False:\n",
    "    model.save_pretrained_merged(\n",
    "        \"model\",\n",
    "        tokenizer,\n",
    "        save_method=\"lora\",\n",
    "    )\n",
    "if False:\n",
    "    model.push_to_hub_merged(\"hf/model\", tokenizer, save_method=\"lora\", token=\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "BDUGPiL3Fkkq"
   },
   "source": [
    "### GGUF / llama.cpp Conversion\n",
    "To save to `GGUF` / `llama.cpp`, we support it natively now! We clone `llama.cpp` and we default save it to `q8_0`. We allow all methods like `q4_k_m`. Use `save_pretrained_gguf` for local saving and `push_to_hub_gguf` for uploading to HF.\n",
    "\n",
    "Some supported quant methods (full list on our [Wiki page](https://github.com/unslothai/unsloth/wiki#gguf-quantization-options)):\n",
    "* `q8_0` - Fast conversion. High resource use, but generally acceptable.\n",
    "* `q4_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K.\n",
    "* `q5_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K.\n",
    "\n",
    "[**NEW**] To finetune and auto export to Ollama, try our [Ollama notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AGo4dbWvFk4M"
   },
   "outputs": [],
   "source": [
    "# Save to 8bit Q8_0\n",
    "if False:\n",
    "    model.save_pretrained_gguf(\n",
    "        \"model\",\n",
    "        tokenizer,\n",
    "    )\n",
    "# Remember to go to https://huggingface.co/settings/tokens for a token!\n",
    "# And change hf to your username!\n",
    "if False:\n",
    "    model.push_to_hub_gguf(\"hf/model\", tokenizer, token=\"\")\n",
    "\n",
    "# Save to 16bit GGUF\n",
    "if False:\n",
    "    model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"f16\")\n",
    "if False:\n",
    "    model.push_to_hub_gguf(\"hf/model\", tokenizer, quantization_method=\"f16\", token=\"\")\n",
    "\n",
    "# Save to q4_k_m GGUF\n",
    "if False:\n",
    "    model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"q4_k_m\")\n",
    "if False:\n",
    "    model.push_to_hub_gguf(\n",
    "        \"hf/model\", tokenizer, quantization_method=\"q4_k_m\", token=\"\"\n",
    "    )\n",
    "\n",
    "# Save to multiple GGUF options - much faster if you want multiple!\n",
    "if False:\n",
    "    model.push_to_hub_gguf(\n",
    "        \"hf/model\",  # Change hf to your username!\n",
    "        tokenizer,\n",
    "        quantization_method=[\n",
    "            \"q4_k_m\",\n",
    "            \"q8_0\",\n",
    "            \"q5_k_m\",\n",
    "        ],\n",
    "        token=\"\",\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bhYfRM8PxWAu"
   },
   "source": [
    "Now, use the `model-unsloth.gguf` file or `model-unsloth-Q4_K_M.gguf` file in llama.cpp or a UI based system like Jan or Open WebUI. You can install Jan [here](https://github.com/janhq/jan) and Open WebUI [here](https://github.com/open-webui/open-webui)\n",
    "\n",
    "And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs or want to keep updated with the latest LLM stuff, or need help, join projects etc, feel free to join our Discord!\n",
    "\n",
    "Some other links:\n",
    "1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)\n",
    "2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)\n",
    "3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)\n",
    "6. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://docs.unsloth.ai/get-started/unsloth-notebooks)!\n",
    "\n",
    "<div class=\"align-center\">\n",
    "  <a href=\"https://unsloth.ai\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png\" width=\"115\"></a>\n",
    "  <a href=\"https://discord.gg/unsloth\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/Discord.png\" width=\"145\"></a>\n",
    "  <a href=\"https://docs.unsloth.ai/\"><img src=\"https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true\" width=\"125\"></a>\n",
    "\n",
    "  Join Discord if you need help + ⭐️ <i>Star us on <a href=\"https://github.com/unslothai/unsloth\">Github</a> </i> ⭐️\n",
    "</div>\n"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}