@ -9,9 +9,9 @@ import traceback
import re
from fastapi import FastAPI
from fastapi . responses import PlainTextResponse
from starlette . websockets import WebSocket
from starlette . websockets import WebSocket , WebSocketDisconnect
from . stt . stt import stt_bytes
from . tts . tts import tts
from . tts . tts import stream_ tts
from pathlib import Path
import asyncio
import urllib . parse
@ -19,11 +19,13 @@ from .utils.kernel import put_kernel_messages_into_queue
from . i import configure_interpreter
from interpreter import interpreter
import ngrok
from . . utils . accumulator import Accumulator
from . utils . logs import setup_logging
from . utils . logs import logger
# Run the project's logging setup (imported from .utils.logs) at import time,
# before the module-level objects below are created.
setup_logging ( )
# Module-wide Accumulator instance. listener() passes each incoming chunk to
# accumulator.accumulate(), which returns None until a full message is ready.
accumulator = Accumulator ( )
# FastAPI application object for this module.
app = FastAPI ( )
@ -105,54 +107,89 @@ async def websocket_endpoint(websocket: WebSocket):
async def receive_messages(websocket: WebSocket):
    """Read frames from ``websocket`` and route them to the inbound queues.

    Text frames are parsed as JSON and dispatched on their ``"role"`` key:
    ``"computer"`` messages are put on ``from_computer`` and ``"user"``
    messages on ``from_user``. Text that is not valid JSON is dropped.
    Binary frames are put on ``from_user`` unchanged.

    The coroutine returns when ``websocket.receive()`` raises, or when a
    ``WebSocketDisconnect`` with code 1000 escapes the loop body.

    Args:
        websocket: The accepted WebSocket connection to read from.

    Raises:
        ValueError: A JSON message has a role other than "computer" or "user".
        KeyError: A JSON message has no "role" key.
        WebSocketDisconnect: The connection closed with a code other than 1000.
    """
    while True:
        try:
            try:
                frame = await websocket.receive()
            except Exception as e:
                # Best effort: any failure to receive ends this reader
                # quietly instead of crashing the endpoint.
                logger.info(f"Stopped receiving from websocket: {e}")
                return

            # Use .get() so a key that is present but None is treated the
            # same as a missing key.
            text = frame.get("text")
            payload = frame.get("bytes")

            if text is not None:
                try:
                    data = json.loads(text)
                except json.JSONDecodeError:
                    # Not JSON: nothing downstream can use it, so drop it.
                    continue

                role = data["role"]
                if role == "computer":
                    # To be handled by interpreter.computer.run
                    from_computer.put(data)
                elif role == "user":
                    await from_user.put(data)
                else:
                    # Raising a bare tuple is itself a TypeError; raise a
                    # real exception that names the offending message.
                    raise ValueError(f"Unknown role: {data!r}")
            elif payload is not None:
                # Binary data is forwarded untouched.
                await from_user.put(payload)
        except WebSocketDisconnect as e:
            if e.code == 1000:
                logger.info("Websocket connection closed normally.")
                return
            raise
async def send_messages(websocket: WebSocket):
    """Forward every message from the ``to_device`` queue over ``websocket``.

    ``dict`` messages are sent with ``send_json`` and ``bytes`` messages with
    ``send_bytes``. Each message is sent exactly once. If the send fails, the
    message is put back on ``to_device`` before the error propagates.

    Args:
        websocket: The accepted WebSocket connection to write to.

    Raises:
        TypeError: A queued message is neither a ``dict`` nor ``bytes``.
            Such a message is not re-queued, because it can never be sent.
    """
    while True:
        message = await to_device.get()
        logger.debug(f"Sending to the device: {type(message)} {message}")

        # Pick the sender before the try block so that an unsendable message
        # is rejected outright rather than being put back on the queue.
        if isinstance(message, dict):
            send = websocket.send_json
        elif isinstance(message, bytes):
            send = websocket.send_bytes
        else:
            raise TypeError("Message must be a dict or bytes")

        try:
            await send(message)
        except BaseException:
            # Deliberately broad (covers cancellation too): make sure to put
            # the message back in the queue if we failed to send it.
            # NOTE: put() appends, so the message goes to the back of the queue.
            await to_device.put(message)
            raise
