{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "wAwm-SoKNjoX"
},
"source": [
"# Sample GRPO Training Notebook\n",
"- Stolen from unsloth :3 thanks!\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rv8V30lINjob"
},
"source": [
"### Installation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1iVpHRPXNjoc"
},
"outputs": [],
"source": [
"# @title Colab Extra Install { display-mode: \"form\" }\n",
"import os\n",
"\n",
"if \"COLAB_\" not in \"\".join(os.environ.keys()):\n",
" !pip install -q unsloth vllm\n",
"else:\n",
" !pip install -q --no-deps unsloth vllm\n",
" # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]\n",
" # Skip restarting message in Colab\n",
" import sys\n",
" import re\n",
" import requests\n",
"\n",
" modules = list(sys.modules.keys())\n",
" for x in modules:\n",
" sys.modules.pop(x) if \"PIL\" in x or \"google\" in x else None\n",
" !pip install -q --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo\n",
" !pip install -q sentencepiece protobuf datasets huggingface_hub hf_transfer\n",
"\n",
" # vLLM requirements - vLLM breaks Colab due to reinstalling numpy\n",
" f = requests.get(\n",
" \"https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt\"\n",
" ).content\n",
" with open(\"vllm_requirements.txt\", \"wb\") as file:\n",
" file.write(re.sub(rb\"(transformers|numpy|xformers)[^\\n]{1,}\\n\", b\"\", f))\n",
" !pip install -q -r vllm_requirements.txt"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HnqT9aF5Njoc"
},
"source": [
"### Unsloth"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IpAO-idyNjoc"
},
"source": [
"Load up `Llama 3.1 8B Instruct`, and set parameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000,
"referenced_widgets": [
"09fca7e3b90f4b17a2344b67100e26bf",
"6eee6e73473440a39c2a14d11b19e7b9",
"f64c0803b3c34610a34f9c5f2b5ce6ba",
"4bae66db2fc041278f13ae1def5d1ddd",
"ce588ed5880240d7bfd91c150ef299ed",
"bc92bf99bcf345b59699eb8f0423f7cd",
"f589e115f63d4fdb95816dccb34c3e3e",
"fb45312477f64384912ba3e843efb0b3",
"4e541334eb26485ca5764b9bf0bbbd34",
"fe9dbcf4e85646cb8c684d82252b93e2",
"d36f0dda703c47cfae6f544e7e655806",
"cc9a242d278d4ea1b0f0f44d611439a2",
"9ec23dd10394466eb6db5336912f09de",
"e72af3cf2aac44d3a10b692b5fd54afe",
"7a0041d1e5c349b4b70e96554ea18113",
"f3fe4eb2a2fc4b2fb255632272dc0415",
"b095d3eac38847c0b087ffcc5876e0b6",
"9ebd8c869bc144cf8142375e7d7f04dd",
"48403d21996b4ebfa931cf260af6c283",
"b019cf9965424856b3fd781bd9e73118",
"ff032a23470845ba95a12227baf0e0ce",
"18d2fd154ffc4671b76f00760eb35571",
"bc8f8134f144435681cbc50f9f5e5274",
"d7a1980bf7cc4895aea14df80bf92302",
"6be20c9739de44ffb7ef2465ad17424d",
"39411144c4c44eae935262bc6312d5b9",
"32f3e3f08eb1441ca70191eaa6191805",
"2df204a0982b466eb3c25bf80645b6c1",
"9ec592e894b945cea7c2a756a96eb959",
"7ff47ae65e944dc9bbd50f01edb7fa02",
"2f5bc26300174430aa5b0610b5ef4c72",
"79da5ea5195c4c63b0ec1ee354511a87",
"a5f3704359df48e4b4c5b418a5f72021",
"2cae8f5f92304278bcfa4a07fb5aafdc",
"63681e5808d1443c9f04654d58952448",
"91200f911ca64162b006235d0ec0c5f1",
"40c4dfad109643529a16b3de307e2e1a",
"491466071e6742c6ab30419597d78fde",
"a67ca6e6c91244978b6b468257ffefd7",
"781f0652012b424aa43cd39028e3a99a",
"9e4ae8a9e1b247ffbd232d97be270cd8",
"c558b6c8e51046c6b3e0809ccf749348",
"0ea62e14a2d44b8ca7c6f15bca8e1a2b",
"27e78ef30ff64be5af79b6431816f569",
"b973da5783dc49d1802831298be3ab7e",
"d7e3f2c2441743089ee7ed10eb0319dd",
"7f490169dd8a4dbeb367af5259b0949e",
"2a8b54de6a974cc3a3f0b47daa10af89",
"70780d08fb00445da32b28b40d33412f",
"2c83de81fc91497db859d1be6860d6f4",
"397a49fc365b4c3a95db709b71fc32e3",
"1e75c59df9524dec928b21d65dd43c08",
"747b264f119746c08bbd6685c134ec4f",
"364d9414fd5f4280be4ad621874e15cb",
"7b3ad76804bd47a68fba0bf9f967cb04",
"2c34cb165ce447dca52c25eff7cc91ce",
"e8f729e29f7c4639a10a02a9f4978385",
"f60723e017f64a2183af5f05ca827340",
"165ed67ab04240fca77e70e6a905ad48",
"e229f6a91dcf4b5182673cc6c7002063",
"fe4ecfe3522c432b964608f137cfb936",
"e1e8c601a2bb4c0b8f0cdf3bf0dd0b2e",
"db89c098dd414fa08440071935414b63",
"efb3aff9a97340fb8acf5f98402e7c1d",
"d4eeb17c3a124c919058c0ffe72a60b6",
"8b6bd37f64454464b44691ae3bef0b9f",
"239c67254f4044259fcbb35cb78d72b3",
"7bdc3ab924804e6badd7f403d0ce33f1",
"2b1ff4602dee4511a450588c501a7174",
"bfae7687d2a440189c29dc53e47b71ab",
"c72dbf9684ed4cf0a55655ebaf645b96",
"7083c21ed3dd4ed6b442afdf4a18a05d",
"6e4ce08c86f24beeb63d38d436143dc9",
"86498c04c2a347968e65867eb4ed2102",
"dc90aea1a6e445ebaf68bb305586bf47",
"4c69d53cf33c4865af8feead04b810ad",
"9dbc67661c614d8a82d0790d9b694e6d",
"97f09dcebaf94be3b115a69b1440392b",
"a568aff772814128b0ad329594e67610",
"92ac502bff6e48d98cbb07a65d84dca6",
"0b9dd1c421f14c56986cf30abe0a9455",
"50a3cb7fc0ce45bb8c1baabc2a54fff2",
"fac3506146244766b886c715ea57ec7b",
"d068e96c067e43ffa1777a666b0df147",
"9e9b1a12028443c9abe6d88d0ed39d15",
"638c59bf00ba4536ad157975faa26984",
"73177ee1d64c4f56aa393c45c42e7499",
"98aabe0e4167414d943ff3d7759093ad",
"0f040a5765854b699b32bf096bb0563c",
"813e8e40df6447c18e47414ab38f69a2",
"533f84f90e7247e8a60da8b6ce60fe89",
"80c5f4171f2844749b5f3e67a46178b9",
"0722b569fa2e45f9af7d4fb8134e632b",
"115a8f3c523c4cd3aaa472af87fe0e4a",
"26cde13f97db4ae4a0900fb5184a51a2",
"5c4e90490762447c9712372a18f48e20",
"690c64fbee9f4c038a29030086c73368",
"cbe9c4e2f8bc49dfb1d8a4cddc36d724",
"f33aef89be7641a694fb11e3101c7d85",
"6798efcf0f784fca91a2e84c91553296",
"190504fd7fec4f08ba43b972ed28d23b",
"142703403e5e4ba1ac1531d7c82c752a",
"0a45b8dd65764272869fd1424415e6af",
"d536479db8034c99be7e745a440baf7d",
"594a583be494419c845fd873904d4d6d",
"ba21e0c745714031aae12b1665847e54",
"3491a152831245f7b03d06fe94f42116",
"b80d4e38b5c54dd585d1da848ee9c4c8",
"fa155f662d1c40f496daef290ff7a5ff",
"9b2a88f9712d409bbf74bf0c2e14df16"
]
},
"id": "DkIvEkIIkEyB",
"outputId": "e1aaf3b5-6ad0-45c7-be08-2017e74ebba9"
},
"outputs": [],
"source": [
"from unsloth import FastLanguageModel\n",
"\n",
"max_seq_length = 1024 # Can increase for longer reasoning traces\n",
"lora_rank = 32 # Larger rank = smarter, but slower\n",
"\n",
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
" model_name=\"meta-llama/meta-Llama-3.1-8B-Instruct\",\n",
" max_seq_length=max_seq_length,\n",
" load_in_4bit=True, # False for LoRA 16bit\n",
" fast_inference=True, # Enable vLLM fast inference\n",
" max_lora_rank=lora_rank,\n",
" gpu_memory_utilization=0.6, # Reduce if out of memory\n",
")\n",
"\n",
"model = FastLanguageModel.get_peft_model(\n",
" model,\n",
" r=lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n",
" target_modules=[\n",
" \"q_proj\",\n",
" \"k_proj\",\n",
" \"v_proj\",\n",
" \"o_proj\",\n",
" \"gate_proj\",\n",
" \"up_proj\",\n",
" \"down_proj\",\n",
" ], # Remove QKVO if out of memory\n",
" lora_alpha=lora_rank,\n",
" use_gradient_checkpointing=\"unsloth\", # Enable long context finetuning\n",
" random_state=3407,\n",
")"
]
},
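{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (our addition, not part of the original notebook), we can print how many parameters the LoRA actually trains. This assumes the object returned by `get_peft_model` exposes PEFT's `print_trainable_parameters()`; we fall back to counting manually if it doesn't."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch: report how many parameters the LoRA adapters train.\n",
"# Assumes the PEFT wrapper exposes print_trainable_parameters().\n",
"if hasattr(model, \"print_trainable_parameters\"):\n",
"    model.print_trainable_parameters()\n",
"else:\n",
"    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"    total = sum(p.numel() for p in model.parameters())\n",
"    print(f\"trainable: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)\")"
]
},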
{
"cell_type": "markdown",
"metadata": {
"id": "7KGgPgk_5S8r"
},
"source": [
"### Data Prep\n",
"<a name=\"Data\"></a>\n",
"\n",
"We directly leverage [@willccbb](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb) for data prep and all reward functions. You are free to create your own!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 321,
"referenced_widgets": [
"5954f25ab3334f8c8e0576ee75922c0c",
"5db92327420a48f4a1ad966ecd7ac8c5",
"1a1651d5ca7d4ccd887235343e49ffb3",
"22462b8181964257a53c85be6265eef6",
"1c09f4dc86124b1296ed96fc48b8d856",
"5e9d2ad532a8447a904688629a91f171",
"2df38622a5ab46979ee53efeb8c31b91",
"a75b75f7dec54759816845c3fda31f22",
"b14a63b4cac142c5adaaebbfef8739fe",
"1a090b75a50b459a8d2f822efc692dfa",
"0da6c9fa4eeb4b4d832f8bbc614bc8d9",
"3e1954a710bd43d1907b702c7f70c6de",
"483db1483d1d4a97a59c1e6cdad04c6a",
"cf88de7717504af691ca8db0fb68a6b6",
"d4de07c4656c40329bc2c1450bd94bb5",
"160f401dc9a74266bb610e747f8f711e",
"269f762b5af041eab750226febe70e07",
"f38c687b78224c8aa2227914998b4627",
"eca5ebef74f341f79dc96964d0189862",
"009ea1226a194cb2923ec9bc3a6e1843",
"2aade13ff2ca4250ae02b7a12b1ef16d",
"23531535f2fc4b99a5ecd1b63fb93d80",
"a137a1bc2b4f437a8477e079f003ad8d",
"05836f874c9a49d390721a05ff055ac9",
"9032130a4a9646db92805eeb4221c83d",
"ac58943ffd184cdca0b69cce35f54ec7",
"8dc3c53fac5644a6a2841aca62226300",
"fda3ed1250d34b12aae8401287322548",
"e85f549d21dd4c3293f8239a053b0ad4",
"a8d0ef0ef4f747e294394b04d07cbe34",
"15b83f9309a642d2bb1bc79da539b923",
"436282df7bcc45c4b951de60f2e4ec53",
"911d088df0894243b69fd4f69c46e7f3",
"eba35bc404f640bcb087de7a7ce74942",
"6b39fce0206f4ca894b943a42eb9dc04",
"c3ec8a75007846519d40f6726d2a01ee",
"de3486da03594b0091b77cbf702401a5",
"2e1b5b6b3ae84190a7f50993d78fb6ea",
"81f43f522faa4dcf9e557793a2027e90",
"ae46c8f17c0644fdbd1aa6d412c322b7",
"5de480828775465bb76b674ecd713bd0",
"a159d55b4de84020bc127d6c7758d2a9",
"cc6097b25ec74decb51a397ccd9fcf27",
"d2af3c19e6ec46a3915f6badf7062b7b",
"8d2463ee53394dc097843bc733191c97",
"f681a0ff692c4cd3b14b1e064547c52f",
"0ecef6aade754029ba939e30c91538be",
"87e25abbd016411199a2e976cd7ab550",
"93f708b722194750a9f66a2597987033",
"76dc9524ee764741a1615de6bc43492e",
"78201e7953b647ee9656fc7e8f4abbcb",
"0bc45e643a9b4a0e8db9f85180c108b2",
"cfb5247cf65c4bdfbcf8207c5d8c72c8",
"32fe42f8a11d4131a3a529b554f028fe",
"39c33d513fb74712be8fcf73093a0c9e",
"908333eedde74ec3a0885e52454b3627",
"934617e655d04f85a87d6444b7c1ff0e",
"63d0c36c4bb8469f8cf5ac410d2c11d6",
"092d58ba2c444878a1cbfd19bdc14027",
"c8ecf533f45f4ae29bdcfdaeca7daedf",
"d89eae07da0e4344909900e24e3d0d09",
"b1e00fc06b664fc59d4921388f187269",
"b834a1ed2fc14d209c405969def96965",
"dc681483ceeb4e878124858ffb467338",
"ef2a701c9c594d4d9ffc8379fa9b5899",
"22c31da7c0564f09bbdb04efdefad11a"
]
},
"id": "cXk993X6C2ZZ",
"outputId": "d6b161df-4d15-4cae-a02c-22fb901b91f1"
},
"outputs": [],
"source": [
"import re\n",
"from datasets import load_dataset, Dataset\n",
"\n",
"# Load and prep dataset\n",
"SYSTEM_PROMPT = \"\"\"\n",
"Respond in the following format:\n",
"<reasoning>\n",
"...\n",
"</reasoning>\n",
"<answer>\n",
"...\n",
"</answer>\n",
"\"\"\"\n",
"\n",
"XML_COT_FORMAT = \"\"\"\\\n",
"<reasoning>\n",
"{reasoning}\n",
"</reasoning>\n",
"<answer>\n",
"{answer}\n",
"</answer>\n",
"\"\"\"\n",
"\n",
"\n",
"def extract_xml_answer(text: str) -> str:\n",
" answer = text.split(\"<answer>\")[-1]\n",
" answer = answer.split(\"</answer>\")[0]\n",
" return answer.strip()\n",
"\n",
"\n",
"def extract_hash_answer(text: str) -> str | None:\n",
" if \"####\" not in text:\n",
" return None\n",
" return text.split(\"####\")[1].strip()\n",
"\n",
"\n",
"# uncomment middle messages for 1-shot prompting\n",
"def get_gsm8k_questions(split=\"train\") -> Dataset:\n",
" data = load_dataset(\"openai/gsm8k\", \"main\")[split] # type: ignore\n",
" data = data.map(\n",
" lambda x: { # type: ignore\n",
" \"prompt\": [\n",
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
" {\"role\": \"user\", \"content\": x[\"question\"]},\n",
" ],\n",
" \"answer\": extract_hash_answer(x[\"answer\"]),\n",
" }\n",
" ) # type: ignore\n",
" return data # type: ignore\n",
"\n",
"\n",
"dataset = get_gsm8k_questions()\n",
"\n",
"\n",
"# Reward functions\n",
"def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:\n",
" responses = [completion[0][\"content\"] for completion in completions]\n",
" q = prompts[0][-1][\"content\"]\n",
" extracted_responses = [extract_xml_answer(r) for r in responses]\n",
" print(\n",
" \"-\" * 20,\n",
" f\"Question:\\n{q}\",\n",
" f\"\\nAnswer:\\n{answer[0]}\",\n",
" f\"\\nResponse:\\n{responses[0]}\",\n",
" f\"\\nExtracted:\\n{extracted_responses[0]}\",\n",
" )\n",
" return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]\n",
"\n",
"\n",
"def int_reward_func(completions, **kwargs) -> list[float]:\n",
" responses = [completion[0][\"content\"] for completion in completions]\n",
" extracted_responses = [extract_xml_answer(r) for r in responses]\n",
" return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]\n",
"\n",
"\n",
"def strict_format_reward_func(completions, **kwargs) -> list[float]:\n",
" \"\"\"Reward function that checks if the completion has a specific format.\"\"\"\n",
" pattern = r\"^<reasoning>\\n.*?\\n</reasoning>\\n<answer>\\n.*?\\n</answer>\\n$\"\n",
" responses = [completion[0][\"content\"] for completion in completions]\n",
" matches = [re.match(pattern, r) for r in responses]\n",
" return [0.5 if match else 0.0 for match in matches]\n",
"\n",
"\n",
"def soft_format_reward_func(completions, **kwargs) -> list[float]:\n",
" \"\"\"Reward function that checks if the completion has a specific format.\"\"\"\n",
" pattern = r\"<reasoning>.*?</reasoning>\\s*<answer>.*?</answer>\"\n",
" responses = [completion[0][\"content\"] for completion in completions]\n",
" matches = [re.match(pattern, r) for r in responses]\n",
" return [0.5 if match else 0.0 for match in matches]\n",
"\n",
"\n",
"def count_xml(text) -> float:\n",
" count = 0.0\n",
" if text.count(\"<reasoning>\\n\") == 1:\n",
" count += 0.125\n",
" if text.count(\"\\n</reasoning>\\n\") == 1:\n",
" count += 0.125\n",
" if text.count(\"\\n<answer>\\n\") == 1:\n",
" count += 0.125\n",
" count -= len(text.split(\"\\n</answer>\\n\")[-1]) * 0.001\n",
" if text.count(\"\\n</answer>\") == 1:\n",
" count += 0.125\n",
" count -= (len(text.split(\"\\n</answer>\")[-1]) - 1) * 0.001\n",
" return count\n",
"\n",
"\n",
"def xmlcount_reward_func(completions, **kwargs) -> list[float]:\n",
" contents = [completion[0][\"content\"] for completion in completions]\n",
" return [count_xml(c) for c in contents]"
]
},
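{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check (our addition, not part of the original gist): run the extraction helper and the reward functions on a hand-written completion in the expected format, and confirm they score it the way we intend."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch: sanity-check the reward functions on a toy completion.\n",
"sample = XML_COT_FORMAT.format(reasoning=\"2 + 2 = 4\", answer=\"4\")\n",
"completions = [[{\"role\": \"assistant\", \"content\": sample}]]\n",
"\n",
"print(\"extracted:\", extract_xml_answer(sample))\n",
"print(\"int reward:\", int_reward_func(completions))\n",
"print(\"soft format:\", soft_format_reward_func(completions))\n",
"print(\"strict format:\", strict_format_reward_func(completions))\n",
"print(\"xml count:\", xmlcount_reward_func(completions))"
]
},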
{
"cell_type": "markdown",
"metadata": {
"id": "Ux6iqP7z5YOo"
},
"source": [
"<a name=\"Train\"></a>\n",
"### Train the model\n",
"\n",
"Now set up GRPO Trainer and all configurations!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ptqkXK2D4d6p",
"outputId": "1e7b980b-c662-49c9-fb3a-d929edbafa09"
},
"outputs": [],
"source": [
"max_prompt_length = 256\n",
"\n",
"from trl import GRPOConfig, GRPOTrainer\n",
"\n",
"training_args = GRPOConfig(\n",
" learning_rate=5e-6,\n",
" adam_beta1=0.9,\n",
" adam_beta2=0.99,\n",
" weight_decay=0.1,\n",
" warmup_ratio=0.1,\n",
" lr_scheduler_type=\"cosine\",\n",
" optim=\"paged_adamw_8bit\",\n",
" logging_steps=1,\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=1, # Increase to 4 for smoother training\n",
" num_generations=6, # Decrease if out of memory\n",
" max_prompt_length=max_prompt_length,\n",
" max_completion_length=max_seq_length - max_prompt_length,\n",
" # num_train_epochs = 1, # Set to 1 for a full training run\n",
" max_steps=250,\n",
" save_steps=250,\n",
" max_grad_norm=0.1,\n",
" report_to=\"none\", # Can use Weights & Biases\n",
" output_dir=\"outputs\",\n",
")"
]
},
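{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small sanity check (our addition): the prompt budget plus the completion budget must fit inside `max_seq_length`, otherwise generations get truncated."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch: verify the token budget adds up before training.\n",
"assert (\n",
"    max_prompt_length + training_args.max_completion_length <= max_seq_length\n",
"), \"prompt + completion budget exceeds max_seq_length\"\n",
"print(\n",
"    f\"prompt budget: {max_prompt_length}, \"\n",
"    f\"completion budget: {training_args.max_completion_length}, \"\n",
"    f\"total allowed: {max_seq_length}\"\n",
")"
]
},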
{
"cell_type": "markdown",
"metadata": {
"id": "r9Mv8UZO5hz-"
},
"source": [
"And let's run the trainer! If you scroll up, you'll see a table of rewards. The goal is to see the `reward` column increase!\n",
"\n",
"You might have to wait 150 to 200 steps for any action. You'll probably get 0 reward for the first 100 steps. Please be patient!\n",
"\n",
"| Step | Training Loss | reward | reward_std | completion_length | kl |\n",
"|------|---------------|-----------|------------|-------------------|----------|\n",
"| 1 | 0.000000 | 0.125000 | 0.000000 | 200.000000 | 0.000000 |\n",
"| 2 | 0.000000 | 0.072375 | 0.248112 | 200.000000 | 0.000000 |\n",
"| 3 | 0.000000 | -0.079000 | 0.163776 | 182.500000 | 0.000005 |\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "vzOuSVCL_GA9",
"outputId": "09405367-ccea-4e1c-aabf-2dbb9318c4fa"
},
"outputs": [],
"source": [
"trainer = GRPOTrainer(\n",
" model=model,\n",
" processing_class=tokenizer,\n",
" reward_funcs=[\n",
" xmlcount_reward_func,\n",
" soft_format_reward_func,\n",
" strict_format_reward_func,\n",
" int_reward_func,\n",
" correctness_reward_func,\n",
" ],\n",
" args=training_args,\n",
" train_dataset=dataset,\n",
")\n",
"trainer.train()"
]
},
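{
"cell_type": "markdown",
"metadata": {},
"source": [
"The metrics from the table above also live in `trainer.state.log_history`. Here's a minimal sketch (our addition) that pulls out the mean reward per logged step, assuming TRL logged a `reward` entry as shown above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch: inspect how the mean reward evolved over training.\n",
"# Assumes each logged dict carries a \"reward\" key, as in the table above.\n",
"reward_log = [\n",
"    (log.get(\"step\", i), log[\"reward\"])\n",
"    for i, log in enumerate(trainer.state.log_history)\n",
"    if \"reward\" in log\n",
"]\n",
"for step, reward in reward_log[-10:]:\n",
"    print(f\"step {step}: mean reward {reward:.3f}\")"
]
},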
{
"cell_type": "markdown",
"metadata": {
"id": "tlaUdxC_VHpz"
},
"source": [
"<a name=\"Inference\"></a>\n",
"### Inference\n",
"Now let's try the model we just trained! First, let's first try the model without any GRPO trained:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 249
},
"id": "qtcz_lpbVC92",
"outputId": "70e4f329-acac-4d31-a8cd-47f7c6088747"
},
"outputs": [],
"source": [
"text = tokenizer.apply_chat_template(\n",
" [\n",
" {\"role\": \"user\", \"content\": \"Calculate pi.\"},\n",
" ],\n",
" tokenize=False,\n",
" add_generation_prompt=True,\n",
")\n",
"\n",
"from vllm import SamplingParams\n",
"\n",
"sampling_params = SamplingParams(\n",
" temperature=0.8,\n",
" top_p=0.95,\n",
" max_tokens=1024,\n",
")\n",
"output = (\n",
" model.fast_generate(\n",
" [text],\n",
" sampling_params=sampling_params,\n",
" lora_request=None,\n",
" )[0]\n",
" .outputs[0]\n",
" .text\n",
")\n",
"\n",
"output"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Colxz9TAVMsi"
},
"source": [
"And now with the LoRA we just trained with GRPO - we first save the LoRA first!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AL-BcuB1VLIv"
},
"outputs": [],
"source": [
"model.save_lora(\"grpo_saved_lora\")"
]
},
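{
"cell_type": "markdown",
"metadata": {},
"source": [
"Quick check (our addition): list what `save_lora` wrote to disk. The exact filenames depend on the PEFT version, but you should typically see `adapter_config.json` and an `adapter_model` weights file."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# List the saved adapter files; exact names depend on the PEFT version.\n",
"for name in sorted(os.listdir(\"grpo_saved_lora\")):\n",
"    size_kb = os.path.getsize(os.path.join(\"grpo_saved_lora\", name)) / 1024\n",
"    print(f\"{name}: {size_kb:,.0f} KiB\")"
]
},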
{
"cell_type": "markdown",
"metadata": {
"id": "CwpbwnDBVRLg"
},
"source": [
"Now we load the LoRA and test:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 249
},
"id": "zf_OY5WMVOxF",
"outputId": "22373f16-a3bc-4c99-8a1d-1994522a5f0f"
},
"outputs": [],
"source": [
"text = tokenizer.apply_chat_template(\n",
" [\n",
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
" {\"role\": \"user\", \"content\": \"Calculate pi.\"},\n",
" ],\n",
" tokenize=False,\n",
" add_generation_prompt=True,\n",
")\n",
"\n",
"from vllm import SamplingParams\n",
"\n",
"sampling_params = SamplingParams(\n",
" temperature=0.8,\n",
" top_p=0.95,\n",
" max_tokens=1024,\n",
")\n",
"output = (\n",
" model.fast_generate(\n",
" text,\n",
" sampling_params=sampling_params,\n",
" lora_request=model.load_lora(\"grpo_saved_lora\"),\n",
" )[0]\n",
" .outputs[0]\n",
" .text\n",
")\n",
"\n",
"output"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6aDgFfhFYIAS"
},
"source": [
"Our reasoning model is much better - it's not always correct, since we only trained it for an hour or so - it'll be better if we extend the sequence length and train for longer!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-NUEmHFSYNTp"
},
"source": [
"<a name=\"Save\"></a>\n",
"### Saving to float16 for VLLM\n",
"\n",
"We also support saving to `float16` directly. Select `merged_16bit` for float16 or `merged_4bit` for int4. We also allow `lora` adapters as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NjXGTkp7YNtB"
},
"outputs": [],
"source": [
"# Merge to 16bit\n",
"if False:\n",
" model.save_pretrained_merged(\n",
" \"model\",\n",
" tokenizer,\n",
" save_method=\"merged_16bit\",\n",
" )\n",
"if False:\n",
" model.push_to_hub_merged(\n",
" \"hf/model\", tokenizer, save_method=\"merged_16bit\", token=\"\"\n",
" )\n",
"\n",
"# Merge to 4bit\n",
"if False:\n",
" model.save_pretrained_merged(\n",
" \"model\",\n",
" tokenizer,\n",
" save_method=\"merged_4bit\",\n",
" )\n",
"if False:\n",
" model.push_to_hub_merged(\"hf/model\", tokenizer, save_method=\"merged_4bit\", token=\"\")\n",
"\n",
"# Just LoRA adapters\n",
"if False:\n",
" model.save_pretrained_merged(\n",
" \"model\",\n",
" tokenizer,\n",
" save_method=\"lora\",\n",
" )\n",
"if False:\n",
" model.push_to_hub_merged(\"hf/model\", tokenizer, save_method=\"lora\", token=\"\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "52WMb3k_YPt8"
},
"source": [
"### GGUF / llama.cpp Conversion\n",
"To save to `GGUF` / `llama.cpp`, we support it natively now! We clone `llama.cpp` and we default save it to `q8_0`. We allow all methods like `q4_k_m`. Use `save_pretrained_gguf` for local saving and `push_to_hub_gguf` for uploading to HF.\n",
"\n",
"Some supported quant methods (full list on our [Wiki page](https://github.com/unslothai/unsloth/wiki#gguf-quantization-options)):\n",
"* `q8_0` - Fast conversion. High resource use, but generally acceptable.\n",
"* `q4_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K.\n",
"* `q5_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K.\n",
"\n",
"[**NEW**] To finetune and auto export to Ollama, try our [Ollama notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QyEjW-WuYQIm"
},
"outputs": [],
"source": [
"# Save to 8bit Q8_0\n",
"if False:\n",
" model.save_pretrained_gguf(\n",
" \"model\",\n",
" tokenizer,\n",
" )\n",
"# Remember to go to https://huggingface.co/settings/tokens for a token!\n",
"# And change hf to your username!\n",
"if False:\n",
" model.push_to_hub_gguf(\"hf/model\", tokenizer, token=\"\")\n",
"\n",
"# Save to 16bit GGUF\n",
"if False:\n",
" model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"f16\")\n",
"if False:\n",
" model.push_to_hub_gguf(\"hf/model\", tokenizer, quantization_method=\"f16\", token=\"\")\n",
"\n",
"# Save to q4_k_m GGUF\n",
"if False:\n",
" model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"q4_k_m\")\n",
"if False:\n",
" model.push_to_hub_gguf(\n",
" \"hf/model\", tokenizer, quantization_method=\"q4_k_m\", token=\"\"\n",
" )\n",
"\n",
"# Save to multiple GGUF options - much faster if you want multiple!\n",
"if False:\n",
" model.push_to_hub_gguf(\n",
" \"hf/model\", # Change hf to your username!\n",
" tokenizer,\n",
" quantization_method=[\n",
" \"q4_k_m\",\n",
" \"q8_0\",\n",
" \"q5_k_m\",\n",
" ],\n",
" token=\"\",\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "upXTZ_0aNjof"
},
"source": [
"Now, use the `model-unsloth.gguf` file or `model-unsloth-Q4_K_M.gguf` file in llama.cpp or a UI based system like Jan or Open WebUI. You can install Jan [here](https://github.com/janhq/jan) and Open WebUI [here](https://github.com/open-webui/open-webui)\n",
"\n",
"And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs or want to keep updated with the latest LLM stuff, or need help, join projects etc, feel free to join our Discord!\n",
"\n",
"Some other links:\n",
"1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)\n",
"2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)\n",
"3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)\n",
"6. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://docs.unsloth.ai/get-started/unsloth-notebooks)!\n",
"\n",
"<div class=\"align-center\">\n",
" <a href=\"https://unsloth.ai\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png\" width=\"115\"></a>\n",
" <a href=\"https://discord.gg/unsloth\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/Discord.png\" width=\"145\"></a>\n",
" <a href=\"https://docs.unsloth.ai/\"><img src=\"https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true\" width=\"125\"></a>\n",
"\n",
" Join Discord if you need help + ⭐️ <i>Star us on <a href=\"https://github.com/unslothai/unsloth\">Github</a> </i> ⭐️\n",
"</div>\n"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 0
}