parent 04d56325bb
commit 3c2deaced9
@@ -1,2 +1,2 @@
-HF_TOKEN=
-OPENROUTER_API_KEY=
+HF_TOKEN=<your-huggingface-token>
+OPENROUTER_API_KEY=<your-openrouter-api-key>
@@ -1,232 +0,0 @@
"""
This script performs two main tasks:
1. It loads a markdown document, splits it into chunks, generates embeddings,
and builds a FAISS index (which is saved locally).
2. It generates QA pairs from the document using llama.
For each chunk (using a sliding window for context), it generates multiple question-answer pairs
with different difficulties. The generation is performed in batch with one retry for failed prompts.
Successfully generated QA pairs are saved to "saved_data/questions.json".

Requirements:
pip install langchain faiss-cpu unsloth vllm
"""

import json
import os
import pickle
import re
from typing import Dict, List, Optional, Tuple

from langchain.text_splitter import RecursiveCharacterTextSplitter

# ========= Part 1: Document Processing and Embedding Generation =========
# Load and split the markdown document using LangChain
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_community.vectorstores import FAISS

from embeddings import CustomHuggingFaceEmbeddings

# Load your markdown file (adjust the path as needed)
loader = UnstructuredMarkdownLoader("./data/mission_report.md")
docs = loader.load()

# Split the document into smaller chunks (each 1000 characters, no overlap)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
chunks = text_splitter.split_documents(docs)

# Save chunks for later use  # TODO: change to csv? easier inspect.
os.makedirs("saved_data", exist_ok=True)
with open("saved_data/chunks.pkl", "wb") as f:
    pickle.dump(chunks, f)
print(f"Saved {len(chunks)} chunks to saved_data/chunks.pkl")

embeddings = CustomHuggingFaceEmbeddings()

# Create a FAISS vector store from the document chunks and save it locally
vectorstore = FAISS.from_documents(chunks, embeddings)
vectorstore.save_local("faiss_index")
print("Saved FAISS index to 'faiss_index'")

# TODO: add the paraphrased chunks to the vector store

# ========= Part 2: QA Generation using Llama Backend =========

# Setup Llama backend via unsloth and vLLM
from unsloth import FastLanguageModel
from vllm import SamplingParams

import rl_helpers  # Ensure you have this or remove if not used

# Load the Llama model (adjust parameters as needed)
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="meta-llama/meta-Llama-3.1-8B-Instruct",
    max_seq_length=4096,
    load_in_4bit=True,  # Use 4-bit quantization if desired
    fast_inference=True,  # Enable fast inference
    gpu_memory_utilization=0.6,  # Adjust based on your GPU memory
)

# Define sampling parameters for generation
sampling_params = SamplingParams(
    temperature=0.3,
    top_p=0.95,
    max_tokens=4096,
)


def batch_generate(prompts: List[str]) -> List[str]:
    """
    Given a list of prompt strings, returns a list of generated outputs.
    """

    def format_input(text: str) -> str:
        return tokenizer.apply_chat_template(
            [{"role": "user", "content": text}],
            tokenize=False,
            add_generation_prompt=True,
        )

    formatted = [format_input(p) for p in prompts]
    outputs = model.fast_generate(formatted, sampling_params=sampling_params)
    return [output.outputs[0].text for output in outputs]


def parse_qa_block(block: str) -> Optional[Tuple[str, str, str]]:
    """
    Parses a QA block that should contain exactly three non-empty lines:
    - A line starting with "Question:"
    - A line starting with "Answer:"
    - A line starting with "Difficulty:"

    If the markers are not present but the block contains exactly three lines,
    those are used in order.

    Returns a tuple (question, answer, difficulty) or None if parsing fails.
    """
    lines = [line.strip() for line in block.splitlines() if line.strip()]
    if not lines:
        return None

    question, answer, difficulty = None, None, None
    for line in lines:
        lower = line.lower()
        if question is None and lower.startswith("question:"):
            question = line[len("question:") :].strip()
        elif answer is None and lower.startswith("answer:"):
            answer = line[len("answer:") :].strip()
        elif difficulty is None and lower.startswith("difficulty:"):
            difficulty = line[len("difficulty:") :].strip()

    if question and answer and difficulty:
        return question, answer, difficulty
    if len(lines) == 3:
        return lines[0], lines[1], lines[2]
    return None


