From dfea671d5ee988b7fc61784ee14a27a6e47ebac6 Mon Sep 17 00:00:00 2001 From: Kye Date: Tue, 14 Nov 2023 16:09:05 -0500 Subject: [PATCH] tests for yi, stable diffusion, timm models, etc --- swarms/models/autotemp.py | 2 +- swarms/models/simple_ada.py | 1 + tests/models/bioclip.py | 161 +++++++++++++++++++ tests/models/distill_whisper.py | 118 +++++++++++++- tests/models/distilled_whisperx.py | 119 -------------- tests/models/llama_function_caller.py | 115 +++++++++++++ tests/models/speech_t5.py | 139 ++++++++++++++++ tests/models/ssd_1b.py | 223 ++++++++++++++++++++++++++ tests/models/timm_model.py | 164 +++++++++++++++++++ tests/models/yi_200k.py | 106 ++++++++++++ 10 files changed, 1024 insertions(+), 124 deletions(-) create mode 100644 tests/models/bioclip.py delete mode 100644 tests/models/distilled_whisperx.py create mode 100644 tests/models/llama_function_caller.py create mode 100644 tests/models/speech_t5.py create mode 100644 tests/models/ssd_1b.py create mode 100644 tests/models/timm_model.py create mode 100644 tests/models/yi_200k.py diff --git a/swarms/models/autotemp.py b/swarms/models/autotemp.py index 3c89ad73..c3abb894 100644 --- a/swarms/models/autotemp.py +++ b/swarms/models/autotemp.py @@ -1,6 +1,6 @@ import re from concurrent.futures import ThreadPoolExecutor, as_completed -from swarms.models.auto_temp import OpenAIChat +from swarms.models.openai_models import OpenAIChat class AutoTempAgent: diff --git a/swarms/models/simple_ada.py b/swarms/models/simple_ada.py index 973adaea..6a0dbcc9 100644 --- a/swarms/models/simple_ada.py +++ b/swarms/models/simple_ada.py @@ -2,6 +2,7 @@ from openai import OpenAI client = OpenAI() + def get_ada_embeddings(text: str, model: str = "text-embedding-ada-002"): """ Simple function to get embeddings from ada diff --git a/tests/models/bioclip.py b/tests/models/bioclip.py new file mode 100644 index 00000000..50a65570 --- /dev/null +++ b/tests/models/bioclip.py @@ -0,0 +1,161 @@ +# Import necessary modules and define 
fixtures if needed +import os +import pytest +import torch +from PIL import Image +from swarms.models.bioclip import BioClip + + +# Define fixtures if needed +@pytest.fixture +def sample_image_path(): + return "path_to_sample_image.jpg" + + +@pytest.fixture +def clip_instance(): + return BioClip("microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224") + + +# Basic tests for the BioClip class +def test_clip_initialization(clip_instance): + assert isinstance(clip_instance.model, torch.nn.Module) + assert hasattr(clip_instance, "model_path") + assert hasattr(clip_instance, "preprocess_train") + assert hasattr(clip_instance, "preprocess_val") + assert hasattr(clip_instance, "tokenizer") + assert hasattr(clip_instance, "device") + + +def test_clip_call_method(clip_instance, sample_image_path): + labels = [ + "adenocarcinoma histopathology", + "brain MRI", + "covid line chart", + "squamous cell carcinoma histopathology", + "immunohistochemistry histopathology", + "bone X-ray", + "chest X-ray", + "pie chart", + "hematoxylin and eosin histopathology", + ] + result = clip_instance(sample_image_path, labels) + assert isinstance(result, dict) + assert len(result) == len(labels) + + +def test_clip_plot_image_with_metadata(clip_instance, sample_image_path): + metadata = { + "filename": "sample_image.jpg", + "top_probs": {"label1": 0.75, "label2": 0.65}, + } + clip_instance.plot_image_with_metadata(sample_image_path, metadata) + + +# More test cases can be added to cover additional functionality and edge cases + + +# Parameterized tests for different image and label combinations +@pytest.mark.parametrize( + "image_path, labels", + [ + ("image1.jpg", ["label1", "label2"]), + ("image2.jpg", ["label3", "label4"]), + # Add more image and label combinations + ], +) +def test_clip_parameterized_calls(clip_instance, image_path, labels): + result = clip_instance(image_path, labels) + assert isinstance(result, dict) + assert len(result) == len(labels) + + +# Test image preprocessing 
+def test_clip_image_preprocessing(clip_instance, sample_image_path): + image = Image.open(sample_image_path) + processed_image = clip_instance.preprocess_val(image) + assert isinstance(processed_image, torch.Tensor) + + +# Test label tokenization +def test_clip_label_tokenization(clip_instance): + labels = ["label1", "label2"] + tokenized_labels = clip_instance.tokenizer(labels) + assert isinstance(tokenized_labels, torch.Tensor) + assert tokenized_labels.shape[0] == len(labels) + + +# More tests can be added to cover other methods and edge cases + + +# End-to-end tests with actual images and labels +def test_clip_end_to_end(clip_instance, sample_image_path): + labels = [ + "adenocarcinoma histopathology", + "brain MRI", + "covid line chart", + "squamous cell carcinoma histopathology", + "immunohistochemistry histopathology", + "bone X-ray", + "chest X-ray", + "pie chart", + "hematoxylin and eosin histopathology", + ] + result = clip_instance(sample_image_path, labels) + assert isinstance(result, dict) + assert len(result) == len(labels) + + +# Test label tokenization with long labels +def test_clip_long_labels(clip_instance): + labels = ["label" + str(i) for i in range(100)] + tokenized_labels = clip_instance.tokenizer(labels) + assert isinstance(tokenized_labels, torch.Tensor) + assert tokenized_labels.shape[0] == len(labels) + + +# Test handling of multiple image files +def test_clip_multiple_images(clip_instance, sample_image_path): + labels = ["label1", "label2"] + image_paths = [sample_image_path, "image2.jpg"] + results = clip_instance(image_paths, labels) + assert isinstance(results, list) + assert len(results) == len(image_paths) + for result in results: + assert isinstance(result, dict) + assert len(result) == len(labels) + + +# Test model inference performance +def test_clip_inference_performance(clip_instance, sample_image_path, benchmark): + labels = [ + "adenocarcinoma histopathology", + "brain MRI", + "covid line chart", + "squamous cell carcinoma 
histopathology", + "immunohistochemistry histopathology", + "bone X-ray", + "chest X-ray", + "pie chart", + "hematoxylin and eosin histopathology", + ] + result = benchmark(clip_instance, sample_image_path, labels) + assert isinstance(result, dict) + assert len(result) == len(labels) + + +# Test different preprocessing pipelines +def test_clip_preprocessing_pipelines(clip_instance, sample_image_path): + labels = ["label1", "label2"] + image = Image.open(sample_image_path) + + # Test preprocessing for training + processed_image_train = clip_instance.preprocess_train(image) + assert isinstance(processed_image_train, torch.Tensor) + + # Test preprocessing for validation + processed_image_val = clip_instance.preprocess_val(image) + assert isinstance(processed_image_val, torch.Tensor) + + +# ... diff --git a/tests/models/distill_whisper.py b/tests/models/distill_whisper.py index 6fbfccd1..d83caf62 100644 --- a/tests/models/distill_whisper.py +++ b/tests/models/distill_whisper.py @@ -1,13 +1,14 @@ import os import tempfile from functools import wraps -from unittest.mock import patch +from unittest.mock import AsyncMock, MagicMock, patch import numpy as np import pytest import torch +from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor -from swarms.models.distill_whisperx import DistilWhisperModel, async_retry +from swarms.models.distilled_whisperx import DistilWhisperModel, async_retry @pytest.fixture @@ -150,5 +151,114 @@ def test_create_audio_file(): os.remove(audio_file_path) -if __name__ == "__main__": - pytest.main() +# test_distilled_whisperx.py + + +# Fixtures for setting up model, processor, and audio files +@pytest.fixture(scope="module") +def model_id(): + return "distil-whisper/distil-large-v2" + + +@pytest.fixture(scope="module") +def whisper_model(model_id): + return DistilWhisperModel(model_id) + + +@pytest.fixture(scope="session") +def audio_file_path(tmp_path_factory): + # You would create a small temporary MP3 file here for testing + # or 
use a public domain MP3 file's path + return "path/to/valid_audio.mp3" + + +@pytest.fixture(scope="session") +def invalid_audio_file_path(): + return "path/to/invalid_audio.mp3" + + +@pytest.fixture(scope="session") +def audio_dict(): + # This should represent a valid audio dictionary as expected by the model + return {"array": torch.randn(1, 16000), "sampling_rate": 16000} + + +# Test initialization +def test_initialization(whisper_model): + assert whisper_model.model is not None + assert whisper_model.processor is not None + + +# Test successful transcription with file path +def test_transcribe_with_file_path(whisper_model, audio_file_path): + transcription = whisper_model.transcribe(audio_file_path) + assert isinstance(transcription, str) + + +# Test successful transcription with audio dict +def test_transcribe_with_audio_dict(whisper_model, audio_dict): + transcription = whisper_model.transcribe(audio_dict) + assert isinstance(transcription, str) + + +# Test for file not found error +def test_file_not_found(whisper_model, invalid_audio_file_path): + with pytest.raises(Exception): + whisper_model.transcribe(invalid_audio_file_path) + + +# Asynchronous tests +@pytest.mark.asyncio +async def test_async_transcription_success(whisper_model, audio_file_path): + transcription = await whisper_model.async_transcribe(audio_file_path) + assert isinstance(transcription, str) + + +@pytest.mark.asyncio +async def test_async_transcription_failure(whisper_model, invalid_audio_file_path): + with pytest.raises(Exception): + await whisper_model.async_transcribe(invalid_audio_file_path) + + +# Testing real-time transcription simulation +def test_real_time_transcription(whisper_model, audio_file_path, capsys): + whisper_model.real_time_transcribe(audio_file_path, chunk_duration=1) + captured = capsys.readouterr() + assert "Starting real-time transcription..." 
in captured.out + + +# Testing retry decorator for asynchronous function +@pytest.mark.asyncio +async def test_async_retry(): + @async_retry(max_retries=2, exceptions=(ValueError,), delay=0) + async def failing_func(): + raise ValueError("Test") + + with pytest.raises(ValueError): + await failing_func() + + +# Mocking the actual model to avoid GPU/CPU intensive operations during test +@pytest.fixture +def mocked_model(monkeypatch): + model_mock = AsyncMock(AutoModelForSpeechSeq2Seq) + processor_mock = MagicMock(AutoProcessor) + monkeypatch.setattr( + "swarms.models.distilled_whisperx.AutoModelForSpeechSeq2Seq.from_pretrained", + model_mock, + ) + monkeypatch.setattr( + "swarms.models.distilled_whisperx.AutoProcessor.from_pretrained", processor_mock + ) + return model_mock, processor_mock + + +@pytest.mark.asyncio +async def test_async_transcribe_with_mocked_model(mocked_model, audio_file_path): + model_mock, processor_mock = mocked_model + # Set up what the mock should return when it's called + model_mock.return_value.generate.return_value = torch.tensor([[0]]) + processor_mock.return_value.batch_decode.return_value = ["mocked transcription"] + model_wrapper = DistilWhisperModel() + transcription = await model_wrapper.async_transcribe(audio_file_path) + assert transcription == "mocked transcription" diff --git a/tests/models/distilled_whisperx.py b/tests/models/distilled_whisperx.py deleted file mode 100644 index 4bdd10f3..00000000 --- a/tests/models/distilled_whisperx.py +++ /dev/null @@ -1,119 +0,0 @@ -# test_distilled_whisperx.py - -from unittest.mock import AsyncMock, MagicMock - -import pytest -import torch -from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor - -from swarms.models.distilled_whisperx import DistilWhisperModel, async_retry - - -# Fixtures for setting up model, processor, and audio files -@pytest.fixture(scope="module") -def model_id(): - return "distil-whisper/distil-large-v2" - - -@pytest.fixture(scope="module") -def 
whisper_model(model_id): - return DistilWhisperModel(model_id) - - -@pytest.fixture(scope="session") -def audio_file_path(tmp_path_factory): - # You would create a small temporary MP3 file here for testing - # or use a public domain MP3 file's path - return "path/to/valid_audio.mp3" - - -@pytest.fixture(scope="session") -def invalid_audio_file_path(): - return "path/to/invalid_audio.mp3" - - -@pytest.fixture(scope="session") -def audio_dict(): - # This should represent a valid audio dictionary as expected by the model - return {"array": torch.randn(1, 16000), "sampling_rate": 16000} - - -# Test initialization -def test_initialization(whisper_model): - assert whisper_model.model is not None - assert whisper_model.processor is not None - - -# Test successful transcription with file path -def test_transcribe_with_file_path(whisper_model, audio_file_path): - transcription = whisper_model.transcribe(audio_file_path) - assert isinstance(transcription, str) - - -# Test successful transcription with audio dict -def test_transcribe_with_audio_dict(whisper_model, audio_dict): - transcription = whisper_model.transcribe(audio_dict) - assert isinstance(transcription, str) - - -# Test for file not found error -def test_file_not_found(whisper_model, invalid_audio_file_path): - with pytest.raises(Exception): - whisper_model.transcribe(invalid_audio_file_path) - - -# Asynchronous tests -@pytest.mark.asyncio -async def test_async_transcription_success(whisper_model, audio_file_path): - transcription = await whisper_model.async_transcribe(audio_file_path) - assert isinstance(transcription, str) - - -@pytest.mark.asyncio -async def test_async_transcription_failure(whisper_model, invalid_audio_file_path): - with pytest.raises(Exception): - await whisper_model.async_transcribe(invalid_audio_file_path) - - -# Testing real-time transcription simulation -def test_real_time_transcription(whisper_model, audio_file_path, capsys): - whisper_model.real_time_transcribe(audio_file_path, 
chunk_duration=1) - captured = capsys.readouterr() - assert "Starting real-time transcription..." in captured.out - - -# Testing retry decorator for asynchronous function -@pytest.mark.asyncio -async def test_async_retry(): - @async_retry(max_retries=2, exceptions=(ValueError,), delay=0) - async def failing_func(): - raise ValueError("Test") - - with pytest.raises(ValueError): - await failing_func() - - -# Mocking the actual model to avoid GPU/CPU intensive operations during test -@pytest.fixture -def mocked_model(monkeypatch): - model_mock = AsyncMock(AutoModelForSpeechSeq2Seq) - processor_mock = MagicMock(AutoProcessor) - monkeypatch.setattr( - "swarms.models.distilled_whisperx.AutoModelForSpeechSeq2Seq.from_pretrained", - model_mock, - ) - monkeypatch.setattr( - "swarms.models.distilled_whisperx.AutoProcessor.from_pretrained", processor_mock - ) - return model_mock, processor_mock - - -@pytest.mark.asyncio -async def test_async_transcribe_with_mocked_model(mocked_model, audio_file_path): - model_mock, processor_mock = mocked_model - # Set up what the mock should return when it's called - model_mock.return_value.generate.return_value = torch.tensor([[0]]) - processor_mock.return_value.batch_decode.return_value = ["mocked transcription"] - model_wrapper = DistilWhisperModel() - transcription = await model_wrapper.async_transcribe(audio_file_path) - assert transcription == "mocked transcription" diff --git a/tests/models/llama_function_caller.py b/tests/models/llama_function_caller.py new file mode 100644 index 00000000..c54c264b --- /dev/null +++ b/tests/models/llama_function_caller.py @@ -0,0 +1,115 @@ +import pytest +from swarms.models.llama_function_caller import LlamaFunctionCaller + + +# Define fixtures if needed +@pytest.fixture +def llama_caller(): + # Initialize the LlamaFunctionCaller with a sample model + return LlamaFunctionCaller() + + +# Basic test for model loading +def test_llama_model_loading(llama_caller): + assert llama_caller.model is not None + 
assert llama_caller.tokenizer is not None + + +# Test adding and calling custom functions +def test_llama_custom_function(llama_caller): + def sample_function(arg1, arg2): + return f"Sample function called with args: {arg1}, {arg2}" + + llama_caller.add_func( + name="sample_function", + function=sample_function, + description="Sample custom function", + arguments=[ + {"name": "arg1", "type": "string", "description": "Argument 1"}, + {"name": "arg2", "type": "string", "description": "Argument 2"}, + ], + ) + + result = llama_caller.call_function( + "sample_function", arg1="arg1_value", arg2="arg2_value" + ) + assert result == "Sample function called with args: arg1_value, arg2_value" + + +# Test streaming user prompts +def test_llama_streaming(llama_caller): + user_prompt = "Tell me about the tallest mountain in the world." + response = llama_caller(user_prompt) + assert isinstance(response, str) + assert len(response) > 0 + + +# Test custom function not found +def test_llama_custom_function_not_found(llama_caller): + with pytest.raises(ValueError): + llama_caller.call_function("non_existent_function") + + +# Test invalid arguments for custom function +def test_llama_custom_function_invalid_arguments(llama_caller): + def sample_function(arg1, arg2): + return f"Sample function called with args: {arg1}, {arg2}" + + llama_caller.add_func( + name="sample_function", + function=sample_function, + description="Sample custom function", + arguments=[ + {"name": "arg1", "type": "string", "description": "Argument 1"}, + {"name": "arg2", "type": "string", "description": "Argument 2"}, + ], + ) + + with pytest.raises(TypeError): + llama_caller.call_function("sample_function", arg1="arg1_value") + + +# Test streaming with custom runtime +def test_llama_custom_runtime(): + llama_caller = LlamaFunctionCaller( + model_id="Your-Model-ID", cache_dir="Your-Cache-Directory", runtime="cuda" + ) + user_prompt = "Tell me about the tallest mountain in the world." 
+ response = llama_caller(user_prompt) + assert isinstance(response, str) + assert len(response) > 0 + + +# Test caching functionality +def test_llama_cache(): + llama_caller = LlamaFunctionCaller( + model_id="Your-Model-ID", cache_dir="Your-Cache-Directory", runtime="cuda" + ) + + # Perform a request to populate the cache + user_prompt = "Tell me about the tallest mountain in the world." + response = llama_caller(user_prompt) + + # Check if the response is retrieved from the cache + llama_caller.model.from_cache = True + response_from_cache = llama_caller(user_prompt) + assert response == response_from_cache + + +# Test response length within max_tokens limit +def test_llama_response_length(): + llama_caller = LlamaFunctionCaller( + model_id="Your-Model-ID", cache_dir="Your-Cache-Directory", runtime="cuda" + ) + + # Generate a long prompt + long_prompt = "A " + "test " * 100 # Approximately 500 tokens + + # Ensure the response does not exceed max_tokens + response = llama_caller(long_prompt) + assert len(response.split()) <= 500 + + +# Add more test cases as needed to cover different aspects of your code + +# ... diff --git a/tests/models/speech_t5.py b/tests/models/speech_t5.py new file mode 100644 index 00000000..4e5f4cb1 --- /dev/null +++ b/tests/models/speech_t5.py @@ -0,0 +1,139 @@ +import pytest +import os +import torch +from swarms.models.speecht5 import SpeechT5 + + +# Create fixtures if needed +@pytest.fixture +def speecht5_model(): + return SpeechT5() + + +# Test cases for the SpeechT5 class + + +def test_speecht5_init(speecht5_model): + assert isinstance(speecht5_model.processor, SpeechT5.processor.__class__) + assert isinstance(speecht5_model.model, SpeechT5.model.__class__) + assert isinstance(speecht5_model.vocoder, SpeechT5.vocoder.__class__) + assert isinstance(speecht5_model.embeddings_dataset, torch.utils.data.Dataset) + + +def test_speecht5_call(speecht5_model): + text = "Hello, how are you?" 
+ speech = speecht5_model(text) + assert isinstance(speech, torch.Tensor) + + +def test_speecht5_save_speech(speecht5_model): + text = "Hello, how are you?" + speech = speecht5_model(text) + filename = "test_speech.wav" + speecht5_model.save_speech(speech, filename) + assert os.path.isfile(filename) + os.remove(filename) + + +def test_speecht5_set_model(speecht5_model): + old_model_name = speecht5_model.model_name + new_model_name = "facebook/speecht5-tts" + speecht5_model.set_model(new_model_name) + assert speecht5_model.model_name == new_model_name + assert speecht5_model.processor.model_name == new_model_name + assert speecht5_model.model.config.model_name_or_path == new_model_name + speecht5_model.set_model(old_model_name) # Restore original model + + +def test_speecht5_set_vocoder(speecht5_model): + old_vocoder_name = speecht5_model.vocoder_name + new_vocoder_name = "facebook/speecht5-hifigan" + speecht5_model.set_vocoder(new_vocoder_name) + assert speecht5_model.vocoder_name == new_vocoder_name + assert speecht5_model.vocoder.config.model_name_or_path == new_vocoder_name + speecht5_model.set_vocoder(old_vocoder_name) # Restore original vocoder + + +def test_speecht5_set_embeddings_dataset(speecht5_model): + old_dataset_name = speecht5_model.dataset_name + new_dataset_name = "Matthijs/cmu-arctic-xvectors-test" + speecht5_model.set_embeddings_dataset(new_dataset_name) + assert speecht5_model.dataset_name == new_dataset_name + assert isinstance(speecht5_model.embeddings_dataset, torch.utils.data.Dataset) + speecht5_model.set_embeddings_dataset(old_dataset_name) # Restore original dataset + + +def test_speecht5_get_sampling_rate(speecht5_model): + sampling_rate = speecht5_model.get_sampling_rate() + assert sampling_rate == 16000 + + +def test_speecht5_print_model_details(speecht5_model, capsys): + speecht5_model.print_model_details() + captured = capsys.readouterr() + assert "Model Name: " in captured.out + assert "Vocoder Name: " in captured.out + + +def 
test_speecht5_quick_synthesize(speecht5_model): + text = "Hello, how are you?" + speech = speecht5_model.quick_synthesize(text) + assert isinstance(speech, list) + assert isinstance(speech[0], dict) + assert "audio" in speech[0] + + +def test_speecht5_change_dataset_split(speecht5_model): + split = "test" + speecht5_model.change_dataset_split(split) + assert speecht5_model.embeddings_dataset.split == split + + +def test_speecht5_load_custom_embedding(speecht5_model): + xvector = [0.1, 0.2, 0.3, 0.4, 0.5] + embedding = speecht5_model.load_custom_embedding(xvector) + assert torch.all(torch.eq(embedding, torch.tensor(xvector).unsqueeze(0))) + + +def test_speecht5_with_different_speakers(speecht5_model): + text = "Hello, how are you?" + speakers = [7306, 5324, 1234] + for speaker_id in speakers: + speech = speecht5_model(text, speaker_id=speaker_id) + assert isinstance(speech, torch.Tensor) + + +def test_speecht5_save_speech_with_different_extensions(speecht5_model): + text = "Hello, how are you?" + speech = speecht5_model(text) + extensions = [".wav", ".flac"] + for extension in extensions: + filename = f"test_speech{extension}" + speecht5_model.save_speech(speech, filename) + assert os.path.isfile(filename) + os.remove(filename) + + +def test_speecht5_invalid_speaker_id(speecht5_model): + text = "Hello, how are you?" + invalid_speaker_id = 9999 # Speaker ID that does not exist in the dataset + with pytest.raises(IndexError): + speecht5_model(text, speaker_id=invalid_speaker_id) + + +def test_speecht5_invalid_save_path(speecht5_model): + text = "Hello, how are you?" + speech = speecht5_model(text) + invalid_path = "/invalid_directory/test_speech.wav" + with pytest.raises(FileNotFoundError): + speecht5_model.save_speech(speech, invalid_path) + + +def test_speecht5_change_vocoder_model(speecht5_model): + text = "Hello, how are you?" 
+ old_vocoder_name = speecht5_model.vocoder_name + new_vocoder_name = "facebook/speecht5-hifigan-ljspeech" + speecht5_model.set_vocoder(new_vocoder_name) + speech = speecht5_model(text) + assert isinstance(speech, torch.Tensor) + speecht5_model.set_vocoder(old_vocoder_name) # Restore original vocoder diff --git a/tests/models/ssd_1b.py b/tests/models/ssd_1b.py new file mode 100644 index 00000000..7bd3154c --- /dev/null +++ b/tests/models/ssd_1b.py @@ -0,0 +1,223 @@ +import pytest +from swarms.models.ssd_1b import SSD1B +from PIL import Image + + +# Create fixtures if needed +@pytest.fixture +def ssd1b_model(): + return SSD1B() + + +# Basic tests for model initialization and method call +def test_ssd1b_model_initialization(ssd1b_model): + assert ssd1b_model is not None + + +def test_ssd1b_call(ssd1b_model): + task = "A painting of a dog" + neg_prompt = "ugly, blurry, poor quality" + image_url = ssd1b_model(task, neg_prompt) + assert isinstance(image_url, str) + assert image_url.startswith("https://") # Assuming it starts with "https://" + + +# Add more tests for various aspects of the class and methods + + +# Example of a parameterized test for different tasks +@pytest.mark.parametrize("task", ["A painting of a cat", "A painting of a tree"]) +def test_ssd1b_parameterized_task(ssd1b_model, task): + image_url = ssd1b_model(task) + assert isinstance(image_url, str) + assert image_url.startswith("https://") # Assuming it starts with "https://" + + +# Example of a test using mocks to isolate units of code +def test_ssd1b_with_mock(ssd1b_model, mocker): + mocker.patch("your_module.StableDiffusionXLPipeline") # Mock the pipeline + task = "A painting of a cat" + image_url = ssd1b_model(task) + assert isinstance(image_url, str) + assert image_url.startswith("https://") # Assuming it starts with "https://" + + +def test_ssd1b_call_with_cache(ssd1b_model): + task = "A painting of a dog" + neg_prompt = "ugly, blurry, poor quality" + image_url1 = ssd1b_model(task, neg_prompt) + 
image_url2 = ssd1b_model(task, neg_prompt) # Should use cache + assert image_url1 == image_url2 + + +def test_ssd1b_invalid_task(ssd1b_model): + invalid_task = "" + with pytest.raises(ValueError): + ssd1b_model(invalid_task) + + +def test_ssd1b_failed_api_call(ssd1b_model, mocker): + mocker.patch( + "your_module.StableDiffusionXLPipeline" + ) # Mock the pipeline to raise an exception + task = "A painting of a cat" + with pytest.raises(Exception): + ssd1b_model(task) + + +def test_ssd1b_process_batch_concurrently(ssd1b_model): + tasks = [ + "A painting of a dog", + "A beautiful sunset", + "A portrait of a person", + ] + results = ssd1b_model.process_batch_concurrently(tasks) + assert isinstance(results, list) + assert len(results) == len(tasks) + + +def test_ssd1b_process_empty_batch_concurrently(ssd1b_model): + tasks = [] + results = ssd1b_model.process_batch_concurrently(tasks) + assert isinstance(results, list) + assert len(results) == 0 + + +def test_ssd1b_download_image(ssd1b_model): + task = "A painting of a dog" + neg_prompt = "ugly, blurry, poor quality" + image_url = ssd1b_model(task, neg_prompt) + img = ssd1b_model._download_image(image_url) + assert isinstance(img, Image.Image) + + +def test_ssd1b_generate_uuid(ssd1b_model): + uuid_str = ssd1b_model._generate_uuid() + assert isinstance(uuid_str, str) + assert len(uuid_str) == 36 # UUID format + + +def test_ssd1b_rate_limited_call(ssd1b_model): + task = "A painting of a dog" + image_url = ssd1b_model.rate_limited_call(task) + assert isinstance(image_url, str) + assert image_url.startswith("https://") + + +# Test cases for additional scenarios and behaviors +def test_ssd1b_dashboard_printing(ssd1b_model, capsys): + ssd1b_model.dashboard = True + ssd1b_model.print_dashboard() + captured = capsys.readouterr() + assert "SSD1B Dashboard:" in captured.out + + +def test_ssd1b_generate_image_name(ssd1b_model): + task = "A painting of a dog" + img_name = ssd1b_model._generate_image_name(task) + assert 
isinstance(img_name, str) + assert len(img_name) > 0 + + +def test_ssd1b_set_width_height(ssd1b_model, mocker): + img = mocker.MagicMock() + width, height = 800, 600 + result = ssd1b_model.set_width_height(img, width, height) + assert result == img.resize.return_value + + +def test_ssd1b_read_img(ssd1b_model, mocker): + img = mocker.MagicMock() + result = ssd1b_model.read_img(img) + assert result == img.open.return_value + + +def test_ssd1b_convert_to_bytesio(ssd1b_model, mocker): + img = mocker.MagicMock() + img_format = "PNG" + result = ssd1b_model.convert_to_bytesio(img, img_format) + assert isinstance(result, bytes) + + +def test_ssd1b_save_image(ssd1b_model, mocker, tmp_path): + img = mocker.MagicMock() + img_name = "test.png" + save_path = tmp_path / img_name + ssd1b_model._download_image(img, img_name, save_path) + assert save_path.exists() + + +def test_ssd1b_repr_str(ssd1b_model): + task = "A painting of a dog" + image_url = ssd1b_model(task) + assert repr(ssd1b_model) == f"SSD1B(image_url={image_url})" + assert str(ssd1b_model) == f"SSD1B(image_url={image_url})" + + +import pytest +from your_module import SSD1B + + +# Create fixtures if needed +@pytest.fixture +def ssd1b_model(): + return SSD1B() + + +# Test cases for additional scenarios and behaviors +def test_ssd1b_dashboard_printing(ssd1b_model, capsys): + ssd1b_model.dashboard = True + ssd1b_model.print_dashboard() + captured = capsys.readouterr() + assert "SSD1B Dashboard:" in captured.out + + +def test_ssd1b_generate_image_name(ssd1b_model): + task = "A painting of a dog" + img_name = ssd1b_model._generate_image_name(task) + assert isinstance(img_name, str) + assert len(img_name) > 0 + + +def test_ssd1b_set_width_height(ssd1b_model, mocker): + img = mocker.MagicMock() + width, height = 800, 600 + result = ssd1b_model.set_width_height(img, width, height) + assert result == img.resize.return_value + + +def test_ssd1b_read_img(ssd1b_model, mocker): + img = mocker.MagicMock() + result = 
ssd1b_model.read_img(img)
+    assert result == img.open.return_value
+
+
+def test_ssd1b_convert_to_bytesio(ssd1b_model, mocker):
+    img = mocker.MagicMock()
+    img_format = "PNG"
+    result = ssd1b_model.convert_to_bytesio(img, img_format)
+    assert isinstance(result, bytes)
+
+
+def test_ssd1b_save_image(ssd1b_model, mocker, tmp_path):
+    img = mocker.MagicMock()
+    img_name = "test.png"
+    save_path = tmp_path / img_name
+    ssd1b_model._download_image(img, img_name, save_path)
+    assert save_path.exists()
+
+
+def test_ssd1b_repr_str(ssd1b_model):
+    task = "A painting of a dog"
+    image_url = ssd1b_model(task)
+    assert repr(ssd1b_model) == f"SSD1B(image_url={image_url})"
+    assert str(ssd1b_model) == f"SSD1B(image_url={image_url})"
+
+
+def test_ssd1b_rate_limited_call(ssd1b_model, mocker):
+    task = "A painting of a dog"
+    mocker.patch.object(
+        ssd1b_model, "__call__", side_effect=Exception("Rate limit exceeded")
+    )
+    with pytest.raises(Exception, match="Rate limit exceeded"):
+        ssd1b_model.rate_limited_call(task)
diff --git a/tests/models/timm_model.py b/tests/models/timm_model.py
new file mode 100644
index 00000000..a3e62605
--- /dev/null
+++ b/tests/models/timm_model.py
@@ -0,0 +1,178 @@
+"""Tests for ``swarms.models.timm.TimmModel`` and ``TimmModelInfo``."""
+import os
+from unittest.mock import Mock
+
+import pytest
+import torch
+
+from swarms.models.timm import TimmModel, TimmModelInfo
+
+# (model_name, pretrained, in_chans) cases shared by the parametrized tests.
+MODEL_CASES = [
+    ("resnet18", True, 3),
+    ("resnet50", False, 1),
+    ("efficientnet_b0", True, 3),
+]
+
+
+@pytest.fixture
+def sample_model_info():
+    """Return the ``resnet18`` configuration used by the fixed-input tests."""
+    return TimmModelInfo(model_name="resnet18", pretrained=True, in_chans=3)
+
+
+def _assert_creates_module(model_info):
+    """Assert that ``_create_model(model_info)`` returns a ``torch.nn.Module``."""
+    model = TimmModel()._create_model(model_info)
+    assert isinstance(model, torch.nn.Module)
+
+
+def _assert_call_returns_size(model_info):
+    """Assert that calling the handler on a random batch returns a ``torch.Size``."""
+    model_handler = TimmModel()
+    input_tensor = torch.randn(1, model_info.in_chans, 224, 224)
+    output_shape = model_handler(model_info, input_tensor)
+    assert isinstance(output_shape, torch.Size)
+
+
+def _assert_unknown_model_raises():
+    """Assert that calling the handler with an unknown model name raises."""
+    model_handler = TimmModel()
+    model_info = TimmModelInfo(model_name="invalid_model", pretrained=True, in_chans=3)
+    input_tensor = torch.randn(1, 3, 224, 224)
+    # NOTE(review): the concrete exception type is not visible from this file;
+    # narrow ``Exception`` once it is confirmed.
+    with pytest.raises(Exception):
+        model_handler(model_info, input_tensor)
+
+
+def _assert_stubbed_create_model(model_info):
+    """Assert that a stubbed ``_create_model`` is used and receives ``model_info``."""
+    model_handler = TimmModel()
+    model_handler._create_model = Mock(return_value=torch.nn.Module())
+    model = model_handler._create_model(model_info)
+    assert isinstance(model, torch.nn.Module)
+    model_handler._create_model.assert_called_once_with(model_info)
+
+
+def test_get_supported_models():
+    """``_get_supported_models`` returns a non-empty list."""
+    supported_models = TimmModel()._get_supported_models()
+    assert isinstance(supported_models, list)
+    assert supported_models
+
+
+def test_create_model(sample_model_info):
+    """``_create_model`` builds a module for the sample configuration."""
+    _assert_creates_module(sample_model_info)
+
+
+def test_call(sample_model_info):
+    """Calling the handler with the sample configuration returns a shape."""
+    _assert_call_returns_size(sample_model_info)
+
+
+@pytest.mark.parametrize("model_name, pretrained, in_chans", MODEL_CASES)
+def test_create_model_parameterized(model_name, pretrained, in_chans):
+    """``_create_model`` builds a module for every entry in ``MODEL_CASES``."""
+    model_info = TimmModelInfo(
+        model_name=model_name, pretrained=pretrained, in_chans=in_chans
+    )
+    _assert_creates_module(model_info)
+
+
+@pytest.mark.parametrize("model_name, pretrained, in_chans", MODEL_CASES)
+def test_call_parameterized(model_name, pretrained, in_chans):
+    """Calling the handler returns a shape for every entry in ``MODEL_CASES``."""
+    model_info = TimmModelInfo(
+        model_name=model_name, pretrained=pretrained, in_chans=in_chans
+    )
+    _assert_call_returns_size(model_info)
+
+
+def test_get_supported_models_mock():
+    """A stubbed ``_get_supported_models`` returns the configured list."""
+    model_handler = TimmModel()
+    model_handler._get_supported_models = Mock(return_value=["resnet18", "resnet50"])
+    supported_models = model_handler._get_supported_models()
+    assert supported_models == ["resnet18", "resnet50"]
+
+
+def test_create_model_mock(sample_model_info):
+    """A stubbed ``_create_model`` is called with the sample configuration."""
+    _assert_stubbed_create_model(sample_model_info)
+
+
+def test_call_exception():
+    """An unknown model name makes the handler raise."""
+    _assert_unknown_model_raises()
+
+
+@pytest.mark.skip(reason="coverage is collected with the pytest-cov CLI options")
+def test_coverage():
+    """Placeholder kept so the test name stays available.
+
+    Calling ``pytest.main`` from inside a test starts a nested session whose
+    exit code is discarded, so coverage is requested on the command line.
+    """
+
+
+def test_environment_variable(monkeypatch):
+    """Build a ``TimmModelInfo`` from environment variables and call the handler."""
+    # ``monkeypatch`` restores the environment after the test, so the values
+    # do not leak into other tests.
+    monkeypatch.setenv("MODEL_NAME", "resnet18")
+    monkeypatch.setenv("PRETRAINED", "True")
+    monkeypatch.setenv("IN_CHANS", "3")
+
+    model_info = TimmModelInfo(
+        model_name=os.environ["MODEL_NAME"],
+        # ``bool("False")`` is True, so compare the text instead of casting it.
+        pretrained=os.environ["PRETRAINED"].strip().lower() == "true",
+        in_chans=int(os.environ["IN_CHANS"]),
+    )
+    _assert_call_returns_size(model_info)
+
+
+# NOTE: ``slow`` is a custom marker; register it in the pytest configuration to
+# avoid ``PytestUnknownMarkWarning``.
+@pytest.mark.slow
+def test_marked_slow():
+    """Run the ``resnet18`` forward pass under the ``slow`` marker."""
+    model_info = TimmModelInfo(model_name="resnet18", pretrained=True, in_chans=3)
+    _assert_call_returns_size(model_info)
+
+
+@pytest.mark.parametrize("model_name, pretrained, in_chans", MODEL_CASES)
+def test_marked_parameterized(model_name, pretrained, in_chans):
+    """``_create_model`` builds a module for every entry in ``MODEL_CASES``."""
+    model_info = TimmModelInfo(
+        model_name=model_name, pretrained=pretrained, in_chans=in_chans
+    )
+    _assert_creates_module(model_info)
+
+
+def test_exception_testing():
+    """An unknown model name makes the handler raise."""
+    _assert_unknown_model_raises()
+
+
+def test_parameterized_testing():
+    """Calling the handler with ``resnet18`` returns a shape."""
+    model_info = TimmModelInfo(model_name="resnet18", pretrained=True, in_chans=3)
+    _assert_call_returns_size(model_info)
+
+
+def test_use_mocks_and_monkeypatching():
+    """A stubbed ``_create_model`` is called with a ``resnet18`` configuration."""
+    model_info = TimmModelInfo(model_name="resnet18", pretrained=True, in_chans=3)
+    _assert_stubbed_create_model(model_info)
+
+
+@pytest.mark.skip(reason="coverage is collected with the pytest-cov CLI options")
+def test_coverage_report():
+    """Placeholder kept so the test name stays available.
+
+    Calling ``pytest.main`` from inside a test starts a nested session whose
+    exit code is discarded, so coverage is requested on the command line.
+    """
diff --git a/tests/models/yi_200k.py b/tests/models/yi_200k.py
new file mode 100644
index 00000000..72a6d1b2
--- /dev/null
+++ b/tests/models/yi_200k.py
@@ -0,0 +1,111 @@
+"""Tests for ``swarms.models.yi_200k.Yi34B200k``."""
+import pytest
+import torch
+from transformers import PreTrainedTokenizerBase
+
+from swarms.models.yi_200k import Yi34B200k
+
+# Prompt shared by every generation test.
+PROMPT = "There's a place where time stands still."
+
+
+@pytest.fixture(scope="module")
+def yi34b_model():
+    """Return a ``Yi34B200k`` instance shared by all tests in this module."""
+    # Module scope builds the model once instead of once per test.
+    # NOTE(review): assumes calls leave no state on the instance - confirm.
+    return Yi34B200k()
+
+
+def test_yi34b_init(yi34b_model):
+    """The wrapper exposes a torch module and a Hugging Face tokenizer."""
+    assert isinstance(yi34b_model.model, torch.nn.Module)
+    # ``AutoTokenizer`` is a factory that cannot be instantiated, so check the
+    # common tokenizer base class instead.
+    assert isinstance(yi34b_model.tokenizer, PreTrainedTokenizerBase)
+
+
+def test_yi34b_generate_text(yi34b_model):
+    """Generation returns a non-empty string."""
+    generated_text = yi34b_model(PROMPT)
+    assert isinstance(generated_text, str)
+    assert len(generated_text) > 0
+
+
+@pytest.mark.parametrize("max_length", [64, 128, 256, 512])
+def test_yi34b_generate_text_with_length(yi34b_model, max_length):
+    """The generated text is no longer than ``max_length``."""
+    generated_text = yi34b_model(PROMPT, max_length=max_length)
+    # NOTE(review): this compares characters; confirm the unit of ``max_length``.
+    assert len(generated_text) <= max_length
+
+
+# 1.0 is the upper bound named by the invalid-temperature test below.
+@pytest.mark.parametrize("temperature", [0.1, 0.5, 1.0])
+def test_yi34b_generate_text_with_temperature(yi34b_model, temperature):
+    """Generation succeeds for temperatures inside the accepted range."""
+    generated_text = yi34b_model(PROMPT, temperature=temperature)
+    assert isinstance(generated_text, str)
+
+
+def test_yi34b_generate_text_with_invalid_prompt(yi34b_model):
+    """A ``None`` prompt is rejected with ``ValueError``."""
+    with pytest.raises(ValueError, match="Input prompt must be a non-empty string"):
+        yi34b_model(None)
+
+
+def test_yi34b_generate_text_with_invalid_max_length(yi34b_model):
+    """A negative ``max_length`` is rejected with ``ValueError``."""
+    with pytest.raises(ValueError, match="max_length must be a positive integer"):
+        yi34b_model(PROMPT, max_length=-1)
+
+
+def test_yi34b_generate_text_with_invalid_temperature(yi34b_model):
+    """A temperature above 1.0 is rejected with ``ValueError``."""
+    with pytest.raises(ValueError, match="temperature must be between 0.01 and 1.0"):
+        yi34b_model(PROMPT, temperature=2.0)
+
+
+@pytest.mark.parametrize("top_k", [20, 30, 50])
+def test_yi34b_generate_text_with_top_k(yi34b_model, top_k):
+    """Generation succeeds for several ``top_k`` values."""
+    generated_text = yi34b_model(PROMPT, top_k=top_k)
+    assert isinstance(generated_text, str)
+
+
+@pytest.mark.parametrize("top_p", [0.5, 0.7, 0.9])
+def test_yi34b_generate_text_with_top_p(yi34b_model, top_p):
+    """Generation succeeds for several ``top_p`` values."""
+    generated_text = yi34b_model(PROMPT, top_p=top_p)
+    assert isinstance(generated_text, str)
+
+
+def test_yi34b_generate_text_with_invalid_top_k(yi34b_model):
+    """A negative ``top_k`` is rejected with ``ValueError``."""
+    with pytest.raises(ValueError, match="top_k must be a non-negative integer"):
+        yi34b_model(PROMPT, top_k=-1)
+
+
+def test_yi34b_generate_text_with_invalid_top_p(yi34b_model):
+    """A ``top_p`` above 1.0 is rejected with ``ValueError``."""
+    with pytest.raises(ValueError, match="top_p must be between 0.0 and 1.0"):
+        yi34b_model(PROMPT, top_p=1.5)
+
+
+@pytest.mark.parametrize("repitition_penalty", [1.0, 1.2, 1.5])
+def test_yi34b_generate_text_with_repitition_penalty(yi34b_model, repitition_penalty):
+    """Generation succeeds for several ``repitition_penalty`` values."""
+    generated_text = yi34b_model(PROMPT, repitition_penalty=repitition_penalty)
+    assert isinstance(generated_text, str)
+
+
+def test_yi34b_generate_text_with_invalid_repitition_penalty(yi34b_model):
+    """A zero ``repitition_penalty`` is rejected with ``ValueError``."""
+    with pytest.raises(ValueError, match="repitition_penalty must be a positive float"):
+        yi34b_model(PROMPT, repitition_penalty=0.0)
+
+
+def test_yi34b_generate_text_with_invalid_device(yi34b_model):
+    """An unknown ``device_map`` is rejected with ``ValueError``."""
+    with pytest.raises(ValueError, match="Invalid device_map"):
+        yi34b_model(PROMPT, device_map="invalid_device")