pull/246/head
parent
1324789123
commit
008763aab0
@ -0,0 +1,110 @@
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import PlainTextResponse
|
||||
from starlette.websockets import WebSocket, WebSocketDisconnect
|
||||
import asyncio
|
||||
from .utils.logs import setup_logging
|
||||
from .utils.logs import logger
|
||||
import traceback
|
||||
import json
|
||||
from ..utils.print_markdown import print_markdown
|
||||
from .queues import Queues
|
||||
|
||||
|
||||
setup_logging()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
from_computer, from_user, to_device = Queues.get()
|
||||
|
||||
|
||||
@app.get("/ping")
|
||||
async def ping():
|
||||
return PlainTextResponse("pong")
|
||||
|
||||
|
||||
@app.websocket("/")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
receive_task = asyncio.create_task(receive_messages(websocket))
|
||||
send_task = asyncio.create_task(send_messages(websocket))
|
||||
try:
|
||||
await asyncio.gather(receive_task, send_task)
|
||||
except Exception as e:
|
||||
logger.debug(traceback.format_exc())
|
||||
logger.info(f"Connection lost. Error: {e}")
|
||||
|
||||
|
||||
@app.post("/")
|
||||
async def add_computer_message(request: Request):
|
||||
body = await request.json()
|
||||
text = body.get("text")
|
||||
if not text:
|
||||
return {"error": "Missing 'text' in request body"}, 422
|
||||
message = {"role": "user", "type": "message", "content": text}
|
||||
await from_user.put({"role": "user", "type": "message", "start": True})
|
||||
await from_user.put(message)
|
||||
await from_user.put({"role": "user", "type": "message", "end": True})
|
||||
|
||||
|
||||
async def receive_messages(websocket: WebSocket):
|
||||
while True:
|
||||
try:
|
||||
try:
|
||||
data = await websocket.receive()
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
return
|
||||
|
||||
if "text" in data:
|
||||
try:
|
||||
data = json.loads(data["text"])
|
||||
if data["role"] == "computer":
|
||||
from_computer.put(
|
||||
data
|
||||
) # To be handled by interpreter.computer.run
|
||||
elif data["role"] == "user":
|
||||
await from_user.put(data)
|
||||
else:
|
||||
raise ("Unknown role:", data)
|
||||
except json.JSONDecodeError:
|
||||
pass # data is not JSON, leave it as is
|
||||
elif "bytes" in data:
|
||||
data = data["bytes"] # binary data
|
||||
await from_user.put(data)
|
||||
except WebSocketDisconnect as e:
|
||||
if e.code == 1000:
|
||||
logger.info("Websocket connection closed normally.")
|
||||
return
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
async def send_messages(websocket: WebSocket):
|
||||
while True:
|
||||
try:
|
||||
message = await to_device.get()
|
||||
# print(f"Sending to the device: {type(message)} {str(message)[:100]}")
|
||||
|
||||
if isinstance(message, dict):
|
||||
await websocket.send_json(message)
|
||||
elif isinstance(message, bytes):
|
||||
await websocket.send_bytes(message)
|
||||
else:
|
||||
raise TypeError("Message must be a dict or bytes")
|
||||
except Exception as e:
|
||||
if message:
|
||||
# Make sure to put the message back in the queue if you failed to send it
|
||||
await to_device.put(message)
|
||||
raise
|
||||
|
||||
# TODO: These two methods should change to lifespan
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
print("")
|
||||
print_markdown("\n*Ready.*\n")
|
||||
print("")
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
print_markdown("*Server is shutting down*")
|
@ -0,0 +1,43 @@
|
||||
import asyncio
|
||||
import queue
|
||||
|
||||
'''
|
||||
Queues are created on demand and should
|
||||
be accessed inside the currect event loop
|
||||
from a asyncio.run(co()) call.
|
||||
'''
|
||||
|
||||
class _ReadOnly(type):
|
||||
@property
|
||||
def from_computer(cls):
|
||||
if not cls._from_computer:
|
||||
# Sync queue because interpreter.run is synchronous.
|
||||
cls._from_computer = queue.Queue()
|
||||
return cls._from_computer
|
||||
|
||||
@property
|
||||
def from_user(cls):
|
||||
if not cls._from_user:
|
||||
cls._from_user = asyncio.Queue()
|
||||
return cls._from_user
|
||||
|
||||
@property
|
||||
def to_device(cls):
|
||||
if not cls._to_device:
|
||||
cls._to_device = asyncio.Queue()
|
||||
return cls._to_device
|
||||
|
||||
|
||||
class Queues(metaclass=_ReadOnly):
|
||||
# Queues used in server and app
|
||||
# Just for computer messages from the device.
|
||||
_from_computer = None
|
||||
|
||||
# Just for user messages from the device.
|
||||
_from_user = None
|
||||
|
||||
# For messages we send.
|
||||
_to_device = None
|
||||
|
||||
def get():
|
||||
return Queues.from_computer, Queues.from_user, Queues.to_device
|
Loading…
Reference in new issue