import pytest
import torch
from torch import nn

from swarms.utils import load_model_torch


class DummyModel(nn.Module):
    """Minimal single-layer network used as a fixture by the tests below.

    It maps 10 input features to 2 output features through one linear layer.
    """

    def __init__(self):
        super().__init__()
        # The attribute name determines the state-dict keys
        # ("fc.weight" / "fc.bias"), so it must remain `fc`.
        self.fc = nn.Linear(in_features=10, out_features=2)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return its output."""
        projected = self.fc(x)
        return projected


# Test case 1: Test if model can be loaded successfully
def test_load_model_torch_success(tmp_path):
    """Round-trip a state dict through disk and check the weights survive.

    A ``DummyModel`` state dict is saved under ``tmp_path`` and loaded into a
    second, independently initialised ``DummyModel`` via ``load_model_torch``.
    The loaded model must have the right type *and* the saved parameters.
    """
    model = DummyModel()
    # Save the model to a temporary directory
    model_path = tmp_path / "model.pt"
    torch.save(model.state_dict(), model_path)

    # Load the weights into a fresh instance (its initial weights differ
    # from `model`'s because nn.Linear initialises randomly).
    model_loaded = load_model_torch(model_path, model=DummyModel())

    # Check if loaded model has the same architecture
    assert isinstance(
        model_loaded, DummyModel
    ), "Loaded model type mismatch."

    # The type check alone would also pass for the untouched instance that
    # was handed in, so verify that the saved parameters were really loaded.
    expected_state = model.state_dict()
    loaded_state = model_loaded.state_dict()
    assert (
        loaded_state.keys() == expected_state.keys()
    ), "Loaded state dict keys mismatch."
    for name, expected_tensor in expected_state.items():
        # Compare on CPU so the check does not depend on which device the
        # loader placed the model on.
        assert torch.equal(
            loaded_state[name].cpu(), expected_tensor.cpu()
        ), f"Loaded parameter '{name}' differs from the saved one."


# Test case 2: Test if function raises FileNotFoundError for non-existent file
def test_load_model_torch_file_not_found():
    """Loading from a path that does not exist must raise FileNotFoundError."""
    missing_path = "non_existent_model.pt"
    with pytest.raises(FileNotFoundError):
        load_model_torch(missing_path)


# Test case 3: Test if function catches and raises RuntimeError for invalid model file
def test_load_model_torch_invalid_file(tmp_path):
    """A file that is not a serialized model must surface as RuntimeError."""
    # Plain text is written where a checkpoint is expected.
    bogus_checkpoint = tmp_path / "invalid_model.pt"
    bogus_checkpoint.write_text("Invalid model file.")

    with pytest.raises(RuntimeError):
        load_model_torch(bogus_checkpoint)


# Test case 4: Test for handling of 'strict' parameter
def test_load_model_torch_strict_handling(tmp_path):
    """Check that the ``strict`` flag decides whether key mismatches fail.

    The saved state dict carries keys for an extra layer that ``DummyModel``
    does not define. ``strict`` governs exactly this kind of mismatch
    (missing/unexpected keys); a *shape* mismatch would be rejected by
    ``load_state_dict`` regardless of ``strict`` and therefore would not
    exercise the parameter at all.
    """
    # Create a model with an additional layer so the checkpoint contains
    # unexpected keys ("extra.weight", "extra.bias") on top of the "fc" ones.
    model = DummyModel()
    model.extra = nn.Linear(2, 2)
    model_path = tmp_path / "model.pt"
    torch.save(model.state_dict(), model_path)

    # With 'strict' set to True the unexpected keys must be rejected
    with pytest.raises(RuntimeError):
        load_model_torch(model_path, model=DummyModel(), strict=True)

    # With 'strict' set to False the same checkpoint must load, ignoring the
    # extra keys while still restoring the matching "fc" parameters.
    model_loaded = load_model_torch(
        model_path, model=DummyModel(), strict=False
    )
    assert torch.equal(
        model_loaded.fc.weight.detach().cpu(),
        model.fc.weight.detach().cpu(),
    ), "Matching parameters not loaded with strict=False."


# Test case 5: Test for 'device' parameter handling
def test_load_model_torch_device_handling(tmp_path):
    """A model loaded with an explicit ``device`` must live on that device."""
    checkpoint_source = DummyModel()
    checkpoint_path = tmp_path / "model.pt"
    torch.save(checkpoint_source.state_dict(), checkpoint_path)

    # Request an explicit device; CPU is used so the test runs on any machine.
    target_device = torch.device("cpu")
    restored = load_model_torch(
        checkpoint_path, model=DummyModel(), device=target_device
    )

    weight_device = restored.fc.weight.device
    assert (
        weight_device == target_device
    ), "Model not loaded to specified device."


# Test case 6: Testing for correct handling of '*args' and '**kwargs'
def test_load_model_torch_args_kwargs_handling(monkeypatch, tmp_path):
    """Extra keyword arguments must be forwarded to ``torch.load``.

    ``torch.load`` is replaced by a recording stub. The stub returns a real
    state dict so that whatever the loader does with the result of
    ``torch.load`` keeps working, and the assertions run after the call so a
    failure cannot be masked by exception handling inside the loader.
    """
    model = DummyModel()
    model_path = tmp_path / "model.pt"
    torch.save(model.state_dict(), model_path)

    # Captured before patching; handed back by the stub as the "loaded" data.
    state_dict = model.state_dict()
    recorded_kwargs = []

    def mock_torch_load(*args, **kwargs):
        recorded_kwargs.append(kwargs)
        return state_dict

    # Monkeypatch 'torch.load' to check if '*args' and '**kwargs' are passed correctly
    monkeypatch.setattr(torch, "load", mock_torch_load)
    load_model_torch(
        model_path, model=DummyModel(), pickle_module="dummy_module"
    )

    # Without this check the test would pass vacuously if the stub was
    # never invoked.
    assert recorded_kwargs, "'torch.load' was not called."
    for kwargs in recorded_kwargs:
        assert (
            "pickle_module" in kwargs
        ), "Keyword arguments not passed to 'torch.load'."


# Test case 7: Test for model loading on CPU if no GPU is available
def test_load_model_torch_cpu(tmp_path):
    """With CUDA reported as unavailable, the loaded model must be on CPU.

    ``torch.cuda.is_available`` is patched to return ``False`` for the
    duration of the load only; the patch is undone when the context exits so
    it cannot leak into other tests.
    """
    model = DummyModel()
    model_path = tmp_path / "model.pt"
    torch.save(model.state_dict(), model_path)

    def mock_torch_cuda_is_available():
        return False

    # Monkeypatch to simulate no GPU available. `MonkeyPatch.setattr` is an
    # instance method, so it needs a MonkeyPatch object; calling it on the
    # class would bind `torch.cuda` as `self` and raise instead of patching.
    with pytest.MonkeyPatch.context() as patcher:
        patcher.setattr(
            torch.cuda, "is_available", mock_torch_cuda_is_available
        )
        model_loaded = load_model_torch(model_path, model=DummyModel())

    # Ensure model is loaded on CPU
    assert next(model_loaded.parameters()).device.type == "cpu"