model tests for biogpt, etc

Former-commit-id: aed02ffaac
discord-bot-framework
Kye 1 year ago
parent 85ec9bac7d
commit b0c1dd9e50

@ -1,5 +1,11 @@
import pytest import pytest
from swarms.chunkers.base import BaseChunker, TextArtifact, ChunkSeparator, OpenAiTokenizer # adjust the import paths accordingly from swarms.chunkers.base import (
BaseChunker,
TextArtifact,
ChunkSeparator,
OpenAiTokenizer,
) # adjust the import paths accordingly
# 1. Test Initialization # 1. Test Initialization
def test_chunker_initialization(): def test_chunker_initialization():
@ -7,14 +13,17 @@ def test_chunker_initialization():
assert isinstance(chunker, BaseChunker) assert isinstance(chunker, BaseChunker)
assert chunker.max_tokens == chunker.tokenizer.max_tokens assert chunker.max_tokens == chunker.tokenizer.max_tokens
def test_default_separators(): def test_default_separators():
chunker = BaseChunker() chunker = BaseChunker()
assert chunker.separators == BaseChunker.DEFAULT_SEPARATORS assert chunker.separators == BaseChunker.DEFAULT_SEPARATORS
def test_default_tokenizer(): def test_default_tokenizer():
chunker = BaseChunker() chunker = BaseChunker()
assert isinstance(chunker.tokenizer, OpenAiTokenizer) assert isinstance(chunker.tokenizer, OpenAiTokenizer)
# 2. Test Basic Chunking # 2. Test Basic Chunking
@pytest.mark.parametrize( @pytest.mark.parametrize(
"input_text, expected_output", "input_text, expected_output",
@ -29,6 +38,7 @@ def test_basic_chunk(input_text, expected_output):
result = chunker.chunk(input_text) result = chunker.chunk(input_text)
assert result == expected_output assert result == expected_output
# 3. Test Chunking with Different Separators # 3. Test Chunking with Different Separators
def test_custom_separators(): def test_custom_separators():
custom_separator = ChunkSeparator(";") custom_separator = ChunkSeparator(";")
@ -38,6 +48,7 @@ def test_custom_separators():
result = chunker.chunk(input_text) result = chunker.chunk(input_text)
assert result == expected_output assert result == expected_output
# 4. Test Recursive Chunking # 4. Test Recursive Chunking
def test_recursive_chunking(): def test_recursive_chunking():
chunker = BaseChunker(max_tokens=5) chunker = BaseChunker(max_tokens=5)
@ -47,22 +58,25 @@ def test_recursive_chunking():
TextArtifact("is a"), TextArtifact("is a"),
TextArtifact("more"), TextArtifact("more"),
TextArtifact("complex"), TextArtifact("complex"),
TextArtifact("text.") TextArtifact("text."),
] ]
result = chunker.chunk(input_text) result = chunker.chunk(input_text)
assert result == expected_output assert result == expected_output
# 5. Test Edge Cases and Special Scenarios # 5. Test Edge Cases and Special Scenarios
def test_empty_text(): def test_empty_text():
chunker = BaseChunker() chunker = BaseChunker()
result = chunker.chunk("") result = chunker.chunk("")
assert result == [] assert result == []
def test_whitespace_text(): def test_whitespace_text():
chunker = BaseChunker() chunker = BaseChunker()
result = chunker.chunk(" ") result = chunker.chunk(" ")
assert result == [TextArtifact(" ")] assert result == [TextArtifact(" ")]
def test_single_word(): def test_single_word():
chunker = BaseChunker() chunker = BaseChunker()
result = chunker.chunk("Hello") result = chunker.chunk("Hello")

