merge server from temp-branch

pull/279/head
Ben Xu 7 months ago
parent 375ed1f575
commit 456ac51634

```diff
@@ -1,12 +1,19 @@
+# make this obvious
+from .profiles.default import interpreter as base_interpreter
+# from .profiles.fast import interpreter as base_interpreter
+# from .profiles.local import interpreter as base_interpreter
+# TODO: remove files i.py, llm.py, conftest?, services
 import asyncio
 import traceback
 import json
 from fastapi import FastAPI, WebSocket
 from fastapi.responses import PlainTextResponse
 from uvicorn import Config, Server
-from .i import configure_interpreter
-from interpreter import interpreter as base_interpreter
-from starlette.websockets import WebSocketDisconnect
+# from interpreter import interpreter as base_interpreter
 from .async_interpreter import AsyncInterpreter
 from fastapi.middleware.cors import CORSMiddleware
 from typing import List, Dict, Any
```
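The key change in this hunk is the import switch: rather than importing the stock `interpreter` and configuring it inside `main()`, the server now pulls a pre-configured instance from a swappable profile module (`default`, `fast`, or `local`), so changing behavior is a one-line import edit. The latency-oriented settings deleted from `main()` in the next hunk presumably moved into one of these profiles; the sketch below is a hypothetical `profiles/fast.py` assembled only from the settings visible in the removed block, not the actual file contents.

```python
# Hypothetical profiles/fast.py; the real profile is not shown in this diff.
# Settings are copied from the removed `if asynchronous:` branch of main().
import os

from interpreter import interpreter

interpreter.system_message = (
    "You are a helpful assistant that can answer questions and help with tasks."
)
interpreter.computer.import_computer_api = False
interpreter.llm.model = "groq/llama3-8b-8192"
interpreter.llm.api_key = os.environ["GROQ_API_KEY"]
interpreter.llm.supports_functions = False
interpreter.auto_run = True
```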
```diff
@@ -17,23 +24,9 @@ os.environ["STT_RUNNER"] = "server"
 os.environ["TTS_RUNNER"] = "server"
 
-async def main(server_host, server_port, tts_service, asynchronous):
-    if asynchronous:
-        base_interpreter.system_message = (
-            "You are a helpful assistant that can answer questions and help with tasks."
-        )
-        base_interpreter.computer.import_computer_api = False
-        base_interpreter.llm.model = "groq/llama3-8b-8192"
-        base_interpreter.llm.api_key = os.environ["GROQ_API_KEY"]
-        base_interpreter.llm.supports_functions = False
-        base_interpreter.auto_run = True
-        base_interpreter.tts = tts_service
-        interpreter = AsyncInterpreter(base_interpreter)
-    else:
-        configured_interpreter = configure_interpreter(base_interpreter)
-        configured_interpreter.llm.supports_functions = True
-        configured_interpreter.tts = tts_service
-        interpreter = AsyncInterpreter(configured_interpreter)
+async def main(server_host, server_port, tts_service):
+    base_interpreter.tts = tts_service
+    interpreter = AsyncInterpreter(base_interpreter)
 
     app = FastAPI()
```
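`main()` now only attaches the TTS service and wraps the profile's interpreter in `AsyncInterpreter`. The wrapper itself is not part of this diff; judging from its call sites in the websocket handler below (`interpreter.input(...)` and `await interpreter.output()`), its shape is roughly a pair of asyncio queues. The following is an assumed sketch of that interface, not the real implementation:

```python
import asyncio


class AsyncInterpreterSketch:
    """Assumed shape of AsyncInterpreter, inferred from its call sites only."""

    def __init__(self, base_interpreter):
        self.interpreter = base_interpreter
        self._input: asyncio.Queue = asyncio.Queue()   # chunks from the client
        self._output: asyncio.Queue = asyncio.Queue()  # chunks for the client

    async def input(self, chunk) -> None:
        # Audio bytes or text forwarded from the websocket endpoint.
        await self._input.put(chunk)

    async def output(self):
        # Next chunk for the client: audio bytes or a message dict.
        return await self._output.get()
```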
```diff
@@ -59,78 +52,43 @@ async def main(server_host, server_port, tts_service, asynchronous):
     @app.websocket("/")
     async def websocket_endpoint(websocket: WebSocket):
         await websocket.accept()
 
-        async def send_output():
-            try:
-                while True:
-                    output = await interpreter.output()
-                    if isinstance(output, bytes):
-                        try:
-                            await websocket.send_bytes(output)
-                        except Exception as e:
-                            print(f"Error: {e}")
-                            traceback.print_exc()
-                            return {"error": str(e)}
-                    elif isinstance(output, dict):
-                        try:
-                            await websocket.send_text(json.dumps(output))
-                        except Exception as e:
-                            print(f"Error: {e}")
-                            traceback.print_exc()
-                            return {"error": str(e)}
-            except asyncio.CancelledError:
-                print("WebSocket connection closed")
-                traceback.print_exc()
-
-        async def receive_input():
-            try:
-                while True:
-                    # print("server awaiting input")
-                    data = await websocket.receive()
-                    if isinstance(data, bytes):
-                        try:
-                            await interpreter.input(data)
-                        except Exception as e:
-                            print(f"Error: {e}")
-                            traceback.print_exc()
-                            return {"error": str(e)}
-                    elif "bytes" in data:
-                        try:
-                            await interpreter.input(data["bytes"])
-                        except Exception as e:
-                            print(f"Error: {e}")
-                            traceback.print_exc()
-                            return {"error": str(e)}
-                    elif "text" in data:
-                        try:
-                            await interpreter.input(data["text"])
-                        except Exception as e:
-                            print(f"Error: {e}")
-                            traceback.print_exc()
-                            return {"error": str(e)}
-            except asyncio.CancelledError:
-                print("WebSocket connection closed")
-                traceback.print_exc()
-
-        try:
-            send_task = asyncio.create_task(send_output())
-            receive_task = asyncio.create_task(receive_input())
-            await asyncio.gather(send_task, receive_task)
-        except WebSocketDisconnect:
-            print("WebSocket disconnected")
+        try:
+
+            async def receive_input():
+                while True:
+                    if websocket.client_state == "DISCONNECTED":
+                        break
+                    data = await websocket.receive()
+                    if isinstance(data, bytes):
+                        await interpreter.input(data)
+                    elif "bytes" in data:
+                        await interpreter.input(data["bytes"])
+                        # print("RECEIVED INPUT", data)
+                    elif "text" in data:
+                        # print("RECEIVED INPUT", data)
+                        await interpreter.input(data["text"])
+
+            async def send_output():
+                while True:
+                    output = await interpreter.output()
+                    if isinstance(output, bytes):
+                        # print(f"Sending {len(output)} bytes of audio data.")
+                        await websocket.send_bytes(output)
+                        # we dont send out bytes rn, no TTS
+                    elif isinstance(output, dict):
+                        # print("sending text")
+                        await websocket.send_text(json.dumps(output))
+
+            await asyncio.gather(send_output(), receive_input())
         except Exception as e:
             print(f"WebSocket connection closed with exception: {e}")
             traceback.print_exc()
         finally:
-            print("server closing ws connection")
-            await websocket.close()
+            if not websocket.client_state == "DISCONNECTED":
+                await websocket.close()
 
     print(f"Starting server on {server_host}:{server_port}")
```
```diff
@@ -140,4 +98,4 @@ async def main(server_host, server_port, tts_service, asynchronous):
 if __name__ == "__main__":
-    asyncio.run(main("localhost", 8000))
+    asyncio.run(main())
```
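Note that the rewritten `__main__` block calls `main()` with no arguments while the new signature still requires `server_host`, `server_port`, and `tts_service`, so running the module directly would raise a `TypeError`. A direct-run call matching the new signature would be:

```python
if __name__ == "__main__":
    # "coqui" is the tts_service the CLI below selects for --local;
    # pick whichever service the deployment actually uses.
    asyncio.run(main("localhost", 8000, "coqui"))
```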

```diff
@@ -77,9 +77,6 @@ def run(
     mobile: bool = typer.Option(
         False, "--mobile", help="Toggle server to support mobile app"
     ),
-    asynchronous: bool = typer.Option(
-        False, "--async", help="use interpreter optimized for latency"
-    ),
 ):
     _run(
         server=server or mobile,
@@ -102,7 +99,6 @@ def run(
         local=local,
         qr=qr or mobile,
         mobile=mobile,
-        asynchronous=asynchronous,
     )
@@ -127,7 +123,6 @@ def _run(
     local: bool = False,
     qr: bool = False,
     mobile: bool = False,
-    asynchronous: bool = False,
 ):
     if local:
         tts_service = "coqui"
@@ -162,7 +157,6 @@ def _run(
             server_host,
             server_port,
             tts_service,
-            asynchronous,
             # llm_service,
             # model,
             # llm_supports_vision,
```
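With `--async` removed end to end (the typer option, the `run` to `_run` pass-through, and the `main(...)` call), both CLI paths now start the same server. To smoke-test the websocket endpoint by hand, a minimal duplex client along these lines should work; the third-party `websockets` package and the plain-text payload are assumptions, not part of this diff:

```python
import asyncio
import json

import websockets  # third-party client library, assumed for this sketch


async def smoke_test(uri: str = "ws://localhost:8000/") -> None:
    async with websockets.connect(uri) as ws:
        # Text frames land in the server's `elif "text" in data` branch.
        await ws.send("hello")
        # The server sends dicts as JSON text and audio as binary frames.
        while True:  # Ctrl-C to stop
            msg = await ws.recv()
            if isinstance(msg, bytes):
                print(f"audio chunk: {len(msg)} bytes")
            else:
                print(json.loads(msg))


asyncio.run(smoke_test())
```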
