tests for yi, stable diffusion, timm models, etc

Former-commit-id: dfea671d5e
clean-history
Kye 1 year ago
parent 59f3b4c83f
commit 48643b38c1

@@ -1,6 +1,6 @@
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from swarms.models.auto_temp import OpenAIChat
from swarms.models.openai_models import OpenAIChat
class AutoTempAgent:

@@ -2,6 +2,7 @@ from openai import OpenAI
client = OpenAI()
def get_ada_embeddings(text: str, model: str = "text-embedding-ada-002"):
"""
Simple function to get embeddings from ada

@@ -0,0 +1,161 @@
# Import necessary modules and define fixtures if needed
import os
import pytest
import torch
from PIL import Image
from swarms.models.bioclip import BioClip
# Define fixtures if needed
@pytest.fixture
def sample_image_path():
return "path_to_sample_image.jpg"
@pytest.fixture
def clip_instance():
return BioClip("microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224")
# Basic tests for the BioClip class
def test_clip_initialization(clip_instance):
assert isinstance(clip_instance.model, torch.nn.Module)
assert hasattr(clip_instance, "model_path")
assert hasattr(clip_instance, "preprocess_train")
assert hasattr(clip_instance, "preprocess_val")
assert hasattr(clip_instance, "tokenizer")
assert hasattr(clip_instance, "device")
def test_clip_call_method(clip_instance, sample_image_path):
labels = [
"adenocarcinoma histopathology",
"brain MRI",
"covid line chart",
"squamous cell carcinoma histopathology",
"immunohistochemistry histopathology",
"bone X-ray",
"chest X-ray",
"pie chart",
"hematoxylin and eosin histopathology",
]
result = clip_instance(sample_image_path, labels)
assert isinstance(result, dict)
assert len(result) == len(labels)
def test_clip_plot_image_with_metadata(clip_instance, sample_image_path):
metadata = {
"filename": "sample_image.jpg",
"top_probs": {"label1": 0.75, "label2": 0.65},
}
clip_instance.plot_image_with_metadata(sample_image_path, metadata)
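# Hedged safeguard, assuming plot_image_with_metadata renders through matplotlib:
# forcing the non-interactive Agg backend keeps the plotting test headless on CI.
@pytest.fixture(autouse=True)
def _headless_matplotlib():
    matplotlib = pytest.importorskip("matplotlib")
    matplotlib.use("Agg", force=True)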
# More test cases can be added to cover additional functionality and edge cases
# Parameterized tests for different image and label combinations
@pytest.mark.parametrize(
"image_path, labels",
[
("image1.jpg", ["label1", "label2"]),
("image2.jpg", ["label3", "label4"]),
# Add more image and label combinations
],
)
def test_clip_parameterized_calls(clip_instance, image_path, labels):
result = clip_instance(image_path, labels)
assert isinstance(result, dict)
assert len(result) == len(labels)
# Test image preprocessing
def test_clip_image_preprocessing(clip_instance, sample_image_path):
image = Image.open(sample_image_path)
processed_image = clip_instance.preprocess_val(image)
assert isinstance(processed_image, torch.Tensor)
# Test label tokenization
def test_clip_label_tokenization(clip_instance):
labels = ["label1", "label2"]
tokenized_labels = clip_instance.tokenizer(labels)
assert isinstance(tokenized_labels, torch.Tensor)
assert tokenized_labels.shape[0] == len(labels)
# More tests can be added to cover other methods and edge cases
# End-to-end tests with actual images and labels
def test_clip_end_to_end(clip_instance, sample_image_path):
labels = [
"adenocarcinoma histopathology",
"brain MRI",
"covid line chart",
"squamous cell carcinoma histopathology",
"immunohistochemistry histopathology",
"bone X-ray",
"chest X-ray",
"pie chart",
"hematoxylin and eosin histopathology",
]
result = clip_instance(sample_image_path, labels)
assert isinstance(result, dict)
assert len(result) == len(labels)
# Test label tokenization with long labels
def test_clip_long_labels(clip_instance):
labels = ["label" + str(i) for i in range(100)]
tokenized_labels = clip_instance.tokenizer(labels)
assert isinstance(tokenized_labels, torch.Tensor)
assert tokenized_labels.shape[0] == len(labels)
# Test handling of multiple image files
def test_clip_multiple_images(clip_instance, sample_image_path):
labels = ["label1", "label2"]
image_paths = [sample_image_path, "image2.jpg"]
results = clip_instance(image_paths, labels)
assert isinstance(results, list)
assert len(results) == len(image_paths)
for result in results:
assert isinstance(result, dict)
assert len(result) == len(labels)
# Test model inference performance
def test_clip_inference_performance(clip_instance, sample_image_path, benchmark):
labels = [
"adenocarcinoma histopathology",
"brain MRI",
"covid line chart",
"squamous cell carcinoma histopathology",
"immunohistochemistry histopathology",
"bone X-ray",
"chest X-ray",
"pie chart",
"hematoxylin and eosin histopathology",
]
result = benchmark(clip_instance, sample_image_path, labels)
assert isinstance(result, dict)
assert len(result) == len(labels)
# Test different preprocessing pipelines
def test_clip_preprocessing_pipelines(clip_instance, sample_image_path):
labels = ["label1", "label2"]
image = Image.open(sample_image_path)
# Test preprocessing for training
processed_image_train = clip_instance.preprocess_train(image)
assert isinstance(processed_image_train, torch.Tensor)
# Test preprocessing for validation
processed_image_val = clip_instance.preprocess_val(image)
assert isinstance(processed_image_val, torch.Tensor)
# ...

@@ -1,13 +1,14 @@
import os
import tempfile
from functools import wraps
from unittest.mock import patch
from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from swarms.models.distill_whisperx import DistilWhisperModel, async_retry
from swarms.models.distilled_whisperx import DistilWhisperModel, async_retry
@pytest.fixture
@@ -150,5 +151,114 @@ def test_create_audio_file():
os.remove(audio_file_path)
if __name__ == "__main__":
pytest.main()
# test_distilled_whisperx.py
# Fixtures for setting up model, processor, and audio files
@pytest.fixture(scope="module")
def model_id():
return "distil-whisper/distil-large-v2"
@pytest.fixture(scope="module")
def whisper_model(model_id):
return DistilWhisperModel(model_id)
@pytest.fixture(scope="session")
def audio_file_path(tmp_path_factory):
# You would create a small temporary MP3 file here for testing
# or use a public domain MP3 file's path
return "path/to/valid_audio.mp3"
@pytest.fixture(scope="session")
def invalid_audio_file_path():
return "path/to/invalid_audio.mp3"
@pytest.fixture(scope="session")
def audio_dict():
# This should represent a valid audio dictionary as expected by the model
return {"array": torch.randn(1, 16000), "sampling_rate": 16000}
# Test initialization
def test_initialization(whisper_model):
assert whisper_model.model is not None
assert whisper_model.processor is not None
# Test successful transcription with file path
def test_transcribe_with_file_path(whisper_model, audio_file_path):
transcription = whisper_model.transcribe(audio_file_path)
assert isinstance(transcription, str)
# Test successful transcription with audio dict
def test_transcribe_with_audio_dict(whisper_model, audio_dict):
transcription = whisper_model.transcribe(audio_dict)
assert isinstance(transcription, str)
# Test for file not found error
def test_file_not_found(whisper_model, invalid_audio_file_path):
with pytest.raises(Exception):
whisper_model.transcribe(invalid_audio_file_path)
# Asynchronous tests
@pytest.mark.asyncio
async def test_async_transcription_success(whisper_model, audio_file_path):
transcription = await whisper_model.async_transcribe(audio_file_path)
assert isinstance(transcription, str)
@pytest.mark.asyncio
async def test_async_transcription_failure(whisper_model, invalid_audio_file_path):
with pytest.raises(Exception):
await whisper_model.async_transcribe(invalid_audio_file_path)
# Testing real-time transcription simulation
def test_real_time_transcription(whisper_model, audio_file_path, capsys):
whisper_model.real_time_transcribe(audio_file_path, chunk_duration=1)
captured = capsys.readouterr()
assert "Starting real-time transcription..." in captured.out
# Testing retry decorator for asynchronous function
@pytest.mark.asyncio
async def test_async_retry():
@async_retry(max_retries=2, exceptions=(ValueError,), delay=0)
async def failing_func():
raise ValueError("Test")
with pytest.raises(ValueError):
await failing_func()
# Mocking the actual model to avoid GPU/CPU intensive operations during test
@pytest.fixture
def mocked_model(monkeypatch):
model_mock = AsyncMock(AutoModelForSpeechSeq2Seq)
processor_mock = MagicMock(AutoProcessor)
monkeypatch.setattr(
"swarms.models.distilled_whisperx.AutoModelForSpeechSeq2Seq.from_pretrained",
model_mock,
)
monkeypatch.setattr(
"swarms.models.distilled_whisperx.AutoProcessor.from_pretrained", processor_mock
)
return model_mock, processor_mock
@pytest.mark.asyncio
async def test_async_transcribe_with_mocked_model(mocked_model, audio_file_path):
model_mock, processor_mock = mocked_model
# Set up what the mock should return when it's called
model_mock.return_value.generate.return_value = torch.tensor([[0]])
processor_mock.return_value.batch_decode.return_value = ["mocked transcription"]
model_wrapper = DistilWhisperModel()
transcription = await model_wrapper.async_transcribe(audio_file_path)
assert transcription == "mocked transcription"

@@ -1,119 +0,0 @@
# test_distilled_whisperx.py
from unittest.mock import AsyncMock, MagicMock
import pytest
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from swarms.models.distilled_whisperx import DistilWhisperModel, async_retry
# Fixtures for setting up model, processor, and audio files
@pytest.fixture(scope="module")
def model_id():
return "distil-whisper/distil-large-v2"
@pytest.fixture(scope="module")
def whisper_model(model_id):
return DistilWhisperModel(model_id)
@pytest.fixture(scope="session")
def audio_file_path(tmp_path_factory):
# You would create a small temporary MP3 file here for testing
# or use a public domain MP3 file's path
return "path/to/valid_audio.mp3"
@pytest.fixture(scope="session")
def invalid_audio_file_path():
return "path/to/invalid_audio.mp3"
@pytest.fixture(scope="session")
def audio_dict():
# This should represent a valid audio dictionary as expected by the model
return {"array": torch.randn(1, 16000), "sampling_rate": 16000}
# Test initialization
def test_initialization(whisper_model):
assert whisper_model.model is not None
assert whisper_model.processor is not None
# Test successful transcription with file path
def test_transcribe_with_file_path(whisper_model, audio_file_path):
transcription = whisper_model.transcribe(audio_file_path)
assert isinstance(transcription, str)
# Test successful transcription with audio dict
def test_transcribe_with_audio_dict(whisper_model, audio_dict):
transcription = whisper_model.transcribe(audio_dict)
assert isinstance(transcription, str)
# Test for file not found error
def test_file_not_found(whisper_model, invalid_audio_file_path):
with pytest.raises(Exception):
whisper_model.transcribe(invalid_audio_file_path)
# Asynchronous tests
@pytest.mark.asyncio
async def test_async_transcription_success(whisper_model, audio_file_path):
transcription = await whisper_model.async_transcribe(audio_file_path)
assert isinstance(transcription, str)
@pytest.mark.asyncio
async def test_async_transcription_failure(whisper_model, invalid_audio_file_path):
with pytest.raises(Exception):
await whisper_model.async_transcribe(invalid_audio_file_path)
# Testing real-time transcription simulation
def test_real_time_transcription(whisper_model, audio_file_path, capsys):
whisper_model.real_time_transcribe(audio_file_path, chunk_duration=1)
captured = capsys.readouterr()
assert "Starting real-time transcription..." in captured.out
# Testing retry decorator for asynchronous function
@pytest.mark.asyncio
async def test_async_retry():
@async_retry(max_retries=2, exceptions=(ValueError,), delay=0)
async def failing_func():
raise ValueError("Test")
with pytest.raises(ValueError):
await failing_func()
# Mocking the actual model to avoid GPU/CPU intensive operations during test
@pytest.fixture
def mocked_model(monkeypatch):
model_mock = AsyncMock(AutoModelForSpeechSeq2Seq)
processor_mock = MagicMock(AutoProcessor)
monkeypatch.setattr(
"swarms.models.distilled_whisperx.AutoModelForSpeechSeq2Seq.from_pretrained",
model_mock,
)
monkeypatch.setattr(
"swarms.models.distilled_whisperx.AutoProcessor.from_pretrained", processor_mock
)
return model_mock, processor_mock
@pytest.mark.asyncio
async def test_async_transcribe_with_mocked_model(mocked_model, audio_file_path):
model_mock, processor_mock = mocked_model
# Set up what the mock should return when it's called
model_mock.return_value.generate.return_value = torch.tensor([[0]])
processor_mock.return_value.batch_decode.return_value = ["mocked transcription"]
model_wrapper = DistilWhisperModel()
transcription = await model_wrapper.async_transcribe(audio_file_path)
assert transcription == "mocked transcription"

@@ -0,0 +1,115 @@
import pytest
from swarms.models.llama_function_caller import LlamaFunctionCaller
# Define fixtures if needed
@pytest.fixture
def llama_caller():
# Initialize the LlamaFunctionCaller with a sample model
return LlamaFunctionCaller()
# Basic test for model loading
def test_llama_model_loading(llama_caller):
assert llama_caller.model is not None
assert llama_caller.tokenizer is not None
# Test adding and calling custom functions
def test_llama_custom_function(llama_caller):
def sample_function(arg1, arg2):
return f"Sample function called with args: {arg1}, {arg2}"
llama_caller.add_func(
name="sample_function",
function=sample_function,
description="Sample custom function",
arguments=[
{"name": "arg1", "type": "string", "description": "Argument 1"},
{"name": "arg2", "type": "string", "description": "Argument 2"},
],
)
result = llama_caller.call_function(
"sample_function", arg1="arg1_value", arg2="arg2_value"
)
assert result == "Sample function called with args: arg1_value, arg2_value"
# Test streaming user prompts
def test_llama_streaming(llama_caller):
user_prompt = "Tell me about the tallest mountain in the world."
response = llama_caller(user_prompt)
assert isinstance(response, str)
assert len(response) > 0
# Test custom function not found
def test_llama_custom_function_not_found(llama_caller):
with pytest.raises(ValueError):
llama_caller.call_function("non_existent_function")
# Test invalid arguments for custom function
def test_llama_custom_function_invalid_arguments(llama_caller):
def sample_function(arg1, arg2):
return f"Sample function called with args: {arg1}, {arg2}"
llama_caller.add_func(
name="sample_function",
function=sample_function,
description="Sample custom function",
arguments=[
{"name": "arg1", "type": "string", "description": "Argument 1"},
{"name": "arg2", "type": "string", "description": "Argument 2"},
],
)
with pytest.raises(TypeError):
llama_caller.call_function("sample_function", arg1="arg1_value")
# Test streaming with custom runtime
def test_llama_custom_runtime():
llama_caller = LlamaFunctionCaller(
model_id="Your-Model-ID", cache_dir="Your-Cache-Directory", runtime="cuda"
)
user_prompt = "Tell me about the tallest mountain in the world."
response = llama_caller(user_prompt)
assert isinstance(response, str)
assert len(response) > 0
# Test caching functionality
def test_llama_cache():
llama_caller = LlamaFunctionCaller(
model_id="Your-Model-ID", cache_dir="Your-Cache-Directory", runtime="cuda"
)
# Perform a request to populate the cache
user_prompt = "Tell me about the tallest mountain in the world."
response = llama_caller(user_prompt)
# Check if the response is retrieved from the cache
llama_caller.model.from_cache = True
response_from_cache = llama_caller(user_prompt)
assert response == response_from_cache
# Test response length within max_tokens limit
def test_llama_response_length():
llama_caller = LlamaFunctionCaller(
model_id="Your-Model-ID", cache_dir="Your-Cache-Directory", runtime="cuda"
)
# Generate a long prompt
long_prompt = "A " + "test " * 100 # Approximately 500 tokens
# Ensure the response does not exceed max_tokens
response = llama_caller(long_prompt)
assert len(response.split()) <= 500
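# Hedged variant of the length check above: splitting on whitespace counts words,
# not tokens. If the caller exposes a Hugging Face-style tokenizer (an assumption;
# test_llama_model_loading asserts a tokenizer attribute exists), the budget can be
# verified in token space instead.
def test_llama_response_token_length():
    llama_caller = LlamaFunctionCaller(
        model_id="Your-Model-ID", cache_dir="Your-Cache-Directory", runtime="cuda"
    )
    response = llama_caller("Tell me about the tallest mountain in the world.")
    token_ids = llama_caller.tokenizer(response)["input_ids"]  # assumes HF tokenizer API
    assert len(token_ids) <= 500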
# Add more test cases as needed to cover different aspects of your code
# ...

@@ -0,0 +1,139 @@
import pytest
import os
import torch
from swarms.models.speecht5 import SpeechT5
# Create fixtures if needed
@pytest.fixture
def speecht5_model():
return SpeechT5()
# Test cases for the SpeechT5 class
def test_speecht5_init(speecht5_model):
assert speecht5_model.processor is not None
assert speecht5_model.model is not None
assert speecht5_model.vocoder is not None
assert isinstance(speecht5_model.embeddings_dataset, torch.utils.data.Dataset)
def test_speecht5_call(speecht5_model):
text = "Hello, how are you?"
speech = speecht5_model(text)
assert isinstance(speech, torch.Tensor)
def test_speecht5_save_speech(speecht5_model):
text = "Hello, how are you?"
speech = speecht5_model(text)
filename = "test_speech.wav"
speecht5_model.save_speech(speech, filename)
assert os.path.isfile(filename)
os.remove(filename)
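# Hedged follow-up to the save test: if the soundfile package is available (an
# assumption, hence importorskip), the written file can also be checked against the
# sampling rate that get_sampling_rate() reports.
def test_speecht5_saved_file_sampling_rate(speecht5_model, tmp_path):
    sf = pytest.importorskip("soundfile")
    speech = speecht5_model("Hello, how are you?")
    filename = str(tmp_path / "speech_check.wav")
    speecht5_model.save_speech(speech, filename)
    _, sample_rate = sf.read(filename)
    assert sample_rate == speecht5_model.get_sampling_rate()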
def test_speecht5_set_model(speecht5_model):
old_model_name = speecht5_model.model_name
new_model_name = "facebook/speecht5-tts"
speecht5_model.set_model(new_model_name)
assert speecht5_model.model_name == new_model_name
assert speecht5_model.processor.model_name == new_model_name
assert speecht5_model.model.config.model_name_or_path == new_model_name
speecht5_model.set_model(old_model_name) # Restore original model
def test_speecht5_set_vocoder(speecht5_model):
old_vocoder_name = speecht5_model.vocoder_name
new_vocoder_name = "facebook/speecht5-hifigan"
speecht5_model.set_vocoder(new_vocoder_name)
assert speecht5_model.vocoder_name == new_vocoder_name
assert speecht5_model.vocoder.config.model_name_or_path == new_vocoder_name
speecht5_model.set_vocoder(old_vocoder_name) # Restore original vocoder
def test_speecht5_set_embeddings_dataset(speecht5_model):
old_dataset_name = speecht5_model.dataset_name
new_dataset_name = "Matthijs/cmu-arctic-xvectors-test"
speecht5_model.set_embeddings_dataset(new_dataset_name)
assert speecht5_model.dataset_name == new_dataset_name
assert isinstance(speecht5_model.embeddings_dataset, torch.utils.data.Dataset)
speecht5_model.set_embeddings_dataset(old_dataset_name) # Restore original dataset
def test_speecht5_get_sampling_rate(speecht5_model):
sampling_rate = speecht5_model.get_sampling_rate()
assert sampling_rate == 16000
def test_speecht5_print_model_details(speecht5_model, capsys):
speecht5_model.print_model_details()
captured = capsys.readouterr()
assert "Model Name: " in captured.out
assert "Vocoder Name: " in captured.out
def test_speecht5_quick_synthesize(speecht5_model):
text = "Hello, how are you?"
speech = speecht5_model.quick_synthesize(text)
assert isinstance(speech, list)
assert isinstance(speech[0], dict)
assert "audio" in speech[0]
def test_speecht5_change_dataset_split(speecht5_model):
split = "test"
speecht5_model.change_dataset_split(split)
assert speecht5_model.embeddings_dataset.split == split
def test_speecht5_load_custom_embedding(speecht5_model):
xvector = [0.1, 0.2, 0.3, 0.4, 0.5]
embedding = speecht5_model.load_custom_embedding(xvector)
assert torch.all(torch.eq(embedding, torch.tensor(xvector).unsqueeze(0)))
def test_speecht5_with_different_speakers(speecht5_model):
text = "Hello, how are you?"
speakers = [7306, 5324, 1234]
for speaker_id in speakers:
speech = speecht5_model(text, speaker_id=speaker_id)
assert isinstance(speech, torch.Tensor)
def test_speecht5_save_speech_with_different_extensions(speecht5_model):
text = "Hello, how are you?"
speech = speecht5_model(text)
extensions = [".wav", ".flac"]
for extension in extensions:
filename = f"test_speech{extension}"
speecht5_model.save_speech(speech, filename)
assert os.path.isfile(filename)
os.remove(filename)
def test_speecht5_invalid_speaker_id(speecht5_model):
text = "Hello, how are you?"
invalid_speaker_id = 9999 # Speaker ID that does not exist in the dataset
with pytest.raises(IndexError):
speecht5_model(text, speaker_id=invalid_speaker_id)
def test_speecht5_invalid_save_path(speecht5_model):
text = "Hello, how are you?"
speech = speecht5_model(text)
invalid_path = "/invalid_directory/test_speech.wav"
with pytest.raises(FileNotFoundError):
speecht5_model.save_speech(speech, invalid_path)
def test_speecht5_change_vocoder_model(speecht5_model):
text = "Hello, how are you?"
old_vocoder_name = speecht5_model.vocoder_name
new_vocoder_name = "facebook/speecht5-hifigan-ljspeech"
speecht5_model.set_vocoder(new_vocoder_name)
speech = speecht5_model(text)
assert isinstance(speech, torch.Tensor)
speecht5_model.set_vocoder(old_vocoder_name) # Restore original vocoder

@@ -0,0 +1,223 @@
import pytest
from swarms.models.ssd_1b import SSD1B
from PIL import Image
# Create fixtures if needed
@pytest.fixture
def ssd1b_model():
return SSD1B()
# Basic tests for model initialization and method call
def test_ssd1b_model_initialization(ssd1b_model):
assert ssd1b_model is not None
def test_ssd1b_call(ssd1b_model):
task = "A painting of a dog"
neg_prompt = "ugly, blurry, poor quality"
image_url = ssd1b_model(task, neg_prompt)
assert isinstance(image_url, str)
assert image_url.startswith("https://") # Assuming it starts with "https://"
# Add more tests for various aspects of the class and methods
# Example of a parameterized test for different tasks
@pytest.mark.parametrize("task", ["A painting of a cat", "A painting of a tree"])
def test_ssd1b_parameterized_task(ssd1b_model, task):
image_url = ssd1b_model(task)
assert isinstance(image_url, str)
assert image_url.startswith("https://") # Assuming it starts with "https://"
# Example of a test using mocks to isolate units of code
def test_ssd1b_with_mock(ssd1b_model, mocker):
mocker.patch("your_module.StableDiffusionXLPipeline") # Mock the pipeline
task = "A painting of a cat"
image_url = ssd1b_model(task)
assert isinstance(image_url, str)
assert image_url.startswith("https://") # Assuming it starts with "https://"
def test_ssd1b_call_with_cache(ssd1b_model):
task = "A painting of a dog"
neg_prompt = "ugly, blurry, poor quality"
image_url1 = ssd1b_model(task, neg_prompt)
image_url2 = ssd1b_model(task, neg_prompt) # Should use cache
assert image_url1 == image_url2
def test_ssd1b_invalid_task(ssd1b_model):
invalid_task = ""
with pytest.raises(ValueError):
ssd1b_model(invalid_task)
def test_ssd1b_failed_api_call(ssd1b_model, mocker):
mocker.patch(
"swarms.models.ssd_1b.StableDiffusionXLPipeline",  # assumed import location of the pipeline
side_effect=Exception("Simulated pipeline failure"),
)  # Mock the pipeline so it raises an exception when called
task = "A painting of a cat"
with pytest.raises(Exception):
ssd1b_model(task)
def test_ssd1b_process_batch_concurrently(ssd1b_model):
tasks = [
"A painting of a dog",
"A beautiful sunset",
"A portrait of a person",
]
results = ssd1b_model.process_batch_concurrently(tasks)
assert isinstance(results, list)
assert len(results) == len(tasks)
def test_ssd1b_process_empty_batch_concurrently(ssd1b_model):
tasks = []
results = ssd1b_model.process_batch_concurrently(tasks)
assert isinstance(results, list)
assert len(results) == 0
def test_ssd1b_download_image(ssd1b_model):
task = "A painting of a dog"
neg_prompt = "ugly, blurry, poor quality"
image_url = ssd1b_model(task, neg_prompt)
img = ssd1b_model._download_image(image_url)
assert isinstance(img, Image.Image)
def test_ssd1b_generate_uuid(ssd1b_model):
uuid_str = ssd1b_model._generate_uuid()
assert isinstance(uuid_str, str)
assert len(uuid_str) == 36 # UUID format
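# Hedged, stricter variant of the check above: parsing with the stdlib uuid module
# validates the format, not just the length, assuming _generate_uuid returns a
# canonical UUID string.
def test_ssd1b_generate_uuid_is_parseable(ssd1b_model):
    import uuid

    uuid_str = ssd1b_model._generate_uuid()
    assert str(uuid.UUID(uuid_str)) == uuid_str.lower()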
def test_ssd1b_rate_limited_call(ssd1b_model):
task = "A painting of a dog"
image_url = ssd1b_model.rate_limited_call(task)
assert isinstance(image_url, str)
assert image_url.startswith("https://")
# Test cases for additional scenarios and behaviors
def test_ssd1b_dashboard_printing(ssd1b_model, capsys):
ssd1b_model.dashboard = True
ssd1b_model.print_dashboard()
captured = capsys.readouterr()
assert "SSD1B Dashboard:" in captured.out
def test_ssd1b_generate_image_name(ssd1b_model):
task = "A painting of a dog"
img_name = ssd1b_model._generate_image_name(task)
assert isinstance(img_name, str)
assert len(img_name) > 0
def test_ssd1b_set_width_height(ssd1b_model, mocker):
img = mocker.MagicMock()
width, height = 800, 600
result = ssd1b_model.set_width_height(img, width, height)
assert result == img.resize.return_value
def test_ssd1b_read_img(ssd1b_model, mocker):
img = mocker.MagicMock()
result = ssd1b_model.read_img(img)
assert result == img.open.return_value
def test_ssd1b_convert_to_bytesio(ssd1b_model, mocker):
img = mocker.MagicMock()
img_format = "PNG"
result = ssd1b_model.convert_to_bytesio(img, img_format)
assert isinstance(result, bytes)
def test_ssd1b_save_image(ssd1b_model, mocker, tmp_path):
img = mocker.MagicMock()
img_name = "test.png"
save_path = tmp_path / img_name
ssd1b_model._download_image(img, img_name, save_path)
assert save_path.exists()
def test_ssd1b_repr_str(ssd1b_model):
task = "A painting of a dog"
image_url = ssd1b_model(task)
assert repr(ssd1b_model) == f"SSD1B(image_url={image_url})"
assert str(ssd1b_model) == f"SSD1B(image_url={image_url})"
def test_ssd1b_rate_limited_call_raises(ssd1b_model, mocker):
task = "A painting of a dog"
mocker.patch.object(
ssd1b_model, "__call__", side_effect=Exception("Rate limit exceeded")
)
with pytest.raises(Exception, match="Rate limit exceeded"):
ssd1b_model.rate_limited_call(task)

@@ -0,0 +1,164 @@
from unittest.mock import Mock
import torch
import pytest
from swarms.models.timm import TimmModel, TimmModelInfo
@pytest.fixture
def sample_model_info():
return TimmModelInfo(model_name="resnet18", pretrained=True, in_chans=3)
def test_get_supported_models():
model_handler = TimmModel()
supported_models = model_handler._get_supported_models()
assert isinstance(supported_models, list)
assert len(supported_models) > 0
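# Hedged cross-check, assuming the handler sources its list from timm itself:
# every name it reports should also appear in timm.list_models().
def test_supported_models_exist_in_timm():
    import timm

    model_handler = TimmModel()
    supported = model_handler._get_supported_models()
    assert set(supported).issubset(set(timm.list_models()))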
def test_create_model(sample_model_info):
model_handler = TimmModel()
model = model_handler._create_model(sample_model_info)
assert isinstance(model, torch.nn.Module)
def test_call(sample_model_info):
model_handler = TimmModel()
input_tensor = torch.randn(1, 3, 224, 224)
output_shape = model_handler(sample_model_info, input_tensor)
assert isinstance(output_shape, torch.Size)
@pytest.mark.parametrize(
"model_name, pretrained, in_chans",
[
("resnet18", True, 3),
("resnet50", False, 1),
("efficientnet_b0", True, 3),
],
)
def test_create_model_parameterized(model_name, pretrained, in_chans):
model_info = TimmModelInfo(
model_name=model_name, pretrained=pretrained, in_chans=in_chans
)
model_handler = TimmModel()
model = model_handler._create_model(model_info)
assert isinstance(model, torch.nn.Module)
@pytest.mark.parametrize(
"model_name, pretrained, in_chans",
[
("resnet18", True, 3),
("resnet50", False, 1),
("efficientnet_b0", True, 3),
],
)
def test_call_parameterized(model_name, pretrained, in_chans):
model_info = TimmModelInfo(
model_name=model_name, pretrained=pretrained, in_chans=in_chans
)
model_handler = TimmModel()
input_tensor = torch.randn(1, in_chans, 224, 224)
output_shape = model_handler(model_info, input_tensor)
assert isinstance(output_shape, torch.Size)
def test_get_supported_models_mock():
model_handler = TimmModel()
model_handler._get_supported_models = Mock(return_value=["resnet18", "resnet50"])
supported_models = model_handler._get_supported_models()
assert supported_models == ["resnet18", "resnet50"]
def test_create_model_mock(sample_model_info):
model_handler = TimmModel()
model_handler._create_model = Mock(return_value=torch.nn.Module())
model = model_handler._create_model(sample_model_info)
assert isinstance(model, torch.nn.Module)
def test_call_exception():
model_handler = TimmModel()
model_info = TimmModelInfo(model_name="invalid_model", pretrained=True, in_chans=3)
input_tensor = torch.randn(1, 3, 224, 224)
with pytest.raises(Exception):
model_handler(model_info, input_tensor)
def test_coverage():
pytest.main(["--cov=my_module", "--cov-report=html"])
def test_environment_variable():
import os
os.environ["MODEL_NAME"] = "resnet18"
os.environ["PRETRAINED"] = "True"
os.environ["IN_CHANS"] = "3"
model_handler = TimmModel()
model_info = TimmModelInfo(
model_name=os.environ["MODEL_NAME"],
pretrained=bool(os.environ["PRETRAINED"]),
in_chans=int(os.environ["IN_CHANS"]),
)
input_tensor = torch.randn(1, model_info.in_chans, 224, 224)
output_shape = model_handler(model_info, input_tensor)
assert isinstance(output_shape, torch.Size)
@pytest.mark.slow
def test_marked_slow():
model_handler = TimmModel()
model_info = TimmModelInfo(model_name="resnet18", pretrained=True, in_chans=3)
input_tensor = torch.randn(1, 3, 224, 224)
output_shape = model_handler(model_info, input_tensor)
assert isinstance(output_shape, torch.Size)
@pytest.mark.parametrize(
"model_name, pretrained, in_chans",
[
("resnet18", True, 3),
("resnet50", False, 1),
("efficientnet_b0", True, 3),
],
)
def test_marked_parameterized(model_name, pretrained, in_chans):
model_info = TimmModelInfo(
model_name=model_name, pretrained=pretrained, in_chans=in_chans
)
model_handler = TimmModel()
model = model_handler._create_model(model_info)
assert isinstance(model, torch.nn.Module)
def test_exception_testing():
model_handler = TimmModel()
model_info = TimmModelInfo(model_name="invalid_model", pretrained=True, in_chans=3)
input_tensor = torch.randn(1, 3, 224, 224)
with pytest.raises(Exception):
model_handler(model_info, input_tensor)
def test_parameterized_testing():
model_handler = TimmModel()
model_info = TimmModelInfo(model_name="resnet18", pretrained=True, in_chans=3)
input_tensor = torch.randn(1, 3, 224, 224)
output_shape = model_handler(model_info, input_tensor)
assert isinstance(output_shape, torch.Size)
def test_use_mocks_and_monkeypatching():
model_handler = TimmModel()
model_handler._create_model = Mock(return_value=torch.nn.Module())
model_info = TimmModelInfo(model_name="resnet18", pretrained=True, in_chans=3)
model = model_handler._create_model(model_info)
assert isinstance(model, torch.nn.Module)
def test_coverage_report():
# Install pytest-cov
# Run tests with coverage report
pytest.main(["--cov=my_module", "--cov-report=html"])

@@ -0,0 +1,106 @@
import pytest
import torch
from transformers import PreTrainedTokenizerBase
from swarms.models.yi_200k import Yi34B200k
# Create fixtures if needed
@pytest.fixture
def yi34b_model():
return Yi34B200k()
# Test cases for the Yi34B200k class
def test_yi34b_init(yi34b_model):
assert isinstance(yi34b_model.model, torch.nn.Module)
assert isinstance(yi34b_model.tokenizer, PreTrainedTokenizerBase)  # AutoTokenizer is a factory class, not a common base
def test_yi34b_generate_text(yi34b_model):
prompt = "There's a place where time stands still."
generated_text = yi34b_model(prompt)
assert isinstance(generated_text, str)
assert len(generated_text) > 0
@pytest.mark.parametrize("max_length", [64, 128, 256, 512])
def test_yi34b_generate_text_with_length(yi34b_model, max_length):
prompt = "There's a place where time stands still."
generated_text = yi34b_model(prompt, max_length=max_length)
assert len(generated_text) <= max_length
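# Hedged variant of the length check above: max_length is normally a token budget,
# so comparing the character count of the text is only a loose proxy. Assuming the
# wrapper exposes a Hugging Face-style tokenizer, the bound can be checked in token
# space instead.
@pytest.mark.parametrize("max_length", [64, 128])
def test_yi34b_generated_token_count(yi34b_model, max_length):
    prompt = "There's a place where time stands still."
    generated_text = yi34b_model(prompt, max_length=max_length)
    token_ids = yi34b_model.tokenizer(generated_text)["input_ids"]  # assumes HF tokenizer API
    assert len(token_ids) <= max_length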
@pytest.mark.parametrize("temperature", [0.5, 1.0, 1.5])
def test_yi34b_generate_text_with_temperature(yi34b_model, temperature):
prompt = "There's a place where time stands still."
generated_text = yi34b_model(prompt, temperature=temperature)
assert isinstance(generated_text, str)
def test_yi34b_generate_text_with_invalid_prompt(yi34b_model):
prompt = None # Invalid prompt
with pytest.raises(ValueError, match="Input prompt must be a non-empty string"):
yi34b_model(prompt)
def test_yi34b_generate_text_with_invalid_max_length(yi34b_model):
prompt = "There's a place where time stands still."
max_length = -1 # Invalid max_length
with pytest.raises(ValueError, match="max_length must be a positive integer"):
yi34b_model(prompt, max_length=max_length)
def test_yi34b_generate_text_with_invalid_temperature(yi34b_model):
prompt = "There's a place where time stands still."
temperature = 2.0 # Invalid temperature
with pytest.raises(ValueError, match="temperature must be between 0.01 and 1.0"):
yi34b_model(prompt, temperature=temperature)
@pytest.mark.parametrize("top_k", [20, 30, 50])
def test_yi34b_generate_text_with_top_k(yi34b_model, top_k):
prompt = "There's a place where time stands still."
generated_text = yi34b_model(prompt, top_k=top_k)
assert isinstance(generated_text, str)
@pytest.mark.parametrize("top_p", [0.5, 0.7, 0.9])
def test_yi34b_generate_text_with_top_p(yi34b_model, top_p):
prompt = "There's a place where time stands still."
generated_text = yi34b_model(prompt, top_p=top_p)
assert isinstance(generated_text, str)
def test_yi34b_generate_text_with_invalid_top_k(yi34b_model):
prompt = "There's a place where time stands still."
top_k = -1 # Invalid top_k
with pytest.raises(ValueError, match="top_k must be a non-negative integer"):
yi34b_model(prompt, top_k=top_k)
def test_yi34b_generate_text_with_invalid_top_p(yi34b_model):
prompt = "There's a place where time stands still."
top_p = 1.5 # Invalid top_p
with pytest.raises(ValueError, match="top_p must be between 0.0 and 1.0"):
yi34b_model(prompt, top_p=top_p)
@pytest.mark.parametrize("repitition_penalty", [1.0, 1.2, 1.5])
def test_yi34b_generate_text_with_repitition_penalty(yi34b_model, repitition_penalty):
prompt = "There's a place where time stands still."
generated_text = yi34b_model(prompt, repitition_penalty=repitition_penalty)
assert isinstance(generated_text, str)
def test_yi34b_generate_text_with_invalid_repitition_penalty(yi34b_model):
prompt = "There's a place where time stands still."
repitition_penalty = 0.0 # Invalid repitition_penalty
with pytest.raises(ValueError, match="repitition_penalty must be a positive float"):
yi34b_model(prompt, repitition_penalty=repitition_penalty)
def test_yi34b_generate_text_with_invalid_device(yi34b_model):
prompt = "There's a place where time stands still."
device_map = "invalid_device" # Invalid device_map
with pytest.raises(ValueError, match="Invalid device_map"):
yi34b_model(prompt, device_map=device_map)