@ -0,0 +1,207 @@
from unittest.mock import patch
# Import necessary modules
import pytest
import torch
from transformers import BioGptForCausalLM, BioGptTokenizer
# Fixture for BioGPT instance
@pytest.fixture
def biogpt_instance():
from swarms.models import (
BioGPT,
)
return BioGPT()
# 36. Test if BioGPT provides a response for a simple biomedical question
def test_biomedical_response_1(biogpt_instance):
question = "What are the functions of the mitochondria?"
response = biogpt_instance(question)
assert response and isinstance(response, str)
# 37. Test for a genetics-based question
def test_genetics_response(biogpt_instance):
question = "Can you explain the Mendelian inheritance?"
response = biogpt_instance(question)
assert response and isinstance(response, str)
# 38. Test for a question about viruses
def test_virus_response(biogpt_instance):
question = "How do RNA viruses replicate?"
response = biogpt_instance(question)
assert response and isinstance(response, str)
# 39. Test for a cell biology related question
def test_cell_biology_response(biogpt_instance):
question = "Describe the cell cycle and its phases."
response = biogpt_instance(question)
assert response and isinstance(response, str)
# 40. Test for a question about protein structure
def test_protein_structure_response(biogpt_instance):
question = "What's the difference between alpha helix and beta sheet structures in proteins?"
response = biogpt_instance(question)
assert response and isinstance(response, str)
# 41. Test for a pharmacology question
def test_pharmacology_response(biogpt_instance):
question = "How do beta blockers work?"
response = biogpt_instance(question)
assert response and isinstance(response, str)
# 42. Test for an anatomy-based question
def test_anatomy_response(biogpt_instance):
question = "Describe the structure of the human heart."
response = biogpt_instance(question)
assert response and isinstance(response, str)
# 43. Test for a question about bioinformatics
def test_bioinformatics_response(biogpt_instance):
question = "What is a BLAST search?"
response = biogpt_instance(question)
assert response and isinstance(response, str)
# 44. Test for a neuroscience question
def test_neuroscience_response(biogpt_instance):
question = "Explain the function of synapses in the nervous system."
response = biogpt_instance(question)
assert response and isinstance(response, str)
# 45. Test for an immunology question
def test_immunology_response(biogpt_instance):
question = "What is the role of T cells in the immune response?"
response = biogpt_instance(question)
assert response and isinstance(response, str)
def test_init(bio_gpt):
assert bio_gpt.model_name == "microsoft/biogpt"
assert bio_gpt.max_length == 500
assert bio_gpt.num_return_sequences == 5
assert bio_gpt.do_sample is True
assert bio_gpt.min_length == 100
def test_call(bio_gpt, monkeypatch):
def mock_pipeline(*args, **kwargs):
class MockGenerator:
def __call__(self, text, **kwargs):
return ["Generated text"]
return MockGenerator()
monkeypatch.setattr("transformers.pipeline", mock_pipeline)
result = bio_gpt("Input text")
assert result == ["Generated text"]
def test_get_features(bio_gpt):
features = bio_gpt.get_features("Input text")
assert "last_hidden_state" in features
def test_beam_search_decoding(bio_gpt):
generated_text = bio_gpt.beam_search_decoding("Input text")
assert isinstance(generated_text, str)
def test_set_pretrained_model(bio_gpt):
bio_gpt.set_pretrained_model("new_model")
assert bio_gpt.model_name == "new_model"
def test_get_config(bio_gpt):
config = bio_gpt.get_config()
assert "vocab_size" in config
def test_save_load_model(tmp_path, bio_gpt):
bio_gpt.save_model(tmp_path)
bio_gpt.load_from_path(tmp_path)
assert bio_gpt.model_name == "microsoft/biogpt"
def test_print_model(capsys, bio_gpt):
bio_gpt.print_model()
captured = capsys.readouterr()
assert "BioGptForCausalLM" in captured.out
# 26. Test if set_pretrained_model changes the model_name
def test_set_pretrained_model_name_change(biogpt_instance):
biogpt_instance.set_pretrained_model("new_model_name")
assert biogpt_instance.model_name == "new_model_name"
# 27. Test get_config return type
def test_get_config_return_type(biogpt_instance):
config = biogpt_instance.get_config()
assert isinstance(config, type(biogpt_instance.model.config))
# 28. Test saving model functionality by checking if files are created
@patch.object(BioGptForCausalLM, "save_pretrained")
@patch.object(BioGptTokenizer, "save_pretrained")
def test_save_model(mock_save_model, mock_save_tokenizer, biogpt_instance):
path = "test_path"
biogpt_instance.save_model(path)
mock_save_model.assert_called_once_with(path)
mock_save_tokenizer.assert_called_once_with(path)
# 29. Test loading model from path
@patch.object(BioGptForCausalLM, "from_pretrained")
@patch.object(BioGptTokenizer, "from_pretrained")
def test_load_from_path(mock_load_model, mock_load_tokenizer, biogpt_instance):
path = "test_path"
biogpt_instance.load_from_path(path)
mock_load_model.assert_called_once_with(path)
mock_load_tokenizer.assert_called_once_with(path)
# 30. Test print_model doesn't raise any error
def test_print_model_metadata(biogpt_instance):
try:
biogpt_instance.print_model()
except Exception as e:
pytest.fail(f"print_model() raised an exception: {e}")
# 31. Test that beam_search_decoding uses the correct number of beams
@patch.object(BioGptForCausalLM, "generate")
def test_beam_search_decoding_num_beams(mock_generate, biogpt_instance):
biogpt_instance.beam_search_decoding("test_sentence", num_beams=7)
_, kwargs = mock_generate.call_args
assert kwargs["num_beams"] == 7
# 32. Test if beam_search_decoding handles early_stopping
@patch.object(BioGptForCausalLM, "generate")
def test_beam_search_decoding_early_stopping(mock_generate, biogpt_instance):
biogpt_instance.beam_search_decoding("test_sentence", early_stopping=False)
_, kwargs = mock_generate.call_args
assert kwargs["early_stopping"] is False
# 33. Test get_features return type
def test_get_features_return_type(biogpt_instance):
result = biogpt_instance.get_features("This is a sample text.")
assert isinstance(result, torch.nn.modules.module.Module)
# 34. Test if default model is set correctly during initialization
def test_default_model_name(biogpt_instance):
assert biogpt_instance.model_name == "microsoft/biogpt"

