{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "Oh3eu2ULxWAn"
},
"source": [
"To run this, press \"*Runtime*\" and press \"*Run all*\" on a **free** Tesla T4 Google Colab instance!\n",
"<div class=\"align-center\">\n",
"<a href=\"https://unsloth.ai/\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png\" width=\"115\"></a>\n",
"<a href=\"https://discord.gg/unsloth\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/Discord button.png\" width=\"145\"></a>\n",
"<a href=\"https://docs.unsloth.ai/\"><img src=\"https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true\" width=\"125\"></a></a> Join Discord if you need help + ⭐ <i>Star us on <a href=\"https://github.com/unslothai/unsloth\">Github</a> </i> ⭐\n",
"</div>\n",
"\n",
"To install Unsloth on your own computer, follow the installation instructions on our Github page [here](https://docs.unsloth.ai/get-started/installing-+-updating).\n",
"\n",
"You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), & [how to save it](#Save)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rNxpti3ExWAp"
},
"source": [
"### News"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yo-qTlvNxWAp"
},
"source": [
"**Read our [Gemma 3 blog](https://unsloth.ai/blog/gemma3) for what's new in Unsloth and our [Reasoning blog](https://unsloth.ai/blog/r1-reasoning) on how to train reasoning models.**\n",
"\n",
"Visit our docs for all our [model uploads](https://docs.unsloth.ai/get-started/all-our-models) and [notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks).\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MrkTn5S9xWAq"
},
"source": [
"### Installation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NV0BlxygxWAq"
},
"outputs": [],
"source": [
"%%capture\n",
"import os\n",
"if \"COLAB_\" not in \"\".join(os.environ.keys()):\n",
" !pip install unsloth vllm\n",
"else:\n",
" # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]\n",
" !pip install --no-deps unsloth vllm"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8lIsllEIxWAr"
},
"outputs": [],
"source": [
"#@title Colab Extra Install { display-mode: \"form\" }\n",
"%%capture\n",
"import os\n",
"if \"COLAB_\" not in \"\".join(os.environ.keys()):\n",
" !pip install unsloth vllm\n",
"else:\n",
" !pip install --no-deps unsloth vllm\n",
" # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]\n",
" # Skip restarting message in Colab\n",
" import sys, re, requests; modules = list(sys.modules.keys())\n",
" for x in modules: sys.modules.pop(x) if \"PIL\" in x or \"google\" in x else None\n",
" !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo\n",
" !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer\n",
"\n",
" # vLLM requirements - vLLM breaks Colab due to reinstalling numpy\n",
" f = requests.get(\"https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt\").content\n",
" with open(\"vllm_requirements.txt\", \"wb\") as file:\n",
" file.write(re.sub(rb\"(transformers|numpy|xformers)[^\\n]{1,}\\n\", b\"\", f))\n",
" !pip install -r vllm_requirements.txt"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a6uTCUXFxWAr"
},
"source": [
"### Unsloth"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Joje4qPsyxM9"
},
"source": [
"Load up `Qwen 2.5 3B Instruct`, and set parameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000,
"referenced_widgets": [
"b46d7118f71043109c2dd00fb0f40557",
"32defd049d8a473c8ad8fc954c0e3d13",
"424078e3ad6c405c925bedfb76d18b3f",
"8d937714a44d422aa9851c02c2345e5c",
"25b26c398e314cb79a2f8bbb7e32e158",
"694494752ab2458ead83b0331af20122",
"1cc89fa4d7c744c69c3be1a3d0b3ba2d",
"15116ef21fae463ba9013c8b9dbd40b5",
"43fdf1fe592147538ec43410a1bf8e09",
"52ba857bdfc742a2aae9e9a949054e5c",
"589a3857411f49d4ae019465f536da45",
"5dafeefcf436462e81a9affe7f6c66d9",
"98d1697fb2054fd88432f027e592ecc0",
"b1adc20d1d8740f99a6ebd589f240688",
"f45a696d73054cd6af181a7ce35a62e4",
"372b687ac774400088d7252467f9cef2",
"a422e26291bb47d08e36b7d94fffcc9b",
"799dc8fe6f7d4e4db68ebd6b1d8fdbd5",
"e61a718f06a045b6b028fbd33382d37a",
"7c712c0011954872834e90cc908b8cf6",
"9929aaffe7614506ae934a14d2b5c3ca",
"1060550e9dec4fa59e2b9b488af61b48",
"ebd36fe7f8dd47b99e69176ace9b985d",
"42a3c5c410cd49a38400b788980ad428",
"e5373b31cbfc4e87af314bf4898d2b72",
"4379f29ff4a74df784bb2fb3989c64cc",
"87b42e656263458cbb33ee3ca29128a8",
"b245d8e4c115438583ef8b433d4d16be",
"df7f9d244f70440aaf401c6a856a6ed0",
"5d39d9e8cb724af59d7af29ec9c41cf4",
"6292f7a4de1948d09bebaf5064fa0d42",
"62cbedb2d67d4a2c8888b46e56e39f65",
"57f2c072b0574f74a7dbe2a9d303cbd4",
"78345adf6d8143e3a649a05c5d476115",
"f89bfff6526c43bcb75c93f987bfba72",
"a7244280fc934b8484b3dfd7eecca161",
"ec7ae2abc92f49f98ab0c52953bf82af",
"092355509a67474e9cc13a6b4718a192",
"c69730dfeec14cedaf8917ee9ad27dcb",
"00177cac24374c7d9144e3856594ddaf",
"1b1e49195c5b48a1b6e4364db01c5c87",
"97af130f6bd54d618e0af88ed9e85819",
"1f88a409f3c64ea291861837e830ce39",
"e9881edb2c7846fea2070c8e724b25cb",
"dfc8cc5ee6704f1fb3b4df0507ea471b",
"16bb453bd4934bd9840a9d365be604cd",
"1ee1a56d04fa4e85a20e4cde46c164e5",
"9872da92026e4debb96b731c892408e2",
"451d181b9ef24b0e9fc3914a4466b9f1",
"6a49aa5a2b3646d4b772e2398eacf72d",
"edeb0f283b0b495eab6e06a15fa38af0",
"0ed98c973a9d4d30b31ab751a7ecf3f6",
"a0b7dab4f65e448f816e00d12dc3b469",
"3b6d4b41992b43acac5daa31f119cd30",
"c5b4960fc3ba4459a6eaff43772050a3",
"71f9e574addd491cbbe65ddbe0eb23ad",
"442e58ceec4e44a3a4a2ebadd145ee73",
"4382289897b54e63b22541d0b45c9bf3",
"ba6998ac52544199b4e828d1a7f99d0d",
"29399d40d44d46c8b3648218e8378e89",
"c6b2d407256549a5bc7543cb6d7475b4",
"545251749b10408d82c9a60b914dce34",
"1bbadef2fca5427fa1a4975153968f56",
"d267ec5dd1954b9a88e566dff4a9af9d",
"4ddc8fa0c9834cd3b36274a9e0e05504",
"ec13219478284bbdbe69cbcd3dff5c34",
"562637f216114f42aa352cabd59cab57",
"e37bcfc1604b4a788f53825c5a7e04b7",
"16e05d45084c44b68697316bc8791b88",
"9b80621f94a0484791b33a13b3c6bade",
"fc87fbccba304554be90701b39429918",
"aaf79572906c4443b5542bef5dee8055",
"bc81dfa869f94f059545a7d59b7eb70a",
"b3c84c45c0804e4cb703f9c512c6f90d",
"6520db76749d4a8aaa6d26005f07c43c",
"29a3f867c8bd4769b8e649b8912776a2",
"2317c9a94dfe4a62a8605038ebc2210a",
"bd27949a1e3c43e5912a01b04b6621a4",
"94ed75e3a9fe4cbdb373d4038af0c32b",
"adfd8a554ada48f9a3dc2fab69467b79",
"0f9dcd9b698b4af699cfafdf2874612c",
"f8700d095ee04a40a15c6a88d7daf1ae",
"1ccb12282661417383a7f96bd5ea4eac",
"32a16b0cbbe548f4869cf665d737b9ba",
"3f546bf0aa124e45b39daa0328a1d23d",
"61f6966f8b6d4070a9ea3604373cb96a",
"f8198301b2db406685189589e59efd9a",
"e813e26451e3476a82a9b75668b638a1",
"8aa00f7a2de64af8b0a5b8cd5fefb14e",
"3404fb0c4c4b4ca0ae6d9eeec6722e02",
"f268eda021e040e6b25889e72ee8dd1e",
"b61f7faf19594e3c81e1429045f341de",
"a4ecf424e6ed44c597f25acee7b43df2",
"c314f7a0c6084790a2620ec231297525",
"950c1b85a8bd456ab928df4cc3710b21",
"806b17b8ce074563941dbba9ca2c9422",
"0b42f72fac844c19adea7b6689e79432",
"37e2f38471d4481eb93031ca52bdb785",
"067aaf24d87a400b9f636762051598f3",
"a03f0a986af84eacbb5c333f49726105",
"abf0e1d60e9f4fadbcda876dd28f6d14",
"81a7650cc4124b58b4e191093a5433e1",
"74f24411d35e4c918c608a1afd799156",
"a8501b8551d4431fbb708d6556e7dd85",
"24757877452d48f0afe92bb66e2c48ac",
"b64be6a0d6df44899ba6fe0a47f676e1",
"c1588f6fe9ca48a5a7512734f26ae9a3",
"374765208bf345089f58e925d2a70984",
"78774135f2d1468d9172fa7588afef20",
"9b3233c52ecc4b259deed2d5d6f332eb",
"2567e718b28f417e834a10c5310f92b2",
"587ea317bd0a4831a9417a3980207a73",
"9e857c0cb85d4b50ad73c670436b48a5",
"f8705a1680534caa801dc6e14c3601e3",
"7e982757835048ae85abf6cba5d27ecb",
"e14a085ca7e6410ca78fe0f27533be6d",
"6d204a683e2d49deb5f504f91ee72044",
"6659b3fef8644c6e96ed9d43adb744d5",
"4c130bb196db422c8b9c309cbe59bac1",
"32f42609c389450bbaad10f0a884c527",
"a9e874a076234974a102b02551ec8666",
"e505a51d928542b883743a6d48c5cbd3",
"32ae8645c975473b81788c17c376367a",
"7cb758be0be64a48930379b2cfa9fb44",
"fb4bf57c071d4ef68046988a4965b312",
"f1af6e23e42c4051a1c7fd4864450eaa",
"5cf88ae3c2fc4419a71eb45f28f67384",
"327fa24e70ee4cae903fe5673bcf3bfc",
"d1d5afc36de043afb8638dec0001288e",
"fffe74f9a1184de5b1297c659c7891b6",
"878702a019e34789b5b2f68eb321f2c8",
"91998be216844c079144c18278f660d7",
"5a1c360e82204ad99238901d366828cf",
"fc98ab1d58d84f69b6c36dbc19f5edd7",
"e2e86147b7d44f66a2cdfe02d63ee141",
"2f5178d399ec48648b435bac95faf188",
"c13ff304d7bc4c67a1f51da5d2303f96",
"0e8c4086734c4296884db9c5b67ddd33",
"c3a97f3130bc42c1be2f8c7dc81ca6cc",
"0043184800574bfdbfb3c82f4842c3ed",
"32719fe47c8543f69bcddc9b9bc0f719",
"784a5c0d201d46459a160e0dc8857fd9",
"d753c1ecc6034237ba4401a007451bfb",
"55fd60765d5e46c0baf1ac075820fc87",
"771af8b28988459ba4f2dbaadd0d5ce1",
"0cc22299405f4361b1b5f6234f35bae3",
"0f22a4136dce46f89e9efbf0581057a2",
"5cda6bd66e3248c4be477f1c1b40198d",
"269b131121554544a21d47596a24aa58",
"73758a3ca9c14d7eaa6504d8b9b6e333",
"5a16728cee0f4b75be08696261d8d8cc",
"da77a7b205b1446a9dacab0550aafebe",
"d6bd8a9af6e84500956bb565f2fe2c4e",
"6643bd32eb984a929159c8d441530aaf",
"8358f98f790d47fdb783eb5ca210c1fc",
"0c94630c2693424d9edd8b9921333aac",
"edf185cd11544b2c8bce478fb2f0b565",
"a7b4b7d3f1a34b089d72eb2a4ab169b0",
"09a607655a054ed5ad7f36716f8ee4c6",
"9258ed2d94924de3b944105e280fbf5f",
"1c68fbbc833d4c3fbbb686f15b9ba939",
"fa996ce2f8f24a5d80914c4ef5b6e788",
"f77e9e785cc04c00822fc8b7ef2959da",
"5d190801b10b49a698c3a69c7cd1f253",
"0112cfe0a27e43fd9509ef8da050e188",
"4c6b2159bbc04b8d811fadff3a22f4a6",
"c705d699568543829441df2258f2e615",
"eb8f921c8d67442b9bc2a6a67930d94c",
"8a89212807f243d4b135a3188f4fb6d0",
"cc018dfe29144d0182058816134e1fd9",
"c404db8298364742b8ce076cddcb7641",
"03aca6d92597461481419eac9bdb6635",
"59b90e05822f480793685902b9c6d055",
"d13cc9053e7a4982b6fc7f00474f25c7",
"c9b781db7ee2422b8dc4eb6aa2575d87",
"5100eb13df1c453980942ec78d0e810b"
]
},
"id": "DkIvEkIIkEyB",
"outputId": "8bf13df5-67ab-4ea2-db01-5719298b2dc9"
},
"outputs": [],
"source": [
"from unsloth import FastLanguageModel, is_bfloat16_supported\n",
"import torch\n",
"max_seq_length = 1024 # Can increase for longer reasoning traces\n",
"lora_rank = 64 # Larger rank = smarter, but slower\n",
"\n",
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
" model_name = \"Qwen/Qwen2.5-3B-Instruct\",\n",
" max_seq_length = max_seq_length,\n",
" load_in_4bit = True, # False for LoRA 16bit\n",
" fast_inference = True, # Enable vLLM fast inference\n",
" max_lora_rank = lora_rank,\n",
" gpu_memory_utilization = 0.5, # Reduce if out of memory\n",
")\n",
"\n",
"model = FastLanguageModel.get_peft_model(\n",
" model,\n",
" r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n",
" target_modules = [\n",
" \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
" \"gate_proj\", \"up_proj\", \"down_proj\",\n",
" ], # Remove QKVO if out of memory\n",
" lora_alpha = lora_rank,\n",
" use_gradient_checkpointing = \"unsloth\", # Enable long context finetuning\n",
" random_state = 3407,\n",
")"
]
},
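{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, check how many parameters the LoRA adapters actually train. This is a small sanity check, assuming the object returned by `get_peft_model` exposes PEFT's `print_trainable_parameters()` - a rank-64 adapter on the seven projection modules trains only a small fraction of the full 3B weights."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check: report trainable (LoRA) vs total parameters.\n",
"# Assumes the PEFT wrapper exposes print_trainable_parameters(), as in the peft library.\n",
"model.print_trainable_parameters()"
]
},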
{
"cell_type": "markdown",
"metadata": {
"id": "0Y56ln_izS9E"
},
"source": [
"### Data Prep\n",
"<a name=\"Data\"></a>\n",
"\n",
"We directly leverage [@willccbb](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb) for data prep and all reward functions. You are free to create your own!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 209,
"referenced_widgets": [
"6b2cf7a462884bc4a3c2f8670f8982ed",
"8a358b880c194e6baab24f711937f695",
"0ec5dfa2fa6b4be2bfb45692ac20c053",
"1a619a0a53c149c3b9b0cbd53f1d2034",
"fdcbb976f35e4aa383d9ebcd7304849f",
"b131891a909743f5826429a84a74b4f9",
"3745da58146d445abe863a248e86d9be",
"22d29d39ef8149dba932ed9af68def67",
"413b4818b6de4bd995944df2f73bc3d5",
"c626fa7f5c624ec5a414e6f58d71ea96",
"2a8fb4b980ab45be81ae39d5c6b7bad4",
"ec1837ba407c4ecb86335d9c0c239ad9",
"1793da35014840b08d776b4abc9a1212",
"a06bde82466c43a49b58c0d942071be0",
"a1a3ee663a154319ba186954816e6932",
"97a9db845a9c45caa88ea1cd0ab6c40d",
"231be59acb2a4921ba53ba97bdf5888a",
"5aa0e699560c4fe89250b4ccac6401cb",
"7529f4d84b9f46709ec5f0e22b75baf4",
"8b8fd84911e94b5d8d627911fa9a64f7",
"420154f0bbe344a7979769a8ef066df7",
"1ee34b9171604805887aee3baa4ff099",
"3c95b86a26bc49b4a65589df0e7865e2",
"ef6d7a6a7fb344fab5befe656b98c7b7",
"c3d523c23ebf4b769c59eda950a3fc1f",
"bffc24c5bcdb47988f2e309c7fbc21b8",
"11052ab3d1f441248458cd6ab8b810db",
"0c2e39b20d5941969101b793b94fdbfa",
"f0c74de390814c2b9f38f7cc02458426",
"6c9502ae4cbc4248b14590fc6efe832f",
"8972cb9e823342bcb5c493e487dc64b6",
"71898e427b1245d682178a90c6752e8e",
"90295f1308054c66b557da231d25d3d5",
"54f444e6a25349fcb8fb35b9efb33e68",
"2f9a64ad318248bf99fccdedafaddfc0",
"d3ce3f3bee5646d893242758444acd6a",
"5b4f99c20a8a4ecf873b0a8ef8b150cd",
"70e141512e91492bae4713400879d019",
"7c67cf4d83e34cd6a475405be45f345e",
"c184e417ad134e94995a55d4bc7c94b5",
"f1c8d7356b584fb8965d2a62ccec52c8",
"9bcb0000da6e434d9b241a09d9b46244",
"11bd7b89fa794a899db9c83cb2a99df5",
"0e9d3d2244a54eb781311c01965c33e4",
"00d1bd9a327241e8b434ef7b3bde30df",
"cd17f032bbb44575977797caaa92bd54",
"04b74b7851334892b0db68bea1c8fd7e",
"44b6ecc9b9fc436e92701ec2409e0dd1",
"46998bc2b1ee4858a8ee1108ce36b9b7",
"95fa6a54e7e543c7b2dc5d5cf04daf58",
"8ee15a3d02954d31899d0b6a8dbf7988",
"1d5a58d63c51427bbbfb729216ea8468",
"e2790d1e5e714b3e9081b4edcef24535",
"96519710bf5047239f6800100471ee6e",
"e9b020ce2dda442b871b41e417210b52",
"4d18ef3f7c8046ddb0b03314861c3414",
"176ccdb40b994ee3b381f7ea0f379c3c",
"3519fdf3abac4d7690c45710f617e5d0",
"dc320ae9e2f7426588d49848b0871ca4",
"6b5a1287c457454fb367e5dc03d8b157",
"989519f3e61b4132b3f88fbdf29703a9",
"e0f87f6fd7c84261818624defe575a1c",
"e6d2b6aaa9794971b01c778b1a0596db",
"334c4833051e4ccea727dc2cdf3871c3",
"d36027362b41413fa3ee4670dd820a88",
"d94e3315b99747c39190fee9b79f00ac"
]
},
"id": "cXk993X6C2ZZ",
"outputId": "2d545127-0313-405e-d8f9-a43d7dba8d11"
},
"outputs": [],
"source": [
"import re\n",
"from datasets import load_dataset, Dataset\n",
"\n",
"# Load and prep dataset\n",
"SYSTEM_PROMPT = \"\"\"\n",
"Respond in the following format:\n",
"<reasoning>\n",
"...\n",
"</reasoning>\n",
"<answer>\n",
"...\n",
"</answer>\n",
"\"\"\"\n",
"\n",
"XML_COT_FORMAT = \"\"\"\\\n",
"<reasoning>\n",
"{reasoning}\n",
"</reasoning>\n",
"<answer>\n",
"{answer}\n",
"</answer>\n",
"\"\"\"\n",
"\n",
"def extract_xml_answer(text: str) -> str:\n",
" answer = text.split(\"<answer>\")[-1]\n",
" answer = answer.split(\"</answer>\")[0]\n",
" return answer.strip()\n",
"\n",
"def extract_hash_answer(text: str) -> str | None:\n",
" if \"####\" not in text:\n",
" return None\n",
" return text.split(\"####\")[1].strip()\n",
"\n",
"# uncomment middle messages for 1-shot prompting\n",
"def get_gsm8k_questions(split = \"train\") -> Dataset:\n",
" data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore\n",
" data = data.map(lambda x: { # type: ignore\n",
" 'prompt': [\n",
" {'role': 'system', 'content': SYSTEM_PROMPT},\n",
" {'role': 'user', 'content': x['question']}\n",
" ],\n",
" 'answer': extract_hash_answer(x['answer'])\n",
" }) # type: ignore\n",
" return data # type: ignore\n",
"\n",
"dataset = get_gsm8k_questions()\n",
"\n",
"# Reward functions\n",
"def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:\n",
" responses = [completion[0]['content'] for completion in completions]\n",
" q = prompts[0][-1]['content']\n",
" extracted_responses = [extract_xml_answer(r) for r in responses]\n",
" print('-'*20, f\"Question:\\n{q}\", f\"\\nAnswer:\\n{answer[0]}\", f\"\\nResponse:\\n{responses[0]}\", f\"\\nExtracted:\\n{extracted_responses[0]}\")\n",
" return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]\n",
"\n",
"def int_reward_func(completions, **kwargs) -> list[float]:\n",
" responses = [completion[0]['content'] for completion in completions]\n",
" extracted_responses = [extract_xml_answer(r) for r in responses]\n",
" return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]\n",
"\n",
"def strict_format_reward_func(completions, **kwargs) -> list[float]:\n",
" \"\"\"Reward function that checks if the completion has a specific format.\"\"\"\n",
" pattern = r\"^<reasoning>\\n.*?\\n</reasoning>\\n<answer>\\n.*?\\n</answer>\\n$\"\n",
" responses = [completion[0][\"content\"] for completion in completions]\n",
" matches = [re.match(pattern, r) for r in responses]\n",
" return [0.5 if match else 0.0 for match in matches]\n",
"\n",
"def soft_format_reward_func(completions, **kwargs) -> list[float]:\n",
" \"\"\"Reward function that checks if the completion has a specific format.\"\"\"\n",
" pattern = r\"<reasoning>.*?</reasoning>\\s*<answer>.*?</answer>\"\n",
" responses = [completion[0][\"content\"] for completion in completions]\n",
" matches = [re.match(pattern, r) for r in responses]\n",
" return [0.5 if match else 0.0 for match in matches]\n",
"\n",
"def count_xml(text) -> float:\n",
" count = 0.0\n",
" if text.count(\"<reasoning>\\n\") == 1:\n",
" count += 0.125\n",
" if text.count(\"\\n</reasoning>\\n\") == 1:\n",
" count += 0.125\n",
" if text.count(\"\\n<answer>\\n\") == 1:\n",
" count += 0.125\n",
" count -= len(text.split(\"\\n</answer>\\n\")[-1])*0.001\n",
" if text.count(\"\\n</answer>\") == 1:\n",
" count += 0.125\n",
" count -= (len(text.split(\"\\n</answer>\")[-1]) - 1)*0.001\n",
" return count\n",
"\n",
"def xmlcount_reward_func(completions, **kwargs) -> list[float]:\n",
" contents = [completion[0][\"content\"] for completion in completions]\n",
" return [count_xml(c) for c in contents]"
]
},
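{
"cell_type": "markdown",
"metadata": {},
"source": [
"As mentioned above, you can write your own reward functions too. Below is a minimal sketch of a hypothetical extra reward (not part of the original recipe): it gives a small bonus to completions that stay reasonably short. Any function with this signature - taking `completions` (plus dataset columns via `**kwargs`) and returning one float per completion - can be appended to the `reward_funcs` list passed to `GRPOTrainer` later."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch of a custom reward function (hypothetical, illustrative only).\n",
"# GRPO reward functions receive the sampled completions and must return one float\n",
"# per completion; here we mildly reward completions under ~150 words.\n",
"def brevity_reward_func(completions, **kwargs) -> list[float]:\n",
"    responses = [completion[0][\"content\"] for completion in completions]\n",
"    return [0.25 if len(r.split()) < 150 else 0.0 for r in responses]"
]
},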
{
"cell_type": "markdown",
"metadata": {
"id": "bTnL_tJnzh2L"
},
"source": [
"<a name=\"Train\"></a>\n",
"### Train the model\n",
"\n",
"Now set up GRPO Trainer and all configurations!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ptqkXK2D4d6p",
"outputId": "7d36a642-4cc8-48ee-c8cc-1470978d4503"
},
"outputs": [],
"source": [
"from trl import GRPOConfig, GRPOTrainer\n",
"training_args = GRPOConfig(\n",
" use_vllm = True, # use vLLM for fast inference!\n",
" learning_rate = 5e-6,\n",
" adam_beta1 = 0.9,\n",
" adam_beta2 = 0.99,\n",
" weight_decay = 0.1,\n",
" warmup_ratio = 0.1,\n",
" lr_scheduler_type = \"cosine\",\n",
" optim = \"adamw_8bit\",\n",
" logging_steps = 1,\n",
" bf16 = is_bfloat16_supported(),\n",
" fp16 = not is_bfloat16_supported(),\n",
" per_device_train_batch_size = 1,\n",
" gradient_accumulation_steps = 1, # Increase to 4 for smoother training\n",
" num_generations = 8, # Decrease if out of memory\n",
" max_prompt_length = 256,\n",
" max_completion_length = 200,\n",
" # num_train_epochs = 1, # Set to 1 for a full training run\n",
" max_steps = 250,\n",
" save_steps = 250,\n",
" max_grad_norm = 0.1,\n",
" report_to = \"none\", # Can use Weights & Biases\n",
" output_dir = \"outputs\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "X_71Y0eKz5yE"
},
"source": [
"And let's run the trainer! If you scroll up, you'll see a table of rewards. The goal is to see the `reward` column increase!\n",
"\n",
"You might have to wait 150 to 200 steps for any action. You'll probably get 0 reward for the first 100 steps. Please be patient!\n",
"\n",
"| Step | Training Loss | reward | reward_std | completion_length | kl |\n",
"|------|---------------|-----------|------------|-------------------|----------|\n",
"| 1 | 0.000000 | 0.125000 | 0.000000 | 200.000000 | 0.000000 |\n",
"| 2 | 0.000000 | 0.072375 | 0.248112 | 200.000000 | 0.000000 |\n",
"| 3 | 0.000000 | -0.079000 | 0.163776 | 182.500000 | 0.000005 |\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "vzOuSVCL_GA9",
"outputId": "c6fdbbe8-7062-45a2-a48c-649f2b878ac4"
},
"outputs": [],
"source": [
"trainer = GRPOTrainer(\n",
" model = model,\n",
" processing_class = tokenizer,\n",
" reward_funcs = [\n",
" xmlcount_reward_func,\n",
" soft_format_reward_func,\n",
" strict_format_reward_func,\n",
" int_reward_func,\n",
" correctness_reward_func,\n",
" ],\n",
" args = training_args,\n",
" train_dataset = dataset,\n",
")\n",
"trainer.train()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yUbluAAhD0Lg"
},
"source": [
"<a name=\"Inference\"></a>\n",
"### Inference\n",
"Now let's try the model we just trained! First, let's first try the model without any GRPO trained:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 72
},
"id": "IqzsdZzeDM_m",
"outputId": "a7637565-80fe-4ee9-cdc9-33b9affb65dc"
},
"outputs": [],
"source": [
"text = tokenizer.apply_chat_template([\n",
" {\"role\" : \"user\", \"content\" : \"How many r's are in strawberry?\"},\n",
"], tokenize = False, add_generation_prompt = True)\n",
"\n",
"from vllm import SamplingParams\n",
"sampling_params = SamplingParams(\n",
" temperature = 0.8,\n",
" top_p = 0.95,\n",
" max_tokens = 1024,\n",
")\n",
"output = model.fast_generate(\n",
" [text],\n",
" sampling_params = sampling_params,\n",
" lora_request = None,\n",
")[0].outputs[0].text\n",
"\n",
"output"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "G4lzJD7REFjs"
},
"source": [
"And now with the LoRA we just trained with GRPO - we first save the LoRA first!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YC9BBT0RESln"
},
"outputs": [],
"source": [
"model.save_lora(\"grpo_saved_lora\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LherO2vzEbMt"
},
"source": [
"Now we load the LoRA and test:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 125
},
"id": "SDKIhhvN6lAF",
"outputId": "8bc37894-678f-40d7-874b-365300685836"
},
"outputs": [],
"source": [
"text = tokenizer.apply_chat_template([\n",
" {\"role\" : \"system\", \"content\" : SYSTEM_PROMPT},\n",
" {\"role\" : \"user\", \"content\" : \"How many r's are in strawberry?\"},\n",
"], tokenize = False, add_generation_prompt = True)\n",
"\n",
"from vllm import SamplingParams\n",
"sampling_params = SamplingParams(\n",
" temperature = 0.8,\n",
" top_p = 0.95,\n",
" max_tokens = 1024,\n",
")\n",
"output = model.fast_generate(\n",
" text,\n",
" sampling_params = sampling_params,\n",
" lora_request = model.load_lora(\"grpo_saved_lora\"),\n",
")[0].outputs[0].text\n",
"\n",
"output"
]
},
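{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since the system prompt asks for `<answer>` tags, you can reuse `extract_xml_answer` from the data prep section to pull out just the final answer - a quick check that the trained model actually follows the format:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Reuse the helper from the data prep section to pull out the final answer.\n",
"# If the model didn't follow the <answer> format, this just returns the trailing text.\n",
"print(extract_xml_answer(output))"
]
},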
{
"cell_type": "markdown",
"metadata": {
"id": "_ZBnvg2f9Nlg"
},
"source": [
"Our reasoning model is much better - it's not always correct, since we only trained it for an hour or so - it'll be better if we extend the sequence length and train for longer!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RphEZRSfFhru"
},
"source": [
"<a name=\"Save\"></a>\n",
"### Saving to float16 for VLLM\n",
"\n",
"We also support saving to `float16` directly. Select `merged_16bit` for float16 or `merged_4bit` for int4. We also allow `lora` adapters as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GwNY9_PrFiXZ"
},
"outputs": [],
"source": [
"# Merge to 16bit\n",
"if False: model.save_pretrained_merged(\"model\", tokenizer, save_method = \"merged_16bit\",)\n",
"if False: model.push_to_hub_merged(\"hf/model\", tokenizer, save_method = \"merged_16bit\", token = \"\")\n",
"\n",
"# Merge to 4bit\n",
"if False: model.save_pretrained_merged(\"model\", tokenizer, save_method = \"merged_4bit\",)\n",
"if False: model.push_to_hub_merged(\"hf/model\", tokenizer, save_method = \"merged_4bit\", token = \"\")\n",
"\n",
"# Just LoRA adapters\n",
"if False: model.save_pretrained_merged(\"model\", tokenizer, save_method = \"lora\",)\n",
"if False: model.push_to_hub_merged(\"hf/model\", tokenizer, save_method = \"lora\", token = \"\")"
]
},
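{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you flipped one of the merged 16-bit saves above to `True`, here is a minimal sketch for loading the merged checkpoint back with Unsloth. It assumes a local `model` folder was written by `save_pretrained_merged`, and is kept behind `if False` so it doesn't run by default:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: reload the merged 16-bit checkpoint for inference.\n",
"# Assumes the local \"model\" folder was written by save_pretrained_merged above.\n",
"if False:\n",
"    from unsloth import FastLanguageModel\n",
"    model, tokenizer = FastLanguageModel.from_pretrained(\n",
"        model_name = \"model\",\n",
"        max_seq_length = max_seq_length,\n",
"        load_in_4bit = False, # merged weights are already 16-bit\n",
"    )"
]
},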
{
"cell_type": "markdown",
"metadata": {
"id": "BDUGPiL3Fkkq"
},
"source": [
"### GGUF / llama.cpp Conversion\n",
"To save to `GGUF` / `llama.cpp`, we support it natively now! We clone `llama.cpp` and we default save it to `q8_0`. We allow all methods like `q4_k_m`. Use `save_pretrained_gguf` for local saving and `push_to_hub_gguf` for uploading to HF.\n",
"\n",
"Some supported quant methods (full list on our [Wiki page](https://github.com/unslothai/unsloth/wiki#gguf-quantization-options)):\n",
"* `q8_0` - Fast conversion. High resource use, but generally acceptable.\n",
"* `q4_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K.\n",
"* `q5_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K.\n",
"\n",
"[**NEW**] To finetune and auto export to Ollama, try our [Ollama notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AGo4dbWvFk4M"
},
"outputs": [],
"source": [
"# Save to 8bit Q8_0\n",
"if False: model.save_pretrained_gguf(\"model\", tokenizer,)\n",
"# Remember to go to https://huggingface.co/settings/tokens for a token!\n",
"# And change hf to your username!\n",
"if False: model.push_to_hub_gguf(\"hf/model\", tokenizer, token = \"\")\n",
"\n",
"# Save to 16bit GGUF\n",
"if False: model.save_pretrained_gguf(\"model\", tokenizer, quantization_method = \"f16\")\n",
"if False: model.push_to_hub_gguf(\"hf/model\", tokenizer, quantization_method = \"f16\", token = \"\")\n",
"\n",
"# Save to q4_k_m GGUF\n",
"if False: model.save_pretrained_gguf(\"model\", tokenizer, quantization_method = \"q4_k_m\")\n",
"if False: model.push_to_hub_gguf(\"hf/model\", tokenizer, quantization_method = \"q4_k_m\", token = \"\")\n",
"\n",
"# Save to multiple GGUF options - much faster if you want multiple!\n",
"if False:\n",
" model.push_to_hub_gguf(\n",
" \"hf/model\", # Change hf to your username!\n",
" tokenizer,\n",
" quantization_method = [\"q4_k_m\", \"q8_0\", \"q5_k_m\",],\n",
" token = \"\",\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bhYfRM8PxWAu"
},
"source": [
"Now, use the `model-unsloth.gguf` file or `model-unsloth-Q4_K_M.gguf` file in llama.cpp or a UI based system like Jan or Open WebUI. You can install Jan [here](https://github.com/janhq/jan) and Open WebUI [here](https://github.com/open-webui/open-webui)\n",
"\n",
"And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs or want to keep updated with the latest LLM stuff, or need help, join projects etc, feel free to join our Discord!\n",
"\n",
"Some other links:\n",
"1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)\n",
"2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)\n",
"3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)\n",
"6. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://docs.unsloth.ai/get-started/unsloth-notebooks)!\n",
"\n",
"<div class=\"align-center\">\n",
" <a href=\"https://unsloth.ai\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png\" width=\"115\"></a>\n",
" <a href=\"https://discord.gg/unsloth\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/Discord.png\" width=\"145\"></a>\n",
" <a href=\"https://docs.unsloth.ai/\"><img src=\"https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true\" width=\"125\"></a>\n",
"\n",
" Join Discord if you need help + ⭐️ <i>Star us on <a href=\"https://github.com/unslothai/unsloth\">Github</a> </i> ⭐️\n",
"</div>\n"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}