test: add Qwen tokenizer adapter tests

Implemented unit tests for the Qwen tokenizer adapter, covering format handling, mask generation, and multi-turn conversation support.
main
thinhlpg 1 month ago
parent 6efe01d5ff
commit 133cb1ab90

@ -3,10 +3,10 @@ Test module for tokenizer adapters.
""" """
import torch import torch
from transformers import LlamaTokenizerFast from transformers import AutoTokenizer, LlamaTokenizerFast
from src.config import logger from src.config import logger
from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter from src.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter
# Test conversation used across all tests # Test conversation used across all tests
TEST_CHAT = [ TEST_CHAT = [
@ -426,3 +426,66 @@ def test_r1_distil_multi_turn():
for i in range(len(response_tokens) - len(end_marker_tokens) + 1): for i in range(len(response_tokens) - len(end_marker_tokens) + 1):
if response_tokens[i : i + len(end_marker_tokens)].tolist() == end_marker_tokens: if response_tokens[i : i + len(end_marker_tokens)].tolist() == end_marker_tokens:
assert not response_mask[i : i + len(end_marker_tokens)].any(), "End markers should not be masked" assert not response_mask[i : i + len(end_marker_tokens)].any(), "End markers should not be masked"
def test_qwen_format():
    """Verify the Qwen adapter splits a rendered chat into prompt and response.

    Renders ``TEST_CHAT`` with the Qwen2.5-1.5B-Instruct chat template, passes
    the text through ``QwenTokenizerAdapter.split_prompt_assistant`` and checks
    that the assistant start marker ends up in the prompt part while the
    expected assistant text ends up in the response part.
    """
    qwen_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
    qwen_adapter = QwenTokenizerAdapter()

    rendered = qwen_tokenizer.apply_chat_template(TEST_CHAT, tokenize=False)
    # Defensive fallback: decode if the template did not hand back a string.
    rendered = rendered if isinstance(rendered, str) else qwen_tokenizer.decode(rendered)

    prompt_part, response_part = qwen_adapter.split_prompt_assistant(rendered)

    # Basic format checks
    assert "<|im_start|>assistant" in prompt_part
    assert "I'm doing great" in response_part
def test_qwen_mask():
    """Verify basic properties of the mask produced by the Qwen adapter.

    The mask returned by ``QwenTokenizerAdapter.get_mask`` must be an integer
    ``torch.Tensor`` holding only 0/1 values, with at least one non-zero entry,
    and must have one entry per token of the tokenized conversation.
    """
    qwen_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
    qwen_adapter = QwenTokenizerAdapter()

    rendered = qwen_tokenizer.apply_chat_template(TEST_CHAT, tokenize=False)
    # Defensive fallback: decode if the template did not hand back a string.
    rendered = rendered if isinstance(rendered, str) else qwen_tokenizer.decode(rendered)

    # Build the mask, then check its type, dtype and contents.
    token_mask = qwen_adapter.get_mask(rendered, qwen_tokenizer)
    assert isinstance(token_mask, torch.Tensor)
    assert token_mask.dtype == torch.int
    # At least one position must be flagged.
    assert token_mask.sum().item() > 0
    # No entry may be anything other than 0 or 1.
    assert not any(value not in [0, 1] for value in token_mask.tolist())

    # The mask must line up one-to-one with the tokenized conversation.
    tokenized = qwen_tokenizer(rendered, add_special_tokens=False)
    assert len(token_mask) == len(tokenized.input_ids)
def test_qwen_multi_turn():
    """Verify the Qwen adapter copes with a conversation of several turns.

    Renders a two-exchange chat with the Qwen2.5-1.5B-Instruct chat template
    and checks that the adapter yields a non-empty mask and that both assistant
    replies appear in the response part of the split.
    """
    qwen_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
    qwen_adapter = QwenTokenizerAdapter()

    # Simple multi-turn chat, alternating user and assistant messages.
    turns = (
        ("user", "Hi"),
        ("assistant", "Hello!"),
        ("user", "How are you?"),
        ("assistant", "I'm good!"),
    )
    chat = [{"role": speaker, "content": text} for speaker, text in turns]

    rendered = qwen_tokenizer.apply_chat_template(chat, tokenize=False)
    # Defensive fallback: decode if the template did not hand back a string.
    rendered = rendered if isinstance(rendered, str) else qwen_tokenizer.decode(rendered)

    # Exercise both adapter entry points on the multi-turn text.
    token_mask = qwen_adapter.get_mask(rendered, qwen_tokenizer)
    _, response_part = qwen_adapter.split_prompt_assistant(rendered)

    assert len(token_mask) > 0
    # Both assistant replies must be present in the response part.
    for expected_reply in ("Hello!", "I'm good!"):
        assert expected_reply in response_part

Loading…
Cancel
Save