parent
59f3b4c83f
commit
48643b38c1
@ -0,0 +1,161 @@
|
|||||||
|
# Import necessary modules and define fixtures if needed
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from swarms.models.bioclip import BioClip
|
||||||
|
|
||||||
|
|
||||||
|
# Define fixtures if needed
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_image_path():
|
||||||
|
return "path_to_sample_image.jpg"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def clip_instance():
|
||||||
|
return BioClip("microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224")
|
||||||
|
|
||||||
|
|
||||||
|
# Basic tests for the BioClip class
|
||||||
|
def test_clip_initialization(clip_instance):
|
||||||
|
assert isinstance(clip_instance.model, torch.nn.Module)
|
||||||
|
assert hasattr(clip_instance, "model_path")
|
||||||
|
assert hasattr(clip_instance, "preprocess_train")
|
||||||
|
assert hasattr(clip_instance, "preprocess_val")
|
||||||
|
assert hasattr(clip_instance, "tokenizer")
|
||||||
|
assert hasattr(clip_instance, "device")
|
||||||
|
|
||||||
|
|
||||||
|
def test_clip_call_method(clip_instance, sample_image_path):
|
||||||
|
labels = [
|
||||||
|
"adenocarcinoma histopathology",
|
||||||
|
"brain MRI",
|
||||||
|
"covid line chart",
|
||||||
|
"squamous cell carcinoma histopathology",
|
||||||
|
"immunohistochemistry histopathology",
|
||||||
|
"bone X-ray",
|
||||||
|
"chest X-ray",
|
||||||
|
"pie chart",
|
||||||
|
"hematoxylin and eosin histopathology",
|
||||||
|
]
|
||||||
|
result = clip_instance(sample_image_path, labels)
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert len(result) == len(labels)
|
||||||
|
|
||||||
|
|
||||||
|
def test_clip_plot_image_with_metadata(clip_instance, sample_image_path):
|
||||||
|
metadata = {
|
||||||
|
"filename": "sample_image.jpg",
|
||||||
|
"top_probs": {"label1": 0.75, "label2": 0.65},
|
||||||
|
}
|
||||||
|
clip_instance.plot_image_with_metadata(sample_image_path, metadata)
|
||||||
|
|
||||||
|
|
||||||
|
# More test cases can be added to cover additional functionality and edge cases
|
||||||
|
|
||||||
|
|
||||||
|
# Parameterized tests for different image and label combinations
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"image_path, labels",
|
||||||
|
[
|
||||||
|
("image1.jpg", ["label1", "label2"]),
|
||||||
|
("image2.jpg", ["label3", "label4"]),
|
||||||
|
# Add more image and label combinations
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_clip_parameterized_calls(clip_instance, image_path, labels):
|
||||||
|
result = clip_instance(image_path, labels)
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert len(result) == len(labels)
|
||||||
|
|
||||||
|
|
||||||
|
# Test image preprocessing
|
||||||
|
def test_clip_image_preprocessing(clip_instance, sample_image_path):
|
||||||
|
image = Image.open(sample_image_path)
|
||||||
|
processed_image = clip_instance.preprocess_val(image)
|
||||||
|
assert isinstance(processed_image, torch.Tensor)
|
||||||
|
|
||||||
|
|
||||||
|
# Test label tokenization
|
||||||
|
def test_clip_label_tokenization(clip_instance):
|
||||||
|
labels = ["label1", "label2"]
|
||||||
|
tokenized_labels = clip_instance.tokenizer(labels)
|
||||||
|
assert isinstance(tokenized_labels, torch.Tensor)
|
||||||
|
assert tokenized_labels.shape[0] == len(labels)
|
||||||
|
|
||||||
|
|
||||||
|
# More tests can be added to cover other methods and edge cases
|
||||||
|
|
||||||
|
|
||||||
|
# End-to-end tests with actual images and labels
|
||||||
|
def test_clip_end_to_end(clip_instance, sample_image_path):
|
||||||
|
labels = [
|
||||||
|
"adenocarcinoma histopathology",
|
||||||
|
"brain MRI",
|
||||||
|
"covid line chart",
|
||||||
|
"squamous cell carcinoma histopathology",
|
||||||
|
"immunohistochemistry histopathology",
|
||||||
|
"bone X-ray",
|
||||||
|
"chest X-ray",
|
||||||
|
"pie chart",
|
||||||
|
"hematoxylin and eosin histopathology",
|
||||||
|
]
|
||||||
|
result = clip_instance(sample_image_path, labels)
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert len(result) == len(labels)
|
||||||
|
|
||||||
|
|
||||||
|
# Test label tokenization with long labels
|
||||||
|
def test_clip_long_labels(clip_instance):
|
||||||
|
labels = ["label" + str(i) for i in range(100)]
|
||||||
|
tokenized_labels = clip_instance.tokenizer(labels)
|
||||||
|
assert isinstance(tokenized_labels, torch.Tensor)
|
||||||
|
assert tokenized_labels.shape[0] == len(labels)
|
||||||
|
|
||||||
|
|
||||||
|
# Test handling of multiple image files
|
||||||
|
def test_clip_multiple_images(clip_instance, sample_image_path):
|
||||||
|
labels = ["label1", "label2"]
|
||||||
|
image_paths = [sample_image_path, "image2.jpg"]
|
||||||
|
results = clip_instance(image_paths, labels)
|
||||||
|
assert isinstance(results, list)
|
||||||
|
assert len(results) == len(image_paths)
|
||||||
|
for result in results:
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert len(result) == len(labels)
|
||||||
|
|
||||||
|
|
||||||
|
# Test model inference performance
|
||||||
|
def test_clip_inference_performance(clip_instance, sample_image_path, benchmark):
|
||||||
|
labels = [
|
||||||
|
"adenocarcinoma histopathology",
|
||||||
|
"brain MRI",
|
||||||
|
"covid line chart",
|
||||||
|
"squamous cell carcinoma histopathology",
|
||||||
|
"immunohistochemistry histopathology",
|
||||||
|
"bone X-ray",
|
||||||
|
"chest X-ray",
|
||||||
|
"pie chart",
|
||||||
|
"hematoxylin and eosin histopathology",
|
||||||
|
]
|
||||||
|
result = benchmark(clip_instance, sample_image_path, labels)
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert len(result) == len(labels)
|
||||||
|
|
||||||
|
|
||||||
|
# Test different preprocessing pipelines
|
||||||
|
def test_clip_preprocessing_pipelines(clip_instance, sample_image_path):
|
||||||
|
labels = ["label1", "label2"]
|
||||||
|
image = Image.open(sample_image_path)
|
||||||
|
|
||||||
|
# Test preprocessing for training
|
||||||
|
processed_image_train = clip_instance.preprocess_train(image)
|
||||||
|
assert isinstance(processed_image_train, torch.Tensor)
|
||||||
|
|
||||||
|
# Test preprocessing for validation
|
||||||
|
processed_image_val = clip_instance.preprocess_val(image)
|
||||||
|
assert isinstance(processed_image_val, torch.Tensor)
|
||||||
|
|
||||||
|
|
||||||
|
# ...
|
@ -1,119 +0,0 @@
|
|||||||
# test_distilled_whisperx.py
|
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
|
|
||||||
|
|
||||||
from swarms.models.distilled_whisperx import DistilWhisperModel, async_retry
|
|
||||||
|
|
||||||
|
|
||||||
# Fixtures for setting up model, processor, and audio files
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def model_id():
|
|
||||||
return "distil-whisper/distil-large-v2"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def whisper_model(model_id):
|
|
||||||
return DistilWhisperModel(model_id)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def audio_file_path(tmp_path_factory):
|
|
||||||
# You would create a small temporary MP3 file here for testing
|
|
||||||
# or use a public domain MP3 file's path
|
|
||||||
return "path/to/valid_audio.mp3"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def invalid_audio_file_path():
|
|
||||||
return "path/to/invalid_audio.mp3"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def audio_dict():
|
|
||||||
# This should represent a valid audio dictionary as expected by the model
|
|
||||||
return {"array": torch.randn(1, 16000), "sampling_rate": 16000}
|
|
||||||
|
|
||||||
|
|
||||||
# Test initialization
|
|
||||||
def test_initialization(whisper_model):
|
|
||||||
assert whisper_model.model is not None
|
|
||||||
assert whisper_model.processor is not None
|
|
||||||
|
|
||||||
|
|
||||||
# Test successful transcription with file path
|
|
||||||
def test_transcribe_with_file_path(whisper_model, audio_file_path):
|
|
||||||
transcription = whisper_model.transcribe(audio_file_path)
|
|
||||||
assert isinstance(transcription, str)
|
|
||||||
|
|
||||||
|
|
||||||
# Test successful transcription with audio dict
|
|
||||||
def test_transcribe_with_audio_dict(whisper_model, audio_dict):
|
|
||||||
transcription = whisper_model.transcribe(audio_dict)
|
|
||||||
assert isinstance(transcription, str)
|
|
||||||
|
|
||||||
|
|
||||||
# Test for file not found error
|
|
||||||
def test_file_not_found(whisper_model, invalid_audio_file_path):
|
|
||||||
with pytest.raises(Exception):
|
|
||||||
whisper_model.transcribe(invalid_audio_file_path)
|
|
||||||
|
|
||||||
|
|
||||||
# Asynchronous tests
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_async_transcription_success(whisper_model, audio_file_path):
|
|
||||||
transcription = await whisper_model.async_transcribe(audio_file_path)
|
|
||||||
assert isinstance(transcription, str)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_async_transcription_failure(whisper_model, invalid_audio_file_path):
|
|
||||||
with pytest.raises(Exception):
|
|
||||||
await whisper_model.async_transcribe(invalid_audio_file_path)
|
|
||||||
|
|
||||||
|
|
||||||
# Testing real-time transcription simulation
|
|
||||||
def test_real_time_transcription(whisper_model, audio_file_path, capsys):
|
|
||||||
whisper_model.real_time_transcribe(audio_file_path, chunk_duration=1)
|
|
||||||
captured = capsys.readouterr()
|
|
||||||
assert "Starting real-time transcription..." in captured.out
|
|
||||||
|
|
||||||
|
|
||||||
# Testing retry decorator for asynchronous function
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_async_retry():
|
|
||||||
@async_retry(max_retries=2, exceptions=(ValueError,), delay=0)
|
|
||||||
async def failing_func():
|
|
||||||
raise ValueError("Test")
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
await failing_func()
|
|
||||||
|
|
||||||
|
|
||||||
# Mocking the actual model to avoid GPU/CPU intensive operations during test
|
|
||||||
@pytest.fixture
|
|
||||||
def mocked_model(monkeypatch):
|
|
||||||
model_mock = AsyncMock(AutoModelForSpeechSeq2Seq)
|
|
||||||
processor_mock = MagicMock(AutoProcessor)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"swarms.models.distilled_whisperx.AutoModelForSpeechSeq2Seq.from_pretrained",
|
|
||||||
model_mock,
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"swarms.models.distilled_whisperx.AutoProcessor.from_pretrained", processor_mock
|
|
||||||
)
|
|
||||||
return model_mock, processor_mock
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_async_transcribe_with_mocked_model(mocked_model, audio_file_path):
|
|
||||||
model_mock, processor_mock = mocked_model
|
|
||||||
# Set up what the mock should return when it's called
|
|
||||||
model_mock.return_value.generate.return_value = torch.tensor([[0]])
|
|
||||||
processor_mock.return_value.batch_decode.return_value = ["mocked transcription"]
|
|
||||||
model_wrapper = DistilWhisperModel()
|
|
||||||
transcription = await model_wrapper.async_transcribe(audio_file_path)
|
|
||||||
assert transcription == "mocked transcription"
|
|
@ -0,0 +1,115 @@
|
|||||||
|
import pytest
|
||||||
|
from swarms.models.llama_function_caller import LlamaFunctionCaller
|
||||||
|
|
||||||
|
|
||||||
|
# Define fixtures if needed
|
||||||
|
@pytest.fixture
|
||||||
|
def llama_caller():
|
||||||
|
# Initialize the LlamaFunctionCaller with a sample model
|
||||||
|
return LlamaFunctionCaller()
|
||||||
|
|
||||||
|
|
||||||
|
# Basic test for model loading
|
||||||
|
def test_llama_model_loading(llama_caller):
|
||||||
|
assert llama_caller.model is not None
|
||||||
|
assert llama_caller.tokenizer is not None
|
||||||
|
|
||||||
|
|
||||||
|
# Test adding and calling custom functions
|
||||||
|
def test_llama_custom_function(llama_caller):
|
||||||
|
def sample_function(arg1, arg2):
|
||||||
|
return f"Sample function called with args: {arg1}, {arg2}"
|
||||||
|
|
||||||
|
llama_caller.add_func(
|
||||||
|
name="sample_function",
|
||||||
|
function=sample_function,
|
||||||
|
description="Sample custom function",
|
||||||
|
arguments=[
|
||||||
|
{"name": "arg1", "type": "string", "description": "Argument 1"},
|
||||||
|
{"name": "arg2", "type": "string", "description": "Argument 2"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
result = llama_caller.call_function(
|
||||||
|
"sample_function", arg1="arg1_value", arg2="arg2_value"
|
||||||
|
)
|
||||||
|
assert result == "Sample function called with args: arg1_value, arg2_value"
|
||||||
|
|
||||||
|
|
||||||
|
# Test streaming user prompts
|
||||||
|
def test_llama_streaming(llama_caller):
|
||||||
|
user_prompt = "Tell me about the tallest mountain in the world."
|
||||||
|
response = llama_caller(user_prompt)
|
||||||
|
assert isinstance(response, str)
|
||||||
|
assert len(response) > 0
|
||||||
|
|
||||||
|
|
||||||
|
# Test custom function not found
|
||||||
|
def test_llama_custom_function_not_found(llama_caller):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
llama_caller.call_function("non_existent_function")
|
||||||
|
|
||||||
|
|
||||||
|
# Test invalid arguments for custom function
|
||||||
|
def test_llama_custom_function_invalid_arguments(llama_caller):
|
||||||
|
def sample_function(arg1, arg2):
|
||||||
|
return f"Sample function called with args: {arg1}, {arg2}"
|
||||||
|
|
||||||
|
llama_caller.add_func(
|
||||||
|
name="sample_function",
|
||||||
|
function=sample_function,
|
||||||
|
description="Sample custom function",
|
||||||
|
arguments=[
|
||||||
|
{"name": "arg1", "type": "string", "description": "Argument 1"},
|
||||||
|
{"name": "arg2", "type": "string", "description": "Argument 2"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
llama_caller.call_function("sample_function", arg1="arg1_value")
|
||||||
|
|
||||||
|
|
||||||
|
# Test streaming with custom runtime
|
||||||
|
def test_llama_custom_runtime():
|
||||||
|
llama_caller = LlamaFunctionCaller(
|
||||||
|
model_id="Your-Model-ID", cache_dir="Your-Cache-Directory", runtime="cuda"
|
||||||
|
)
|
||||||
|
user_prompt = "Tell me about the tallest mountain in the world."
|
||||||
|
response = llama_caller(user_prompt)
|
||||||
|
assert isinstance(response, str)
|
||||||
|
assert len(response) > 0
|
||||||
|
|
||||||
|
|
||||||
|
# Test caching functionality
|
||||||
|
def test_llama_cache():
|
||||||
|
llama_caller = LlamaFunctionCaller(
|
||||||
|
model_id="Your-Model-ID", cache_dir="Your-Cache-Directory", runtime="cuda"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Perform a request to populate the cache
|
||||||
|
user_prompt = "Tell me about the tallest mountain in the world."
|
||||||
|
response = llama_caller(user_prompt)
|
||||||
|
|
||||||
|
# Check if the response is retrieved from the cache
|
||||||
|
llama_caller.model.from_cache = True
|
||||||
|
response_from_cache = llama_caller(user_prompt)
|
||||||
|
assert response == response_from_cache
|
||||||
|
|
||||||
|
|
||||||
|
# Test response length within max_tokens limit
|
||||||
|
def test_llama_response_length():
|
||||||
|
llama_caller = LlamaFunctionCaller(
|
||||||
|
model_id="Your-Model-ID", cache_dir="Your-Cache-Directory", runtime="cuda"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate a long prompt
|
||||||
|
long_prompt = "A " + "test " * 100 # Approximately 500 tokens
|
||||||
|
|
||||||
|
# Ensure the response does not exceed max_tokens
|
||||||
|
response = llama_caller(long_prompt)
|
||||||
|
assert len(response.split()) <= 500
|
||||||
|
|
||||||
|
|
||||||
|
# Add more test cases as needed to cover different aspects of your code
|
||||||
|
|
||||||
|
# ...
|
@ -0,0 +1,139 @@
|
|||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from swarms.models.speecht5 import SpeechT5
|
||||||
|
|
||||||
|
|
||||||
|
# Create fixtures if needed
|
||||||
|
@pytest.fixture
|
||||||
|
def speecht5_model():
|
||||||
|
return SpeechT5()
|
||||||
|
|
||||||
|
|
||||||
|
# Test cases for the SpeechT5 class
|
||||||
|
|
||||||
|
|
||||||
|
def test_speecht5_init(speecht5_model):
|
||||||
|
assert isinstance(speecht5_model.processor, SpeechT5.processor.__class__)
|
||||||
|
assert isinstance(speecht5_model.model, SpeechT5.model.__class__)
|
||||||
|
assert isinstance(speecht5_model.vocoder, SpeechT5.vocoder.__class__)
|
||||||
|
assert isinstance(speecht5_model.embeddings_dataset, torch.utils.data.Dataset)
|
||||||
|
|
||||||
|
|
||||||
|
def test_speecht5_call(speecht5_model):
|
||||||
|
text = "Hello, how are you?"
|
||||||
|
speech = speecht5_model(text)
|
||||||
|
assert isinstance(speech, torch.Tensor)
|
||||||
|
|
||||||
|
|
||||||
|
def test_speecht5_save_speech(speecht5_model):
|
||||||
|
text = "Hello, how are you?"
|
||||||
|
speech = speecht5_model(text)
|
||||||
|
filename = "test_speech.wav"
|
||||||
|
speecht5_model.save_speech(speech, filename)
|
||||||
|
assert os.path.isfile(filename)
|
||||||
|
os.remove(filename)
|
||||||
|
|
||||||
|
|
||||||
|
def test_speecht5_set_model(speecht5_model):
|
||||||
|
old_model_name = speecht5_model.model_name
|
||||||
|
new_model_name = "facebook/speecht5-tts"
|
||||||
|
speecht5_model.set_model(new_model_name)
|
||||||
|
assert speecht5_model.model_name == new_model_name
|
||||||
|
assert speecht5_model.processor.model_name == new_model_name
|
||||||
|
assert speecht5_model.model.config.model_name_or_path == new_model_name
|
||||||
|
speecht5_model.set_model(old_model_name) # Restore original model
|
||||||
|
|
||||||
|
|
||||||
|
def test_speecht5_set_vocoder(speecht5_model):
|
||||||
|
old_vocoder_name = speecht5_model.vocoder_name
|
||||||
|
new_vocoder_name = "facebook/speecht5-hifigan"
|
||||||
|
speecht5_model.set_vocoder(new_vocoder_name)
|
||||||
|
assert speecht5_model.vocoder_name == new_vocoder_name
|
||||||
|
assert speecht5_model.vocoder.config.model_name_or_path == new_vocoder_name
|
||||||
|
speecht5_model.set_vocoder(old_vocoder_name) # Restore original vocoder
|
||||||
|
|
||||||
|
|
||||||
|
def test_speecht5_set_embeddings_dataset(speecht5_model):
|
||||||
|
old_dataset_name = speecht5_model.dataset_name
|
||||||
|
new_dataset_name = "Matthijs/cmu-arctic-xvectors-test"
|
||||||
|
speecht5_model.set_embeddings_dataset(new_dataset_name)
|
||||||
|
assert speecht5_model.dataset_name == new_dataset_name
|
||||||
|
assert isinstance(speecht5_model.embeddings_dataset, torch.utils.data.Dataset)
|
||||||
|
speecht5_model.set_embeddings_dataset(old_dataset_name) # Restore original dataset
|
||||||
|
|
||||||
|
|
||||||
|
def test_speecht5_get_sampling_rate(speecht5_model):
|
||||||
|
sampling_rate = speecht5_model.get_sampling_rate()
|
||||||
|
assert sampling_rate == 16000
|
||||||
|
|
||||||
|
|
||||||
|
def test_speecht5_print_model_details(speecht5_model, capsys):
|
||||||
|
speecht5_model.print_model_details()
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert "Model Name: " in captured.out
|
||||||
|
assert "Vocoder Name: " in captured.out
|
||||||
|
|
||||||
|
|
||||||
|
def test_speecht5_quick_synthesize(speecht5_model):
|
||||||
|
text = "Hello, how are you?"
|
||||||
|
speech = speecht5_model.quick_synthesize(text)
|
||||||
|
assert isinstance(speech, list)
|
||||||
|
assert isinstance(speech[0], dict)
|
||||||
|
assert "audio" in speech[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_speecht5_change_dataset_split(speecht5_model):
|
||||||
|
split = "test"
|
||||||
|
speecht5_model.change_dataset_split(split)
|
||||||
|
assert speecht5_model.embeddings_dataset.split == split
|
||||||
|
|
||||||
|
|
||||||
|
def test_speecht5_load_custom_embedding(speecht5_model):
|
||||||
|
xvector = [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||||
|
embedding = speecht5_model.load_custom_embedding(xvector)
|
||||||
|
assert torch.all(torch.eq(embedding, torch.tensor(xvector).unsqueeze(0)))
|
||||||
|
|
||||||
|
|
||||||
|
def test_speecht5_with_different_speakers(speecht5_model):
|
||||||
|
text = "Hello, how are you?"
|
||||||
|
speakers = [7306, 5324, 1234]
|
||||||
|
for speaker_id in speakers:
|
||||||
|
speech = speecht5_model(text, speaker_id=speaker_id)
|
||||||
|
assert isinstance(speech, torch.Tensor)
|
||||||
|
|
||||||
|
|
||||||
|
def test_speecht5_save_speech_with_different_extensions(speecht5_model):
|
||||||
|
text = "Hello, how are you?"
|
||||||
|
speech = speecht5_model(text)
|
||||||
|
extensions = [".wav", ".flac"]
|
||||||
|
for extension in extensions:
|
||||||
|
filename = f"test_speech{extension}"
|
||||||
|
speecht5_model.save_speech(speech, filename)
|
||||||
|
assert os.path.isfile(filename)
|
||||||
|
os.remove(filename)
|
||||||
|
|
||||||
|
|
||||||
|
def test_speecht5_invalid_speaker_id(speecht5_model):
|
||||||
|
text = "Hello, how are you?"
|
||||||
|
invalid_speaker_id = 9999 # Speaker ID that does not exist in the dataset
|
||||||
|
with pytest.raises(IndexError):
|
||||||
|
speecht5_model(text, speaker_id=invalid_speaker_id)
|
||||||
|
|
||||||
|
|
||||||
|
def test_speecht5_invalid_save_path(speecht5_model):
|
||||||
|
text = "Hello, how are you?"
|
||||||
|
speech = speecht5_model(text)
|
||||||
|
invalid_path = "/invalid_directory/test_speech.wav"
|
||||||
|
with pytest.raises(FileNotFoundError):
|
||||||
|
speecht5_model.save_speech(speech, invalid_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_speecht5_change_vocoder_model(speecht5_model):
|
||||||
|
text = "Hello, how are you?"
|
||||||
|
old_vocoder_name = speecht5_model.vocoder_name
|
||||||
|
new_vocoder_name = "facebook/speecht5-hifigan-ljspeech"
|
||||||
|
speecht5_model.set_vocoder(new_vocoder_name)
|
||||||
|
speech = speecht5_model(text)
|
||||||
|
assert isinstance(speech, torch.Tensor)
|
||||||
|
speecht5_model.set_vocoder(old_vocoder_name) # Restore original vocoder
|
@ -0,0 +1,223 @@
|
|||||||
|
import pytest
|
||||||
|
from swarms.models.ssd_1b import SSD1B
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
# Create fixtures if needed
|
||||||
|
@pytest.fixture
|
||||||
|
def ssd1b_model():
|
||||||
|
return SSD1B()
|
||||||
|
|
||||||
|
|
||||||
|
# Basic tests for model initialization and method call
|
||||||
|
def test_ssd1b_model_initialization(ssd1b_model):
|
||||||
|
assert ssd1b_model is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssd1b_call(ssd1b_model):
|
||||||
|
task = "A painting of a dog"
|
||||||
|
neg_prompt = "ugly, blurry, poor quality"
|
||||||
|
image_url = ssd1b_model(task, neg_prompt)
|
||||||
|
assert isinstance(image_url, str)
|
||||||
|
assert image_url.startswith("https://") # Assuming it starts with "https://"
|
||||||
|
|
||||||
|
|
||||||
|
# Add more tests for various aspects of the class and methods
|
||||||
|
|
||||||
|
|
||||||
|
# Example of a parameterized test for different tasks
|
||||||
|
@pytest.mark.parametrize("task", ["A painting of a cat", "A painting of a tree"])
|
||||||
|
def test_ssd1b_parameterized_task(ssd1b_model, task):
|
||||||
|
image_url = ssd1b_model(task)
|
||||||
|
assert isinstance(image_url, str)
|
||||||
|
assert image_url.startswith("https://") # Assuming it starts with "https://"
|
||||||
|
|
||||||
|
|
||||||
|
# Example of a test using mocks to isolate units of code
|
||||||
|
def test_ssd1b_with_mock(ssd1b_model, mocker):
|
||||||
|
mocker.patch("your_module.StableDiffusionXLPipeline") # Mock the pipeline
|
||||||
|
task = "A painting of a cat"
|
||||||
|
image_url = ssd1b_model(task)
|
||||||
|
assert isinstance(image_url, str)
|
||||||
|
assert image_url.startswith("https://") # Assuming it starts with "https://"
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssd1b_call_with_cache(ssd1b_model):
|
||||||
|
task = "A painting of a dog"
|
||||||
|
neg_prompt = "ugly, blurry, poor quality"
|
||||||
|
image_url1 = ssd1b_model(task, neg_prompt)
|
||||||
|
image_url2 = ssd1b_model(task, neg_prompt) # Should use cache
|
||||||
|
assert image_url1 == image_url2
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssd1b_invalid_task(ssd1b_model):
|
||||||
|
invalid_task = ""
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
ssd1b_model(invalid_task)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssd1b_failed_api_call(ssd1b_model, mocker):
|
||||||
|
mocker.patch(
|
||||||
|
"your_module.StableDiffusionXLPipeline"
|
||||||
|
) # Mock the pipeline to raise an exception
|
||||||
|
task = "A painting of a cat"
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
ssd1b_model(task)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssd1b_process_batch_concurrently(ssd1b_model):
|
||||||
|
tasks = [
|
||||||
|
"A painting of a dog",
|
||||||
|
"A beautiful sunset",
|
||||||
|
"A portrait of a person",
|
||||||
|
]
|
||||||
|
results = ssd1b_model.process_batch_concurrently(tasks)
|
||||||
|
assert isinstance(results, list)
|
||||||
|
assert len(results) == len(tasks)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssd1b_process_empty_batch_concurrently(ssd1b_model):
|
||||||
|
tasks = []
|
||||||
|
results = ssd1b_model.process_batch_concurrently(tasks)
|
||||||
|
assert isinstance(results, list)
|
||||||
|
assert len(results) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssd1b_download_image(ssd1b_model):
|
||||||
|
task = "A painting of a dog"
|
||||||
|
neg_prompt = "ugly, blurry, poor quality"
|
||||||
|
image_url = ssd1b_model(task, neg_prompt)
|
||||||
|
img = ssd1b_model._download_image(image_url)
|
||||||
|
assert isinstance(img, Image.Image)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssd1b_generate_uuid(ssd1b_model):
|
||||||
|
uuid_str = ssd1b_model._generate_uuid()
|
||||||
|
assert isinstance(uuid_str, str)
|
||||||
|
assert len(uuid_str) == 36 # UUID format
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssd1b_rate_limited_call(ssd1b_model):
|
||||||
|
task = "A painting of a dog"
|
||||||
|
image_url = ssd1b_model.rate_limited_call(task)
|
||||||
|
assert isinstance(image_url, str)
|
||||||
|
assert image_url.startswith("https://")
|
||||||
|
|
||||||
|
|
||||||
|
# Test cases for additional scenarios and behaviors
|
||||||
|
def test_ssd1b_dashboard_printing(ssd1b_model, capsys):
|
||||||
|
ssd1b_model.dashboard = True
|
||||||
|
ssd1b_model.print_dashboard()
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert "SSD1B Dashboard:" in captured.out
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssd1b_generate_image_name(ssd1b_model):
|
||||||
|
task = "A painting of a dog"
|
||||||
|
img_name = ssd1b_model._generate_image_name(task)
|
||||||
|
assert isinstance(img_name, str)
|
||||||
|
assert len(img_name) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssd1b_set_width_height(ssd1b_model, mocker):
|
||||||
|
img = mocker.MagicMock()
|
||||||
|
width, height = 800, 600
|
||||||
|
result = ssd1b_model.set_width_height(img, width, height)
|
||||||
|
assert result == img.resize.return_value
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssd1b_read_img(ssd1b_model, mocker):
|
||||||
|
img = mocker.MagicMock()
|
||||||
|
result = ssd1b_model.read_img(img)
|
||||||
|
assert result == img.open.return_value
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssd1b_convert_to_bytesio(ssd1b_model, mocker):
|
||||||
|
img = mocker.MagicMock()
|
||||||
|
img_format = "PNG"
|
||||||
|
result = ssd1b_model.convert_to_bytesio(img, img_format)
|
||||||
|
assert isinstance(result, bytes)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssd1b_save_image(ssd1b_model, mocker, tmp_path):
|
||||||
|
img = mocker.MagicMock()
|
||||||
|
img_name = "test.png"
|
||||||
|
save_path = tmp_path / img_name
|
||||||
|
ssd1b_model._download_image(img, img_name, save_path)
|
||||||
|
assert save_path.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssd1b_repr_str(ssd1b_model):
|
||||||
|
task = "A painting of a dog"
|
||||||
|
image_url = ssd1b_model(task)
|
||||||
|
assert repr(ssd1b_model) == f"SSD1B(image_url={image_url})"
|
||||||
|
assert str(ssd1b_model) == f"SSD1B(image_url={image_url})"
|
||||||
|
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from your_module import SSD1B
|
||||||
|
|
||||||
|
|
||||||
|
# Create fixtures if needed
|
||||||
|
@pytest.fixture
|
||||||
|
def ssd1b_model():
|
||||||
|
return SSD1B()
|
||||||
|
|
||||||
|
|
||||||
|
# Test cases for additional scenarios and behaviors
|
||||||
|
def test_ssd1b_dashboard_printing(ssd1b_model, capsys):
|
||||||
|
ssd1b_model.dashboard = True
|
||||||
|
ssd1b_model.print_dashboard()
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert "SSD1B Dashboard:" in captured.out
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssd1b_generate_image_name(ssd1b_model):
|
||||||
|
task = "A painting of a dog"
|
||||||
|
img_name = ssd1b_model._generate_image_name(task)
|
||||||
|
assert isinstance(img_name, str)
|
||||||
|
assert len(img_name) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssd1b_set_width_height(ssd1b_model, mocker):
|
||||||
|
img = mocker.MagicMock()
|
||||||
|
width, height = 800, 600
|
||||||
|
result = ssd1b_model.set_width_height(img, width, height)
|
||||||
|
assert result == img.resize.return_value
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssd1b_read_img(ssd1b_model, mocker):
|
||||||
|
img = mocker.MagicMock()
|
||||||
|
result = ssd1b_model.read_img(img)
|
||||||
|
assert result == img.open.return_value
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssd1b_convert_to_bytesio(ssd1b_model, mocker):
|
||||||
|
img = mocker.MagicMock()
|
||||||
|
img_format = "PNG"
|
||||||
|
result = ssd1b_model.convert_to_bytesio(img, img_format)
|
||||||
|
assert isinstance(result, bytes)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssd1b_save_image(ssd1b_model, mocker, tmp_path):
|
||||||
|
img = mocker.MagicMock()
|
||||||
|
img_name = "test.png"
|
||||||
|
save_path = tmp_path / img_name
|
||||||
|
ssd1b_model._download_image(img, img_name, save_path)
|
||||||
|
assert save_path.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssd1b_repr_str(ssd1b_model):
|
||||||
|
task = "A painting of a dog"
|
||||||
|
image_url = ssd1b_model(task)
|
||||||
|
assert repr(ssd1b_model) == f"SSD1B(image_url={image_url})"
|
||||||
|
assert str(ssd1b_model) == f"SSD1B(image_url={image_url})"
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssd1b_rate_limited_call(ssd1b_model, mocker):
|
||||||
|
task = "A painting of a dog"
|
||||||
|
mocker.patch.object(
|
||||||
|
ssd1b_model, "__call__", side_effect=Exception("Rate limit exceeded")
|
||||||
|
)
|
||||||
|
with pytest.raises(Exception, match="Rate limit exceeded"):
|
||||||
|
ssd1b_model.rate_limited_call(task)
|
@ -0,0 +1,164 @@
|
|||||||
|
from unittest.mock import Mock
|
||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from swarms.models.timm import TimmModel, TimmModelInfo
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_model_info():
|
||||||
|
return TimmModelInfo(model_name="resnet18", pretrained=True, in_chans=3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_supported_models():
|
||||||
|
model_handler = TimmModel()
|
||||||
|
supported_models = model_handler._get_supported_models()
|
||||||
|
assert isinstance(supported_models, list)
|
||||||
|
assert len(supported_models) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_model(sample_model_info):
|
||||||
|
model_handler = TimmModel()
|
||||||
|
model = model_handler._create_model(sample_model_info)
|
||||||
|
assert isinstance(model, torch.nn.Module)
|
||||||
|
|
||||||
|
|
||||||
|
def test_call(sample_model_info):
|
||||||
|
model_handler = TimmModel()
|
||||||
|
input_tensor = torch.randn(1, 3, 224, 224)
|
||||||
|
output_shape = model_handler.__call__(sample_model_info, input_tensor)
|
||||||
|
assert isinstance(output_shape, torch.Size)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model_name, pretrained, in_chans",
|
||||||
|
[
|
||||||
|
("resnet18", True, 3),
|
||||||
|
("resnet50", False, 1),
|
||||||
|
("efficientnet_b0", True, 3),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_create_model_parameterized(model_name, pretrained, in_chans):
|
||||||
|
model_info = TimmModelInfo(
|
||||||
|
model_name=model_name, pretrained=pretrained, in_chans=in_chans
|
||||||
|
)
|
||||||
|
model_handler = TimmModel()
|
||||||
|
model = model_handler._create_model(model_info)
|
||||||
|
assert isinstance(model, torch.nn.Module)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model_name, pretrained, in_chans",
|
||||||
|
[
|
||||||
|
("resnet18", True, 3),
|
||||||
|
("resnet50", False, 1),
|
||||||
|
("efficientnet_b0", True, 3),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_call_parameterized(model_name, pretrained, in_chans):
|
||||||
|
model_info = TimmModelInfo(
|
||||||
|
model_name=model_name, pretrained=pretrained, in_chans=in_chans
|
||||||
|
)
|
||||||
|
model_handler = TimmModel()
|
||||||
|
input_tensor = torch.randn(1, in_chans, 224, 224)
|
||||||
|
output_shape = model_handler.__call__(model_info, input_tensor)
|
||||||
|
assert isinstance(output_shape, torch.Size)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_supported_models_mock():
|
||||||
|
model_handler = TimmModel()
|
||||||
|
model_handler._get_supported_models = Mock(return_value=["resnet18", "resnet50"])
|
||||||
|
supported_models = model_handler._get_supported_models()
|
||||||
|
assert supported_models == ["resnet18", "resnet50"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_model_mock(sample_model_info):
|
||||||
|
model_handler = TimmModel()
|
||||||
|
model_handler._create_model = Mock(return_value=torch.nn.Module())
|
||||||
|
model = model_handler._create_model(sample_model_info)
|
||||||
|
assert isinstance(model, torch.nn.Module)
|
||||||
|
|
||||||
|
|
||||||
|
def test_call_exception():
|
||||||
|
model_handler = TimmModel()
|
||||||
|
model_info = TimmModelInfo(model_name="invalid_model", pretrained=True, in_chans=3)
|
||||||
|
input_tensor = torch.randn(1, 3, 224, 224)
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
model_handler.__call__(model_info, input_tensor)
|
||||||
|
|
||||||
|
|
||||||
|
def test_coverage():
|
||||||
|
pytest.main(["--cov=my_module", "--cov-report=html"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_environment_variable():
|
||||||
|
import os
|
||||||
|
|
||||||
|
os.environ["MODEL_NAME"] = "resnet18"
|
||||||
|
os.environ["PRETRAINED"] = "True"
|
||||||
|
os.environ["IN_CHANS"] = "3"
|
||||||
|
|
||||||
|
model_handler = TimmModel()
|
||||||
|
model_info = TimmModelInfo(
|
||||||
|
model_name=os.environ["MODEL_NAME"],
|
||||||
|
pretrained=bool(os.environ["PRETRAINED"]),
|
||||||
|
in_chans=int(os.environ["IN_CHANS"]),
|
||||||
|
)
|
||||||
|
input_tensor = torch.randn(1, model_info.in_chans, 224, 224)
|
||||||
|
output_shape = model_handler(model_info, input_tensor)
|
||||||
|
assert isinstance(output_shape, torch.Size)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
|
def test_marked_slow():
|
||||||
|
model_handler = TimmModel()
|
||||||
|
model_info = TimmModelInfo(model_name="resnet18", pretrained=True, in_chans=3)
|
||||||
|
input_tensor = torch.randn(1, 3, 224, 224)
|
||||||
|
output_shape = model_handler(model_info, input_tensor)
|
||||||
|
assert isinstance(output_shape, torch.Size)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model_name, pretrained, in_chans",
|
||||||
|
[
|
||||||
|
("resnet18", True, 3),
|
||||||
|
("resnet50", False, 1),
|
||||||
|
("efficientnet_b0", True, 3),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_marked_parameterized(model_name, pretrained, in_chans):
|
||||||
|
model_info = TimmModelInfo(
|
||||||
|
model_name=model_name, pretrained=pretrained, in_chans=in_chans
|
||||||
|
)
|
||||||
|
model_handler = TimmModel()
|
||||||
|
model = model_handler._create_model(model_info)
|
||||||
|
assert isinstance(model, torch.nn.Module)
|
||||||
|
|
||||||
|
|
||||||
|
def test_exception_testing():
|
||||||
|
model_handler = TimmModel()
|
||||||
|
model_info = TimmModelInfo(model_name="invalid_model", pretrained=True, in_chans=3)
|
||||||
|
input_tensor = torch.randn(1, 3, 224, 224)
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
model_handler.__call__(model_info, input_tensor)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parameterized_testing():
|
||||||
|
model_handler = TimmModel()
|
||||||
|
model_info = TimmModelInfo(model_name="resnet18", pretrained=True, in_chans=3)
|
||||||
|
input_tensor = torch.randn(1, 3, 224, 224)
|
||||||
|
output_shape = model_handler.__call__(model_info, input_tensor)
|
||||||
|
assert isinstance(output_shape, torch.Size)
|
||||||
|
|
||||||
|
|
||||||
|
def test_use_mocks_and_monkeypatching():
|
||||||
|
model_handler = TimmModel()
|
||||||
|
model_handler._create_model = Mock(return_value=torch.nn.Module())
|
||||||
|
model_info = TimmModelInfo(model_name="resnet18", pretrained=True, in_chans=3)
|
||||||
|
model = model_handler._create_model(model_info)
|
||||||
|
assert isinstance(model, torch.nn.Module)
|
||||||
|
|
||||||
|
|
||||||
|
def test_coverage_report():
|
||||||
|
# Install pytest-cov
|
||||||
|
# Run tests with coverage report
|
||||||
|
pytest.main(["--cov=my_module", "--cov-report=html"])
|
@ -0,0 +1,106 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
from swarms.models.yi_200k import Yi34B200k
|
||||||
|
|
||||||
|
|
||||||
|
# Create fixtures if needed
|
||||||
|
@pytest.fixture
|
||||||
|
def yi34b_model():
|
||||||
|
return Yi34B200k()
|
||||||
|
|
||||||
|
|
||||||
|
# Test cases for the Yi34B200k class
|
||||||
|
def test_yi34b_init(yi34b_model):
|
||||||
|
assert isinstance(yi34b_model.model, torch.nn.Module)
|
||||||
|
assert isinstance(yi34b_model.tokenizer, AutoTokenizer)
|
||||||
|
|
||||||
|
|
||||||
|
def test_yi34b_generate_text(yi34b_model):
|
||||||
|
prompt = "There's a place where time stands still."
|
||||||
|
generated_text = yi34b_model(prompt)
|
||||||
|
assert isinstance(generated_text, str)
|
||||||
|
assert len(generated_text) > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("max_length", [64, 128, 256, 512])
|
||||||
|
def test_yi34b_generate_text_with_length(yi34b_model, max_length):
|
||||||
|
prompt = "There's a place where time stands still."
|
||||||
|
generated_text = yi34b_model(prompt, max_length=max_length)
|
||||||
|
assert len(generated_text) <= max_length
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("temperature", [0.5, 1.0, 1.5])
|
||||||
|
def test_yi34b_generate_text_with_temperature(yi34b_model, temperature):
|
||||||
|
prompt = "There's a place where time stands still."
|
||||||
|
generated_text = yi34b_model(prompt, temperature=temperature)
|
||||||
|
assert isinstance(generated_text, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_yi34b_generate_text_with_invalid_prompt(yi34b_model):
|
||||||
|
prompt = None # Invalid prompt
|
||||||
|
with pytest.raises(ValueError, match="Input prompt must be a non-empty string"):
|
||||||
|
yi34b_model(prompt)
|
||||||
|
|
||||||
|
|
||||||
|
def test_yi34b_generate_text_with_invalid_max_length(yi34b_model):
|
||||||
|
prompt = "There's a place where time stands still."
|
||||||
|
max_length = -1 # Invalid max_length
|
||||||
|
with pytest.raises(ValueError, match="max_length must be a positive integer"):
|
||||||
|
yi34b_model(prompt, max_length=max_length)
|
||||||
|
|
||||||
|
|
||||||
|
def test_yi34b_generate_text_with_invalid_temperature(yi34b_model):
|
||||||
|
prompt = "There's a place where time stands still."
|
||||||
|
temperature = 2.0 # Invalid temperature
|
||||||
|
with pytest.raises(ValueError, match="temperature must be between 0.01 and 1.0"):
|
||||||
|
yi34b_model(prompt, temperature=temperature)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("top_k", [20, 30, 50])
|
||||||
|
def test_yi34b_generate_text_with_top_k(yi34b_model, top_k):
|
||||||
|
prompt = "There's a place where time stands still."
|
||||||
|
generated_text = yi34b_model(prompt, top_k=top_k)
|
||||||
|
assert isinstance(generated_text, str)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("top_p", [0.5, 0.7, 0.9])
|
||||||
|
def test_yi34b_generate_text_with_top_p(yi34b_model, top_p):
|
||||||
|
prompt = "There's a place where time stands still."
|
||||||
|
generated_text = yi34b_model(prompt, top_p=top_p)
|
||||||
|
assert isinstance(generated_text, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_yi34b_generate_text_with_invalid_top_k(yi34b_model):
|
||||||
|
prompt = "There's a place where time stands still."
|
||||||
|
top_k = -1 # Invalid top_k
|
||||||
|
with pytest.raises(ValueError, match="top_k must be a non-negative integer"):
|
||||||
|
yi34b_model(prompt, top_k=top_k)
|
||||||
|
|
||||||
|
|
||||||
|
def test_yi34b_generate_text_with_invalid_top_p(yi34b_model):
|
||||||
|
prompt = "There's a place where time stands still."
|
||||||
|
top_p = 1.5 # Invalid top_p
|
||||||
|
with pytest.raises(ValueError, match="top_p must be between 0.0 and 1.0"):
|
||||||
|
yi34b_model(prompt, top_p=top_p)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("repitition_penalty", [1.0, 1.2, 1.5])
|
||||||
|
def test_yi34b_generate_text_with_repitition_penalty(yi34b_model, repitition_penalty):
|
||||||
|
prompt = "There's a place where time stands still."
|
||||||
|
generated_text = yi34b_model(prompt, repitition_penalty=repitition_penalty)
|
||||||
|
assert isinstance(generated_text, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_yi34b_generate_text_with_invalid_repitition_penalty(yi34b_model):
|
||||||
|
prompt = "There's a place where time stands still."
|
||||||
|
repitition_penalty = 0.0 # Invalid repitition_penalty
|
||||||
|
with pytest.raises(ValueError, match="repitition_penalty must be a positive float"):
|
||||||
|
yi34b_model(prompt, repitition_penalty=repitition_penalty)
|
||||||
|
|
||||||
|
|
||||||
|
def test_yi34b_generate_text_with_invalid_device(yi34b_model):
|
||||||
|
prompt = "There's a place where time stands still."
|
||||||
|
device_map = "invalid_device" # Invalid device_map
|
||||||
|
with pytest.raises(ValueError, match="Invalid device_map"):
|
||||||
|
yi34b_model(prompt, device_map=device_map)
|
Loading…
Reference in new issue