{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "wAwm-SoKNjoX"
},
"source": [
"# Sample GRPO Training Notebook\n",
"- Adapted from the Unsloth GRPO notebooks. Thanks, Unsloth!\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rv8V30lINjob"
},
"source": [
"### Installation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1iVpHRPXNjoc"
},
"outputs": [],
"source": [
"# @title Colab Extra Install { display-mode: \"form\" }\n",
"# Install unsloth + vLLM. The environment counts as Colab when the substring\n",
"# COLAB_ shows up in the concatenated environment variable names.\n",
"import os\n",
"\n",
"if \"COLAB_\" not in \"\".join(os.environ.keys()):\n",
"    !pip install -q unsloth vllm\n",
"else:\n",
"    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]\n",
"    !pip install -q --no-deps unsloth vllm\n",
"\n",
"    # Skip restarting message in Colab\n",
"    import sys, re, requests\n",
"\n",
"    # Iterate over a snapshot of the names: entries are removed from sys.modules\n",
"    # while looping.\n",
"    for module_name in list(sys.modules.keys()):\n",
"        if \"PIL\" in module_name or \"google\" in module_name:\n",
"            sys.modules.pop(module_name)\n",
"\n",
"    !pip install -q --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo\n",
"    !pip install -q sentencepiece protobuf datasets huggingface_hub hf_transfer\n",
"\n",
"    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy\n",
"    response = requests.get(\n",
"        \"https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt\",\n",
"        timeout=30,  # do not hang the cell forever on a stalled connection\n",
"    )\n",
"    # Fail loudly on an HTTP error instead of writing an error page into the\n",
"    # requirements file.\n",
"    response.raise_for_status()\n",
"\n",
"    # Strip the transformers / numpy / xformers entries before installing the rest.\n",
"    with open(\"vllm_requirements.txt\", \"wb\") as requirements_file:\n",
"        requirements_file.write(\n",
"            re.sub(rb\"(transformers|numpy|xformers)[^\\n]{1,}\\n\", b\"\", response.content)\n",
"        )\n",
"    !pip install -q -r vllm_requirements.txt"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HnqT9aF5Njoc"
},
"source": [
"### Unsloth"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IpAO-idyNjoc"
},
"source": [
"Load `Llama 3.1 8B Instruct` and set the sequence-length and LoRA parameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000,
"referenced_widgets": [
"09fca7e3b90f4b17a2344b67100e26bf",
"6eee6e73473440a39c2a14d11b19e7b9",
"f64c0803b3c34610a34f9c5f2b5ce6ba",
"4bae66db2fc041278f13ae1def5d1ddd",
"ce588ed5880240d7bfd91c150ef299ed",
"bc92bf99bcf345b59699eb8f0423f7cd",
"f589e115f63d4fdb95816dccb34c3e3e",
"fb45312477f64384912ba3e843efb0b3",
"4e541334eb26485ca5764b9bf0bbbd34",
"fe9dbcf4e85646cb8c684d82252b93e2",
"d36f0dda703c47cfae6f544e7e655806",
"cc9a242d278d4ea1b0f0f44d611439a2",
"9ec23dd10394466eb6db5336912f09de",
"e72af3cf2aac44d3a10b692b5fd54afe",
"7a0041d1e5c349b4b70e96554ea18113",
"f3fe4eb2a2fc4b2fb255632272dc0415",
"b095d3eac38847c0b087ffcc5876e0b6",
"9ebd8c869bc144cf8142375e7d7f04dd",
"48403d21996b4ebfa931cf260af6c283",
"b019cf9965424856b3fd781bd9e73118",
"ff032a23470845ba95a12227baf0e0ce",
"18d2fd154ffc4671b76f00760eb35571",
"bc8f8134f144435681cbc50f9f5e5274",
"d7a1980bf7cc4895aea14df80bf92302",
"6be20c9739de44ffb7ef2465ad17424d",
"39411144c4c44eae935262bc6312d5b9",
"32f3e3f08eb1441ca70191eaa6191805",
"2df204a0982b466eb3c25bf80645b6c1",
"9ec592e894b945cea7c2a756a96eb959",
"7ff47ae65e944dc9bbd50f01edb7fa02",
"2f5bc26300174430aa5b0610b5ef4c72",
"79da5ea5195c4c63b0ec1ee354511a87",
"a5f3704359df48e4b4c5b418a5f72021",
"2cae8f5f92304278bcfa4a07fb5aafdc",
"63681e5808d1443c9f04654d58952448",
"91200f911ca64162b006235d0ec0c5f1",
"40c4dfad109643529a16b3de307e2e1a",
"491466071e6742c6ab30419597d78fde",
"a67ca6e6c91244978b6b468257ffefd7",
"781f0652012b424aa43cd39028e3a99a",
"9e4ae8a9e1b247ffbd232d97be270cd8",
"c558b6c8e51046c6b3e0809ccf749348",
"0ea62e14a2d44b8ca7c6f15bca8e1a2b",
"27e78ef30ff64be5af79b6431816f569",
"b973da5783dc49d1802831298be3ab7e",
"d7e3f2c2441743089ee7ed10eb0319dd",
"7f490169dd8a4dbeb367af5259b0949e",
"2a8b54de6a974cc3a3f0b47daa10af89",
"70780d08fb00445da32b28b40d33412f",
"2c83de81fc91497db859d1be6860d6f4",
"397a49fc365b4c3a95db709b71fc32e3",
"1e75c59df9524dec928b21d65dd43c08",
"747b264f119746c08bbd6685c134ec4f",
"364d9414fd5f4280be4ad621874e15cb",
"7b3ad76804bd47a68fba0bf9f967cb04",
"2c34cb165ce447dca52c25eff7cc91ce",
"e8f729e29f7c4639a10a02a9f4978385",
"f60723e017f64a2183af5f05ca827340",
"165ed67ab04240fca77e70e6a905ad48",
"e229f6a91dcf4b5182673cc6c7002063",
"fe4ecfe3522c432b964608f137cfb936",
"e1e8c601a2bb4c0b8f0cdf3bf0dd0b2e",
"db89c098dd414fa08440071935414b63",
"efb3aff9a97340fb8acf5f98402e7c1d",
"d4eeb17c3a124c919058c0ffe72a60b6",
"8b6bd37f64454464b44691ae3bef0b9f",
"239c67254f4044259fcbb35cb78d72b3",
"7bdc3ab924804e6badd7f403d0ce33f1",
"2b1ff4602dee4511a450588c501a7174",
"bfae7687d2a440189c29dc53e47b71ab",
"c72dbf9684ed4cf0a55655ebaf645b96",
"7083c21ed3dd4ed6b442afdf4a18a05d",
"6e4ce08c86f24beeb63d38d436143dc9",
"86498c04c2a347968e65867eb4ed2102",
"dc90aea1a6e445ebaf68bb305586bf47",
"4c69d53cf33c4865af8feead04b810ad",
"9dbc67661c614d8a82d0790d9b694e6d",
"97f09dcebaf94be3b115a69b1440392b",
"a568aff772814128b0ad329594e67610",
"92ac502bff6e48d98cbb07a65d84dca6",
"0b9dd1c421f14c56986cf30abe0a9455",
"50a3cb7fc0ce45bb8c1baabc2a54fff2",
"fac3506146244766b886c715ea57ec7b",
"d068e96c067e43ffa1777a666b0df147",
"9e9b1a12028443c9abe6d88d0ed39d15",
"638c59bf00ba4536ad157975faa26984",
"73177ee1d64c4f56aa393c45c42e7499",
"98aabe0e4167414d943ff3d7759093ad",
"0f040a5765854b699b32bf096bb0563c",
"813e8e40df6447c18e47414ab38f69a2",
"533f84f90e7247e8a60da8b6ce60fe89",
"80c5f4171f2844749b5f3e67a46178b9",
"0722b569fa2e45f9af7d4fb8134e632b",
"115a8f3c523c4cd3aaa472af87fe0e4a",
"26cde13f97db4ae4a0900fb5184a51a2",
"5c4e90490762447c9712372a18f48e20",
"690c64fbee9f4c038a29030086c73368",
"cbe9c4e2f8bc49dfb1d8a4cddc36d724",
"f33aef89be7641a694fb11e3101c7d85",
"6798efcf0f784fca91a2e84c91553296",
"190504fd7fec4f08ba43b972ed28d23b",
"142703403e5e4ba1ac1531d7c82c752a",
"0a45b8dd65764272869fd1424415e6af",
"d536479db8034c99be7e745a440baf7d",
"594a583be494419c845fd873904d4d6d",
"ba21e0c745714031aae12b1665847e54",
"3491a152831245f7b03d06fe94f42116",
"b80d4e38b5c54dd585d1da848ee9c4c8",
"fa155f662d1c40f496daef290ff7a5ff",
"9b2a88f9712d409bbf74bf0c2e14df16"
]
},
"id": "DkIvEkIIkEyB",
"outputId": "e1aaf3b5-6ad0-45c7-be08-2017e74ebba9"
},
"outputs": [],
"source": [
"from unsloth import FastLanguageModel\n",
"import torch\n",
"\n",
"# max_seq_length is handed to from_pretrained; lora_rank feeds max_lora_rank,\n",
"# r and lora_alpha in the two calls that follow.\n",
"max_seq_length = 1024  # Can increase for longer reasoning traces\n",
"lora_rank = 32  # Larger rank = smarter, but slower\n",
"\n",
"# Base model + tokenizer\n",
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
"    model_name=\"meta-llama/meta-Llama-3.1-8B-Instruct\",\n",
"    load_in_4bit=True,  # False for LoRA 16bit\n",
"    fast_inference=True,  # Enable vLLM fast inference\n",
"    gpu_memory_utilization=0.6,  # Reduce if out of memory\n",
"    max_seq_length=max_seq_length,\n",
"    max_lora_rank=lora_rank,\n",
")\n",
"\n",
"# Wrap the loaded model with a PEFT adapter; lora_alpha is set equal to the rank.\n",
"model = FastLanguageModel.get_peft_model(\n",
"    model,\n",
"    r=lora_rank,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n",
"    lora_alpha=lora_rank,\n",
"    # Remove QKVO if out of memory\n",
"    target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
"    use_gradient_checkpointing=\"unsloth\",  # Enable long context finetuning\n",
"    random_state=3407,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7KGgPgk_5S8r"
},
"source": [
"### Data Prep\n",
"\n",
"\n",
"We directly reuse [@willccbb's gist](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb) for data prep and all reward functions. You are free to create your own!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 321,
"referenced_widgets": [
"5954f25ab3334f8c8e0576ee75922c0c",
"5db92327420a48f4a1ad966ecd7ac8c5",
"1a1651d5ca7d4ccd887235343e49ffb3",
"22462b8181964257a53c85be6265eef6",
"1c09f4dc86124b1296ed96fc48b8d856",
"5e9d2ad532a8447a904688629a91f171",
"2df38622a5ab46979ee53efeb8c31b91",
"a75b75f7dec54759816845c3fda31f22",
"b14a63b4cac142c5adaaebbfef8739fe",
"1a090b75a50b459a8d2f822efc692dfa",
"0da6c9fa4eeb4b4d832f8bbc614bc8d9",
"3e1954a710bd43d1907b702c7f70c6de",
"483db1483d1d4a97a59c1e6cdad04c6a",
"cf88de7717504af691ca8db0fb68a6b6",
"d4de07c4656c40329bc2c1450bd94bb5",
"160f401dc9a74266bb610e747f8f711e",
"269f762b5af041eab750226febe70e07",
"f38c687b78224c8aa2227914998b4627",
"eca5ebef74f341f79dc96964d0189862",
"009ea1226a194cb2923ec9bc3a6e1843",
"2aade13ff2ca4250ae02b7a12b1ef16d",
"23531535f2fc4b99a5ecd1b63fb93d80",
"a137a1bc2b4f437a8477e079f003ad8d",
"05836f874c9a49d390721a05ff055ac9",
"9032130a4a9646db92805eeb4221c83d",
"ac58943ffd184cdca0b69cce35f54ec7",
"8dc3c53fac5644a6a2841aca62226300",
"fda3ed1250d34b12aae8401287322548",
"e85f549d21dd4c3293f8239a053b0ad4",
"a8d0ef0ef4f747e294394b04d07cbe34",
"15b83f9309a642d2bb1bc79da539b923",
"436282df7bcc45c4b951de60f2e4ec53",
"911d088df0894243b69fd4f69c46e7f3",
"eba35bc404f640bcb087de7a7ce74942",
"6b39fce0206f4ca894b943a42eb9dc04",
"c3ec8a75007846519d40f6726d2a01ee",
"de3486da03594b0091b77cbf702401a5",
"2e1b5b6b3ae84190a7f50993d78fb6ea",
"81f43f522faa4dcf9e557793a2027e90",
"ae46c8f17c0644fdbd1aa6d412c322b7",
"5de480828775465bb76b674ecd713bd0",
"a159d55b4de84020bc127d6c7758d2a9",
"cc6097b25ec74decb51a397ccd9fcf27",
"d2af3c19e6ec46a3915f6badf7062b7b",
"8d2463ee53394dc097843bc733191c97",
"f681a0ff692c4cd3b14b1e064547c52f",
"0ecef6aade754029ba939e30c91538be",
"87e25abbd016411199a2e976cd7ab550",
"93f708b722194750a9f66a2597987033",
"76dc9524ee764741a1615de6bc43492e",
"78201e7953b647ee9656fc7e8f4abbcb",
"0bc45e643a9b4a0e8db9f85180c108b2",
"cfb5247cf65c4bdfbcf8207c5d8c72c8",
"32fe42f8a11d4131a3a529b554f028fe",
"39c33d513fb74712be8fcf73093a0c9e",
"908333eedde74ec3a0885e52454b3627",
"934617e655d04f85a87d6444b7c1ff0e",
"63d0c36c4bb8469f8cf5ac410d2c11d6",
"092d58ba2c444878a1cbfd19bdc14027",
"c8ecf533f45f4ae29bdcfdaeca7daedf",
"d89eae07da0e4344909900e24e3d0d09",
"b1e00fc06b664fc59d4921388f187269",
"b834a1ed2fc14d209c405969def96965",
"dc681483ceeb4e878124858ffb467338",
"ef2a701c9c594d4d9ffc8379fa9b5899",
"22c31da7c0564f09bbdb04efdefad11a"
]
},
"id": "cXk993X6C2ZZ",
"outputId": "d6b161df-4d15-4cae-a02c-22fb901b91f1"
},
"outputs": [],
"source": [
"import re\n",
"from datasets import load_dataset, Dataset\n",
"\n",
"# Load and prep dataset\n",
"SYSTEM_PROMPT = \"\"\"\n",
"Respond in the following format:\n",
"