Fixing asyncio queue creation and usage by decoupling app, queues, and uvicorn config.

pull/246/head
Robert Brisita 9 months ago
parent 1324789123
commit 008763aab0

@ -0,0 +1,110 @@
from fastapi import FastAPI, Request
from fastapi.responses import PlainTextResponse
from starlette.websockets import WebSocket, WebSocketDisconnect
import asyncio
from .utils.logs import setup_logging
from .utils.logs import logger
import traceback
import json
from ..utils.print_markdown import print_markdown
from .queues import Queues
# Configure logging before the application object and its handlers exist.
setup_logging()
# ASGI application; the route and event handlers in this module register on it.
app = FastAPI()
# Shared queues, resolved once when this module is imported:
#   from_computer - computer messages from the device,
#   from_user     - user messages from the device,
#   to_device     - messages we send to the device.
from_computer, from_user, to_device = Queues.get()
@app.get("/ping")
async def ping():
    """Liveness probe: answer every request with the plain-text body ``pong``."""
    pong = PlainTextResponse("pong")
    return pong
@app.websocket("/")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one device connection by running the receive and send loops together.

    The connection is accepted, then ``receive_messages`` and ``send_messages``
    are started as tasks. As soon as either task finishes (normally or with an
    error) the other one is cancelled, so no task keeps using a socket that is
    already gone. Errors are logged, not propagated.

    Args:
        websocket: The incoming WebSocket connection.
    """
    await websocket.accept()
    receive_task = asyncio.create_task(receive_messages(websocket))
    send_task = asyncio.create_task(send_messages(websocket))
    tasks = (receive_task, send_task)
    try:
        # Stop at the first finished task: gather() would keep waiting on the
        # sibling forever after a clean disconnect.
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            task.result()  # re-raise whatever ended the task
    except Exception as e:
        logger.debug(traceback.format_exc())
        logger.info(f"Connection lost. Error: {e}")
    finally:
        for task in tasks:
            if not task.done():
                task.cancel()
        # Let cancelled tasks unwind and mark every exception as retrieved.
        await asyncio.gather(*tasks, return_exceptions=True)
@app.post("/")
async def add_computer_message(request: Request):
    """Queue a text message posted over HTTP as a user message.

    The request body must be a JSON object with a non-empty ``text`` field.
    The text is put on ``from_user`` wrapped between a ``start`` and an ``end``
    flag message.

    Args:
        request: The incoming HTTP request.

    Returns:
        ``None`` on success, or a 422 JSON response when ``text`` is missing.
    """
    # Local import so this handler does not depend on the module's import list.
    from fastapi.responses import JSONResponse

    body = await request.json()
    # A JSON array/string/number has no .get(); treat it like a missing field.
    text = body.get("text") if isinstance(body, dict) else None
    if not text:
        # Returning `(payload, 422)` is a Flask idiom; FastAPI would serialise
        # the tuple as a JSON array with status 200.
        return JSONResponse(
            status_code=422,
            content={"error": "Missing 'text' in request body"},
        )
    message = {"role": "user", "type": "message", "content": text}
    await from_user.put({"role": "user", "type": "message", "start": True})
    await from_user.put(message)
    await from_user.put({"role": "user", "type": "message", "end": True})
async def receive_messages(websocket: WebSocket):
    """Read messages from the device and route them to the matching queue.

    Text frames are parsed as JSON and dispatched on their ``role`` field:
    ``"computer"`` goes to ``from_computer``, ``"user"`` goes to ``from_user``.
    Text that is not valid JSON is ignored. Binary frames are forwarded to
    ``from_user`` unchanged.

    The loop ends when ``websocket.receive()`` fails or when the socket
    disconnects with close code 1000.

    Args:
        websocket: The accepted WebSocket connection to read from.

    Raises:
        ValueError: If a JSON message carries a role other than ``"computer"``
            or ``"user"``.
        WebSocketDisconnect: If the socket disconnects with a code other
            than 1000.
    """
    while True:
        try:
            try:
                data = await websocket.receive()
            except Exception as e:
                print(str(e))
                return
            if "text" in data:
                try:
                    data = json.loads(data["text"])
                except json.JSONDecodeError:
                    continue  # data is not JSON, nothing to route
                role = data["role"]
                if role == "computer":
                    # To be handled by interpreter.computer.run
                    from_computer.put(data)
                elif role == "user":
                    await from_user.put(data)
                else:
                    # Raising a bare tuple is itself a TypeError; raise a real
                    # exception that names the offending message.
                    raise ValueError(f"Unknown role: {data}")
            elif "bytes" in data:
                data = data["bytes"]  # binary data
                await from_user.put(data)
        except WebSocketDisconnect as e:
            if e.code == 1000:
                logger.info("Websocket connection closed normally.")
                return
            raise
async def send_messages(websocket: WebSocket):
    """Forward every message from ``to_device`` to the device.

    ``dict`` messages are sent as JSON and ``bytes`` messages as binary frames.
    If handling a message fails, it is put back on the queue before the error
    propagates.

    Args:
        websocket: The accepted WebSocket connection to write to.

    Raises:
        TypeError: If a queued message is neither a ``dict`` nor ``bytes``.
        Exception: Any error raised while sending is re-raised after the
            message has been re-queued.
    """
    while True:
        # Fetch outside the try block: inside it, a failure in get() would
        # leave `message` unbound or still pointing at the previous message,
        # which would then be re-queued as a duplicate.
        message = await to_device.get()
        try:
            if isinstance(message, dict):
                await websocket.send_json(message)
            elif isinstance(message, bytes):
                await websocket.send_bytes(message)
            else:
                raise TypeError("Message must be a dict or bytes")
        except Exception:
            # Make sure to put the message back in the queue if you failed to send it
            await to_device.put(message)
            raise
# TODO: These two methods should change to lifespan
@app.on_event("startup")
async def startup_event():
    """Print a ``Ready.`` banner, framed by blank lines, when the app starts."""
    print()
    print_markdown("\n*Ready.*\n")
    print()
@app.on_event("shutdown")
async def shutdown_event():
    """Print a notice when the app is shutting down."""
    farewell = "*Server is shutting down*"
    print_markdown(farewell)

@ -0,0 +1,43 @@
import asyncio
import queue
'''
Queues are created on demand and should
be accessed inside the current event loop
from an asyncio.run(co()) call.
'''
class _ReadOnly(type):
    """Metaclass that exposes lazily created queues as read-only class properties.

    Each property creates its queue on first access and stores it on the
    class, so later accesses return the same object. The properties define no
    setter, so assigning to them on the class raises ``AttributeError``.
    """

    @property
    def from_computer(cls):
        """Return the queue of computer messages, creating it on first access."""
        # Identity check: only a missing queue may trigger creation.
        if cls._from_computer is None:
            # Sync queue because interpreter.run is synchronous.
            cls._from_computer = queue.Queue()
        return cls._from_computer

    @property
    def from_user(cls):
        """Return the queue of user messages, creating it on first access."""
        if cls._from_user is None:
            cls._from_user = asyncio.Queue()
        return cls._from_user

    @property
    def to_device(cls):
        """Return the queue of outbound messages, creating it on first access."""
        if cls._to_device is None:
            cls._to_device = asyncio.Queue()
        return cls._to_device


class Queues(metaclass=_ReadOnly):
    """Queues used in server and app."""

    # Just for computer messages from the device.
    _from_computer = None
    # Just for user messages from the device.
    _from_user = None
    # For messages we send.
    _to_device = None

    @staticmethod
    def get():
        """Return the ``(from_computer, from_user, to_device)`` queues.

        Declared static so the call works on the class and on an instance.
        """
        return Queues.from_computer, Queues.from_user, Queues.to_device

@ -5,14 +5,10 @@ load_dotenv() # take environment variables from .env.
import traceback import traceback
from platformdirs import user_data_dir from platformdirs import user_data_dir
import json import json
import queue
import os import os
import datetime import datetime
from .utils.bytes_to_wav import bytes_to_wav from .utils.bytes_to_wav import bytes_to_wav
import re import re
from fastapi import FastAPI, Request
from fastapi.responses import PlainTextResponse
from starlette.websockets import WebSocket, WebSocketDisconnect
import asyncio import asyncio
from .utils.kernel import put_kernel_messages_into_queue from .utils.kernel import put_kernel_messages_into_queue
from .i import configure_interpreter from .i import configure_interpreter
@ -20,8 +16,8 @@ from interpreter import interpreter
from ..utils.accumulator import Accumulator from ..utils.accumulator import Accumulator
from .utils.logs import setup_logging from .utils.logs import setup_logging
from .utils.logs import logger from .utils.logs import logger
from ..utils.print_markdown import print_markdown from ..utils.print_markdown import print_markdown
from .queues import Queues
os.environ["STT_RUNNER"] = "server" os.environ["STT_RUNNER"] = "server"
os.environ["TTS_RUNNER"] = "server" os.environ["TTS_RUNNER"] = "server"
@ -40,8 +36,6 @@ setup_logging()
accumulator = Accumulator() accumulator = Accumulator()
app = FastAPI()
app_dir = user_data_dir("01") app_dir = user_data_dir("01")
conversation_history_path = os.path.join(app_dir, "conversations", "user.json") conversation_history_path = os.path.join(app_dir, "conversations", "user.json")
@ -56,14 +50,6 @@ def is_full_sentence(text):
def split_into_sentences(text): def split_into_sentences(text):
return re.split(r"(?<=[.!?])\s+", text) return re.split(r"(?<=[.!?])\s+", text)
# Queues
from_computer = (
queue.Queue()
) # Just for computer messages from the device. Sync queue because interpreter.run is synchronous
from_user = asyncio.Queue() # Just for user messages from the device.
to_device = asyncio.Queue() # For messages we send.
# Switch code executor to device if that's set # Switch code executor to device if that's set
if os.getenv("CODE_RUNNER") == "device": if os.getenv("CODE_RUNNER") == "device":
@ -76,9 +62,11 @@ if os.getenv("CODE_RUNNER") == "device":
def __init__(self): def __init__(self):
self.halt = False self.halt = False
def run(self, code): async def run(self, code):
"""Generator that yields a dictionary in LMC Format.""" """Generator that yields a dictionary in LMC Format."""
from_computer, _, to_device = Queues.get()
# Prepare the data # Prepare the data
message = { message = {
"role": "assistant", "role": "assistant",
@ -89,7 +77,7 @@ if os.getenv("CODE_RUNNER") == "device":
# Unless it was just sent to the device, send it wrapped in flags # Unless it was just sent to the device, send it wrapped in flags
if not (interpreter.messages and interpreter.messages[-1] == message): if not (interpreter.messages and interpreter.messages[-1] == message):
to_device.put( await to_device.put(
{ {
"role": "assistant", "role": "assistant",
"type": "code", "type": "code",
@ -97,8 +85,8 @@ if os.getenv("CODE_RUNNER") == "device":
"start": True, "start": True,
} }
) )
to_device.put(message) await to_device.put(message)
to_device.put( await to_device.put(
{ {
"role": "assistant", "role": "assistant",
"type": "code", "type": "code",
@ -130,86 +118,9 @@ if os.getenv("CODE_RUNNER") == "device":
interpreter = configure_interpreter(interpreter) interpreter = configure_interpreter(interpreter)
@app.get("/ping")
async def ping():
return PlainTextResponse("pong")
@app.websocket("/")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
receive_task = asyncio.create_task(receive_messages(websocket))
send_task = asyncio.create_task(send_messages(websocket))
try:
await asyncio.gather(receive_task, send_task)
except Exception as e:
logger.debug(traceback.format_exc())
logger.info(f"Connection lost. Error: {e}")
@app.post("/")
async def add_computer_message(request: Request):
body = await request.json()
text = body.get("text")
if not text:
return {"error": "Missing 'text' in request body"}, 422
message = {"role": "user", "type": "message", "content": text}
await from_user.put({"role": "user", "type": "message", "start": True})
await from_user.put(message)
await from_user.put({"role": "user", "type": "message", "end": True})
async def receive_messages(websocket: WebSocket):
while True:
try:
try:
data = await websocket.receive()
except Exception as e:
print(str(e))
return
if "text" in data:
try:
data = json.loads(data["text"])
if data["role"] == "computer":
from_computer.put(
data
) # To be handled by interpreter.computer.run
elif data["role"] == "user":
await from_user.put(data)
else:
raise ("Unknown role:", data)
except json.JSONDecodeError:
pass # data is not JSON, leave it as is
elif "bytes" in data:
data = data["bytes"] # binary data
await from_user.put(data)
except WebSocketDisconnect as e:
if e.code == 1000:
logger.info("Websocket connection closed normally.")
return
else:
raise
async def send_messages(websocket: WebSocket):
while True:
message = await to_device.get()
# print(f"Sending to the device: {type(message)} {str(message)[:100]}")
try:
if isinstance(message, dict):
await websocket.send_json(message)
elif isinstance(message, bytes):
await websocket.send_bytes(message)
else:
raise TypeError("Message must be a dict or bytes")
except:
# Make sure to put the message back in the queue if you failed to send it
await to_device.put(message)
raise
async def listener(): async def listener():
from_computer, from_user, to_device = Queues.get()
while True: while True:
try: try:
while True: while True:
@ -250,6 +161,7 @@ async def listener():
time.sleep(15) time.sleep(15)
# stt is a bound method
text = stt(audio_file_path) text = stt(audio_file_path)
print("> ", text) print("> ", text)
message = {"role": "user", "type": "message", "content": text} message = {"role": "user", "type": "message", "content": text}
@ -367,10 +279,11 @@ async def stream_tts_to_device(sentence):
return return
for chunk in stream_tts(sentence): for chunk in stream_tts(sentence):
await to_device.put(chunk) await Queues.to_device.put(chunk)
def stream_tts(sentence): def stream_tts(sentence):
# tts is a bound method
audio_file = tts(sentence) audio_file = tts(sentence)
with open(audio_file, "rb") as f: with open(audio_file, "rb") as f:
@ -392,23 +305,6 @@ from uvicorn import Config, Server
import os import os
from importlib import import_module from importlib import import_module
# these will be overwritten
HOST = ""
PORT = 0
@app.on_event("startup")
async def startup_event():
server_url = f"{HOST}:{PORT}"
print("")
print_markdown("\n*Ready.*\n")
print("")
@app.on_event("shutdown")
async def shutdown_event():
print_markdown("*Server is shutting down*")
async def main( async def main(
server_host, server_host,
@ -423,11 +319,6 @@ async def main(
tts_service, tts_service,
stt_service, stt_service,
): ):
global HOST
global PORT
PORT = server_port
HOST = server_host
# Setup services # Setup services
application_directory = user_data_dir("01") application_directory = user_data_dir("01")
services_directory = os.path.join(application_directory, "services") services_directory = os.path.join(application_directory, "services")
@ -470,6 +361,7 @@ async def main(
service_instance = ServiceClass(config) service_instance = ServiceClass(config)
globals()[service] = getattr(service_instance, service) globals()[service] = getattr(service_instance, service)
# llm is a bound method
interpreter.llm.completions = llm interpreter.llm.completions = llm
# Start listening # Start listening
@ -477,9 +369,9 @@ async def main(
# Start watching the kernel if it's your job to do that # Start watching the kernel if it's your job to do that
if True: # in the future, code can run on device. for now, just server. if True: # in the future, code can run on device. for now, just server.
asyncio.create_task(put_kernel_messages_into_queue(from_computer)) asyncio.create_task(put_kernel_messages_into_queue(Queues.from_computer))
config = Config(app, host=server_host, port=int(server_port), lifespan="on") config = Config("source.server.app:app", host=server_host, port=int(server_port), lifespan="on")
server = Server(config) server = Server(config)
await server.serve() await server.serve()

Loading…
Cancel
Save