You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							165 lines
						
					
					
						
							5.7 KiB
						
					
					
				
			
		
		
	
	
							165 lines
						
					
					
						
							5.7 KiB
						
					
					
				import asyncio
 | 
						|
import traceback
 | 
						|
import json
 | 
						|
from fastapi import FastAPI, WebSocket, Header
 | 
						|
from fastapi.responses import PlainTextResponse
 | 
						|
from uvicorn import Config, Server
 | 
						|
from interpreter import interpreter as base_interpreter
 | 
						|
from starlette.websockets import WebSocketDisconnect
 | 
						|
from .async_interpreter import AsyncInterpreter
 | 
						|
from fastapi.middleware.cors import CORSMiddleware
 | 
						|
from typing import List, Dict, Any
 | 
						|
from openai import OpenAI
 | 
						|
from pydantic import BaseModel
 | 
						|
import argparse
 | 
						|
import os
 | 
						|
 | 
						|
# import sentry_sdk
 | 
						|
 | 
						|
# --- Module-level configuration of the shared base interpreter -------------
# These assignments run at import time, before any server is started.

# Prompt installed as the interpreter's system message.
base_interpreter.system_message = (
    "You are a helpful assistant that can answer questions and help with tasks."
)
base_interpreter.computer.import_computer_api = False

# LLM settings. The API key is read from the environment with a plain
# subscript, so importing this module raises KeyError when GROQ_API_KEY
# is not set.
base_interpreter.llm.model = "groq/llama3-8b-8192"
base_interpreter.llm.api_key = os.environ["GROQ_API_KEY"]
base_interpreter.llm.supports_functions = False

base_interpreter.auto_run = True
base_interpreter.tts = "elevenlabs"

# Both runner variables are set to "server" for this process.
# NOTE(review): presumably read by the STT/TTS components elsewhere in the
# project -- confirm against their consumers.
os.environ["STT_RUNNER"] = "server"
os.environ["TTS_RUNNER"] = "server"
 | 
						|
 | 
						|
 | 
						|
async def main(server_host, server_port):
    """Build the FastAPI application and serve it with uvicorn.

    The app exposes three routes:

    * ``GET /ping`` -- returns the plain-text body ``pong``.
    * ``POST /load_chat`` -- replaces the interpreter's message history with
      the posted list of messages.
    * ``WS /`` -- a WebSocket that forwards incoming frames to
      ``interpreter.input`` and streams ``interpreter.output`` back.

    Args:
        server_host: Host passed to the uvicorn ``Config``.
        server_port: Port passed to the uvicorn ``Config``.

    Returns:
        None. The coroutine completes when ``server.serve()`` returns.
    """
    interpreter = AsyncInterpreter(base_interpreter)

    app = FastAPI()

    # NOTE(review): wildcard origins combined with allow_credentials=True is
    # a very permissive CORS policy -- confirm this is intended outside of
    # local development.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],  # Allow all methods (GET, POST, etc.)
        allow_headers=["*"],  # Allow all headers
    )

    @app.get("/ping")
    async def ping():
        """Liveness probe: always answers ``pong``."""
        return PlainTextResponse("pong")

    @app.post("/load_chat")
    async def load_chat(messages: List[Dict[str, Any]]):
        """Replace the current conversation with the posted ``messages``."""
        interpreter.interpreter.messages = messages
        interpreter.active_chat_messages = messages
        print("🪼🪼🪼🪼🪼🪼 Messages loaded: ", interpreter.active_chat_messages)
        return {"status": "success"}

    print("About to set up the websocker endpoint!!!!!!!!!!!!!!!!!!!!!!!!!")

    @app.websocket("/")
    async def websocket_endpoint(websocket: WebSocket):
        """Bridge one WebSocket connection to the interpreter.

        Two pump tasks run concurrently, one per direction. As soon as
        either one finishes or fails, the other is cancelled so that no
        task outlives the connection.
        """
        print("websocket hit")
        await websocket.accept()
        print("websocket accepted")

        async def send_output():
            """Forward interpreter output to the client until a send fails."""
            while True:
                output = await interpreter.output()

                if isinstance(output, bytes):
                    print("server sending bytes output")
                    try:
                        await websocket.send_bytes(output)
                        print("server successfully sent bytes output")
                    except Exception as e:
                        print(f"Error: {e}")
                        traceback.print_exc()
                        return

                elif isinstance(output, dict):
                    print("server sending text output")
                    try:
                        await websocket.send_text(json.dumps(output))
                        print("server successfully sent text output")
                    except Exception as e:
                        print(f"Error: {e}")
                        traceback.print_exc()
                        return

        async def receive_input():
            """Forward client frames to the interpreter until disconnect."""
            while True:
                print("server awaiting input")
                message = await websocket.receive()

                # receive() yields ASGI message dicts. After a disconnect
                # message, calling receive() again raises, so stop here.
                if message.get("type") == "websocket.disconnect":
                    print("WebSocket disconnected")
                    return

                # A message may carry a "bytes"/"text" key whose value is
                # None, so select on the value rather than key presence.
                payload = message.get("bytes")
                if payload is None:
                    payload = message.get("text")
                if payload is None:
                    continue

                try:
                    await interpreter.input(payload)
                except Exception as e:
                    print(f"Error: {e}")
                    traceback.print_exc()
                    return

        send_task = asyncio.create_task(send_output())
        receive_task = asyncio.create_task(receive_input())

        try:
            print("server starting to handle ws connection")

            # Wait for the first pump to end; gather() would keep waiting on
            # the other one and never cancel it.
            done, _pending = await asyncio.wait(
                {send_task, receive_task},
                return_when=asyncio.FIRST_COMPLETED,
            )
            for task in done:
                task.result()  # re-raises the pump's exception, if any

            print("server finished handling ws connection")

        except WebSocketDisconnect:
            print("WebSocket disconnected")
        except Exception as e:
            print(f"WebSocket connection closed with exception: {e}")
            traceback.print_exc()
        finally:
            # Stop whichever pump is still running and wait for both to
            # settle so no exception is left unretrieved.
            for task in (send_task, receive_task):
                if not task.done():
                    task.cancel()
            await asyncio.gather(send_task, receive_task, return_exceptions=True)

            print("server closing ws connection")
            try:
                await websocket.close()
            except (RuntimeError, WebSocketDisconnect):
                # The connection was already closed (e.g. the client went
                # away first); there is nothing left to close.
                pass

    print(f"Starting server on {server_host}:{server_port}")
    config = Config(app, host=server_host, port=server_port, lifespan="on")
    server = Server(config)
    await server.serve()
 | 
						|
 | 
						|
 | 
						|
if __name__ == "__main__":
    # Host and port default to the previously hard-coded values, so running
    # the module without arguments behaves exactly as before.
    parser = argparse.ArgumentParser(
        description="Run the interpreter WebSocket/HTTP server."
    )
    parser.add_argument(
        "--host", default="localhost", help="host to bind (default: localhost)"
    )
    parser.add_argument(
        "--port", type=int, default=8000, help="port to bind (default: 8000)"
    )
    args = parser.parse_args()

    asyncio.run(main(args.host, args.port))
 |