@@ -57,6 +57,7 @@
"source": [
"source": [
"%%capture\n",
"%%capture\n",
"import os\n",
"import os\n",
"\n",
"if \"COLAB_\" not in \"\".join(os.environ.keys()):\n",
"if \"COLAB_\" not in \"\".join(os.environ.keys()):\n",
" !pip install unsloth vllm\n",
" !pip install unsloth vllm\n",
"else:\n",
"else:\n",
@@ -75,19 +76,27 @@
"# @title Colab Extra Install { display-mode: \"form\" }\n",
"# @title Colab Extra Install { display-mode: \"form\" }\n",
"%%capture\n",
"%%capture\n",
"import os\n",
"import os\n",
"\n",
"if \"COLAB_\" not in \"\".join(os.environ.keys()):\n",
"if \"COLAB_\" not in \"\".join(os.environ.keys()):\n",
" !pip install unsloth vllm\n",
" !pip install unsloth vllm\n",
"else:\n",
"else:\n",
" !pip install --no-deps unsloth vllm\n",
" !pip install --no-deps unsloth vllm\n",
" # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]\n",
" # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]\n",
" # Skip restarting message in Colab\n",
" # Skip restarting message in Colab\n",
" import sys, re, requests; modules = list(sys.modules.keys())\n",
" import sys\n",
" for x in modules: sys.modules.pop(x) if \"PIL\" in x or \"google\" in x else None\n",
" import re\n",
" import requests\n",
"\n",
" modules = list(sys.modules.keys())\n",
" for x in modules:\n",
" sys.modules.pop(x) if \"PIL\" in x or \"google\" in x else None\n",
" !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo\n",
" !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo\n",
" !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer\n",
" !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer\n",
"\n",
"\n",
" # vLLM requirements - vLLM breaks Colab due to reinstalling numpy\n",
" # vLLM requirements - vLLM breaks Colab due to reinstalling numpy\n",
" f = requests.get(\"https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt\").content\n",
" f = requests.get(\n",
" \"https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt\"\n",
" ).content\n",
" with open(\"vllm_requirements.txt\", \"wb\") as file:\n",
" with open(\"vllm_requirements.txt\", \"wb\") as file:\n",
" file.write(re.sub(rb\"(transformers|numpy|xformers)[^\\n]{1,}\\n\", b\"\", f))\n",
" file.write(re.sub(rb\"(transformers|numpy|xformers)[^\\n]{1,}\\n\", b\"\", f))\n",
" !pip install -r vllm_requirements.txt"
" !pip install -r vllm_requirements.txt"
@@ -303,7 +312,7 @@
"outputs": [],
"outputs": [],
"source": [
"source": [
"from unsloth import FastLanguageModel, is_bfloat16_supported\n",
"from unsloth import FastLanguageModel, is_bfloat16_supported\n",
"import torch \n",
"\n",
"max_seq_length = 1024 # Can increase for longer reasoning traces\n",
"max_seq_length = 1024 # Can increase for longer reasoning traces\n",
"lora_rank = 64 # Larger rank = smarter, but slower\n",
"lora_rank = 64 # Larger rank = smarter, but slower\n",
"\n",
"\n",
@@ -320,8 +329,13 @@
" model,\n",
" model,\n",
" r=lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n",
" r=lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n",
" target_modules=[\n",
" target_modules=[\n",
" \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
" \"q_proj\",\n",
" \"gate_proj\", \"up_proj\", \"down_proj\",\n",
" \"k_proj\",\n",
" \"v_proj\",\n",
" \"o_proj\",\n",
" \"gate_proj\",\n",
" \"up_proj\",\n",
" \"down_proj\",\n",
" ], # Remove QKVO if out of memory\n",
" ], # Remove QKVO if out of memory\n",
" lora_alpha=lora_rank,\n",
" lora_alpha=lora_rank,\n",
" use_gradient_checkpointing=\"unsloth\", # Enable long context finetuning\n",
" use_gradient_checkpointing=\"unsloth\", # Enable long context finetuning\n",
@@ -445,43 +459,58 @@
"</answer>\n",
"</answer>\n",
"\"\"\"\n",
"\"\"\"\n",
"\n",
"\n",
"\n",
"def extract_xml_answer(text: str) -> str:\n",
"def extract_xml_answer(text: str) -> str:\n",
" answer = text.split(\"<answer>\")[-1]\n",
" answer = text.split(\"<answer>\")[-1]\n",
" answer = answer.split(\"</answer>\")[0]\n",
" answer = answer.split(\"</answer>\")[0]\n",
" return answer.strip()\n",
" return answer.strip()\n",
"\n",
"\n",
"\n",
"def extract_hash_answer(text: str) -> str | None:\n",
"def extract_hash_answer(text: str) -> str | None:\n",
" if \"####\" not in text:\n",
" if \"####\" not in text:\n",
" return None\n",
" return None\n",
" return text.split(\"####\")[1].strip()\n",
" return text.split(\"####\")[1].strip()\n",
"\n",
"\n",
"\n",
"# uncomment middle messages for 1-shot prompting\n",
"# uncomment middle messages for 1-shot prompting\n",
"def get_gsm8k_questions(split=\"train\") -> Dataset:\n",
"def get_gsm8k_questions(split=\"train\") -> Dataset:\n",
" data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore\n",
" data = load_dataset(\"openai/gsm8k\", \"main\")[split] # type: ignore\n",
" data = data.map(lambda x: { # type: ignore\n",
" data = data.map(\n",
" 'prompt': [\n",
" lambda x: { # type: ignore\n",
" {'role': 'system', 'content': SYSTEM_PROMPT},\n",
" \"prompt\": [\n",
" {'role': 'user', 'content': x['question']}\n",
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
" {\"role\": \"user\", \"content\": x[\"question\"]},\n",
" ],\n",
" ],\n",
" 'answer': extract_hash_answer(x['answer'])\n",
" \"answer\": extract_hash_answer(x[\"answer\"]),\n",
" }) # type: ignore\n",
" }\n",
" ) # type: ignore\n",
" return data # type: ignore\n",
" return data # type: ignore\n",
"\n",
"\n",
"\n",
"dataset = get_gsm8k_questions()\n",
"dataset = get_gsm8k_questions()\n",
"\n",
"\n",
"\n",
"# Reward functions\n",
"# Reward functions\n",
"def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:\n",
"def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:\n",
" responses = [completion[0]['content' ] for completion in completions]\n",
" responses = [completion[0][\"content\" ] for completion in completions]\n",
" q = prompts[0][-1]['content' ]\n",
" q = prompts[0][-1][\"content\" ]\n",
" extracted_responses = [extract_xml_answer(r) for r in responses]\n",
" extracted_responses = [extract_xml_answer(r) for r in responses]\n",
" print('-'*20, f\"Question:\\n{q}\", f\"\\nAnswer:\\n{answer[0]}\", f\"\\nResponse:\\n{responses[0]}\", f\"\\nExtracted:\\n{extracted_responses[0]}\")\n",
" print(\n",
" \"-\" * 20,\n",
" f\"Question:\\n{q}\",\n",
" f\"\\nAnswer:\\n{answer[0]}\",\n",
" f\"\\nResponse:\\n{responses[0]}\",\n",
" f\"\\nExtracted:\\n{extracted_responses[0]}\",\n",
" )\n",
" return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]\n",
" return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]\n",
"\n",
"\n",
"\n",
"def int_reward_func(completions, **kwargs) -> list[float]:\n",
"def int_reward_func(completions, **kwargs) -> list[float]:\n",
" responses = [completion[0]['content'] for completion in completions]\n",
" responses = [completion[0][\"content\" ] for completion in completions]\n",
" extracted_responses = [extract_xml_answer(r) for r in responses]\n",
" extracted_responses = [extract_xml_answer(r) for r in responses]\n",
" return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]\n",
" return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]\n",
"\n",
"\n",
"\n",
"def strict_format_reward_func(completions, **kwargs) -> list[float]:\n",
"def strict_format_reward_func(completions, **kwargs) -> list[float]:\n",
" \"\"\"Reward function that checks if the completion has a specific format.\"\"\"\n",
" \"\"\"Reward function that checks if the completion has a specific format.\"\"\"\n",
" pattern = r\"^<reasoning>\\n.*?\\n</reasoning>\\n<answer>\\n.*?\\n</answer>\\n$\"\n",
" pattern = r\"^<reasoning>\\n.*?\\n</reasoning>\\n<answer>\\n.*?\\n</answer>\\n$\"\n",
@@ -489,6 +518,7 @@
" matches = [re.match(pattern, r) for r in responses]\n",
" matches = [re.match(pattern, r) for r in responses]\n",
" return [0.5 if match else 0.0 for match in matches]\n",
" return [0.5 if match else 0.0 for match in matches]\n",
"\n",
"\n",
"\n",
"def soft_format_reward_func(completions, **kwargs) -> list[float]:\n",
"def soft_format_reward_func(completions, **kwargs) -> list[float]:\n",
" \"\"\"Reward function that checks if the completion has a specific format.\"\"\"\n",
" \"\"\"Reward function that checks if the completion has a specific format.\"\"\"\n",
" pattern = r\"<reasoning>.*?</reasoning>\\s*<answer>.*?</answer>\"\n",
" pattern = r\"<reasoning>.*?</reasoning>\\s*<answer>.*?</answer>\"\n",
@@ -496,6 +526,7 @@
" matches = [re.match(pattern, r) for r in responses]\n",
" matches = [re.match(pattern, r) for r in responses]\n",
" return [0.5 if match else 0.0 for match in matches]\n",
" return [0.5 if match else 0.0 for match in matches]\n",
"\n",
"\n",
"\n",
"def count_xml(text) -> float:\n",
"def count_xml(text) -> float:\n",
" count = 0.0\n",
" count = 0.0\n",
" if text.count(\"<reasoning>\\n\") == 1:\n",
" if text.count(\"<reasoning>\\n\") == 1:\n",
@@ -510,6 +541,7 @@
" count -= (len(text.split(\"\\n</answer>\")[-1]) - 1) * 0.001\n",
" count -= (len(text.split(\"\\n</answer>\")[-1]) - 1) * 0.001\n",
" return count\n",
" return count\n",
"\n",
"\n",
"\n",
"def xmlcount_reward_func(completions, **kwargs) -> list[float]:\n",
"def xmlcount_reward_func(completions, **kwargs) -> list[float]:\n",
" contents = [completion[0][\"content\"] for completion in completions]\n",
" contents = [completion[0][\"content\"] for completion in completions]\n",
" return [count_xml(c) for c in contents]"
" return [count_xml(c) for c in contents]"
@@ -540,6 +572,7 @@
"outputs": [],
"outputs": [],
"source": [
"source": [
"from trl import GRPOConfig, GRPOTrainer\n",
"from trl import GRPOConfig, GRPOTrainer\n",
"\n",
"training_args = GRPOConfig(\n",
"training_args = GRPOConfig(\n",
" use_vllm=True, # use vLLM for fast inference!\n",
" use_vllm=True, # use vLLM for fast inference!\n",
" learning_rate=5e-6,\n",
" learning_rate=5e-6,\n",
@@ -636,21 +669,30 @@
},
"outputs": [],
"source": [
"text = tokenizer.apply_chat_template([\n",
"text = tokenizer.apply_chat_template(\n",
" [\n",
" {\"role\": \"user\", \"content\": \"How many r's are in strawberry?\"},\n",
" {\"role\": \"user\", \"content\": \"How many r's are in strawberry?\"},\n",
"], tokenize = False, add_generation_prompt = True)\n",
" ],\n",
" tokenize=False,\n",
" add_generation_prompt=True,\n",
")\n",
"\n",
"\n",
"from vllm import SamplingParams\n",
"from vllm import SamplingParams\n",
"\n",
"sampling_params = SamplingParams(\n",
"sampling_params = SamplingParams(\n",
" temperature=0.8,\n",
" temperature=0.8,\n",
" top_p=0.95,\n",
" top_p=0.95,\n",
" max_tokens=1024,\n",
" max_tokens=1024,\n",
")\n",
")\n",
"output = model.fast_generate(\n",
"output = (\n",
" model.fast_generate(\n",
" [text],\n",
" [text],\n",
" sampling_params=sampling_params,\n",
" sampling_params=sampling_params,\n",
" lora_request=None,\n",
" lora_request=None,\n",
")[0].outputs[0].text\n",
" )[0]\n",
" .outputs[0]\n",
" .text\n",
")\n",
"\n",
"\n",
"output"
"output"
]
]
@@ -697,22 +739,31 @@
},
"outputs": [],
"source": [
"text = tokenizer.apply_chat_template([\n",
"text = tokenizer.apply_chat_template(\n",
" [\n",
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
" {\"role\": \"user\", \"content\": \"How many r's are in strawberry?\"},\n",
" {\"role\": \"user\", \"content\": \"How many r's are in strawberry?\"},\n",
"], tokenize = False, add_generation_prompt = True)\n",
" ],\n",
" tokenize=False,\n",
" add_generation_prompt=True,\n",
")\n",
"\n",
"\n",
"from vllm import SamplingParams\n",
"from vllm import SamplingParams\n",
"\n",
"sampling_params = SamplingParams(\n",
"sampling_params = SamplingParams(\n",
" temperature=0.8,\n",
" temperature=0.8,\n",
" top_p=0.95,\n",
" top_p=0.95,\n",
" max_tokens=1024,\n",
" max_tokens=1024,\n",
")\n",
")\n",
"output = model.fast_generate(\n",
"output = (\n",
" model.fast_generate(\n",
" text,\n",
" text,\n",
" sampling_params=sampling_params,\n",
" sampling_params=sampling_params,\n",
" lora_request=model.load_lora(\"grpo_saved_lora\"),\n",
" lora_request=model.load_lora(\"grpo_saved_lora\"),\n",
")[0].outputs[0].text\n",
" )[0]\n",
" .outputs[0]\n",
" .text\n",
")\n",
"\n",
"\n",
"output"
"output"
]
]
@@ -747,16 +798,36 @@
"outputs": [],
"outputs": [],
"source": [
"source": [
"# Merge to 16bit\n",
"# Merge to 16bit\n",
"if False: model.save_pretrained_merged(\"model\", tokenizer, save_method = \"merged_16bit\",)\n",
"if False:\n",
"if False: model.push_to_hub_merged(\"hf/model\", tokenizer, save_method = \"merged_16bit\", token = \"\")\n",
" model.save_pretrained_merged(\n",
" \"model\",\n",
" tokenizer,\n",
" save_method=\"merged_16bit\",\n",
" )\n",
"if False:\n",
" model.push_to_hub_merged(\n",
" \"hf/model\", tokenizer, save_method=\"merged_16bit\", token=\"\"\n",
" )\n",
"\n",
"\n",
"# Merge to 4bit\n",
"# Merge to 4bit\n",
"if False: model.save_pretrained_merged(\"model\", tokenizer, save_method = \"merged_4bit\",)\n",
"if False:\n",
"if False: model.push_to_hub_merged(\"hf/model\", tokenizer, save_method = \"merged_4bit\", token = \"\")\n",
" model.save_pretrained_merged(\n",
" \"model\",\n",
" tokenizer,\n",
" save_method=\"merged_4bit\",\n",
" )\n",
"if False:\n",
" model.push_to_hub_merged(\"hf/model\", tokenizer, save_method=\"merged_4bit\", token=\"\")\n",
"\n",
"\n",
"# Just LoRA adapters\n",
"# Just LoRA adapters\n",
"if False: model.save_pretrained_merged(\"model\", tokenizer, save_method = \"lora\",)\n",
"if False:\n",
"if False: model.push_to_hub_merged(\"hf/model\", tokenizer, save_method = \"lora\", token = \"\")"
" model.save_pretrained_merged(\n",
" \"model\",\n",
" tokenizer,\n",
" save_method=\"lora\",\n",
" )\n",
"if False:\n",
" model.push_to_hub_merged(\"hf/model\", tokenizer, save_method=\"lora\", token=\"\")"
]
},
{
@@ -785,25 +856,40 @@
"outputs": [],
"outputs": [],
"source": [
"source": [
"# Save to 8bit Q8_0\n",
"# Save to 8bit Q8_0\n",
"if False: model.save_pretrained_gguf(\"model\", tokenizer,)\n",
"if False:\n",
" model.save_pretrained_gguf(\n",
" \"model\",\n",
" tokenizer,\n",
" )\n",
"# Remember to go to https://huggingface.co/settings/tokens for a token!\n",
"# Remember to go to https://huggingface.co/settings/tokens for a token!\n",
"# And change hf to your username!\n",
"# And change hf to your username!\n",
"if False: model.push_to_hub_gguf(\"hf/model\", tokenizer, token = \"\")\n",
"if False:\n",
" model.push_to_hub_gguf(\"hf/model\", tokenizer, token=\"\")\n",
"\n",
"\n",
"# Save to 16bit GGUF\n",
"# Save to 16bit GGUF\n",
"if False: model.save_pretrained_gguf(\"model\", tokenizer, quantization_method = \"f16\")\n",
"if False:\n",
"if False: model.push_to_hub_gguf(\"hf/model\", tokenizer, quantization_method = \"f16\", token = \"\")\n",
" model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"f16\")\n",
"if False:\n",
" model.push_to_hub_gguf(\"hf/model\", tokenizer, quantization_method=\"f16\", token=\"\")\n",
"\n",
"\n",
"# Save to q4_k_m GGUF\n",
"# Save to q4_k_m GGUF\n",
"if False: model.save_pretrained_gguf(\"model\", tokenizer, quantization_method = \"q4_k_m\")\n",
"if False:\n",
"if False: model.push_to_hub_gguf(\"hf/model\", tokenizer, quantization_method = \"q4_k_m\", token = \"\")\n",
" model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"q4_k_m\")\n",
"if False:\n",
" model.push_to_hub_gguf(\n",
" \"hf/model\", tokenizer, quantization_method=\"q4_k_m\", token=\"\"\n",
" )\n",
"\n",
"\n",
"# Save to multiple GGUF options - much faster if you want multiple!\n",
"# Save to multiple GGUF options - much faster if you want multiple!\n",
"if False:\n",
"if False:\n",
" model.push_to_hub_gguf(\n",
" model.push_to_hub_gguf(\n",
" \"hf/model\", # Change hf to your username!\n",
" \"hf/model\", # Change hf to your username!\n",
" tokenizer,\n",
" tokenizer,\n",
" quantization_method = [\"q4_k_m\", \"q8_0\", \"q5_k_m\",],\n",
" quantization_method=[\n",
" \"q4_k_m\",\n",
" \"q8_0\",\n",
" \"q5_k_m\",\n",
" ],\n",
" token=\"\",\n",
" token=\"\",\n",
" )"
" )"
]
]