from unittest.mock import Mock
import torch
import pytest
from swarms.models.timm import TimmModel, TimmModelInfo


@pytest.fixture
def sample_model_info():
    """Provide a ``TimmModelInfo`` for ``resnet18`` (pretrained, 3 input channels)."""
    info = TimmModelInfo(
        model_name="resnet18",
        pretrained=True,
        in_chans=3,
    )
    return info


def test_get_supported_models():
    """``_get_supported_models`` must hand back a list with at least one entry."""
    names = TimmModel()._get_supported_models()
    assert isinstance(names, list)
    assert len(names) >= 1


def test_create_model(sample_model_info):
    """``_create_model`` must return a ``torch.nn.Module`` for the fixture's model info."""
    handler = TimmModel()
    created = handler._create_model(sample_model_info)
    assert isinstance(created, torch.nn.Module)


def test_call(sample_model_info):
    """Calling the handler on a random 1x3x224x224 tensor must yield a ``torch.Size``."""
    batch = torch.randn(1, 3, 224, 224)
    handler = TimmModel()
    result = handler(sample_model_info, batch)
    assert isinstance(result, torch.Size)


@pytest.mark.parametrize(
    "model_name, pretrained, in_chans",
    [
        ("resnet18", True, 3),
        ("resnet50", False, 1),
        ("efficientnet_b0", True, 3),
    ],
)
def test_create_model_parameterized(model_name, pretrained, in_chans):
    """``_create_model`` must return a ``torch.nn.Module`` for each parametrized config."""
    handler = TimmModel()
    info = TimmModelInfo(model_name=model_name, pretrained=pretrained, in_chans=in_chans)
    built = handler._create_model(info)
    assert isinstance(built, torch.nn.Module)


@pytest.mark.parametrize(
    "model_name, pretrained, in_chans",
    [
        ("resnet18", True, 3),
        ("resnet50", False, 1),
        ("efficientnet_b0", True, 3),
    ],
)
def test_call_parameterized(model_name, pretrained, in_chans):
    """Calling the handler must yield a ``torch.Size`` for each parametrized config.

    The random input tensor uses ``in_chans`` as its channel dimension so it
    matches the channel count given to ``TimmModelInfo``.
    """
    info = TimmModelInfo(model_name=model_name, pretrained=pretrained, in_chans=in_chans)
    batch = torch.randn(1, in_chans, 224, 224)
    handler = TimmModel()
    result = handler(info, batch)
    assert isinstance(result, torch.Size)


def test_get_supported_models_mock():
    """Replace ``_get_supported_models`` with a ``Mock`` and check its canned return value."""
    handler = TimmModel()
    stub = Mock(return_value=["resnet18", "resnet50"])
    handler._get_supported_models = stub
    # NOTE(review): this only exercises the Mock itself, not TimmModel's real logic.
    assert handler._get_supported_models() == ["resnet18", "resnet50"]


def test_create_model_mock(sample_model_info):
    """Replace ``_create_model`` with a ``Mock`` returning a bare ``torch.nn.Module``."""
    handler = TimmModel()
    placeholder = torch.nn.Module()
    handler._create_model = Mock(return_value=placeholder)
    produced = handler._create_model(sample_model_info)
    assert isinstance(produced, torch.nn.Module)


def test_call_exception():
    """Calling the handler with model name ``"invalid_model"`` is expected to raise."""
    bad_info = TimmModelInfo(model_name="invalid_model", pretrained=True, in_chans=3)
    batch = torch.randn(1, 3, 224, 224)
    handler = TimmModel()
    with pytest.raises(Exception):
        handler(bad_info, batch)


def test_coverage():
    """Placeholder for coverage collection; always skipped.

    The previous body called ``pytest.main(["--cov=my_module", ...])`` from
    inside a running test. That starts a nested pytest session in the same
    process which, given no path arguments, can collect this module again and
    re-enter this very test. Its exit code was also discarded, so the test
    verified nothing, and ``my_module`` is not the module under test here
    (this file imports ``swarms.models.timm``).

    Coverage should be collected by the outer invocation instead, e.g.
    ``pytest --cov=swarms.models.timm --cov-report=html``.
    """
    pytest.skip(
        "coverage is collected by invoking pytest with --cov from the command "
        "line, not by calling pytest.main() inside a test"
    )


