add async interpreter with coqui, openai, elevenlabs tts

pull/279/head
Ben Xu 7 months ago
parent 2627fba481
commit eee00ac026

@ -90,6 +90,7 @@ class Device:
self.audiosegments = asyncio.Queue()
self.server_url = ""
self.ctrl_pressed = False
self.tts_service = ""
self.playback_latency = None
def fetch_image_from_camera(self, camera_index=CAMERA_DEVICE_INDEX):
@ -164,30 +165,18 @@ class Device:
while True:
try:
audio = await self.audiosegments.get()
# print("got audio segment!!!!")
if self.playback_latency:
if self.playback_latency and isinstance(audio, bytes):
elapsed_time = time.time() - self.playback_latency
print(f"Time from request to playback: {elapsed_time} seconds")
# print(f"Time from request to playback: {elapsed_time} seconds")
self.playback_latency = None
if audio is not None:
if self.tts_service == "elevenlabs":
mpv_process.stdin.write(audio) # type: ignore
mpv_process.stdin.flush() # type: ignore
"""
args = ["ffplay", "-autoexit", "-", "-nodisp"]
proc = subprocess.Popen(
args=args,
stdout=subprocess.PIPE,
stdin=subprocess.PIPE,
stderr=subprocess.PIPE,
)
out, err = proc.communicate(input=audio)
proc.poll()
else:
play(audio)
"""
# self.audiosegments.remove(audio)
# await asyncio.sleep(0.1)
await asyncio.sleep(0.1)
except asyncio.exceptions.CancelledError:
# This happens once at the start?
pass
@ -342,24 +331,17 @@ class Device:
async def message_sender(self, websocket):
while True:
try:
message = await asyncio.get_event_loop().run_in_executor(
None, send_queue.get
)
if isinstance(message, bytes):
await websocket.send(message)
else:
await websocket.send(json.dumps(message))
send_queue.task_done()
await asyncio.sleep(0.01)
except:
traceback.print_exc()
async def websocket_communication(self, WS_URL):
print("websocket communication was called!!!!")
show_connection_log = True
async def exec_ws_communication(websocket):
@ -373,48 +355,48 @@ class Device:
asyncio.create_task(self.message_sender(websocket))
while True:
await asyncio.sleep(0.0001)
await asyncio.sleep(0.01)
chunk = await websocket.recv()
logger.debug(f"Got this message from the server: {type(chunk)} {chunk}")
# print((f"Got this message from the server: {type(chunk)} {chunk}"))
# print("received chunk from server")
if type(chunk) == str:
chunk = json.loads(chunk)
# message = accumulator.accumulate(chunk)
if self.tts_service == "elevenlabs":
message = chunk
else:
message = accumulator.accumulate(chunk)
if message == None:
# Will be None until we have a full message ready
continue
# At this point, we have our message
# print("checkpoint reached!", message)
if isinstance(message, bytes):
# if message["type"] == "audio" and message["format"].startswith("bytes"):
if isinstance(message, bytes) or (
message["type"] == "audio" and message["format"].startswith("bytes")
):
# Convert bytes to audio file
# audio_bytes = message["content"]
if self.tts_service == "elevenlabs":
audio_bytes = message
audio = audio_bytes
else:
audio_bytes = message["content"]
# Create an AudioSegment instance with the raw data
"""
audio = AudioSegment(
# raw audio data (bytes)
data=audio_bytes,
# signed 16-bit little-endian format
sample_width=2,
# 24,000 Hz frame rate
frame_rate=16000,
# 16,000 Hz frame rate
frame_rate=22050,
# mono sound
channels=1,
)
"""
# print("audio segment was created")
await self.audiosegments.put(audio_bytes)
# await self.audiosegments.put(audio)
await self.audiosegments.put(audio)
# Run the code if that's the client's job
if os.getenv("CODE_RUNNER") == "client":
@ -434,29 +416,26 @@ class Device:
except Exception as e:
logger.error(f"Error while attempting to connect: {e}")
else:
print("websocket url is", WS_URL)
while True:
try:
async with websockets.connect(WS_URL) as websocket:
print("awaiting exec_ws_communication")
await exec_ws_communication(websocket)
except:
logger.info(traceback.format_exc())
logger.debug(traceback.format_exc())
if show_connection_log:
logger.info(f"Connecting to `{WS_URL}`...")
show_connection_log = False
await asyncio.sleep(2)
async def start_async(self):
print("start async was called!!!!!")
# Configuration for WebSocket
WS_URL = f"ws://{self.server_url}"
# Start the WebSocket communication
asyncio.create_task(self.websocket_communication(WS_URL))
# Start watching the kernel if it's your job to do that
if os.getenv("CODE_RUNNER") == "client":
# client is not running code!
asyncio.create_task(put_kernel_messages_into_queue(send_queue))
asyncio.create_task(self.play_audiosegments())
@ -488,10 +467,8 @@ class Device:
on_press=self.on_press, on_release=self.on_release
)
listener.start()
print("listener for keyboard started!!!!!")
def start(self):
print("device was started!!!!!!")
if os.getenv("TEACH_MODE") != "True":
asyncio.run(self.start_async())
p.terminate()

