You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
79 lines
2.7 KiB
79 lines
2.7 KiB
1 year ago
|
import pytest
|
||
|
from unittest.mock import patch, MagicMock
|
||
|
from swarms.models.openai_tts import OpenAITTS
|
||
|
|
||
|
def test_openaitts_initialization():
    """A default-constructed OpenAITTS carries the expected default settings."""
    # Checked in this order: model, endpoint URL, voice, chunk size.
    expected_defaults = {
        "model_name": "tts-1-1106",
        "proxy_url": "https://api.openai.com/v1/audio/speech",
        "voice": "onyx",
        "chunk_size": 1024 * 1024,
    }

    tts = OpenAITTS()

    assert isinstance(tts, OpenAITTS)
    for attribute, expected in expected_defaults.items():
        assert getattr(tts, attribute) == expected, attribute
|
||
|
|
||
|
def test_openaitts_initialization_custom_parameters():
    """Each positional constructor argument ends up on its matching attribute.

    The positional order exercised here is: model_name, proxy_url,
    openai_api_key, voice, chunk_size (as read back from the instance).
    """
    positional_args = (
        "custom_model",
        "custom_url",
        "custom_key",
        "custom_voice",
        2048,
    )

    tts = OpenAITTS(*positional_args)

    observed = (
        tts.model_name,
        tts.proxy_url,
        tts.openai_api_key,
        tts.voice,
        tts.chunk_size,
    )
    assert observed == positional_args
|
||
|
|
||
|
@patch("requests.post")
def test_run(mock_post):
    """run() POSTs the default model/voice payload and joins the response chunks.

    The mocked response's ``iter_content`` yields two byte chunks; run() is
    expected to return their concatenation.
    """
    fake_response = MagicMock()
    fake_response.iter_content.return_value = [b"chunk1", b"chunk2"]
    mock_post.return_value = fake_response

    tts = OpenAITTS()
    result = tts.run("Hello world")

    assert result == b"chunk1chunk2"

    expected_headers = {"Authorization": f"Bearer {tts.openai_api_key}"}
    expected_body = {
        "model": "tts-1-1106",
        "input": "Hello world",
        "voice": "onyx",
    }
    mock_post.assert_called_once_with(
        "https://api.openai.com/v1/audio/speech",
        headers=expected_headers,
        json=expected_body,
    )
|
||
|
|
||
|
@patch("requests.post")
def test_run_empty_task(mock_post):
    """run("") must raise, and must do so before any HTTP request is sent."""
    # Give the mocked endpoint a well-formed audio response. Without this, the
    # bare MagicMock yields no chunks, so the test could pass merely because
    # the mock produced an empty/odd payload rather than because run()
    # rejected the empty input.
    mock_response = MagicMock()
    mock_response.iter_content.return_value = [b"chunk1", b"chunk2"]
    mock_post.return_value = mock_response

    tts = OpenAITTS()

    # NOTE(review): Exception is intentionally broad because the specific
    # exception run() uses for invalid input is not pinned here; narrow it
    # (e.g. to ValueError) once that contract is fixed.
    with pytest.raises(Exception):
        tts.run("")

    # Invalid input must be rejected up front, without spending an API call.
    mock_post.assert_not_called()
|
||
|
|
||
|
@patch("requests.post")
def test_run_very_long_task(mock_post):
    """run() must reject an oversized task before any HTTP request is sent.

    10,000 characters is well above the 4096-character ``input`` limit that
    OpenAI documents for the ``/v1/audio/speech`` endpoint.
    """
    # A well-formed audio response means the test can only pass if run()
    # itself rejects the input, not because the mock returned nothing.
    mock_response = MagicMock()
    mock_response.iter_content.return_value = [b"chunk1", b"chunk2"]
    mock_post.return_value = mock_response

    tts = OpenAITTS()

    # NOTE(review): narrow Exception to the concrete type run() raises for
    # oversized input once that contract is defined.
    with pytest.raises(Exception):
        tts.run("A" * 10000)

    # The oversized payload must never reach the API.
    mock_post.assert_not_called()
|
||
|
|
||
|
@patch("requests.post")
def test_run_invalid_task(mock_post):
    """run(None) must raise rather than POST a null ``input`` to the API."""
    # A well-formed audio response rules out the test passing only because
    # the unconfigured mock yielded an empty body.
    mock_response = MagicMock()
    mock_response.iter_content.return_value = [b"chunk1", b"chunk2"]
    mock_post.return_value = mock_response

    tts = OpenAITTS()

    # NOTE(review): narrow Exception (e.g. to TypeError/ValueError) once the
    # exception run() raises for non-string input is pinned down.
    with pytest.raises(Exception):
        tts.run(None)

    # A non-string task must be rejected before any network request.
    mock_post.assert_not_called()
|
||
|
|
||
|
@patch("requests.post")
def test_run_custom_model(mock_post):
    """A model passed positionally is sent as the ``model`` field of the request."""
    response = MagicMock()
    response.iter_content.return_value = [b"chunk1", b"chunk2"]
    mock_post.return_value = response

    tts = OpenAITTS("custom_model")
    audio_bytes = tts.run("Hello world")

    assert audio_bytes == b"chunk1chunk2"

    # Only the model differs from the defaults; voice stays "onyx".
    payload = {
        "model": "custom_model",
        "input": "Hello world",
        "voice": "onyx",
    }
    auth_headers = {"Authorization": f"Bearer {tts.openai_api_key}"}
    mock_post.assert_called_once_with(
        "https://api.openai.com/v1/audio/speech",
        headers=auth_headers,
        json=payload,
    )
|
||
|
|
||
|
@patch("requests.post")
def test_run_custom_voice(mock_post):
    """A ``voice`` keyword argument is forwarded as the request's ``voice`` field."""
    stub_response = MagicMock()
    stub_response.iter_content.return_value = [b"chunk1", b"chunk2"]
    mock_post.return_value = stub_response

    tts = OpenAITTS(voice="custom_voice")
    output = tts.run("Hello world")

    assert output == b"chunk1chunk2"

    # Inspect the single recorded call directly (no spec, so plain equality).
    assert mock_post.call_count == 1
    recorded = mock_post.call_args
    assert recorded.args == ("https://api.openai.com/v1/audio/speech",)
    assert recorded.kwargs == {
        "headers": {"Authorization": f"Bearer {tts.openai_api_key}"},
        "json": {
            "model": "tts-1-1106",
            "input": "Hello world",
            "voice": "custom_voice",
        },
    }
|