def parse_multiple_qa_output(output: str) -> List[Tuple[str, str, str]]:
    """
    Splits the output into blocks (separated by one or more blank lines) and
    attempts to parse each as a QA pair.

    Returns a list of successfully parsed QA tuples.
    """
    blocks = re.split(r"\n\s*\n", output.strip())
    qa_pairs = []
    for block in blocks:
        parsed = parse_qa_block(block)
        if parsed:
            qa_pairs.append(parsed)
    return qa_pairs


def generate_question_batch_for_chunks(
    chunks: List, num_questions: int = 2, difficulty: str = None
) -> List[Dict]:
    """
    Generates QA pairs for multiple chunks in batch.

    For each chunk (except the first and last), a sliding window is used for context:
    - before: previous chunk's content
    - current: current chunk's content
    - after: next chunk's content

    Each prompt instructs the model to output exactly three lines per QA pair with markers.
    Failed prompts are retried once in batch; if still unsuccessful, they are skipped.

    Returns a list of dicts with keys: "chunk_id", "question", "answer", "difficulty".
    """
    prompts = []
    chunk_ids = []

    # Prepare prompts using a sliding window
    for i in range(1, len(chunks) - 1):
        before = chunks[i - 1].page_content
        current = chunks[i].page_content
        after = chunks[i + 1].page_content
        prompt = (
            f"From the text within ==BEGIN== and ==END==, generate {num_questions} questions with answers.\n"
            "For each QA pair, output exactly three lines with no extra commentary:\n"
            "Line 1: Question: <your question>\n"
            "Line 2: Answer: <the answer>\n"
            "Line 3: Difficulty: <easy, medium, or hard>\n"
            "Do not include any additional text.\n\n"
            "==BEGIN==\n"
            f"{before}\n{current}\n{after}\n"
            "==END==\n"
        )
        prompts.append(prompt)
        chunk_ids.append(i)

    # First batch generation
    outputs = batch_generate(prompts)
    results = [None] * len(outputs)
    failed_indices = []

    # Parse each output
    for idx, output in enumerate(outputs):
        qa_pairs = parse_multiple_qa_output(output)
        if qa_pairs is None or len(qa_pairs) < num_questions:
            failed_indices.append(idx)
        else:
            results[idx] = qa_pairs[:num_questions]

    # Retry failed prompts in batch
    if failed_indices:
        print(f"Retrying {len(failed_indices)} failed prompt(s)...")
        retry_prompts = [prompts[i] for i in failed_indices]
        retry_outputs = batch_generate(retry_prompts)
        for j, idx in enumerate(failed_indices):
            qa_pairs = parse_multiple_qa_output(retry_outputs[j])
            if qa_pairs is not None and len(qa_pairs) >= num_questions:
                results[idx] = qa_pairs[:num_questions]
            else:
                results[idx] = None  # Mark as failed

    # Build final output, skipping prompts that failed even after retry
    final_questions = []
    for i, qa_list in enumerate(results):
        if qa_list is not None:
            for qa in qa_list:
                final_questions.append(
                    {
                        "chunk_id": chunk_ids[i],
                        "question": qa[0],
                        "answer": qa[1],
                        "difficulty": qa[2],
                    }
                )
    return final_questions


# Generate QA pairs in batch (using a sliding window over the chunks)
all_questions = generate_question_batch_for_chunks(
    chunks, num_questions=2, difficulty="medium"
)
print(f"Generated {len(all_questions)} QA pairs.")

# Save the QA pairs to a JSON file
questions_path = os.path.join("saved_data", "questions.json")
with open(questions_path, "w") as f:
    json.dump(all_questions, f, indent=2)
