diff --git a/software/source/clients/base_device.py b/software/source/clients/base_device.py index d67bcc2..188e915 100644 --- a/software/source/clients/base_device.py +++ b/software/source/clients/base_device.py @@ -285,11 +285,49 @@ class Device: await websocket.send(json.dumps(message)) send_queue.task_done() await asyncio.sleep(0.01) - + + async def authenticate(self, websocket): + while True: + # Receive authentication request from the server + auth_request = await websocket.recv() + auth_data = json.loads(auth_request) + + if auth_data["type"] == "auth_request": + # Send authentication response with the token + token = os.getenv("WS_TOKEN") + if token: + auth_response = {"token": token} + await websocket.send(json.dumps(auth_response)) + + # Receive authentication result from the server + auth_result = await websocket.recv() + result_data = json.loads(auth_result) + + if result_data["type"] == "auth_success": + # Authentication successful + return True + else: + # Authentication failed + logger.error("Authentication failed. Closing the connection.") + await websocket.close() + return False + else: + logger.error("WS_TOKEN not found in environment variables.") + await websocket.close() + return False + else: + # Unexpected message from the server + logger.warning(f"Unexpected message from the server: {auth_data}") + async def websocket_communication(self, WS_URL): show_connection_log = True async def exec_ws_communication(websocket): + + # Perform authentication + if not await self.authenticate(websocket): + return # Authentication successful, continue with the rest of the communication + if CAMERA_ENABLED: print( "\nHold the spacebar to start recording. Press 'c' to capture an image from the camera. Press CTRL-C to exit." @@ -347,8 +385,8 @@ class Device: # Workaround for Windows 10 not latching to the websocket server. # See https://github.com/OpenInterpreter/01/issues/197 try: - ws = websockets.connect(WS_URL) - await exec_ws_communication(ws) + async with websockets.connect(WS_URL) as websocket: + await exec_ws_communication(websocket) except Exception as e: logger.error(f"Error while attempting to connect: {e}") else: