diff --git a/tests/test_tokenizer_adapters.py b/tests/test_tokenizer_adapters.py
index ed6570e..1d1db98 100644
--- a/tests/test_tokenizer_adapters.py
+++ b/tests/test_tokenizer_adapters.py
@@ -3,10 +3,10 @@ Test module for tokenizer adapters.
 """
 
 import torch
-from transformers import LlamaTokenizerFast
+from transformers import AutoTokenizer, LlamaTokenizerFast
 
 from src.config import logger
-from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
+from src.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter
 
 # Test conversation used across all tests
 TEST_CHAT = [
@@ -426,3 +426,66 @@ def test_r1_distil_multi_turn():
     for i in range(len(response_tokens) - len(end_marker_tokens) + 1):
         if response_tokens[i : i + len(end_marker_tokens)].tolist() == end_marker_tokens:
             assert not response_mask[i : i + len(end_marker_tokens)].any(), "End markers should not be masked"
+
+
+def test_qwen_format():
+    """Test Qwen tokenizer adapter format handling."""
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
+    adapter = QwenTokenizerAdapter()
+
+    convo = tokenizer.apply_chat_template(TEST_CHAT, tokenize=False)
+    if not isinstance(convo, str):
+        convo = tokenizer.decode(convo)
+
+    prompt, response = adapter.split_prompt_assistant(convo)
+
+    # Basic format checks
+    assert "<|im_start|>assistant" in prompt
+    assert "I'm doing great" in response
+
+
+def test_qwen_mask():
+    """Test Qwen tokenizer adapter mask generation."""
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
+    adapter = QwenTokenizerAdapter()
+
+    convo = tokenizer.apply_chat_template(TEST_CHAT, tokenize=False)
+    if not isinstance(convo, str):
+        convo = tokenizer.decode(convo)
+
+    # Get mask and verify basic properties
+    mask = adapter.get_mask(convo, tokenizer)
+    assert isinstance(mask, torch.Tensor)
+    assert mask.dtype == torch.int
+    assert mask.sum().item() > 0  # Has some masked tokens
+    assert all(x in [0, 1] for x in mask.tolist())  # Only 0s and 1s
+
+    # Verify mask length matches input length
+    encoding = tokenizer(convo, add_special_tokens=False)
+    assert len(mask) == len(encoding.input_ids)
+
+
+def test_qwen_multi_turn():
+    """Test Qwen adapter with multi-turn conversations."""
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
+    adapter = QwenTokenizerAdapter()
+
+    # Simple multi-turn chat
+    chat = [
+        {"role": "user", "content": "Hi"},
+        {"role": "assistant", "content": "Hello!"},
+        {"role": "user", "content": "How are you?"},
+        {"role": "assistant", "content": "I'm good!"},
+    ]
+
+    convo = tokenizer.apply_chat_template(chat, tokenize=False)
+    if not isinstance(convo, str):
+        convo = tokenizer.decode(convo)
+
+    # Test basic multi-turn functionality
+    mask = adapter.get_mask(convo, tokenizer)
+    prompt, response = adapter.split_prompt_assistant(convo)
+
+    assert len(mask) > 0
+    assert "Hello!" in response
+    assert "I'm good!" in response
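
Note for reviewers: the adapter itself lives in src/tokenizer_adapter and is not part of this diff, so the sketch below is inferred from the tests, not the actual implementation. It illustrates the contract the three new tests exercise, assuming Qwen2.5's ChatML template, where each turn is rendered as "<|im_start|>{role}\n{content}<|im_end|>". The class name QwenTokenizerAdapterSketch and both method bodies are hypothetical; only the method names, argument shapes, and the checked properties (an int tensor of 0s and 1s, one entry per token of the untokenized conversation) come from the tests above, and the convention that 1 marks assistant tokens is an assumption.

import torch


class QwenTokenizerAdapterSketch:
    """Hypothetical adapter satisfying the new tests; not the project's real class."""

    ASSISTANT_MARKER = "<|im_start|>assistant"
    END_MARKER = "<|im_end|>"

    def split_prompt_assistant(self, convo: str) -> tuple[str, str]:
        # Everything up to and including the first assistant marker is the
        # prompt; the remainder (the assistant turns onward) is the response.
        idx = convo.find(self.ASSISTANT_MARKER)
        if idx == -1:
            raise ValueError("No assistant turn found in conversation")
        split_at = idx + len(self.ASSISTANT_MARKER)
        return convo[:split_at], convo[split_at:]

    def get_mask(self, convo: str, tokenizer) -> torch.Tensor:
        # Collect the [start, end) character spans of assistant content, then
        # flag every token whose character offsets overlap one of the spans.
        # Requires a fast tokenizer so return_offsets_mapping is available
        # (AutoTokenizer loads the fast Qwen tokenizer by default).
        spans = []
        pos = 0
        while True:
            start = convo.find(self.ASSISTANT_MARKER, pos)
            if start == -1:
                break
            content_start = start + len(self.ASSISTANT_MARKER)
            end = convo.find(self.END_MARKER, content_start)
            if end == -1:
                end = len(convo)
            spans.append((content_start, end))
            pos = end
        enc = tokenizer(convo, add_special_tokens=False, return_offsets_mapping=True)
        mask = torch.zeros(len(enc.input_ids), dtype=torch.int)
        for i, (tok_start, tok_end) in enumerate(enc.offset_mapping):
            if any(s < tok_end and tok_start < e for s, e in spans):
                mask[i] = 1
        return mask

A sketch like this would pass the properties the tests check: get_mask returns a torch.int tensor of 0s and 1s whose length equals len(tokenizer(convo, add_special_tokens=False).input_ids), and split_prompt_assistant keeps "<|im_start|>assistant" in the prompt while all assistant content (including later turns such as "I'm good!") lands in the response.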