From aed02ffaac9fbba6f1b9e32cab65070816a37937 Mon Sep 17 00:00:00 2001 From: Kye Date: Fri, 20 Oct 2023 14:44:29 -0400 Subject: [PATCH] model tests for biogpt, etc --- swarms/models/idefics.py | 2 +- tests/chunkers/basechunker.py | 18 ++- tests/models/biogpt.py | 207 ++++++++++++++++++++++++++++++++++ tests/models/fuyu.py | 27 ++++- tests/models/idefics.py | 56 ++++++--- tests/models/vilt.py | 45 ++++++-- 6 files changed, 319 insertions(+), 36 deletions(-) create mode 100644 tests/models/biogpt.py diff --git a/swarms/models/idefics.py b/swarms/models/idefics.py index 6bb8082d..73cb4991 100644 --- a/swarms/models/idefics.py +++ b/swarms/models/idefics.py @@ -39,7 +39,7 @@ class Idefics: # Usage ``` from swarms.models import idefics - + model = idefics() user_input = "User: What is in this image? https://upload.wikimedia.org/wikipedia/commons/8/86/Id%C3%A9fix.JPG" diff --git a/tests/chunkers/basechunker.py b/tests/chunkers/basechunker.py index b6ced038..f70705bc 100644 --- a/tests/chunkers/basechunker.py +++ b/tests/chunkers/basechunker.py @@ -1,5 +1,11 @@ import pytest -from swarms.chunkers.base import BaseChunker, TextArtifact, ChunkSeparator, OpenAiTokenizer # adjust the import paths accordingly +from swarms.chunkers.base import ( + BaseChunker, + TextArtifact, + ChunkSeparator, + OpenAiTokenizer, +) # adjust the import paths accordingly + # 1. Test Initialization def test_chunker_initialization(): @@ -7,14 +13,17 @@ def test_chunker_initialization(): assert isinstance(chunker, BaseChunker) assert chunker.max_tokens == chunker.tokenizer.max_tokens + def test_default_separators(): chunker = BaseChunker() assert chunker.separators == BaseChunker.DEFAULT_SEPARATORS + def test_default_tokenizer(): chunker = BaseChunker() assert isinstance(chunker.tokenizer, OpenAiTokenizer) + # 2. Test Basic Chunking @pytest.mark.parametrize( "input_text, expected_output", @@ -29,6 +38,7 @@ def test_basic_chunk(input_text, expected_output): result = chunker.chunk(input_text) assert result == expected_output + # 3. Test Chunking with Different Separators def test_custom_separators(): custom_separator = ChunkSeparator(";") @@ -38,6 +48,7 @@ def test_custom_separators(): result = chunker.chunk(input_text) assert result == expected_output + # 4. Test Recursive Chunking def test_recursive_chunking(): chunker = BaseChunker(max_tokens=5) @@ -47,22 +58,25 @@ def test_recursive_chunking(): TextArtifact("is a"), TextArtifact("more"), TextArtifact("complex"), - TextArtifact("text.") + TextArtifact("text."), ] result = chunker.chunk(input_text) assert result == expected_output + # 5. Test Edge Cases and Special Scenarios def test_empty_text(): chunker = BaseChunker() result = chunker.chunk("") assert result == [] + def test_whitespace_text(): chunker = BaseChunker() result = chunker.chunk(" ") assert result == [TextArtifact(" ")] + def test_single_word(): chunker = BaseChunker() result = chunker.chunk("Hello") diff --git a/tests/models/biogpt.py b/tests/models/biogpt.py new file mode 100644 index 00000000..29cbe86c --- /dev/null +++ b/tests/models/biogpt.py @@ -0,0 +1,207 @@ +from unittest.mock import patch + +# Import necessary modules +import pytest +import torch +from transformers import BioGptForCausalLM, BioGptTokenizer + + + +# Fixture for BioGPT instance +@pytest.fixture +def biogpt_instance(): + from swarms.models import ( + BioGPT, + ) + + return BioGPT() + + +# 36. Test if BioGPT provides a response for a simple biomedical question +def test_biomedical_response_1(biogpt_instance): + question = "What are the functions of the mitochondria?" + response = biogpt_instance(question) + assert response and isinstance(response, str) + + +# 37. Test for a genetics-based question +def test_genetics_response(biogpt_instance): + question = "Can you explain the Mendelian inheritance?" + response = biogpt_instance(question) + assert response and isinstance(response, str) + + +# 38. Test for a question about viruses +def test_virus_response(biogpt_instance): + question = "How do RNA viruses replicate?" + response = biogpt_instance(question) + assert response and isinstance(response, str) + + +# 39. Test for a cell biology related question +def test_cell_biology_response(biogpt_instance): + question = "Describe the cell cycle and its phases." + response = biogpt_instance(question) + assert response and isinstance(response, str) + + +# 40. Test for a question about protein structure +def test_protein_structure_response(biogpt_instance): + question = "What's the difference between alpha helix and beta sheet structures in proteins?" + response = biogpt_instance(question) + assert response and isinstance(response, str) + + +# 41. Test for a pharmacology question +def test_pharmacology_response(biogpt_instance): + question = "How do beta blockers work?" + response = biogpt_instance(question) + assert response and isinstance(response, str) + + +# 42. Test for an anatomy-based question +def test_anatomy_response(biogpt_instance): + question = "Describe the structure of the human heart." + response = biogpt_instance(question) + assert response and isinstance(response, str) + + +# 43. Test for a question about bioinformatics +def test_bioinformatics_response(biogpt_instance): + question = "What is a BLAST search?" + response = biogpt_instance(question) + assert response and isinstance(response, str) + + +# 44. Test for a neuroscience question +def test_neuroscience_response(biogpt_instance): + question = "Explain the function of synapses in the nervous system." + response = biogpt_instance(question) + assert response and isinstance(response, str) + + +# 45. Test for an immunology question +def test_immunology_response(biogpt_instance): + question = "What is the role of T cells in the immune response?" + response = biogpt_instance(question) + assert response and isinstance(response, str) + + +def test_init(bio_gpt): + assert bio_gpt.model_name == "microsoft/biogpt" + assert bio_gpt.max_length == 500 + assert bio_gpt.num_return_sequences == 5 + assert bio_gpt.do_sample is True + assert bio_gpt.min_length == 100 + + +def test_call(bio_gpt, monkeypatch): + def mock_pipeline(*args, **kwargs): + class MockGenerator: + def __call__(self, text, **kwargs): + return ["Generated text"] + + return MockGenerator() + + monkeypatch.setattr("transformers.pipeline", mock_pipeline) + result = bio_gpt("Input text") + assert result == ["Generated text"] + + +def test_get_features(bio_gpt): + features = bio_gpt.get_features("Input text") + assert "last_hidden_state" in features + + +def test_beam_search_decoding(bio_gpt): + generated_text = bio_gpt.beam_search_decoding("Input text") + assert isinstance(generated_text, str) + + +def test_set_pretrained_model(bio_gpt): + bio_gpt.set_pretrained_model("new_model") + assert bio_gpt.model_name == "new_model" + + +def test_get_config(bio_gpt): + config = bio_gpt.get_config() + assert "vocab_size" in config + + +def test_save_load_model(tmp_path, bio_gpt): + bio_gpt.save_model(tmp_path) + bio_gpt.load_from_path(tmp_path) + assert bio_gpt.model_name == "microsoft/biogpt" + + +def test_print_model(capsys, bio_gpt): + bio_gpt.print_model() + captured = capsys.readouterr() + assert "BioGptForCausalLM" in captured.out + + +# 26. Test if set_pretrained_model changes the model_name +def test_set_pretrained_model_name_change(biogpt_instance): + biogpt_instance.set_pretrained_model("new_model_name") + assert biogpt_instance.model_name == "new_model_name" + + +# 27. Test get_config return type +def test_get_config_return_type(biogpt_instance): + config = biogpt_instance.get_config() + assert isinstance(config, type(biogpt_instance.model.config)) + + +# 28. Test saving model functionality by checking if files are created +@patch.object(BioGptForCausalLM, "save_pretrained") +@patch.object(BioGptTokenizer, "save_pretrained") +def test_save_model(mock_save_model, mock_save_tokenizer, biogpt_instance): + path = "test_path" + biogpt_instance.save_model(path) + mock_save_model.assert_called_once_with(path) + mock_save_tokenizer.assert_called_once_with(path) + + +# 29. Test loading model from path +@patch.object(BioGptForCausalLM, "from_pretrained") +@patch.object(BioGptTokenizer, "from_pretrained") +def test_load_from_path(mock_load_model, mock_load_tokenizer, biogpt_instance): + path = "test_path" + biogpt_instance.load_from_path(path) + mock_load_model.assert_called_once_with(path) + mock_load_tokenizer.assert_called_once_with(path) + + +# 30. Test print_model doesn't raise any error +def test_print_model_metadata(biogpt_instance): + try: + biogpt_instance.print_model() + except Exception as e: + pytest.fail(f"print_model() raised an exception: {e}") + + +# 31. Test that beam_search_decoding uses the correct number of beams +@patch.object(BioGptForCausalLM, "generate") +def test_beam_search_decoding_num_beams(mock_generate, biogpt_instance): + biogpt_instance.beam_search_decoding("test_sentence", num_beams=7) + _, kwargs = mock_generate.call_args + assert kwargs["num_beams"] == 7 + + +# 32. Test if beam_search_decoding handles early_stopping +@patch.object(BioGptForCausalLM, "generate") +def test_beam_search_decoding_early_stopping(mock_generate, biogpt_instance): + biogpt_instance.beam_search_decoding("test_sentence", early_stopping=False) + _, kwargs = mock_generate.call_args + assert kwargs["early_stopping"] is False + + +# 33. Test get_features return type +def test_get_features_return_type(biogpt_instance): + result = biogpt_instance.get_features("This is a sample text.") + assert isinstance(result, torch.nn.modules.module.Module) + + +# 34. Test if default model is set correctly during initialization +def test_default_model_name(biogpt_instance): + assert biogpt_instance.model_name == "microsoft/biogpt" diff --git a/tests/models/fuyu.py b/tests/models/fuyu.py index fddb172a..9a26dbfb 100644 --- a/tests/models/fuyu.py +++ b/tests/models/fuyu.py @@ -5,11 +5,13 @@ from swarms.models import Fuyu from transformers import FuyuProcessor, FuyuImageProcessor from PIL import Image + # Basic test to ensure instantiation of class. def test_fuyu_initialization(): fuyu_instance = Fuyu() assert isinstance(fuyu_instance, Fuyu) + # Using parameterized testing for different init parameters. @pytest.mark.parametrize( "pretrained_path, device_map, max_new_tokens", @@ -24,73 +26,92 @@ def test_fuyu_parameters(pretrained_path, device_map, max_new_tokens): assert fuyu_instance.device_map == device_map assert fuyu_instance.max_new_tokens == max_new_tokens + # Fixture for creating a Fuyu instance. @pytest.fixture def fuyu_instance(): return Fuyu() + # Test using the fixture. def test_fuyu_processor_initialization(fuyu_instance): assert isinstance(fuyu_instance.processor, FuyuProcessor) assert isinstance(fuyu_instance.image_processor, FuyuImageProcessor) + # Test exception when providing an invalid image path. def test_invalid_image_path(fuyu_instance): with pytest.raises(FileNotFoundError): fuyu_instance("Hello", "invalid/path/to/image.png") + # Using monkeypatch to replace the Image.open method to simulate a failure. def test_image_open_failure(fuyu_instance, monkeypatch): - def mock_open(*args, **kwargs): raise Exception("Mocked failure") monkeypatch.setattr(Image, "open", mock_open) with pytest.raises(Exception, match="Mocked failure"): - fuyu_instance("Hello", "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D") + fuyu_instance( + "Hello", + "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D", + ) + # Marking a slow test. @pytest.mark.slow def test_fuyu_model_output(fuyu_instance): # This is a dummy test and may not be functional without real data. - output = fuyu_instance("Hello, my name is", "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D") + output = fuyu_instance( + "Hello, my name is", + "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D", + ) assert isinstance(output, str) + def test_tokenizer_type(fuyu_instance): assert "tokenizer" in dir(fuyu_instance) + def test_processor_has_image_processor_and_tokenizer(fuyu_instance): assert fuyu_instance.processor.image_processor == fuyu_instance.image_processor assert fuyu_instance.processor.tokenizer == fuyu_instance.tokenizer + def test_model_device_map(fuyu_instance): assert fuyu_instance.model.device_map == fuyu_instance.device_map + # Testing maximum tokens setting def test_max_new_tokens_setting(fuyu_instance): assert fuyu_instance.max_new_tokens == 7 + # Test if an exception is raised when invalid text is provided. def test_invalid_text_input(fuyu_instance): with pytest.raises(Exception): fuyu_instance(None, "path/to/image.png") + # Test if an exception is raised when empty text is provided. def test_empty_text_input(fuyu_instance): with pytest.raises(Exception): fuyu_instance("", "path/to/image.png") + # Test if an exception is raised when a very long text is provided. def test_very_long_text_input(fuyu_instance): with pytest.raises(Exception): fuyu_instance("A" * 10000, "path/to/image.png") + # Check model's default device map def test_default_device_map(): fuyu_instance = Fuyu() assert fuyu_instance.device_map == "cuda:0" + # Testing if processor is correctly initialized def test_processor_initialization(fuyu_instance): assert isinstance(fuyu_instance.processor, FuyuProcessor) diff --git a/tests/models/idefics.py b/tests/models/idefics.py index 2cfee1a4..610657bd 100644 --- a/tests/models/idefics.py +++ b/tests/models/idefics.py @@ -3,97 +3,117 @@ from unittest.mock import patch import torch from swarms.models.idefics import Idefics, IdeficsForVisionText2Text, AutoProcessor + @pytest.fixture def idefics_instance(): - with patch("torch.cuda.is_available", return_value=False): # Assuming tests are run on CPU for simplicity + with patch( + "torch.cuda.is_available", return_value=False + ): # Assuming tests are run on CPU for simplicity instance = Idefics() return instance + # Basic Tests def test_init_default(idefics_instance): assert idefics_instance.device == "cpu" assert idefics_instance.max_length == 100 assert not idefics_instance.chat_history + @pytest.mark.parametrize( "device,expected", [ - (None, "cpu"), + (None, "cpu"), ("cuda", "cuda"), ("cpu", "cpu"), - ] + ], ) def test_init_device(device, expected): - with patch("torch.cuda.is_available", return_value=True if expected == "cuda" else False): + with patch( + "torch.cuda.is_available", return_value=True if expected == "cuda" else False + ): instance = Idefics(device=device) assert instance.device == expected + # Test `run` method def test_run(idefics_instance): prompts = [["User: Test"]] - with patch.object(idefics_instance, "processor") as mock_processor, \ - patch.object(idefics_instance, "model") as mock_model: + with patch.object(idefics_instance, "processor") as mock_processor, patch.object( + idefics_instance, "model" + ) as mock_model: mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])} mock_model.generate.return_value = torch.tensor([1, 2, 3]) mock_processor.batch_decode.return_value = ["Test"] - + result = idefics_instance.run(prompts) - + assert result == ["Test"] + # Test `__call__` method (using the same logic as run for simplicity) def test_call(idefics_instance): prompts = [["User: Test"]] - with patch.object(idefics_instance, "processor") as mock_processor, \ - patch.object(idefics_instance, "model") as mock_model: + with patch.object(idefics_instance, "processor") as mock_processor, patch.object( + idefics_instance, "model" + ) as mock_model: mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])} mock_model.generate.return_value = torch.tensor([1, 2, 3]) mock_processor.batch_decode.return_value = ["Test"] - + result = idefics_instance(prompts) - + assert result == ["Test"] + # Test `chat` method def test_chat(idefics_instance): user_input = "User: Hello" response = "Model: Hi there!" with patch.object(idefics_instance, "run", return_value=[response]): result = idefics_instance.chat(user_input) - + assert result == response assert idefics_instance.chat_history == [user_input, response] + # Test `set_checkpoint` method def test_set_checkpoint(idefics_instance): new_checkpoint = "new_checkpoint" - with patch.object(IdeficsForVisionText2Text, "from_pretrained") as mock_from_pretrained, \ - patch.object(AutoProcessor, "from_pretrained"): + with patch.object( + IdeficsForVisionText2Text, "from_pretrained" + ) as mock_from_pretrained, patch.object(AutoProcessor, "from_pretrained"): idefics_instance.set_checkpoint(new_checkpoint) - + mock_from_pretrained.assert_called_with(new_checkpoint, torch_dtype=torch.bfloat16) + # Test `set_device` method def test_set_device(idefics_instance): new_device = "cuda" with patch.object(idefics_instance.model, "to"): idefics_instance.set_device(new_device) - + assert idefics_instance.device == new_device + # Test `set_max_length` method def test_set_max_length(idefics_instance): new_length = 150 idefics_instance.set_max_length(new_length) assert idefics_instance.max_length == new_length + # Test `clear_chat_history` method def test_clear_chat_history(idefics_instance): idefics_instance.chat_history = ["User: Test", "Model: Response"] idefics_instance.clear_chat_history() assert not idefics_instance.chat_history + # Exception Tests def test_run_with_empty_prompts(idefics_instance): - with pytest.raises(Exception): # Replace Exception with the actual exception that may arise for an empty prompt. + with pytest.raises( + Exception + ): # Replace Exception with the actual exception that may arise for an empty prompt. idefics_instance.run([]) diff --git a/tests/models/vilt.py b/tests/models/vilt.py index ba9fafba..b376f41b 100644 --- a/tests/models/vilt.py +++ b/tests/models/vilt.py @@ -2,20 +2,23 @@ import pytest from unittest.mock import patch, Mock from swarms.models.vilt import Vilt, Image, requests + # Fixture for Vilt instance @pytest.fixture def vilt_instance(): return Vilt() + # 1. Test Initialization def test_vilt_initialization(vilt_instance): assert isinstance(vilt_instance, Vilt) assert vilt_instance.processor is not None assert vilt_instance.model is not None + # 2. Test Model Predictions -@patch.object(requests, 'get') -@patch.object(Image, 'open') +@patch.object(requests, "get") +@patch.object(Image, "open") def test_vilt_prediction(mock_image_open, mock_requests_get, vilt_instance): mock_image = Mock() mock_image_open.return_value = mock_image @@ -23,13 +26,21 @@ def test_vilt_prediction(mock_image_open, mock_requests_get, vilt_instance): # It's a mock response, so no real answer expected with pytest.raises(Exception): # Ensure exception is more specific - vilt_instance("What is this image", "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80") + vilt_instance( + "What is this image", + "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80", + ) + # 3. Test Exception Handling for network -@patch.object(requests, 'get', side_effect=requests.RequestException("Network error")) +@patch.object(requests, "get", side_effect=requests.RequestException("Network error")) def test_vilt_network_exception(vilt_instance): with pytest.raises(requests.RequestException): - vilt_instance("What is this image", "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80") + vilt_instance( + "What is this image", + "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80", + ) + # Parameterized test cases for different inputs @pytest.mark.parametrize( @@ -39,26 +50,37 @@ def test_vilt_network_exception(vilt_instance): ("Who is in the image?", "http://example.com/image2.jpg"), ("Where was this picture taken?", "http://example.com/image3.jpg"), # ... Add more scenarios - ] + ], ) def test_vilt_various_inputs(text, image_url, vilt_instance): with pytest.raises(Exception): # Again, ensure exception is more specific vilt_instance(text, image_url) + # Test with invalid or empty text @pytest.mark.parametrize( "text,image_url", [ - ("", "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80"), - (None, "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80"), - (" ", "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80"), + ( + "", + "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80", + ), + ( + None, + "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80", + ), + ( + " ", + "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80", + ), # ... Add more scenarios - ] + ], ) def test_vilt_invalid_text(text, image_url, vilt_instance): with pytest.raises(ValueError): vilt_instance(text, image_url) + # Test with invalid or empty image_url @pytest.mark.parametrize( "text,image_url", @@ -66,9 +88,8 @@ def test_vilt_invalid_text(text, image_url, vilt_instance): ("What is this?", ""), ("Who is in the image?", None), ("Where was this picture taken?", " "), - ] + ], ) def test_vilt_invalid_image_url(text, image_url, vilt_instance): with pytest.raises(ValueError): vilt_instance(text, image_url) -