parent
a9b91d4653
commit
f39d722f2a
@ -0,0 +1,30 @@
|
||||
import sys
from typing import Optional

import torch

from swarms.utils.load_model_torch import load_model_torch
|
||||
|
||||
|
||||
def prep_torch_inference(
    model_path: Optional[str] = None,
    device: Optional[torch.device] = None,
    *args,
    **kwargs,
):
    """
    Prepare a Torch model for inference.

    Loads the model from ``model_path`` via ``load_model_torch`` and switches
    it to evaluation mode (disabling dropout and freezing batch-norm
    statistics).

    Args:
        model_path (str, optional): Path to the model file.
        device (torch.device, optional): Device to run the model on.
        *args: Additional positional arguments (currently unused).
        **kwargs: Additional keyword arguments (currently unused).

    Returns:
        torch.nn.Module: The prepared model, or ``None`` if loading fails
        for any reason. Callers rely on this soft-failure (None) contract,
        so no exception is propagated.
    """
    try:
        model = load_model_torch(model_path, device)
        model.eval()
        return model
    except Exception as e:
        # Soft-fail by design: the test suite expects None rather than a
        # raised exception for missing paths / runtime errors. Report on
        # stderr so the message isn't lost in normal program output.
        print(
            f"Error occurred while preparing Torch model: {e}",
            file=sys.stderr,
        )
        return None
|
@ -0,0 +1,50 @@
|
||||
import torch
|
||||
from unittest.mock import MagicMock
|
||||
from swarms.utils.prep_torch_model_inference import (
|
||||
prep_torch_inference,
|
||||
)
|
||||
|
||||
|
||||
def test_prep_torch_inference_no_model_path():
    """Calling with no model path must fail softly and yield None."""
    assert prep_torch_inference() is None
|
||||
|
||||
|
||||
def test_prep_torch_inference_model_not_found(mocker):
    """A missing model file is swallowed and surfaces as a None result."""
    target = "swarms.utils.prep_torch_model_inference.load_model_torch"
    mocker.patch(target, side_effect=FileNotFoundError)

    outcome = prep_torch_inference("non_existent_model_path")

    assert outcome is None
|
||||
|
||||
|
||||
def test_prep_torch_inference_runtime_error(mocker):
    """A RuntimeError during loading is swallowed and yields None."""
    target = "swarms.utils.prep_torch_model_inference.load_model_torch"
    mocker.patch(target, side_effect=RuntimeError)

    outcome = prep_torch_inference("model_path")

    assert outcome is None
|
||||
|
||||
|
||||
def test_prep_torch_inference_no_device_specified(mocker):
    """With no device given, the model is loaded, returned, and put in eval mode."""
    mock_model = MagicMock(spec=torch.nn.Module)
    mocker.patch(
        "swarms.utils.prep_torch_model_inference.load_model_torch",
        return_value=mock_model,
    )

    result = prep_torch_inference("model_path")

    # The loaded model must be handed back unchanged, not just eval()'d.
    assert result is mock_model
    mock_model.eval.assert_called_once()
|
||||
|
||||
|
||||
def test_prep_torch_inference_device_specified(mocker):
    """With an explicit device, the model is loaded, returned, and put in eval mode."""
    mock_model = MagicMock(spec=torch.nn.Module)
    mocker.patch(
        "swarms.utils.prep_torch_model_inference.load_model_torch",
        return_value=mock_model,
    )

    result = prep_torch_inference(
        "model_path", device=torch.device("cuda")
    )

    # The loaded model must be handed back unchanged, not just eval()'d.
    assert result is mock_model
    mock_model.eval.assert_called_once()
|
Loading…
Reference in new issue