@@ -3,10 +3,10 @@ Test module for tokenizer adapters.
 """
 
 import torch
-from transformers import LlamaTokenizerFast
+from transformers import AutoTokenizer, LlamaTokenizerFast
 
 from src.config import logger
-from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
+from src.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter
 
 # Test conversation used across all tests
 TEST_CHAT = [
@@ -426,3 +426,66 @@ def test_r1_distil_multi_turn():
     for i in range(len(response_tokens) - len(end_marker_tokens) + 1):
         if response_tokens[i : i + len(end_marker_tokens)].tolist() == end_marker_tokens:
             assert not response_mask[i : i + len(end_marker_tokens)].any(), "End markers should not be masked"
+
+
+def test_qwen_format():
+    """Test Qwen tokenizer adapter format handling."""
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
+    adapter = QwenTokenizerAdapter()
+
+    convo = tokenizer.apply_chat_template(TEST_CHAT, tokenize=False)
+    if not isinstance(convo, str):
+        convo = tokenizer.decode(convo)
+
+    prompt, response = adapter.split_prompt_assistant(convo)
+
+    # Basic format checks
+    assert "<|im_start|>assistant" in prompt
+    assert "I'm doing great" in response
+
+
+def test_qwen_mask():
+    """Test Qwen tokenizer adapter mask generation."""
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
+    adapter = QwenTokenizerAdapter()
+
+    convo = tokenizer.apply_chat_template(TEST_CHAT, tokenize=False)
+    if not isinstance(convo, str):
+        convo = tokenizer.decode(convo)
+
+    # Get mask and verify basic properties
+    mask = adapter.get_mask(convo, tokenizer)
+    assert isinstance(mask, torch.Tensor)
+    assert mask.dtype == torch.int
+    assert mask.sum().item() > 0  # Has some masked tokens
+    assert all(x in [0, 1] for x in mask.tolist())  # Only 0s and 1s
+
+    # Verify mask length matches input length
+    encoding = tokenizer(convo, add_special_tokens=False)
+    assert len(mask) == len(encoding.input_ids)
+
+
+def test_qwen_multi_turn():
+    """Test Qwen adapter with multi-turn conversations."""
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
+    adapter = QwenTokenizerAdapter()
+
+    # Simple multi-turn chat
+    chat = [
+        {"role": "user", "content": "Hi"},
+        {"role": "assistant", "content": "Hello!"},
+        {"role": "user", "content": "How are you?"},
+        {"role": "assistant", "content": "I'm good!"},
+    ]
+
+    convo = tokenizer.apply_chat_template(chat, tokenize=False)
+    if not isinstance(convo, str):
+        convo = tokenizer.decode(convo)
+
+    # Test basic multi-turn functionality
+    mask = adapter.get_mask(convo, tokenizer)
+    prompt, response = adapter.split_prompt_assistant(convo)
+
+    assert len(mask) > 0
+    assert "Hello!" in response
+    assert "I'm good!" in response