@ -5,11 +5,13 @@ from swarms.models import Fuyu
from transformers import FuyuProcessor, FuyuImageProcessor from transformers import FuyuProcessor, FuyuImageProcessor
from PIL import Image from PIL import Image
# Basic test to ensure instantiation of class. # Basic test to ensure instantiation of class.
def test_fuyu_initialization(): def test_fuyu_initialization():
fuyu_instance = Fuyu() fuyu_instance = Fuyu()
assert isinstance(fuyu_instance, Fuyu) assert isinstance(fuyu_instance, Fuyu)
# Using parameterized testing for different init parameters. # Using parameterized testing for different init parameters.
@pytest.mark.parametrize( @pytest.mark.parametrize(
"pretrained_path, device_map, max_new_tokens", "pretrained_path, device_map, max_new_tokens",
@ -24,73 +26,92 @@ def test_fuyu_parameters(pretrained_path, device_map, max_new_tokens):
assert fuyu_instance.device_map == device_map assert fuyu_instance.device_map == device_map
assert fuyu_instance.max_new_tokens == max_new_tokens assert fuyu_instance.max_new_tokens == max_new_tokens
# Fixture for creating a Fuyu instance. # Fixture for creating a Fuyu instance.
@pytest.fixture @pytest.fixture
def fuyu_instance(): def fuyu_instance():
return Fuyu() return Fuyu()
# Test using the fixture. # Test using the fixture.
def test_fuyu_processor_initialization(fuyu_instance): def test_fuyu_processor_initialization(fuyu_instance):
assert isinstance(fuyu_instance.processor, FuyuProcessor) assert isinstance(fuyu_instance.processor, FuyuProcessor)
assert isinstance(fuyu_instance.image_processor, FuyuImageProcessor) assert isinstance(fuyu_instance.image_processor, FuyuImageProcessor)
# Test exception when providing an invalid image path. # Test exception when providing an invalid image path.
def test_invalid_image_path(fuyu_instance): def test_invalid_image_path(fuyu_instance):
with pytest.raises(FileNotFoundError): with pytest.raises(FileNotFoundError):
fuyu_instance("Hello", "invalid/path/to/image.png") fuyu_instance("Hello", "invalid/path/to/image.png")
# Using monkeypatch to replace the Image.open method to simulate a failure. # Using monkeypatch to replace the Image.open method to simulate a failure.
def test_image_open_failure(fuyu_instance, monkeypatch): def test_image_open_failure(fuyu_instance, monkeypatch):
def mock_open(*args, **kwargs): def mock_open(*args, **kwargs):
raise Exception("Mocked failure") raise Exception("Mocked failure")
monkeypatch.setattr(Image, "open", mock_open) monkeypatch.setattr(Image, "open", mock_open)
with pytest.raises(Exception, match="Mocked failure"): with pytest.raises(Exception, match="Mocked failure"):
fuyu_instance("Hello", "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D") fuyu_instance(
"Hello",
"https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
)
# Marking a slow test. # Marking a slow test.
@pytest.mark.slow @pytest.mark.slow
def test_fuyu_model_output(fuyu_instance): def test_fuyu_model_output(fuyu_instance):
# This is a dummy test and may not be functional without real data. # This is a dummy test and may not be functional without real data.
output = fuyu_instance("Hello, my name is", "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D") output = fuyu_instance(
"Hello, my name is",
"https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
)
assert isinstance(output, str) assert isinstance(output, str)
def test_tokenizer_type(fuyu_instance): def test_tokenizer_type(fuyu_instance):
assert "tokenizer" in dir(fuyu_instance) assert "tokenizer" in dir(fuyu_instance)
def test_processor_has_image_processor_and_tokenizer(fuyu_instance): def test_processor_has_image_processor_and_tokenizer(fuyu_instance):
assert fuyu_instance.processor.image_processor == fuyu_instance.image_processor assert fuyu_instance.processor.image_processor == fuyu_instance.image_processor
assert fuyu_instance.processor.tokenizer == fuyu_instance.tokenizer assert fuyu_instance.processor.tokenizer == fuyu_instance.tokenizer
def test_model_device_map(fuyu_instance): def test_model_device_map(fuyu_instance):
assert fuyu_instance.model.device_map == fuyu_instance.device_map assert fuyu_instance.model.device_map == fuyu_instance.device_map
# Testing maximum tokens setting # Testing maximum tokens setting
def test_max_new_tokens_setting(fuyu_instance): def test_max_new_tokens_setting(fuyu_instance):
assert fuyu_instance.max_new_tokens == 7 assert fuyu_instance.max_new_tokens == 7
# Test if an exception is raised when invalid text is provided. # Test if an exception is raised when invalid text is provided.
def test_invalid_text_input(fuyu_instance): def test_invalid_text_input(fuyu_instance):
with pytest.raises(Exception): with pytest.raises(Exception):
fuyu_instance(None, "path/to/image.png") fuyu_instance(None, "path/to/image.png")
# Test if an exception is raised when empty text is provided. # Test if an exception is raised when empty text is provided.
def test_empty_text_input(fuyu_instance): def test_empty_text_input(fuyu_instance):
with pytest.raises(Exception): with pytest.raises(Exception):
fuyu_instance("", "path/to/image.png") fuyu_instance("", "path/to/image.png")
# Test if an exception is raised when a very long text is provided. # Test if an exception is raised when a very long text is provided.
def test_very_long_text_input(fuyu_instance): def test_very_long_text_input(fuyu_instance):
with pytest.raises(Exception): with pytest.raises(Exception):
fuyu_instance("A" * 10000, "path/to/image.png") fuyu_instance("A" * 10000, "path/to/image.png")
# Check model's default device map # Check model's default device map
def test_default_device_map(): def test_default_device_map():
fuyu_instance = Fuyu() fuyu_instance = Fuyu()
assert fuyu_instance.device_map == "cuda:0" assert fuyu_instance.device_map == "cuda:0"
# Testing if processor is correctly initialized # Testing if processor is correctly initialized
def test_processor_initialization(fuyu_instance): def test_processor_initialization(fuyu_instance):
assert isinstance(fuyu_instance.processor, FuyuProcessor) assert isinstance(fuyu_instance.processor, FuyuProcessor)

@ -3,36 +3,45 @@ from unittest.mock import patch
import torch import torch
from swarms.models.idefics import Idefics, IdeficsForVisionText2Text, AutoProcessor from swarms.models.idefics import Idefics, IdeficsForVisionText2Text, AutoProcessor
@pytest.fixture @pytest.fixture
def idefics_instance(): def idefics_instance():
with patch("torch.cuda.is_available", return_value=False): # Assuming tests are run on CPU for simplicity with patch(
"torch.cuda.is_available", return_value=False
): # Assuming tests are run on CPU for simplicity
instance = Idefics() instance = Idefics()
return instance return instance
# Basic Tests # Basic Tests
def test_init_default(idefics_instance): def test_init_default(idefics_instance):
assert idefics_instance.device == "cpu" assert idefics_instance.device == "cpu"
assert idefics_instance.max_length == 100 assert idefics_instance.max_length == 100
assert not idefics_instance.chat_history assert not idefics_instance.chat_history
@pytest.mark.parametrize( @pytest.mark.parametrize(
"device,expected", "device,expected",
[ [
(None, "cpu"), (None, "cpu"),
("cuda", "cuda"), ("cuda", "cuda"),
("cpu", "cpu"), ("cpu", "cpu"),
] ],
) )
def test_init_device(device, expected): def test_init_device(device, expected):
with patch("torch.cuda.is_available", return_value=True if expected == "cuda" else False): with patch(
"torch.cuda.is_available", return_value=True if expected == "cuda" else False
):
instance = Idefics(device=device) instance = Idefics(device=device)
assert instance.device == expected assert instance.device == expected
# Test `run` method # Test `run` method
def test_run(idefics_instance): def test_run(idefics_instance):
prompts = [["User: Test"]] prompts = [["User: Test"]]
with patch.object(idefics_instance, "processor") as mock_processor, \ with patch.object(idefics_instance, "processor") as mock_processor, patch.object(
patch.object(idefics_instance, "model") as mock_model: idefics_instance, "model"
) as mock_model:
mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])} mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])}
mock_model.generate.return_value = torch.tensor([1, 2, 3]) mock_model.generate.return_value = torch.tensor([1, 2, 3])
mock_processor.batch_decode.return_value = ["Test"] mock_processor.batch_decode.return_value = ["Test"]
@ -41,11 +50,13 @@ def test_run(idefics_instance):
assert result == ["Test"] assert result == ["Test"]
# Test `__call__` method (using the same logic as run for simplicity) # Test `__call__` method (using the same logic as run for simplicity)
def test_call(idefics_instance): def test_call(idefics_instance):
prompts = [["User: Test"]] prompts = [["User: Test"]]
with patch.object(idefics_instance, "processor") as mock_processor, \ with patch.object(idefics_instance, "processor") as mock_processor, patch.object(
patch.object(idefics_instance, "model") as mock_model: idefics_instance, "model"
) as mock_model:
mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])} mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])}
mock_model.generate.return_value = torch.tensor([1, 2, 3]) mock_model.generate.return_value = torch.tensor([1, 2, 3])
mock_processor.batch_decode.return_value = ["Test"] mock_processor.batch_decode.return_value = ["Test"]
@ -54,6 +65,7 @@ def test_call(idefics_instance):
assert result == ["Test"] assert result == ["Test"]
# Test `chat` method # Test `chat` method
def test_chat(idefics_instance): def test_chat(idefics_instance):
user_input = "User: Hello" user_input = "User: Hello"
@ -64,15 +76,18 @@ def test_chat(idefics_instance):
assert result == response assert result == response
assert idefics_instance.chat_history == [user_input, response] assert idefics_instance.chat_history == [user_input, response]
# Test `set_checkpoint` method # Test `set_checkpoint` method
def test_set_checkpoint(idefics_instance): def test_set_checkpoint(idefics_instance):
new_checkpoint = "new_checkpoint" new_checkpoint = "new_checkpoint"
with patch.object(IdeficsForVisionText2Text, "from_pretrained") as mock_from_pretrained, \ with patch.object(
patch.object(AutoProcessor, "from_pretrained"): IdeficsForVisionText2Text, "from_pretrained"
) as mock_from_pretrained, patch.object(AutoProcessor, "from_pretrained"):
idefics_instance.set_checkpoint(new_checkpoint) idefics_instance.set_checkpoint(new_checkpoint)
mock_from_pretrained.assert_called_with(new_checkpoint, torch_dtype=torch.bfloat16) mock_from_pretrained.assert_called_with(new_checkpoint, torch_dtype=torch.bfloat16)
# Test `set_device` method # Test `set_device` method
def test_set_device(idefics_instance): def test_set_device(idefics_instance):
new_device = "cuda" new_device = "cuda"
@ -81,19 +96,24 @@ def test_set_device(idefics_instance):
assert idefics_instance.device == new_device assert idefics_instance.device == new_device
# Test `set_max_length` method # Test `set_max_length` method
def test_set_max_length(idefics_instance): def test_set_max_length(idefics_instance):
new_length = 150 new_length = 150
idefics_instance.set_max_length(new_length) idefics_instance.set_max_length(new_length)
assert idefics_instance.max_length == new_length assert idefics_instance.max_length == new_length
# Test `clear_chat_history` method # Test `clear_chat_history` method
def test_clear_chat_history(idefics_instance): def test_clear_chat_history(idefics_instance):
idefics_instance.chat_history = ["User: Test", "Model: Response"] idefics_instance.chat_history = ["User: Test", "Model: Response"]
idefics_instance.clear_chat_history() idefics_instance.clear_chat_history()
assert not idefics_instance.chat_history assert not idefics_instance.chat_history
# Exception Tests # Exception Tests
def test_run_with_empty_prompts(idefics_instance): def test_run_with_empty_prompts(idefics_instance):
with pytest.raises(Exception): # Replace Exception with the actual exception that may arise for an empty prompt. with pytest.raises(
Exception
): # Replace Exception with the actual exception that may arise for an empty prompt.
idefics_instance.run([]) idefics_instance.run([])

