|
|
@ -65,11 +65,11 @@ torch_compile_options = {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# @torch.compile(
|
|
|
|
@torch.compile(
|
|
|
|
# dynamic=True,
|
|
|
|
dynamic=True,
|
|
|
|
# fullgraph=True,
|
|
|
|
fullgraph=True,
|
|
|
|
# options=torch_compile_options,
|
|
|
|
options=torch_compile_options,
|
|
|
|
# )
|
|
|
|
)
|
|
|
|
def selective_log_softmax(logits, index):
|
|
|
|
def selective_log_softmax(logits, index):
|
|
|
|
logits = logits.to(torch.float32)
|
|
|
|
logits = logits.to(torch.float32)
|
|
|
|
selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
|
|
|
|
selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
|
|
|
@ -86,26 +86,26 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages)
|
|
|
|
new_logits = new_logits.to(torch.float32)
|
|
|
|
new_logits = new_logits.to(torch.float32)
|
|
|
|
|
|
|
|
|
|
|
|
# Print FULL tensor contents
|
|
|
|
# Print FULL tensor contents
|
|
|
|
logger.debug("\n🔍 DETAILED TENSOR ANALYSIS:")
|
|
|
|
# logger.debug("\n🔍 DETAILED TENSOR ANALYSIS:")
|
|
|
|
logger.debug("\n1️⃣ Input IDs:")
|
|
|
|
# logger.debug("\n1️⃣ Input IDs:")
|
|
|
|
logger.debug(f"Shape: {input_ids.shape}")
|
|
|
|
# logger.debug(f"Shape: {input_ids.shape}")
|
|
|
|
# Use tensor.cpu().numpy() to safely print content
|
|
|
|
# Use tensor.cpu().numpy() to safely print content
|
|
|
|
|
|
|
|
|
|
|
|
logger.debug(f"Type: {mask.dtype}")
|
|
|
|
# logger.debug(f"Type: {mask.dtype}")
|
|
|
|
logger.debug(f"Sum: {mask.sum().item()}") # Use .item() to get Python scalar
|
|
|
|
# logger.debug(f"Sum: {mask.sum().item()}") # Use .item() to get Python scalar
|
|
|
|
|
|
|
|
|
|
|
|
logger.debug("\n3️⃣ Old Logits:")
|
|
|
|
# logger.debug("\n3️⃣ Old Logits:")
|
|
|
|
logger.debug(f"Shape: {old_logits.shape}")
|
|
|
|
# logger.debug(f"Shape: {old_logits.shape}")
|
|
|
|
logger.debug(f"Type: {old_logits.dtype}")
|
|
|
|
# logger.debug(f"Type: {old_logits.dtype}")
|
|
|
|
logger.debug(f"Mean: {old_logits.mean().item():.4f}")
|
|
|
|
# logger.debug(f"Mean: {old_logits.mean().item():.4f}")
|
|
|
|
|
|
|
|
|
|
|
|
logger.debug("\n4️⃣ New Logits:")
|
|
|
|
# logger.debug("\n4️⃣ New Logits:")
|
|
|
|
logger.debug(f"Shape: {new_logits.shape}")
|
|
|
|
# logger.debug(f"Shape: {new_logits.shape}")
|
|
|
|
logger.debug(f"Type: {new_logits.dtype}")
|
|
|
|
# logger.debug(f"Type: {new_logits.dtype}")
|
|
|
|
logger.debug(f"Mean: {new_logits.mean().item():.4f}")
|
|
|
|
# logger.debug(f"Mean: {new_logits.mean().item():.4f}")
|
|
|
|
|
|
|
|
|
|
|
|
logger.debug("\n5️⃣ Advantages:")
|
|
|
|
# logger.debug("\n5️⃣ Advantages:")
|
|
|
|
logger.debug(f"Shape: {advantages.shape}")
|
|
|
|
# logger.debug(f"Shape: {advantages.shape}")
|
|
|
|
|
|
|
|
|
|
|
|
input_ids = input_ids.unsqueeze(-1)
|
|
|
|
input_ids = input_ids.unsqueeze(-1)
|
|
|
|
|
|
|
|
|
|
|
@ -115,9 +115,9 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages)
|
|
|
|
old = old_x - torch.logsumexp(old_logits, dim=-1)
|
|
|
|
old = old_x - torch.logsumexp(old_logits, dim=-1)
|
|
|
|
new = new_x - torch.logsumexp(new_logits, dim=-1)
|
|
|
|
new = new_x - torch.logsumexp(new_logits, dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
logger.debug("\n6️⃣ After Gather & LogSumExp:")
|
|
|
|
# logger.debug("\n6️⃣ After Gather & LogSumExp:")
|
|
|
|
logger.debug(f"old_x shape: {old_x.shape}, new_x shape: {new_x.shape}")
|
|
|
|
# logger.debug(f"old_x shape: {old_x.shape}, new_x shape: {new_x.shape}")
|
|
|
|
logger.debug(f"old shape: {old.shape}, new shape: {new.shape}")
|
|
|
|
# logger.debug(f"old shape: {old.shape}, new shape: {new.shape}")
|
|
|
|
|
|
|
|
|
|
|
|
# Reverse KL
|
|
|
|
# Reverse KL
|
|
|
|
kl_i = torch.exp(old - new) - (old - new) - 1.0
|
|
|
|
kl_i = torch.exp(old - new) - (old - new) - 1.0
|
|
|
@ -140,10 +140,8 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages)
|
|
|
|
# loss = loss_per_reward.mean()
|
|
|
|
# loss = loss_per_reward.mean()
|
|
|
|
|
|
|
|
|
|
|
|
# Add print statements here for debugging
|
|
|
|
# Add print statements here for debugging
|
|
|
|
logger.debug(f"🚨 Debug: loss_i shape: {loss_i.shape}")
|
|
|
|
# logger.debug(f"🚨 Debug: loss_i shape: {loss_i.shape}")
|
|
|
|
logger.debug(
|
|
|
|
# logger.debug(f"🚨 Debug: mask shape: {mask.shape}")
|
|
|
|
f"🚨 Debug: mask shape: {mask.shape}"
|
|
|
|
|
|
|
|
) # Note: Mask shape might change slightly due to float conversion
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
loss = (loss_i * mask).sum() / mask.sum()
|
|
|
|
loss = (loss_i * mask).sum() / mask.sum()
|
|
|
|
|
|
|
|
|
|
|
@ -224,11 +222,11 @@ class UnslothEfficientGRPO(torch.autograd.Function):
|
|
|
|
|
|
|
|
|
|
|
|
pass
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
# accumulate_chunk = torch.compile(
|
|
|
|
accumulate_chunk = torch.compile(
|
|
|
|
# accumulate_chunk,
|
|
|
|
accumulate_chunk,
|
|
|
|
# fullgraph=True,
|
|
|
|
fullgraph=True,
|
|
|
|
# options=torch_compile_options,
|
|
|
|
options=torch_compile_options,
|
|
|
|
# )
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
grad_inputs_chunks = torch.chunk(grad_inputs, chunks=n_chunks, dim=0)
|
|
|
|
grad_inputs_chunks = torch.chunk(grad_inputs, chunks=n_chunks, dim=0)
|
|
|
|
new_hidden_states = torch.chunk(_new_hidden_states, chunks=n_chunks, dim=0)
|
|
|
|
new_hidden_states = torch.chunk(_new_hidden_states, chunks=n_chunks, dim=0)
|
|
|
@ -1058,13 +1056,13 @@ class _UnslothGRPOTrainer(Trainer):
|
|
|
|
return None
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
|
|
|
|
def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
|
|
|
|
logger.debug("\n🔍 DEBUG: Starting _prepare_inputs")
|
|
|
|
# logger.debug("\n🔍 DEBUG: Starting _prepare_inputs")
|
|
|
|
device = self.accelerator.device
|
|
|
|
device = self.accelerator.device
|
|
|
|
prompts = [x["prompt"] for x in inputs]
|
|
|
|
prompts = [x["prompt"] for x in inputs]
|
|
|
|
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
|
|
|
|
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
|
|
|
|
logger.debug("\n1️⃣ Before tokenization:")
|
|
|
|
# logger.debug("\n1️⃣ Before tokenization:")
|
|
|
|
logger.debug(f"Number of prompts: {len(prompts)}")
|
|
|
|
# logger.debug(f"Number of prompts: {len(prompts)}")
|
|
|
|
logger.debug(f"Sample prompt text length: {len(prompts_text[0]) if prompts_text else 0}")
|
|
|
|
# logger.debug(f"Sample prompt text length: {len(prompts_text[0]) if prompts_text else 0}")
|
|
|
|
|
|
|
|
|
|
|
|
prompt_inputs = self.processing_class(
|
|
|
|
prompt_inputs = self.processing_class(
|
|
|
|
prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
|
|
|
|
prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
|
|
|
@ -1072,17 +1070,17 @@ class _UnslothGRPOTrainer(Trainer):
|
|
|
|
prompt_inputs = super()._prepare_inputs(prompt_inputs)
|
|
|
|
prompt_inputs = super()._prepare_inputs(prompt_inputs)
|
|
|
|
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
|
|
|
|
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
|
|
|
|
|
|
|
|
|
|
|
|
logger.debug("\n2️⃣ After initial tokenization:")
|
|
|
|
# logger.debug("\n2️⃣ After initial tokenization:")
|
|
|
|
logger.debug(f"prompt_ids shape: {prompt_ids.shape}")
|
|
|
|
# logger.debug(f"prompt_ids shape: {prompt_ids.shape}")
|
|
|
|
logger.debug(f"prompt_mask shape: {prompt_mask.shape}")
|
|
|
|
# logger.debug(f"prompt_mask shape: {prompt_mask.shape}")
|
|
|
|
logger.debug(f"prompt_mask sum: {prompt_mask.sum().item()}")
|
|
|
|
# logger.debug(f"prompt_mask sum: {prompt_mask.sum().item()}")
|
|
|
|
|
|
|
|
|
|
|
|
if self.max_prompt_length is not None:
|
|
|
|
if self.max_prompt_length is not None:
|
|
|
|
prompt_ids = prompt_ids[:, -self.max_prompt_length :]
|
|
|
|
prompt_ids = prompt_ids[:, -self.max_prompt_length :]
|
|
|
|
prompt_mask = prompt_mask[:, -self.max_prompt_length :]
|
|
|
|
prompt_mask = prompt_mask[:, -self.max_prompt_length :]
|
|
|
|
logger.debug("\n3️⃣ After prompt length truncation:")
|
|
|
|
# logger.debug("\n3️⃣ After prompt length truncation:")
|
|
|
|
logger.debug(f"prompt_ids shape: {prompt_ids.shape}")
|
|
|
|
# logger.debug(f"prompt_ids shape: {prompt_ids.shape}")
|
|
|
|
logger.debug(f"prompt_mask shape: {prompt_mask.shape}")
|
|
|
|
# logger.debug(f"prompt_mask shape: {prompt_mask.shape}")
|
|
|
|
|
|
|
|
|
|
|
|
# Generate completions using either vLLM or regular generation
|
|
|
|
# Generate completions using either vLLM or regular generation
|
|
|
|
if self.args.use_vllm:
|
|
|
|
if self.args.use_vllm:
|
|
|
@ -1094,7 +1092,7 @@ class _UnslothGRPOTrainer(Trainer):
|
|
|
|
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
|
|
|
|
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
|
|
|
|
all_prompts_text = gather_object(prompts_text)
|
|
|
|
all_prompts_text = gather_object(prompts_text)
|
|
|
|
if self.accelerator.is_main_process:
|
|
|
|
if self.accelerator.is_main_process:
|
|
|
|
logger.debug(all_prompts_text)
|
|
|
|
# logger.debug(all_prompts_text)
|
|
|
|
generate_fn = lambda prompts_text: self.llm.generate(
|
|
|
|
generate_fn = lambda prompts_text: self.llm.generate(
|
|
|
|
prompts_text,
|
|
|
|
prompts_text,
|
|
|
|
sampling_params=self.sampling_params,
|
|
|
|
sampling_params=self.sampling_params,
|
|
|
@ -1112,10 +1110,10 @@ class _UnslothGRPOTrainer(Trainer):
|
|
|
|
prompt_inputs = agentic_outputs.prompt_tokens
|
|
|
|
prompt_inputs = agentic_outputs.prompt_tokens
|
|
|
|
completion_ids = agentic_outputs.response_tokens
|
|
|
|
completion_ids = agentic_outputs.response_tokens
|
|
|
|
completion_mask = agentic_outputs.response_masks
|
|
|
|
completion_mask = agentic_outputs.response_masks
|
|
|
|
for i in range(len(completion_ids)):
|
|
|
|
# for i in range(len(completion_ids)):
|
|
|
|
logger.debug(f"prompt_inputs {i} len before padding: {len(prompt_inputs[i])}")
|
|
|
|
# logger.debug(f"prompt_inputs {i} len before padding: {len(prompt_inputs[i])}")
|
|
|
|
logger.debug(f"completion_ids {i} len before padding: {len(completion_ids[i])}")
|
|
|
|
# logger.debug(f"completion_ids {i} len before padding: {len(completion_ids[i])}")
|
|
|
|
logger.debug(f"completion_mask {i} len before padding: {len(completion_mask[i])}")
|
|
|
|
# logger.debug(f"completion_mask {i} len before padding: {len(completion_mask[i])}")
|
|
|
|
|
|
|
|
|
|
|
|
prompt_ids = pad(
|
|
|
|
prompt_ids = pad(
|
|
|
|
prompt_inputs,
|
|
|
|
prompt_inputs,
|
|
|
@ -1128,10 +1126,10 @@ class _UnslothGRPOTrainer(Trainer):
|
|
|
|
padding_side="right",
|
|
|
|
padding_side="right",
|
|
|
|
).to(device)
|
|
|
|
).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
for i in range(len(completion_ids)):
|
|
|
|
# for i in range(len(completion_ids)):
|
|
|
|
logger.debug(f"prompt_inputs {i} len after padding: {len(prompt_inputs[i])}")
|
|
|
|
# logger.debug(f"prompt_inputs {i} len after padding: {len(prompt_inputs[i])}")
|
|
|
|
logger.debug(f"prompt_ids {i} len after padding: {len(prompt_ids[i])}")
|
|
|
|
# logger.debug(f"prompt_ids {i} len after padding: {len(prompt_ids[i])}")
|
|
|
|
logger.debug(f"completion_mask {i} len after padding: {len(completion_mask[i])}")
|
|
|
|
# logger.debug(f"completion_mask {i} len after padding: {len(completion_mask[i])}")
|
|
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
outputs = generate_fn(all_prompts_text)
|
|
|
|
outputs = generate_fn(all_prompts_text)
|
|
|
@ -1149,16 +1147,16 @@ class _UnslothGRPOTrainer(Trainer):
|
|
|
|
|
|
|
|
|
|
|
|
# Pad the completions, and concatenate them with the prompts
|
|
|
|
# Pad the completions, and concatenate them with the prompts
|
|
|
|
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
|
|
|
|
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
|
|
|
|
logger.debug("\n4️⃣ Before completion padding:")
|
|
|
|
# logger.debug("\n4️⃣ Before completion padding:")
|
|
|
|
logger.debug(f"completion_ids shapes: {[ids.shape for ids in completion_ids]}")
|
|
|
|
# logger.debug(f"completion_ids shapes: {[ids.shape for ids in completion_ids]}")
|
|
|
|
|
|
|
|
|
|
|
|
completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
|
|
|
|
completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
|
|
|
|
logger.debug("\n5️⃣ After completion padding:")
|
|
|
|
# logger.debug("\n5️⃣ After completion padding:")
|
|
|
|
logger.debug(f"completion_ids shape: {completion_ids.shape}")
|
|
|
|
# logger.debug(f"completion_ids shape: {completion_ids.shape}")
|
|
|
|
|
|
|
|
|
|
|
|
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
|
|
|
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
|
|
|
logger.debug("\n6️⃣ After concatenation:")
|
|
|
|
# logger.debug("\n6️⃣ After concatenation:")
|
|
|
|
logger.debug(f"prompt_completion_ids shape: {prompt_completion_ids.shape}")
|
|
|
|
# logger.debug(f"prompt_completion_ids shape: {prompt_completion_ids.shape}")
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
# Regular generation path
|
|
|
|
# Regular generation path
|
|
|
|
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
|
|
|
|
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
|
|
|
@ -1173,52 +1171,52 @@ class _UnslothGRPOTrainer(Trainer):
|
|
|
|
|
|
|
|
|
|
|
|
if not self.use_agentic_generate:
|
|
|
|
if not self.use_agentic_generate:
|
|
|
|
# Mask everything after the first EOS token
|
|
|
|
# Mask everything after the first EOS token
|
|
|
|
logger.debug("\n🔍 Starting EOS token detection and masking:")
|
|
|
|
# logger.debug("\n🔍 Starting EOS token detection and masking:")
|
|
|
|
logger.debug(f"completion_ids shape: {completion_ids.shape}")
|
|
|
|
# logger.debug(f"completion_ids shape: {completion_ids.shape}")
|
|
|
|
logger.debug(f"eos_token_id: {self.processing_class.eos_token_id}")
|
|
|
|
# logger.debug(f"eos_token_id: {self.processing_class.eos_token_id}")
|
|
|
|
|
|
|
|
|
|
|
|
# Debug EOS detection
|
|
|
|
# Debug EOS detection
|
|
|
|
is_eos = completion_ids == self.processing_class.eos_token_id
|
|
|
|
is_eos = completion_ids == self.processing_class.eos_token_id
|
|
|
|
logger.debug("\n7️⃣ EOS Detection Details:")
|
|
|
|
# logger.debug("\n7️⃣ EOS Detection Details:")
|
|
|
|
logger.debug(f"is_eos shape: {is_eos.shape}")
|
|
|
|
# logger.debug(f"is_eos shape: {is_eos.shape}")
|
|
|
|
logger.debug(f"Sample is_eos values (first sequence):\n{is_eos[0]}")
|
|
|
|
# logger.debug(f"Sample is_eos values (first sequence):\n{is_eos[0]}")
|
|
|
|
logger.debug(f"Any EOS tokens found: {is_eos.any().item()}")
|
|
|
|
# logger.debug(f"Any EOS tokens found: {is_eos.any().item()}")
|
|
|
|
logger.debug(f"EOS positions: {is_eos.nonzero()}")
|
|
|
|
# logger.debug(f"EOS positions: {is_eos.nonzero()}")
|
|
|
|
|
|
|
|
|
|
|
|
# Debug EOS index tensor creation
|
|
|
|
# Debug EOS index tensor creation
|
|
|
|
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
|
|
|
|
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
|
|
|
|
logger.debug("\n8️⃣ EOS Index Creation:")
|
|
|
|
# logger.debug("\n8️⃣ EOS Index Creation:")
|
|
|
|
logger.debug(f"eos_idx initial shape: {eos_idx.shape}")
|
|
|
|
# logger.debug(f"eos_idx initial shape: {eos_idx.shape}")
|
|
|
|
logger.debug(f"eos_idx initial values: {eos_idx}")
|
|
|
|
# logger.debug(f"eos_idx initial values: {eos_idx}")
|
|
|
|
|
|
|
|
|
|
|
|
# Debug the complex indexing operation
|
|
|
|
# Debug the complex indexing operation
|
|
|
|
logger.debug("\n9️⃣ EOS Position Analysis:")
|
|
|
|
# logger.debug("\n9️⃣ EOS Position Analysis:")
|
|
|
|
logger.debug(f"Sequences with EOS: {is_eos.any(dim=1).sum().item()}")
|
|
|
|
# logger.debug(f"Sequences with EOS: {is_eos.any(dim=1).sum().item()}")
|
|
|
|
logger.debug(f"First EOS positions: {is_eos.int().argmax(dim=1)}")
|
|
|
|
# logger.debug(f"First EOS positions: {is_eos.int().argmax(dim=1)}")
|
|
|
|
|
|
|
|
|
|
|
|
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
|
|
|
|
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
|
|
|
|
logger.debug("\n🔟 After EOS Index Update:")
|
|
|
|
# logger.debug("\n🔟 After EOS Index Update:")
|
|
|
|
logger.debug(f"Updated eos_idx values: {eos_idx}")
|
|
|
|
# logger.debug(f"Updated eos_idx values: {eos_idx}")
|
|
|
|
|
|
|
|
|
|
|
|
# Debug sequence indices creation
|
|
|
|
# Debug sequence indices creation
|
|
|
|
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
|
|
|
|
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
|
|
|
|
logger.debug("\n1️⃣1️⃣ Sequence Indices:")
|
|
|
|
# logger.debug("\n1️⃣1️⃣ Sequence Indices:")
|
|
|
|
logger.debug(f"sequence_indices shape: {sequence_indices.shape}")
|
|
|
|
# logger.debug(f"sequence_indices shape: {sequence_indices.shape}")
|
|
|
|
logger.debug(f"Sample sequence_indices (first row):\n{sequence_indices[0]}")
|
|
|
|
# logger.debug(f"Sample sequence_indices (first row):\n{sequence_indices[0]}")
|
|
|
|
|
|
|
|
|
|
|
|
# Debug final mask creation
|
|
|
|
# Debug final mask creation
|
|
|
|
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
|
|
|
|
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
|
|
|
|
logger.debug("\n1️⃣2️⃣ Final Completion Mask:")
|
|
|
|
# logger.debug("\n1️⃣2️⃣ Final Completion Mask:")
|
|
|
|
logger.debug(f"completion_mask shape: {completion_mask.shape}")
|
|
|
|
# logger.debug(f"completion_mask shape: {completion_mask.shape}")
|
|
|
|
logger.debug(f"Sample mask (first sequence):\n{completion_mask[0]}")
|
|
|
|
# logger.debug(f"Sample mask (first sequence):\n{completion_mask[0]}")
|
|
|
|
logger.debug("Mask statistics:")
|
|
|
|
# logger.debug("Mask statistics:")
|
|
|
|
logger.debug(f"- Total 1s: {completion_mask.sum().item()}")
|
|
|
|
# logger.debug(f"- Total 1s: {completion_mask.sum().item()}")
|
|
|
|
logger.debug(f"- Average sequence length: {completion_mask.sum(dim=1).float().mean().item():.2f}")
|
|
|
|
# logger.debug(f"- Average sequence length: {completion_mask.sum(dim=1).float().mean().item():.2f}")
|
|
|
|
|
|
|
|
|
|
|
|
# Add a final validation check
|
|
|
|
# Add a final validation check
|
|
|
|
logger.debug("\n7️⃣ Final Validation:")
|
|
|
|
# logger.debug("\n7️⃣ Final Validation:")
|
|
|
|
logger.debug(f"Input shape: {completion_ids.shape}")
|
|
|
|
# logger.debug(f"Input shape: {completion_ids.shape}")
|
|
|
|
logger.debug(f"Mask shape: {completion_mask.shape}")
|
|
|
|
# logger.debug(f"Mask shape: {completion_mask.shape}")
|
|
|
|
|
|
|
|
|
|
|
|
# Concatenate prompt_mask with completion_mask for logit computation
|
|
|
|
# Concatenate prompt_mask with completion_mask for logit computation
|
|
|
|
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C)
|
|
|
|
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C)
|
|
|
@ -1329,7 +1327,7 @@ class _UnslothGRPOTrainer(Trainer):
|
|
|
|
|
|
|
|
|
|
|
|
# Log the metrics
|
|
|
|
# Log the metrics
|
|
|
|
reward_per_func = rewards_per_func.mean(0)
|
|
|
|
reward_per_func = rewards_per_func.mean(0)
|
|
|
|
logger.debug("rewards_per_func:", reward_per_func)
|
|
|
|
# logger.debug("rewards_per_func:", reward_per_func)
|
|
|
|
for i, reward_func in enumerate(self.reward_funcs):
|
|
|
|
for i, reward_func in enumerate(self.reward_funcs):
|
|
|
|
if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
|
|
|
|
if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
|
|
|
|
reward_func_name = reward_func.config._name_or_path.split("/")[-1]
|
|
|
|
reward_func_name = reward_func.config._name_or_path.split("/")[-1]
|
|
|
@ -1640,7 +1638,7 @@ class UnslothGRPOTrainer(_UnslothGRPOTrainer):
|
|
|
|
from transformers import __version__ as transformers_version
|
|
|
|
from transformers import __version__ as transformers_version
|
|
|
|
|
|
|
|
|
|
|
|
if Version(transformers_version) <= Version("4.45.2"):
|
|
|
|
if Version(transformers_version) <= Version("4.45.2"):
|
|
|
|
logger.debug(
|
|
|
|
print(
|
|
|
|
"**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n"
|
|
|
|
"**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n"
|
|
|
|
"`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`"
|
|
|
|
"`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`"
|
|
|
|
)
|
|
|
|
)
|
|
|
|