You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
122 lines
3.4 KiB
122 lines
3.4 KiB
import asyncio
|
|
import traceback
|
|
import json
|
|
from fastapi import FastAPI, WebSocket, Depends
|
|
from fastapi.responses import PlainTextResponse
|
|
from uvicorn import Config, Server
|
|
from .async_interpreter import AsyncInterpreter
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from typing import List, Dict, Any
|
|
import os
|
|
import importlib.util
|
|
|
|
|
|
|
|
# Tell the speech-to-text and text-to-speech layers that they run server-side.
os.environ.update({"STT_RUNNER": "server", "TTS_RUNNER": "server"})

# The ASGI application served by `main`.
app = FastAPI()

# Fully permissive CORS policy: any origin, method and header, with credentials.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],  # Allow all headers
    allow_methods=["*"],  # Allow all methods (GET, POST, etc.)
    allow_credentials=True,
    allow_origins=["*"],
)
|
|
|
|
|
|
async def get_debug_flag():
    """Dependency that returns the `debug` value stored on the app state."""
    state = app.state
    return state.debug
|
|
|
|
|
|
@app.get("/ping")
async def ping():
    """Liveness probe: reply to GET /ping with the plain-text body "pong"."""
    reply = PlainTextResponse(content="pong")
    return reply
|
|
|
|
|
|
@app.websocket("/")
async def websocket_endpoint(
    websocket: WebSocket, debug: bool = Depends(get_debug_flag)
):
    """Bridge one WebSocket client to the module-level interpreter.

    After accepting the connection, a ``config`` message carrying the
    interpreter's ``tts`` setting is sent to the client. Two coroutines then
    run concurrently: one forwards every received binary or text frame to
    ``interpreter.input``, the other relays whatever ``interpreter.output``
    yields back to the client (``bytes`` as a binary frame, ``dict`` as JSON
    text). As soon as either side finishes or fails, the other is cancelled
    and the socket is closed if it is still open.

    Args:
        websocket: The incoming WebSocket connection.
        debug: Value injected from ``get_debug_flag``; not used in this body.
    """
    # Imported locally so this handler brings its own dependency into scope;
    # fastapi re-exports Starlette's connection-state enum.
    from fastapi.websockets import WebSocketState

    await websocket.accept()

    global global_interpreter
    interpreter = global_interpreter

    # Send the tts_service value to the client
    await websocket.send_text(
        json.dumps({"type": "config", "tts_service": interpreter.interpreter.tts})
    )

    async def receive_input():
        """Forward client frames to the interpreter until the client leaves."""
        while True:
            # client_state is an enum member; comparing it with the string
            # "DISCONNECTED" is always False, so compare against the enum.
            # Starlette flips it once a disconnect message has been received,
            # which lets this loop end instead of calling receive() again.
            if websocket.client_state == WebSocketState.DISCONNECTED:
                break

            data = await websocket.receive()

            if isinstance(data, bytes):
                await interpreter.input(data)
            # ASGI allows the unused payload key to be present with a None
            # value, so test the value rather than mere key membership.
            elif data.get("bytes") is not None:
                await interpreter.input(data["bytes"])
            elif data.get("text") is not None:
                await interpreter.input(data["text"])

    async def send_output():
        """Relay interpreter output to the client."""
        while True:
            output = await interpreter.output()

            if isinstance(output, bytes):
                await websocket.send_bytes(output)
            elif isinstance(output, dict):
                await websocket.send_text(json.dumps(output))

    # Explicit tasks (rather than a bare gather) so the surviving coroutine
    # can be cancelled; otherwise send_output would keep consuming the shared
    # interpreter's output on behalf of a connection that is already gone.
    tasks = [
        asyncio.create_task(send_output()),
        asyncio.create_task(receive_input()),
    ]

    try:
        done, _pending = await asyncio.wait(
            tasks, return_when=asyncio.FIRST_COMPLETED
        )
        for finished in done:
            # Re-raise whatever ended the task so it is reported below.
            finished.result()
    except Exception as e:
        print(f"WebSocket connection closed with exception: {e}")
        traceback.print_exc()
    finally:
        for task in tasks:
            task.cancel()  # no-op for tasks that already finished
        # Wait for the cancellations to land and retrieve any leftover
        # exceptions so they are not reported as "never retrieved".
        await asyncio.gather(*tasks, return_exceptions=True)

        # close() raises if either side of the connection is already closed.
        already_closed = WebSocketState.DISCONNECTED in (
            websocket.client_state,
            websocket.application_state,
        )
        if not already_closed:
            await websocket.close()
|
|
|
|
|
|
async def main(server_host, server_port, profile, debug):
    """Load the interpreter from a profile file and serve the app with uvicorn.

    Args:
        server_host: Host passed to the uvicorn ``Config``.
        server_port: Port passed to the uvicorn ``Config``.
        profile: Filesystem path of a Python file that defines a module-level
            ``interpreter`` object.
        debug: Stored on ``app.state.debug`` and passed to ``AsyncInterpreter``.

    Raises:
        ImportError: If no import spec or loader can be built for ``profile``
            (for example, the path does not have a recognised Python suffix).
    """
    # Load the profile module from the provided path. This is done before any
    # shared state is touched so that a bad path fails fast.
    spec = importlib.util.spec_from_file_location("profile", profile)
    if spec is None or spec.loader is None:
        # spec_from_file_location returns None when it cannot pick a loader;
        # fail with a clear message instead of an AttributeError on None.
        raise ImportError(f"Cannot load profile module from path: {profile!r}")
    profile_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(profile_module)

    app.state.debug = debug

    # Get the interpreter from the profile
    interpreter = profile_module.interpreter

    if not hasattr(interpreter, 'tts'):
        print("Setting TTS provider to default: openai")
        interpreter.tts = "openai"

    # Make it async
    interpreter = AsyncInterpreter(interpreter, debug)

    # Publish the wrapped interpreter for the WebSocket handler to pick up.
    global global_interpreter
    global_interpreter = interpreter

    print(f"Starting server on {server_host}:{server_port}")
    config = Config(app, host=server_host, port=server_port, lifespan="on")
    server = Server(config)
    await server.serve()
|
|
|
|
|
|
if __name__ == "__main__":
    # main() has four required parameters, so calling it with no arguments
    # always raised TypeError. Collect them from the command line instead;
    # host, port and profile are all required so no defaults are invented here.
    import argparse

    parser = argparse.ArgumentParser(
        description="Run the interpreter WebSocket server."
    )
    parser.add_argument("--host", required=True, help="host to bind the server to")
    parser.add_argument(
        "--port", required=True, type=int, help="port to bind the server to"
    )
    parser.add_argument(
        "--profile",
        required=True,
        help="path to a Python profile file defining `interpreter`",
    )
    parser.add_argument("--debug", action="store_true", help="enable debug mode")
    args = parser.parse_args()

    asyncio.run(main(args.host, args.port, args.profile, args.debug))
|