From 900944066375cb67623b5cef658e0b725753a81b Mon Sep 17 00:00:00 2001 From: thinhlpg Date: Thu, 3 Apr 2025 13:32:19 +0700 Subject: [PATCH] chore: disable logging, enable torch compile --- src/UnslothGRPOTrainerTemp.py | 174 +++++++++++++++++----------------- 1 file changed, 86 insertions(+), 88 deletions(-) diff --git a/src/UnslothGRPOTrainerTemp.py b/src/UnslothGRPOTrainerTemp.py index 780ddd0..18ca2bf 100644 --- a/src/UnslothGRPOTrainerTemp.py +++ b/src/UnslothGRPOTrainerTemp.py @@ -65,11 +65,11 @@ torch_compile_options = { } -# @torch.compile( -# dynamic=True, -# fullgraph=True, -# options=torch_compile_options, -# ) +@torch.compile( + dynamic=True, + fullgraph=True, + options=torch_compile_options, +) def selective_log_softmax(logits, index): logits = logits.to(torch.float32) selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1) @@ -86,26 +86,26 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) new_logits = new_logits.to(torch.float32) # Print FULL tensor contents - logger.debug("\n🔍 DETAILED TENSOR ANALYSIS:") - logger.debug("\n1️⃣ Input IDs:") - logger.debug(f"Shape: {input_ids.shape}") # Use tensor.cpu().numpy() to safely print content - logger.debug(f"Type: {mask.dtype}") - logger.debug(f"Sum: {mask.sum().item()}") # Use .item() to get Python scalar + # logger.debug("\n🔍 DETAILED TENSOR ANALYSIS:") + # logger.debug("\n1️⃣ Input IDs:") + # logger.debug(f"Shape: {input_ids.shape}") # Use tensor.cpu().numpy() to safely print content + # logger.debug(f"Type: {mask.dtype}") + # logger.debug(f"Sum: {mask.sum().item()}") # Use .item() to get Python scalar - logger.debug("\n3️⃣ Old Logits:") - logger.debug(f"Shape: {old_logits.shape}") - logger.debug(f"Type: {old_logits.dtype}") - logger.debug(f"Mean: {old_logits.mean().item():.4f}") + # logger.debug("\n3️⃣ Old Logits:") + # logger.debug(f"Shape: {old_logits.shape}") + # logger.debug(f"Type: {old_logits.dtype}") + # logger.debug(f"Mean: {old_logits.mean().item():.4f}") - logger.debug("\n4️⃣ New 
Logits:") - logger.debug(f"Shape: {new_logits.shape}") - logger.debug(f"Type: {new_logits.dtype}") - logger.debug(f"Mean: {new_logits.mean().item():.4f}") + # logger.debug("\n4️⃣ New Logits:") + # logger.debug(f"Shape: {new_logits.shape}") + # logger.debug(f"Type: {new_logits.dtype}") + # logger.debug(f"Mean: {new_logits.mean().item():.4f}") - logger.debug("\n5️⃣ Advantages:") - logger.debug(f"Shape: {advantages.shape}") + # logger.debug("\n5️⃣ Advantages:") + # logger.debug(f"Shape: {advantages.shape}") input_ids = input_ids.unsqueeze(-1) @@ -115,9 +115,9 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) old = old_x - torch.logsumexp(old_logits, dim=-1) new = new_x - torch.logsumexp(new_logits, dim=-1) - logger.debug("\n6️⃣ After Gather & LogSumExp:") - logger.debug(f"old_x shape: {old_x.shape}, new_x shape: {new_x.shape}") - logger.debug(f"old shape: {old.shape}, new shape: {new.shape}") + # logger.debug("\n6️⃣ After Gather & LogSumExp:") + # logger.debug(f"old_x shape: {old_x.shape}, new_x shape: {new_x.shape}") + # logger.debug(f"old shape: {old.shape}, new shape: {new.shape}") # Reverse KL kl_i = torch.exp(old - new) - (old - new) - 1.0 @@ -140,10 +140,8 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) # loss = loss_per_reward.mean() # Add print statements here for debugging - logger.debug(f"🚨 Debug: loss_i shape: {loss_i.shape}") - logger.debug( - f"🚨 Debug: mask shape: {mask.shape}" - ) # Note: Mask shape might change slightly due to float conversion + # logger.debug(f"🚨 Debug: loss_i shape: {loss_i.shape}") + # logger.debug(f"🚨 Debug: mask shape: {mask.shape}") loss = (loss_i * mask).sum() / mask.sum() @@ -224,11 +222,11 @@ class UnslothEfficientGRPO(torch.autograd.Function): pass - # accumulate_chunk = torch.compile( - # accumulate_chunk, - # fullgraph=True, - # options=torch_compile_options, - # ) + accumulate_chunk = torch.compile( + accumulate_chunk, + fullgraph=True, + 
options=torch_compile_options, + ) grad_inputs_chunks = torch.chunk(grad_inputs, chunks=n_chunks, dim=0) new_hidden_states = torch.chunk(_new_hidden_states, chunks=n_chunks, dim=0) @@ -1058,13 +1056,13 @@ class _UnslothGRPOTrainer(Trainer): return None def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]: - logger.debug("\n🔍 DEBUG: Starting _prepare_inputs") + # logger.debug("\n🔍 DEBUG: Starting _prepare_inputs") device = self.accelerator.device prompts = [x["prompt"] for x in inputs] prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] - logger.debug("\n1️⃣ Before tokenization:") - logger.debug(f"Number of prompts: {len(prompts)}") - logger.debug(f"Sample prompt text length: {len(prompts_text[0]) if prompts_text else 0}") + # logger.debug("\n1️⃣ Before tokenization:") + # logger.debug(f"Number of prompts: {len(prompts)}") + # logger.debug(f"Sample prompt text length: {len(prompts_text[0]) if prompts_text else 0}") prompt_inputs = self.processing_class( prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False @@ -1072,17 +1070,17 @@ class _UnslothGRPOTrainer(Trainer): prompt_inputs = super()._prepare_inputs(prompt_inputs) prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] - logger.debug("\n2️⃣ After initial tokenization:") - logger.debug(f"prompt_ids shape: {prompt_ids.shape}") - logger.debug(f"prompt_mask shape: {prompt_mask.shape}") - logger.debug(f"prompt_mask sum: {prompt_mask.sum().item()}") + # logger.debug("\n2️⃣ After initial tokenization:") + # logger.debug(f"prompt_ids shape: {prompt_ids.shape}") + # logger.debug(f"prompt_mask shape: {prompt_mask.shape}") + # logger.debug(f"prompt_mask sum: {prompt_mask.sum().item()}") if self.max_prompt_length is not None: prompt_ids = prompt_ids[:, -self.max_prompt_length :] prompt_mask = prompt_mask[:, -self.max_prompt_length :] - 
logger.debug("\n3️⃣ After prompt length truncation:") - logger.debug(f"prompt_ids shape: {prompt_ids.shape}") - logger.debug(f"prompt_mask shape: {prompt_mask.shape}") + # logger.debug("\n3️⃣ After prompt length truncation:") + # logger.debug(f"prompt_ids shape: {prompt_ids.shape}") + # logger.debug(f"prompt_mask shape: {prompt_mask.shape}") # Generate completions using either vLLM or regular generation if self.args.use_vllm: @@ -1094,7 +1092,7 @@ class _UnslothGRPOTrainer(Trainer): # Generate completions using vLLM: gather all prompts and use them in a single call in the main process all_prompts_text = gather_object(prompts_text) if self.accelerator.is_main_process: - logger.debug(all_prompts_text) + # logger.debug(all_prompts_text) generate_fn = lambda prompts_text: self.llm.generate( prompts_text, sampling_params=self.sampling_params, @@ -1112,10 +1110,10 @@ class _UnslothGRPOTrainer(Trainer): prompt_inputs = agentic_outputs.prompt_tokens completion_ids = agentic_outputs.response_tokens completion_mask = agentic_outputs.response_masks - for i in range(len(completion_ids)): - logger.debug(f"prompt_inputs {i} len before padding: {len(prompt_inputs[i])}") - logger.debug(f"completion_ids {i} len before padding: {len(completion_ids[i])}") - logger.debug(f"completion_mask {i} len before padding: {len(completion_mask[i])}") + # for i in range(len(completion_ids)): + # logger.debug(f"prompt_inputs {i} len before padding: {len(prompt_inputs[i])}") + # logger.debug(f"completion_ids {i} len before padding: {len(completion_ids[i])}") + # logger.debug(f"completion_mask {i} len before padding: {len(completion_mask[i])}") prompt_ids = pad( prompt_inputs, @@ -1128,10 +1126,10 @@ class _UnslothGRPOTrainer(Trainer): padding_side="right", ).to(device) - for i in range(len(completion_ids)): - logger.debug(f"prompt_inputs {i} len after padding: {len(prompt_inputs[i])}") - logger.debug(f"prompt_ids {i} len after padding: {len(prompt_ids[i])}") - logger.debug(f"completion_mask {i} len 
after padding: {len(completion_mask[i])}") + # for i in range(len(completion_ids)): + # logger.debug(f"prompt_inputs {i} len after padding: {len(prompt_inputs[i])}") + # logger.debug(f"prompt_ids {i} len after padding: {len(prompt_ids[i])}") + # logger.debug(f"completion_mask {i} len after padding: {len(completion_mask[i])}") else: outputs = generate_fn(all_prompts_text) @@ -1149,16 +1147,16 @@ class _UnslothGRPOTrainer(Trainer): # Pad the completions, and concatenate them with the prompts completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] - logger.debug("\n4️⃣ Before completion padding:") - logger.debug(f"completion_ids shapes: {[ids.shape for ids in completion_ids]}") + # logger.debug("\n4️⃣ Before completion padding:") + # logger.debug(f"completion_ids shapes: {[ids.shape for ids in completion_ids]}") completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id) - logger.debug("\n5️⃣ After completion padding:") - logger.debug(f"completion_ids shape: {completion_ids.shape}") + # logger.debug("\n5️⃣ After completion padding:") + # logger.debug(f"completion_ids shape: {completion_ids.shape}") prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) - logger.debug("\n6️⃣ After concatenation:") - logger.debug(f"prompt_completion_ids shape: {prompt_completion_ids.shape}") + # logger.debug("\n6️⃣ After concatenation:") + # logger.debug(f"prompt_completion_ids shape: {prompt_completion_ids.shape}") else: # Regular generation path with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: @@ -1173,52 +1171,52 @@ class _UnslothGRPOTrainer(Trainer): if not self.use_agentic_generate: # Mask everything after the first EOS token - logger.debug("\n🔍 Starting EOS token detection and masking:") - logger.debug(f"completion_ids shape: {completion_ids.shape}") - logger.debug(f"eos_token_id: {self.processing_class.eos_token_id}") + # logger.debug("\n🔍 Starting EOS token detection and masking:") 
+ # logger.debug(f"completion_ids shape: {completion_ids.shape}") + # logger.debug(f"eos_token_id: {self.processing_class.eos_token_id}") # Debug EOS detection is_eos = completion_ids == self.processing_class.eos_token_id - logger.debug("\n7️⃣ EOS Detection Details:") - logger.debug(f"is_eos shape: {is_eos.shape}") - logger.debug(f"Sample is_eos values (first sequence):\n{is_eos[0]}") - logger.debug(f"Any EOS tokens found: {is_eos.any().item()}") - logger.debug(f"EOS positions: {is_eos.nonzero()}") + # logger.debug("\n7️⃣ EOS Detection Details:") + # logger.debug(f"is_eos shape: {is_eos.shape}") + # logger.debug(f"Sample is_eos values (first sequence):\n{is_eos[0]}") + # logger.debug(f"Any EOS tokens found: {is_eos.any().item()}") + # logger.debug(f"EOS positions: {is_eos.nonzero()}") # Debug EOS index tensor creation eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) - logger.debug("\n8️⃣ EOS Index Creation:") - logger.debug(f"eos_idx initial shape: {eos_idx.shape}") - logger.debug(f"eos_idx initial values: {eos_idx}") + # logger.debug("\n8️⃣ EOS Index Creation:") + # logger.debug(f"eos_idx initial shape: {eos_idx.shape}") + # logger.debug(f"eos_idx initial values: {eos_idx}") # Debug the complex indexing operation - logger.debug("\n9️⃣ EOS Position Analysis:") - logger.debug(f"Sequences with EOS: {is_eos.any(dim=1).sum().item()}") - logger.debug(f"First EOS positions: {is_eos.int().argmax(dim=1)}") + # logger.debug("\n9️⃣ EOS Position Analysis:") + # logger.debug(f"Sequences with EOS: {is_eos.any(dim=1).sum().item()}") + # logger.debug(f"First EOS positions: {is_eos.int().argmax(dim=1)}") eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] - logger.debug("\n🔟 After EOS Index Update:") - logger.debug(f"Updated eos_idx values: {eos_idx}") + # logger.debug("\n🔟 After EOS Index Update:") + # logger.debug(f"Updated eos_idx values: {eos_idx}") # Debug sequence indices creation sequence_indices = 
torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) - logger.debug("\n1️⃣1️⃣ Sequence Indices:") - logger.debug(f"sequence_indices shape: {sequence_indices.shape}") - logger.debug(f"Sample sequence_indices (first row):\n{sequence_indices[0]}") + # logger.debug("\n1️⃣1️⃣ Sequence Indices:") + # logger.debug(f"sequence_indices shape: {sequence_indices.shape}") + # logger.debug(f"Sample sequence_indices (first row):\n{sequence_indices[0]}") # Debug final mask creation completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() - logger.debug("\n1️⃣2️⃣ Final Completion Mask:") - logger.debug(f"completion_mask shape: {completion_mask.shape}") - logger.debug(f"Sample mask (first sequence):\n{completion_mask[0]}") - logger.debug("Mask statistics:") - logger.debug(f"- Total 1s: {completion_mask.sum().item()}") - logger.debug(f"- Average sequence length: {completion_mask.sum(dim=1).float().mean().item():.2f}") + # logger.debug("\n1️⃣2️⃣ Final Completion Mask:") + # logger.debug(f"completion_mask shape: {completion_mask.shape}") + # logger.debug(f"Sample mask (first sequence):\n{completion_mask[0]}") + # logger.debug("Mask statistics:") + # logger.debug(f"- Total 1s: {completion_mask.sum().item()}") + # logger.debug(f"- Average sequence length: {completion_mask.sum(dim=1).float().mean().item():.2f}") # Add a final validation check - logger.debug("\n7️⃣ Final Validation:") - logger.debug(f"Input shape: {completion_ids.shape}") - logger.debug(f"Mask shape: {completion_mask.shape}") + # logger.debug("\n7️⃣ Final Validation:") + # logger.debug(f"Input shape: {completion_ids.shape}") + # logger.debug(f"Mask shape: {completion_mask.shape}") # Concatenate prompt_mask with completion_mask for logit computation attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C) @@ -1329,7 +1327,7 @@ class _UnslothGRPOTrainer(Trainer): # Log the metrics reward_per_func = rewards_per_func.mean(0) - logger.debug("rewards_per_func:", reward_per_func) 
+ # logger.debug("rewards_per_func:", reward_per_func) for i, reward_func in enumerate(self.reward_funcs): if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models reward_func_name = reward_func.config._name_or_path.split("/")[-1] @@ -1640,7 +1638,7 @@ class UnslothGRPOTrainer(_UnslothGRPOTrainer): from transformers import __version__ as transformers_version if Version(transformers_version) <= Version("4.45.2"): - logger.debug( + print( "**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n" "`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`" )