print(f"Saved questions to {questions_path}")
@@ -1,3 +0,0 @@
unsloth_compiled_cache
0_*
faiss_index*
@@ -0,0 +1,301 @@
import os
import sys
from datetime import datetime
from pathlib import Path

import torch
from dotenv import load_dotenv
from loguru import logger
from vllm import SamplingParams

# Load environment variables from .env file if it exists
load_dotenv(override=True)

# Project paths
PROJ_ROOT = Path(__file__).resolve().parent.parent
DATA_DIR = PROJ_ROOT / "data"
LOG_FOLDER = PROJ_ROOT / "logs"

# Model configuration
# MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
device_id = (
    1 if os.environ.get("CUDA_VISIBLE_DEVICES") == "1" else torch.cuda.current_device()
)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

OUTPUT_DIR = (
    PROJ_ROOT
    / f"trainer_output_{MODEL_NAME.replace('/', '_')}_gpu{device_id}_{timestamp}"
)

# Model parameters
MODEL_CONFIG = {
    "max_seq_length": 4096 * 2,  # Can increase for longer reasoning traces
    "lora_rank": 64,  # Larger rank = smarter, but slower
    "gpu_memory_utilization": 0.6,  # Reduce if out of memory
    "model_name": MODEL_NAME,
    "target_modules": [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],  # Remove QKVO if out of memory
}

# Training parameters
TRAINING_CONFIG = {
    "learning_rate": 5e-6,
    "adam_beta1": 0.9,
    "adam_beta2": 0.99,
    "weight_decay": 0.1,
    "warmup_ratio": 0.1,
    "lr_scheduler_type": "cosine",
    "optim": "paged_adamw_8bit",
    "logging_steps": 1,
    "per_device_train_batch_size": 8,
    "gradient_accumulation_steps": 1,  # Increase to 4 for smoother training
    "num_generations": 8,  # Decrease if out of memory
    "max_prompt_length": 1024,
    "max_completion_length": 1024,
    "max_steps": 101,
    "save_steps": 50,
    "max_grad_norm": 0.1,
    "report_to": "tensorboard",
}


# Sampling parameters
def get_sampling_params(temperature: float = 0.1) -> SamplingParams:
    """Get sampling parameters for text generation"""
    return SamplingParams(
        temperature=temperature,
        top_p=0.95,
        max_tokens=4096,
    )


# Initialize logging based on environment
def _init_logging(env: str = "development") -> None:
    """
    Initialize logging configuration with console logging
    and default file logging to ./logs directory.
    Additional file logging will be set up later in update_log_path().

    Args:
        env: The environment for logging ('development' or 'production')
    """
    # Create default log folder
    if not LOG_FOLDER.exists():
        LOG_FOLDER.mkdir(parents=True, exist_ok=True)

    # Remove any existing handlers
    logger.remove()

    # Define the logging format
    console_format = (
        "<green>{time:YYYY-MM-DD HH:mm:ss}</green> "
        "| <level>{level: <8}</level> "
        "| <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> "
        "- <level>{message}</level>"
    )

    file_format = (
        "{time:YYYY-MM-DD at HH:mm:ss} "
        "| {level} "
        "| {name}:{function}:{line} "
        "- {message}"
    )

    # Add console logging
    logger.add(
        sys.stderr,
        format=console_format,
        level="DEBUG" if env == "development" else "INFO",
        colorize=True,
        backtrace=True,
        diagnose=env == "development",
    )

    # Add default file logging to ./logs directory
    logger.add(
        LOG_FOLDER / "app.log",
        format=file_format,
        level="INFO",
        rotation="500 MB",
        retention="7 days",
        compression="zip",
        enqueue=True,  # Enables asynchronous logging
    )

    # Add custom level for requests
    logger.level("REQUEST", no=25, color="<yellow>", icon=" ")

    # Configure exception handling
    def exception_handler(exc_type, exc_value, exc_traceback):
        if issubclass(exc_type, KeyboardInterrupt):
            sys.__excepthook__(exc_type, exc_value, exc_traceback)
            return
        logger.opt(exception=(exc_type, exc_value, exc_traceback)).critical(
            "Unhandled exception"
        )

    sys.excepthook = exception_handler


