diff --git a/software/source/server/async_server.py b/software/source/server/async_server.py index 7c815ab..13fcecd 100644 --- a/software/source/server/async_server.py +++ b/software/source/server/async_server.py @@ -1,12 +1,19 @@ +# make this obvious +from .profiles.default import interpreter as base_interpreter + +# from .profiles.fast import interpreter as base_interpreter +# from .profiles.local import interpreter as base_interpreter + +# TODO: remove files i.py, llm.py, conftest?, services + import asyncio import traceback import json from fastapi import FastAPI, WebSocket from fastapi.responses import PlainTextResponse from uvicorn import Config, Server -from .i import configure_interpreter -from interpreter import interpreter as base_interpreter -from starlette.websockets import WebSocketDisconnect + +# from interpreter import interpreter as base_interpreter from .async_interpreter import AsyncInterpreter from fastapi.middleware.cors import CORSMiddleware from typing import List, Dict, Any @@ -17,23 +24,9 @@ os.environ["STT_RUNNER"] = "server" os.environ["TTS_RUNNER"] = "server" -async def main(server_host, server_port, tts_service, asynchronous): - if asynchronous: - base_interpreter.system_message = ( - "You are a helpful assistant that can answer questions and help with tasks." - ) - base_interpreter.computer.import_computer_api = False - base_interpreter.llm.model = "groq/llama3-8b-8192" - base_interpreter.llm.api_key = os.environ["GROQ_API_KEY"] - base_interpreter.llm.supports_functions = False - base_interpreter.auto_run = True - base_interpreter.tts = tts_service - interpreter = AsyncInterpreter(base_interpreter) - else: - configured_interpreter = configure_interpreter(base_interpreter) - configured_interpreter.llm.supports_functions = True - configured_interpreter.tts = tts_service - interpreter = AsyncInterpreter(configured_interpreter) +async def main(server_host, server_port, tts_service): + base_interpreter.tts = tts_service + interpreter = AsyncInterpreter(base_interpreter) app = FastAPI() @@ -59,79 +52,44 @@ async def main(server_host, server_port, tts_service, asynchronous): @app.websocket("/") async def websocket_endpoint(websocket: WebSocket): await websocket.accept() + try: - async def send_output(): - try: + async def receive_input(): while True: - output = await interpreter.output() - - if isinstance(output, bytes): - try: - await websocket.send_bytes(output) - except Exception as e: - print(f"Error: {e}") - traceback.print_exc() - return {"error": str(e)} + if websocket.client_state == "DISCONNECTED": + break - elif isinstance(output, dict): - try: - await websocket.send_text(json.dumps(output)) - - except Exception as e: - print(f"Error: {e}") - traceback.print_exc() - return {"error": str(e)} - except asyncio.CancelledError: - print("WebSocket connection closed") - traceback.print_exc() - - async def receive_input(): - try: - while True: - # print("server awaiting input") data = await websocket.receive() if isinstance(data, bytes): - try: - await interpreter.input(data) - except Exception as e: - print(f"Error: {e}") - traceback.print_exc() - return {"error": str(e)} - + await interpreter.input(data) elif "bytes" in data: - try: - await interpreter.input(data["bytes"]) - except Exception as e: - print(f"Error: {e}") - traceback.print_exc() - return {"error": str(e)} - + await interpreter.input(data["bytes"]) + # print("RECEIVED INPUT", data) elif "text" in data: - try: - await interpreter.input(data["text"]) - except Exception as e: - print(f"Error: {e}") - traceback.print_exc() - return {"error": str(e)} - except asyncio.CancelledError: - print("WebSocket connection closed") - traceback.print_exc() + # print("RECEIVED INPUT", data) + await interpreter.input(data["text"]) - try: - send_task = asyncio.create_task(send_output()) - receive_task = asyncio.create_task(receive_input()) + async def send_output(): + while True: + output = await interpreter.output() + + if isinstance(output, bytes): + # print(f"Sending {len(output)} bytes of audio data.") + await websocket.send_bytes(output) + # we dont send out bytes rn, no TTS - await asyncio.gather(send_task, receive_task) + elif isinstance(output, dict): + # print("sending text") + await websocket.send_text(json.dumps(output)) - except WebSocketDisconnect: - print("WebSocket disconnected") + await asyncio.gather(send_output(), receive_input()) except Exception as e: print(f"WebSocket connection closed with exception: {e}") traceback.print_exc() finally: - print("server closing ws connection") - await websocket.close() + if not websocket.client_state == "DISCONNECTED": + await websocket.close() print(f"Starting server on {server_host}:{server_port}") config = Config(app, host=server_host, port=server_port, lifespan="on") @@ -140,4 +98,4 @@ async def main(server_host, server_port, tts_service, asynchronous): if __name__ == "__main__": - asyncio.run(main("localhost", 8000)) + asyncio.run(main()) diff --git a/software/start.py b/software/start.py index de206cf..5b711da 100644 --- a/software/start.py +++ b/software/start.py @@ -77,9 +77,6 @@ def run( mobile: bool = typer.Option( False, "--mobile", help="Toggle server to support mobile app" ), - asynchronous: bool = typer.Option( - False, "--async", help="use interpreter optimized for latency" - ), ): _run( server=server or mobile, @@ -102,7 +99,6 @@ def run( local=local, qr=qr or mobile, mobile=mobile, - asynchronous=asynchronous, ) @@ -127,7 +123,6 @@ def _run( local: bool = False, qr: bool = False, mobile: bool = False, - asynchronous: bool = False, ): if local: tts_service = "coqui" @@ -162,7 +157,6 @@ def _run( server_host, server_port, tts_service, - asynchronous, # llm_service, # model, # llm_supports_vision,