[OpenAITTS][Save to filepath] [FEAT][BaseTTSModel]

pull/289/head
Kye 1 year ago
parent 01501ac2ad
commit ec580fa60e

@ -1,10 +1,10 @@
from swarms import OpenAITTS
tts = OpenAITTS(
model_name = "tts-1-1106",
voice = "onyx",
openai_api_key="sk"
model_name="tts-1-1106",
voice="onyx",
openai_api_key="YOUR_API_KEY",
)
out = tts("Hello world")
out = tts.run_and_save("pliny is a girl and a chicken")
print(out)

@ -8,6 +8,7 @@ from swarms.models.openai_models import (
AzureOpenAI,
OpenAIChat,
) # noqa: E402
# from swarms.models.vllm import vLLM # noqa: E402
# from swarms.models.zephyr import Zephyr # noqa: E402

@ -0,0 +1,41 @@
import wave
from typing import Optional
from swarms.models.base_llm import AbstractLLM
from abc import ABC, abstractmethod
class BaseTTSModel(AbstractLLM):
    """Base class for text-to-speech models.

    Stores the settings shared by TTS backends and offers a helper that
    writes speech bytes to a WAV file. Subclasses must implement ``run``.

    Args:
        model_name: Stored as ``self.model_name``.
        voice: Stored as ``self.voice``.
        chunk_size: Stored as ``self.chunk_size``; this class does not use
            it itself.
        save_to_file (bool): Stored as ``self.autosave``. Defaults to False.
        saved_filepath (Optional[str]): Stored as ``self.saved_filepath``.
            Defaults to None.
    """

    def __init__(
        self,
        model_name,
        voice,
        chunk_size,
        save_to_file: bool = False,
        saved_filepath: Optional[str] = None,
    ):
        # Initialise the base class, as the sibling OpenAITTS model does.
        super().__init__()
        self.model_name = model_name
        self.voice = voice
        self.chunk_size = chunk_size
        # Kept under the name `autosave`: an instance attribute called
        # `save_to_file` would shadow the `save_to_file` method below.
        self.autosave = save_to_file
        self.saved_filepath = saved_filepath

    def save(self, filepath: Optional[str] = None):
        """Placeholder for saving the model; currently does nothing."""
        pass

    def load(self, filepath: Optional[str] = None):
        """Placeholder for loading the model; currently does nothing."""
        pass

    @abstractmethod
    def run(self, task: str, *args, **kwargs):
        """Run the model on ``task``; must be implemented by subclasses."""
        pass

    def save_to_file(
        self,
        speech_data,
        filename,
        channels: int = 1,
        sample_width: int = 2,
        framerate: int = 22050,
    ):
        """Save the speech data to a WAV file.

        Args:
            speech_data (bytes): The audio frames to write.
            filename (str): The path to the file where the speech will be
                saved.
            channels (int): Number of audio channels. Defaults to 1.
            sample_width (int): Sample width in bytes. Defaults to 2.
            framerate (int): Sampling rate in Hz. Defaults to 22050.
        """
        # NOTE(review): the frames are written as-is, so speech_data is
        # assumed to be raw PCM matching the parameters above - confirm
        # against what the concrete `run` implementation returns.
        with wave.open(filename, "wb") as file:
            file.setnchannels(channels)
            file.setsampwidth(sample_width)
            file.setframerate(framerate)
            file.writeframes(speech_data)

@ -1,14 +1,24 @@
import os
import sys
import openai
import requests
from dotenv import load_dotenv
import subprocess
from swarms.models.base_llm import AbstractLLM
# `wave` is part of the Python standard library, so it is imported directly.
# Installing `pyaudio` with pip at import time could not supply a missing
# `wave` module and would be an unexpected side effect of importing this file.
import wave
# Load .env file
load_dotenv()
# OpenAI API Key env
def openai_api_key_env():
openai_api_key = os.getenv("OPENAI_API_KEY")
@ -40,13 +50,16 @@ class OpenAITTS(AbstractLLM):
>>> tts.run("Hello world")
"""
def __init__(
self,
model_name: str = "tts-1-1106",
proxy_url: str = "https://api.openai.com/v1/audio/speech",
openai_api_key: str = openai_api_key_env,
voice: str = "onyx",
chunk_size = 1024 * 1024,
chunk_size=1024 * 1024,
autosave: bool = False,
saved_filepath: str = None,
):
super().__init__()
self.model_name = model_name
@ -54,13 +67,12 @@ class OpenAITTS(AbstractLLM):
self.openai_api_key = openai_api_key
self.voice = voice
self.chunk_size = chunk_size
self.autosave = autosave
self.saved_filepath = saved_filepath
def run(
self,
task: str,
*args,
**kwargs
):
self.saved_filepath = "runs/tts_speech.wav"
def run(self, task: str, *args, **kwargs):
"""Run the tts model
Args:
@ -71,7 +83,7 @@ class OpenAITTS(AbstractLLM):
"""
response = requests.post(
self.proxy_url,
headers = {
headers={
"Authorization": f"Bearer {self.openai_api_key}",
},
json={
@ -82,6 +94,28 @@ class OpenAITTS(AbstractLLM):
)
audio = b""
for chunk in response.iter_content(chunk_size = 1024 * 1024):
for chunk in response.iter_content(chunk_size=1024 * 1024):
audio += chunk
return audio
def run_and_save(self, task: str = None, *args, **kwargs):
    """Run the TTS model and save the output to ``self.saved_filepath``.

    Args:
        task (str): The text to be converted to speech.
        *args: Extra positional arguments forwarded to ``run``.
        **kwargs: Extra keyword arguments forwarded to ``run``.

    Returns:
        bytes: The speech data returned by ``run``.

    Raises:
        ValueError: If ``self.saved_filepath`` is not set.
    """
    if not self.saved_filepath:
        raise ValueError("saved_filepath must be set to save the speech data")

    # Run the TTS model, passing through any extra arguments the caller gave.
    speech_data = self.run(task, *args, **kwargs)

    # wave.open does not create missing parent directories, so make sure
    # the destination directory exists before writing.
    parent_dir = os.path.dirname(self.saved_filepath)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # NOTE(review): the bytes are written unchanged as 16-bit mono frames
    # at 22050 Hz; this assumes `run` returns raw PCM in that layout -
    # confirm against the audio format the speech endpoint actually returns.
    with wave.open(self.saved_filepath, "wb") as file:
        file.setnchannels(1)
        file.setsampwidth(2)
        file.setframerate(22050)
        file.writeframes(speech_data)

    return speech_data

@ -23,7 +23,6 @@ from termcolor import colored
from swarms.structs.agent import Agent
class AbstractSwarm(ABC):
"""
Abstract Swarm Class for multi-agent systems
@ -161,9 +160,7 @@ class AbstractSwarm(ABC):
pass
# @abstractmethod
def assign_task(
self, agent: "Agent", task: Any
) -> Dict:
def assign_task(self, agent: "Agent", task: Any) -> Dict:
"""Assign a task to an agent.

Placeholder: the body is only `pass`, so calling it returns None despite
the `Dict` annotation. NOTE(review): presumably concrete swarms override
this method - confirm.
"""
pass

