feat: refactor the whole codebase, add logic for training R1 Distil base models, and update template and reward logic

- Break rl_helpers down into smaller modules
- Remove the deprecated rl_helpers module to streamline the codebase
- Enhance the initial user prompt template, inspired by Search-R1
main
thinhlpg 1 month ago
parent c90c03267e
commit 31dcbf5d8a

@@ -54,6 +54,8 @@ from trl.trainer.grpo_trainer import (
    wandb,
)
+ from src.config import logger
+
torch_compile_options = {
    "epilogue_fusion": True,
    "max_autotune": False,
@@ -63,11 +65,11 @@ torch_compile_options = {
}
- @torch.compile(
-     dynamic=True,
-     fullgraph=True,
-     options=torch_compile_options,
- )
+ # @torch.compile(
+ #     dynamic=True,
+ #     fullgraph=True,
+ #     options=torch_compile_options,
+ # )
def selective_log_softmax(logits, index):
    logits = logits.to(torch.float32)
    selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
@@ -82,6 +84,29 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages)
    # All Unsloth Zoo code licensed under LGPLv3
    old_logits = old_logits.to(torch.float32)
    new_logits = new_logits.to(torch.float32)
+
+     # Print FULL tensor contents
+     logger.debug("\n🔍 DETAILED TENSOR ANALYSIS:")
+     logger.debug("\n1⃣ Input IDs:")
+     logger.debug(f"Shape: {input_ids.shape}")
+     # Use tensor.cpu().numpy() to safely print content
+     logger.debug(f"Type: {mask.dtype}")
+     logger.debug(f"Sum: {mask.sum().item()}")  # Use .item() to get Python scalar
+     logger.debug("\n3⃣ Old Logits:")
+     logger.debug(f"Shape: {old_logits.shape}")
+     logger.debug(f"Type: {old_logits.dtype}")
+     logger.debug(f"Mean: {old_logits.mean().item():.4f}")
+     logger.debug("\n4⃣ New Logits:")
+     logger.debug(f"Shape: {new_logits.shape}")
+     logger.debug(f"Type: {new_logits.dtype}")
+     logger.debug(f"Mean: {new_logits.mean().item():.4f}")
+     logger.debug("\n5⃣ Advantages:")
+     logger.debug(f"Shape: {advantages.shape}")
+
    input_ids = input_ids.unsqueeze(-1)
    # x_i - logsumexp(x_i)
@@ -90,6 +115,10 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages)
    old = old_x - torch.logsumexp(old_logits, dim=-1)
    new = new_x - torch.logsumexp(new_logits, dim=-1)
+
+     logger.debug("\n6⃣ After Gather & LogSumExp:")
+     logger.debug(f"old_x shape: {old_x.shape}, new_x shape: {new_x.shape}")
+     logger.debug(f"old shape: {old.shape}, new shape: {new.shape}")
+
    # Reverse KL
    kl_i = torch.exp(old - new) - (old - new) - 1.0
    # Full correct reverse KL divergence?? Missing term maybe?
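For context on the question in that comment: kl_i = exp(old - new) - (old - new) - 1.0 is the "k3" estimator of the reverse KL that TRL's GRPO trainer also uses, and it is non-negative for every value of the log-ratio, so it already acts as a complete per-token penalty. A tiny standalone check (illustrative only; the tensors are made up):

import torch

old_logprob = torch.tensor([-1.2, -0.7, -2.0])  # hypothetical per-token log-probs under the old/reference policy
new_logprob = torch.tensor([-1.0, -0.9, -1.5])  # hypothetical per-token log-probs under the current policy
r = old_logprob - new_logprob                   # log-ratio, as in grpo_compute_loss
kl_i = torch.exp(r) - r - 1.0                   # matches the line above
assert bool((kl_i >= 0).all())                  # exp(r) - r - 1 >= 0 for every r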
@@ -109,6 +138,13 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages)
    # See https://github.com/huggingface/trl/pull/2881
    # loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward
    # loss = loss_per_reward.mean()
+
+     # Add print statements here for debugging
+     logger.debug(f"🚨 Debug: loss_i shape: {loss_i.shape}")
+     logger.debug(
+         f"🚨 Debug: mask shape: {mask.shape}"
+     )  # Note: Mask shape might change slightly due to float conversion
+
    loss = (loss_i * mask).sum() / mask.sum()
    # Get metrics as well which are folded
@@ -165,14 +201,7 @@ class UnslothEfficientGRPO(torch.autograd.Function):
        accumulated_completion_length = torch.zeros(1, device=device)
        accumulated_mean_kl = torch.zeros(1, device=device)

-         def accumulate_chunk(
-             new_hidden_states_j,
-             old_hidden_states_j,
-             input_ids_j,
-             mask_j,
-             advantages_j,
-             scaling,
-         ):
+         def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling):
            (
                (chunk_grad_input,),
                (
@@ -187,14 +216,7 @@ class UnslothEfficientGRPO(torch.autograd.Function):
                compute_loss,
                argnums=(0,),
                has_aux=True,
-             )(
-                 new_hidden_states_j,
-                 old_hidden_states_j,
-                 input_ids_j,
-                 mask_j,
-                 advantages_j,
-                 scaling,
-             )
+             )(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling)
            accumulated_loss.add_(unscaled_loss)
            accumulated_completion_length.add_(chunk_completion_length)
            accumulated_mean_kl.add_(chunk_mean_kl)
@@ -202,11 +224,11 @@ class UnslothEfficientGRPO(torch.autograd.Function):
            pass

-         accumulate_chunk = torch.compile(
-             accumulate_chunk,
-             fullgraph=True,
-             options=torch_compile_options,
-         )
+         # accumulate_chunk = torch.compile(
+         #     accumulate_chunk,
+         #     fullgraph=True,
+         #     options=torch_compile_options,
+         # )

        grad_inputs_chunks = torch.chunk(grad_inputs, chunks=n_chunks, dim=0)
        new_hidden_states = torch.chunk(_new_hidden_states, chunks=n_chunks, dim=0)
@@ -228,28 +250,14 @@ class UnslothEfficientGRPO(torch.autograd.Function):
            input_ids_j,
            mask_j,
            advantages_j,
-         ) in zip(
-             grad_inputs_chunks,
-             new_hidden_states,
-             old_hidden_states,
-             input_ids,
-             mask,
-             advantages,
-         ):
+         ) in zip(grad_inputs_chunks, new_hidden_states, old_hidden_states, input_ids, mask, advantages):
            mark_dynamic(new_hidden_states_j)
            mark_dynamic(old_hidden_states_j)
            mark_dynamic(input_ids_j)
            mark_dynamic(mask_j)
            grad_inputs_j.copy_(
-                 accumulate_chunk(
-                     new_hidden_states_j,
-                     old_hidden_states_j,
-                     input_ids_j,
-                     mask_j,
-                     advantages_j,
-                     scaling,
-                 )
+                 accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling)
            )
        pass
@@ -624,7 +632,7 @@ class UnslothGRPOConfig(GRPOConfig):
            save_strategy = "no"
        div = per_device_train_batch_size // num_generations
        if div * num_generations != per_device_train_batch_size:
-             print(
+             logger.debug(
                "Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.\nWe will change the batch size of "
                + str(per_device_train_batch_size)
                + " to the `num_generations` of "
@@ -971,11 +979,7 @@ class _UnslothGRPOTrainer(Trainer):
            self.sampling_params = SamplingParams(
                temperature=args.temperature,
                max_tokens=self.max_completion_length,
-                 **getattr(
-                     getattr(args, "vllm_sampling_params", vLLMSamplingParams()),
-                     "_set_kwargs",
-                     {},
-                 ),
+                 **getattr(getattr(args, "vllm_sampling_params", vLLMSamplingParams()), "_set_kwargs", {}),
            )
        else:
            self.generation_config = GenerationConfig(
@@ -1038,9 +1042,7 @@ class _UnslothGRPOTrainer(Trainer):
        with torch.amp.autocast(device_type="cuda", dtype=self._autocast_dtype):
            # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
            logits = model(
-                 input_ids=input_ids,
-                 attention_mask=attention_mask,
-                 logits_to_keep=logits_to_keep + 1,
+                 input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1
            ).logits
        logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
@@ -1056,25 +1058,31 @@ class _UnslothGRPOTrainer(Trainer):
        return None

    def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
+         logger.debug("\n🔍 DEBUG: Starting _prepare_inputs")
        device = self.accelerator.device
        prompts = [x["prompt"] for x in inputs]
        prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
+
+         logger.debug("\n1⃣ Before tokenization:")
+         logger.debug(f"Number of prompts: {len(prompts)}")
+         logger.debug(f"Sample prompt text length: {len(prompts_text[0]) if prompts_text else 0}")
+
        prompt_inputs = self.processing_class(
-             prompts_text,
-             return_tensors="pt",
-             padding=True,
-             padding_side="left",
-             add_special_tokens=False,
+             prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
        )
        prompt_inputs = super()._prepare_inputs(prompt_inputs)
-         prompt_ids, prompt_mask = (
-             prompt_inputs["input_ids"],
-             prompt_inputs["attention_mask"],
-         )
+         prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
+
+         logger.debug("\n2⃣ After initial tokenization:")
+         logger.debug(f"prompt_ids shape: {prompt_ids.shape}")
+         logger.debug(f"prompt_mask shape: {prompt_mask.shape}")
+         logger.debug(f"prompt_mask sum: {prompt_mask.sum().item()}")

        if self.max_prompt_length is not None:
            prompt_ids = prompt_ids[:, -self.max_prompt_length :]
            prompt_mask = prompt_mask[:, -self.max_prompt_length :]
+             logger.debug("\n3⃣ After prompt length truncation:")
+             logger.debug(f"prompt_ids shape: {prompt_ids.shape}")
+             logger.debug(f"prompt_mask shape: {prompt_mask.shape}")

        # Generate completions using either vLLM or regular generation
        if self.args.use_vllm:
@@ -1086,7 +1094,7 @@ class _UnslothGRPOTrainer(Trainer):
            # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
            all_prompts_text = gather_object(prompts_text)
            if self.accelerator.is_main_process:
-                 print(all_prompts_text)
+                 logger.debug(all_prompts_text)
                generate_fn = lambda prompts_text: self.llm.generate(
                    prompts_text,
                    sampling_params=self.sampling_params,
@@ -1104,6 +1112,11 @@ class _UnslothGRPOTrainer(Trainer):
                    prompt_inputs = agentic_outputs.prompt_tokens
                    completion_ids = agentic_outputs.response_tokens
                    completion_mask = agentic_outputs.response_masks
+
+                     for i in range(len(completion_ids)):
+                         logger.debug(f"prompt_inputs {i} len before padding: {len(prompt_inputs[i])}")
+                         logger.debug(f"completion_ids {i} len before padding: {len(completion_ids[i])}")
+                         logger.debug(f"completion_mask {i} len before padding: {len(completion_mask[i])}")
                    prompt_ids = pad(
                        prompt_inputs,
                        padding_value=self.processing_class.pad_token_id,
@@ -1114,6 +1127,12 @@ class _UnslothGRPOTrainer(Trainer):
                        padding_value=0,
                        padding_side="right",
                    ).to(device)
+
+                     for i in range(len(completion_ids)):
+                         logger.debug(f"prompt_inputs {i} len after padding: {len(prompt_inputs[i])}")
+                         logger.debug(f"prompt_ids {i} len after padding: {len(prompt_ids[i])}")
+                         logger.debug(f"completion_mask {i} len after padding: {len(completion_mask[i])}")
+
                else:
                    outputs = generate_fn(all_prompts_text)
                    completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
@@ -1130,15 +1149,21 @@ class _UnslothGRPOTrainer(Trainer):
            # Pad the completions, and concatenate them with the prompts
            completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
+             logger.debug("\n4⃣ Before completion padding:")
+             logger.debug(f"completion_ids shapes: {[ids.shape for ids in completion_ids]}")
            completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
+             logger.debug("\n5⃣ After completion padding:")
+             logger.debug(f"completion_ids shape: {completion_ids.shape}")
            prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+             logger.debug("\n6⃣ After concatenation:")
+             logger.debug(f"prompt_completion_ids shape: {prompt_completion_ids.shape}")
        else:
            # Regular generation path
            with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
                prompt_completion_ids = unwrapped_model.generate(
-                     prompt_ids,
-                     attention_mask=prompt_mask,
-                     generation_config=self.generation_config,
+                     prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
                )

        # Compute prompt length and extract completion ids
@@ -1148,11 +1173,52 @@ class _UnslothGRPOTrainer(Trainer):
        if not self.use_agentic_generate:
            # Mask everything after the first EOS token
+             logger.debug("\n🔍 Starting EOS token detection and masking:")
+             logger.debug(f"completion_ids shape: {completion_ids.shape}")
+             logger.debug(f"eos_token_id: {self.processing_class.eos_token_id}")
+
+             # Debug EOS detection
            is_eos = completion_ids == self.processing_class.eos_token_id
+             logger.debug("\n7⃣ EOS Detection Details:")
+             logger.debug(f"is_eos shape: {is_eos.shape}")
+             logger.debug(f"Sample is_eos values (first sequence):\n{is_eos[0]}")
+             logger.debug(f"Any EOS tokens found: {is_eos.any().item()}")
+             logger.debug(f"EOS positions: {is_eos.nonzero()}")
+
+             # Debug EOS index tensor creation
            eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
+             logger.debug("\n8⃣ EOS Index Creation:")
+             logger.debug(f"eos_idx initial shape: {eos_idx.shape}")
+             logger.debug(f"eos_idx initial values: {eos_idx}")
+
+             # Debug the complex indexing operation
+             logger.debug("\n9⃣ EOS Position Analysis:")
+             logger.debug(f"Sequences with EOS: {is_eos.any(dim=1).sum().item()}")
+             logger.debug(f"First EOS positions: {is_eos.int().argmax(dim=1)}")
            eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
+             logger.debug("\n🔟 After EOS Index Update:")
+             logger.debug(f"Updated eos_idx values: {eos_idx}")
+
+             # Debug sequence indices creation
            sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
+             logger.debug("\n1⃣1⃣ Sequence Indices:")
+             logger.debug(f"sequence_indices shape: {sequence_indices.shape}")
+             logger.debug(f"Sample sequence_indices (first row):\n{sequence_indices[0]}")
+
+             # Debug final mask creation
            completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
+             logger.debug("\n1⃣2⃣ Final Completion Mask:")
+             logger.debug(f"completion_mask shape: {completion_mask.shape}")
+             logger.debug(f"Sample mask (first sequence):\n{completion_mask[0]}")
+             logger.debug("Mask statistics:")
+             logger.debug(f"- Total 1s: {completion_mask.sum().item()}")
+             logger.debug(f"- Average sequence length: {completion_mask.sum(dim=1).float().mean().item():.2f}")
+
+             # Add a final validation check
+             logger.debug("\n7⃣ Final Validation:")
+             logger.debug(f"Input shape: {completion_ids.shape}")
+             logger.debug(f"Mask shape: {completion_mask.shape}")

        # Concatenate prompt_mask with completion_mask for logit computation
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B*G, P+C)
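The EOS-masking logic that these debug statements instrument can be traced on a toy tensor (standalone sketch; eos_token_id=2 and the token ids are made up):

import torch

completion_ids = torch.tensor([[5, 7, 2, 9], [5, 7, 9, 4]])  # second sequence has no EOS
eos_token_id = 2
is_eos = completion_ids == eos_token_id
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1)).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
# completion_mask -> [[1, 1, 1, 0], [1, 1, 1, 1]]: everything after the first EOS is zeroed out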
@@ -1173,18 +1239,12 @@ class _UnslothGRPOTrainer(Trainer):
        ):
            if self.ref_model is not None:
                ref_per_token_logps = self._get_per_token_logps(
-                     self.ref_model,
-                     prompt_completion_ids,
-                     attention_mask,
-                     logits_to_keep,
+                     self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
                )
            else:
                with self.accelerator.unwrap_model(self.model, keep_fp32_wrapper=False).disable_adapter():
                    ref_per_token_logps = self._get_per_token_logps(
-                         self.model,
-                         prompt_completion_ids,
-                         attention_mask,
-                         logits_to_keep,
+                         self.model, prompt_completion_ids, attention_mask, logits_to_keep
                    )

        # Decode the generated completions
@@ -1211,11 +1271,7 @@ class _UnslothGRPOTrainer(Trainer):
                else:
                    texts = [p + c for p, c in zip(prompts, completions)]
                reward_inputs = reward_processing_class(
-                     texts,
-                     return_tensors="pt",
-                     padding=True,
-                     padding_side="right",
-                     add_special_tokens=False,
+                     texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
                )
                reward_inputs = super()._prepare_inputs(reward_inputs)
                with (
@@ -1273,7 +1329,7 @@ class _UnslothGRPOTrainer(Trainer):
        # Log the metrics
        reward_per_func = rewards_per_func.mean(0)
-         print("rewards_per_func:", reward_per_func)
+         logger.debug("rewards_per_func:", reward_per_func)
        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
@@ -1318,10 +1374,7 @@ class _UnslothGRPOTrainer(Trainer):
        # Compute the per-token log probabilities for the model
        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
-         completion_ids, completion_mask = (
-             inputs["completion_ids"],
-             inputs["completion_mask"],
-         )
+         completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        bsz, qlen = input_ids.shape
        # attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
@@ -1369,13 +1422,7 @@ class _UnslothGRPOTrainer(Trainer):
        self._metrics["kl"].append(mean_kl.item())
        return loss

-     def prediction_step(
-         self,
-         model,
-         inputs,
-         prediction_loss_only,
-         ignore_keys: Optional[list[str]] = None,
-     ):
+     def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
        inputs = self._prepare_inputs(inputs)
        with torch.no_grad():
            with self.compute_loss_context_manager():
@@ -1593,7 +1640,7 @@ class UnslothGRPOTrainer(_UnslothGRPOTrainer):
        from transformers import __version__ as transformers_version

        if Version(transformers_version) <= Version("4.45.2"):
-             print(
+             logger.debug(
                "**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n"
                "`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`"
            )

@@ -0,0 +1,43 @@
"""
Main package exports for RL helpers.
"""
from trl.trainer.grpo_trainer import apply_chat_template
from src.agent import Agent
from src.config import logger
from src.evaluation import check_student_answers, run_eval, verify
from src.prompts import build_user_prompt, format_search_results, get_system_prompt
from src.rewards import (
build_reward_correctness_fn,
reward_em_chunk,
reward_format,
reward_retry,
)
from src.search_module import get_qa_dataset, search
from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
__all__ = [
# Prompts
"get_system_prompt",
"build_user_prompt",
"format_search_results",
"apply_chat_template",
# Agent
"Agent",
"LlamaTokenizerAdapter",
"R1DistilTokenizerAdapter",
# Rewards
"build_reward_correctness_fn",
"reward_format",
"reward_retry",
"reward_em_chunk",
# Evaluation
"run_eval",
"check_student_answers",
"verify",
# Search
"get_qa_dataset",
"search",
"logger",
]
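A minimal sketch of the surface this package init exposes (assuming it lives at src/__init__.py; the question text is made up):

from src import Agent, LlamaTokenizerAdapter, get_system_prompt

agent = Agent(LlamaTokenizerAdapter())
chat = agent.get_initial_chat("Who proposed the theory of general relativity?")
# chat["messages"][0] carries the system prompt from get_system_prompt();
# chat["messages"][1] is the Search-R1-style user prompt that wraps the question.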

@@ -0,0 +1,237 @@
"""
Core agent functionality for handling tool-based conversations.
This module provides a base agent class for handling tool-based conversations.
"""
import re
from dataclasses import dataclass
import torch
from trl.trainer.grpo_trainer import apply_chat_template
from src.config import logger
from src.prompts import build_user_prompt, get_system_prompt
from src.search_module import search
from src.tokenizer_adapter import TokenizerAdapter
@dataclass
class AgenticOutputs:
"""Outputs from running the agent on a batch of questions."""
prompt_tokens: list[torch.Tensor]
response_tokens: list[torch.Tensor]
response_masks: list[torch.Tensor]
final_response_str: list[str]
full_chat_states: list[dict]
class Agent:
"""Base agent class for handling tool-based conversations."""
def __init__(self, tokenizer_adapter: TokenizerAdapter):
"""Initialize the agent with a tokenizer adapter."""
self.tokenizer_adapter = tokenizer_adapter
def get_initial_chat(self, question: str) -> dict:
"""Initialize a chat state with the question."""
return {
"messages": [
{"role": "system", "content": get_system_prompt()},
{"role": "user", "content": build_user_prompt(question)},
]
}
def extract_search_query(self, text: str) -> str | None:
"""Extract search query from text between <search> tags."""
pattern = re.compile(r"<search>(.*?)</search>", re.DOTALL)
matches = pattern.findall(text)
return matches[-1] if matches else None
def run_agent_generations(self, generate_fn, tokenizer, chat_states: list[dict]) -> list[dict]:
"""Run generation for chat states requiring assistant responses."""
logger.debug(f"Starting generation for {len(chat_states)} chat states")
prompts = []
batch_indices = []
for idx, chat_state in enumerate(chat_states):
if chat_state.get("finished"):
logger.debug(f"Chat state {idx} already finished, skipping")
continue
if chat_state["messages"][-1]["role"] in ["ipython", "user"]:
prompt = apply_chat_template(chat_state, tokenizer=tokenizer)["text"]
prompts.append(prompt)
batch_indices.append(idx)
logger.debug(f"Added prompt for chat state {idx}")
if prompts:
logger.info(f"Generating responses for {len(prompts)} prompts")
responses = generate_fn(prompts)
for i, idx in enumerate(batch_indices):
chat_state = chat_states[idx]
response = responses[i]
if hasattr(response, "outputs"):
full_response = response.outputs[0].text
else:
full_response = response
assistant_response = full_response.split(self.tokenizer_adapter.get_assistant_marker())[-1]
chat_state["messages"].append({"role": "assistant", "content": assistant_response})
logger.debug(f"Added assistant response to chat state {idx}")
else:
logger.debug("No prompts to generate responses for")
return chat_states
def check_finished_chats(self, chat_states: list[dict]) -> list[dict]:
"""Check which chat states are finished (no more search queries)."""
for chat_state in chat_states:
if chat_state.get("finished"):
continue
assert chat_state["messages"][-1]["role"] == "assistant", "Expected the last role to be assistant"
assistant_response = chat_state["messages"][-1]["content"]
if not re.search(r"<search>.*?</search>", assistant_response, re.DOTALL):
chat_state["finished"] = True
return chat_states
def run_tool_calls(self, chat_states: list[dict]) -> list[dict]:
"""Execute tool calls found in chat states."""
logger.debug(f"Running tool calls for {len(chat_states)} chat states")
for chat_state in chat_states:
if chat_state.get("finished"):
logger.debug("Chat state already finished, skipping tool calls")
continue
assert chat_state["messages"][-1]["role"] == "assistant", (
"Expected the last role to be assistant to run tool calls"
)
try:
assistant_response = chat_state["messages"][-1]["content"]
search_query = self.extract_search_query(assistant_response)
if search_query:
logger.info(f"🔍 Search Query: {search_query}")
results = search(search_query, return_type=str, results=2)
formatted_results = f"<information>{results}</information>"
logger.info(f" Information: {formatted_results}")
# chat_state["messages"].append({"role": "ipython", "content": formatted_results})
chat_state["messages"].append({"role": "user", "content": formatted_results})
logger.debug("Added search results to chat state")
except Exception as e:
logger.error(f"Error during tool call: {str(e)}")
chat_state["messages"].append({"role": "system", "content": f"Error during post-processing: {str(e)}"})
chat_state["finished"] = True
return chat_states
def get_chat_num_tokens(self, chat_state: dict, tokenizer) -> int:
"""Get number of tokens in chat state."""
chat_text = apply_chat_template(chat_state, tokenizer=tokenizer)["text"]
return tokenizer(chat_text, add_special_tokens=False, return_tensors="pt")["input_ids"].squeeze().shape[0]
def check_exceeded_max_new_tokens(self, chat_states: list[dict], max_new_tokens: int, tokenizer) -> list[dict]:
"""Check if any chat state has exceeded max new tokens."""
for chat_state in chat_states:
if chat_state.get("finished"):
continue
initial_length = chat_state["initial_length"]
new_length = self.get_chat_num_tokens(chat_state, tokenizer)
if new_length - initial_length > max_new_tokens:
chat_state["finished"] = True
return chat_states
def run_agent(
self,
generate_fn,
tokenizer,
questions: list[str],
max_generations: int = 5,
max_new_tokens: int = 4096,
correct_contents: list[str] | None = None,
) -> AgenticOutputs:
"""Run the agent to completion for a batch of questions.
This method follows the same flow as rl_helpers.py:
1. Initialize chat states with questions
2. Run agent loop (generations, check finished, tool calls, check tokens)
3. Process final outputs (split prompt/response, get masks)
The key difference from our previous implementation is in how we handle
the final tokenization and masking, which now matches rl_helpers.py exactly.
"""
# Step 1: Initialize chat states with questions
chat_states = [self.get_initial_chat(q) for q in questions]
if correct_contents:
for chat_state, correct_content in zip(chat_states, correct_contents):
chat_state["correct_content"] = correct_content
# Set initial token lengths for each chat state
for chat_state in chat_states:
chat_state["initial_length"] = self.get_chat_num_tokens(chat_state, tokenizer)
# Step 2: Run agent loop
for i in range(max_generations):
chat_states = self.run_agent_generations(generate_fn, tokenizer, chat_states)
chat_states = self.check_finished_chats(chat_states)
chat_states = self.run_tool_calls(chat_states)
chat_states = self.check_exceeded_max_new_tokens(chat_states, max_new_tokens, tokenizer)
# Step 3: Process final outputs
# Get the final answers from each chat state
answers = [chat["messages"][-1]["content"] for chat in chat_states]
# Convert chat states to text format for tokenization
str_chats = [apply_chat_template(chat, tokenizer=tokenizer)["text"] for chat in chat_states]
prompt_toks, response_toks, response_masks = [], [], []
# Process each chat state to get tokens and masks
for str_chat in str_chats:
try:
# Split into prompt and response parts
# Note: If assistant marker is missing, split_prompt_assistant will return (full_text, "")
prompt_text, response_text = self.tokenizer_adapter.split_prompt_assistant(str_chat)
# Get prompt tokens
prompt_toks.append(
tokenizer(prompt_text, add_special_tokens=False, return_tensors="pt")["input_ids"].squeeze()
)
# Get response tokens (truncated to max_new_tokens)
response_toks.append(
tokenizer(response_text, add_special_tokens=False, return_tensors="pt")["input_ids"].squeeze()[
:max_new_tokens
]
)
# Get full mask and slice it properly
# This matches rl_helpers.py exactly:
# 1. Get full mask for entire text
# 2. Slice from prompt length to end
# 3. Truncate to max_new_tokens
full_mask = self.tokenizer_adapter.get_mask(str_chat, tokenizer)
prompt_len = prompt_toks[-1].shape[0]
mask = full_mask[prompt_len:][:max_new_tokens]
response_masks.append(mask)
# debug if the tokens and masks are of same length by logger info
logger.debug(f"Prompt tokens length: {len(prompt_toks[-1])}")
logger.debug(f"Mask length: {len(mask)}")
logger.debug(f"Response tokens length: {len(response_toks[-1])}")
except Exception:
# If anything fails, add empty tensors
# This matches rl_helpers.py's behavior of not handling errors explicitly
prompt_toks.append(torch.tensor([], dtype=torch.long))
response_toks.append(torch.tensor([], dtype=torch.long))
response_masks.append(torch.tensor([], dtype=torch.long))
# Return final outputs
return AgenticOutputs(
prompt_tokens=prompt_toks,
response_tokens=response_toks,
response_masks=response_masks,
final_response_str=answers,
full_chat_states=chat_states,
)
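A rough driver for Agent.run_agent outside the trainer, useful for smoke-testing the loop. Everything here is a placeholder: the canned generate_fn answers immediately instead of calling vLLM, and any instruct tokenizer with a chat template could stand in for the Llama one.

from transformers import AutoTokenizer

from src.agent import Agent
from src.tokenizer_adapter import LlamaTokenizerAdapter

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

def generate_fn(prompts):
    # Stand-in for a vLLM call: answer directly so the loop finishes in one pass.
    return ["<think>I already know this.</think> <answer>Paris</answer>" for _ in prompts]

agent = Agent(LlamaTokenizerAdapter())
outputs = agent.run_agent(generate_fn, tokenizer, ["What is the capital of France?"], max_generations=2)
print(outputs.final_response_str)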

@@ -19,6 +19,8 @@ LOG_FOLDER = PROJ_ROOT / "logs"
# Model configuration
# MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
+ # MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
+ # MODEL_NAME = "unsloth/Qwen2-1.5B"  # Smoke test first

device_id = 1 if os.environ.get("CUDA_VISIBLE_DEVICES") == "1" else torch.cuda.current_device()
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -100,21 +102,21 @@ def _init_logging(env: str = "development") -> None:
    file_format = "{time:YYYY-MM-DD at HH:mm:ss} | {level} | {name}:{function}:{line} - {message}"

-     # Add console logging with DEBUG level
+     # Add console logging with INFO level (minimal terminal output)
    logger.add(
        sys.stderr,
        format=console_format,
-         level="DEBUG",  # Always use DEBUG level
+         level="INFO",  # "INFO", # Changed from DEBUG to INFO for minimal terminal output
        colorize=True,
        backtrace=True,
-         diagnose=True,  # Always enable diagnostics
+         diagnose=True,
    )

-     # Add default file logging to ./logs directory with DEBUG level
+     # Add default file logging to ./logs directory with DEBUG level (full details)
    logger.add(
        LOG_FOLDER / "app.log",
        format=file_format,
-         level="DEBUG",  # Always use DEBUG level
+         level="DEBUG",  # Keep DEBUG level for full file logging
        rotation="500 MB",
        retention="7 days",
        compression="zip",
@@ -235,44 +237,6 @@ def setup_logger(module_name=None, create_dirs: bool = False):
    return logger

- # Tensorboard writer singleton
- _tensorboard_writer = None
-
- # Safe tensorboard logging function
- def log_metric(key, value, step=0):
-     """
-     Log a metric safely to tensorboard if writer is available.
-
-     Args:
-         key: Metric name
-         value: Metric value
-         step: Training step
-     """
-     global _tensorboard_writer
-
-     # Skip tensorboard logging if disabled in config
-     if TRAINING_CONFIG.get("report_to") != "tensorboard":
-         logger.debug(f"Tensorboard disabled. Metric: {key}={value} (step {step})")
-         return
-
-     # Get paths and initialize writer if needed
-     paths = get_paths(create_dirs=False)
-     if paths["tensorboard_dir"].exists():
-         # Only create writer once
-         if _tensorboard_writer is None:
-             from torch.utils.tensorboard.writer import SummaryWriter
-
-             _tensorboard_writer = SummaryWriter(paths["tensorboard_dir"])
-             logger.debug(f"Created tensorboard writer at {paths['tensorboard_dir']}")
-
-         # Add scalar using existing writer
-         _tensorboard_writer.add_scalar(key, value, step)
-         # No need to close the writer - it will be closed at process exit
-     else:
-         logger.debug(f"Tensorboard metric: {key}={value} (step {step})")
-
# Initialize logging on module import
env = os.getenv("APP_ENV", "development")
_init_logging(env=env)
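The intent of the two sinks above (INFO on the console, DEBUG in the log file) can be sanity-checked in isolation with loguru directly (a sketch; the log path is arbitrary):

import sys

from loguru import logger

logger.remove()                            # drop the default sink
logger.add(sys.stderr, level="INFO")       # console: INFO and above only
logger.add("logs/app.log", level="DEBUG")  # file: everything, including DEBUG
logger.debug("written to the file only")
logger.info("written to both the console and the file")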

@@ -0,0 +1,238 @@
"""
Evaluation utilities for RL training.
"""
import inspect
from datetime import datetime
from src.agent import Agent
from src.config import logger
from src.search_module import get_qa_dataset
from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
async def verify(student_answer: str, question: str, answer: str) -> bool:
"""
Verify if student's answer matches the correct answer.
Args:
student_answer: The model's answer
question: The original question
answer: The ground truth answer
Returns:
bool: True if answer is correct, False otherwise
"""
logger.debug(f"Verifying answer for question: {question}")
logger.debug(f"Student answer: {student_answer}")
logger.debug(f"Correct answer: {answer}")
# Simple string matching for now
# TODO: Implement more sophisticated matching
return student_answer.strip().lower() == answer.strip().lower()
def check_student_answers(
questions: list[str],
answers: list[str],
student_answers: list, # Can be strings or dicts
vllm_generate_func,
tokenizer,
log_file=None,
) -> list[bool]:
"""
Evaluates a list of student answers against the true answers using a vLLM generate function.
Args:
questions: List of questions
answers: List of correct answers
student_answers: List of student answers to evaluate
vllm_generate_func: Function to generate verification responses
tokenizer: Tokenizer for formatting prompts
log_file: Optional path to write detailed results
Returns:
List of boolean results (True for correct answers)
"""
logger.info(f"Checking {len(questions)} student answers")
if not (len(questions) == len(answers) == len(student_answers)):
logger.error("Mismatched lengths between questions, answers, and student answers")
raise ValueError("The number of questions, answers, and student answers must be equal.")
prompts = []
for question, answer, student_ans in zip(questions, answers, student_answers):
prompt_text = (
"You are grading a student's answer to a question. For the following question, "
"compare the student's answer to the correct answer. Reply with 'Yes' if the student's answer contains the correct information, "
"even if it's not an exact match. If the student's answer doesn't contain the right information or is completely incorrect, reply with 'No'.\n\n"
f"Question: {question}\n"
f"Correct Answer: {answer}\n"
f"Student Answer: {student_ans}\n\n"
"Your response should be just 'Yes' or 'No'."
)
formatted_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt_text}],
tokenize=False,
add_generation_prompt=True,
)
prompts.append(formatted_prompt)
logger.debug(f"Created verification prompt for question: {question[:50]}...")
logger.info("Generating verification responses")
responses = vllm_generate_func(prompts)
responses_text = []
for response in responses:
# Handle different response formats
if hasattr(response, "outputs"):
try:
responses_text.append(response.outputs[0].text)
except (AttributeError, IndexError):
# Fallback for simple string responses
responses_text.append(str(response))
else:
responses_text.append(str(response))
logger.debug(f"Got {len(responses_text)} verification responses")
results = []
for response in responses_text:
results.append("yes" in response.lower())
logger.debug(f"Verification result: {'yes' in response.lower()}")
logger.info(f"Verification complete. {sum(results)}/{len(results)} answers correct")
# Append the QA details and verifier's response to the specified log file
if log_file:
with open(log_file, "a") as file:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
file.write(f"\n📝 === QA Evaluation at {timestamp} ===\n")
file.write(f"📂 File: {__file__}\n")
# Get current frame info safely
frame = inspect.currentframe()
if frame:
file.write(f"📍 Line: {frame.f_lineno}\n")
# Don't forget to delete the frame to avoid reference cycles
del frame
file.write("=" * 80 + "\n")
for i, (question, answer, student_ans, verifier_response) in enumerate(
zip(questions, answers, student_answers, responses_text)
):
file.write(f"\n❓ Question {i + 1}:\n")
file.write("-" * 40 + "\n")
file.write(f"📋 Question: {question}\n")
file.write(f"✅ Correct Answer: {answer}\n")
file.write(f"👨‍🎓 Student Answer: {student_ans}\n")
file.write(f"🔍 Verifier said: {verifier_response}\n")
# Add search results if available in the chat state
if isinstance(student_ans, dict) and "messages" in student_ans:
# Get messages from dict
messages = student_ans.get("messages", [])
search_results = [msg.get("content", "") for msg in messages if msg.get("role") == "ipython"]
if search_results:
file.write("\n🔎 Search Results:\n")
for j, result in enumerate(search_results, 1):
file.write(f"\nSearch {j}:\n{result}\n")
file.write("-" * 40 + "\n")
file.write(
f"\n📊 Summary: {sum(results)}/{len(results)} answers correct ({sum(results) / len(results) * 100:.2f}%)\n"
)
file.write("=" * 80 + "\n\n")
return results
def run_eval(generate_fn, verify_fn, tokenizer, output_file=None, debug_file=None):
"""
Run evaluation on the test dataset and return results.
Args:
generate_fn: Function to generate completions
verify_fn: Function to verify results
tokenizer: Tokenizer for processing text
output_file: Path to save evaluation results summary
debug_file: Path to save detailed debug information
Returns:
full_chat_states: The chat states from evaluation
"""
train_dataset, test_dataset = get_qa_dataset()
questions = test_dataset["prompt"]
# Create agent with appropriate adapter based on tokenizer
tokenizer_name = tokenizer.name_or_path.lower()
if "deepseek-r1-distill" in tokenizer_name:
adapter = R1DistilTokenizerAdapter()
elif "llama" in tokenizer_name:
adapter = LlamaTokenizerAdapter()
else:
adapter = R1DistilTokenizerAdapter()
agent = Agent(adapter)
agentic_outputs = agent.run_agent(generate_fn, tokenizer, questions)
full_chat_states = agentic_outputs.full_chat_states
final_responses = agentic_outputs.final_response_str
rewards = verify_fn(questions, full_chat_states, answer=test_dataset["answer"])
# Calculate results
percent_correct = sum(rewards) / len(rewards) * 100
# Log results to console
logger.info("RESULTS:")
logger.info(f"percentage of correct answers: {percent_correct:.2f}%")
logger.info("=" * 30)
# Save results to file if specified
if output_file:
try:
with open(output_file, "w") as f:
f.write("EVALUATION RESULTS\n")
f.write("=================\n\n")
f.write(f"Total questions: {len(questions)}\n")
f.write(f"Correct answers: {sum(rewards)}\n")
f.write(f"Percentage correct: {percent_correct:.2f}%\n\n")
f.write("Individual results:\n")
for i, (q, r, resp) in enumerate(zip(questions, rewards, final_responses)):
f.write(f"\nQ{i + 1}: {q[:100]}...\n")
f.write(f"Correct: {'' if r else ''}\n")
f.write(f"Response: {resp[:150]}...\n")
f.write("-" * 40 + "\n")
logger.info(f"Saved evaluation results to {output_file}")
except Exception as e:
logger.error(f"Error saving results file: {e}")
# Save debug information if specified
if debug_file:
try:
import json
debug_data = []
for i, (q, r, resp, chat) in enumerate(zip(questions, rewards, final_responses, full_chat_states)):
debug_data.append(
{
"question_id": i,
"question": q,
"is_correct": bool(r),
"final_response": resp,
"chat_state": {
k: str(v) if isinstance(v, (list, dict)) else v
for k, v in chat.items()
if k != "tokenizer"
},
}
)
with open(debug_file, "w") as f:
json.dump(debug_data, f, indent=2)
logger.info(f"Saved debug information to {debug_file}")
except Exception as e:
logger.error(f"Error saving debug file: {e}")
return full_chat_states
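check_student_answers can be exercised without vLLM by passing a stub judge (a sketch only: the stub always replies "Yes", and any tokenizer that has a chat template would work in place of the Llama one):

from transformers import AutoTokenizer

from src.evaluation import check_student_answers

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

def fake_judge(prompts):
    # Stand-in for vllm_generate_func
    return ["Yes"] * len(prompts)

results = check_student_answers(
    questions=["What is the capital of France?"],
    answers=["Paris"],
    student_answers=["Paris"],
    vllm_generate_func=fake_judge,
    tokenizer=tokenizer,
)
print(results)  # [True]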

@@ -0,0 +1,67 @@
"""
Prompt-related functions for handling system and user prompts.
"""
from datetime import datetime
def get_system_prompt():
"""Get the system prompt with current date."""
current_date = datetime.now().strftime("%d %b %Y")
return f"""Cutting Knowledge Date: December 2023
Today Date: {current_date}
You are a helpful assistant with search capabilities.
"""
def build_user_prompt(q):
"""
Build a user prompt with the question using the new template format.
Args:
q (str): The question to ask
Returns:
str: Formatted user prompt
"""
user_prompt = f"""Answer the given question. \
You must conduct reasoning inside <think> and </think> first every time you get new information. \
After reasoning, if you find you lack some knowledge, you can call a search engine by <search> query </search>. \
Based on the user's core intent, formulate the most effective search query using specific, descriptive keywords that differentiate the topic clearly. \
Aim for queries that resemble how an expert searcher might phrase it, like using "compare lithium-ion vs solid-state battery efficiency" rather than just "batteries". \
The document will be provided inside <information> and </information> tags to you later. \
You can search as many turns as you want, but only one search query per turn. \
If you find no further external knowledge needed, you can directly provide the answer inside <answer> and </answer>, without detailed illustrations. \
Only answer when you have 100% confidence in the search results, else continue searching. \
Question: {q}\n"""
return user_prompt
def format_search_results(results: str | list[str]) -> str:
"""
Format search results for display, matching the format from infer.py.
Each result should be in the format: "Doc X(Title: Y) content"
Args:
results: Search results as string or list of strings
Returns:
Formatted search results with document titles
"""
if isinstance(results, list):
# If results are already in the correct format, just join them
if any("Doc" in r and "Title:" in r for r in results):
content = "\n".join(results)
else:
# If results are raw content, format them with default titles
content = "\n".join([f"Doc {i + 1}(Title: Document {i + 1})\n{r}" for i, r in enumerate(results)])
else:
# If single result is already formatted, use it as is
if "Doc" in results and "Title:" in results:
content = results
else:
# If single result is raw content, format it with default title
content = f"Doc 1(Title: Document 1)\n{results}"
return content
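A quick illustration of the two helpers above (the expected outputs are paraphrased in the comments):

from src.prompts import build_user_prompt, format_search_results

print(build_user_prompt("At what temperature does water boil at sea level?"))
# -> the Search-R1-style instructions followed by "Question: ..."

print(format_search_results(["first raw passage", "second raw passage"]))
# -> "Doc 1(Title: Document 1)\nfirst raw passage\nDoc 2(Title: Document 2)\nsecond raw passage"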

@@ -0,0 +1,312 @@
"""
Reward functions for RL training.
"""
import re
import numpy as np
from src.config import logger
from src.evaluation import check_student_answers
def build_reward_correctness_fn(
vllm_generate_func,
tokenizer,
):
"""Build a reward function that checks correctness of student answers.
Args:
vllm_generate_func: Function to generate answers using vLLM
tokenizer: Tokenizer for the model
Returns:
A reward function that takes prompts and completions and returns correctness scores
"""
def reward_correctness(prompts: list, completions: list, **reward_kwargs) -> list:
"""Calculate reward based on correctness of student answers.
Args:
prompts: List of input prompts
completions: List of model completions
**reward_kwargs: Additional arguments for reward calculation
Returns:
List of correctness scores between 0 and 1
"""
teacher_answers = reward_kwargs["answer"]
student_answers = [completion["messages"][-1]["content"] for completion in completions]
# Log non-exact matches
for i, (student, teacher) in enumerate(zip(student_answers, teacher_answers)):
if student.strip().lower() != teacher.strip().lower():
logger.debug(f"Non-exact match at index {i}:\nStudent: {student}\nTeacher: {teacher}")
correct = check_student_answers(
prompts,
teacher_answers,
student_answers,
vllm_generate_func=vllm_generate_func,
tokenizer=tokenizer,
)
# Log correctness metrics with length info
logger.info(f"Correctness metrics: {correct}")
logger.info(f"Average correctness: {np.mean(correct):.2f}")
logger.info(f"Standard deviation: {np.std(correct):.2f}")
# Log length metrics
student_lengths = [len(ans.strip()) for ans in student_answers]
teacher_lengths = [len(ans.strip()) for ans in teacher_answers]
logger.info(f"Student lengths: {student_lengths}")
logger.info(f"Teacher lengths: {teacher_lengths}")
logger.info(f"Average student length: {np.mean(student_lengths):.2f}")
logger.info(f"Average teacher length: {np.mean(teacher_lengths):.2f}")
logger.info(f"Length ratio: {np.mean(student_lengths) / np.mean(teacher_lengths):.2f}")
return correct
return reward_correctness
def reward_format(prompts: list, completions: list, **reward_kwargs) -> list:
"""Reward function that checks if the completion follows the required format with proper tags.
Args:
prompts: List of input prompts
completions: List of completion dictionaries containing messages
**reward_kwargs: Additional reward parameters
Returns:
list: List of rewards (1.0 for valid format, 0.0 for invalid)
"""
# Regex patterns for each tag type - no markdown allowed
think_pattern = r"<think>[\s\S]*?</think>"
search_pattern = r"<search>[\s\S]*?</search>"
answer_pattern = r"<answer>[\s\S]*?</answer>"
# Information tag patterns - handle multiple variants
info_patterns = [
r"<information>[\s\S]*?</information>", # Standard
r"<info>[\s\S]*?</info>", # Shortened
r"<Info[\w]*>[\s\S]*?</Info[\w]*>", # Capitalized variants
r"<INFORMATION>[\s\S]*?</INFORMATION>", # Uppercase
r"<INFO>[\s\S]*?</INFO>", # Uppercase shortened
]
# Invalid patterns (bold/italic tags)
invalid_patterns = [
r"\*\*<\/?(?:think|search|answer|information|info)>\*\*", # Bold tags
r"\*<\/?(?:think|search|answer|information|info)>\*", # Italic tags
r"_<\/?(?:think|search|answer|information|info)>_", # Underscore italic
]
rewards = []
for completion in completions:
messages = completion.get("messages", [])
assistant_msgs = [msg["content"] for msg in messages if msg["role"] == "assistant"]
if not assistant_msgs:
rewards.append(0.0)
continue
content = assistant_msgs[-1] # Get the last assistant message
# Check for invalid markdown formatting
has_invalid_tags = any(re.search(pattern, content) for pattern in invalid_patterns)
if has_invalid_tags:
logger.debug("Found markdown-formatted tags in response")
rewards.append(0.0)
continue
# Check for any information tag variants (should not exist in assistant messages)
has_info_tags = False
for pattern in info_patterns:
info_matches = re.findall(pattern, content, re.IGNORECASE)
if info_matches:
logger.debug(f"Found {len(info_matches)} information tag(s) of type '{pattern}' in assistant message")
has_info_tags = True
break
if has_info_tags:
rewards.append(0.0)
continue
# Find all tag matches
think_matches = re.findall(think_pattern, content)
search_matches = re.findall(search_pattern, content)
answer_matches = re.findall(answer_pattern, content)
# Verify tag presence and count
has_think = len(think_matches) >= 1
has_answer = len(answer_matches) == 1 # Must have exactly one answer
has_search = len(search_matches) >= 1 # One or more search tags
# Check for search and answer in the same message (not allowed)
if has_search and has_answer:
logger.debug("Found both search and answer tags in the same message")
rewards.append(0.0)
continue
# Award reward - must have think tag and either answer or search (but not both)
reward = 1.0 if has_think and (has_answer or has_search) else 0.0
rewards.append(reward)
# Log issues for debugging
if not reward:
logger.debug(f"Format issues - think: {has_think}, answer: {has_answer}, search: {has_search}")
if search_matches:
logger.debug(f"Number of search tags: {len(search_matches)}")
# Log overall metrics
logger.info(f"Format reward metrics - Mean: {np.mean(rewards):.3f}, Valid formats: {sum(rewards)}/{len(rewards)}")
return rewards
# TODO: Implement this reward function if the project survives
def reward_long_query(completions, **kwargs):
"""Reward function that checks if the query is long."""
pass
def reward_retry(prompts: list, completions: list, **reward_kwargs) -> list:
"""
Reward function that encourages optimal retry behavior.
Rewards increase with more search attempts but caps at optimal_search_count.
Penalizes having multiple searches in a single message.
Args:
prompts: List of input prompts
completions: List of completion dictionaries with messages
**reward_kwargs: Additional reward parameters (chunk_id, answer, etc.)
Returns:
List of rewards for each completion, rounded to 3 decimal places
"""
rewards = []
search_queries = []
violations = []
# Config for retry rewards
optimal_search_count = 5 # Cap rewards at this many searches
base_reward = 0.2 # Base reward for having at least one search
increment = 0.15 # Reward increment per search attempt (0.2 + 5*0.15 = 0.95 max)
violation_penalty = 0.5 # Penalty for having multiple searches in one message
# Regex pattern for search tags
search_pattern = r"<search>[\s\S]*?</search>"
for completion in completions:
# Get assistant messages
assistant_messages = [msg["content"] for msg in completion["messages"] if msg["role"] == "assistant"]
# Count search tags in assistant messages
message_searches = []
for msg in assistant_messages:
# Find all search tags in each message
search_matches = re.findall(search_pattern, msg)
message_searches.append(len(search_matches))
# Record total search queries
total_searches = sum(message_searches)
search_queries.append(total_searches)
# Check for violations (more than one search query per message)
violation = any(count > 1 for count in message_searches)
violations.append(violation)
# Calculate reward
if total_searches == 0:
reward = 0.0 # No searches = no reward
else:
# Base reward for having at least one search
reward = base_reward
# Add incremental reward for each search up to optimal_search_count
search_bonus = min(total_searches, optimal_search_count) * increment
reward += search_bonus
# Cap reward at 1.0
reward = min(1.0, reward)
# Apply penalty if there's a violation
if violation:
reward *= 1 - violation_penalty
# Round to 3 decimal places to avoid floating point precision issues
reward = round(reward, 3)
rewards.append(reward)
# Log metrics with search distribution info
logger.info(f"Retry behavior rewards: {np.mean(rewards):.3f} ± {np.std(rewards):.3f}")
logger.info(f"Search tags per completion: {np.mean(search_queries):.2f} ± {np.std(search_queries):.2f}")
logger.info(f"Violations (>1 search per message): {sum(violations)}/{len(violations)}")
logger.info(f"Search counts distribution: {search_queries}")
return rewards
def reward_em_chunk(prompts: list, completions: list, **reward_kwargs) -> list:
"""Reward function that checks if model's search queries hit the correct chunk content.
Args:
prompts: List of input prompts
completions: List of completion dictionaries with messages
**reward_kwargs: Additional reward parameters including:
- chunk_content: List of correct chunk contents to match against
- step: Optional step number for logging metrics
Returns:
list: List of rewards (1.0 for exact match, 0.0 otherwise)
Raises:
ValueError: If chunk_content is not provided in reward_kwargs
"""
logger.debug(f"Calculating rewards for {len(prompts)} prompts")
# Get correct chunk contents from reward kwargs
correct_contents = reward_kwargs.get("chunk_content", [])
if not correct_contents:
logger.error("No chunk_content provided in reward_kwargs")
raise ValueError("chunk_content must be provided in reward_kwargs")
rewards = []
for i, (completion, correct_content) in enumerate(zip(completions, correct_contents)):
# Get all messages from ipython or user roles that start with <information>
search_results = [
msg["content"]
for msg in completion["messages"]
if msg["role"] in ("ipython", "user") and msg["content"].strip().startswith("<information>")
]
logger.debug(f"Found {len(search_results)} search results for prompt {i}")
# Log ground truth and searched chunks for debugging
logger.info(f"📝 Ground Truth Chunk: {correct_content}")
for j, result in enumerate(search_results):
logger.info(f"🔍 Searched Chunk {j + 1}: {result}")
# Check if any search hit the correct chunk content
found_correct_chunk = any(correct_content in result for result in search_results)
if not found_correct_chunk:
logger.warning(
f"Failed to find correct chunk for prompt {i}:\n"
f"Search results: {[r[:100] + '...' for r in search_results]}"
)
reward = 1.0 if found_correct_chunk else 0.0
rewards.append(reward)
logger.debug(f"Reward for prompt {i}: {reward}")
# Log summary metrics
logger.info("Chunk Query Rewards Summary:")
logger.info(f"Total prompts: {len(prompts)}")
logger.info(f"Correct matches: {sum(rewards)}")
logger.info(f"Average reward: {np.mean(rewards):.3f}")
logger.info(f"Reward std: {np.std(rewards):.3f}")
return rewards
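A toy check of reward_format on two hand-written chat states (illustration only; the completion dicts mimic the message structure the agent produces):

from src.rewards import reward_format

good = {"messages": [{"role": "assistant", "content": "<think>reasoning</think><answer>42</answer>"}]}
bad = {"messages": [{"role": "assistant", "content": "<think>reasoning</think><search>q</search><answer>42</answer>"}]}
print(reward_format([], [good, bad]))  # expected [1.0, 0.0]: search and answer may not share one message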

@@ -1,806 +0,0 @@
"""
RL helpers module for handling tool-based conversations.
This module provides utility functions for handling chat-based tool interactions
and calculating rewards based on the quality of responses.
"""
import inspect
import re
from dataclasses import dataclass
from datetime import datetime
import nest_asyncio
import numpy as np
import torch
from src.config import log_metric, logger
from src.search_module import get_qa_dataset, search
# Apply nest_asyncio for supporting async operations in notebooks
nest_asyncio.apply()
from trl.trainer.grpo_trainer import apply_chat_template
# Constants for prompts and tool definitions
def get_system_prompt():
"""Get the system prompt with current date."""
current_date = datetime.now().strftime("%d %b %Y")
return f"""Cutting Knowledge Date: December 2023
Today Date: {current_date}
When you receive a tool call response, use the output to format an answer to the original user question.
You are a helpful assistant with tool calling capabilities.
"""
def build_user_prompt(q):
"""
Build a user prompt with the question using the new template format.
Args:
q (str): The question to ask
Returns:
str: Formatted user prompt
"""
user_prompt = f"""Answer the given question. \
You must conduct reasoning inside <think> and </think> first every time you get new information. \
After reasoning, if you find you lack some knowledge, you can call a search engine by <search> query </search>. \
You can search as many times as your want. \
If you find no further external knowledge needed, you can directly provide the answer inside <answer> and </answer>, without detailed illustrations. For example, <answer> Beijing </answer>.
IMPORTANT INSTRUCTIONS:
1. PLEASE CONSIDER CHAT HISTORY WHEN ANSWERING THE QUESTION.
2. ONLY ANSWER WHEN YOU HAVE 100% CONFIDENCE IN THE SEARCH RESULTS, ELSE CONTINUE SEARCHING.
3. PLEASE SEARCH MULTIPLE TIMES WITH DIFFERENT QUERIES.
Question: {q}\n"""
return user_prompt
def format_search_results(results: str | list[str]) -> str:
"""
Format search results for display, matching the format from infer.py.
Each result should be in the format: "Doc X(Title: Y) content"
Args:
results: Search results as string or list of strings
Returns:
Formatted search results with document titles
"""
if isinstance(results, list):
# If results are already in the correct format, just join them
if any("Doc" in r and "Title:" in r for r in results):
content = "\n".join(results)
else:
# If results are raw content, format them with default titles
content = "\n".join([f"Doc {i + 1}(Title: Document {i + 1})\n{r}" for i, r in enumerate(results)])
else:
# If single result is already formatted, use it as is
if "Doc" in results and "Title:" in results:
content = results
else:
# If single result is raw content, format it with default title
content = f"Doc 1(Title: Document 1)\n{results}"
return content
def get_initial_chat(question):
"""
Initialize a chat state with the question.
Args:
question (str): The question to ask
Returns:
dict: Initial chat state with system and user messages
"""
return {
"messages": [
{"role": "system", "content": get_system_prompt()},
{"role": "user", "content": build_user_prompt(question)},
]
}
def remove_reasoning(text: str) -> str:
"""
Removes all content between <think> and </think> tags,
including the tags themselves.
Parameters:
text (str): The input text that may contain <think>...</think> tags.
Returns:
str: The text with the tags and their content removed.
"""
# The regex pattern matches from <think> to </think> non-greedily.
pattern = r"<think>.*?</think>"
cleaned_text = re.sub(pattern, "", text, flags=re.DOTALL)
return cleaned_text
def run_agent_generations(generate_fn, tokenizer, chat_states):
"""
Run generation for chat states requiring assistant responses.
"""
logger.debug(f"Starting generation for {len(chat_states)} chat states")
prompts = []
batch_indices = []
# Prepare prompts for chat states needing an assistant response.
for idx, chat_state in enumerate(chat_states):
if chat_state.get("finished"):
logger.debug(f"Chat state {idx} already finished, skipping")
continue
if chat_state["messages"][-1]["role"] in ["ipython", "user"]:
prompt = apply_chat_template(chat_state, tokenizer=tokenizer)["text"]
prompts.append(prompt)
batch_indices.append(idx)
logger.debug(f"Added prompt for chat state {idx}")
if prompts:
logger.info(f"Generating responses for {len(prompts)} prompts")
responses = generate_fn(prompts)
for i, idx in enumerate(batch_indices):
chat_state = chat_states[idx]
response = responses[i]
if hasattr(response, "outputs"):
full_response = response.outputs[0].text
else:
full_response = response
assistant_response = full_response.split("<|start_header_id|>assistant<|end_header_id|>")[-1]
chat_state["messages"].append({"role": "assistant", "content": assistant_response})
logger.debug(f"Added assistant response to chat state {idx}")
else:
logger.debug("No prompts to generate responses for")
return chat_states
def check_finished_chats(chat_states):
"""
Check which chat states are finished (no more search queries).
Args:
chat_states: List of chat states
Returns:
list: Updated chat states with finished flag
"""
for chat_state in chat_states:
if chat_state.get("finished"):
continue
assert chat_state["messages"][-1]["role"] == "assistant", "Expected the last role to be assistant"
assistant_response = chat_state["messages"][-1]["content"]
# Check if there are any search queries in the response
if not re.search(r"<search>.*?</search>", assistant_response, re.DOTALL):
chat_state["finished"] = True
return chat_states
def extract_search_query(text: str) -> str | None:
"""
Extract search query from text between <search> tags.
Args:
text (str): Text containing search query
Returns:
str | None: Search query if found, None otherwise
"""
pattern = re.compile(r"<search>(.*?)</search>", re.DOTALL)
matches = pattern.findall(text)
return matches[-1] if matches else None
def run_tool_calls(chat_states):
"""
Execute tool calls found in chat states.
"""
logger.debug(f"Running tool calls for {len(chat_states)} chat states")
total_retries = 0
for chat_state in chat_states:
if chat_state.get("finished"):
logger.debug("Chat state already finished, skipping tool calls")
continue
assert chat_state["messages"][-1]["role"] == "assistant", (
"Expected the last role to be assistant to run tool calls"
)
try:
assistant_response = chat_state["messages"][-1]["content"]
search_query = extract_search_query(assistant_response)
if search_query:
logger.info(f"🔍 Search Query: {search_query}")
results = search(search_query, return_type=str, results=2)
# Wrap results in <information> tags
formatted_results = f"<information>{results}</information>"
logger.info(f" Information: {formatted_results}")
chat_state["messages"].append({"role": "ipython", "content": formatted_results})
total_retries += 1
logger.debug("Added search results to chat state")
except Exception as e:
logger.error(f"Error during tool call: {str(e)}")
chat_state["messages"].append({"role": "system", "content": f"Error during post-processing: {str(e)}"})
chat_state["finished"] = True
return chat_states
def get_mask(text, tokenizer):
encoding = tokenizer(text, add_special_tokens=False)
start_header_id = tokenizer.convert_tokens_to_ids("<|start_header_id|>")
assistant_token = tokenizer.convert_tokens_to_ids("assistant")
end_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>")
eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
assistant_ranges = []
i = 0
while i < len(encoding.input_ids) - 1:
if encoding.input_ids[i] == start_header_id and encoding.input_ids[i + 1] == assistant_token:
i += 2
while i < len(encoding.input_ids) and encoding.input_ids[i] != end_header_id:
i += 1
i += 2
start_idx = i
while i < len(encoding.input_ids) and encoding.input_ids[i] != eot_id:
i += 1
end_idx = i
assistant_ranges.append((start_idx, end_idx))
else:
i += 1
mask = [0] * len(encoding.input_ids)
for start_idx, end_idx in assistant_ranges:
for idx in range(start_idx, end_idx):
mask[idx] = 1
return torch.tensor(mask, dtype=torch.int)
def check_exceeded_max_new_tokens(chat_states, max_new_tokens, tokenizer):
for chat_state in chat_states:
if chat_state.get("finished"):
continue
initial_length = chat_state["initial_length"]
new_length = get_chat_num_tokens(chat_state, tokenizer)
if new_length - initial_length > max_new_tokens:
chat_state["finished"] = True
return chat_states
@dataclass
class AgenticOutputs:
prompt_tokens: list[torch.Tensor]
response_tokens: list[torch.Tensor]
response_masks: list[torch.Tensor]
final_response_str: list[str]
full_chat_states: list[dict]
def get_chat_num_tokens(chat_state, tokenizer):
chat_text = apply_chat_template(chat_state, tokenizer=tokenizer)["text"]
return tokenizer(chat_text, add_special_tokens=False, return_tensors="pt")["input_ids"].squeeze().shape[0]
def run_agent(
generate_fn,
tokenizer,
questions,
max_generations=5,
max_new_tokens=4096,
correct_contents=None,
):
"""
Run the agent to completion for a batch of questions.
"""
logger.info(f"Starting agent run with {len(questions)} questions")
logger.debug(f"Max generations: {max_generations}, Max new tokens: {max_new_tokens}")
chat_states = [get_initial_chat(q) for q in questions]
# Add correct content to chat states if provided
if correct_contents:
for chat_state, correct_content in zip(chat_states, correct_contents):
chat_state["correct_content"] = correct_content
# set the initial_prompt length
for i, chat_state in enumerate(chat_states):
chat_state["initial_length"] = get_chat_num_tokens(chat_state, tokenizer)
logger.debug(f"Initial length for question {i}: {chat_state['initial_length']}")
# agent loop
for i in range(max_generations):
logger.info(f"Starting generation step {i + 1}/{max_generations}")
chat_states = run_agent_generations(generate_fn, tokenizer, chat_states)
chat_states = check_finished_chats(chat_states)
chat_states = run_tool_calls(chat_states)
chat_states = check_exceeded_max_new_tokens(chat_states, max_new_tokens, tokenizer)
finished_count = sum(1 for state in chat_states if state.get("finished"))
logger.info(f"Finished {finished_count}/{len(chat_states)} chat states after step {i + 1}")
logger.info("Agent run completed")
# Process final outputs
logger.debug("Processing final outputs")
answers = []
for chat in chat_states:
answers.append(chat["messages"][-1]["content"])
logger.debug(f"Final answer: {chat['messages'][-1]['content'][:100]}...")
def split_prompt_assistant(convo_text):
marker = "<|start_header_id|>assistant<|end_header_id|>"
idx = convo_text.find(marker)
if idx == -1:
logger.error("Could not find assistant marker in conversation text")
raise ValueError("Could not find assistant marker in conversation text.")
return convo_text, ""
prompt = convo_text[: idx + len(marker)]
assistant_response = convo_text[idx + len(marker) :]
return prompt, assistant_response
str_chats = [apply_chat_template(chat, tokenizer=tokenizer)["text"] for chat in chat_states]
prompt_toks, response_toks, response_masks = [], [], []
logger.debug("Processing tokenization")
for i, str_chat in enumerate(str_chats):
prompt, response = split_prompt_assistant(str_chat)
prompt_toks.append(tokenizer(prompt, add_special_tokens=False, return_tensors="pt")["input_ids"].squeeze())
response_toks.append(
tokenizer(response, add_special_tokens=False, return_tensors="pt")["input_ids"].squeeze()[:max_new_tokens]
)
mask = get_mask(str_chat, tokenizer)[len(prompt_toks[-1]) :][:max_new_tokens]
response_masks.append(mask)
logger.debug(f"Processed tokens for chat {i}")
final_response_str = [chat["messages"][-1]["content"] for chat in chat_states]
full_chat_states = chat_states
logger.info("Agent run completed successfully")
return AgenticOutputs(
prompt_tokens=prompt_toks,
response_tokens=response_toks,
response_masks=response_masks,
final_response_str=final_response_str,
full_chat_states=full_chat_states,
)
# Verification
async def verify(student_answer: str, question: str, answer: str) -> bool:
"""
Verify if student's answer matches the correct answer.
Args:
student_answer: The model's answer
question: The original question
answer: The ground truth answer
Returns:
bool: True if answer is correct, False otherwise
"""
logger.debug(f"Verifying answer for question: {question}")
logger.debug(f"Student answer: {student_answer}")
logger.debug(f"Correct answer: {answer}")
# Simple string matching for now
# TODO: Implement more sophisticated matching
return student_answer.strip().lower() == answer.strip().lower()
def check_student_answers(
questions: list[str],
answers: list[str],
student_answers: list, # Can be strings or dicts
vllm_generate_func,
tokenizer,
log_file=None,
) -> list[bool]:
"""
Evaluates a list of student answers against the true answers using a vLLM generate function.
Args:
questions: List of questions
answers: List of correct answers
student_answers: List of student answers to evaluate
vllm_generate_func: Function to generate verification responses
tokenizer: Tokenizer for formatting prompts
log_file: Optional path to write detailed results
Returns:
List of boolean results (True for correct answers)
"""
logger.info(f"Checking {len(questions)} student answers")
if not (len(questions) == len(answers) == len(student_answers)):
logger.error("Mismatched lengths between questions, answers, and student answers")
raise ValueError("The number of questions, answers, and student answers must be equal.")
prompts = []
for question, answer, student_ans in zip(questions, answers, student_answers):
prompt_text = (
"You are grading a student's answer to a question. For the following question, "
"compare the student's answer to the correct answer. Reply with 'Yes' if the student's answer contains the correct information, "
"even if it's not an exact match. If the student's answer doesn't contain the right information or is completely incorrect, reply with 'No'.\n\n"
f"Question: {question}\n"
f"Correct Answer: {answer}\n"
f"Student Answer: {student_ans}\n\n"
"Your response should be just 'Yes' or 'No'."
)
formatted_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt_text}],
tokenize=False,
add_generation_prompt=True,
)
prompts.append(formatted_prompt)
logger.debug(f"Created verification prompt for question: {question[:50]}...")
logger.info("Generating verification responses")
responses = vllm_generate_func(prompts)
responses_text = []
for response in responses:
# Handle different response formats
if hasattr(response, "outputs"):
try:
responses_text.append(response.outputs[0].text)
except (AttributeError, IndexError):
# Fallback for simple string responses
responses_text.append(str(response))
else:
responses_text.append(str(response))
logger.debug(f"Got {len(responses_text)} verification responses")
results = []
for response in responses_text:
results.append("yes" in response.lower())
logger.debug(f"Verification result: {'yes' in response.lower()}")
logger.info(f"Verification complete. {sum(results)}/{len(results)} answers correct")
# Append the QA details and verifier's response to the specified log file
if log_file:
with open(log_file, "a") as file:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
file.write(f"\n📝 === QA Evaluation at {timestamp} ===\n")
file.write(f"📂 File: {__file__}\n")
# Get current frame info safely
frame = inspect.currentframe()
if frame:
file.write(f"📍 Line: {frame.f_lineno}\n")
# Don't forget to delete the frame to avoid reference cycles
del frame
file.write("=" * 80 + "\n")
for i, (question, answer, student_ans, verifier_response) in enumerate(
zip(questions, answers, student_answers, responses_text)
):
file.write(f"\n❓ Question {i + 1}:\n")
file.write("-" * 40 + "\n")
file.write(f"📋 Question: {question}\n")
file.write(f"✅ Correct Answer: {answer}\n")
file.write(f"👨‍🎓 Student Answer: {student_ans}\n")
file.write(f"🔍 Verifier said: {verifier_response}\n")
# Add search results if available in the chat state
if isinstance(student_ans, dict) and "messages" in student_ans:
# Get messages from dict
messages = student_ans.get("messages", [])
search_results = [msg.get("content", "") for msg in messages if msg.get("role") == "ipython"]
if search_results:
file.write("\n🔎 Search Results:\n")
for j, result in enumerate(search_results, 1):
file.write(f"\nSearch {j}:\n{result}\n")
file.write("-" * 40 + "\n")
file.write(
f"\n📊 Summary: {sum(results)}/{len(results)} answers correct ({sum(results) / len(results) * 100:.2f}%)\n"
)
file.write("=" * 80 + "\n\n")
return results
# Reward Functions
def build_reward_correctness_fn(generate_fn, tokenizer, log_file=None):
def reward_correctness(prompts, completions, **reward_kwargs):
teacher_answers = reward_kwargs["answer"]
student_answers = [completion["messages"][-1]["content"] for completion in completions]
# Log non-exact matches
for i, (student, teacher) in enumerate(zip(student_answers, teacher_answers)):
if student.strip().lower() != teacher.strip().lower():
logger.warning(f"Non-exact match at index {i}:\nStudent: {student}\nTeacher: {teacher}")
correct = check_student_answers(
prompts,
teacher_answers,
student_answers,
vllm_generate_func=generate_fn,
tokenizer=tokenizer,
log_file=log_file,
)
# Log correctness metrics with length info
log_metric("rewards/correctness", np.mean(correct), reward_kwargs.get("step", 0))
log_metric("rewards/correctness_std", np.std(correct), reward_kwargs.get("step", 0))
# Log length metrics
student_lengths = [len(ans.strip()) for ans in student_answers]
teacher_lengths = [len(ans.strip()) for ans in teacher_answers]
log_metric(
"metrics/avg_student_length",
np.mean(student_lengths),
reward_kwargs.get("step", 0),
)
log_metric(
"metrics/avg_teacher_length",
np.mean(teacher_lengths),
reward_kwargs.get("step", 0),
)
log_metric(
"metrics/length_ratio",
np.mean(student_lengths) / np.mean(teacher_lengths),
reward_kwargs.get("step", 0),
)
return correct
return reward_correctness
def reward_formatting(prompts, completions, **reward_kwargs):
# make sure full chats doesn't have any error function calls
has_error = [False] * len(completions)
for i, chat in enumerate(completions):
for message in chat["messages"]:
if "Error during" in message["content"]:
has_error[i] = True
logger.warning(f"Error in chat {i}: {message['content']}")
break
rewards = [0.7 if not e else 0 for e in has_error]
# Log formatting metrics
log_metric("rewards/formatting", np.mean(rewards), reward_kwargs.get("step", 0))
log_metric("rewards/formatting_std", np.std(rewards), reward_kwargs.get("step", 0))
log_metric("metrics/error_rate", np.mean(has_error), reward_kwargs.get("step", 0))
return rewards
def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[float]:
"""
Reward function that encourages optimal retry behavior by only rewarding completions
where every assistant message contains at most 1 search query.
"""
rewards: list[float] = []
for completion in completions:
# Get ALL assistant messages
assistant_msgs: list[str] = [
msg["content"]
for msg in completion["messages"]
if msg["role"] == "assistant" and msg["content"] is not None
]
if not assistant_msgs:
rewards.append(0.0)
continue
# Check if every message has at most 1 search query
has_multiple_searches = False
total_searches = 0
for msg in assistant_msgs:
search_count = len(re.findall(r"<search>.*?</search>", msg, re.DOTALL))
total_searches += search_count
if search_count > 1:
has_multiple_searches = True
logger.warning(f"Message contains {search_count} search queries, which exceeds the limit of 1")
break
# Only reward if no message has multiple search queries
if has_multiple_searches:
rewards.append(0.0)
else:
# Base reward is 1.0 if constraint is met
base_reward = 1.0
# Slight penalty for having too many total searches across all messages
if total_searches > 4:
penalty = 0.1 * (total_searches - 4)
base_reward = max(0.2, base_reward - penalty)
logger.debug(f"Applied penalty for {total_searches} total searches: {penalty}")
rewards.append(base_reward)
# Log retry behavior metrics
log_metric("rewards/retry_behavior", np.mean(rewards), reward_kwargs.get("step", 0))
log_metric("rewards/retry_behavior_std", np.std(rewards), reward_kwargs.get("step", 0))
log_metric(
"metrics/avg_searches_per_msg",
np.mean(
[
len(re.findall(r"<search>.*?</search>", msg["content"], re.DOTALL))
for completion in completions
for msg in completion["messages"]
if msg["role"] == "assistant"
]
),
reward_kwargs.get("step", 0),
)
log_metric(
"metrics/multiple_search_violation_rate",
np.mean([0.0 if rewards[i] > 0.0 else 1.0 for i in range(len(rewards))]),
reward_kwargs.get("step", 0),
)
return rewards
def reward_exact_match_chunk_query(prompts, completions, **reward_kwargs):
"""
Reward function that checks if the model's search queries hit the correct chunk content.
"""
logger.debug(f"Calculating rewards for {len(prompts)} prompts")
# Get correct chunk contents from reward kwargs
correct_contents = reward_kwargs.get("chunk_content", [])
if not correct_contents:
logger.error("No chunk_content provided in reward_kwargs")
raise ValueError("chunk_content must be provided in reward_kwargs")
rewards = []
for i, (chat_state, correct_content) in enumerate(zip(completions, correct_contents)):
# Get all ipython messages (search results) from the chat
search_results = [msg["content"] for msg in chat_state["messages"] if msg["role"] == "ipython"]
logger.debug(f"Found {len(search_results)} search results for prompt {i}")
# Log ground truth chunk and searched chunks
logger.info(f"📝 Ground Truth Chunk: {correct_content}")
for j, result in enumerate(search_results):
logger.info(f"🔍 Searched Chunk {j + 1}: {result}")
# Check if any search hit the correct chunk content
found_correct_chunk = False
for result in search_results:
if correct_content in result:
found_correct_chunk = True
logger.debug(f"Found correct chunk content in search results for prompt {i}")
break
if not found_correct_chunk:
logger.warning(
f"Failed to find correct chunk for prompt {i}:\n"
f"Search results: {[r[:100] + '...' for r in search_results]}"
)
reward = 1.0 if found_correct_chunk else 0.0
rewards.append(reward)
logger.debug(f"Reward for prompt {i}: {reward}")
# Log detailed metrics for debugging
log_metric(
f"debug/chunk_match_{i}",
1 if found_correct_chunk else 0,
reward_kwargs.get("step", 0),
)
log_metric(
f"debug/search_results_count_{i}",
len(search_results),
reward_kwargs.get("step", 0),
)
if search_results:
log_metric(
f"debug/result_length_{i}",
np.mean([len(r.split()) for r in search_results]),
reward_kwargs.get("step", 0),
)
# Log chunk query metrics
log_metric("rewards/chunk_query", np.mean(rewards), reward_kwargs.get("step", 0))
log_metric("rewards/chunk_query_std", np.std(rewards), reward_kwargs.get("step", 0))
log_metric(
"metrics/avg_search_results",
np.mean(
[
len([msg["content"] for msg in chat_state["messages"] if msg["role"] == "ipython"])
for chat_state in completions
]
),
reward_kwargs.get("step", 0),
)
log_metric("metrics/chunk_match_rate", np.mean(rewards), reward_kwargs.get("step", 0))
# Log detailed debugging info
logger.info("Chunk Query Rewards Summary:")
logger.info(f"Total prompts: {len(prompts)}")
logger.info(f"Correct matches: {sum(rewards)}")
logger.info(f"Average reward: {np.mean(rewards):.3f}")
logger.info(f"Reward std: {np.std(rewards):.3f}")
return rewards
def run_eval(generate_fn, verify_fn, tokenizer, output_file=None, debug_file=None):
"""
Run evaluation on the test dataset and return results.
Args:
generate_fn: Function to generate completions
verify_fn: Function to verify results
tokenizer: Tokenizer for processing text
output_file: Path to save evaluation results summary
debug_file: Path to save detailed debug information
Returns:
full_chat_states: The chat states from evaluation
"""
train_dataset, test_dataset = get_qa_dataset()
questions = test_dataset["prompt"]
agentic_outputs = run_agent(generate_fn, tokenizer, questions)
full_chat_states = agentic_outputs.full_chat_states
final_responses = agentic_outputs.final_response_str
rewards = verify_fn(questions, full_chat_states, answer=test_dataset["answer"])
# Calculate results
percent_correct = sum(rewards) / len(rewards) * 100
# Log results to console
logger.info("RESULTS:")
logger.info(f"percentage of correct answers: {percent_correct:.2f}%")
logger.info("=" * 30)
# Save results to file if specified
if output_file:
try:
with open(output_file, "w") as f:
f.write("EVALUATION RESULTS\n")
f.write("=================\n\n")
f.write(f"Total questions: {len(questions)}\n")
f.write(f"Correct answers: {sum(rewards)}\n")
f.write(f"Percentage correct: {percent_correct:.2f}%\n\n")
f.write("Individual results:\n")
for i, (q, r, resp) in enumerate(zip(questions, rewards, final_responses)):
f.write(f"\nQ{i + 1}: {q[:100]}...\n")
f.write(f"Correct: {'' if r else ''}\n")
f.write(f"Response: {resp[:150]}...\n")
f.write("-" * 40 + "\n")
logger.info(f"Saved evaluation results to {output_file}")
except Exception as e:
logger.error(f"Error saving results file: {e}")
# Save debug information if specified
if debug_file:
try:
import json
debug_data = []
for i, (q, r, resp, chat) in enumerate(zip(questions, rewards, final_responses, full_chat_states)):
debug_data.append(
{
"question_id": i,
"question": q,
"is_correct": bool(r),
"final_response": resp,
"chat_state": {
k: str(v) if isinstance(v, (list, dict)) else v
for k, v in chat.items()
if k != "tokenizer"
},
}
)
with open(debug_file, "w") as f:
json.dump(debug_data, f, indent=2)
logger.info(f"Saved debug information to {debug_file}")
except Exception as e:
logger.error(f"Error saving debug file: {e}")
return full_chat_states

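For reference, a sketch of how the deleted helpers above were wired together before this refactor (generate_fn and tokenizer are assumed to be provided by the training or eval script):

from src.rl_helpers import build_reward_correctness_fn, run_eval

verify_fn = build_reward_correctness_fn(generate_fn, tokenizer, log_file="qa_log.txt")
chat_states = run_eval(
    generate_fn=generate_fn,
    verify_fn=verify_fn,
    tokenizer=tokenizer,
    output_file="eval_results.txt",
)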
@ -7,7 +7,7 @@ import json
import random import random
from datasets import Dataset from datasets import Dataset
from langchain.vectorstores import FAISS from langchain_community.vectorstores import FAISS
from src.config import DATA_DIR, logger from src.config import DATA_DIR, logger
from src.embeddings import CustomHuggingFaceEmbeddings from src.embeddings import CustomHuggingFaceEmbeddings

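The FAISS import above moves from langchain to langchain_community. A minimal sketch of building and querying a store with the project's embeddings wrapper, assuming CustomHuggingFaceEmbeddings can be constructed with its defaults; the texts are illustrative:

from langchain_community.vectorstores import FAISS
from src.embeddings import CustomHuggingFaceEmbeddings

embeddings = CustomHuggingFaceEmbeddings()
vectorstore = FAISS.from_texts(["Paris is the capital of France."], embeddings)
docs = vectorstore.similarity_search("capital of France", k=1)  # top-1 nearest chunk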
@ -0,0 +1,306 @@
"""
Tokenizer adapter implementations for different models.
This module provides adapter classes for handling different tokenizer formats.
"""
from abc import ABC, abstractmethod
import torch
from src.config import logger
class TokenizerAdapter(ABC):
"""Base class for tokenizer adapters."""
@abstractmethod
def get_assistant_marker(self) -> str:
"""Get the assistant marker for the model."""
pass
@abstractmethod
def get_end_marker(self) -> str:
"""Get the end marker for the model."""
pass
@abstractmethod
def get_mask(self, text: str, tokenizer) -> torch.Tensor:
"""Get the mask for the model's response."""
pass
@abstractmethod
def split_prompt_assistant(self, text: str) -> tuple[str, str]:
"""Split conversation text into prompt and assistant response."""
pass
class LlamaTokenizerAdapter(TokenizerAdapter):
"""Adapter for Llama model tokenizer."""
def get_assistant_marker(self) -> str:
"""Get the assistant marker."""
return "<|start_header_id|>assistant<|end_header_id|>"
def get_end_marker(self) -> str:
"""Get the end marker."""
return "<|eot_id|>"
def split_prompt_assistant(self, convo_text: str) -> tuple[str, str]:
"""Split the text into prompt and assistant parts.
Args:
convo_text: The text to split
Returns:
A tuple of (prompt, assistant)
"""
# EXACT replication from rl_helpers.py but using existing method
marker = self.get_assistant_marker() # Use existing method but same value
idx = convo_text.find(marker)
if idx == -1:
raise ValueError("Could not find assistant marker in conversation text.")
return convo_text, ""
# Include the marker in the prompt by slicing up to the end of the marker.
prompt = convo_text[: idx + len(marker)]
# The assistant response is everything after the marker.
assistant_response = convo_text[idx + len(marker) :]
return prompt, assistant_response
def get_mask(self, text: str, tokenizer) -> torch.Tensor:
"""Get the mask for the text.
Args:
text: The text to get the mask for
tokenizer: The tokenizer to use
Returns:
A tensor of 0s and 1s where 1s indicate assistant tokens
"""
# Log input
logger.debug(f"🔍 Llama: Full text length: {len(text)}")
# EXACT replication from rl_helpers.py but using existing methods
encoding = tokenizer(text, add_special_tokens=False)
# Use existing methods but same values
start_header_id = tokenizer.convert_tokens_to_ids("<|start_header_id|>")
assistant_token = tokenizer.convert_tokens_to_ids("assistant")
end_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>")
eot_id = tokenizer.convert_tokens_to_ids(self.get_end_marker()) # Use existing method but same value
# Log token IDs
logger.debug(f"🔍 Llama: Tokenized length: {len(encoding.input_ids)}")
logger.debug(f"🔍 Llama: Input IDs: {encoding.input_ids}")
logger.debug(
f"🔍 Llama: Special token IDs: start={start_header_id}, assistant={assistant_token}, end={end_header_id}, eot={eot_id}"
)
assistant_ranges = []
i = 0
while i < len(encoding.input_ids) - 1:
if encoding.input_ids[i] == start_header_id and encoding.input_ids[i + 1] == assistant_token:
logger.debug(f"🔍 Llama: Found assistant marker at position {i}")
logger.debug(f"🔍 Llama: Assistant marker tokens: {encoding.input_ids[i : i + 2]}")
i += 2
while i < len(encoding.input_ids) and encoding.input_ids[i] != end_header_id:
i += 1
i += 2
start_idx = i
logger.debug(f"🔍 Llama: Found start of response at {start_idx}")
logger.debug(f"🔍 Llama: Start token ID: {encoding.input_ids[start_idx]}")
while i < len(encoding.input_ids) and encoding.input_ids[i] != eot_id:
i += 1
end_idx = i
logger.debug(f"🔍 Llama: Found end of response at {end_idx}")
logger.debug(f"🔍 Llama: End token ID: {encoding.input_ids[end_idx]}")
logger.debug(f"🔍 Llama: Response token IDs: {encoding.input_ids[start_idx:end_idx]}")
assistant_ranges.append((start_idx, end_idx))
else:
i += 1
mask = [0] * len(encoding.input_ids)
for start_idx, end_idx in assistant_ranges:
for idx in range(start_idx, end_idx):
mask[idx] = 1
mask = torch.tensor(mask, dtype=torch.int)
# Log final mask
logger.debug(f"🔍 Llama: Final mask shape: {mask.shape}")
logger.debug(f"🔍 Llama: Mask sum: {mask.sum().item()}")
logger.debug(f"🔍 Llama: Mask: {mask}")
# Additional debug info
try:
prompt, response = self.split_prompt_assistant(text)
prompt_tokens = tokenizer(prompt, add_special_tokens=False).input_ids
response_tokens = tokenizer(response, add_special_tokens=False).input_ids
logger.debug(f"🔍 Llama: Prompt length: {len(prompt)}")
logger.debug(f"🔍 Llama: Response length: {len(response)}")
logger.debug(f"🔍 Llama: Prompt token IDs: {prompt_tokens}")
logger.debug(f"🔍 Llama: Response token IDs: {response_tokens}")
logger.debug(f"🔍 Llama: Prompt: {prompt[:100]}...")
logger.debug(f"🔍 Llama: Response: {response[:100]}...")
logger.debug(f"🔍 Llama: Full input IDs length: {len(encoding.input_ids)}")
logger.debug(f"🔍 Llama: Prompt + Response token IDs length: {len(prompt_tokens) + len(response_tokens)}")
logger.debug(
f"🔍 Llama: Difference in lengths: {len(encoding.input_ids) - (len(prompt_tokens) + len(response_tokens))}"
)
except Exception as e:
logger.error(f"🔍 Llama: Error splitting prompt/response: {e}")
return mask
class R1DistilTokenizerAdapter(TokenizerAdapter):
"""Adapter for R1-Distil model tokenizer."""
def get_assistant_marker(self) -> str:
marker = "<Assistant>"
return marker
def get_end_marker(self) -> str:
marker = "<end▁of▁sentence>"
return marker
def get_begin_marker(self) -> str:
return "<begin▁of▁sentence>"
def get_user_marker(self) -> str:
return "<User>"
def get_mask(self, text: str, tokenizer) -> torch.Tensor:
"""Get the mask for the text.
Args:
text: The text to get the mask for
tokenizer: The tokenizer to use
Returns:
A tensor of 0s and 1s where 1s indicate assistant tokens
"""
logger.debug(f"🔍 R1Distil: Getting mask for text length: {len(text)}")
# Get all markers
assistant_marker = self.get_assistant_marker()
end_marker = self.get_end_marker()
# Get the full tokenization
encoding = tokenizer(text, add_special_tokens=False)
tokens = encoding.input_ids
logger.debug(f"🔍 R1Distil: Full text token IDs: {tokens}")
# Create mask initialized to zeros - ENSURE SAME LENGTH AS INPUT_IDS
mask = torch.zeros(len(tokens), dtype=torch.int)
# Get token IDs for markers
assistant_tokens = tokenizer(assistant_marker, add_special_tokens=False).input_ids
end_tokens = tokenizer(end_marker, add_special_tokens=False).input_ids
logger.debug(f"🔍 R1Distil: Assistant marker token IDs: {assistant_tokens}")
logger.debug(f"🔍 R1Distil: End marker token IDs: {end_tokens}")
# Find all assistant responses
assistant_ranges = []
i = 0
while i < len(tokens):
# Look for assistant marker
if i + len(assistant_tokens) <= len(tokens) and tokens[i : i + len(assistant_tokens)] == assistant_tokens:
logger.debug(f"🔍 R1Distil: Found assistant marker at position {i}")
# Start masking AFTER the assistant marker
start_idx = i + len(assistant_tokens)
# Find end marker
end_idx = None
j = start_idx
while j < len(tokens):
if j + len(end_tokens) <= len(tokens) and tokens[j : j + len(end_tokens)] == end_tokens:
end_idx = j # Don't include the end marker in the mask
break
j += 1
if end_idx is None:
# If no end marker found, mask until the end
end_idx = len(tokens)
logger.debug(f"🔍 R1Distil: Response range: {start_idx} to {end_idx}")
assistant_ranges.append((start_idx, end_idx))
i = end_idx + len(end_tokens) # Start next search after the end marker
else:
i += 1
# Apply mask for all found ranges
for start_idx, end_idx in assistant_ranges:
mask[start_idx:end_idx] = 1
logger.debug(f"🔍 R1Distil: Found {len(assistant_ranges)} assistant responses")
logger.debug(f"🔍 R1Distil: Final mask sum: {mask.sum().item()}")
logger.debug(f"🔍 R1Distil: Final mask length: {len(mask)}")
logger.debug(f"🔍 R1Distil: Mask: {mask}")
return mask
def split_prompt_assistant(self, text: str) -> tuple[str, str]:
"""Split the text into prompt and assistant parts.
Args:
text: The text to split
Returns:
A tuple of (prompt, assistant)
"""
logger.debug(f"🔍 R1Distil: Splitting text of length: {len(text)}")
# Find the assistant marker
marker = self.get_assistant_marker()
end_marker = self.get_end_marker()
# Find ALL assistant markers in the text
assistant_markers = []
pos = 0
while True:
pos = text.find(marker, pos)
if pos == -1:
break
assistant_markers.append(pos)
pos += len(marker)
if not assistant_markers:
raise ValueError("Could not find assistant marker in text")
# Get the positions of all markers for later use
marker_positions = []
for start_pos in assistant_markers:
response_start = start_pos + len(marker)
# Find the end marker after this response
end_pos = text.find(end_marker, response_start)
if end_pos == -1:
end_pos = len(text)
else:
end_pos = end_pos + len(end_marker)
marker_positions.append((start_pos, response_start, end_pos))
# Get the full response (all assistant outputs concatenated)
full_response = ""
for _, resp_start, resp_end in marker_positions:
full_response += text[resp_start:resp_end]
# Include ALL assistant markers and responses in the response
# This matches how the mask is generated in get_mask
first_assistant_pos = marker_positions[0][0]
last_response_end = marker_positions[-1][2]
# Split into prompt and response
prompt = text[:first_assistant_pos] # Everything before the first assistant marker
response = text[first_assistant_pos:last_response_end] # All markers and responses
logger.debug(f"🔍 R1Distil: Prompt length: {len(prompt)}")
logger.debug(f"🔍 R1Distil: Response length: {len(response)}")
logger.debug(f"🔍 R1Distil: Response token count estimate: {len(response) / 4}") # Rough estimate
logger.debug(f"🔍 R1Distil: Final prompt: {prompt[:100]}...")
logger.debug(f"🔍 R1Distil: Final response: {response[:100]}...")
return prompt, response

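A minimal sketch of exercising the new adapters on a hand-built Llama-style conversation; the model id and conversation text are illustrative, and real inputs normally come from apply_chat_template:

from transformers import AutoTokenizer
from src.tokenizer_adapter import LlamaTokenizerAdapter

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
adapter = LlamaTokenizerAdapter()
convo = (
    "<|start_header_id|>user<|end_header_id|>\n\nWhat is 2+2?<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\nThe answer is 4.<|eot_id|>"
)
prompt, response = adapter.split_prompt_assistant(convo)  # prompt ends at the assistant marker
mask = adapter.get_mask(convo, tokenizer)  # 1s over assistant response tokens, 0s elsewhere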
@ -1,8 +1,16 @@
"""
Train a model using GRPO (Group Relative Policy Optimization).
"""
import os import os
from unsloth import FastLanguageModel, is_bfloat16_supported from unsloth import FastLanguageModel, is_bfloat16_supported
import src.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp import src.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp
# Import reward functions
from src import build_reward_correctness_fn, get_qa_dataset, reward_em_chunk, reward_format, reward_retry
from src.agent import Agent
from src.config import ( from src.config import (
MODEL_CONFIG, MODEL_CONFIG,
MODEL_NAME, MODEL_NAME,
@ -13,16 +21,9 @@ from src.config import (
logger, logger,
update_log_path, update_log_path,
) )
from src.rewards import build_reward_correctness_fn, reward_em_chunk, reward_retry
# Import reward functions from src.search_module import get_qa_dataset
from src.rl_helpers import ( from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
build_reward_correctness_fn,
get_qa_dataset,
reward_exact_match_chunk_query,
reward_formatting,
reward_retry_behavior,
run_agent,
)
# Initialize training directories # Initialize training directories
paths = init_training_dirs() paths = init_training_dirs()
@ -78,7 +79,17 @@ def agentic_generate(
generate_fn, generate_fn,
max_generations: int = 10, max_generations: int = 10,
): ):
return run_agent(generate_fn, tokenizer, prompts, max_generations) # Create agent with appropriate adapter based on tokenizer
tokenizer_name = tokenizer.name_or_path.lower()
if "deepseek-r1-distill" in tokenizer_name:
adapter = R1DistilTokenizerAdapter()
elif "llama" in tokenizer_name:
adapter = LlamaTokenizerAdapter()
else:
adapter = R1DistilTokenizerAdapter()
agent = Agent(adapter)
return agent.run_agent(generate_fn, tokenizer, prompts, max_generations)
model.agentic_generate = agentic_generate model.agentic_generate = agentic_generate
@ -102,13 +113,12 @@ trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(
processing_class=tokenizer, processing_class=tokenizer,
reward_funcs=[ reward_funcs=[
build_reward_correctness_fn( build_reward_correctness_fn(
verifier_generate_fn, vllm_generate_func=verifier_generate_fn,
tokenizer, tokenizer=tokenizer,
log_file=os.path.join(paths["log_dir"], "qa_log.txt"),
), ),
reward_formatting, reward_format,
reward_retry_behavior, reward_retry,
reward_exact_match_chunk_query, reward_em_chunk,
], ],
args=training_args, args=training_args,
train_dataset=train_dataset, train_dataset=train_dataset,

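A compact sketch of the adapter-selection logic used in agentic_generate above, factored into a helper; pick_adapter is a hypothetical name, tokenizer, generate_fn, and prompts are assumed to be in scope, and unknown tokenizers fall back to the R1 adapter as in the training script:

from src.agent import Agent
from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter

def pick_adapter(tokenizer):
    name = tokenizer.name_or_path.lower()
    if "deepseek-r1-distill" in name:
        return R1DistilTokenizerAdapter()
    if "llama" in name:
        return LlamaTokenizerAdapter()
    # Fall back to the R1 adapter for unrecognised tokenizers, as the script above does
    return R1DistilTokenizerAdapter()

agent = Agent(pick_adapter(tokenizer))
outputs = agent.run_agent(generate_fn, tokenizer, prompts, max_generations=10)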