You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
111 lines
3.4 KiB
111 lines
3.4 KiB
from fastapi import FastAPI, Request
|
|
from fastapi.responses import PlainTextResponse
|
|
from starlette.websockets import WebSocket, WebSocketDisconnect
|
|
import asyncio
|
|
from .utils.logs import setup_logging
|
|
from .utils.logs import logger
|
|
import traceback
|
|
import json
|
|
from ..utils.print_markdown import print_markdown
|
|
from .queues import Queues
|
|
|
|
|
|
setup_logging()
|
|
|
|
app = FastAPI()
|
|
|
|
from_computer, from_user, to_device = Queues.get()
|
|
|
|
|
|
@app.get("/ping")
|
|
async def ping():
|
|
return PlainTextResponse("pong")
|
|
|
|
|
|
@app.websocket("/")
|
|
async def websocket_endpoint(websocket: WebSocket):
|
|
await websocket.accept()
|
|
receive_task = asyncio.create_task(receive_messages(websocket))
|
|
send_task = asyncio.create_task(send_messages(websocket))
|
|
try:
|
|
await asyncio.gather(receive_task, send_task)
|
|
except Exception as e:
|
|
logger.debug(traceback.format_exc())
|
|
logger.info(f"Connection lost. Error: {e}")
|
|
|
|
|
|
@app.post("/")
|
|
async def add_computer_message(request: Request):
|
|
body = await request.json()
|
|
text = body.get("text")
|
|
if not text:
|
|
return {"error": "Missing 'text' in request body"}, 422
|
|
message = {"role": "user", "type": "message", "content": text}
|
|
await from_user.put({"role": "user", "type": "message", "start": True})
|
|
await from_user.put(message)
|
|
await from_user.put({"role": "user", "type": "message", "end": True})
|
|
|
|
|
|
async def receive_messages(websocket: WebSocket):
|
|
while True:
|
|
try:
|
|
try:
|
|
data = await websocket.receive()
|
|
except Exception as e:
|
|
print(str(e))
|
|
return
|
|
|
|
if "text" in data:
|
|
try:
|
|
data = json.loads(data["text"])
|
|
if data["role"] == "computer":
|
|
from_computer.put(
|
|
data
|
|
) # To be handled by interpreter.computer.run
|
|
elif data["role"] == "user":
|
|
await from_user.put(data)
|
|
else:
|
|
raise ("Unknown role:", data)
|
|
except json.JSONDecodeError:
|
|
pass # data is not JSON, leave it as is
|
|
elif "bytes" in data:
|
|
data = data["bytes"] # binary data
|
|
await from_user.put(data)
|
|
except WebSocketDisconnect as e:
|
|
if e.code == 1000:
|
|
logger.info("Websocket connection closed normally.")
|
|
return
|
|
else:
|
|
raise
|
|
|
|
|
|
async def send_messages(websocket: WebSocket):
|
|
while True:
|
|
try:
|
|
message = await to_device.get()
|
|
# print(f"Sending to the device: {type(message)} {str(message)[:100]}")
|
|
|
|
if isinstance(message, dict):
|
|
await websocket.send_json(message)
|
|
elif isinstance(message, bytes):
|
|
await websocket.send_bytes(message)
|
|
else:
|
|
raise TypeError("Message must be a dict or bytes")
|
|
except Exception as e:
|
|
if message:
|
|
# Make sure to put the message back in the queue if you failed to send it
|
|
await to_device.put(message)
|
|
raise
|
|
|
|
# TODO: These two methods should change to lifespan
|
|
@app.on_event("startup")
|
|
async def startup_event():
|
|
print("")
|
|
print_markdown("\n*Ready.*\n")
|
|
print("")
|
|
|
|
|
|
@app.on_event("shutdown")
|
|
async def shutdown_event():
|
|
print_markdown("*Server is shutting down*")
|