{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train AutoDidact\n",
    "- Taken from [AutoDidact](https://github.com/menloresearch/DeepSearch/blob/main/notebooks/train_autodidact.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "59DIs5BMcvjN",
    "outputId": "a4b3de70-c99c-4e76-ee06-dab6a6505a8b"
   },
   "outputs": [],
   "source": [
    "from unsloth import FastLanguageModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 700,
     "referenced_widgets": [
      "d8d0dca36cfc47f0919924da07c231e8",
      "5f3d96b613e94e9984d4599ca9ca7b17",
      "66c3271554b1455eb56be55c9241e45e",
      "d36b61cf796c429080e93ea838a3759e",
      "94873c3c077e483790b34f95c421f484",
      "ea549fffa8c2469888d1668158bc105c",
      "98b432b98839428f85d91580c21e80e2",
      "fee4f852c9744a07b909e586e3615604",
      "3febcf8a8eca40c28aafc697f3ec8776",
      "b4e1eb8eeb064c88a2142e474fb8327f",
      "da10502506f9448c9de94f1ddd84d3b1",
      "e6cc388e78c14abfaa49d2be6fa1b5d9",
      "769bde36e2ba4434bddd78e7d5911be4",
      "3c522d78b1834068bd4b155d0f87a4d7",
      "a23afba19c2a4d3a90d771fc55f8d490",
      "6221f0be3b8d48e797c873565a216680",
      "1ac03aff5c314b00ac938c80eb7b2f8a",
      "88c63d94a05a42c49d5f8958a27987a6",
      "0ca67b0c4ca64eb788358a51308f6b97",
      "83c3c811923a4642aba156d1215b39d2",
      "e863bf099e064da7b482c21fe7b77de7",
      "697faad6643a43aca98015da4faef186"
     ]
    },
    "id": "DkIvEkIIkEyB",
    "outputId": "514dea04-804e-47a8-b891-ed3f4a6fb530"
   },
   "outputs": [],
   "source": [
    "from unsloth import is_bfloat16_supported\n",
    "\n",
    "# Sequence length and LoRA rank are shared by the loader and the adapter setup below.\n",
    "max_seq_length = 2 * 4096  # raise for longer reasoning traces\n",
    "lora_rank = 64  # larger rank = smarter, but slower\n",
    "\n",
    "# Load the base model and tokenizer in 4-bit with vLLM fast inference enabled.\n",
    "model, tokenizer = FastLanguageModel.from_pretrained(\n",
    "    model_name=\"meta-llama/meta-Llama-3.1-8B-Instruct\",\n",
    "    max_seq_length=max_seq_length,\n",
    "    max_lora_rank=lora_rank,\n",
    "    load_in_4bit=True,  # set to False for 16-bit LoRA\n",
    "    fast_inference=True,  # vLLM fast inference\n",
    "    gpu_memory_utilization=0.6,  # lower this on out-of-memory errors\n",
    ")\n",
    "\n",
    "# Wrap the loaded model with LoRA adapters on the listed projection modules.\n",
    "model = FastLanguageModel.get_peft_model(\n",
    "    model,\n",
    "    r=lora_rank,  # any number > 0; suggested 8, 16, 32, 64, 128\n",
    "    lora_alpha=lora_rank,\n",
    "    target_modules=[\n",
    "        # drop the q/k/v/o entries on out-of-memory errors\n",
    "        \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
    "        \"gate_proj\", \"up_proj\", \"down_proj\",\n",
    "    ],\n",
    "    use_gradient_checkpointing=\"unsloth\",  # enables long-context finetuning\n",
    "    random_state=3407,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cXk993X6C2ZZ"
   },
   "outputs": [],
   "source": [
    "from rl_helpers import get_qa_dataset\n",
    "\n",
    "# get_qa_dataset() comes from the project-local rl_helpers module and returns two\n",
    "# values. train_dataset is later passed to the GRPO trainer; test_dataset is\n",
    "# presumably the held-out split (it is not referenced again in this notebook).\n",
    "train_dataset, test_dataset = get_qa_dataset()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ux6iqP7z5YOo"
   },
   "source": [
    "<a name=\"Train\"></a>\n",
    "### Train the model\n",
    "\n",
    "Now set up the GRPO trainer and all of its configuration!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Weights & Biases project name, exposed through the process environment.\n",
    "os.environ.update(WANDB_PROJECT=\"bootstrap-search-rl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ptqkXK2D4d6p",
    "outputId": "9d5551f4-0276-47ca-e4ca-e96c846cc976"
   },
   "outputs": [],
   "source": [
    "import UnslothGRPOTrainerTemp\n",
    "\n",
    "# GRPO training configuration, grouped by concern.\n",
    "training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(\n",
    "    # --- generation ---\n",
    "    use_vllm=True,  # vLLM for fast inference\n",
    "    use_agentic_generate=True,  # agentic generation\n",
    "    num_generations=8,  # decrease on out-of-memory errors\n",
    "    max_prompt_length=1024,\n",
    "    max_completion_length=1024,\n",
    "    # --- optimizer and learning-rate schedule ---\n",
    "    optim=\"paged_adamw_8bit\",\n",
    "    learning_rate=5e-6,\n",
    "    adam_beta1=0.9,\n",
    "    adam_beta2=0.99,\n",
    "    weight_decay=0.1,\n",
    "    max_grad_norm=0.1,\n",
    "    lr_scheduler_type=\"cosine\",\n",
    "    warmup_ratio=0.1,\n",
    "    # --- precision: bf16 when supported, fp16 otherwise ---\n",
    "    bf16=is_bfloat16_supported(),\n",
    "    fp16=not is_bfloat16_supported(),\n",
    "    # --- batching ---\n",
    "    per_device_train_batch_size=8,\n",
    "    gradient_accumulation_steps=1,  # increase to 4 for smoother training\n",
    "    # --- run length, checkpointing, logging ---\n",
    "    max_steps=101,  # alternatively use num_train_epochs=1 for a full training run\n",
    "    save_steps=50,\n",
    "    logging_steps=1,\n",
    "    report_to=\"none\",  # Weights & Biases can be used here\n",
    "    output_dir=\"full_local_training\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import rl_helpers\n",
    "import UnslothGRPOTrainerTemp\n",
    "from vllm import SamplingParams\n",
    "\n",
    "# Bind the agent loop before agentic_generate is defined, so the function does not\n",
    "# rely on a global that is only assigned further down the cell.\n",
    "run_agent = rl_helpers.run_agent\n",
    "\n",
    "\n",
    "def agentic_generate(\n",
    "    prompts: list[str],\n",
    "    generate_fn,\n",
    "    max_generations: int = 10,\n",
    "):\n",
    "    \"\"\"Run the rl_helpers agent loop over ``prompts``.\n",
    "\n",
    "    Args:\n",
    "        prompts: Prompts forwarded to ``run_agent``.\n",
    "        generate_fn: Generation callable forwarded to ``run_agent``.\n",
    "        max_generations: Forwarded to ``run_agent`` as its last argument\n",
    "            (presumably a cap on generation rounds -- confirm in rl_helpers).\n",
    "\n",
    "    Returns:\n",
    "        Whatever ``run_agent`` returns.\n",
    "    \"\"\"\n",
    "    return run_agent(generate_fn, tokenizer, prompts, max_generations)\n",
    "\n",
    "\n",
    "# Attach the agent loop to the model; training_args sets use_agentic_generate=True.\n",
    "model.agentic_generate = agentic_generate\n",
    "\n",
    "# Sampling settings for the generations used by the correctness reward.\n",
    "verifier_sampling_params = SamplingParams(\n",
    "    temperature=0.1,\n",
    "    top_p=0.95,\n",
    "    max_tokens=4096,\n",
    ")\n",
    "\n",
    "\n",
    "def verifier_generate_fn(inputs):\n",
    "    \"\"\"Generate with ``model.fast_generate`` using ``verifier_sampling_params``.\"\"\"\n",
    "    return model.fast_generate(\n",
    "        inputs,\n",
    "        sampling_params=verifier_sampling_params,\n",
    "    )\n",
    "\n",
    "\n",
    "# Reward functions: correctness is built around the verifier generator above,\n",
    "# formatting is taken from rl_helpers as-is.\n",
    "reward_correctness = rl_helpers.build_reward_correctness_fn(\n",
    "    verifier_generate_fn,\n",
    "    tokenizer,\n",
    ")\n",
    "reward_formatting = rl_helpers.reward_formatting\n",
    "\n",
    "trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(\n",
    "    model=model,\n",
    "    processing_class=tokenizer,\n",
    "    reward_funcs=[\n",
    "        reward_correctness,\n",
    "        reward_formatting,\n",
    "    ],\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start training with the trainer built above. training_args requests max_steps=101\n",
    "# with save_steps=50 and writes its output under full_local_training/.\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tlaUdxC_VHpz"
   },
   "source": [
    "<a name=\"Inference\"></a>\n",
    "### Inference\n",
    "Now let's benchmark the model we trained!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from vllm import SamplingParams\n",
    "import rl_helpers\n",
    "\n",
    "sampling_params = SamplingParams(\n",
    "    temperature=0.5,\n",
    "    top_p=0.95,\n",
    "    max_tokens=4096,\n",
    ")\n",
    "\n",
    "# Checkpoint written by the training run above (output_dir full_local_training,\n",
    "# max_steps 101). Update this path if either of those settings changes.\n",
    "LORA_CHECKPOINT_DIR = \"full_local_training/checkpoint-101\"\n",
    "\n",
    "# Build the LoRA request once and reuse it. Calling model.load_lora() inside\n",
    "# eval_generate_fn would repeat that call on every generation.\n",
    "trained_lora_request = model.load_lora(LORA_CHECKPOINT_DIR)\n",
    "\n",
    "\n",
    "def eval_generate_fn(inputs):\n",
    "    \"\"\"Generate with ``model.fast_generate`` using the trained LoRA request.\"\"\"\n",
    "    return model.fast_generate(\n",
    "        inputs,\n",
    "        sampling_params=sampling_params,\n",
    "        lora_request=trained_lora_request,\n",
    "    )\n",
    "\n",
    "\n",
    "rl_helpers.run_eval(\n",
    "    generate_fn=eval_generate_fn,\n",
    "    verify_fn=reward_correctness,\n",
    "    tokenizer=tokenizer,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline evaluation: no lora_request is passed to fast_generate.\n",
    "def eval_generate_fn(inputs):\n",
    "    \"\"\"Generate with the shared eval sampling settings and no LoRA request.\"\"\"\n",
    "    generations = model.fast_generate(inputs, sampling_params=sampling_params)\n",
    "    return generations\n",
    "\n",
    "\n",
    "rl_helpers.run_eval(\n",
    "    tokenizer=tokenizer,\n",
    "    generate_fn=eval_generate_fn,\n",
    "    verify_fn=reward_correctness,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}