{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train AutoDidact\n",
    "- Taken from [AutoDidact](https://github.com/menloresearch/DeepSearch/blob/main/notebooks/train_autodidact.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "59DIs5BMcvjN",
    "outputId": "a4b3de70-c99c-4e76-ee06-dab6a6505a8b"
   },
   "outputs": [],
   "source": [
    "from unsloth import FastLanguageModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 700,
     "referenced_widgets": [
      "d8d0dca36cfc47f0919924da07c231e8",
      "5f3d96b613e94e9984d4599ca9ca7b17",
      "66c3271554b1455eb56be55c9241e45e",
      "d36b61cf796c429080e93ea838a3759e",
      "94873c3c077e483790b34f95c421f484",
      "ea549fffa8c2469888d1668158bc105c",
      "98b432b98839428f85d91580c21e80e2",
      "fee4f852c9744a07b909e586e3615604",
      "3febcf8a8eca40c28aafc697f3ec8776",
      "b4e1eb8eeb064c88a2142e474fb8327f",
      "da10502506f9448c9de94f1ddd84d3b1",
      "e6cc388e78c14abfaa49d2be6fa1b5d9",
      "769bde36e2ba4434bddd78e7d5911be4",
      "3c522d78b1834068bd4b155d0f87a4d7",
      "a23afba19c2a4d3a90d771fc55f8d490",
      "6221f0be3b8d48e797c873565a216680",
      "1ac03aff5c314b00ac938c80eb7b2f8a",
      "88c63d94a05a42c49d5f8958a27987a6",
      "0ca67b0c4ca64eb788358a51308f6b97",
      "83c3c811923a4642aba156d1215b39d2",
      "e863bf099e064da7b482c21fe7b77de7",
      "697faad6643a43aca98015da4faef186"
     ]
    },
    "id": "DkIvEkIIkEyB",
    "outputId": "514dea04-804e-47a8-b891-ed3f4a6fb530"
   },
   "outputs": [],
   "source": [
    "from unsloth import is_bfloat16_supported\n",
    "\n",
    "max_seq_length = 4096 * 2  # Can increase for longer reasoning traces\n",
    "lora_rank = 64  # Larger rank = smarter, but slower\n",
    "\n",
    "# Load the base model in 4-bit with vLLM-backed fast inference enabled.\n",
    "model, tokenizer = FastLanguageModel.from_pretrained(\n",
    "    model_name=\"meta-llama/meta-Llama-3.1-8B-Instruct\",\n",
    "    max_seq_length=max_seq_length,\n",
    "    load_in_4bit=True,  # False for LoRA 16bit\n",
    "    fast_inference=True,  # Enable vLLM fast inference\n",
    "    max_lora_rank=lora_rank,\n",
    "    gpu_memory_utilization=0.6,  # Reduce if out of memory\n",
    ")\n",
    "\n",
    "# Wrap the base model with LoRA adapters; only these adapters are trained.\n",
    "model = FastLanguageModel.get_peft_model(\n",
    "    model,\n",
    "    r=lora_rank,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n",
    "    target_modules=[\n",
    "        \"q_proj\",\n",
    "        \"k_proj\",\n",
    "        \"v_proj\",\n",
    "        \"o_proj\",\n",
    "        \"gate_proj\",\n",
    "        \"up_proj\",\n",
    "        \"down_proj\",\n",
    "    ],  # Remove QKVO if out of memory\n",
    "    lora_alpha=lora_rank,\n",
    "    use_gradient_checkpointing=\"unsloth\",  # Enable long context finetuning\n",
    "    random_state=3407,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cXk993X6C2ZZ"
   },
   "outputs": [],
   "source": [
    "from rl_helpers import get_qa_dataset\n",
    "\n",
    "train_dataset, test_dataset = get_qa_dataset()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ux6iqP7z5YOo"
   },
   "source": [
    "\n",
    "### Train the model\n",
    "\n",
    "Now set up the GRPO trainer and its configuration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Only takes effect when `report_to` in the training config below is switched to \"wandb\".\n",
    "os.environ[\"WANDB_PROJECT\"] = \"bootstrap-search-rl\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ptqkXK2D4d6p",
    "outputId": "9d5551f4-0276-47ca-e4ca-e96c846cc976"
   },
   "outputs": [],
   "source": [
    "import UnslothGRPOTrainerTemp\n",
    "\n",
    "training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(\n",
    "    use_vllm=True,  # use vLLM for fast inference!\n",
    "    use_agentic_generate=True,  # use agentic generation\n",
    "    learning_rate=5e-6,\n",
    "    adam_beta1=0.9,\n",
    "    adam_beta2=0.99,\n",
    "    weight_decay=0.1,\n",
    "    warmup_ratio=0.1,\n",
    "    lr_scheduler_type=\"cosine\",\n",
    "    optim=\"paged_adamw_8bit\",\n",
    "    logging_steps=1,\n",
    "    bf16=is_bfloat16_supported(),\n",
    "    fp16=not is_bfloat16_supported(),\n",
    "    per_device_train_batch_size=8,\n",
    "    gradient_accumulation_steps=1,  # Increase to 4 for smoother training\n",
    "    num_generations=8,  # Decrease if out of memory\n",
    "    max_prompt_length=1024,\n",
    "    max_completion_length=1024,\n",
    "    # num_train_epochs = 1, # Set to 1 for a full training run\n",
    "    max_steps=101,\n",
    "    save_steps=50,\n",
    "    max_grad_norm=0.1,\n",
    "    report_to=\"none\",  # Can use Weights & Biases (set to \"wandb\")\n",
    "    output_dir=\"full_local_training\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from vllm import SamplingParams\n",
    "\n",
    "import rl_helpers\n",
    "\n",
    "# Bound before `agentic_generate` so the cell reads top-down; the wrapper\n",
    "# looks this name up at call time.\n",
    "run_agent = rl_helpers.run_agent\n",
    "\n",
    "\n",
    "def agentic_generate(\n",
    "    prompts: list[str],\n",
    "    generate_fn,\n",
    "    max_generations: int = 10,\n",
    "):\n",
    "    \"\"\"Run `rl_helpers.run_agent` over `prompts` with the notebook's tokenizer.\n",
    "\n",
    "    `generate_fn` and `max_generations` are forwarded to `run_agent` unchanged.\n",
    "    \"\"\"\n",
    "    return run_agent(generate_fn, tokenizer, prompts, max_generations)\n",
    "\n",
    "\n",
    "# Attach the agent loop to the model object.\n",
    "# NOTE(review): presumably consumed by the trainer because\n",
    "# `use_agentic_generate=True` is set in the config; confirm in UnslothGRPOTrainerTemp.\n",
    "model.agentic_generate = agentic_generate\n",
    "\n",
    "# Low temperature for the generations used by the correctness reward.\n",
    "verifier_sampling_params = SamplingParams(\n",
    "    temperature=0.1,\n",
    "    top_p=0.95,\n",
    "    max_tokens=4096,\n",
    ")\n",
    "\n",
    "\n",
    "def verifier_generate_fn(inputs):\n",
    "    \"\"\"Generate with `model.fast_generate` using the verifier sampling settings.\n",
    "\n",
    "    No `lora_request` is passed here.\n",
    "    \"\"\"\n",
    "    return model.fast_generate(\n",
    "        inputs,\n",
    "        sampling_params=verifier_sampling_params,\n",
    "    )\n",
    "\n",
    "\n",
    "reward_correctness = rl_helpers.build_reward_correctness_fn(\n",
    "    verifier_generate_fn,\n",
    "    tokenizer,\n",
    ")\n",
    "reward_formatting = rl_helpers.reward_formatting\n",
    "\n",
    "trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(\n",
    "    model=model,\n",
    "    processing_class=tokenizer,\n",
    "    reward_funcs=[\n",
    "        reward_correctness,\n",
    "        reward_formatting,\n",
    "    ],\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tlaUdxC_VHpz"
   },
   "source": [
    "\n",
    "### Inference\n",
    "Now let's benchmark the model we trained!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sampling_params = SamplingParams(\n",
    "    temperature=0.5,\n",
    "    top_p=0.95,\n",
    "    max_tokens=4096,\n",
    ")\n",
    "\n",
    "# Derive the final checkpoint path from the training config so it cannot drift\n",
    "# from `output_dir` / `max_steps` (\"full_local_training/checkpoint-101\" with the\n",
    "# current settings). If you train with `num_train_epochs` instead of `max_steps`,\n",
    "# point this at the checkpoint directory you want to evaluate.\n",
    "lora_checkpoint_dir = f\"{training_args.output_dir}/checkpoint-{training_args.max_steps}\"\n",
    "\n",
    "# Load the trained LoRA once and reuse the request for every generate call,\n",
    "# instead of calling `load_lora` again on each batch.\n",
    "trained_lora_request = model.load_lora(lora_checkpoint_dir)\n",
    "\n",
    "\n",
    "def eval_generate_fn(inputs):\n",
    "    \"\"\"Generate with the trained LoRA adapter applied.\"\"\"\n",
    "    return model.fast_generate(\n",
    "        inputs,\n",
    "        sampling_params=sampling_params,\n",
    "        lora_request=trained_lora_request,\n",
    "    )\n",
    "\n",
    "\n",
    "rl_helpers.run_eval(\n",
    "    generate_fn=eval_generate_fn,\n",
    "    verify_fn=reward_correctness,\n",
    "    tokenizer=tokenizer,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# eval w/o lora\n",
    "def baseline_generate_fn(inputs):\n",
    "    \"\"\"Generate without passing a LoRA request, as the comparison baseline.\"\"\"\n",
    "    return model.fast_generate(\n",
    "        inputs,\n",
    "        sampling_params=sampling_params,\n",
    "    )\n",
    "\n",
    "\n",
    "rl_helpers.run_eval(\n",
    "    generate_fn=baseline_generate_fn,\n",
    "    verify_fn=reward_correctness,\n",
    "    tokenizer=tokenizer,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}