You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

333 lines
8.8 KiB

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Train AutoDidact\n",
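"- Taken from [AutoDidact](https://github.com/menloresearch/DeepSearch/blob/main/notebooks/train_autodidact.ipynb)\n",
"\n",
"This notebook fine-tunes a LoRA adapter on Llama 3.1 8B Instruct with GRPO (using agentic generation plus correctness and formatting rewards), then evaluates the trained adapter against the base model."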
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "59DIs5BMcvjN",
"outputId": "a4b3de70-c99c-4e76-ee06-dab6a6505a8b"
},
"outputs": [],
"source": [
"from unsloth import FastLanguageModel"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 700,
"referenced_widgets": [
"d8d0dca36cfc47f0919924da07c231e8",
"5f3d96b613e94e9984d4599ca9ca7b17",
"66c3271554b1455eb56be55c9241e45e",
"d36b61cf796c429080e93ea838a3759e",
"94873c3c077e483790b34f95c421f484",
"ea549fffa8c2469888d1668158bc105c",
"98b432b98839428f85d91580c21e80e2",
"fee4f852c9744a07b909e586e3615604",
"3febcf8a8eca40c28aafc697f3ec8776",
"b4e1eb8eeb064c88a2142e474fb8327f",
"da10502506f9448c9de94f1ddd84d3b1",
"e6cc388e78c14abfaa49d2be6fa1b5d9",
"769bde36e2ba4434bddd78e7d5911be4",
"3c522d78b1834068bd4b155d0f87a4d7",
"a23afba19c2a4d3a90d771fc55f8d490",
"6221f0be3b8d48e797c873565a216680",
"1ac03aff5c314b00ac938c80eb7b2f8a",
"88c63d94a05a42c49d5f8958a27987a6",
"0ca67b0c4ca64eb788358a51308f6b97",
"83c3c811923a4642aba156d1215b39d2",
"e863bf099e064da7b482c21fe7b77de7",
"697faad6643a43aca98015da4faef186"
]
},
"id": "DkIvEkIIkEyB",
"outputId": "514dea04-804e-47a8-b891-ed3f4a6fb530"
},
"outputs": [],
"source": [
"from unsloth import is_bfloat16_supported\n",
"\n",
"# Maximum sequence length passed to the model; can be increased for longer reasoning traces.\n",
"max_seq_length = 2 * 4096\n",
"# LoRA rank: a larger rank adds capacity but trains more slowly.\n",
"lora_rank = 64\n",
"\n",
"# Load the base model in 4-bit with vLLM fast inference enabled.\n",
"base_model_kwargs = dict(\n",
"    model_name=\"meta-llama/meta-Llama-3.1-8B-Instruct\",\n",
"    max_seq_length=max_seq_length,\n",
"    load_in_4bit=True,  # set to False for 16-bit LoRA\n",
"    fast_inference=True,  # enable vLLM fast inference\n",
"    max_lora_rank=lora_rank,\n",
"    gpu_memory_utilization=0.6,  # reduce if out of memory\n",
")\n",
"model, tokenizer = FastLanguageModel.from_pretrained(**base_model_kwargs)\n",
"\n",
"# Projection modules that receive LoRA adapters; drop the q/k/v/o entries if out of memory.\n",
"lora_target_modules = [\n",
"    \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
"    \"gate_proj\", \"up_proj\", \"down_proj\",\n",
"]\n",
"model = FastLanguageModel.get_peft_model(\n",
"    model,\n",
"    r=lora_rank,  # any positive int; 8, 16, 32, 64 and 128 are common choices\n",
"    target_modules=lora_target_modules,\n",
"    lora_alpha=lora_rank,\n",
"    use_gradient_checkpointing=\"unsloth\",  # enables long-context finetuning\n",
"    random_state=3407,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cXk993X6C2ZZ"
},
"outputs": [],
"source": [
"from rl_helpers import get_qa_dataset\n",
"\n",
"# get_qa_dataset() returns a (train, test) pair; the trainer below consumes only the train split.\n",
"qa_splits = get_qa_dataset()\n",
"train_dataset, test_dataset = qa_splits"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ux6iqP7z5YOo"
},
"source": [
"<a name=\"Train\"></a>\n",
"### Train the model\n",
"\n",
"Now set up the GRPO trainer: logging, training arguments, reward functions, and the trainer itself."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# Weights and Biases project name. Only set it when absent so a value already exported\n",
"# in the environment is not overwritten. Only relevant when `report_to` in the\n",
"# training config includes \"wandb\".\n",
"if \"WANDB_PROJECT\" not in os.environ:\n",
"    os.environ[\"WANDB_PROJECT\"] = \"bootstrap-search-rl\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ptqkXK2D4d6p",
"outputId": "9d5551f4-0276-47ca-e4ca-e96c846cc976"
},
"outputs": [],
"source": [
"import UnslothGRPOTrainerTemp\n",
"\n",
"# Query bf16 support once so the bf16 and fp16 flags are derived from the same answer.\n",
"use_bf16 = is_bfloat16_supported()\n",
"\n",
"training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(\n",
"    # Generation\n",
"    use_vllm=True,  # use vLLM for fast inference\n",
"    use_agentic_generate=True,  # use agentic generation\n",
"    num_generations=8,  # decrease if out of memory\n",
"    max_prompt_length=1024,\n",
"    max_completion_length=1024,\n",
"    # Optimizer and learning-rate schedule\n",
"    optim=\"paged_adamw_8bit\",\n",
"    learning_rate=5e-6,\n",
"    adam_beta1=0.9,\n",
"    adam_beta2=0.99,\n",
"    weight_decay=0.1,\n",
"    warmup_ratio=0.1,\n",
"    lr_scheduler_type=\"cosine\",\n",
"    max_grad_norm=0.1,\n",
"    # Precision: bf16 when the GPU supports it, fp16 otherwise\n",
"    bf16=use_bf16,\n",
"    fp16=not use_bf16,\n",
"    # Batching\n",
"    per_device_train_batch_size=8,\n",
"    gradient_accumulation_steps=1,  # increase to 4 for smoother training\n",
"    # Run length and checkpointing; for a full run, replace max_steps with num_train_epochs=1\n",
"    max_steps=101,\n",
"    save_steps=50,\n",
"    output_dir=\"full_local_training\",\n",
"    # Logging\n",
"    logging_steps=1,\n",
"    report_to=\"none\",  # set to \"wandb\" to log to Weights and Biases\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import rl_helpers\n",
"import UnslothGRPOTrainerTemp\n",
"from vllm import SamplingParams\n",
"\n",
"# Bind the agent loop before any function that calls it, so nothing in this cell\n",
"# relies on a name that is only defined further down.\n",
"run_agent = rl_helpers.run_agent\n",
"\n",
"\n",
"def agentic_generate(\n",
"    prompts: list[str],\n",
"    generate_fn,\n",
"    max_generations: int = 10,\n",
"):\n",
"    \"\"\"Roll out `prompts` with the rl_helpers agent loop.\n",
"\n",
"    Args:\n",
"        prompts: Prompt strings to roll out.\n",
"        generate_fn: Generation callable forwarded to `run_agent`.\n",
"        max_generations: Forwarded to `run_agent` (default 10).\n",
"\n",
"    Returns:\n",
"        Whatever `run_agent(generate_fn, tokenizer, prompts, max_generations)` returns.\n",
"    \"\"\"\n",
"    return run_agent(generate_fn, tokenizer, prompts, max_generations)\n",
"\n",
"\n",
"# Attach to the model; presumably the trainer calls this when use_agentic_generate=True.\n",
"model.agentic_generate = agentic_generate\n",
"\n",
"# Low-temperature sampling for the correctness verifier.\n",
"verifier_sampling_params = SamplingParams(\n",
"    temperature=0.1,\n",
"    top_p=0.95,\n",
"    max_tokens=4096,\n",
")\n",
"\n",
"\n",
"def verifier_generate_fn(inputs):\n",
"    \"\"\"Generate verifier outputs for `inputs` with `model.fast_generate`.\n",
"\n",
"    No `lora_request` is passed. NOTE(review): presumably the verifier therefore\n",
"    runs on the base weights rather than the adapter being trained - confirm.\n",
"    \"\"\"\n",
"    return model.fast_generate(\n",
"        inputs,\n",
"        sampling_params=verifier_sampling_params,\n",
"    )\n",
"\n",
"\n",
"# Reward functions passed to the GRPO trainer.\n",
"reward_correctness = rl_helpers.build_reward_correctness_fn(\n",
"    verifier_generate_fn,\n",
"    tokenizer,\n",
")\n",
"reward_formatting = rl_helpers.reward_formatting\n",
"\n",
"trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(\n",
"    model=model,\n",
"    processing_class=tokenizer,\n",
"    reward_funcs=[\n",
"        reward_correctness,\n",
"        reward_formatting,\n",
"    ],\n",
"    args=training_args,\n",
"    train_dataset=train_dataset,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Run training and keep the returned result for inspection.\n",
"train_output = trainer.train()\n",
"train_output"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tlaUdxC_VHpz"
},
"source": [
"<a name=\"Inference\"></a>\n",
"### Inference\n",
"Now let's benchmark the trained LoRA adapter, then the base model without it for comparison."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import rl_helpers\n",
"from vllm import SamplingParams\n",
"\n",
"# Sampling settings shared by the adapter evaluation here and the base-model evaluation below.\n",
"sampling_params = SamplingParams(\n",
"    temperature=0.5,\n",
"    top_p=0.95,\n",
"    max_tokens=4096,\n",
")\n",
"\n",
"# Final checkpoint of the training run, derived from the training config so the path\n",
"# stays in sync if `output_dir` or `max_steps` change.\n",
"lora_checkpoint_dir = os.path.join(\n",
"    training_args.output_dir, f\"checkpoint-{training_args.max_steps}\"\n",
")\n",
"if not os.path.isdir(lora_checkpoint_dir):\n",
"    raise FileNotFoundError(\n",
"        f\"Trained LoRA checkpoint not found at {lora_checkpoint_dir!r}; run training first.\"\n",
"    )\n",
"\n",
"# Load the trained LoRA once and reuse the request, rather than calling\n",
"# model.load_lora() again on every generate call.\n",
"trained_lora_request = model.load_lora(lora_checkpoint_dir)\n",
"\n",
"\n",
"def eval_generate_fn(inputs):\n",
"    \"\"\"Generate evaluation outputs for `inputs` with the trained LoRA adapter applied.\"\"\"\n",
"    return model.fast_generate(\n",
"        inputs,\n",
"        sampling_params=sampling_params,\n",
"        lora_request=trained_lora_request,\n",
"    )\n",
"\n",
"\n",
"rl_helpers.run_eval(\n",
"    generate_fn=eval_generate_fn,\n",
"    verify_fn=reward_correctness,\n",
"    tokenizer=tokenizer,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Baseline: evaluate without the LoRA adapter, reusing `sampling_params` from the\n",
"# adapter evaluation above so both runs are sampled the same way.\n",
"def base_model_generate_fn(inputs):\n",
"    \"\"\"Generate evaluation outputs for `inputs` with no LoRA adapter applied.\"\"\"\n",
"    return model.fast_generate(\n",
"        inputs,\n",
"        sampling_params=sampling_params,\n",
"    )\n",
"\n",
"\n",
"rl_helpers.run_eval(\n",
"    generate_fn=base_model_generate_fn,\n",
"    verify_fn=reward_correctness,\n",
"    tokenizer=tokenizer,\n",
")"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 0
}