Better LMC accumulator logic, queued audio messages, proper audio streaming via bytes

pull/34/head^2
killian 11 months ago
parent cb628fa314
commit d4629c017c

@ -22,6 +22,7 @@ import wave
import tempfile import tempfile
from datetime import datetime from datetime import datetime
from interpreter import interpreter # Just for code execution. Maybe we should let people do from interpreter.computer import run? from interpreter import interpreter # Just for code execution. Maybe we should let people do from interpreter.computer import run?
# In the future, I guess kernel watching code should be elsewhere? Somewhere server / client agnostic?
from ..server.utils.kernel import put_kernel_messages_into_queue from ..server.utils.kernel import put_kernel_messages_into_queue
from ..server.utils.get_system_info import get_system_info from ..server.utils.get_system_info import get_system_info
from ..server.stt.stt import stt_wav from ..server.stt.stt import stt_wav
@ -30,6 +31,11 @@ from ..server.utils.logs import setup_logging
from ..server.utils.logs import logger from ..server.utils.logs import logger
setup_logging() setup_logging()
from ..utils.accumulator import Accumulator
accumulator = Accumulator()
# Configuration for Audio Recording # Configuration for Audio Recording
CHUNK = 1024 # Record in chunks of 1024 samples CHUNK = 1024 # Record in chunks of 1024 samples
FORMAT = pyaudio.paInt16 # 16 bits per sample FORMAT = pyaudio.paInt16 # 16 bits per sample
@ -44,19 +50,30 @@ current_platform = get_system_info()
# Initialize PyAudio # Initialize PyAudio
p = pyaudio.PyAudio() p = pyaudio.PyAudio()
import asyncio
send_queue = queue.Queue() send_queue = queue.Queue()
class Device: class Device:
def __init__(self): def __init__(self):
self.audiosegments = []
pass pass
async def play_audiosegments(self):
"""Plays them sequentially."""
while True:
try:
for audio in self.audiosegments:
play(audio)
self.audiosegments.remove(audio)
await asyncio.sleep(0.1)
except:
traceback.print_exc()
def record_audio(self): def record_audio(self):
if os.getenv('STT_RUNNER') == "server": if os.getenv('STT_RUNNER') == "server":
# STT will happen on the server. we're sending audio. # STT will happen on the server. we're sending audio.
send_queue.put({"role": "user", "type": "audio", "format": "audio/wav", "start": True}) send_queue.put({"role": "user", "type": "audio", "format": "bytes.wav", "start": True})
elif os.getenv('STT_RUNNER') == "client": elif os.getenv('STT_RUNNER') == "client":
# STT will happen here, on the client. we're sending text. # STT will happen here, on the client. we're sending text.
send_queue.put({"role": "user", "type": "message", "start": True}) send_queue.put({"role": "user", "type": "message", "start": True})
@ -92,8 +109,8 @@ class Device:
send_queue.put({"role": "user", "type": "message", "content": "stop"}) send_queue.put({"role": "user", "type": "message", "content": "stop"})
send_queue.put({"role": "user", "type": "message", "end": True}) send_queue.put({"role": "user", "type": "message", "end": True})
else: else:
send_queue.put({"role": "user", "type": "audio", "format": "audio/wav", "content": ""}) send_queue.put({"role": "user", "type": "audio", "format": "bytes.wav", "content": ""})
send_queue.put({"role": "user", "type": "audio", "format": "audio/wav", "end": True}) send_queue.put({"role": "user", "type": "audio", "format": "bytes.wav", "end": True})
else: else:
if os.getenv('STT_RUNNER') == "client": if os.getenv('STT_RUNNER') == "client":
# Run stt then send text # Run stt then send text
@ -105,9 +122,9 @@ class Device:
with open(wav_path, 'rb') as audio_file: with open(wav_path, 'rb') as audio_file:
byte_data = audio_file.read(CHUNK) byte_data = audio_file.read(CHUNK)
while byte_data: while byte_data:
send_queue.put({"role": "user", "type": "audio", "format": "audio/wav", "content": str(byte_data)}) send_queue.put(byte_data)
byte_data = audio_file.read(CHUNK) byte_data = audio_file.read(CHUNK)
send_queue.put({"role": "user", "type": "audio", "format": "audio/wav", "end": True}) send_queue.put({"role": "user", "type": "audio", "format": "bytes.wav", "end": True})
if os.path.exists(wav_path): if os.path.exists(wav_path):
os.remove(wav_path) os.remove(wav_path)
@ -140,8 +157,12 @@ class Device:
async def message_sender(self, websocket): async def message_sender(self, websocket):
while True: while True:
message = await asyncio.get_event_loop().run_in_executor(None, send_queue.get) message = await asyncio.get_event_loop().run_in_executor(None, send_queue.get)
await websocket.send(json.dumps(message)) if isinstance(message, bytes):
await websocket.send(message)
else:
await websocket.send(json.dumps(message))
send_queue.task_done() send_queue.task_done()
await asyncio.sleep(0.01)
async def websocket_communication(self, WS_URL): async def websocket_communication(self, WS_URL):
while True: while True:
@ -150,52 +171,42 @@ class Device:
logger.info("Press the spacebar to start/stop recording. Press ESC to exit.") logger.info("Press the spacebar to start/stop recording. Press ESC to exit.")
asyncio.create_task(self.message_sender(websocket)) asyncio.create_task(self.message_sender(websocket))
initial_message = {"role": None, "type": None, "format": None, "content": None}
message_so_far = initial_message
while True: while True:
message = await websocket.recv() await asyncio.sleep(0.01)
chunk = await websocket.recv()
logger.debug(f"Got this message from the server: {type(message)} {message}") logger.debug(f"Got this message from the server: {type(chunk)} {chunk}")
if type(message) == str: if type(chunk) == str:
message = json.loads(message) chunk = json.loads(chunk)
if message.get("end"): message = accumulator.accumulate(chunk)
logger.debug(f"Complete message from the server: {message_so_far}") if message == None:
logger.info("\n") # Will be None until we have a full message ready
message_so_far = initial_message continue
if "content" in message: # At this point, we have our message
print(message['content'], end="", flush=True)
if any(message_so_far[key] != message[key] for key in message_so_far if key != "content"):
message_so_far = message
else:
message_so_far["content"] += message["content"]
if message["type"] == "audio" and "content" in message: if message["type"] == "audio" and message["format"].startswith("bytes"):
audio_bytes = bytes(ast.literal_eval(message["content"]))
# Convert bytes to audio file # Convert bytes to audio file
audio_file = io.BytesIO(audio_bytes) # Format will be bytes.wav or bytes.opus
audio = AudioSegment.from_mp3(audio_file) audio_bytes = io.BytesIO(message["content"])
audio = AudioSegment.from_file(audio_bytes, codec=message["format"].split(".")[1])
# Play the audio
play(audio)
await asyncio.sleep(1) self.audiosegments.append(audio)
# Run the code if that's the client's job # Run the code if that's the client's job
if os.getenv('CODE_RUNNER') == "client": if os.getenv('CODE_RUNNER') == "client":
if message["type"] == "code" and "end" in message: if message["type"] == "code" and "end" in message:
language = message_so_far["format"] language = message["format"]
code = message_so_far["content"] code = message["content"]
result = interpreter.computer.run(language, code) result = interpreter.computer.run(language, code)
send_queue.put(result) send_queue.put(result)
except: except:
# traceback.print_exc() traceback.print_exc()
logger.info(f"Connecting to `{WS_URL}`...") logger.info(f"Connecting to `{WS_URL}`...")
await asyncio.sleep(2) await asyncio.sleep(2)
@ -212,6 +223,7 @@ class Device:
if os.getenv('CODE_RUNNER') == "client": if os.getenv('CODE_RUNNER') == "client":
asyncio.create_task(put_kernel_messages_into_queue(send_queue)) asyncio.create_task(put_kernel_messages_into_queue(send_queue))
asyncio.create_task(self.play_audiosegments())
# If Raspberry Pi, add the button listener, otherwise use the spacebar # If Raspberry Pi, add the button listener, otherwise use the spacebar
if current_platform.startswith("raspberry-pi"): if current_platform.startswith("raspberry-pi"):

