parent 04d56325bb
commit 3c2deaced9
@@ -1,2 +1,2 @@
-HF_TOKEN=
+HF_TOKEN=<your-huggingface-token>
-OPENROUTER_API_KEY=
+OPENROUTER_API_KEY=<your-openrouter-api-key>
@@ -1,232 +0,0 @@
"""
This script performs two main tasks:
1. It loads a markdown document, splits it into chunks, generates embeddings,
   and builds a FAISS index (which is saved locally).
2. It generates QA pairs from the document using Llama.
   For each chunk (using a sliding window for context), it generates multiple question-answer pairs
   with different difficulties. The generation is performed in batch with one retry for failed prompts.
   Successfully generated QA pairs are saved to "saved_data/questions.json".

Requirements:
    pip install langchain faiss-cpu unsloth vllm
"""

import json
import os
import pickle
import re
from typing import Dict, List, Optional, Tuple

from langchain.text_splitter import RecursiveCharacterTextSplitter

# ========= Part 1: Document Processing and Embedding Generation =========
# Load and split the markdown document using LangChain
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_community.vectorstores import FAISS

from embeddings import CustomHuggingFaceEmbeddings

# Load your markdown file (adjust the path as needed)
loader = UnstructuredMarkdownLoader("./data/mission_report.md")
docs = loader.load()

# Split the document into smaller chunks (each 1000 characters, no overlap)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
chunks = text_splitter.split_documents(docs)

# Save chunks for later use  # TODO: change to CSV? Easier to inspect.
os.makedirs("saved_data", exist_ok=True)
with open("saved_data/chunks.pkl", "wb") as f:
    pickle.dump(chunks, f)
print(f"Saved {len(chunks)} chunks to saved_data/chunks.pkl")

embeddings = CustomHuggingFaceEmbeddings()

# Create a FAISS vector store from the document chunks and save it locally
vectorstore = FAISS.from_documents(chunks, embeddings)
vectorstore.save_local("faiss_index")
print("Saved FAISS index to 'faiss_index'")

# TODO: add the paraphrased chunks to the vector store

# ========= Part 2: QA Generation using Llama Backend =========

# Setup Llama backend via unsloth and vLLM
from unsloth import FastLanguageModel
from vllm import SamplingParams

import rl_helpers  # Ensure you have this or remove if not used

# Load the Llama model (adjust parameters as needed)
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="meta-llama/meta-Llama-3.1-8B-Instruct",
    max_seq_length=4096,
    load_in_4bit=True,  # Use 4-bit quantization if desired
    fast_inference=True,  # Enable fast inference
    gpu_memory_utilization=0.6,  # Adjust based on your GPU memory
)

# Define sampling parameters for generation
sampling_params = SamplingParams(
    temperature=0.3,
    top_p=0.95,
    max_tokens=4096,
)


def batch_generate(prompts: List[str]) -> List[str]:
    """
    Given a list of prompt strings, returns a list of generated outputs.
    """

    def format_input(text: str) -> str:
        return tokenizer.apply_chat_template(
            [{"role": "user", "content": text}],
            tokenize=False,
            add_generation_prompt=True,
        )

    formatted = [format_input(p) for p in prompts]
    outputs = model.fast_generate(formatted, sampling_params=sampling_params)
    return [output.outputs[0].text for output in outputs]


def parse_qa_block(block: str) -> Optional[Tuple[str, str, str]]:
    """
    Parses a QA block that should contain exactly three non-empty lines:
      - A line starting with "Question:"
      - A line starting with "Answer:"
      - A line starting with "Difficulty:"

    If the markers are not present but the block contains exactly three lines,
    those are used in order.

    Returns a tuple (question, answer, difficulty) or None if parsing fails.
    """
    lines = [line.strip() for line in block.splitlines() if line.strip()]
    if not lines:
        return None

    question, answer, difficulty = None, None, None
    for line in lines:
        lower = line.lower()
        if question is None and lower.startswith("question:"):
            question = line[len("question:") :].strip()
        elif answer is None and lower.startswith("answer:"):
            answer = line[len("answer:") :].strip()
        elif difficulty is None and lower.startswith("difficulty:"):
            difficulty = line[len("difficulty:") :].strip()

    if question and answer and difficulty:
        return question, answer, difficulty
    if len(lines) == 3:
        return lines[0], lines[1], lines[2]
    return None


def parse_multiple_qa_output(output: str) -> List[Tuple[str, str, str]]:
    """
    Splits the output into blocks (separated by one or more blank lines) and
    attempts to parse each as a QA pair.

    Returns a list of successfully parsed QA tuples.
    """
    blocks = re.split(r"\n\s*\n", output.strip())
    qa_pairs = []
    for block in blocks:
        parsed = parse_qa_block(block)
        if parsed:
            qa_pairs.append(parsed)
    return qa_pairs
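

# A minimal usage sketch of the two parsers above; the sample text here is
# illustrative and not taken from the mission report. It shows the expected
# three-line Question/Answer/Difficulty format and the tuples it parses into.
_example_output = (
    "Question: Which file does this script index?\n"
    "Answer: ./data/mission_report.md\n"
    "Difficulty: easy\n"
    "\n"
    "Question: How large is each chunk?\n"
    "Answer: 1000 characters, with no overlap.\n"
    "Difficulty: medium"
)
assert parse_multiple_qa_output(_example_output) == [
    ("Which file does this script index?", "./data/mission_report.md", "easy"),
    ("How large is each chunk?", "1000 characters, with no overlap.", "medium"),
]
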
def generate_question_batch_for_chunks(
    chunks: List, num_questions: int = 2, difficulty: str = None
) -> List[Dict]:
    """
    Generates QA pairs for multiple chunks in batch.

    For each chunk (except the first and last), a sliding window is used for context:
      - before: previous chunk's content
      - current: current chunk's content
      - after: next chunk's content

    Each prompt instructs the model to output exactly three lines per QA pair with markers.
    Failed prompts are retried once in batch; if still unsuccessful, they are skipped.

    Returns a list of dicts with keys: "chunk_id", "question", "answer", "difficulty".
    """
    prompts = []
    chunk_ids = []

    # Prepare prompts using a sliding window
    for i in range(1, len(chunks) - 1):
        before = chunks[i - 1].page_content
        current = chunks[i].page_content
        after = chunks[i + 1].page_content
        prompt = (
            f"From the text within ==BEGIN== and ==END==, generate {num_questions} questions with answers.\n"
            "For each QA pair, output exactly three lines with no extra commentary:\n"
            "Line 1: Question: <your question>\n"
            "Line 2: Answer: <the answer>\n"
            "Line 3: Difficulty: <easy, medium, or hard>\n"
            "Do not include any additional text.\n\n"
            "==BEGIN==\n"
            f"{before}\n{current}\n{after}\n"
            "==END==\n"
        )
        prompts.append(prompt)
        chunk_ids.append(i)

    # First batch generation
    outputs = batch_generate(prompts)
    results = [None] * len(outputs)
    failed_indices = []

    # Parse each output
    for idx, output in enumerate(outputs):
        qa_pairs = parse_multiple_qa_output(output)
        if qa_pairs is None or len(qa_pairs) < num_questions:
            failed_indices.append(idx)
        else:
            results[idx] = qa_pairs[:num_questions]

    # Retry failed prompts in batch
    if failed_indices:
        print(f"Retrying {len(failed_indices)} failed prompt(s)...")
        retry_prompts = [prompts[i] for i in failed_indices]
        retry_outputs = batch_generate(retry_prompts)
        for j, idx in enumerate(failed_indices):
            qa_pairs = parse_multiple_qa_output(retry_outputs[j])
            if qa_pairs is not None and len(qa_pairs) >= num_questions:
                results[idx] = qa_pairs[:num_questions]
            else:
                results[idx] = None  # Mark as failed

    # Build final output, skipping prompts that failed even after retry
    final_questions = []
    for i, qa_list in enumerate(results):
        if qa_list is not None:
            for qa in qa_list:
                final_questions.append(
                    {
                        "chunk_id": chunk_ids[i],
                        "question": qa[0],
                        "answer": qa[1],
                        "difficulty": qa[2],
                    }
                )
    return final_questions


# Generate QA pairs in batch (using a sliding window over the chunks)
all_questions = generate_question_batch_for_chunks(
    chunks, num_questions=2, difficulty="medium"
)
print(f"Generated {len(all_questions)} QA pairs.")

# Save the QA pairs to a JSON file
questions_path = os.path.join("saved_data", "questions.json")
with open(questions_path, "w") as f:
    json.dump(all_questions, f, indent=2)
print(f"Saved questions to {questions_path}")
@@ -1,3 +0,0 @@
unsloth_compiled_cache
0_*
faiss_index*
@@ -0,0 +1,301 @@
import os
import sys
from datetime import datetime
from pathlib import Path

import torch
from dotenv import load_dotenv
from loguru import logger
from vllm import SamplingParams

# Load environment variables from .env file if it exists
load_dotenv(override=True)

# Project paths
PROJ_ROOT = Path(__file__).resolve().parent.parent
DATA_DIR = PROJ_ROOT / "data"
LOG_FOLDER = PROJ_ROOT / "logs"

# Model configuration
# MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
device_id = (
    1 if os.environ.get("CUDA_VISIBLE_DEVICES") == "1" else torch.cuda.current_device()
)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

OUTPUT_DIR = (
    PROJ_ROOT
    / f"trainer_output_{MODEL_NAME.replace('/', '_')}_gpu{device_id}_{timestamp}"
)

# Model parameters
MODEL_CONFIG = {
    "max_seq_length": 4096 * 2,  # Can increase for longer reasoning traces
    "lora_rank": 64,  # Larger rank = smarter, but slower
    "gpu_memory_utilization": 0.6,  # Reduce if out of memory
    "model_name": MODEL_NAME,
    "target_modules": [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],  # Remove QKVO if out of memory
}

# Training parameters
TRAINING_CONFIG = {
    "learning_rate": 5e-6,
    "adam_beta1": 0.9,
    "adam_beta2": 0.99,
    "weight_decay": 0.1,
    "warmup_ratio": 0.1,
    "lr_scheduler_type": "cosine",
    "optim": "paged_adamw_8bit",
    "logging_steps": 1,
    "per_device_train_batch_size": 8,
    "gradient_accumulation_steps": 1,  # Increase to 4 for smoother training
    "num_generations": 8,  # Decrease if out of memory
    "max_prompt_length": 1024,
    "max_completion_length": 1024,
    "max_steps": 101,
    "save_steps": 50,
    "max_grad_norm": 0.1,
    "report_to": "tensorboard",
}


# Sampling parameters
def get_sampling_params(temperature: float = 0.1) -> SamplingParams:
    """Get sampling parameters for text generation"""
    return SamplingParams(
        temperature=temperature,
        top_p=0.95,
        max_tokens=4096,
    )
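

# A minimal usage sketch: the training script requests a low-temperature sampler
# for answer verification; the other temperature shown here is illustrative only.
#   verifier_params = get_sampling_params(temperature=0.1)  # near-deterministic verification
#   eval_params = get_sampling_params(temperature=0.5)      # more diverse generation for evals
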
# Initialize logging based on environment
def _init_logging(env: str = "development") -> None:
    """
    Initialize logging configuration with console logging
    and default file logging to ./logs directory.
    Additional file logging will be set up later in update_log_path().

    Args:
        env: The environment for logging ('development' or 'production')
    """
    # Create default log folder
    if not LOG_FOLDER.exists():
        LOG_FOLDER.mkdir(parents=True, exist_ok=True)

    # Remove any existing handlers
    logger.remove()

    # Define the logging format
    console_format = (
        "<green>{time:YYYY-MM-DD HH:mm:ss}</green> "
        "| <level>{level: <8}</level> "
        "| <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> "
        "- <level>{message}</level>"
    )

    file_format = (
        "{time:YYYY-MM-DD at HH:mm:ss} "
        "| {level} "
        "| {name}:{function}:{line} "
        "- {message}"
    )

    # Add console logging
    logger.add(
        sys.stderr,
        format=console_format,
        level="DEBUG" if env == "development" else "INFO",
        colorize=True,
        backtrace=True,
        diagnose=env == "development",
    )

    # Add default file logging to ./logs directory
    logger.add(
        LOG_FOLDER / "app.log",
        format=file_format,
        level="INFO",
        rotation="500 MB",
        retention="7 days",
        compression="zip",
        enqueue=True,  # Enables asynchronous logging
    )

    # Add custom level for requests
    logger.level("REQUEST", no=25, color="<yellow>", icon=" ")

    # Configure exception handling
    def exception_handler(exc_type, exc_value, exc_traceback):
        if issubclass(exc_type, KeyboardInterrupt):
            sys.__excepthook__(exc_type, exc_value, exc_traceback)
            return
        logger.opt(exception=(exc_type, exc_value, exc_traceback)).critical(
            "Unhandled exception"
        )

    sys.excepthook = exception_handler


# Update the log files to point to the training directory
def update_log_path(log_dir=None):
    """
    Add a log file in the training directory while keeping the default ./logs logging.
    Should be called after the training directory is created.

    Args:
        log_dir: Path to store additional log files (default: uses get_paths()["log_dir"])
    """
    # Use provided log_dir or get from training paths
    if log_dir is None:
        paths = get_paths(create_dirs=True)
        log_dir = paths["log_dir"]
    else:
        log_dir = Path(log_dir)
        log_dir.mkdir(exist_ok=True, parents=True)

    file_format = (
        "{time:YYYY-MM-DD at HH:mm:ss} "
        "| {level} "
        "| {name}:{function}:{line} "
        "- {message}"
    )

    # Add additional file handler pointing to training directory
    # No need to remove existing handlers as we want to keep those
    logger.add(
        log_dir / "app.log",
        format=file_format,
        level="INFO",
        rotation="500 MB",
        retention="7 days",
        compression="zip",
        enqueue=True,  # Enables asynchronous logging
    )

    logger.info(f"Additional logs will be stored in: {log_dir}")


# Paths configuration without creating directories
def get_paths(create_dirs: bool = False) -> dict:
    """
    Get common paths for the project

    Args:
        create_dirs: Whether to create the directories

    Returns:
        Dictionary with paths
    """
    output_dir = Path(OUTPUT_DIR)
    log_dir = output_dir / "logs"
    tensorboard_dir = output_dir / "runs"

    # Only create directories if explicitly requested
    if create_dirs:
        output_dir.mkdir(exist_ok=True)
        log_dir.mkdir(exist_ok=True)

        # Only create tensorboard directory if it's enabled in config
        if TRAINING_CONFIG.get("report_to") == "tensorboard":
            tensorboard_dir.mkdir(exist_ok=True)

    return {
        "output_dir": output_dir,
        "log_dir": log_dir,
        "tensorboard_dir": tensorboard_dir,
        "proj_root": PROJ_ROOT,
        "data_dir": DATA_DIR,
    }


# Create training directories
def init_training_dirs():
    """Initialize all directories needed for training"""
    paths = get_paths(create_dirs=True)

    # Also ensure our standard project directories exist
    for directory in [
        DATA_DIR,
        LOG_FOLDER,
    ]:
        directory.mkdir(exist_ok=True, parents=True)

    return paths


# For backward compatibility - will be deprecated
def setup_logger(module_name=None, create_dirs: bool = False):
    """
    Setup a logger for a specific module with consistent configuration.

    Note: This function is kept for backward compatibility.
    Use the global 'logger' instead for new code.

    Args:
        module_name: Optional name of module for module-specific log file
        create_dirs: Whether to create log directories

    Returns:
        Configured logger instance
    """
    logger.warning(
        "setup_logger is deprecated. Import logger directly from config instead."
    )
    return logger


# Tensorboard writer singleton
_tensorboard_writer = None


# Safe tensorboard logging function
def log_metric(key, value, step=0):
    """
    Log a metric safely to tensorboard if writer is available.

    Args:
        key: Metric name
        value: Metric value
        step: Training step
    """
    global _tensorboard_writer

    # Skip tensorboard logging if disabled in config
    if TRAINING_CONFIG.get("report_to") != "tensorboard":
        logger.debug(f"Tensorboard disabled. Metric: {key}={value} (step {step})")
        return

    # Get paths and initialize writer if needed
    paths = get_paths(create_dirs=False)
    if paths["tensorboard_dir"].exists():
        # Only create writer once
        if _tensorboard_writer is None:
            from torch.utils.tensorboard.writer import SummaryWriter

            _tensorboard_writer = SummaryWriter(paths["tensorboard_dir"])
            logger.debug(f"Created tensorboard writer at {paths['tensorboard_dir']}")

        # Add scalar using existing writer
        _tensorboard_writer.add_scalar(key, value, step)
        # No need to close the writer - it will be closed at process exit
    else:
        logger.debug(f"Tensorboard metric: {key}={value} (step {step})")
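

# A minimal usage sketch (metric name and values are illustrative): call this from a
# training or reward loop; it falls back to a debug log when tensorboard is disabled
# or the run directory does not exist yet.
#   log_metric("reward/correctness", 0.42, step=10)
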
# Initialize logging on module import
env = os.getenv("APP_ENV", "development")
_init_logging(env=env)

# Log project root on import
logger.info(f"Project root path: {PROJ_ROOT}")
logger.debug(f"Running in {env} environment")
@@ -1,192 +0,0 @@
# %%
import torch

# %%
from unsloth import FastLanguageModel, is_bfloat16_supported

max_seq_length = 4096 * 2  # Can increase for longer reasoning traces
lora_rank = 64  # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="meta-llama/meta-Llama-3.1-8B-Instruct",
    max_seq_length=max_seq_length,
    load_in_4bit=True,  # False for LoRA 16bit
    fast_inference=True,  # Enable vLLM fast inference
    max_lora_rank=lora_rank,
    gpu_memory_utilization=0.6,  # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,  # Choose any number > 0! Suggested 8, 16, 32, 64, 128
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],  # Remove QKVO if out of memory
    lora_alpha=lora_rank,
    use_gradient_checkpointing="unsloth",  # Enable long context finetuning
    random_state=3407,
)

# %%
import re

from datasets import Dataset, load_dataset

from rl_helpers import get_qa_dataset
from search_module import get_question_answer, get_question_count, search

train_dataset, test_dataset = get_qa_dataset()

# %% [markdown]
# <a name="Train"></a>
# ### Train the model
#
# Now set up GRPO Trainer and all configurations!

# %%
import os

os.environ["WANDB_PROJECT"] = "bootstrap-search-rl"

# %%
# from UnslothGRPOTrainerTemp import UnslothGRPOConfig, _UnslothGRPOTrainer
import UnslothGRPOTrainerTemp

training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(
    use_vllm=True,  # use vLLM for fast inference!
    use_agentic_generate=True,  # use agentic generation
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",
    logging_steps=1,
    bf16=is_bfloat16_supported(),
    fp16=not is_bfloat16_supported(),
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,  # Increase to 4 for smoother training
    num_generations=8,  # Decrease if out of memory
    max_prompt_length=1024,
    max_completion_length=1024,
    # num_train_epochs=1,  # Set to 1 for a full training run
    max_steps=101,
    save_steps=50,
    max_grad_norm=0.1,
    report_to="none",  # Can use Weights & Biases
    output_dir="full_local_training",
)

# %%
import rl_helpers

# importlib.reload(rl_helpers)


def agentic_generate(
    prompts: list[str],
    generate_fn,
    max_generations: int = 6,
):
    return run_agent(generate_fn, tokenizer, prompts, max_generations)


model.agentic_generate = agentic_generate


from vllm import SamplingParams

verifier_sampling_params = SamplingParams(
    temperature=0.1,
    top_p=0.95,
    max_tokens=4096,
)


def verifier_generate_fn(inputs):
    return model.fast_generate(
        inputs,
        sampling_params=verifier_sampling_params,
    )


run_agent = rl_helpers.run_agent
reward_correctness = rl_helpers.build_reward_correctness_fn(
    verifier_generate_fn,
    tokenizer,
)
reward_formatting = rl_helpers.reward_formatting

import UnslothGRPOTrainerTemp

trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        reward_correctness,
        reward_formatting,
    ],
    args=training_args,
    train_dataset=train_dataset,
)

# %%
trainer.train()

# %% [markdown]
# <a name="Inference"></a>
# ### Inference
# Now let's try to benchmark the model we trained!

# %%
from vllm import SamplingParams

import rl_helpers

sampling_params = SamplingParams(
    temperature=0.5,
    top_p=0.95,
    max_tokens=4096,
)


def eval_generate_fn(inputs):
    return model.fast_generate(
        inputs,
        sampling_params=sampling_params,
        lora_request=model.load_lora(
            "full_local_training/checkpoint-101"
        ),  # load the trained LoRA
    )


rl_helpers.run_eval(
    generate_fn=eval_generate_fn,
    verify_fn=reward_correctness,
    tokenizer=tokenizer,
)


# %%
# eval w/o lora
def eval_generate_fn(inputs):
    return model.fast_generate(
        inputs,
        sampling_params=sampling_params,
    )


rl_helpers.run_eval(
    generate_fn=eval_generate_fn,
    verify_fn=reward_correctness,
    tokenizer=tokenizer,
)
@@ -1,196 +0,0 @@
# %%
import torch

# %%
from unsloth import FastLanguageModel, is_bfloat16_supported

max_seq_length = 4096 * 2  # Can increase for longer reasoning traces
lora_rank = 64  # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="meta-llama/Llama-3.2-1B-Instruct",
    max_seq_length=max_seq_length,
    load_in_4bit=True,  # False for LoRA 16bit
    fast_inference=True,  # Enable vLLM fast inference
    max_lora_rank=lora_rank,
    gpu_memory_utilization=0.6,  # Reduce if out of memory
)


print(tokenizer.chat_template)  # See what chat format the model expects


model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,  # Choose any number > 0! Suggested 8, 16, 32, 64, 128
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],  # Remove QKVO if out of memory
    lora_alpha=lora_rank,
    use_gradient_checkpointing="unsloth",  # Enable long context finetuning
    random_state=3407,
)

# %%
import re

from datasets import Dataset, load_dataset

from rl_helpers import get_qa_dataset
from search_module import get_question_answer, get_question_count, search

train_dataset, test_dataset = get_qa_dataset()

# %% [markdown]
# <a name="Train"></a>
# ### Train the model
#
# Now set up GRPO Trainer and all configurations!

# %%
import os

os.environ["WANDB_PROJECT"] = "bootstrap-search-rl"

# %%
# from UnslothGRPOTrainerTemp import UnslothGRPOConfig, _UnslothGRPOTrainer
import UnslothGRPOTrainerTemp

training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(
    use_vllm=True,  # use vLLM for fast inference!
    use_agentic_generate=True,  # use agentic generation
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",
    logging_steps=1,
    bf16=is_bfloat16_supported(),
    fp16=not is_bfloat16_supported(),
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,  # Increase to 4 for smoother training
    num_generations=8,  # Decrease if out of memory
    max_prompt_length=1024,
    max_completion_length=1024,
    # num_train_epochs=1,  # Set to 1 for a full training run
    max_steps=101,
    save_steps=50,
    max_grad_norm=0.1,
    report_to="none",  # Can use Weights & Biases
    output_dir="full_local_training",
)

# %%
import rl_helpers

# importlib.reload(rl_helpers)


def agentic_generate(
    prompts: list[str],
    generate_fn,
    max_generations: int = 6,
):
    return run_agent(generate_fn, tokenizer, prompts, max_generations)


model.agentic_generate = agentic_generate


from vllm import SamplingParams

verifier_sampling_params = SamplingParams(
    temperature=0.1,
    top_p=0.95,
    max_tokens=4096,
)


def verifier_generate_fn(inputs):
    return model.fast_generate(
        inputs,
        sampling_params=verifier_sampling_params,
    )


run_agent = rl_helpers.run_agent
reward_correctness = rl_helpers.build_reward_correctness_fn(
    verifier_generate_fn,
    tokenizer,
)
reward_formatting = rl_helpers.reward_formatting

import UnslothGRPOTrainerTemp

trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        reward_correctness,
        reward_formatting,
    ],
    args=training_args,
    train_dataset=train_dataset,
)

# %%
trainer.train()

# %% [markdown]
# <a name="Inference"></a>
# ### Inference
# Now let's try to benchmark the model we trained!

# %%
from vllm import SamplingParams

import rl_helpers

sampling_params = SamplingParams(
    temperature=0.5,
    top_p=0.95,
    max_tokens=4096,
)


def eval_generate_fn(inputs):
    return model.fast_generate(
        inputs,
        sampling_params=sampling_params,
        lora_request=model.load_lora(
            "full_local_training/checkpoint-101"
        ),  # load the trained LoRA
    )


rl_helpers.run_eval(
    generate_fn=eval_generate_fn,
    verify_fn=reward_correctness,
    tokenizer=tokenizer,
)


# %%
# eval w/o lora
def eval_generate_fn(inputs):
    return model.fast_generate(
        inputs,
        sampling_params=sampling_params,
    )


rl_helpers.run_eval(
    generate_fn=eval_generate_fn,
    verify_fn=reward_correctness,
    tokenizer=tokenizer,
)
@@ -0,0 +1,124 @@
import os

from unsloth import FastLanguageModel, is_bfloat16_supported

import src.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp

# Import reward functions
from src.rl_helpers import (
    build_reward_correctness_fn,
    get_qa_dataset,
    reward_exact_match_chunk_query,
    reward_formatting,
    reward_retry_behavior,
    run_agent,
)
from src.config import (
    MODEL_CONFIG,
    MODEL_NAME,
    OUTPUT_DIR,
    TRAINING_CONFIG,
    get_sampling_params,
    init_training_dirs,
    logger,
    update_log_path,
)

# Initialize training directories
paths = init_training_dirs()

# Update logger to use the training directory
update_log_path(paths["log_dir"])
logger.info(f"Training output directory: {paths['output_dir']}")
logger.info(f"Logs are being saved to both ./logs and {paths['log_dir']}")

# Initialize model and tokenizer
logger.info(f"Initializing model {MODEL_NAME}")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_NAME,
    max_seq_length=MODEL_CONFIG["max_seq_length"],
    load_in_4bit=True,  # False for LoRA 16bit
    fast_inference=True,  # Enable vLLM fast inference
    max_lora_rank=MODEL_CONFIG["lora_rank"],
    gpu_memory_utilization=MODEL_CONFIG["gpu_memory_utilization"],
)

# Setup LoRA
logger.info("Setting up LoRA adapter")
model = FastLanguageModel.get_peft_model(
    model,
    r=MODEL_CONFIG["lora_rank"],
    target_modules=MODEL_CONFIG["target_modules"],
    lora_alpha=MODEL_CONFIG["lora_rank"],
    use_gradient_checkpointing=True,  # Enable long context finetuning
    random_state=3407,
)

# Load datasets
logger.info("Loading datasets")
train_dataset, test_dataset = get_qa_dataset()
logger.info(
    f"Loaded {len(train_dataset)} training examples and {len(test_dataset)} test examples"
)

# Setup training arguments
logger.info("Setting up training arguments")
training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(
    use_vllm=True,  # use vLLM for fast inference!
    use_agentic_generate=True,  # use agentic generation
    **TRAINING_CONFIG,
    bf16=is_bfloat16_supported(),
    fp16=not is_bfloat16_supported(),
    output_dir=OUTPUT_DIR,
    # report_to="tensorboard",  # ❓ Doesn't this end up with billions of tensorboard files if report_to is set right here?
)


# Setup model generation functions
def agentic_generate(
    prompts: list,
    generate_fn,
    max_generations: int = 10,
):
    return run_agent(generate_fn, tokenizer, prompts, max_generations)


model.agentic_generate = agentic_generate

# Setup verifier
logger.info("Setting up verifier")
verifier_sampling_params = get_sampling_params(temperature=0.1)


def verifier_generate_fn(inputs):
    return model.fast_generate(
        inputs,
        sampling_params=verifier_sampling_params,
    )


# Setup trainer
logger.info("Initializing trainer")
trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        build_reward_correctness_fn(
            verifier_generate_fn,
            tokenizer,
            log_file=os.path.join(paths["log_dir"], "qa_log.txt"),
        ),
        reward_formatting,
        reward_retry_behavior,
        reward_exact_match_chunk_query,
    ],
    args=training_args,
    train_dataset=train_dataset,
)

# Train the model
if __name__ == "__main__":
    logger.info("Starting training")
    trainer.train()
    logger.info("Training completed")
    logger.info(f"Model saved to {OUTPUT_DIR}")