{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Oh3eu2ULxWAn" }, "source": [ "To run this, press \"*Runtime*\" and then \"*Run all*\" on a **free** Tesla T4 Google Colab instance!\n", "\n", "[Join Discord](https://discord.gg/unsloth) if you need help + ⭐ Star us on [Github](https://github.com/unslothai/unsloth) ⭐\n", "\n", "To install Unsloth on your own computer, follow the installation instructions on our Github page [here](https://docs.unsloth.ai/get-started/installing-+-updating).\n", "\n", "You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), and [how to save it](#Save).\n" ] }, { "cell_type": "markdown", "metadata": { "id": "rNxpti3ExWAp" }, "source": [ "### News" ] }, { "cell_type": "markdown", "metadata": { "id": "yo-qTlvNxWAp" }, "source": [ "**Read our [Gemma 3 blog](https://unsloth.ai/blog/gemma3) for what's new in Unsloth and our [Reasoning blog](https://unsloth.ai/blog/r1-reasoning) on how to train reasoning models.**\n", "\n", "Visit our docs for all our [model uploads](https://docs.unsloth.ai/get-started/all-our-models) and [notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks).\n" ] }, { "cell_type": "markdown", "metadata": { "id": "MrkTn5S9xWAq" }, "source": [ "### Installation" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "NV0BlxygxWAq" }, "outputs": [], "source": [ "%%capture\n", "import os\n", "\n", "if \"COLAB_\" not in \"\".join(os.environ.keys()):\n", "    !pip install unsloth vllm\n", "else:\n", "    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]\n", "    !pip install --no-deps unsloth vllm" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "8lIsllEIxWAr" }, "outputs": [], "source": [ "# @title Colab Extra Install { display-mode: \"form\" }\n", "%%capture\n", "import os\n", "\n", "if \"COLAB_\" not in \"\".join(os.environ.keys()):\n", "    !pip install unsloth vllm\n", "else:\n", "    !pip install --no-deps unsloth vllm\n", "    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]\n", "    # Skip restarting message in Colab\n", "    import sys\n", "    import re\n", "    import requests\n", "\n", "    modules = list(sys.modules.keys())\n", "    for x in modules:\n", "        sys.modules.pop(x) if \"PIL\" in x or \"google\" in x else None\n", "    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo\n", "    !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer\n", "\n", "    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy\n", "    f = requests.get(\n", "        \"https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt\"\n", "    ).content\n", "    with open(\"vllm_requirements.txt\", \"wb\") as file:\n", "        file.write(re.sub(rb\"(transformers|numpy|xformers)[^\\n]{1,}\\n\", b\"\", f))\n", "    !pip install -r vllm_requirements.txt" ] },
{ "cell_type": "markdown", "metadata": { "id": "a6uTCUXFxWAr" }, "source": [ "### Unsloth" ] }, { "cell_type": "markdown", "metadata": { "id": "Joje4qPsyxM9" }, "source": [ "Load `Qwen 2.5 3B Instruct` and set the parameters." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000, "referenced_widgets": [ "b46d7118f71043109c2dd00fb0f40557", "32defd049d8a473c8ad8fc954c0e3d13", "424078e3ad6c405c925bedfb76d18b3f", "8d937714a44d422aa9851c02c2345e5c", "25b26c398e314cb79a2f8bbb7e32e158", "694494752ab2458ead83b0331af20122", "1cc89fa4d7c744c69c3be1a3d0b3ba2d", "15116ef21fae463ba9013c8b9dbd40b5", "43fdf1fe592147538ec43410a1bf8e09", "52ba857bdfc742a2aae9e9a949054e5c", "589a3857411f49d4ae019465f536da45", "5dafeefcf436462e81a9affe7f6c66d9", "98d1697fb2054fd88432f027e592ecc0", "b1adc20d1d8740f99a6ebd589f240688", "f45a696d73054cd6af181a7ce35a62e4", "372b687ac774400088d7252467f9cef2", "a422e26291bb47d08e36b7d94fffcc9b", "799dc8fe6f7d4e4db68ebd6b1d8fdbd5", "e61a718f06a045b6b028fbd33382d37a", "7c712c0011954872834e90cc908b8cf6", 
"9929aaffe7614506ae934a14d2b5c3ca", "1060550e9dec4fa59e2b9b488af61b48", "ebd36fe7f8dd47b99e69176ace9b985d", "42a3c5c410cd49a38400b788980ad428", "e5373b31cbfc4e87af314bf4898d2b72", "4379f29ff4a74df784bb2fb3989c64cc", "87b42e656263458cbb33ee3ca29128a8", "b245d8e4c115438583ef8b433d4d16be", "df7f9d244f70440aaf401c6a856a6ed0", "5d39d9e8cb724af59d7af29ec9c41cf4", "6292f7a4de1948d09bebaf5064fa0d42", "62cbedb2d67d4a2c8888b46e56e39f65", "57f2c072b0574f74a7dbe2a9d303cbd4", "78345adf6d8143e3a649a05c5d476115", "f89bfff6526c43bcb75c93f987bfba72", "a7244280fc934b8484b3dfd7eecca161", "ec7ae2abc92f49f98ab0c52953bf82af", "092355509a67474e9cc13a6b4718a192", "c69730dfeec14cedaf8917ee9ad27dcb", "00177cac24374c7d9144e3856594ddaf", "1b1e49195c5b48a1b6e4364db01c5c87", "97af130f6bd54d618e0af88ed9e85819", "1f88a409f3c64ea291861837e830ce39", "e9881edb2c7846fea2070c8e724b25cb", "dfc8cc5ee6704f1fb3b4df0507ea471b", "16bb453bd4934bd9840a9d365be604cd", "1ee1a56d04fa4e85a20e4cde46c164e5", "9872da92026e4debb96b731c892408e2", "451d181b9ef24b0e9fc3914a4466b9f1", "6a49aa5a2b3646d4b772e2398eacf72d", "edeb0f283b0b495eab6e06a15fa38af0", "0ed98c973a9d4d30b31ab751a7ecf3f6", "a0b7dab4f65e448f816e00d12dc3b469", "3b6d4b41992b43acac5daa31f119cd30", "c5b4960fc3ba4459a6eaff43772050a3", "71f9e574addd491cbbe65ddbe0eb23ad", "442e58ceec4e44a3a4a2ebadd145ee73", "4382289897b54e63b22541d0b45c9bf3", "ba6998ac52544199b4e828d1a7f99d0d", "29399d40d44d46c8b3648218e8378e89", "c6b2d407256549a5bc7543cb6d7475b4", "545251749b10408d82c9a60b914dce34", "1bbadef2fca5427fa1a4975153968f56", "d267ec5dd1954b9a88e566dff4a9af9d", "4ddc8fa0c9834cd3b36274a9e0e05504", "ec13219478284bbdbe69cbcd3dff5c34", "562637f216114f42aa352cabd59cab57", "e37bcfc1604b4a788f53825c5a7e04b7", "16e05d45084c44b68697316bc8791b88", "9b80621f94a0484791b33a13b3c6bade", "fc87fbccba304554be90701b39429918", "aaf79572906c4443b5542bef5dee8055", "bc81dfa869f94f059545a7d59b7eb70a", "b3c84c45c0804e4cb703f9c512c6f90d", "6520db76749d4a8aaa6d26005f07c43c", 
"29a3f867c8bd4769b8e649b8912776a2", "2317c9a94dfe4a62a8605038ebc2210a", "bd27949a1e3c43e5912a01b04b6621a4", "94ed75e3a9fe4cbdb373d4038af0c32b", "adfd8a554ada48f9a3dc2fab69467b79", "0f9dcd9b698b4af699cfafdf2874612c", "f8700d095ee04a40a15c6a88d7daf1ae", "1ccb12282661417383a7f96bd5ea4eac", "32a16b0cbbe548f4869cf665d737b9ba", "3f546bf0aa124e45b39daa0328a1d23d", "61f6966f8b6d4070a9ea3604373cb96a", "f8198301b2db406685189589e59efd9a", "e813e26451e3476a82a9b75668b638a1", "8aa00f7a2de64af8b0a5b8cd5fefb14e", "3404fb0c4c4b4ca0ae6d9eeec6722e02", "f268eda021e040e6b25889e72ee8dd1e", "b61f7faf19594e3c81e1429045f341de", "a4ecf424e6ed44c597f25acee7b43df2", "c314f7a0c6084790a2620ec231297525", "950c1b85a8bd456ab928df4cc3710b21", "806b17b8ce074563941dbba9ca2c9422", "0b42f72fac844c19adea7b6689e79432", "37e2f38471d4481eb93031ca52bdb785", "067aaf24d87a400b9f636762051598f3", "a03f0a986af84eacbb5c333f49726105", "abf0e1d60e9f4fadbcda876dd28f6d14", "81a7650cc4124b58b4e191093a5433e1", "74f24411d35e4c918c608a1afd799156", "a8501b8551d4431fbb708d6556e7dd85", "24757877452d48f0afe92bb66e2c48ac", "b64be6a0d6df44899ba6fe0a47f676e1", "c1588f6fe9ca48a5a7512734f26ae9a3", "374765208bf345089f58e925d2a70984", "78774135f2d1468d9172fa7588afef20", "9b3233c52ecc4b259deed2d5d6f332eb", "2567e718b28f417e834a10c5310f92b2", "587ea317bd0a4831a9417a3980207a73", "9e857c0cb85d4b50ad73c670436b48a5", "f8705a1680534caa801dc6e14c3601e3", "7e982757835048ae85abf6cba5d27ecb", "e14a085ca7e6410ca78fe0f27533be6d", "6d204a683e2d49deb5f504f91ee72044", "6659b3fef8644c6e96ed9d43adb744d5", "4c130bb196db422c8b9c309cbe59bac1", "32f42609c389450bbaad10f0a884c527", "a9e874a076234974a102b02551ec8666", "e505a51d928542b883743a6d48c5cbd3", "32ae8645c975473b81788c17c376367a", "7cb758be0be64a48930379b2cfa9fb44", "fb4bf57c071d4ef68046988a4965b312", "f1af6e23e42c4051a1c7fd4864450eaa", "5cf88ae3c2fc4419a71eb45f28f67384", "327fa24e70ee4cae903fe5673bcf3bfc", "d1d5afc36de043afb8638dec0001288e", "fffe74f9a1184de5b1297c659c7891b6", 
"878702a019e34789b5b2f68eb321f2c8", "91998be216844c079144c18278f660d7", "5a1c360e82204ad99238901d366828cf", "fc98ab1d58d84f69b6c36dbc19f5edd7", "e2e86147b7d44f66a2cdfe02d63ee141", "2f5178d399ec48648b435bac95faf188", "c13ff304d7bc4c67a1f51da5d2303f96", "0e8c4086734c4296884db9c5b67ddd33", "c3a97f3130bc42c1be2f8c7dc81ca6cc", "0043184800574bfdbfb3c82f4842c3ed", "32719fe47c8543f69bcddc9b9bc0f719", "784a5c0d201d46459a160e0dc8857fd9", "d753c1ecc6034237ba4401a007451bfb", "55fd60765d5e46c0baf1ac075820fc87", "771af8b28988459ba4f2dbaadd0d5ce1", "0cc22299405f4361b1b5f6234f35bae3", "0f22a4136dce46f89e9efbf0581057a2", "5cda6bd66e3248c4be477f1c1b40198d", "269b131121554544a21d47596a24aa58", "73758a3ca9c14d7eaa6504d8b9b6e333", "5a16728cee0f4b75be08696261d8d8cc", "da77a7b205b1446a9dacab0550aafebe", "d6bd8a9af6e84500956bb565f2fe2c4e", "6643bd32eb984a929159c8d441530aaf", "8358f98f790d47fdb783eb5ca210c1fc", "0c94630c2693424d9edd8b9921333aac", "edf185cd11544b2c8bce478fb2f0b565", "a7b4b7d3f1a34b089d72eb2a4ab169b0", "09a607655a054ed5ad7f36716f8ee4c6", "9258ed2d94924de3b944105e280fbf5f", "1c68fbbc833d4c3fbbb686f15b9ba939", "fa996ce2f8f24a5d80914c4ef5b6e788", "f77e9e785cc04c00822fc8b7ef2959da", "5d190801b10b49a698c3a69c7cd1f253", "0112cfe0a27e43fd9509ef8da050e188", "4c6b2159bbc04b8d811fadff3a22f4a6", "c705d699568543829441df2258f2e615", "eb8f921c8d67442b9bc2a6a67930d94c", "8a89212807f243d4b135a3188f4fb6d0", "cc018dfe29144d0182058816134e1fd9", "c404db8298364742b8ce076cddcb7641", "03aca6d92597461481419eac9bdb6635", "59b90e05822f480793685902b9c6d055", "d13cc9053e7a4982b6fc7f00474f25c7", "c9b781db7ee2422b8dc4eb6aa2575d87", "5100eb13df1c453980942ec78d0e810b" ] }, "id": "DkIvEkIIkEyB", "outputId": "8bf13df5-67ab-4ea2-db01-5719298b2dc9" }, "outputs": [], "source": [ "from unsloth import FastLanguageModel, is_bfloat16_supported\n", "\n", "max_seq_length = 1024 # Can increase for longer reasoning traces\n", "lora_rank = 64 # Larger rank = smarter, but slower\n", "\n", "model, tokenizer = 
FastLanguageModel.from_pretrained(\n", "    model_name=\"Qwen/Qwen2.5-3B-Instruct\",\n", "    max_seq_length=max_seq_length,\n", "    load_in_4bit=True,  # False for LoRA 16bit\n", "    fast_inference=True,  # Enable vLLM fast inference\n", "    max_lora_rank=lora_rank,\n", "    gpu_memory_utilization=0.5,  # Reduce if out of memory\n", ")\n", "\n", "model = FastLanguageModel.get_peft_model(\n", "    model,\n", "    r=lora_rank,  # Choose any number > 0! Suggested: 8, 16, 32, 64, 128\n", "    target_modules=[\n", "        \"q_proj\",\n", "        \"k_proj\",\n", "        \"v_proj\",\n", "        \"o_proj\",\n", "        \"gate_proj\",\n", "        \"up_proj\",\n", "        \"down_proj\",\n", "    ],  # Remove QKVO if out of memory\n", "    lora_alpha=lora_rank,\n", "    use_gradient_checkpointing=\"unsloth\",  # Enable long context finetuning\n", "    random_state=3407,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "0Y56ln_izS9E" }, "source": [ "<a name=\"Data\"></a>\n", "### Data Prep\n", "\n", "We directly reuse [@willccbb](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb)'s script for data prep and all the reward functions. You are free to create your own!" 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 209, "referenced_widgets": [ "6b2cf7a462884bc4a3c2f8670f8982ed", "8a358b880c194e6baab24f711937f695", "0ec5dfa2fa6b4be2bfb45692ac20c053", "1a619a0a53c149c3b9b0cbd53f1d2034", "fdcbb976f35e4aa383d9ebcd7304849f", "b131891a909743f5826429a84a74b4f9", "3745da58146d445abe863a248e86d9be", "22d29d39ef8149dba932ed9af68def67", "413b4818b6de4bd995944df2f73bc3d5", "c626fa7f5c624ec5a414e6f58d71ea96", "2a8fb4b980ab45be81ae39d5c6b7bad4", "ec1837ba407c4ecb86335d9c0c239ad9", "1793da35014840b08d776b4abc9a1212", "a06bde82466c43a49b58c0d942071be0", "a1a3ee663a154319ba186954816e6932", "97a9db845a9c45caa88ea1cd0ab6c40d", "231be59acb2a4921ba53ba97bdf5888a", "5aa0e699560c4fe89250b4ccac6401cb", "7529f4d84b9f46709ec5f0e22b75baf4", "8b8fd84911e94b5d8d627911fa9a64f7", "420154f0bbe344a7979769a8ef066df7", "1ee34b9171604805887aee3baa4ff099", "3c95b86a26bc49b4a65589df0e7865e2", "ef6d7a6a7fb344fab5befe656b98c7b7", "c3d523c23ebf4b769c59eda950a3fc1f", "bffc24c5bcdb47988f2e309c7fbc21b8", "11052ab3d1f441248458cd6ab8b810db", "0c2e39b20d5941969101b793b94fdbfa", "f0c74de390814c2b9f38f7cc02458426", "6c9502ae4cbc4248b14590fc6efe832f", "8972cb9e823342bcb5c493e487dc64b6", "71898e427b1245d682178a90c6752e8e", "90295f1308054c66b557da231d25d3d5", "54f444e6a25349fcb8fb35b9efb33e68", "2f9a64ad318248bf99fccdedafaddfc0", "d3ce3f3bee5646d893242758444acd6a", "5b4f99c20a8a4ecf873b0a8ef8b150cd", "70e141512e91492bae4713400879d019", "7c67cf4d83e34cd6a475405be45f345e", "c184e417ad134e94995a55d4bc7c94b5", "f1c8d7356b584fb8965d2a62ccec52c8", "9bcb0000da6e434d9b241a09d9b46244", "11bd7b89fa794a899db9c83cb2a99df5", "0e9d3d2244a54eb781311c01965c33e4", "00d1bd9a327241e8b434ef7b3bde30df", "cd17f032bbb44575977797caaa92bd54", "04b74b7851334892b0db68bea1c8fd7e", "44b6ecc9b9fc436e92701ec2409e0dd1", "46998bc2b1ee4858a8ee1108ce36b9b7", "95fa6a54e7e543c7b2dc5d5cf04daf58", "8ee15a3d02954d31899d0b6a8dbf7988", 
"1d5a58d63c51427bbbfb729216ea8468", "e2790d1e5e714b3e9081b4edcef24535", "96519710bf5047239f6800100471ee6e", "e9b020ce2dda442b871b41e417210b52", "4d18ef3f7c8046ddb0b03314861c3414", "176ccdb40b994ee3b381f7ea0f379c3c", "3519fdf3abac4d7690c45710f617e5d0", "dc320ae9e2f7426588d49848b0871ca4", "6b5a1287c457454fb367e5dc03d8b157", "989519f3e61b4132b3f88fbdf29703a9", "e0f87f6fd7c84261818624defe575a1c", "e6d2b6aaa9794971b01c778b1a0596db", "334c4833051e4ccea727dc2cdf3871c3", "d36027362b41413fa3ee4670dd820a88", "d94e3315b99747c39190fee9b79f00ac" ] }, "id": "cXk993X6C2ZZ", "outputId": "2d545127-0313-405e-d8f9-a43d7dba8d11" }, "outputs": [], "source": [ "import re\n", "from datasets import load_dataset, Dataset\n", "\n", "# Load and prep dataset\n", "SYSTEM_PROMPT = \"\"\"\n", "Respond in the following format:\n", "<reasoning>\n", "...\n", "</reasoning>\n", "<answer>\n", "...\n", "</answer>\n", "\"\"\"\n", "\n", "XML_COT_FORMAT = \"\"\"\\\n", "<reasoning>\n", "{reasoning}\n", "</reasoning>\n", "<answer>\n", "{answer}\n", "</answer>\n", "\"\"\"\n", "\n", "\n", "def extract_xml_answer(text: str) -> str:\n", "    answer = text.split(\"<answer>\")[-1]\n", "    answer = answer.split(\"</answer>\")[0]\n", "    return answer.strip()\n", "\n", "\n", "def extract_hash_answer(text: str) -> str | None:\n", "    if \"####\" not in text:\n", "        return None\n", "    return text.split(\"####\")[1].strip()\n", "\n", "\n", "# Each example becomes a system + user chat prompt plus the gold answer after \"####\"\n", "def get_gsm8k_questions(split=\"train\") -> Dataset:\n", "    data = load_dataset(\"openai/gsm8k\", \"main\")[split]  # type: ignore\n", "    data = data.map(\n", "        lambda x: {  # type: ignore\n", "            \"prompt\": [\n", "                {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n", "                {\"role\": \"user\", \"content\": x[\"question\"]},\n", "            ],\n", "            \"answer\": extract_hash_answer(x[\"answer\"]),\n", "        }\n", "    )  # type: ignore\n", "    return data  # type: ignore\n", "\n", "\n", "dataset = get_gsm8k_questions()\n", "\n", "\n", "# Reward functions\n", "def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:\n", "    responses = [completion[0][\"content\"] for completion in completions]\n", "    q = prompts[0][-1][\"content\"]\n", "    extracted_responses = [extract_xml_answer(r) for r in responses]\n", "    print(\n", "        \"-\" * 20,\n", "        f\"Question:\\n{q}\",\n", "        f\"\\nAnswer:\\n{answer[0]}\",\n", "        f\"\\nResponse:\\n{responses[0]}\",\n", "        f\"\\nExtracted:\\n{extracted_responses[0]}\",\n", "    )\n", "    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]\n", "\n", "\n", "def int_reward_func(completions, **kwargs) -> list[float]:\n", "    responses = [completion[0][\"content\"] for completion in completions]\n", "    extracted_responses = [extract_xml_answer(r) for r in responses]\n", "    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]\n", "\n", "\n", "def strict_format_reward_func(completions, **kwargs) -> list[float]:\n", "    \"\"\"Reward function that checks if the completion exactly matches the XML format.\"\"\"\n", "    pattern = r\"^<reasoning>\\n.*?\\n</reasoning>\\n<answer>\\n.*?\\n</answer>\\n$\"\n", "    responses = [completion[0][\"content\"] for completion in completions]\n", "    # re.DOTALL lets .*? span multi-line reasoning\n", "    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]\n", "    return [0.5 if match else 0.0 for match in matches]\n", "\n", "\n", "def soft_format_reward_func(completions, **kwargs) -> list[float]:\n", "    \"\"\"Reward function that checks if the completion loosely follows the XML format.\"\"\"\n", "    pattern = r\"<reasoning>.*?</reasoning>\\s*<answer>.*?</answer>\"\n", "    responses = [completion[0][\"content\"] for completion in completions]\n", "    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]\n", "    return [0.5 if match else 0.0 for match in matches]\n", "\n", "\n", "def count_xml(text) -> float:\n", "    count = 0.0\n", "    if text.count(\"<reasoning>\\n\") == 1:\n", "        count += 0.125\n", "    if text.count(\"\\n</reasoning>\\n\") == 1:\n", "        count += 0.125\n", "    if text.count(\"\\n<answer>\\n\") == 1:\n", "        count += 0.125\n", "        # Penalize any text after the closing </answer> tag\n", "        count -= len(text.split(\"\\n</answer>\\n\")[-1]) * 0.001\n", "    if text.count(\"\\n</answer>\") == 1:\n", "        count += 0.125\n", "        count -= (len(text.split(\"\\n</answer>\")[-1]) - 1) * 0.001\n", "    return count\n", "\n", "\n", 
"def xmlcount_reward_func(completions, **kwargs) -> list[float]:\n", "    contents = [completion[0][\"content\"] for completion in completions]\n", "    return [count_xml(c) for c in contents]" ] }, { "cell_type": "markdown", "metadata": { "id": "bTnL_tJnzh2L" }, "source": [ "<a name=\"Train\"></a>\n", "### Train the model\n", "\n", "Now set up the GRPO trainer and all its configurations!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ptqkXK2D4d6p", "outputId": "7d36a642-4cc8-48ee-c8cc-1470978d4503" }, "outputs": [], "source": [ "from trl import GRPOConfig, GRPOTrainer\n", "\n", "training_args = GRPOConfig(\n", "    use_vllm=True,  # use vLLM for fast inference!\n", "    learning_rate=5e-6,\n", "    adam_beta1=0.9,\n", "    adam_beta2=0.99,\n", "    weight_decay=0.1,\n", "    warmup_ratio=0.1,\n", "    lr_scheduler_type=\"cosine\",\n", "    optim=\"adamw_8bit\",\n", "    logging_steps=1,\n", "    bf16=is_bfloat16_supported(),\n", "    fp16=not is_bfloat16_supported(),\n", "    per_device_train_batch_size=1,\n", "    gradient_accumulation_steps=1,  # Increase to 4 for smoother training\n", "    num_generations=8,  # Decrease if out of memory\n", "    max_prompt_length=256,\n", "    max_completion_length=200,\n", "    # num_train_epochs = 1, # Set to 1 for a full training run\n", "    max_steps=250,\n", "    save_steps=250,\n", "    max_grad_norm=0.1,\n", "    report_to=\"none\",  # Can use Weights & Biases\n", "    output_dir=\"outputs\",\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "X_71Y0eKz5yE" }, "source": [ "Now let's run the trainer! As it runs, the cell output shows a table of rewards. The goal is to see the `reward` column increase!\n", "\n", "You might have to wait 150 to 200 steps before you see any improvement, and you'll probably get close to 0 reward for the first 100 steps. 
Please be patient!\n", "\n", "Example output from the first few steps:\n", "\n", "| Step | Training Loss | reward | reward_std | completion_length | kl |\n", "|------|---------------|-----------|------------|-------------------|----------|\n", "| 1 | 0.000000 | 0.125000 | 0.000000 | 200.000000 | 0.000000 |\n", "| 2 | 0.000000 | 0.072375 | 0.248112 | 200.000000 | 0.000000 |\n", "| 3 | 0.000000 | -0.079000 | 0.163776 | 182.500000 | 0.000005 |\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "vzOuSVCL_GA9", "outputId": "c6fdbbe8-7062-45a2-a48c-649f2b878ac4" }, "outputs": [], "source": [ "trainer = GRPOTrainer(\n", "    model=model,\n", "    processing_class=tokenizer,\n", "    reward_funcs=[\n", "        xmlcount_reward_func,\n", "        soft_format_reward_func,\n", "        strict_format_reward_func,\n", "        int_reward_func,\n", "        correctness_reward_func,\n", "    ],\n", "    args=training_args,\n", "    train_dataset=dataset,\n", ")\n", "trainer.train()" ] }, { "cell_type": "markdown", "metadata": { "id": "yUbluAAhD0Lg" }, "source": [ "<a name=\"Inference\"></a>\n", "### Inference\n", "Now let's try the model we just trained! 
First, let's see how the model does without any GRPO training:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 72 }, "id": "IqzsdZzeDM_m", "outputId": "a7637565-80fe-4ee9-cdc9-33b9affb65dc" }, "outputs": [], "source": [ "text = tokenizer.apply_chat_template(\n", "    [\n", "        {\"role\": \"user\", \"content\": \"How many r's are in strawberry?\"},\n", "    ],\n", "    tokenize=False,\n", "    add_generation_prompt=True,\n", ")\n", "\n", "from vllm import SamplingParams\n", "\n", "sampling_params = SamplingParams(\n", "    temperature=0.8,\n", "    top_p=0.95,\n", "    max_tokens=1024,\n", ")\n", "output = (\n", "    model.fast_generate(\n", "        [text],\n", "        sampling_params=sampling_params,\n", "        lora_request=None,  # No LoRA: generate with the base model\n", "    )[0]\n", "    .outputs[0]\n", "    .text\n", ")\n", "\n", "output" ] }, { "cell_type": "markdown", "metadata": { "id": "G4lzJD7REFjs" }, "source": [ "Now let's try the LoRA we just trained with GRPO. First, we save it:" 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YC9BBT0RESln" }, "outputs": [], "source": [ "model.save_lora(\"grpo_saved_lora\")" ] }, { "cell_type": "markdown", "metadata": { "id": "LherO2vzEbMt" }, "source": [ "Now we load the LoRA and test it:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 125 }, "id": "SDKIhhvN6lAF", "outputId": "8bc37894-678f-40d7-874b-365300685836" }, "outputs": [], "source": [ "text = tokenizer.apply_chat_template(\n", "    [\n", "        {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n", "        {\"role\": \"user\", \"content\": \"How many r's are in strawberry?\"},\n", "    ],\n", "    tokenize=False,\n", "    add_generation_prompt=True,\n", ")\n", "\n", "from vllm import SamplingParams\n", "\n", "sampling_params = SamplingParams(\n", "    temperature=0.8,\n", "    top_p=0.95,\n", "    max_tokens=1024,\n", ")\n", "output = (\n", "    model.fast_generate(\n", "        [text],\n", "        sampling_params=sampling_params,\n", "        lora_request=model.load_lora(\"grpo_saved_lora\"),\n", "    )[0]\n", "    .outputs[0]\n", "    .text\n", ")\n", "\n", "output" ] }, { "cell_type": "markdown", "metadata": { "id": "_ZBnvg2f9Nlg" }, "source": [ "Our reasoning model is much better. It's not always correct, since we only trained it for an hour or so, but it should improve if we extend the sequence length and train for longer!" ] }, { "cell_type": "markdown", "metadata": { "id": "RphEZRSfFhru" }, "source": [ "<a name=\"Save\"></a>\n", "### Saving to float16 for vLLM\n", "\n", "We also support saving to `float16` directly. Select `merged_16bit` for float16 or `merged_4bit` for int4. We also allow saving just the `lora` adapters as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account! You can get a personal access token at https://huggingface.co/settings/tokens." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GwNY9_PrFiXZ" }, "outputs": [], "source": [ "# Merge to 16bit\n", "if False:\n", "    model.save_pretrained_merged(\n", "        \"model\",\n", "        tokenizer,\n", "        save_method=\"merged_16bit\",\n", "    )\n", "if False:\n", "    model.push_to_hub_merged(\n", "        \"hf/model\", tokenizer, save_method=\"merged_16bit\", token=\"\"\n", "    )\n", "\n", "# Merge to 4bit\n", "if False:\n", "    model.save_pretrained_merged(\n", "        \"model\",\n", "        tokenizer,\n", "        save_method=\"merged_4bit\",\n", "    )\n", "if False:\n", "    model.push_to_hub_merged(\"hf/model\", tokenizer, save_method=\"merged_4bit\", token=\"\")\n", "\n", "# Just LoRA adapters\n", "if False:\n", "    model.save_pretrained_merged(\n", "        \"model\",\n", "        tokenizer,\n", "        save_method=\"lora\",\n", "    )\n", "if False:\n", "    model.push_to_hub_merged(\"hf/model\", tokenizer, save_method=\"lora\", token=\"\")" ] }, { "cell_type": "markdown", "metadata": { "id": "BDUGPiL3Fkkq" }, "source": [ "### GGUF / llama.cpp Conversion\n", "We now support saving to `GGUF` / `llama.cpp` natively! We clone `llama.cpp` and save to `q8_0` by default, but all quantization methods, such as `q4_k_m`, are supported. Use `save_pretrained_gguf` for local saving and `push_to_hub_gguf` for uploading to HF.\n", "\n", "Some supported quant methods (full list on our [Wiki page](https://github.com/unslothai/unsloth/wiki#gguf-quantization-options)):\n", "* `q8_0` - Fast conversion. High resource use, but generally acceptable.\n", "* `q4_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K.\n", "* `q5_k_m` - Recommended. 
Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K.\n", "\n", "[**NEW**] To finetune and automatically export to Ollama, try our [Ollama notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AGo4dbWvFk4M" }, "outputs": [], "source": [ "# Save to 8bit Q8_0\n", "if False:\n", "    model.save_pretrained_gguf(\n", "        \"model\",\n", "        tokenizer,\n", "    )\n", "# Remember to go to https://huggingface.co/settings/tokens for a token!\n", "# And change hf to your username!\n", "if False:\n", "    model.push_to_hub_gguf(\"hf/model\", tokenizer, token=\"\")\n", "\n", "# Save to 16bit GGUF\n", "if False:\n", "    model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"f16\")\n", "if False:\n", "    model.push_to_hub_gguf(\"hf/model\", tokenizer, quantization_method=\"f16\", token=\"\")\n", "\n", "# Save to q4_k_m GGUF\n", "if False:\n", "    model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"q4_k_m\")\n", "if False:\n", "    model.push_to_hub_gguf(\n", "        \"hf/model\", tokenizer, quantization_method=\"q4_k_m\", token=\"\"\n", "    )\n", "\n", "# Save to multiple GGUF options - much faster if you want multiple!\n", "if False:\n", "    model.push_to_hub_gguf(\n", "        \"hf/model\",  # Change hf to your username!\n", "        tokenizer,\n", "        quantization_method=[\n", "            \"q4_k_m\",\n", "            \"q8_0\",\n", "            \"q5_k_m\",\n", "        ],\n", "        token=\"\",\n", "    )" ] }, { "cell_type": "markdown", "metadata": { "id": "bhYfRM8PxWAu" }, "source": [ "Now, use the `model-unsloth.gguf` file or the `model-unsloth-Q4_K_M.gguf` file in llama.cpp or a UI-based system like Jan or Open WebUI. You can install Jan [here](https://github.com/janhq/jan) and Open WebUI [here](https://github.com/open-webui/open-webui).\n", "\n", "And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! 
If you find any bugs, want to keep up with the latest LLM news, need help, or want to join projects, feel free to join our Discord!\n", "\n", "Some other links:\n", "1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)\n", "2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)\n", "3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)\n", "4. See notebooks for DPO, ORPO, continued pretraining, conversational finetuning and more on our [documentation](https://docs.unsloth.ai/get-started/unsloth-notebooks)!\n" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }