merge server from temp-branch

pull/279/head
Ben Xu 7 months ago
parent 375ed1f575
commit 456ac51634

@ -1,12 +1,19 @@
# make this obvious
from .profiles.default import interpreter as base_interpreter
# from .profiles.fast import interpreter as base_interpreter
# from .profiles.local import interpreter as base_interpreter
# TODO: remove files i.py, llm.py, conftest?, services
import asyncio
import traceback
import json
from fastapi import FastAPI, WebSocket
from fastapi.responses import PlainTextResponse
from uvicorn import Config, Server
from .i import configure_interpreter
from interpreter import interpreter as base_interpreter
from starlette.websockets import WebSocketDisconnect
# from interpreter import interpreter as base_interpreter
from .async_interpreter import AsyncInterpreter
from fastapi.middleware.cors import CORSMiddleware
from typing import List, Dict, Any
@ -17,23 +24,9 @@ os.environ["STT_RUNNER"] = "server"
os.environ["TTS_RUNNER"] = "server"
async def main(server_host, server_port, tts_service, asynchronous):
if asynchronous:
base_interpreter.system_message = (
"You are a helpful assistant that can answer questions and help with tasks."
)
base_interpreter.computer.import_computer_api = False
base_interpreter.llm.model = "groq/llama3-8b-8192"
base_interpreter.llm.api_key = os.environ["GROQ_API_KEY"]
base_interpreter.llm.supports_functions = False
base_interpreter.auto_run = True
async def main(server_host, server_port, tts_service):
base_interpreter.tts = tts_service
interpreter = AsyncInterpreter(base_interpreter)
else:
configured_interpreter = configure_interpreter(base_interpreter)
configured_interpreter.llm.supports_functions = True
configured_interpreter.tts = tts_service
interpreter = AsyncInterpreter(configured_interpreter)
app = FastAPI()
@ -59,78 +52,43 @@ async def main(server_host, server_port, tts_service, asynchronous):
@app.websocket("/")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
async def send_output():
try:
while True:
output = await interpreter.output()
if isinstance(output, bytes):
try:
await websocket.send_bytes(output)
except Exception as e:
print(f"Error: {e}")
traceback.print_exc()
return {"error": str(e)}
elif isinstance(output, dict):
try:
await websocket.send_text(json.dumps(output))
except Exception as e:
print(f"Error: {e}")
traceback.print_exc()
return {"error": str(e)}
except asyncio.CancelledError:
print("WebSocket connection closed")
traceback.print_exc()
async def receive_input():
try:
while True:
# print("server awaiting input")
if websocket.client_state == "DISCONNECTED":
break
data = await websocket.receive()
if isinstance(data, bytes):
try:
await interpreter.input(data)
except Exception as e:
print(f"Error: {e}")
traceback.print_exc()
return {"error": str(e)}
elif "bytes" in data:
try:
await interpreter.input(data["bytes"])
except Exception as e:
print(f"Error: {e}")
traceback.print_exc()
return {"error": str(e)}
# print("RECEIVED INPUT", data)
elif "text" in data:
try:
# print("RECEIVED INPUT", data)
await interpreter.input(data["text"])
except Exception as e:
print(f"Error: {e}")
traceback.print_exc()
return {"error": str(e)}
except asyncio.CancelledError:
print("WebSocket connection closed")
traceback.print_exc()
try:
send_task = asyncio.create_task(send_output())
receive_task = asyncio.create_task(receive_input())
async def send_output():
while True:
output = await interpreter.output()
await asyncio.gather(send_task, receive_task)
if isinstance(output, bytes):
# print(f"Sending {len(output)} bytes of audio data.")
await websocket.send_bytes(output)
# we dont send out bytes rn, no TTS
elif isinstance(output, dict):
# print("sending text")
await websocket.send_text(json.dumps(output))
except WebSocketDisconnect:
print("WebSocket disconnected")
await asyncio.gather(send_output(), receive_input())
except Exception as e:
print(f"WebSocket connection closed with exception: {e}")
traceback.print_exc()
finally:
print("server closing ws connection")
if not websocket.client_state == "DISCONNECTED":
await websocket.close()
print(f"Starting server on {server_host}:{server_port}")
@ -140,4 +98,4 @@ async def main(server_host, server_port, tts_service, asynchronous):
if __name__ == "__main__":
asyncio.run(main("localhost", 8000))
asyncio.run(main())

@ -77,9 +77,6 @@ def run(
mobile: bool = typer.Option(
False, "--mobile", help="Toggle server to support mobile app"
),
asynchronous: bool = typer.Option(
False, "--async", help="use interpreter optimized for latency"
),
):
_run(
server=server or mobile,
@ -102,7 +99,6 @@ def run(
local=local,
qr=qr or mobile,
mobile=mobile,
asynchronous=asynchronous,
)
@ -127,7 +123,6 @@ def _run(
local: bool = False,
qr: bool = False,
mobile: bool = False,
asynchronous: bool = False,
):
if local:
tts_service = "coqui"
@ -162,7 +157,6 @@ def _run(
server_host,
server_port,
tts_service,
asynchronous,
# llm_service,
# model,
# llm_supports_vision,

Loading…
Cancel
Save