@ -3,8 +3,9 @@ from ..base_device import Device
device = Device()
def main(server_url):
def main(server_url, tts_service):
device.server_url = server_url
device.tts_service = tts_service
device.start()

@ -3,8 +3,9 @@ from ..base_device import Device
device = Device()
def main(server_url):
def main(server_url, tts_service):
device.server_url = server_url
device.tts_service = tts_service
device.start()

@ -3,8 +3,9 @@ from ..base_device import Device
device = Device()
def main(server_url):
def main(server_url, tts_service):
device.server_url = server_url
device.tts_service = tts_service
device.start()

@ -10,16 +10,9 @@
"""
###
from pynput import keyboard
from RealtimeTTS import (
TextToAudioStream,
OpenAIEngine,
CoquiEngine,
ElevenlabsEngine,
SystemEngine,
GTTSEngine,
)
from RealtimeTTS import TextToAudioStream, CoquiEngine, OpenAIEngine, ElevenlabsEngine
from RealtimeSTT import AudioToTextRecorder
import time
import asyncio
@ -29,9 +22,9 @@ import os
class AsyncInterpreter:
def __init__(self, interpreter):
self.stt_latency = None
self.tts_latency = None
self.interpreter_latency = None
# self.stt_latency = None
# self.tts_latency = None
# self.interpreter_latency = None
self.interpreter = interpreter
# STT
@ -45,12 +38,9 @@ class AsyncInterpreter:
engine = CoquiEngine()
elif self.interpreter.tts == "openai":
engine = OpenAIEngine()
elif self.interpreter.tts == "gtts":
engine = GTTSEngine()
elif self.interpreter.tts == "elevenlabs":
engine = ElevenlabsEngine(api_key=os.environ["ELEVEN_LABS_API_KEY"])
elif self.interpreter.tts == "system":
engine = SystemEngine()
engine.set_voice("Michael")
else:
raise ValueError(f"Unsupported TTS engine: {self.interpreter.tts}")
self.tts = TextToAudioStream(engine)
@ -112,41 +102,11 @@ class AsyncInterpreter:
# print("ADDING TO QUEUE:", chunk)
asyncio.create_task(self._add_to_queue(self._output_queue, chunk))
async def run(self):
"""
Runs OI on the audio bytes submitted to the input. Will add streaming LMC chunks to the _output_queue.
"""
self.interpreter.messages = self.active_chat_messages
# self.beeper.start()
self.stt.stop()
# message = self.stt.text()
# print("THE MESSAGE:", message)
# accumulates the input queue message
input_queue = []
while not self._input_queue.empty():
input_queue.append(self._input_queue.get())
# print("INPUT QUEUE:", input_queue)
# message = [i for i in input_queue if i["type"] == "message"][0]["content"]
start_stt = time.time()
message = self.stt.text()
end_stt = time.time()
self.stt_latency = end_stt - start_stt
print("STT LATENCY", self.stt_latency)
# print(message)
end_interpreter = 0
# print(message)
def generate(message):
def generate(self, message, start_interpreter):
last_lmc_start_flag = self._last_lmc_start_flag
self.interpreter.messages = self.active_chat_messages
# print("🍀🍀🍀🍀GENERATING, using these messages: ", self.interpreter.messages)
# print("🍀 🍀 🍀 🍀 active_chat_messages: ", self.active_chat_messages)
print("message is", message)
# print("message is", message)
for chunk in self.interpreter.chat(message, display=True, stream=True):
@ -165,7 +125,7 @@ class AsyncInterpreter:
# Experimental: The AI voice sounds better with replacements like these, but it should happen at the TTS layer
# content = content.replace(". ", ". ... ").replace(", ", ", ... ").replace("!", "! ... ").replace("?", "? ... ")
# print("yielding this", content)
# print("yielding ", content)
yield content
# Handle code blocks
@ -196,27 +156,42 @@ class AsyncInterpreter:
)
# Send a completion signal
end_interpreter = time.time()
self.interpreter_latency = end_interpreter - start_interpreter
print("INTERPRETER LATENCY", self.interpreter_latency)
# end_interpreter = time.time()
# self.interpreter_latency = end_interpreter - start_interpreter
# print("INTERPRETER LATENCY", self.interpreter_latency)
# self.add_to_output_queue_sync({"role": "server","type": "completion", "content": "DONE"})
async def run(self):
"""
Runs OI on the audio bytes submitted to the input. Will add streaming LMC chunks to the _output_queue.
"""
self.interpreter.messages = self.active_chat_messages
self.stt.stop()
input_queue = []
while not self._input_queue.empty():
input_queue.append(self._input_queue.get())
# start_stt = time.time()
message = self.stt.text()
# end_stt = time.time()
# self.stt_latency = end_stt - start_stt
# print("STT LATENCY", self.stt_latency)
# print(message)
# Feed generate to RealtimeTTS
self.add_to_output_queue_sync(
{"role": "assistant", "type": "audio", "format": "bytes.wav", "start": True}
)
start_interpreter = time.time()
text_iterator = generate(message)
text_iterator = self.generate(message, start_interpreter)
self.tts.feed(text_iterator)
self.tts.play_async(on_audio_chunk=self.on_tts_chunk, muted=True)
while True:
if self.tts.is_playing():
start_tts = time.time()
self.tts.play_async(on_audio_chunk=self.on_tts_chunk, muted=True)
break
await asyncio.sleep(0.1)
while True:
await asyncio.sleep(0.1)
# print("is_playing", self.tts.is_playing())
@ -229,14 +204,14 @@ class AsyncInterpreter:
"end": True,
}
)
end_tts = time.time()
self.tts_latency = end_tts - start_tts
print("TTS LATENCY", self.tts_latency)
# end_tts = time.time()
# self.tts_latency = end_tts - self.tts.stream_start_time
# print("TTS LATENCY", self.tts_latency)
self.tts.stop()
break
async def _on_tts_chunk_async(self, chunk):
# print("SENDING TTS CHUNK")
# print("adding chunk to queue")
await self._add_to_queue(self._output_queue, chunk)
def on_tts_chunk(self, chunk):
@ -244,4 +219,5 @@ class AsyncInterpreter:
asyncio.run(self._on_tts_chunk_async(chunk))
async def output(self):
# print("outputting chunks")
return await self._output_queue.get()

@ -1,20 +1,23 @@
import asyncio
import traceback
import json
from fastapi import FastAPI, WebSocket, Header
from fastapi import FastAPI, WebSocket
from fastapi.responses import PlainTextResponse
from uvicorn import Config, Server
from .i import configure_interpreter
from interpreter import interpreter as base_interpreter
from .async_interpreter import AsyncInterpreter
from fastapi.middleware.cors import CORSMiddleware
from typing import List, Dict, Any
from openai import OpenAI
from pydantic import BaseModel
import argparse
import os
# import sentry_sdk
os.environ["STT_RUNNER"] = "server"
os.environ["TTS_RUNNER"] = "server"
async def main(server_host, server_port, tts_service, asynchronous):
if asynchronous:
base_interpreter.system_message = (
"You are a helpful assistant that can answer questions and help with tasks."
)
@ -23,21 +26,13 @@ base_interpreter.llm.model = "groq/llama3-8b-8192"
base_interpreter.llm.api_key = os.environ["GROQ_API_KEY"]
base_interpreter.llm.supports_functions = False
base_interpreter.auto_run = True
os.environ["STT_RUNNER"] = "server"
os.environ["TTS_RUNNER"] = "server"
# Parse command line arguments for port number
"""
parser = argparse.ArgumentParser(description="FastAPI server.")
parser.add_argument("--port", type=int, default=8000, help="Port to run on.")
args = parser.parse_args()
"""
base_interpreter.tts = "coqui"
async def main(server_host, server_port):
base_interpreter.tts = tts_service
interpreter = AsyncInterpreter(base_interpreter)
else:
configured_interpreter = configure_interpreter(base_interpreter)
configured_interpreter.llm.supports_functions = True
configured_interpreter.tts = tts_service
interpreter = AsyncInterpreter(configured_interpreter)
app = FastAPI()
@ -107,37 +102,6 @@ async def main(server_host, server_port):
server = Server(config)
await server.serve()
class Rename(BaseModel):
input: str
@app.post("/rename-chat")
async def rename_chat(body_content: Rename, x_api_key: str = Header(None)):
print("RENAME CHAT REQUEST in PY 🌙🌙🌙🌙")
input_value = body_content.input
client = OpenAI(
# defaults to os.environ.get("OPENAI_API_KEY")
api_key=x_api_key,
)
try:
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{
"role": "user",
"content": f"Given the following chat snippet, create a unique and descriptive title in less than 8 words. Your answer must not be related to customer service.\n\n{input_value}",
}
],
temperature=0.3,
stream=False,
)
print(response)
completion = response["choices"][0]["message"]["content"]
return {"data": {"content": completion}}
except Exception as e:
print(f"Error: {e}")
traceback.print_exc()
return {"error": str(e)}
if __name__ == "__main__":
asyncio.run(main())

@ -6,6 +6,8 @@ import os
import importlib
from source.server.tunnel import create_tunnel
from source.server.async_server import main
# from source.server.server import main
from source.server.utils.local_mode import select_local_model
import signal
@ -63,7 +65,7 @@ def run(
0.8, "--temperature", help="Specify the temperature for generation"
),
tts_service: str = typer.Option(
"openai", "--tts-service", help="Specify the TTS service"
"elevenlabs", "--tts-service", help="Specify the TTS service"
),
stt_service: str = typer.Option(
"openai", "--stt-service", help="Specify the STT service"
@ -75,6 +77,9 @@ def run(
mobile: bool = typer.Option(
False, "--mobile", help="Toggle server to support mobile app"
),
asynchronous: bool = typer.Option(
False, "--async", help="use interpreter optimized for latency"
),
):
_run(
server=server or mobile,
@ -97,6 +102,7 @@ def run(
local=local,
qr=qr or mobile,
mobile=mobile,
asynchronous=asynchronous,
)
@ -116,14 +122,15 @@ def _run(
context_window: int = 2048,
max_tokens: int = 4096,
temperature: float = 0.8,
tts_service: str = "openai",
tts_service: str = "elevenlabs",
stt_service: str = "openai",
local: bool = False,
qr: bool = False,
mobile: bool = False,
asynchronous: bool = False,
):
if local:
tts_service = "piper"
tts_service = "coqui"
# llm_service = "llamafile"
stt_service = "local-whisper"
select_local_model()
@ -154,6 +161,8 @@ def _run(
main(
server_host,
server_port,
tts_service,
asynchronous,
# llm_service,
# model,
# llm_supports_vision,
@ -161,7 +170,6 @@ def _run(
# context_window,
# max_tokens,
# temperature,
# tts_service,
# stt_service,
# mobile,
),
@ -180,7 +188,6 @@ def _run(
system_type = platform.system()
if system_type == "Darwin": # Mac OS
client_type = "mac"
print("initiating mac device with base device!!!")
elif system_type == "Windows": # Windows System
client_type = "windows"
elif system_type == "Linux": # Linux System
@ -196,9 +203,10 @@ def _run(
module = importlib.import_module(
f".clients.{client_type}.device", package="source"
)
# server_url = "0.0.0.0:8000"
client_thread = threading.Thread(target=module.main, args=[server_url])
print("client thread started")
client_thread = threading.Thread(
target=module.main, args=[server_url, tts_service]
)
client_thread.start()
try:

Loading…
Cancel
Save