add async interpreter with coqui, openai, elevenlabs tts

pull/279/head
Ben Xu 7 months ago
parent 2627fba481
commit eee00ac026

@@ -90,6 +90,7 @@ class Device:
         self.audiosegments = asyncio.Queue()
         self.server_url = ""
         self.ctrl_pressed = False
+        self.tts_service = ""
         self.playback_latency = None

     def fetch_image_from_camera(self, camera_index=CAMERA_DEVICE_INDEX):
@@ -164,30 +165,18 @@ class Device:
         while True:
             try:
                 audio = await self.audiosegments.get()
-                # print("got audio segment!!!!")
-                if self.playback_latency:
+                if self.playback_latency and isinstance(audio, bytes):
                     elapsed_time = time.time() - self.playback_latency
-                    print(f"Time from request to playback: {elapsed_time} seconds")
+                    # print(f"Time from request to playback: {elapsed_time} seconds")
                     self.playback_latency = None
-                if audio is not None:
+                if self.tts_service == "elevenlabs":
                     mpv_process.stdin.write(audio)  # type: ignore
                     mpv_process.stdin.flush()  # type: ignore
-                """
-                args = ["ffplay", "-autoexit", "-", "-nodisp"]
-                proc = subprocess.Popen(
-                    args=args,
-                    stdout=subprocess.PIPE,
-                    stdin=subprocess.PIPE,
-                    stderr=subprocess.PIPE,
-                )
-                out, err = proc.communicate(input=audio)
-                proc.poll()
-                play(audio)
-                """
-                # self.audiosegments.remove(audio)
-                # await asyncio.sleep(0.1)
+                else:
+                    play(audio)
+                await asyncio.sleep(0.1)
             except asyncio.exceptions.CancelledError:
                 # This happens once at the start?
                 pass
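
Note: a minimal sketch of the elevenlabs playback path above (illustrative, not part of the commit). It assumes mpv is on PATH; the diff's mpv_process is created elsewhere in the client, but the idea is the same: ElevenLabs streams already-encoded audio, so the bytes are piped straight to mpv's decoder instead of going through pydub.

import subprocess

# Spawn mpv reading encoded audio from stdin; --no-cache and --no-terminal
# keep startup latency and console noise down.
mpv_process = subprocess.Popen(
    ["mpv", "--no-cache", "--no-terminal", "--", "fd://0"],
    stdin=subprocess.PIPE,
)

def play_raw_bytes(audio: bytes) -> None:
    # Encoded audio (e.g. MP3 from ElevenLabs) goes directly to mpv's stdin.
    mpv_process.stdin.write(audio)
    mpv_process.stdin.flush()
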
@@ -342,24 +331,17 @@ class Device:
     async def message_sender(self, websocket):
         while True:
-            try:
-                message = await asyncio.get_event_loop().run_in_executor(
-                    None, send_queue.get
-                )
-                if isinstance(message, bytes):
-                    await websocket.send(message)
-                else:
-                    await websocket.send(json.dumps(message))
-                send_queue.task_done()
-                await asyncio.sleep(0.01)
-            except:
-                traceback.print_exc()
+            message = await asyncio.get_event_loop().run_in_executor(
+                None, send_queue.get
+            )
+            if isinstance(message, bytes):
+                await websocket.send(message)
+            else:
+                await websocket.send(json.dumps(message))
+            send_queue.task_done()
+            await asyncio.sleep(0.01)

     async def websocket_communication(self, WS_URL):
-        print("websocket communication was called!!!!")
         show_connection_log = True

         async def exec_ws_communication(websocket):
@@ -373,48 +355,48 @@ class Device:
             asyncio.create_task(self.message_sender(websocket))

             while True:
-                await asyncio.sleep(0.0001)
+                await asyncio.sleep(0.01)
                 chunk = await websocket.recv()
                 logger.debug(f"Got this message from the server: {type(chunk)} {chunk}")
-                # print((f"Got this message from the server: {type(chunk)} {chunk}"))
+                # print("received chunk from server")
                 if type(chunk) == str:
                     chunk = json.loads(chunk)

-                # message = accumulator.accumulate(chunk)
-                message = chunk
+                if self.tts_service == "elevenlabs":
+                    message = chunk
+                else:
+                    message = accumulator.accumulate(chunk)
                 if message == None:
                     # Will be None until we have a full message ready
                     continue

                 # At this point, we have our message
-                # print("checkpoint reached!", message)
-                if isinstance(message, bytes):
-                    # if message["type"] == "audio" and message["format"].startswith("bytes"):
+                if isinstance(message, bytes) or (
+                    message["type"] == "audio" and message["format"].startswith("bytes")
+                ):
                     # Convert bytes to audio file
-                    # audio_bytes = message["content"]
-                    audio_bytes = message
-                    # Create an AudioSegment instance with the raw data
-                    """
-                    audio = AudioSegment(
-                        # raw audio data (bytes)
-                        data=audio_bytes,
-                        # signed 16-bit little-endian format
-                        sample_width=2,
-                        # 24,000 Hz frame rate
-                        frame_rate=16000,
-                        # mono sound
-                        channels=1,
-                    )
-                    """
-                    # print("audio segment was created")
-                    await self.audiosegments.put(audio_bytes)
-                    # await self.audiosegments.put(audio)
+                    if self.tts_service == "elevenlabs":
+                        audio_bytes = message
+                        audio = audio_bytes
+                    else:
+                        audio_bytes = message["content"]
+                        # Create an AudioSegment instance with the raw data
+                        audio = AudioSegment(
+                            # raw audio data (bytes)
+                            data=audio_bytes,
+                            # signed 16-bit little-endian format
+                            sample_width=2,
+                            # 22,050 Hz frame rate
+                            frame_rate=22050,
+                            # mono sound
+                            channels=1,
+                        )
+                    await self.audiosegments.put(audio)

                 # Run the code if that's the client's job
                 if os.getenv("CODE_RUNNER") == "client":
@@ -434,29 +416,26 @@ class Device:
                 except Exception as e:
                     logger.error(f"Error while attempting to connect: {e}")
         else:
-            print("websocket url is", WS_URL)
             while True:
                 try:
                     async with websockets.connect(WS_URL) as websocket:
-                        print("awaiting exec_ws_communication")
                         await exec_ws_communication(websocket)
                 except:
-                    logger.info(traceback.format_exc())
+                    logger.debug(traceback.format_exc())
                     if show_connection_log:
                         logger.info(f"Connecting to `{WS_URL}`...")
                         show_connection_log = False
                     await asyncio.sleep(2)

     async def start_async(self):
-        print("start async was called!!!!!")
         # Configuration for WebSocket
         WS_URL = f"ws://{self.server_url}"
         # Start the WebSocket communication
         asyncio.create_task(self.websocket_communication(WS_URL))

         # Start watching the kernel if it's your job to do that
         if os.getenv("CODE_RUNNER") == "client":
-            # client is not running code!
             asyncio.create_task(put_kernel_messages_into_queue(send_queue))

         asyncio.create_task(self.play_audiosegments())
@ -488,10 +467,8 @@ class Device:
on_press=self.on_press, on_release=self.on_release on_press=self.on_press, on_release=self.on_release
) )
listener.start() listener.start()
print("listener for keyboard started!!!!!")
def start(self): def start(self):
print("device was started!!!!!!")
if os.getenv("TEACH_MODE") != "True": if os.getenv("TEACH_MODE") != "True":
asyncio.run(self.start_async()) asyncio.run(self.start_async())
p.terminate() p.terminate()

@@ -3,8 +3,9 @@ from ..base_device import Device

 device = Device()


-def main(server_url):
+def main(server_url, tts_service):
     device.server_url = server_url
+    device.tts_service = tts_service
     device.start()

@@ -3,8 +3,9 @@ from ..base_device import Device

 device = Device()


-def main(server_url):
+def main(server_url, tts_service):
     device.server_url = server_url
+    device.tts_service = tts_service
     device.start()

@@ -3,8 +3,9 @@ from ..base_device import Device

 device = Device()


-def main(server_url):
+def main(server_url, tts_service):
     device.server_url = server_url
+    device.tts_service = tts_service
     device.start()

@@ -10,16 +10,9 @@
 """
 ###

 from pynput import keyboard
-from RealtimeTTS import (
-    TextToAudioStream,
-    OpenAIEngine,
-    CoquiEngine,
-    ElevenlabsEngine,
-    SystemEngine,
-    GTTSEngine,
-)
+from RealtimeTTS import TextToAudioStream, CoquiEngine, OpenAIEngine, ElevenlabsEngine
 from RealtimeSTT import AudioToTextRecorder
 import time
 import asyncio
@@ -29,9 +22,9 @@ import os
 class AsyncInterpreter:
     def __init__(self, interpreter):
-        self.stt_latency = None
-        self.tts_latency = None
-        self.interpreter_latency = None
+        # self.stt_latency = None
+        # self.tts_latency = None
+        # self.interpreter_latency = None
         self.interpreter = interpreter

         # STT
@@ -45,12 +38,9 @@ class AsyncInterpreter:
             engine = CoquiEngine()
         elif self.interpreter.tts == "openai":
             engine = OpenAIEngine()
-        elif self.interpreter.tts == "gtts":
-            engine = GTTSEngine()
         elif self.interpreter.tts == "elevenlabs":
             engine = ElevenlabsEngine(api_key=os.environ["ELEVEN_LABS_API_KEY"])
-        elif self.interpreter.tts == "system":
-            engine = SystemEngine()
+            engine.set_voice("Michael")
         else:
             raise ValueError(f"Unsupported TTS engine: {self.interpreter.tts}")
         self.tts = TextToAudioStream(engine)
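
Note: the selection above follows RealtimeTTS's interchangeable-engine design. A minimal standalone sketch (illustrative, not part of the commit), assuming the relevant API keys are exported; "Michael" is the voice the diff picks:

import os
from RealtimeTTS import TextToAudioStream, CoquiEngine, OpenAIEngine, ElevenlabsEngine

def make_engine(name: str):
    if name == "coqui":
        return CoquiEngine()   # local model, no API key required
    if name == "openai":
        return OpenAIEngine()  # reads OPENAI_API_KEY from the environment
    if name == "elevenlabs":
        engine = ElevenlabsEngine(api_key=os.environ["ELEVEN_LABS_API_KEY"])
        engine.set_voice("Michael")  # named voice, as in the diff
        return engine
    raise ValueError(f"Unsupported TTS engine: {name}")

stream = TextToAudioStream(make_engine("openai"))
stream.feed("Hello from the async interpreter.")  # feed() also accepts generators
stream.play()  # blocking; the server uses play_async() with an on_audio_chunk callback
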
@@ -112,41 +102,11 @@ class AsyncInterpreter:
         # print("ADDING TO QUEUE:", chunk)
         asyncio.create_task(self._add_to_queue(self._output_queue, chunk))

-    async def run(self):
-        """
-        Runs OI on the audio bytes submitted to the input. Will add streaming LMC chunks to the _output_queue.
-        """
-        self.interpreter.messages = self.active_chat_messages
-        # self.beeper.start()
-        self.stt.stop()
-        # message = self.stt.text()
-        # print("THE MESSAGE:", message)
-
-        # accumulates the input queue message
-        input_queue = []
-        while not self._input_queue.empty():
-            input_queue.append(self._input_queue.get())
-        # print("INPUT QUEUE:", input_queue)
-        # message = [i for i in input_queue if i["type"] == "message"][0]["content"]
-
-        start_stt = time.time()
-        message = self.stt.text()
-        end_stt = time.time()
-        self.stt_latency = end_stt - start_stt
-        print("STT LATENCY", self.stt_latency)
-
-        # print(message)
-        end_interpreter = 0
-        # print(message)
-
-        def generate(message):
-            last_lmc_start_flag = self._last_lmc_start_flag
-            self.interpreter.messages = self.active_chat_messages
-            # print("🍀🍀🍀🍀GENERATING, using these messages: ", self.interpreter.messages)
-            # print("🍀 🍀 🍀 🍀 active_chat_messages: ", self.active_chat_messages)
-            print("message is", message)
+    def generate(self, message, start_interpreter):
+        last_lmc_start_flag = self._last_lmc_start_flag
+        self.interpreter.messages = self.active_chat_messages
+
+        # print("message is", message)

         for chunk in self.interpreter.chat(message, display=True, stream=True):
@@ -165,7 +125,7 @@ class AsyncInterpreter:
             # Experimental: The AI voice sounds better with replacements like these, but it should happen at the TTS layer
             # content = content.replace(". ", ". ... ").replace(", ", ", ... ").replace("!", "! ... ").replace("?", "? ... ")
-            # print("yielding this", content)
+            # print("yielding ", content)
             yield content

         # Handle code blocks
@@ -196,27 +156,42 @@ class AsyncInterpreter:
         )

         # Send a completion signal
-        end_interpreter = time.time()
-        self.interpreter_latency = end_interpreter - start_interpreter
-        print("INTERPRETER LATENCY", self.interpreter_latency)
+        # end_interpreter = time.time()
+        # self.interpreter_latency = end_interpreter - start_interpreter
+        # print("INTERPRETER LATENCY", self.interpreter_latency)
         # self.add_to_output_queue_sync({"role": "server","type": "completion", "content": "DONE"})

+    async def run(self):
+        """
+        Runs OI on the audio bytes submitted to the input. Will add streaming LMC chunks to the _output_queue.
+        """
+        self.interpreter.messages = self.active_chat_messages
+        self.stt.stop()
+
+        input_queue = []
+        while not self._input_queue.empty():
+            input_queue.append(self._input_queue.get())
+
+        # start_stt = time.time()
+        message = self.stt.text()
+        # end_stt = time.time()
+        # self.stt_latency = end_stt - start_stt
+        # print("STT LATENCY", self.stt_latency)
+
+        # print(message)
+
         # Feed generate to RealtimeTTS
         self.add_to_output_queue_sync(
             {"role": "assistant", "type": "audio", "format": "bytes.wav", "start": True}
         )
         start_interpreter = time.time()
-        text_iterator = generate(message)
+        text_iterator = self.generate(message, start_interpreter)

         self.tts.feed(text_iterator)

         self.tts.play_async(on_audio_chunk=self.on_tts_chunk, muted=True)
-        while True:
-            if self.tts.is_playing():
-                start_tts = time.time()
-                break
-            await asyncio.sleep(0.1)

         while True:
             await asyncio.sleep(0.1)
             # print("is_playing", self.tts.is_playing())
@@ -229,14 +204,14 @@ class AsyncInterpreter:
                         "end": True,
                     }
                 )
-                end_tts = time.time()
-                self.tts_latency = end_tts - start_tts
-                print("TTS LATENCY", self.tts_latency)
+                # end_tts = time.time()
+                # self.tts_latency = end_tts - self.tts.stream_start_time
+                # print("TTS LATENCY", self.tts_latency)
                 self.tts.stop()
                 break

     async def _on_tts_chunk_async(self, chunk):
-        # print("SENDING TTS CHUNK")
+        # print("adding chunk to queue")
         await self._add_to_queue(self._output_queue, chunk)

     def on_tts_chunk(self, chunk):
@@ -244,4 +219,5 @@ class AsyncInterpreter:
         asyncio.run(self._on_tts_chunk_async(chunk))

     async def output(self):
+        # print("outputting chunks")
         return await self._output_queue.get()
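
Note: RealtimeTTS invokes on_tts_chunk from its own playback thread, and the code above hops each chunk back into asyncio with asyncio.run(). A minimal sketch of the same bridge using run_coroutine_threadsafe, a common alternative when a loop is already running (illustrative names, not part of the commit):

import asyncio

output_queue: "asyncio.Queue[bytes]" = asyncio.Queue()

def on_tts_chunk(chunk: bytes, loop: asyncio.AbstractEventLoop) -> None:
    # Called from the TTS worker thread: schedule the put() onto the running
    # event loop instead of spinning up a fresh loop per chunk.
    asyncio.run_coroutine_threadsafe(output_queue.put(chunk), loop)

async def output() -> bytes:
    # Mirrors AsyncInterpreter.output(): consumers await chunks as they arrive.
    return await output_queue.get()
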

@ -1,20 +1,23 @@
import asyncio import asyncio
import traceback import traceback
import json import json
from fastapi import FastAPI, WebSocket, Header from fastapi import FastAPI, WebSocket
from fastapi.responses import PlainTextResponse from fastapi.responses import PlainTextResponse
from uvicorn import Config, Server from uvicorn import Config, Server
from .i import configure_interpreter
from interpreter import interpreter as base_interpreter from interpreter import interpreter as base_interpreter
from .async_interpreter import AsyncInterpreter from .async_interpreter import AsyncInterpreter
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from typing import List, Dict, Any from typing import List, Dict, Any
from openai import OpenAI
from pydantic import BaseModel
import argparse
import os import os
# import sentry_sdk
os.environ["STT_RUNNER"] = "server"
os.environ["TTS_RUNNER"] = "server"
async def main(server_host, server_port, tts_service, asynchronous):
if asynchronous:
base_interpreter.system_message = ( base_interpreter.system_message = (
"You are a helpful assistant that can answer questions and help with tasks." "You are a helpful assistant that can answer questions and help with tasks."
) )
@@ -23,21 +26,13 @@ base_interpreter.llm.model = "groq/llama3-8b-8192"
         base_interpreter.llm.api_key = os.environ["GROQ_API_KEY"]
         base_interpreter.llm.supports_functions = False
         base_interpreter.auto_run = True
+        base_interpreter.tts = tts_service
-
-os.environ["STT_RUNNER"] = "server"
-os.environ["TTS_RUNNER"] = "server"
-
-# Parse command line arguments for port number
-"""
-parser = argparse.ArgumentParser(description="FastAPI server.")
-parser.add_argument("--port", type=int, default=8000, help="Port to run on.")
-args = parser.parse_args()
-"""
-
-base_interpreter.tts = "coqui"
-
-
-async def main(server_host, server_port):
         interpreter = AsyncInterpreter(base_interpreter)
+    else:
+        configured_interpreter = configure_interpreter(base_interpreter)
+        configured_interpreter.llm.supports_functions = True
+        configured_interpreter.tts = tts_service
+        interpreter = AsyncInterpreter(configured_interpreter)

     app = FastAPI()
@@ -107,37 +102,6 @@ async def main(server_host, server_port):
     server = Server(config)
     await server.serve()

-class Rename(BaseModel):
-    input: str
-
-@app.post("/rename-chat")
-async def rename_chat(body_content: Rename, x_api_key: str = Header(None)):
-    print("RENAME CHAT REQUEST in PY 🌙🌙🌙🌙")
-    input_value = body_content.input
-    client = OpenAI(
-        # defaults to os.environ.get("OPENAI_API_KEY")
-        api_key=x_api_key,
-    )
-    try:
-        response = client.chat.completions.create(
-            model="gpt-3.5-turbo",
-            messages=[
-                {
-                    "role": "user",
-                    "content": f"Given the following chat snippet, create a unique and descriptive title in less than 8 words. Your answer must not be related to customer service.\n\n{input_value}",
-                }
-            ],
-            temperature=0.3,
-            stream=False,
-        )
-        print(response)
-        completion = response["choices"][0]["message"]["content"]
-        return {"data": {"content": completion}}
-    except Exception as e:
-        print(f"Error: {e}")
-        traceback.print_exc()
-        return {"error": str(e)}

 if __name__ == "__main__":
     asyncio.run(main())

@@ -6,6 +6,8 @@ import os
 import importlib
 from source.server.tunnel import create_tunnel
 from source.server.async_server import main
+
+# from source.server.server import main
 from source.server.utils.local_mode import select_local_model

 import signal
@ -63,7 +65,7 @@ def run(
0.8, "--temperature", help="Specify the temperature for generation" 0.8, "--temperature", help="Specify the temperature for generation"
), ),
tts_service: str = typer.Option( tts_service: str = typer.Option(
"openai", "--tts-service", help="Specify the TTS service" "elevenlabs", "--tts-service", help="Specify the TTS service"
), ),
stt_service: str = typer.Option( stt_service: str = typer.Option(
"openai", "--stt-service", help="Specify the STT service" "openai", "--stt-service", help="Specify the STT service"
@ -75,6 +77,9 @@ def run(
mobile: bool = typer.Option( mobile: bool = typer.Option(
False, "--mobile", help="Toggle server to support mobile app" False, "--mobile", help="Toggle server to support mobile app"
), ),
asynchronous: bool = typer.Option(
False, "--async", help="use interpreter optimized for latency"
),
): ):
_run( _run(
server=server or mobile, server=server or mobile,
@ -97,6 +102,7 @@ def run(
local=local, local=local,
qr=qr or mobile, qr=qr or mobile,
mobile=mobile, mobile=mobile,
asynchronous=asynchronous,
) )
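
Note: the new --async flag is plain typer plumbing, a boolean option threaded from run() through _run() into the server's main(). A self-contained sketch of the pattern (illustrative, not part of the commit; the app object and echo body are hypothetical), assuming typer is installed:

import typer

app = typer.Typer()

@app.command()
def run(
    tts_service: str = typer.Option(
        "elevenlabs", "--tts-service", help="Specify the TTS service"
    ),
    asynchronous: bool = typer.Option(
        False, "--async", help="use interpreter optimized for latency"
    ),
):
    # The real CLI forwards these values to _run(...) and on to the server's main(...).
    typer.echo(f"tts_service={tts_service} asynchronous={asynchronous}")

if __name__ == "__main__":
    app()
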
@@ -116,14 +122,15 @@ def _run(
     context_window: int = 2048,
     max_tokens: int = 4096,
     temperature: float = 0.8,
-    tts_service: str = "openai",
+    tts_service: str = "elevenlabs",
     stt_service: str = "openai",
     local: bool = False,
     qr: bool = False,
     mobile: bool = False,
+    asynchronous: bool = False,
 ):
     if local:
-        tts_service = "piper"
+        tts_service = "coqui"
         # llm_service = "llamafile"
         stt_service = "local-whisper"
         select_local_model()
@@ -154,6 +161,8 @@ def _run(
             main(
                 server_host,
                 server_port,
+                tts_service,
+                asynchronous,
                 # llm_service,
                 # model,
                 # llm_supports_vision,
@@ -161,7 +170,6 @@ def _run(
                 # context_window,
                 # max_tokens,
                 # temperature,
-                # tts_service,
                 # stt_service,
                 # mobile,
             ),
@ -180,7 +188,6 @@ def _run(
system_type = platform.system() system_type = platform.system()
if system_type == "Darwin": # Mac OS if system_type == "Darwin": # Mac OS
client_type = "mac" client_type = "mac"
print("initiating mac device with base device!!!")
elif system_type == "Windows": # Windows System elif system_type == "Windows": # Windows System
client_type = "windows" client_type = "windows"
elif system_type == "Linux": # Linux System elif system_type == "Linux": # Linux System
@@ -196,9 +203,10 @@ def _run(
         module = importlib.import_module(
             f".clients.{client_type}.device", package="source"
         )
-        # server_url = "0.0.0.0:8000"
-        client_thread = threading.Thread(target=module.main, args=[server_url])
-        print("client thread started")
+        client_thread = threading.Thread(
+            target=module.main, args=[server_url, tts_service]
+        )
         client_thread.start()

     try:
