chore: disable logging, enable torch compile

main
thinhlpg 1 month ago
parent d2f03b96ab
commit 9009440663

@@ -65,11 +65,11 @@ torch_compile_options = {
 }
-# @torch.compile(
-# dynamic=True,
-# fullgraph=True,
-# options=torch_compile_options,
-# )
+@torch.compile(
+    dynamic=True,
+    fullgraph=True,
+    options=torch_compile_options,
+)
 def selective_log_softmax(logits, index):
     logits = logits.to(torch.float32)
     selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
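Note: this hunk swaps the commented-out decorator for an active one, so selective_log_softmax is now compiled. A minimal standalone sketch of the same pattern, assuming PyTorch >= 2.0; the options dict below is illustrative, not the exact torch_compile_options defined earlier in the file:

    import torch

    @torch.compile(dynamic=True, fullgraph=True, options={"epilogue_fusion": True})
    def selective_log_softmax(logits, index):
        # log-softmax evaluated only at the gathered token ids:
        # log p(index) = logits[index] - logsumexp(logits)
        logits = logits.to(torch.float32)
        selected = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
        return selected - torch.logsumexp(logits, dim=-1)

    logits = torch.randn(2, 4, 16)                      # (batch, seq, vocab)
    index = torch.randint(0, 16, (2, 4))                # token ids to score
    print(selective_log_softmax(logits, index).shape)   # torch.Size([2, 4])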
@@ -86,26 +86,26 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages)
     new_logits = new_logits.to(torch.float32)
     # Print FULL tensor contents
-    logger.debug("\n🔍 DETAILED TENSOR ANALYSIS:")
-    logger.debug("\n1⃣ Input IDs:")
-    logger.debug(f"Shape: {input_ids.shape}")
+    # logger.debug("\n🔍 DETAILED TENSOR ANALYSIS:")
+    # logger.debug("\n1⃣ Input IDs:")
+    # logger.debug(f"Shape: {input_ids.shape}")
     # Use tensor.cpu().numpy() to safely print content
-    logger.debug(f"Type: {mask.dtype}")
-    logger.debug(f"Sum: {mask.sum().item()}")  # Use .item() to get Python scalar
+    # logger.debug(f"Type: {mask.dtype}")
+    # logger.debug(f"Sum: {mask.sum().item()}")  # Use .item() to get Python scalar
-    logger.debug("\n3⃣ Old Logits:")
-    logger.debug(f"Shape: {old_logits.shape}")
-    logger.debug(f"Type: {old_logits.dtype}")
-    logger.debug(f"Mean: {old_logits.mean().item():.4f}")
+    # logger.debug("\n3⃣ Old Logits:")
+    # logger.debug(f"Shape: {old_logits.shape}")
+    # logger.debug(f"Type: {old_logits.dtype}")
+    # logger.debug(f"Mean: {old_logits.mean().item():.4f}")
-    logger.debug("\n4⃣ New Logits:")
-    logger.debug(f"Shape: {new_logits.shape}")
-    logger.debug(f"Type: {new_logits.dtype}")
-    logger.debug(f"Mean: {new_logits.mean().item():.4f}")
+    # logger.debug("\n4⃣ New Logits:")
+    # logger.debug(f"Shape: {new_logits.shape}")
+    # logger.debug(f"Type: {new_logits.dtype}")
+    # logger.debug(f"Mean: {new_logits.mean().item():.4f}")
-    logger.debug("\n5⃣ Advantages:")
-    logger.debug(f"Shape: {advantages.shape}")
+    # logger.debug("\n5⃣ Advantages:")
+    # logger.debug(f"Shape: {advantages.shape}")
     input_ids = input_ids.unsqueeze(-1)
@@ -115,9 +115,9 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages)
     old = old_x - torch.logsumexp(old_logits, dim=-1)
     new = new_x - torch.logsumexp(new_logits, dim=-1)
-    logger.debug("\n6⃣ After Gather & LogSumExp:")
-    logger.debug(f"old_x shape: {old_x.shape}, new_x shape: {new_x.shape}")
-    logger.debug(f"old shape: {old.shape}, new shape: {new.shape}")
+    # logger.debug("\n6⃣ After Gather & LogSumExp:")
+    # logger.debug(f"old_x shape: {old_x.shape}, new_x shape: {new_x.shape}")
+    # logger.debug(f"old shape: {old.shape}, new shape: {new.shape}")
     # Reverse KL
     kl_i = torch.exp(old - new) - (old - new) - 1.0
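For reference, the kl_i line kept above is the non-negative per-token KL estimator k3 = exp(r) - r - 1 with r = log p_old - log p_new, as popularized by Schulman's "Approximating KL Divergence" note. A quick standalone sanity check:

    import torch

    old = torch.tensor([-1.2, -0.5, -2.0])   # per-token log-probs, old policy
    new = torch.tensor([-1.0, -0.7, -2.0])   # per-token log-probs, new policy
    r = old - new
    kl_i = torch.exp(r) - r - 1.0            # elementwise >= 0; exactly 0 where old == new
    print(kl_i)                              # tensor([0.0187, 0.0214, 0.0000])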
@@ -140,10 +140,8 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages)
     # loss = loss_per_reward.mean()
     # Add print statements here for debugging
-    logger.debug(f"🚨 Debug: loss_i shape: {loss_i.shape}")
-    logger.debug(
-        f"🚨 Debug: mask shape: {mask.shape}"
-    )  # Note: Mask shape might change slightly due to float conversion
+    # logger.debug(f"🚨 Debug: loss_i shape: {loss_i.shape}")
+    # logger.debug(f"🚨 Debug: mask shape: {mask.shape}")
     loss = (loss_i * mask).sum() / mask.sum()
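The surviving line computes a masked token-mean: only positions with mask == 1 contribute, normalized by the count of valid tokens rather than the full tensor size. A tiny illustration:

    import torch

    loss_i = torch.tensor([[0.5, 0.2, 0.9],
                           [0.1, 0.4, 0.3]])
    mask = torch.tensor([[1.0, 1.0, 0.0],
                         [1.0, 0.0, 0.0]])       # zeros mark padding
    loss = (loss_i * mask).sum() / mask.sum()    # (0.5 + 0.2 + 0.1) / 3
    print(loss)                                  # tensor(0.2667)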
@@ -224,11 +222,11 @@ class UnslothEfficientGRPO(torch.autograd.Function):
         pass
-        # accumulate_chunk = torch.compile(
-        # accumulate_chunk,
-        # fullgraph=True,
-        # options=torch_compile_options,
-        # )
+        accumulate_chunk = torch.compile(
+            accumulate_chunk,
+            fullgraph=True,
+            options=torch_compile_options,
+        )
         grad_inputs_chunks = torch.chunk(grad_inputs, chunks=n_chunks, dim=0)
         new_hidden_states = torch.chunk(_new_hidden_states, chunks=n_chunks, dim=0)
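Note: this hunk enables compilation of the inner accumulate_chunk step that the torch.chunk loop below it drives. A minimal sketch of that pattern, with a stand-in chunk function (the real accumulate_chunk computes the GRPO loss and gradients per chunk):

    import torch

    def accumulate_chunk(chunk):
        # stand-in per-chunk computation; compiled once, reused for every chunk
        return (chunk.float() ** 2).sum()

    accumulate_chunk = torch.compile(accumulate_chunk, fullgraph=True)

    hidden = torch.randn(8, 4)
    total = sum(accumulate_chunk(c) for c in torch.chunk(hidden, chunks=4, dim=0))
    print(total)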
@@ -1058,13 +1056,13 @@ class _UnslothGRPOTrainer(Trainer):
         return None
     def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
-        logger.debug("\n🔍 DEBUG: Starting _prepare_inputs")
+        # logger.debug("\n🔍 DEBUG: Starting _prepare_inputs")
         device = self.accelerator.device
         prompts = [x["prompt"] for x in inputs]
         prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
-        logger.debug("\n1⃣ Before tokenization:")
-        logger.debug(f"Number of prompts: {len(prompts)}")
-        logger.debug(f"Sample prompt text length: {len(prompts_text[0]) if prompts_text else 0}")
+        # logger.debug("\n1⃣ Before tokenization:")
+        # logger.debug(f"Number of prompts: {len(prompts)}")
+        # logger.debug(f"Sample prompt text length: {len(prompts_text[0]) if prompts_text else 0}")
         prompt_inputs = self.processing_class(
             prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
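The tokenizer call above left-pads so every prompt ends at the same position, which is what decoder-only generation needs: new tokens continue immediately after the last real prompt token. A standalone illustration, assuming a transformers version that accepts padding_side at call time (as the trainer code above does); the gpt2 checkpoint is only an example:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")      # example checkpoint only
    tok.pad_token = tok.eos_token                    # gpt2 has no pad token by default
    batch = tok(
        ["short prompt", "a noticeably longer prompt here"],
        return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False,
    )
    print(batch["attention_mask"])                   # zeros appear on the left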
@@ -1072,17 +1070,17 @@ class _UnslothGRPOTrainer(Trainer):
         prompt_inputs = super()._prepare_inputs(prompt_inputs)
         prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
-        logger.debug("\n2⃣ After initial tokenization:")
-        logger.debug(f"prompt_ids shape: {prompt_ids.shape}")
-        logger.debug(f"prompt_mask shape: {prompt_mask.shape}")
-        logger.debug(f"prompt_mask sum: {prompt_mask.sum().item()}")
+        # logger.debug("\n2⃣ After initial tokenization:")
+        # logger.debug(f"prompt_ids shape: {prompt_ids.shape}")
+        # logger.debug(f"prompt_mask shape: {prompt_mask.shape}")
+        # logger.debug(f"prompt_mask sum: {prompt_mask.sum().item()}")
         if self.max_prompt_length is not None:
             prompt_ids = prompt_ids[:, -self.max_prompt_length :]
             prompt_mask = prompt_mask[:, -self.max_prompt_length :]
-            logger.debug("\n3⃣ After prompt length truncation:")
-            logger.debug(f"prompt_ids shape: {prompt_ids.shape}")
-            logger.debug(f"prompt_mask shape: {prompt_mask.shape}")
+            # logger.debug("\n3⃣ After prompt length truncation:")
+            # logger.debug(f"prompt_ids shape: {prompt_ids.shape}")
+            # logger.debug(f"prompt_mask shape: {prompt_mask.shape}")
         # Generate completions using either vLLM or regular generation
         if self.args.use_vllm:
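The truncation kept in this hunk slices from the left, preserving the most recent max_prompt_length tokens (together with the matching attention-mask slice). A quick check:

    import torch

    prompt_ids = torch.arange(10).unsqueeze(0)   # one prompt of length 10
    max_prompt_length = 4
    print(prompt_ids[:, -max_prompt_length:])    # tensor([[6, 7, 8, 9]])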
@@ -1094,7 +1092,7 @@ class _UnslothGRPOTrainer(Trainer):
             # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
             all_prompts_text = gather_object(prompts_text)
             if self.accelerator.is_main_process:
-                logger.debug(all_prompts_text)
+                # logger.debug(all_prompts_text)
                 generate_fn = lambda prompts_text: self.llm.generate(
                     prompts_text,
                     sampling_params=self.sampling_params,
@@ -1112,10 +1110,10 @@ class _UnslothGRPOTrainer(Trainer):
                 prompt_inputs = agentic_outputs.prompt_tokens
                 completion_ids = agentic_outputs.response_tokens
                 completion_mask = agentic_outputs.response_masks
-                for i in range(len(completion_ids)):
-                    logger.debug(f"prompt_inputs {i} len before padding: {len(prompt_inputs[i])}")
-                    logger.debug(f"completion_ids {i} len before padding: {len(completion_ids[i])}")
-                    logger.debug(f"completion_mask {i} len before padding: {len(completion_mask[i])}")
+                # for i in range(len(completion_ids)):
+                #     logger.debug(f"prompt_inputs {i} len before padding: {len(prompt_inputs[i])}")
+                #     logger.debug(f"completion_ids {i} len before padding: {len(completion_ids[i])}")
+                #     logger.debug(f"completion_mask {i} len before padding: {len(completion_mask[i])}")
                 prompt_ids = pad(
                     prompt_inputs,
@@ -1128,10 +1126,10 @@ class _UnslothGRPOTrainer(Trainer):
                     padding_side="right",
                 ).to(device)
-                for i in range(len(completion_ids)):
-                    logger.debug(f"prompt_inputs {i} len after padding: {len(prompt_inputs[i])}")
-                    logger.debug(f"prompt_ids {i} len after padding: {len(prompt_ids[i])}")
-                    logger.debug(f"completion_mask {i} len after padding: {len(completion_mask[i])}")
+                # for i in range(len(completion_ids)):
+                #     logger.debug(f"prompt_inputs {i} len after padding: {len(prompt_inputs[i])}")
+                #     logger.debug(f"prompt_ids {i} len after padding: {len(prompt_ids[i])}")
+                #     logger.debug(f"completion_mask {i} len after padding: {len(completion_mask[i])}")
             else:
                 outputs = generate_fn(all_prompts_text)
@@ -1149,16 +1147,16 @@ class _UnslothGRPOTrainer(Trainer):
             # Pad the completions, and concatenate them with the prompts
             completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
-            logger.debug("\n4⃣ Before completion padding:")
-            logger.debug(f"completion_ids shapes: {[ids.shape for ids in completion_ids]}")
+            # logger.debug("\n4⃣ Before completion padding:")
+            # logger.debug(f"completion_ids shapes: {[ids.shape for ids in completion_ids]}")
             completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
-            logger.debug("\n5⃣ After completion padding:")
-            logger.debug(f"completion_ids shape: {completion_ids.shape}")
+            # logger.debug("\n5⃣ After completion padding:")
+            # logger.debug(f"completion_ids shape: {completion_ids.shape}")
             prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
-            logger.debug("\n6⃣ After concatenation:")
-            logger.debug(f"prompt_completion_ids shape: {prompt_completion_ids.shape}")
+            # logger.debug("\n6⃣ After concatenation:")
+            # logger.debug(f"prompt_completion_ids shape: {prompt_completion_ids.shape}")
         else:
             # Regular generation path
             with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
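The pad(...) helper in the hunk above is presumably TRL's padding utility; an equivalent standalone sketch of the pad-then-concatenate step using torch.nn.utils.rnn.pad_sequence, with a hypothetical pad_token_id standing in for processing_class.pad_token_id:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    pad_token_id = 0                                     # stand-in pad id, example only
    completion_ids = [torch.tensor([5, 6]), torch.tensor([7, 8, 9])]
    completions = pad_sequence(completion_ids, batch_first=True, padding_value=pad_token_id)
    prompt_ids = torch.tensor([[1, 2, 3], [1, 2, 4]])    # already-padded prompts
    prompt_completion_ids = torch.cat([prompt_ids, completions], dim=1)
    print(prompt_completion_ids)
    # tensor([[1, 2, 3, 5, 6, 0],
    #         [1, 2, 4, 7, 8, 9]])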
@@ -1173,52 +1171,52 @@ class _UnslothGRPOTrainer(Trainer):
         if not self.use_agentic_generate:
             # Mask everything after the first EOS token
-            logger.debug("\n🔍 Starting EOS token detection and masking:")
-            logger.debug(f"completion_ids shape: {completion_ids.shape}")
-            logger.debug(f"eos_token_id: {self.processing_class.eos_token_id}")
+            # logger.debug("\n🔍 Starting EOS token detection and masking:")
+            # logger.debug(f"completion_ids shape: {completion_ids.shape}")
+            # logger.debug(f"eos_token_id: {self.processing_class.eos_token_id}")
             # Debug EOS detection
             is_eos = completion_ids == self.processing_class.eos_token_id
-            logger.debug("\n7⃣ EOS Detection Details:")
-            logger.debug(f"is_eos shape: {is_eos.shape}")
-            logger.debug(f"Sample is_eos values (first sequence):\n{is_eos[0]}")
-            logger.debug(f"Any EOS tokens found: {is_eos.any().item()}")
-            logger.debug(f"EOS positions: {is_eos.nonzero()}")
+            # logger.debug("\n7⃣ EOS Detection Details:")
+            # logger.debug(f"is_eos shape: {is_eos.shape}")
+            # logger.debug(f"Sample is_eos values (first sequence):\n{is_eos[0]}")
+            # logger.debug(f"Any EOS tokens found: {is_eos.any().item()}")
+            # logger.debug(f"EOS positions: {is_eos.nonzero()}")
             # Debug EOS index tensor creation
             eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
-            logger.debug("\n8⃣ EOS Index Creation:")
-            logger.debug(f"eos_idx initial shape: {eos_idx.shape}")
-            logger.debug(f"eos_idx initial values: {eos_idx}")
+            # logger.debug("\n8⃣ EOS Index Creation:")
+            # logger.debug(f"eos_idx initial shape: {eos_idx.shape}")
+            # logger.debug(f"eos_idx initial values: {eos_idx}")
             # Debug the complex indexing operation
-            logger.debug("\n9⃣ EOS Position Analysis:")
-            logger.debug(f"Sequences with EOS: {is_eos.any(dim=1).sum().item()}")
-            logger.debug(f"First EOS positions: {is_eos.int().argmax(dim=1)}")
+            # logger.debug("\n9⃣ EOS Position Analysis:")
+            # logger.debug(f"Sequences with EOS: {is_eos.any(dim=1).sum().item()}")
+            # logger.debug(f"First EOS positions: {is_eos.int().argmax(dim=1)}")
             eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
-            logger.debug("\n🔟 After EOS Index Update:")
-            logger.debug(f"Updated eos_idx values: {eos_idx}")
+            # logger.debug("\n🔟 After EOS Index Update:")
+            # logger.debug(f"Updated eos_idx values: {eos_idx}")
             # Debug sequence indices creation
             sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
-            logger.debug("\n1⃣1⃣ Sequence Indices:")
-            logger.debug(f"sequence_indices shape: {sequence_indices.shape}")
-            logger.debug(f"Sample sequence_indices (first row):\n{sequence_indices[0]}")
+            # logger.debug("\n1⃣1⃣ Sequence Indices:")
+            # logger.debug(f"sequence_indices shape: {sequence_indices.shape}")
+            # logger.debug(f"Sample sequence_indices (first row):\n{sequence_indices[0]}")
             # Debug final mask creation
             completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
-            logger.debug("\n1⃣2⃣ Final Completion Mask:")
-            logger.debug(f"completion_mask shape: {completion_mask.shape}")
-            logger.debug(f"Sample mask (first sequence):\n{completion_mask[0]}")
-            logger.debug("Mask statistics:")
-            logger.debug(f"- Total 1s: {completion_mask.sum().item()}")
-            logger.debug(f"- Average sequence length: {completion_mask.sum(dim=1).float().mean().item():.2f}")
+            # logger.debug("\n1⃣2⃣ Final Completion Mask:")
+            # logger.debug(f"completion_mask shape: {completion_mask.shape}")
+            # logger.debug(f"Sample mask (first sequence):\n{completion_mask[0]}")
+            # logger.debug("Mask statistics:")
+            # logger.debug(f"- Total 1s: {completion_mask.sum().item()}")
+            # logger.debug(f"- Average sequence length: {completion_mask.sum(dim=1).float().mean().item():.2f}")
             # Add a final validation check
-            logger.debug("\n7⃣ Final Validation:")
-            logger.debug(f"Input shape: {completion_ids.shape}")
-            logger.debug(f"Mask shape: {completion_mask.shape}")
+            # logger.debug("\n7⃣ Final Validation:")
+            # logger.debug(f"Input shape: {completion_ids.shape}")
+            # logger.debug(f"Mask shape: {completion_mask.shape}")
         # Concatenate prompt_mask with completion_mask for logit computation
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B*G, P+C)
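The masking logic this hunk quiets down, condensed into a runnable example: every position up to and including the first EOS stays 1, everything after becomes 0, and sequences with no EOS remain fully unmasked:

    import torch

    eos_token_id = 2
    completion_ids = torch.tensor([[5, 7, 2, 9, 9],      # EOS at position 2
                                   [5, 7, 8, 9, 6]])     # no EOS
    is_eos = completion_ids == eos_token_id
    eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long)
    eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
    sequence_indices = torch.arange(is_eos.size(1)).expand(is_eos.size(0), -1)
    completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
    print(completion_mask)
    # tensor([[1, 1, 1, 0, 0],
    #         [1, 1, 1, 1, 1]])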
@@ -1329,7 +1327,7 @@ class _UnslothGRPOTrainer(Trainer):
         # Log the metrics
         reward_per_func = rewards_per_func.mean(0)
-        logger.debug("rewards_per_func:", reward_per_func)
+        # logger.debug("rewards_per_func:", reward_per_func)
         for i, reward_func in enumerate(self.reward_funcs):
             if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
                 reward_func_name = reward_func.config._name_or_path.split("/")[-1]
@@ -1640,7 +1638,7 @@ class UnslothGRPOTrainer(_UnslothGRPOTrainer):
         from transformers import __version__ as transformers_version
         if Version(transformers_version) <= Version("4.45.2"):
-            logger.debug(
+            print(
                 "**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n"
                 "`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`"
             )
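The warning this hunk promotes from logger.debug to print fires on old transformers releases; the gate itself is a standard packaging-style version comparison. A standalone sketch of the same check:

    from packaging.version import Version
    from transformers import __version__ as transformers_version

    if Version(transformers_version) <= Version("4.45.2"):
        print("**** Unsloth: please upgrade transformers/TRL/Unsloth for the "
              "fixed gradient_accumulation_steps behaviour.")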