# Update the log files to point to the training directory
def update_log_path(log_dir=None):
    """
    Add a log file in the training directory while keeping the default ./logs logging.
    Should be called after the training directory is created.

    Args:
        log_dir: Path to store additional log files (default: uses get_paths()["log_dir"])
    """
    # Use provided log_dir or get from training paths
    if log_dir is None:
        paths = get_paths(create_dirs=True)
        log_dir = paths["log_dir"]
    else:
        log_dir = Path(log_dir)
        log_dir.mkdir(exist_ok=True, parents=True)

    file_format = (
        "{time:YYYY-MM-DD at HH:mm:ss} "
        "| {level} "
        "| {name}:{function}:{line} "
        "- {message}"
    )

    # Add additional file handler pointing to training directory
    # No need to remove existing handlers as we want to keep those
    logger.add(
        log_dir / "app.log",
        format=file_format,
        level="INFO",
        rotation="500 MB",
        retention="7 days",
        compression="zip",
        enqueue=True,  # Enables asynchronous logging
    )

    logger.info(f"Additional logs will be stored in: {log_dir}")


# Paths configuration without creating directories
def get_paths(create_dirs: bool = False) -> dict:
    """
    Get common paths for the project

    Args:
        create_dirs: Whether to create the directories

    Returns:
        Dictionary with paths
    """
    output_dir = Path(OUTPUT_DIR)
    log_dir = output_dir / "logs"
    tensorboard_dir = output_dir / "runs"

    # Only create directories if explicitly requested
    if create_dirs:
        output_dir.mkdir(exist_ok=True)
        log_dir.mkdir(exist_ok=True)

        # Only create tensorboard directory if it's enabled in config
        if TRAINING_CONFIG.get("report_to") == "tensorboard":
            tensorboard_dir.mkdir(exist_ok=True)

    return {
        "output_dir": output_dir,
        "log_dir": log_dir,
        "tensorboard_dir": tensorboard_dir,
        "proj_root": PROJ_ROOT,
        "data_dir": DATA_DIR,
    }


# Create training directories
def init_training_dirs():
    """Initialize all directories needed for training"""
    paths = get_paths(create_dirs=True)

    # Also ensure our standard project directories exist
    for directory in [
        DATA_DIR,
        LOG_FOLDER,
    ]:
        directory.mkdir(exist_ok=True, parents=True)

    return paths


# For backward compatibility - will be deprecated
def setup_logger(module_name=None, create_dirs: bool = False):
    """
    Setup a logger for a specific module with consistent configuration.

    Note: This function is kept for backward compatibility.
    Use the global 'logger' instead for new code.

    Args:
        module_name: Optional name of module for module-specific log file
        create_dirs: Whether to create log directories

    Returns:
        Configured logger instance
    """
    logger.warning(
        "setup_logger is deprecated. Import logger directly from config instead."
    )
    return logger


# Tensorboard writer singleton
_tensorboard_writer = None


# Safe tensorboard logging function
def log_metric(key, value, step=0):
    """
    Log a metric safely to tensorboard if writer is available.

    Args:
        key: Metric name
        value: Metric value
        step: Training step
    """
    global _tensorboard_writer

    # Skip tensorboard logging if disabled in config
    if TRAINING_CONFIG.get("report_to") != "tensorboard":
        logger.debug(f"Tensorboard disabled. Metric: {key}={value} (step {step})")
        return

    # Get paths and initialize writer if needed
    paths = get_paths(create_dirs=False)
    if paths["tensorboard_dir"].exists():
        # Only create writer once
        if _tensorboard_writer is None:
            from torch.utils.tensorboard.writer import SummaryWriter

            _tensorboard_writer = SummaryWriter(paths["tensorboard_dir"])
            logger.debug(f"Created tensorboard writer at {paths['tensorboard_dir']}")

        # Add scalar using existing writer
        _tensorboard_writer.add_scalar(key, value, step)
        # No need to close the writer - it will be closed at process exit
    else:
        logger.debug(f"Tensorboard metric: {key}={value} (step {step})")


