add realtime tts streaming

pull/279/head
Ben Xu 7 months ago
parent 9e04e2c5de
commit 72f7d140d4

software/poetry.lock (generated)

File diff suppressed because it is too large.

@@ -33,19 +33,20 @@ python-crontab = "^3.0.0"
 inquirer = "^3.2.4"
 pyqrcode = "^1.2.1"
 realtimestt = "^0.1.12"
-realtimetts = "^0.3.44"
+realtimetts = "^0.4.1"
 keyboard = "^0.13.5"
 pyautogui = "^0.9.54"
 ctranslate2 = "4.1.0"
 py3-tts = "^3.5"
-elevenlabs = "0.2.27"
+elevenlabs = "1.2.2"
 groq = "^0.5.0"
-open-interpreter = "^0.2.5"
+open-interpreter = "^0.2.6"
 litellm = "1.35.35"
-openai = "1.13.3"
+openai = "1.30.5"
 pywebview = "*"
 pyobjc = "*"
+sentry-sdk = "^2.4.0"

 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"

@@ -2,6 +2,7 @@ from dotenv import load_dotenv

 load_dotenv()  # take environment variables from .env.

+import subprocess
 import os
 import sys
 import asyncio
@@ -46,7 +47,7 @@ accumulator = Accumulator()
 CHUNK = 1024  # Record in chunks of 1024 samples
 FORMAT = pyaudio.paInt16  # 16 bits per sample
 CHANNELS = 1  # Mono
-RATE = 44100  # Sample rate
+RATE = 16000  # Sample rate
 RECORDING = False  # Flag to control recording state
 SPACEBAR_PRESSED = False  # Flag to track spacebar press state
@@ -86,10 +87,10 @@ class Device:
     def __init__(self):
         self.pressed_keys = set()
         self.captured_images = []
-        self.audiosegments = []
+        self.audiosegments = asyncio.Queue()
         self.server_url = ""
         self.ctrl_pressed = False
-        # self.latency = None
+        self.playback_latency = None

     def fetch_image_from_camera(self, camera_index=CAMERA_DEVICE_INDEX):
         """Captures an image from the specified camera device and saves it to a temporary file. Adds the image to the captured_images list."""
@@ -153,14 +154,26 @@ class Device:
         """Plays them sequentially."""
         while True:
             try:
-                for audio in self.audiosegments:
-                    # if self.latency:
-                    #     elapsed_time = time.time() - self.latency
-                    #     print(f"Time from request to playback: {elapsed_time} seconds")
-                    #     self.latency = None
-                    play(audio)
-                    self.audiosegments.remove(audio)
-                await asyncio.sleep(0.1)
+                audio = await self.audiosegments.get()
+                # print("got audio segment!!!!")
+                if self.playback_latency:
+                    elapsed_time = time.time() - self.playback_latency
+                    print(f"Time from request to playback: {elapsed_time} seconds")
+                    self.playback_latency = None
+
+                args = ["ffplay", "-autoexit", "-", "-nodisp"]
+                proc = subprocess.Popen(
+                    args=args,
+                    stdout=subprocess.PIPE,
+                    stdin=subprocess.PIPE,
+                    stderr=subprocess.PIPE,
+                )
+                out, err = proc.communicate(input=audio)
+                proc.poll()
+                # play(audio)
+                # self.audiosegments.remove(audio)
+                # await asyncio.sleep(0.1)
             except asyncio.exceptions.CancelledError:
                 # This happens once at the start?
                 pass
@@ -208,7 +221,7 @@ class Device:
         stream.stop_stream()
         stream.close()
         print("Recording stopped.")
-        # self.latency = time.time()
+        self.playback_latency = time.time()

         duration = wav_file.getnframes() / RATE
         if duration < 0.3:
@@ -315,18 +328,21 @@ class Device:

     async def message_sender(self, websocket):
         while True:
-            message = await asyncio.get_event_loop().run_in_executor(
-                None, send_queue.get
-            )
-            if isinstance(message, bytes):
-                await websocket.send(message)
-            else:
-                await websocket.send(json.dumps(message))
-            send_queue.task_done()
-            await asyncio.sleep(0.01)
+            try:
+                message = await asyncio.get_event_loop().run_in_executor(
+                    None, send_queue.get
+                )
+                if isinstance(message, bytes):
+                    await websocket.send(message)
+                else:
+                    await websocket.send(json.dumps(message))
+                send_queue.task_done()
+                await asyncio.sleep(0.01)
+            except:
+                traceback.print_exc()

     async def websocket_communication(self, WS_URL):
         print("websocket communication was called!!!!")
@@ -343,7 +359,7 @@ class Device:
             asyncio.create_task(self.message_sender(websocket))

             while True:
-                await asyncio.sleep(0.01)
+                await asyncio.sleep(0.0001)
                 chunk = await websocket.recv()

                 logger.debug(f"Got this message from the server: {type(chunk)} {chunk}")
@@ -351,31 +367,38 @@ class Device:
                 if type(chunk) == str:
                     chunk = json.loads(chunk)

-                message = accumulator.accumulate(chunk)
+                # message = accumulator.accumulate(chunk)
+                message = chunk
                 if message == None:
                     # Will be None until we have a full message ready
                     continue

                 # At this point, we have our message
+                # print("checkpoint reached!", message)

-                if message["type"] == "audio" and message["format"].startswith("bytes"):
+                if isinstance(message, bytes):
+                    # if message["type"] == "audio" and message["format"].startswith("bytes"):
                     # Convert bytes to audio file
-                    audio_bytes = message["content"]
+                    # audio_bytes = message["content"]
+                    audio_bytes = message

                     # Create an AudioSegment instance with the raw data
+                    """
                     audio = AudioSegment(
                         # raw audio data (bytes)
                         data=audio_bytes,
                         # signed 16-bit little-endian format
                         sample_width=2,
-                        # 16,000 Hz frame rate
-                        frame_rate=16000,
+                        # 24,000 Hz frame rate
+                        frame_rate=24000,
                         # mono sound
                         channels=1,
                     )
+                    """

-                    self.audiosegments.append(audio)
+                    # print("audio segment was created")
+                    await self.audiosegments.put(audio_bytes)

                 # Run the code if that's the client's job
                 if os.getenv("CODE_RUNNER") == "client":
@@ -399,6 +422,7 @@ class Device:
         while True:
             try:
                 async with websockets.connect(WS_URL) as websocket:
+                    print("awaiting exec_ws_communication")
                     await exec_ws_communication(websocket)
             except:
                 logger.info(traceback.format_exc())
@@ -410,7 +434,7 @@ class Device:
     async def start_async(self):
         print("start async was called!!!!!")
         # Configuration for WebSocket
-        WS_URL = f"ws://{self.server_url}"
+        WS_URL = f"ws://{self.server_url}/ws"
         # Start the WebSocket communication
         asyncio.create_task(self.websocket_communication(WS_URL))
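
Note on the playback change above: the client now pulls audio off an asyncio.Queue and pipes each chunk to ffplay instead of decoding it with pydub. A minimal sketch of that consumer loop, assuming each queue item is a complete WAV-encoded chunk and that ffplay is on PATH (the function name here is illustrative, not the project's API):

import asyncio
import subprocess
import time

async def play_audio(queue: asyncio.Queue, playback_latency=None):
    """Consume WAV chunks from the queue and pipe each one to ffplay."""
    while True:
        audio_bytes = await queue.get()
        if playback_latency:
            print(f"Time from request to playback: {time.time() - playback_latency} seconds")
            playback_latency = None
        # "-" reads from stdin, "-autoexit" quits at end of input, "-nodisp" hides the video window.
        proc = subprocess.Popen(
            ["ffplay", "-autoexit", "-nodisp", "-"],
            stdin=subprocess.PIPE,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        proc.communicate(input=audio_bytes)  # blocks until playback finishes

One caveat: proc.communicate() blocks the event loop for the duration of playback; wrapping it in run_in_executor (or switching to asyncio.create_subprocess_exec) would keep the websocket loop responsive while audio plays.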

@@ -12,7 +12,14 @@
 ###

 from pynput import keyboard
-from RealtimeTTS import TextToAudioStream, OpenAIEngine, CoquiEngine
+from RealtimeTTS import (
+    TextToAudioStream,
+    OpenAIEngine,
+    CoquiEngine,
+    ElevenlabsEngine,
+    SystemEngine,
+    GTTSEngine,
+)
 from RealtimeSTT import AudioToTextRecorder
 import time
 import asyncio
@@ -21,11 +28,14 @@ import json

 class AsyncInterpreter:
     def __init__(self, interpreter):
+        self.stt_latency = None
+        self.tts_latency = None
+        self.interpreter_latency = None
         self.interpreter = interpreter

         # STT
         self.stt = AudioToTextRecorder(
-            model="tiny", spinner=False, use_microphone=False
+            model="tiny.en", spinner=False, use_microphone=False
         )
         self.stt.stop()  # It needs this for some reason
@@ -34,6 +44,16 @@ class AsyncInterpreter:
             engine = CoquiEngine()
         elif self.interpreter.tts == "openai":
             engine = OpenAIEngine()
+        elif self.interpreter.tts == "gtts":
+            engine = GTTSEngine()
+        elif self.interpreter.tts == "elevenlabs":
+            engine = ElevenlabsEngine(
+                api_key="sk_077cb1cabdf67e62b85f8782e66e5d8e11f78b450c7ce171"
+            )
+        elif self.interpreter.tts == "system":
+            engine = SystemEngine()
+        else:
+            raise ValueError(f"Unsupported TTS engine: {self.interpreter.tts}")
         self.tts = TextToAudioStream(engine)

         self.active_chat_messages = []
@@ -112,7 +132,11 @@ class AsyncInterpreter:
         # print("INPUT QUEUE:", input_queue)
         # message = [i for i in input_queue if i["type"] == "message"][0]["content"]
+        start_stt = time.time()
         message = self.stt.text()
+        end_stt = time.time()
+        self.stt_latency = end_stt - start_stt
+        print("STT LATENCY", self.stt_latency)

         # print(message)
@@ -141,7 +165,7 @@ class AsyncInterpreter:
                     # Experimental: The AI voice sounds better with replacements like these, but it should happen at the TTS layer
                     # content = content.replace(". ", ". ... ").replace(", ", ", ... ").replace("!", "! ... ").replace("?", "? ... ")
+                    print("yielding this", content)
                     yield content

                 # Handle code blocks
@@ -172,17 +196,24 @@ class AsyncInterpreter:
         )

         # Send a completion signal
+        end_interpreter = time.time()
+        self.interpreter_latency = end_interpreter - start_interpreter
+        print("INTERPRETER LATENCY", self.interpreter_latency)
         # self.add_to_output_queue_sync({"role": "server","type": "completion", "content": "DONE"})

         # Feed generate to RealtimeTTS
         self.add_to_output_queue_sync(
             {"role": "assistant", "type": "audio", "format": "bytes.wav", "start": True}
         )
-        self.tts.feed(generate(message))
+        start_interpreter = time.time()
+        text_iterator = generate(message)
+        self.tts.feed(text_iterator)
         self.tts.play_async(on_audio_chunk=self.on_tts_chunk, muted=True)

         while True:
             if self.tts.is_playing():
+                start_tts = time.time()
                 break
             await asyncio.sleep(0.1)

         while True:
@@ -197,6 +228,9 @@ class AsyncInterpreter:
                         "end": True,
                     }
                 )
+                end_tts = time.time()
+                self.tts_latency = end_tts - start_tts
+                print("TTS LATENCY", self.tts_latency)
                 break

     async def _on_tts_chunk_async(self, chunk):
@@ -204,6 +238,7 @@ class AsyncInterpreter:
         await self._add_to_queue(self._output_queue, chunk)

     def on_tts_chunk(self, chunk):
+        # print("ye")
         asyncio.run(self._on_tts_chunk_async(chunk))

     async def output(self):
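
Note on the engine-selection block above: it hardcodes an ElevenLabs API key in source. A minimal sketch of the same selection logic with the key read from an environment variable instead (ELEVENLABS_API_KEY is an assumed variable name, not something this diff defines; the engine classes are the ones the diff imports from RealtimeTTS):

import os
from RealtimeTTS import (
    TextToAudioStream,
    OpenAIEngine,
    CoquiEngine,
    ElevenlabsEngine,
    SystemEngine,
    GTTSEngine,
)

def build_tts_stream(tts_name: str) -> TextToAudioStream:
    """Pick a RealtimeTTS engine by name; secrets come from the environment."""
    if tts_name == "coqui":
        engine = CoquiEngine()
    elif tts_name == "openai":
        engine = OpenAIEngine()
    elif tts_name == "gtts":
        engine = GTTSEngine()
    elif tts_name == "elevenlabs":
        # Assumed env var; avoids committing the key to the repo.
        engine = ElevenlabsEngine(api_key=os.environ["ELEVENLABS_API_KEY"])
    elif tts_name == "system":
        engine = SystemEngine()
    else:
        raise ValueError(f"Unsupported TTS engine: {tts_name}")
    return TextToAudioStream(engine)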

@@ -12,6 +12,18 @@ from pydantic import BaseModel
 import argparse
 import os

+# import sentry_sdk
+
+base_interpreter.system_message = (
+    "You are a helpful assistant that can answer questions and help with tasks."
+)
+base_interpreter.computer.import_computer_api = False
+base_interpreter.llm.model = "groq/mixtral-8x7b-32768"
+base_interpreter.llm.api_key = (
+    "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"
+)
+base_interpreter.llm.supports_functions = False
+
 os.environ["STT_RUNNER"] = "server"
 os.environ["TTS_RUNNER"] = "server"
@@ -20,11 +32,24 @@ parser = argparse.ArgumentParser(description="FastAPI server.")
 parser.add_argument("--port", type=int, default=8000, help="Port to run on.")
 args = parser.parse_args()

-base_interpreter.tts = "openai"
-base_interpreter.llm.model = "gpt-4-turbo"
+base_interpreter.tts = "elevenlabs"


 async def main():
+    """
+    sentry_sdk.init(
+        dsn="https://a1465f62a31c7dfb23e1616da86341e9@o4506046614667264.ingest.us.sentry.io/4507374662385664",
+        enable_tracing=True,
+        # Set traces_sample_rate to 1.0 to capture 100%
+        # of transactions for performance monitoring.
+        traces_sample_rate=1.0,
+        # Set profiles_sample_rate to 1.0 to profile 100%
+        # of sampled transactions.
+        # We recommend adjusting this value in production.
+        profiles_sample_rate=1.0,
+    )
+    """
     interpreter = AsyncInterpreter(base_interpreter)

     app = FastAPI()
@@ -51,6 +76,9 @@ async def main():
         async def receive_input():
             while True:
+                if websocket.client_state == "DISCONNECTED":
+                    break
+
                 data = await websocket.receive()

                 if isinstance(data, bytes):
@@ -65,19 +93,23 @@ async def main():
         async def send_output():
             while True:
                 output = await interpreter.output()

                 if isinstance(output, bytes):
+                    # print(f"Sending {len(output)} bytes of audio data.")
                     await websocket.send_bytes(output)
                     # we dont send out bytes rn, no TTS
-                    pass
                 elif isinstance(output, dict):
+                    # print("sending text")
                     await websocket.send_text(json.dumps(output))

-        await asyncio.gather(receive_input(), send_output())
+        await asyncio.gather(send_output(), receive_input())
     except Exception as e:
         print(f"WebSocket connection closed with exception: {e}")
         traceback.print_exc()
     finally:
-        await websocket.close()
+        if not websocket.client_state == "DISCONNECTED":
+            await websocket.close()

 config = Config(app, host="0.0.0.0", port=8000, lifespan="on")
 server = Server(config)
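
Note on the disconnect guards above: in FastAPI/Starlette, websocket.client_state is a WebSocketState enum, so comparing it to the string "DISCONNECTED" never matches. A minimal, self-contained sketch of the same two concurrent loops using the enum; the module-level queue stands in for interpreter.output(), which the diff wires up elsewhere, and the input handling is deliberately stubbed:

import asyncio
import json

from fastapi import FastAPI, WebSocket
from starlette.websockets import WebSocketState

app = FastAPI()
output_queue: asyncio.Queue = asyncio.Queue()  # stand-in for interpreter.output()

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()

    async def receive_input():
        # Compare against the enum member, not the string "DISCONNECTED".
        while websocket.client_state != WebSocketState.DISCONNECTED:
            data = await websocket.receive()  # Starlette message dict
            # ... hand `data` to the interpreter here (omitted in this sketch) ...

    async def send_output():
        while True:
            output = await output_queue.get()
            if isinstance(output, bytes):
                await websocket.send_bytes(output)
            elif isinstance(output, dict):
                await websocket.send_text(json.dumps(output))

    try:
        # Run both loops concurrently, mirroring the diff's asyncio.gather call.
        await asyncio.gather(send_output(), receive_input())
    finally:
        if websocket.client_state != WebSocketState.DISCONNECTED:
            await websocket.close()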

@@ -23,6 +23,7 @@ from .utils.logs import logger
 import base64
 import shutil
 from ..utils.print_markdown import print_markdown
+import time

 os.environ["STT_RUNNER"] = "server"
 os.environ["TTS_RUNNER"] = "server"
@@ -383,6 +384,7 @@ async def stream_tts_to_device(sentence, mobile: bool):

 def stream_tts(sentence, mobile: bool):
     audio_file = tts(sentence, mobile)

     # Read the entire WAV file

@@ -5,7 +5,7 @@ import threading
 import os
 import importlib
 from source.server.tunnel import create_tunnel
-from source.server.server import main
+from source.server.async_server import main
 from source.server.utils.local_mode import select_local_model
 import signal
@@ -152,18 +152,18 @@ def _run(
         target=loop.run_until_complete,
         args=(
             main(
-                server_host,
-                server_port,
-                llm_service,
-                model,
-                llm_supports_vision,
-                llm_supports_functions,
-                context_window,
-                max_tokens,
-                temperature,
-                tts_service,
-                stt_service,
-                mobile,
+                # server_host,
+                # server_port,
+                # llm_service,
+                # model,
+                # llm_supports_vision,
+                # llm_supports_functions,
+                # context_window,
+                # max_tokens,
+                # temperature,
+                # tts_service,
+                # stt_service,
+                # mobile,
             ),
         ),
     )
@@ -196,7 +196,7 @@ def _run(
         module = importlib.import_module(
             f".clients.{client_type}.device", package="source"
         )
+        server_url = "0.0.0.0:8000"
         client_thread = threading.Thread(target=module.main, args=[server_url])
         print("client thread started")
         client_thread.start()
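
Note on the hardcoded server_url above: a small sketch of deriving it from the CLI options that this hunk comments out. Reusing server_host/server_port this way is an assumption about intent, not what the commit does, and the helper name is illustrative:

import threading

def start_client(module, server_host=None, server_port=None):
    # Fall back to the commit's hardcoded default when no host/port is given.
    server_url = f"{server_host}:{server_port}" if server_host and server_port else "0.0.0.0:8000"
    client_thread = threading.Thread(target=module.main, args=[server_url])
    print("client thread started")
    client_thread.start()
    return client_thread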
