Merge pull request #284 from benxu3/async-interpreter

add --debug flag
pull/247/merge
killian 6 months ago committed by GitHub
commit dbb920b27f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -129,7 +129,7 @@ If you want to run local speech-to-text using Whisper, you must install Rust. Fo
To customize the behavior of the system, edit the [system message, model, skills library path,](https://docs.openinterpreter.com/settings/all-settings) etc. in the `profiles` directory under the `server` directory. This file sets up an interpreter, and is powered by Open Interpreter. To customize the behavior of the system, edit the [system message, model, skills library path,](https://docs.openinterpreter.com/settings/all-settings) etc. in the `profiles` directory under the `server` directory. This file sets up an interpreter, and is powered by Open Interpreter.
To specify the text-to-speech service for the 01 `base_device.py`, set `interpreter.tts` to either "openai" for OpenAI, "elevenlabs" for ElevenLabs, or "coqui" for Coqui (local) in a profile. For the 01 Light, set `SPEAKER_SAMPLE_RATE` to 24000 for Coqui (local) or 22050 for OpenAI TTS. We currently don't support ElevenLabs TTS on the 01 Light. To specify the text-to-speech service for the 01 `base_device.py`, set `interpreter.tts` to either "openai" for OpenAI, "elevenlabs" for ElevenLabs, or "coqui" for Coqui (local) in a profile. For the 01 Light, set `SPEAKER_SAMPLE_RATE` in `client.ino` under the `esp32` client directory to 24000 for Coqui (local) or 22050 for OpenAI TTS. We currently don't support ElevenLabs TTS on the 01 Light.
## Ubuntu Dependencies ## Ubuntu Dependencies

@ -91,6 +91,8 @@ class Device:
self.server_url = "" self.server_url = ""
self.ctrl_pressed = False self.ctrl_pressed = False
self.tts_service = "" self.tts_service = ""
self.debug = False
self.playback_latency = None
def fetch_image_from_camera(self, camera_index=CAMERA_DEVICE_INDEX): def fetch_image_from_camera(self, camera_index=CAMERA_DEVICE_INDEX):
"""Captures an image from the specified camera device and saves it to a temporary file. Adds the image to the captured_images list.""" """Captures an image from the specified camera device and saves it to a temporary file. Adds the image to the captured_images list."""
@ -164,6 +166,10 @@ class Device:
while True: while True:
try: try:
audio = await self.audiosegments.get() audio = await self.audiosegments.get()
if self.debug and self.playback_latency and isinstance(audio, bytes):
elapsed_time = time.time() - self.playback_latency
print(f"Time from request to playback: {elapsed_time} seconds")
self.playback_latency = None
if self.tts_service == "elevenlabs": if self.tts_service == "elevenlabs":
mpv_process.stdin.write(audio) # type: ignore mpv_process.stdin.write(audio) # type: ignore
@ -219,6 +225,8 @@ class Device:
stream.stop_stream() stream.stop_stream()
stream.close() stream.close()
print("Recording stopped.") print("Recording stopped.")
if self.debug:
self.playback_latency = time.time()
duration = wav_file.getnframes() / RATE duration = wav_file.getnframes() / RATE
if duration < 0.3: if duration < 0.3:

@ -3,8 +3,9 @@ from ..base_device import Device
device = Device() device = Device()
def main(server_url): def main(server_url, debug):
device.server_url = server_url device.server_url = server_url
device.debug = debug
device.start() device.start()

@ -3,8 +3,9 @@ from ..base_device import Device
device = Device() device = Device()
def main(server_url): def main(server_url, debug):
device.server_url = server_url device.server_url = server_url
device.debug = debug
device.start() device.start()

@ -3,8 +3,9 @@ from ..base_device import Device
device = Device() device = Device()
def main(server_url): def main(server_url, debug):
device.server_url = server_url device.server_url = server_url
device.debug = debug
device.start() device.start()

@ -21,7 +21,14 @@ import os
class AsyncInterpreter: class AsyncInterpreter:
def __init__(self, interpreter): def __init__(self, interpreter, debug):
self.stt_latency = None
self.tts_latency = None
self.interpreter_latency = None
# time from first put to first yield
self.tffytfp = None
self.debug = debug
self.interpreter = interpreter self.interpreter = interpreter
self.audio_chunks = [] self.audio_chunks = []
@ -126,6 +133,8 @@ class AsyncInterpreter:
# Experimental: The AI voice sounds better with replacements like these, but it should happen at the TTS layer # Experimental: The AI voice sounds better with replacements like these, but it should happen at the TTS layer
# content = content.replace(". ", ". ... ").replace(", ", ", ... ").replace("!", "! ... ").replace("?", "? ... ") # content = content.replace(". ", ". ... ").replace(", ", ", ... ").replace("!", "! ... ").replace("?", "? ... ")
# print("yielding ", content) # print("yielding ", content)
if self.tffytfp is None:
self.tffytfp = time.time()
yield content yield content
@ -157,6 +166,10 @@ class AsyncInterpreter:
) )
# Send a completion signal # Send a completion signal
if self.debug:
end_interpreter = time.time()
self.interpreter_latency = end_interpreter - start_interpreter
print("INTERPRETER LATENCY", self.interpreter_latency)
# self.add_to_output_queue_sync({"role": "server","type": "completion", "content": "DONE"}) # self.add_to_output_queue_sync({"role": "server","type": "completion", "content": "DONE"})
async def run(self): async def run(self):
@ -171,13 +184,20 @@ class AsyncInterpreter:
while not self._input_queue.empty(): while not self._input_queue.empty():
input_queue.append(self._input_queue.get()) input_queue.append(self._input_queue.get())
if self.debug:
start_stt = time.time()
message = self.stt.text() message = self.stt.text()
end_stt = time.time()
self.stt_latency = end_stt - start_stt
print("STT LATENCY", self.stt_latency)
if self.audio_chunks: if self.audio_chunks:
audio_bytes = bytearray(b"".join(self.audio_chunks)) audio_bytes = bytearray(b"".join(self.audio_chunks))
wav_file_path = bytes_to_wav(audio_bytes, "audio/raw") wav_file_path = bytes_to_wav(audio_bytes, "audio/raw")
print("wav_file_path ", wav_file_path) print("wav_file_path ", wav_file_path)
self.audio_chunks = [] self.audio_chunks = []
else:
message = self.stt.text()
print(message) print(message)
@ -204,11 +224,22 @@ class AsyncInterpreter:
"end": True, "end": True,
} }
) )
if self.debug:
end_tts = time.time()
self.tts_latency = end_tts - self.tts.stream_start_time
print("TTS LATENCY", self.tts_latency)
self.tts.stop() self.tts.stop()
break break
async def _on_tts_chunk_async(self, chunk): async def _on_tts_chunk_async(self, chunk):
# print("adding chunk to queue") # print("adding chunk to queue")
if self.debug and self.tffytfp is not None and self.tffytfp != 0:
print(
"time from first yield to first put is ",
time.time() - self.tffytfp,
)
self.tffytfp = 0
await self._add_to_queue(self._output_queue, chunk) await self._add_to_queue(self._output_queue, chunk)
def on_tts_chunk(self, chunk): def on_tts_chunk(self, chunk):

