{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train R1 Distil\n",
    "This notebook is for caching the model loading so that it doesn't take so long to reload every time I change the trainer source code."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup (imports, model, data, trainer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.append(\"..\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from unsloth import FastLanguageModel, is_bfloat16_supported\n",
    "\n",
    "import src.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp\n",
    "from src.config import (\n",
    "    MODEL_CONFIG,\n",
    "    MODEL_NAME,\n",
    "    OUTPUT_DIR,\n",
    "    TRAINING_CONFIG,\n",
    "    get_sampling_params,\n",
    "    init_training_dirs,\n",
    "    logger,\n",
    "    update_log_path,\n",
    ")\n",
    "\n",
    "# Import reward functions\n",
    "from src.rl_helpers import (\n",
    "    build_reward_correctness_fn,\n",
    "    get_qa_dataset,\n",
    "    reward_exact_match_chunk_query,\n",
    "    reward_formatting,\n",
    "    reward_retry_behavior,\n",
    "    run_agent,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize training directories\n",
    "paths = init_training_dirs()\n",
    "\n",
    "# Update logger to use the training directory\n",
    "update_log_path(paths[\"log_dir\"])\n",
    "logger.info(f\"Training output directory: {paths['output_dir']}\")\n",
    "logger.info(f\"Logs are being saved to both ./logs and {paths['log_dir']}\")\n",
    "\n",
    "\n",
    "# Initialize model and tokenizer\n",
    "logger.info(f\"Initializing model {MODEL_NAME}\")\n",
    "model, tokenizer = FastLanguageModel.from_pretrained(\n",
    "    model_name=MODEL_NAME,\n",
    "    max_seq_length=MODEL_CONFIG[\"max_seq_length\"],\n",
    "    load_in_4bit=True,  # False for LoRA 16bit\n",
    "    fast_inference=True,  # Enable vLLM fast inference\n",
    "    max_lora_rank=MODEL_CONFIG[\"lora_rank\"],\n",
    "    gpu_memory_utilization=MODEL_CONFIG[\"gpu_memory_utilization\"],\n",
    ")\n",
    "\n",
    "# Setup LoRA\n",
    "logger.info(\"Setting up LoRA adapter\")\n",
    "model = FastLanguageModel.get_peft_model(\n",
    "    model,\n",
    "    r=MODEL_CONFIG[\"lora_rank\"],\n",
    "    target_modules=MODEL_CONFIG[\"target_modules\"],\n",
    "    lora_alpha=MODEL_CONFIG[\"lora_rank\"],\n",
    "    use_gradient_checkpointing=True,  # Enable long context finetuning\n",
    "    random_state=3407,\n",
    ")\n",
    "\n",
    "# Load datasets\n",
    "logger.info(\"Loading datasets\")\n",
    "train_dataset, test_dataset = get_qa_dataset()\n",
    "logger.info(\n",
    "    f\"Loaded {len(train_dataset)} training examples and {len(test_dataset)} test examples\"\n",
    ")\n",
    "\n",
    "# Setup training arguments\n",
    "logger.info(\"Setting up training arguments\")\n",
    "training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(\n",
    "    use_vllm=True,  # use vLLM for fast inference!\n",
    "    use_agentic_generate=True,  # use agentic generation\n",
    "    **TRAINING_CONFIG,\n",
    "    bf16=is_bfloat16_supported(),\n",
    "    fp16=not is_bfloat16_supported(),\n",
    "    output_dir=OUTPUT_DIR,\n",
    "    # report_to=\"tensorboard\",  # ❓ Does't have billions of tensorboard files if set report to right here\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup model generation functions\n",
    "def agentic_generate(\n",
    "    prompts: list,\n",
    "    generate_fn,\n",
    "    max_generations: int = 10,\n",
    "):\n",
    "    return run_agent(generate_fn, tokenizer, prompts, max_generations)\n",
    "\n",
    "\n",
    "model.agentic_generate = agentic_generate\n",
    "\n",
    "# Setup verifier\n",
    "logger.info(\"Setting up verifier\")\n",
    "verifier_sampling_params = get_sampling_params(temperature=0.1)\n",
    "\n",
    "\n",
    "def verifier_generate_fn(inputs):\n",
    "    return model.fast_generate(\n",
    "        inputs,\n",
    "        sampling_params=verifier_sampling_params,\n",
    "    )\n",
    "\n",
    "\n",
    "# Setup trainer\n",
    "logger.info(\"Initializing trainer\")\n",
    "trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(\n",
    "    model=model,\n",
    "    processing_class=tokenizer,\n",
    "    reward_funcs=[\n",
    "        build_reward_correctness_fn(\n",
    "            verifier_generate_fn,\n",
    "            tokenizer,\n",
    "            log_file=os.path.join(paths[\"log_dir\"], \"qa_log.txt\"),\n",
    "        ),\n",
    "        reward_formatting,\n",
    "        reward_retry_behavior,\n",
    "        reward_exact_match_chunk_query,\n",
    "    ],\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset,\n",
    ")\n",
    "\n",
    "print(\"Trainer initialized successfully! Starting training...\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the model\n",
    "if __name__ == \"__main__\":\n",
    "    logger.info(\"Starting training\")\n",
    "    trainer.train()\n",
    "    logger.info(\"Training completed\")\n",
    "    logger.info(f\"Model saved to {OUTPUT_DIR}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "deepsearch-py311",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}