[FEAT][load_model_torch]

pull/334/head
Kye 1 year ago
parent 385d2df93a
commit a9b91d4653

@ -6,15 +6,16 @@ from swarms.utils.parse_code import (
from swarms.utils.pdf_to_text import pdf_to_text
from swarms.utils.math_eval import math_eval
from swarms.utils.llm_metrics_decorator import metrics_decorator
# from swarms.utils.phoenix_handler import phoenix_trace_decorator
from swarms.utils.device_checker_cuda import check_device
from swarms.utils.load_model_torch import load_model_torch
__all__ = [
"display_markdown_message",
"SubprocessCodeInterpreter",
"extract_code_in_backticks_in_string",
"pdf_to_text",
# "phoenix_trace_decorator",
"math_eval",
"metrics_decorator",
"check_device",
"load_model_torch",
]

@ -0,0 +1,57 @@
import torch
from torch import nn
def load_model_torch(
model_path: str = None,
device: torch.device = None,
model: nn.Module = None,
strict: bool = True,
map_location=None,
*args,
**kwargs,
) -> nn.Module:
"""
Load a PyTorch model from a given path and move it to the specified device.
Args:
model_path (str): Path to the saved model file.
device (torch.device): Device to move the model to.
model (nn.Module): The model architecture, if the model file only contains the state dictionary.
strict (bool): Whether to strictly enforce that the keys in the state dictionary match the keys returned by the model's `state_dict()` function.
map_location (callable): A function to remap the storage locations of the loaded model.
*args: Additional arguments to pass to `torch.load`.
**kwargs: Additional keyword arguments to pass to `torch.load`.
Returns:
nn.Module: The loaded model.
Raises:
FileNotFoundError: If the model file is not found.
RuntimeError: If there is an error while loading the model.
"""
if device is None:
device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
try:
if model is None:
model = torch.load(
model_path, map_location=map_location, *args, **kwargs
)
else:
model.load_state_dict(
torch.load(
model_path,
map_location=map_location,
*args,
**kwargs,
),
strict=strict,
)
return model.to(device)
except FileNotFoundError:
raise FileNotFoundError(f"Model file not found: {model_path}")
except RuntimeError as e:
raise RuntimeError(f"Error loading model: {str(e)}")

@ -0,0 +1,56 @@
import pytest
import torch
from unittest.mock import MagicMock
from swarms.utils.load_model_torch import load_model_torch
def test_load_model_torch_no_model_path():
with pytest.raises(FileNotFoundError):
load_model_torch()
def test_load_model_torch_model_not_found(mocker):
mocker.patch("torch.load", side_effect=FileNotFoundError)
with pytest.raises(FileNotFoundError):
load_model_torch("non_existent_model_path")
def test_load_model_torch_runtime_error(mocker):
mocker.patch("torch.load", side_effect=RuntimeError)
with pytest.raises(RuntimeError):
load_model_torch("model_path")
def test_load_model_torch_no_device_specified(mocker):
mock_model = MagicMock(spec=torch.nn.Module)
mocker.patch("torch.load", return_value=mock_model)
mocker.patch("torch.cuda.is_available", return_value=False)
model = load_model_torch("model_path")
mock_model.to.assert_called_once_with(torch.device("cpu"))
def test_load_model_torch_device_specified(mocker):
mock_model = MagicMock(spec=torch.nn.Module)
mocker.patch("torch.load", return_value=mock_model)
model = load_model_torch(
"model_path", device=torch.device("cuda")
)
mock_model.to.assert_called_once_with(torch.device("cuda"))
def test_load_model_torch_model_specified(mocker):
mock_model = MagicMock(spec=torch.nn.Module)
mocker.patch("torch.load", return_value={"key": "value"})
load_model_torch("model_path", model=mock_model)
mock_model.load_state_dict.assert_called_once_with(
{"key": "value"}, strict=True
)
def test_load_model_torch_model_specified_strict_false(mocker):
mock_model = MagicMock(spec=torch.nn.Module)
mocker.patch("torch.load", return_value={"key": "value"})
load_model_torch("model_path", model=mock_model, strict=False)
mock_model.load_state_dict.assert_called_once_with(
{"key": "value"}, strict=False
)
Loading…
Cancel
Save