style: change line length to 119, organize imports

Branch: main
Author: thinhlpg, 1 month ago
Parent: abb18b10d8
Commit: 04593fa8fd

@@ -3,7 +3,7 @@
 # make sure to test the local checkout in scripts and not the pre-installed one
 export PYTHONPATH = src
-check_dirs := src
+check_dirs := . src notebooks scripts
 # Development dependencies
 install:

@@ -170,9 +170,7 @@ def test_lora_functionality(model, tokenizer, lora_path):
     # Sample with base model
     logger.info("Generating with base model...")
-    sampling_params = get_sampling_params(
-        temperature=0.7
-    )  # Higher temp to make differences more obvious
+    sampling_params = get_sampling_params(temperature=0.7)  # Higher temp to make differences more obvious
     base_response = model.fast_generate(
         [formatted_prompt],
         sampling_params=sampling_params,
@@ -217,12 +215,8 @@ def test_lora_functionality(model, tokenizer, lora_path):
 logger.info("-" * 40)
 if are_identical:
-    logger.warning(
-        "\nWARNING: LoRA adapter does not seem to change the model's output"
-    )
-    logger.warning(
-        "This could indicate that the LoRA adapter is not being properly applied"
-    )
+    logger.warning("\nWARNING: LoRA adapter does not seem to change the model's output")
+    logger.warning("This could indicate that the LoRA adapter is not being properly applied")
 else:
     logger.info("\nLoRA adapter is working as expected (outputs are different)")
@@ -257,9 +251,7 @@ def evaluate_model(
 # Prioritize the directory passed from the shell script if available
 if trainer_dir and os.path.isdir(trainer_dir):
     trainer_output_dir = os.path.abspath(trainer_dir)
-    logger.info(
-        f"Using trainer directory passed from arguments: {trainer_output_dir}"
-    )
+    logger.info(f"Using trainer directory passed from arguments: {trainer_output_dir}")
 else:
     logger.warning(
         f"Trainer directory not provided or invalid: {trainer_dir}. Attempting to determine automatically."
@@ -274,9 +266,7 @@ def evaluate_model(
     # If a LoRA path exists (provided or found), get its parent's parent
     checkpoint_dir = os.path.dirname(os.path.abspath(temp_lora_path))
     trainer_output_dir = os.path.dirname(checkpoint_dir)
-    logger.info(
-        f"Determined trainer directory from LoRA path ({temp_lora_path}): {trainer_output_dir}"
-    )
+    logger.info(f"Determined trainer directory from LoRA path ({temp_lora_path}): {trainer_output_dir}")
 else:
     # If no LoRA path, default to current directory (should ideally not happen if called from eval.sh)
     trainer_output_dir = os.path.abspath(".")
@@ -290,22 +280,16 @@ def evaluate_model(
 detected_checkpoint = find_latest_checkpoint(search_dir=trainer_output_dir)
 if detected_checkpoint:
     lora_path = detected_checkpoint
-    logger.info(
-        f"Auto-detected latest checkpoint in {trainer_output_dir}: {lora_path}"
-    )
+    logger.info(f"Auto-detected latest checkpoint in {trainer_output_dir}: {lora_path}")
 else:
-    logger.warning(
-        f"No checkpoint found in {trainer_output_dir} for auto-detection. Evaluating base model."
-    )
+    logger.warning(f"No checkpoint found in {trainer_output_dir} for auto-detection. Evaluating base model.")
     lora_path = None
 model_type = "LoRA" if lora_path else "Base"
 logger.info(f"\n{'=' * 50}")
 logger.info(f"Starting evaluation of {model_type} model")
-logger.info(
-    f"Trainer Output Directory: {trainer_output_dir}"
-)  # Log the final directory
+logger.info(f"Trainer Output Directory: {trainer_output_dir}")  # Log the final directory
 logger.info(f"{'=' * 50}")
 # --- Create eval_logs directory ---
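
Note: find_latest_checkpoint itself is not part of this diff. A minimal sketch of what such a helper could look like, assuming Hugging Face-style checkpoint-<step> folders inside the trainer output directory (the signature and behavior here are assumptions, not the project's actual implementation):

import os
import re


def find_latest_checkpoint(search_dir="."):
    # Collect (step, path) pairs for every checkpoint-<step> subdirectory.
    candidates = []
    for name in os.listdir(search_dir):
        match = re.fullmatch(r"checkpoint-(\d+)", name)
        path = os.path.join(search_dir, name)
        if match and os.path.isdir(path):
            candidates.append((int(match.group(1)), path))
    # Highest step wins; None when no checkpoint directory is present.
    return max(candidates)[1] if candidates else None
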
@@ -319,18 +303,14 @@ def evaluate_model(
 # Fallback to current directory if creation fails
 eval_log_dir = os.path.abspath("./eval_logs")
 os.makedirs(eval_log_dir, exist_ok=True)
-logger.warning(
-    f"Fell back to creating eval_logs in current directory: {eval_log_dir}"
-)
+logger.warning(f"Fell back to creating eval_logs in current directory: {eval_log_dir}")
 # Create file names based on model type
 model_prefix = "lora" if lora_path else "base"
 timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
 # Define all output file paths
-eval_log_file = os.path.join(
-    eval_log_dir, f"{model_prefix}_model_eval_{timestamp}.log"
-)
+eval_log_file = os.path.join(eval_log_dir, f"{model_prefix}_model_eval_{timestamp}.log")
 output_file = os.path.join(eval_log_dir, f"{model_prefix}_model_results.txt")
 debug_file = os.path.join(eval_log_dir, f"{model_prefix}_model_results_debug.json")
@@ -343,12 +323,8 @@ def evaluate_model(
 if lora_path:
     lora_request = model.load_lora(lora_path)
     load_time = time.time() - start_time
-    logger.info(
-        f"LoRA adapter loaded in {load_time:.2f} seconds: {lora_request}"
-    )
-    responses = model.fast_generate(
-        inputs, sampling_params=sampling_params, lora_request=lora_request
-    )
+    logger.info(f"LoRA adapter loaded in {load_time:.2f} seconds: {lora_request}")
+    responses = model.fast_generate(inputs, sampling_params=sampling_params, lora_request=lora_request)
 else:
     # For base model, add additional logging
     logger.info("Generating with base model (no LoRA)")
@@ -373,9 +349,7 @@ def evaluate_model(
     return model.fast_generate(inputs, sampling_params=verifier_params)
 # Prepare the verification function
-verify_fn = rl_helpers.build_reward_correctness_fn(
-    verifier_generate_fn, tokenizer, log_file=eval_log_file
-)
+verify_fn = rl_helpers.build_reward_correctness_fn(verifier_generate_fn, tokenizer, log_file=eval_log_file)
 # Get the dataset and prepare questions and answers
 train_dataset, test_dataset = rl_helpers.get_qa_dataset()
@@ -438,9 +412,7 @@ def evaluate_model(
     f.write(f"Results saved to: {output_file}\n")
     f.write(f"Debug data saved to: {debug_file}\n\n")
-logger.info(
-    f"Evaluation completed. Results saved to {output_file} and {debug_file}"
-)
+logger.info(f"Evaluation completed. Results saved to {output_file} and {debug_file}")
 return results
@@ -463,9 +435,7 @@ def compare_models(lora_path, temperature=0.5, output_file=None, trainer_dir=None):
     lora_path = detected_checkpoint
     logger.info(f"Auto-detected latest checkpoint: {lora_path}")
 else:
-    logger.warning(
-        "No checkpoint found for auto-detection. Skipping comparison."
-    )
+    logger.warning("No checkpoint found for auto-detection. Skipping comparison.")
     return
 # Set up output directory in the checkpoint directory
@@ -560,9 +530,7 @@ if __name__ == "__main__":
     default="auto",
     help="Path to LoRA weights (use 'auto' for auto-detection)",
 )
-parser.add_argument(
-    "--temperature", type=float, default=0.5, help="Sampling temperature"
-)
+parser.add_argument("--temperature", type=float, default=0.5, help="Sampling temperature")
 parser.add_argument(
     "--output_file",
     type=str,
@@ -609,24 +577,16 @@ if __name__ == "__main__":
     update_log_path(eval_log_dir)
     logger.info(f"Logs will be saved to both ./logs and {eval_log_dir}")
 except ImportError:
-    logger.info(
-        "Config's update_log_path not available, using default logging"
-    )
+    logger.info("Config's update_log_path not available, using default logging")
 if trainer_dir:
     logger.info(f"Using trainer directory: {trainer_dir}")
-    logger.info(
-        f"All evaluation files will be stored in: {os.path.join(trainer_dir, 'eval_logs')}"
-    )
+    logger.info(f"All evaluation files will be stored in: {os.path.join(trainer_dir, 'eval_logs')}")
 else:
-    logger.warning(
-        "No trainer directory found, will attempt to determine during evaluation"
-    )
+    logger.warning("No trainer directory found, will attempt to determine during evaluation")
 logger.info(f"Starting model evaluation with temperature {args.temperature}")
-results = compare_models(
-    args.lora_path, args.temperature, args.output_file, trainer_dir=trainer_dir
-)
+results = compare_models(args.lora_path, args.temperature, args.output_file, trainer_dir=trainer_dir)
 if results:
     logger.info("Evaluation completed successfully")
     logger.info(f"Final improvement: {results['improvement']:.4f}")

@@ -498,7 +498,7 @@ You are a helpful assistant with tool calling capabilities."""
 try:
     with open(filepath, "w", encoding="utf-8") as f:
         f.write(f"{'=' * 80}\n")
-        f.write(f"DEEPSEARCH CHAT HISTORY\n")
+        f.write("DEEPSEARCH CHAT HISTORY\n")
         f.write(f"Model: {MODEL_NAME}\n")
         f.write(f"LoRA Path: {self.lora_path if self.lora_path else 'None'}\n")
         f.write(f"Temperature: {self.temperature}\n")

@@ -70,15 +70,11 @@ torch_compile_options = {
 )
 def selective_log_softmax(logits, index):
     logits = logits.to(torch.float32)
-    selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(
-        -1
-    )
+    selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
     # loop to reduce peak mem consumption
     # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
     logsumexp_values = torch.logsumexp(logits, dim=-1)
-    per_token_logps = (
-        selected_logits - logsumexp_values
-    )  # log_softmax(x_i) = x_i - logsumexp(x)
+    per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
    return per_token_logps
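
The selective_log_softmax hunk above relies on the identity log_softmax(x)_i = x_i - logsumexp(x). A small self-contained check with toy shapes (not part of the commit) that the gather-then-subtract form matches torch's built-in log_softmax:

import torch

# Hypothetical shapes: 2 sequences, 5 timesteps, vocabulary of 11 tokens.
logits = torch.randn(2, 5, 11)
index = torch.randint(0, 11, (2, 5))

# Gather the logit of each chosen token, then subtract logsumexp over the vocab.
selected = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
per_token_logps = selected - torch.logsumexp(logits, dim=-1)

# Same values as computing the full log_softmax and gathering afterwards.
reference = torch.log_softmax(logits, dim=-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(per_token_logps, reference, atol=1e-6)
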
@@ -139,17 +135,11 @@ class UnslothEfficientGRPO(torch.autograd.Function):
     scaler=None,
     n_chunks=1,
 ):
-    def compute_loss(
-        new_hidden_states, old_hidden_states, input_ids, mask, advantages, scaling
-    ):
+    def compute_loss(new_hidden_states, old_hidden_states, input_ids, mask, advantages, scaling):
         new_logits = torch.matmul(new_hidden_states, lm_head.t())
-        new_logits = new_logits[
-            :, :-1, :
-        ]  # exclude the last logit: it corresponds to the next token pred
+        new_logits = new_logits[:, :-1, :]  # exclude the last logit: it corresponds to the next token pred
         old_logits = torch.matmul(old_hidden_states, lm_head.t())
-        old_logits = old_logits[
-            :, :-1, :
-        ]  # exclude the last logit: it corresponds to the next token pred
+        old_logits = old_logits[:, :-1, :]  # exclude the last logit: it corresponds to the next token pred
         loss, completion_length, mean_kl = grpo_compute_loss(
             old_logits,
             new_logits,
@@ -311,11 +301,7 @@ def grpo_accumulated_loss(
 n_chunks = bsz
 n_chunks = factors[min(np.searchsorted(factors, n_chunks), len(factors) - 1)]
-mixed_dtype = (
-    torch.float16
-    if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16"
-    else torch.bfloat16
-)
+mixed_dtype = torch.float16 if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16" else torch.bfloat16
 os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
 completion_input_ids = input_ids[:, -logits_to_keep:]
@@ -324,18 +310,12 @@ def grpo_accumulated_loss(
 with torch.amp.autocast(device_type="cuda", dtype=mixed_dtype):
     with (
         torch.inference_mode(),
-        trainer.accelerator.unwrap_model(
-            trainer.model, keep_fp32_wrapper=False
-        ).disable_adapter(),
+        trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper=False).disable_adapter(),
     ):
-        old_hidden_states = trainer.model(
-            input_ids=input_ids, logits_to_keep=logits_to_keep + 1
-        ).logits
+        old_hidden_states = trainer.model(input_ids=input_ids, logits_to_keep=logits_to_keep + 1).logits
     pass
-    new_hidden_states = trainer.model(
-        input_ids=input_ids, logits_to_keep=logits_to_keep + 1
-    ).logits
+    new_hidden_states = trainer.model(input_ids=input_ids, logits_to_keep=logits_to_keep + 1).logits
     loss, completion_length, mean_kl = UnslothEfficientGRPO.apply(
         new_hidden_states,
@@ -352,13 +332,9 @@ def grpo_accumulated_loss(
 # Old non efficient code path
 new_logits = torch.matmul(new_hidden_states, lm_head.t())
-new_logits = new_logits[
-    :, :-1, :
-]  # exclude the last logit: it corresponds to the next token pred
+new_logits = new_logits[:, :-1, :]  # exclude the last logit: it corresponds to the next token pred
 old_logits = torch.matmul(old_hidden_states, lm_head.t())
-old_logits = old_logits[
-    :, :-1, :
-]  # exclude the last logit: it corresponds to the next token pred
+old_logits = old_logits[:, :-1, :]  # exclude the last logit: it corresponds to the next token pred
 loss, completion_length, mean_kl = grpo_compute_loss(
     old_logits,
     new_logits,
@@ -824,24 +800,14 @@ class _UnslothGRPOTrainer(Trainer):
     reward_funcs: Union[RewardFunc, list[RewardFunc]],
     args: GRPOConfig = None,
     train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
-    eval_dataset: Optional[
-        Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]
-    ] = None,
+    eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
     processing_class: Optional[PreTrainedTokenizerBase] = None,
-    reward_processing_classes: Optional[
-        Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]
-    ] = None,
+    reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
    callbacks: Optional[list[TrainerCallback]] = None,
-    optimizers: tuple[
-        Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]
-    ] = (None, None),
+    optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
     peft_config: Optional["PeftConfig"] = None,
 ):
-    if (
-        hasattr(model, "vllm_engine")
-        and hasattr(args, "use_vllm")
-        and (getattr(args, "use_vllm", False) == False)
-    ):
+    if hasattr(model, "vllm_engine") and hasattr(args, "use_vllm") and (getattr(args, "use_vllm", False) == False):
         args.use_vllm = True
     # Args
     if args is None:
@@ -855,11 +821,7 @@ class _UnslothGRPOTrainer(Trainer):
 if isinstance(model, str):
     model_id = model
     torch_dtype = model_init_kwargs.get("torch_dtype")
-    if (
-        isinstance(torch_dtype, torch.dtype)
-        or torch_dtype == "auto"
-        or torch_dtype is None
-    ):
+    if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
         pass  # torch_dtype is already a torch.dtype or "auto" or None
     elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
         torch_dtype = getattr(torch, torch_dtype)
@@ -871,9 +833,7 @@ class _UnslothGRPOTrainer(Trainer):
 )
 # Disable caching if gradient checkpointing is enabled (not supported)
 model_init_kwargs["use_cache"] = (
-    False
-    if args.gradient_checkpointing
-    else model_init_kwargs.get("use_cache")
+    False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
 )
 model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
 else:
@@ -889,9 +849,7 @@ class _UnslothGRPOTrainer(Trainer):
 # Reference model
 if is_deepspeed_zero3_enabled():
-    self.ref_model = AutoModelForCausalLM.from_pretrained(
-        model_id, **model_init_kwargs
-    )
+    self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
 elif not is_peft_model(model):
     # If PEFT configuration is not provided, create a reference model based on the initial model.
     self.ref_model = create_reference_model(model)
@@ -902,9 +860,7 @@ class _UnslothGRPOTrainer(Trainer):
 # Processing class
 if processing_class is None:
-    processing_class = AutoTokenizer.from_pretrained(
-        model.config._name_or_path, padding_side="left"
-    )
+    processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
 # Reward functions
 if not isinstance(reward_funcs, list):
@@ -934,22 +890,14 @@ class _UnslothGRPOTrainer(Trainer):
     reward_processing_classes = [reward_processing_classes]
 else:
     if len(reward_processing_classes) != len(reward_funcs):
-        raise ValueError(
-            "The number of reward processing classes must match the number of reward functions."
-        )
+        raise ValueError("The number of reward processing classes must match the number of reward functions.")
-for i, (reward_processing_class, reward_func) in enumerate(
-    zip(reward_processing_classes, reward_funcs)
-):
+for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
     if isinstance(reward_func, PreTrainedModel):
         if reward_processing_class is None:
-            reward_processing_class = AutoTokenizer.from_pretrained(
-                reward_func.config._name_or_path
-            )
+            reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
         if reward_processing_class.pad_token_id is None:
-            reward_processing_class.pad_token = (
-                reward_processing_class.eos_token
-            )
+            reward_processing_class.pad_token = reward_processing_class.eos_token
         # The reward model computes the reward for the latest non-padded token in the input sequence.
         # So it's important to set the pad token ID to the padding token ID of the processing class.
         reward_func.config.pad_token_id = reward_processing_class.pad_token_id
@@ -962,9 +910,7 @@ class _UnslothGRPOTrainer(Trainer):
 # Training arguments
 self.max_prompt_length = args.max_prompt_length
-self.max_completion_length = (
-    args.max_completion_length
-)  # = |o_i| in the GRPO paper
+self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
 self.num_generations = args.num_generations  # = G in the GRPO paper
 self.use_vllm = args.use_vllm
 self.use_agentic_generate = args.use_agentic_generate
@@ -997,11 +943,7 @@ class _UnslothGRPOTrainer(Trainer):
 # Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
 num_processes = self.accelerator.num_processes
 global_batch_size = args.per_device_train_batch_size * num_processes
-possible_values = [
-    n_gen
-    for n_gen in range(2, global_batch_size + 1)
-    if (global_batch_size) % n_gen == 0
-]
+possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
 if self.num_generations not in possible_values:
     raise ValueError(
         f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
@@ -1010,11 +952,7 @@ class _UnslothGRPOTrainer(Trainer):
 )
 if self.args.eval_strategy != "no":
     global_batch_size = args.per_device_eval_batch_size * num_processes
-    possible_values = [
-        n_gen
-        for n_gen in range(2, global_batch_size + 1)
-        if (global_batch_size) % n_gen == 0
-    ]
+    possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
     if self.num_generations not in possible_values:
         raise ValueError(
             f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
@@ -1059,22 +997,14 @@ class _UnslothGRPOTrainer(Trainer):
 if self.is_deepspeed_enabled:
     self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
 else:
-    self.ref_model = self.accelerator.prepare_model(
-        self.ref_model, evaluation_mode=True
-    )
+    self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
 if args.sync_ref_model:
-    self.add_callback(
-        SyncRefModelCallback(
-            ref_model=self.ref_model, accelerator=self.accelerator
-        )
-    )
+    self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
 for i, reward_func in enumerate(self.reward_funcs):
     if isinstance(reward_func, PreTrainedModel):
-        self.reward_funcs[i] = self.accelerator.prepare_model(
-            reward_func, evaluation_mode=True
-        )
+        self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
 def _set_signature_columns_if_needed(self):
     # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
@@ -1089,27 +1019,21 @@ class _UnslothGRPOTrainer(Trainer):
 # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
 # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
 # preventing discrepancies in group formation.
-return RepeatRandomSampler(
-    self.train_dataset, self.num_generations, seed=self.args.seed
-)
+return RepeatRandomSampler(self.train_dataset, self.num_generations, seed=self.args.seed)
 def _get_eval_sampler(self, eval_dataset) -> Sampler:
     # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
     # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
     # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
     # preventing discrepancies in group formation.
-    return RepeatRandomSampler(
-        eval_dataset, self.num_generations, seed=self.args.seed
-    )
+    return RepeatRandomSampler(eval_dataset, self.num_generations, seed=self.args.seed)
 # Get the per-token log probabilities for the completions for the model and the reference model
 def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
     return None  # Unsloth efficient GRPO
     if not hasattr(self, "_autocast_dtype"):
         self._autocast_dtype = (
-            torch.float16
-            if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16"
-            else torch.bfloat16
+            torch.float16 if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16" else torch.bfloat16
         )
     with torch.amp.autocast(device_type="cuda", dtype=self._autocast_dtype):
         # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
@@ -1118,9 +1042,7 @@ class _UnslothGRPOTrainer(Trainer):
     attention_mask=attention_mask,
     logits_to_keep=logits_to_keep + 1,
 ).logits
-logits = logits[
-    :, :-1, :
-]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
 input_ids = input_ids[:, -logits_to_keep:]
 # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
@@ -1133,15 +1055,10 @@ class _UnslothGRPOTrainer(Trainer):
 def _move_model_to_vllm(self, *args, **kwargs):
     return None
-def _prepare_inputs(
-    self, inputs: dict[str, Union[torch.Tensor, Any]]
-) -> dict[str, Union[torch.Tensor, Any]]:
+def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
     device = self.accelerator.device
     prompts = [x["prompt"] for x in inputs]
-    prompts_text = [
-        maybe_apply_chat_template(example, self.processing_class)["prompt"]
-        for example in inputs
-    ]
+    prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
     prompt_inputs = self.processing_class(
         prompts_text,
         return_tensors="pt",
@@ -1174,9 +1091,7 @@ class _UnslothGRPOTrainer(Trainer):
     prompts_text,
     sampling_params=self.sampling_params,
     use_tqdm=False,
-    lora_request=self.model.load_lora(
-        "grpo_trainer_lora_model", load_tensors=True
-    ),
+    lora_request=self.model.load_lora("grpo_trainer_lora_model", load_tensors=True),
 )
 if self.use_agentic_generate:
     agentic_outputs = self.model.agentic_generate(
@@ -1201,11 +1116,7 @@ class _UnslothGRPOTrainer(Trainer):
         ).to(device)
     else:
         outputs = generate_fn(all_prompts_text)
-        completion_ids = [
-            out.token_ids
-            for completions in outputs
-            for out in completions.outputs
-        ]
+        completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
 else:
     completion_ids = [None] * len(all_prompts_text)
 # Broadcast the completions from the main process to all processes, ensuring each process receives its
@@ -1218,18 +1129,12 @@ class _UnslothGRPOTrainer(Trainer):
     completion_ids = completion_ids[process_slice]
     # Pad the completions, and concatenate them with the prompts
-    completion_ids = [
-        torch.tensor(ids, device=device) for ids in completion_ids
-    ]
-    completion_ids = pad(
-        completion_ids, padding_value=self.processing_class.pad_token_id
-    )
+    completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
+    completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
     prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
 else:
     # Regular generation path
-    with unwrap_model_for_generation(
-        self.model, self.accelerator
-    ) as unwrapped_model:
+    with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
         prompt_completion_ids = unwrapped_model.generate(
             prompt_ids,
             attention_mask=prompt_mask,
@@ -1244,21 +1149,15 @@ class _UnslothGRPOTrainer(Trainer):
 if not self.use_agentic_generate:
     # Mask everything after the first EOS token
     is_eos = completion_ids == self.processing_class.eos_token_id
-    eos_idx = torch.full(
-        (is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device
-    )
+    eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
     eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
-    sequence_indices = torch.arange(is_eos.size(1), device=device).expand(
-        is_eos.size(0), -1
-    )
+    sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
     completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
 # Concatenate prompt_mask with completion_mask for logit computation
 attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B*G, P+C)
-logits_to_keep = completion_ids.size(
-    1
-)  # we only need to compute the logits for the completion tokens
+logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
 # this does nothing
 with (
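
The EOS masking collapsed in the hunk above keeps every completion token up to and including the first EOS and zeroes everything after it; rows without an EOS stay fully unmasked. A standalone toy example with hypothetical token ids (not part of the commit):

import torch

eos_token_id = 2  # hypothetical id
completion_ids = torch.tensor([[5, 7, 2, 9, 9],
                               [5, 7, 9, 9, 9]])  # second row never emits EOS

is_eos = completion_ids == eos_token_id
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1)).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

print(completion_mask)
# tensor([[1, 1, 1, 0, 0],   <- masked after the first EOS
#         [1, 1, 1, 1, 1]])  <- no EOS, keep the full completion
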
@@ -1280,9 +1179,7 @@ class _UnslothGRPOTrainer(Trainer):
     logits_to_keep,
 )
 else:
-    with self.accelerator.unwrap_model(
-        self.model, keep_fp32_wrapper=False
-    ).disable_adapter():
+    with self.accelerator.unwrap_model(self.model, keep_fp32_wrapper=False).disable_adapter():
         ref_per_token_logps = self._get_per_token_logps(
             self.model,
             prompt_completion_ids,
@@ -1292,42 +1189,25 @@ class _UnslothGRPOTrainer(Trainer):
 # Decode the generated completions
 if not self.use_agentic_generate:
-    completions_text = self.processing_class.batch_decode(
-        completion_ids, skip_special_tokens=True
-    )
+    completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
     if is_conversational(inputs[0]):
         completions = []
         for prompt, completion in zip(prompts, completions_text):
-            bootstrap = (
-                prompt.pop()["content"]
-                if prompt[-1]["role"] == "assistant"
-                else ""
-            )
-            completions.append(
-                [{"role": "assistant", "content": bootstrap + completion}]
-            )
+            bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
+            completions.append([{"role": "assistant", "content": bootstrap + completion}])
     else:
         completions = completions_text
 else:
     completions = full_chats
-rewards_per_func = torch.zeros(
-    len(prompts), len(self.reward_funcs), device=device
-)
+rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
 for i, (reward_func, reward_processing_class) in enumerate(
     zip(self.reward_funcs, self.reward_processing_classes)
 ):
-    if isinstance(
-        reward_func, nn.Module
-    ):  # Module instead of PretrainedModel for compat with compiled models
+    if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
         if is_conversational(inputs[0]):
-            messages = [
-                {"messages": p + c} for p, c in zip(prompts, completions)
-            ]
-            texts = [
-                apply_chat_template(x, reward_processing_class)["text"]
-                for x in messages
-            ]
+            messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
+            texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
         else:
            texts = [p + c for p, c in zip(prompts, completions)]
         reward_inputs = reward_processing_class(
@@ -1343,36 +1223,25 @@ class _UnslothGRPOTrainer(Trainer):
 torch.amp.autocast(
     device_type="cuda",
     dtype=torch.float16
-    if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16")
-    == "fp16"
+    if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16"
     else torch.bfloat16,
 )
 if not torch.is_autocast_enabled("cuda")
 else nullcontext(),
 ):
-    rewards_per_func[:, i] = reward_func(**reward_inputs).logits[
-        :, 0
-    ]  # Shape (B*G,)
+    rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
 else:
     # Repeat all input columns (but "prompt" and "completion") to match the number of generations
     keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
-    reward_kwargs = {
-        key: [example[key] for example in inputs] for key in keys
-    }
-    output_reward_func = reward_func(
-        prompts=prompts, completions=completions, **reward_kwargs
-    )
-    rewards_per_func[:, i] = torch.tensor(
-        output_reward_func, dtype=torch.float32, device=device
-    )
+    reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
+    output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
+    rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
 # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
 # completions may be distributed across processes
 rewards_per_func = gather(rewards_per_func)
 # Apply weights to each reward function's output and sum
-rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(
-    dim=1
-)
+rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)
 # else:
 # reward_fn = self.reward_funcs[0]
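
For reference, the weighted sum near the end of this hunk folds one reward column per reward function into a single scalar per completion. A toy example with invented scores and weights (not part of the commit):

import torch

# Hypothetical: 4 completions scored by 2 reward functions, weighted 1.0 and 0.5.
rewards_per_func = torch.tensor([[1.0, 0.2],
                                 [0.0, 0.8],
                                 [1.0, 1.0],
                                 [0.5, 0.0]])
reward_weights = torch.tensor([1.0, 0.5])

rewards = (rewards_per_func * reward_weights.unsqueeze(0)).sum(dim=1)
print(rewards)  # tensor([1.1000, 0.4000, 1.5000, 0.5000])
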
@@ -1391,12 +1260,8 @@ class _UnslothGRPOTrainer(Trainer):
 std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
 # Normalize the rewards to compute the advantages
-mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(
-    self.num_generations, dim=0
-)
-std_grouped_rewards = std_grouped_rewards.repeat_interleave(
-    self.num_generations, dim=0
-)
+mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
 advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
 # Slice to keep only the local part of the data
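
The normalization above is the GRPO group baseline: the G completions of each prompt are compared against their own group's mean and standard deviation. A small numeric sketch with hypothetical rewards:

import torch

num_generations = 2  # G completions per prompt
rewards = torch.tensor([1.0, 0.0, 0.5, 0.5])  # hypothetical: 2 prompts x 2 generations

grouped = rewards.view(-1, num_generations)
mean_grouped = grouped.mean(dim=1).repeat_interleave(num_generations, dim=0)
std_grouped = grouped.std(dim=1).repeat_interleave(num_generations, dim=0)
advantages = (rewards - mean_grouped) / (std_grouped + 1e-4)

print(advantages)  # roughly [0.707, -0.707, 0.0, 0.0]
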
@@ -1410,15 +1275,11 @@ class _UnslothGRPOTrainer(Trainer):
 reward_per_func = rewards_per_func.mean(0)
 print("rewards_per_func:", reward_per_func)
 for i, reward_func in enumerate(self.reward_funcs):
-    if isinstance(
-        reward_func, nn.Module
-    ):  # Module instead of PretrainedModel for compat with compiled models
+    if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
         reward_func_name = reward_func.config._name_or_path.split("/")[-1]
     else:
         reward_func_name = reward_func.__name__
-    self._metrics[f"rewards/{reward_func_name}"].append(
-        reward_per_func[i].item()
-    )
+    self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
 self._metrics["reward"].append(rewards.mean().item())
 self._metrics["reward_std"].append(std_grouped_rewards.mean().item())
@@ -1451,9 +1312,7 @@ class _UnslothGRPOTrainer(Trainer):
     "advantages": advantages,
 }
-def compute_loss(
-    self, model, inputs, return_outputs=False, num_items_in_batch=None
-):
+def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
     if return_outputs:
         raise ValueError("The GRPOTrainer does not support returning outputs")
     # Compute the per-token log probabilities for the model
@@ -1467,14 +1326,10 @@ class _UnslothGRPOTrainer(Trainer):
 bsz, qlen = input_ids.shape
 # attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
 attention_mask = None
-logits_to_keep = completion_ids.size(
-    1
-)  # we only need to compute the logits for the completion tokens
+logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
 _input_ids = input_ids
 _logits_to_keep = logits_to_keep
-per_token_logps = self._get_per_token_logps(
-    model, input_ids, attention_mask, logits_to_keep
-)
+per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
 # Compute the KL divergence between the model and the reference model
 ref_per_token_logps = inputs["ref_per_token_logps"]
@@ -1529,9 +1384,7 @@ class _UnslothGRPOTrainer(Trainer):
 return loss, None, None
 def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
-    metrics = {
-        key: sum(val) / len(val) for key, val in self._metrics.items()
-    }  # average the metrics
+    metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()}  # average the metrics
     # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
     # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
@@ -1565,9 +1418,7 @@ class _UnslothGRPOTrainer(Trainer):
 if not self.is_world_process_zero():
     return
-if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(
-    self.model.config._name_or_path
-):
+if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
     base_model = self.model.config._name_or_path
 else:
     base_model = None
@@ -1596,9 +1447,7 @@ class _UnslothGRPOTrainer(Trainer):
 hub_model_id=self.hub_model_id,
 dataset_name=dataset_name,
 tags=tags,
-wandb_url=wandb.run.get_url()
-if is_wandb_available() and wandb.run is not None
-else None,
+wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
 comet_url=get_comet_experiment_url(),
 trainer_name="GRPO",
 trainer_citation=citation,
@@ -1735,10 +1584,7 @@ class UnslothGRPOTrainer(_UnslothGRPOTrainer):
 args.fp16 = float16
 args.bf16 = not float16
 os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16" if float16 else "bf16"
-if (
-    getattr(args, "eval_dataset", None) is not None
-    and getattr(args, "eval_strategy", "no") == "no"
-):
+if getattr(args, "eval_dataset", None) is not None and getattr(args, "eval_strategy", "no") == "no":
     args.eval_strategy = "steps"
     if getattr(args, "eval_steps", None) is None:
         args.eval_steps = 0.1
@@ -1755,10 +1601,7 @@ class UnslothGRPOTrainer(_UnslothGRPOTrainer):
 eval_bsz = getattr(args, "per_device_eval_batch_size", 8)
 if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz:
     args.per_device_eval_batch_size = args.per_device_train_batch_size
-if (
-    getattr(args, "eval_accumulation_steps", None) is None
-    and ga_steps is not None
-):
+if getattr(args, "eval_accumulation_steps", None) is None and ga_steps is not None:
     args.eval_accumulation_steps = ga_steps
 fp16_full_eval = getattr(args, "fp16_full_eval", False)
 bf16_full_eval = getattr(args, "bf16_full_eval", False)
@@ -1787,9 +1630,7 @@ class UnslothGRPOTrainer(_UnslothGRPOTrainer):
 if "processing_class" in locals():
     if hasattr(processing_class, "padding_side"):
         processing_class.padding_side = "right"
-    if hasattr(processing_class, "tokenizer") and hasattr(
-        processing_class.tokenizer, "padding_side"
-    ):
+    if hasattr(processing_class, "tokenizer") and hasattr(processing_class.tokenizer, "padding_side"):
         processing_class.tokenizer.padding_side = "right"
 other_metrics = []
 if not isinstance(reward_funcs, list):

@@ -19,15 +19,10 @@ LOG_FOLDER = PROJ_ROOT / "logs"
 # Model configuration
 # MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
 MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
-device_id = (
-    1 if os.environ.get("CUDA_VISIBLE_DEVICES") == "1" else torch.cuda.current_device()
-)
+device_id = 1 if os.environ.get("CUDA_VISIBLE_DEVICES") == "1" else torch.cuda.current_device()
 timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-OUTPUT_DIR = (
-    PROJ_ROOT
-    / f"trainer_output_{MODEL_NAME.replace('/', '_')}_gpu{device_id}_{timestamp}"
-)
+OUTPUT_DIR = PROJ_ROOT / f"trainer_output_{MODEL_NAME.replace('/', '_')}_gpu{device_id}_{timestamp}"
 # Model parameters
 MODEL_CONFIG = {
@@ -103,12 +98,7 @@ def _init_logging(env: str = "development") -> None:
     "- <level>{message}</level>"
 )
-file_format = (
-    "{time:YYYY-MM-DD at HH:mm:ss} "
-    "| {level} "
-    "| {name}:{function}:{line} "
-    "- {message}"
-)
+file_format = "{time:YYYY-MM-DD at HH:mm:ss} | {level} | {name}:{function}:{line} - {message}"
 # Add console logging
 logger.add(
@@ -139,9 +129,7 @@ def _init_logging(env: str = "development") -> None:
 if issubclass(exc_type, KeyboardInterrupt):
     sys.__excepthook__(exc_type, exc_value, exc_traceback)
     return
-logger.opt(exception=(exc_type, exc_value, exc_traceback)).critical(
-    "Unhandled exception"
-)
+logger.opt(exception=(exc_type, exc_value, exc_traceback)).critical("Unhandled exception")
 sys.excepthook = exception_handler
@@ -163,12 +151,7 @@ def update_log_path(log_dir=None):
 log_dir = Path(log_dir)
 log_dir.mkdir(exist_ok=True, parents=True)
-file_format = (
-    "{time:YYYY-MM-DD at HH:mm:ss} "
-    "| {level} "
-    "| {name}:{function}:{line} "
-    "- {message}"
-)
+file_format = "{time:YYYY-MM-DD at HH:mm:ss} | {level} | {name}:{function}:{line} - {message}"
 # Add additional file handler pointing to training directory
 # No need to remove existing handlers as we want to keep those
@@ -248,9 +231,7 @@ def setup_logger(module_name=None, create_dirs: bool = False):
 Returns:
     Configured logger instance
 """
-logger.warning(
-    "setup_logger is deprecated. Import logger directly from config instead."
-)
+logger.warning("setup_logger is deprecated. Import logger directly from config instead.")
 return logger

@@ -186,12 +186,8 @@ def run_agent_generations(generate_fn, tokenizer, chat_states):
     full_response = response.outputs[0].text
 else:
     full_response = response
-assistant_response = full_response.split(
-    "<|start_header_id|>assistant<|end_header_id|>"
-)[-1]
-chat_state["messages"].append(
-    {"role": "assistant", "content": assistant_response}
-)
+assistant_response = full_response.split("<|start_header_id|>assistant<|end_header_id|>")[-1]
+chat_state["messages"].append({"role": "assistant", "content": assistant_response})
 logger.debug(f"Added assistant response to chat state {idx}")
 else:
     logger.debug("No prompts to generate responses for")
@@ -211,9 +207,7 @@ def check_finished_chats(chat_states):
 for chat_state in chat_states:
     if chat_state.get("finished"):
         continue
-    assert (
-        chat_state["messages"][-1]["role"] == "assistant"
-    ), "Expected the last role to be assistant"
+    assert chat_state["messages"][-1]["role"] == "assistant", "Expected the last role to be assistant"
     assistant_response = chat_state["messages"][-1]["content"]
     function_calls = extract_json_objects(assistant_response)
     if len(function_calls) == 0:
@@ -232,17 +226,15 @@ def run_tool_calls(chat_states):
 if chat_state.get("finished"):
     logger.debug("Chat state already finished, skipping tool calls")
     continue
-assert (
-    chat_state["messages"][-1]["role"] == "assistant"
-), "Expected the last role to be assistant to run tool calls"
+assert chat_state["messages"][-1]["role"] == "assistant", (
+    "Expected the last role to be assistant to run tool calls"
+)
 try:
     assistant_response = chat_state["messages"][-1]["content"]
     function_calls = extract_json_objects(assistant_response)
     if len(function_calls) > 1:
         logger.warning("Multiple function calls found in assistant response")
-        raise ValueError(
-            "Expected only one function call in assistant response"
-        )
+        raise ValueError("Expected only one function call in assistant response")
     elif len(function_calls) == 1:
         function_call = function_calls[0]
         query = function_call["function"]["parameters"]["query"]
@@ -257,9 +249,7 @@ def run_tool_calls(chat_states):
     logger.debug("Added search results to chat state")
 except Exception as e:
     logger.error(f"Error during tool call: {str(e)}")
-    chat_state["messages"].append(
-        {"role": "system", "content": f"Error during post-processing: {str(e)}"}
-    )
+    chat_state["messages"].append({"role": "system", "content": f"Error during post-processing: {str(e)}"})
     chat_state["finished"] = True
 return chat_states
@@ -273,14 +263,9 @@ def get_mask(text, tokenizer):
 assistant_ranges = []
 i = 0
 while i < len(encoding.input_ids) - 1:
-    if (
-        encoding.input_ids[i] == start_header_id
-        and encoding.input_ids[i + 1] == assistant_token
-    ):
+    if encoding.input_ids[i] == start_header_id and encoding.input_ids[i + 1] == assistant_token:
         i += 2
-        while (
-            i < len(encoding.input_ids) and encoding.input_ids[i] != end_header_id
-        ):
+        while i < len(encoding.input_ids) and encoding.input_ids[i] != end_header_id:
             i += 1
         i += 2
         start_idx = i
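
get_mask scans raw token ids for the <|start_header_id|>assistant<|end_header_id|> pattern. A sketch of how those ids could be resolved for config's MODEL_NAME, assuming a Llama-3-style template in which "assistant" maps to a single token (the exact ids are tokenizer-specific and this snippet is not part of the commit):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

start_header_id = tokenizer.convert_tokens_to_ids("<|start_header_id|>")
end_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>")
assistant_token = tokenizer.convert_tokens_to_ids("assistant")

ids = tokenizer("<|start_header_id|>assistant<|end_header_id|>\n\nHi", add_special_tokens=False).input_ids
# The loop in get_mask looks for start_header_id followed by assistant_token, then skips to end_header_id.
print(ids[0] == start_header_id, ids[1] == assistant_token)
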
@@ -319,11 +304,7 @@ class AgenticOutputs:
 def get_chat_num_tokens(chat_state, tokenizer):
     chat_text = apply_chat_template(chat_state, tokenizer=tokenizer)["text"]
-    return (
-        tokenizer(chat_text, add_special_tokens=False, return_tensors="pt")["input_ids"]
-        .squeeze()
-        .shape[0]
-    )
+    return tokenizer(chat_text, add_special_tokens=False, return_tensors="pt")["input_ids"].squeeze().shape[0]
 def run_agent(
@@ -338,9 +319,7 @@ def run_agent(
 Run the agent to completion for a batch of questions.
 """
 logger.info(f"Starting agent run with {len(questions)} questions")
-logger.debug(
-    f"Max generations: {max_generations}, Max new tokens: {max_new_tokens}"
-)
+logger.debug(f"Max generations: {max_generations}, Max new tokens: {max_new_tokens}")
 chat_states = [get_initial_chat(q) for q in questions]
 # Add correct content to chat states if provided
@@ -359,13 +338,9 @@ def run_agent(
 chat_states = run_agent_generations(generate_fn, tokenizer, chat_states)
 chat_states = check_finished_chats(chat_states)
 chat_states = run_tool_calls(chat_states)
-chat_states = check_exceeded_max_new_tokens(
-    chat_states, max_new_tokens, tokenizer
-)
+chat_states = check_exceeded_max_new_tokens(chat_states, max_new_tokens, tokenizer)
 finished_count = sum(1 for state in chat_states if state.get("finished"))
-logger.info(
-    f"Finished {finished_count}/{len(chat_states)} chat states after step {i + 1}"
-)
+logger.info(f"Finished {finished_count}/{len(chat_states)} chat states after step {i + 1}")
 logger.info("Agent run completed")
@ -387,23 +362,15 @@ def run_agent(
assistant_response = convo_text[idx + len(marker) :] assistant_response = convo_text[idx + len(marker) :]
return prompt, assistant_response return prompt, assistant_response
str_chats = [ str_chats = [apply_chat_template(chat, tokenizer=tokenizer)["text"] for chat in chat_states]
apply_chat_template(chat, tokenizer=tokenizer)["text"] for chat in chat_states
]
prompt_toks, response_toks, response_masks = [], [], [] prompt_toks, response_toks, response_masks = [], [], []
logger.debug("Processing tokenization") logger.debug("Processing tokenization")
for i, str_chat in enumerate(str_chats): for i, str_chat in enumerate(str_chats):
prompt, response = split_prompt_assistant(str_chat) prompt, response = split_prompt_assistant(str_chat)
prompt_toks.append( prompt_toks.append(tokenizer(prompt, add_special_tokens=False, return_tensors="pt")["input_ids"].squeeze())
tokenizer(prompt, add_special_tokens=False, return_tensors="pt")[
"input_ids"
].squeeze()
)
response_toks.append( response_toks.append(
tokenizer(response, add_special_tokens=False, return_tensors="pt")[ tokenizer(response, add_special_tokens=False, return_tensors="pt")["input_ids"].squeeze()[:max_new_tokens]
"input_ids"
].squeeze()[:max_new_tokens]
) )
mask = get_mask(str_chat, tokenizer)[len(prompt_toks[-1]) :][:max_new_tokens] mask = get_mask(str_chat, tokenizer)[len(prompt_toks[-1]) :][:max_new_tokens]
response_masks.append(mask) response_masks.append(mask)
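The slicing at the end of this hunk is easy to misread after the reflow: the mask returned by get_mask covers the whole conversation, so it is first offset by the prompt length and then truncated to max_new_tokens, matching the truncation applied to the response tokens. A toy alignment check with invented values:

```python
# Toy check of the mask slicing above; the numbers are made up, not real token IDs.
prompt_ids = [11, 12, 13]          # tokenizer(prompt)["input_ids"]
full_mask = [0, 0, 0, 1, 1, 1, 1]  # get_mask over the whole chat text
max_new_tokens = 3

response_mask = full_mask[len(prompt_ids):][:max_new_tokens]
assert response_mask == [1, 1, 1]  # only assistant tokens, capped at max_new_tokens
```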
@ -469,12 +436,8 @@ def check_student_answers(
logger.info(f"Checking {len(questions)} student answers") logger.info(f"Checking {len(questions)} student answers")
if not (len(questions) == len(answers) == len(student_answers)): if not (len(questions) == len(answers) == len(student_answers)):
logger.error( logger.error("Mismatched lengths between questions, answers, and student answers")
"Mismatched lengths between questions, answers, and student answers" raise ValueError("The number of questions, answers, and student answers must be equal.")
)
raise ValueError(
"The number of questions, answers, and student answers must be equal."
)
prompts = [] prompts = []
for question, answer, student_ans in zip(questions, answers, student_answers): for question, answer, student_ans in zip(questions, answers, student_answers):
@ -548,11 +511,7 @@ def check_student_answers(
if isinstance(student_ans, dict) and "messages" in student_ans: if isinstance(student_ans, dict) and "messages" in student_ans:
# Get messages from dict # Get messages from dict
messages = student_ans.get("messages", []) messages = student_ans.get("messages", [])
search_results = [ search_results = [msg.get("content", "") for msg in messages if msg.get("role") == "ipython"]
msg.get("content", "")
for msg in messages
if msg.get("role") == "ipython"
]
if search_results: if search_results:
file.write("\n🔎 Search Results:\n") file.write("\n🔎 Search Results:\n")
for j, result in enumerate(search_results, 1): for j, result in enumerate(search_results, 1):
@ -572,18 +531,12 @@ def check_student_answers(
def build_reward_correctness_fn(generate_fn, tokenizer, log_file=None): def build_reward_correctness_fn(generate_fn, tokenizer, log_file=None):
def reward_correctness(prompts, completions, **reward_kwargs): def reward_correctness(prompts, completions, **reward_kwargs):
teacher_answers = reward_kwargs["answer"] teacher_answers = reward_kwargs["answer"]
student_answers = [ student_answers = [completion["messages"][-1]["content"] for completion in completions]
completion["messages"][-1]["content"] for completion in completions
]
# Log non-exact matches # Log non-exact matches
for i, (student, teacher) in enumerate(zip(student_answers, teacher_answers)): for i, (student, teacher) in enumerate(zip(student_answers, teacher_answers)):
if student.strip().lower() != teacher.strip().lower(): if student.strip().lower() != teacher.strip().lower():
logger.warning( logger.warning(f"Non-exact match at index {i}:\nStudent: {student}\nTeacher: {teacher}")
f"Non-exact match at index {i}:\n"
f"Student: {student}\n"
f"Teacher: {teacher}"
)
correct = check_student_answers( correct = check_student_answers(
prompts, prompts,
@ -595,12 +548,8 @@ def build_reward_correctness_fn(generate_fn, tokenizer, log_file=None):
) )
# Log correctness metrics with length info # Log correctness metrics with length info
log_metric( log_metric("rewards/correctness", np.mean(correct), reward_kwargs.get("step", 0))
"rewards/correctness", np.mean(correct), reward_kwargs.get("step", 0) log_metric("rewards/correctness_std", np.std(correct), reward_kwargs.get("step", 0))
)
log_metric(
"rewards/correctness_std", np.std(correct), reward_kwargs.get("step", 0)
)
# Log length metrics # Log length metrics
student_lengths = [len(ans.strip()) for ans in student_answers] student_lengths = [len(ans.strip()) for ans in student_answers]
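The list comprehension collapsed in this hunk also documents the completion structure the reward expects: each completion is a dict with a "messages" list, and the student answer is the content of the last message. A minimal sketch with invented field values:

```python
# Shape assumed by reward_correctness, inferred from the indexing in the diff.
completions = [
    {
        "messages": [
            {"role": "assistant", "content": "The capital of France is Paris."},
        ]
    },
]
student_answers = [c["messages"][-1]["content"] for c in completions]
assert student_answers == ["The capital of France is Paris."]
```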
@ -676,9 +625,7 @@ def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[floa
if json_count > 1: if json_count > 1:
has_multiple_json = True has_multiple_json = True
logger.warning( logger.warning(f"Message contains {json_count} JSON objects, which exceeds the limit of 1")
f"Message contains {json_count} JSON objects, which exceeds the limit of 1"
)
break break
# Only reward if no message has multiple JSON objects # Only reward if no message has multiple JSON objects
@ -692,17 +639,13 @@ def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[floa
if total_json_objects > 4: if total_json_objects > 4:
penalty = 0.1 * (total_json_objects - 4) penalty = 0.1 * (total_json_objects - 4)
base_reward = max(0.2, base_reward - penalty) base_reward = max(0.2, base_reward - penalty)
logger.debug( logger.debug(f"Applied penalty for {total_json_objects} total JSON objects: {penalty}")
f"Applied penalty for {total_json_objects} total JSON objects: {penalty}"
)
rewards.append(base_reward) rewards.append(base_reward)
# Log retry behavior metrics # Log retry behavior metrics
log_metric("rewards/retry_behavior", np.mean(rewards), reward_kwargs.get("step", 0)) log_metric("rewards/retry_behavior", np.mean(rewards), reward_kwargs.get("step", 0))
log_metric( log_metric("rewards/retry_behavior_std", np.std(rewards), reward_kwargs.get("step", 0))
"rewards/retry_behavior_std", np.std(rewards), reward_kwargs.get("step", 0)
)
log_metric( log_metric(
"metrics/avg_json_per_msg", "metrics/avg_json_per_msg",
np.mean( np.mean(
@ -737,13 +680,9 @@ def reward_exact_match_chunk_query(prompts, completions, **reward_kwargs):
raise ValueError("chunk_content must be provided in reward_kwargs") raise ValueError("chunk_content must be provided in reward_kwargs")
rewards = [] rewards = []
for i, (chat_state, correct_content) in enumerate( for i, (chat_state, correct_content) in enumerate(zip(completions, correct_contents)):
zip(completions, correct_contents)
):
# Get all ipython messages (search results) from the chat # Get all ipython messages (search results) from the chat
search_results = [ search_results = [msg["content"] for msg in chat_state["messages"] if msg["role"] == "ipython"]
msg["content"] for msg in chat_state["messages"] if msg["role"] == "ipython"
]
logger.debug(f"Found {len(search_results)} search results for prompt {i}") logger.debug(f"Found {len(search_results)} search results for prompt {i}")
# Log ground truth chunk and searched chunks # Log ground truth chunk and searched chunks
@ -756,9 +695,7 @@ def reward_exact_match_chunk_query(prompts, completions, **reward_kwargs):
for result in search_results: for result in search_results:
if correct_content in result: if correct_content in result:
found_correct_chunk = True found_correct_chunk = True
logger.debug( logger.debug(f"Found correct chunk content in search results for prompt {i}")
f"Found correct chunk content in search results for prompt {i}"
)
break break
if not found_correct_chunk: if not found_correct_chunk:
@ -796,21 +733,13 @@ def reward_exact_match_chunk_query(prompts, completions, **reward_kwargs):
"metrics/avg_search_results", "metrics/avg_search_results",
np.mean( np.mean(
[ [
len( len([msg["content"] for msg in chat_state["messages"] if msg["role"] == "ipython"])
[
msg["content"]
for msg in chat_state["messages"]
if msg["role"] == "ipython"
]
)
for chat_state in completions for chat_state in completions
] ]
), ),
reward_kwargs.get("step", 0), reward_kwargs.get("step", 0),
) )
log_metric( log_metric("metrics/chunk_match_rate", np.mean(rewards), reward_kwargs.get("step", 0))
"metrics/chunk_match_rate", np.mean(rewards), reward_kwargs.get("step", 0)
)
# Log detailed debugging info # Log detailed debugging info
logger.info("Chunk Query Rewards Summary:") logger.info("Chunk Query Rewards Summary:")
@ -862,9 +791,7 @@ def run_eval(generate_fn, verify_fn, tokenizer, output_file=None, debug_file=Non
f.write(f"Percentage correct: {percent_correct:.2f}%\n\n") f.write(f"Percentage correct: {percent_correct:.2f}%\n\n")
f.write("Individual results:\n") f.write("Individual results:\n")
for i, (q, r, resp) in enumerate( for i, (q, r, resp) in enumerate(zip(questions, rewards, final_responses)):
zip(questions, rewards, final_responses)
):
f.write(f"\nQ{i + 1}: {q[:100]}...\n") f.write(f"\nQ{i + 1}: {q[:100]}...\n")
f.write(f"Correct: {'' if r else ''}\n") f.write(f"Correct: {'' if r else ''}\n")
f.write(f"Response: {resp[:150]}...\n") f.write(f"Response: {resp[:150]}...\n")
@ -879,9 +806,7 @@ def run_eval(generate_fn, verify_fn, tokenizer, output_file=None, debug_file=Non
import json import json
debug_data = [] debug_data = []
for i, (q, r, resp, chat) in enumerate( for i, (q, r, resp, chat) in enumerate(zip(questions, rewards, final_responses, full_chat_states)):
zip(questions, rewards, final_responses, full_chat_states)
):
debug_data.append( debug_data.append(
{ {
"question_id": i, "question_id": i,

@ -20,9 +20,7 @@ def load_vectorstore():
embeddings = CustomHuggingFaceEmbeddings() embeddings = CustomHuggingFaceEmbeddings()
# Load the FAISS index from the data directory # Load the FAISS index from the data directory
logger.info(f"Loading FAISS index from: {DATA_DIR}") logger.info(f"Loading FAISS index from: {DATA_DIR}")
vectorstore = FAISS.load_local( vectorstore = FAISS.load_local(str(DATA_DIR), embeddings, allow_dangerous_deserialization=True)
str(DATA_DIR), embeddings, allow_dangerous_deserialization=True
)
logger.info("Successfully loaded FAISS index") logger.info("Successfully loaded FAISS index")
return vectorstore return vectorstore
except Exception as e: except Exception as e:
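The single-line FAISS.load_local call above keeps allow_dangerous_deserialization=True, which recent langchain releases require when loading a locally pickled index. A hedged usage sketch; the query string is invented, and load_vectorstore() returning None on failure is an assumption based on the try/except shown in the hunk:

```python
# Assumed usage of the loader above.
vectorstore = load_vectorstore()
if vectorstore is not None:
    hits = vectorstore.similarity_search("What dataset was used for training?", k=3)
    for doc in hits:
        print(doc.page_content[:200])
```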
@ -125,9 +123,7 @@ def get_question_answer(idx=None, return_both: bool = True) -> dict:
# Select question by index # Select question by index
qa_pair = questions[idx] qa_pair = questions[idx]
else: else:
raise ValueError( raise ValueError(f"Index out of range. Must be between 0 and {len(questions) - 1}")
f"Index out of range. Must be between 0 and {len(questions) - 1}"
)
question = qa_pair["question"] question = qa_pair["question"]
answer = qa_pair["answer"] answer = qa_pair["answer"]

@ -3,16 +3,6 @@ import os
from unsloth import FastLanguageModel, is_bfloat16_supported from unsloth import FastLanguageModel, is_bfloat16_supported
import src.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp import src.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp
# Import reward functions
from src.rl_helpers import (
build_reward_correctness_fn,
get_qa_dataset,
reward_exact_match_chunk_query,
reward_formatting,
reward_retry_behavior,
run_agent,
)
from src.config import ( from src.config import (
MODEL_CONFIG, MODEL_CONFIG,
MODEL_NAME, MODEL_NAME,
@ -24,6 +14,16 @@ from src.config import (
update_log_path, update_log_path,
) )
# Import reward functions
from src.rl_helpers import (
build_reward_correctness_fn,
get_qa_dataset,
reward_exact_match_chunk_query,
reward_formatting,
reward_retry_behavior,
run_agent,
)
# Initialize training directories # Initialize training directories
paths = init_training_dirs() paths = init_training_dirs()
@ -57,9 +57,7 @@ model = FastLanguageModel.get_peft_model(
# Load datasets # Load datasets
logger.info("Loading datasets") logger.info("Loading datasets")
train_dataset, test_dataset = get_qa_dataset() train_dataset, test_dataset = get_qa_dataset()
logger.info( logger.info(f"Loaded {len(train_dataset)} training examples and {len(test_dataset)} test examples")
f"Loaded {len(train_dataset)} training examples and {len(test_dataset)} test examples"
)
# Setup training arguments # Setup training arguments
logger.info("Setting up training arguments") logger.info("Setting up training arguments")