@ -1,30 +1,18 @@
# import from the profiles directory the interpreter to be served
# add other profiles to the directory to define other interpreter instances and import them here
# {.profiles.fast: optimizes for STT/TTS latency with the fastest models }
# {.profiles.local: uses local models and local STT/TTS }
# {.profiles.default: uses default interpreter settings with optimized TTS latency }
# from .profiles.fast import interpreter as base_interpreter
# from .profiles.local import interpreter as base_interpreter
from .profiles.default import interpreter as base_interpreter
import asyncio import asyncio
import traceback import traceback
import json import json
from fastapi import FastAPI, WebSocket from fastapi import FastAPI, WebSocket, Depends
from fastapi.responses import PlainTextResponse from fastapi.responses import PlainTextResponse
from uvicorn import Config, Server from uvicorn import Config, Server
from .async_interpreter import AsyncInterpreter from .async_interpreter import AsyncInterpreter
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from typing import List, Dict, Any from typing import List, Dict, Any
import os import os
import importlib.util
os.environ["STT_RUNNER"] = "server" os.environ["STT_RUNNER"] = "server"
os.environ["TTS_RUNNER"] = "server" os.environ["TTS_RUNNER"] = "server"
# interpreter.tts set in the profiles directory!!!!
interpreter = AsyncInterpreter(base_interpreter)
app = FastAPI() app = FastAPI()
@ -37,13 +25,19 @@ app.add_middleware(
) )
async def get_debug_flag():
return app.state.debug
@app.get("/ping") @app.get("/ping")
async def ping(): async def ping():
return PlainTextResponse("pong") return PlainTextResponse("pong")
@app.websocket("/") @app.websocket("/")
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(
websocket: WebSocket, debug: bool = Depends(get_debug_flag)
):
await websocket.accept() await websocket.accept()
# Send the tts_service value to the client # Send the tts_service value to the client
@ -91,7 +85,25 @@ async def websocket_endpoint(websocket: WebSocket):
await websocket.close() await websocket.close()
async def main(server_host, server_port): async def main(server_host, server_port, profile, debug):
app.state.debug = debug
# Load the profile module from the provided path
spec = importlib.util.spec_from_file_location("profile", profile)
profile_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(profile_module)
# Get the interpreter from the profile
interpreter = profile_module.interpreter
if not hasattr(interpreter, 'tts'):
print("Setting TTS provider to default: openai")
interpreter.tts = "openai"
# Make it async
interpreter = AsyncInterpreter(interpreter, debug)
print(f"Starting server on {server_host}:{server_port}") print(f"Starting server on {server_host}:{server_port}")
config = Config(app, host=server_host, port=server_port, lifespan="on") config = Config(app, host=server_host, port=server_port, lifespan="on")
server = Server(config) server = Server(config)

