{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wAwm-SoKNjoX"
   },
   "source": [
    "# Sample GRPO Training Notebook\n",
    "- Adapted from the Unsloth GRPO notebook :3 thanks!\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rv8V30lINjob"
   },
   "source": [
    "### Installation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "1iVpHRPXNjoc"
   },
   "outputs": [],
   "source": [
    "# @title Colab Extra Install { display-mode: \"form\" }\n",
    "import os\n",
    "\n",
    "if \"COLAB_\" not in \"\".join(os.environ.keys()):\n",
    "    !pip install -q unsloth vllm\n",
    "else:\n",
    "    # [NOTE] The steps below are only needed in Colab! Elsewhere, just use: pip install unsloth vllm\n",
    "    !pip install -q --no-deps unsloth vllm\n",
    "    # Skip restarting message in Colab\n",
    "    import sys\n",
    "    import re\n",
    "    import requests\n",
    "\n",
    "    modules = list(sys.modules.keys())\n",
    "    for x in modules:\n",
    "        sys.modules.pop(x) if \"PIL\" in x or \"google\" in x else None\n",
    "    !pip install -q --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo\n",
    "    !pip install -q sentencepiece protobuf datasets huggingface_hub hf_transfer\n",
    "\n",
    "    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy\n",
    "    f = requests.get(\n",
    "        \"https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt\"\n",
    "    ).content\n",
    "    with open(\"vllm_requirements.txt\", \"wb\") as file:\n",
    "        file.write(re.sub(rb\"(transformers|numpy|xformers)[^\\n]{1,}\\n\", b\"\", f))\n",
    "    !pip install -q -r vllm_requirements.txt"
   ]
  },
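  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optional (added; not part of the original notebook): a quick sanity check that a CUDA GPU is visible before loading the model, since GRPO with vLLM requires one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sanity check: confirm a CUDA GPU is visible before loading the model.\n",
    "import torch\n",
    "\n",
    "assert torch.cuda.is_available(), \"GRPO with vLLM needs a CUDA GPU\"\n",
    "print(torch.cuda.get_device_name(0))\n",
    "print(f\"{torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB total\")"
   ]
  },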
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HnqT9aF5Njoc"
   },
   "source": [
    "### Unsloth"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "IpAO-idyNjoc"
   },
   "source": [
    "Load up `Llama 3.1 8B Instruct` and set parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "DkIvEkIIkEyB"
   },
   "outputs": [],
   "source": [
    "from unsloth import FastLanguageModel\n",
    "\n",
    "max_seq_length = 1024 # Can increase for longer reasoning traces\n",
    "lora_rank = 32 # Larger rank = smarter, but slower\n",
    "\n",
    "model, tokenizer = FastLanguageModel.from_pretrained(\n",
    "    model_name=\"meta-llama/meta-Llama-3.1-8B-Instruct\",\n",
    "    max_seq_length=max_seq_length,\n",
    "    load_in_4bit=True, # False for LoRA 16bit\n",
    "    fast_inference=True, # Enable vLLM fast inference\n",
    "    max_lora_rank=lora_rank,\n",
    "    gpu_memory_utilization=0.6, # Reduce if out of memory\n",
    ")\n",
    "\n",
    "model = FastLanguageModel.get_peft_model(\n",
    "    model,\n",
    "    r=lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n",
    "    target_modules=[\n",
    "        \"q_proj\",\n",
    "        \"k_proj\",\n",
    "        \"v_proj\",\n",
    "        \"o_proj\",\n",
    "        \"gate_proj\",\n",
    "        \"up_proj\",\n",
    "        \"down_proj\",\n",
    "    ], # Remove QKVO if out of memory\n",
    "    lora_alpha=lora_rank,\n",
    "    use_gradient_checkpointing=\"unsloth\", # Enable long context finetuning\n",
    "    random_state=3407,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7KGgPgk_5S8r"
   },
   "source": [
    "### Data Prep\n",
    "<a name=\"Data\"></a>\n",
    "\n",
    "Data prep and all reward functions come directly from [@willccbb's gist](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb). You are free to create your own!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cXk993X6C2ZZ"
   },
   "outputs": [],
   "source": [
    "import re\n",
    "from datasets import load_dataset, Dataset\n",
    "\n",
    "# Load and prep dataset\n",
    "SYSTEM_PROMPT = \"\"\"\n",
    "Respond in the following format:\n",
    "<reasoning>\n",
    "...\n",
    "</reasoning>\n",
    "<answer>\n",
    "...\n",
    "</answer>\n",
    "\"\"\"\n",
    "\n",
    "XML_COT_FORMAT = \"\"\"\\\n",
    "<reasoning>\n",
    "{reasoning}\n",
    "</reasoning>\n",
    "<answer>\n",
    "{answer}\n",
    "</answer>\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "def extract_xml_answer(text: str) -> str:\n",
    "    answer = text.split(\"<answer>\")[-1]\n",
    "    answer = answer.split(\"</answer>\")[0]\n",
    "    return answer.strip()\n",
    "\n",
    "\n",
    "def extract_hash_answer(text: str) -> str | None:\n",
    "    if \"####\" not in text:\n",
    "        return None\n",
    "    return text.split(\"####\")[1].strip()\n",
    "\n",
    "\n",
    "# uncomment middle messages for 1-shot prompting\n",
    "def get_gsm8k_questions(split=\"train\") -> Dataset:\n",
    "    data = load_dataset(\"openai/gsm8k\", \"main\")[split]  # type: ignore\n",
    "    data = data.map(\n",
    "        lambda x: {  # type: ignore\n",
    "            \"prompt\": [\n",
    "                {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
    "                {\"role\": \"user\", \"content\": x[\"question\"]},\n",
    "            ],\n",
    "            \"answer\": extract_hash_answer(x[\"answer\"]),\n",
    "        }\n",
    "    )  # type: ignore\n",
    "    return data  # type: ignore\n",
    "\n",
    "\n",
    "dataset = get_gsm8k_questions()\n",
    "\n",
    "\n",
    "# Reward functions\n",
    "def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:\n",
    "    responses = [completion[0][\"content\"] for completion in completions]\n",
    "    q = prompts[0][-1][\"content\"]\n",
    "    extracted_responses = [extract_xml_answer(r) for r in responses]\n",
    "    print(\n",
    "        \"-\" * 20,\n",
    "        f\"Question:\\n{q}\",\n",
    "        f\"\\nAnswer:\\n{answer[0]}\",\n",
    "        f\"\\nResponse:\\n{responses[0]}\",\n",
    "        f\"\\nExtracted:\\n{extracted_responses[0]}\",\n",
    "    )\n",
    "    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]\n",
    "\n",
    "\n",
    "def int_reward_func(completions, **kwargs) -> list[float]:\n",
    "    responses = [completion[0][\"content\"] for completion in completions]\n",
    "    extracted_responses = [extract_xml_answer(r) for r in responses]\n",
    "    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]\n",
    "\n",
    "\n",
    "def strict_format_reward_func(completions, **kwargs) -> list[float]:\n",
    "    \"\"\"Reward function that checks if the completion matches the exact XML format.\"\"\"\n",
    "    pattern = r\"^<reasoning>\\n.*?\\n</reasoning>\\n<answer>\\n.*?\\n</answer>\\n$\"\n",
    "    responses = [completion[0][\"content\"] for completion in completions]\n",
    "    matches = [re.match(pattern, r) for r in responses]\n",
    "    return [0.5 if match else 0.0 for match in matches]\n",
    "\n",
    "\n",
    "def soft_format_reward_func(completions, **kwargs) -> list[float]:\n",
    "    \"\"\"Reward function that checks if the completion loosely matches the XML format.\"\"\"\n",
    "    pattern = r\"<reasoning>.*?</reasoning>\\s*<answer>.*?</answer>\"\n",
    "    responses = [completion[0][\"content\"] for completion in completions]\n",
    "    matches = [re.match(pattern, r) for r in responses]\n",
    "    return [0.5 if match else 0.0 for match in matches]\n",
    "\n",
    "\n",
    "def count_xml(text) -> float:\n",
    "    count = 0.0\n",
    "    if text.count(\"<reasoning>\\n\") == 1:\n",
    "        count += 0.125\n",
    "    if text.count(\"\\n</reasoning>\\n\") == 1:\n",
    "        count += 0.125\n",
    "    if text.count(\"\\n<answer>\\n\") == 1:\n",
    "        count += 0.125\n",
    "        count -= len(text.split(\"\\n</answer>\\n\")[-1]) * 0.001\n",
    "    if text.count(\"\\n</answer>\") == 1:\n",
    "        count += 0.125\n",
    "        count -= (len(text.split(\"\\n</answer>\")[-1]) - 1) * 0.001\n",
    "    return count\n",
    "\n",
    "\n",
    "def xmlcount_reward_func(completions, **kwargs) -> list[float]:\n",
    "    contents = [completion[0][\"content\"] for completion in completions]\n",
    "    return [count_xml(c) for c in contents]"
   ]
  },
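  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check (added; not from the original gist): run a few of the reward functions on hand-written completions shaped like the `completions` argument the trainer passes, and confirm a well-formatted answer scores higher than a sloppy one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sanity check: reward functions on mock completions, shaped like\n",
    "# the `completions` argument GRPOTrainer passes to each reward function.\n",
    "_good = XML_COT_FORMAT.format(reasoning=\"2 + 2 = 4\", answer=\"4\")\n",
    "_bad = \"The answer is 4.\"\n",
    "_mock = [\n",
    "    [{\"role\": \"assistant\", \"content\": _good}],\n",
    "    [{\"role\": \"assistant\", \"content\": _bad}],\n",
    "]\n",
    "\n",
    "# The well-formatted completion should score higher in each case.\n",
    "print(\"strict format:\", strict_format_reward_func(completions=_mock))\n",
    "print(\"int answer:\", int_reward_func(completions=_mock))\n",
    "print(\"xml count:\", xmlcount_reward_func(completions=_mock))"
   ]
  },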
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ux6iqP7z5YOo"
   },
   "source": [
    "<a name=\"Train\"></a>\n",
    "### Train the model\n",
    "\n",
    "Now set up the GRPO trainer and all of its configuration!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ptqkXK2D4d6p"
   },
   "outputs": [],
   "source": [
    "from trl import GRPOConfig, GRPOTrainer\n",
    "\n",
    "max_prompt_length = 256\n",
    "\n",
    "training_args = GRPOConfig(\n",
    "    learning_rate=5e-6,\n",
    "    adam_beta1=0.9,\n",
    "    adam_beta2=0.99,\n",
    "    weight_decay=0.1,\n",
    "    warmup_ratio=0.1,\n",
    "    lr_scheduler_type=\"cosine\",\n",
    "    optim=\"paged_adamw_8bit\",\n",
    "    logging_steps=1,\n",
    "    per_device_train_batch_size=1,\n",
    "    gradient_accumulation_steps=1, # Increase to 4 for smoother training\n",
    "    num_generations=6, # Decrease if out of memory\n",
    "    max_prompt_length=max_prompt_length,\n",
    "    max_completion_length=max_seq_length - max_prompt_length,\n",
    "    # num_train_epochs = 1, # Set to 1 for a full training run\n",
    "    max_steps=250,\n",
    "    save_steps=250,\n",
    "    max_grad_norm=0.1,\n",
    "    report_to=\"none\", # Can use Weights & Biases\n",
    "    output_dir=\"outputs\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "r9Mv8UZO5hz-"
   },
   "source": [
    "And let's run the trainer! As training progresses, a table of rewards will appear below; the goal is for the `reward` column to increase!\n",
    "\n",
    "You may have to wait 150 to 200 steps before anything happens, and you'll probably get 0 reward for the first 100 steps. Please be patient!\n",
    "\n",
    "| Step | Training Loss | reward | reward_std | completion_length | kl |\n",
    "|------|---------------|-----------|------------|-------------------|----------|\n",
    "| 1 | 0.000000 | 0.125000 | 0.000000 | 200.000000 | 0.000000 |\n",
    "| 2 | 0.000000 | 0.072375 | 0.248112 | 200.000000 | 0.000000 |\n",
    "| 3 | 0.000000 | -0.079000 | 0.163776 | 182.500000 | 0.000005 |\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vzOuSVCL_GA9"
   },
   "outputs": [],
   "source": [
    "trainer = GRPOTrainer(\n",
    "    model=model,\n",
    "    processing_class=tokenizer,\n",
    "    reward_funcs=[\n",
    "        xmlcount_reward_func,\n",
    "        soft_format_reward_func,\n",
    "        strict_format_reward_func,\n",
    "        int_reward_func,\n",
    "        correctness_reward_func,\n",
    "    ],\n",
    "    args=training_args,\n",
    "    train_dataset=dataset,\n",
    ")\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tlaUdxC_VHpz"
   },
   "source": [
    "<a name=\"Inference\"></a>\n",
    "### Inference\n",
    "Now let's try the model we just trained! First, try the base model without the GRPO-trained LoRA:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "qtcz_lpbVC92"
   },
   "outputs": [],
   "source": [
    "from vllm import SamplingParams\n",
    "\n",
    "text = tokenizer.apply_chat_template(\n",
    "    [\n",
    "        {\"role\": \"user\", \"content\": \"Calculate pi.\"},\n",
    "    ],\n",
    "    tokenize=False,\n",
    "    add_generation_prompt=True,\n",
    ")\n",
    "\n",
    "sampling_params = SamplingParams(\n",
    "    temperature=0.8,\n",
    "    top_p=0.95,\n",
    "    max_tokens=1024,\n",
    ")\n",
    "output = (\n",
    "    model.fast_generate(\n",
    "        [text],\n",
    "        sampling_params=sampling_params,\n",
    "        lora_request=None,\n",
    "    )[0]\n",
    "    .outputs[0]\n",
    "    .text\n",
    ")\n",
    "\n",
    "output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Colxz9TAVMsi"
   },
   "source": [
    "And now with the LoRA we just trained with GRPO! First, save the LoRA:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AL-BcuB1VLIv"
   },
   "outputs": [],
   "source": [
    "model.save_lora(\"grpo_saved_lora\")"
   ]
  },
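  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optionally (added; not part of the original notebook), list the saved adapter directory to confirm the files were written:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added: confirm the adapter files were written to disk.\n",
    "import os\n",
    "\n",
    "print(sorted(os.listdir(\"grpo_saved_lora\")))"
   ]
  },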
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CwpbwnDBVRLg"
   },
   "source": [
    "Now we load the LoRA and test:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zf_OY5WMVOxF"
   },
   "outputs": [],
   "source": [
    "from vllm import SamplingParams\n",
    "\n",
    "text = tokenizer.apply_chat_template(\n",
    "    [\n",
    "        {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
    "        {\"role\": \"user\", \"content\": \"Calculate pi.\"},\n",
    "    ],\n",
    "    tokenize=False,\n",
    "    add_generation_prompt=True,\n",
    ")\n",
    "\n",
    "sampling_params = SamplingParams(\n",
    "    temperature=0.8,\n",
    "    top_p=0.95,\n",
    "    max_tokens=1024,\n",
    ")\n",
    "output = (\n",
    "    model.fast_generate(\n",
    "        [text],\n",
    "        sampling_params=sampling_params,\n",
    "        lora_request=model.load_lora(\"grpo_saved_lora\"),\n",
    "    )[0]\n",
    "    .outputs[0]\n",
    "    .text\n",
    ")\n",
    "\n",
    "output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6aDgFfhFYIAS"
   },
   "source": [
    "Our reasoning model is much better! It's not always correct, since we only trained it for an hour or so; it will improve if we extend the sequence length and train for longer."
   ]
  },
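  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough check (added; not part of the original notebook), you can score a handful of GSM8K test questions with the trained LoRA and count exact-match answers. Ten samples are far too few for a real benchmark; this only shows a trend."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added rough-accuracy sketch: exact-match on a few GSM8K test questions.\n",
    "# Far too few samples for a real benchmark; this only shows a trend.\n",
    "from vllm import SamplingParams\n",
    "\n",
    "test_data = get_gsm8k_questions(split=\"test\").select(range(10))\n",
    "lora = model.load_lora(\"grpo_saved_lora\")\n",
    "greedy = SamplingParams(temperature=0.0, max_tokens=512)\n",
    "\n",
    "correct = 0\n",
    "for ex in test_data:\n",
    "    text = tokenizer.apply_chat_template(\n",
    "        ex[\"prompt\"], tokenize=False, add_generation_prompt=True\n",
    "    )\n",
    "    out = model.fast_generate([text], sampling_params=greedy, lora_request=lora)\n",
    "    correct += extract_xml_answer(out[0].outputs[0].text) == ex[\"answer\"]\n",
    "\n",
    "print(f\"{correct}/{len(test_data)} exact matches\")"
   ]
  },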
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-NUEmHFSYNTp"
   },
   "source": [
    "<a name=\"Save\"></a>\n",
    "### Saving to float16 for vLLM\n",
    "\n",
    "We also support saving to `float16` directly. Select `merged_16bit` for float16 or `merged_4bit` for int4. We also allow `lora` adapters as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NjXGTkp7YNtB"
   },
   "outputs": [],
   "source": [
    "# Merge to 16bit\n",
    "if False:\n",
    "    model.save_pretrained_merged(\n",
    "        \"model\",\n",
    "        tokenizer,\n",
    "        save_method=\"merged_16bit\",\n",
    "    )\n",
    "if False:\n",
    "    model.push_to_hub_merged(\n",
    "        \"hf/model\", tokenizer, save_method=\"merged_16bit\", token=\"\"\n",
    "    )\n",
    "\n",
    "# Merge to 4bit\n",
    "if False:\n",
    "    model.save_pretrained_merged(\n",
    "        \"model\",\n",
    "        tokenizer,\n",
    "        save_method=\"merged_4bit\",\n",
    "    )\n",
    "if False:\n",
    "    model.push_to_hub_merged(\"hf/model\", tokenizer, save_method=\"merged_4bit\", token=\"\")\n",
    "\n",
    "# Just LoRA adapters\n",
    "if False:\n",
    "    model.save_pretrained_merged(\n",
    "        \"model\",\n",
    "        tokenizer,\n",
    "        save_method=\"lora\",\n",
    "    )\n",
    "if False:\n",
    "    model.push_to_hub_merged(\"hf/model\", tokenizer, save_method=\"lora\", token=\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "52WMb3k_YPt8"
   },
   "source": [
    "### GGUF / llama.cpp Conversion\n",
    "Saving to `GGUF` / `llama.cpp` is now supported natively! We clone `llama.cpp` and default to saving as `q8_0`; all quantization methods like `q4_k_m` are allowed. Use `save_pretrained_gguf` for local saving and `push_to_hub_gguf` for uploading to HF.\n",
    "\n",
    "Some supported quant methods (full list on our [Wiki page](https://github.com/unslothai/unsloth/wiki#gguf-quantization-options)):\n",
    "* `q8_0` - Fast conversion. High resource use, but generally acceptable.\n",
    "* `q4_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K.\n",
    "* `q5_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K.\n",
    "\n",
    "[**NEW**] To finetune and auto export to Ollama, try our [Ollama notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "QyEjW-WuYQIm"
   },
   "outputs": [],
   "source": [
    "# Save to 8bit Q8_0\n",
    "if False:\n",
    "    model.save_pretrained_gguf(\n",
    "        \"model\",\n",
    "        tokenizer,\n",
    "    )\n",
    "# Remember to go to https://huggingface.co/settings/tokens for a token!\n",
    "# And change hf to your username!\n",
    "if False:\n",
    "    model.push_to_hub_gguf(\"hf/model\", tokenizer, token=\"\")\n",
    "\n",
    "# Save to 16bit GGUF\n",
    "if False:\n",
    "    model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"f16\")\n",
    "if False:\n",
    "    model.push_to_hub_gguf(\"hf/model\", tokenizer, quantization_method=\"f16\", token=\"\")\n",
    "\n",
    "# Save to q4_k_m GGUF\n",
    "if False:\n",
    "    model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"q4_k_m\")\n",
    "if False:\n",
    "    model.push_to_hub_gguf(\n",
    "        \"hf/model\", tokenizer, quantization_method=\"q4_k_m\", token=\"\"\n",
    "    )\n",
    "\n",
    "# Save to multiple GGUF options - much faster if you want multiple!\n",
    "if False:\n",
    "    model.push_to_hub_gguf(\n",
    "        \"hf/model\", # Change hf to your username!\n",
    "        tokenizer,\n",
    "        quantization_method=[\n",
    "            \"q4_k_m\",\n",
    "            \"q8_0\",\n",
    "            \"q5_k_m\",\n",
    "        ],\n",
    "        token=\"\",\n",
    "    )"
   ]
  },
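  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once exported, the GGUF file can also be loaded from Python. Below is a minimal sketch (added; not from the original notebook) using the `llama-cpp-python` bindings; the `model_path` is an assumption, so point it at whichever GGUF file the export above actually produced."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch: load the exported GGUF with llama-cpp-python\n",
    "# (pip install llama-cpp-python). The model_path below is an assumption;\n",
    "# point it at whichever GGUF file the export above produced.\n",
    "if False:\n",
    "    from llama_cpp import Llama\n",
    "\n",
    "    llm = Llama(model_path=\"model/unsloth.Q8_0.gguf\", n_ctx=1024)\n",
    "    result = llm.create_chat_completion(\n",
    "        messages=[\n",
    "            {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
    "            {\"role\": \"user\", \"content\": \"What is 12 * 7?\"},\n",
    "        ],\n",
    "        max_tokens=256,\n",
    "    )\n",
    "    print(result[\"choices\"][0][\"message\"][\"content\"])"
   ]
  },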
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "upXTZ_0aNjof"
   },
   "source": [
    "Now, use the `model-unsloth.gguf` file or `model-unsloth-Q4_K_M.gguf` file in llama.cpp or a UI-based system like Jan or Open WebUI. You can install Jan [here](https://github.com/janhq/jan) and Open WebUI [here](https://github.com/open-webui/open-webui).\n",
    "\n",
    "And we're done! If you have any questions on Unsloth, find any bugs, want to keep up with the latest LLM news, need help, or want to join projects, feel free to join our [Discord](https://discord.gg/unsloth)!\n",
    "\n",
    "Some other links:\n",
    "1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)\n",
    "2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)\n",
    "3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)\n",
    "4. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://docs.unsloth.ai/get-started/unsloth-notebooks)!\n",
    "\n",
    "<div class=\"align-center\">\n",
    "  <a href=\"https://unsloth.ai\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png\" width=\"115\"></a>\n",
    "  <a href=\"https://discord.gg/unsloth\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/Discord.png\" width=\"145\"></a>\n",
    "  <a href=\"https://docs.unsloth.ai/\"><img src=\"https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true\" width=\"125\"></a>\n",
    "\n",
    "  Join Discord if you need help + ⭐️ <i>Star us on <a href=\"https://github.com/unslothai/unsloth\">Github</a> </i> ⭐️\n",
    "</div>\n"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}