# Initialize logging on module import
env = os.getenv("APP_ENV", "development")
_init_logging(env=env)

# Log project root on import
logger.info(f"Project root path: {PROJ_ROOT}")
logger.debug(f"Running in {env} environment")
@@ -1,192 +0,0 @@
# %%
import torch

# %%
from unsloth import FastLanguageModel, is_bfloat16_supported

max_seq_length = 4096 * 2  # Can increase for longer reasoning traces
lora_rank = 64  # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="meta-llama/meta-Llama-3.1-8B-Instruct",
    max_seq_length=max_seq_length,
    load_in_4bit=True,  # False for LoRA 16bit
    fast_inference=True,  # Enable vLLM fast inference
    max_lora_rank=lora_rank,
    gpu_memory_utilization=0.6,  # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],  # Remove QKVO if out of memory
    lora_alpha=lora_rank,
    use_gradient_checkpointing="unsloth",  # Enable long context finetuning
    random_state=3407,
)

# %%
import re

from datasets import Dataset, load_dataset

from rl_helpers import get_qa_dataset
from search_module import get_question_answer, get_question_count, search

train_dataset, test_dataset = get_qa_dataset()

# %% [markdown]
# <a name="Train"></a>
# ### Train the model
#
# Now set up GRPO Trainer and all configurations!

# %%
import os

os.environ["WANDB_PROJECT"] = "bootstrap-search-rl"

# %%
# from UnslothGRPOTrainerTemp import UnslothGRPOConfig, _UnslothGRPOTrainer
import UnslothGRPOTrainerTemp

training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(
    use_vllm=True,  # use vLLM for fast inference!
    use_agentic_generate=True,  # use agentic generation
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",
    logging_steps=1,
    bf16=is_bfloat16_supported(),
    fp16=not is_bfloat16_supported(),
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,  # Increase to 4 for smoother training
    num_generations=8,  # Decrease if out of memory
    max_prompt_length=1024,
    max_completion_length=1024,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps=101,
    save_steps=50,
    max_grad_norm=0.1,
    report_to="none",  # Can use Weights & Biases
    output_dir="full_local_training",
)

# %%


import rl_helpers

# importlib.reload(rl_helpers)


def agentic_generate(
    prompts: list[str],
    generate_fn,
    max_generations: int = 6,
):
    return run_agent(generate_fn, tokenizer, prompts, max_generations)


model.agentic_generate = agentic_generate


from vllm import SamplingParams

verifier_sampling_params = SamplingParams(
    temperature=0.1,
    top_p=0.95,
    max_tokens=4096,
)


def verifier_generate_fn(inputs):
    return model.fast_generate(
        inputs,
        sampling_params=verifier_sampling_params,
    )


run_agent = rl_helpers.run_agent
reward_correctness = rl_helpers.build_reward_correctness_fn(
    verifier_generate_fn,
    tokenizer,
)
reward_formatting = rl_helpers.reward_formatting

import UnslothGRPOTrainerTemp

trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        reward_correctness,
        reward_formatting,
    ],
    args=training_args,
    train_dataset=train_dataset,
)

# %%
trainer.train()

# %% [markdown]
# <a name="Inference"></a>
# ### Inference
# Now let's try benchmark the model we trained!

# %%
from vllm import SamplingParams

import rl_helpers

sampling_params = SamplingParams(
    temperature=0.5,
    top_p=0.95,
    max_tokens=4096,
)


def eval_generate_fn(inputs):
    return model.fast_generate(
        inputs,
        sampling_params=sampling_params,
        lora_request=model.load_lora(
            "full_local_training/checkpoint-101"
        ),  # load the trained LoRA
    )