async def listener ( ) :
audio_bytes = bytearray ( )
while True :
while True :
if not from_user . empty ( ) :
message = await from_user . get ( )
chunk = await from_user . get ( )
break
elif not from_computer . empty ( ) :
message = from_computer . get ( )
chunk = from_computer . get ( )
break
await asyncio . sleep ( 1 )
if type ( message ) == str :
message = json . loads ( message )
# Hold the audio in a buffer. If it's ready (we got end flag, stt it)
if message [ " type " ] == " audio " :
if " content " in message :
audio_bytes . extend ( bytes ( ast . literal_eval ( message [ " content " ] ) ) )
if " end " in message :
content = stt_bytes ( audio_bytes , message [ " format " ] )
if content == None : # If it was nothing / silence
continue
audio_bytes = bytearray ( )
message = { " role " : " user " , " type " : " message " , " content " : content }
else :
message = accumulator . accumulate ( chunk )
if message == None :
# Will be None until we have a full message ready
continue
# print(str(message)[:1000])
# At this point, we have our message
if message [ " type " ] == " audio " and message [ " format " ] . startswith ( " bytes " ) :
if not message [ " content " ] : # If it was nothing / silence
continue
# Ignore flags, we only needed them for audio ^
if " content " not in message or message [ " content " ] == None :
continue
# Convert bytes to audio file
# Format will be bytes.wav or bytes.opus
mime_type = " audio/ " + message [ " format " ] . split ( " . " ) [ 1 ]
text = stt_bytes ( message [ " content " ] , mime_type )
message = { " role " : " user " , " type " : " message " , " content " : text }
# At this point, we have only text messages
# Custom stop message will halt us
if message [ " content " ] . lower ( ) . strip ( " .,! " ) == " stop " :
if message [ " content " ] . lower ( ) . strip ( " .,! " ) == " stop " :
continue
# Load, append, and save conversation history
@ -173,19 +210,31 @@ async def listener():
# Yield to the event loop, so you actually send it out
await asyncio . sleep ( 0.01 )
# Speak full sentences out loud
if chunk [ " role " ] == " assistant " and " content " in chunk :
accumulated_text + = chunk [ " content " ]
sentences = split_into_sentences ( accumulated_text )
if is_full_sentence ( sentences [ - 1 ] ) :
for sentence in sentences :
await stream_or_play_tts ( sentence )
accumulated_text = " "
else :
for sentence in sentences [ : - 1 ] :
await stream_or_play_tts ( sentence )
accumulated_text = sentences [ - 1 ]
if os . getenv ( ' TTS_RUNNER ' ) == " server " :
# Speak full sentences out loud
if chunk [ " role " ] == " assistant " and " content " in chunk :
accumulated_text + = chunk [ " content " ]
sentences = split_into_sentences ( accumulated_text )
# If we're going to speak, say we're going to stop sending text.
# This should be fixed probably, we should be able to do both in parallel, or only one.
if any ( is_full_sentence ( sentence ) for sentence in sentences ) :
await to_device . put ( { " role " : " assistant " , " type " : " message " , " end " : True } )
if is_full_sentence ( sentences [ - 1 ] ) :
for sentence in sentences :
await stream_tts_to_device ( sentence )
accumulated_text = " "
else :
for sentence in sentences [ : - 1 ] :
await stream_tts_to_device ( sentence )
accumulated_text = sentences [ - 1 ]
# If we're going to speak, say we're going to stop sending text.
# This should be fixed probably, we should be able to do both in parallel, or only one.
if any ( is_full_sentence ( sentence ) for sentence in sentences ) :
await to_device . put ( { " role " : " assistant " , " type " : " message " , " start " : True } )
# If we have a new message, save our progress and go back to the top
if not from_user . empty ( ) :
@ -215,19 +264,12 @@ async def listener():
break
else :
with open ( conversation_history_path , ' w ' ) as file :
json . dump ( interpreter . messages , file , indent = 4 )
async def stream_or_play_tts(sentence):
    """Run text-to-speech for ``sentence``, on the server or for the device.

    When the ``TTS_RUNNER`` environment variable equals ``"server"``, this
    calls ``tts(sentence, play_audio=True)`` and queues nothing.

    Otherwise it puts three ``audio/mp3`` messages on ``to_device``, in order:
    a start marker, a content message holding ``str()`` of whatever
    ``tts(sentence, play_audio=False)`` returned, and an end marker. The end
    marker is queued even if synthesis raises, so every start marker is
    always followed by an end marker; the exception still propagates.

    Args:
        sentence: The text to synthesize.
    """
    if os.getenv("TTS_RUNNER") == "server":
        tts(sentence, play_audio=True)
        return

    def audio_message(**fields):
        # Common envelope shared by the start / content / end messages.
        return {"role": "assistant", "type": "audio", "format": "audio/mp3", **fields}

    await to_device.put(audio_message(start=True))
    try:
        audio_bytes = tts(sentence, play_audio=False)
        await to_device.put(audio_message(content=str(audio_bytes)))
    finally:
        # Always close the audio message that the start marker opened.
        await to_device.put(audio_message(end=True))
json . dump ( interpreter . messages , file , indent = 4 )
async def stream_tts_to_device(sentence):
    """Put each item yielded by ``stream_tts(sentence)`` on ``to_device``.

    Items are queued one at a time, in the order ``stream_tts`` yields them.

    Args:
        sentence: The text handed to ``stream_tts``.
    """
    for tts_chunk in stream_tts(sentence):
        await to_device.put(tts_chunk)
async def setup_ngrok ( ngrok_auth_token , parsed_url ) :
# Set up Ngrok
logger . info ( " Setting up Ngrok " )