You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
41 lines
1.4 KiB
41 lines
1.4 KiB
import pytest
|
|
from unittest.mock import patch, MagicMock
|
|
from swarms.models.mistral import Mistral
|
|
|
|
def test_mistral_initialization():
    """Constructing ``Mistral(device="cpu")`` should set the expected attributes.

    Only ``device`` is passed explicitly; every other attribute checked here is
    whatever the constructor assigns when left unspecified.
    """
    mistral = Mistral(device="cpu")

    assert isinstance(mistral, Mistral)
    assert mistral.ai_name == "Node Model Agent"
    # Singletons are compared by identity (PEP 8; ruff E711/E712), not ``==``.
    assert mistral.system_prompt is None
    assert mistral.model_name == "mistralai/Mistral-7B-v0.1"
    assert mistral.device == "cpu"
    assert mistral.use_flash_attention is False
    assert mistral.temperature == 1.0
    assert mistral.max_length == 100
    assert mistral.history == []
|
|
|
|
# NOTE(review): assumes swarms.models.mistral imports AutoModelForCausalLM and
# AutoTokenizer at module level (patch must target where the names are looked
# up) — confirm against the module.
@patch("swarms.models.mistral.AutoModelForCausalLM.from_pretrained")
@patch("swarms.models.mistral.AutoTokenizer.from_pretrained")
def test_mistral_load_model(mock_tokenizer, mock_model):
    """``load_model`` should call each ``from_pretrained`` exactly once.

    Both ``from_pretrained`` loaders are replaced with mocks for the duration
    of the test. Stacked ``@patch`` decorators inject their mocks bottom-up,
    which is why ``mock_tokenizer`` is the first argument.
    """
    mistral = Mistral(device="cpu")
    # Clear any calls made during construction so the assertions below pin
    # only the explicit ``load_model()`` call.
    mock_model.reset_mock()
    mock_tokenizer.reset_mock()

    mistral.load_model()

    mock_model.assert_called_once()
    mock_tokenizer.assert_called_once()
|
|
|
|
@patch("swarms.models.mistral.Mistral.load_model")
def test_mistral_run(mock_load_model):
    """``run`` is expected to invoke ``load_model`` exactly once.

    ``Mistral.load_model`` is replaced with a mock for the whole test, so the
    real loader is never executed.
    """
    mistral = Mistral(device="cpu")

    mistral.run("What's the weather in miami")

    # NOTE(review): the patch is also active during construction, so if
    # ``Mistral.__init__`` calls ``load_model`` that call is counted here too —
    # confirm which call this assertion is meant to pin.
    mock_load_model.assert_called_once()
|
|
|
|
@patch("swarms.models.mistral.Mistral.run")
def test_mistral_chat(mock_run):
    """``chat`` is expected to delegate to ``run`` exactly once.

    ``Mistral.run`` is replaced with a mock, so no generation takes place;
    only the delegation is checked, not the arguments passed to ``run``.
    """
    mistral = Mistral(device="cpu")

    mistral.chat("What's the weather in miami")

    mock_run.assert_called_once()
|
|
|
|
def test_mistral__stream_response():
    """``_stream_response`` should yield the response's space-separated words in order."""
    model = Mistral(device="cpu")
    text = "It's sunny in Miami."
    expected_words = ["It's", "sunny", "in", "Miami."]

    streamed = model._stream_response(text)

    assert list(streamed) == expected_words