From eee00ac0268e66fd497a37887a457c1f109a6cf1 Mon Sep 17 00:00:00 2001 From: Ben Xu Date: Tue, 18 Jun 2024 05:47:12 -0700 Subject: [PATCH] add async interpreter with coqui, openai, elevenlabs tts --- software/source/clients/base_device.py | 119 ++++++------- software/source/clients/linux/device.py | 3 +- software/source/clients/mac/device.py | 3 +- software/source/clients/windows/device.py | 3 +- software/source/server/async_interpreter.py | 176 +++++++++----------- software/source/server/async_server.py | 74 +++----- software/start.py | 24 ++- 7 files changed, 165 insertions(+), 237 deletions(-) diff --git a/software/source/clients/base_device.py b/software/source/clients/base_device.py index b713601..c5ac73f 100644 --- a/software/source/clients/base_device.py +++ b/software/source/clients/base_device.py @@ -90,6 +90,7 @@ class Device: self.audiosegments = asyncio.Queue() self.server_url = "" self.ctrl_pressed = False + self.tts_service = "" self.playback_latency = None def fetch_image_from_camera(self, camera_index=CAMERA_DEVICE_INDEX): @@ -164,30 +165,18 @@ class Device: while True: try: audio = await self.audiosegments.get() - # print("got audio segment!!!!") - if self.playback_latency: + if self.playback_latency and isinstance(audio, bytes): elapsed_time = time.time() - self.playback_latency - print(f"Time from request to playback: {elapsed_time} seconds") + # print(f"Time from request to playback: {elapsed_time} seconds") self.playback_latency = None - if audio is not None: + if self.tts_service == "elevenlabs": mpv_process.stdin.write(audio) # type: ignore mpv_process.stdin.flush() # type: ignore - """ - args = ["ffplay", "-autoexit", "-", "-nodisp"] - proc = subprocess.Popen( - args=args, - stdout=subprocess.PIPE, - stdin=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - out, err = proc.communicate(input=audio) - proc.poll() + else: + play(audio) - play(audio) - """ - # self.audiosegments.remove(audio) - # await asyncio.sleep(0.1) + await asyncio.sleep(0.1) except asyncio.exceptions.CancelledError: # This happens once at the start? pass @@ -342,24 +331,17 @@ class Device: async def message_sender(self, websocket): while True: - try: - message = await asyncio.get_event_loop().run_in_executor( - None, send_queue.get - ) - - if isinstance(message, bytes): - await websocket.send(message) - - else: - await websocket.send(json.dumps(message)) - - send_queue.task_done() - await asyncio.sleep(0.01) - except: - traceback.print_exc() + message = await asyncio.get_event_loop().run_in_executor( + None, send_queue.get + ) + if isinstance(message, bytes): + await websocket.send(message) + else: + await websocket.send(json.dumps(message)) + send_queue.task_done() + await asyncio.sleep(0.01) async def websocket_communication(self, WS_URL): - print("websocket communication was called!!!!") show_connection_log = True async def exec_ws_communication(websocket): @@ -373,48 +355,48 @@ class Device: asyncio.create_task(self.message_sender(websocket)) while True: - await asyncio.sleep(0.0001) + await asyncio.sleep(0.01) chunk = await websocket.recv() logger.debug(f"Got this message from the server: {type(chunk)} {chunk}") - # print((f"Got this message from the server: {type(chunk)} {chunk}")) + # print("received chunk from server") + if type(chunk) == str: chunk = json.loads(chunk) - # message = accumulator.accumulate(chunk) - message = chunk + if self.tts_service == "elevenlabs": + message = chunk + else: + message = accumulator.accumulate(chunk) + if message == None: # Will be None until we have a full message ready continue # At this point, we have our message - # print("checkpoint reached!", message) - if isinstance(message, bytes): - - # if message["type"] == "audio" and message["format"].startswith("bytes"): + if isinstance(message, bytes) or ( + message["type"] == "audio" and message["format"].startswith("bytes") + ): # Convert bytes to audio file - - # audio_bytes = message["content"] - audio_bytes = message - - # Create an AudioSegment instance with the raw data - """ - audio = AudioSegment( - # raw audio data (bytes) - data=audio_bytes, - # signed 16-bit little-endian format - sample_width=2, - # 24,000 Hz frame rate - frame_rate=16000, - # mono sound - channels=1, - ) - """ - - # print("audio segment was created") - await self.audiosegments.put(audio_bytes) - - # await self.audiosegments.put(audio) + if self.tts_service == "elevenlabs": + audio_bytes = message + audio = audio_bytes + else: + audio_bytes = message["content"] + + # Create an AudioSegment instance with the raw data + audio = AudioSegment( + # raw audio data (bytes) + data=audio_bytes, + # signed 16-bit little-endian format + sample_width=2, + # 16,000 Hz frame rate + frame_rate=22050, + # mono sound + channels=1, + ) + + await self.audiosegments.put(audio) # Run the code if that's the client's job if os.getenv("CODE_RUNNER") == "client": @@ -434,29 +416,26 @@ class Device: except Exception as e: logger.error(f"Error while attempting to connect: {e}") else: - print("websocket url is", WS_URL) while True: try: async with websockets.connect(WS_URL) as websocket: - print("awaiting exec_ws_communication") await exec_ws_communication(websocket) except: - logger.info(traceback.format_exc()) + logger.debug(traceback.format_exc()) if show_connection_log: logger.info(f"Connecting to `{WS_URL}`...") show_connection_log = False await asyncio.sleep(2) async def start_async(self): - print("start async was called!!!!!") # Configuration for WebSocket WS_URL = f"ws://{self.server_url}" - # Start the WebSocket communication asyncio.create_task(self.websocket_communication(WS_URL)) # Start watching the kernel if it's your job to do that if os.getenv("CODE_RUNNER") == "client": + # client is not running code! asyncio.create_task(put_kernel_messages_into_queue(send_queue)) asyncio.create_task(self.play_audiosegments()) @@ -488,10 +467,8 @@ class Device: on_press=self.on_press, on_release=self.on_release ) listener.start() - print("listener for keyboard started!!!!!") def start(self): - print("device was started!!!!!!") if os.getenv("TEACH_MODE") != "True": asyncio.run(self.start_async()) p.terminate() diff --git a/software/source/clients/linux/device.py b/software/source/clients/linux/device.py index 0fa0fed..cf549d5 100644 --- a/software/source/clients/linux/device.py +++ b/software/source/clients/linux/device.py @@ -3,8 +3,9 @@ from ..base_device import Device device = Device() -def main(server_url): +def main(server_url, tts_service): device.server_url = server_url + device.tts_service = tts_service device.start() diff --git a/software/source/clients/mac/device.py b/software/source/clients/mac/device.py index 0fa0fed..cf549d5 100644 --- a/software/source/clients/mac/device.py +++ b/software/source/clients/mac/device.py @@ -3,8 +3,9 @@ from ..base_device import Device device = Device() -def main(server_url): +def main(server_url, tts_service): device.server_url = server_url + device.tts_service = tts_service device.start() diff --git a/software/source/clients/windows/device.py b/software/source/clients/windows/device.py index 0fa0fed..cf549d5 100644 --- a/software/source/clients/windows/device.py +++ b/software/source/clients/windows/device.py @@ -3,8 +3,9 @@ from ..base_device import Device device = Device() -def main(server_url): +def main(server_url, tts_service): device.server_url = server_url + device.tts_service = tts_service device.start() diff --git a/software/source/server/async_interpreter.py b/software/source/server/async_interpreter.py index 209ff73..4fee556 100644 --- a/software/source/server/async_interpreter.py +++ b/software/source/server/async_interpreter.py @@ -10,16 +10,9 @@ """ ### - from pynput import keyboard -from RealtimeTTS import ( - TextToAudioStream, - OpenAIEngine, - CoquiEngine, - ElevenlabsEngine, - SystemEngine, - GTTSEngine, -) + +from RealtimeTTS import TextToAudioStream, CoquiEngine, OpenAIEngine, ElevenlabsEngine from RealtimeSTT import AudioToTextRecorder import time import asyncio @@ -29,9 +22,9 @@ import os class AsyncInterpreter: def __init__(self, interpreter): - self.stt_latency = None - self.tts_latency = None - self.interpreter_latency = None + # self.stt_latency = None + # self.tts_latency = None + # self.interpreter_latency = None self.interpreter = interpreter # STT @@ -45,12 +38,9 @@ class AsyncInterpreter: engine = CoquiEngine() elif self.interpreter.tts == "openai": engine = OpenAIEngine() - elif self.interpreter.tts == "gtts": - engine = GTTSEngine() elif self.interpreter.tts == "elevenlabs": engine = ElevenlabsEngine(api_key=os.environ["ELEVEN_LABS_API_KEY"]) - elif self.interpreter.tts == "system": - engine = SystemEngine() + engine.set_voice("Michael") else: raise ValueError(f"Unsupported TTS engine: {self.interpreter.tts}") self.tts = TextToAudioStream(engine) @@ -112,111 +102,96 @@ class AsyncInterpreter: # print("ADDING TO QUEUE:", chunk) asyncio.create_task(self._add_to_queue(self._output_queue, chunk)) + def generate(self, message, start_interpreter): + last_lmc_start_flag = self._last_lmc_start_flag + self.interpreter.messages = self.active_chat_messages + + # print("message is", message) + + for chunk in self.interpreter.chat(message, display=True, stream=True): + + if self._last_lmc_start_flag != last_lmc_start_flag: + # self.beeper.stop() + break + + # self.add_to_output_queue_sync(chunk) # To send text, not just audio + + content = chunk.get("content") + + # Handle message blocks + if chunk.get("type") == "message": + if content: + # self.beeper.stop() + + # Experimental: The AI voice sounds better with replacements like these, but it should happen at the TTS layer + # content = content.replace(". ", ". ... ").replace(", ", ", ... ").replace("!", "! ... ").replace("?", "? ... ") + # print("yielding ", content) + yield content + + # Handle code blocks + elif chunk.get("type") == "code": + if "start" in chunk: + # self.beeper.start() + pass + + # Experimental: If the AI wants to type, we should type immediatly + if ( + self.interpreter.messages[-1] + .get("content", "") + .startswith("computer.keyboard.write(") + ): + keyboard.controller.type(content) + self._in_keyboard_write_block = True + if "end" in chunk and self._in_keyboard_write_block: + self._in_keyboard_write_block = False + # (This will make it so it doesn't type twice when the block executes) + if self.interpreter.messages[-1]["content"].startswith( + "computer.keyboard.write(" + ): + self.interpreter.messages[-1]["content"] = ( + "dummy_variable = (" + + self.interpreter.messages[-1]["content"][ + len("computer.keyboard.write(") : + ] + ) + + # Send a completion signal + # end_interpreter = time.time() + # self.interpreter_latency = end_interpreter - start_interpreter + # print("INTERPRETER LATENCY", self.interpreter_latency) + # self.add_to_output_queue_sync({"role": "server","type": "completion", "content": "DONE"}) + async def run(self): """ Runs OI on the audio bytes submitted to the input. Will add streaming LMC chunks to the _output_queue. """ self.interpreter.messages = self.active_chat_messages - # self.beeper.start() - self.stt.stop() - # message = self.stt.text() - # print("THE MESSAGE:", message) - # accumulates the input queue message input_queue = [] while not self._input_queue.empty(): input_queue.append(self._input_queue.get()) - # print("INPUT QUEUE:", input_queue) - # message = [i for i in input_queue if i["type"] == "message"][0]["content"] - start_stt = time.time() + # start_stt = time.time() message = self.stt.text() - end_stt = time.time() - self.stt_latency = end_stt - start_stt - print("STT LATENCY", self.stt_latency) - - # print(message) - end_interpreter = 0 + # end_stt = time.time() + # self.stt_latency = end_stt - start_stt + # print("STT LATENCY", self.stt_latency) # print(message) - def generate(message): - last_lmc_start_flag = self._last_lmc_start_flag - self.interpreter.messages = self.active_chat_messages - # print("πŸ€πŸ€πŸ€πŸ€GENERATING, using these messages: ", self.interpreter.messages) - # print("πŸ€ πŸ€ πŸ€ πŸ€ active_chat_messages: ", self.active_chat_messages) - print("message is", message) - - for chunk in self.interpreter.chat(message, display=True, stream=True): - - if self._last_lmc_start_flag != last_lmc_start_flag: - # self.beeper.stop() - break - - # self.add_to_output_queue_sync(chunk) # To send text, not just audio - - content = chunk.get("content") - - # Handle message blocks - if chunk.get("type") == "message": - if content: - # self.beeper.stop() - - # Experimental: The AI voice sounds better with replacements like these, but it should happen at the TTS layer - # content = content.replace(". ", ". ... ").replace(", ", ", ... ").replace("!", "! ... ").replace("?", "? ... ") - # print("yielding this", content) - yield content - - # Handle code blocks - elif chunk.get("type") == "code": - if "start" in chunk: - # self.beeper.start() - pass - - # Experimental: If the AI wants to type, we should type immediatly - if ( - self.interpreter.messages[-1] - .get("content", "") - .startswith("computer.keyboard.write(") - ): - keyboard.controller.type(content) - self._in_keyboard_write_block = True - if "end" in chunk and self._in_keyboard_write_block: - self._in_keyboard_write_block = False - # (This will make it so it doesn't type twice when the block executes) - if self.interpreter.messages[-1]["content"].startswith( - "computer.keyboard.write(" - ): - self.interpreter.messages[-1]["content"] = ( - "dummy_variable = (" - + self.interpreter.messages[-1]["content"][ - len("computer.keyboard.write(") : - ] - ) - - # Send a completion signal - end_interpreter = time.time() - self.interpreter_latency = end_interpreter - start_interpreter - print("INTERPRETER LATENCY", self.interpreter_latency) - # self.add_to_output_queue_sync({"role": "server","type": "completion", "content": "DONE"}) # Feed generate to RealtimeTTS self.add_to_output_queue_sync( {"role": "assistant", "type": "audio", "format": "bytes.wav", "start": True} ) start_interpreter = time.time() - text_iterator = generate(message) + text_iterator = self.generate(message, start_interpreter) self.tts.feed(text_iterator) - self.tts.play_async(on_audio_chunk=self.on_tts_chunk, muted=True) - while True: - if self.tts.is_playing(): - start_tts = time.time() + self.tts.play_async(on_audio_chunk=self.on_tts_chunk, muted=True) - break - await asyncio.sleep(0.1) while True: await asyncio.sleep(0.1) # print("is_playing", self.tts.is_playing()) @@ -229,14 +204,14 @@ class AsyncInterpreter: "end": True, } ) - end_tts = time.time() - self.tts_latency = end_tts - start_tts - print("TTS LATENCY", self.tts_latency) + # end_tts = time.time() + # self.tts_latency = end_tts - self.tts.stream_start_time + # print("TTS LATENCY", self.tts_latency) self.tts.stop() break async def _on_tts_chunk_async(self, chunk): - # print("SENDING TTS CHUNK") + # print("adding chunk to queue") await self._add_to_queue(self._output_queue, chunk) def on_tts_chunk(self, chunk): @@ -244,4 +219,5 @@ class AsyncInterpreter: asyncio.run(self._on_tts_chunk_async(chunk)) async def output(self): + # print("outputting chunks") return await self._output_queue.get() diff --git a/software/source/server/async_server.py b/software/source/server/async_server.py index ace4b4a..d52db12 100644 --- a/software/source/server/async_server.py +++ b/software/source/server/async_server.py @@ -1,43 +1,38 @@ import asyncio import traceback import json -from fastapi import FastAPI, WebSocket, Header +from fastapi import FastAPI, WebSocket from fastapi.responses import PlainTextResponse from uvicorn import Config, Server +from .i import configure_interpreter from interpreter import interpreter as base_interpreter from .async_interpreter import AsyncInterpreter from fastapi.middleware.cors import CORSMiddleware from typing import List, Dict, Any -from openai import OpenAI -from pydantic import BaseModel -import argparse import os -# import sentry_sdk - -base_interpreter.system_message = ( - "You are a helpful assistant that can answer questions and help with tasks." -) -base_interpreter.computer.import_computer_api = False -base_interpreter.llm.model = "groq/llama3-8b-8192" -base_interpreter.llm.api_key = os.environ["GROQ_API_KEY"] -base_interpreter.llm.supports_functions = False -base_interpreter.auto_run = True os.environ["STT_RUNNER"] = "server" os.environ["TTS_RUNNER"] = "server" -# Parse command line arguments for port number -""" -parser = argparse.ArgumentParser(description="FastAPI server.") -parser.add_argument("--port", type=int, default=8000, help="Port to run on.") -args = parser.parse_args() -""" -base_interpreter.tts = "coqui" - -async def main(server_host, server_port): - interpreter = AsyncInterpreter(base_interpreter) +async def main(server_host, server_port, tts_service, asynchronous): + if asynchronous: + base_interpreter.system_message = ( + "You are a helpful assistant that can answer questions and help with tasks." + ) + base_interpreter.computer.import_computer_api = False + base_interpreter.llm.model = "groq/llama3-8b-8192" + base_interpreter.llm.api_key = os.environ["GROQ_API_KEY"] + base_interpreter.llm.supports_functions = False + base_interpreter.auto_run = True + base_interpreter.tts = tts_service + interpreter = AsyncInterpreter(base_interpreter) + else: + configured_interpreter = configure_interpreter(base_interpreter) + configured_interpreter.llm.supports_functions = True + configured_interpreter.tts = tts_service + interpreter = AsyncInterpreter(configured_interpreter) app = FastAPI() @@ -107,37 +102,6 @@ async def main(server_host, server_port): server = Server(config) await server.serve() - class Rename(BaseModel): - input: str - - @app.post("/rename-chat") - async def rename_chat(body_content: Rename, x_api_key: str = Header(None)): - print("RENAME CHAT REQUEST in PY πŸŒ™πŸŒ™πŸŒ™πŸŒ™") - input_value = body_content.input - client = OpenAI( - # defaults to os.environ.get("OPENAI_API_KEY") - api_key=x_api_key, - ) - try: - response = client.chat.completions.create( - model="gpt-3.5-turbo", - messages=[ - { - "role": "user", - "content": f"Given the following chat snippet, create a unique and descriptive title in less than 8 words. Your answer must not be related to customer service.\n\n{input_value}", - } - ], - temperature=0.3, - stream=False, - ) - print(response) - completion = response["choices"][0]["message"]["content"] - return {"data": {"content": completion}} - except Exception as e: - print(f"Error: {e}") - traceback.print_exc() - return {"error": str(e)} - if __name__ == "__main__": asyncio.run(main()) diff --git a/software/start.py b/software/start.py index 7c5186e..6db7ebb 100644 --- a/software/start.py +++ b/software/start.py @@ -6,6 +6,8 @@ import os import importlib from source.server.tunnel import create_tunnel from source.server.async_server import main + +# from source.server.server import main from source.server.utils.local_mode import select_local_model import signal @@ -63,7 +65,7 @@ def run( 0.8, "--temperature", help="Specify the temperature for generation" ), tts_service: str = typer.Option( - "openai", "--tts-service", help="Specify the TTS service" + "elevenlabs", "--tts-service", help="Specify the TTS service" ), stt_service: str = typer.Option( "openai", "--stt-service", help="Specify the STT service" @@ -75,6 +77,9 @@ def run( mobile: bool = typer.Option( False, "--mobile", help="Toggle server to support mobile app" ), + asynchronous: bool = typer.Option( + False, "--async", help="use interpreter optimized for latency" + ), ): _run( server=server or mobile, @@ -97,6 +102,7 @@ def run( local=local, qr=qr or mobile, mobile=mobile, + asynchronous=asynchronous, ) @@ -116,14 +122,15 @@ def _run( context_window: int = 2048, max_tokens: int = 4096, temperature: float = 0.8, - tts_service: str = "openai", + tts_service: str = "elevenlabs", stt_service: str = "openai", local: bool = False, qr: bool = False, mobile: bool = False, + asynchronous: bool = False, ): if local: - tts_service = "piper" + tts_service = "coqui" # llm_service = "llamafile" stt_service = "local-whisper" select_local_model() @@ -154,6 +161,8 @@ def _run( main( server_host, server_port, + tts_service, + asynchronous, # llm_service, # model, # llm_supports_vision, @@ -161,7 +170,6 @@ def _run( # context_window, # max_tokens, # temperature, - # tts_service, # stt_service, # mobile, ), @@ -180,7 +188,6 @@ def _run( system_type = platform.system() if system_type == "Darwin": # Mac OS client_type = "mac" - print("initiating mac device with base device!!!") elif system_type == "Windows": # Windows System client_type = "windows" elif system_type == "Linux": # Linux System @@ -196,9 +203,10 @@ def _run( module = importlib.import_module( f".clients.{client_type}.device", package="source" ) - # server_url = "0.0.0.0:8000" - client_thread = threading.Thread(target=module.main, args=[server_url]) - print("client thread started") + + client_thread = threading.Thread( + target=module.main, args=[server_url, tts_service] + ) client_thread.start() try: