add debug flag

pull/284/head
Ben Xu 7 months ago
parent ef48e9c8fb
commit 72f41ad760

@ -91,6 +91,8 @@ class Device:
self.server_url = "" self.server_url = ""
self.ctrl_pressed = False self.ctrl_pressed = False
self.tts_service = "" self.tts_service = ""
self.debug = False
self.playback_latency = None
def fetch_image_from_camera(self, camera_index=CAMERA_DEVICE_INDEX): def fetch_image_from_camera(self, camera_index=CAMERA_DEVICE_INDEX):
"""Captures an image from the specified camera device and saves it to a temporary file. Adds the image to the captured_images list.""" """Captures an image from the specified camera device and saves it to a temporary file. Adds the image to the captured_images list."""
@ -164,6 +166,10 @@ class Device:
while True: while True:
try: try:
audio = await self.audiosegments.get() audio = await self.audiosegments.get()
if self.debug and self.playback_latency and isinstance(audio, bytes):
elapsed_time = time.time() - self.playback_latency
print(f"Time from request to playback: {elapsed_time} seconds")
self.playback_latency = None
if self.tts_service == "elevenlabs": if self.tts_service == "elevenlabs":
mpv_process.stdin.write(audio) # type: ignore mpv_process.stdin.write(audio) # type: ignore
@ -219,6 +225,8 @@ class Device:
stream.stop_stream() stream.stop_stream()
stream.close() stream.close()
print("Recording stopped.") print("Recording stopped.")
if self.debug:
self.playback_latency = time.time()
duration = wav_file.getnframes() / RATE duration = wav_file.getnframes() / RATE
if duration < 0.3: if duration < 0.3:

@ -3,8 +3,9 @@ from ..base_device import Device
device = Device() device = Device()
def main(server_url): def main(server_url, debug):
device.server_url = server_url device.server_url = server_url
device.debug = debug
device.start() device.start()

@ -3,8 +3,9 @@ from ..base_device import Device
device = Device() device = Device()
def main(server_url): def main(server_url, debug):
device.server_url = server_url device.server_url = server_url
device.debug = debug
device.start() device.start()

@ -3,8 +3,9 @@ from ..base_device import Device
device = Device() device = Device()
def main(server_url): def main(server_url, debug):
device.server_url = server_url device.server_url = server_url
device.debug = debug
device.start() device.start()

@ -21,7 +21,13 @@ import os
class AsyncInterpreter: class AsyncInterpreter:
def __init__(self, interpreter): def __init__(self, interpreter, debug):
self.stt_latency = None
self.tts_latency = None
self.interpreter_latency = None
self.tffytfp = None
self.debug = debug
self.interpreter = interpreter self.interpreter = interpreter
self.audio_chunks = [] self.audio_chunks = []
@ -126,6 +132,8 @@ class AsyncInterpreter:
# Experimental: The AI voice sounds better with replacements like these, but it should happen at the TTS layer # Experimental: The AI voice sounds better with replacements like these, but it should happen at the TTS layer
# content = content.replace(". ", ". ... ").replace(", ", ", ... ").replace("!", "! ... ").replace("?", "? ... ") # content = content.replace(". ", ". ... ").replace(", ", ", ... ").replace("!", "! ... ").replace("?", "? ... ")
# print("yielding ", content) # print("yielding ", content)
if self.time_from_first_yield_to_first_put is None:
self.time_from_first_yield_to_first_put = time.time()
yield content yield content
@ -157,6 +165,10 @@ class AsyncInterpreter:
) )
# Send a completion signal # Send a completion signal
if self.debug:
end_interpreter = time.time()
self.interpreter_latency = end_interpreter - start_interpreter
print("INTERPRETER LATENCY", self.interpreter_latency)
# self.add_to_output_queue_sync({"role": "server","type": "completion", "content": "DONE"}) # self.add_to_output_queue_sync({"role": "server","type": "completion", "content": "DONE"})
async def run(self): async def run(self):
@ -171,13 +183,20 @@ class AsyncInterpreter:
while not self._input_queue.empty(): while not self._input_queue.empty():
input_queue.append(self._input_queue.get()) input_queue.append(self._input_queue.get())
if self.debug:
start_stt = time.time()
message = self.stt.text() message = self.stt.text()
end_stt = time.time()
self.stt_latency = end_stt - start_stt
print("STT LATENCY", self.stt_latency)
if self.audio_chunks: if self.audio_chunks:
audio_bytes = bytearray(b"".join(self.audio_chunks)) audio_bytes = bytearray(b"".join(self.audio_chunks))
wav_file_path = bytes_to_wav(audio_bytes, "audio/raw") wav_file_path = bytes_to_wav(audio_bytes, "audio/raw")
print("wav_file_path ", wav_file_path) print("wav_file_path ", wav_file_path)
self.audio_chunks = [] self.audio_chunks = []
else:
message = self.stt.text()
print(message) print(message)
@ -204,11 +223,22 @@ class AsyncInterpreter:
"end": True, "end": True,
} }
) )
if self.debug:
end_tts = time.time()
self.tts_latency = end_tts - self.tts.stream_start_time
print("TTS LATENCY", self.tts_latency)
self.tts.stop() self.tts.stop()
break break
async def _on_tts_chunk_async(self, chunk): async def _on_tts_chunk_async(self, chunk):
# print("adding chunk to queue") # print("adding chunk to queue")
if self.debug and self.tffytfp is not None and self.tffytfp != 0:
print(
"time from first yield to first put is ",
time.time() - self.tffytfp,
)
self.tffytfp = 0
await self._add_to_queue(self._output_queue, chunk) await self._add_to_queue(self._output_queue, chunk)
def on_tts_chunk(self, chunk): def on_tts_chunk(self, chunk):

