parent
58c0ee1986
commit
e719e83912
@ -0,0 +1,87 @@
|
||||
import os
|
||||
|
||||
import openai
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from swarms.models.base_llm import AbstractLLM
|
||||
|
||||
# Load .env file
|
||||
load_dotenv()
|
||||
|
||||
# OpenAI API Key env
|
||||
def openai_api_key_env():
|
||||
openai_api_key = os.getenv("OPENAI_API_KEY")
|
||||
return openai_api_key
|
||||
|
||||
|
||||
class OpenAITTS(AbstractLLM):
|
||||
"""OpenAI TTS model
|
||||
|
||||
Attributes:
|
||||
model_name (str): _description_
|
||||
proxy_url (str): _description_
|
||||
openai_api_key (str): _description_
|
||||
voice (str): _description_
|
||||
chunk_size (_type_): _description_
|
||||
|
||||
Methods:
|
||||
run: _description_
|
||||
|
||||
|
||||
Examples:
|
||||
>>> from swarms.models.openai_tts import OpenAITTS
|
||||
>>> tts = OpenAITTS(
|
||||
... model_name = "tts-1-1106",
|
||||
... proxy_url = "https://api.openai.com/v1/audio/speech",
|
||||
... openai_api_key = openai_api_key_env,
|
||||
... voice = "onyx",
|
||||
... )
|
||||
>>> tts.run("Hello world")
|
||||
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "tts-1-1106",
|
||||
proxy_url: str = "https://api.openai.com/v1/audio/speech",
|
||||
openai_api_key: str = openai_api_key_env,
|
||||
voice: str = "onyx",
|
||||
chunk_size = 1024 * 1024,
|
||||
):
|
||||
super().__init__()
|
||||
self.model_name = model_name
|
||||
self.proxy_url = proxy_url
|
||||
self.openai_api_key = openai_api_key
|
||||
self.voice = voice
|
||||
self.chunk_size = chunk_size
|
||||
|
||||
def run(
|
||||
self,
|
||||
task: str,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
"""Run the tts model
|
||||
|
||||
Args:
|
||||
task (str): _description_
|
||||
|
||||
Returns:
|
||||
_type_: _description_
|
||||
"""
|
||||
response = requests.post(
|
||||
self.proxy_url,
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.openai_api_key}",
|
||||
},
|
||||
json={
|
||||
"model": self.model_name,
|
||||
"input": task,
|
||||
"voice": self.voice,
|
||||
},
|
||||
)
|
||||
|
||||
audio = b""
|
||||
for chunk in response.iter_content(chunk_size = 1024 * 1024):
|
||||
audio += chunk
|
||||
return audio
|
@ -0,0 +1,79 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from swarms.models.openai_tts import OpenAITTS
|
||||
|
||||
def test_openaitts_initialization():
|
||||
tts = OpenAITTS()
|
||||
assert isinstance(tts, OpenAITTS)
|
||||
assert tts.model_name == "tts-1-1106"
|
||||
assert tts.proxy_url == "https://api.openai.com/v1/audio/speech"
|
||||
assert tts.voice == "onyx"
|
||||
assert tts.chunk_size == 1024 * 1024
|
||||
|
||||
def test_openaitts_initialization_custom_parameters():
|
||||
tts = OpenAITTS("custom_model", "custom_url", "custom_key", "custom_voice", 2048)
|
||||
assert tts.model_name == "custom_model"
|
||||
assert tts.proxy_url == "custom_url"
|
||||
assert tts.openai_api_key == "custom_key"
|
||||
assert tts.voice == "custom_voice"
|
||||
assert tts.chunk_size == 2048
|
||||
|
||||
@patch("requests.post")
|
||||
def test_run(mock_post):
|
||||
mock_response = MagicMock()
|
||||
mock_response.iter_content.return_value = [b"chunk1", b"chunk2"]
|
||||
mock_post.return_value = mock_response
|
||||
tts = OpenAITTS()
|
||||
audio = tts.run("Hello world")
|
||||
assert audio == b"chunk1chunk2"
|
||||
mock_post.assert_called_once_with(
|
||||
"https://api.openai.com/v1/audio/speech",
|
||||
headers={"Authorization": f"Bearer {tts.openai_api_key}"},
|
||||
json={"model": "tts-1-1106", "input": "Hello world", "voice": "onyx"},
|
||||
)
|
||||
|
||||
@patch("requests.post")
|
||||
def test_run_empty_task(mock_post):
|
||||
tts = OpenAITTS()
|
||||
with pytest.raises(Exception):
|
||||
tts.run("")
|
||||
|
||||
@patch("requests.post")
|
||||
def test_run_very_long_task(mock_post):
|
||||
tts = OpenAITTS()
|
||||
with pytest.raises(Exception):
|
||||
tts.run("A" * 10000)
|
||||
|
||||
@patch("requests.post")
|
||||
def test_run_invalid_task(mock_post):
|
||||
tts = OpenAITTS()
|
||||
with pytest.raises(Exception):
|
||||
tts.run(None)
|
||||
|
||||
@patch("requests.post")
|
||||
def test_run_custom_model(mock_post):
|
||||
mock_response = MagicMock()
|
||||
mock_response.iter_content.return_value = [b"chunk1", b"chunk2"]
|
||||
mock_post.return_value = mock_response
|
||||
tts = OpenAITTS("custom_model")
|
||||
audio = tts.run("Hello world")
|
||||
assert audio == b"chunk1chunk2"
|
||||
mock_post.assert_called_once_with(
|
||||
"https://api.openai.com/v1/audio/speech",
|
||||
headers={"Authorization": f"Bearer {tts.openai_api_key}"},
|
||||
json={"model": "custom_model", "input": "Hello world", "voice": "onyx"},
|
||||
)
|
||||
|
||||
@patch("requests.post")
|
||||
def test_run_custom_voice(mock_post):
|
||||
mock_response = MagicMock()
|
||||
mock_response.iter_content.return_value = [b"chunk1", b"chunk2"]
|
||||
mock_post.return_value = mock_response
|
||||
tts = OpenAITTS(voice="custom_voice")
|
||||
audio = tts.run("Hello world")
|
||||
assert audio == b"chunk1chunk2"
|
||||
mock_post.assert_called_once_with(
|
||||
"https://api.openai.com/v1/audio/speech",
|
||||
headers={"Authorization": f"Bearer {tts.openai_api_key}"},
|
||||
json={"model": "tts-1-1106", "input": "Hello world", "voice": "custom_voice"},
|
||||
)
|
@ -0,0 +1,10 @@
|
||||
from swarms import OpenAITTS
|
||||
|
||||
tts = OpenAITTS(
|
||||
model_name = "tts-1-1106",
|
||||
voice = "onyx",
|
||||
openai_api_key="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"
|
||||
)
|
||||
|
||||
out = tts("Hello world")
|
||||
print(out)
|
Loading…
Reference in new issue