{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "wAwm-SoKNjoX" }, "source": [ "# Sample GRPO Training Notebook\n", "- Stolen from unsloth :3 thanks!\n" ] }, { "cell_type": "markdown", "metadata": { "id": "rv8V30lINjob" }, "source": [ "### Installation" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1iVpHRPXNjoc" }, "outputs": [], "source": [ "# @title Colab Extra Install { display-mode: \"form\" }\n", "import os\n", "\n", "if \"COLAB_\" not in \"\".join(os.environ.keys()):\n", " !pip install -q unsloth vllm\n", "else:\n", " !pip install -q --no-deps unsloth vllm\n", " # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]\n", " # Skip restarting message in Colab\n", " import sys, re, requests\n", "\n", " modules = list(sys.modules.keys())\n", " for x in modules:\n", " sys.modules.pop(x) if \"PIL\" in x or \"google\" in x else None\n", " !pip install -q --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo\n", " !pip install -q sentencepiece protobuf datasets huggingface_hub hf_transfer\n", "\n", " # vLLM requirements - vLLM breaks Colab due to reinstalling numpy\n", " f = requests.get(\n", " \"https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt\"\n", " ).content\n", " with open(\"vllm_requirements.txt\", \"wb\") as file:\n", " file.write(re.sub(rb\"(transformers|numpy|xformers)[^\\n]{1,}\\n\", b\"\", f))\n", " !pip install -q -r vllm_requirements.txt" ] }, { "cell_type": "markdown", "metadata": { "id": "HnqT9aF5Njoc" }, "source": [ "### Unsloth" ] }, { "cell_type": "markdown", "metadata": { "id": "IpAO-idyNjoc" }, "source": [ "Load up `Llama 3.1 8B Instruct`, and set parameters" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000, "referenced_widgets": [ "09fca7e3b90f4b17a2344b67100e26bf", "6eee6e73473440a39c2a14d11b19e7b9", 
"f64c0803b3c34610a34f9c5f2b5ce6ba", "4bae66db2fc041278f13ae1def5d1ddd", "ce588ed5880240d7bfd91c150ef299ed", "bc92bf99bcf345b59699eb8f0423f7cd", "f589e115f63d4fdb95816dccb34c3e3e", "fb45312477f64384912ba3e843efb0b3", "4e541334eb26485ca5764b9bf0bbbd34", "fe9dbcf4e85646cb8c684d82252b93e2", "d36f0dda703c47cfae6f544e7e655806", "cc9a242d278d4ea1b0f0f44d611439a2", "9ec23dd10394466eb6db5336912f09de", "e72af3cf2aac44d3a10b692b5fd54afe", "7a0041d1e5c349b4b70e96554ea18113", "f3fe4eb2a2fc4b2fb255632272dc0415", "b095d3eac38847c0b087ffcc5876e0b6", "9ebd8c869bc144cf8142375e7d7f04dd", "48403d21996b4ebfa931cf260af6c283", "b019cf9965424856b3fd781bd9e73118", "ff032a23470845ba95a12227baf0e0ce", "18d2fd154ffc4671b76f00760eb35571", "bc8f8134f144435681cbc50f9f5e5274", "d7a1980bf7cc4895aea14df80bf92302", "6be20c9739de44ffb7ef2465ad17424d", "39411144c4c44eae935262bc6312d5b9", "32f3e3f08eb1441ca70191eaa6191805", "2df204a0982b466eb3c25bf80645b6c1", "9ec592e894b945cea7c2a756a96eb959", "7ff47ae65e944dc9bbd50f01edb7fa02", "2f5bc26300174430aa5b0610b5ef4c72", "79da5ea5195c4c63b0ec1ee354511a87", "a5f3704359df48e4b4c5b418a5f72021", "2cae8f5f92304278bcfa4a07fb5aafdc", "63681e5808d1443c9f04654d58952448", "91200f911ca64162b006235d0ec0c5f1", "40c4dfad109643529a16b3de307e2e1a", "491466071e6742c6ab30419597d78fde", "a67ca6e6c91244978b6b468257ffefd7", "781f0652012b424aa43cd39028e3a99a", "9e4ae8a9e1b247ffbd232d97be270cd8", "c558b6c8e51046c6b3e0809ccf749348", "0ea62e14a2d44b8ca7c6f15bca8e1a2b", "27e78ef30ff64be5af79b6431816f569", "b973da5783dc49d1802831298be3ab7e", "d7e3f2c2441743089ee7ed10eb0319dd", "7f490169dd8a4dbeb367af5259b0949e", "2a8b54de6a974cc3a3f0b47daa10af89", "70780d08fb00445da32b28b40d33412f", "2c83de81fc91497db859d1be6860d6f4", "397a49fc365b4c3a95db709b71fc32e3", "1e75c59df9524dec928b21d65dd43c08", "747b264f119746c08bbd6685c134ec4f", "364d9414fd5f4280be4ad621874e15cb", "7b3ad76804bd47a68fba0bf9f967cb04", "2c34cb165ce447dca52c25eff7cc91ce", "e8f729e29f7c4639a10a02a9f4978385", 
"f60723e017f64a2183af5f05ca827340", "165ed67ab04240fca77e70e6a905ad48", "e229f6a91dcf4b5182673cc6c7002063", "fe4ecfe3522c432b964608f137cfb936", "e1e8c601a2bb4c0b8f0cdf3bf0dd0b2e", "db89c098dd414fa08440071935414b63", "efb3aff9a97340fb8acf5f98402e7c1d", "d4eeb17c3a124c919058c0ffe72a60b6", "8b6bd37f64454464b44691ae3bef0b9f", "239c67254f4044259fcbb35cb78d72b3", "7bdc3ab924804e6badd7f403d0ce33f1", "2b1ff4602dee4511a450588c501a7174", "bfae7687d2a440189c29dc53e47b71ab", "c72dbf9684ed4cf0a55655ebaf645b96", "7083c21ed3dd4ed6b442afdf4a18a05d", "6e4ce08c86f24beeb63d38d436143dc9", "86498c04c2a347968e65867eb4ed2102", "dc90aea1a6e445ebaf68bb305586bf47", "4c69d53cf33c4865af8feead04b810ad", "9dbc67661c614d8a82d0790d9b694e6d", "97f09dcebaf94be3b115a69b1440392b", "a568aff772814128b0ad329594e67610", "92ac502bff6e48d98cbb07a65d84dca6", "0b9dd1c421f14c56986cf30abe0a9455", "50a3cb7fc0ce45bb8c1baabc2a54fff2", "fac3506146244766b886c715ea57ec7b", "d068e96c067e43ffa1777a666b0df147", "9e9b1a12028443c9abe6d88d0ed39d15", "638c59bf00ba4536ad157975faa26984", "73177ee1d64c4f56aa393c45c42e7499", "98aabe0e4167414d943ff3d7759093ad", "0f040a5765854b699b32bf096bb0563c", "813e8e40df6447c18e47414ab38f69a2", "533f84f90e7247e8a60da8b6ce60fe89", "80c5f4171f2844749b5f3e67a46178b9", "0722b569fa2e45f9af7d4fb8134e632b", "115a8f3c523c4cd3aaa472af87fe0e4a", "26cde13f97db4ae4a0900fb5184a51a2", "5c4e90490762447c9712372a18f48e20", "690c64fbee9f4c038a29030086c73368", "cbe9c4e2f8bc49dfb1d8a4cddc36d724", "f33aef89be7641a694fb11e3101c7d85", "6798efcf0f784fca91a2e84c91553296", "190504fd7fec4f08ba43b972ed28d23b", "142703403e5e4ba1ac1531d7c82c752a", "0a45b8dd65764272869fd1424415e6af", "d536479db8034c99be7e745a440baf7d", "594a583be494419c845fd873904d4d6d", "ba21e0c745714031aae12b1665847e54", "3491a152831245f7b03d06fe94f42116", "b80d4e38b5c54dd585d1da848ee9c4c8", "fa155f662d1c40f496daef290ff7a5ff", "9b2a88f9712d409bbf74bf0c2e14df16" ] }, "id": "DkIvEkIIkEyB", "outputId": "e1aaf3b5-6ad0-45c7-be08-2017e74ebba9" }, "outputs": 
[], "source": [ "from unsloth import FastLanguageModel\n", "import torch\n", "\n", "max_seq_length = 1024 # Can increase for longer reasoning traces\n", "lora_rank = 32 # Larger rank = smarter, but slower\n", "\n", "model, tokenizer = FastLanguageModel.from_pretrained(\n", " model_name=\"meta-llama/meta-Llama-3.1-8B-Instruct\",\n", " max_seq_length=max_seq_length,\n", " load_in_4bit=True, # False for LoRA 16bit\n", " fast_inference=True, # Enable vLLM fast inference\n", " max_lora_rank=lora_rank,\n", " gpu_memory_utilization=0.6, # Reduce if out of memory\n", ")\n", "\n", "model = FastLanguageModel.get_peft_model(\n", " model,\n", " r=lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n", " target_modules=[\n", " \"q_proj\",\n", " \"k_proj\",\n", " \"v_proj\",\n", " \"o_proj\",\n", " \"gate_proj\",\n", " \"up_proj\",\n", " \"down_proj\",\n", " ], # Remove QKVO if out of memory\n", " lora_alpha=lora_rank,\n", " use_gradient_checkpointing=\"unsloth\", # Enable long context finetuning\n", " random_state=3407,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "7KGgPgk_5S8r" }, "source": [ "### Data Prep\n", "\n", "\n", "We directly leverage [@willccbb](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb) for data prep and all reward functions. You are free to create your own!" 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 321, "referenced_widgets": [ "5954f25ab3334f8c8e0576ee75922c0c", "5db92327420a48f4a1ad966ecd7ac8c5", "1a1651d5ca7d4ccd887235343e49ffb3", "22462b8181964257a53c85be6265eef6", "1c09f4dc86124b1296ed96fc48b8d856", "5e9d2ad532a8447a904688629a91f171", "2df38622a5ab46979ee53efeb8c31b91", "a75b75f7dec54759816845c3fda31f22", "b14a63b4cac142c5adaaebbfef8739fe", "1a090b75a50b459a8d2f822efc692dfa", "0da6c9fa4eeb4b4d832f8bbc614bc8d9", "3e1954a710bd43d1907b702c7f70c6de", "483db1483d1d4a97a59c1e6cdad04c6a", "cf88de7717504af691ca8db0fb68a6b6", "d4de07c4656c40329bc2c1450bd94bb5", "160f401dc9a74266bb610e747f8f711e", "269f762b5af041eab750226febe70e07", "f38c687b78224c8aa2227914998b4627", "eca5ebef74f341f79dc96964d0189862", "009ea1226a194cb2923ec9bc3a6e1843", "2aade13ff2ca4250ae02b7a12b1ef16d", "23531535f2fc4b99a5ecd1b63fb93d80", "a137a1bc2b4f437a8477e079f003ad8d", "05836f874c9a49d390721a05ff055ac9", "9032130a4a9646db92805eeb4221c83d", "ac58943ffd184cdca0b69cce35f54ec7", "8dc3c53fac5644a6a2841aca62226300", "fda3ed1250d34b12aae8401287322548", "e85f549d21dd4c3293f8239a053b0ad4", "a8d0ef0ef4f747e294394b04d07cbe34", "15b83f9309a642d2bb1bc79da539b923", "436282df7bcc45c4b951de60f2e4ec53", "911d088df0894243b69fd4f69c46e7f3", "eba35bc404f640bcb087de7a7ce74942", "6b39fce0206f4ca894b943a42eb9dc04", "c3ec8a75007846519d40f6726d2a01ee", "de3486da03594b0091b77cbf702401a5", "2e1b5b6b3ae84190a7f50993d78fb6ea", "81f43f522faa4dcf9e557793a2027e90", "ae46c8f17c0644fdbd1aa6d412c322b7", "5de480828775465bb76b674ecd713bd0", "a159d55b4de84020bc127d6c7758d2a9", "cc6097b25ec74decb51a397ccd9fcf27", "d2af3c19e6ec46a3915f6badf7062b7b", "8d2463ee53394dc097843bc733191c97", "f681a0ff692c4cd3b14b1e064547c52f", "0ecef6aade754029ba939e30c91538be", "87e25abbd016411199a2e976cd7ab550", "93f708b722194750a9f66a2597987033", "76dc9524ee764741a1615de6bc43492e", "78201e7953b647ee9656fc7e8f4abbcb", 
"0bc45e643a9b4a0e8db9f85180c108b2", "cfb5247cf65c4bdfbcf8207c5d8c72c8", "32fe42f8a11d4131a3a529b554f028fe", "39c33d513fb74712be8fcf73093a0c9e", "908333eedde74ec3a0885e52454b3627", "934617e655d04f85a87d6444b7c1ff0e", "63d0c36c4bb8469f8cf5ac410d2c11d6", "092d58ba2c444878a1cbfd19bdc14027", "c8ecf533f45f4ae29bdcfdaeca7daedf", "d89eae07da0e4344909900e24e3d0d09", "b1e00fc06b664fc59d4921388f187269", "b834a1ed2fc14d209c405969def96965", "dc681483ceeb4e878124858ffb467338", "ef2a701c9c594d4d9ffc8379fa9b5899", "22c31da7c0564f09bbdb04efdefad11a" ] }, "id": "cXk993X6C2ZZ", "outputId": "d6b161df-4d15-4cae-a02c-22fb901b91f1" }, "outputs": [], "source": [ "import re\n", "from datasets import load_dataset, Dataset\n", "\n", "# Load and prep dataset\n", "SYSTEM_PROMPT = \"\"\"\n", "Respond in the following format:\n", "<reasoning>\n", "...\n", "</reasoning>\n", "<answer>\n", "...\n", "</answer>\n", "\"\"\"\n", "\n", "XML_COT_FORMAT = \"\"\"\\\n", "<reasoning>\n", "{reasoning}\n", "</reasoning>\n", "<answer>\n", "{answer}\n", "</answer>\n", "\"\"\"\n", "\n", "\n", "def extract_xml_answer(text: str) -> str:\n", " answer = text.split(\"<answer>\")[-1]\n", " answer = answer.split(\"</answer>\")[0]\n", " return answer.strip()\n", "\n", "\n", "def extract_hash_answer(text: str) -> str | None:\n", " if \"####\" not in text:\n", " return None\n", " return text.split(\"####\")[1].strip()\n", "\n", "\n", "# uncomment middle messages for 1-shot prompting\n", "def get_gsm8k_questions(split=\"train\") -> Dataset:\n", " data = load_dataset(\"openai/gsm8k\", \"main\")[split] # type: ignore\n", " data = data.map(\n", " lambda x: { # type: ignore\n", " \"prompt\": [\n", " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n", " {\"role\": \"user\", \"content\": x[\"question\"]},\n", " ],\n", " \"answer\": extract_hash_answer(x[\"answer\"]),\n", " }\n", " ) # type: ignore\n", " return data # type: ignore\n", "\n", "\n", "dataset = get_gsm8k_questions()\n", "\n", "\n", "# Reward functions\n", "def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:\n", " responses = 
[completion[0][\"content\"] for completion in completions]\n", " q = prompts[0][-1][\"content\"]\n", " extracted_responses = [extract_xml_answer(r) for r in responses]\n", " print(\n", " \"-\" * 20,\n", " f\"Question:\\n{q}\",\n", " f\"\\nAnswer:\\n{answer[0]}\",\n", " f\"\\nResponse:\\n{responses[0]}\",\n", " f\"\\nExtracted:\\n{extracted_responses[0]}\",\n", " )\n", " return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]\n", "\n", "\n", "def int_reward_func(completions, **kwargs) -> list[float]:\n", " responses = [completion[0][\"content\"] for completion in completions]\n", " extracted_responses = [extract_xml_answer(r) for r in responses]\n", " return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]\n", "\n", "\n", "def strict_format_reward_func(completions, **kwargs) -> list[float]:\n", " \"\"\"Reward function that checks if the completion has a specific format.\"\"\"\n", " pattern = r\"^<reasoning>\\n.*?\\n</reasoning>\\n<answer>\\n.*?\\n</answer>\\n$\"\n", " responses = [completion[0][\"content\"] for completion in completions]\n", " matches = [re.match(pattern, r) for r in responses]\n", " return [0.5 if match else 0.0 for match in matches]\n", "\n", "\n", "def soft_format_reward_func(completions, **kwargs) -> list[float]:\n", " \"\"\"Reward function that checks if the completion has a specific format.\"\"\"\n", " pattern = r\"<reasoning>.*?</reasoning>\\s*<answer>.*?</answer>\"\n", " responses = [completion[0][\"content\"] for completion in completions]\n", " matches = [re.match(pattern, r) for r in responses]\n", " return [0.5 if match else 0.0 for match in matches]\n", "\n", "\n", "def count_xml(text) -> float:\n", " count = 0.0\n", " if text.count(\"<reasoning>\\n\") == 1:\n", " count += 0.125\n", " if text.count(\"\\n</reasoning>\\n\") == 1:\n", " count += 0.125\n", " if text.count(\"\\n<answer>\\n\") == 1:\n", " count += 0.125\n", " count -= len(text.split(\"\\n</answer>\\n\")[-1]) * 0.001\n", " if text.count(\"\\n</answer>\") == 1:\n", " count += 0.125\n", " count -= (len(text.split(\"\\n</answer>\")[-1]) - 1) * 0.001\n", " return count\n", "\n", "\n", 
"def xmlcount_reward_func(completions, **kwargs) -> list[float]:\n", " contents = [completion[0][\"content\"] for completion in completions]\n", " return [count_xml(c) for c in contents]" ] }, { "cell_type": "markdown", "metadata": { "id": "Ux6iqP7z5YOo" }, "source": [ "\n", "### Train the model\n", "\n", "Now set up GRPO Trainer and all configurations!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ptqkXK2D4d6p", "outputId": "1e7b980b-c662-49c9-fb3a-d929edbafa09" }, "outputs": [], "source": [ "max_prompt_length = 256\n", "\n", "from trl import GRPOConfig, GRPOTrainer\n", "\n", "training_args = GRPOConfig(\n", " learning_rate=5e-6,\n", " adam_beta1=0.9,\n", " adam_beta2=0.99,\n", " weight_decay=0.1,\n", " warmup_ratio=0.1,\n", " lr_scheduler_type=\"cosine\",\n", " optim=\"paged_adamw_8bit\",\n", " logging_steps=1,\n", " per_device_train_batch_size=1,\n", " gradient_accumulation_steps=1, # Increase to 4 for smoother training\n", " num_generations=6, # Decrease if out of memory\n", " max_prompt_length=max_prompt_length,\n", " max_completion_length=max_seq_length - max_prompt_length,\n", " # num_train_epochs = 1, # Set to 1 for a full training run\n", " max_steps=250,\n", " save_steps=250,\n", " max_grad_norm=0.1,\n", " report_to=\"none\", # Can use Weights & Biases\n", " output_dir=\"outputs\",\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "r9Mv8UZO5hz-" }, "source": [ "And let's run the trainer! If you scroll up, you'll see a table of rewards. The goal is to see the `reward` column increase!\n", "\n", "You might have to wait 150 to 200 steps for any action. You'll probably get 0 reward for the first 100 steps. 
Please be patient!\n", "\n", "| Step | Training Loss | reward | reward_std | completion_length | kl |\n", "|------|---------------|-----------|------------|-------------------|----------|\n", "| 1 | 0.000000 | 0.125000 | 0.000000 | 200.000000 | 0.000000 |\n", "| 2 | 0.000000 | 0.072375 | 0.248112 | 200.000000 | 0.000000 |\n", "| 3 | 0.000000 | -0.079000 | 0.163776 | 182.500000 | 0.000005 |\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "vzOuSVCL_GA9", "outputId": "09405367-ccea-4e1c-aabf-2dbb9318c4fa" }, "outputs": [], "source": [ "trainer = GRPOTrainer(\n", " model=model,\n", " processing_class=tokenizer,\n", " reward_funcs=[\n", " xmlcount_reward_func,\n", " soft_format_reward_func,\n", " strict_format_reward_func,\n", " int_reward_func,\n", " correctness_reward_func,\n", " ],\n", " args=training_args,\n", " train_dataset=dataset,\n", ")\n", "trainer.train()" ] }, { "cell_type": "markdown", "metadata": { "id": "tlaUdxC_VHpz" }, "source": [ "\n", "### Inference\n", "Now let's try the model we just trained! 
First, let's first try the model without any GRPO trained:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 249 }, "id": "qtcz_lpbVC92", "outputId": "70e4f329-acac-4d31-a8cd-47f7c6088747" }, "outputs": [], "source": [ "text = tokenizer.apply_chat_template(\n", " [\n", " {\"role\": \"user\", \"content\": \"Calculate pi.\"},\n", " ],\n", " tokenize=False,\n", " add_generation_prompt=True,\n", ")\n", "\n", "from vllm import SamplingParams\n", "\n", "sampling_params = SamplingParams(\n", " temperature=0.8,\n", " top_p=0.95,\n", " max_tokens=1024,\n", ")\n", "output = (\n", " model.fast_generate(\n", " [text],\n", " sampling_params=sampling_params,\n", " lora_request=None,\n", " )[0]\n", " .outputs[0]\n", " .text\n", ")\n", "\n", "output" ] }, { "cell_type": "markdown", "metadata": { "id": "Colxz9TAVMsi" }, "source": [ "And now with the LoRA we just trained with GRPO - we first save the LoRA first!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AL-BcuB1VLIv" }, "outputs": [], "source": [ "model.save_lora(\"grpo_saved_lora\")" ] }, { "cell_type": "markdown", "metadata": { "id": "CwpbwnDBVRLg" }, "source": [ "Now we load the LoRA and test:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 249 }, "id": "zf_OY5WMVOxF", "outputId": "22373f16-a3bc-4c99-8a1d-1994522a5f0f" }, "outputs": [], "source": [ "text = tokenizer.apply_chat_template(\n", " [\n", " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n", " {\"role\": \"user\", \"content\": \"Calculate pi.\"},\n", " ],\n", " tokenize=False,\n", " add_generation_prompt=True,\n", ")\n", "\n", "from vllm import SamplingParams\n", "\n", "sampling_params = SamplingParams(\n", " temperature=0.8,\n", " top_p=0.95,\n", " max_tokens=1024,\n", ")\n", "output = (\n", " model.fast_generate(\n", " text,\n", " sampling_params=sampling_params,\n", " 
lora_request=model.load_lora(\"grpo_saved_lora\"),\n", " )[0]\n", " .outputs[0]\n", " .text\n", ")\n", "\n", "output" ] }, { "cell_type": "markdown", "metadata": { "id": "6aDgFfhFYIAS" }, "source": [ "Our reasoning model is much better - it's not always correct, since we only trained it for an hour or so - it'll be better if we extend the sequence length and train for longer!" ] }, { "cell_type": "markdown", "metadata": { "id": "-NUEmHFSYNTp" }, "source": [ "\n", "### Saving to float16 for VLLM\n", "\n", "We also support saving to `float16` directly. Select `merged_16bit` for float16 or `merged_4bit` for int4. We also allow `lora` adapters as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NjXGTkp7YNtB" }, "outputs": [], "source": [ "# Merge to 16bit\n", "if False:\n", " model.save_pretrained_merged(\n", " \"model\",\n", " tokenizer,\n", " save_method=\"merged_16bit\",\n", " )\n", "if False:\n", " model.push_to_hub_merged(\n", " \"hf/model\", tokenizer, save_method=\"merged_16bit\", token=\"\"\n", " )\n", "\n", "# Merge to 4bit\n", "if False:\n", " model.save_pretrained_merged(\n", " \"model\",\n", " tokenizer,\n", " save_method=\"merged_4bit\",\n", " )\n", "if False:\n", " model.push_to_hub_merged(\"hf/model\", tokenizer, save_method=\"merged_4bit\", token=\"\")\n", "\n", "# Just LoRA adapters\n", "if False:\n", " model.save_pretrained_merged(\n", " \"model\",\n", " tokenizer,\n", " save_method=\"lora\",\n", " )\n", "if False:\n", " model.push_to_hub_merged(\"hf/model\", tokenizer, save_method=\"lora\", token=\"\")" ] }, { "cell_type": "markdown", "metadata": { "id": "52WMb3k_YPt8" }, "source": [ "### GGUF / llama.cpp Conversion\n", "To save to `GGUF` / `llama.cpp`, we support it natively now! We clone `llama.cpp` and we default save it to `q8_0`. We allow all methods like `q4_k_m`. 
Use `save_pretrained_gguf` for local saving and `push_to_hub_gguf` for uploading to HF.\n", "\n", "Some supported quant methods (full list on our [Wiki page](https://github.com/unslothai/unsloth/wiki#gguf-quantization-options)):\n", "* `q8_0` - Fast conversion. High resource use, but generally acceptable.\n", "* `q4_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K.\n", "* `q5_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K.\n", "\n", "[**NEW**] To finetune and auto export to Ollama, try our [Ollama notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QyEjW-WuYQIm" }, "outputs": [], "source": [ "# Save to 8bit Q8_0\n", "if False:\n", " model.save_pretrained_gguf(\n", " \"model\",\n", " tokenizer,\n", " )\n", "# Remember to go to https://huggingface.co/settings/tokens for a token!\n", "# And change hf to your username!\n", "if False:\n", " model.push_to_hub_gguf(\"hf/model\", tokenizer, token=\"\")\n", "\n", "# Save to 16bit GGUF\n", "if False:\n", " model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"f16\")\n", "if False:\n", " model.push_to_hub_gguf(\"hf/model\", tokenizer, quantization_method=\"f16\", token=\"\")\n", "\n", "# Save to q4_k_m GGUF\n", "if False:\n", " model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"q4_k_m\")\n", "if False:\n", " model.push_to_hub_gguf(\n", " \"hf/model\", tokenizer, quantization_method=\"q4_k_m\", token=\"\"\n", " )\n", "\n", "# Save to multiple GGUF options - much faster if you want multiple!\n", "if False:\n", " model.push_to_hub_gguf(\n", " \"hf/model\", # Change hf to your username!\n", " tokenizer,\n", " quantization_method=[\n", " \"q4_k_m\",\n", " \"q8_0\",\n", " \"q5_k_m\",\n", " ],\n", " token=\"\",\n", " )" ] }, { "cell_type": "markdown", "metadata": { 
"id": "upXTZ_0aNjof" }, "source": [ "Now, use the `model-unsloth.gguf` file or `model-unsloth-Q4_K_M.gguf` file in llama.cpp or a UI based system like Jan or Open WebUI. You can install Jan [here](https://github.com/janhq/jan) and Open WebUI [here](https://github.com/open-webui/open-webui)\n", "\n", "And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs or want to keep updated with the latest LLM stuff, or need help, join projects etc, feel free to join our Discord!\n", "\n", "Some other links:\n", "1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)\n", "2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)\n", "3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)\n", "6. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://docs.unsloth.ai/get-started/unsloth-notebooks)!\n", "\n", "
\n", " \n", " \n", " \n", "\n", " Join Discord if you need help + ⭐️ Star us on Github ⭐️\n", "
\n" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [] }, "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" } }, "nbformat": 4, "nbformat_minor": 0 }