diff --git a/config.py b/config.py
index 1ff243f..20430b9 100644
--- a/config.py
+++ b/config.py
@@ -28,7 +28,9 @@ GENERATOR_SERVER_PORT = 8002
 
 # Model configuration
 # MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
-MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
+# MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
+MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
+# MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
 # MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
 # MODEL_NAME = "unsloth/Qwen2-1.5B"  # Smoke test first
 device_id = 1 if os.environ.get("CUDA_VISIBLE_DEVICES") == "1" else torch.cuda.current_device()
@@ -38,7 +40,7 @@ OUTPUT_DIR = PROJ_ROOT / f"trainer_output_{MODEL_NAME.replace('/', '_')}_gpu{dev
 
 # Model parameters
 MODEL_CONFIG = {
-    "max_seq_length": 4096 * 2,  # Can increase for longer reasoning traces
+    "max_seq_length": 4096 * 6,  # 24k tokens -> just try to utilize the longer context
     "lora_rank": 64,  # Larger rank = smarter, but slower
     "gpu_memory_utilization": 0.6,  # Reduce if out of memory
     "model_name": MODEL_NAME,
@@ -66,9 +68,9 @@ TRAINING_CONFIG = {
     "per_device_train_batch_size": 8,
     "gradient_accumulation_steps": 1,  # Increase to 4 for smoother training
     "num_generations": 6,  # Decrease if out of memory
-    "max_prompt_length": 1024,
-    "max_completion_length": 1024,
-    "max_steps": 101,
+    "max_prompt_length": 4096 * 4 - 2048,
+    "max_completion_length": 2048,
+    "max_steps": 1000,
     "save_steps": 50,
     "max_grad_norm": 0.1,
     "report_to": "tensorboard",
@@ -81,7 +83,7 @@ def get_sampling_params(temperature: float = 0.1) -> SamplingParams:
     return SamplingParams(
         temperature=temperature,
         top_p=0.95,
-        max_tokens=4096,
+        max_tokens=4096 * 6,
     )
 
 
diff --git a/src/UnslothGRPOTrainerTemp.py b/src/UnslothGRPOTrainerTemp.py
index 731e174..62c3314 100644
--- a/src/UnslothGRPOTrainerTemp.py
+++ b/src/UnslothGRPOTrainerTemp.py
@@ -571,7 +571,7 @@ class UnslothGRPOConfig(GRPOConfig):
         include_inputs_for_metrics=False,
         eval_do_concat_batches=True,
         fp16_backend="auto",
-        evaluation_strategy=None,
+        # evaluation_strategy=None,
         push_to_hub_model_id=None,
         push_to_hub_organization=None,
         push_to_hub_token=None,
@@ -744,7 +744,7 @@ class UnslothGRPOConfig(GRPOConfig):
            include_inputs_for_metrics=include_inputs_for_metrics,
            eval_do_concat_batches=eval_do_concat_batches,
            fp16_backend=fp16_backend,
-            evaluation_strategy=evaluation_strategy,
+            # evaluation_strategy=evaluation_strategy,
            push_to_hub_model_id=push_to_hub_model_id,
            push_to_hub_organization=push_to_hub_organization,
            push_to_hub_token=push_to_hub_token,
@@ -1337,6 +1337,8 @@ class _UnslothGRPOTrainer(Trainer):
 
         self._metrics["reward"].append(rewards.mean().item())
         self._metrics["reward_std"].append(std_grouped_rewards.mean().item())
+        self._metrics["advantages_mean"].append(advantages.mean().item())
+        self._metrics["advantages_std"].append(advantages.std().item())
 
         if (
             self.log_completions
@@ -1357,6 +1359,10 @@ class _UnslothGRPOTrainer(Trainer):
             if wandb.run is not None and self.accelerator.is_main_process:
                 wandb.log({"completions": wandb.Table(dataframe=df)})
 
+        # Log prompt length
+        prompt_length = prompt_mask.sum(dim=1).float().mean().item()
+        self._metrics["prompt_length"].append(prompt_length)
+
         return {
             "prompt_ids": prompt_ids,
             "prompt_mask": prompt_mask,
diff --git a/src/agent.py b/src/agent.py
index d3eb87c..7bca1b9 100644
--- a/src/agent.py
+++ b/src/agent.py
@@ -11,6 +11,7 @@ from trl.trainer.grpo_trainer import apply_chat_template
 
 from config import logger
 from src.prompts import build_user_prompt, get_system_prompt
+# TODO: refactor this, it's terrible
 from src.search_module import search
 from src.tokenizer_adapter import TokenizerAdapter
 
@@ -119,7 +120,7 @@ class Agent:
             search_query = extract_search_query(assistant_response)
             if search_query:
                 logger.info(f"🔍 Search Query: {search_query}")
-                results = self.search_fn(search_query, return_type=str, results=2)
+                results = self.search_fn(search_query)
                 formatted_results = f"{results}"
                 logger.info(f"ℹ️ Information: {formatted_results}")
 
diff --git a/src/search_module.py b/src/search_module.py
index d196543..5182276 100644
--- a/src/search_module.py
+++ b/src/search_module.py
@@ -12,6 +12,8 @@ from langchain_community.vectorstores import FAISS
 from config import DATA_DIR, logger
 from src.embeddings import CustomHuggingFaceEmbeddings
 
+PROCESSED_DATA_DIR = DATA_DIR / "processed"
+
 
 # Load pre-saved vectorstore
 def load_vectorstore():
@@ -19,8 +21,8 @@
     try:
         embeddings = CustomHuggingFaceEmbeddings()
         # Load the FAISS index from the data directory
-        logger.info(f"Loading FAISS index from: {DATA_DIR}")
-        vectorstore = FAISS.load_local(str(DATA_DIR), embeddings, allow_dangerous_deserialization=True)
+        logger.info(f"Loading FAISS index from: {PROCESSED_DATA_DIR}")
+        vectorstore = FAISS.load_local(str(PROCESSED_DATA_DIR), embeddings, allow_dangerous_deserialization=True)
         logger.info("Successfully loaded FAISS index")
         return vectorstore
     except Exception as e:
@@ -75,12 +77,12 @@ def search(query: str, return_type=str, results: int = 5):
 
 
 def load_qa_data():
     """Load the pre-generated questions"""
     try:
-        questions_path = DATA_DIR / "questions.json"
+        questions_path = PROCESSED_DATA_DIR / "questions.jsonl"
         logger.info(f"Loading questions from: {questions_path}")
         # Load the questions
         with open(questions_path, "r") as f:
-            questions = json.load(f)
+            questions = [json.loads(line) for line in f]
         logger.info(f"Successfully loaded {len(questions)} questions")
         return questions
@@ -142,7 +144,7 @@ def get_question_count() -> int:
     return len(questions)
 
 
-def get_qa_dataset(randomize: bool = False) -> tuple:
+def get_qa_dataset(randomize: bool = False, test_size: float = 0.1, seed: int = 42) -> tuple:
     """
     Return a HuggingFace Dataset containing question and answer pairs.
 
@@ -150,20 +152,45 @@
     Each element in the dataset is a dictionary that includes at least:
         - "question": The question text.
         - "answer": The corresponding answer text.
+        - "supporting_paragraphs": The supporting paragraphs for the question.
     Additional keys present in the original questions data will also be included.
 
+    Args:
+        randomize: Whether to shuffle the dataset
+        test_size: Proportion of the dataset to include in the test split (0 for train-only)
+        seed: Random seed for reproducibility
+
     Returns:
-        A HuggingFace Dataset object.
+        A tuple of (train_dataset, test_dataset) HuggingFace Dataset objects.
+        If test_size=0, test_dataset will be empty. If test_size=1, train_dataset will be empty.
     """
     if questions is None:
         raise ValueError("Questions not loaded. Please ensure questions.json exists.")
 
     qa_dataset = Dataset.from_list(questions)
     if randomize:
-        qa_dataset = qa_dataset.shuffle(seed=42)
-    train_dataset = qa_dataset.train_test_split(test_size=0.1, seed=42)["train"]
-    test_dataset = qa_dataset.train_test_split(test_size=0.1, seed=42)["test"]
-    # rename the column of the dataset from "question" to "input"
-    train_dataset = train_dataset.rename_column("question", "prompt")
-    test_dataset = test_dataset.rename_column("question", "prompt")
-    return train_dataset, test_dataset
+        qa_dataset = qa_dataset.shuffle(seed=seed)
+
+    # Create empty dataset for when train or test size is 0
+    empty_dataset = Dataset.from_list([])
+
+    if test_size <= 0:
+        # Only train dataset, empty test dataset
+        train_dataset = qa_dataset
+        train_dataset = train_dataset.rename_column("question", "prompt")
+        return train_dataset, empty_dataset
+    elif test_size >= 1:
+        # Only test dataset, empty train dataset
+        test_dataset = qa_dataset
+        test_dataset = test_dataset.rename_column("question", "prompt")
+        return empty_dataset, test_dataset
+    else:
+        # Both train and test datasets
+        split = qa_dataset.train_test_split(test_size=test_size, seed=seed)
+        train_dataset = split["train"]
+        test_dataset = split["test"]
+
+        # rename the column of the dataset from "question" to "input"
+        train_dataset = train_dataset.rename_column("question", "prompt")
+        test_dataset = test_dataset.rename_column("question", "prompt")
+        return train_dataset, test_dataset
diff --git a/train_grpo.py b/train_grpo.py
index 9f3dd58..1460638 100644
--- a/train_grpo.py
+++ b/train_grpo.py
@@ -7,10 +7,6 @@ import os
 
 from unsloth import FastLanguageModel, is_bfloat16_supported
 import src.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp
-
-# Import reward functions
-from src import build_reward_correctness_fn, get_qa_dataset, reward_em_chunk, reward_format, reward_retry
-from src.agent import Agent
 from config import (
     MODEL_CONFIG,
     MODEL_NAME,
@@ -21,6 +17,10 @@ from config import (
     logger,
     update_log_path,
 )
+
+# Import reward functions
+from src import build_reward_correctness_fn, get_qa_dataset, reward_em_chunk, reward_format, reward_retry
+from src.agent import Agent
 from src.rewards import (
     build_reward_correctness_fn,
     reward_em_chunk,
@@ -64,7 +64,7 @@ model = FastLanguageModel.get_peft_model(
 
 # Load datasets
 logger.info("Loading datasets")
-train_dataset, test_dataset = get_qa_dataset()
+train_dataset, test_dataset = get_qa_dataset(randomize=True, test_size=0, seed=42)
 logger.info(f"Loaded {len(train_dataset)} training examples and {len(test_dataset)} test examples")
 
 # Setup training arguments
@@ -76,8 +76,7 @@ training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(
     bf16=is_bfloat16_supported(),
     fp16=not is_bfloat16_supported(),
     output_dir=OUTPUT_DIR,
-    reward_weights=[4.0, 2.0, 1.0, 1.0, 1.0, 1.0],
-    # report_to="tensorboard",  # ❓ Does't have billions of tensorboard files if set report to right here
+    reward_weights=[2.0, 1.0, 1.0, 1.0],
 )
 
 
@@ -85,7 +84,7 @@ training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(
 def agentic_generate(
     prompts: list,
     generate_fn,
-    max_generations: int = 20,
+    max_generations: int = 32,
 ):
     # Create agent with appropriate adapter based on tokenizer
     tokenizer_name = tokenizer.name_or_path.lower()
@@ -129,8 +128,8 @@ trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(
         reward_format,
         reward_retry,
         reward_em_chunk,
-        reward_search_strategy,
-        reward_search_diversity,
+        # reward_search_strategy,
+        # reward_search_diversity,
     ],
     args=training_args,
     train_dataset=train_dataset,
@@ -142,13 +141,3 @@ if __name__ == "__main__":
     trainer.train()
     logger.info("Training completed")
     logger.info(f"Model saved to {OUTPUT_DIR}")
-
-    # Save model to FP16 format
-    logger.info("Saving model to FP16 format")
-    model_merged_dir = os.path.join(OUTPUT_DIR, "model_merged_16bit")
-    model.save_pretrained_merged(
-        model_merged_dir,
-        tokenizer,
-        save_method="merged_16bit",
-    )
-    logger.info(f"FP16 model saved to {model_merged_dir}")
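
For reference, a minimal usage sketch (not part of the patch) of the reworked get_qa_dataset split behavior introduced above. It assumes the repository is on the Python path and that the processed questions.jsonl and FAISS index exist under DATA_DIR / "processed", since src.search_module loads both at import time.

# Hypothetical usage sketch of the new get_qa_dataset(randomize, test_size, seed) signature.
from src.search_module import get_qa_dataset

# Train-only split, as train_grpo.py now requests: the test dataset comes back empty,
# and the "question" column is renamed to "prompt" in the returned dataset(s).
train_dataset, test_dataset = get_qa_dataset(randomize=True, test_size=0, seed=42)
assert len(test_dataset) == 0
print(train_dataset[0]["prompt"])

# A fractional test_size falls through to Dataset.train_test_split for a real split.
train_dataset, test_dataset = get_qa_dataset(randomize=True, test_size=0.1, seed=42)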