@ -12,7 +12,7 @@ from .profiles.default import interpreter as base_interpreter
import asyncio import asyncio
import traceback import traceback
import json import json
from fastapi import FastAPI, WebSocket from fastapi import FastAPI, WebSocket, Depends
from fastapi.responses import PlainTextResponse from fastapi.responses import PlainTextResponse
from uvicorn import Config, Server from uvicorn import Config, Server
from .async_interpreter import AsyncInterpreter from .async_interpreter import AsyncInterpreter
@ -23,8 +23,6 @@ import os
os.environ["STT_RUNNER"] = "server" os.environ["STT_RUNNER"] = "server"
os.environ["TTS_RUNNER"] = "server" os.environ["TTS_RUNNER"] = "server"
# interpreter.tts set in the profiles directory!!!!
interpreter = AsyncInterpreter(base_interpreter)
app = FastAPI() app = FastAPI()
@ -37,15 +35,24 @@ app.add_middleware(
) )
async def get_debug_flag():
return app.state.debug
@app.get("/ping") @app.get("/ping")
async def ping(): async def ping():
return PlainTextResponse("pong") return PlainTextResponse("pong")
@app.websocket("/") @app.websocket("/")
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(
websocket: WebSocket, debug: bool = Depends(get_debug_flag)
):
await websocket.accept() await websocket.accept()
# interpreter.tts set in the profiles directory!!!!
interpreter = AsyncInterpreter(base_interpreter, debug)
# Send the tts_service value to the client # Send the tts_service value to the client
await websocket.send_text( await websocket.send_text(
json.dumps({"type": "config", "tts_service": interpreter.interpreter.tts}) json.dumps({"type": "config", "tts_service": interpreter.interpreter.tts})
@ -91,7 +98,9 @@ async def websocket_endpoint(websocket: WebSocket):
await websocket.close() await websocket.close()
async def main(server_host, server_port): async def main(server_host, server_port, debug):
app.state.debug = debug
print(f"Starting server on {server_host}:{server_port}") print(f"Starting server on {server_host}:{server_port}")
config = Config(app, host=server_host, port=server_port, lifespan="on") config = Config(app, host=server_host, port=server_port, lifespan="on")
server = Server(config) server = Server(config)

@ -41,6 +41,11 @@ def run(
qr: bool = typer.Option( qr: bool = typer.Option(
False, "--qr", help="Display QR code to scan to connect to the server" False, "--qr", help="Display QR code to scan to connect to the server"
), ),
debug: bool = typer.Option(
False,
"--debug",
help="Print latency measurements and save microphone recordings locally for manual playback.",
),
): ):
_run( _run(
server=server, server=server,
@ -52,6 +57,7 @@ def run(
server_url=server_url, server_url=server_url,
client_type=client_type, client_type=client_type,
qr=qr, qr=qr,
debug=debug,
) )
@ -65,6 +71,7 @@ def _run(
server_url: str = None, server_url: str = None,
client_type: str = "auto", client_type: str = "auto",
qr: bool = False, qr: bool = False,
debug: bool = False,
): ):
system_type = platform.system() system_type = platform.system()
@ -93,6 +100,7 @@ def _run(
main( main(
server_host, server_host,
server_port, server_port,
debug,
), ),
), ),
) )
@ -125,7 +133,7 @@ def _run(
f".clients.{client_type}.device", package="source" f".clients.{client_type}.device", package="source"
) )
client_thread = threading.Thread(target=module.main, args=[server_url]) client_thread = threading.Thread(target=module.main, args=[server_url, debug])
client_thread.start() client_thread.start()
try: try:

Loading…
Cancel
Save