"""Unit tests for ``swarms.models.TimmModel``.

NOTE(review): these tests exercise two different ``TimmModel`` APIs:
  * ``TimmModel(model_name, pretrained, in_chans)`` with ``__call__(task)``
    and ``list_models()``;
  * a no-argument ``TimmModel()`` with ``_get_supported_models()``,
    ``_create_model(info)`` and ``__call__(info, tensor)``.
Presumably only one of them matches the current implementation -- confirm
and drop the stale group.

Coverage is collected from the command line, not from inside a test, e.g.
``pytest --cov=swarms.models.timm --cov-report=html`` (needs ``pytest-cov``).
Invoking ``pytest.main`` from within a test re-collects the suite, including
that test itself, so it must not be done here.
"""

from unittest.mock import Mock, patch

import pytest
import torch

from swarms.models import TimmModel

# NOTE(review): the ``sample_model_info`` fixture used below is not defined in
# this file; presumably it is provided by a conftest.py -- verify.


def test_timm_model_init():
    """The constructor stores its arguments and keeps ``list_models()``'s result."""
    with patch("swarms.models.timm.list_models") as mock_list_models:
        model_name = "resnet18"
        pretrained = True
        in_chans = 3

        timm_model = TimmModel(model_name, pretrained, in_chans)

        mock_list_models.assert_called_once()
        assert timm_model.model_name == model_name
        assert timm_model.pretrained == pretrained
        assert timm_model.in_chans == in_chans
        assert timm_model.models == mock_list_models.return_value


def test_timm_model_call():
    """Calling the model builds it with ``create_model`` and applies it to the task."""
    with patch("swarms.models.timm.create_model") as mock_create_model:
        model_name = "resnet18"
        pretrained = True
        in_chans = 3
        timm_model = TimmModel(model_name, pretrained, in_chans)
        task = torch.rand(1, in_chans, 224, 224)

        result = timm_model(task)

        mock_create_model.assert_called_once_with(
            model_name, pretrained=pretrained, in_chans=in_chans
        )
        # A mock returns the same ``return_value`` object on every call, so
        # compare against it directly instead of calling the mocked model
        # again (which would record a spurious extra call).
        assert result is mock_create_model.return_value.return_value


def test_timm_model_list_models():
    """``list_models()`` returns whatever ``timm.list_models`` produced."""
    with patch("swarms.models.timm.list_models") as mock_list_models:
        model_name = "resnet18"
        pretrained = True
        in_chans = 3
        timm_model = TimmModel(model_name, pretrained, in_chans)

        result = timm_model.list_models()

        mock_list_models.assert_called_once()
        assert result == mock_list_models.return_value


def test_get_supported_models():
    """``_get_supported_models()`` returns a non-empty list."""
    model_handler = TimmModel()

    supported_models = model_handler._get_supported_models()

    assert isinstance(supported_models, list)
    assert len(supported_models) > 0


def test_create_model(sample_model_info):
    """``_create_model`` builds a ``torch.nn.Module``."""
    model_handler = TimmModel()

    model = model_handler._create_model(sample_model_info)

    assert isinstance(model, torch.nn.Module)


def test_call(sample_model_info):
    """Calling the handler on a batch yields a ``torch.Size``."""
    model_handler = TimmModel()
    input_tensor = torch.randn(1, 3, 224, 224)

    output_shape = model_handler(sample_model_info, input_tensor)

    assert isinstance(output_shape, torch.Size)


def test_get_supported_models_mock():
    """Replacing ``_get_supported_models`` with a Mock returns the stubbed list."""
    # NOTE(review): this only verifies the Mock installed on the line below,
    # not any TimmModel logic -- consider patching a dependency instead.
    model_handler = TimmModel()
    model_handler._get_supported_models = Mock(
        return_value=["resnet18", "resnet50"]
    )

    supported_models = model_handler._get_supported_models()

    assert supported_models == ["resnet18", "resnet50"]


def test_create_model_mock(sample_model_info):
    """Replacing ``_create_model`` with a Mock returns the stubbed module."""
    # NOTE(review): like the test above, this asserts on the Mock itself
    # rather than on TimmModel behaviour.
    model_handler = TimmModel()
    model_handler._create_model = Mock(return_value=torch.nn.Module())

    model = model_handler._create_model(sample_model_info)

    assert isinstance(model, torch.nn.Module)