def test_environment_variable():
    """Build a ``TimmModelInfo`` from environment variables and run the handler.

    ``MODEL_NAME``, ``PRETRAINED`` and ``IN_CHANS`` are set only for the
    duration of the test, then the handler is called on a random tensor whose
    channel dimension is ``model_info.in_chans``; the result must be a
    ``torch.Size``.
    """
    import os
    from unittest.mock import patch

    env = {
        "MODEL_NAME": "resnet18",
        "PRETRAINED": "True",
        "IN_CHANS": "3",
    }

    # patch.dict restores os.environ on exit, so these variables do not leak
    # into tests that run afterwards.
    with patch.dict(os.environ, env):
        model_handler = TimmModel()
        model_info = TimmModelInfo(
            model_name=os.environ["MODEL_NAME"],
            # bool("False") is True for any non-empty string, so the flag is
            # parsed by comparing its text explicitly.
            pretrained=os.environ["PRETRAINED"].strip().lower()
            in ("1", "true", "yes"),
            in_chans=int(os.environ["IN_CHANS"]),
        )
        input_tensor = torch.randn(1, model_info.in_chans, 224, 224)
        output_shape = model_handler(model_info, input_tensor)

    assert isinstance(output_shape, torch.Size)


@pytest.mark.slow
def test_marked_slow():
    """Same check as a plain handler call, tagged with the ``slow`` marker."""
    info = TimmModelInfo(
        model_name="resnet18",
        pretrained=True,
        in_chans=3,
    )
    batch = torch.randn(1, 3, 224, 224)
    result = TimmModel()(info, batch)
    assert isinstance(result, torch.Size)


@pytest.mark.parametrize(
    "model_name, pretrained, in_chans",
    [
        ("resnet18", True, 3),
        ("resnet50", False, 1),
        ("efficientnet_b0", True, 3),
    ],
)
def test_marked_parameterized(model_name, pretrained, in_chans):
    """For each parametrized config, ``_create_model`` must return a ``torch.nn.Module``."""
    config = TimmModelInfo(model_name=model_name, pretrained=pretrained, in_chans=in_chans)
    network = TimmModel()._create_model(config)
    assert isinstance(network, torch.nn.Module)


def test_exception_testing():
    """The handler is expected to raise when given model name ``"invalid_model"``."""
    handler = TimmModel()
    batch = torch.randn(1, 3, 224, 224)
    unknown = TimmModelInfo(
        model_name="invalid_model",
        pretrained=True,
        in_chans=3,
    )
    with pytest.raises(Exception):
        handler(unknown, batch)


def test_parameterized_testing():
    """Call the handler once with a fixed ``resnet18`` config and expect a ``torch.Size``.

    Despite its name, this test is not parametrized.
    """
    config = TimmModelInfo(model_name="resnet18", pretrained=True, in_chans=3)
    batch = torch.randn(1, 3, 224, 224)
    handler = TimmModel()
    outcome = handler(config, batch)
    assert isinstance(outcome, torch.Size)


def test_use_mocks_and_monkeypatching():
    """Swap ``_create_model`` for a ``Mock`` and check the mocked return type."""
    config = TimmModelInfo(model_name="resnet18", pretrained=True, in_chans=3)
    handler = TimmModel()
    fake_module = torch.nn.Module()
    handler._create_model = Mock(return_value=fake_module)
    returned = handler._create_model(config)
    assert isinstance(returned, torch.nn.Module)


def test_coverage_report():
    """Placeholder for generating a coverage report; always skipped.

    The previous body called ``pytest.main(["--cov=my_module",
    "--cov-report=html"])`` from inside a running test. That launches a nested
    pytest session in the same process which, with no path arguments, can
    collect this module again and re-enter this test. The exit code was
    ignored, so nothing was asserted, and ``my_module`` is not the module
    imported by this file (``swarms.models.timm``).

    To produce the report, install ``pytest-cov`` and run pytest from the
    command line, e.g. ``pytest --cov=swarms.models.timm --cov-report=html``.
    """
    pytest.skip(
        "coverage reports are produced by invoking pytest with --cov from the "
        "command line, not by calling pytest.main() inside a test"
    )