{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load finetuned model and say hello"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ✅ Utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from unsloth import FastLanguageModel\n",
    "from vllm import SamplingParams\n",
    "\n",
    "\n",
    "def load_model(\n",
    "    # model_name=\"meta-llama/Llama-3.2-1B-Instruct\",\n",
    "    model_name=\"meta-llama/meta-Llama-3.1-8B-Instruct\",\n",
    "    lora_path=\"../trainer_output_meta-llama_Llama-3.1-8B-Instruct_gpu1_20250326_134236/checkpoint-101\",\n",
    "    max_seq_length=8192,\n",
    "):\n",
    "    \"\"\"Load model and tokenizer with optional LoRA weights.\"\"\"\n",
    "    # Load base model\n",
    "    model, tokenizer = FastLanguageModel.from_pretrained(\n",
    "        model_name=model_name,\n",
    "        max_seq_length=max_seq_length,\n",
    "        load_in_4bit=True,\n",
    "        fast_inference=True,\n",
    "        max_lora_rank=64,\n",
    "        gpu_memory_utilization=0.6,\n",
    "    )\n",
    "\n",
    "    # Setup LoRA if path provided\n",
    "    if lora_path:\n",
    "        model = FastLanguageModel.get_peft_model(\n",
    "            model,\n",
    "            r=64,\n",
    "            target_modules=[\n",
    "                \"q_proj\",\n",
    "                \"k_proj\",\n",
    "                \"v_proj\",\n",
    "                \"o_proj\",\n",
    "                \"gate_proj\",\n",
    "                \"up_proj\",\n",
    "                \"down_proj\",\n",
    "            ],\n",
    "            lora_alpha=64,\n",
    "            use_gradient_checkpointing=True,\n",
    "            random_state=3407,\n",
    "        )\n",
    "        model.load_lora(lora_path)\n",
    "\n",
    "    return model, tokenizer\n",
    "\n",
    "\n",
    "def get_sampling_params(\n",
    "    temperature=0.7,\n",
    "    top_p=0.95,\n",
    "    max_tokens=4096,\n",
    "):\n",
    "    \"\"\"Get sampling parameters for text generation.\"\"\"\n",
    "    return SamplingParams(\n",
    "        temperature=temperature,\n",
    "        top_p=top_p,\n",
    "        max_tokens=max_tokens,\n",
    "    )\n",
    "\n",
    "\n",
    "def generate_response(prompt, model, tokenizer, sampling_params):\n",
    "    \"\"\"Generate a response from the model.\"\"\"\n",
    "    inputs = tokenizer.apply_chat_template(\n",
    "        [{\"role\": \"user\", \"content\": prompt}],\n",
    "        tokenize=False,\n",
    "        add_generation_prompt=True,\n",
    "    )\n",
    "\n",
    "    outputs = model.fast_generate([inputs], sampling_params=sampling_params)\n",
    "\n",
    "    if hasattr(outputs[0], \"outputs\"):\n",
    "        response_text = outputs[0].outputs[0].text\n",
    "    else:\n",
    "        response_text = outputs[0]\n",
    "\n",
    "    return response_text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model, tokenizer = load_model()  # Using default hardcoded path\n",
    "sampling_params = get_sampling_params()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "response = generate_response(\"Hi! How are you?\", model, tokenizer, sampling_params)\n",
    "print(\"\\nModel response:\")\n",
    "print(response)"
   ]
  },
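  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: depending on the Unsloth version, `model.load_lora(...)` may return a LoRA request object that has to be passed to `fast_generate` explicitly rather than being applied globally. The next cell is a minimal sketch of that variant; the `lora_request` keyword and the behaviour of `load_lora` are assumptions to verify against the installed version before enabling it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch (assumption): apply the LoRA at generation time via a vLLM lora_request.\n",
    "# Flip to True only after confirming your Unsloth version supports this keyword.\n",
    "if False:\n",
    "    lora_request = model.load_lora(\n",
    "        \"../trainer_output_meta-llama_Llama-3.1-8B-Instruct_gpu1_20250326_134236/checkpoint-101\"\n",
    "    )\n",
    "    prompt_text = tokenizer.apply_chat_template(\n",
    "        [{\"role\": \"user\", \"content\": \"Hi! How are you?\"}],\n",
    "        tokenize=False,\n",
    "        add_generation_prompt=True,\n",
    "    )\n",
    "    outputs = model.fast_generate(\n",
    "        [prompt_text],\n",
    "        sampling_params=sampling_params,\n",
    "        lora_request=lora_request,\n",
    "    )\n",
    "    print(outputs[0].outputs[0].text)"
   ]
  },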
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Merge LoRA and save model to 16 bit, then test load inference\n",
    "Optional: run the following cells only when you want to export and re-test a merged 16-bit checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save_pretrained_merged(\n",
    "    \"model\",\n",
    "    tokenizer,\n",
    "    save_method=\"merged_16bit\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ✅ Test loading the merged 16-bit model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test load merged model\n",
    "model, tokenizer = load_model(model_name=\"model\", lora_path=None)\n",
    "sampling_params = get_sampling_params()\n",
    "response = generate_response(\"Hi! How are you?\", model, tokenizer, sampling_params)\n",
    "print(\"\\nModel response:\")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ❌ Save model to Ollama format\n",
    "- Bug: no `llama-quantize` binary on WSL, fix later\n",
    "- Docs: https://docs.unsloth.ai/basics/running-and-saving-models/saving-to-gguf\n",
    "\n",
    "Build llama.cpp to get the conversion and quantization tools:\n",
    "```bash\n",
    "git clone https://github.com/ggml-org/llama.cpp\n",
    "cd llama.cpp\n",
    "cmake -B build\n",
    "cmake --build build --config Release\n",
    "```"
   ]
  },
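  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Possible workaround while `llama-quantize` is missing from PATH on WSL: call the binary produced by the cmake build directly. A sketch, assuming the build above succeeded and that an f16 GGUF was exported; the file names below are placeholders to adjust:\n",
    "\n",
    "```bash\n",
    "# llama-quantize is placed in build/bin by the cmake build above\n",
    "./llama.cpp/build/bin/llama-quantize model-gguf/model-f16.gguf model-gguf/model-Q8_0.gguf Q8_0\n",
    "```"
   ]
  },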
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save to 8bit Q8_0\n",
    "if True:\n",
    "    model.save_pretrained_gguf(\n",
    "        \"model-gguf\",\n",
    "        tokenizer,\n",
    "    )\n",
    "# Remember to go to https://huggingface.co/settings/tokens for a token!\n",
    "# And change hf to your username!\n",
    "if False:\n",
    "    model.push_to_hub_gguf(\"hf/model\", tokenizer, token=\"\")\n",
    "\n",
    "# Save to 16bit GGUF\n",
    "if False:\n",
    "    model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"f16\")\n",
    "if False:\n",
    "    model.push_to_hub_gguf(\"hf/model\", tokenizer, quantization_method=\"f16\", token=\"\")\n",
    "\n",
    "# Save to q4_k_m GGUF\n",
    "if False:\n",
    "    model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"q4_k_m\")\n",
    "if False:\n",
    "    model.push_to_hub_gguf(\n",
    "        \"hf/model\", tokenizer, quantization_method=\"q4_k_m\", token=\"\"\n",
    "    )\n",
    "\n",
    "# Save to multiple GGUF options - much faster if you want multiple!\n",
    "if False:\n",
    "    model.push_to_hub_gguf(\n",
    "        \"hf/model\",  # Change hf to your username!\n",
    "        tokenizer,\n",
    "        quantization_method=[\n",
    "            \"q4_k_m\",\n",
    "            \"q8_0\",\n",
    "            \"q5_k_m\",\n",
    "        ],\n",
    "        token=\"\",\n",
    "    )"
   ]
  },
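  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once a GGUF file exists, it can be registered with Ollama via a `Modelfile`. A minimal sketch, assuming a placeholder output file name (adjust the `FROM` path and model name to whatever `save_pretrained_gguf` actually wrote into `model-gguf/`):\n",
    "\n",
    "```bash\n",
    "# Modelfile pointing at the exported GGUF (path is a placeholder)\n",
    "echo 'FROM ./model-gguf/model-Q8_0.gguf' > Modelfile\n",
    "ollama create my-finetune -f Modelfile\n",
    "ollama run my-finetune\n",
    "```"
   ]
  }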
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "deepsearch-py311",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}