chore: disable logging, enable torch compile

main
thinhlpg 1 month ago
parent d2f03b96ab
commit 9009440663

@@ -65,11 +65,11 @@ torch_compile_options = {
 }
-# @torch.compile(
-#     dynamic=True,
-#     fullgraph=True,
-#     options=torch_compile_options,
-# )
+@torch.compile(
+    dynamic=True,
+    fullgraph=True,
+    options=torch_compile_options,
+)
 def selective_log_softmax(logits, index):
     logits = logits.to(torch.float32)
     selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
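With the decorator active, selective_log_softmax is now traced by torch.compile. For reference, a minimal eager-mode sketch of what the helper computes: the per-token log-probabilities log_softmax(logits) at the chosen indices, via gather plus logsumexp, without materializing a full vocab-sized log-softmax output (shapes below are illustrative assumptions, not taken from the diff):

```python
import torch

def selective_log_softmax_reference(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # logits: (batch, seq, vocab); index: (batch, seq) token ids
    logits = logits.to(torch.float32)
    selected = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
    # log-softmax of only the selected entries: x_i - logsumexp(x)
    return selected - torch.logsumexp(logits, dim=-1)

batch, seq, vocab = 2, 5, 11
logits = torch.randn(batch, seq, vocab)
index = torch.randint(0, vocab, (batch, seq))
out = selective_log_softmax_reference(logits, index)
# Equivalent but more memory-hungry: full log_softmax, then gather
full = torch.log_softmax(logits, dim=-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(out, full, atol=1e-5)
```

Gathering first avoids allocating the full (batch, seq, vocab) log-softmax tensor, which is why this shape of kernel is a good torch.compile target.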
@@ -86,26 +86,26 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages)
     new_logits = new_logits.to(torch.float32)
     # Print FULL tensor contents
-    logger.debug("\n🔍 DETAILED TENSOR ANALYSIS:")
-    logger.debug("\n1⃣ Input IDs:")
-    logger.debug(f"Shape: {input_ids.shape}")
+    # logger.debug("\n🔍 DETAILED TENSOR ANALYSIS:")
+    # logger.debug("\n1⃣ Input IDs:")
+    # logger.debug(f"Shape: {input_ids.shape}")
     # Use tensor.cpu().numpy() to safely print content
-    logger.debug(f"Type: {mask.dtype}")
-    logger.debug(f"Sum: {mask.sum().item()}")  # Use .item() to get Python scalar
-    logger.debug("\n3⃣ Old Logits:")
-    logger.debug(f"Shape: {old_logits.shape}")
-    logger.debug(f"Type: {old_logits.dtype}")
-    logger.debug(f"Mean: {old_logits.mean().item():.4f}")
-    logger.debug("\n4⃣ New Logits:")
-    logger.debug(f"Shape: {new_logits.shape}")
-    logger.debug(f"Type: {new_logits.dtype}")
-    logger.debug(f"Mean: {new_logits.mean().item():.4f}")
-    logger.debug("\n5⃣ Advantages:")
-    logger.debug(f"Shape: {advantages.shape}")
+    # logger.debug(f"Type: {mask.dtype}")
+    # logger.debug(f"Sum: {mask.sum().item()}")  # Use .item() to get Python scalar
+    # logger.debug("\n3⃣ Old Logits:")
+    # logger.debug(f"Shape: {old_logits.shape}")
+    # logger.debug(f"Type: {old_logits.dtype}")
+    # logger.debug(f"Mean: {old_logits.mean().item():.4f}")
+    # logger.debug("\n4⃣ New Logits:")
+    # logger.debug(f"Shape: {new_logits.shape}")
+    # logger.debug(f"Type: {new_logits.dtype}")
+    # logger.debug(f"Mean: {new_logits.mean().item():.4f}")
+    # logger.debug("\n5⃣ Advantages:")
+    # logger.debug(f"Shape: {advantages.shape}")
     input_ids = input_ids.unsqueeze(-1)
@@ -115,9 +115,9 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages)
     old = old_x - torch.logsumexp(old_logits, dim=-1)
     new = new_x - torch.logsumexp(new_logits, dim=-1)
-    logger.debug("\n6⃣ After Gather & LogSumExp:")
-    logger.debug(f"old_x shape: {old_x.shape}, new_x shape: {new_x.shape}")
-    logger.debug(f"old shape: {old.shape}, new shape: {new.shape}")
+    # logger.debug("\n6⃣ After Gather & LogSumExp:")
+    # logger.debug(f"old_x shape: {old_x.shape}, new_x shape: {new_x.shape}")
+    # logger.debug(f"old shape: {old.shape}, new shape: {new.shape}")
     # Reverse KL
     kl_i = torch.exp(old - new) - (old - new) - 1.0
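For reference on the kl_i line above: with per-token log-ratio x = old − new, the quantity e^x − x − 1 is the "k3" KL estimator from Schulman's note on approximating KL divergence. It is non-negative everywhere (convexity of e^x) and zero exactly when the two policies assign the same log-probability. A tiny numeric check with assumed values:

```python
import torch

old = torch.tensor([-1.2, -0.4, -2.0])  # assumed per-token log-probs (old policy)
new = torch.tensor([-1.0, -0.9, -2.0])  # assumed per-token log-probs (new policy)
log_r = old - new
kl_i = torch.exp(log_r) - log_r - 1.0   # k3 estimator, per token
assert (kl_i >= 0).all()                             # e^x - x - 1 >= 0 always
assert torch.isclose(kl_i[2], torch.tensor(0.0))     # zero where old == new
```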
@@ -140,10 +140,8 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages)
     # loss = loss_per_reward.mean()
     # Add print statements here for debugging
-    logger.debug(f"🚨 Debug: loss_i shape: {loss_i.shape}")
-    logger.debug(
-        f"🚨 Debug: mask shape: {mask.shape}"
-    )  # Note: Mask shape might change slightly due to float conversion
+    # logger.debug(f"🚨 Debug: loss_i shape: {loss_i.shape}")
+    # logger.debug(f"🚨 Debug: mask shape: {mask.shape}")
     loss = (loss_i * mask).sum() / mask.sum()
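The surviving loss line is a masked mean: per-token losses are averaged only over real (unmasked) completion tokens, so padding positions do not dilute the loss. A toy check with assumed numbers:

```python
import torch

loss_i = torch.tensor([[0.5, 0.3, 0.2], [0.4, 0.1, 0.9]])  # assumed per-token losses
mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])    # 1 = real token, 0 = padding
loss = (loss_i * mask).sum() / mask.sum()                   # mean over the 3 unmasked tokens
assert torch.isclose(loss, torch.tensor((0.5 + 0.3 + 0.4) / 3))
```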
@@ -224,11 +222,11 @@ class UnslothEfficientGRPO(torch.autograd.Function):
         pass
-        # accumulate_chunk = torch.compile(
-        #     accumulate_chunk,
-        #     fullgraph=True,
-        #     options=torch_compile_options,
-        # )
+        accumulate_chunk = torch.compile(
+            accumulate_chunk,
+            fullgraph=True,
+            options=torch_compile_options,
+        )
         grad_inputs_chunks = torch.chunk(grad_inputs, chunks=n_chunks, dim=0)
         new_hidden_states = torch.chunk(_new_hidden_states, chunks=n_chunks, dim=0)
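Context for the newly enabled compile call: accumulate_chunk is compiled once with fullgraph=True and then driven over torch.chunk'd slices of the batch, so the expensive per-chunk computation runs as one fused graph per chunk rather than the whole batch at once. A self-contained sketch of that pattern; the function body, option contents, and shapes are assumptions, not the trainer's internals:

```python
import torch

torch_compile_options = {"epilogue_fusion": True}  # assumed option contents

def accumulate_chunk(chunk: torch.Tensor, acc: torch.Tensor) -> torch.Tensor:
    # stand-in for the real per-chunk loss/gradient computation
    return acc + chunk.square().sum()

accumulate_chunk = torch.compile(
    accumulate_chunk,
    fullgraph=True,
    options=torch_compile_options,
)

x = torch.randn(8, 4)
acc = torch.zeros(())
for chunk in torch.chunk(x, chunks=4, dim=0):  # 4 chunks of 2 rows each
    acc = accumulate_chunk(chunk, acc)
assert torch.isclose(acc, x.square().sum())    # chunked result matches the full pass
```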
@@ -1058,13 +1056,13 @@ class _UnslothGRPOTrainer(Trainer):
         return None
     def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
-        logger.debug("\n🔍 DEBUG: Starting _prepare_inputs")
+        # logger.debug("\n🔍 DEBUG: Starting _prepare_inputs")
         device = self.accelerator.device
         prompts = [x["prompt"] for x in inputs]
         prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
-        logger.debug("\n1⃣ Before tokenization:")
-        logger.debug(f"Number of prompts: {len(prompts)}")
-        logger.debug(f"Sample prompt text length: {len(prompts_text[0]) if prompts_text else 0}")
+        # logger.debug("\n1⃣ Before tokenization:")
+        # logger.debug(f"Number of prompts: {len(prompts)}")
+        # logger.debug(f"Sample prompt text length: {len(prompts_text[0]) if prompts_text else 0}")
         prompt_inputs = self.processing_class(
             prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
@@ -1072,17 +1070,17 @@ class _UnslothGRPOTrainer(Trainer):
         prompt_inputs = super()._prepare_inputs(prompt_inputs)
         prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
-        logger.debug("\n2⃣ After initial tokenization:")
-        logger.debug(f"prompt_ids shape: {prompt_ids.shape}")
-        logger.debug(f"prompt_mask shape: {prompt_mask.shape}")
-        logger.debug(f"prompt_mask sum: {prompt_mask.sum().item()}")
+        # logger.debug("\n2⃣ After initial tokenization:")
+        # logger.debug(f"prompt_ids shape: {prompt_ids.shape}")
+        # logger.debug(f"prompt_mask shape: {prompt_mask.shape}")
+        # logger.debug(f"prompt_mask sum: {prompt_mask.sum().item()}")
         if self.max_prompt_length is not None:
             prompt_ids = prompt_ids[:, -self.max_prompt_length :]
             prompt_mask = prompt_mask[:, -self.max_prompt_length :]
-            logger.debug("\n3⃣ After prompt length truncation:")
-            logger.debug(f"prompt_ids shape: {prompt_ids.shape}")
-            logger.debug(f"prompt_mask shape: {prompt_mask.shape}")
+            # logger.debug("\n3⃣ After prompt length truncation:")
+            # logger.debug(f"prompt_ids shape: {prompt_ids.shape}")
+            # logger.debug(f"prompt_mask shape: {prompt_mask.shape}")
         # Generate completions using either vLLM or regular generation
         if self.args.use_vllm:
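On the truncation kept above (now without debug output): since the tokenizer call uses padding_side="left", the newest prompt tokens sit at the right edge, so the slice [:, -max_prompt_length:] keeps the tail of each prompt. A toy illustration with assumed values:

```python
import torch

max_prompt_length = 4                          # assumed limit
prompt_ids = torch.arange(12).reshape(2, 6)    # (batch=2, seq=6), toy ids
prompt_mask = torch.ones_like(prompt_ids)
prompt_ids = prompt_ids[:, -max_prompt_length:]    # keep the LAST 4 columns
prompt_mask = prompt_mask[:, -max_prompt_length:]
assert prompt_ids.shape == (2, 4)              # end of each prompt preserved
```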
@@ -1094,7 +1092,7 @@ class _UnslothGRPOTrainer(Trainer):
             # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
             all_prompts_text = gather_object(prompts_text)
             if self.accelerator.is_main_process:
-                logger.debug(all_prompts_text)
+                # logger.debug(all_prompts_text)
                 generate_fn = lambda prompts_text: self.llm.generate(
                     prompts_text,
                     sampling_params=self.sampling_params,
@@ -1112,10 +1110,10 @@ class _UnslothGRPOTrainer(Trainer):
                 prompt_inputs = agentic_outputs.prompt_tokens
                 completion_ids = agentic_outputs.response_tokens
                 completion_mask = agentic_outputs.response_masks
-                for i in range(len(completion_ids)):
-                    logger.debug(f"prompt_inputs {i} len before padding: {len(prompt_inputs[i])}")
-                    logger.debug(f"completion_ids {i} len before padding: {len(completion_ids[i])}")
-                    logger.debug(f"completion_mask {i} len before padding: {len(completion_mask[i])}")
+                # for i in range(len(completion_ids)):
+                #     logger.debug(f"prompt_inputs {i} len before padding: {len(prompt_inputs[i])}")
+                #     logger.debug(f"completion_ids {i} len before padding: {len(completion_ids[i])}")
+                #     logger.debug(f"completion_mask {i} len before padding: {len(completion_mask[i])}")
                 prompt_ids = pad(
                     prompt_inputs,
@@ -1128,10 +1126,10 @@ class _UnslothGRPOTrainer(Trainer):
                     padding_side="right",
                 ).to(device)
-                for i in range(len(completion_ids)):
-                    logger.debug(f"prompt_inputs {i} len after padding: {len(prompt_inputs[i])}")
-                    logger.debug(f"prompt_ids {i} len after padding: {len(prompt_ids[i])}")
-                    logger.debug(f"completion_mask {i} len after padding: {len(completion_mask[i])}")
+                # for i in range(len(completion_ids)):
+                #     logger.debug(f"prompt_inputs {i} len after padding: {len(prompt_inputs[i])}")
+                #     logger.debug(f"prompt_ids {i} len after padding: {len(prompt_ids[i])}")
+                #     logger.debug(f"completion_mask {i} len after padding: {len(completion_mask[i])}")
             else:
                 outputs = generate_fn(all_prompts_text)
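For reference, the pad used in this hunk is a TRL-style utility that pads ragged token lists into one rectangular tensor. A minimal stand-in showing the assumed behavior, not the library's actual implementation:

```python
import torch

def pad(seqs: list[torch.Tensor], padding_value: int, padding_side: str = "right") -> torch.Tensor:
    # Pad a list of 1-D tensors to a common length with padding_value.
    max_len = max(s.size(0) for s in seqs)
    out = torch.full((len(seqs), max_len), padding_value, dtype=seqs[0].dtype)
    for i, s in enumerate(seqs):
        if padding_side == "right":
            out[i, : s.size(0)] = s          # content left, padding right
        else:
            out[i, max_len - s.size(0):] = s  # padding left, content right
    return out

ids = [torch.tensor([5, 6, 7]), torch.tensor([8])]
assert pad(ids, padding_value=0).tolist() == [[5, 6, 7], [8, 0, 0]]
```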
@@ -1149,16 +1147,16 @@ class _UnslothGRPOTrainer(Trainer):
             # Pad the completions, and concatenate them with the prompts
             completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
-            logger.debug("\n4⃣ Before completion padding:")
-            logger.debug(f"completion_ids shapes: {[ids.shape for ids in completion_ids]}")
+            # logger.debug("\n4⃣ Before completion padding:")
+            # logger.debug(f"completion_ids shapes: {[ids.shape for ids in completion_ids]}")
             completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
-            logger.debug("\n5⃣ After completion padding:")
-            logger.debug(f"completion_ids shape: {completion_ids.shape}")
+            # logger.debug("\n5⃣ After completion padding:")
+            # logger.debug(f"completion_ids shape: {completion_ids.shape}")
             prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
-            logger.debug("\n6⃣ After concatenation:")
-            logger.debug(f"prompt_completion_ids shape: {prompt_completion_ids.shape}")
+            # logger.debug("\n6⃣ After concatenation:")
+            # logger.debug(f"prompt_completion_ids shape: {prompt_completion_ids.shape}")
         else:
             # Regular generation path
             with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
@@ -1173,52 +1171,52 @@ class _UnslothGRPOTrainer(Trainer):
         if not self.use_agentic_generate:
             # Mask everything after the first EOS token
-            logger.debug("\n🔍 Starting EOS token detection and masking:")
-            logger.debug(f"completion_ids shape: {completion_ids.shape}")
-            logger.debug(f"eos_token_id: {self.processing_class.eos_token_id}")
+            # logger.debug("\n🔍 Starting EOS token detection and masking:")
+            # logger.debug(f"completion_ids shape: {completion_ids.shape}")
+            # logger.debug(f"eos_token_id: {self.processing_class.eos_token_id}")
             # Debug EOS detection
             is_eos = completion_ids == self.processing_class.eos_token_id
-            logger.debug("\n7⃣ EOS Detection Details:")
-            logger.debug(f"is_eos shape: {is_eos.shape}")
-            logger.debug(f"Sample is_eos values (first sequence):\n{is_eos[0]}")
-            logger.debug(f"Any EOS tokens found: {is_eos.any().item()}")
-            logger.debug(f"EOS positions: {is_eos.nonzero()}")
+            # logger.debug("\n7⃣ EOS Detection Details:")
+            # logger.debug(f"is_eos shape: {is_eos.shape}")
+            # logger.debug(f"Sample is_eos values (first sequence):\n{is_eos[0]}")
+            # logger.debug(f"Any EOS tokens found: {is_eos.any().item()}")
+            # logger.debug(f"EOS positions: {is_eos.nonzero()}")
             # Debug EOS index tensor creation
             eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
-            logger.debug("\n8⃣ EOS Index Creation:")
-            logger.debug(f"eos_idx initial shape: {eos_idx.shape}")
-            logger.debug(f"eos_idx initial values: {eos_idx}")
+            # logger.debug("\n8⃣ EOS Index Creation:")
+            # logger.debug(f"eos_idx initial shape: {eos_idx.shape}")
+            # logger.debug(f"eos_idx initial values: {eos_idx}")
             # Debug the complex indexing operation
-            logger.debug("\n9⃣ EOS Position Analysis:")
-            logger.debug(f"Sequences with EOS: {is_eos.any(dim=1).sum().item()}")
-            logger.debug(f"First EOS positions: {is_eos.int().argmax(dim=1)}")
+            # logger.debug("\n9⃣ EOS Position Analysis:")
+            # logger.debug(f"Sequences with EOS: {is_eos.any(dim=1).sum().item()}")
+            # logger.debug(f"First EOS positions: {is_eos.int().argmax(dim=1)}")
             eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
-            logger.debug("\n🔟 After EOS Index Update:")
-            logger.debug(f"Updated eos_idx values: {eos_idx}")
+            # logger.debug("\n🔟 After EOS Index Update:")
+            # logger.debug(f"Updated eos_idx values: {eos_idx}")
             # Debug sequence indices creation
             sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
-            logger.debug("\n1⃣1⃣ Sequence Indices:")
-            logger.debug(f"sequence_indices shape: {sequence_indices.shape}")
-            logger.debug(f"Sample sequence_indices (first row):\n{sequence_indices[0]}")
+            # logger.debug("\n1⃣1⃣ Sequence Indices:")
+            # logger.debug(f"sequence_indices shape: {sequence_indices.shape}")
+            # logger.debug(f"Sample sequence_indices (first row):\n{sequence_indices[0]}")
             # Debug final mask creation
             completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
-            logger.debug("\n1⃣2⃣ Final Completion Mask:")
-            logger.debug(f"completion_mask shape: {completion_mask.shape}")
-            logger.debug(f"Sample mask (first sequence):\n{completion_mask[0]}")
-            logger.debug("Mask statistics:")
-            logger.debug(f"- Total 1s: {completion_mask.sum().item()}")
-            logger.debug(f"- Average sequence length: {completion_mask.sum(dim=1).float().mean().item():.2f}")
+            # logger.debug("\n1⃣2⃣ Final Completion Mask:")
+            # logger.debug(f"completion_mask shape: {completion_mask.shape}")
+            # logger.debug(f"Sample mask (first sequence):\n{completion_mask[0]}")
+            # logger.debug("Mask statistics:")
+            # logger.debug(f"- Total 1s: {completion_mask.sum().item()}")
+            # logger.debug(f"- Average sequence length: {completion_mask.sum(dim=1).float().mean().item():.2f}")
             # Add a final validation check
-            logger.debug("\n7⃣ Final Validation:")
-            logger.debug(f"Input shape: {completion_ids.shape}")
-            logger.debug(f"Mask shape: {completion_mask.shape}")
+            # logger.debug("\n7⃣ Final Validation:")
+            # logger.debug(f"Input shape: {completion_ids.shape}")
+            # logger.debug(f"Mask shape: {completion_mask.shape}")
         # Concatenate prompt_mask with completion_mask for logit computation
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B*G, P+C)
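The EOS masking logic that survives this hunk (minus the debug logging) can be traced on a toy batch: rows containing an EOS keep tokens up to and including the first EOS, rows without one keep every token. A sketch with an assumed eos_token_id:

```python
import torch

eos_token_id = 2  # assumed value
completion_ids = torch.tensor([[4, 5, 2, 9],    # EOS at position 2
                               [7, 8, 3, 6]])   # no EOS at all
is_eos = completion_ids == eos_token_id
# Default to seq length so rows without EOS keep every token
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1)).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
assert completion_mask.tolist() == [[1, 1, 1, 0], [1, 1, 1, 1]]
```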
@@ -1329,7 +1327,7 @@ class _UnslothGRPOTrainer(Trainer):
         # Log the metrics
         reward_per_func = rewards_per_func.mean(0)
-        logger.debug("rewards_per_func:", reward_per_func)
+        # logger.debug("rewards_per_func:", reward_per_func)
         for i, reward_func in enumerate(self.reward_funcs):
             if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
                 reward_func_name = reward_func.config._name_or_path.split("/")[-1]
@@ -1640,7 +1638,7 @@ class UnslothGRPOTrainer(_UnslothGRPOTrainer):
     from transformers import __version__ as transformers_version
     if Version(transformers_version) <= Version("4.45.2"):
-        logger.debug(
+        print(
             "**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n"
             "`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`"
         )
