diff --git a/src/UnslothGRPOTrainerTemp.py b/src/UnslothGRPOTrainerTemp.py index 67f375f..780ddd0 100644 --- a/src/UnslothGRPOTrainerTemp.py +++ b/src/UnslothGRPOTrainerTemp.py @@ -54,6 +54,8 @@ from trl.trainer.grpo_trainer import ( wandb, ) +from src.config import logger + torch_compile_options = { "epilogue_fusion": True, "max_autotune": False, @@ -63,11 +65,11 @@ torch_compile_options = { } -@torch.compile( - dynamic=True, - fullgraph=True, - options=torch_compile_options, -) +# @torch.compile( +# dynamic=True, +# fullgraph=True, +# options=torch_compile_options, +# ) def selective_log_softmax(logits, index): logits = logits.to(torch.float32) selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1) @@ -82,6 +84,29 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) # All Unsloth Zoo code licensed under LGPLv3 old_logits = old_logits.to(torch.float32) new_logits = new_logits.to(torch.float32) + + # Print FULL tensor contents + logger.debug("\nπŸ” DETAILED TENSOR ANALYSIS:") + logger.debug("\n1️⃣ Input IDs:") + logger.debug(f"Shape: {input_ids.shape}") + # Use tensor.cpu().numpy() to safely print content + + logger.debug(f"Type: {mask.dtype}") + logger.debug(f"Sum: {mask.sum().item()}") # Use .item() to get Python scalar + + logger.debug("\n3️⃣ Old Logits:") + logger.debug(f"Shape: {old_logits.shape}") + logger.debug(f"Type: {old_logits.dtype}") + logger.debug(f"Mean: {old_logits.mean().item():.4f}") + + logger.debug("\n4️⃣ New Logits:") + logger.debug(f"Shape: {new_logits.shape}") + logger.debug(f"Type: {new_logits.dtype}") + logger.debug(f"Mean: {new_logits.mean().item():.4f}") + + logger.debug("\n5️⃣ Advantages:") + logger.debug(f"Shape: {advantages.shape}") + input_ids = input_ids.unsqueeze(-1) # x_i - logsumexp(x_i) @@ -90,6 +115,10 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) old = old_x - torch.logsumexp(old_logits, dim=-1) new = new_x - torch.logsumexp(new_logits, dim=-1) + logger.debug("\n6️⃣ After Gather & LogSumExp:") + logger.debug(f"old_x shape: {old_x.shape}, new_x shape: {new_x.shape}") + logger.debug(f"old shape: {old.shape}, new shape: {new.shape}") + # Reverse KL kl_i = torch.exp(old - new) - (old - new) - 1.0 # Full correct reverse KL divergence?? Missing term maybe? 
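Note on the KL term above: assuming `old_logits` come from the reference policy (as in the trainer's call path) and writing the per-token difference as $\Delta_t = \log \pi_{\text{ref}}(o_t) - \log \pi_\theta(o_t)$, the line `kl_i = torch.exp(old - new) - (old - new) - 1.0` computes $\hat{D}_{\mathrm{KL}}\big[\pi_\theta \,\|\, \pi_{\text{ref}}\big]_t = e^{\Delta_t} - \Delta_t - 1 = \frac{\pi_{\text{ref}}(o_t)}{\pi_\theta(o_t)} - \log \frac{\pi_{\text{ref}}(o_t)}{\pi_\theta(o_t)} - 1 \ge 0$, the standard low-variance "k3" estimator of the reverse KL penalty used by GRPO, so no term is missing from the expression the comment questions.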
@@ -109,6 +138,13 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) # See https://github.com/huggingface/trl/pull/2881 # loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward # loss = loss_per_reward.mean() + + # Add print statements here for debugging + logger.debug(f"🚨 Debug: loss_i shape: {loss_i.shape}") + logger.debug( + f"🚨 Debug: mask shape: {mask.shape}" + ) # Note: Mask shape might change slightly due to float conversion + loss = (loss_i * mask).sum() / mask.sum() # Get metrics as well which are folded @@ -165,14 +201,7 @@ class UnslothEfficientGRPO(torch.autograd.Function): accumulated_completion_length = torch.zeros(1, device=device) accumulated_mean_kl = torch.zeros(1, device=device) - def accumulate_chunk( - new_hidden_states_j, - old_hidden_states_j, - input_ids_j, - mask_j, - advantages_j, - scaling, - ): + def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling): ( (chunk_grad_input,), ( @@ -187,14 +216,7 @@ class UnslothEfficientGRPO(torch.autograd.Function): compute_loss, argnums=(0,), has_aux=True, - )( - new_hidden_states_j, - old_hidden_states_j, - input_ids_j, - mask_j, - advantages_j, - scaling, - ) + )(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling) accumulated_loss.add_(unscaled_loss) accumulated_completion_length.add_(chunk_completion_length) accumulated_mean_kl.add_(chunk_mean_kl) @@ -202,11 +224,11 @@ class UnslothEfficientGRPO(torch.autograd.Function): pass - accumulate_chunk = torch.compile( - accumulate_chunk, - fullgraph=True, - options=torch_compile_options, - ) + # accumulate_chunk = torch.compile( + # accumulate_chunk, + # fullgraph=True, + # options=torch_compile_options, + # ) grad_inputs_chunks = torch.chunk(grad_inputs, chunks=n_chunks, dim=0) new_hidden_states = torch.chunk(_new_hidden_states, chunks=n_chunks, dim=0) @@ -228,28 +250,14 @@ class UnslothEfficientGRPO(torch.autograd.Function): input_ids_j, mask_j, advantages_j, - ) in zip( - grad_inputs_chunks, - new_hidden_states, - old_hidden_states, - input_ids, - mask, - advantages, - ): + ) in zip(grad_inputs_chunks, new_hidden_states, old_hidden_states, input_ids, mask, advantages): mark_dynamic(new_hidden_states_j) mark_dynamic(old_hidden_states_j) mark_dynamic(input_ids_j) mark_dynamic(mask_j) grad_inputs_j.copy_( - accumulate_chunk( - new_hidden_states_j, - old_hidden_states_j, - input_ids_j, - mask_j, - advantages_j, - scaling, - ) + accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling) ) pass @@ -624,7 +632,7 @@ class UnslothGRPOConfig(GRPOConfig): save_strategy = "no" div = per_device_train_batch_size // num_generations if div * num_generations != per_device_train_batch_size: - print( + logger.debug( "Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.\nWe will change the batch size of " + str(per_device_train_batch_size) + " to the `num_generations` of " @@ -971,11 +979,7 @@ class _UnslothGRPOTrainer(Trainer): self.sampling_params = SamplingParams( temperature=args.temperature, max_tokens=self.max_completion_length, - **getattr( - getattr(args, "vllm_sampling_params", vLLMSamplingParams()), - "_set_kwargs", - {}, - ), + **getattr(getattr(args, "vllm_sampling_params", vLLMSamplingParams()), "_set_kwargs", {}), ) else: self.generation_config = GenerationConfig( @@ -1038,9 +1042,7 @@ class _UnslothGRPOTrainer(Trainer): with torch.amp.autocast(device_type="cuda", 
dtype=self._autocast_dtype): # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded logits = model( - input_ids=input_ids, - attention_mask=attention_mask, - logits_to_keep=logits_to_keep + 1, + input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1 ).logits logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred @@ -1056,25 +1058,31 @@ class _UnslothGRPOTrainer(Trainer): return None def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]: + logger.debug("\nπŸ” DEBUG: Starting _prepare_inputs") device = self.accelerator.device prompts = [x["prompt"] for x in inputs] prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] + logger.debug("\n1️⃣ Before tokenization:") + logger.debug(f"Number of prompts: {len(prompts)}") + logger.debug(f"Sample prompt text length: {len(prompts_text[0]) if prompts_text else 0}") + prompt_inputs = self.processing_class( - prompts_text, - return_tensors="pt", - padding=True, - padding_side="left", - add_special_tokens=False, + prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False ) prompt_inputs = super()._prepare_inputs(prompt_inputs) - prompt_ids, prompt_mask = ( - prompt_inputs["input_ids"], - prompt_inputs["attention_mask"], - ) + prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] + + logger.debug("\n2️⃣ After initial tokenization:") + logger.debug(f"prompt_ids shape: {prompt_ids.shape}") + logger.debug(f"prompt_mask shape: {prompt_mask.shape}") + logger.debug(f"prompt_mask sum: {prompt_mask.sum().item()}") if self.max_prompt_length is not None: prompt_ids = prompt_ids[:, -self.max_prompt_length :] prompt_mask = prompt_mask[:, -self.max_prompt_length :] + logger.debug("\n3️⃣ After prompt length truncation:") + logger.debug(f"prompt_ids shape: {prompt_ids.shape}") + logger.debug(f"prompt_mask shape: {prompt_mask.shape}") # Generate completions using either vLLM or regular generation if self.args.use_vllm: @@ -1086,7 +1094,7 @@ class _UnslothGRPOTrainer(Trainer): # Generate completions using vLLM: gather all prompts and use them in a single call in the main process all_prompts_text = gather_object(prompts_text) if self.accelerator.is_main_process: - print(all_prompts_text) + logger.debug(all_prompts_text) generate_fn = lambda prompts_text: self.llm.generate( prompts_text, sampling_params=self.sampling_params, @@ -1104,6 +1112,11 @@ class _UnslothGRPOTrainer(Trainer): prompt_inputs = agentic_outputs.prompt_tokens completion_ids = agentic_outputs.response_tokens completion_mask = agentic_outputs.response_masks + for i in range(len(completion_ids)): + logger.debug(f"prompt_inputs {i} len before padding: {len(prompt_inputs[i])}") + logger.debug(f"completion_ids {i} len before padding: {len(completion_ids[i])}") + logger.debug(f"completion_mask {i} len before padding: {len(completion_mask[i])}") + prompt_ids = pad( prompt_inputs, padding_value=self.processing_class.pad_token_id, @@ -1114,6 +1127,12 @@ class _UnslothGRPOTrainer(Trainer): padding_value=0, padding_side="right", ).to(device) + + for i in range(len(completion_ids)): + logger.debug(f"prompt_inputs {i} len after padding: {len(prompt_inputs[i])}") + logger.debug(f"prompt_ids {i} len after padding: {len(prompt_ids[i])}") + logger.debug(f"completion_mask {i} len after padding: {len(completion_mask[i])}") + else: outputs 
= generate_fn(all_prompts_text) completion_ids = [out.token_ids for completions in outputs for out in completions.outputs] @@ -1130,15 +1149,21 @@ class _UnslothGRPOTrainer(Trainer): # Pad the completions, and concatenate them with the prompts completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] + logger.debug("\n4️⃣ Before completion padding:") + logger.debug(f"completion_ids shapes: {[ids.shape for ids in completion_ids]}") + completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id) + logger.debug("\n5️⃣ After completion padding:") + logger.debug(f"completion_ids shape: {completion_ids.shape}") + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) + logger.debug("\n6️⃣ After concatenation:") + logger.debug(f"prompt_completion_ids shape: {prompt_completion_ids.shape}") else: # Regular generation path with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: prompt_completion_ids = unwrapped_model.generate( - prompt_ids, - attention_mask=prompt_mask, - generation_config=self.generation_config, + prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config ) # Compute prompt length and extract completion ids @@ -1148,11 +1173,52 @@ class _UnslothGRPOTrainer(Trainer): if not self.use_agentic_generate: # Mask everything after the first EOS token + logger.debug("\nπŸ” Starting EOS token detection and masking:") + logger.debug(f"completion_ids shape: {completion_ids.shape}") + logger.debug(f"eos_token_id: {self.processing_class.eos_token_id}") + + # Debug EOS detection is_eos = completion_ids == self.processing_class.eos_token_id + logger.debug("\n7️⃣ EOS Detection Details:") + logger.debug(f"is_eos shape: {is_eos.shape}") + logger.debug(f"Sample is_eos values (first sequence):\n{is_eos[0]}") + logger.debug(f"Any EOS tokens found: {is_eos.any().item()}") + logger.debug(f"EOS positions: {is_eos.nonzero()}") + + # Debug EOS index tensor creation eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) + logger.debug("\n8️⃣ EOS Index Creation:") + logger.debug(f"eos_idx initial shape: {eos_idx.shape}") + logger.debug(f"eos_idx initial values: {eos_idx}") + + # Debug the complex indexing operation + logger.debug("\n9️⃣ EOS Position Analysis:") + logger.debug(f"Sequences with EOS: {is_eos.any(dim=1).sum().item()}") + logger.debug(f"First EOS positions: {is_eos.int().argmax(dim=1)}") + eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] + logger.debug("\nπŸ”Ÿ After EOS Index Update:") + logger.debug(f"Updated eos_idx values: {eos_idx}") + + # Debug sequence indices creation sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) + logger.debug("\n1️⃣1️⃣ Sequence Indices:") + logger.debug(f"sequence_indices shape: {sequence_indices.shape}") + logger.debug(f"Sample sequence_indices (first row):\n{sequence_indices[0]}") + + # Debug final mask creation completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() + logger.debug("\n1️⃣2️⃣ Final Completion Mask:") + logger.debug(f"completion_mask shape: {completion_mask.shape}") + logger.debug(f"Sample mask (first sequence):\n{completion_mask[0]}") + logger.debug("Mask statistics:") + logger.debug(f"- Total 1s: {completion_mask.sum().item()}") + logger.debug(f"- Average sequence length: {completion_mask.sum(dim=1).float().mean().item():.2f}") + + # Add a final validation check + logger.debug("\n7️⃣ Final Validation:") + logger.debug(f"Input shape: 
{completion_ids.shape}") + logger.debug(f"Mask shape: {completion_mask.shape}") # Concatenate prompt_mask with completion_mask for logit computation attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C) @@ -1173,18 +1239,12 @@ class _UnslothGRPOTrainer(Trainer): ): if self.ref_model is not None: ref_per_token_logps = self._get_per_token_logps( - self.ref_model, - prompt_completion_ids, - attention_mask, - logits_to_keep, + self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep ) else: with self.accelerator.unwrap_model(self.model, keep_fp32_wrapper=False).disable_adapter(): ref_per_token_logps = self._get_per_token_logps( - self.model, - prompt_completion_ids, - attention_mask, - logits_to_keep, + self.model, prompt_completion_ids, attention_mask, logits_to_keep ) # Decode the generated completions @@ -1211,11 +1271,7 @@ class _UnslothGRPOTrainer(Trainer): else: texts = [p + c for p, c in zip(prompts, completions)] reward_inputs = reward_processing_class( - texts, - return_tensors="pt", - padding=True, - padding_side="right", - add_special_tokens=False, + texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False ) reward_inputs = super()._prepare_inputs(reward_inputs) with ( @@ -1273,7 +1329,7 @@ class _UnslothGRPOTrainer(Trainer): # Log the metrics reward_per_func = rewards_per_func.mean(0) - print("rewards_per_func:", reward_per_func) + logger.debug("rewards_per_func:", reward_per_func) for i, reward_func in enumerate(self.reward_funcs): if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models reward_func_name = reward_func.config._name_or_path.split("/")[-1] @@ -1318,10 +1374,7 @@ class _UnslothGRPOTrainer(Trainer): # Compute the per-token log probabilities for the model prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] - completion_ids, completion_mask = ( - inputs["completion_ids"], - inputs["completion_mask"], - ) + completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] input_ids = torch.cat([prompt_ids, completion_ids], dim=1) bsz, qlen = input_ids.shape # attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) @@ -1369,13 +1422,7 @@ class _UnslothGRPOTrainer(Trainer): self._metrics["kl"].append(mean_kl.item()) return loss - def prediction_step( - self, - model, - inputs, - prediction_loss_only, - ignore_keys: Optional[list[str]] = None, - ): + def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None): inputs = self._prepare_inputs(inputs) with torch.no_grad(): with self.compute_loss_context_manager(): @@ -1593,7 +1640,7 @@ class UnslothGRPOTrainer(_UnslothGRPOTrainer): from transformers import __version__ as transformers_version if Version(transformers_version) <= Version("4.45.2"): - print( + logger.debug( "**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n" "`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`" ) diff --git a/src/__init__.py b/src/__init__.py index e69de29..169eccd 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -0,0 +1,43 @@ +""" +Main package exports for RL helpers. 
+""" + +from trl.trainer.grpo_trainer import apply_chat_template + +from src.agent import Agent +from src.config import logger +from src.evaluation import check_student_answers, run_eval, verify +from src.prompts import build_user_prompt, format_search_results, get_system_prompt +from src.rewards import ( + build_reward_correctness_fn, + reward_em_chunk, + reward_format, + reward_retry, +) +from src.search_module import get_qa_dataset, search +from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter + +__all__ = [ + # Prompts + "get_system_prompt", + "build_user_prompt", + "format_search_results", + "apply_chat_template", + # Agent + "Agent", + "LlamaTokenizerAdapter", + "R1DistilTokenizerAdapter", + # Rewards + "build_reward_correctness_fn", + "reward_format", + "reward_retry", + "reward_em_chunk", + # Evaluation + "run_eval", + "check_student_answers", + "verify", + # Search + "get_qa_dataset", + "search", + "logger", +] diff --git a/src/agent.py b/src/agent.py new file mode 100644 index 0000000..11b4748 --- /dev/null +++ b/src/agent.py @@ -0,0 +1,237 @@ +""" +Core agent functionality for handling tool-based conversations. +This module provides a base agent class for handling tool-based conversations. +""" + +import re +from dataclasses import dataclass + +import torch +from trl.trainer.grpo_trainer import apply_chat_template + +from src.config import logger +from src.prompts import build_user_prompt, get_system_prompt +from src.search_module import search +from src.tokenizer_adapter import TokenizerAdapter + + +@dataclass +class AgenticOutputs: + """Outputs from running the agent on a batch of questions.""" + + prompt_tokens: list[torch.Tensor] + response_tokens: list[torch.Tensor] + response_masks: list[torch.Tensor] + final_response_str: list[str] + full_chat_states: list[dict] + + +class Agent: + """Base agent class for handling tool-based conversations.""" + + def __init__(self, tokenizer_adapter: TokenizerAdapter): + """Initialize the agent with a tokenizer adapter.""" + self.tokenizer_adapter = tokenizer_adapter + + def get_initial_chat(self, question: str) -> dict: + """Initialize a chat state with the question.""" + return { + "messages": [ + {"role": "system", "content": get_system_prompt()}, + {"role": "user", "content": build_user_prompt(question)}, + ] + } + + def extract_search_query(self, text: str) -> str | None: + """Extract search query from text between tags.""" + pattern = re.compile(r"(.*?)", re.DOTALL) + matches = pattern.findall(text) + return matches[-1] if matches else None + + def run_agent_generations(self, generate_fn, tokenizer, chat_states: list[dict]) -> list[dict]: + """Run generation for chat states requiring assistant responses.""" + logger.debug(f"Starting generation for {len(chat_states)} chat states") + prompts = [] + batch_indices = [] + + for idx, chat_state in enumerate(chat_states): + if chat_state.get("finished"): + logger.debug(f"Chat state {idx} already finished, skipping") + continue + + if chat_state["messages"][-1]["role"] in ["ipython", "user"]: + prompt = apply_chat_template(chat_state, tokenizer=tokenizer)["text"] + prompts.append(prompt) + batch_indices.append(idx) + logger.debug(f"Added prompt for chat state {idx}") + + if prompts: + logger.info(f"Generating responses for {len(prompts)} prompts") + responses = generate_fn(prompts) + for i, idx in enumerate(batch_indices): + chat_state = chat_states[idx] + response = responses[i] + if hasattr(response, "outputs"): + full_response = response.outputs[0].text + 
else: + full_response = response + + assistant_response = full_response.split(self.tokenizer_adapter.get_assistant_marker())[-1] + chat_state["messages"].append({"role": "assistant", "content": assistant_response}) + logger.debug(f"Added assistant response to chat state {idx}") + else: + logger.debug("No prompts to generate responses for") + + return chat_states + + def check_finished_chats(self, chat_states: list[dict]) -> list[dict]: + """Check which chat states are finished (no more search queries).""" + for chat_state in chat_states: + if chat_state.get("finished"): + continue + assert chat_state["messages"][-1]["role"] == "assistant", "Expected the last role to be assistant" + assistant_response = chat_state["messages"][-1]["content"] + if not re.search(r"<search>.*?</search>", assistant_response, re.DOTALL): + chat_state["finished"] = True + return chat_states + + def run_tool_calls(self, chat_states: list[dict]) -> list[dict]: + """Execute tool calls found in chat states.""" + logger.debug(f"Running tool calls for {len(chat_states)} chat states") + + for chat_state in chat_states: + if chat_state.get("finished"): + logger.debug("Chat state already finished, skipping tool calls") + continue + + assert chat_state["messages"][-1]["role"] == "assistant", ( + "Expected the last role to be assistant to run tool calls" + ) + try: + assistant_response = chat_state["messages"][-1]["content"] + search_query = self.extract_search_query(assistant_response) + if search_query: + logger.info(f"🔍 Search Query: {search_query}") + results = search(search_query, return_type=str, results=2) + formatted_results = f"<information>{results}</information>" + logger.info(f"ℹ️ Information: {formatted_results}") + + # chat_state["messages"].append({"role": "ipython", "content": formatted_results}) + chat_state["messages"].append({"role": "user", "content": formatted_results}) + logger.debug("Added search results to chat state") + except Exception as e: + logger.error(f"Error during tool call: {str(e)}") + chat_state["messages"].append({"role": "system", "content": f"Error during post-processing: {str(e)}"}) + chat_state["finished"] = True + + return chat_states + + def get_chat_num_tokens(self, chat_state: dict, tokenizer) -> int: + """Get number of tokens in chat state.""" + chat_text = apply_chat_template(chat_state, tokenizer=tokenizer)["text"] + return tokenizer(chat_text, add_special_tokens=False, return_tensors="pt")["input_ids"].squeeze().shape[0] + + def check_exceeded_max_new_tokens(self, chat_states: list[dict], max_new_tokens: int, tokenizer) -> list[dict]: + """Check if any chat state has exceeded max new tokens.""" + for chat_state in chat_states: + if chat_state.get("finished"): + continue + initial_length = chat_state["initial_length"] + new_length = self.get_chat_num_tokens(chat_state, tokenizer) + if new_length - initial_length > max_new_tokens: + chat_state["finished"] = True + return chat_states + + def run_agent( + self, + generate_fn, + tokenizer, + questions: list[str], + max_generations: int = 5, + max_new_tokens: int = 4096, + correct_contents: list[str] | None = None, + ) -> AgenticOutputs: + """Run the agent to completion for a batch of questions. + + This method follows the same flow as rl_helpers.py: + 1. Initialize chat states with questions + 2. Run agent loop (generations, check finished, tool calls, check tokens) + 3.
Process final outputs (split prompt/response, get masks) + + The key difference from our previous implementation is in how we handle + the final tokenization and masking, which now matches rl_helpers.py exactly. + """ + # Step 1: Initialize chat states with questions + chat_states = [self.get_initial_chat(q) for q in questions] + if correct_contents: + for chat_state, correct_content in zip(chat_states, correct_contents): + chat_state["correct_content"] = correct_content + + # Set initial token lengths for each chat state + for chat_state in chat_states: + chat_state["initial_length"] = self.get_chat_num_tokens(chat_state, tokenizer) + + # Step 2: Run agent loop + for i in range(max_generations): + chat_states = self.run_agent_generations(generate_fn, tokenizer, chat_states) + chat_states = self.check_finished_chats(chat_states) + chat_states = self.run_tool_calls(chat_states) + chat_states = self.check_exceeded_max_new_tokens(chat_states, max_new_tokens, tokenizer) + + # Step 3: Process final outputs + # Get the final answers from each chat state + answers = [chat["messages"][-1]["content"] for chat in chat_states] + + # Convert chat states to text format for tokenization + str_chats = [apply_chat_template(chat, tokenizer=tokenizer)["text"] for chat in chat_states] + prompt_toks, response_toks, response_masks = [], [], [] + + # Process each chat state to get tokens and masks + for str_chat in str_chats: + try: + # Split into prompt and response parts + # Note: If assistant marker is missing, split_prompt_assistant will return (full_text, "") + prompt_text, response_text = self.tokenizer_adapter.split_prompt_assistant(str_chat) + + # Get prompt tokens + prompt_toks.append( + tokenizer(prompt_text, add_special_tokens=False, return_tensors="pt")["input_ids"].squeeze() + ) + + # Get response tokens (truncated to max_new_tokens) + response_toks.append( + tokenizer(response_text, add_special_tokens=False, return_tensors="pt")["input_ids"].squeeze()[ + :max_new_tokens + ] + ) + + # Get full mask and slice it properly + # This matches rl_helpers.py exactly: + # 1. Get full mask for entire text + # 2. Slice from prompt length to end + # 3. 
Truncate to max_new_tokens + full_mask = self.tokenizer_adapter.get_mask(str_chat, tokenizer) + prompt_len = prompt_toks[-1].shape[0] + mask = full_mask[prompt_len:][:max_new_tokens] + response_masks.append(mask) + + # debug if the tokens and masks are of same length by logger info + logger.debug(f"Prompt tokens length: {len(prompt_toks[-1])}") + logger.debug(f"Mask length: {len(mask)}") + logger.debug(f"Response tokens length: {len(response_toks[-1])}") + + except Exception: + # If anything fails, add empty tensors + # This matches rl_helpers.py's behavior of not handling errors explicitly + prompt_toks.append(torch.tensor([], dtype=torch.long)) + response_toks.append(torch.tensor([], dtype=torch.long)) + response_masks.append(torch.tensor([], dtype=torch.long)) + + # Return final outputs + return AgenticOutputs( + prompt_tokens=prompt_toks, + response_tokens=response_toks, + response_masks=response_masks, + final_response_str=answers, + full_chat_states=chat_states, + ) diff --git a/src/config.py b/src/config.py index 8eb4fab..cba9d89 100644 --- a/src/config.py +++ b/src/config.py @@ -19,6 +19,8 @@ LOG_FOLDER = PROJ_ROOT / "logs" # Model configuration # MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" +# MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" +# MODEL_NAME = "unsloth/Qwen2-1.5B" # Smoke test first device_id = 1 if os.environ.get("CUDA_VISIBLE_DEVICES") == "1" else torch.cuda.current_device() timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") @@ -100,21 +102,21 @@ def _init_logging(env: str = "development") -> None: file_format = "{time:YYYY-MM-DD at HH:mm:ss} | {level} | {name}:{function}:{line} - {message}" - # Add console logging with DEBUG level + # Add console logging with INFO level (minimal terminal output) logger.add( sys.stderr, format=console_format, - level="DEBUG", # Always use DEBUG level + level="INFO", # "INFO", # Changed from DEBUG to INFO for minimal terminal output colorize=True, backtrace=True, - diagnose=True, # Always enable diagnostics + diagnose=True, ) - # Add default file logging to ./logs directory with DEBUG level + # Add default file logging to ./logs directory with DEBUG level (full details) logger.add( LOG_FOLDER / "app.log", format=file_format, - level="DEBUG", # Always use DEBUG level + level="DEBUG", # Keep DEBUG level for full file logging rotation="500 MB", retention="7 days", compression="zip", @@ -235,44 +237,6 @@ def setup_logger(module_name=None, create_dirs: bool = False): return logger -# Tensorboard writer singleton -_tensorboard_writer = None - - -# Safe tensorboard logging function -def log_metric(key, value, step=0): - """ - Log a metric safely to tensorboard if writer is available. - - Args: - key: Metric name - value: Metric value - step: Training step - """ - global _tensorboard_writer - - # Skip tensorboard logging if disabled in config - if TRAINING_CONFIG.get("report_to") != "tensorboard": - logger.debug(f"Tensorboard disabled. 
Metric: {key}={value} (step {step})") - return - - # Get paths and initialize writer if needed - paths = get_paths(create_dirs=False) - if paths["tensorboard_dir"].exists(): - # Only create writer once - if _tensorboard_writer is None: - from torch.utils.tensorboard.writer import SummaryWriter - - _tensorboard_writer = SummaryWriter(paths["tensorboard_dir"]) - logger.debug(f"Created tensorboard writer at {paths['tensorboard_dir']}") - - # Add scalar using existing writer - _tensorboard_writer.add_scalar(key, value, step) - # No need to close the writer - it will be closed at process exit - else: - logger.debug(f"Tensorboard metric: {key}={value} (step {step})") - - # Initialize logging on module import env = os.getenv("APP_ENV", "development") _init_logging(env=env) diff --git a/src/evaluation.py b/src/evaluation.py new file mode 100644 index 0000000..286df9d --- /dev/null +++ b/src/evaluation.py @@ -0,0 +1,238 @@ +""" +Evaluation utilities for RL training. +""" + +import inspect +from datetime import datetime + +from src.agent import Agent +from src.config import logger +from src.search_module import get_qa_dataset +from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter + + +async def verify(student_answer: str, question: str, answer: str) -> bool: + """ + Verify if student's answer matches the correct answer. + + Args: + student_answer: The model's answer + question: The original question + answer: The ground truth answer + + Returns: + bool: True if answer is correct, False otherwise + """ + logger.debug(f"Verifying answer for question: {question}") + logger.debug(f"Student answer: {student_answer}") + logger.debug(f"Correct answer: {answer}") + + # Simple string matching for now + # TODO: Implement more sophisticated matching + return student_answer.strip().lower() == answer.strip().lower() + + +def check_student_answers( + questions: list[str], + answers: list[str], + student_answers: list, # Can be strings or dicts + vllm_generate_func, + tokenizer, + log_file=None, +) -> list[bool]: + """ + Evaluates a list of student answers against the true answers using a vLLM generate function. + + Args: + questions: List of questions + answers: List of correct answers + student_answers: List of student answers to evaluate + vllm_generate_func: Function to generate verification responses + tokenizer: Tokenizer for formatting prompts + log_file: Optional path to write detailed results + + Returns: + List of boolean results (True for correct answers) + """ + logger.info(f"Checking {len(questions)} student answers") + + if not (len(questions) == len(answers) == len(student_answers)): + logger.error("Mismatched lengths between questions, answers, and student answers") + raise ValueError("The number of questions, answers, and student answers must be equal.") + + prompts = [] + for question, answer, student_ans in zip(questions, answers, student_answers): + prompt_text = ( + "You are grading a student's answer to a question. For the following question, " + "compare the student's answer to the correct answer. Reply with 'Yes' if the student's answer contains the correct information, " + "even if it's not an exact match. If the student's answer doesn't contain the right information or is completely incorrect, reply with 'No'.\n\n" + f"Question: {question}\n" + f"Correct Answer: {answer}\n" + f"Student Answer: {student_ans}\n\n" + "Your response should be just 'Yes' or 'No'." 
+ ) + + formatted_prompt = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt_text}], + tokenize=False, + add_generation_prompt=True, + ) + prompts.append(formatted_prompt) + logger.debug(f"Created verification prompt for question: {question[:50]}...") + + logger.info("Generating verification responses") + responses = vllm_generate_func(prompts) + responses_text = [] + for response in responses: + # Handle different response formats + if hasattr(response, "outputs"): + try: + responses_text.append(response.outputs[0].text) + except (AttributeError, IndexError): + # Fallback for simple string responses + responses_text.append(str(response)) + else: + responses_text.append(str(response)) + logger.debug(f"Got {len(responses_text)} verification responses") + + results = [] + for response in responses_text: + results.append("yes" in response.lower()) + logger.debug(f"Verification result: {'yes' in response.lower()}") + + logger.info(f"Verification complete. {sum(results)}/{len(results)} answers correct") + + # Append the QA details and verifier's response to the specified log file + if log_file: + with open(log_file, "a") as file: + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + file.write(f"\nπŸ“ === QA Evaluation at {timestamp} ===\n") + file.write(f"πŸ“‚ File: {__file__}\n") + + # Get current frame info safely + frame = inspect.currentframe() + if frame: + file.write(f"πŸ“ Line: {frame.f_lineno}\n") + # Don't forget to delete the frame to avoid reference cycles + del frame + + file.write("=" * 80 + "\n") + + for i, (question, answer, student_ans, verifier_response) in enumerate( + zip(questions, answers, student_answers, responses_text) + ): + file.write(f"\n❓ Question {i + 1}:\n") + file.write("-" * 40 + "\n") + file.write(f"πŸ“‹ Question: {question}\n") + file.write(f"βœ… Correct Answer: {answer}\n") + file.write(f"πŸ‘¨β€πŸŽ“ Student Answer: {student_ans}\n") + file.write(f"πŸ” Verifier said: {verifier_response}\n") + + # Add search results if available in the chat state + if isinstance(student_ans, dict) and "messages" in student_ans: + # Get messages from dict + messages = student_ans.get("messages", []) + search_results = [msg.get("content", "") for msg in messages if msg.get("role") == "ipython"] + if search_results: + file.write("\nπŸ”Ž Search Results:\n") + for j, result in enumerate(search_results, 1): + file.write(f"\nSearch {j}:\n{result}\n") + + file.write("-" * 40 + "\n") + + file.write( + f"\nπŸ“Š Summary: {sum(results)}/{len(results)} answers correct ({sum(results) / len(results) * 100:.2f}%)\n" + ) + file.write("=" * 80 + "\n\n") + + return results + + +def run_eval(generate_fn, verify_fn, tokenizer, output_file=None, debug_file=None): + """ + Run evaluation on the test dataset and return results. 
+ + Args: + generate_fn: Function to generate completions + verify_fn: Function to verify results + tokenizer: Tokenizer for processing text + output_file: Path to save evaluation results summary + debug_file: Path to save detailed debug information + + Returns: + full_chat_states: The chat states from evaluation + """ + train_dataset, test_dataset = get_qa_dataset() + questions = test_dataset["prompt"] + + # Create agent with appropriate adapter based on tokenizer + tokenizer_name = tokenizer.name_or_path.lower() + if "deepseek-r1-distill" in tokenizer_name: + adapter = R1DistilTokenizerAdapter() + elif "llama" in tokenizer_name: + adapter = LlamaTokenizerAdapter() + else: + adapter = R1DistilTokenizerAdapter() + + agent = Agent(adapter) + agentic_outputs = agent.run_agent(generate_fn, tokenizer, questions) + full_chat_states = agentic_outputs.full_chat_states + final_responses = agentic_outputs.final_response_str + rewards = verify_fn(questions, full_chat_states, answer=test_dataset["answer"]) + + # Calculate results + percent_correct = sum(rewards) / len(rewards) * 100 + + # Log results to console + logger.info("RESULTS:") + logger.info(f"percentage of correct answers: {percent_correct:.2f}%") + logger.info("=" * 30) + + # Save results to file if specified + if output_file: + try: + with open(output_file, "w") as f: + f.write("EVALUATION RESULTS\n") + f.write("=================\n\n") + f.write(f"Total questions: {len(questions)}\n") + f.write(f"Correct answers: {sum(rewards)}\n") + f.write(f"Percentage correct: {percent_correct:.2f}%\n\n") + + f.write("Individual results:\n") + for i, (q, r, resp) in enumerate(zip(questions, rewards, final_responses)): + f.write(f"\nQ{i + 1}: {q[:100]}...\n") + f.write(f"Correct: {'✓' if r else '✗'}\n") + f.write(f"Response: {resp[:150]}...\n") + f.write("-" * 40 + "\n") + logger.info(f"Saved evaluation results to {output_file}") + except Exception as e: + logger.error(f"Error saving results file: {e}") + + # Save debug information if specified + if debug_file: + try: + import json + + debug_data = [] + for i, (q, r, resp, chat) in enumerate(zip(questions, rewards, final_responses, full_chat_states)): + debug_data.append( + { + "question_id": i, + "question": q, + "is_correct": bool(r), + "final_response": resp, + "chat_state": { + k: str(v) if isinstance(v, (list, dict)) else v + for k, v in chat.items() + if k != "tokenizer" + }, + } + ) + + with open(debug_file, "w") as f: + json.dump(debug_data, f, indent=2) + logger.info(f"Saved debug information to {debug_file}") + except Exception as e: + logger.error(f"Error saving debug file: {e}") + + return full_chat_states diff --git a/src/prompts.py b/src/prompts.py new file mode 100644 index 0000000..e4e6005 --- /dev/null +++ b/src/prompts.py @@ -0,0 +1,67 @@ +""" +Prompt-related functions for handling system and user prompts. +""" + +from datetime import datetime + + +def get_system_prompt(): + """Get the system prompt with current date.""" + current_date = datetime.now().strftime("%d %b %Y") + return f"""Cutting Knowledge Date: December 2023 +Today Date: {current_date} + +You are a helpful assistant with search capabilities. +""" + + +def build_user_prompt(q): + """ + Build a user prompt with the question using the new template format. + + Args: + q (str): The question to ask + + Returns: + str: Formatted user prompt + """ + user_prompt = f"""Answer the given question. \ +You must conduct reasoning inside <think> and </think> first every time you get new information.
\ After reasoning, if you find you lack some knowledge, you can call a search engine by <search> query </search>. \ Based on the user's core intent, formulate the most effective search query using specific, descriptive keywords that differentiate the topic clearly. \ Aim for queries that resemble how an expert searcher might phrase it, like using "compare lithium-ion vs solid-state battery efficiency" rather than just "batteries". \ The document will be provided inside <information> and </information> tags to you later. \ You can search as many turns as you want, but only one search query per turn. \ If you find no further external knowledge needed, you can directly provide the answer inside <answer> and </answer>, without detailed illustrations. \ Only answer when you have 100% confidence in the search results, else continue searching. \ Question: {q}\n""" + return user_prompt + + +def format_search_results(results: str | list[str]) -> str: + """ + Format search results for display, matching the format from infer.py. + Each result should be in the format: "Doc X(Title: Y) content" + + Args: + results: Search results as string or list of strings + + Returns: + Formatted search results with document titles + """ + if isinstance(results, list): + # If results are already in the correct format, just join them + if any("Doc" in r and "Title:" in r for r in results): + content = "\n".join(results) + else: + # If results are raw content, format them with default titles + content = "\n".join([f"Doc {i + 1}(Title: Document {i + 1})\n{r}" for i, r in enumerate(results)]) + else: + # If single result is already formatted, use it as is + if "Doc" in results and "Title:" in results: + content = results + else: + # If single result is raw content, format it with default title + content = f"Doc 1(Title: Document 1)\n{results}" + + return content diff --git a/src/rewards.py b/src/rewards.py new file mode 100644 index 0000000..ea7b206 --- /dev/null +++ b/src/rewards.py @@ -0,0 +1,312 @@ +""" +Reward functions for RL training. +""" + +import re + +import numpy as np + +from src.config import logger +from src.evaluation import check_student_answers + + +def build_reward_correctness_fn( + vllm_generate_func, + tokenizer, +): + """Build a reward function that checks correctness of student answers. + + Args: + vllm_generate_func: Function to generate answers using vLLM + tokenizer: Tokenizer for the model + + Returns: + A reward function that takes prompts and completions and returns correctness scores + """ + + def reward_correctness(prompts: list, completions: list, **reward_kwargs) -> list: + """Calculate reward based on correctness of student answers.
+ + Args: + prompts: List of input prompts + completions: List of model completions + **reward_kwargs: Additional arguments for reward calculation + + Returns: + List of correctness scores between 0 and 1 + """ + teacher_answers = reward_kwargs["answer"] + student_answers = [completion["messages"][-1]["content"] for completion in completions] + + # Log non-exact matches + for i, (student, teacher) in enumerate(zip(student_answers, teacher_answers)): + if student.strip().lower() != teacher.strip().lower(): + logger.debug(f"Non-exact match at index {i}:\nStudent: {student}\nTeacher: {teacher}") + + correct = check_student_answers( + prompts, + teacher_answers, + student_answers, + vllm_generate_func=vllm_generate_func, + tokenizer=tokenizer, + ) + + # Log correctness metrics with length info + logger.info(f"Correctness metrics: {correct}") + logger.info(f"Average correctness: {np.mean(correct):.2f}") + logger.info(f"Standard deviation: {np.std(correct):.2f}") + + # Log length metrics + student_lengths = [len(ans.strip()) for ans in student_answers] + teacher_lengths = [len(ans.strip()) for ans in teacher_answers] + logger.info(f"Student lengths: {student_lengths}") + logger.info(f"Teacher lengths: {teacher_lengths}") + logger.info(f"Average student length: {np.mean(student_lengths):.2f}") + logger.info(f"Average teacher length: {np.mean(teacher_lengths):.2f}") + logger.info(f"Length ratio: {np.mean(student_lengths) / np.mean(teacher_lengths):.2f}") + + return correct + + return reward_correctness + + +def reward_format(prompts: list, completions: list, **reward_kwargs) -> list: + """Reward function that checks if the completion follows the required format with proper tags. + + Args: + prompts: List of input prompts + completions: List of completion dictionaries containing messages + **reward_kwargs: Additional reward parameters + + Returns: + list: List of rewards (1.0 for valid format, 0.0 for invalid) + """ + # Regex patterns for each tag type - no markdown allowed + think_pattern = r"<think>[\s\S]*?</think>" + search_pattern = r"<search>[\s\S]*?</search>" + answer_pattern = r"<answer>[\s\S]*?</answer>"
+ + # Information tag patterns - handle multiple variants + info_patterns = [ + r"<information>[\s\S]*?</information>", # Standard + r"<info>[\s\S]*?</info>", # Shortened + r"<Information>[\s\S]*?</Information>", # Capitalized variants + r"<INFORMATION>[\s\S]*?</INFORMATION>", # Uppercase + r"<INFO>[\s\S]*?</INFO>", # Uppercase shortened + ] + + # Invalid patterns (bold/italic tags) + invalid_patterns = [ + r"\*\*<\/?(?:think|search|answer|information|info)>\*\*", # Bold tags + r"\*<\/?(?:think|search|answer|information|info)>\*", # Italic tags + r"_<\/?(?:think|search|answer|information|info)>_", # Underscore italic + ] + + rewards = [] + + for completion in completions: + messages = completion.get("messages", []) + assistant_msgs = [msg["content"] for msg in messages if msg["role"] == "assistant"] + + if not assistant_msgs: + rewards.append(0.0) + continue + + content = assistant_msgs[-1] # Get the last assistant message + + # Check for invalid markdown formatting + has_invalid_tags = any(re.search(pattern, content) for pattern in invalid_patterns) + if has_invalid_tags: + logger.debug("Found markdown-formatted tags in response") + rewards.append(0.0) + continue + + # Check for any information tag variants (should not exist in assistant messages) + has_info_tags = False + for pattern in info_patterns: + info_matches = re.findall(pattern, content, re.IGNORECASE) + if info_matches: + logger.debug(f"Found {len(info_matches)} information tag(s) of type '{pattern}' in assistant message") + has_info_tags = True + break + + if has_info_tags: + rewards.append(0.0) + continue + + # Find all tag matches + think_matches = re.findall(think_pattern, content) + search_matches = re.findall(search_pattern, content) + answer_matches = re.findall(answer_pattern, content) + + # Verify tag presence and count + has_think = len(think_matches) >= 1 + has_answer = len(answer_matches) == 1 # Must have exactly one answer + has_search = len(search_matches) >= 1 # One or more search tags + + # Check for search and answer in the same message (not allowed) + if has_search and has_answer: + logger.debug("Found both search and answer tags in the same message") + rewards.append(0.0) + continue + + # Award reward - must have think tag and either answer or search (but not both) + reward = 1.0 if has_think and (has_answer or has_search) else 0.0 + rewards.append(reward) + + # Log issues for debugging + if not reward: + logger.debug(f"Format issues - think: {has_think}, answer: {has_answer}, search: {has_search}") + if search_matches: + logger.debug(f"Number of search tags: {len(search_matches)}") + + # Log overall metrics + logger.info(f"Format reward metrics - Mean: {np.mean(rewards):.3f}, Valid formats: {sum(rewards)}/{len(rewards)}") + + return rewards + + +# TODO: Implement this reward function if the project survives +def reward_long_query(completions, **kwargs): + """Reward function that checks if the query is long.""" + pass + + +def reward_retry(prompts: list, completions: list, **reward_kwargs) -> list: + """ + Reward function that encourages optimal retry behavior. + Rewards increase with more search attempts but cap at optimal_search_count. + Penalizes having multiple searches in a single message. + + Args: + prompts: List of input prompts + completions: List of completion dictionaries with messages + **reward_kwargs: Additional reward parameters (chunk_id, answer, etc.)
+ + Returns: + List of rewards for each completion, rounded to 3 decimal places + """ + rewards = [] + search_queries = [] + violations = [] + + # Config for retry rewards + optimal_search_count = 5 # Cap rewards at this many searches + base_reward = 0.2 # Base reward for having at least one search + increment = 0.15 # Reward increment per search attempt (0.2 + 5*0.15 = 0.95 max) + violation_penalty = 0.5 # Penalty for having multiple searches in one message + + # Regex pattern for search tags + search_pattern = r"<search>[\s\S]*?</search>" + + for completion in completions: + # Get assistant messages + assistant_messages = [msg["content"] for msg in completion["messages"] if msg["role"] == "assistant"] + + # Count search tags in assistant messages + message_searches = [] + for msg in assistant_messages: + # Find all search tags in each message + search_matches = re.findall(search_pattern, msg) + message_searches.append(len(search_matches)) + + # Record total search queries + total_searches = sum(message_searches) + search_queries.append(total_searches) + + # Check for violations (more than one search query per message) + violation = any(count > 1 for count in message_searches) + violations.append(violation) + + # Calculate reward + if total_searches == 0: + reward = 0.0 # No searches = no reward + else: + # Base reward for having at least one search + reward = base_reward + + # Add incremental reward for each search up to optimal_search_count + search_bonus = min(total_searches, optimal_search_count) * increment + reward += search_bonus + + # Cap reward at 1.0 + reward = min(1.0, reward) + + # Apply penalty if there's a violation + if violation: + reward *= 1 - violation_penalty + + # Round to 3 decimal places to avoid floating point precision issues + reward = round(reward, 3) + + rewards.append(reward) + + # Log metrics with search distribution info + logger.info(f"Retry behavior rewards: {np.mean(rewards):.3f} ± {np.std(rewards):.3f}") + logger.info(f"Search tags per completion: {np.mean(search_queries):.2f} ± {np.std(search_queries):.2f}") + logger.info(f"Violations (>1 search per message): {sum(violations)}/{len(violations)}") + logger.info(f"Search counts distribution: {search_queries}") + + return rewards + + +def reward_em_chunk(prompts: list, completions: list, **reward_kwargs) -> list: + """Reward function that checks if model's search queries hit the correct chunk content.
+ + Args: + prompts: List of input prompts + completions: List of completion dictionaries with messages + **reward_kwargs: Additional reward parameters including: + - chunk_content: List of correct chunk contents to match against + - step: Optional step number for logging metrics + + Returns: + list: List of rewards (1.0 for exact match, 0.0 otherwise) + + Raises: + ValueError: If chunk_content is not provided in reward_kwargs + """ + logger.debug(f"Calculating rewards for {len(prompts)} prompts") + + # Get correct chunk contents from reward kwargs + correct_contents = reward_kwargs.get("chunk_content", []) + if not correct_contents: + logger.error("No chunk_content provided in reward_kwargs") + raise ValueError("chunk_content must be provided in reward_kwargs") + + rewards = [] + for i, (completion, correct_content) in enumerate(zip(completions, correct_contents)): + # Get all messages from ipython or user roles that start with <information> + search_results = [ + msg["content"] + for msg in completion["messages"] + if msg["role"] in ("ipython", "user") and msg["content"].strip().startswith("<information>") + ] + logger.debug(f"Found {len(search_results)} search results for prompt {i}") + + # Log ground truth and searched chunks for debugging + logger.info(f"📝 Ground Truth Chunk: {correct_content}") + for j, result in enumerate(search_results): + logger.info(f"🔍 Searched Chunk {j + 1}: {result}") + + # Check if any search hit the correct chunk content + found_correct_chunk = any(correct_content in result for result in search_results) + + if not found_correct_chunk: + logger.warning( + f"Failed to find correct chunk for prompt {i}:\n" + f"Search results: {[r[:100] + '...' for r in search_results]}" + ) + + reward = 1.0 if found_correct_chunk else 0.0 + rewards.append(reward) + logger.debug(f"Reward for prompt {i}: {reward}") + + # Log summary metrics + logger.info("Chunk Query Rewards Summary:") + logger.info(f"Total prompts: {len(prompts)}") + logger.info(f"Correct matches: {sum(rewards)}") + logger.info(f"Average reward: {np.mean(rewards):.3f}") + logger.info(f"Reward std: {np.std(rewards):.3f}") + + return rewards diff --git a/src/rl_helpers.py b/src/rl_helpers.py deleted file mode 100644 index 7e4a8a7..0000000 --- a/src/rl_helpers.py +++ /dev/null @@ -1,806 +0,0 @@ -""" -RL helpers module for handling tool-based conversations. -This module provides utility functions for handling chat-based tool interactions -and calculating rewards based on the quality of responses. -""" - -import inspect -import re -from dataclasses import dataclass -from datetime import datetime - -import nest_asyncio -import numpy as np -import torch - -from src.config import log_metric, logger -from src.search_module import get_qa_dataset, search - -# Apply nest_asyncio for supporting async operations in notebooks -nest_asyncio.apply() - -from trl.trainer.grpo_trainer import apply_chat_template - - -# Constants for prompts and tool definitions -def get_system_prompt(): - """Get the system prompt with current date.""" - current_date = datetime.now().strftime("%d %b %Y") - return f"""Cutting Knowledge Date: December 2023 -Today Date: {current_date} - -When you receive a tool call response, use the output to format an answer to the original user question. - -You are a helpful assistant with tool calling capabilities. -""" - - -def build_user_prompt(q): - """ - Build a user prompt with the question using the new template format.
- - Args: - q (str): The question to ask - - Returns: - str: Formatted user prompt - """ - user_prompt = f"""Answer the given question. \ -You must conduct reasoning inside and first every time you get new information. \ -After reasoning, if you find you lack some knowledge, you can call a search engine by query . \ -You can search as many times as your want. \ -If you find no further external knowledge needed, you can directly provide the answer inside and , without detailed illustrations. For example, Beijing . - -IMPORTANT INSTRUCTIONS: -1. PLEASE CONSIDER CHAT HISTORY WHEN ANSWERING THE QUESTION. -2. ONLY ANSWER WHEN YOU HAVE 100% CONFIDENCE IN THE SEARCH RESULTS, ELSE CONTINUE SEARCHING. -3. PLEASE SEARCH MULTIPLE TIMES WITH DIFFERENT QUERIES. - -Question: {q}\n""" - return user_prompt - - -def format_search_results(results: str | list[str]) -> str: - """ - Format search results for display, matching the format from infer.py. - Each result should be in the format: "Doc X(Title: Y) content" - - Args: - results: Search results as string or list of strings - - Returns: - Formatted search results with document titles - """ - if isinstance(results, list): - # If results are already in the correct format, just join them - if any("Doc" in r and "Title:" in r for r in results): - content = "\n".join(results) - else: - # If results are raw content, format them with default titles - content = "\n".join([f"Doc {i + 1}(Title: Document {i + 1})\n{r}" for i, r in enumerate(results)]) - else: - # If single result is already formatted, use it as is - if "Doc" in results and "Title:" in results: - content = results - else: - # If single result is raw content, format it with default title - content = f"Doc 1(Title: Document 1)\n{results}" - - return content - - -def get_initial_chat(question): - """ - Initialize a chat state with the question. - - Args: - question (str): The question to ask - - Returns: - dict: Initial chat state with system and user messages - """ - return { - "messages": [ - {"role": "system", "content": get_system_prompt()}, - {"role": "user", "content": build_user_prompt(question)}, - ] - } - - -def remove_reasoning(text: str) -> str: - """ - Removes all content between and tags, - including the tags themselves. - - Parameters: - text (str): The input text that may contain ... tags. - - Returns: - str: The text with the tags and their content removed. - """ - # The regex pattern matches from to non-greedily. - pattern = r".*?" - cleaned_text = re.sub(pattern, "", text, flags=re.DOTALL) - return cleaned_text - - -def run_agent_generations(generate_fn, tokenizer, chat_states): - """ - Run generation for chat states requiring assistant responses. - """ - logger.debug(f"Starting generation for {len(chat_states)} chat states") - prompts = [] - batch_indices = [] - # Prepare prompts for chat states needing an assistant response. 
- for idx, chat_state in enumerate(chat_states): - if chat_state.get("finished"): - logger.debug(f"Chat state {idx} already finished, skipping") - continue - - if chat_state["messages"][-1]["role"] in ["ipython", "user"]: - prompt = apply_chat_template(chat_state, tokenizer=tokenizer)["text"] - prompts.append(prompt) - batch_indices.append(idx) - logger.debug(f"Added prompt for chat state {idx}") - - if prompts: - logger.info(f"Generating responses for {len(prompts)} prompts") - responses = generate_fn(prompts) - for i, idx in enumerate(batch_indices): - chat_state = chat_states[idx] - response = responses[i] - if hasattr(response, "outputs"): - full_response = response.outputs[0].text - else: - full_response = response - assistant_response = full_response.split("<|start_header_id|>assistant<|end_header_id|>")[-1] - chat_state["messages"].append({"role": "assistant", "content": assistant_response}) - logger.debug(f"Added assistant response to chat state {idx}") - else: - logger.debug("No prompts to generate responses for") - return chat_states - - -def check_finished_chats(chat_states): - """ - Check which chat states are finished (no more search queries). - - Args: - chat_states: List of chat states - - Returns: - list: Updated chat states with finished flag - """ - for chat_state in chat_states: - if chat_state.get("finished"): - continue - assert chat_state["messages"][-1]["role"] == "assistant", "Expected the last role to be assistant" - assistant_response = chat_state["messages"][-1]["content"] - # Check if there are any search queries in the response - if not re.search(r".*?", assistant_response, re.DOTALL): - chat_state["finished"] = True - return chat_states - - -def extract_search_query(text: str) -> str | None: - """ - Extract search query from text between tags. - - Args: - text (str): Text containing search query - - Returns: - str | None: Search query if found, None otherwise - """ - pattern = re.compile(r"(.*?)", re.DOTALL) - matches = pattern.findall(text) - return matches[-1] if matches else None - - -def run_tool_calls(chat_states): - """ - Execute tool calls found in chat states. 
-def run_tool_calls(chat_states):
-    """
-    Execute tool calls found in chat states.
-    """
-    logger.debug(f"Running tool calls for {len(chat_states)} chat states")
-    total_retries = 0
-
-    for chat_state in chat_states:
-        if chat_state.get("finished"):
-            logger.debug("Chat state already finished, skipping tool calls")
-            continue
-        assert chat_state["messages"][-1]["role"] == "assistant", (
-            "Expected the last role to be assistant to run tool calls"
-        )
-        try:
-            assistant_response = chat_state["messages"][-1]["content"]
-            search_query = extract_search_query(assistant_response)
-            if search_query:
-                logger.info(f"πŸ” Search Query: {search_query}")
-                results = search(search_query, return_type=str, results=2)
-                # Wrap results in <information> tags
-                formatted_results = f"<information>{results}</information>"
-                logger.info(f"ℹ️ Information: {formatted_results}")
-
-                chat_state["messages"].append({"role": "ipython", "content": formatted_results})
-                total_retries += 1
-                logger.debug("Added search results to chat state")
-        except Exception as e:
-            logger.error(f"Error during tool call: {str(e)}")
-            chat_state["messages"].append({"role": "system", "content": f"Error during post-processing: {str(e)}"})
-            chat_state["finished"] = True
-    return chat_states
-
-
-def get_mask(text, tokenizer):
-    encoding = tokenizer(text, add_special_tokens=False)
-    start_header_id = tokenizer.convert_tokens_to_ids("<|start_header_id|>")
-    assistant_token = tokenizer.convert_tokens_to_ids("assistant")
-    end_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>")
-    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
-    assistant_ranges = []
-    i = 0
-    while i < len(encoding.input_ids) - 1:
-        if encoding.input_ids[i] == start_header_id and encoding.input_ids[i + 1] == assistant_token:
-            i += 2
-            while i < len(encoding.input_ids) and encoding.input_ids[i] != end_header_id:
-                i += 1
-            i += 2
-            start_idx = i
-            while i < len(encoding.input_ids) and encoding.input_ids[i] != eot_id:
-                i += 1
-            end_idx = i
-            assistant_ranges.append((start_idx, end_idx))
-        else:
-            i += 1
-    mask = [0] * len(encoding.input_ids)
-    for start_idx, end_idx in assistant_ranges:
-        for idx in range(start_idx, end_idx):
-            mask[idx] = 1
-    return torch.tensor(mask, dtype=torch.int)
-
-
-def check_exceeded_max_new_tokens(chat_states, max_new_tokens, tokenizer):
-    for chat_state in chat_states:
-        if chat_state.get("finished"):
-            continue
-        initial_length = chat_state["initial_length"]
-        new_length = get_chat_num_tokens(chat_state, tokenizer)
-        if new_length - initial_length > max_new_tokens:
-            chat_state["finished"] = True
-    return chat_states
-
-
-@dataclass
-class AgenticOutputs:
-    prompt_tokens: list[torch.Tensor]
-    response_tokens: list[torch.Tensor]
-    response_masks: list[torch.Tensor]
-    final_response_str: list[str]
-    full_chat_states: list[dict]
-
-
-def get_chat_num_tokens(chat_state, tokenizer):
-    chat_text = apply_chat_template(chat_state, tokenizer=tokenizer)["text"]
-    return tokenizer(chat_text, add_special_tokens=False, return_tensors="pt")["input_ids"].squeeze().shape[0]
-
-
-def run_agent(
-    generate_fn,
-    tokenizer,
-    questions,
-    max_generations=5,
-    max_new_tokens=4096,
-    correct_contents=None,
-):
-    """
-    Run the agent to completion for a batch of questions.
- """ - logger.info(f"Starting agent run with {len(questions)} questions") - logger.debug(f"Max generations: {max_generations}, Max new tokens: {max_new_tokens}") - - chat_states = [get_initial_chat(q) for q in questions] - # Add correct content to chat states if provided - if correct_contents: - for chat_state, correct_content in zip(chat_states, correct_contents): - chat_state["correct_content"] = correct_content - - # set the initial_prompt length - for i, chat_state in enumerate(chat_states): - chat_state["initial_length"] = get_chat_num_tokens(chat_state, tokenizer) - logger.debug(f"Initial length for question {i}: {chat_state['initial_length']}") - - # agent loop - for i in range(max_generations): - logger.info(f"Starting generation step {i + 1}/{max_generations}") - chat_states = run_agent_generations(generate_fn, tokenizer, chat_states) - chat_states = check_finished_chats(chat_states) - chat_states = run_tool_calls(chat_states) - chat_states = check_exceeded_max_new_tokens(chat_states, max_new_tokens, tokenizer) - finished_count = sum(1 for state in chat_states if state.get("finished")) - logger.info(f"Finished {finished_count}/{len(chat_states)} chat states after step {i + 1}") - - logger.info("Agent run completed") - - # Process final outputs - logger.debug("Processing final outputs") - answers = [] - for chat in chat_states: - answers.append(chat["messages"][-1]["content"]) - logger.debug(f"Final answer: {chat['messages'][-1]['content'][:100]}...") - - def split_prompt_assistant(convo_text): - marker = "<|start_header_id|>assistant<|end_header_id|>" - idx = convo_text.find(marker) - if idx == -1: - logger.error("Could not find assistant marker in conversation text") - raise ValueError("Could not find assistant marker in conversation text.") - return convo_text, "" - prompt = convo_text[: idx + len(marker)] - assistant_response = convo_text[idx + len(marker) :] - return prompt, assistant_response - - str_chats = [apply_chat_template(chat, tokenizer=tokenizer)["text"] for chat in chat_states] - prompt_toks, response_toks, response_masks = [], [], [] - - logger.debug("Processing tokenization") - for i, str_chat in enumerate(str_chats): - prompt, response = split_prompt_assistant(str_chat) - prompt_toks.append(tokenizer(prompt, add_special_tokens=False, return_tensors="pt")["input_ids"].squeeze()) - response_toks.append( - tokenizer(response, add_special_tokens=False, return_tensors="pt")["input_ids"].squeeze()[:max_new_tokens] - ) - mask = get_mask(str_chat, tokenizer)[len(prompt_toks[-1]) :][:max_new_tokens] - response_masks.append(mask) - logger.debug(f"Processed tokens for chat {i}") - - final_response_str = [chat["messages"][-1]["content"] for chat in chat_states] - full_chat_states = chat_states - - logger.info("Agent run completed successfully") - return AgenticOutputs( - prompt_tokens=prompt_toks, - response_tokens=response_toks, - response_masks=response_masks, - final_response_str=final_response_str, - full_chat_states=full_chat_states, - ) - - -# Verification -async def verify(student_answer: str, question: str, answer: str) -> bool: - """ - Verify if student's answer matches the correct answer. 
- - Args: - student_answer: The model's answer - question: The original question - answer: The ground truth answer - - Returns: - bool: True if answer is correct, False otherwise - """ - logger.debug(f"Verifying answer for question: {question}") - logger.debug(f"Student answer: {student_answer}") - logger.debug(f"Correct answer: {answer}") - - # Simple string matching for now - # TODO: Implement more sophisticated matching - return student_answer.strip().lower() == answer.strip().lower() - - -def check_student_answers( - questions: list[str], - answers: list[str], - student_answers: list, # Can be strings or dicts - vllm_generate_func, - tokenizer, - log_file=None, -) -> list[bool]: - """ - Evaluates a list of student answers against the true answers using a vLLM generate function. - - Args: - questions: List of questions - answers: List of correct answers - student_answers: List of student answers to evaluate - vllm_generate_func: Function to generate verification responses - tokenizer: Tokenizer for formatting prompts - log_file: Optional path to write detailed results - - Returns: - List of boolean results (True for correct answers) - """ - logger.info(f"Checking {len(questions)} student answers") - - if not (len(questions) == len(answers) == len(student_answers)): - logger.error("Mismatched lengths between questions, answers, and student answers") - raise ValueError("The number of questions, answers, and student answers must be equal.") - - prompts = [] - for question, answer, student_ans in zip(questions, answers, student_answers): - prompt_text = ( - "You are grading a student's answer to a question. For the following question, " - "compare the student's answer to the correct answer. Reply with 'Yes' if the student's answer contains the correct information, " - "even if it's not an exact match. If the student's answer doesn't contain the right information or is completely incorrect, reply with 'No'.\n\n" - f"Question: {question}\n" - f"Correct Answer: {answer}\n" - f"Student Answer: {student_ans}\n\n" - "Your response should be just 'Yes' or 'No'." - ) - - formatted_prompt = tokenizer.apply_chat_template( - [{"role": "user", "content": prompt_text}], - tokenize=False, - add_generation_prompt=True, - ) - prompts.append(formatted_prompt) - logger.debug(f"Created verification prompt for question: {question[:50]}...") - - logger.info("Generating verification responses") - responses = vllm_generate_func(prompts) - responses_text = [] - for response in responses: - # Handle different response formats - if hasattr(response, "outputs"): - try: - responses_text.append(response.outputs[0].text) - except (AttributeError, IndexError): - # Fallback for simple string responses - responses_text.append(str(response)) - else: - responses_text.append(str(response)) - logger.debug(f"Got {len(responses_text)} verification responses") - - results = [] - for response in responses_text: - results.append("yes" in response.lower()) - logger.debug(f"Verification result: {'yes' in response.lower()}") - - logger.info(f"Verification complete. 
{sum(results)}/{len(results)} answers correct") - - # Append the QA details and verifier's response to the specified log file - if log_file: - with open(log_file, "a") as file: - timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - file.write(f"\nπŸ“ === QA Evaluation at {timestamp} ===\n") - file.write(f"πŸ“‚ File: {__file__}\n") - - # Get current frame info safely - frame = inspect.currentframe() - if frame: - file.write(f"πŸ“ Line: {frame.f_lineno}\n") - # Don't forget to delete the frame to avoid reference cycles - del frame - - file.write("=" * 80 + "\n") - - for i, (question, answer, student_ans, verifier_response) in enumerate( - zip(questions, answers, student_answers, responses_text) - ): - file.write(f"\n❓ Question {i + 1}:\n") - file.write("-" * 40 + "\n") - file.write(f"πŸ“‹ Question: {question}\n") - file.write(f"βœ… Correct Answer: {answer}\n") - file.write(f"πŸ‘¨β€πŸŽ“ Student Answer: {student_ans}\n") - file.write(f"πŸ” Verifier said: {verifier_response}\n") - - # Add search results if available in the chat state - if isinstance(student_ans, dict) and "messages" in student_ans: - # Get messages from dict - messages = student_ans.get("messages", []) - search_results = [msg.get("content", "") for msg in messages if msg.get("role") == "ipython"] - if search_results: - file.write("\nπŸ”Ž Search Results:\n") - for j, result in enumerate(search_results, 1): - file.write(f"\nSearch {j}:\n{result}\n") - - file.write("-" * 40 + "\n") - - file.write( - f"\nπŸ“Š Summary: {sum(results)}/{len(results)} answers correct ({sum(results) / len(results) * 100:.2f}%)\n" - ) - file.write("=" * 80 + "\n\n") - - return results - - -# Reward Functions -def build_reward_correctness_fn(generate_fn, tokenizer, log_file=None): - def reward_correctness(prompts, completions, **reward_kwargs): - teacher_answers = reward_kwargs["answer"] - student_answers = [completion["messages"][-1]["content"] for completion in completions] - - # Log non-exact matches - for i, (student, teacher) in enumerate(zip(student_answers, teacher_answers)): - if student.strip().lower() != teacher.strip().lower(): - logger.warning(f"Non-exact match at index {i}:\nStudent: {student}\nTeacher: {teacher}") - - correct = check_student_answers( - prompts, - teacher_answers, - student_answers, - vllm_generate_func=generate_fn, - tokenizer=tokenizer, - log_file=log_file, - ) - - # Log correctness metrics with length info - log_metric("rewards/correctness", np.mean(correct), reward_kwargs.get("step", 0)) - log_metric("rewards/correctness_std", np.std(correct), reward_kwargs.get("step", 0)) - - # Log length metrics - student_lengths = [len(ans.strip()) for ans in student_answers] - teacher_lengths = [len(ans.strip()) for ans in teacher_answers] - log_metric( - "metrics/avg_student_length", - np.mean(student_lengths), - reward_kwargs.get("step", 0), - ) - log_metric( - "metrics/avg_teacher_length", - np.mean(teacher_lengths), - reward_kwargs.get("step", 0), - ) - log_metric( - "metrics/length_ratio", - np.mean(student_lengths) / np.mean(teacher_lengths), - reward_kwargs.get("step", 0), - ) - - return correct - - return reward_correctness - - -def reward_formatting(prompts, completions, **reward_kwargs): - # make sure full chats doesn't have any error function calls - has_error = [False] * len(completions) - for i, chat in enumerate(completions): - for message in chat["messages"]: - if "Error during" in message["content"]: - has_error[i] = True - logger.warning(f"Error in chat {i}: {message['content']}") - break - - rewards = [0.7 
if not e else 0 for e in has_error]
-
-    # Log formatting metrics
-    log_metric("rewards/formatting", np.mean(rewards), reward_kwargs.get("step", 0))
-    log_metric("rewards/formatting_std", np.std(rewards), reward_kwargs.get("step", 0))
-    log_metric("metrics/error_rate", np.mean(has_error), reward_kwargs.get("step", 0))
-
-    return rewards
-
-
-def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[float]:
-    """
-    Reward function that encourages optimal retry behavior by only rewarding completions
-    where every assistant message contains at most 1 search query.
-    """
-    rewards: list[float] = []
-
-    for completion in completions:
-        # Get ALL assistant messages
-        assistant_msgs: list[str] = [
-            msg["content"]
-            for msg in completion["messages"]
-            if msg["role"] == "assistant" and msg["content"] is not None
-        ]
-
-        if not assistant_msgs:
-            rewards.append(0.0)
-            continue
-
-        # Check if every message has at most 1 search query
-        has_multiple_searches = False
-        total_searches = 0
-
-        for msg in assistant_msgs:
-            search_count = len(re.findall(r"<search>.*?</search>", msg, re.DOTALL))
-            total_searches += search_count
-
-            if search_count > 1:
-                has_multiple_searches = True
-                logger.warning(f"Message contains {search_count} search queries, which exceeds the limit of 1")
-                break
-
-        # Only reward if no message has multiple search queries
-        if has_multiple_searches:
-            rewards.append(0.0)
-        else:
-            # Base reward is 1.0 if constraint is met
-            base_reward = 1.0
-
-            # Slight penalty for having too many total searches across all messages
-            if total_searches > 4:
-                penalty = 0.1 * (total_searches - 4)
-                base_reward = max(0.2, base_reward - penalty)
-                logger.debug(f"Applied penalty for {total_searches} total searches: {penalty}")
-
-            rewards.append(base_reward)
-
-    # Log retry behavior metrics
-    log_metric("rewards/retry_behavior", np.mean(rewards), reward_kwargs.get("step", 0))
-    log_metric("rewards/retry_behavior_std", np.std(rewards), reward_kwargs.get("step", 0))
-    log_metric(
-        "metrics/avg_searches_per_msg",
-        np.mean(
-            [
-                len(re.findall(r"<search>.*?</search>", msg["content"], re.DOTALL))
-                for completion in completions
-                for msg in completion["messages"]
-                if msg["role"] == "assistant"
-            ]
-        ),
-        reward_kwargs.get("step", 0),
-    )
-    log_metric(
-        "metrics/multiple_search_violation_rate",
-        np.mean([0.0 if rewards[i] > 0.0 else 1.0 for i in range(len(rewards))]),
-        reward_kwargs.get("step", 0),
-    )
-
-    return rewards
-
-
-def reward_exact_match_chunk_query(prompts, completions, **reward_kwargs):
-    """
-    Reward function that checks if the model's search queries hit the correct chunk content.
- """ - logger.debug(f"Calculating rewards for {len(prompts)} prompts") - - # Get correct chunk contents from reward kwargs - correct_contents = reward_kwargs.get("chunk_content", []) - if not correct_contents: - logger.error("No chunk_content provided in reward_kwargs") - raise ValueError("chunk_content must be provided in reward_kwargs") - - rewards = [] - for i, (chat_state, correct_content) in enumerate(zip(completions, correct_contents)): - # Get all ipython messages (search results) from the chat - search_results = [msg["content"] for msg in chat_state["messages"] if msg["role"] == "ipython"] - logger.debug(f"Found {len(search_results)} search results for prompt {i}") - - # Log ground truth chunk and searched chunks - logger.info(f"πŸ“ Ground Truth Chunk: {correct_content}") - for j, result in enumerate(search_results): - logger.info(f"πŸ” Searched Chunk {j + 1}: {result}") - - # Check if any search hit the correct chunk content - found_correct_chunk = False - for result in search_results: - if correct_content in result: - found_correct_chunk = True - logger.debug(f"Found correct chunk content in search results for prompt {i}") - break - - if not found_correct_chunk: - logger.warning( - f"Failed to find correct chunk for prompt {i}:\n" - f"Search results: {[r[:100] + '...' for r in search_results]}" - ) - - reward = 1.0 if found_correct_chunk else 0.0 - rewards.append(reward) - logger.debug(f"Reward for prompt {i}: {reward}") - - # Log detailed metrics for debugging - log_metric( - f"debug/chunk_match_{i}", - 1 if found_correct_chunk else 0, - reward_kwargs.get("step", 0), - ) - log_metric( - f"debug/search_results_count_{i}", - len(search_results), - reward_kwargs.get("step", 0), - ) - if search_results: - log_metric( - f"debug/result_length_{i}", - np.mean([len(r.split()) for r in search_results]), - reward_kwargs.get("step", 0), - ) - - # Log chunk query metrics - log_metric("rewards/chunk_query", np.mean(rewards), reward_kwargs.get("step", 0)) - log_metric("rewards/chunk_query_std", np.std(rewards), reward_kwargs.get("step", 0)) - log_metric( - "metrics/avg_search_results", - np.mean( - [ - len([msg["content"] for msg in chat_state["messages"] if msg["role"] == "ipython"]) - for chat_state in completions - ] - ), - reward_kwargs.get("step", 0), - ) - log_metric("metrics/chunk_match_rate", np.mean(rewards), reward_kwargs.get("step", 0)) - - # Log detailed debugging info - logger.info("Chunk Query Rewards Summary:") - logger.info(f"Total prompts: {len(prompts)}") - logger.info(f"Correct matches: {sum(rewards)}") - logger.info(f"Average reward: {np.mean(rewards):.3f}") - logger.info(f"Reward std: {np.std(rewards):.3f}") - - return rewards - - -def run_eval(generate_fn, verify_fn, tokenizer, output_file=None, debug_file=None): - """ - Run evaluation on the test dataset and return results. 
- - Args: - generate_fn: Function to generate completions - verify_fn: Function to verify results - tokenizer: Tokenizer for processing text - output_file: Path to save evaluation results summary - debug_file: Path to save detailed debug information - - Returns: - full_chat_states: The chat states from evaluation - """ - train_dataset, test_dataset = get_qa_dataset() - questions = test_dataset["prompt"] - agentic_outputs = run_agent(generate_fn, tokenizer, questions) - full_chat_states = agentic_outputs.full_chat_states - final_responses = agentic_outputs.final_response_str - rewards = verify_fn(questions, full_chat_states, answer=test_dataset["answer"]) - - # Calculate results - percent_correct = sum(rewards) / len(rewards) * 100 - - # Log results to console - logger.info("RESULTS:") - logger.info(f"percentage of correct answers: {percent_correct:.2f}%") - logger.info("=" * 30) - - # Save results to file if specified - if output_file: - try: - with open(output_file, "w") as f: - f.write("EVALUATION RESULTS\n") - f.write("=================\n\n") - f.write(f"Total questions: {len(questions)}\n") - f.write(f"Correct answers: {sum(rewards)}\n") - f.write(f"Percentage correct: {percent_correct:.2f}%\n\n") - - f.write("Individual results:\n") - for i, (q, r, resp) in enumerate(zip(questions, rewards, final_responses)): - f.write(f"\nQ{i + 1}: {q[:100]}...\n") - f.write(f"Correct: {'βœ“' if r else 'βœ—'}\n") - f.write(f"Response: {resp[:150]}...\n") - f.write("-" * 40 + "\n") - logger.info(f"Saved evaluation results to {output_file}") - except Exception as e: - logger.error(f"Error saving results file: {e}") - - # Save debug information if specified - if debug_file: - try: - import json - - debug_data = [] - for i, (q, r, resp, chat) in enumerate(zip(questions, rewards, final_responses, full_chat_states)): - debug_data.append( - { - "question_id": i, - "question": q, - "is_correct": bool(r), - "final_response": resp, - "chat_state": { - k: str(v) if isinstance(v, (list, dict)) else v - for k, v in chat.items() - if k != "tokenizer" - }, - } - ) - - with open(debug_file, "w") as f: - json.dump(debug_data, f, indent=2) - logger.info(f"Saved debug information to {debug_file}") - except Exception as e: - logger.error(f"Error saving debug file: {e}") - - return full_chat_states diff --git a/src/search_module.py b/src/search_module.py index fe4991d..48aff90 100644 --- a/src/search_module.py +++ b/src/search_module.py @@ -7,7 +7,7 @@ import json import random from datasets import Dataset -from langchain.vectorstores import FAISS +from langchain_community.vectorstores import FAISS from src.config import DATA_DIR, logger from src.embeddings import CustomHuggingFaceEmbeddings diff --git a/src/tokenizer_adapter.py b/src/tokenizer_adapter.py new file mode 100644 index 0000000..d5b4bc8 --- /dev/null +++ b/src/tokenizer_adapter.py @@ -0,0 +1,306 @@ +""" +Tokenizer adapter implementations for different models. +This module provides adapter classes for handling different tokenizer formats. 
+""" + +from abc import ABC, abstractmethod + +import torch + +from src.config import logger + + +class TokenizerAdapter(ABC): + """Base class for tokenizer adapters.""" + + @abstractmethod + def get_assistant_marker(self) -> str: + """Get the assistant marker for the model.""" + pass + + @abstractmethod + def get_end_marker(self) -> str: + """Get the end marker for the model.""" + pass + + @abstractmethod + def get_mask(self, text: str, tokenizer) -> torch.Tensor: + """Get the mask for the model's response.""" + pass + + @abstractmethod + def split_prompt_assistant(self, text: str) -> tuple[str, str]: + """Split conversation text into prompt and assistant response.""" + pass + + +class LlamaTokenizerAdapter(TokenizerAdapter): + """Adapter for Llama model tokenizer.""" + + def get_assistant_marker(self) -> str: + """Get the assistant marker.""" + return "<|start_header_id|>assistant<|end_header_id|>" + + def get_end_marker(self) -> str: + """Get the end marker.""" + return "<|eot_id|>" + + def split_prompt_assistant(self, convo_text: str) -> tuple[str, str]: + """Split the text into prompt and assistant parts. + + Args: + convo_text: The text to split + + Returns: + A tuple of (prompt, assistant) + """ + # EXACT replication from rl_helpers.py but using existing method + marker = self.get_assistant_marker() # Use existing method but same value + idx = convo_text.find(marker) + if idx == -1: + raise ValueError("Could not find assistant marker in conversation text.") + return convo_text, "" + + # Include the marker in the prompt by slicing up to the end of the marker. + prompt = convo_text[: idx + len(marker)] + # The assistant response is everything after the marker. + assistant_response = convo_text[idx + len(marker) :] + return prompt, assistant_response + + def get_mask(self, text: str, tokenizer) -> torch.Tensor: + """Get the mask for the text. 
+ + Args: + text: The text to get the mask for + tokenizer: The tokenizer to use + + Returns: + A tensor of 0s and 1s where 1s indicate assistant tokens + """ + # Log input + logger.debug(f"πŸ” Llama: Full text length: {len(text)}") + + # EXACT replication from rl_helpers.py but using existing methods + encoding = tokenizer(text, add_special_tokens=False) + # Use existing methods but same values + start_header_id = tokenizer.convert_tokens_to_ids("<|start_header_id|>") + assistant_token = tokenizer.convert_tokens_to_ids("assistant") + end_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>") + eot_id = tokenizer.convert_tokens_to_ids(self.get_end_marker()) # Use existing method but same value + + # Log token IDs + logger.debug(f"πŸ” Llama: Tokenized length: {len(encoding.input_ids)}") + logger.debug(f"πŸ” Llama: Input IDs: {encoding.input_ids}") + logger.debug( + f"πŸ” Llama: Special token IDs: start={start_header_id}, assistant={assistant_token}, end={end_header_id}, eot={eot_id}" + ) + + assistant_ranges = [] + i = 0 + while i < len(encoding.input_ids) - 1: + if encoding.input_ids[i] == start_header_id and encoding.input_ids[i + 1] == assistant_token: + logger.debug(f"πŸ” Llama: Found assistant marker at position {i}") + logger.debug(f"πŸ” Llama: Assistant marker tokens: {encoding.input_ids[i : i + 2]}") + i += 2 + while i < len(encoding.input_ids) and encoding.input_ids[i] != end_header_id: + i += 1 + i += 2 + start_idx = i + logger.debug(f"πŸ” Llama: Found start of response at {start_idx}") + logger.debug(f"πŸ” Llama: Start token ID: {encoding.input_ids[start_idx]}") + while i < len(encoding.input_ids) and encoding.input_ids[i] != eot_id: + i += 1 + end_idx = i + logger.debug(f"πŸ” Llama: Found end of response at {end_idx}") + logger.debug(f"πŸ” Llama: End token ID: {encoding.input_ids[end_idx]}") + logger.debug(f"πŸ” Llama: Response token IDs: {encoding.input_ids[start_idx:end_idx]}") + assistant_ranges.append((start_idx, end_idx)) + else: + i += 1 + + mask = [0] * len(encoding.input_ids) + for start_idx, end_idx in assistant_ranges: + for idx in range(start_idx, end_idx): + mask[idx] = 1 + + mask = torch.tensor(mask, dtype=torch.int) + + # Log final mask + logger.debug(f"πŸ” Llama: Final mask shape: {mask.shape}") + logger.debug(f"πŸ” Llama: Mask sum: {mask.sum().item()}") + logger.debug(f"πŸ” Llama: Mask: {mask}") + + # Additional debug info + try: + prompt, response = self.split_prompt_assistant(text) + prompt_tokens = tokenizer(prompt, add_special_tokens=False).input_ids + response_tokens = tokenizer(response, add_special_tokens=False).input_ids + + logger.debug(f"πŸ” Llama: Prompt length: {len(prompt)}") + logger.debug(f"πŸ” Llama: Response length: {len(response)}") + logger.debug(f"πŸ” Llama: Prompt token IDs: {prompt_tokens}") + logger.debug(f"πŸ” Llama: Response token IDs: {response_tokens}") + logger.debug(f"πŸ” Llama: Prompt: {prompt[:100]}...") + logger.debug(f"πŸ” Llama: Response: {response[:100]}...") + logger.debug(f"πŸ” Llama: Full input IDs length: {len(encoding.input_ids)}") + logger.debug(f"πŸ” Llama: Prompt + Response token IDs length: {len(prompt_tokens) + len(response_tokens)}") + logger.debug( + f"πŸ” Llama: Difference in lengths: {len(encoding.input_ids) - (len(prompt_tokens) + len(response_tokens))}" + ) + except Exception as e: + logger.error(f"πŸ” Llama: Error splitting prompt/response: {e}") + + return mask + + +class R1DistilTokenizerAdapter(TokenizerAdapter): + """Adapter for R1-Distil model tokenizer.""" + + def 
get_assistant_marker(self) -> str: + marker = "<|Assistant|>" + return marker + + def get_end_marker(self) -> str: + marker = "<|end▁of▁sentence|>" + return marker + + def get_begin_marker(self) -> str: + return "<|begin▁of▁sentence|>" + + def get_user_marker(self) -> str: + return "<|User|>" + + def get_mask(self, text: str, tokenizer) -> torch.Tensor: + """Get the mask for the text. + + Args: + text: The text to get the mask for + tokenizer: The tokenizer to use + + Returns: + A tensor of 0s and 1s where 1s indicate assistant tokens + """ + logger.debug(f"πŸ” R1Distil: Getting mask for text length: {len(text)}") + + # Get all markers + assistant_marker = self.get_assistant_marker() + end_marker = self.get_end_marker() + + # Get the full tokenization + encoding = tokenizer(text, add_special_tokens=False) + tokens = encoding.input_ids + logger.debug(f"πŸ” R1Distil: Full text token IDs: {tokens}") + + # Create mask initialized to zeros - ENSURE SAME LENGTH AS INPUT_IDS + mask = torch.zeros(len(tokens), dtype=torch.int) + + # Get token IDs for markers + assistant_tokens = tokenizer(assistant_marker, add_special_tokens=False).input_ids + end_tokens = tokenizer(end_marker, add_special_tokens=False).input_ids + logger.debug(f"πŸ” R1Distil: Assistant marker token IDs: {assistant_tokens}") + logger.debug(f"πŸ” R1Distil: End marker token IDs: {end_tokens}") + + # Find all assistant responses + assistant_ranges = [] + i = 0 + while i < len(tokens): + # Look for assistant marker + if i + len(assistant_tokens) <= len(tokens) and tokens[i : i + len(assistant_tokens)] == assistant_tokens: + logger.debug(f"πŸ” R1Distil: Found assistant marker at position {i}") + + # Start masking AFTER the assistant marker + start_idx = i + len(assistant_tokens) + + # Find end marker + end_idx = None + j = start_idx + while j < len(tokens): + if j + len(end_tokens) <= len(tokens) and tokens[j : j + len(end_tokens)] == end_tokens: + end_idx = j # Don't include the end marker in the mask + break + j += 1 + + if end_idx is None: + # If no end marker found, mask until the end + end_idx = len(tokens) + + logger.debug(f"πŸ” R1Distil: Response range: {start_idx} to {end_idx}") + assistant_ranges.append((start_idx, end_idx)) + i = end_idx + len(end_tokens) # Start next search after the end marker + else: + i += 1 + + # Apply mask for all found ranges + for start_idx, end_idx in assistant_ranges: + mask[start_idx:end_idx] = 1 + + logger.debug(f"πŸ” R1Distil: Found {len(assistant_ranges)} assistant responses") + logger.debug(f"πŸ” R1Distil: Final mask sum: {mask.sum().item()}") + logger.debug(f"πŸ” R1Distil: Final mask length: {len(mask)}") + logger.debug(f"πŸ” R1Distil: Mask: {mask}") + + return mask + + def split_prompt_assistant(self, text: str) -> tuple[str, str]: + """Split the text into prompt and assistant parts. 
+
+        Args:
+            text: The text to split
+
+        Returns:
+            A tuple of (prompt, assistant)
+        """
+        logger.debug(f"πŸ” R1Distil: Splitting text of length: {len(text)}")
+
+        # Find the assistant marker
+        marker = self.get_assistant_marker()
+        end_marker = self.get_end_marker()
+
+        # Find ALL assistant markers in the text
+        assistant_markers = []
+        pos = 0
+        while True:
+            pos = text.find(marker, pos)
+            if pos == -1:
+                break
+            assistant_markers.append(pos)
+            pos += len(marker)
+
+        if not assistant_markers:
+            raise ValueError("Could not find assistant marker in text")
+
+        # Get the positions of all markers for later use
+        marker_positions = []
+        for start_pos in assistant_markers:
+            response_start = start_pos + len(marker)
+
+            # Find the end marker after this response
+            end_pos = text.find(end_marker, response_start)
+            if end_pos == -1:
+                end_pos = len(text)
+            else:
+                end_pos = end_pos + len(end_marker)
+
+            marker_positions.append((start_pos, response_start, end_pos))
+
+        # Get the full response (all assistant outputs concatenated)
+        full_response = ""
+        for _, resp_start, resp_end in marker_positions:
+            full_response += text[resp_start:resp_end]
+
+        # Include ALL assistant markers and responses in the response
+        # This matches how the mask is generated in get_mask
+        first_assistant_pos = marker_positions[0][0]
+        last_response_end = marker_positions[-1][2]
+
+        # Split into prompt and response
+        prompt = text[:first_assistant_pos]  # Everything before the first assistant marker
+        response = text[first_assistant_pos:last_response_end]  # All markers and responses
+
+        logger.debug(f"πŸ” R1Distil: Prompt length: {len(prompt)}")
+        logger.debug(f"πŸ” R1Distil: Response length: {len(response)}")
+        logger.debug(f"πŸ” R1Distil: Response token count estimate: {len(response) / 4}")  # Rough estimate
+        logger.debug(f"πŸ” R1Distil: Final prompt: {prompt[:100]}...")
+        logger.debug(f"πŸ” R1Distil: Final response: {response[:100]}...")
+
+        return prompt, response
diff --git a/train_grpo.py b/train_grpo.py index e29fdae..37d68c4 100644 --- a/train_grpo.py +++ b/train_grpo.py @@ -1,8 +1,16 @@ +""" +Train a model using GRPO (Group Relative Policy Optimization). 
+""" + import os from unsloth import FastLanguageModel, is_bfloat16_supported import src.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp + +# Import reward functions +from src import build_reward_correctness_fn, get_qa_dataset, reward_em_chunk, reward_format, reward_retry +from src.agent import Agent from src.config import ( MODEL_CONFIG, MODEL_NAME, @@ -13,16 +21,9 @@ from src.config import ( logger, update_log_path, ) - -# Import reward functions -from src.rl_helpers import ( - build_reward_correctness_fn, - get_qa_dataset, - reward_exact_match_chunk_query, - reward_formatting, - reward_retry_behavior, - run_agent, -) +from src.rewards import build_reward_correctness_fn, reward_em_chunk, reward_retry +from src.search_module import get_qa_dataset +from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter # Initialize training directories paths = init_training_dirs() @@ -78,7 +79,17 @@ def agentic_generate( generate_fn, max_generations: int = 10, ): - return run_agent(generate_fn, tokenizer, prompts, max_generations) + # Create agent with appropriate adapter based on tokenizer + tokenizer_name = tokenizer.name_or_path.lower() + if "deepseek-r1-distill" in tokenizer_name: + adapter = R1DistilTokenizerAdapter() + elif "llama" in tokenizer_name: + adapter = LlamaTokenizerAdapter() + else: + adapter = R1DistilTokenizerAdapter() + + agent = Agent(adapter) + return agent.run_agent(generate_fn, tokenizer, prompts, max_generations) model.agentic_generate = agentic_generate @@ -102,13 +113,12 @@ trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer( processing_class=tokenizer, reward_funcs=[ build_reward_correctness_fn( - verifier_generate_fn, - tokenizer, - log_file=os.path.join(paths["log_dir"], "qa_log.txt"), + vllm_generate_func=verifier_generate_fn, + tokenizer=tokenizer, ), - reward_formatting, - reward_retry_behavior, - reward_exact_match_chunk_query, + reward_format, + reward_retry, + reward_em_chunk, ], args=training_args, train_dataset=train_dataset,
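To make the adapter contract above concrete, a small usage sketch follows. It mirrors the selection logic in agentic_generate; the model id and chat text are assumptions for illustration and are not taken from this repository:

from transformers import AutoTokenizer

from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter

# Pick the adapter the same way agentic_generate in train_grpo.py does.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")  # assumed model id
name = tokenizer.name_or_path.lower()
if "deepseek-r1-distill" in name:
    adapter = R1DistilTokenizerAdapter()
elif "llama" in name:
    adapter = LlamaTokenizerAdapter()
else:
    adapter = R1DistilTokenizerAdapter()

# Illustrative chat text in the Llama header format the adapter expects.
chat_text = (
    "<|start_header_id|>user<|end_header_id|>\n\nWho wrote Dune?<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n<answer>Frank Herbert</answer><|eot_id|>"
)
prompt, response = adapter.split_prompt_assistant(chat_text)
mask = adapter.get_mask(chat_text, tokenizer)  # 1s over the assistant's response tokens only
print(len(prompt), len(response), int(mask.sum()))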