style: change line length to 119, organize imports

branch: main
author: thinhlpg, 1 month ago
parent abb18b10d8
commit 04593fa8fd

@ -3,7 +3,7 @@
# make sure to test the local checkout in scripts and not the pre-installed one
export PYTHONPATH = src
check_dirs := src
check_dirs := . src notebooks scripts
# Development dependencies
install:

@ -170,9 +170,7 @@ def test_lora_functionality(model, tokenizer, lora_path):
# Sample with base model
logger.info("Generating with base model...")
sampling_params = get_sampling_params(
temperature=0.7
) # Higher temp to make differences more obvious
sampling_params = get_sampling_params(temperature=0.7) # Higher temp to make differences more obvious
base_response = model.fast_generate(
[formatted_prompt],
sampling_params=sampling_params,
@ -217,12 +215,8 @@ def test_lora_functionality(model, tokenizer, lora_path):
logger.info("-" * 40)
if are_identical:
logger.warning(
"\nWARNING: LoRA adapter does not seem to change the model's output"
)
logger.warning(
"This could indicate that the LoRA adapter is not being properly applied"
)
logger.warning("\nWARNING: LoRA adapter does not seem to change the model's output")
logger.warning("This could indicate that the LoRA adapter is not being properly applied")
else:
logger.info("\nLoRA adapter is working as expected (outputs are different)")
@ -257,9 +251,7 @@ def evaluate_model(
# Prioritize the directory passed from the shell script if available
if trainer_dir and os.path.isdir(trainer_dir):
trainer_output_dir = os.path.abspath(trainer_dir)
logger.info(
f"Using trainer directory passed from arguments: {trainer_output_dir}"
)
logger.info(f"Using trainer directory passed from arguments: {trainer_output_dir}")
else:
logger.warning(
f"Trainer directory not provided or invalid: {trainer_dir}. Attempting to determine automatically."
@ -274,9 +266,7 @@ def evaluate_model(
# If a LoRA path exists (provided or found), get its parent's parent
checkpoint_dir = os.path.dirname(os.path.abspath(temp_lora_path))
trainer_output_dir = os.path.dirname(checkpoint_dir)
logger.info(
f"Determined trainer directory from LoRA path ({temp_lora_path}): {trainer_output_dir}"
)
logger.info(f"Determined trainer directory from LoRA path ({temp_lora_path}): {trainer_output_dir}")
else:
# If no LoRA path, default to current directory (should ideally not happen if called from eval.sh)
trainer_output_dir = os.path.abspath(".")
@ -290,22 +280,16 @@ def evaluate_model(
detected_checkpoint = find_latest_checkpoint(search_dir=trainer_output_dir)
if detected_checkpoint:
lora_path = detected_checkpoint
logger.info(
f"Auto-detected latest checkpoint in {trainer_output_dir}: {lora_path}"
)
logger.info(f"Auto-detected latest checkpoint in {trainer_output_dir}: {lora_path}")
else:
logger.warning(
f"No checkpoint found in {trainer_output_dir} for auto-detection. Evaluating base model."
)
logger.warning(f"No checkpoint found in {trainer_output_dir} for auto-detection. Evaluating base model.")
lora_path = None
model_type = "LoRA" if lora_path else "Base"
logger.info(f"\n{'=' * 50}")
logger.info(f"Starting evaluation of {model_type} model")
logger.info(
f"Trainer Output Directory: {trainer_output_dir}"
) # Log the final directory
logger.info(f"Trainer Output Directory: {trainer_output_dir}") # Log the final directory
logger.info(f"{'=' * 50}")
# --- Create eval_logs directory ---
@ -319,18 +303,14 @@ def evaluate_model(
# Fallback to current directory if creation fails
eval_log_dir = os.path.abspath("./eval_logs")
os.makedirs(eval_log_dir, exist_ok=True)
logger.warning(
f"Fell back to creating eval_logs in current directory: {eval_log_dir}"
)
logger.warning(f"Fell back to creating eval_logs in current directory: {eval_log_dir}")
# Create file names based on model type
model_prefix = "lora" if lora_path else "base"
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# Define all output file paths
eval_log_file = os.path.join(
eval_log_dir, f"{model_prefix}_model_eval_{timestamp}.log"
)
eval_log_file = os.path.join(eval_log_dir, f"{model_prefix}_model_eval_{timestamp}.log")
output_file = os.path.join(eval_log_dir, f"{model_prefix}_model_results.txt")
debug_file = os.path.join(eval_log_dir, f"{model_prefix}_model_results_debug.json")
@ -343,12 +323,8 @@ def evaluate_model(
if lora_path:
lora_request = model.load_lora(lora_path)
load_time = time.time() - start_time
logger.info(
f"LoRA adapter loaded in {load_time:.2f} seconds: {lora_request}"
)
responses = model.fast_generate(
inputs, sampling_params=sampling_params, lora_request=lora_request
)
logger.info(f"LoRA adapter loaded in {load_time:.2f} seconds: {lora_request}")
responses = model.fast_generate(inputs, sampling_params=sampling_params, lora_request=lora_request)
else:
# For base model, add additional logging
logger.info("Generating with base model (no LoRA)")
@ -373,9 +349,7 @@ def evaluate_model(
return model.fast_generate(inputs, sampling_params=verifier_params)
# Prepare the verification function
verify_fn = rl_helpers.build_reward_correctness_fn(
verifier_generate_fn, tokenizer, log_file=eval_log_file
)
verify_fn = rl_helpers.build_reward_correctness_fn(verifier_generate_fn, tokenizer, log_file=eval_log_file)
# Get the dataset and prepare questions and answers
train_dataset, test_dataset = rl_helpers.get_qa_dataset()
@ -438,9 +412,7 @@ def evaluate_model(
f.write(f"Results saved to: {output_file}\n")
f.write(f"Debug data saved to: {debug_file}\n\n")
logger.info(
f"Evaluation completed. Results saved to {output_file} and {debug_file}"
)
logger.info(f"Evaluation completed. Results saved to {output_file} and {debug_file}")
return results
@ -463,9 +435,7 @@ def compare_models(lora_path, temperature=0.5, output_file=None, trainer_dir=Non
lora_path = detected_checkpoint
logger.info(f"Auto-detected latest checkpoint: {lora_path}")
else:
logger.warning(
"No checkpoint found for auto-detection. Skipping comparison."
)
logger.warning("No checkpoint found for auto-detection. Skipping comparison.")
return
# Set up output directory in the checkpoint directory
@ -560,9 +530,7 @@ if __name__ == "__main__":
default="auto",
help="Path to LoRA weights (use 'auto' for auto-detection)",
)
parser.add_argument(
"--temperature", type=float, default=0.5, help="Sampling temperature"
)
parser.add_argument("--temperature", type=float, default=0.5, help="Sampling temperature")
parser.add_argument(
"--output_file",
type=str,
@ -609,24 +577,16 @@ if __name__ == "__main__":
update_log_path(eval_log_dir)
logger.info(f"Logs will be saved to both ./logs and {eval_log_dir}")
except ImportError:
logger.info(
"Config's update_log_path not available, using default logging"
)
logger.info("Config's update_log_path not available, using default logging")
if trainer_dir:
logger.info(f"Using trainer directory: {trainer_dir}")
logger.info(
f"All evaluation files will be stored in: {os.path.join(trainer_dir, 'eval_logs')}"
)
logger.info(f"All evaluation files will be stored in: {os.path.join(trainer_dir, 'eval_logs')}")
else:
logger.warning(
"No trainer directory found, will attempt to determine during evaluation"
)
logger.warning("No trainer directory found, will attempt to determine during evaluation")
logger.info(f"Starting model evaluation with temperature {args.temperature}")
results = compare_models(
args.lora_path, args.temperature, args.output_file, trainer_dir=trainer_dir
)
results = compare_models(args.lora_path, args.temperature, args.output_file, trainer_dir=trainer_dir)
if results:
logger.info("Evaluation completed successfully")
logger.info(f"Final improvement: {results['improvement']:.4f}")

@ -498,7 +498,7 @@ You are a helpful assistant with tool calling capabilities."""
try:
with open(filepath, "w", encoding="utf-8") as f:
f.write(f"{'=' * 80}\n")
f.write(f"DEEPSEARCH CHAT HISTORY\n")
f.write("DEEPSEARCH CHAT HISTORY\n")
f.write(f"Model: {MODEL_NAME}\n")
f.write(f"LoRA Path: {self.lora_path if self.lora_path else 'None'}\n")
f.write(f"Temperature: {self.temperature}\n")

@ -70,15 +70,11 @@ torch_compile_options = {
)
def selective_log_softmax(logits, index):
logits = logits.to(torch.float32)
selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(
-1
)
selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
# loop to reduce peak mem consumption
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
logsumexp_values = torch.logsumexp(logits, dim=-1)
per_token_logps = (
selected_logits - logsumexp_values
) # log_softmax(x_i) = x_i - logsumexp(x)
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
return per_token_logps
@ -139,17 +135,11 @@ class UnslothEfficientGRPO(torch.autograd.Function):
scaler=None,
n_chunks=1,
):
def compute_loss(
new_hidden_states, old_hidden_states, input_ids, mask, advantages, scaling
):
def compute_loss(new_hidden_states, old_hidden_states, input_ids, mask, advantages, scaling):
new_logits = torch.matmul(new_hidden_states, lm_head.t())
new_logits = new_logits[
:, :-1, :
] # exclude the last logit: it corresponds to the next token pred
new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
old_logits = torch.matmul(old_hidden_states, lm_head.t())
old_logits = old_logits[
:, :-1, :
] # exclude the last logit: it corresponds to the next token pred
old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
loss, completion_length, mean_kl = grpo_compute_loss(
old_logits,
new_logits,
@ -311,11 +301,7 @@ def grpo_accumulated_loss(
n_chunks = bsz
n_chunks = factors[min(np.searchsorted(factors, n_chunks), len(factors) - 1)]
mixed_dtype = (
torch.float16
if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16"
else torch.bfloat16
)
mixed_dtype = torch.float16 if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16" else torch.bfloat16
os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
completion_input_ids = input_ids[:, -logits_to_keep:]
@ -324,18 +310,12 @@ def grpo_accumulated_loss(
with torch.amp.autocast(device_type="cuda", dtype=mixed_dtype):
with (
torch.inference_mode(),
trainer.accelerator.unwrap_model(
trainer.model, keep_fp32_wrapper=False
).disable_adapter(),
trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper=False).disable_adapter(),
):
old_hidden_states = trainer.model(
input_ids=input_ids, logits_to_keep=logits_to_keep + 1
).logits
old_hidden_states = trainer.model(input_ids=input_ids, logits_to_keep=logits_to_keep + 1).logits
pass
new_hidden_states = trainer.model(
input_ids=input_ids, logits_to_keep=logits_to_keep + 1
).logits
new_hidden_states = trainer.model(input_ids=input_ids, logits_to_keep=logits_to_keep + 1).logits
loss, completion_length, mean_kl = UnslothEfficientGRPO.apply(
new_hidden_states,
@ -352,13 +332,9 @@ def grpo_accumulated_loss(
# Old non efficient code path
new_logits = torch.matmul(new_hidden_states, lm_head.t())
new_logits = new_logits[
:, :-1, :
] # exclude the last logit: it corresponds to the next token pred
new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
old_logits = torch.matmul(old_hidden_states, lm_head.t())
old_logits = old_logits[
:, :-1, :
] # exclude the last logit: it corresponds to the next token pred
old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
loss, completion_length, mean_kl = grpo_compute_loss(
old_logits,
new_logits,
@ -824,24 +800,14 @@ class _UnslothGRPOTrainer(Trainer):
reward_funcs: Union[RewardFunc, list[RewardFunc]],
args: GRPOConfig = None,
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
eval_dataset: Optional[
Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]
] = None,
eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
processing_class: Optional[PreTrainedTokenizerBase] = None,
reward_processing_classes: Optional[
Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]
] = None,
reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
callbacks: Optional[list[TrainerCallback]] = None,
optimizers: tuple[
Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]
] = (None, None),
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
peft_config: Optional["PeftConfig"] = None,
):
if (
hasattr(model, "vllm_engine")
and hasattr(args, "use_vllm")
and (getattr(args, "use_vllm", False) == False)
):
if hasattr(model, "vllm_engine") and hasattr(args, "use_vllm") and (getattr(args, "use_vllm", False) == False):
args.use_vllm = True
# Args
if args is None:
@ -855,11 +821,7 @@ class _UnslothGRPOTrainer(Trainer):
if isinstance(model, str):
model_id = model
torch_dtype = model_init_kwargs.get("torch_dtype")
if (
isinstance(torch_dtype, torch.dtype)
or torch_dtype == "auto"
or torch_dtype is None
):
if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
pass # torch_dtype is already a torch.dtype or "auto" or None
elif isinstance(torch_dtype, str): # it's a str, but not "auto"
torch_dtype = getattr(torch, torch_dtype)
@ -871,9 +833,7 @@ class _UnslothGRPOTrainer(Trainer):
)
# Disable caching if gradient checkpointing is enabled (not supported)
model_init_kwargs["use_cache"] = (
False
if args.gradient_checkpointing
else model_init_kwargs.get("use_cache")
False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
)
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
else:
@ -889,9 +849,7 @@ class _UnslothGRPOTrainer(Trainer):
# Reference model
if is_deepspeed_zero3_enabled():
self.ref_model = AutoModelForCausalLM.from_pretrained(
model_id, **model_init_kwargs
)
self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
elif not is_peft_model(model):
# If PEFT configuration is not provided, create a reference model based on the initial model.
self.ref_model = create_reference_model(model)
@ -902,9 +860,7 @@ class _UnslothGRPOTrainer(Trainer):
# Processing class
if processing_class is None:
processing_class = AutoTokenizer.from_pretrained(
model.config._name_or_path, padding_side="left"
)
processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
# Reward functions
if not isinstance(reward_funcs, list):
@ -934,22 +890,14 @@ class _UnslothGRPOTrainer(Trainer):
reward_processing_classes = [reward_processing_classes]
else:
if len(reward_processing_classes) != len(reward_funcs):
raise ValueError(
"The number of reward processing classes must match the number of reward functions."
)
raise ValueError("The number of reward processing classes must match the number of reward functions.")
for i, (reward_processing_class, reward_func) in enumerate(
zip(reward_processing_classes, reward_funcs)
):
for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
if isinstance(reward_func, PreTrainedModel):
if reward_processing_class is None:
reward_processing_class = AutoTokenizer.from_pretrained(
reward_func.config._name_or_path
)
reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
if reward_processing_class.pad_token_id is None:
reward_processing_class.pad_token = (
reward_processing_class.eos_token
)
reward_processing_class.pad_token = reward_processing_class.eos_token
# The reward model computes the reward for the latest non-padded token in the input sequence.
# So it's important to set the pad token ID to the padding token ID of the processing class.
reward_func.config.pad_token_id = reward_processing_class.pad_token_id
@ -962,9 +910,7 @@ class _UnslothGRPOTrainer(Trainer):
# Training arguments
self.max_prompt_length = args.max_prompt_length
self.max_completion_length = (
args.max_completion_length
) # = |o_i| in the GRPO paper
self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
self.num_generations = args.num_generations # = G in the GRPO paper
self.use_vllm = args.use_vllm
self.use_agentic_generate = args.use_agentic_generate
@ -997,11 +943,7 @@ class _UnslothGRPOTrainer(Trainer):
# Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
num_processes = self.accelerator.num_processes
global_batch_size = args.per_device_train_batch_size * num_processes
possible_values = [
n_gen
for n_gen in range(2, global_batch_size + 1)
if (global_batch_size) % n_gen == 0
]
possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
if self.num_generations not in possible_values:
raise ValueError(
f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
@ -1010,11 +952,7 @@ class _UnslothGRPOTrainer(Trainer):
)
if self.args.eval_strategy != "no":
global_batch_size = args.per_device_eval_batch_size * num_processes
possible_values = [
n_gen
for n_gen in range(2, global_batch_size + 1)
if (global_batch_size) % n_gen == 0
]
possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
if self.num_generations not in possible_values:
raise ValueError(
f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
@ -1059,22 +997,14 @@ class _UnslothGRPOTrainer(Trainer):
if self.is_deepspeed_enabled:
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
else:
self.ref_model = self.accelerator.prepare_model(
self.ref_model, evaluation_mode=True
)
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
if args.sync_ref_model:
self.add_callback(
SyncRefModelCallback(
ref_model=self.ref_model, accelerator=self.accelerator
)
)
self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
for i, reward_func in enumerate(self.reward_funcs):
if isinstance(reward_func, PreTrainedModel):
self.reward_funcs[i] = self.accelerator.prepare_model(
reward_func, evaluation_mode=True
)
self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
def _set_signature_columns_if_needed(self):
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
@ -1089,27 +1019,21 @@ class _UnslothGRPOTrainer(Trainer):
# identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
# within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
# preventing discrepancies in group formation.
return RepeatRandomSampler(
self.train_dataset, self.num_generations, seed=self.args.seed
)
return RepeatRandomSampler(self.train_dataset, self.num_generations, seed=self.args.seed)
def _get_eval_sampler(self, eval_dataset) -> Sampler:
# Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
# identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
# within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
# preventing discrepancies in group formation.
return RepeatRandomSampler(
eval_dataset, self.num_generations, seed=self.args.seed
)
return RepeatRandomSampler(eval_dataset, self.num_generations, seed=self.args.seed)
# Get the per-token log probabilities for the completions for the model and the reference model
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
return None # Unsloth efficient GRPO
if not hasattr(self, "_autocast_dtype"):
self._autocast_dtype = (
torch.float16
if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16"
else torch.bfloat16
torch.float16 if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16" else torch.bfloat16
)
with torch.amp.autocast(device_type="cuda", dtype=self._autocast_dtype):
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
@ -1118,9 +1042,7 @@ class _UnslothGRPOTrainer(Trainer):
attention_mask=attention_mask,
logits_to_keep=logits_to_keep + 1,
).logits
logits = logits[
:, :-1, :
] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
input_ids = input_ids[:, -logits_to_keep:]
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
@ -1133,15 +1055,10 @@ class _UnslothGRPOTrainer(Trainer):
def _move_model_to_vllm(self, *args, **kwargs):
return None
def _prepare_inputs(
self, inputs: dict[str, Union[torch.Tensor, Any]]
) -> dict[str, Union[torch.Tensor, Any]]:
def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
device = self.accelerator.device
prompts = [x["prompt"] for x in inputs]
prompts_text = [
maybe_apply_chat_template(example, self.processing_class)["prompt"]
for example in inputs
]
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
prompt_inputs = self.processing_class(
prompts_text,
return_tensors="pt",
@ -1174,9 +1091,7 @@ class _UnslothGRPOTrainer(Trainer):
prompts_text,
sampling_params=self.sampling_params,
use_tqdm=False,
lora_request=self.model.load_lora(
"grpo_trainer_lora_model", load_tensors=True
),
lora_request=self.model.load_lora("grpo_trainer_lora_model", load_tensors=True),
)
if self.use_agentic_generate:
agentic_outputs = self.model.agentic_generate(
@ -1201,11 +1116,7 @@ class _UnslothGRPOTrainer(Trainer):
).to(device)
else:
outputs = generate_fn(all_prompts_text)
completion_ids = [
out.token_ids
for completions in outputs
for out in completions.outputs
]
completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
else:
completion_ids = [None] * len(all_prompts_text)
# Broadcast the completions from the main process to all processes, ensuring each process receives its
@ -1218,18 +1129,12 @@ class _UnslothGRPOTrainer(Trainer):
completion_ids = completion_ids[process_slice]
# Pad the completions, and concatenate them with the prompts
completion_ids = [
torch.tensor(ids, device=device) for ids in completion_ids
]
completion_ids = pad(
completion_ids, padding_value=self.processing_class.pad_token_id
)
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
else:
# Regular generation path
with unwrap_model_for_generation(
self.model, self.accelerator
) as unwrapped_model:
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
prompt_completion_ids = unwrapped_model.generate(
prompt_ids,
attention_mask=prompt_mask,
@ -1244,21 +1149,15 @@ class _UnslothGRPOTrainer(Trainer):
if not self.use_agentic_generate:
# Mask everything after the first EOS token
is_eos = completion_ids == self.processing_class.eos_token_id
eos_idx = torch.full(
(is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device
)
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(
is_eos.size(0), -1
)
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
# Concatenate prompt_mask with completion_mask for logit computation
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C)
logits_to_keep = completion_ids.size(
1
) # we only need to compute the logits for the completion tokens
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
# this does nothing
with (
@ -1280,9 +1179,7 @@ class _UnslothGRPOTrainer(Trainer):
logits_to_keep,
)
else:
with self.accelerator.unwrap_model(
self.model, keep_fp32_wrapper=False
).disable_adapter():
with self.accelerator.unwrap_model(self.model, keep_fp32_wrapper=False).disable_adapter():
ref_per_token_logps = self._get_per_token_logps(
self.model,
prompt_completion_ids,
@ -1292,42 +1189,25 @@ class _UnslothGRPOTrainer(Trainer):
# Decode the generated completions
if not self.use_agentic_generate:
completions_text = self.processing_class.batch_decode(
completion_ids, skip_special_tokens=True
)
completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
if is_conversational(inputs[0]):
completions = []
for prompt, completion in zip(prompts, completions_text):
bootstrap = (
prompt.pop()["content"]
if prompt[-1]["role"] == "assistant"
else ""
)
completions.append(
[{"role": "assistant", "content": bootstrap + completion}]
)
bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
completions.append([{"role": "assistant", "content": bootstrap + completion}])
else:
completions = completions_text
else:
completions = full_chats
rewards_per_func = torch.zeros(
len(prompts), len(self.reward_funcs), device=device
)
rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
for i, (reward_func, reward_processing_class) in enumerate(
zip(self.reward_funcs, self.reward_processing_classes)
):
if isinstance(
reward_func, nn.Module
): # Module instead of PretrainedModel for compat with compiled models
if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
if is_conversational(inputs[0]):
messages = [
{"messages": p + c} for p, c in zip(prompts, completions)
]
texts = [
apply_chat_template(x, reward_processing_class)["text"]
for x in messages
]
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
else:
texts = [p + c for p, c in zip(prompts, completions)]
reward_inputs = reward_processing_class(
@ -1343,36 +1223,25 @@ class _UnslothGRPOTrainer(Trainer):
torch.amp.autocast(
device_type="cuda",
dtype=torch.float16
if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16")
== "fp16"
if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16"
else torch.bfloat16,
)
if not torch.is_autocast_enabled("cuda")
else nullcontext(),
):
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[
:, 0
] # Shape (B*G,)
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
else:
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
reward_kwargs = {
key: [example[key] for example in inputs] for key in keys
}
output_reward_func = reward_func(
prompts=prompts, completions=completions, **reward_kwargs
)
rewards_per_func[:, i] = torch.tensor(
output_reward_func, dtype=torch.float32, device=device
)
reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
# Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
# completions may be distributed across processes
rewards_per_func = gather(rewards_per_func)
# Apply weights to each reward function's output and sum
rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(
dim=1
)
rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)
# else:
# reward_fn = self.reward_funcs[0]
@ -1391,12 +1260,8 @@ class _UnslothGRPOTrainer(Trainer):
std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
# Normalize the rewards to compute the advantages
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(
self.num_generations, dim=0
)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(
self.num_generations, dim=0
)
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
# Slice to keep only the local part of the data
@ -1410,15 +1275,11 @@ class _UnslothGRPOTrainer(Trainer):
reward_per_func = rewards_per_func.mean(0)
print("rewards_per_func:", reward_per_func)
for i, reward_func in enumerate(self.reward_funcs):
if isinstance(
reward_func, nn.Module
): # Module instead of PretrainedModel for compat with compiled models
if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
reward_func_name = reward_func.config._name_or_path.split("/")[-1]
else:
reward_func_name = reward_func.__name__
self._metrics[f"rewards/{reward_func_name}"].append(
reward_per_func[i].item()
)
self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
self._metrics["reward"].append(rewards.mean().item())
self._metrics["reward_std"].append(std_grouped_rewards.mean().item())
@ -1451,9 +1312,7 @@ class _UnslothGRPOTrainer(Trainer):
"advantages": advantages,
}
def compute_loss(
self, model, inputs, return_outputs=False, num_items_in_batch=None
):
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
if return_outputs:
raise ValueError("The GRPOTrainer does not support returning outputs")
# Compute the per-token log probabilities for the model
@ -1467,14 +1326,10 @@ class _UnslothGRPOTrainer(Trainer):
bsz, qlen = input_ids.shape
# attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
attention_mask = None
logits_to_keep = completion_ids.size(
1
) # we only need to compute the logits for the completion tokens
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
_input_ids = input_ids
_logits_to_keep = logits_to_keep
per_token_logps = self._get_per_token_logps(
model, input_ids, attention_mask, logits_to_keep
)
per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
# Compute the KL divergence between the model and the reference model
ref_per_token_logps = inputs["ref_per_token_logps"]
@ -1529,9 +1384,7 @@ class _UnslothGRPOTrainer(Trainer):
return loss, None, None
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
metrics = {
key: sum(val) / len(val) for key, val in self._metrics.items()
} # average the metrics
metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
# This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
# start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
@ -1565,9 +1418,7 @@ class _UnslothGRPOTrainer(Trainer):
if not self.is_world_process_zero():
return
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(
self.model.config._name_or_path
):
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
base_model = self.model.config._name_or_path
else:
base_model = None
@ -1596,9 +1447,7 @@ class _UnslothGRPOTrainer(Trainer):
hub_model_id=self.hub_model_id,
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url()
if is_wandb_available() and wandb.run is not None
else None,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="GRPO",
trainer_citation=citation,
@ -1735,10 +1584,7 @@ class UnslothGRPOTrainer(_UnslothGRPOTrainer):
args.fp16 = float16
args.bf16 = not float16
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16" if float16 else "bf16"
if (
getattr(args, "eval_dataset", None) is not None
and getattr(args, "eval_strategy", "no") == "no"
):
if getattr(args, "eval_dataset", None) is not None and getattr(args, "eval_strategy", "no") == "no":
args.eval_strategy = "steps"
if getattr(args, "eval_steps", None) is None:
args.eval_steps = 0.1
@ -1755,10 +1601,7 @@ class UnslothGRPOTrainer(_UnslothGRPOTrainer):
eval_bsz = getattr(args, "per_device_eval_batch_size", 8)
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz:
args.per_device_eval_batch_size = args.per_device_train_batch_size
if (
getattr(args, "eval_accumulation_steps", None) is None
and ga_steps is not None
):
if getattr(args, "eval_accumulation_steps", None) is None and ga_steps is not None:
args.eval_accumulation_steps = ga_steps
fp16_full_eval = getattr(args, "fp16_full_eval", False)
bf16_full_eval = getattr(args, "bf16_full_eval", False)
@ -1787,9 +1630,7 @@ class UnslothGRPOTrainer(_UnslothGRPOTrainer):
if "processing_class" in locals():
if hasattr(processing_class, "padding_side"):
processing_class.padding_side = "right"
if hasattr(processing_class, "tokenizer") and hasattr(
processing_class.tokenizer, "padding_side"
):
if hasattr(processing_class, "tokenizer") and hasattr(processing_class.tokenizer, "padding_side"):
processing_class.tokenizer.padding_side = "right"
other_metrics = []
if not isinstance(reward_funcs, list):

@ -19,15 +19,10 @@ LOG_FOLDER = PROJ_ROOT / "logs"
# Model configuration
# MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
device_id = (
1 if os.environ.get("CUDA_VISIBLE_DEVICES") == "1" else torch.cuda.current_device()
)
device_id = 1 if os.environ.get("CUDA_VISIBLE_DEVICES") == "1" else torch.cuda.current_device()
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
OUTPUT_DIR = (
PROJ_ROOT
/ f"trainer_output_{MODEL_NAME.replace('/', '_')}_gpu{device_id}_{timestamp}"
)
OUTPUT_DIR = PROJ_ROOT / f"trainer_output_{MODEL_NAME.replace('/', '_')}_gpu{device_id}_{timestamp}"
# Model parameters
MODEL_CONFIG = {
@ -103,12 +98,7 @@ def _init_logging(env: str = "development") -> None:
"- <level>{message}</level>"
)
file_format = (
"{time:YYYY-MM-DD at HH:mm:ss} "
"| {level} "
"| {name}:{function}:{line} "
"- {message}"
)
file_format = "{time:YYYY-MM-DD at HH:mm:ss} | {level} | {name}:{function}:{line} - {message}"
# Add console logging
logger.add(
@ -139,9 +129,7 @@ def _init_logging(env: str = "development") -> None:
if issubclass(exc_type, KeyboardInterrupt):
sys.__excepthook__(exc_type, exc_value, exc_traceback)
return
logger.opt(exception=(exc_type, exc_value, exc_traceback)).critical(
"Unhandled exception"
)
logger.opt(exception=(exc_type, exc_value, exc_traceback)).critical("Unhandled exception")
sys.excepthook = exception_handler
@ -163,12 +151,7 @@ def update_log_path(log_dir=None):
log_dir = Path(log_dir)
log_dir.mkdir(exist_ok=True, parents=True)
file_format = (
"{time:YYYY-MM-DD at HH:mm:ss} "
"| {level} "
"| {name}:{function}:{line} "
"- {message}"
)
file_format = "{time:YYYY-MM-DD at HH:mm:ss} | {level} | {name}:{function}:{line} - {message}"
# Add additional file handler pointing to training directory
# No need to remove existing handlers as we want to keep those
@ -248,9 +231,7 @@ def setup_logger(module_name=None, create_dirs: bool = False):
Returns:
Configured logger instance
"""
logger.warning(
"setup_logger is deprecated. Import logger directly from config instead."
)
logger.warning("setup_logger is deprecated. Import logger directly from config instead.")
return logger

@ -186,12 +186,8 @@ def run_agent_generations(generate_fn, tokenizer, chat_states):
full_response = response.outputs[0].text
else:
full_response = response
assistant_response = full_response.split(
"<|start_header_id|>assistant<|end_header_id|>"
)[-1]
chat_state["messages"].append(
{"role": "assistant", "content": assistant_response}
)
assistant_response = full_response.split("<|start_header_id|>assistant<|end_header_id|>")[-1]
chat_state["messages"].append({"role": "assistant", "content": assistant_response})
logger.debug(f"Added assistant response to chat state {idx}")
else:
logger.debug("No prompts to generate responses for")
@ -211,9 +207,7 @@ def check_finished_chats(chat_states):
for chat_state in chat_states:
if chat_state.get("finished"):
continue
assert (
chat_state["messages"][-1]["role"] == "assistant"
), "Expected the last role to be assistant"
assert chat_state["messages"][-1]["role"] == "assistant", "Expected the last role to be assistant"
assistant_response = chat_state["messages"][-1]["content"]
function_calls = extract_json_objects(assistant_response)
if len(function_calls) == 0:
@ -232,17 +226,15 @@ def run_tool_calls(chat_states):
if chat_state.get("finished"):
logger.debug("Chat state already finished, skipping tool calls")
continue
assert (
chat_state["messages"][-1]["role"] == "assistant"
), "Expected the last role to be assistant to run tool calls"
assert chat_state["messages"][-1]["role"] == "assistant", (
"Expected the last role to be assistant to run tool calls"
)
try:
assistant_response = chat_state["messages"][-1]["content"]
function_calls = extract_json_objects(assistant_response)
if len(function_calls) > 1:
logger.warning("Multiple function calls found in assistant response")
raise ValueError(
"Expected only one function call in assistant response"
)
raise ValueError("Expected only one function call in assistant response")
elif len(function_calls) == 1:
function_call = function_calls[0]
query = function_call["function"]["parameters"]["query"]
@ -257,9 +249,7 @@ def run_tool_calls(chat_states):
logger.debug("Added search results to chat state")
except Exception as e:
logger.error(f"Error during tool call: {str(e)}")
chat_state["messages"].append(
{"role": "system", "content": f"Error during post-processing: {str(e)}"}
)
chat_state["messages"].append({"role": "system", "content": f"Error during post-processing: {str(e)}"})
chat_state["finished"] = True
return chat_states
@ -273,14 +263,9 @@ def get_mask(text, tokenizer):
assistant_ranges = []
i = 0
while i < len(encoding.input_ids) - 1:
if (
encoding.input_ids[i] == start_header_id
and encoding.input_ids[i + 1] == assistant_token
):
if encoding.input_ids[i] == start_header_id and encoding.input_ids[i + 1] == assistant_token:
i += 2
while (
i < len(encoding.input_ids) and encoding.input_ids[i] != end_header_id
):
while i < len(encoding.input_ids) and encoding.input_ids[i] != end_header_id:
i += 1
i += 2
start_idx = i
@ -319,11 +304,7 @@ class AgenticOutputs:
def get_chat_num_tokens(chat_state, tokenizer):
chat_text = apply_chat_template(chat_state, tokenizer=tokenizer)["text"]
return (
tokenizer(chat_text, add_special_tokens=False, return_tensors="pt")["input_ids"]
.squeeze()
.shape[0]
)
return tokenizer(chat_text, add_special_tokens=False, return_tensors="pt")["input_ids"].squeeze().shape[0]
def run_agent(
@ -338,9 +319,7 @@ def run_agent(
Run the agent to completion for a batch of questions.
"""
logger.info(f"Starting agent run with {len(questions)} questions")
logger.debug(
f"Max generations: {max_generations}, Max new tokens: {max_new_tokens}"
)
logger.debug(f"Max generations: {max_generations}, Max new tokens: {max_new_tokens}")
chat_states = [get_initial_chat(q) for q in questions]
# Add correct content to chat states if provided
@ -359,13 +338,9 @@ def run_agent(
chat_states = run_agent_generations(generate_fn, tokenizer, chat_states)
chat_states = check_finished_chats(chat_states)
chat_states = run_tool_calls(chat_states)
chat_states = check_exceeded_max_new_tokens(
chat_states, max_new_tokens, tokenizer
)
chat_states = check_exceeded_max_new_tokens(chat_states, max_new_tokens, tokenizer)
finished_count = sum(1 for state in chat_states if state.get("finished"))
logger.info(
f"Finished {finished_count}/{len(chat_states)} chat states after step {i + 1}"
)
logger.info(f"Finished {finished_count}/{len(chat_states)} chat states after step {i + 1}")
logger.info("Agent run completed")
@ -387,23 +362,15 @@ def run_agent(
assistant_response = convo_text[idx + len(marker) :]
return prompt, assistant_response
str_chats = [
apply_chat_template(chat, tokenizer=tokenizer)["text"] for chat in chat_states
]
str_chats = [apply_chat_template(chat, tokenizer=tokenizer)["text"] for chat in chat_states]
prompt_toks, response_toks, response_masks = [], [], []
logger.debug("Processing tokenization")
for i, str_chat in enumerate(str_chats):
prompt, response = split_prompt_assistant(str_chat)
prompt_toks.append(
tokenizer(prompt, add_special_tokens=False, return_tensors="pt")[
"input_ids"
].squeeze()
)
prompt_toks.append(tokenizer(prompt, add_special_tokens=False, return_tensors="pt")["input_ids"].squeeze())
response_toks.append(
tokenizer(response, add_special_tokens=False, return_tensors="pt")[
"input_ids"
].squeeze()[:max_new_tokens]
tokenizer(response, add_special_tokens=False, return_tensors="pt")["input_ids"].squeeze()[:max_new_tokens]
)
mask = get_mask(str_chat, tokenizer)[len(prompt_toks[-1]) :][:max_new_tokens]
response_masks.append(mask)
@ -469,12 +436,8 @@ def check_student_answers(
logger.info(f"Checking {len(questions)} student answers")
if not (len(questions) == len(answers) == len(student_answers)):
logger.error(
"Mismatched lengths between questions, answers, and student answers"
)
raise ValueError(
"The number of questions, answers, and student answers must be equal."
)
logger.error("Mismatched lengths between questions, answers, and student answers")
raise ValueError("The number of questions, answers, and student answers must be equal.")
prompts = []
for question, answer, student_ans in zip(questions, answers, student_answers):
@ -548,11 +511,7 @@ def check_student_answers(
if isinstance(student_ans, dict) and "messages" in student_ans:
# Get messages from dict
messages = student_ans.get("messages", [])
search_results = [
msg.get("content", "")
for msg in messages
if msg.get("role") == "ipython"
]
search_results = [msg.get("content", "") for msg in messages if msg.get("role") == "ipython"]
if search_results:
file.write("\n🔎 Search Results:\n")
for j, result in enumerate(search_results, 1):
@ -572,18 +531,12 @@ def check_student_answers(
def build_reward_correctness_fn(generate_fn, tokenizer, log_file=None):
def reward_correctness(prompts, completions, **reward_kwargs):
teacher_answers = reward_kwargs["answer"]
student_answers = [
completion["messages"][-1]["content"] for completion in completions
]
student_answers = [completion["messages"][-1]["content"] for completion in completions]
# Log non-exact matches
for i, (student, teacher) in enumerate(zip(student_answers, teacher_answers)):
if student.strip().lower() != teacher.strip().lower():
logger.warning(
f"Non-exact match at index {i}:\n"
f"Student: {student}\n"
f"Teacher: {teacher}"
)
logger.warning(f"Non-exact match at index {i}:\nStudent: {student}\nTeacher: {teacher}")
correct = check_student_answers(
prompts,
@ -595,12 +548,8 @@ def build_reward_correctness_fn(generate_fn, tokenizer, log_file=None):
)
# Log correctness metrics with length info
log_metric(
"rewards/correctness", np.mean(correct), reward_kwargs.get("step", 0)
)
log_metric(
"rewards/correctness_std", np.std(correct), reward_kwargs.get("step", 0)
)
log_metric("rewards/correctness", np.mean(correct), reward_kwargs.get("step", 0))
log_metric("rewards/correctness_std", np.std(correct), reward_kwargs.get("step", 0))
# Log length metrics
student_lengths = [len(ans.strip()) for ans in student_answers]
@ -676,9 +625,7 @@ def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[floa
if json_count > 1:
has_multiple_json = True
logger.warning(
f"Message contains {json_count} JSON objects, which exceeds the limit of 1"
)
logger.warning(f"Message contains {json_count} JSON objects, which exceeds the limit of 1")
break
# Only reward if no message has multiple JSON objects
@ -692,17 +639,13 @@ def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[floa
if total_json_objects > 4:
penalty = 0.1 * (total_json_objects - 4)
base_reward = max(0.2, base_reward - penalty)
logger.debug(
f"Applied penalty for {total_json_objects} total JSON objects: {penalty}"
)
logger.debug(f"Applied penalty for {total_json_objects} total JSON objects: {penalty}")
rewards.append(base_reward)
# Log retry behavior metrics
log_metric("rewards/retry_behavior", np.mean(rewards), reward_kwargs.get("step", 0))
log_metric(
"rewards/retry_behavior_std", np.std(rewards), reward_kwargs.get("step", 0)
)
log_metric("rewards/retry_behavior_std", np.std(rewards), reward_kwargs.get("step", 0))
log_metric(
"metrics/avg_json_per_msg",
np.mean(
@ -737,13 +680,9 @@ def reward_exact_match_chunk_query(prompts, completions, **reward_kwargs):
raise ValueError("chunk_content must be provided in reward_kwargs")
rewards = []
for i, (chat_state, correct_content) in enumerate(
zip(completions, correct_contents)
):
for i, (chat_state, correct_content) in enumerate(zip(completions, correct_contents)):
# Get all ipython messages (search results) from the chat
search_results = [
msg["content"] for msg in chat_state["messages"] if msg["role"] == "ipython"
]
search_results = [msg["content"] for msg in chat_state["messages"] if msg["role"] == "ipython"]
logger.debug(f"Found {len(search_results)} search results for prompt {i}")
# Log ground truth chunk and searched chunks
@ -756,9 +695,7 @@ def reward_exact_match_chunk_query(prompts, completions, **reward_kwargs):
for result in search_results:
if correct_content in result:
found_correct_chunk = True
logger.debug(
f"Found correct chunk content in search results for prompt {i}"
)
logger.debug(f"Found correct chunk content in search results for prompt {i}")
break
if not found_correct_chunk:
@ -796,21 +733,13 @@ def reward_exact_match_chunk_query(prompts, completions, **reward_kwargs):
"metrics/avg_search_results",
np.mean(
[
len(
[
msg["content"]
for msg in chat_state["messages"]
if msg["role"] == "ipython"
]
)
len([msg["content"] for msg in chat_state["messages"] if msg["role"] == "ipython"])
for chat_state in completions
]
),
reward_kwargs.get("step", 0),
)
log_metric(
"metrics/chunk_match_rate", np.mean(rewards), reward_kwargs.get("step", 0)
)
log_metric("metrics/chunk_match_rate", np.mean(rewards), reward_kwargs.get("step", 0))
# Log detailed debugging info
logger.info("Chunk Query Rewards Summary:")
@ -862,9 +791,7 @@ def run_eval(generate_fn, verify_fn, tokenizer, output_file=None, debug_file=Non
f.write(f"Percentage correct: {percent_correct:.2f}%\n\n")
f.write("Individual results:\n")
for i, (q, r, resp) in enumerate(
zip(questions, rewards, final_responses)
):
for i, (q, r, resp) in enumerate(zip(questions, rewards, final_responses)):
f.write(f"\nQ{i + 1}: {q[:100]}...\n")
f.write(f"Correct: {'' if r else ''}\n")
f.write(f"Response: {resp[:150]}...\n")
@ -879,9 +806,7 @@ def run_eval(generate_fn, verify_fn, tokenizer, output_file=None, debug_file=Non
import json
debug_data = []
for i, (q, r, resp, chat) in enumerate(
zip(questions, rewards, final_responses, full_chat_states)
):
for i, (q, r, resp, chat) in enumerate(zip(questions, rewards, final_responses, full_chat_states)):
debug_data.append(
{
"question_id": i,

@ -20,9 +20,7 @@ def load_vectorstore():
embeddings = CustomHuggingFaceEmbeddings()
# Load the FAISS index from the data directory
logger.info(f"Loading FAISS index from: {DATA_DIR}")
vectorstore = FAISS.load_local(
str(DATA_DIR), embeddings, allow_dangerous_deserialization=True
)
vectorstore = FAISS.load_local(str(DATA_DIR), embeddings, allow_dangerous_deserialization=True)
logger.info("Successfully loaded FAISS index")
return vectorstore
except Exception as e:
@ -125,9 +123,7 @@ def get_question_answer(idx=None, return_both: bool = True) -> dict:
# Select question by index
qa_pair = questions[idx]
else:
raise ValueError(
f"Index out of range. Must be between 0 and {len(questions) - 1}"
)
raise ValueError(f"Index out of range. Must be between 0 and {len(questions) - 1}")
question = qa_pair["question"]
answer = qa_pair["answer"]

@ -3,16 +3,6 @@ import os
from unsloth import FastLanguageModel, is_bfloat16_supported
import src.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp
# Import reward functions
from src.rl_helpers import (
build_reward_correctness_fn,
get_qa_dataset,
reward_exact_match_chunk_query,
reward_formatting,
reward_retry_behavior,
run_agent,
)
from src.config import (
MODEL_CONFIG,
MODEL_NAME,
@ -24,6 +14,16 @@ from src.config import (
update_log_path,
)
# Import reward functions
from src.rl_helpers import (
build_reward_correctness_fn,
get_qa_dataset,
reward_exact_match_chunk_query,
reward_formatting,
reward_retry_behavior,
run_agent,
)
# Initialize training directories
paths = init_training_dirs()
@ -57,9 +57,7 @@ model = FastLanguageModel.get_peft_model(
# Load datasets
logger.info("Loading datasets")
train_dataset, test_dataset = get_qa_dataset()
logger.info(
f"Loaded {len(train_dataset)} training examples and {len(test_dataset)} test examples"
)
logger.info(f"Loaded {len(train_dataset)} training examples and {len(test_dataset)} test examples")
# Setup training arguments
logger.info("Setting up training arguments")
