parent
a9b91d4653
commit
f39d722f2a
@ -0,0 +1,30 @@
|
|||||||
|
import torch
|
||||||
|
from swarms.utils.load_model_torch import load_model_torch
|
||||||
|
|
||||||
|
|
||||||
|
def prep_torch_inference(
|
||||||
|
model_path: str = None,
|
||||||
|
device: torch.device = None,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Prepare a Torch model for inference.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path (str): Path to the model file.
|
||||||
|
device (torch.device): Device to run the model on.
|
||||||
|
*args: Additional positional arguments.
|
||||||
|
**kwargs: Additional keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.nn.Module: The prepared model.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
model = load_model_torch(model_path, device)
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
except Exception as e:
|
||||||
|
# Add error handling code here
|
||||||
|
print(f"Error occurred while preparing Torch model: {e}")
|
||||||
|
return None
|
@ -0,0 +1,50 @@
|
|||||||
|
import torch
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
from swarms.utils.prep_torch_model_inference import (
|
||||||
|
prep_torch_inference,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_prep_torch_inference_no_model_path():
|
||||||
|
result = prep_torch_inference()
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_prep_torch_inference_model_not_found(mocker):
|
||||||
|
mocker.patch(
|
||||||
|
"swarms.utils.prep_torch_model_inference.load_model_torch",
|
||||||
|
side_effect=FileNotFoundError,
|
||||||
|
)
|
||||||
|
result = prep_torch_inference("non_existent_model_path")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_prep_torch_inference_runtime_error(mocker):
|
||||||
|
mocker.patch(
|
||||||
|
"swarms.utils.prep_torch_model_inference.load_model_torch",
|
||||||
|
side_effect=RuntimeError,
|
||||||
|
)
|
||||||
|
result = prep_torch_inference("model_path")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_prep_torch_inference_no_device_specified(mocker):
|
||||||
|
mock_model = MagicMock(spec=torch.nn.Module)
|
||||||
|
mocker.patch(
|
||||||
|
"swarms.utils.prep_torch_model_inference.load_model_torch",
|
||||||
|
return_value=mock_model,
|
||||||
|
)
|
||||||
|
prep_torch_inference("model_path")
|
||||||
|
mock_model.eval.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_prep_torch_inference_device_specified(mocker):
|
||||||
|
mock_model = MagicMock(spec=torch.nn.Module)
|
||||||
|
mocker.patch(
|
||||||
|
"swarms.utils.prep_torch_model_inference.load_model_torch",
|
||||||
|
return_value=mock_model,
|
||||||
|
)
|
||||||
|
prep_torch_inference(
|
||||||
|
"model_path", device=torch.device("cuda")
|
||||||
|
)
|
||||||
|
mock_model.eval.assert_called_once()
|
Loading…
Reference in new issue