{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train AutoDidact\n",
    "- Taken from [AutoDidact](https://github.com/menloresearch/DeepSearch/blob/main/notebooks/train_autodidact.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "59DIs5BMcvjN",
    "outputId": "a4b3de70-c99c-4e76-ee06-dab6a6505a8b"
   },
   "outputs": [],
   "source": [
    "from unsloth import FastLanguageModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 700,
     "referenced_widgets": [
      "d8d0dca36cfc47f0919924da07c231e8",
      "5f3d96b613e94e9984d4599ca9ca7b17",
      "66c3271554b1455eb56be55c9241e45e",
      "d36b61cf796c429080e93ea838a3759e",
      "94873c3c077e483790b34f95c421f484",
      "ea549fffa8c2469888d1668158bc105c",
      "98b432b98839428f85d91580c21e80e2",
      "fee4f852c9744a07b909e586e3615604",
      "3febcf8a8eca40c28aafc697f3ec8776",
      "b4e1eb8eeb064c88a2142e474fb8327f",
      "da10502506f9448c9de94f1ddd84d3b1",
      "e6cc388e78c14abfaa49d2be6fa1b5d9",
      "769bde36e2ba4434bddd78e7d5911be4",
      "3c522d78b1834068bd4b155d0f87a4d7",
      "a23afba19c2a4d3a90d771fc55f8d490",
      "6221f0be3b8d48e797c873565a216680",
      "1ac03aff5c314b00ac938c80eb7b2f8a",
      "88c63d94a05a42c49d5f8958a27987a6",
      "0ca67b0c4ca64eb788358a51308f6b97",
      "83c3c811923a4642aba156d1215b39d2",
      "e863bf099e064da7b482c21fe7b77de7",
      "697faad6643a43aca98015da4faef186"
     ]
    },
    "id": "DkIvEkIIkEyB",
    "outputId": "514dea04-804e-47a8-b891-ed3f4a6fb530"
   },
   "outputs": [],
   "source": [
    "from unsloth import is_bfloat16_supported\n",
    "import torch\n",
    "\n",
    "max_seq_length = 4096 * 2 # Can increase for longer reasoning traces\n",
    "lora_rank = 64 # Larger rank = smarter, but slower\n",
    "\n",
    "model, tokenizer = FastLanguageModel.from_pretrained(\n",
    "    model_name=\"meta-llama/meta-Llama-3.1-8B-Instruct\",\n",
    "    max_seq_length=max_seq_length,\n",
    "    load_in_4bit=True, # False for LoRA 16bit\n",
    "    fast_inference=True, # Enable vLLM fast inference\n",
    "    max_lora_rank=lora_rank,\n",
    "    gpu_memory_utilization=0.6, # Reduce if out of memory\n",
    ")\n",
    "\n",
    "model = FastLanguageModel.get_peft_model(\n",
    "    model,\n",
    "    r=lora_rank, # Choose any number > 0! Suggested: 8, 16, 32, 64, 128\n",
    "    target_modules=[\n",
    "        \"q_proj\",\n",
    "        \"k_proj\",\n",
    "        \"v_proj\",\n",
    "        \"o_proj\",\n",
    "        \"gate_proj\",\n",
    "        \"up_proj\",\n",
    "        \"down_proj\",\n",
    "    ], # Remove QKVO if out of memory\n",
    "    lora_alpha=lora_rank,\n",
    "    use_gradient_checkpointing=\"unsloth\", # Enable long context finetuning\n",
    "    random_state=3407,\n",
    ")"
   ]
  },
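  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optional sanity check: a minimal sketch that generates once from the base model before any training, assuming `model.fast_generate` accepts a list of prompt strings and returns vLLM `RequestOutput` objects (as in the evaluation cells below)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged sanity check: one short greedy generation from the base model.\n",
    "# Assumes model.fast_generate takes prompt strings and returns vLLM RequestOutput objects.\n",
    "from vllm import SamplingParams\n",
    "\n",
    "smoke_params = SamplingParams(temperature=0.0, max_tokens=32)\n",
    "smoke_out = model.fast_generate([\"What is 2 + 2?\"], sampling_params=smoke_params)\n",
    "print(smoke_out[0].outputs[0].text)"
   ]
  },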
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cXk993X6C2ZZ"
   },
   "outputs": [],
   "source": [
    "import re\n",
    "from datasets import load_dataset, Dataset\n",
    "from search_module import search, get_question_answer, get_question_count\n",
    "from rl_helpers import get_qa_dataset\n",
    "\n",
    "train_dataset, test_dataset = get_qa_dataset()"
   ]
  },
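  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick look at what `get_qa_dataset` returns, assuming it yields Hugging Face `Dataset` objects (the `datasets` import above suggests this); the column names are whatever `rl_helpers` defines."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged inspection: peek at the QA splits (assumes Hugging Face Dataset objects).\n",
    "print(len(train_dataset), \"train /\", len(test_dataset), \"test examples\")\n",
    "print(train_dataset[0])"
   ]
  },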
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ux6iqP7z5YOo"
   },
   "source": [
    "<a name=\"Train\"></a>\n",
    "### Train the model\n",
    "\n",
    "Now set up the GRPO Trainer and all of its configuration!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Project name used if Weights & Biases reporting is enabled (see report_to below).\n",
    "os.environ[\"WANDB_PROJECT\"] = \"bootstrap-search-rl\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ptqkXK2D4d6p",
    "outputId": "9d5551f4-0276-47ca-e4ca-e96c846cc976"
   },
   "outputs": [],
   "source": [
    "# from UnslothGRPOTrainerTemp import UnslothGRPOConfig, _UnslothGRPOTrainer\n",
    "import UnslothGRPOTrainerTemp\n",
    "\n",
    "training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(\n",
    "    use_vllm=True, # use vLLM for fast inference!\n",
    "    use_agentic_generate=True, # use agentic generation\n",
    "    learning_rate=5e-6,\n",
    "    adam_beta1=0.9,\n",
    "    adam_beta2=0.99,\n",
    "    weight_decay=0.1,\n",
    "    warmup_ratio=0.1,\n",
    "    lr_scheduler_type=\"cosine\",\n",
    "    optim=\"paged_adamw_8bit\",\n",
    "    logging_steps=1,\n",
    "    bf16=is_bfloat16_supported(),\n",
    "    fp16=not is_bfloat16_supported(),\n",
    "    per_device_train_batch_size=8,\n",
    "    gradient_accumulation_steps=1, # Increase to 4 for smoother training\n",
    "    num_generations=8, # Decrease if out of memory\n",
    "    max_prompt_length=1024,\n",
    "    max_completion_length=1024,\n",
    "    # num_train_epochs = 1, # Set to 1 for a full training run\n",
    "    max_steps=101,\n",
    "    save_steps=50,\n",
    "    max_grad_norm=0.1,\n",
    "    report_to=\"none\", # Can use Weights & Biases\n",
    "    output_dir=\"full_local_training\",\n",
    ")"
   ]
  },
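  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A note on the batch arithmetic, assuming the vendored trainer keeps TRL's GRPO semantics: with `per_device_train_batch_size=8` and `num_generations=8`, each step scores the 8 sampled completions of a single prompt, since the effective batch size must be divisible by `num_generations`. Raising `gradient_accumulation_steps` to 4 would average each update over 4 prompts instead of 1."
   ]
  },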
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import rl_helpers\n",
    "# importlib.reload(rl_helpers)\n",
    "\n",
    "run_agent = rl_helpers.run_agent\n",
    "\n",
    "\n",
    "def agentic_generate(\n",
    "    prompts: list[str],\n",
    "    generate_fn,\n",
    "    max_generations: int = 6,\n",
    "):\n",
    "    return run_agent(generate_fn, tokenizer, prompts, max_generations)\n",
    "\n",
    "\n",
    "model.agentic_generate = agentic_generate\n",
    "\n",
    "\n",
    "from vllm import SamplingParams\n",
    "\n",
    "verifier_sampling_params = SamplingParams(\n",
    "    temperature=0.1,\n",
    "    top_p=0.95,\n",
    "    max_tokens=4096,\n",
    ")\n",
    "\n",
    "\n",
    "def verifier_generate_fn(inputs):\n",
    "    return model.fast_generate(\n",
    "        inputs,\n",
    "        sampling_params=verifier_sampling_params,\n",
    "    )\n",
    "\n",
    "\n",
    "reward_correctness = rl_helpers.build_reward_correctness_fn(\n",
    "    verifier_generate_fn,\n",
    "    tokenizer,\n",
    ")\n",
    "reward_formatting = rl_helpers.reward_formatting\n",
    "\n",
    "trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(\n",
    "    model=model,\n",
    "    processing_class=tokenizer,\n",
    "    reward_funcs=[\n",
    "        reward_correctness,\n",
    "        reward_formatting,\n",
    "    ],\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
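  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before loading a LoRA checkpoint below, it helps to list what the trainer actually wrote to `output_dir`, since the checkpoint numbers follow `save_steps` and `max_steps` and can differ between runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# List saved checkpoints under the output_dir configured above.\n",
    "print(sorted(os.listdir(\"full_local_training\")))"
   ]
  },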
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tlaUdxC_VHpz"
   },
   "source": [
    "<a name=\"Inference\"></a>\n",
    "### Inference\n",
    "Now let's benchmark the model we trained!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from vllm import SamplingParams\n",
    "import rl_helpers\n",
    "\n",
    "sampling_params = SamplingParams(\n",
    "    temperature=0.5,\n",
    "    top_p=0.95,\n",
    "    max_tokens=4096,\n",
    ")\n",
    "\n",
    "\n",
    "def eval_generate_fn(inputs):\n",
    "    return model.fast_generate(\n",
    "        inputs,\n",
    "        sampling_params=sampling_params,\n",
    "        lora_request=model.load_lora(\n",
    "            \"full_local_training/checkpoint-101\"\n",
    "        ), # load the trained LoRA\n",
    "    )\n",
    "\n",
    "\n",
    "rl_helpers.run_eval(\n",
    "    generate_fn=eval_generate_fn,\n",
    "    verify_fn=reward_correctness,\n",
    "    tokenizer=tokenizer,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate without the LoRA adapter (base-model baseline)\n",
    "def eval_generate_fn(inputs):\n",
    "    return model.fast_generate(\n",
    "        inputs,\n",
    "        sampling_params=sampling_params,\n",
    "    )\n",
    "\n",
    "\n",
    "rl_helpers.run_eval(\n",
    "    generate_fn=eval_generate_fn,\n",
    "    verify_fn=reward_correctness,\n",
    "    tokenizer=tokenizer,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}