You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/tests/models/mistral.py

41 lines
1.4 KiB

import pytest
from unittest.mock import patch, MagicMock
from swarms.models.mistral import Mistral
def test_mistral_initialization():
    """Check the attribute values of a Mistral built with only ``device`` set.

    Every assertion other than ``device`` pins a value the constructor
    supplies when the caller does not pass it.
    """
    mistral = Mistral(device="cpu")
    assert isinstance(mistral, Mistral)
    assert mistral.ai_name == "Node Model Agent"
    # Compare singletons by identity (PEP 8 / ruff E711, E712), not with ``==``.
    assert mistral.system_prompt is None
    assert mistral.model_name == "mistralai/Mistral-7B-v0.1"
    assert mistral.device == "cpu"
    assert mistral.use_flash_attention is False
    assert mistral.temperature == 1.0
    assert mistral.max_length == 100
    assert mistral.history == []
# Patch the names where swarms.models.mistral looks them up (the original
# 'your_module' placeholder cannot be imported). NOTE(review): assumes that
# module imports AutoModelForCausalLM / AutoTokenizer at module level — confirm.
@patch("swarms.models.mistral.AutoModelForCausalLM.from_pretrained")
@patch("swarms.models.mistral.AutoTokenizer.from_pretrained")
def test_mistral_load_model(mock_tokenizer, mock_model):
    """load_model() should call each ``from_pretrained`` loader exactly once.

    Stacked ``@patch`` decorators inject their mocks bottom-up, so the
    tokenizer mock arrives first and the model mock second.
    """
    mistral = Mistral(device="cpu")
    mistral.load_model()
    mock_model.assert_called_once()
    mock_tokenizer.assert_called_once()
# Patch via the real module path; the original 'your_module' placeholder
# cannot be imported.
@patch("swarms.models.mistral.Mistral.load_model")
def test_mistral_run(mock_load_model):
    """run() should result in exactly one call to load_model().

    The patch is active for the whole test, so a load_model() call made by
    the constructor would also be counted.
    NOTE(review): assumes run() invokes load_model() and tolerates it being
    mocked out (no real weights loaded) — confirm against Mistral.run.
    """
    mistral = Mistral(device="cpu")
    mistral.run("What's the weather in miami")
    mock_load_model.assert_called_once()
# Patch via the real module path; the original 'your_module' placeholder
# cannot be imported.
@patch("swarms.models.mistral.Mistral.run")
def test_mistral_chat(mock_run):
    """chat() should delegate to run() exactly once.

    run() is mocked, so no model inference happens in this test.
    """
    mistral = Mistral(device="cpu")
    mistral.chat("What's the weather in miami")
    mock_run.assert_called_once()
def test_mistral__stream_response():
    """_stream_response should emit the response one word at a time.

    The expected tokens are the words of the input sentence in order,
    with punctuation kept attached to its word.
    """
    model = Mistral(device="cpu")
    text = "It's sunny in Miami."
    expected = ["It's", "sunny", "in", "Miami."]
    streamed = list(model._stream_response(text))
    assert streamed == expected