parent
385d2df93a
commit
a9b91d4653
@ -0,0 +1,57 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_torch(
|
||||||
|
model_path: str = None,
|
||||||
|
device: torch.device = None,
|
||||||
|
model: nn.Module = None,
|
||||||
|
strict: bool = True,
|
||||||
|
map_location=None,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
) -> nn.Module:
|
||||||
|
"""
|
||||||
|
Load a PyTorch model from a given path and move it to the specified device.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path (str): Path to the saved model file.
|
||||||
|
device (torch.device): Device to move the model to.
|
||||||
|
model (nn.Module): The model architecture, if the model file only contains the state dictionary.
|
||||||
|
strict (bool): Whether to strictly enforce that the keys in the state dictionary match the keys returned by the model's `state_dict()` function.
|
||||||
|
map_location (callable): A function to remap the storage locations of the loaded model.
|
||||||
|
*args: Additional arguments to pass to `torch.load`.
|
||||||
|
**kwargs: Additional keyword arguments to pass to `torch.load`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
nn.Module: The loaded model.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If the model file is not found.
|
||||||
|
RuntimeError: If there is an error while loading the model.
|
||||||
|
"""
|
||||||
|
if device is None:
|
||||||
|
device = torch.device(
|
||||||
|
"cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if model is None:
|
||||||
|
model = torch.load(
|
||||||
|
model_path, map_location=map_location, *args, **kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model.load_state_dict(
|
||||||
|
torch.load(
|
||||||
|
model_path,
|
||||||
|
map_location=map_location,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
),
|
||||||
|
strict=strict,
|
||||||
|
)
|
||||||
|
return model.to(device)
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise FileNotFoundError(f"Model file not found: {model_path}")
|
||||||
|
except RuntimeError as e:
|
||||||
|
raise RuntimeError(f"Error loading model: {str(e)}")
|
@ -0,0 +1,56 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
from swarms.utils.load_model_torch import load_model_torch
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_model_torch_no_model_path():
|
||||||
|
with pytest.raises(FileNotFoundError):
|
||||||
|
load_model_torch()
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_model_torch_model_not_found(mocker):
|
||||||
|
mocker.patch("torch.load", side_effect=FileNotFoundError)
|
||||||
|
with pytest.raises(FileNotFoundError):
|
||||||
|
load_model_torch("non_existent_model_path")
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_model_torch_runtime_error(mocker):
|
||||||
|
mocker.patch("torch.load", side_effect=RuntimeError)
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
load_model_torch("model_path")
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_model_torch_no_device_specified(mocker):
|
||||||
|
mock_model = MagicMock(spec=torch.nn.Module)
|
||||||
|
mocker.patch("torch.load", return_value=mock_model)
|
||||||
|
mocker.patch("torch.cuda.is_available", return_value=False)
|
||||||
|
model = load_model_torch("model_path")
|
||||||
|
mock_model.to.assert_called_once_with(torch.device("cpu"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_model_torch_device_specified(mocker):
|
||||||
|
mock_model = MagicMock(spec=torch.nn.Module)
|
||||||
|
mocker.patch("torch.load", return_value=mock_model)
|
||||||
|
model = load_model_torch(
|
||||||
|
"model_path", device=torch.device("cuda")
|
||||||
|
)
|
||||||
|
mock_model.to.assert_called_once_with(torch.device("cuda"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_model_torch_model_specified(mocker):
|
||||||
|
mock_model = MagicMock(spec=torch.nn.Module)
|
||||||
|
mocker.patch("torch.load", return_value={"key": "value"})
|
||||||
|
load_model_torch("model_path", model=mock_model)
|
||||||
|
mock_model.load_state_dict.assert_called_once_with(
|
||||||
|
{"key": "value"}, strict=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_model_torch_model_specified_strict_false(mocker):
|
||||||
|
mock_model = MagicMock(spec=torch.nn.Module)
|
||||||
|
mocker.patch("torch.load", return_value={"key": "value"})
|
||||||
|
load_model_torch("model_path", model=mock_model, strict=False)
|
||||||
|
mock_model.load_state_dict.assert_called_once_with(
|
||||||
|
{"key": "value"}, strict=False
|
||||||
|
)
|
Loading…
Reference in new issue