From af7f38c792b537e97a95fe40bb2efcb3b2942779 Mon Sep 17 00:00:00 2001 From: thinhlpg Date: Thu, 3 Apr 2025 14:48:44 +0700 Subject: [PATCH] feat: add code for qwen architecture --- notebooks/250402_inspect_mask.ipynb | 85 +++++++++++++++++++++ src/tokenizer_adapter.py | 112 ++++++++++++++++++++++++++++ train_grpo.py | 6 +- 3 files changed, 201 insertions(+), 2 deletions(-) diff --git a/notebooks/250402_inspect_mask.ipynb b/notebooks/250402_inspect_mask.ipynb index 29a7788..b5c0994 100644 --- a/notebooks/250402_inspect_mask.ipynb +++ b/notebooks/250402_inspect_mask.ipynb @@ -279,6 +279,91 @@ "masked_df = df[df[\"Mask\"] == 1]\n", "print(masked_df.to_string(index=False))" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Inspect Qwen 2.5 Instruct" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Imports\n", + "import sys\n", + "\n", + "sys.path.append(\"..\")\n", + "\n", + "from transformers import AutoTokenizer\n", + "from src.tokenizer_adapter import QwenTokenizerAdapter\n", + "import pandas as pd\n", + "\n", + "pd.set_option(\"display.max_rows\", None)\n", + "pd.set_option(\"display.max_colwidth\", None)\n", + "\n", + "# Initialize\n", + "tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-1.5B-Instruct\")\n", + "adapter = QwenTokenizerAdapter()\n", + "\n", + "# Example conversation using chat template\n", + "chat = [\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": \"You are a friendly chatbot who always responds in the style of a pirate\",\n", + " },\n", + " {\"role\": \"user\", \"content\": \"Hello, how are you?\"},\n", + " {\"role\": \"assistant\", \"content\": \"I'm doing great. 
How can I help you today?\"},\n", + " {\"role\": \"ipython\", \"content\": \"THIS IS THE DOCUMENT!!!\"}, # this shit doesn't work in chat template\n", + " {\"role\": \"user\", \"content\": \"Hello, have you eanten?\"},\n", + " {\"role\": \"assistant\", \"content\": \"No I'm hungry?\"},\n", + "]\n", + "\n", + "# Get the formatted conversation using chat template\n", + "convo = tokenizer.apply_chat_template(chat, tokenize=False)\n", + "print(\"šŸ’¬ Raw Chat Template Output:\")\n", + "print(f\"{'-' * 50}\\n{convo}\\n{'-' * 50}\\n\")\n", + "\n", + "# 1. Show text splitting\n", + "prompt, response = adapter.split_prompt_assistant(convo)\n", + "print(\"šŸ” Text Split:\")\n", + "print(f\"Prompt:\\n{'-' * 50}\\n{prompt}\\n{'-' * 50}\")\n", + "print(f\"Response:\\n{'-' * 50}\\n{response}\\n{'-' * 50}\\n\")\n", + "\n", + "# 2. Get tokens and mask\n", + "encoding = tokenizer(convo, add_special_tokens=False)\n", + "input_ids = encoding.input_ids\n", + "tokens = tokenizer.convert_ids_to_tokens(input_ids)\n", + "mask = adapter.get_mask(convo, tokenizer)\n", + "\n", + "# 3. Create detailed view\n", + "df = pd.DataFrame(\n", + " {\n", + " \"Position\": range(len(tokens)),\n", + " \"Token ID\": input_ids,\n", + " \"Token\": tokens,\n", + " \"Text\": [tokenizer.decode([id]) for id in input_ids],\n", + " \"Mask\": mask.tolist(),\n", + " }\n", + ")\n", + "\n", + "print(\"šŸ“Š Token Analysis:\")\n", + "print(df.to_string(index=False))\n", + "\n", + "# 4. Quick Stats\n", + "print(\"\\nšŸ“ˆ Quick Stats:\")\n", + "print(f\"Total tokens: {len(tokens)}\")\n", + "print(f\"Masked tokens (1s): {mask.sum().item()}\")\n", + "print(f\"Unmasked tokens (0s): {len(mask) - mask.sum().item()}\")\n", + "\n", + "# 5. 
class QwenTokenizerAdapter(TokenizerAdapter):
    """Adapter for the Qwen2.5 model tokenizer.

    Assistant turns are recognised by the ``<|im_start|>assistant`` marker and
    are considered finished at the next ``<|im_end|>`` marker.
    """

    def get_assistant_marker(self) -> str:
        """Return the text that opens an assistant turn."""
        return "<|im_start|>assistant"

    def get_end_marker(self) -> str:
        """Return the text that closes a turn."""
        return "<|im_end|>"

    def split_prompt_assistant(self, convo_text: str) -> tuple[str, str]:
        """Split the conversation at the first assistant marker.

        Args:
            convo_text: The full conversation text to split.

        Returns:
            A tuple ``(prompt, assistant)``. ``prompt`` ends with the assistant
            marker itself; ``assistant`` is everything that follows it.

        Raises:
            ValueError: If the assistant marker does not occur in the text.
        """
        marker = self.get_assistant_marker()
        idx = convo_text.find(marker)
        if idx == -1:
            raise ValueError("Could not find assistant marker in conversation text.")

        # Keep the marker on the prompt side of the split.
        split_at = idx + len(marker)
        return convo_text[:split_at], convo_text[split_at:]

    def get_mask(self, text: str, tokenizer) -> torch.Tensor:
        """Build a per-token mask that marks assistant response tokens.

        Every ``<|im_start|>`` token immediately followed by the ``assistant``
        token opens a response. All tokens after that pair, up to but not
        including the next ``<|im_end|>`` token, are marked with 1. A response
        with no closing ``<|im_end|>`` is marked through to the end of the
        input.

        Args:
            text: The conversation text to tokenize and mask.
            tokenizer: Tokenizer used to encode ``text`` and to look up the
                marker token IDs. NOTE(review): assumes ``assistant`` maps to a
                single vocabulary token - confirm for the tokenizer in use.

        Returns:
            An integer tensor with one entry per token of ``text``; 1 marks an
            assistant token, 0 marks everything else.
        """
        logger.debug(f"🔍 Qwen: Full text length: {len(text)}")

        input_ids = tokenizer(text, add_special_tokens=False).input_ids
        num_tokens = len(input_ids)

        # Token IDs for the markers that delimit an assistant turn.
        im_start = tokenizer.convert_tokens_to_ids("<|im_start|>")
        assistant_token = tokenizer.convert_tokens_to_ids("assistant")
        im_end = tokenizer.convert_tokens_to_ids(self.get_end_marker())

        logger.debug(f"🔍 Qwen: Tokenized length: {num_tokens}")
        logger.debug(f"🔍 Qwen: Input IDs: {input_ids}")
        logger.debug(f"🔍 Qwen: Special token IDs: im_start={im_start}, assistant={assistant_token}, im_end={im_end}")

        mask = [0] * num_tokens
        i = 0
        while i < num_tokens - 1:
            if input_ids[i] != im_start or input_ids[i + 1] != assistant_token:
                i += 1
                continue

            logger.debug(f"🔍 Qwen: Found assistant marker at position {i}")
            logger.debug(f"🔍 Qwen: Assistant marker tokens: {input_ids[i : i + 2]}")
            i += 2  # Skip past <|im_start|>assistant
            start_idx = i
            logger.debug(f"🔍 Qwen: Found start of response at {start_idx}")
            # The marker may be the very last thing in the input (nothing
            # generated yet), so only index when a token is actually there.
            if start_idx < num_tokens:
                logger.debug(f"🔍 Qwen: Start token ID: {input_ids[start_idx]}")

            while i < num_tokens and input_ids[i] != im_end:
                i += 1
            end_idx = i
            logger.debug(f"🔍 Qwen: Found end of response at {end_idx}")
            # end_idx equals num_tokens when the response was never closed
            # (e.g. truncated generation); indexing there would raise.
            if end_idx < num_tokens:
                logger.debug(f"🔍 Qwen: End token ID: {input_ids[end_idx]}")
            else:
                logger.debug("🔍 Qwen: No end marker found; response runs to end of input")
            logger.debug(f"🔍 Qwen: Response token IDs: {input_ids[start_idx:end_idx]}")

            mask[start_idx:end_idx] = [1] * (end_idx - start_idx)

        mask = torch.tensor(mask, dtype=torch.int)

        logger.debug(f"🔍 Qwen: Final mask shape: {mask.shape}")
        logger.debug(f"🔍 Qwen: Mask sum: {mask.sum().item()}")
        logger.debug(f"🔍 Qwen: Mask: {mask}")

        self._log_split_diagnostics(text, tokenizer, num_tokens)

        return mask

    def _log_split_diagnostics(self, text: str, tokenizer, total_tokens: int) -> None:
        """Log how the text-level prompt/response split compares to the full encoding.

        Purely diagnostic: any failure is logged and never propagated, so it
        cannot affect the mask returned by ``get_mask``.
        """
        try:
            prompt, response = self.split_prompt_assistant(text)
            prompt_tokens = tokenizer(prompt, add_special_tokens=False).input_ids
            response_tokens = tokenizer(response, add_special_tokens=False).input_ids
            split_total = len(prompt_tokens) + len(response_tokens)

            logger.debug(f"🔍 Qwen: Prompt length: {len(prompt)}")
            logger.debug(f"🔍 Qwen: Response length: {len(response)}")
            logger.debug(f"🔍 Qwen: Prompt token IDs: {prompt_tokens}")
            logger.debug(f"🔍 Qwen: Response token IDs: {response_tokens}")
            logger.debug(f"🔍 Qwen: Prompt: {prompt[:100]}...")
            logger.debug(f"🔍 Qwen: Response: {response[:100]}...")
            logger.debug(f"🔍 Qwen: Full input IDs length: {total_tokens}")
            logger.debug(f"🔍 Qwen: Prompt + Response token IDs length: {split_total}")
            logger.debug(f"🔍 Qwen: Difference in lengths: {total_tokens - split_total}")
        except Exception as e:  # best-effort diagnostics only
            logger.error(f"🔍 Qwen: Error splitting prompt/response: {e}")
return agent.run_agent(generate_fn, tokenizer, prompts, max_generations)