[OpenAITTS][Save to filepath] [FEAT][BaseTTSModel]

pull/289/head
Kye 1 year ago
parent 01501ac2ad
commit ec580fa60e

@ -3,8 +3,8 @@ from swarms import OpenAITTS
tts = OpenAITTS( tts = OpenAITTS(
model_name="tts-1-1106", model_name="tts-1-1106",
voice="onyx", voice="onyx",
openai_api_key="sk" openai_api_key="YOUR_API_KEY",
) )
out = tts("Hello world") out = tts.run_and_save("pliny is a girl and a chicken")
print(out) print(out)

@ -8,6 +8,7 @@ from swarms.models.openai_models import (
AzureOpenAI, AzureOpenAI,
OpenAIChat, OpenAIChat,
) # noqa: E402 ) # noqa: E402
# from swarms.models.vllm import vLLM # noqa: E402 # from swarms.models.vllm import vLLM # noqa: E402
# from swarms.models.zephyr import Zephyr # noqa: E402 # from swarms.models.zephyr import Zephyr # noqa: E402

@ -0,0 +1,41 @@
import wave
from typing import Optional
from swarms.models.base_llm import AbstractLLM
from abc import ABC, abstractmethod
class BaseTTSModel(AbstractLLM):
def __init__(
self,
model_name,
voice,
chunk_size,
save_to_file: bool = False,
saved_filepath: Optional[str] = None,
):
self.model_name = model_name
self.voice = voice
self.chunk_size = chunk_size
def save(self, filepath: Optional[str] = None):
pass
def load(self, filepath: Optional[str] = None):
pass
@abstractmethod
def run(self, task: str, *args, **kwargs):
pass
def save_to_file(self, speech_data, filename):
"""Save the speech data to a file.
Args:
speech_data (bytes): The speech data.
filename (str): The path to the file where the speech will be saved.
"""
with wave.open(filename, "wb") as file:
file.setnchannels(1)
file.setsampwidth(2)
file.setframerate(22050)
file.writeframes(speech_data)

@ -1,14 +1,24 @@
import os import os
import sys
import openai import openai
import requests import requests
from dotenv import load_dotenv from dotenv import load_dotenv
import subprocess
from swarms.models.base_llm import AbstractLLM from swarms.models.base_llm import AbstractLLM
try:
import wave
except ImportError as error:
print(f"Import Error: {error} - Please install pyaudio")
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "pyaudio"]
)
# Load .env file # Load .env file
load_dotenv() load_dotenv()
# OpenAI API Key env # OpenAI API Key env
def openai_api_key_env(): def openai_api_key_env():
openai_api_key = os.getenv("OPENAI_API_KEY") openai_api_key = os.getenv("OPENAI_API_KEY")
@ -40,6 +50,7 @@ class OpenAITTS(AbstractLLM):
>>> tts.run("Hello world") >>> tts.run("Hello world")
""" """
def __init__( def __init__(
self, self,
model_name: str = "tts-1-1106", model_name: str = "tts-1-1106",
@ -47,6 +58,8 @@ class OpenAITTS(AbstractLLM):
openai_api_key: str = openai_api_key_env, openai_api_key: str = openai_api_key_env,
voice: str = "onyx", voice: str = "onyx",
chunk_size=1024 * 1024, chunk_size=1024 * 1024,
autosave: bool = False,
saved_filepath: str = None,
): ):
super().__init__() super().__init__()
self.model_name = model_name self.model_name = model_name
@ -54,13 +67,12 @@ class OpenAITTS(AbstractLLM):
self.openai_api_key = openai_api_key self.openai_api_key = openai_api_key
self.voice = voice self.voice = voice
self.chunk_size = chunk_size self.chunk_size = chunk_size
self.autosave = autosave
self.saved_filepath = saved_filepath
def run( self.saved_filepath = "runs/tts_speech.wav"
self,
task: str, def run(self, task: str, *args, **kwargs):
*args,
**kwargs
):
"""Run the tts model """Run the tts model
Args: Args:
@ -85,3 +97,25 @@ class OpenAITTS(AbstractLLM):
for chunk in response.iter_content(chunk_size=1024 * 1024): for chunk in response.iter_content(chunk_size=1024 * 1024):
audio += chunk audio += chunk
return audio return audio
def run_and_save(self, task: str = None, *args, **kwargs):
"""Run the TTS model and save the output to a file.
Args:
task (str): The text to be converted to speech.
filename (str): The path to the file where the speech will be saved.
Returns:
bytes: The speech data.
"""
# Run the TTS model.
speech_data = self.run(task)
# Save the speech data to a file.
with wave.open(self.saved_filepath, "wb") as file:
file.setnchannels(1)
file.setsampwidth(2)
file.setframerate(22050)
file.writeframes(speech_data)
return speech_data

