You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/tests/utils/test_prep_torch_model_infer...

51 lines
1.4 KiB

6 months ago
from unittest.mock import MagicMock
import torch
from swarms.utils.prep_torch_model_inference import (
prep_torch_inference,
)
def test_prep_torch_inference_no_model_path():
    """Calling prep_torch_inference with no arguments must yield None, not raise."""
    assert prep_torch_inference() is None
def test_prep_torch_inference_model_not_found(mocker):
    """A FileNotFoundError raised by the loader must be handled, giving None."""
    loader_target = "swarms.utils.prep_torch_model_inference.load_model_torch"
    # Force the loader to fail as if the checkpoint file were missing.
    mocker.patch(loader_target, side_effect=FileNotFoundError)

    outcome = prep_torch_inference("non_existent_model_path")

    assert outcome is None
def test_prep_torch_inference_runtime_error(mocker):
    """A RuntimeError raised by the loader must be handled, giving None."""
    loader_target = "swarms.utils.prep_torch_model_inference.load_model_torch"
    # Simulate a generic runtime failure while loading the model.
    mocker.patch(loader_target, side_effect=RuntimeError)

    outcome = prep_torch_inference("model_path")

    assert outcome is None
def test_prep_torch_inference_no_device_specified(mocker):
    """Without a device argument, the loaded model is put in eval mode.

    ``load_model_torch`` is patched so no real checkpoint is read. The
    MagicMock is spec'd against ``torch.nn.Module``, so accessing an
    attribute that ``nn.Module`` does not define raises AttributeError.
    """
    mock_model = MagicMock(spec=torch.nn.Module)
    mock_loader = mocker.patch(
        "swarms.utils.prep_torch_model_inference.load_model_torch",
        return_value=mock_model,
    )

    result = prep_torch_inference("model_path")

    # The patched loader must actually be the code path that was exercised.
    mock_loader.assert_called_once()
    mock_model.eval.assert_called_once()
    # The failure-path tests pin a None result; the success path must not
    # collapse to None, or those tests could not tell the cases apart.
    assert result is not None
def test_prep_torch_inference_device_specified(mocker):
    """With an explicit device, the loaded model is still put in eval mode.

    Constructing ``torch.device("cuda")`` only builds a device descriptor and
    does not itself require a GPU. ``load_model_torch`` is patched, so no
    real checkpoint is read.
    """
    mock_model = MagicMock(spec=torch.nn.Module)
    mock_loader = mocker.patch(
        "swarms.utils.prep_torch_model_inference.load_model_torch",
        return_value=mock_model,
    )

    result = prep_torch_inference("model_path", device=torch.device("cuda"))

    # The patched loader must actually be the code path that was exercised.
    mock_loader.assert_called_once()
    mock_model.eval.assert_called_once()
    # Success must be distinguishable from the None returned on failure.
    assert result is not None
    # NOTE(review): this does not verify that `device` reaches the loader or
    # the model; confirm how prep_torch_inference uses it before asserting
    # on mock_loader.call_args.