feat: update model configuration (longer context) and dataset loading logic for improved performance and flexibility

main
thinhlpg 4 weeks ago
parent 4a1d45271d
commit 2df9f39fda

@@ -28,7 +28,9 @@ GENERATOR_SERVER_PORT = 8002
# Model configuration
# MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
# MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
# MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
# MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
# MODEL_NAME = "unsloth/Qwen2-1.5B" # Smoke test first
device_id = 1 if os.environ.get("CUDA_VISIBLE_DEVICES") == "1" else torch.cuda.current_device()
@@ -38,7 +40,7 @@ OUTPUT_DIR = PROJ_ROOT / f"trainer_output_{MODEL_NAME.replace('/', '_')}_gpu{dev
# Model parameters
MODEL_CONFIG = {
"max_seq_length": 4096 * 2, # Can increase for longer reasoning traces
"max_seq_length": 4096 * 6, # 24k tokens -> just try to utiliiz
"lora_rank": 64, # Larger rank = smarter, but slower
"gpu_memory_utilization": 0.6, # Reduce if out of memory
"model_name": MODEL_NAME,
@@ -66,9 +68,9 @@ TRAINING_CONFIG = {
"per_device_train_batch_size": 8,
"gradient_accumulation_steps": 1, # Increase to 4 for smoother training
"num_generations": 6, # Decrease if out of memory
"max_prompt_length": 1024,
"max_completion_length": 1024,
"max_steps": 101,
"max_prompt_length": 4096 * 4 - 2048,
"max_completion_length": 2048,
"max_steps": 1000,
"save_steps": 50,
"max_grad_norm": 0.1,
"report_to": "tensorboard",
@@ -81,7 +83,7 @@ def get_sampling_params(temperature: float = 0.1) -> SamplingParams:
return SamplingParams(
temperature=temperature,
top_p=0.95,
max_tokens=4096,
max_tokens=4096 * 6,
)
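A usage sketch for the enlarged sampling budget; the generate call shown here (Unsloth's vLLM-backed fast_generate) is an assumption and may differ from the actual call site:

# Sketch: sample with the new max_tokens budget to allow long reasoning traces
sampling_params = get_sampling_params(temperature=0.7)
outputs = model.fast_generate(["What is GRPO?"], sampling_params=sampling_params)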

@@ -571,7 +571,7 @@ class UnslothGRPOConfig(GRPOConfig):
include_inputs_for_metrics=False,
eval_do_concat_batches=True,
fp16_backend="auto",
evaluation_strategy=None,
# evaluation_strategy=None,
push_to_hub_model_id=None,
push_to_hub_organization=None,
push_to_hub_token=None,
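A likely reason for commenting this out (an assumption): newer transformers releases renamed evaluation_strategy to eval_strategy, so passing the old keyword can fail; the modern equivalent of the same setting would be:

# eval_strategy="no",  # current transformers name for the same setting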
@@ -744,7 +744,7 @@ class UnslothGRPOConfig(GRPOConfig):
include_inputs_for_metrics=include_inputs_for_metrics,
eval_do_concat_batches=eval_do_concat_batches,
fp16_backend=fp16_backend,
evaluation_strategy=evaluation_strategy,
# evaluation_strategy=evaluation_strategy,
push_to_hub_model_id=push_to_hub_model_id,
push_to_hub_organization=push_to_hub_organization,
push_to_hub_token=push_to_hub_token,
@@ -1337,6 +1337,8 @@ class _UnslothGRPOTrainer(Trainer):
self._metrics["reward"].append(rewards.mean().item())
self._metrics["reward_std"].append(std_grouped_rewards.mean().item())
self._metrics["advantages_mean"].append(advantages.mean().item())
self._metrics["advantages_std"].append(advantages.std().item())
if (
self.log_completions
@@ -1357,6 +1359,10 @@ class _UnslothGRPOTrainer(Trainer):
if wandb.run is not None and self.accelerator.is_main_process:
wandb.log({"completions": wandb.Table(dataframe=df)})
# Log prompt length
prompt_length = prompt_mask.sum(dim=1).float().mean().item()
self._metrics["prompt_length"].append(prompt_length)
return {
"prompt_ids": prompt_ids,
"prompt_mask": prompt_mask,

@@ -11,6 +11,7 @@ from trl.trainer.grpo_trainer import apply_chat_template
from config import logger
from src.prompts import build_user_prompt, get_system_prompt
# TODO: refactor this, it's terrible
from src.search_module import search
from src.tokenizer_adapter import TokenizerAdapter
@@ -119,7 +120,7 @@ class Agent:
search_query = extract_search_query(assistant_response)
if search_query:
logger.info(f"🔍 Search Query: {search_query}")
results = self.search_fn(search_query, return_type=str, results=2)
results = self.search_fn(search_query)
formatted_results = f"<information>{results}</information>"
logger.info(f" Information: {formatted_results}")

@@ -12,6 +12,8 @@ from langchain_community.vectorstores import FAISS
from config import DATA_DIR, logger
from src.embeddings import CustomHuggingFaceEmbeddings
PROCESSED_DATA_DIR = DATA_DIR / "processed"
# Load pre-saved vectorstore
def load_vectorstore():
@@ -19,8 +21,8 @@ def load_vectorstore():
try:
embeddings = CustomHuggingFaceEmbeddings()
# Load the FAISS index from the data directory
logger.info(f"Loading FAISS index from: {DATA_DIR}")
vectorstore = FAISS.load_local(str(DATA_DIR), embeddings, allow_dangerous_deserialization=True)
logger.info(f"Loading FAISS index from: {PROCESSED_DATA_DIR}")
vectorstore = FAISS.load_local(str(PROCESSED_DATA_DIR), embeddings, allow_dangerous_deserialization=True)
logger.info("Successfully loaded FAISS index")
return vectorstore
except Exception as e:
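A usage sketch for the relocated index, assuming it was saved under DATA_DIR / "processed" with FAISS.save_local and that the except branch above returns None:

# Sketch: query the relocated FAISS index
vectorstore = load_vectorstore()
if vectorstore is not None:
    docs = vectorstore.similarity_search("What is the training objective?", k=2)
    for doc in docs:
        print(doc.page_content[:200])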
@@ -75,12 +77,12 @@ def search(query: str, return_type=str, results: int = 5):
def load_qa_data():
"""Load the pre-generated questions"""
try:
questions_path = DATA_DIR / "questions.json"
questions_path = PROCESSED_DATA_DIR / "questions.jsonl"
logger.info(f"Loading questions from: {questions_path}")
# Load the questions
with open(questions_path, "r") as f:
questions = json.load(f)
questions = [json.loads(line) for line in f]
logger.info(f"Successfully loaded {len(questions)} questions")
return questions
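Each line of questions.jsonl is expected to hold one standalone JSON object; a read sketch with a hypothetical record, field names taken from the get_qa_dataset docstring below and the path assumed:

import json

# Hypothetical line: {"question": "...", "answer": "...", "supporting_paragraphs": ["..."]}
with open("data/processed/questions.jsonl") as f:
    first = json.loads(next(f))
    print(first["question"], "->", first["answer"])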
@@ -142,7 +144,7 @@ def get_question_count() -> int:
return len(questions)
def get_qa_dataset(randomize: bool = False) -> tuple:
def get_qa_dataset(randomize: bool = False, test_size: float = 0.1, seed: int = 42) -> tuple:
"""
Return a HuggingFace Dataset containing question and answer pairs.
@@ -150,19 +152,44 @@ def get_qa_dataset(randomize: bool = False) -> tuple:
Each element in the dataset is a dictionary that includes at least:
- "question": The question text.
- "answer": The corresponding answer text.
- "supporting_paragraphs": The supporting paragraphs for the question.
Additional keys present in the original questions data will also be included.
Args:
randomize: Whether to shuffle the dataset
test_size: Proportion of the dataset to include in the test split (0 for train-only)
seed: Random seed for reproducibility
Returns:
A HuggingFace Dataset object.
A tuple of (train_dataset, test_dataset) HuggingFace Dataset objects.
If test_size=0, test_dataset will be empty. If test_size=1, train_dataset will be empty.
"""
if questions is None:
raise ValueError("Questions not loaded. Please ensure questions.json exists.")
qa_dataset = Dataset.from_list(questions)
if randomize:
qa_dataset = qa_dataset.shuffle(seed=42)
train_dataset = qa_dataset.train_test_split(test_size=0.1, seed=42)["train"]
test_dataset = qa_dataset.train_test_split(test_size=0.1, seed=42)["test"]
qa_dataset = qa_dataset.shuffle(seed=seed)
# Create empty dataset for when train or test size is 0
empty_dataset = Dataset.from_list([])
if test_size <= 0:
# Only train dataset, empty test dataset
train_dataset = qa_dataset
train_dataset = train_dataset.rename_column("question", "prompt")
return train_dataset, empty_dataset
elif test_size >= 1:
# Only test dataset, empty train dataset
test_dataset = qa_dataset
test_dataset = test_dataset.rename_column("question", "prompt")
return empty_dataset, test_dataset
else:
# Both train and test datasets
split = qa_dataset.train_test_split(test_size=test_size, seed=seed)
train_dataset = split["train"]
test_dataset = split["test"]
# rename the column of the dataset from "question" to "prompt"
train_dataset = train_dataset.rename_column("question", "prompt")
test_dataset = test_dataset.rename_column("question", "prompt")
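Usage sketch covering the three branches; the train-only form matches the call made in train.py later in this diff:

# All data for training, empty test split (as used in train.py)
train_ds, test_ds = get_qa_dataset(randomize=True, test_size=0, seed=42)

# Reproducible 90/10 split
train_ds, test_ds = get_qa_dataset(randomize=True, test_size=0.1, seed=42)

# Evaluation-only dataset
_, eval_ds = get_qa_dataset(test_size=1)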

@@ -7,10 +7,6 @@ import os
from unsloth import FastLanguageModel, is_bfloat16_supported
import src.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp
# Import reward functions
from src import build_reward_correctness_fn, get_qa_dataset, reward_em_chunk, reward_format, reward_retry
from src.agent import Agent
from config import (
MODEL_CONFIG,
MODEL_NAME,
@@ -21,6 +17,10 @@ from config import (
logger,
update_log_path,
)
# Import reward functions
from src import build_reward_correctness_fn, get_qa_dataset, reward_em_chunk, reward_format, reward_retry
from src.agent import Agent
from src.rewards import (
build_reward_correctness_fn,
reward_em_chunk,
@@ -64,7 +64,7 @@ model = FastLanguageModel.get_peft_model(
# Load datasets
logger.info("Loading datasets")
train_dataset, test_dataset = get_qa_dataset()
train_dataset, test_dataset = get_qa_dataset(randomize=True, test_size=0, seed=42)
logger.info(f"Loaded {len(train_dataset)} training examples and {len(test_dataset)} test examples")
# Setup training arguments
@@ -76,8 +76,7 @@ training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(
bf16=is_bfloat16_supported(),
fp16=not is_bfloat16_supported(),
output_dir=OUTPUT_DIR,
reward_weights=[4.0, 2.0, 1.0, 1.0, 1.0, 1.0],
# report_to="tensorboard", # ❓ Does't have billions of tensorboard files if set report to right here
reward_weights=[2.0, 1.0, 1.0, 1.0],
)
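The weight list shrinks from six entries to four to match the four reward functions left active in the trainer below; a sketch of the assumed pairing (order taken from the reward_funcs list, correctness function name assumed):

reward_weights = [2.0, 1.0, 1.0, 1.0]   # one weight per active reward function
reward_funcs = [reward_correctness, reward_format, reward_retry, reward_em_chunk]
assert len(reward_weights) == len(reward_funcs)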
@@ -85,7 +84,7 @@ training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(
def agentic_generate(
prompts: list,
generate_fn,
max_generations: int = 20,
max_generations: int = 32,
):
# Create agent with appropriate adapter based on tokenizer
tokenizer_name = tokenizer.name_or_path.lower()
@@ -129,8 +128,8 @@ trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(
reward_format,
reward_retry,
reward_em_chunk,
reward_search_strategy,
reward_search_diversity,
# reward_search_strategy,
# reward_search_diversity,
],
args=training_args,
train_dataset=train_dataset,
@@ -142,13 +141,3 @@ if __name__ == "__main__":
trainer.train()
logger.info("Training completed")
logger.info(f"Model saved to {OUTPUT_DIR}")
# Save model to FP16 format
logger.info("Saving model to FP16 format")
model_merged_dir = os.path.join(OUTPUT_DIR, "model_merged_16bit")
model.save_pretrained_merged(
model_merged_dir,
tokenizer,
save_method="merged_16bit",
)
logger.info(f"FP16 model saved to {model_merged_dir}")