@ -23,7 +23,6 @@ from termcolor import colored
from swarms.structs.agent import Agent from swarms.structs.agent import Agent
class AbstractSwarm(ABC): class AbstractSwarm(ABC):
""" """
Abstract Swarm Class for multi-agent systems Abstract Swarm Class for multi-agent systems
@ -161,9 +160,7 @@ class AbstractSwarm(ABC):
pass pass
# @abstractmethod # @abstractmethod
def assign_task( def assign_task(self, agent: "Agent", task: Any) -> Dict:
self, agent: "Agent", task: Any
) -> Dict:
"""Assign a task to a agent""" """Assign a task to a agent"""
pass pass

@ -2,6 +2,7 @@ import pytest
from unittest.mock import patch, MagicMock from unittest.mock import patch, MagicMock
from swarms.models.openai_tts import OpenAITTS from swarms.models.openai_tts import OpenAITTS
def test_openaitts_initialization(): def test_openaitts_initialization():
tts = OpenAITTS() tts = OpenAITTS()
assert isinstance(tts, OpenAITTS) assert isinstance(tts, OpenAITTS)
@ -10,14 +11,22 @@ def test_openaitts_initialization():
assert tts.voice == "onyx" assert tts.voice == "onyx"
assert tts.chunk_size == 1024 * 1024 assert tts.chunk_size == 1024 * 1024
def test_openaitts_initialization_custom_parameters(): def test_openaitts_initialization_custom_parameters():
tts = OpenAITTS("custom_model", "custom_url", "custom_key", "custom_voice", 2048) tts = OpenAITTS(
"custom_model",
"custom_url",
"custom_key",
"custom_voice",
2048,
)
assert tts.model_name == "custom_model" assert tts.model_name == "custom_model"
assert tts.proxy_url == "custom_url" assert tts.proxy_url == "custom_url"
assert tts.openai_api_key == "custom_key" assert tts.openai_api_key == "custom_key"
assert tts.voice == "custom_voice" assert tts.voice == "custom_voice"
assert tts.chunk_size == 2048 assert tts.chunk_size == 2048
@patch("requests.post") @patch("requests.post")
def test_run(mock_post): def test_run(mock_post):
mock_response = MagicMock() mock_response = MagicMock()
@ -29,27 +38,35 @@ def test_run(mock_post):
mock_post.assert_called_once_with( mock_post.assert_called_once_with(
"https://api.openai.com/v1/audio/speech", "https://api.openai.com/v1/audio/speech",
headers={"Authorization": f"Bearer {tts.openai_api_key}"}, headers={"Authorization": f"Bearer {tts.openai_api_key}"},
json={"model": "tts-1-1106", "input": "Hello world", "voice": "onyx"}, json={
"model": "tts-1-1106",
"input": "Hello world",
"voice": "onyx",
},
) )
@patch("requests.post") @patch("requests.post")
def test_run_empty_task(mock_post): def test_run_empty_task(mock_post):
tts = OpenAITTS() tts = OpenAITTS()
with pytest.raises(Exception): with pytest.raises(Exception):
tts.run("") tts.run("")
@patch("requests.post") @patch("requests.post")
def test_run_very_long_task(mock_post): def test_run_very_long_task(mock_post):
tts = OpenAITTS() tts = OpenAITTS()
with pytest.raises(Exception): with pytest.raises(Exception):
tts.run("A" * 10000) tts.run("A" * 10000)
@patch("requests.post") @patch("requests.post")
def test_run_invalid_task(mock_post): def test_run_invalid_task(mock_post):
tts = OpenAITTS() tts = OpenAITTS()
with pytest.raises(Exception): with pytest.raises(Exception):
tts.run(None) tts.run(None)
@patch("requests.post") @patch("requests.post")
def test_run_custom_model(mock_post): def test_run_custom_model(mock_post):
mock_response = MagicMock() mock_response = MagicMock()
@ -61,9 +78,14 @@ def test_run_custom_model(mock_post):
mock_post.assert_called_once_with( mock_post.assert_called_once_with(
"https://api.openai.com/v1/audio/speech", "https://api.openai.com/v1/audio/speech",
headers={"Authorization": f"Bearer {tts.openai_api_key}"}, headers={"Authorization": f"Bearer {tts.openai_api_key}"},
json={"model": "custom_model", "input": "Hello world", "voice": "onyx"}, json={
"model": "custom_model",
"input": "Hello world",
"voice": "onyx",
},
) )
@patch("requests.post") @patch("requests.post")
def test_run_custom_voice(mock_post): def test_run_custom_voice(mock_post):
mock_response = MagicMock() mock_response = MagicMock()
@ -75,5 +97,9 @@ def test_run_custom_voice(mock_post):
mock_post.assert_called_once_with( mock_post.assert_called_once_with(
"https://api.openai.com/v1/audio/speech", "https://api.openai.com/v1/audio/speech",
headers={"Authorization": f"Bearer {tts.openai_api_key}"}, headers={"Authorization": f"Bearer {tts.openai_api_key}"},
json={"model": "tts-1-1106", "input": "Hello world", "voice": "custom_voice"}, json={
"model": "tts-1-1106",
"input": "Hello world",
"voice": "custom_voice",
},
) )
Loading…
Cancel
Save