rl_helpers.run_eval(
    generate_fn=eval_generate_fn,
    verify_fn=reward_correctness,
    tokenizer=tokenizer,
)


# %%
# eval w/o lora
def eval_generate_fn(inputs):
    return model.fast_generate(
        inputs,
        sampling_params=sampling_params,
    )


rl_helpers.run_eval(
    generate_fn=eval_generate_fn,
    verify_fn=reward_correctness,
    tokenizer=tokenizer,
)
@@ -1,196 +0,0 @@
# %%
import torch

# %%
from unsloth import FastLanguageModel, is_bfloat16_supported

max_seq_length = 4096 * 2  # Can increase for longer reasoning traces
lora_rank = 64  # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="meta-llama/Llama-3.2-1B-Instruct",
    max_seq_length=max_seq_length,
    load_in_4bit=True,  # False for LoRA 16bit
    fast_inference=True,  # Enable vLLM fast inference
    max_lora_rank=lora_rank,
    gpu_memory_utilization=0.6,  # Reduce if out of memory
)


print(tokenizer.chat_template)  # See what format Qwen expects


model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],  # Remove QKVO if out of memory
    lora_alpha=lora_rank,
    use_gradient_checkpointing="unsloth",  # Enable long context finetuning
    random_state=3407,
)

# %%
import re

from datasets import Dataset, load_dataset

from rl_helpers import get_qa_dataset
from search_module import get_question_answer, get_question_count, search

train_dataset, test_dataset = get_qa_dataset()

# %% [markdown]
# <a name="Train"></a>
# ### Train the model
#
# Now set up GRPO Trainer and all configurations!

# %%
import os

os.environ["WANDB_PROJECT"] = "bootstrap-search-rl"

# %%
# from UnslothGRPOTrainerTemp import UnslothGRPOConfig, _UnslothGRPOTrainer
import UnslothGRPOTrainerTemp

training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(
    use_vllm=True,  # use vLLM for fast inference!
    use_agentic_generate=True,  # use agentic generation
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",
    logging_steps=1,
    bf16=is_bfloat16_supported(),
    fp16=not is_bfloat16_supported(),
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,  # Increase to 4 for smoother training
    num_generations=8,  # Decrease if out of memory
    max_prompt_length=1024,
    max_completion_length=1024,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps=101,
    save_steps=50,
    max_grad_norm=0.1,
    report_to="none",  # Can use Weights & Biases
    output_dir="full_local_training",
)

# %%


import rl_helpers

# importlib.reload(rl_helpers)


def agentic_generate(
    prompts: list[str],
    generate_fn,
    max_generations: int = 6,
):
    return run_agent(generate_fn, tokenizer, prompts, max_generations)


model.agentic_generate = agentic_generate


from vllm import SamplingParams

verifier_sampling_params = SamplingParams(
    temperature=0.1,
    top_p=0.95,
    max_tokens=4096,
)


def verifier_generate_fn(inputs):
    return model.fast_generate(
        inputs,
        sampling_params=verifier_sampling_params,
    )


run_agent = rl_helpers.run_agent
reward_correctness = rl_helpers.build_reward_correctness_fn(
    verifier_generate_fn,
    tokenizer,
)
reward_formatting = rl_helpers.reward_formatting

import UnslothGRPOTrainerTemp

trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        reward_correctness,
        reward_formatting,
    ],
    args=training_args,
    train_dataset=train_dataset,
)

# %%
trainer.train()

# %% [markdown]
# <a name="Inference"></a>
# ### Inference
# Now let's try benchmark the model we trained!

# %%
from vllm import SamplingParams

import rl_helpers

sampling_params = SamplingParams(
    temperature=0.5,
    top_p=0.95,
    max_tokens=4096,
)


def eval_generate_fn(inputs):
    return model.fast_generate(
        inputs,
        sampling_params=sampling_params,
        lora_request=model.load_lora(
            "full_local_training/checkpoint-101"
        ),  # load the trained LoRA
    )


