{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train R1 Distil\n",
    "This notebook caches the model loading so that it does not take as long to reload every time the trainer source code changes."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.append(\"..\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from unsloth import FastLanguageModel, is_bfloat16_supported\n",
    "\n",
    "import src.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp\n",
    "from src.config import (\n",
    "    MODEL_CONFIG,\n",
    "    MODEL_NAME,\n",
    "    OUTPUT_DIR,\n",
    "    TRAINING_CONFIG,\n",
    "    get_sampling_params,\n",
    "    init_training_dirs,\n",
    "    logger,\n",
    "    update_log_path,\n",
    ")\n",
    "\n",
    "# Import reward functions\n",
    "from src.rl_helpers import (\n",
    "    build_reward_correctness_fn,\n",
    "    get_qa_dataset,\n",
    "    reward_exact_match_chunk_query,\n",
    "    reward_formatting,\n",
    "    reward_retry_behavior,\n",
    "    run_agent,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize training directories\n",
    "paths = init_training_dirs()\n",
    "\n",
    "# Update logger to use the training directory\n",
    "update_log_path(paths[\"log_dir\"])\n",
    "logger.info(f\"Training output directory: {paths['output_dir']}\")\n",
    "logger.info(f\"Logs are being saved to both ./logs and {paths['log_dir']}\")\n",
    "\n",
    "\n",
    "# Initialize model and tokenizer\n",
    "logger.info(f\"Initializing model {MODEL_NAME}\")\n",
    "model, tokenizer = FastLanguageModel.from_pretrained(\n",
    "    model_name=MODEL_NAME,\n",
    "    max_seq_length=MODEL_CONFIG[\"max_seq_length\"],\n",
    "    load_in_4bit=True,  # False for LoRA 16bit\n",
    "    fast_inference=True,  # Enable vLLM fast inference\n",
    "    max_lora_rank=MODEL_CONFIG[\"lora_rank\"],\n",
    "    gpu_memory_utilization=MODEL_CONFIG[\"gpu_memory_utilization\"],\n",
    ")\n",
    "\n",
    "# Setup LoRA\n",
    "logger.info(\"Setting up LoRA adapter\")\n",
    "model = FastLanguageModel.get_peft_model(\n",
    "    model,\n",
    "    r=MODEL_CONFIG[\"lora_rank\"],\n",
    "    target_modules=MODEL_CONFIG[\"target_modules\"],\n",
    "    lora_alpha=MODEL_CONFIG[\"lora_rank\"],\n",
    "    use_gradient_checkpointing=True,  # Enable long context finetuning\n",
    "    random_state=3407,\n",
    ")\n",
    "\n",
    "# Load datasets\n",
    "logger.info(\"Loading datasets\")\n",
    "train_dataset, test_dataset = get_qa_dataset()\n",
    "logger.info(\n",
    "    f\"Loaded {len(train_dataset)} training examples and {len(test_dataset)} test examples\"\n",
    ")\n",
    "\n",
    "# Setup training arguments\n",
    "logger.info(\"Setting up training arguments\")\n",
    "training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(\n",
    "    use_vllm=True,  # use vLLM for fast inference!\n",
    "    use_agentic_generate=True,  # use agentic generation\n",
    "    **TRAINING_CONFIG,\n",
    "    bf16=is_bfloat16_supported(),\n",
    "    fp16=not is_bfloat16_supported(),\n",
    "    output_dir=OUTPUT_DIR,\n",
    "    # report_to=\"tensorboard\",  # ❓ check whether setting report_to here avoids creating a huge number of TensorBoard event files\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup model generation functions\n",
    "def agentic_generate(\n",
    "    prompts: list,\n",
    "    generate_fn,\n",
    "    max_generations: int = 10,\n",
    "):\n",
    "    return run_agent(generate_fn, tokenizer, prompts, max_generations)\n",
    "\n",
    "\n",
    "model.agentic_generate = agentic_generate\n",
    "\n",
    "# Setup verifier\n",
    "logger.info(\"Setting up verifier\")\n",
    "verifier_sampling_params = get_sampling_params(temperature=0.1)\n",
    "\n",
    "\n",
    "def verifier_generate_fn(inputs):\n",
    "    return model.fast_generate(\n",
    "        inputs,\n",
    "        sampling_params=verifier_sampling_params,\n",
    "    )\n",
    "\n",
    "\n",
    "# Setup trainer\n",
    "logger.info(\"Initializing trainer\")\n",
    "trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(\n",
    "    model=model,\n",
    "    processing_class=tokenizer,\n",
    "    reward_funcs=[\n",
    "        build_reward_correctness_fn(\n",
    "            verifier_generate_fn,\n",
    "            tokenizer,\n",
    "            log_file=os.path.join(paths[\"log_dir\"], \"qa_log.txt\"),\n",
    "        ),\n",
    "        reward_formatting,\n",
    "        reward_retry_behavior,\n",
    "        reward_exact_match_chunk_query,\n",
    "    ],\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset,\n",
    ")\n",
    "\n",
    "print(\"Trainer initialized successfully! Starting training...\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the model\n",
    "if __name__ == \"__main__\":\n",
    "    logger.info(\"Starting training\")\n",
    "    trainer.train()\n",
    "    logger.info(\"Training completed\")\n",
    "    logger.info(f\"Model saved to {OUTPUT_DIR}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "deepsearch-py311",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}