@ -9,9 +9,9 @@ import traceback
import re import re
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.responses import PlainTextResponse from fastapi.responses import PlainTextResponse
from starlette.websockets import WebSocket from starlette.websockets import WebSocket, WebSocketDisconnect
from .stt.stt import stt_bytes from .stt.stt import stt_bytes
from .tts.tts import tts from .tts.tts import stream_tts
from pathlib import Path from pathlib import Path
import asyncio import asyncio
import urllib.parse import urllib.parse
@ -19,11 +19,13 @@ from .utils.kernel import put_kernel_messages_into_queue
from .i import configure_interpreter from .i import configure_interpreter
from interpreter import interpreter from interpreter import interpreter
import ngrok import ngrok
from ..utils.accumulator import Accumulator
from .utils.logs import setup_logging from .utils.logs import setup_logging
from .utils.logs import logger from .utils.logs import logger
setup_logging() setup_logging()
accumulator = Accumulator()
app = FastAPI() app = FastAPI()
@ -105,54 +107,89 @@ async def websocket_endpoint(websocket: WebSocket):
async def receive_messages(websocket: WebSocket): async def receive_messages(websocket: WebSocket):
while True: while True:
data = await websocket.receive_json() try:
if data["role"] == "computer": try:
from_computer.put(data) # To be handled by interpreter.computer.run data = await websocket.receive()
elif data["role"] == "user": except Exception as e:
await from_user.put(data) print(str(e))
else: return
raise("Unknown role:", data) if 'text' in data:
try:
data = json.loads(data['text'])
if data["role"] == "computer":
from_computer.put(data) # To be handled by interpreter.computer.run
elif data["role"] == "user":
await from_user.put(data)
else:
raise("Unknown role:", data)
except json.JSONDecodeError:
pass # data is not JSON, leave it as is
elif 'bytes' in data:
data = data['bytes'] # binary data
await from_user.put(data)
except WebSocketDisconnect as e:
if e.code == 1000:
logger.info("Websocket connection closed normally.")
return
else:
raise
async def send_messages(websocket: WebSocket): async def send_messages(websocket: WebSocket):
while True: while True:
message = await to_device.get() message = await to_device.get()
logger.debug(f"Sending to the device: {type(message)} {message}") logger.debug(f"Sending to the device: {type(message)} {message}")
await websocket.send_json(message)
try:
if isinstance(message, dict):
await websocket.send_json(message)
elif isinstance(message, bytes):
await websocket.send_bytes(message)
else:
raise TypeError("Message must be a dict or bytes")
except:
# Make sure to put the message back in the queue if you failed to send it
await to_device.put(message)
raise
async def listener(): async def listener():
audio_bytes = bytearray()
while True: while True:
while True: while True:
if not from_user.empty(): if not from_user.empty():
message = await from_user.get() chunk = await from_user.get()
break break
elif not from_computer.empty(): elif not from_computer.empty():
message = from_computer.get() chunk = from_computer.get()
break break
await asyncio.sleep(1) await asyncio.sleep(1)
if type(message) == str:
message = json.loads(message)
# Hold the audio in a buffer. If it's ready (we got end flag, stt it)
if message["type"] == "audio":
if "content" in message:
audio_bytes.extend(bytes(ast.literal_eval(message["content"])))
if "end" in message:
content = stt_bytes(audio_bytes, message["format"])
if content == None: # If it was nothing / silence
continue
audio_bytes = bytearray()
message = {"role": "user", "type": "message", "content": content}
else:
continue
# Ignore flags, we only needed them for audio ^ message = accumulator.accumulate(chunk)
if "content" not in message or message["content"] == None: if message == None:
# Will be None until we have a full message ready
continue continue
# print(str(message)[:1000])
# At this point, we have our message
if message["type"] == "audio" and message["format"].startswith("bytes"):
if not message["content"]: # If it was nothing / silence
continue
# Convert bytes to audio file
# Format will be bytes.wav or bytes.opus
mime_type = "audio/" + message["format"].split(".")[1]
text = stt_bytes(message["content"], mime_type)
message = {"role": "user", "type": "message", "content": text}
# At this point, we have only text messages
# Custom stop message will halt us # Custom stop message will halt us
if message["content"].lower().strip(".,!") == "stop": if message["content"].lower().strip(".,! ") == "stop":
continue continue
# Load, append, and save conversation history # Load, append, and save conversation history
@ -173,18 +210,30 @@ async def listener():
# Yield to the event loop, so you actually send it out # Yield to the event loop, so you actually send it out
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
# Speak full sentences out loud if os.getenv('TTS_RUNNER') == "server":
if chunk["role"] == "assistant" and "content" in chunk: # Speak full sentences out loud
accumulated_text += chunk["content"] if chunk["role"] == "assistant" and "content" in chunk:
sentences = split_into_sentences(accumulated_text) accumulated_text += chunk["content"]
if is_full_sentence(sentences[-1]): sentences = split_into_sentences(accumulated_text)
for sentence in sentences:
await stream_or_play_tts(sentence) # If we're going to speak, say we're going to stop sending text.
accumulated_text = "" # This should be fixed probably, we should be able to do both in parallel, or only one.
else: if any(is_full_sentence(sentence) for sentence in sentences):
for sentence in sentences[:-1]: await to_device.put({"role": "assistant", "type": "message", "end": True})
await stream_or_play_tts(sentence)
accumulated_text = sentences[-1] if is_full_sentence(sentences[-1]):
for sentence in sentences:
await stream_tts_to_device(sentence)
accumulated_text = ""
else:
for sentence in sentences[:-1]:
await stream_tts_to_device(sentence)
accumulated_text = sentences[-1]
# If we're going to speak, say we're going to stop sending text.
# This should be fixed probably, we should be able to do both in parallel, or only one.
if any(is_full_sentence(sentence) for sentence in sentences):
await to_device.put({"role": "assistant", "type": "message", "start": True})
# If we have a new message, save our progress and go back to the top # If we have a new message, save our progress and go back to the top
if not from_user.empty(): if not from_user.empty():
@ -217,16 +266,9 @@ async def listener():
with open(conversation_history_path, 'w') as file: with open(conversation_history_path, 'w') as file:
json.dump(interpreter.messages, file, indent=4) json.dump(interpreter.messages, file, indent=4)
async def stream_tts_to_device(sentence):
async def stream_or_play_tts(sentence): for chunk in stream_tts(sentence):
await to_device.put(chunk)
if os.getenv('TTS_RUNNER') == "server":
tts(sentence, play_audio=True)
else:
await to_device.put({"role": "assistant", "type": "audio", "format": "audio/mp3", "start": True})
audio_bytes = tts(sentence, play_audio=False)
await to_device.put({"role": "assistant", "type": "audio", "format": "audio/mp3", "content": str(audio_bytes)})
await to_device.put({"role": "assistant", "type": "audio", "format": "audio/mp3", "end": True})
async def setup_ngrok(ngrok_auth_token, parsed_url): async def setup_ngrok(ngrok_auth_token, parsed_url):
# Set up Ngrok # Set up Ngrok

