Spaces:
Runtime error
Runtime error
| from uuid import UUID | |
| import asyncio | |
| from fastapi import WebSocket | |
| from fastapi.websockets import WebSocketDisconnect | |
| from starlette.websockets import WebSocketState | |
| import logging | |
| from typing import Any | |
| from util import ParamsModel | |
# Per-user session storage: user id -> {"websocket": WebSocket, "queue": asyncio.Queue}.
_SessionEntry = WebSocket | asyncio.Queue
Connections = dict[UUID, dict[str, _SessionEntry]]
class ServerFullException(Exception):
    """Raised when the server is full and cannot accept another user."""
class ConnectionManager:
    """Registry of connected websocket users and their pending-params queues.

    ``active_connections`` maps each user id to a session dict that holds the
    user's ``"websocket"`` and an unbounded ``asyncio.Queue`` under ``"queue"``.
    The ``send_json`` / ``receive_json`` / ``receive_bytes`` helpers do not
    propagate ordinary exceptions. They log the problem, drop the session
    where appropriate, and (for receives) return ``None``.
    """

    # Substrings of ``RuntimeError`` messages meaning the websocket is already
    # closed or closing. The last entry is the message Starlette's
    # ``WebSocket.receive`` raises once a disconnect has been received. None of
    # the earlier entries matched it.
    _CLOSED_SOCKET_ERRORS = (
        "WebSocket is not connected",
        "Cannot call \"send\" once a close message has been sent",
        "Cannot call \"receive\" once a close message has been sent",
        "WebSocket is disconnected",
        "Cannot call \"receive\" once a disconnect message has been received",
    )

    def __init__(self):
        self.active_connections: Connections = {}

    async def connect(
        self, user_id: UUID, websocket: WebSocket, max_queue_size: int = 0
    ):
        """Accept ``websocket`` and register a new session for ``user_id``.

        Args:
            user_id: Identifier of the connecting user.
            websocket: The incoming websocket. It is accepted first, which lets
                the "Server is full" message below be delivered to the client.
            max_queue_size: Maximum number of simultaneous users. ``0`` (the
                default) or any non-positive value disables the limit.

        Raises:
            ServerFullException: If the user limit has been reached. The client
                is sent an error message and the socket is closed before raising.
        """
        await websocket.accept()
        user_count = self.get_user_count()
        print(f"User count: {user_count}")
        if max_queue_size > 0 and user_count >= max_queue_size:
            print("Server is full")
            await websocket.send_json({"status": "error", "message": "Server is full"})
            await websocket.close()
            raise ServerFullException("Server is full")
        print(f"New user connected: {user_id}")
        # NOTE(review): an existing session for the same user_id is replaced
        # without closing its websocket or waking waiters on its queue.
        self.active_connections[user_id] = {
            "websocket": websocket,
            "queue": asyncio.Queue(),
        }
        await websocket.send_json(
            {"status": "connected", "message": "Connected"},
        )
        await websocket.send_json({"status": "wait"})
        await websocket.send_json({"status": "send_frame"})

    def check_user(self, user_id: UUID) -> bool:
        """Return True if ``user_id`` currently has a registered session."""
        return user_id in self.active_connections

    async def update_data(self, user_id: UUID, new_data: ParamsModel):
        """Enqueue ``new_data`` for ``user_id``; ignored if the user has no session."""
        user_session = self.active_connections.get(user_id)
        if user_session:
            queue = user_session["queue"]
            await queue.put(new_data)

    async def get_latest_data(self, user_id: UUID) -> ParamsModel | None:
        """Wait for and return the next queued params for ``user_id``.

        ``asyncio.Queue`` is FIFO, so despite the name this returns the oldest
        pending item. The call blocks until an item is available. Returns
        ``None`` immediately if the user has no session.
        """
        user_session = self.active_connections.get(user_id)
        if not user_session:
            return None
        # ``await Queue.get()`` waits for an item rather than raising
        # ``QueueEmpty``, so there is no empty-queue case to handle here.
        # NOTE(review): ``delete_user`` does not wake a coroutine waiting here,
        # so a waiter on a removed session stays blocked.
        return await user_session["queue"].get()

    def delete_user(self, user_id: UUID):
        """Remove ``user_id``'s session and discard its queued items (no-op if absent)."""
        user_session = self.active_connections.pop(user_id, None)
        if user_session:
            queue = user_session["queue"]
            while not queue.empty():
                try:
                    queue.get_nowait()
                except asyncio.QueueEmpty:
                    break

    def get_user_count(self) -> int:
        """Return the number of registered sessions."""
        return len(self.active_connections)

    def get_websocket(self, user_id: UUID) -> WebSocket | None:
        """Return the user's websocket only while it is fully open, else None."""
        user_session = self.active_connections.get(user_id)
        if user_session:
            websocket = user_session["websocket"]
            # Both client_state and application_state must be CONNECTED: either
            # side may already have started (or finished) closing.
            if (websocket.client_state == WebSocketState.CONNECTED and
                    websocket.application_state == WebSocketState.CONNECTED):
                return websocket
        return None

    async def disconnect(self, user_id: UUID):
        """Close the user's websocket if still open, then remove the session.

        Errors raised while closing are logged, not propagated. The session is
        always removed.
        """
        if user_id not in self.active_connections:
            return
        # Read the websocket directly instead of via get_websocket, so a
        # half-closed socket still gets a close attempt.
        user_session = self.active_connections.get(user_id)
        if user_session and "websocket" in user_session:
            websocket = user_session["websocket"]
            try:
                # Only attempt close if neither side has already disconnected.
                if (websocket.client_state != WebSocketState.DISCONNECTED and
                        websocket.application_state != WebSocketState.DISCONNECTED):
                    await websocket.close()
            except Exception as e:
                logging.error(f"Error closing websocket for {user_id}: {e}")
        self.delete_user(user_id)

    def _is_closed_socket_error(self, error: RuntimeError) -> bool:
        """Return True if ``error``'s message says the websocket is closed/closing."""
        message = str(error)
        return any(fragment in message for fragment in self._CLOSED_SOCKET_ERRORS)

    async def _handle_runtime_error(
        self, user_id: UUID, error: RuntimeError, operation: str
    ) -> None:
        """Disconnect on closed-socket RuntimeErrors; log anything else as an error."""
        if self._is_closed_socket_error(error):
            logging.info(f"WebSocket disconnected/closing for user {user_id}: {error}")
            await self.disconnect(user_id)
        else:
            logging.error(f"Runtime error in {operation}: {error}")

    async def _handle_disconnect_event(
        self, user_id: UUID, error: WebSocketDisconnect, context: str = ""
    ) -> None:
        """Log a WebSocketDisconnect (``context`` is inserted into the message) and drop the session."""
        code = error.code
        if code == 1006:  # ABNORMAL_CLOSURE
            logging.info(f"WebSocket abnormally closed for user {user_id}{context}: Connection was closed without a proper close handshake")
        else:
            logging.info(f"WebSocket disconnected for user {user_id} with code {code}{context}: {error.reason}")
        if user_id in self.active_connections:
            await self.disconnect(user_id)

    async def _handle_unexpected_error(
        self, user_id: UUID, error: Exception, label: str
    ) -> None:
        """Log an unexpected error and drop the session to prevent further failures."""
        logging.error(f"Error: {label}: {error}")
        if user_id in self.active_connections:
            await self.disconnect(user_id)

    async def send_json(self, user_id: UUID, data: dict[str, Any]):
        """Send ``data`` as JSON to ``user_id``; a no-op if the socket is not open.

        Connection failures are logged and cause the session to be dropped.
        Exceptions are not propagated.
        """
        try:
            websocket = self.get_websocket(user_id)
            if not websocket:
                return
            try:
                await websocket.send_json(data)
            except RuntimeError as e:
                await self._handle_runtime_error(user_id, e, "send_json")
        except WebSocketDisconnect as e:
            await self._handle_disconnect_event(user_id, e, " during send")
        except Exception as e:
            await self._handle_unexpected_error(user_id, e, "Send json")

    async def receive_json(self, user_id: UUID) -> dict | None:
        """Receive one JSON object from ``user_id``.

        Returns:
            The decoded dict. Returns ``None`` if the socket is not open, the
            payload is not valid JSON, the payload is not a JSON object, or the
            connection failed. On connection failures the session is dropped.
        """
        try:
            websocket = self.get_websocket(user_id)
            if not websocket:
                return None
            try:
                data = await websocket.receive_json()
            except ValueError as json_err:
                # Malformed JSON: report it but keep the connection open.
                logging.error(f"JSON parsing error for user {user_id}: {json_err}")
                return None
            except RuntimeError as e:
                await self._handle_runtime_error(user_id, e, "receive_json")
                return None
            if not isinstance(data, dict):
                logging.error(f"Expected dict but received {type(data)} from user {user_id}: {data}")
                return None
            return data
        except WebSocketDisconnect as e:
            await self._handle_disconnect_event(user_id, e)
            return None
        except Exception as e:
            await self._handle_unexpected_error(user_id, e, "Receive json")
            return None

    async def receive_bytes(self, user_id: UUID) -> bytes | None:
        """Receive one binary message from ``user_id``.

        Returns:
            The received bytes, or ``None`` if the socket is not open or the
            connection failed. On connection failures the session is dropped.
        """
        try:
            websocket = self.get_websocket(user_id)
            if not websocket:
                return None
            try:
                return await websocket.receive_bytes()
            except RuntimeError as e:
                await self._handle_runtime_error(user_id, e, "receive_bytes")
                return None
        except WebSocketDisconnect as e:
            await self._handle_disconnect_event(user_id, e)
            return None
        except Exception as e:
            await self._handle_unexpected_error(user_id, e, "Receive bytes")
            return None