"""
Tokenizer adapter implementations for different models.
This module provides adapter classes for handling different tokenizer formats.
"""
from abc import ABC, abstractmethod
import torch
from config import logger


class TokenizerAdapter(ABC):
    """Base class for tokenizer adapters."""

    @abstractmethod
    def get_assistant_marker(self) -> str:
        """Get the assistant marker for the model."""
        pass

    @abstractmethod
    def get_end_marker(self) -> str:
        """Get the end marker for the model."""
        pass

    @abstractmethod
    def get_mask(self, text: str, tokenizer) -> torch.Tensor:
        """Get the mask for the model's response."""
        pass

    @abstractmethod
    def split_prompt_assistant(self, text: str) -> tuple[str, str]:
        """Split conversation text into prompt and assistant response."""
        pass
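

# A minimal dispatch sketch (added for illustration; not part of the original
# module). The name substrings are assumptions -- adjust them to match your
# actual checkpoint identifiers.
def get_adapter_for_model(model_name: str) -> TokenizerAdapter:
    """Pick the adapter matching a model name (illustrative helper)."""
    name = model_name.lower()
    if "llama" in name:
        return LlamaTokenizerAdapter()
    if "r1" in name or "distil" in name:
        return R1DistilTokenizerAdapter()
    if "qwen" in name:
        return QwenTokenizerAdapter()
    raise ValueError(f"No tokenizer adapter registered for: {model_name}")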


class LlamaTokenizerAdapter(TokenizerAdapter):
    """Adapter for Llama model tokenizer."""

    def get_assistant_marker(self) -> str:
        """Get the assistant marker."""
        return "<|start_header_id|>assistant<|end_header_id|>"

    def get_end_marker(self) -> str:
        """Get the end marker."""
        return "<|eot_id|>"

    def split_prompt_assistant(self, convo_text: str) -> tuple[str, str]:
        """Split the text into prompt and assistant parts.

        Args:
            convo_text: The text to split

        Returns:
            A tuple of (prompt, assistant)
        """
        # Mirrors the original rl_helpers.py logic, using the adapter's marker.
        marker = self.get_assistant_marker()
        idx = convo_text.find(marker)
        if idx == -1:
            raise ValueError("Could not find assistant marker in conversation text.")
        # Include the marker in the prompt by slicing up to the end of the marker.
        prompt = convo_text[: idx + len(marker)]
        # The assistant response is everything after the marker.
        assistant_response = convo_text[idx + len(marker) :]
        return prompt, assistant_response
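
    # Example (illustrative, assuming the Llama 3 chat layout): given
    #   "...<|start_header_id|>assistant<|end_header_id|>\n\nHello!<|eot_id|>"
    # the prompt ends with the assistant marker and the response is
    # "\n\nHello!<|eot_id|>" -- the end marker stays in the response.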

    def get_mask(self, text: str, tokenizer) -> torch.Tensor:
        """Get the mask for the text.

        Args:
            text: The text to get the mask for
            tokenizer: The tokenizer to use

        Returns:
            A tensor of 0s and 1s where 1s indicate assistant tokens
        """
        logger.debug(f"🔍 Llama: Full text length: {len(text)}")

        # Mirrors the original rl_helpers.py logic, using the adapter's markers.
        encoding = tokenizer(text, add_special_tokens=False)
        start_header_id = tokenizer.convert_tokens_to_ids("<|start_header_id|>")
        assistant_token = tokenizer.convert_tokens_to_ids("assistant")
        end_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>")
        eot_id = tokenizer.convert_tokens_to_ids(self.get_end_marker())

        logger.debug(f"🔍 Llama: Tokenized length: {len(encoding.input_ids)}")
        logger.debug(f"🔍 Llama: Input IDs: {encoding.input_ids}")
        logger.debug(
            f"🔍 Llama: Special token IDs: start={start_header_id}, assistant={assistant_token}, "
            f"end={end_header_id}, eot={eot_id}"
        )

        assistant_ranges = []
        i = 0
        while i < len(encoding.input_ids) - 1:
            if encoding.input_ids[i] == start_header_id and encoding.input_ids[i + 1] == assistant_token:
                logger.debug(f"🔍 Llama: Found assistant marker at position {i}")
                logger.debug(f"🔍 Llama: Assistant marker tokens: {encoding.input_ids[i : i + 2]}")
                # Skip "<|start_header_id|>assistant", then advance past
                # "<|end_header_id|>" and the newline token that follows it.
                i += 2
                while i < len(encoding.input_ids) and encoding.input_ids[i] != end_header_id:
                    i += 1
                i += 2
                start_idx = i
                logger.debug(f"🔍 Llama: Found start of response at {start_idx}")
                # Slice rather than index so we don't raise IndexError when the
                # response runs to the end of the sequence.
                logger.debug(f"🔍 Llama: Start token ID: {encoding.input_ids[start_idx : start_idx + 1]}")
                while i < len(encoding.input_ids) and encoding.input_ids[i] != eot_id:
                    i += 1
                end_idx = i
                logger.debug(f"🔍 Llama: Found end of response at {end_idx}")
                logger.debug(f"🔍 Llama: End token ID: {encoding.input_ids[end_idx : end_idx + 1]}")
                logger.debug(f"🔍 Llama: Response token IDs: {encoding.input_ids[start_idx:end_idx]}")
                assistant_ranges.append((start_idx, end_idx))
            else:
                i += 1

        mask = [0] * len(encoding.input_ids)
        for start_idx, end_idx in assistant_ranges:
            for idx in range(start_idx, end_idx):
                mask[idx] = 1
        mask = torch.tensor(mask, dtype=torch.int)

        logger.debug(f"🔍 Llama: Final mask shape: {mask.shape}")
        logger.debug(f"🔍 Llama: Mask sum: {mask.sum().item()}")
        logger.debug(f"🔍 Llama: Mask: {mask}")

        # Additional debug info: compare against a plain prompt/response split.
        try:
            prompt, response = self.split_prompt_assistant(text)
            prompt_tokens = tokenizer(prompt, add_special_tokens=False).input_ids
            response_tokens = tokenizer(response, add_special_tokens=False).input_ids
            logger.debug(f"🔍 Llama: Prompt length: {len(prompt)}")
            logger.debug(f"🔍 Llama: Response length: {len(response)}")
            logger.debug(f"🔍 Llama: Prompt token IDs: {prompt_tokens}")
            logger.debug(f"🔍 Llama: Response token IDs: {response_tokens}")
            logger.debug(f"🔍 Llama: Prompt: {prompt[:100]}...")
            logger.debug(f"🔍 Llama: Response: {response[:100]}...")
            logger.debug(f"🔍 Llama: Full input IDs length: {len(encoding.input_ids)}")
            logger.debug(f"🔍 Llama: Prompt + Response token IDs length: {len(prompt_tokens) + len(response_tokens)}")
            logger.debug(
                f"🔍 Llama: Difference in lengths: {len(encoding.input_ids) - (len(prompt_tokens) + len(response_tokens))}"
            )
        except Exception as e:
            logger.error(f"🔍 Llama: Error splitting prompt/response: {e}")

        return mask
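
    # Masking sketch (illustrative, assuming the Llama 3 chat layout):
    #   <|start_header_id|>user<|end_header_id|>\n\nHi<|eot_id|>
    #   <|start_header_id|>assistant<|end_header_id|>\n\nHello!<|eot_id|>
    # get_mask() sets 1s over the "Hello!" tokens: the scan skips the assistant
    # header plus the newline token after it, and stops before <|eot_id|>.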


class R1DistilTokenizerAdapter(TokenizerAdapter):
    """Adapter for R1-Distil model tokenizer."""

    def get_assistant_marker(self) -> str:
        return "<｜Assistant｜>"

    def get_end_marker(self) -> str:
        return "<｜end▁of▁sentence｜>"

    def get_begin_marker(self) -> str:
        return "<｜begin▁of▁sentence｜>"

    def get_user_marker(self) -> str:
        return "<｜User｜>"

    def get_mask(self, text: str, tokenizer) -> torch.Tensor:
        """Get the mask for the text.

        Args:
            text: The text to get the mask for
            tokenizer: The tokenizer to use

        Returns:
            A tensor of 0s and 1s where 1s indicate assistant tokens
        """
        logger.debug(f"🔍 R1Distil: Getting mask for text length: {len(text)}")

        # Get all markers
        assistant_marker = self.get_assistant_marker()
        end_marker = self.get_end_marker()

        # Get the full tokenization
        encoding = tokenizer(text, add_special_tokens=False)
        tokens = encoding.input_ids
        logger.debug(f"🔍 R1Distil: Full text token IDs: {tokens}")

        # Create the mask initialized to zeros -- same length as input_ids
        mask = torch.zeros(len(tokens), dtype=torch.int)

        # Get token IDs for the markers
        assistant_tokens = tokenizer(assistant_marker, add_special_tokens=False).input_ids
        end_tokens = tokenizer(end_marker, add_special_tokens=False).input_ids
        logger.debug(f"🔍 R1Distil: Assistant marker token IDs: {assistant_tokens}")
        logger.debug(f"🔍 R1Distil: End marker token IDs: {end_tokens}")

        # Find all assistant responses
        assistant_ranges = []
        i = 0
        while i < len(tokens):
            # Look for the assistant marker
            if i + len(assistant_tokens) <= len(tokens) and tokens[i : i + len(assistant_tokens)] == assistant_tokens:
                logger.debug(f"🔍 R1Distil: Found assistant marker at position {i}")
                # Start masking AFTER the assistant marker
                start_idx = i + len(assistant_tokens)
                # Find the end marker
                end_idx = None
                j = start_idx
                while j < len(tokens):
                    if j + len(end_tokens) <= len(tokens) and tokens[j : j + len(end_tokens)] == end_tokens:
                        end_idx = j  # Don't include the end marker in the mask
                        break
                    j += 1
                if end_idx is None:
                    # If no end marker is found, mask until the end
                    end_idx = len(tokens)
                logger.debug(f"🔍 R1Distil: Response range: {start_idx} to {end_idx}")
                assistant_ranges.append((start_idx, end_idx))
                i = end_idx + len(end_tokens)  # Start the next search after the end marker
            else:
                i += 1

        # Apply the mask for all found ranges
        for start_idx, end_idx in assistant_ranges:
            mask[start_idx:end_idx] = 1

        logger.debug(f"🔍 R1Distil: Found {len(assistant_ranges)} assistant responses")
        logger.debug(f"🔍 R1Distil: Final mask sum: {mask.sum().item()}")
        logger.debug(f"🔍 R1Distil: Final mask length: {len(mask)}")
        logger.debug(f"🔍 R1Distil: Mask: {mask}")
        return mask

    def split_prompt_assistant(self, text: str) -> tuple[str, str]:
        """Split the text into prompt and assistant parts.

        Args:
            text: The text to split

        Returns:
            A tuple of (prompt, assistant)
        """
        logger.debug(f"🔍 R1Distil: Splitting text of length: {len(text)}")

        marker = self.get_assistant_marker()
        end_marker = self.get_end_marker()

        # Find ALL assistant markers in the text
        assistant_markers = []
        pos = 0
        while True:
            pos = text.find(marker, pos)
            if pos == -1:
                break
            assistant_markers.append(pos)
            pos += len(marker)

        if not assistant_markers:
            raise ValueError("Could not find assistant marker in text")

        # Record the span of each assistant turn: (marker start, response start, response end)
        marker_positions = []
        for start_pos in assistant_markers:
            response_start = start_pos + len(marker)
            # Find the end marker after this response
            end_pos = text.find(end_marker, response_start)
            if end_pos == -1:
                end_pos = len(text)
            else:
                end_pos = end_pos + len(end_marker)
            marker_positions.append((start_pos, response_start, end_pos))

        # Include ALL assistant markers and responses in the response part;
        # this matches how the mask is generated in get_mask.
        first_assistant_pos = marker_positions[0][0]
        last_response_end = marker_positions[-1][2]

        prompt = text[:first_assistant_pos]  # Everything before the first assistant marker
        response = text[first_assistant_pos:last_response_end]  # All markers and responses

        logger.debug(f"🔍 R1Distil: Prompt length: {len(prompt)}")
        logger.debug(f"🔍 R1Distil: Response length: {len(response)}")
        logger.debug(f"🔍 R1Distil: Response token count estimate: {len(response) / 4}")  # Rough estimate
        logger.debug(f"🔍 R1Distil: Final prompt: {prompt[:100]}...")
        logger.debug(f"🔍 R1Distil: Final response: {response[:100]}...")
        return prompt, response
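
    # Layout sketch (illustrative; assumes the standard DeepSeek R1-Distill
    # chat template):
    #   <｜begin▁of▁sentence｜>{system}<｜User｜>{question}<｜Assistant｜>{answer}<｜end▁of▁sentence｜>
    # get_mask() sets 1s over the {answer} tokens, excluding both markers.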


class QwenTokenizerAdapter(TokenizerAdapter):
    """Adapter for Qwen2.5 model tokenizer."""

    def get_assistant_marker(self) -> str:
        """Get the assistant marker."""
        return "<|im_start|>assistant"

    def get_end_marker(self) -> str:
        """Get the end marker."""
        return "<|im_end|>"

    def split_prompt_assistant(self, convo_text: str) -> tuple[str, str]:
        """Split the text into prompt and assistant parts.

        Args:
            convo_text: The text to split

        Returns:
            A tuple of (prompt, assistant)
        """
        marker = self.get_assistant_marker()
        idx = convo_text.find(marker)
        if idx == -1:
            raise ValueError("Could not find assistant marker in conversation text.")
        # Include the marker in the prompt by slicing up to the end of the marker
        prompt = convo_text[: idx + len(marker)]
        # The assistant response is everything after the marker
        assistant_response = convo_text[idx + len(marker) :]
        return prompt, assistant_response

    def get_mask(self, text: str, tokenizer) -> torch.Tensor:
        """Get the mask for the text.

        Args:
            text: The text to get the mask for
            tokenizer: The tokenizer to use

        Returns:
            A tensor of 0s and 1s where 1s indicate assistant tokens
        """
        logger.debug(f"🔍 Qwen: Full text length: {len(text)}")

        encoding = tokenizer(text, add_special_tokens=False)

        # Get token IDs for markers
        im_start = tokenizer.convert_tokens_to_ids("<|im_start|>")
        assistant_token = tokenizer.convert_tokens_to_ids("assistant")
        im_end = tokenizer.convert_tokens_to_ids(self.get_end_marker())

        logger.debug(f"🔍 Qwen: Tokenized length: {len(encoding.input_ids)}")
        logger.debug(f"🔍 Qwen: Input IDs: {encoding.input_ids}")
        logger.debug(f"🔍 Qwen: Special token IDs: im_start={im_start}, assistant={assistant_token}, im_end={im_end}")

        assistant_ranges = []
        i = 0
        while i < len(encoding.input_ids) - 1:
            if encoding.input_ids[i] == im_start and encoding.input_ids[i + 1] == assistant_token:
                logger.debug(f"🔍 Qwen: Found assistant marker at position {i}")
                logger.debug(f"🔍 Qwen: Assistant marker tokens: {encoding.input_ids[i : i + 2]}")
                i += 2  # Skip past <|im_start|>assistant
                start_idx = i
                logger.debug(f"🔍 Qwen: Found start of response at {start_idx}")
                # Slice rather than index so we don't raise IndexError when the
                # response runs to the end of the sequence.
                logger.debug(f"🔍 Qwen: Start token ID: {encoding.input_ids[start_idx : start_idx + 1]}")
                while i < len(encoding.input_ids) and encoding.input_ids[i] != im_end:
                    i += 1
                end_idx = i
                logger.debug(f"🔍 Qwen: Found end of response at {end_idx}")
                logger.debug(f"🔍 Qwen: End token ID: {encoding.input_ids[end_idx : end_idx + 1]}")
                logger.debug(f"🔍 Qwen: Response token IDs: {encoding.input_ids[start_idx:end_idx]}")
                assistant_ranges.append((start_idx, end_idx))
            else:
                i += 1

        mask = [0] * len(encoding.input_ids)
        for start_idx, end_idx in assistant_ranges:
            for idx in range(start_idx, end_idx):
                mask[idx] = 1
        mask = torch.tensor(mask, dtype=torch.int)

        logger.debug(f"🔍 Qwen: Final mask shape: {mask.shape}")
        logger.debug(f"🔍 Qwen: Mask sum: {mask.sum().item()}")
        logger.debug(f"🔍 Qwen: Mask: {mask}")

        # Additional debug info: compare against a plain prompt/response split.
        try:
            prompt, response = self.split_prompt_assistant(text)
            prompt_tokens = tokenizer(prompt, add_special_tokens=False).input_ids
            response_tokens = tokenizer(response, add_special_tokens=False).input_ids
            logger.debug(f"🔍 Qwen: Prompt length: {len(prompt)}")
            logger.debug(f"🔍 Qwen: Response length: {len(response)}")
            logger.debug(f"🔍 Qwen: Prompt token IDs: {prompt_tokens}")
            logger.debug(f"🔍 Qwen: Response token IDs: {response_tokens}")
            logger.debug(f"🔍 Qwen: Prompt: {prompt[:100]}...")
            logger.debug(f"🔍 Qwen: Response: {response[:100]}...")
            logger.debug(f"🔍 Qwen: Full input IDs length: {len(encoding.input_ids)}")
            logger.debug(f"🔍 Qwen: Prompt + Response token IDs length: {len(prompt_tokens) + len(response_tokens)}")
            logger.debug(
                f"🔍 Qwen: Difference in lengths: {len(encoding.input_ids) - (len(prompt_tokens) + len(response_tokens))}"
            )
        except Exception as e:
            logger.error(f"🔍 Qwen: Error splitting prompt/response: {e}")

        return mask
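

if __name__ == "__main__":
    # Smoke-test sketch (added for illustration; not part of the original
    # module). The checkpoint name is an assumption -- any Qwen2.5 chat model
    # with the standard <|im_start|>/<|im_end|> template should behave the same.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    adapter = QwenTokenizerAdapter()
    text = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello!"},
        ],
        tokenize=False,
    )
    mask = adapter.get_mask(text, tokenizer)
    input_ids = tokenizer(text, add_special_tokens=False).input_ids
    assert len(mask) == len(input_ids)  # The mask must align 1:1 with the tokens
    print(f"Mask covers {int(mask.sum())} of {len(input_ids)} tokens")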