You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/tests/models/test_openaitts.py

108 lines
2.9 KiB

5 months ago
from unittest.mock import MagicMock, patch
import pytest
4 months ago
from swarm_models.openai_tts import OpenAITTS
5 months ago
def test_openaitts_initialization():
    """A default-constructed OpenAITTS exposes the expected default settings."""
    tts = OpenAITTS()
    assert isinstance(tts, OpenAITTS)

    # Attribute name -> value the no-argument constructor is expected to set.
    expected_defaults = {
        "model_name": "tts-1-1106",
        "proxy_url": "https://api.openai.com/v1/audio/speech",
        "voice": "onyx",
        "chunk_size": 1024 * 1024,
    }
    for attr_name, expected_value in expected_defaults.items():
        assert getattr(tts, attr_name) == expected_value, attr_name
def test_openaitts_initialization_custom_parameters():
    """Positional constructor arguments land on the matching attributes.

    Argument order exercised here: model name, proxy URL, API key, voice,
    chunk size.
    """
    constructor_args = (
        "custom_model",
        "custom_url",
        "custom_key",
        "custom_voice",
        2048,
    )
    tts = OpenAITTS(*constructor_args)

    # Read the attributes back in the same order they were passed in.
    observed = (
        tts.model_name,
        tts.proxy_url,
        tts.openai_api_key,
        tts.voice,
        tts.chunk_size,
    )
    assert observed == constructor_args
@patch("requests.post")
def test_run(mock_post):
    """run() sends one POST with default settings and joins the streamed chunks.

    ``requests.post`` is patched, so no network traffic occurs; the fake
    response yields two byte chunks from ``iter_content``.
    """
    fake_response = MagicMock()
    fake_response.iter_content.return_value = [b"chunk1", b"chunk2"]
    mock_post.return_value = fake_response

    tts = OpenAITTS()
    assert tts.run("Hello world") == b"chunk1chunk2"

    expected_headers = {"Authorization": f"Bearer {tts.openai_api_key}"}
    expected_payload = {
        "model": "tts-1-1106",
        "input": "Hello world",
        "voice": "onyx",
    }
    mock_post.assert_called_once_with(
        "https://api.openai.com/v1/audio/speech",
        headers=expected_headers,
        json=expected_payload,
    )
@patch("requests.post")
def test_run_empty_task(mock_post):
    """An empty task is rejected before any HTTP request is sent.

    ``requests.post`` is patched, so calling it would return a MagicMock
    rather than raise. Any exception therefore originates in OpenAITTS itself.
    Checking that the mock was never called pins that the rejection is
    client-side input validation, not some failure after the request.
    """
    tts = OpenAITTS()
    # NOTE(review): narrow to the specific exception type once OpenAITTS
    # documents one; a bare Exception also accepts unrelated bugs.
    with pytest.raises(Exception):
        tts.run("")
    mock_post.assert_not_called()
@patch("requests.post")
def test_run_very_long_task(mock_post):
    """An over-long task (10,000 characters) is rejected before any request.

    ``requests.post`` is patched, so calling it would return a MagicMock
    rather than raise. Any exception therefore originates in OpenAITTS itself.
    Checking that the mock was never called pins that the length check happens
    client-side. The limit presumably mirrors the speech endpoint's documented
    input cap (4096 characters per OpenAI's API reference); confirm against
    the OpenAITTS implementation.
    """
    tts = OpenAITTS()
    # NOTE(review): narrow to the specific exception type once OpenAITTS
    # documents one; a bare Exception also accepts unrelated bugs.
    with pytest.raises(Exception):
        tts.run("A" * 10000)
    mock_post.assert_not_called()
@patch("requests.post")
def test_run_invalid_task(mock_post):
    """A non-string task (None) is rejected before any HTTP request is sent.

    ``requests.post`` is patched, so calling it would return a MagicMock
    rather than raise. Any exception therefore originates in OpenAITTS itself.
    Checking that the mock was never called pins that the rejection is
    client-side input validation, not some failure after the request.
    """
    tts = OpenAITTS()
    # NOTE(review): narrow to the specific exception type (likely a TypeError
    # or ValueError) once OpenAITTS documents one; a bare Exception also
    # accepts unrelated bugs.
    with pytest.raises(Exception):
        tts.run(None)
    mock_post.assert_not_called()
@patch("requests.post")
def test_run_custom_model(mock_post):
    """A model name passed as the first positional argument reaches the payload.

    Only the "model" field should differ from the default request; URL, voice
    and headers keep their defaults.
    """
    fake_response = MagicMock()
    fake_response.iter_content.return_value = [b"chunk1", b"chunk2"]
    mock_post.return_value = fake_response

    tts = OpenAITTS("custom_model")
    assert tts.run("Hello world") == b"chunk1chunk2"

    expected_headers = {"Authorization": f"Bearer {tts.openai_api_key}"}
    expected_payload = {
        "model": "custom_model",
        "input": "Hello world",
        "voice": "onyx",
    }
    mock_post.assert_called_once_with(
        "https://api.openai.com/v1/audio/speech",
        headers=expected_headers,
        json=expected_payload,
    )
@patch("requests.post")
def test_run_custom_voice(mock_post):
    """A voice given via the ``voice=`` keyword reaches the request payload.

    Only the "voice" field should differ from the default request; URL, model
    and headers keep their defaults.
    """
    fake_response = MagicMock()
    fake_response.iter_content.return_value = [b"chunk1", b"chunk2"]
    mock_post.return_value = fake_response

    tts = OpenAITTS(voice="custom_voice")
    assert tts.run("Hello world") == b"chunk1chunk2"

    expected_headers = {"Authorization": f"Bearer {tts.openai_api_key}"}
    expected_payload = {
        "model": "tts-1-1106",
        "input": "Hello world",
        "voice": "custom_voice",
    }
    mock_post.assert_called_once_with(
        "https://api.openai.com/v1/audio/speech",
        headers=expected_headers,
        json=expected_payload,
    )