# %%
import torch
# %%
from unsloth import FastLanguageModel, is_bfloat16_supported
max_seq_length = 4096 * 2 # Can increase for longer reasoning traces
lora_rank = 64 # Larger rank = smarter, but slower
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="meta-llama/meta-Llama-3.1-8B-Instruct",
max_seq_length=max_seq_length,
load_in_4bit=True, # False for LoRA 16bit
fast_inference=True, # Enable vLLM fast inference
max_lora_rank=lora_rank,
gpu_memory_utilization=0.6, # Reduce if out of memory
)
model = FastLanguageModel.get_peft_model(
model,
    r=lora_rank, # Choose any number > 0! Suggested values: 8, 16, 32, 64, 128
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
], # Remove QKVO if out of memory
lora_alpha=lora_rank,
use_gradient_checkpointing="unsloth", # Enable long context finetuning
random_state=3407,
)
# %%
import re
from datasets import Dataset, load_dataset
from rl_helpers import get_qa_dataset
from search_module import get_question_answer, get_question_count, search
train_dataset, test_dataset = get_qa_dataset()
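# %% [markdown]
# It helps to peek at the data before training. The cell below assumes `get_qa_dataset`
# returns Hugging Face `Dataset` objects (consistent with the `datasets` import above);
# if the column names differ in your setup, adjust the indexing accordingly.
# %%
print(train_dataset)
print(test_dataset)
print(train_dataset[0])  # inspect one example to see the available columns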
# %% [markdown]
#
# ### Train the model
#
# Now set up the GRPO Trainer and all of its configuration!
# %%
import os
os.environ["WANDB_PROJECT"] = "bootstrap-search-rl"
# %%
# from UnslothGRPOTrainerTemp import UnslothGRPOConfig, _UnslothGRPOTrainer
import UnslothGRPOTrainerTemp
training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(
use_vllm=True, # use vLLM for fast inference!
use_agentic_generate=True, # use agentic generation
learning_rate=5e-6,
adam_beta1=0.9,
adam_beta2=0.99,
weight_decay=0.1,
warmup_ratio=0.1,
lr_scheduler_type="cosine",
optim="paged_adamw_8bit",
logging_steps=1,
bf16=is_bfloat16_supported(),
fp16=not is_bfloat16_supported(),
per_device_train_batch_size=8,
gradient_accumulation_steps=1, # Increase to 4 for smoother training
num_generations=8, # Decrease if out of memory
max_prompt_length=1024,
max_completion_length=1024,
# num_train_epochs = 1, # Set to 1 for a full training run
max_steps=101,
save_steps=50,
max_grad_norm=0.1,
    report_to="none", # Set to "wandb" to log to the WANDB_PROJECT defined above
output_dir="full_local_training",
)
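# %% [markdown]
# GRPO samples `num_generations` completions per prompt, so the effective batch size
# (`per_device_train_batch_size * gradient_accumulation_steps` on a single GPU) generally
# has to be divisible by `num_generations`. The cell below is a minimal sanity check of
# that constraint for the configuration above; it assumes a single-GPU run.
# %%
effective_batch = (
    training_args.per_device_train_batch_size
    * training_args.gradient_accumulation_steps
)
assert effective_batch % training_args.num_generations == 0, (
    f"Effective batch size {effective_batch} is not divisible by "
    f"num_generations={training_args.num_generations}"
)
print(
    f"Effective batch size: {effective_batch}, "
    f"generations per prompt: {training_args.num_generations}"
)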
# %%
import rl_helpers
# importlib.reload(rl_helpers)
# run_agent is assigned below (from rl_helpers); it is looked up at call time.
def agentic_generate(
prompts: list[str],
generate_fn,
max_generations: int = 6,
):
return run_agent(generate_fn, tokenizer, prompts, max_generations)
model.agentic_generate = agentic_generate  # exposed for the trainer's agentic generation path (use_agentic_generate=True)
from vllm import SamplingParams
# Low temperature so the verifier's judgements are near-deterministic
verifier_sampling_params = SamplingParams(
temperature=0.1,
top_p=0.95,
max_tokens=4096,
)
def verifier_generate_fn(inputs):
return model.fast_generate(
inputs,
sampling_params=verifier_sampling_params,
)
run_agent = rl_helpers.run_agent
reward_correctness = rl_helpers.build_reward_correctness_fn(
verifier_generate_fn,
tokenizer,
)
reward_formatting = rl_helpers.reward_formatting
import UnslothGRPOTrainerTemp
trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=[
reward_correctness,
reward_formatting,
],
args=training_args,
train_dataset=train_dataset,
)
# %%
trainer.train()
# %% [markdown]
#
# ### Inference
# Now let's benchmark the model we trained!
# %%
from vllm import SamplingParams
import rl_helpers
sampling_params = SamplingParams(
temperature=0.5,
top_p=0.95,
max_tokens=4096,
)
def eval_generate_fn(inputs):
return model.fast_generate(
inputs,
sampling_params=sampling_params,
lora_request=model.load_lora(
"full_local_training/checkpoint-101"
        ), # load the LoRA adapter saved at the final step (101 = max_steps)
)
rl_helpers.run_eval(
generate_fn=eval_generate_fn,
verify_fn=reward_correctness,
tokenizer=tokenizer,
)
# %%
# Evaluate the base model without the LoRA adapter, as a baseline for comparison
def eval_generate_fn(inputs):
return model.fast_generate(
inputs,
sampling_params=sampling_params,
)
rl_helpers.run_eval(
generate_fn=eval_generate_fn,
verify_fn=reward_correctness,
tokenizer=tokenizer,
)
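# %% [markdown]
# Optionally, persist the trained adapter and tokenizer outside the checkpoint directory
# so they can be reloaded later. This is a minimal sketch using the standard Hugging Face
# / PEFT `save_pretrained` API; the output path below is arbitrary, and only the LoRA
# adapter weights (not the full base model) are written.
# %%
model.save_pretrained("full_local_training/final_lora")
tokenizer.save_pretrained("full_local_training/final_lora")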