from unittest.mock import MagicMock

import pytest
import torch

from swarms.utils.load_model_torch import load_model_torch


def test_load_model_torch_no_model_path():
    """Expect FileNotFoundError when the loader gets no arguments."""
    # NOTE(review): torch.load is not patched in this test, so the
    # outcome depends on the real loader's handling of the default
    # path -- confirm it surfaces as FileNotFoundError.
    expected_error = FileNotFoundError
    with pytest.raises(expected_error):
        load_model_torch()


def test_load_model_torch_model_not_found(mocker):
    """Expect FileNotFoundError when ``torch.load`` raises it."""
    # Replace torch.load, then attach the failure to the stand-in.
    patched_load = mocker.patch("torch.load")
    patched_load.side_effect = FileNotFoundError
    missing_path = "non_existent_model_path"
    with pytest.raises(FileNotFoundError):
        load_model_torch(missing_path)


def test_load_model_torch_runtime_error(mocker):
    """Expect RuntimeError when ``torch.load`` raises RuntimeError."""
    failing_load = mocker.patch("torch.load")
    failing_load.side_effect = RuntimeError
    checkpoint_path = "model_path"
    with pytest.raises(RuntimeError):
        load_model_torch(checkpoint_path)


def test_load_model_torch_no_device_specified(mocker):
    """Check the CPU device is used when none is given.

    CUDA is reported as unavailable and ``torch.load`` returns a mocked
    module; the test then expects exactly one ``.to(cpu)`` call on it.
    """
    fake_module = MagicMock(spec=torch.nn.Module)
    # Force the "no GPU" answer so the outcome is host-independent.
    mocker.patch("torch.cuda.is_available", return_value=False)
    mocker.patch("torch.load", return_value=fake_module)

    load_model_torch("model_path")

    expected_device = torch.device("cpu")
    fake_module.to.assert_called_once_with(expected_device)


def test_load_model_torch_device_specified(mocker):
    """Check an explicit ``device`` argument is passed to ``.to``.

    ``torch.load`` returns a mocked module; the test expects exactly one
    ``.to`` call on it, with a CUDA device.
    """
    requested_device = torch.device("cuda")
    stand_in = MagicMock(spec=torch.nn.Module)
    patched_load = mocker.patch("torch.load")
    patched_load.return_value = stand_in

    load_model_torch("model_path", device=requested_device)

    stand_in.to.assert_called_once_with(requested_device)


def test_load_model_torch_model_specified(mocker):
    """Check the loaded data is fed to a supplied model.

    ``torch.load`` returns a small dict; the test expects one
    ``load_state_dict`` call on the model with that dict and
    ``strict=True``.
    """
    supplied_model = MagicMock(spec=torch.nn.Module)
    patched_load = mocker.patch("torch.load")
    patched_load.return_value = {"key": "value"}

    load_model_torch("model_path", model=supplied_model)

    # Compared against a fresh literal, not the object handed out above.
    expected_state = {"key": "value"}
    supplied_model.load_state_dict.assert_called_once_with(
        expected_state, strict=True
    )


def test_load_model_torch_model_specified_strict_false(mocker):
    """Check ``strict=False`` reaches ``load_state_dict``.

    ``torch.load`` returns a small dict; the test expects one
    ``load_state_dict`` call on the model with that dict and
    ``strict=False``.
    """
    supplied_model = MagicMock(spec=torch.nn.Module)
    patched_load = mocker.patch("torch.load")
    patched_load.return_value = {"key": "value"}

    load_model_torch("model_path", model=supplied_model, strict=False)

    # Compared against a fresh literal, not the object handed out above.
    expected_state = {"key": "value"}
    supplied_model.load_state_dict.assert_called_once_with(
        expected_state, strict=False
    )