You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/tests/utils/test_prep_torch_inference.py

50 lines
1.5 KiB

6 months ago
import unittest
from unittest.mock import Mock
import pytest
import torch
from swarms.utils import prep_torch_inference
def test_prep_torch_inference():
    """Happy path: the loaded model is returned and put into eval mode.

    ``swarms.utils.load_model_torch`` is patched to return a mock model, and
    the test then checks that ``prep_torch_inference`` hands that same model
    back after calling its ``eval`` method exactly once.
    """
    # NOTE(review): the patch only has an effect if prep_torch_inference
    # resolves load_model_torch through the ``swarms.utils`` namespace --
    # confirm against the implementation's own import statement.
    checkpoint = "model_path"

    # Prefer the GPU when one is present, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        backend = "cuda"
    else:
        backend = "cpu"
    target_device = torch.device(backend)

    fake_model = Mock()
    fake_model.eval = Mock()

    loader_patch = unittest.mock.patch(
        "swarms.utils.load_model_torch", return_value=fake_model
    )
    with loader_patch:
        loaded = prep_torch_inference(checkpoint, target_device)

    # The function must return the loaded model and have called eval() once.
    assert loaded == fake_model
    fake_model.eval.assert_called_once()
@pytest.mark.parametrize(
    "model_path, device",
    [
        # Invalid file path, valid device
        pytest.param(
            "invalid_path", torch.device("cuda"), id="invalid-path-cuda"
        ),
        # None file path, valid device
        pytest.param(None, torch.device("cuda"), id="none-path-cuda"),
        # Valid-looking file path, None device
        pytest.param("model_path", None, id="path-none-device"),
        # None file path, None device
        pytest.param(None, None, id="none-path-none-device"),
    ],
)
def test_prep_torch_inference_exceptions(model_path, device):
    """Invalid inputs must be reported as ``None``, not as a raised exception.

    The sibling test ``test_prep_torch_inference_return_none`` states that a
    failing ``load_model_torch`` makes ``prep_torch_inference`` return
    ``None``. This test previously expected ``pytest.raises(Exception)`` for
    the very same ``("invalid_path", cuda)`` input, so the two tests could
    never both pass. It now checks the ``None`` contract for every bad input.

    Args:
        model_path: Path handed to ``prep_torch_inference``; may be ``None``.
        device: Target ``torch.device``; may be ``None``.
    """
    # NOTE(review): assumes the cases with a None path or None device also
    # fail inside the loader and are therefore swallowed -- confirm against
    # the prep_torch_inference implementation.
    assert prep_torch_inference(model_path, device) is None
def test_prep_torch_inference_return_none():
    """A failed model load yields ``None`` rather than a model object."""
    # The path is invalid, so load_model_torch is expected to raise and
    # prep_torch_inference should respond by returning None.
    bad_path, cuda_device = "invalid_path", torch.device("cuda")
    outcome = prep_torch_inference(bad_path, cuda_device)
    assert outcome is None