{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "Oh3eu2ULxWAn"
},
"source": [
"To run this notebook, press \"*Runtime*\" then \"*Run all*\" on a **free** Tesla T4 Google Colab instance!\n",
"<div class=\"align-center\">\n",
"<a href=\"https://unsloth.ai/\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png\" width=\"115\"></a>\n",
"<a href=\"https://discord.gg/unsloth\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/Discord button.png\" width=\"145\"></a>\n",
"<a href=\"https://docs.unsloth.ai/\"><img src=\"https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true\" width=\"125\"></a> Join Discord if you need help + ⭐ <i>Star us on <a href=\"https://github.com/unslothai/unsloth\">GitHub</a> </i> ⭐\n",
"</div>\n",
"\n",
"To install Unsloth on your own computer, follow our installation instructions [here](https://docs.unsloth.ai/get-started/installing-+-updating).\n",
"\n",
"You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), and [how to save it](#Save).\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rNxpti3ExWAp"
},
"source": [
"### News"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yo-qTlvNxWAp"
},
"source": [
"**Read our [Gemma 3 blog](https://unsloth.ai/blog/gemma3) for what's new in Unsloth and our [Reasoning blog](https://unsloth.ai/blog/r1-reasoning) on how to train reasoning models.**\n",
"\n",
"Visit our docs for all our [model uploads](https://docs.unsloth.ai/get-started/all-our-models) and [notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks).\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MrkTn5S9xWAq"
},
"source": [
"### Installation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NV0BlxygxWAq"
},
"outputs": [],
"source": [
"%%capture\n",
"import os\n",
"\n",
"if \"COLAB_\" not in \"\".join(os.environ.keys()):\n",
"    !pip install unsloth vllm\n",
"else:\n",
"    # [NOTE] Use `--no-deps` ONLY in Colab; elsewhere use `pip install unsloth vllm`\n",
"    !pip install --no-deps unsloth vllm"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8lIsllEIxWAr"
},
"outputs": [],
"source": [
"# @title Colab Extra Install { display-mode: \"form\" }\n",
"%%capture\n",
"import os\n",
"\n",
"if \"COLAB_\" not in \"\".join(os.environ.keys()):\n",
"    !pip install unsloth vllm\n",
"else:\n",
"    !pip install --no-deps unsloth vllm\n",
"    # [NOTE] Use `--no-deps` ONLY in Colab; elsewhere use `pip install unsloth vllm`\n",
"    # Skip the restart message in Colab by purging cached PIL / google modules\n",
"    import sys\n",
"    import re\n",
"    import requests\n",
"\n",
"    modules = list(sys.modules.keys())\n",
"    for x in modules:\n",
"        sys.modules.pop(x) if \"PIL\" in x or \"google\" in x else None\n",
"    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo\n",
"    !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer\n",
"\n",
"    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy\n",
"    f = requests.get(\n",
"        \"https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt\"\n",
"    ).content\n",
"    with open(\"vllm_requirements.txt\", \"wb\") as file:\n",
"        file.write(re.sub(rb\"(transformers|numpy|xformers)[^\\n]{1,}\\n\", b\"\", f))\n",
"    !pip install -r vllm_requirements.txt"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a6uTCUXFxWAr"
},
"source": [
"### Unsloth"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Joje4qPsyxM9"
},
"source": [
"Load up `Qwen 2.5 3B Instruct` and set parameters:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "DkIvEkIIkEyB",
"outputId": "8bf13df5-67ab-4ea2-db01-5719298b2dc9"
},
"outputs": [],
"source": [
"from unsloth import FastLanguageModel, is_bfloat16_supported\n",
"\n",
"max_seq_length = 1024  # Can increase for longer reasoning traces\n",
"lora_rank = 64  # Larger rank = smarter, but slower\n",
"\n",
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
"    model_name=\"Qwen/Qwen2.5-3B-Instruct\",\n",
"    max_seq_length=max_seq_length,\n",
"    load_in_4bit=True,  # False for LoRA 16bit\n",
"    fast_inference=True,  # Enable vLLM fast inference\n",
"    max_lora_rank=lora_rank,\n",
"    gpu_memory_utilization=0.5,  # Reduce if out of memory\n",
")\n",
"\n",
"model = FastLanguageModel.get_peft_model(\n",
"    model,\n",
"    r=lora_rank,  # Choose any number > 0! Suggested: 8, 16, 32, 64, 128\n",
"    target_modules=[\n",
"        \"q_proj\",\n",
"        \"k_proj\",\n",
"        \"v_proj\",\n",
"        \"o_proj\",\n",
"        \"gate_proj\",\n",
"        \"up_proj\",\n",
"        \"down_proj\",\n",
"    ],  # Remove QKVO if out of memory\n",
"    lora_alpha=lora_rank,\n",
"    use_gradient_checkpointing=\"unsloth\",  # Enable long context finetuning\n",
"    random_state=3407,\n",
")"
]
},
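{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, sanity-check the LoRA setup. This is a quick sketch, assuming the returned model exposes PEFT's standard `print_trainable_parameters` method: only the adapter weights (a small fraction of the 3B base model) should be trainable."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check (assumes the standard PEFT API is available on the wrapped model):\n",
"# only the LoRA adapter weights should be trainable, not the frozen base model.\n",
"model.print_trainable_parameters()"
]
},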
{
"cell_type": "markdown",
"metadata": {
"id": "0Y56ln_izS9E"
},
"source": [
"### Data Prep\n",
"<a name=\"Data\"></a>\n",
"\n",
"We directly leverage [@willccbb](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb)'s gist for data prep and all reward functions. You are free to create your own!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 209
},
"id": "cXk993X6C2ZZ",
"outputId": "2d545127-0313-405e-d8f9-a43d7dba8d11"
},
"outputs": [],
"source": [
"import re\n",
"from datasets import load_dataset, Dataset\n",
"\n",
"# Load and prep dataset\n",
"SYSTEM_PROMPT = \"\"\"\n",
"Respond in the following format:\n",
"<reasoning>\n",
"...\n",
"</reasoning>\n",
"<answer>\n",
"...\n",
"</answer>\n",
"\"\"\"\n",
"\n",
"XML_COT_FORMAT = \"\"\"\\\n",
"<reasoning>\n",
"{reasoning}\n",
"</reasoning>\n",
"<answer>\n",
"{answer}\n",
"</answer>\n",
"\"\"\"\n",
"\n",
"\n",
"def extract_xml_answer(text: str) -> str:\n",
"    answer = text.split(\"<answer>\")[-1]\n",
"    answer = answer.split(\"</answer>\")[0]\n",
"    return answer.strip()\n",
"\n",
"\n",
"def extract_hash_answer(text: str) -> str | None:\n",
"    if \"####\" not in text:\n",
"        return None\n",
"    return text.split(\"####\")[1].strip()\n",
"\n",
"\n",
"# Add few-shot example messages between the system and user turns for 1-shot prompting\n",
"def get_gsm8k_questions(split=\"train\") -> Dataset:\n",
"    data = load_dataset(\"openai/gsm8k\", \"main\")[split]  # type: ignore\n",
"    data = data.map(\n",
"        lambda x: {  # type: ignore\n",
"            \"prompt\": [\n",
"                {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
"                {\"role\": \"user\", \"content\": x[\"question\"]},\n",
"            ],\n",
"            \"answer\": extract_hash_answer(x[\"answer\"]),\n",
"        }\n",
"    )  # type: ignore\n",
"    return data  # type: ignore\n",
"\n",
"\n",
"dataset = get_gsm8k_questions()\n",
"\n",
"\n",
"# Reward functions\n",
"def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:\n",
"    responses = [completion[0][\"content\"] for completion in completions]\n",
"    q = prompts[0][-1][\"content\"]\n",
"    extracted_responses = [extract_xml_answer(r) for r in responses]\n",
"    print(\n",
"        \"-\" * 20,\n",
"        f\"Question:\\n{q}\",\n",
"        f\"\\nAnswer:\\n{answer[0]}\",\n",
"        f\"\\nResponse:\\n{responses[0]}\",\n",
"        f\"\\nExtracted:\\n{extracted_responses[0]}\",\n",
"    )\n",
"    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]\n",
"\n",
"\n",
"def int_reward_func(completions, **kwargs) -> list[float]:\n",
"    responses = [completion[0][\"content\"] for completion in completions]\n",
"    extracted_responses = [extract_xml_answer(r) for r in responses]\n",
"    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]\n",
"\n",
"\n",
"def strict_format_reward_func(completions, **kwargs) -> list[float]:\n",
"    \"\"\"Reward function that checks if the completion matches the exact XML format.\"\"\"\n",
"    pattern = r\"^<reasoning>\\n.*?\\n</reasoning>\\n<answer>\\n.*?\\n</answer>\\n$\"\n",
"    responses = [completion[0][\"content\"] for completion in completions]\n",
"    matches = [re.match(pattern, r) for r in responses]\n",
"    return [0.5 if match else 0.0 for match in matches]\n",
"\n",
"\n",
"def soft_format_reward_func(completions, **kwargs) -> list[float]:\n",
"    \"\"\"Reward function that checks if the completion loosely follows the XML format.\"\"\"\n",
"    pattern = r\"<reasoning>.*?</reasoning>\\s*<answer>.*?</answer>\"\n",
"    responses = [completion[0][\"content\"] for completion in completions]\n",
"    matches = [re.match(pattern, r) for r in responses]\n",
"    return [0.5 if match else 0.0 for match in matches]\n",
"\n",
"\n",
"def count_xml(text) -> float:\n",
"    count = 0.0\n",
"    if text.count(\"<reasoning>\\n\") == 1:\n",
"        count += 0.125\n",
"    if text.count(\"\\n</reasoning>\\n\") == 1:\n",
"        count += 0.125\n",
"    if text.count(\"\\n<answer>\\n\") == 1:\n",
"        count += 0.125\n",
"        count -= len(text.split(\"\\n</answer>\\n\")[-1]) * 0.001\n",
"    if text.count(\"\\n</answer>\") == 1:\n",
"        count += 0.125\n",
"        count -= (len(text.split(\"\\n</answer>\")[-1]) - 1) * 0.001\n",
"    return count\n",
"\n",
"\n",
"def xmlcount_reward_func(completions, **kwargs) -> list[float]:\n",
"    contents = [completion[0][\"content\"] for completion in completions]\n",
"    return [count_xml(c) for c in contents]"
]
},
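{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before training, it can help to sanity-check the reward functions on a toy example. The inputs below are hypothetical and simply mirror the message-list shapes the functions above index into (`completion[0][\"content\"]`, etc.):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy inputs (hypothetical) matching the shapes the reward functions expect:\n",
"# prompts and completions are lists (one entry per sample) of message lists.\n",
"sample_prompts = [\n",
"    [\n",
"        {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
"        {\"role\": \"user\", \"content\": \"What is 2 + 2?\"},\n",
"    ]\n",
"]\n",
"sample_completions = [\n",
"    [{\"content\": XML_COT_FORMAT.format(reasoning=\"2 + 2 = 4.\", answer=\"4\")}]\n",
"]\n",
"\n",
"# Expect [2.0]: the extracted <answer> matches the reference (also prints a debug block)\n",
"print(correctness_reward_func(sample_prompts, sample_completions, answer=[\"4\"]))\n",
"# Expect [0.5]: the completion follows the strict XML format exactly\n",
"print(strict_format_reward_func(completions=sample_completions))"
]
},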
{
"cell_type": "markdown",
"metadata": {
"id": "bTnL_tJnzh2L"
},
"source": [
"<a name=\"Train\"></a>\n",
"### Train the model\n",
"\n",
"Now set up the GRPO Trainer and all its configuration!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ptqkXK2D4d6p",
"outputId": "7d36a642-4cc8-48ee-c8cc-1470978d4503"
},
"outputs": [],
"source": [
"from trl import GRPOConfig, GRPOTrainer\n",
"\n",
"training_args = GRPOConfig(\n",
"    use_vllm=True,  # use vLLM for fast inference!\n",
"    learning_rate=5e-6,\n",
"    adam_beta1=0.9,\n",
"    adam_beta2=0.99,\n",
"    weight_decay=0.1,\n",
"    warmup_ratio=0.1,\n",
"    lr_scheduler_type=\"cosine\",\n",
"    optim=\"adamw_8bit\",\n",
"    logging_steps=1,\n",
"    bf16=is_bfloat16_supported(),\n",
"    fp16=not is_bfloat16_supported(),\n",
"    per_device_train_batch_size=1,\n",
"    gradient_accumulation_steps=1,  # Increase to 4 for smoother training\n",
"    num_generations=8,  # Decrease if out of memory\n",
"    max_prompt_length=256,\n",
"    max_completion_length=200,\n",
"    # num_train_epochs=1,  # Set to 1 for a full training run\n",
"    max_steps=250,\n",
"    save_steps=250,\n",
"    max_grad_norm=0.1,\n",
"    report_to=\"none\",  # Can use Weights & Biases\n",
"    output_dir=\"outputs\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "X_71Y0eKz5yE"
},
"source": [
"And let's run the trainer! If you scroll up, you'll see a table of rewards. The goal is to see the `reward` column increase!\n",
"\n",
"You might have to wait 150 to 200 steps for any action. You'll probably get 0 reward for the first 100 steps. Please be patient!\n",
"\n",
"| Step | Training Loss | reward | reward_std | completion_length | kl |\n",
"|------|---------------|-----------|------------|-------------------|----------|\n",
"| 1 | 0.000000 | 0.125000 | 0.000000 | 200.000000 | 0.000000 |\n",
"| 2 | 0.000000 | 0.072375 | 0.248112 | 200.000000 | 0.000000 |\n",
"| 3 | 0.000000 | -0.079000 | 0.163776 | 182.500000 | 0.000005 |\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "vzOuSVCL_GA9",
"outputId": "c6fdbbe8-7062-45a2-a48c-649f2b878ac4"
},
"outputs": [],
"source": [
"trainer = GRPOTrainer(\n",
"    model=model,\n",
"    processing_class=tokenizer,\n",
"    reward_funcs=[\n",
"        xmlcount_reward_func,\n",
"        soft_format_reward_func,\n",
"        strict_format_reward_func,\n",
"        int_reward_func,\n",
"        correctness_reward_func,\n",
"    ],\n",
"    args=training_args,\n",
"    train_dataset=dataset,\n",
")\n",
"trainer.train()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yUbluAAhD0Lg"
},
"source": [
"<a name=\"Inference\"></a>\n",
"### Inference\n",
"Now let's try the model we just trained! First, let's try the base model without any GRPO training:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 72
},
"id": "IqzsdZzeDM_m",
"outputId": "a7637565-80fe-4ee9-cdc9-33b9affb65dc"
},
"outputs": [],
"source": [
"text = tokenizer.apply_chat_template(\n",
"    [\n",
"        {\"role\": \"user\", \"content\": \"How many r's are in strawberry?\"},\n",
"    ],\n",
"    tokenize=False,\n",
"    add_generation_prompt=True,\n",
")\n",
"\n",
"from vllm import SamplingParams\n",
"\n",
"sampling_params = SamplingParams(\n",
"    temperature=0.8,\n",
"    top_p=0.95,\n",
"    max_tokens=1024,\n",
")\n",
"output = (\n",
"    model.fast_generate(\n",
"        [text],\n",
"        sampling_params=sampling_params,\n",
"        lora_request=None,\n",
"    )[0]\n",
"    .outputs[0]\n",
"    .text\n",
")\n",
"\n",
"output"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "G4lzJD7REFjs"
},
"source": [
"And now let's try the LoRA we just trained with GRPO. First, save the LoRA:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YC9BBT0RESln"
},
"outputs": [],
"source": [
"model.save_lora(\"grpo_saved_lora\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LherO2vzEbMt"
},
"source": [
"Now we load the LoRA and test:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 125
},
"id": "SDKIhhvN6lAF",
"outputId": "8bc37894-678f-40d7-874b-365300685836"
},
"outputs": [],
"source": [
"text = tokenizer.apply_chat_template(\n",
"    [\n",
"        {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
"        {\"role\": \"user\", \"content\": \"How many r's are in strawberry?\"},\n",
"    ],\n",
"    tokenize=False,\n",
"    add_generation_prompt=True,\n",
")\n",
"\n",
"from vllm import SamplingParams\n",
"\n",
"sampling_params = SamplingParams(\n",
"    temperature=0.8,\n",
"    top_p=0.95,\n",
"    max_tokens=1024,\n",
")\n",
"output = (\n",
"    model.fast_generate(\n",
"        text,\n",
"        sampling_params=sampling_params,\n",
"        lora_request=model.load_lora(\"grpo_saved_lora\"),\n",
"    )[0]\n",
"    .outputs[0]\n",
"    .text\n",
")\n",
"\n",
"output"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_ZBnvg2f9Nlg"
},
"source": [
"Our reasoning model is much better! It's not always correct, since we only trained it for an hour or so, but it will improve if we extend the sequence length and train for longer."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RphEZRSfFhru"
},
"source": [
"<a name=\"Save\"></a>\n",
"### Saving to float16 for vLLM\n",
"\n",
"We also support saving to `float16` directly. Select `merged_16bit` for float16 or `merged_4bit` for int4. We also allow saving `lora` adapters as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GwNY9_PrFiXZ"
},
"outputs": [],
"source": [
"# Merge to 16bit\n",
"if False:\n",
"    model.save_pretrained_merged(\n",
"        \"model\",\n",
"        tokenizer,\n",
"        save_method=\"merged_16bit\",\n",
"    )\n",
"if False:\n",
"    model.push_to_hub_merged(\n",
"        \"hf/model\", tokenizer, save_method=\"merged_16bit\", token=\"\"\n",
"    )\n",
"\n",
"# Merge to 4bit\n",
"if False:\n",
"    model.save_pretrained_merged(\n",
"        \"model\",\n",
"        tokenizer,\n",
"        save_method=\"merged_4bit\",\n",
"    )\n",
"if False:\n",
"    model.push_to_hub_merged(\"hf/model\", tokenizer, save_method=\"merged_4bit\", token=\"\")\n",
"\n",
"# Just LoRA adapters\n",
"if False:\n",
"    model.save_pretrained_merged(\n",
"        \"model\",\n",
"        tokenizer,\n",
"        save_method=\"lora\",\n",
"    )\n",
"if False:\n",
"    model.push_to_hub_merged(\"hf/model\", tokenizer, save_method=\"lora\", token=\"\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BDUGPiL3Fkkq"
},
"source": [
"### GGUF / llama.cpp Conversion\n",
"We now support saving to `GGUF` / `llama.cpp` natively! We clone `llama.cpp` and default to saving as `q8_0`. All methods like `q4_k_m` are allowed. Use `save_pretrained_gguf` for local saving and `push_to_hub_gguf` for uploading to HF.\n",
"\n",
"Some supported quant methods (full list on our [Wiki page](https://github.com/unslothai/unsloth/wiki#gguf-quantization-options)):\n",
"* `q8_0` - Fast conversion. High resource use, but generally acceptable.\n",
"* `q4_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K.\n",
"* `q5_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K.\n",
"\n",
"[**NEW**] To finetune and auto export to Ollama, try our [Ollama notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AGo4dbWvFk4M"
},
"outputs": [],
"source": [
"# Save to 8bit Q8_0\n",
"if False:\n",
"    model.save_pretrained_gguf(\n",
"        \"model\",\n",
"        tokenizer,\n",
"    )\n",
"# Remember to go to https://huggingface.co/settings/tokens for a token!\n",
"# And change hf to your username!\n",
"if False:\n",
"    model.push_to_hub_gguf(\"hf/model\", tokenizer, token=\"\")\n",
"\n",
"# Save to 16bit GGUF\n",
"if False:\n",
"    model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"f16\")\n",
"if False:\n",
"    model.push_to_hub_gguf(\"hf/model\", tokenizer, quantization_method=\"f16\", token=\"\")\n",
"\n",
"# Save to q4_k_m GGUF\n",
"if False:\n",
"    model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"q4_k_m\")\n",
"if False:\n",
"    model.push_to_hub_gguf(\n",
"        \"hf/model\", tokenizer, quantization_method=\"q4_k_m\", token=\"\"\n",
"    )\n",
"\n",
"# Save to multiple GGUF options - much faster if you want multiple!\n",
"if False:\n",
"    model.push_to_hub_gguf(\n",
"        \"hf/model\",  # Change hf to your username!\n",
"        tokenizer,\n",
"        quantization_method=[\n",
"            \"q4_k_m\",\n",
"            \"q8_0\",\n",
"            \"q5_k_m\",\n",
"        ],\n",
"        token=\"\",\n",
"    )"
]
},
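{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once exported, you can try the GGUF locally with llama.cpp's CLI. This is just a sketch; the model path below is hypothetical and depends on your chosen quant and output directory:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A quick sketch: run the exported GGUF with llama.cpp's CLI.\n",
"# The model path is hypothetical; adjust it to your build and output file.\n",
"if False:\n",
"    !./llama.cpp/llama-cli -m model/unsloth.Q4_K_M.gguf -p \"How many r's are in strawberry?\""
]
},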
{
"cell_type": "markdown",
"metadata": {
"id": "bhYfRM8PxWAu"
},
"source": [
"Now, use the `model-unsloth.gguf` file or `model-unsloth-Q4_K_M.gguf` file in llama.cpp or a UI-based system like Jan or Open WebUI. You can install Jan [here](https://github.com/janhq/jan) and Open WebUI [here](https://github.com/open-webui/open-webui).\n",
"\n",
"And we're done! If you have any questions about Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs, want to keep up to date with the latest LLM developments, need help, or want to join projects, feel free to join our Discord!\n",
"\n",
"Some other links:\n",
"1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)\n",
"2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)\n",
"3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)\n",
"4. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://docs.unsloth.ai/get-started/unsloth-notebooks)!\n",
"\n",
"<div class=\"align-center\">\n",
"  <a href=\"https://unsloth.ai\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png\" width=\"115\"></a>\n",
"  <a href=\"https://discord.gg/unsloth\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/Discord.png\" width=\"145\"></a>\n",
"  <a href=\"https://docs.unsloth.ai/\"><img src=\"https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true\" width=\"125\"></a>\n",
"\n",
"  Join Discord if you need help + ⭐️ <i>Star us on <a href=\"https://github.com/unslothai/unsloth\">GitHub</a> </i> ⭐️\n",
"</div>\n"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}