You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/tests/utils/test_load_models_torch.py

57 lines
1.9 KiB

11 months ago
from unittest.mock import MagicMock
import pytest
import torch
11 months ago
from swarms.utils.load_model_torch import load_model_torch
def test_load_model_torch_no_model_path():
    """Expect FileNotFoundError when the loader is called with no arguments."""
    expected_error = FileNotFoundError
    with pytest.raises(expected_error):
        load_model_torch()
def test_load_model_torch_model_not_found(mocker):
    """Expect FileNotFoundError when ``torch.load`` raises it.

    ``torch.load`` is patched so that no real file access happens; the
    patched call raises ``FileNotFoundError`` for any path.
    """
    missing_path = "non_existent_model_path"
    mocker.patch("torch.load", side_effect=FileNotFoundError)

    with pytest.raises(FileNotFoundError):
        load_model_torch(missing_path)
def test_load_model_torch_runtime_error(mocker):
    """Expect RuntimeError when ``torch.load`` raises it.

    ``torch.load`` is patched to raise ``RuntimeError`` for any path.
    """
    checkpoint_path = "model_path"
    mocker.patch("torch.load", side_effect=RuntimeError)

    with pytest.raises(RuntimeError):
        load_model_torch(checkpoint_path)
def test_load_model_torch_no_device_specified(mocker):
    """Expect the loaded model to be moved to CPU when no device is given.

    ``torch.cuda.is_available`` is patched to report ``False`` and
    ``torch.load`` is patched to return a mock module, so the test only
    checks which device the mock's ``to`` method is called with.
    """
    loaded = MagicMock(spec=torch.nn.Module)
    mocker.patch("torch.load", return_value=loaded)
    mocker.patch("torch.cuda.is_available", return_value=False)

    load_model_torch("model_path")

    cpu_device = torch.device("cpu")
    loaded.to.assert_called_once_with(cpu_device)
def test_load_model_torch_device_specified(mocker):
    """Expect the loaded model to be moved to the explicitly given device.

    ``torch.load`` is patched to return a mock module; the test then
    checks that the mock's ``to`` method is called exactly once with the
    device passed via the ``device`` keyword.
    """
    mock_model = MagicMock(spec=torch.nn.Module)
    mocker.patch("torch.load", return_value=mock_model)

    # Building a torch.device only creates a descriptor object, so this
    # does not need a GPU to be present on the test machine.
    requested_device = torch.device("cuda")
    load_model_torch("model_path", device=requested_device)

    mock_model.to.assert_called_once_with(torch.device("cuda"))
def test_load_model_torch_model_specified(mocker):
    """Expect a supplied model to receive the loaded state, strictly.

    ``torch.load`` is patched to return a small dict; the test checks
    that the supplied mock model's ``load_state_dict`` is called exactly
    once with an equal dict and ``strict=True``.
    """
    expected_state = {"key": "value"}
    target_model = MagicMock(spec=torch.nn.Module)
    # Hand the patch its own copy so the expectation stays independent
    # of the object the loader actually receives.
    mocker.patch("torch.load", return_value=dict(expected_state))

    load_model_torch("model_path", model=target_model)

    load_call = target_model.load_state_dict
    load_call.assert_called_once_with(expected_state, strict=True)
def test_load_model_torch_model_specified_strict_false(mocker):
    """Expect ``strict=False`` to be passed through to ``load_state_dict``.

    ``torch.load`` is patched to return a small dict; the test checks
    that the supplied mock model's ``load_state_dict`` is called exactly
    once with an equal dict and ``strict=False``.
    """
    expected_state = {"key": "value"}
    target_model = MagicMock(spec=torch.nn.Module)
    # Hand the patch its own copy so the expectation stays independent
    # of the object the loader actually receives.
    mocker.patch("torch.load", return_value=dict(expected_state))

    load_model_torch("model_path", model=target_model, strict=False)

    load_call = target_model.load_state_dict
    load_call.assert_called_once_with(expected_state, strict=False)