You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
55 lines
1.9 KiB
55 lines
1.9 KiB
import pytest
|
|
import torch
|
|
from unittest.mock import MagicMock
|
|
from swarms.utils.load_model_torch import load_model_torch
|
|
|
|
|
|
def test_load_model_torch_no_model_path():
    """Calling load_model_torch() with no arguments must raise FileNotFoundError.

    NOTE(review): torch.load is not patched in this test, so the outcome
    depends on how load_model_torch (or the real torch.load) reacts to a
    missing path — confirm against the implementation.
    """
    expected_error = FileNotFoundError
    with pytest.raises(expected_error):
        load_model_torch()
|
|
|
|
|
|
def test_load_model_torch_model_not_found(mocker):
    """A FileNotFoundError raised by torch.load must reach the caller."""
    # Make every torch.load call fail as if the file were absent.
    mocker.patch.object(torch, "load", side_effect=FileNotFoundError)

    with pytest.raises(FileNotFoundError):
        load_model_torch("non_existent_model_path")
|
|
|
|
|
|
def test_load_model_torch_runtime_error(mocker):
    """A RuntimeError raised by torch.load must surface as a RuntimeError."""
    failing_load = mocker.patch.object(torch, "load")
    failing_load.side_effect = RuntimeError

    with pytest.raises(RuntimeError):
        load_model_torch("model_path")
|
|
|
|
|
|
def test_load_model_torch_no_device_specified(mocker):
    """With no device argument and CUDA reported unavailable, the loaded
    model is expected to be moved to the CPU exactly once."""
    fake_model = MagicMock(spec=torch.nn.Module)
    mocker.patch.object(torch, "load", return_value=fake_model)
    # Report CUDA as unavailable so the result does not depend on the host.
    mocker.patch.object(torch.cuda, "is_available", return_value=False)

    load_model_torch("model_path")

    cpu_device = torch.device("cpu")
    fake_model.to.assert_called_once_with(cpu_device)
|
|
|
|
|
|
def test_load_model_torch_device_specified(mocker):
    """An explicitly passed device is expected to be the one the model is
    moved to (exactly one .to() call)."""
    # torch.device("cuda") is only a descriptor; creating it needs no GPU.
    target_device = torch.device("cuda")
    fake_model = MagicMock(spec=torch.nn.Module)
    mocker.patch.object(torch, "load", return_value=fake_model)

    load_model_torch("model_path", device=target_device)

    fake_model.to.assert_called_once_with(target_device)
|
|
|
|
|
|
def test_load_model_torch_model_specified(mocker):
    """When a model instance is supplied and no strict flag is given, the
    object returned by torch.load is expected to be passed to the model's
    load_state_dict with strict=True."""
    fake_model = MagicMock(spec=torch.nn.Module)
    mocker.patch.object(torch, "load", return_value={"key": "value"})

    load_model_torch("model_path", model=fake_model)

    # Compare against a separate, equal literal rather than the patched value.
    expected_state = {"key": "value"}
    fake_model.load_state_dict.assert_called_once_with(
        expected_state, strict=True
    )
|
|
|
|
|
|
def test_load_model_torch_model_specified_strict_false(mocker):
    """Passing strict=False alongside a model instance is expected to be
    forwarded unchanged to the model's load_state_dict call."""
    fake_model = MagicMock(spec=torch.nn.Module)
    mocker.patch.object(torch, "load", return_value={"key": "value"})

    load_model_torch("model_path", model=fake_model, strict=False)

    # Compare against a separate, equal literal rather than the patched value.
    expected_state = {"key": "value"}
    fake_model.load_state_dict.assert_called_once_with(
        expected_state, strict=False
    )
|