parent
6ee9bb65b9
commit
e1b0ba42eb
Before Width: | Height: | Size: 170 KiB |
Before Width: | Height: | Size: 878 KiB |
@ -1,42 +0,0 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from swarms.models.fire_function import FireFunctionCaller
|
||||
|
||||
|
||||
def test_fire_function_caller_run(mocker):
|
||||
# Create mock model and tokenizer
|
||||
model = MagicMock()
|
||||
tokenizer = MagicMock()
|
||||
mocker.patch.object(FireFunctionCaller, "model", model)
|
||||
mocker.patch.object(FireFunctionCaller, "tokenizer", tokenizer)
|
||||
|
||||
# Create mock task and arguments
|
||||
task = "Add 2 and 3"
|
||||
args = (2, 3)
|
||||
kwargs = {}
|
||||
|
||||
# Create mock generated_ids and decoded output
|
||||
generated_ids = [1, 2, 3]
|
||||
decoded_output = "5"
|
||||
model.generate.return_value = generated_ids
|
||||
tokenizer.batch_decode.return_value = [decoded_output]
|
||||
|
||||
# Create FireFunctionCaller instance
|
||||
fire_function_caller = FireFunctionCaller()
|
||||
|
||||
# Run the function
|
||||
fire_function_caller.run(task, *args, **kwargs)
|
||||
|
||||
# Assert model.generate was called with the correct inputs
|
||||
model.generate.assert_called_once_with(
|
||||
tokenizer.apply_chat_template.return_value,
|
||||
max_new_tokens=fire_function_caller.max_tokens,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Assert tokenizer.batch_decode was called with the correct inputs
|
||||
tokenizer.batch_decode.assert_called_once_with(generated_ids)
|
||||
|
||||
# Assert the decoded output is printed
|
||||
assert decoded_output in mocker.patch.object(print, "call_args_list")
|
@ -1,162 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from swarms.models.speecht5 import SpeechT5
|
||||
|
||||
|
||||
# Create fixtures if needed
|
||||
@pytest.fixture
|
||||
def speecht5_model():
|
||||
return SpeechT5()
|
||||
|
||||
|
||||
# Test cases for the SpeechT5 class
|
||||
|
||||
|
||||
def test_speecht5_init(speecht5_model):
|
||||
assert isinstance(
|
||||
speecht5_model.processor, SpeechT5.processor.__class__
|
||||
)
|
||||
assert isinstance(speecht5_model.model, SpeechT5.model.__class__)
|
||||
assert isinstance(speecht5_model.vocoder, SpeechT5.vocoder.__class__)
|
||||
assert isinstance(
|
||||
speecht5_model.embeddings_dataset, torch.utils.data.Dataset
|
||||
)
|
||||
|
||||
|
||||
def test_speecht5_call(speecht5_model):
|
||||
text = "Hello, how are you?"
|
||||
speech = speecht5_model(text)
|
||||
assert isinstance(speech, torch.Tensor)
|
||||
|
||||
|
||||
def test_speecht5_save_speech(speecht5_model):
|
||||
text = "Hello, how are you?"
|
||||
speech = speecht5_model(text)
|
||||
filename = "test_speech.wav"
|
||||
speecht5_model.save_speech(speech, filename)
|
||||
assert os.path.isfile(filename)
|
||||
os.remove(filename)
|
||||
|
||||
|
||||
def test_speecht5_set_model(speecht5_model):
|
||||
old_model_name = speecht5_model.model_name
|
||||
new_model_name = "facebook/speecht5-tts"
|
||||
speecht5_model.set_model(new_model_name)
|
||||
assert speecht5_model.model_name == new_model_name
|
||||
assert speecht5_model.processor.model_name == new_model_name
|
||||
assert speecht5_model.model.config.model_name_or_path == new_model_name
|
||||
speecht5_model.set_model(old_model_name) # Restore original model
|
||||
|
||||
|
||||
def test_speecht5_set_vocoder(speecht5_model):
|
||||
old_vocoder_name = speecht5_model.vocoder_name
|
||||
new_vocoder_name = "facebook/speecht5-hifigan"
|
||||
speecht5_model.set_vocoder(new_vocoder_name)
|
||||
assert speecht5_model.vocoder_name == new_vocoder_name
|
||||
assert (
|
||||
speecht5_model.vocoder.config.model_name_or_path
|
||||
== new_vocoder_name
|
||||
)
|
||||
speecht5_model.set_vocoder(
|
||||
old_vocoder_name
|
||||
) # Restore original vocoder
|
||||
|
||||
|
||||
def test_speecht5_set_embeddings_dataset(speecht5_model):
|
||||
old_dataset_name = speecht5_model.dataset_name
|
||||
new_dataset_name = "Matthijs/cmu-arctic-xvectors-test"
|
||||
speecht5_model.set_embeddings_dataset(new_dataset_name)
|
||||
assert speecht5_model.dataset_name == new_dataset_name
|
||||
assert isinstance(
|
||||
speecht5_model.embeddings_dataset, torch.utils.data.Dataset
|
||||
)
|
||||
speecht5_model.set_embeddings_dataset(
|
||||
old_dataset_name
|
||||
) # Restore original dataset
|
||||
|
||||
|
||||
def test_speecht5_get_sampling_rate(speecht5_model):
|
||||
sampling_rate = speecht5_model.get_sampling_rate()
|
||||
assert sampling_rate == 16000
|
||||
|
||||
|
||||
def test_speecht5_print_model_details(speecht5_model, capsys):
|
||||
speecht5_model.print_model_details()
|
||||
captured = capsys.readouterr()
|
||||
assert "Model Name: " in captured.out
|
||||
assert "Vocoder Name: " in captured.out
|
||||
|
||||
|
||||
def test_speecht5_quick_synthesize(speecht5_model):
|
||||
text = "Hello, how are you?"
|
||||
speech = speecht5_model.quick_synthesize(text)
|
||||
assert isinstance(speech, list)
|
||||
assert isinstance(speech[0], dict)
|
||||
assert "audio" in speech[0]
|
||||
|
||||
|
||||
def test_speecht5_change_dataset_split(speecht5_model):
|
||||
split = "test"
|
||||
speecht5_model.change_dataset_split(split)
|
||||
assert speecht5_model.embeddings_dataset.split == split
|
||||
|
||||
|
||||
def test_speecht5_load_custom_embedding(speecht5_model):
|
||||
xvector = [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||
embedding = speecht5_model.load_custom_embedding(xvector)
|
||||
assert torch.all(
|
||||
torch.eq(embedding, torch.tensor(xvector).unsqueeze(0))
|
||||
)
|
||||
|
||||
|
||||
def test_speecht5_with_different_speakers(speecht5_model):
|
||||
text = "Hello, how are you?"
|
||||
speakers = [7306, 5324, 1234]
|
||||
for speaker_id in speakers:
|
||||
speech = speecht5_model(text, speaker_id=speaker_id)
|
||||
assert isinstance(speech, torch.Tensor)
|
||||
|
||||
|
||||
def test_speecht5_save_speech_with_different_extensions(
|
||||
speecht5_model,
|
||||
):
|
||||
text = "Hello, how are you?"
|
||||
speech = speecht5_model(text)
|
||||
extensions = [".wav", ".flac"]
|
||||
for extension in extensions:
|
||||
filename = f"test_speech{extension}"
|
||||
speecht5_model.save_speech(speech, filename)
|
||||
assert os.path.isfile(filename)
|
||||
os.remove(filename)
|
||||
|
||||
|
||||
def test_speecht5_invalid_speaker_id(speecht5_model):
|
||||
text = "Hello, how are you?"
|
||||
invalid_speaker_id = (
|
||||
9999 # Speaker ID that does not exist in the dataset
|
||||
)
|
||||
with pytest.raises(IndexError):
|
||||
speecht5_model(text, speaker_id=invalid_speaker_id)
|
||||
|
||||
|
||||
def test_speecht5_invalid_save_path(speecht5_model):
|
||||
text = "Hello, how are you?"
|
||||
speech = speecht5_model(text)
|
||||
invalid_path = "/invalid_directory/test_speech.wav"
|
||||
with pytest.raises(FileNotFoundError):
|
||||
speecht5_model.save_speech(speech, invalid_path)
|
||||
|
||||
|
||||
def test_speecht5_change_vocoder_model(speecht5_model):
|
||||
text = "Hello, how are you?"
|
||||
old_vocoder_name = speecht5_model.vocoder_name
|
||||
new_vocoder_name = "facebook/speecht5-hifigan-ljspeech"
|
||||
speecht5_model.set_vocoder(new_vocoder_name)
|
||||
speech = speecht5_model(text)
|
||||
assert isinstance(speech, torch.Tensor)
|
||||
speecht5_model.set_vocoder(
|
||||
old_vocoder_name
|
||||
) # Restore original vocoder
|
Loading…
Reference in new issue