You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
50 lines
1.5 KiB
50 lines
1.5 KiB
6 months ago
|
import unittest
|
||
|
from unittest.mock import Mock
|
||
|
|
||
|
import pytest
|
||
|
import torch
|
||
|
|
||
|
from swarms.utils import prep_torch_inference
|
||
|
|
||
|
|
||
|
def test_prep_torch_inference():
    """prep_torch_inference should return the loaded model after calling eval() on it.

    ``swarms.utils.load_model_torch`` is patched so no real checkpoint is read;
    the patched loader hands back ``model_mock``, whose ``eval`` is a Mock so the
    call can be asserted.
    """
    model_path = "model_path"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_mock = Mock()
    model_mock.eval = Mock()

    # Mocking the load_model_torch function to return our mock model.
    # NOTE(review): mock.patch must target the namespace where
    # prep_torch_inference looks the name up — confirm that is swarms.utils.
    with unittest.mock.patch(
        "swarms.utils.load_model_torch", return_value=model_mock
    ) as load_mock:
        model = prep_torch_inference(model_path, device)

    # Surface a mis-targeted patch directly: if the mock was never called, the
    # real loader ran instead and the identity check below would fail with a
    # far less informative message.
    load_mock.assert_called()

    # Check the model object is exactly the one the loader returned and that
    # it was switched to eval mode exactly once.
    assert model is model_mock
    model_mock.eval.assert_called_once()
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "model_path, device",
    [
        (
            "invalid_path",
            torch.device("cuda"),
        ),  # Invalid file path, valid device
        (None, torch.device("cuda")),  # None file path, valid device
        ("model_path", None),  # Placeholder path (not expected to exist), None device
        (None, None),  # None file path, None device
    ],
)
def test_prep_torch_inference_exceptions(model_path, device):
    """Invalid path/device combinations must not yield a model.

    This module's return-None test states the contract: when loading fails,
    prep_torch_inference returns None instead of propagating the error. That
    test asserts None for ("invalid_path", cuda), so the previous
    ``pytest.raises(Exception)`` here could never pass alongside it. A bare
    ``pytest.raises(Exception)`` would also have accepted any unrelated error
    (e.g. a TypeError from a signature change) as success.
    """
    # torch.device("cuda") only builds a device descriptor, so these cases are
    # constructible on CPU-only machines as well.
    assert prep_torch_inference(model_path, device) is None
|
||
|
|
||
|
|
||
|
def test_prep_torch_inference_return_none():
    """An unloadable path should make prep_torch_inference return None.

    Nothing is mocked: the real loader is expected to fail on this path, and
    prep_torch_inference is expected to report that failure as None.
    """
    # Since load_model_torch function will raise an exception, prep_torch_inference should return None
    result = prep_torch_inference(
        "invalid_path",  # Invalid file path
        torch.device("cuda"),  # Valid device
    )
    assert result is None
|