From 04593fa8fdcbd8b74034ac1cbec1b8e1dd196a84 Mon Sep 17 00:00:00 2001 From: thinhlpg Date: Tue, 1 Apr 2025 04:08:31 +0700 Subject: [PATCH] style: change line length to 119, organize imports --- Makefile | 2 +- eval.py | 80 +++------ inference.py | 2 +- src/UnslothGRPOTrainerTemp.py | 299 ++++++++-------------------------- src/config.py | 31 +--- src/rl_helpers.py | 151 +++++------------ src/search_module.py | 8 +- train_grpo.py | 24 ++- 8 files changed, 149 insertions(+), 448 deletions(-) diff --git a/Makefile b/Makefile index cfa32a7..181eaa0 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ # make sure to test the local checkout in scripts and not the pre-installed one export PYTHONPATH = src -check_dirs := src +check_dirs := . src notebooks scripts # Development dependencies install: diff --git a/eval.py b/eval.py index d017868..36b7679 100644 --- a/eval.py +++ b/eval.py @@ -170,9 +170,7 @@ def test_lora_functionality(model, tokenizer, lora_path): # Sample with base model logger.info("Generating with base model...") - sampling_params = get_sampling_params( - temperature=0.7 - ) # Higher temp to make differences more obvious + sampling_params = get_sampling_params(temperature=0.7) # Higher temp to make differences more obvious base_response = model.fast_generate( [formatted_prompt], sampling_params=sampling_params, @@ -217,12 +215,8 @@ def test_lora_functionality(model, tokenizer, lora_path): logger.info("-" * 40) if are_identical: - logger.warning( - "\nWARNING: LoRA adapter does not seem to change the model's output" - ) - logger.warning( - "This could indicate that the LoRA adapter is not being properly applied" - ) + logger.warning("\nWARNING: LoRA adapter does not seem to change the model's output") + logger.warning("This could indicate that the LoRA adapter is not being properly applied") else: logger.info("\nLoRA adapter is working as expected (outputs are different)") @@ -257,9 +251,7 @@ def evaluate_model( # Prioritize the directory passed from the shell script if available if trainer_dir and os.path.isdir(trainer_dir): trainer_output_dir = os.path.abspath(trainer_dir) - logger.info( - f"Using trainer directory passed from arguments: {trainer_output_dir}" - ) + logger.info(f"Using trainer directory passed from arguments: {trainer_output_dir}") else: logger.warning( f"Trainer directory not provided or invalid: {trainer_dir}. Attempting to determine automatically." @@ -274,9 +266,7 @@ def evaluate_model( # If a LoRA path exists (provided or found), get its parent's parent checkpoint_dir = os.path.dirname(os.path.abspath(temp_lora_path)) trainer_output_dir = os.path.dirname(checkpoint_dir) - logger.info( - f"Determined trainer directory from LoRA path ({temp_lora_path}): {trainer_output_dir}" - ) + logger.info(f"Determined trainer directory from LoRA path ({temp_lora_path}): {trainer_output_dir}") else: # If no LoRA path, default to current directory (should ideally not happen if called from eval.sh) trainer_output_dir = os.path.abspath(".") @@ -290,22 +280,16 @@ def evaluate_model( detected_checkpoint = find_latest_checkpoint(search_dir=trainer_output_dir) if detected_checkpoint: lora_path = detected_checkpoint - logger.info( - f"Auto-detected latest checkpoint in {trainer_output_dir}: {lora_path}" - ) + logger.info(f"Auto-detected latest checkpoint in {trainer_output_dir}: {lora_path}") else: - logger.warning( - f"No checkpoint found in {trainer_output_dir} for auto-detection. Evaluating base model." 
- ) + logger.warning(f"No checkpoint found in {trainer_output_dir} for auto-detection. Evaluating base model.") lora_path = None model_type = "LoRA" if lora_path else "Base" logger.info(f"\n{'=' * 50}") logger.info(f"Starting evaluation of {model_type} model") - logger.info( - f"Trainer Output Directory: {trainer_output_dir}" - ) # Log the final directory + logger.info(f"Trainer Output Directory: {trainer_output_dir}") # Log the final directory logger.info(f"{'=' * 50}") # --- Create eval_logs directory --- @@ -319,18 +303,14 @@ def evaluate_model( # Fallback to current directory if creation fails eval_log_dir = os.path.abspath("./eval_logs") os.makedirs(eval_log_dir, exist_ok=True) - logger.warning( - f"Fell back to creating eval_logs in current directory: {eval_log_dir}" - ) + logger.warning(f"Fell back to creating eval_logs in current directory: {eval_log_dir}") # Create file names based on model type model_prefix = "lora" if lora_path else "base" timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") # Define all output file paths - eval_log_file = os.path.join( - eval_log_dir, f"{model_prefix}_model_eval_{timestamp}.log" - ) + eval_log_file = os.path.join(eval_log_dir, f"{model_prefix}_model_eval_{timestamp}.log") output_file = os.path.join(eval_log_dir, f"{model_prefix}_model_results.txt") debug_file = os.path.join(eval_log_dir, f"{model_prefix}_model_results_debug.json") @@ -343,12 +323,8 @@ def evaluate_model( if lora_path: lora_request = model.load_lora(lora_path) load_time = time.time() - start_time - logger.info( - f"LoRA adapter loaded in {load_time:.2f} seconds: {lora_request}" - ) - responses = model.fast_generate( - inputs, sampling_params=sampling_params, lora_request=lora_request - ) + logger.info(f"LoRA adapter loaded in {load_time:.2f} seconds: {lora_request}") + responses = model.fast_generate(inputs, sampling_params=sampling_params, lora_request=lora_request) else: # For base model, add additional logging logger.info("Generating with base model (no LoRA)") @@ -373,9 +349,7 @@ def evaluate_model( return model.fast_generate(inputs, sampling_params=verifier_params) # Prepare the verification function - verify_fn = rl_helpers.build_reward_correctness_fn( - verifier_generate_fn, tokenizer, log_file=eval_log_file - ) + verify_fn = rl_helpers.build_reward_correctness_fn(verifier_generate_fn, tokenizer, log_file=eval_log_file) # Get the dataset and prepare questions and answers train_dataset, test_dataset = rl_helpers.get_qa_dataset() @@ -438,9 +412,7 @@ def evaluate_model( f.write(f"Results saved to: {output_file}\n") f.write(f"Debug data saved to: {debug_file}\n\n") - logger.info( - f"Evaluation completed. Results saved to {output_file} and {debug_file}" - ) + logger.info(f"Evaluation completed. Results saved to {output_file} and {debug_file}") return results @@ -463,9 +435,7 @@ def compare_models(lora_path, temperature=0.5, output_file=None, trainer_dir=Non lora_path = detected_checkpoint logger.info(f"Auto-detected latest checkpoint: {lora_path}") else: - logger.warning( - "No checkpoint found for auto-detection. Skipping comparison." - ) + logger.warning("No checkpoint found for auto-detection. 
Skipping comparison.") return # Set up output directory in the checkpoint directory @@ -560,9 +530,7 @@ if __name__ == "__main__": default="auto", help="Path to LoRA weights (use 'auto' for auto-detection)", ) - parser.add_argument( - "--temperature", type=float, default=0.5, help="Sampling temperature" - ) + parser.add_argument("--temperature", type=float, default=0.5, help="Sampling temperature") parser.add_argument( "--output_file", type=str, @@ -609,24 +577,16 @@ if __name__ == "__main__": update_log_path(eval_log_dir) logger.info(f"Logs will be saved to both ./logs and {eval_log_dir}") except ImportError: - logger.info( - "Config's update_log_path not available, using default logging" - ) + logger.info("Config's update_log_path not available, using default logging") if trainer_dir: logger.info(f"Using trainer directory: {trainer_dir}") - logger.info( - f"All evaluation files will be stored in: {os.path.join(trainer_dir, 'eval_logs')}" - ) + logger.info(f"All evaluation files will be stored in: {os.path.join(trainer_dir, 'eval_logs')}") else: - logger.warning( - "No trainer directory found, will attempt to determine during evaluation" - ) + logger.warning("No trainer directory found, will attempt to determine during evaluation") logger.info(f"Starting model evaluation with temperature {args.temperature}") - results = compare_models( - args.lora_path, args.temperature, args.output_file, trainer_dir=trainer_dir - ) + results = compare_models(args.lora_path, args.temperature, args.output_file, trainer_dir=trainer_dir) if results: logger.info("Evaluation completed successfully") logger.info(f"Final improvement: {results['improvement']:.4f}") diff --git a/inference.py b/inference.py index 6f97746..a2aab1a 100644 --- a/inference.py +++ b/inference.py @@ -498,7 +498,7 @@ You are a helpful assistant with tool calling capabilities.""" try: with open(filepath, "w", encoding="utf-8") as f: f.write(f"{'=' * 80}\n") - f.write(f"DEEPSEARCH CHAT HISTORY\n") + f.write("DEEPSEARCH CHAT HISTORY\n") f.write(f"Model: {MODEL_NAME}\n") f.write(f"LoRA Path: {self.lora_path if self.lora_path else 'None'}\n") f.write(f"Temperature: {self.temperature}\n") diff --git a/src/UnslothGRPOTrainerTemp.py b/src/UnslothGRPOTrainerTemp.py index 8ae417f..67f375f 100644 --- a/src/UnslothGRPOTrainerTemp.py +++ b/src/UnslothGRPOTrainerTemp.py @@ -70,15 +70,11 @@ torch_compile_options = { ) def selective_log_softmax(logits, index): logits = logits.to(torch.float32) - selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze( - -1 - ) + selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1) # loop to reduce peak mem consumption # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) logsumexp_values = torch.logsumexp(logits, dim=-1) - per_token_logps = ( - selected_logits - logsumexp_values - ) # log_softmax(x_i) = x_i - logsumexp(x) + per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x) return per_token_logps @@ -139,17 +135,11 @@ class UnslothEfficientGRPO(torch.autograd.Function): scaler=None, n_chunks=1, ): - def compute_loss( - new_hidden_states, old_hidden_states, input_ids, mask, advantages, scaling - ): + def compute_loss(new_hidden_states, old_hidden_states, input_ids, mask, advantages, scaling): new_logits = torch.matmul(new_hidden_states, lm_head.t()) - new_logits = new_logits[ - :, :-1, : - ] # exclude the last logit: it corresponds to the next token pred + new_logits = new_logits[:, :-1, :] # 
exclude the last logit: it corresponds to the next token pred old_logits = torch.matmul(old_hidden_states, lm_head.t()) - old_logits = old_logits[ - :, :-1, : - ] # exclude the last logit: it corresponds to the next token pred + old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred loss, completion_length, mean_kl = grpo_compute_loss( old_logits, new_logits, @@ -311,11 +301,7 @@ def grpo_accumulated_loss( n_chunks = bsz n_chunks = factors[min(np.searchsorted(factors, n_chunks), len(factors) - 1)] - mixed_dtype = ( - torch.float16 - if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16" - else torch.bfloat16 - ) + mixed_dtype = torch.float16 if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16" else torch.bfloat16 os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1" completion_input_ids = input_ids[:, -logits_to_keep:] @@ -324,18 +310,12 @@ def grpo_accumulated_loss( with torch.amp.autocast(device_type="cuda", dtype=mixed_dtype): with ( torch.inference_mode(), - trainer.accelerator.unwrap_model( - trainer.model, keep_fp32_wrapper=False - ).disable_adapter(), + trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper=False).disable_adapter(), ): - old_hidden_states = trainer.model( - input_ids=input_ids, logits_to_keep=logits_to_keep + 1 - ).logits + old_hidden_states = trainer.model(input_ids=input_ids, logits_to_keep=logits_to_keep + 1).logits pass - new_hidden_states = trainer.model( - input_ids=input_ids, logits_to_keep=logits_to_keep + 1 - ).logits + new_hidden_states = trainer.model(input_ids=input_ids, logits_to_keep=logits_to_keep + 1).logits loss, completion_length, mean_kl = UnslothEfficientGRPO.apply( new_hidden_states, @@ -352,13 +332,9 @@ def grpo_accumulated_loss( # Old non efficient code path new_logits = torch.matmul(new_hidden_states, lm_head.t()) - new_logits = new_logits[ - :, :-1, : - ] # exclude the last logit: it corresponds to the next token pred + new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred old_logits = torch.matmul(old_hidden_states, lm_head.t()) - old_logits = old_logits[ - :, :-1, : - ] # exclude the last logit: it corresponds to the next token pred + old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred loss, completion_length, mean_kl = grpo_compute_loss( old_logits, new_logits, @@ -824,24 +800,14 @@ class _UnslothGRPOTrainer(Trainer): reward_funcs: Union[RewardFunc, list[RewardFunc]], args: GRPOConfig = None, train_dataset: Optional[Union[Dataset, IterableDataset]] = None, - eval_dataset: Optional[ - Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]] - ] = None, + eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None, processing_class: Optional[PreTrainedTokenizerBase] = None, - reward_processing_classes: Optional[ - Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]] - ] = None, + reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None, callbacks: Optional[list[TrainerCallback]] = None, - optimizers: tuple[ - Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR] - ] = (None, None), + optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), peft_config: Optional["PeftConfig"] = None, ): - if ( - hasattr(model, "vllm_engine") - and hasattr(args, "use_vllm") - and (getattr(args, 
"use_vllm", False) == False) - ): + if hasattr(model, "vllm_engine") and hasattr(args, "use_vllm") and (getattr(args, "use_vllm", False) == False): args.use_vllm = True # Args if args is None: @@ -855,11 +821,7 @@ class _UnslothGRPOTrainer(Trainer): if isinstance(model, str): model_id = model torch_dtype = model_init_kwargs.get("torch_dtype") - if ( - isinstance(torch_dtype, torch.dtype) - or torch_dtype == "auto" - or torch_dtype is None - ): + if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None: pass # torch_dtype is already a torch.dtype or "auto" or None elif isinstance(torch_dtype, str): # it's a str, but not "auto" torch_dtype = getattr(torch, torch_dtype) @@ -871,9 +833,7 @@ class _UnslothGRPOTrainer(Trainer): ) # Disable caching if gradient checkpointing is enabled (not supported) model_init_kwargs["use_cache"] = ( - False - if args.gradient_checkpointing - else model_init_kwargs.get("use_cache") + False if args.gradient_checkpointing else model_init_kwargs.get("use_cache") ) model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) else: @@ -889,9 +849,7 @@ class _UnslothGRPOTrainer(Trainer): # Reference model if is_deepspeed_zero3_enabled(): - self.ref_model = AutoModelForCausalLM.from_pretrained( - model_id, **model_init_kwargs - ) + self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs) elif not is_peft_model(model): # If PEFT configuration is not provided, create a reference model based on the initial model. self.ref_model = create_reference_model(model) @@ -902,9 +860,7 @@ class _UnslothGRPOTrainer(Trainer): # Processing class if processing_class is None: - processing_class = AutoTokenizer.from_pretrained( - model.config._name_or_path, padding_side="left" - ) + processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left") # Reward functions if not isinstance(reward_funcs, list): @@ -934,22 +890,14 @@ class _UnslothGRPOTrainer(Trainer): reward_processing_classes = [reward_processing_classes] else: if len(reward_processing_classes) != len(reward_funcs): - raise ValueError( - "The number of reward processing classes must match the number of reward functions." - ) + raise ValueError("The number of reward processing classes must match the number of reward functions.") - for i, (reward_processing_class, reward_func) in enumerate( - zip(reward_processing_classes, reward_funcs) - ): + for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)): if isinstance(reward_func, PreTrainedModel): if reward_processing_class is None: - reward_processing_class = AutoTokenizer.from_pretrained( - reward_func.config._name_or_path - ) + reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path) if reward_processing_class.pad_token_id is None: - reward_processing_class.pad_token = ( - reward_processing_class.eos_token - ) + reward_processing_class.pad_token = reward_processing_class.eos_token # The reward model computes the reward for the latest non-padded token in the input sequence. # So it's important to set the pad token ID to the padding token ID of the processing class. 
reward_func.config.pad_token_id = reward_processing_class.pad_token_id @@ -962,9 +910,7 @@ class _UnslothGRPOTrainer(Trainer): # Training arguments self.max_prompt_length = args.max_prompt_length - self.max_completion_length = ( - args.max_completion_length - ) # = |o_i| in the GRPO paper + self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper self.num_generations = args.num_generations # = G in the GRPO paper self.use_vllm = args.use_vllm self.use_agentic_generate = args.use_agentic_generate @@ -997,11 +943,7 @@ class _UnslothGRPOTrainer(Trainer): # Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations num_processes = self.accelerator.num_processes global_batch_size = args.per_device_train_batch_size * num_processes - possible_values = [ - n_gen - for n_gen in range(2, global_batch_size + 1) - if (global_batch_size) % n_gen == 0 - ] + possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0] if self.num_generations not in possible_values: raise ValueError( f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly " @@ -1010,11 +952,7 @@ class _UnslothGRPOTrainer(Trainer): ) if self.args.eval_strategy != "no": global_batch_size = args.per_device_eval_batch_size * num_processes - possible_values = [ - n_gen - for n_gen in range(2, global_batch_size + 1) - if (global_batch_size) % n_gen == 0 - ] + possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0] if self.num_generations not in possible_values: raise ValueError( f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly " @@ -1059,22 +997,14 @@ class _UnslothGRPOTrainer(Trainer): if self.is_deepspeed_enabled: self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) else: - self.ref_model = self.accelerator.prepare_model( - self.ref_model, evaluation_mode=True - ) + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) if args.sync_ref_model: - self.add_callback( - SyncRefModelCallback( - ref_model=self.ref_model, accelerator=self.accelerator - ) - ) + self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator)) for i, reward_func in enumerate(self.reward_funcs): if isinstance(reward_func, PreTrainedModel): - self.reward_funcs[i] = self.accelerator.prepare_model( - reward_func, evaluation_mode=True - ) + self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True) def _set_signature_columns_if_needed(self): # If `self.args.remove_unused_columns` is True, non-signature columns are removed. @@ -1089,27 +1019,21 @@ class _UnslothGRPOTrainer(Trainer): # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly # within each prompt group. Using the same seed across processes ensures consistent prompt assignment, # preventing discrepancies in group formation. - return RepeatRandomSampler( - self.train_dataset, self.num_generations, seed=self.args.seed - ) + return RepeatRandomSampler(self.train_dataset, self.num_generations, seed=self.args.seed) def _get_eval_sampler(self, eval_dataset) -> Sampler: # Returns a sampler that ensures each prompt is repeated across multiple processes. 
This guarantees that # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly # within each prompt group. Using the same seed across processes ensures consistent prompt assignment, # preventing discrepancies in group formation. - return RepeatRandomSampler( - eval_dataset, self.num_generations, seed=self.args.seed - ) + return RepeatRandomSampler(eval_dataset, self.num_generations, seed=self.args.seed) # Get the per-token log probabilities for the completions for the model and the reference model def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): return None # Unsloth efficient GRPO if not hasattr(self, "_autocast_dtype"): self._autocast_dtype = ( - torch.float16 - if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16" - else torch.bfloat16 + torch.float16 if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16" else torch.bfloat16 ) with torch.amp.autocast(device_type="cuda", dtype=self._autocast_dtype): # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded @@ -1118,9 +1042,7 @@ class _UnslothGRPOTrainer(Trainer): attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1, ).logits - logits = logits[ - :, :-1, : - ] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred + logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred input_ids = input_ids[:, -logits_to_keep:] # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. @@ -1133,15 +1055,10 @@ class _UnslothGRPOTrainer(Trainer): def _move_model_to_vllm(self, *args, **kwargs): return None - def _prepare_inputs( - self, inputs: dict[str, Union[torch.Tensor, Any]] - ) -> dict[str, Union[torch.Tensor, Any]]: + def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]: device = self.accelerator.device prompts = [x["prompt"] for x in inputs] - prompts_text = [ - maybe_apply_chat_template(example, self.processing_class)["prompt"] - for example in inputs - ] + prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] prompt_inputs = self.processing_class( prompts_text, return_tensors="pt", @@ -1174,9 +1091,7 @@ class _UnslothGRPOTrainer(Trainer): prompts_text, sampling_params=self.sampling_params, use_tqdm=False, - lora_request=self.model.load_lora( - "grpo_trainer_lora_model", load_tensors=True - ), + lora_request=self.model.load_lora("grpo_trainer_lora_model", load_tensors=True), ) if self.use_agentic_generate: agentic_outputs = self.model.agentic_generate( @@ -1201,11 +1116,7 @@ class _UnslothGRPOTrainer(Trainer): ).to(device) else: outputs = generate_fn(all_prompts_text) - completion_ids = [ - out.token_ids - for completions in outputs - for out in completions.outputs - ] + completion_ids = [out.token_ids for completions in outputs for out in completions.outputs] else: completion_ids = [None] * len(all_prompts_text) # Broadcast the completions from the main process to all processes, ensuring each process receives its @@ -1218,18 +1129,12 @@ class _UnslothGRPOTrainer(Trainer): completion_ids = completion_ids[process_slice] # Pad the completions, and concatenate them with the prompts - completion_ids = [ - torch.tensor(ids, device=device) for ids in completion_ids - ] - completion_ids = pad( - completion_ids, padding_value=self.processing_class.pad_token_id - ) + completion_ids = 
[torch.tensor(ids, device=device) for ids in completion_ids] + completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id) prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) else: # Regular generation path - with unwrap_model_for_generation( - self.model, self.accelerator - ) as unwrapped_model: + with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: prompt_completion_ids = unwrapped_model.generate( prompt_ids, attention_mask=prompt_mask, @@ -1244,21 +1149,15 @@ class _UnslothGRPOTrainer(Trainer): if not self.use_agentic_generate: # Mask everything after the first EOS token is_eos = completion_ids == self.processing_class.eos_token_id - eos_idx = torch.full( - (is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device - ) + eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] - sequence_indices = torch.arange(is_eos.size(1), device=device).expand( - is_eos.size(0), -1 - ) + sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() # Concatenate prompt_mask with completion_mask for logit computation attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C) - logits_to_keep = completion_ids.size( - 1 - ) # we only need to compute the logits for the completion tokens + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens # this does nothing with ( @@ -1280,9 +1179,7 @@ class _UnslothGRPOTrainer(Trainer): logits_to_keep, ) else: - with self.accelerator.unwrap_model( - self.model, keep_fp32_wrapper=False - ).disable_adapter(): + with self.accelerator.unwrap_model(self.model, keep_fp32_wrapper=False).disable_adapter(): ref_per_token_logps = self._get_per_token_logps( self.model, prompt_completion_ids, @@ -1292,42 +1189,25 @@ class _UnslothGRPOTrainer(Trainer): # Decode the generated completions if not self.use_agentic_generate: - completions_text = self.processing_class.batch_decode( - completion_ids, skip_special_tokens=True - ) + completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) if is_conversational(inputs[0]): completions = [] for prompt, completion in zip(prompts, completions_text): - bootstrap = ( - prompt.pop()["content"] - if prompt[-1]["role"] == "assistant" - else "" - ) - completions.append( - [{"role": "assistant", "content": bootstrap + completion}] - ) + bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else "" + completions.append([{"role": "assistant", "content": bootstrap + completion}]) else: completions = completions_text else: completions = full_chats - rewards_per_func = torch.zeros( - len(prompts), len(self.reward_funcs), device=device - ) + rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) for i, (reward_func, reward_processing_class) in enumerate( zip(self.reward_funcs, self.reward_processing_classes) ): - if isinstance( - reward_func, nn.Module - ): # Module instead of PretrainedModel for compat with compiled models + if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models if is_conversational(inputs[0]): - messages = [ - {"messages": p + c} for p, c in zip(prompts, completions) - ] - texts = [ - apply_chat_template(x, reward_processing_class)["text"] - for x in 
messages - ] + messages = [{"messages": p + c} for p, c in zip(prompts, completions)] + texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages] else: texts = [p + c for p, c in zip(prompts, completions)] reward_inputs = reward_processing_class( @@ -1343,36 +1223,25 @@ class _UnslothGRPOTrainer(Trainer): torch.amp.autocast( device_type="cuda", dtype=torch.float16 - if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") - == "fp16" + if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16" else torch.bfloat16, ) if not torch.is_autocast_enabled("cuda") else nullcontext(), ): - rewards_per_func[:, i] = reward_func(**reward_inputs).logits[ - :, 0 - ] # Shape (B*G,) + rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,) else: # Repeat all input columns (but "prompt" and "completion") to match the number of generations keys = [key for key in inputs[0] if key not in ["prompt", "completion"]] - reward_kwargs = { - key: [example[key] for example in inputs] for key in keys - } - output_reward_func = reward_func( - prompts=prompts, completions=completions, **reward_kwargs - ) - rewards_per_func[:, i] = torch.tensor( - output_reward_func, dtype=torch.float32, device=device - ) + reward_kwargs = {key: [example[key] for example in inputs] for key in keys} + output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs) + rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the # completions may be distributed across processes rewards_per_func = gather(rewards_per_func) # Apply weights to each reward function's output and sum - rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum( - dim=1 - ) + rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1) # else: # reward_fn = self.reward_funcs[0] @@ -1391,12 +1260,8 @@ class _UnslothGRPOTrainer(Trainer): std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1) # Normalize the rewards to compute the advantages - mean_grouped_rewards = mean_grouped_rewards.repeat_interleave( - self.num_generations, dim=0 - ) - std_grouped_rewards = std_grouped_rewards.repeat_interleave( - self.num_generations, dim=0 - ) + mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0) + std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0) advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4) # Slice to keep only the local part of the data @@ -1410,15 +1275,11 @@ class _UnslothGRPOTrainer(Trainer): reward_per_func = rewards_per_func.mean(0) print("rewards_per_func:", reward_per_func) for i, reward_func in enumerate(self.reward_funcs): - if isinstance( - reward_func, nn.Module - ): # Module instead of PretrainedModel for compat with compiled models + if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models reward_func_name = reward_func.config._name_or_path.split("/")[-1] else: reward_func_name = reward_func.__name__ - self._metrics[f"rewards/{reward_func_name}"].append( - reward_per_func[i].item() - ) + self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item()) self._metrics["reward"].append(rewards.mean().item()) self._metrics["reward_std"].append(std_grouped_rewards.mean().item()) @@ -1451,9 +1312,7 @@ class 
_UnslothGRPOTrainer(Trainer): "advantages": advantages, } - def compute_loss( - self, model, inputs, return_outputs=False, num_items_in_batch=None - ): + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): if return_outputs: raise ValueError("The GRPOTrainer does not support returning outputs") # Compute the per-token log probabilities for the model @@ -1467,14 +1326,10 @@ class _UnslothGRPOTrainer(Trainer): bsz, qlen = input_ids.shape # attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) attention_mask = None - logits_to_keep = completion_ids.size( - 1 - ) # we only need to compute the logits for the completion tokens + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens _input_ids = input_ids _logits_to_keep = logits_to_keep - per_token_logps = self._get_per_token_logps( - model, input_ids, attention_mask, logits_to_keep - ) + per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) # Compute the KL divergence between the model and the reference model ref_per_token_logps = inputs["ref_per_token_logps"] @@ -1529,9 +1384,7 @@ class _UnslothGRPOTrainer(Trainer): return loss, None, None def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: - metrics = { - key: sum(val) / len(val) for key, val in self._metrics.items() - } # average the metrics + metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. @@ -1565,9 +1418,7 @@ class _UnslothGRPOTrainer(Trainer): if not self.is_world_process_zero(): return - if hasattr(self.model.config, "_name_or_path") and not os.path.isdir( - self.model.config._name_or_path - ): + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): base_model = self.model.config._name_or_path else: base_model = None @@ -1596,9 +1447,7 @@ class _UnslothGRPOTrainer(Trainer): hub_model_id=self.hub_model_id, dataset_name=dataset_name, tags=tags, - wandb_url=wandb.run.get_url() - if is_wandb_available() and wandb.run is not None - else None, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, comet_url=get_comet_experiment_url(), trainer_name="GRPO", trainer_citation=citation, @@ -1735,10 +1584,7 @@ class UnslothGRPOTrainer(_UnslothGRPOTrainer): args.fp16 = float16 args.bf16 = not float16 os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16" if float16 else "bf16" - if ( - getattr(args, "eval_dataset", None) is not None - and getattr(args, "eval_strategy", "no") == "no" - ): + if getattr(args, "eval_dataset", None) is not None and getattr(args, "eval_strategy", "no") == "no": args.eval_strategy = "steps" if getattr(args, "eval_steps", None) is None: args.eval_steps = 0.1 @@ -1755,10 +1601,7 @@ class UnslothGRPOTrainer(_UnslothGRPOTrainer): eval_bsz = getattr(args, "per_device_eval_batch_size", 8) if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size - if ( - getattr(args, "eval_accumulation_steps", None) is None - and ga_steps is not None - ): + if getattr(args, "eval_accumulation_steps", None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps fp16_full_eval = getattr(args, 
"fp16_full_eval", False) bf16_full_eval = getattr(args, "bf16_full_eval", False) @@ -1787,9 +1630,7 @@ class UnslothGRPOTrainer(_UnslothGRPOTrainer): if "processing_class" in locals(): if hasattr(processing_class, "padding_side"): processing_class.padding_side = "right" - if hasattr(processing_class, "tokenizer") and hasattr( - processing_class.tokenizer, "padding_side" - ): + if hasattr(processing_class, "tokenizer") and hasattr(processing_class.tokenizer, "padding_side"): processing_class.tokenizer.padding_side = "right" other_metrics = [] if not isinstance(reward_funcs, list): diff --git a/src/config.py b/src/config.py index 838b971..ecc767b 100644 --- a/src/config.py +++ b/src/config.py @@ -19,15 +19,10 @@ LOG_FOLDER = PROJ_ROOT / "logs" # Model configuration # MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" -device_id = ( - 1 if os.environ.get("CUDA_VISIBLE_DEVICES") == "1" else torch.cuda.current_device() -) +device_id = 1 if os.environ.get("CUDA_VISIBLE_DEVICES") == "1" else torch.cuda.current_device() timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") -OUTPUT_DIR = ( - PROJ_ROOT - / f"trainer_output_{MODEL_NAME.replace('/', '_')}_gpu{device_id}_{timestamp}" -) +OUTPUT_DIR = PROJ_ROOT / f"trainer_output_{MODEL_NAME.replace('/', '_')}_gpu{device_id}_{timestamp}" # Model parameters MODEL_CONFIG = { @@ -103,12 +98,7 @@ def _init_logging(env: str = "development") -> None: "- {message}" ) - file_format = ( - "{time:YYYY-MM-DD at HH:mm:ss} " - "| {level} " - "| {name}:{function}:{line} " - "- {message}" - ) + file_format = "{time:YYYY-MM-DD at HH:mm:ss} | {level} | {name}:{function}:{line} - {message}" # Add console logging logger.add( @@ -139,9 +129,7 @@ def _init_logging(env: str = "development") -> None: if issubclass(exc_type, KeyboardInterrupt): sys.__excepthook__(exc_type, exc_value, exc_traceback) return - logger.opt(exception=(exc_type, exc_value, exc_traceback)).critical( - "Unhandled exception" - ) + logger.opt(exception=(exc_type, exc_value, exc_traceback)).critical("Unhandled exception") sys.excepthook = exception_handler @@ -163,12 +151,7 @@ def update_log_path(log_dir=None): log_dir = Path(log_dir) log_dir.mkdir(exist_ok=True, parents=True) - file_format = ( - "{time:YYYY-MM-DD at HH:mm:ss} " - "| {level} " - "| {name}:{function}:{line} " - "- {message}" - ) + file_format = "{time:YYYY-MM-DD at HH:mm:ss} | {level} | {name}:{function}:{line} - {message}" # Add additional file handler pointing to training directory # No need to remove existing handlers as we want to keep those @@ -248,9 +231,7 @@ def setup_logger(module_name=None, create_dirs: bool = False): Returns: Configured logger instance """ - logger.warning( - "setup_logger is deprecated. Import logger directly from config instead." - ) + logger.warning("setup_logger is deprecated. 
Import logger directly from config instead.") return logger diff --git a/src/rl_helpers.py b/src/rl_helpers.py index d60c887..117b86b 100644 --- a/src/rl_helpers.py +++ b/src/rl_helpers.py @@ -186,12 +186,8 @@ def run_agent_generations(generate_fn, tokenizer, chat_states): full_response = response.outputs[0].text else: full_response = response - assistant_response = full_response.split( - "<|start_header_id|>assistant<|end_header_id|>" - )[-1] - chat_state["messages"].append( - {"role": "assistant", "content": assistant_response} - ) + assistant_response = full_response.split("<|start_header_id|>assistant<|end_header_id|>")[-1] + chat_state["messages"].append({"role": "assistant", "content": assistant_response}) logger.debug(f"Added assistant response to chat state {idx}") else: logger.debug("No prompts to generate responses for") @@ -211,9 +207,7 @@ def check_finished_chats(chat_states): for chat_state in chat_states: if chat_state.get("finished"): continue - assert ( - chat_state["messages"][-1]["role"] == "assistant" - ), "Expected the last role to be assistant" + assert chat_state["messages"][-1]["role"] == "assistant", "Expected the last role to be assistant" assistant_response = chat_state["messages"][-1]["content"] function_calls = extract_json_objects(assistant_response) if len(function_calls) == 0: @@ -232,17 +226,15 @@ def run_tool_calls(chat_states): if chat_state.get("finished"): logger.debug("Chat state already finished, skipping tool calls") continue - assert ( - chat_state["messages"][-1]["role"] == "assistant" - ), "Expected the last role to be assistant to run tool calls" + assert chat_state["messages"][-1]["role"] == "assistant", ( + "Expected the last role to be assistant to run tool calls" + ) try: assistant_response = chat_state["messages"][-1]["content"] function_calls = extract_json_objects(assistant_response) if len(function_calls) > 1: logger.warning("Multiple function calls found in assistant response") - raise ValueError( - "Expected only one function call in assistant response" - ) + raise ValueError("Expected only one function call in assistant response") elif len(function_calls) == 1: function_call = function_calls[0] query = function_call["function"]["parameters"]["query"] @@ -257,9 +249,7 @@ def run_tool_calls(chat_states): logger.debug("Added search results to chat state") except Exception as e: logger.error(f"Error during tool call: {str(e)}") - chat_state["messages"].append( - {"role": "system", "content": f"Error during post-processing: {str(e)}"} - ) + chat_state["messages"].append({"role": "system", "content": f"Error during post-processing: {str(e)}"}) chat_state["finished"] = True return chat_states @@ -273,14 +263,9 @@ def get_mask(text, tokenizer): assistant_ranges = [] i = 0 while i < len(encoding.input_ids) - 1: - if ( - encoding.input_ids[i] == start_header_id - and encoding.input_ids[i + 1] == assistant_token - ): + if encoding.input_ids[i] == start_header_id and encoding.input_ids[i + 1] == assistant_token: i += 2 - while ( - i < len(encoding.input_ids) and encoding.input_ids[i] != end_header_id - ): + while i < len(encoding.input_ids) and encoding.input_ids[i] != end_header_id: i += 1 i += 2 start_idx = i @@ -319,11 +304,7 @@ class AgenticOutputs: def get_chat_num_tokens(chat_state, tokenizer): chat_text = apply_chat_template(chat_state, tokenizer=tokenizer)["text"] - return ( - tokenizer(chat_text, add_special_tokens=False, return_tensors="pt")["input_ids"] - .squeeze() - .shape[0] - ) + return tokenizer(chat_text, add_special_tokens=False, 
return_tensors="pt")["input_ids"].squeeze().shape[0] def run_agent( @@ -338,9 +319,7 @@ def run_agent( Run the agent to completion for a batch of questions. """ logger.info(f"Starting agent run with {len(questions)} questions") - logger.debug( - f"Max generations: {max_generations}, Max new tokens: {max_new_tokens}" - ) + logger.debug(f"Max generations: {max_generations}, Max new tokens: {max_new_tokens}") chat_states = [get_initial_chat(q) for q in questions] # Add correct content to chat states if provided @@ -359,13 +338,9 @@ def run_agent( chat_states = run_agent_generations(generate_fn, tokenizer, chat_states) chat_states = check_finished_chats(chat_states) chat_states = run_tool_calls(chat_states) - chat_states = check_exceeded_max_new_tokens( - chat_states, max_new_tokens, tokenizer - ) + chat_states = check_exceeded_max_new_tokens(chat_states, max_new_tokens, tokenizer) finished_count = sum(1 for state in chat_states if state.get("finished")) - logger.info( - f"Finished {finished_count}/{len(chat_states)} chat states after step {i + 1}" - ) + logger.info(f"Finished {finished_count}/{len(chat_states)} chat states after step {i + 1}") logger.info("Agent run completed") @@ -387,23 +362,15 @@ def run_agent( assistant_response = convo_text[idx + len(marker) :] return prompt, assistant_response - str_chats = [ - apply_chat_template(chat, tokenizer=tokenizer)["text"] for chat in chat_states - ] + str_chats = [apply_chat_template(chat, tokenizer=tokenizer)["text"] for chat in chat_states] prompt_toks, response_toks, response_masks = [], [], [] logger.debug("Processing tokenization") for i, str_chat in enumerate(str_chats): prompt, response = split_prompt_assistant(str_chat) - prompt_toks.append( - tokenizer(prompt, add_special_tokens=False, return_tensors="pt")[ - "input_ids" - ].squeeze() - ) + prompt_toks.append(tokenizer(prompt, add_special_tokens=False, return_tensors="pt")["input_ids"].squeeze()) response_toks.append( - tokenizer(response, add_special_tokens=False, return_tensors="pt")[ - "input_ids" - ].squeeze()[:max_new_tokens] + tokenizer(response, add_special_tokens=False, return_tensors="pt")["input_ids"].squeeze()[:max_new_tokens] ) mask = get_mask(str_chat, tokenizer)[len(prompt_toks[-1]) :][:max_new_tokens] response_masks.append(mask) @@ -469,12 +436,8 @@ def check_student_answers( logger.info(f"Checking {len(questions)} student answers") if not (len(questions) == len(answers) == len(student_answers)): - logger.error( - "Mismatched lengths between questions, answers, and student answers" - ) - raise ValueError( - "The number of questions, answers, and student answers must be equal." 
- ) + logger.error("Mismatched lengths between questions, answers, and student answers") + raise ValueError("The number of questions, answers, and student answers must be equal.") prompts = [] for question, answer, student_ans in zip(questions, answers, student_answers): @@ -537,7 +500,7 @@ def check_student_answers( for i, (question, answer, student_ans, verifier_response) in enumerate( zip(questions, answers, student_answers, responses_text) ): - file.write(f"\nā“ Question {i+1}:\n") + file.write(f"\nā“ Question {i + 1}:\n") file.write("-" * 40 + "\n") file.write(f"šŸ“‹ Question: {question}\n") file.write(f"āœ… Correct Answer: {answer}\n") @@ -548,11 +511,7 @@ def check_student_answers( if isinstance(student_ans, dict) and "messages" in student_ans: # Get messages from dict messages = student_ans.get("messages", []) - search_results = [ - msg.get("content", "") - for msg in messages - if msg.get("role") == "ipython" - ] + search_results = [msg.get("content", "") for msg in messages if msg.get("role") == "ipython"] if search_results: file.write("\nšŸ”Ž Search Results:\n") for j, result in enumerate(search_results, 1): @@ -561,7 +520,7 @@ def check_student_answers( file.write("-" * 40 + "\n") file.write( - f"\nšŸ“Š Summary: {sum(results)}/{len(results)} answers correct ({sum(results)/len(results)*100:.2f}%)\n" + f"\nšŸ“Š Summary: {sum(results)}/{len(results)} answers correct ({sum(results) / len(results) * 100:.2f}%)\n" ) file.write("=" * 80 + "\n\n") @@ -572,18 +531,12 @@ def check_student_answers( def build_reward_correctness_fn(generate_fn, tokenizer, log_file=None): def reward_correctness(prompts, completions, **reward_kwargs): teacher_answers = reward_kwargs["answer"] - student_answers = [ - completion["messages"][-1]["content"] for completion in completions - ] + student_answers = [completion["messages"][-1]["content"] for completion in completions] # Log non-exact matches for i, (student, teacher) in enumerate(zip(student_answers, teacher_answers)): if student.strip().lower() != teacher.strip().lower(): - logger.warning( - f"Non-exact match at index {i}:\n" - f"Student: {student}\n" - f"Teacher: {teacher}" - ) + logger.warning(f"Non-exact match at index {i}:\nStudent: {student}\nTeacher: {teacher}") correct = check_student_answers( prompts, @@ -595,12 +548,8 @@ def build_reward_correctness_fn(generate_fn, tokenizer, log_file=None): ) # Log correctness metrics with length info - log_metric( - "rewards/correctness", np.mean(correct), reward_kwargs.get("step", 0) - ) - log_metric( - "rewards/correctness_std", np.std(correct), reward_kwargs.get("step", 0) - ) + log_metric("rewards/correctness", np.mean(correct), reward_kwargs.get("step", 0)) + log_metric("rewards/correctness_std", np.std(correct), reward_kwargs.get("step", 0)) # Log length metrics student_lengths = [len(ans.strip()) for ans in student_answers] @@ -676,9 +625,7 @@ def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[floa if json_count > 1: has_multiple_json = True - logger.warning( - f"Message contains {json_count} JSON objects, which exceeds the limit of 1" - ) + logger.warning(f"Message contains {json_count} JSON objects, which exceeds the limit of 1") break # Only reward if no message has multiple JSON objects @@ -692,17 +639,13 @@ def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[floa if total_json_objects > 4: penalty = 0.1 * (total_json_objects - 4) base_reward = max(0.2, base_reward - penalty) - logger.debug( - f"Applied penalty for {total_json_objects} total JSON 
objects: {penalty}" - ) + logger.debug(f"Applied penalty for {total_json_objects} total JSON objects: {penalty}") rewards.append(base_reward) # Log retry behavior metrics log_metric("rewards/retry_behavior", np.mean(rewards), reward_kwargs.get("step", 0)) - log_metric( - "rewards/retry_behavior_std", np.std(rewards), reward_kwargs.get("step", 0) - ) + log_metric("rewards/retry_behavior_std", np.std(rewards), reward_kwargs.get("step", 0)) log_metric( "metrics/avg_json_per_msg", np.mean( @@ -737,28 +680,22 @@ def reward_exact_match_chunk_query(prompts, completions, **reward_kwargs): raise ValueError("chunk_content must be provided in reward_kwargs") rewards = [] - for i, (chat_state, correct_content) in enumerate( - zip(completions, correct_contents) - ): + for i, (chat_state, correct_content) in enumerate(zip(completions, correct_contents)): # Get all ipython messages (search results) from the chat - search_results = [ - msg["content"] for msg in chat_state["messages"] if msg["role"] == "ipython" - ] + search_results = [msg["content"] for msg in chat_state["messages"] if msg["role"] == "ipython"] logger.debug(f"Found {len(search_results)} search results for prompt {i}") # Log ground truth chunk and searched chunks logger.info(f"šŸ“ Ground Truth Chunk: {correct_content}") for j, result in enumerate(search_results): - logger.info(f"šŸ” Searched Chunk {j+1}: {result}") + logger.info(f"šŸ” Searched Chunk {j + 1}: {result}") # Check if any search hit the correct chunk content found_correct_chunk = False for result in search_results: if correct_content in result: found_correct_chunk = True - logger.debug( - f"Found correct chunk content in search results for prompt {i}" - ) + logger.debug(f"Found correct chunk content in search results for prompt {i}") break if not found_correct_chunk: @@ -796,21 +733,13 @@ def reward_exact_match_chunk_query(prompts, completions, **reward_kwargs): "metrics/avg_search_results", np.mean( [ - len( - [ - msg["content"] - for msg in chat_state["messages"] - if msg["role"] == "ipython" - ] - ) + len([msg["content"] for msg in chat_state["messages"] if msg["role"] == "ipython"]) for chat_state in completions ] ), reward_kwargs.get("step", 0), ) - log_metric( - "metrics/chunk_match_rate", np.mean(rewards), reward_kwargs.get("step", 0) - ) + log_metric("metrics/chunk_match_rate", np.mean(rewards), reward_kwargs.get("step", 0)) # Log detailed debugging info logger.info("Chunk Query Rewards Summary:") @@ -862,10 +791,8 @@ def run_eval(generate_fn, verify_fn, tokenizer, output_file=None, debug_file=Non f.write(f"Percentage correct: {percent_correct:.2f}%\n\n") f.write("Individual results:\n") - for i, (q, r, resp) in enumerate( - zip(questions, rewards, final_responses) - ): - f.write(f"\nQ{i+1}: {q[:100]}...\n") + for i, (q, r, resp) in enumerate(zip(questions, rewards, final_responses)): + f.write(f"\nQ{i + 1}: {q[:100]}...\n") f.write(f"Correct: {'āœ“' if r else 'āœ—'}\n") f.write(f"Response: {resp[:150]}...\n") f.write("-" * 40 + "\n") @@ -879,9 +806,7 @@ def run_eval(generate_fn, verify_fn, tokenizer, output_file=None, debug_file=Non import json debug_data = [] - for i, (q, r, resp, chat) in enumerate( - zip(questions, rewards, final_responses, full_chat_states) - ): + for i, (q, r, resp, chat) in enumerate(zip(questions, rewards, final_responses, full_chat_states)): debug_data.append( { "question_id": i, diff --git a/src/search_module.py b/src/search_module.py index 50955b7..fe4991d 100644 --- a/src/search_module.py +++ b/src/search_module.py @@ -20,9 +20,7 @@ def 
load_vectorstore(): embeddings = CustomHuggingFaceEmbeddings() # Load the FAISS index from the data directory logger.info(f"Loading FAISS index from: {DATA_DIR}") - vectorstore = FAISS.load_local( - str(DATA_DIR), embeddings, allow_dangerous_deserialization=True - ) + vectorstore = FAISS.load_local(str(DATA_DIR), embeddings, allow_dangerous_deserialization=True) logger.info("Successfully loaded FAISS index") return vectorstore except Exception as e: @@ -125,9 +123,7 @@ def get_question_answer(idx=None, return_both: bool = True) -> dict: # Select question by index qa_pair = questions[idx] else: - raise ValueError( - f"Index out of range. Must be between 0 and {len(questions) - 1}" - ) + raise ValueError(f"Index out of range. Must be between 0 and {len(questions) - 1}") question = qa_pair["question"] answer = qa_pair["answer"] diff --git a/train_grpo.py b/train_grpo.py index ff0699d..bd9ee63 100644 --- a/train_grpo.py +++ b/train_grpo.py @@ -3,16 +3,6 @@ import os from unsloth import FastLanguageModel, is_bfloat16_supported import src.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp - -# Import reward functions -from src.rl_helpers import ( - build_reward_correctness_fn, - get_qa_dataset, - reward_exact_match_chunk_query, - reward_formatting, - reward_retry_behavior, - run_agent, -) from src.config import ( MODEL_CONFIG, MODEL_NAME, @@ -24,6 +14,16 @@ from src.config import ( update_log_path, ) +# Import reward functions +from src.rl_helpers import ( + build_reward_correctness_fn, + get_qa_dataset, + reward_exact_match_chunk_query, + reward_formatting, + reward_retry_behavior, + run_agent, +) + # Initialize training directories paths = init_training_dirs() @@ -57,9 +57,7 @@ model = FastLanguageModel.get_peft_model( # Load datasets logger.info("Loading datasets") train_dataset, test_dataset = get_qa_dataset() -logger.info( - f"Loaded {len(train_dataset)} training examples and {len(test_dataset)} test examples" -) +logger.info(f"Loaded {len(train_dataset)} training examples and {len(test_dataset)} test examples") # Setup training arguments logger.info("Setting up training arguments")
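
The reformatting above is what a formatter produces at a 119-character line length, combined with isort-style import ordering (alphabetized within each group, as seen in the train_grpo.py hunks where src.config now precedes src.rl_helpers). The repository's actual style/quality targets are not shown in this patch, so the following Makefile lines are only a rough sketch of how the widened check_dirs would typically be consumed; the target names and the choice of ruff are assumptions, not something this diff confirms:

    # Sketch only: assumed targets and assumed use of ruff; not part of this patch.
    style:
    	ruff format --line-length 119 $(check_dirs)
    	ruff check --fix --select I $(check_dirs)    # "I" selects the import-sorting rules

    quality:
    	ruff format --check --line-length 119 $(check_dirs)
    	ruff check --select I $(check_dirs)

The same limit could instead be pinned in pyproject.toml as line-length = 119 under [tool.ruff], which both the linter and the formatter respect.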