@ -2,20 +2,23 @@ import pytest
from unittest.mock import patch, Mock from unittest.mock import patch, Mock
from swarms.models.vilt import Vilt, Image, requests from swarms.models.vilt import Vilt, Image, requests
# Fixture for Vilt instance # Fixture for Vilt instance
@pytest.fixture @pytest.fixture
def vilt_instance(): def vilt_instance():
return Vilt() return Vilt()
# 1. Test Initialization # 1. Test Initialization
def test_vilt_initialization(vilt_instance): def test_vilt_initialization(vilt_instance):
assert isinstance(vilt_instance, Vilt) assert isinstance(vilt_instance, Vilt)
assert vilt_instance.processor is not None assert vilt_instance.processor is not None
assert vilt_instance.model is not None assert vilt_instance.model is not None
# 2. Test Model Predictions # 2. Test Model Predictions
@patch.object(requests, 'get') @patch.object(requests, "get")
@patch.object(Image, 'open') @patch.object(Image, "open")
def test_vilt_prediction(mock_image_open, mock_requests_get, vilt_instance): def test_vilt_prediction(mock_image_open, mock_requests_get, vilt_instance):
mock_image = Mock() mock_image = Mock()
mock_image_open.return_value = mock_image mock_image_open.return_value = mock_image
@ -23,13 +26,21 @@ def test_vilt_prediction(mock_image_open, mock_requests_get, vilt_instance):
# It's a mock response, so no real answer expected # It's a mock response, so no real answer expected
with pytest.raises(Exception): # Ensure exception is more specific with pytest.raises(Exception): # Ensure exception is more specific
vilt_instance("What is this image", "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80") vilt_instance(
"What is this image",
"https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80",
)
# 3. Test Exception Handling for network # 3. Test Exception Handling for network
@patch.object(requests, 'get', side_effect=requests.RequestException("Network error")) @patch.object(requests, "get", side_effect=requests.RequestException("Network error"))
def test_vilt_network_exception(vilt_instance): def test_vilt_network_exception(vilt_instance):
with pytest.raises(requests.RequestException): with pytest.raises(requests.RequestException):
vilt_instance("What is this image", "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80") vilt_instance(
"What is this image",
"https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80",
)
# Parameterized test cases for different inputs # Parameterized test cases for different inputs
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -39,26 +50,37 @@ def test_vilt_network_exception(vilt_instance):
("Who is in the image?", "http://example.com/image2.jpg"), ("Who is in the image?", "http://example.com/image2.jpg"),
("Where was this picture taken?", "http://example.com/image3.jpg"), ("Where was this picture taken?", "http://example.com/image3.jpg"),
# ... Add more scenarios # ... Add more scenarios
] ],
) )
def test_vilt_various_inputs(text, image_url, vilt_instance): def test_vilt_various_inputs(text, image_url, vilt_instance):
with pytest.raises(Exception): # Again, ensure exception is more specific with pytest.raises(Exception): # Again, ensure exception is more specific
vilt_instance(text, image_url) vilt_instance(text, image_url)
# Test with invalid or empty text # Test with invalid or empty text
@pytest.mark.parametrize( @pytest.mark.parametrize(
"text,image_url", "text,image_url",
[ [
("", "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80"), (
(None, "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80"), "",
(" ", "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80"), "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80",
),
(
None,
"https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80",
),
(
" ",
"https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80",
),
# ... Add more scenarios # ... Add more scenarios
] ],
) )
def test_vilt_invalid_text(text, image_url, vilt_instance): def test_vilt_invalid_text(text, image_url, vilt_instance):
with pytest.raises(ValueError): with pytest.raises(ValueError):
vilt_instance(text, image_url) vilt_instance(text, image_url)
# Test with invalid or empty image_url # Test with invalid or empty image_url
@pytest.mark.parametrize( @pytest.mark.parametrize(
"text,image_url", "text,image_url",
@ -66,9 +88,8 @@ def test_vilt_invalid_text(text, image_url, vilt_instance):
("What is this?", ""), ("What is this?", ""),
("Who is in the image?", None), ("Who is in the image?", None),
("Where was this picture taken?", " "), ("Where was this picture taken?", " "),
] ],
) )
def test_vilt_invalid_image_url(text, image_url, vilt_instance): def test_vilt_invalid_image_url(text, image_url, vilt_instance):
with pytest.raises(ValueError): with pytest.raises(ValueError):
vilt_instance(text, image_url) vilt_instance(text, image_url)

Loading…
Cancel
Save