You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/tests/models/test_mixtral.py

54 lines
2.0 KiB

import pytest
from unittest.mock import patch, MagicMock
from swarms.models.mixtral import Mixtral
@patch("swarms.models.mixtral.AutoTokenizer")
@patch("swarms.models.mixtral.AutoModelForCausalLM")
def test_mixtral_init(mock_model, mock_tokenizer):
    """Check that constructing Mixtral calls both loaders and sets defaults.

    Patch decorators are applied bottom-up, so the first argument is the
    AutoModelForCausalLM mock and the second is the AutoTokenizer mock.
    """
    llm = Mixtral()

    # Each patched class must have had from_pretrained invoked exactly once
    # during construction (tokenizer checked first, then model).
    for loader in (mock_tokenizer, mock_model):
        loader.from_pretrained.assert_called_once()

    # Default configuration exposed on the instance.
    assert llm.model_name == "mistralai/Mixtral-8x7B-v0.1"
    assert llm.max_new_tokens == 20
@patch("swarms.models.mixtral.AutoTokenizer")
@patch("swarms.models.mixtral.AutoModelForCausalLM")
def test_mixtral_run(mock_model, mock_tokenizer):
    """Check that Mixtral.run tokenizes, generates, decodes and returns text.

    Patch decorators are applied bottom-up, so ``mock_model`` replaces
    AutoModelForCausalLM and ``mock_tokenizer`` replaces AutoTokenizer.
    """
    mock_tokenizer_instance = MagicMock()
    mock_model_instance = MagicMock()

    # The return values must be configured BEFORE Mixtral() is constructed:
    # the constructor calls from_pretrained, and a return_value assigned
    # afterwards would never reach the instance under test.
    mock_tokenizer.from_pretrained.return_value = mock_tokenizer_instance
    mock_model.from_pretrained.return_value = mock_model_instance

    mock_tokenizer_instance.return_tensors = "pt"
    mock_model_instance.generate.return_value = [101, 102, 103]
    mock_tokenizer_instance.decode.return_value = "Generated text"

    mixtral = Mixtral()
    result = mixtral.run("Test task")

    assert result == "Generated text"
    mock_tokenizer_instance.assert_called_once_with(
        "Test task", return_tensors="pt"
    )
    mock_model_instance.generate.assert_called_once()
    # NOTE(review): assumes run() hands generate()'s return value to decode
    # unchanged (not e.g. its first element) -- confirm against Mixtral.run.
    mock_tokenizer_instance.decode.assert_called_once_with(
        [101, 102, 103], skip_special_tokens=True
    )
@patch("swarms.models.mixtral.AutoTokenizer")
@patch("swarms.models.mixtral.AutoModelForCausalLM")
def test_mixtral_run_error(mock_model, mock_tokenizer):
    """Check that an exception raised by generate() propagates out of run().

    Patch decorators are applied bottom-up, so ``mock_model`` replaces
    AutoModelForCausalLM and ``mock_tokenizer`` replaces AutoTokenizer.
    """
    mock_tokenizer_instance = MagicMock()
    mock_model_instance = MagicMock()

    # Configure the loaders BEFORE constructing Mixtral: the constructor calls
    # from_pretrained, so a return_value assigned afterwards would leave the
    # failing generate() on a mock the instance never uses.
    mock_tokenizer.from_pretrained.return_value = mock_tokenizer_instance
    mock_model.from_pretrained.return_value = mock_model_instance

    mock_tokenizer_instance.return_tensors = "pt"
    mock_model_instance.generate.side_effect = Exception("Test error")

    mixtral = Mixtral()

    with pytest.raises(Exception, match="Test error"):
        mixtral.run("Test task")