from unittest.mock import MagicMock

import torch

from swarms.utils.prep_torch_model_inference import (
    prep_torch_inference,
)


def test_prep_torch_inference_no_model_path():
    result = prep_torch_inference()
    assert result is None


def test_prep_torch_inference_model_not_found(mocker):
    mocker.patch(
        "swarms.utils.prep_torch_model_inference.load_model_torch",
        side_effect=FileNotFoundError,
    )
    result = prep_torch_inference("non_existent_model_path")
    assert result is None


def test_prep_torch_inference_runtime_error(mocker):
    mocker.patch(
        "swarms.utils.prep_torch_model_inference.load_model_torch",
        side_effect=RuntimeError,
    )
    result = prep_torch_inference("model_path")
    assert result is None


def test_prep_torch_inference_no_device_specified(mocker):
    mock_model = MagicMock(spec=torch.nn.Module)
    mocker.patch(
        "swarms.utils.prep_torch_model_inference.load_model_torch",
        return_value=mock_model,
    )
    prep_torch_inference("model_path")
    mock_model.eval.assert_called_once()


def test_prep_torch_inference_device_specified(mocker):
    mock_model = MagicMock(spec=torch.nn.Module)
    mocker.patch(
        "swarms.utils.prep_torch_model_inference.load_model_torch",
        return_value=mock_model,
    )
    prep_torch_inference("model_path", device=torch.device("cuda"))
    mock_model.eval.assert_called_once()