parent
58c0ee1986
commit
e719e83912
@ -0,0 +1,87 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import openai
|
||||||
|
import requests
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from swarms.models.base_llm import AbstractLLM
|
||||||
|
|
||||||
|
# Load .env file
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# OpenAI API Key env
|
||||||
|
def openai_api_key_env():
|
||||||
|
openai_api_key = os.getenv("OPENAI_API_KEY")
|
||||||
|
return openai_api_key
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAITTS(AbstractLLM):
|
||||||
|
"""OpenAI TTS model
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
model_name (str): _description_
|
||||||
|
proxy_url (str): _description_
|
||||||
|
openai_api_key (str): _description_
|
||||||
|
voice (str): _description_
|
||||||
|
chunk_size (_type_): _description_
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
run: _description_
|
||||||
|
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> from swarms.models.openai_tts import OpenAITTS
|
||||||
|
>>> tts = OpenAITTS(
|
||||||
|
... model_name = "tts-1-1106",
|
||||||
|
... proxy_url = "https://api.openai.com/v1/audio/speech",
|
||||||
|
... openai_api_key = openai_api_key_env,
|
||||||
|
... voice = "onyx",
|
||||||
|
... )
|
||||||
|
>>> tts.run("Hello world")
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str = "tts-1-1106",
|
||||||
|
proxy_url: str = "https://api.openai.com/v1/audio/speech",
|
||||||
|
openai_api_key: str = openai_api_key_env,
|
||||||
|
voice: str = "onyx",
|
||||||
|
chunk_size = 1024 * 1024,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.model_name = model_name
|
||||||
|
self.proxy_url = proxy_url
|
||||||
|
self.openai_api_key = openai_api_key
|
||||||
|
self.voice = voice
|
||||||
|
self.chunk_size = chunk_size
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
task: str,
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
"""Run the tts model
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (str): _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
_type_: _description_
|
||||||
|
"""
|
||||||
|
response = requests.post(
|
||||||
|
self.proxy_url,
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.openai_api_key}",
|
||||||
|
},
|
||||||
|
json={
|
||||||
|
"model": self.model_name,
|
||||||
|
"input": task,
|
||||||
|
"voice": self.voice,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
audio = b""
|
||||||
|
for chunk in response.iter_content(chunk_size = 1024 * 1024):
|
||||||
|
audio += chunk
|
||||||
|
return audio
|
@ -0,0 +1,79 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
from swarms.models.openai_tts import OpenAITTS
|
||||||
|
|
||||||
|
def test_openaitts_initialization():
|
||||||
|
tts = OpenAITTS()
|
||||||
|
assert isinstance(tts, OpenAITTS)
|
||||||
|
assert tts.model_name == "tts-1-1106"
|
||||||
|
assert tts.proxy_url == "https://api.openai.com/v1/audio/speech"
|
||||||
|
assert tts.voice == "onyx"
|
||||||
|
assert tts.chunk_size == 1024 * 1024
|
||||||
|
|
||||||
|
def test_openaitts_initialization_custom_parameters():
|
||||||
|
tts = OpenAITTS("custom_model", "custom_url", "custom_key", "custom_voice", 2048)
|
||||||
|
assert tts.model_name == "custom_model"
|
||||||
|
assert tts.proxy_url == "custom_url"
|
||||||
|
assert tts.openai_api_key == "custom_key"
|
||||||
|
assert tts.voice == "custom_voice"
|
||||||
|
assert tts.chunk_size == 2048
|
||||||
|
|
||||||
|
@patch("requests.post")
|
||||||
|
def test_run(mock_post):
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.iter_content.return_value = [b"chunk1", b"chunk2"]
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
tts = OpenAITTS()
|
||||||
|
audio = tts.run("Hello world")
|
||||||
|
assert audio == b"chunk1chunk2"
|
||||||
|
mock_post.assert_called_once_with(
|
||||||
|
"https://api.openai.com/v1/audio/speech",
|
||||||
|
headers={"Authorization": f"Bearer {tts.openai_api_key}"},
|
||||||
|
json={"model": "tts-1-1106", "input": "Hello world", "voice": "onyx"},
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch("requests.post")
|
||||||
|
def test_run_empty_task(mock_post):
|
||||||
|
tts = OpenAITTS()
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
tts.run("")
|
||||||
|
|
||||||
|
@patch("requests.post")
|
||||||
|
def test_run_very_long_task(mock_post):
|
||||||
|
tts = OpenAITTS()
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
tts.run("A" * 10000)
|
||||||
|
|
||||||
|
@patch("requests.post")
|
||||||
|
def test_run_invalid_task(mock_post):
|
||||||
|
tts = OpenAITTS()
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
tts.run(None)
|
||||||
|
|
||||||
|
@patch("requests.post")
|
||||||
|
def test_run_custom_model(mock_post):
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.iter_content.return_value = [b"chunk1", b"chunk2"]
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
tts = OpenAITTS("custom_model")
|
||||||
|
audio = tts.run("Hello world")
|
||||||
|
assert audio == b"chunk1chunk2"
|
||||||
|
mock_post.assert_called_once_with(
|
||||||
|
"https://api.openai.com/v1/audio/speech",
|
||||||
|
headers={"Authorization": f"Bearer {tts.openai_api_key}"},
|
||||||
|
json={"model": "custom_model", "input": "Hello world", "voice": "onyx"},
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch("requests.post")
|
||||||
|
def test_run_custom_voice(mock_post):
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.iter_content.return_value = [b"chunk1", b"chunk2"]
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
tts = OpenAITTS(voice="custom_voice")
|
||||||
|
audio = tts.run("Hello world")
|
||||||
|
assert audio == b"chunk1chunk2"
|
||||||
|
mock_post.assert_called_once_with(
|
||||||
|
"https://api.openai.com/v1/audio/speech",
|
||||||
|
headers={"Authorization": f"Bearer {tts.openai_api_key}"},
|
||||||
|
json={"model": "tts-1-1106", "input": "Hello world", "voice": "custom_voice"},
|
||||||
|
)
|
@ -0,0 +1,10 @@
|
|||||||
|
from swarms import OpenAITTS
|
||||||
|
|
||||||
|
tts = OpenAITTS(
|
||||||
|
model_name = "tts-1-1106",
|
||||||
|
voice = "onyx",
|
||||||
|
openai_api_key="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"
|
||||||
|
)
|
||||||
|
|
||||||
|
out = tts("Hello world")
|
||||||
|
print(out)
|
Loading…
Reference in new issue