|
|
|
@ -304,3 +304,115 @@ class R1DistilTokenizerAdapter(TokenizerAdapter):
|
|
|
|
|
logger.debug(f"🔍 R1Distil: Final response: {response[:100]}...")
|
|
|
|
|
|
|
|
|
|
return prompt, response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class QwenTokenizerAdapter(TokenizerAdapter):
    """Adapter for Qwen2.5 model tokenizer."""

    def get_assistant_marker(self) -> str:
        """Get the assistant marker."""
        return "<|im_start|>assistant"

    def get_end_marker(self) -> str:
        """Get the end marker."""
        return "<|im_end|>"

    def split_prompt_assistant(self, convo_text: str) -> tuple[str, str]:
        """Split the text into prompt and assistant parts.

        The split happens at the first occurrence of the assistant marker.
        The marker itself is kept at the end of the prompt part.

        Args:
            convo_text: The text to split

        Returns:
            A tuple of (prompt, assistant)

        Raises:
            ValueError: If the assistant marker does not occur in the text.
        """
        marker = self.get_assistant_marker()
        idx = convo_text.find(marker)
        if idx == -1:
            raise ValueError("Could not find assistant marker in conversation text.")

        # Include the marker in the prompt by slicing up to the end of the marker
        split_at = idx + len(marker)
        prompt = convo_text[:split_at]
        # The assistant response is everything after the marker
        assistant_response = convo_text[split_at:]
        return prompt, assistant_response

    def get_mask(self, text: str, tokenizer) -> torch.Tensor:
        """Get the mask for the text.

        The text is tokenized without special tokens and scanned for the
        two-token sequence ``<|im_start|>`` followed by ``assistant``. Every
        token after that pair, up to but excluding the next end marker (or the
        end of the input when no end marker follows), is marked with 1. Any
        token that directly follows the marker pair is therefore part of the
        masked span.

        Args:
            text: The text to get the mask for
            tokenizer: The tokenizer to use. It is called as
                ``tokenizer(text, add_special_tokens=False)`` and must return an
                object exposing ``input_ids``; it must also provide
                ``convert_tokens_to_ids``. (Presumably a Hugging Face
                tokenizer -- confirm against callers.)

        Returns:
            A tensor of 0s and 1s where 1s indicate assistant tokens
        """
        # Log input
        logger.debug(f"🔍 Qwen: Full text length: {len(text)}")

        encoding = tokenizer(text, add_special_tokens=False)
        input_ids = encoding.input_ids
        num_tokens = len(input_ids)

        # Get token IDs for markers
        im_start = tokenizer.convert_tokens_to_ids("<|im_start|>")
        assistant_token = tokenizer.convert_tokens_to_ids("assistant")
        im_end = tokenizer.convert_tokens_to_ids(self.get_end_marker())

        # Log token IDs
        logger.debug(f"🔍 Qwen: Tokenized length: {num_tokens}")
        logger.debug(f"🔍 Qwen: Input IDs: {input_ids}")
        logger.debug(f"🔍 Qwen: Special token IDs: im_start={im_start}, assistant={assistant_token}, im_end={im_end}")

        assistant_ranges = []
        i = 0
        # Stop one short of the end: the marker needs two consecutive tokens.
        while i < num_tokens - 1:
            if input_ids[i] == im_start and input_ids[i + 1] == assistant_token:
                logger.debug(f"🔍 Qwen: Found assistant marker at position {i}")
                logger.debug(f"🔍 Qwen: Assistant marker tokens: {input_ids[i : i + 2]}")
                i += 2  # Skip past <|im_start|>assistant
                start_idx = i
                logger.debug(f"🔍 Qwen: Found start of response at {start_idx}")
                # The marker may be the last two tokens, in which case there is
                # no start token to report; indexing would raise IndexError.
                if start_idx < num_tokens:
                    logger.debug(f"🔍 Qwen: Start token ID: {input_ids[start_idx]}")

                while i < num_tokens and input_ids[i] != im_end:
                    i += 1
                end_idx = i
                logger.debug(f"🔍 Qwen: Found end of response at {end_idx}")
                # A response without a closing end marker runs to the end of
                # the input, so end_idx can equal num_tokens here.
                if end_idx < num_tokens:
                    logger.debug(f"🔍 Qwen: End token ID: {input_ids[end_idx]}")
                else:
                    logger.debug("🔍 Qwen: No end marker found; response runs to end of input")
                logger.debug(f"🔍 Qwen: Response token IDs: {input_ids[start_idx:end_idx]}")
                assistant_ranges.append((start_idx, end_idx))
            else:
                i += 1

        # Build the mask directly as a tensor; each range is half-open.
        mask = torch.zeros(num_tokens, dtype=torch.int)
        for start_idx, end_idx in assistant_ranges:
            mask[start_idx:end_idx] = 1

        # Log final mask
        logger.debug(f"🔍 Qwen: Final mask shape: {mask.shape}")
        logger.debug(f"🔍 Qwen: Mask sum: {mask.sum().item()}")
        logger.debug(f"🔍 Qwen: Mask: {mask}")

        # Additional debug info. This is best-effort diagnostics only: a
        # failure here (e.g. no assistant marker in the text) is logged and
        # must never prevent the already computed mask from being returned.
        try:
            prompt, response = self.split_prompt_assistant(text)
            prompt_tokens = tokenizer(prompt, add_special_tokens=False).input_ids
            response_tokens = tokenizer(response, add_special_tokens=False).input_ids

            logger.debug(f"🔍 Qwen: Prompt length: {len(prompt)}")
            logger.debug(f"🔍 Qwen: Response length: {len(response)}")
            logger.debug(f"🔍 Qwen: Prompt token IDs: {prompt_tokens}")
            logger.debug(f"🔍 Qwen: Response token IDs: {response_tokens}")
            logger.debug(f"🔍 Qwen: Prompt: {prompt[:100]}...")
            logger.debug(f"🔍 Qwen: Response: {response[:100]}...")
            logger.debug(f"🔍 Qwen: Full input IDs length: {num_tokens}")
            logger.debug(f"🔍 Qwen: Prompt + Response token IDs length: {len(prompt_tokens) + len(response_tokens)}")
            logger.debug(
                f"🔍 Qwen: Difference in lengths: {num_tokens - (len(prompt_tokens) + len(response_tokens))}"
            )
        except Exception as e:
            logger.error(f"🔍 Qwen: Error splitting prompt/response: {e}")

        return mask
|
|
|
|
|