You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/tests/models/test_openaitts.py

79 lines
2.7 KiB

import pytest
from unittest.mock import patch, MagicMock
from swarms.models.openai_tts import OpenAITTS
def test_openaitts_initialization():
    """Constructing OpenAITTS with no arguments yields the expected defaults.

    Checks the default model name, speech endpoint, voice and streaming
    chunk size (1 MiB).
    """
    instance = OpenAITTS()
    assert isinstance(instance, OpenAITTS)

    expected_defaults = {
        "model_name": "tts-1-1106",
        "proxy_url": "https://api.openai.com/v1/audio/speech",
        "voice": "onyx",
        "chunk_size": 1024 * 1024,
    }
    for attribute, expected in expected_defaults.items():
        assert getattr(instance, attribute) == expected
def test_openaitts_initialization_custom_parameters():
    """Positional constructor arguments land on the corresponding attributes.

    The arguments are passed positionally in the order
    model_name, proxy_url, openai_api_key, voice, chunk_size.
    """
    constructor_args = (
        "custom_model",
        "custom_url",
        "custom_key",
        "custom_voice",
        2048,
    )
    instance = OpenAITTS(*constructor_args)

    attribute_names = (
        "model_name",
        "proxy_url",
        "openai_api_key",
        "voice",
        "chunk_size",
    )
    for attribute, expected in zip(attribute_names, constructor_args):
        assert getattr(instance, attribute) == expected
@patch("requests.post")
def test_run(mock_post):
    """run() posts the default payload and joins the streamed audio chunks.

    ``requests.post`` is patched so no network call happens. The fake
    response streams two byte chunks, which run() is expected to concatenate.
    """
    mock_post.return_value = MagicMock(
        **{"iter_content.return_value": [b"chunk1", b"chunk2"]}
    )
    speaker = OpenAITTS()

    result = speaker.run("Hello world")

    assert result == b"chunk1chunk2"
    expected_payload = {
        "model": "tts-1-1106",
        "input": "Hello world",
        "voice": "onyx",
    }
    mock_post.assert_called_once_with(
        "https://api.openai.com/v1/audio/speech",
        headers={"Authorization": f"Bearer {speaker.openai_api_key}"},
        json=expected_payload,
    )
@patch("requests.post")
def test_run_empty_task(mock_post):
    """OpenAITTS.run is expected to reject an empty task.

    The patched ``requests.post`` returns a response that looks fully
    successful (status 200, ``ok`` true, streams real bytes). This matters:
    with a bare, unconfigured MagicMock, any status/``ok`` check inside
    run() could raise on the mock itself. In that case the assertion below
    would pass for *any* input and prove nothing about task validation.
    """
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.ok = True
    mock_response.iter_content.return_value = [b"chunk1", b"chunk2"]
    mock_post.return_value = mock_response

    tts = OpenAITTS()
    # NOTE(review): ``Exception`` is very broad; narrow it to the specific
    # type run() raises for invalid input (e.g. ValueError) once confirmed.
    with pytest.raises(Exception):
        tts.run("")
@patch("requests.post")
def test_run_very_long_task(mock_post):
    """OpenAITTS.run is expected to reject an overly long (10,000-char) task.

    The patched ``requests.post`` returns a successful-looking response
    (status 200, ``ok`` true, real byte chunks). As a result, an exception
    can only come from run()'s handling of the task, not from a status
    check tripping over an unconfigured MagicMock.
    """
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.ok = True
    mock_response.iter_content.return_value = [b"chunk1", b"chunk2"]
    mock_post.return_value = mock_response

    tts = OpenAITTS()
    # NOTE(review): 10,000 chars presumably exceeds the speech API's
    # per-request input limit — confirm run() enforces a limit. Also narrow
    # ``Exception`` to the concrete type run() raises.
    with pytest.raises(Exception):
        tts.run("A" * 10000)
@patch("requests.post")
def test_run_invalid_task(mock_post):
    """OpenAITTS.run is expected to reject a ``None`` task.

    The patched ``requests.post`` returns a successful-looking response
    (status 200, ``ok`` true, real byte chunks). As a result, the assertion
    only passes if run() itself refuses the input, not because the mocked
    response was left unconfigured.
    """
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.ok = True
    mock_response.iter_content.return_value = [b"chunk1", b"chunk2"]
    mock_post.return_value = mock_response

    tts = OpenAITTS()
    # NOTE(review): narrow ``Exception`` (e.g. to TypeError/ValueError) once
    # the exact type raised by run() for a non-string task is confirmed.
    with pytest.raises(Exception):
        tts.run(None)
@patch("requests.post")
def test_run_custom_model(mock_post):
    """A model name passed to the constructor is sent in the request payload.

    The endpoint and voice stay at their defaults; only ``model`` changes.
    """
    fake_response = MagicMock()
    fake_response.iter_content.return_value = [b"chunk1", b"chunk2"]
    mock_post.return_value = fake_response

    client = OpenAITTS("custom_model")
    audio_bytes = client.run("Hello world")

    assert audio_bytes == b"chunk1chunk2"
    auth_headers = {"Authorization": f"Bearer {client.openai_api_key}"}
    mock_post.assert_called_once_with(
        "https://api.openai.com/v1/audio/speech",
        headers=auth_headers,
        json={
            "model": "custom_model",
            "input": "Hello world",
            "voice": "onyx",
        },
    )
@patch("requests.post")
def test_run_custom_voice(mock_post):
    """A voice given via the ``voice`` keyword is sent in the request payload.

    The endpoint and model stay at their defaults; only ``voice`` changes.
    """
    mock_post.return_value = MagicMock(
        **{"iter_content.return_value": [b"chunk1", b"chunk2"]}
    )

    narrator = OpenAITTS(voice="custom_voice")
    output = narrator.run("Hello world")

    assert output == b"chunk1chunk2"
    payload = dict(
        model="tts-1-1106",
        input="Hello world",
        voice="custom_voice",
    )
    mock_post.assert_called_once_with(
        "https://api.openai.com/v1/audio/speech",
        headers={"Authorization": f"Bearer {narrator.openai_api_key}"},
        json=payload,
    )