From a9b91d46533616b2e62c36fee1677fc81adeab02 Mon Sep 17 00:00:00 2001 From: Kye Date: Sat, 23 Dec 2023 18:23:02 -0500 Subject: [PATCH] [FEAT][load_model_torch] --- swarms/utils/__init__.py | 7 ++-- swarms/utils/load_model_torch.py | 57 ++++++++++++++++++++++++++++++++ tests/utils/load_models_torch.py | 56 +++++++++++++++++++++++++++++++ 3 files changed, 117 insertions(+), 3 deletions(-) create mode 100644 swarms/utils/load_model_torch.py create mode 100644 tests/utils/load_models_torch.py diff --git a/swarms/utils/__init__.py b/swarms/utils/__init__.py index 7dedefec..46628aae 100644 --- a/swarms/utils/__init__.py +++ b/swarms/utils/__init__.py @@ -6,15 +6,16 @@ from swarms.utils.parse_code import ( from swarms.utils.pdf_to_text import pdf_to_text from swarms.utils.math_eval import math_eval from swarms.utils.llm_metrics_decorator import metrics_decorator - -# from swarms.utils.phoenix_handler import phoenix_trace_decorator +from swarms.utils.device_checker_cuda import check_device +from swarms.utils.load_model_torch import load_model_torch __all__ = [ "display_markdown_message", "SubprocessCodeInterpreter", "extract_code_in_backticks_in_string", "pdf_to_text", - # "phoenix_trace_decorator", "math_eval", "metrics_decorator", + "check_device", + "load_model_torch", ] diff --git a/swarms/utils/load_model_torch.py b/swarms/utils/load_model_torch.py new file mode 100644 index 00000000..53649e93 --- /dev/null +++ b/swarms/utils/load_model_torch.py @@ -0,0 +1,57 @@ +import torch +from torch import nn + + +def load_model_torch( + model_path: str = None, + device: torch.device = None, + model: nn.Module = None, + strict: bool = True, + map_location=None, + *args, + **kwargs, +) -> nn.Module: + """ + Load a PyTorch model from a given path and move it to the specified device. + + Args: + model_path (str): Path to the saved model file. + device (torch.device): Device to move the model to. + model (nn.Module): The model architecture, if the model file only contains the state dictionary. + strict (bool): Whether to strictly enforce that the keys in the state dictionary match the keys returned by the model's `state_dict()` function. + map_location (callable): A function to remap the storage locations of the loaded model. + *args: Additional arguments to pass to `torch.load`. + **kwargs: Additional keyword arguments to pass to `torch.load`. + + Returns: + nn.Module: The loaded model. + + Raises: + FileNotFoundError: If the model file is not found. + RuntimeError: If there is an error while loading the model. + """ + if device is None: + device = torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) + + try: + if model is None: + model = torch.load( + model_path, map_location=map_location, *args, **kwargs + ) + else: + model.load_state_dict( + torch.load( + model_path, + map_location=map_location, + *args, + **kwargs, + ), + strict=strict, + ) + return model.to(device) + except FileNotFoundError: + raise FileNotFoundError(f"Model file not found: {model_path}") + except RuntimeError as e: + raise RuntimeError(f"Error loading model: {str(e)}") diff --git a/tests/utils/load_models_torch.py b/tests/utils/load_models_torch.py new file mode 100644 index 00000000..15a66537 --- /dev/null +++ b/tests/utils/load_models_torch.py @@ -0,0 +1,56 @@ +import pytest +import torch +from unittest.mock import MagicMock +from swarms.utils.load_model_torch import load_model_torch + + +def test_load_model_torch_no_model_path(): + with pytest.raises(FileNotFoundError): + load_model_torch() + + +def test_load_model_torch_model_not_found(mocker): + mocker.patch("torch.load", side_effect=FileNotFoundError) + with pytest.raises(FileNotFoundError): + load_model_torch("non_existent_model_path") + + +def test_load_model_torch_runtime_error(mocker): + mocker.patch("torch.load", side_effect=RuntimeError) + with pytest.raises(RuntimeError): + load_model_torch("model_path") + + +def test_load_model_torch_no_device_specified(mocker): + mock_model = MagicMock(spec=torch.nn.Module) + mocker.patch("torch.load", return_value=mock_model) + mocker.patch("torch.cuda.is_available", return_value=False) + model = load_model_torch("model_path") + mock_model.to.assert_called_once_with(torch.device("cpu")) + + +def test_load_model_torch_device_specified(mocker): + mock_model = MagicMock(spec=torch.nn.Module) + mocker.patch("torch.load", return_value=mock_model) + model = load_model_torch( + "model_path", device=torch.device("cuda") + ) + mock_model.to.assert_called_once_with(torch.device("cuda")) + + +def test_load_model_torch_model_specified(mocker): + mock_model = MagicMock(spec=torch.nn.Module) + mocker.patch("torch.load", return_value={"key": "value"}) + load_model_torch("model_path", model=mock_model) + mock_model.load_state_dict.assert_called_once_with( + {"key": "value"}, strict=True + ) + + +def test_load_model_torch_model_specified_strict_false(mocker): + mock_model = MagicMock(spec=torch.nn.Module) + mocker.patch("torch.load", return_value={"key": "value"}) + load_model_torch("model_path", model=mock_model, strict=False) + mock_model.load_state_dict.assert_called_once_with( + {"key": "value"}, strict=False + )