Big thanks to author(s) for the great reference code!
parent 04593fa8fd
commit f6b6cca2ce
@@ -0,0 +1,243 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load finetuned model and say hello"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ✅ Utils"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from unsloth import FastLanguageModel\n",
"from vllm import SamplingParams\n",
"\n",
"\n",
"def load_model(\n",
"    # model_name=\"meta-llama/Llama-3.2-1B-Instruct\",\n",
"    model_name=\"meta-llama/meta-Llama-3.1-8B-Instruct\",\n",
"    lora_path=\"../trainer_output_meta-llama_Llama-3.1-8B-Instruct_gpu1_20250326_134236/checkpoint-101\",\n",
"    max_seq_length=8192,\n",
"):\n",
"    \"\"\"Load model and tokenizer with optional LoRA weights.\"\"\"\n",
"    # Load base model\n",
"    model, tokenizer = FastLanguageModel.from_pretrained(\n",
"        model_name=model_name,\n",
"        max_seq_length=max_seq_length,\n",
"        load_in_4bit=True,\n",
"        fast_inference=True,\n",
"        max_lora_rank=64,\n",
"        gpu_memory_utilization=0.6,\n",
"    )\n",
"\n",
"    # Setup LoRA if path provided\n",
"    if lora_path:\n",
"        model = FastLanguageModel.get_peft_model(\n",
"            model,\n",
"            r=64,\n",
"            target_modules=[\n",
"                \"q_proj\",\n",
"                \"k_proj\",\n",
"                \"v_proj\",\n",
"                \"o_proj\",\n",
"                \"gate_proj\",\n",
"                \"up_proj\",\n",
"                \"down_proj\",\n",
"            ],\n",
"            lora_alpha=64,\n",
"            use_gradient_checkpointing=True,\n",
"            random_state=3407,\n",
"        )\n",
"        model.load_lora(lora_path)\n",
"\n",
"    return model, tokenizer\n",
"\n",
"\n",
"def get_sampling_params(\n",
"    temperature=0.7,\n",
"    top_p=0.95,\n",
"    max_tokens=4096,\n",
"):\n",
"    \"\"\"Get sampling parameters for text generation.\"\"\"\n",
"    return SamplingParams(\n",
"        temperature=temperature,\n",
"        top_p=top_p,\n",
"        max_tokens=max_tokens,\n",
"    )\n",
"\n",
"\n",
"def generate_response(prompt, model, tokenizer, sampling_params):\n",
"    \"\"\"Generate a response from the model.\"\"\"\n",
"    inputs = tokenizer.apply_chat_template(\n",
"        [{\"role\": \"user\", \"content\": prompt}],\n",
"        tokenize=False,\n",
"        add_generation_prompt=True,\n",
"    )\n",
"\n",
"    outputs = model.fast_generate([inputs], sampling_params=sampling_params)\n",
"\n",
"    # vLLM returns RequestOutput objects; fall back to raw strings otherwise\n",
"    if hasattr(outputs[0], \"outputs\"):\n",
"        response_text = outputs[0].outputs[0].text\n",
"    else:\n",
"        response_text = outputs[0]\n",
"\n",
"    return response_text"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model, tokenizer = load_model()  # Using default hardcoded path\n",
"sampling_params = get_sampling_params()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = generate_response(\"Hi! How are you?\", model, tokenizer, sampling_params)\n",
"print(\"\\nModel response:\")\n",
"print(response)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Merge LoRA and save model to 16-bit, then test load inference"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if False:  # flip to True to run the merge\n",
"    model.save_pretrained_merged(\n",
"        \"model\",\n",
"        tokenizer,\n",
"        save_method=\"merged_16bit\",\n",
"    )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ✅ Test load merged model 16bit"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test load merged model\n",
"model, tokenizer = load_model(model_name=\"model\", lora_path=None)\n",
"sampling_params = get_sampling_params()\n",
"response = generate_response(\"Hi! How are you?\", model, tokenizer, sampling_params)\n",
"print(\"\\nModel response:\")\n",
"print(response)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ❌ Save model to Ollama format\n",
"- Bug: no `llama-quantize` on WSL; fix later.\n",
"\n",
"https://docs.unsloth.ai/basics/running-and-saving-models/saving-to-gguf\n",
"```bash\n",
"git clone https://github.com/ggml-org/llama.cpp\n",
"cd llama.cpp\n",
"cmake -B build\n",
"cmake --build build --config Release\n",
"```"
]
},
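{
"cell_type": "markdown",
"metadata": {},
"source": [
"Until the in-notebook GGUF save works on WSL, a possible manual fallback using the llama.cpp build above — this is a sketch, not verified here: the converter script name and output file names are assumptions, and it presumes the merged 16-bit checkpoint was saved to `./model` by the cell further up:\n",
"```bash\n",
"# Convert the merged 16-bit HF checkpoint in ./model to GGUF\n",
"python llama.cpp/convert_hf_to_gguf.py model --outfile model-f16.gguf\n",
"# Quantize to Q8_0 with the freshly built binary\n",
"./llama.cpp/build/bin/llama-quantize model-f16.gguf model-q8_0.gguf Q8_0\n",
"# Register with Ollama via a Modelfile containing: FROM ./model-q8_0.gguf\n",
"ollama create my-finetune -f Modelfile\n",
"```"
]
},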
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Save to 8bit Q8_0\n",
"if True:\n",
"    model.save_pretrained_gguf(\n",
"        \"model-gguf\",\n",
"        tokenizer,\n",
"    )\n",
"# Remember to go to https://huggingface.co/settings/tokens for a token!\n",
"# And change hf to your username!\n",
"if False:\n",
"    model.push_to_hub_gguf(\"hf/model\", tokenizer, token=\"\")\n",
"\n",
"# Save to 16bit GGUF\n",
"if False:\n",
"    model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"f16\")\n",
"if False:\n",
"    model.push_to_hub_gguf(\"hf/model\", tokenizer, quantization_method=\"f16\", token=\"\")\n",
"\n",
"# Save to q4_k_m GGUF\n",
"if False:\n",
"    model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"q4_k_m\")\n",
"if False:\n",
"    model.push_to_hub_gguf(\n",
"        \"hf/model\", tokenizer, quantization_method=\"q4_k_m\", token=\"\"\n",
"    )\n",
"\n",
"# Save to multiple GGUF options - much faster if you want multiple!\n",
"if False:\n",
"    model.push_to_hub_gguf(\n",
"        \"hf/model\",  # Change hf to your username!\n",
"        tokenizer,\n",
"        quantization_method=[\n",
"            \"q4_k_m\",\n",
"            \"q8_0\",\n",
"            \"q5_k_m\",\n",
"        ],\n",
"        token=\"\",\n",
"    )"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "deepsearch-py311",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@@ -0,0 +1,213 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Train R1 Distil\n",
"This notebook caches the model loading so it doesn't take so long to reload every time I change the trainer source code."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Utils"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"sys.path.append(\"..\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"from unsloth import FastLanguageModel, is_bfloat16_supported\n",
"\n",
"import src.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp\n",
"from src.config import (\n",
"    MODEL_CONFIG,\n",
"    MODEL_NAME,\n",
"    OUTPUT_DIR,\n",
"    TRAINING_CONFIG,\n",
"    get_sampling_params,\n",
"    init_training_dirs,\n",
"    logger,\n",
"    update_log_path,\n",
")\n",
"\n",
"# Import reward functions\n",
"from src.rl_helpers_r1_distil import (\n",
"    build_reward_correctness_fn,\n",
"    get_qa_dataset,\n",
"    reward_exact_match_chunk_query,\n",
"    reward_formatting,\n",
"    reward_retry_behavior,\n",
"    run_agent,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize training directories\n",
"paths = init_training_dirs()\n",
"\n",
"# Update logger to use the training directory\n",
"update_log_path(paths[\"log_dir\"])\n",
"logger.info(f\"Training output directory: {paths['output_dir']}\")\n",
"logger.info(f\"Logs are being saved to both ./logs and {paths['log_dir']}\")\n",
"\n",
"\n",
"# Initialize model and tokenizer\n",
"logger.info(f\"Initializing model {MODEL_NAME}\")\n",
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
"    model_name=MODEL_NAME,\n",
"    max_seq_length=MODEL_CONFIG[\"max_seq_length\"],\n",
"    load_in_4bit=True,  # False for LoRA 16bit\n",
"    fast_inference=True,  # Enable vLLM fast inference\n",
"    max_lora_rank=MODEL_CONFIG[\"lora_rank\"],\n",
"    gpu_memory_utilization=MODEL_CONFIG[\"gpu_memory_utilization\"],\n",
")\n",
"\n",
"# Setup LoRA\n",
"logger.info(\"Setting up LoRA adapter\")\n",
"model = FastLanguageModel.get_peft_model(\n",
"    model,\n",
"    r=MODEL_CONFIG[\"lora_rank\"],\n",
"    target_modules=MODEL_CONFIG[\"target_modules\"],\n",
"    lora_alpha=MODEL_CONFIG[\"lora_rank\"],\n",
"    use_gradient_checkpointing=True,  # Enable long context finetuning\n",
"    random_state=3407,\n",
")\n",
"\n",
"# Load datasets\n",
"logger.info(\"Loading datasets\")\n",
"train_dataset, test_dataset = get_qa_dataset()\n",
"logger.info(\n",
"    f\"Loaded {len(train_dataset)} training examples and {len(test_dataset)} test examples\"\n",
")\n",
"\n",
"# Setup training arguments\n",
"logger.info(\"Setting up training arguments\")\n",
"training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(\n",
"    use_vllm=True,  # use vLLM for fast inference!\n",
"    use_agentic_generate=True,  # use agentic generation\n",
"    **TRAINING_CONFIG,\n",
"    bf16=is_bfloat16_supported(),\n",
"    fp16=not is_bfloat16_supported(),\n",
"    output_dir=OUTPUT_DIR,\n",
"    # report_to=\"tensorboard\",  # ❓ does setting report_to right here avoid creating billions of TensorBoard files?\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Setup model generation functions\n",
"def agentic_generate(\n",
"    prompts: list,\n",
"    generate_fn,\n",
"    max_generations: int = 10,\n",
"):\n",
"    return run_agent(generate_fn, tokenizer, prompts, max_generations)\n",
"\n",
"\n",
"model.agentic_generate = agentic_generate\n",
"\n",
"# Setup verifier\n",
"logger.info(\"Setting up verifier\")\n",
"verifier_sampling_params = get_sampling_params(temperature=0.1)\n",
"\n",
"\n",
"def verifier_generate_fn(inputs):\n",
"    return model.fast_generate(\n",
"        inputs,\n",
"        sampling_params=verifier_sampling_params,\n",
"    )\n",
"\n",
"\n",
"# Setup trainer\n",
"logger.info(\"Initializing trainer\")\n",
"trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(\n",
"    model=model,\n",
"    processing_class=tokenizer,\n",
"    reward_funcs=[\n",
"        build_reward_correctness_fn(\n",
"            verifier_generate_fn,\n",
"            tokenizer,\n",
"            log_file=os.path.join(paths[\"log_dir\"], \"qa_log.txt\"),\n",
"        ),\n",
"        reward_formatting,\n",
"        reward_retry_behavior,\n",
"        reward_exact_match_chunk_query,\n",
"    ],\n",
"    args=training_args,\n",
"    train_dataset=train_dataset,\n",
")\n",
"\n",
"print(\"Trainer initialized successfully! Starting training...\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Train the model\n",
"if __name__ == \"__main__\":\n",
"    logger.info(\"Starting training\")\n",
"    trainer.train()\n",
"    logger.info(\"Training completed\")\n",
"    logger.info(f\"Model saved to {OUTPUT_DIR}\")"
]
},
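{
"cell_type": "markdown",
"metadata": {},
"source": [
"After training, the LoRA adapter can be persisted for the inference notebook. A minimal sketch — `save_lora` follows Unsloth's GRPO examples, and the output name is an assumption; trainer checkpoints already land in `OUTPUT_DIR` either way:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if False:  # flip to True once training has finished\n",
"    model.save_lora(\"grpo_saved_lora\")  # assumed name; saves the adapter only"
]
}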
],
"metadata": {
"kernelspec": {
"display_name": "deepsearch-py311",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@@ -0,0 +1,213 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Train R1 Distil\n",
"This notebook caches the model loading so it doesn't take so long to reload every time I change the trainer source code."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Utils"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"sys.path.append(\"..\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"from unsloth import FastLanguageModel, is_bfloat16_supported\n",
"\n",
"import src.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp\n",
"from src.config import (\n",
"    MODEL_CONFIG,\n",
"    MODEL_NAME,\n",
"    OUTPUT_DIR,\n",
"    TRAINING_CONFIG,\n",
"    get_sampling_params,\n",
"    init_training_dirs,\n",
"    logger,\n",
"    update_log_path,\n",
")\n",
"\n",
"# Import reward functions\n",
"from src.rl_helpers import (\n",
"    build_reward_correctness_fn,\n",
"    get_qa_dataset,\n",
"    reward_exact_match_chunk_query,\n",
"    reward_formatting,\n",
"    reward_retry_behavior,\n",
"    run_agent,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize training directories\n",
"paths = init_training_dirs()\n",
"\n",
"# Update logger to use the training directory\n",
"update_log_path(paths[\"log_dir\"])\n",
"logger.info(f\"Training output directory: {paths['output_dir']}\")\n",
"logger.info(f\"Logs are being saved to both ./logs and {paths['log_dir']}\")\n",
"\n",
"\n",
"# Initialize model and tokenizer\n",
"logger.info(f\"Initializing model {MODEL_NAME}\")\n",
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
"    model_name=MODEL_NAME,\n",
"    max_seq_length=MODEL_CONFIG[\"max_seq_length\"],\n",
"    load_in_4bit=True,  # False for LoRA 16bit\n",
"    fast_inference=True,  # Enable vLLM fast inference\n",
"    max_lora_rank=MODEL_CONFIG[\"lora_rank\"],\n",
"    gpu_memory_utilization=MODEL_CONFIG[\"gpu_memory_utilization\"],\n",
")\n",
"\n",
"# Setup LoRA\n",
"logger.info(\"Setting up LoRA adapter\")\n",
"model = FastLanguageModel.get_peft_model(\n",
"    model,\n",
"    r=MODEL_CONFIG[\"lora_rank\"],\n",
"    target_modules=MODEL_CONFIG[\"target_modules\"],\n",
"    lora_alpha=MODEL_CONFIG[\"lora_rank\"],\n",
"    use_gradient_checkpointing=True,  # Enable long context finetuning\n",
"    random_state=3407,\n",
")\n",
"\n",
"# Load datasets\n",
"logger.info(\"Loading datasets\")\n",
"train_dataset, test_dataset = get_qa_dataset()\n",
"logger.info(\n",
"    f\"Loaded {len(train_dataset)} training examples and {len(test_dataset)} test examples\"\n",
")\n",
"\n",
"# Setup training arguments\n",
"logger.info(\"Setting up training arguments\")\n",
"training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(\n",
"    use_vllm=True,  # use vLLM for fast inference!\n",
"    use_agentic_generate=True,  # use agentic generation\n",
"    **TRAINING_CONFIG,\n",
"    bf16=is_bfloat16_supported(),\n",
"    fp16=not is_bfloat16_supported(),\n",
"    output_dir=OUTPUT_DIR,\n",
"    # report_to=\"tensorboard\",  # ❓ does setting report_to right here avoid creating billions of TensorBoard files?\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Setup model generation functions\n",
"def agentic_generate(\n",
"    prompts: list,\n",
"    generate_fn,\n",
"    max_generations: int = 10,\n",
"):\n",
"    return run_agent(generate_fn, tokenizer, prompts, max_generations)\n",
"\n",
"\n",
"model.agentic_generate = agentic_generate\n",
"\n",
"# Setup verifier\n",
"logger.info(\"Setting up verifier\")\n",
"verifier_sampling_params = get_sampling_params(temperature=0.1)\n",
"\n",
"\n",
"def verifier_generate_fn(inputs):\n",
"    return model.fast_generate(\n",
"        inputs,\n",
"        sampling_params=verifier_sampling_params,\n",
"    )\n",
"\n",
"\n",
"# Setup trainer\n",
"logger.info(\"Initializing trainer\")\n",
"trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(\n",
"    model=model,\n",
"    processing_class=tokenizer,\n",
"    reward_funcs=[\n",
"        build_reward_correctness_fn(\n",
"            verifier_generate_fn,\n",
"            tokenizer,\n",
"            log_file=os.path.join(paths[\"log_dir\"], \"qa_log.txt\"),\n",
"        ),\n",
"        reward_formatting,\n",
"        reward_retry_behavior,\n",
"        reward_exact_match_chunk_query,\n",
"    ],\n",
"    args=training_args,\n",
"    train_dataset=train_dataset,\n",
")\n",
"\n",
"print(\"Trainer initialized successfully! Starting training...\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Train the model\n",
"if __name__ == \"__main__\":\n",
"    logger.info(\"Starting training\")\n",
"    trainer.train()\n",
"    logger.info(\"Training completed\")\n",
"    logger.info(f\"Model saved to {OUTPUT_DIR}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "deepsearch-py311",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@@ -0,0 +1,277 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "IqM-T1RTzY6C"
},
"source": [
"To run this, press \"*Runtime*\" and press \"*Run all*\" on a **free** Tesla T4 Google Colab instance!\n",
"<div class=\"align-center\">\n",
"  <a href=\"https://github.com/unslothai/unsloth\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png\" width=\"115\"></a>\n",
"  <a href=\"https://discord.gg/u54VK8m8tk\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/Discord button.png\" width=\"145\"></a>\n",
"  <a href=\"https://ko-fi.com/unsloth\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/Kofi button.png\" width=\"145\"></a> Join Discord if you need help + support us if you can!\n",
"</div>\n",
"\n",
"To install Unsloth on your own computer, follow the installation instructions on our Github page [here](https://github.com/unslothai/unsloth#installation-instructions---conda)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2eSvM9zX_2d3"
},
"outputs": [],
"source": [
"%%capture\n",
"# Installs Unsloth, Xformers (Flash Attention) and all other packages!\n",
"!pip install unsloth\n",
"# Get latest Unsloth\n",
"!pip install --upgrade --no-deps \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r2v_X2fA0Df5"
},
"source": [
"If you want to finetune Llama-3 2x faster and use 70% less VRAM, go to our [finetuning notebook](https://colab.research.google.com/drive/135ced7oHytdxu3N2DNe1Z0kqjyYIkDXp?usp=sharing)!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 371,
"referenced_widgets": [
"7ca2facea2414549ab2a0a3fd07a0723",
"78676eb3a75f45ed860b488a42fb9bc5",
"b69b75dd3b5745f4a7f869f7ba6593d4",
"ea10da0d677849c28bfdfc9206ac1b73",
"493f6a1af5e94b5eb88a4935361a88c3",
"1a4ef35565f54388afdf1839055eb5c8",
"07ae7e9f6861403ea8064a98f4063754",
"fc2ce36bb32f4c648ccb06d65a4e2aea",
"bf44ec9413b24dda932f575ebe32b222",
"2ad28d8e556b4e17b0859aa3683d31c6",
"e32220b1d3f14386834bcd3736ac9888",
"d9c5466214ed4f06876a89c64c29048f",
"1e0c170f075b43ad9e2d39fcdbaa0a06",
"00abd0bd091f4258b7f61d80ecb3118e",
"0f2c8641a38043d2a218ec0809e7a7d3",
"948972fa1a2e400097b590ad47ff6f3d",
"47a13a9b02ad4181b74cc1f5ddeb3f3f",
"7c01111484b14a5999d48393085b0264",
"fb1a288281fc446788fdd0f432b4732c",
"2f026e4c1ff34aafbf7c5deff6c1e66f",
"acc52213c033476885ee30fb92f55e23",
"9caec4d2cb6e4a72b7155912787d3139",
"2abdab3fdd4f4a70a8faaa7b0e39f811",
"4d432e4349964027bae76c34f5734fcf",
"630c170c04494811b21992d6886a3015",
"cb9a92b76a1d4b189e7ed783bde524c9",
"ccc49354213c4c1394855986a0e07483",
"7bf239e457f04e15b4b0e458f0dd1e19",
"f5c6558ef929438eb117159de0285a58",
"447aa145c09c4c6f8bd9ce3ea8b5cd44",
"9c82690bd07346d1ab875bdf5ed80154",
"5200dae3ba2346c1bca27530b82b38b4",
"c7a3b0cba7ee49ba929e0038d9c03e98",
"36250ac65fcb4756a4368f98902a1617",
"705a1673de934cce8418f7893fbf780a",
"1f7ed3fc3a824e119cf559af8bc7c370",
"cafc86be3bd94bc79cb07bc94461ec5d",
"0c190d903a38402ab452de0a0100a8eb",
"b3362bda7942432692049558d0cdf14e",
"4afb54f2995447e68ae3a8129325ffdc",
"31c7337bfd504409992b7e04b713c243",
"9b5a45647bd74e8c878cf6ba43b27190",
"b9c9010af82b4678af43a20546f3ad49",
"9b818c8e8bac40edb97abf8a266ab28e",
"ebaa8d22074d495ea675f51dc2d5a4d6",
"8fecfa968d55466ba8f42730d32bd9b4",
"dd0e8e4cb46945df8dbec4fefd720358",
"60a80b3316c14ce79945e06cb190be39",
"4f0e06a2e2204dad9f29a1697edd4218",
"66527058da914c73b48221011a2b605a",
"58413234f68b484693f0719166942214",
"84b3814386ee4d04b9817f0da2108725",
"01ede422dc5f4d4da7f913880db4984d",
"2e9f87f895344785855cca39bb64d5e2",
"5512e17f3864445a875e640937f296ef",
"932920c0ee834e3e8ffe53ebf7833fa7",
"cb65c93ecef84c2c947c67481326691c",
"73c3f04df8044f34ab45b92c174e121e",
"13b8be4a5af14bfdbbce24d185187906",
"21d8ebf7cbfd4e2b9db111d163aad8b8",
"3122a97298124d0ab0e6297960c132e2",
"cc829a40832849f586b077f6d8e5795b",
"3f5a8b6d049244959c1558c56483f643",
"11fea1d7e4a8426d98165fa5ca7fc650",
"fc3715b03c5249438965886892ba07bb",
"05d6f6c3efed4ecf9893c8aa86fe4176"
]
},
"id": "QmUBVEnvCDJv",
"outputId": "14ca09b9-e1ff-4f91-b98c-7a1ed22eef3a"
},
"outputs": [],
"source": [
"from unsloth import FastLanguageModel\n",
"\n",
"# 4bit pre quantized models we support for 4x faster downloading + no OOMs.\n",
"fourbit_models = [\n",
"    \"unsloth/mistral-7b-instruct-v0.2-bnb-4bit\",\n",
"    \"unsloth/gemma-7b-it-bnb-4bit\",\n",
"]  # More models at https://huggingface.co/unsloth\n",
"\n",
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
"    model_name = \"unsloth/Meta-Llama-3.1-8B-Instruct\",\n",
"    max_seq_length = 8192,\n",
"    load_in_4bit = True,\n",
"    # token = \"hf_...\", # use one if using gated models like meta-llama/Llama-2-7b-hf\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Yvwf0OEtchS_"
},
"outputs": [],
"source": [
"from transformers import TextStreamer\n",
"from unsloth.chat_templates import get_chat_template\n",
"\n",
"tokenizer = get_chat_template(\n",
"    tokenizer,\n",
"    chat_template = \"llama-3.1\",\n",
"    mapping = {\"role\" : \"from\", \"content\" : \"value\", \"user\" : \"human\", \"assistant\" : \"gpt\"}, # ShareGPT style\n",
")\n",
"FastLanguageModel.for_inference(model)  # Enable native 2x faster inference"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qTRSvafleB2Q"
},
"source": [
"Change the \"value\" part to call the model!\n",
"\n",
"Unsloth makes inference natively 2x faster!! No need to change or do anything!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "e2pEuRb1r2Vg",
"outputId": "5cfe649f-cd72-469d-b73e-88748df49690"
},
"outputs": [],
"source": [
"messages = [\n",
"    # EDIT HERE!\n",
"    {\"from\": \"human\", \"value\": \"Continue the fibonacci sequence: 1, 1, 2, 3, 5, 8,\"},\n",
"]\n",
"inputs = tokenizer.apply_chat_template(messages, tokenize = True, add_generation_prompt = True, return_tensors = \"pt\").to(\"cuda\")\n",
"\n",
"text_streamer = TextStreamer(tokenizer)\n",
"_ = model.generate(input_ids = inputs, streamer = text_streamer, max_new_tokens = 1024, use_cache = True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "yo7qGcu-cqCW",
"outputId": "9959df0e-e04b-4af0-c624-d640f896f441"
},
"outputs": [],
"source": [
"messages = [\n",
"    {\"from\": \"human\", \"value\": \"Describe the tallest tower in the world.\"},\n",
"]\n",
"inputs = tokenizer.apply_chat_template(messages, tokenize = True, add_generation_prompt = True, return_tensors = \"pt\").to(\"cuda\")\n",
"\n",
"text_streamer = TextStreamer(tokenizer)\n",
"_ = model.generate(input_ids = inputs, streamer = text_streamer, max_new_tokens = 1024, use_cache = True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "xmQcutr7cxmc",
"outputId": "2cb3f9af-7db6-40cf-9b21-e67a3f841adb"
},
"outputs": [],
"source": [
"messages = [\n",
"    {\"from\": \"human\", \"value\": \"What is Unsloth?\"},\n",
"]\n",
"inputs = tokenizer.apply_chat_template(messages, tokenize = True, add_generation_prompt = True, return_tensors = \"pt\").to(\"cuda\")\n",
"\n",
"text_streamer = TextStreamer(tokenizer)\n",
"_ = model.generate(input_ids = inputs, streamer = text_streamer, max_new_tokens = 1024, use_cache = True)"
]
},
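{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you want the reply back as a Python string instead of streamed to stdout, here is a minimal sketch using plain `transformers` decoding (the prompt is just an example; everything else reuses the objects defined above):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"messages = [\n",
"    {\"from\": \"human\", \"value\": \"What is Unsloth?\"},\n",
"]\n",
"inputs = tokenizer.apply_chat_template(messages, tokenize = True, add_generation_prompt = True, return_tensors = \"pt\").to(\"cuda\")\n",
"\n",
"outputs = model.generate(input_ids = inputs, max_new_tokens = 256, use_cache = True)\n",
"# Drop the prompt tokens and keep only the newly generated ones\n",
"response = tokenizer.batch_decode(outputs[:, inputs.shape[1]:], skip_special_tokens = True)[0]\n",
"print(response)"
]
},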
{
"cell_type": "markdown",
"metadata": {
"id": "Zt9CHJqO6p30"
},
"source": [
"And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/u54VK8m8tk) channel! If you find any bugs, want to keep up to date with the latest LLM news, need help, or want to join projects, feel free to join our Discord!\n",
"\n",
"If you want to finetune Llama-3 2x faster and use 70% less VRAM, go to our [finetuning notebook](https://colab.research.google.com/drive/135ced7oHytdxu3N2DNe1Z0kqjyYIkDXp?usp=sharing)!\n",
"\n",
"Some other links:\n",
"1. Zephyr DPO 2x faster [free Colab](https://colab.research.google.com/drive/15vttTpzzVXv_tJwEk-hIcQ0S9FcEWvwP?usp=sharing)\n",
"2. Llama 7b 2x faster [free Colab](https://colab.research.google.com/drive/1lBzz5KeZJKXjvivbYvmGarix9Ao6Wxe5?usp=sharing)\n",
"3. TinyLlama 4x faster full Alpaca 52K in 1 hour [free Colab](https://colab.research.google.com/drive/1AZghoNBQaMDgWJpi4RbffGM1h6raLUj9?usp=sharing)\n",
"4. CodeLlama 34b 2x faster [A100 on Colab](https://colab.research.google.com/drive/1y7A0AxE3y8gdj4AVkl2aZX47Xu3P1wJT?usp=sharing)\n",
"5. Mistral 7b [free Kaggle version](https://www.kaggle.com/code/danielhanchen/kaggle-mistral-7b-unsloth-notebook)\n",
"6. We also did a [blog](https://huggingface.co/blog/unsloth-trl) with 🤗 HuggingFace, and we're in the TRL [docs](https://huggingface.co/docs/trl/main/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth)!\n",
"7. Text completions like novel writing [notebook](https://colab.research.google.com/drive/1ef-tab5bhkvWmBOObepl1WgJvfvSzn5Q?usp=sharing)\n",
"8. Gemma 6 trillion tokens is 2.5x faster! [free Colab](https://colab.research.google.com/drive/10NbwlsRChbma1v55m8LAPYG15uQv6HLo?usp=sharing)\n",
"\n",
"<div class=\"align-center\">\n",
"  <a href=\"https://github.com/unslothai/unsloth\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png\" width=\"115\"></a>\n",
"  <a href=\"https://discord.gg/u54VK8m8tk\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/Discord.png\" width=\"145\"></a>\n",
"  <a href=\"https://ko-fi.com/unsloth\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/Kofi button.png\" width=\"145\"></a> Support our work if you can! Thanks!\n",
"</div>"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}