@ -6,7 +6,7 @@ from ..utils.print_markdown import print_markdown
def create_tunnel( def create_tunnel(
tunnel_method="ngrok", server_host="localhost", server_port=10001, qr=False tunnel_method="ngrok", server_host="localhost", server_port=10001, qr=False, domain=None
): ):
print_markdown("Exposing server to the internet...") print_markdown("Exposing server to the internet...")
@ -99,8 +99,13 @@ def create_tunnel(
# If ngrok is installed, start it on the specified port # If ngrok is installed, start it on the specified port
# process = subprocess.Popen(f'ngrok http {server_port} --log=stdout', shell=True, stdout=subprocess.PIPE) # process = subprocess.Popen(f'ngrok http {server_port} --log=stdout', shell=True, stdout=subprocess.PIPE)
if domain:
domain = f"--domain={domain}"
else:
domain = ""
process = subprocess.Popen( process = subprocess.Popen(
f"ngrok http {server_port} --scheme http,https --log=stdout", f"ngrok http {server_port} --scheme http,https {domain} --log=stdout",
shell=True, shell=True,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
) )

@ -6,6 +6,7 @@ import os
import importlib import importlib
from source.server.tunnel import create_tunnel from source.server.tunnel import create_tunnel
from source.server.async_server import main from source.server.async_server import main
import subprocess
import signal import signal
@ -41,6 +42,25 @@ def run(
qr: bool = typer.Option( qr: bool = typer.Option(
False, "--qr", help="Display QR code to scan to connect to the server" False, "--qr", help="Display QR code to scan to connect to the server"
), ),
domain: str = typer.Option(
None, "--domain", help="Connect ngrok to a custom domain"
),
profiles: bool = typer.Option(
False,
"--profiles",
help="Opens the folder where this script is contained",
),
profile: str = typer.Option(
"default.py", # default
"--profile",
help="Specify the path to the profile, or the name of the file if it's in the `profiles` directory (run `--profiles` to open the profiles directory)",
),
debug: bool = typer.Option(
False,
"--debug",
help="Print latency measurements and save microphone recordings locally for manual playback.",
),
): ):
_run( _run(
server=server, server=server,
@ -52,6 +72,10 @@ def run(
server_url=server_url, server_url=server_url,
client_type=client_type, client_type=client_type,
qr=qr, qr=qr,
debug=debug,
domain=domain,
profiles=profiles,
profile=profile,
) )
@ -65,8 +89,34 @@ def _run(
server_url: str = None, server_url: str = None,
client_type: str = "auto", client_type: str = "auto",
qr: bool = False, qr: bool = False,
debug: bool = False,
domain = None,
profiles = None,
profile = None,
): ):
profiles_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "source", "server", "profiles")
if profiles:
if platform.system() == "Windows":
subprocess.Popen(['explorer', profiles_dir])
elif platform.system() == "Darwin":
subprocess.Popen(['open', profiles_dir])
elif platform.system() == "Linux":
subprocess.Popen(['xdg-open', profiles_dir])
else:
subprocess.Popen(['open', profiles_dir])
exit(0)
if profile:
if not os.path.isfile(profile):
profile = os.path.join(profiles_dir, profile)
if not os.path.isfile(profile):
profile += ".py"
if not os.path.isfile(profile):
print(f"Invalid profile path: {profile}")
exit(1)
system_type = platform.system() system_type = platform.system()
if system_type == "Windows": if system_type == "Windows":
server_host = "localhost" server_host = "localhost"
@ -84,7 +134,6 @@ def _run(
signal.signal(signal.SIGINT, handle_exit) signal.signal(signal.SIGINT, handle_exit)
if server: if server:
# print(f"Starting server with mobile = {mobile}")
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
server_thread = threading.Thread( server_thread = threading.Thread(
@ -93,6 +142,8 @@ def _run(
main( main(
server_host, server_host,
server_port, server_port,
profile,
debug,
), ),
), ),
) )
@ -100,7 +151,7 @@ def _run(
if expose: if expose:
tunnel_thread = threading.Thread( tunnel_thread = threading.Thread(
target=create_tunnel, args=[tunnel_service, server_host, server_port, qr] target=create_tunnel, args=[tunnel_service, server_host, server_port, qr, domain]
) )
tunnel_thread.start() tunnel_thread.start()
@ -125,7 +176,7 @@ def _run(
f".clients.{client_type}.device", package="source" f".clients.{client_type}.device", package="source"
) )
client_thread = threading.Thread(target=module.main, args=[server_url]) client_thread = threading.Thread(target=module.main, args=[server_url, debug])
client_thread.start() client_thread.start()
try: try:

Loading…
Cancel
Save