From 3e7d6eadd727ca08f7a0fb58c7cae622b932078f Mon Sep 17 00:00:00 2001 From: Keyvan Hardani Date: Wed, 24 Apr 2024 16:29:36 +0200 Subject: [PATCH] Authentication for WebSocket on server.py Add authentication for WebSocket connections. OAuth2PasswordBearer scheme from FastAPI security --- software/source/server/server.py | 34 ++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/software/source/server/server.py b/software/source/server/server.py index c4dd036..e92b50d 100644 --- a/software/source/server/server.py +++ b/software/source/server/server.py @@ -41,6 +41,7 @@ setup_logging() accumulator = Accumulator() app = FastAPI() +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") app_dir = user_data_dir("01") conversation_history_path = os.path.join(app_dir, "conversations", "user.json") @@ -134,10 +135,43 @@ interpreter = configure_interpreter(interpreter) async def ping(): return PlainTextResponse("pong") +async def authenticate(websocket: WebSocket): + # Send authentication request to the client + await websocket.send_json({"type": "auth_request"}) + # Receive authentication response from the client + try: + auth_response = await websocket.receive_json() + except WebSocketDisconnect: + return False + + # Verify the provided token + token = auth_response.get("token") + if not token: + await websocket.send_json({"type": "auth_failure"}) + await websocket.close() + return False + + try: + # Use the OAuth2PasswordBearer scheme to validate the token + token = await oauth2_scheme(token) + except Exception: + await websocket.send_json({"type": "auth_failure"}) + await websocket.close() + return False + + # Authentication successful + await websocket.send_json({"type": "auth_success"}) + return True + @app.websocket("/") async def websocket_endpoint(websocket: WebSocket): await websocket.accept() + + # Perform authentication + if not await authenticate(websocket): + return + receive_task = asyncio.create_task(receive_messages(websocket)) send_task = asyncio.create_task(send_messages(websocket)) try: