import unittest from unittest.mock import Mock import pytest import torch from swarms.utils import prep_torch_inference def test_prep_torch_inference(): model_path = "model_path" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model_mock = Mock() model_mock.eval = Mock() # Mocking the load_model_torch function to return our mock model. with unittest.mock.patch( "swarms.utils.load_model_torch", return_value=model_mock ) as _: model = prep_torch_inference(model_path, device) # Check if model was properly loaded and eval function was called assert model == model_mock model_mock.eval.assert_called_once() @pytest.mark.parametrize( "model_path, device", [ ( "invalid_path", torch.device("cuda"), ), # Invalid file path, valid device (None, torch.device("cuda")), # None file path, valid device ("model_path", None), # Valid file path, None device (None, None), # None file path, None device ], ) def test_prep_torch_inference_exceptions(model_path, device): with pytest.raises(Exception): prep_torch_inference(model_path, device) def test_prep_torch_inference_return_none(): model_path = "invalid_path" # Invalid file path device = torch.device("cuda") # Valid device # Since load_model_torch function will raise an exception, prep_torch_inference should return None assert prep_torch_inference(model_path, device) is None