@ -2,6 +2,7 @@ import pytest
from unittest.mock import patch, MagicMock
from swarms.models.openai_tts import OpenAITTS
def test_openaitts_initialization():
tts = OpenAITTS()
assert isinstance(tts, OpenAITTS)
@ -10,14 +11,22 @@ def test_openaitts_initialization():
assert tts.voice == "onyx"
assert tts.chunk_size == 1024 * 1024
def test_openaitts_initialization_custom_parameters():
tts = OpenAITTS("custom_model", "custom_url", "custom_key", "custom_voice", 2048)
tts = OpenAITTS(
"custom_model",
"custom_url",
"custom_key",
"custom_voice",
2048,
)
assert tts.model_name == "custom_model"
assert tts.proxy_url == "custom_url"
assert tts.openai_api_key == "custom_key"
assert tts.voice == "custom_voice"
assert tts.chunk_size == 2048
@patch("requests.post")
def test_run(mock_post):
mock_response = MagicMock()
@ -29,27 +38,35 @@ def test_run(mock_post):
mock_post.assert_called_once_with(
"https://api.openai.com/v1/audio/speech",
headers={"Authorization": f"Bearer {tts.openai_api_key}"},
json={"model": "tts-1-1106", "input": "Hello world", "voice": "onyx"},
json={
"model": "tts-1-1106",
"input": "Hello world",
"voice": "onyx",
},
)
# Expects OpenAITTS.run to raise when given an empty task string.
# requests.post is patched, so no real HTTP request is sent.
# NOTE(review): the mock is not configured to fail, so this presumably only
# passes if run() rejects the input itself - confirm against run().
@patch("requests.post")
def test_run_empty_task(mock_post):
tts = OpenAITTS()
with pytest.raises(Exception):
tts.run("")
# Expects OpenAITTS.run to raise when given a 10,000-character task.
# requests.post is patched, so no real HTTP request is sent.
# NOTE(review): the mock is not configured to fail, so this presumably only
# passes if run() enforces a length limit itself - confirm against run().
@patch("requests.post")
def test_run_very_long_task(mock_post):
tts = OpenAITTS()
with pytest.raises(Exception):
tts.run("A" * 10000)
# Expects OpenAITTS.run to raise when the task is None.
# requests.post is patched, so no real HTTP request is sent.
# NOTE(review): the mock is not configured to fail, so this presumably only
# passes if run() validates the task type itself - confirm against run().
@patch("requests.post")
def test_run_invalid_task(mock_post):
tts = OpenAITTS()
with pytest.raises(Exception):
tts.run(None)
@patch("requests.post")
def test_run_custom_model(mock_post):
mock_response = MagicMock()
@ -61,9 +78,14 @@ def test_run_custom_model(mock_post):
mock_post.assert_called_once_with(
"https://api.openai.com/v1/audio/speech",
headers={"Authorization": f"Bearer {tts.openai_api_key}"},
json={"model": "custom_model", "input": "Hello world", "voice": "onyx"},
json={
"model": "custom_model",
"input": "Hello world",
"voice": "onyx",
},
)
@patch("requests.post")
def test_run_custom_voice(mock_post):
mock_response = MagicMock()
@ -75,5 +97,9 @@ def test_run_custom_voice(mock_post):
mock_post.assert_called_once_with(
"https://api.openai.com/v1/audio/speech",
headers={"Authorization": f"Bearer {tts.openai_api_key}"},
json={"model": "tts-1-1106", "input": "Hello world", "voice": "custom_voice"},
json={
"model": "tts-1-1106",
"input": "Hello world",
"voice": "custom_voice",
},
)
Loading…
Cancel
Save