parent 385d2df93a
commit a9b91d4653
swarms/utils/load_model_torch.py
@@ -0,0 +1,57 @@
import torch
from torch import nn


def load_model_torch(
    model_path: str = None,
    device: torch.device = None,
    model: nn.Module = None,
    strict: bool = True,
    map_location=None,
    *args,
    **kwargs,
) -> nn.Module:
    """
    Load a PyTorch model from a given path and move it to the specified device.

    Args:
        model_path (str): Path to the saved model file.
        device (torch.device): Device to move the model to.
        model (nn.Module): The model architecture, if the model file only contains the state dictionary.
        strict (bool): Whether to strictly enforce that the keys in the state dictionary match the keys returned by the model's `state_dict()` function.
        map_location (callable): A function to remap the storage locations of the loaded model.
        *args: Additional arguments to pass to `torch.load`.
        **kwargs: Additional keyword arguments to pass to `torch.load`.

    Returns:
        nn.Module: The loaded model.

    Raises:
        FileNotFoundError: If the model file is not found or no path is given.
        RuntimeError: If there is an error while loading the model.
    """
    if model_path is None:
        # Fail early with the documented error rather than passing None to torch.load.
        raise FileNotFoundError("Model file not found: no model_path provided.")

    if device is None:
        # Default to GPU when one is available, otherwise fall back to CPU.
        device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

    try:
        if model is None:
            # The file is expected to contain a full serialized model.
            model = torch.load(
                model_path, *args, map_location=map_location, **kwargs
            )
        else:
            # The file is expected to contain only a state dict;
            # load it into the provided architecture.
            model.load_state_dict(
                torch.load(
                    model_path,
                    *args,
                    map_location=map_location,
                    **kwargs,
                ),
                strict=strict,
            )
        return model.to(device)
    except FileNotFoundError:
        raise FileNotFoundError(f"Model file not found: {model_path}")
    except RuntimeError as e:
        raise RuntimeError(f"Error loading model: {str(e)}") from e
@@ -0,0 +1,56 @@
import pytest
import torch
from unittest.mock import MagicMock
from swarms.utils.load_model_torch import load_model_torch


def test_load_model_torch_no_model_path():
    with pytest.raises(FileNotFoundError):
        load_model_torch()


def test_load_model_torch_model_not_found(mocker):
    mocker.patch("torch.load", side_effect=FileNotFoundError)
    with pytest.raises(FileNotFoundError):
        load_model_torch("non_existent_model_path")


def test_load_model_torch_runtime_error(mocker):
    mocker.patch("torch.load", side_effect=RuntimeError)
    with pytest.raises(RuntimeError):
        load_model_torch("model_path")


def test_load_model_torch_no_device_specified(mocker):
    mock_model = MagicMock(spec=torch.nn.Module)
    mocker.patch("torch.load", return_value=mock_model)
    mocker.patch("torch.cuda.is_available", return_value=False)
    model = load_model_torch("model_path")
    mock_model.to.assert_called_once_with(torch.device("cpu"))


def test_load_model_torch_device_specified(mocker):
    mock_model = MagicMock(spec=torch.nn.Module)
    mocker.patch("torch.load", return_value=mock_model)
    model = load_model_torch("model_path", device=torch.device("cuda"))
    mock_model.to.assert_called_once_with(torch.device("cuda"))


def test_load_model_torch_model_specified(mocker):
    mock_model = MagicMock(spec=torch.nn.Module)
    mocker.patch("torch.load", return_value={"key": "value"})
    load_model_torch("model_path", model=mock_model)
    mock_model.load_state_dict.assert_called_once_with(
        {"key": "value"}, strict=True
    )


def test_load_model_torch_model_specified_strict_false(mocker):
    mock_model = MagicMock(spec=torch.nn.Module)
    mocker.patch("torch.load", return_value={"key": "value"})
    load_model_torch("model_path", model=mock_model, strict=False)
    mock_model.load_state_dict.assert_called_once_with(
        {"key": "value"}, strict=False
    )
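The tests above rely on the pytest-mock plugin's mocker fixture. As a minimal sketch, assuming the same imports as this test module, the same patching pattern can be written with unittest.mock.patch directly if the plugin is unavailable:

from unittest.mock import MagicMock, patch


def test_load_model_torch_state_dict_plain_mock():
    # Equivalent to test_load_model_torch_model_specified, without pytest-mock.
    mock_model = MagicMock(spec=torch.nn.Module)
    with patch("torch.load", return_value={"key": "value"}):
        load_model_torch("model_path", model=mock_model)
    mock_model.load_state_dict.assert_called_once_with(
        {"key": "value"}, strict=True
    )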