rl_helpers.run_eval(
    generate_fn=eval_generate_fn,
    verify_fn=reward_correctness,
    tokenizer=tokenizer,
)


# %%
# eval w/o lora
def eval_generate_fn(inputs):
    return model.fast_generate(
        inputs,
        sampling_params=sampling_params,
    )


rl_helpers.run_eval(
    generate_fn=eval_generate_fn,
    verify_fn=reward_correctness,
    tokenizer=tokenizer,
)
@@ -0,0 +1,124 @@
import os

from unsloth import FastLanguageModel, is_bfloat16_supported

import src.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp

# Import reward functions
from src.rl_helpers import (
    build_reward_correctness_fn,
    get_qa_dataset,
    reward_exact_match_chunk_query,
    reward_formatting,
    reward_retry_behavior,
    run_agent,
)
from src.config import (
    MODEL_CONFIG,
    MODEL_NAME,
    OUTPUT_DIR,
    TRAINING_CONFIG,
    get_sampling_params,
    init_training_dirs,
    logger,
    update_log_path,
)

# Initialize training directories
paths = init_training_dirs()

# Update logger to use the training directory
update_log_path(paths["log_dir"])
logger.info(f"Training output directory: {paths['output_dir']}")
logger.info(f"Logs are being saved to both ./logs and {paths['log_dir']}")

# Initialize model and tokenizer
logger.info(f"Initializing model {MODEL_NAME}")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_NAME,
    max_seq_length=MODEL_CONFIG["max_seq_length"],
    load_in_4bit=True,  # False for LoRA 16bit
    fast_inference=True,  # Enable vLLM fast inference
    max_lora_rank=MODEL_CONFIG["lora_rank"],
    gpu_memory_utilization=MODEL_CONFIG["gpu_memory_utilization"],
)

# Setup LoRA
logger.info("Setting up LoRA adapter")
model = FastLanguageModel.get_peft_model(
    model,
    r=MODEL_CONFIG["lora_rank"],
    target_modules=MODEL_CONFIG["target_modules"],
    lora_alpha=MODEL_CONFIG["lora_rank"],
    use_gradient_checkpointing=True,  # Enable long context finetuning
    random_state=3407,
)

# Load datasets
logger.info("Loading datasets")
train_dataset, test_dataset = get_qa_dataset()
logger.info(
    f"Loaded {len(train_dataset)} training examples and {len(test_dataset)} test examples"
)

# Setup training arguments
logger.info("Setting up training arguments")
training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(
    use_vllm=True,  # use vLLM for fast inference!
    use_agentic_generate=True,  # use agentic generation
    **TRAINING_CONFIG,
    bf16=is_bfloat16_supported(),
    fp16=not is_bfloat16_supported(),
    output_dir=OUTPUT_DIR,
    # report_to="tensorboard", # ❓ Does't have billions of tensorboard files if set report to right here
)


# Setup model generation functions
def agentic_generate(
    prompts: list,
    generate_fn,
    max_generations: int = 10,
):
    return run_agent(generate_fn, tokenizer, prompts, max_generations)


model.agentic_generate = agentic_generate

# Setup verifier
logger.info("Setting up verifier")
verifier_sampling_params = get_sampling_params(temperature=0.1)


def verifier_generate_fn(inputs):
    return model.fast_generate(
        inputs,
        sampling_params=verifier_sampling_params,
    )


# Setup trainer
logger.info("Initializing trainer")
trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        build_reward_correctness_fn(
            verifier_generate_fn,
            tokenizer,
            log_file=os.path.join(paths["log_dir"], "qa_log.txt"),
        ),
        reward_formatting,
        reward_retry_behavior,
        reward_exact_match_chunk_query,
    ],
    args=training_args,
    train_dataset=train_dataset,
)

# Train the model
if __name__ == "__main__":
    logger.info("Starting training")
    trainer.train()
    logger.info("Training completed")
    logger.info(f"Model saved to {OUTPUT_DIR}")