diff --git a/tests/models/nougat.py b/tests/models/nougat.py new file mode 100644 index 00000000..e61a45af --- /dev/null +++ b/tests/models/nougat.py @@ -0,0 +1,206 @@ +import os +from unittest.mock import MagicMock, Mock, patch + +import pytest +import torch +from PIL import Image +from transformers import NougatProcessor, VisionEncoderDecoderModel + +from swarms.models.nougat import Nougat + + +@pytest.fixture +def setup_nougat(): + return Nougat() + + +def test_nougat_default_initialization(setup_nougat): + assert setup_nougat.model_name_or_path == "facebook/nougat-base" + assert setup_nougat.min_length == 1 + assert setup_nougat.max_new_tokens == 30 + + +def test_nougat_custom_initialization(): + nougat = Nougat(model_name_or_path="custom_path", min_length=10, max_new_tokens=50) + assert nougat.model_name_or_path == "custom_path" + assert nougat.min_length == 10 + assert nougat.max_new_tokens == 50 + + +def test_processor_initialization(setup_nougat): + assert isinstance(setup_nougat.processor, NougatProcessor) + + +def test_model_initialization(setup_nougat): + assert isinstance(setup_nougat.model, VisionEncoderDecoderModel) + + +@pytest.mark.parametrize( + "cuda_available, expected_device", [(True, "cuda"), (False, "cpu")] +) +def test_device_initialization(cuda_available, expected_device, monkeypatch): + monkeypatch.setattr( + torch, "cuda", Mock(is_available=Mock(return_value=cuda_available)) + ) + nougat = Nougat() + assert nougat.device == expected_device + + +def test_get_image_valid_path(setup_nougat): + with patch("PIL.Image.open") as mock_open: + mock_open.return_value = Mock(spec=Image.Image) + assert setup_nougat.get_image("valid_path") is not None + + +def test_get_image_invalid_path(setup_nougat): + with pytest.raises(FileNotFoundError): + setup_nougat.get_image("invalid_path") + + +@pytest.mark.parametrize( + "min_len, max_tokens", + [ + (1, 30), + (5, 40), + (10, 50), + ], +) +def test_model_call_with_diff_params(setup_nougat, min_len, max_tokens): + setup_nougat.min_length = min_len + setup_nougat.max_new_tokens = max_tokens + + with patch("PIL.Image.open") as mock_open: + mock_open.return_value = Mock(spec=Image.Image) + # Here, mocking other required methods or adding more complex logic would be necessary. + result = setup_nougat("valid_path") + assert isinstance(result, str) + + +def test_model_call_invalid_image_path(setup_nougat): + with pytest.raises(FileNotFoundError): + setup_nougat("invalid_path") + + +def test_model_call_mocked_output(setup_nougat): + with patch("PIL.Image.open") as mock_open: + mock_open.return_value = Mock(spec=Image.Image) + mock_model = MagicMock() + mock_model.generate.return_value = "mocked_output" + setup_nougat.model = mock_model + + result = setup_nougat("valid_path") + assert result == "mocked_output" + + +@pytest.fixture +def mock_processor_and_model(): + """Mock the NougatProcessor and VisionEncoderDecoderModel to simulate their behavior.""" + with patch( + "transformers.NougatProcessor.from_pretrained", return_value=Mock() + ), patch( + "transformers.VisionEncoderDecoderModel.from_pretrained", return_value=Mock() + ): + yield + + +@pytest.mark.usefixtures("mock_processor_and_model") +def test_nougat_with_sample_image_1(setup_nougat): + result = setup_nougat( + os.path.join( + "sample_images", + "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D", + ) + ) + assert isinstance(result, str) + + +@pytest.mark.usefixtures("mock_processor_and_model") +def test_nougat_with_sample_image_2(setup_nougat): + result = setup_nougat(os.path.join("sample_images", "test2.png")) + assert isinstance(result, str) + + +@pytest.mark.usefixtures("mock_processor_and_model") +def test_nougat_min_length_param(setup_nougat): + setup_nougat.min_length = 10 + result = setup_nougat( + os.path.join( + "sample_images", + "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D", + ) + ) + assert isinstance(result, str) + + +@pytest.mark.usefixtures("mock_processor_and_model") +def test_nougat_max_new_tokens_param(setup_nougat): + setup_nougat.max_new_tokens = 50 + result = setup_nougat( + os.path.join( + "sample_images", + "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D", + ) + ) + assert isinstance(result, str) + + +@pytest.mark.usefixtures("mock_processor_and_model") +def test_nougat_different_model_path(setup_nougat): + setup_nougat.model_name_or_path = "different/path" + result = setup_nougat( + os.path.join( + "sample_images", + "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D", + ) + ) + assert isinstance(result, str) + + +@pytest.mark.usefixtures("mock_processor_and_model") +def test_nougat_bad_image_path(setup_nougat): + with pytest.raises(Exception): # Adjust the exception type accordingly. + setup_nougat("bad_image_path.png") + + +@pytest.mark.usefixtures("mock_processor_and_model") +def test_nougat_image_large_size(setup_nougat): + result = setup_nougat( + os.path.join( + "sample_images", + "https://images.unsplash.com/photo-1697641039266-bfa00367f7cb?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDJ8SnBnNktpZGwtSGt8fGVufDB8fHx8fA%3D%3D", + ) + ) + assert isinstance(result, str) + + +@pytest.mark.usefixtures("mock_processor_and_model") +def test_nougat_image_small_size(setup_nougat): + result = setup_nougat( + os.path.join( + "sample_images", + "https://images.unsplash.com/photo-1697638626987-aa865b769276?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDd8SnBnNktpZGwtSGt8fGVufDB8fHx8fA%3D%3D", + ) + ) + assert isinstance(result, str) + + +@pytest.mark.usefixtures("mock_processor_and_model") +def test_nougat_image_varied_content(setup_nougat): + result = setup_nougat( + os.path.join( + "sample_images", + "https://images.unsplash.com/photo-1697469994783-b12bbd9c4cff?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDE0fEpwZzZLaWRsLUhrfHxlbnwwfHx8fHw%3D", + ) + ) + assert isinstance(result, str) + + +@pytest.mark.usefixtures("mock_processor_and_model") +def test_nougat_image_with_metadata(setup_nougat): + result = setup_nougat( + os.path.join( + "sample_images", + "https://images.unsplash.com/photo-1697273300766-5bbaa53ec2f0?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDE5fEpwZzZLaWRsLUhrfHxlbnwwfHx8fHw%3D", + ) + ) + assert isinstance(result, str)