@@ -70,15 +70,11 @@ torch_compile_options = {
 )
 
 
 def selective_log_softmax(logits, index):
     logits = logits.to(torch.float32)
-    selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(
-        -1
-    )
+    selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
     # loop to reduce peak mem consumption
     # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
     logsumexp_values = torch.logsumexp(logits, dim=-1)
-    per_token_logps = (
-        selected_logits - logsumexp_values
-    )  # log_softmax(x_i) = x_i - logsumexp(x)
+    per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
     return per_token_logps
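
For context, a minimal standalone sketch of the gather-based log-softmax this hunk reformats; the helper name and toy tensors below are illustrative, not part of the patch.

import torch

def selective_log_softmax_demo(logits, index):
    # log_softmax(x_i) = x_i - logsumexp(x): gather the chosen-token logit,
    # then subtract the per-position logsumexp instead of building the full log_softmax.
    logits = logits.to(torch.float32)
    selected = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
    return selected - torch.logsumexp(logits, dim=-1)

# toy check against the direct computation
logits = torch.randn(2, 5, 11)          # (batch, seq, vocab)
tokens = torch.randint(0, 11, (2, 5))   # (batch, seq)
expected = torch.log_softmax(logits, dim=-1).gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(selective_log_softmax_demo(logits, tokens), expected, atol=1e-5)
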
@@ -139,17 +135,11 @@ class UnslothEfficientGRPO(torch.autograd.Function):
         scaler=None,
         n_chunks=1,
     ):
-        def compute_loss(
-            new_hidden_states, old_hidden_states, input_ids, mask, advantages, scaling
-        ):
+        def compute_loss(new_hidden_states, old_hidden_states, input_ids, mask, advantages, scaling):
             new_logits = torch.matmul(new_hidden_states, lm_head.t())
-            new_logits = new_logits[
-                :, :-1, :
-            ]  # exclude the last logit: it corresponds to the next token pred
+            new_logits = new_logits[:, :-1, :]  # exclude the last logit: it corresponds to the next token pred
             old_logits = torch.matmul(old_hidden_states, lm_head.t())
-            old_logits = old_logits[
-                :, :-1, :
-            ]  # exclude the last logit: it corresponds to the next token pred
+            old_logits = old_logits[:, :-1, :]  # exclude the last logit: it corresponds to the next token pred
             loss, completion_length, mean_kl = grpo_compute_loss(
                 old_logits,
                 new_logits,
@@ -311,11 +301,7 @@ def grpo_accumulated_loss(
         n_chunks = bsz
     n_chunks = factors[min(np.searchsorted(factors, n_chunks), len(factors) - 1)]
 
-    mixed_dtype = (
-        torch.float16
-        if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16"
-        else torch.bfloat16
-    )
+    mixed_dtype = torch.float16 if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16" else torch.bfloat16
     os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
 
     completion_input_ids = input_ids[:, -logits_to_keep:]
@@ -324,18 +310,12 @@ def grpo_accumulated_loss(
     with torch.amp.autocast(device_type="cuda", dtype=mixed_dtype):
         with (
             torch.inference_mode(),
-            trainer.accelerator.unwrap_model(
-                trainer.model, keep_fp32_wrapper=False
-            ).disable_adapter(),
+            trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper=False).disable_adapter(),
         ):
-            old_hidden_states = trainer.model(
-                input_ids=input_ids, logits_to_keep=logits_to_keep + 1
-            ).logits
+            old_hidden_states = trainer.model(input_ids=input_ids, logits_to_keep=logits_to_keep + 1).logits
         pass
 
-        new_hidden_states = trainer.model(
-            input_ids=input_ids, logits_to_keep=logits_to_keep + 1
-        ).logits
+        new_hidden_states = trainer.model(input_ids=input_ids, logits_to_keep=logits_to_keep + 1).logits
 
         loss, completion_length, mean_kl = UnslothEfficientGRPO.apply(
             new_hidden_states,
@@ -352,13 +332,9 @@ def grpo_accumulated_loss(
 
     # Old non efficient code path
    new_logits = torch.matmul(new_hidden_states, lm_head.t())
-    new_logits = new_logits[
-        :, :-1, :
-    ]  # exclude the last logit: it corresponds to the next token pred
+    new_logits = new_logits[:, :-1, :]  # exclude the last logit: it corresponds to the next token pred
     old_logits = torch.matmul(old_hidden_states, lm_head.t())
-    old_logits = old_logits[
-        :, :-1, :
-    ]  # exclude the last logit: it corresponds to the next token pred
+    old_logits = old_logits[:, :-1, :]  # exclude the last logit: it corresponds to the next token pred
     loss, completion_length, mean_kl = grpo_compute_loss(
         old_logits,
         new_logits,
@@ -824,24 +800,14 @@ class _UnslothGRPOTrainer(Trainer):
         reward_funcs: Union[RewardFunc, list[RewardFunc]],
         args: GRPOConfig = None,
         train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
-        eval_dataset: Optional[
-            Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]
-        ] = None,
+        eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
         processing_class: Optional[PreTrainedTokenizerBase] = None,
-        reward_processing_classes: Optional[
-            Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]
-        ] = None,
+        reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
         callbacks: Optional[list[TrainerCallback]] = None,
-        optimizers: tuple[
-            Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]
-        ] = (None, None),
+        optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
         peft_config: Optional["PeftConfig"] = None,
     ):
-        if (
-            hasattr(model, "vllm_engine")
-            and hasattr(args, "use_vllm")
-            and (getattr(args, "use_vllm", False) == False)
-        ):
+        if hasattr(model, "vllm_engine") and hasattr(args, "use_vllm") and (getattr(args, "use_vllm", False) == False):
             args.use_vllm = True
         # Args
         if args is None:
@@ -855,11 +821,7 @@ class _UnslothGRPOTrainer(Trainer):
         if isinstance(model, str):
             model_id = model
             torch_dtype = model_init_kwargs.get("torch_dtype")
-            if (
-                isinstance(torch_dtype, torch.dtype)
-                or torch_dtype == "auto"
-                or torch_dtype is None
-            ):
+            if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
                 pass  # torch_dtype is already a torch.dtype or "auto" or None
             elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
                 torch_dtype = getattr(torch, torch_dtype)
@@ -871,9 +833,7 @@ class _UnslothGRPOTrainer(Trainer):
                 )
             # Disable caching if gradient checkpointing is enabled (not supported)
             model_init_kwargs["use_cache"] = (
-                False
-                if args.gradient_checkpointing
-                else model_init_kwargs.get("use_cache")
+                False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
             )
             model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
         else:
@@ -889,9 +849,7 @@ class _UnslothGRPOTrainer(Trainer):
 
         # Reference model
         if is_deepspeed_zero3_enabled():
-            self.ref_model = AutoModelForCausalLM.from_pretrained(
-                model_id, **model_init_kwargs
-            )
+            self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
         elif not is_peft_model(model):
             # If PEFT configuration is not provided, create a reference model based on the initial model.
             self.ref_model = create_reference_model(model)
@@ -902,9 +860,7 @@ class _UnslothGRPOTrainer(Trainer):
 
         # Processing class
         if processing_class is None:
-            processing_class = AutoTokenizer.from_pretrained(
-                model.config._name_or_path, padding_side="left"
-            )
+            processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
 
         # Reward functions
         if not isinstance(reward_funcs, list):
@@ -934,22 +890,14 @@ class _UnslothGRPOTrainer(Trainer):
             reward_processing_classes = [reward_processing_classes]
         else:
             if len(reward_processing_classes) != len(reward_funcs):
-                raise ValueError(
-                    "The number of reward processing classes must match the number of reward functions."
-                )
+                raise ValueError("The number of reward processing classes must match the number of reward functions.")
 
-        for i, (reward_processing_class, reward_func) in enumerate(
-            zip(reward_processing_classes, reward_funcs)
-        ):
+        for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
             if isinstance(reward_func, PreTrainedModel):
                 if reward_processing_class is None:
-                    reward_processing_class = AutoTokenizer.from_pretrained(
-                        reward_func.config._name_or_path
-                    )
+                    reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
                 if reward_processing_class.pad_token_id is None:
-                    reward_processing_class.pad_token = (
-                        reward_processing_class.eos_token
-                    )
+                    reward_processing_class.pad_token = reward_processing_class.eos_token
                 # The reward model computes the reward for the latest non-padded token in the input sequence.
                 # So it's important to set the pad token ID to the padding token ID of the processing class.
                 reward_func.config.pad_token_id = reward_processing_class.pad_token_id
@@ -962,9 +910,7 @@ class _UnslothGRPOTrainer(Trainer):
 
         # Training arguments
         self.max_prompt_length = args.max_prompt_length
-        self.max_completion_length = (
-            args.max_completion_length
-        )  # = |o_i| in the GRPO paper
+        self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
         self.num_generations = args.num_generations  # = G in the GRPO paper
         self.use_vllm = args.use_vllm
         self.use_agentic_generate = args.use_agentic_generate
@@ -997,11 +943,7 @@ class _UnslothGRPOTrainer(Trainer):
         # Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
         num_processes = self.accelerator.num_processes
         global_batch_size = args.per_device_train_batch_size * num_processes
-        possible_values = [
-            n_gen
-            for n_gen in range(2, global_batch_size + 1)
-            if (global_batch_size) % n_gen == 0
-        ]
+        possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
         if self.num_generations not in possible_values:
             raise ValueError(
                 f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
@@ -1010,11 +952,7 @@ class _UnslothGRPOTrainer(Trainer):
             )
         if self.args.eval_strategy != "no":
             global_batch_size = args.per_device_eval_batch_size * num_processes
-            possible_values = [
-                n_gen
-                for n_gen in range(2, global_batch_size + 1)
-                if (global_batch_size) % n_gen == 0
-            ]
+            possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
             if self.num_generations not in possible_values:
                 raise ValueError(
                     f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
@@ -1059,22 +997,14 @@ class _UnslothGRPOTrainer(Trainer):
             if self.is_deepspeed_enabled:
                 self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
             else:
-                self.ref_model = self.accelerator.prepare_model(
-                    self.ref_model, evaluation_mode=True
-                )
+                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
 
         if args.sync_ref_model:
-            self.add_callback(
-                SyncRefModelCallback(
-                    ref_model=self.ref_model, accelerator=self.accelerator
-                )
-            )
+            self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
 
         for i, reward_func in enumerate(self.reward_funcs):
             if isinstance(reward_func, PreTrainedModel):
-                self.reward_funcs[i] = self.accelerator.prepare_model(
-                    reward_func, evaluation_mode=True
-                )
+                self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
 
     def _set_signature_columns_if_needed(self):
         # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
@@ -1089,27 +1019,21 @@ class _UnslothGRPOTrainer(Trainer):
         # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
         # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
         # preventing discrepancies in group formation.
-        return RepeatRandomSampler(
-            self.train_dataset, self.num_generations, seed=self.args.seed
-        )
+        return RepeatRandomSampler(self.train_dataset, self.num_generations, seed=self.args.seed)
 
     def _get_eval_sampler(self, eval_dataset) -> Sampler:
         # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
         # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
         # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
         # preventing discrepancies in group formation.
-        return RepeatRandomSampler(
-            eval_dataset, self.num_generations, seed=self.args.seed
-        )
+        return RepeatRandomSampler(eval_dataset, self.num_generations, seed=self.args.seed)
 
     # Get the per-token log probabilities for the completions for the model and the reference model
     def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
         return None  # Unsloth efficient GRPO
         if not hasattr(self, "_autocast_dtype"):
             self._autocast_dtype = (
-                torch.float16
-                if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16"
-                else torch.bfloat16
+                torch.float16 if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16" else torch.bfloat16
             )
         with torch.amp.autocast(device_type="cuda", dtype=self._autocast_dtype):
             # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
@@ -1118,9 +1042,7 @@ class _UnslothGRPOTrainer(Trainer):
                 attention_mask=attention_mask,
                 logits_to_keep=logits_to_keep + 1,
             ).logits
-            logits = logits[
-                :, :-1, :
-            ]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+            logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
 
             input_ids = input_ids[:, -logits_to_keep:]
             # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
@@ -1133,15 +1055,10 @@ class _UnslothGRPOTrainer(Trainer):
     def _move_model_to_vllm(self, *args, **kwargs):
         return None
 
-    def _prepare_inputs(
-        self, inputs: dict[str, Union[torch.Tensor, Any]]
-    ) -> dict[str, Union[torch.Tensor, Any]]:
+    def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
         device = self.accelerator.device
         prompts = [x["prompt"] for x in inputs]
-        prompts_text = [
-            maybe_apply_chat_template(example, self.processing_class)["prompt"]
-            for example in inputs
-        ]
+        prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
         prompt_inputs = self.processing_class(
             prompts_text,
             return_tensors="pt",
@@ -1174,9 +1091,7 @@ class _UnslothGRPOTrainer(Trainer):
                     prompts_text,
                     sampling_params=self.sampling_params,
                     use_tqdm=False,
-                    lora_request=self.model.load_lora(
-                        "grpo_trainer_lora_model", load_tensors=True
-                    ),
+                    lora_request=self.model.load_lora("grpo_trainer_lora_model", load_tensors=True),
                 )
                 if self.use_agentic_generate:
                     agentic_outputs = self.model.agentic_generate(
@@ -1201,11 +1116,7 @@ class _UnslothGRPOTrainer(Trainer):
                     ).to(device)
                 else:
                     outputs = generate_fn(all_prompts_text)
-                    completion_ids = [
-                        out.token_ids
-                        for completions in outputs
-                        for out in completions.outputs
-                    ]
+                    completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
             else:
                 completion_ids = [None] * len(all_prompts_text)
             # Broadcast the completions from the main process to all processes, ensuring each process receives its
@@ -1218,18 +1129,12 @@ class _UnslothGRPOTrainer(Trainer):
             completion_ids = completion_ids[process_slice]
 
             # Pad the completions, and concatenate them with the prompts
-            completion_ids = [
-                torch.tensor(ids, device=device) for ids in completion_ids
-            ]
-            completion_ids = pad(
-                completion_ids, padding_value=self.processing_class.pad_token_id
-            )
+            completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
+            completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
             prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
         else:
             # Regular generation path
-            with unwrap_model_for_generation(
-                self.model, self.accelerator
-            ) as unwrapped_model:
+            with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
                 prompt_completion_ids = unwrapped_model.generate(
                     prompt_ids,
                     attention_mask=prompt_mask,
@@ -1244,21 +1149,15 @@ class _UnslothGRPOTrainer(Trainer):
         if not self.use_agentic_generate:
             # Mask everything after the first EOS token
             is_eos = completion_ids == self.processing_class.eos_token_id
-            eos_idx = torch.full(
-                (is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device
-            )
+            eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
             eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
-            sequence_indices = torch.arange(is_eos.size(1), device=device).expand(
-                is_eos.size(0), -1
-            )
+            sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
             completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
 
         # Concatenate prompt_mask with completion_mask for logit computation
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B*G, P+C)
 
-        logits_to_keep = completion_ids.size(
-            1
-        )  # we only need to compute the logits for the completion tokens
+        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
 
         # this does nothing
         with (
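
As an aside, the EOS-masking lines above can be checked in isolation; the token ids and eos_id below are made up for illustration and are not part of the patch.

import torch

eos_id = 2
completion_ids = torch.tensor([[5, 7, 2, 9, 9],
                               [4, 6, 8, 3, 1]])  # second row never emits EOS
is_eos = completion_ids == eos_id
# default to "no EOS found" = sequence length, then overwrite where an EOS exists
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1)).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
print(completion_mask)  # tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
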
@@ -1280,9 +1179,7 @@ class _UnslothGRPOTrainer(Trainer):
                     logits_to_keep,
                 )
             else:
-                with self.accelerator.unwrap_model(
-                    self.model, keep_fp32_wrapper=False
-                ).disable_adapter():
+                with self.accelerator.unwrap_model(self.model, keep_fp32_wrapper=False).disable_adapter():
                     ref_per_token_logps = self._get_per_token_logps(
                         self.model,
                         prompt_completion_ids,
@@ -1292,42 +1189,25 @@ class _UnslothGRPOTrainer(Trainer):
 
         # Decode the generated completions
         if not self.use_agentic_generate:
-            completions_text = self.processing_class.batch_decode(
-                completion_ids, skip_special_tokens=True
-            )
+            completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
             if is_conversational(inputs[0]):
                 completions = []
                 for prompt, completion in zip(prompts, completions_text):
-                    bootstrap = (
-                        prompt.pop()["content"]
-                        if prompt[-1]["role"] == "assistant"
-                        else ""
-                    )
-                    completions.append(
-                        [{"role": "assistant", "content": bootstrap + completion}]
-                    )
+                    bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
+                    completions.append([{"role": "assistant", "content": bootstrap + completion}])
             else:
                 completions = completions_text
         else:
             completions = full_chats
 
-        rewards_per_func = torch.zeros(
-            len(prompts), len(self.reward_funcs), device=device
-        )
+        rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
         for i, (reward_func, reward_processing_class) in enumerate(
             zip(self.reward_funcs, self.reward_processing_classes)
         ):
-            if isinstance(
-                reward_func, nn.Module
-            ):  # Module instead of PretrainedModel for compat with compiled models
+            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
                 if is_conversational(inputs[0]):
-                    messages = [
-                        {"messages": p + c} for p, c in zip(prompts, completions)
-                    ]
-                    texts = [
-                        apply_chat_template(x, reward_processing_class)["text"]
-                        for x in messages
-                    ]
+                    messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
+                    texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
                 else:
                     texts = [p + c for p, c in zip(prompts, completions)]
                 reward_inputs = reward_processing_class(
@@ -1343,36 +1223,25 @@ class _UnslothGRPOTrainer(Trainer):
                     torch.amp.autocast(
                         device_type="cuda",
                         dtype=torch.float16
-                        if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16")
-                        == "fp16"
+                        if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16"
                         else torch.bfloat16,
                     )
                     if not torch.is_autocast_enabled("cuda")
                     else nullcontext(),
                 ):
-                    rewards_per_func[:, i] = reward_func(**reward_inputs).logits[
-                        :, 0
-                    ]  # Shape (B*G,)
+                    rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
             else:
                 # Repeat all input columns (but "prompt" and "completion") to match the number of generations
                 keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
-                reward_kwargs = {
-                    key: [example[key] for example in inputs] for key in keys
-                }
-                output_reward_func = reward_func(
-                    prompts=prompts, completions=completions, **reward_kwargs
-                )
-                rewards_per_func[:, i] = torch.tensor(
-                    output_reward_func, dtype=torch.float32, device=device
-                )
+                reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
+                output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
+                rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
 
         # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
         # completions may be distributed across processes
         rewards_per_func = gather(rewards_per_func)
 
         # Apply weights to each reward function's output and sum
-        rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(
-            dim=1
-        )
+        rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)
         # else:
         #     reward_fn = self.reward_funcs[0]
@@ -1391,12 +1260,8 @@ class _UnslothGRPOTrainer(Trainer):
         std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
 
         # Normalize the rewards to compute the advantages
-        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(
-            self.num_generations, dim=0
-        )
-        std_grouped_rewards = std_grouped_rewards.repeat_interleave(
-            self.num_generations, dim=0
-        )
+        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+        std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
         advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
 
         # Slice to keep only the local part of the data
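
As an aside, the group-relative advantage computed above can be reproduced on made-up rewards; the numbers and group size below are illustrative only.

import torch

num_generations = 4  # G completions per prompt
rewards = torch.tensor([1.0, 0.0, 0.5, 0.5,   # group for prompt 0
                        2.0, 2.0, 1.0, 3.0])  # group for prompt 1
mean_g = rewards.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations, dim=0)
std_g = rewards.view(-1, num_generations).std(dim=1).repeat_interleave(num_generations, dim=0)
advantages = (rewards - mean_g) / (std_g + 1e-4)  # same 1e-4 epsilon as the trainer
print(advantages.view(-1, num_generations))
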
@@ -1410,15 +1275,11 @@ class _UnslothGRPOTrainer(Trainer):
         reward_per_func = rewards_per_func.mean(0)
         print("rewards_per_func:", reward_per_func)
         for i, reward_func in enumerate(self.reward_funcs):
-            if isinstance(
-                reward_func, nn.Module
-            ):  # Module instead of PretrainedModel for compat with compiled models
+            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
                 reward_func_name = reward_func.config._name_or_path.split("/")[-1]
             else:
                 reward_func_name = reward_func.__name__
-            self._metrics[f"rewards/{reward_func_name}"].append(
-                reward_per_func[i].item()
-            )
+            self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
 
         self._metrics["reward"].append(rewards.mean().item())
         self._metrics["reward_std"].append(std_grouped_rewards.mean().item())
@@ -1451,9 +1312,7 @@ class _UnslothGRPOTrainer(Trainer):
             "advantages": advantages,
         }
 
-    def compute_loss(
-        self, model, inputs, return_outputs=False, num_items_in_batch=None
-    ):
+    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         if return_outputs:
             raise ValueError("The GRPOTrainer does not support returning outputs")
         # Compute the per-token log probabilities for the model
@@ -1467,14 +1326,10 @@ class _UnslothGRPOTrainer(Trainer):
         bsz, qlen = input_ids.shape
         # attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
         attention_mask = None
-        logits_to_keep = completion_ids.size(
-            1
-        )  # we only need to compute the logits for the completion tokens
+        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
         _input_ids = input_ids
         _logits_to_keep = logits_to_keep
-        per_token_logps = self._get_per_token_logps(
-            model, input_ids, attention_mask, logits_to_keep
-        )
+        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
 
         # Compute the KL divergence between the model and the reference model
         ref_per_token_logps = inputs["ref_per_token_logps"]
@@ -1529,9 +1384,7 @@ class _UnslothGRPOTrainer(Trainer):
         return loss, None, None
 
     def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
-        metrics = {
-            key: sum(val) / len(val) for key, val in self._metrics.items()
-        }  # average the metrics
+        metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()}  # average the metrics
 
         # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
         # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
@@ -1565,9 +1418,7 @@ class _UnslothGRPOTrainer(Trainer):
         if not self.is_world_process_zero():
             return
 
-        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(
-            self.model.config._name_or_path
-        ):
+        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
             base_model = self.model.config._name_or_path
         else:
             base_model = None
@@ -1596,9 +1447,7 @@ class _UnslothGRPOTrainer(Trainer):
             hub_model_id=self.hub_model_id,
             dataset_name=dataset_name,
             tags=tags,
-            wandb_url=wandb.run.get_url()
-            if is_wandb_available() and wandb.run is not None
-            else None,
+            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
             comet_url=get_comet_experiment_url(),
             trainer_name="GRPO",
             trainer_citation=citation,
@@ -1735,10 +1584,7 @@ class UnslothGRPOTrainer(_UnslothGRPOTrainer):
         args.fp16 = float16
         args.bf16 = not float16
         os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16" if float16 else "bf16"
-        if (
-            getattr(args, "eval_dataset", None) is not None
-            and getattr(args, "eval_strategy", "no") == "no"
-        ):
+        if getattr(args, "eval_dataset", None) is not None and getattr(args, "eval_strategy", "no") == "no":
             args.eval_strategy = "steps"
             if getattr(args, "eval_steps", None) is None:
                 args.eval_steps = 0.1
@@ -1755,10 +1601,7 @@ class UnslothGRPOTrainer(_UnslothGRPOTrainer):
         eval_bsz = getattr(args, "per_device_eval_batch_size", 8)
         if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz:
             args.per_device_eval_batch_size = args.per_device_train_batch_size
-        if (
-            getattr(args, "eval_accumulation_steps", None) is None
-            and ga_steps is not None
-        ):
+        if getattr(args, "eval_accumulation_steps", None) is None and ga_steps is not None:
             args.eval_accumulation_steps = ga_steps
         fp16_full_eval = getattr(args, "fp16_full_eval", False)
         bf16_full_eval = getattr(args, "bf16_full_eval", False)
@@ -1787,9 +1630,7 @@ class UnslothGRPOTrainer(_UnslothGRPOTrainer):
         if "processing_class" in locals():
             if hasattr(processing_class, "padding_side"):
                 processing_class.padding_side = "right"
-            if hasattr(processing_class, "tokenizer") and hasattr(
-                processing_class.tokenizer, "padding_side"
-            ):
+            if hasattr(processing_class, "tokenizer") and hasattr(processing_class.tokenizer, "padding_side"):
                 processing_class.tokenizer.padding_side = "right"
         other_metrics = []
         if not isinstance(reward_funcs, list):