@ -12,27 +12,28 @@ import os
import subprocess import subprocess
import tempfile import tempfile
from pydub import AudioSegment from pydub import AudioSegment
from pydub.playback import play
import simpleaudio as sa
client = OpenAI() client = OpenAI()
def tts(text, play_audio): chunk_size = 1024
def stream_tts(text):
"""
A generator that streams tts as LMC messages.
"""
if os.getenv('ALL_LOCAL') == 'False': if os.getenv('ALL_LOCAL') == 'False':
response = client.audio.speech.create( response = client.audio.speech.create(
model="tts-1", model="tts-1",
voice="alloy", voice="alloy",
input=text, input=text,
response_format="mp3" response_format="opus"
) )
with tempfile.NamedTemporaryFile(suffix=".mp3") as temp_file: with tempfile.NamedTemporaryFile(suffix=".opus") as temp_file:
response.stream_to_file(temp_file.name) response.stream_to_file(temp_file.name)
if play_audio: audio_bytes = temp_file.read()
audio = AudioSegment.from_mp3(temp_file.name) file_type = "bytes.opus"
play_audiosegment(audio)
return temp_file.read()
else: else:
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file: with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
output_file = temp_file.name output_file = temp_file.name
@ -43,13 +44,19 @@ def tts(text, play_audio):
'--output_file', output_file '--output_file', output_file
], input=text, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) ], input=text, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
if play_audio: audio_bytes = temp_file.read()
audio = AudioSegment.from_wav(temp_file.name) file_type = "bytes.wav"
play_audiosegment(audio)
return temp_file.read() # Stream the audio
yield {"role": "assistant", "type": "audio", "format": file_type, "start": True}
for i in range(0, len(audio_bytes), chunk_size):
chunk = audio_bytes[i:i+chunk_size]
yield chunk
yield {"role": "assistant", "type": "audio", "format": file_type, "end": True}
def play_audiosegment(audio): def play_audiosegment(audio):
""" """
UNUSED
the default makes some pops. this fixes that the default makes some pops. this fixes that
""" """
@ -73,3 +80,6 @@ def play_audiosegment(audio):
# Wait for the playback to finish # Wait for the playback to finish
play_obj.wait_done() play_obj.wait_done()
# Delete the wav file
os.remove("output_audio.wav")

@ -0,0 +1,40 @@
class Accumulator:
def __init__(self):
self.template = {"role": None, "type": None, "format": None, "content": None}
self.message = self.template
def accumulate(self, chunk):
print(str(chunk)[:100])
if type(chunk) == dict:
if "format" in chunk and chunk["format"] == "active_line":
# We don't do anything with these
return None
if "start" in chunk:
self.message = chunk
self.message.pop("start")
return None
if "content" in chunk:
if any(self.message[key] != chunk[key] for key in self.message if key != "content"):
self.message = chunk
if "content" not in self.message:
self.message["content"] = chunk["content"]
else:
self.message["content"] += chunk["content"]
return None
if "end" in chunk:
# We will proceed
message = self.message
self.message = self.template
return message
if type(chunk) == bytes:
if "content" not in self.message or type(self.message["content"]) != bytes:
self.message["content"] = b""
self.message["content"] += chunk
return None

Binary file not shown.
Loading…
Cancel
Save