# make this obvious
from .profiles.default import interpreter as base_interpreter
# from .profiles.fast import interpreter as base_interpreter
# from .profiles.local import interpreter as base_interpreter

# TODO: remove files i.py, llm.py, conftest?, services

import argparse
import asyncio
import json
import os
import traceback
from typing import Any, Dict, List

from fastapi import FastAPI, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse
from fastapi.websockets import WebSocketState
from uvicorn import Config, Server

# from interpreter import interpreter as base_interpreter
from .async_interpreter import AsyncInterpreter

# NOTE(review): these are set after the profile import above; if that profile
# reads STT_RUNNER / TTS_RUNNER at import time it will not see them — confirm.
os.environ["STT_RUNNER"] = "server"
os.environ["TTS_RUNNER"] = "server"


async def _forward_client_input(websocket: WebSocket, interpreter: AsyncInterpreter) -> None:
    """Feed every frame received on *websocket* into *interpreter* until the client disconnects."""
    while True:
        message = await websocket.receive()
        # A disconnect arrives as a message, not an exception; calling
        # receive() again after it raises RuntimeError, so stop here.
        if message["type"] == "websocket.disconnect":
            return
        if message.get("bytes") is not None:
            await interpreter.input(message["bytes"])
        elif message.get("text") is not None:
            await interpreter.input(message["text"])


async def _forward_interpreter_output(websocket: WebSocket, interpreter: AsyncInterpreter) -> None:
    """Relay interpreter output to *websocket*: bytes as binary frames, dicts as JSON text.

    Outputs of any other type are silently dropped.
    """
    while True:
        output = await interpreter.output()
        if isinstance(output, bytes):
            # Original author's note: "we dont send out bytes rn, no TTS" —
            # this branch may currently be idle.
            await websocket.send_bytes(output)
        elif isinstance(output, dict):
            await websocket.send_text(json.dumps(output))


async def main(server_host, server_port, tts_service):
    """Build the FastAPI app around the interpreter and serve it with uvicorn.

    Endpoints:
        GET  /ping       -- liveness check, returns plain-text "pong".
        POST /load_chat  -- replaces the interpreter's message history.
        WS   /           -- bidirectional stream between client and interpreter.

    Args:
        server_host: Interface for uvicorn to bind.
        server_port: Port for uvicorn to listen on.
        tts_service: Assigned to ``base_interpreter.tts`` before wrapping it
            in an ``AsyncInterpreter``.
    """
    base_interpreter.tts = tts_service
    interpreter = AsyncInterpreter(base_interpreter)

    app = FastAPI()

    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],  # Allow all methods (GET, POST, etc.)
        allow_headers=["*"],  # Allow all headers
    )

    @app.get("/ping")
    async def ping():
        return PlainTextResponse("pong")

    @app.post("/load_chat")
    async def load_chat(messages: List[Dict[str, Any]]):
        # Keep the wrapped interpreter and the async wrapper's view in sync.
        interpreter.interpreter.messages = messages
        interpreter.active_chat_messages = messages
        print("🪼🪼🪼🪼🪼🪼 Messages loaded: ", interpreter.active_chat_messages)
        return {"status": "success"}

    @app.websocket("/")
    async def websocket_endpoint(websocket: WebSocket):
        await websocket.accept()
        tasks = [
            asyncio.create_task(_forward_interpreter_output(websocket, interpreter)),
            asyncio.create_task(_forward_client_input(websocket, interpreter)),
        ]
        try:
            # Stop as soon as either direction finishes (client disconnect or
            # an error); gather() would leave the other loop running forever.
            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            for task in done:
                task.result()  # re-raise any exception from the finished pump
        except Exception as e:
            print(f"WebSocket connection closed with exception: {e}")
            traceback.print_exc()
        finally:
            # Cancel both pumps so no orphaned task keeps consuming
            # interpreter.output() after this connection is gone.
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            if (
                websocket.client_state != WebSocketState.DISCONNECTED
                and websocket.application_state != WebSocketState.DISCONNECTED
            ):
                await websocket.close()

    print(f"Starting server on {server_host}:{server_port}")
    config = Config(app, host=server_host, port=server_port, lifespan="on")
    server = Server(config)
    await server.serve()


if __name__ == "__main__":
    # main() has three required parameters, so the previous bare main() call
    # raised TypeError. Because of the relative imports above, run this as
    # ``python -m <package>.<module>`` rather than as a plain script.
    parser = argparse.ArgumentParser(description="Run the interpreter websocket server.")
    parser.add_argument("--host", default="127.0.0.1", help="Interface to bind (uvicorn's default).")
    parser.add_argument("--port", type=int, default=8000, help="Port to listen on (uvicorn's default).")
    # NOTE(review): assumes a string value is acceptable for interpreter.tts — confirm.
    parser.add_argument("--tts-service", required=True, help="Value assigned to the interpreter's tts attribute.")
    args = parser.parse_args()
    asyncio.run(main(args.host, args.port, args.tts_service))