|
import json |
|
import logging |
|
import uuid |
|
from datetime import UTC, datetime |
|
|
|
from fastapi import WebSocket, WebSocketDisconnect |
|
from typing_extensions import TypedDict |
|
|
|
from .models import MessageType, ParticipantRole |
|
|
|
# Module-level logger, named after this module, used for all log output below.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
class ConnectionMetadata(TypedDict):
    """Connection metadata with proper typing.

    One entry per connected participant, stored by ``RoboticsCore`` in its
    ``connection_metadata`` dict keyed by participant id.
    """

    # Workspace the participant joined.
    workspace_id: str
    # Room (within ``workspace_id``) the participant joined.
    room_id: str
    # Participant identifier taken from the client's join message.
    participant_id: str
    # Whether the participant joined as a producer or a consumer.
    role: ParticipantRole
    # Timezone-aware (UTC) time at which the connection was registered.
    connected_at: datetime
    # Timezone-aware (UTC) time of the most recently handled message.
    last_activity: datetime
    # Number of messages handled on this connection so far.
    message_count: int
|
|
|
|
|
|
|
|
|
|
|
class RoboticsRoom:
    """In-memory state of one room: an optional producer, consumers, joints.

    Attributes:
        id: Identifier of the room.
        workspace_id: Identifier of the workspace that owns the room.
        producer: Participant id of the producer, or None while the slot is free.
        consumers: Participant ids of the consumers, in the order they joined.
        joints: Latest known value for each joint, keyed by joint name.
    """

    def __init__(self, room_id: str, workspace_id: str):
        # Identity: which room this is and which workspace it belongs to.
        self.id, self.workspace_id = room_id, workspace_id

        # Membership starts empty: no producer yet and no consumers.
        self.producer: str | None = None
        self.consumers: list[str] = list()

        # Joint state starts empty until a producer sends values.
        self.joints: dict[str, float] = dict()
|
|
|
|
|
class RoboticsCore:
    """Core robotics system - simplified and merged with workspace support.

    All state is in memory:

    * ``workspaces`` maps ``workspace_id -> {room_id -> RoboticsRoom}``.
    * ``connections`` maps ``participant_id -> WebSocket`` for every
      participant whose WebSocket join succeeded.
    * ``connection_metadata`` maps ``participant_id -> ConnectionMetadata``
      for the same participants; the two dicts are updated together.

    Participant ids are global keys in ``connections``, so one id can be
    attached to at most one live WebSocket at a time.
    """

    def __init__(self):
        # workspace_id -> room_id -> room
        self.workspaces: dict[str, dict[str, RoboticsRoom]] = {}
        # participant_id -> live WebSocket
        self.connections: dict[str, WebSocket] = {}
        # participant_id -> bookkeeping for that connection
        self.connection_metadata: dict[str, ConnectionMetadata] = {}

    # ------------------------------------------------------------------
    # Room management
    # ------------------------------------------------------------------

    def create_room(
        self, workspace_id: str | None = None, room_id: str | None = None
    ) -> tuple[str, str]:
        """Create a new room and return ``(workspace_id, room_id)``.

        Missing (or empty) ids are replaced by fresh UUID4 strings. If a room
        with these ids already exists it is kept as-is rather than replaced,
        so its connected participants and joint state are not orphaned.
        """
        workspace_id = workspace_id or str(uuid.uuid4())
        room_id = room_id or str(uuid.uuid4())

        rooms = self.workspaces.setdefault(workspace_id, {})
        if room_id in rooms:
            # Replacing the room would silently discard its producer,
            # consumers and joints while their WebSockets stay connected.
            logger.warning(
                "Room %s already exists in workspace %s; keeping existing room",
                room_id,
                workspace_id,
            )
            return workspace_id, room_id

        rooms[room_id] = RoboticsRoom(room_id, workspace_id)
        logger.info("Created room %s in workspace %s", room_id, workspace_id)
        return workspace_id, room_id

    def list_rooms(self, workspace_id: str) -> list[dict]:
        """List all rooms in a workspace (an empty list if it is unknown)."""
        return [
            {
                "id": room.id,
                "workspace_id": room.workspace_id,
                "participants": self._participants_summary(room),
                "joints_count": len(room.joints),
            }
            for room in self.workspaces.get(workspace_id, {}).values()
        ]

    def delete_room(self, workspace_id: str, room_id: str) -> bool:
        """Delete a room from a workspace.

        Every participant is first removed from the room via ``leave_room``.
        Their WebSocket connections and connection metadata are not touched
        here; the workspace entry itself is kept even if it becomes empty.

        Returns:
            True if the room existed and was deleted, False otherwise.
        """
        room = self._get_room(workspace_id, room_id)
        if room is None:
            return False

        # Iterate over a copy because leave_room mutates room.consumers.
        for consumer_id in list(room.consumers):
            self.leave_room(workspace_id, room_id, consumer_id)
        if room.producer:
            self.leave_room(workspace_id, room_id, room.producer)

        del self.workspaces[workspace_id][room_id]
        logger.info("Deleted room %s from workspace %s", room_id, workspace_id)
        return True

    def get_room_state(self, workspace_id: str, room_id: str) -> dict:
        """Get detailed room state, including a copy of the joint values.

        Returns ``{"error": "Room not found"}`` if the room does not exist.
        """
        room = self._get_room(workspace_id, room_id)
        if not room:
            return {"error": "Room not found"}

        return {
            "room_id": room_id,
            "workspace_id": workspace_id,
            # Copy so callers cannot mutate the room through the result.
            "joints": dict(room.joints),
            "participants": self._participants_summary(room),
            "timestamp": datetime.now(tz=UTC).isoformat(),
        }

    def get_room_info(self, workspace_id: str, room_id: str) -> dict:
        """Get basic room info (participants and joint count, no values).

        Returns ``{"error": "Room not found"}`` if the room does not exist.
        """
        room = self._get_room(workspace_id, room_id)
        if not room:
            return {"error": "Room not found"}

        return {
            "id": room.id,
            "workspace_id": room.workspace_id,
            "participants": self._participants_summary(room),
            "joints_count": len(room.joints),
            "has_producer": room.producer is not None,
            "active_consumers": len(room.consumers),
        }

    def _get_room(self, workspace_id: str, room_id: str) -> RoboticsRoom | None:
        """Return the room, or None if the workspace or room is unknown."""
        return self.workspaces.get(workspace_id, {}).get(room_id)

    @staticmethod
    def _participants_summary(room: RoboticsRoom) -> dict:
        """Build the ``participants`` section shared by the room views.

        The consumer list is copied so the returned dict does not alias the
        room's internal list.
        """
        consumers = list(room.consumers)
        return {
            "producer": room.producer,
            "consumers": consumers,
            "total": len(consumers) + (1 if room.producer else 0),
        }

    # ------------------------------------------------------------------
    # Membership
    # ------------------------------------------------------------------

    def join_room(
        self,
        workspace_id: str,
        room_id: str,
        participant_id: str,
        role: ParticipantRole,
    ) -> bool:
        """Join room as producer or consumer.

        A room holds at most one producer and any number of distinct
        consumers.

        Returns:
            True if the participant was added. False if the room does not
            exist, the producer slot is already taken, the consumer is
            already in the room, or ``role`` is neither PRODUCER nor CONSUMER.
        """
        room = self._get_room(workspace_id, room_id)
        if not room:
            return False

        if role == ParticipantRole.PRODUCER:
            if room.producer is not None:
                logger.warning(
                    "Producer %s failed to join room %s - room already has producer %s",
                    participant_id,
                    room_id,
                    room.producer,
                )
                return False
            room.producer = participant_id
            logger.info(
                "Producer %s joined room %s in workspace %s",
                participant_id,
                room_id,
                workspace_id,
            )
            return True

        if role == ParticipantRole.CONSUMER:
            if participant_id in room.consumers:
                return False
            room.consumers.append(participant_id)
            logger.info(
                "Consumer %s joined room %s in workspace %s",
                participant_id,
                room_id,
                workspace_id,
            )
            return True

        return False

    def leave_room(self, workspace_id: str, room_id: str, participant_id: str):
        """Remove a participant from a room; a no-op if it is not a member.

        If the participant is the producer, only the producer slot is
        cleared.
        """
        room = self._get_room(workspace_id, room_id)
        if not room:
            return

        if room.producer == participant_id:
            room.producer = None
            logger.info(
                "Producer %s left room %s in workspace %s",
                participant_id,
                room_id,
                workspace_id,
            )
        elif participant_id in room.consumers:
            room.consumers.remove(participant_id)
            logger.info(
                "Consumer %s left room %s in workspace %s",
                participant_id,
                room_id,
                workspace_id,
            )

    # ------------------------------------------------------------------
    # Joint state
    # ------------------------------------------------------------------

    def update_joints(
        self, workspace_id: str, room_id: str, joint_updates: list[dict]
    ) -> list[dict]:
        """Apply joint values to a room and return the entries that changed.

        Each entry must provide ``"name"`` and ``"value"`` keys. All entries
        are read before any is applied, so a malformed entry leaves the
        room's joints untouched instead of half-updated.

        Returns:
            The input entries whose value differed from the stored one, in
            input order (the entries themselves, not copies).

        Raises:
            ValueError: if the room does not exist.
            KeyError: if an entry lacks ``"name"`` or ``"value"``.
            TypeError: if an entry is not subscriptable by string keys.
        """
        room = self._get_room(workspace_id, room_id)
        if not room:
            msg = f"Room {room_id} not found in workspace {workspace_id}"
            raise ValueError(msg)

        # Validation pass: raises before any joint is mutated.
        parsed = [(joint, joint["name"], joint["value"]) for joint in joint_updates]

        changed_joints = []
        for joint, name, value in parsed:
            # Only record entries that actually change the stored value.
            if room.joints.get(name) != value:
                room.joints[name] = value
                changed_joints.append(joint)

        return changed_joints

    # ------------------------------------------------------------------
    # WebSocket handling
    # ------------------------------------------------------------------

    async def handle_websocket(
        self, websocket: WebSocket, workspace_id: str, room_id: str
    ):
        """Handle one participant's WebSocket connection until it ends.

        Flow:
          1. Accept the socket and read a JSON join message with
             ``participant_id`` (non-empty string) and ``role``.
          2. Reject (ERROR message, then close) a malformed join, an id that
             is already connected, or a join that ``join_room`` refuses.
          3. Register the connection. Consumers get a STATE_SYNC with the
             current joints; every participant then gets JOINED.
          4. Dispatch each further text frame to ``_handle_message``.

        Cleanup only undoes this call's own registration, so a rejected or
        malformed join can never tear down another live connection that
        uses the same participant id.
        """
        await websocket.accept()

        participant_id: str | None = None
        registered = False

        try:
            try:
                join_msg = json.loads(await websocket.receive_text())
                participant_id = join_msg["participant_id"]
                role = ParticipantRole(join_msg["role"])
                if not isinstance(participant_id, str) or not participant_id:
                    msg = "participant_id must be a non-empty string"
                    raise ValueError(msg)
            except (KeyError, TypeError, ValueError):
                # json.JSONDecodeError is a ValueError subclass.
                logger.exception("Invalid join message")
                await self._reject(websocket, "Invalid join message")
                return

            if participant_id in self.connections:
                # Ids key self.connections; accepting would overwrite the
                # other live connection.
                logger.warning(
                    "Participant %s is already connected; rejecting join",
                    participant_id,
                )
                await self._reject(websocket, "Cannot join room")
                return

            if not self.join_room(workspace_id, room_id, participant_id, role):
                await self._reject(websocket, "Cannot join room")
                return

            now = datetime.now(tz=UTC)
            self.connections[participant_id] = websocket
            self.connection_metadata[participant_id] = ConnectionMetadata(
                workspace_id=workspace_id,
                room_id=room_id,
                participant_id=participant_id,
                role=role,
                connected_at=now,
                last_activity=now,
                message_count=0,
            )
            registered = True

            if role == ParticipantRole.CONSUMER:
                room = self._get_room(workspace_id, room_id)
                if room:
                    await websocket.send_text(
                        json.dumps({
                            "type": MessageType.STATE_SYNC.value,
                            "data": room.joints,
                            "timestamp": datetime.now(tz=UTC).isoformat(),
                        })
                    )

            await websocket.send_text(
                json.dumps({
                    "type": MessageType.JOINED.value,
                    "room_id": room_id,
                    "workspace_id": workspace_id,
                    "role": role.value,
                    "timestamp": datetime.now(tz=UTC).isoformat(),
                })
            )

            async for message in websocket.iter_text():
                if self.connections.get(participant_id) is not websocket:
                    # This connection was dropped after a failed send (and
                    # the id may now belong to another socket): stop acting
                    # on its behalf.
                    break
                try:
                    msg = json.loads(message)
                except json.JSONDecodeError:
                    logger.exception("Invalid JSON from %s", participant_id)
                    continue
                try:
                    await self._handle_message(
                        workspace_id, room_id, participant_id, role, msg
                    )
                except Exception:
                    logger.exception("Message error")

        except WebSocketDisconnect:
            logger.info("WebSocket disconnected: %s", participant_id)
        except Exception:
            logger.exception("WebSocket error")
        finally:
            if registered:
                # No-op if this socket was already dropped or replaced.
                self._drop_participant(participant_id, websocket)

    async def _reject(self, websocket: WebSocket, reason: str) -> None:
        """Send an ERROR message with ``reason`` and close ``websocket``.

        Best effort: failures (e.g. the peer already left) are only logged.
        """
        try:
            await websocket.send_text(
                json.dumps({
                    "type": MessageType.ERROR.value,
                    "message": reason,
                })
            )
            await websocket.close()
        except Exception:
            logger.exception("Error rejecting WebSocket connection")

    async def _handle_message(
        self,
        workspace_id: str,
        room_id: str,
        participant_id: str,
        role: ParticipantRole,
        message: dict,
    ):
        """Record activity for the sender and dispatch ``message`` by type.

        Messages that are not JSON objects or have an unknown ``type`` are
        answered with an ERROR message.
        """
        metadata = self.connection_metadata.get(participant_id)
        if metadata is not None:
            metadata["last_activity"] = datetime.now(tz=UTC)
            metadata["message_count"] += 1

        if not isinstance(message, dict):
            # e.g. a JSON array or number: there is no "type" field to read.
            logger.warning("Non-object message from %s", participant_id)
            await self._handle_error(participant_id, "Message must be a JSON object")
            return

        try:
            msg_type = MessageType(message.get("type"))
        except ValueError:
            logger.warning(
                "Unknown message type from %s: %s",
                participant_id,
                message.get("type"),
            )
            await self._handle_error(
                participant_id, f"Unknown message type: {message.get('type')}"
            )
            return

        if msg_type == MessageType.JOINT_UPDATE:
            await self._handle_joint_update(
                workspace_id, room_id, participant_id, role, message
            )
        elif msg_type == MessageType.HEARTBEAT:
            await self._handle_heartbeat(participant_id)
        elif msg_type == MessageType.EMERGENCY_STOP:
            await self._handle_emergency_stop(
                workspace_id, room_id, participant_id, message
            )
        else:
            logger.warning(
                "Unhandled message type %s from %s", msg_type, participant_id
            )

    async def _handle_joint_update(
        self,
        workspace_id: str,
        room_id: str,
        participant_id: str,
        role: ParticipantRole,
        message: dict,
    ):
        """Apply a producer's joint update and broadcast the changed joints.

        Non-producers and empty ``data`` are ignored (with a warning). Any
        failure is logged and reported to the sender as an ERROR message.
        """
        if role != ParticipantRole.PRODUCER:
            logger.warning(
                "Non-producer %s attempted to send joint update", participant_id
            )
            return

        joints = message.get("data", [])
        if not joints:
            logger.warning("Empty joint data from producer %s", participant_id)
            return

        try:
            changed_joints = self.update_joints(workspace_id, room_id, joints)
            if changed_joints:
                await self._broadcast_joint_update(
                    workspace_id, room_id, changed_joints, participant_id
                )
                logger.debug(
                    "Producer %s sent %d joint updates",
                    participant_id,
                    len(changed_joints),
                )
        except Exception:
            logger.exception("Error processing joint update from %s", participant_id)
            await self._handle_error(participant_id, "Failed to process joint update")

    async def _handle_heartbeat(self, participant_id: str):
        """Reply to a heartbeat with HEARTBEAT_ACK."""
        try:
            await self._send_to_participant(
                participant_id,
                {
                    "type": MessageType.HEARTBEAT_ACK.value,
                    "timestamp": datetime.now(tz=UTC).isoformat(),
                },
            )
            logger.debug("Heartbeat acknowledged for %s", participant_id)
        except Exception:
            logger.exception("Error handling heartbeat from %s", participant_id)

    async def _handle_emergency_stop(
        self, workspace_id: str, room_id: str, participant_id: str, message: dict
    ):
        """Broadcast an EMERGENCY_STOP to everyone in the room, sender included.

        The optional ``reason`` field of ``message`` is forwarded; it
        defaults to a text naming the sender.
        """
        try:
            reason = message.get("reason", f"Emergency stop from {participant_id}")
            emergency_message = {
                "type": MessageType.EMERGENCY_STOP.value,
                "timestamp": datetime.now(tz=UTC).isoformat(),
                "reason": reason,
                "source": participant_id,
            }
            await self._broadcast_to_all_participants(
                workspace_id, room_id, emergency_message
            )
            logger.warning(
                "🚨 Emergency stop triggered by %s in room %s (workspace %s)",
                participant_id,
                room_id,
                workspace_id,
            )
        except Exception:
            logger.exception("Error handling emergency stop from %s", participant_id)

    async def _handle_error(self, participant_id: str, error_message: str):
        """Send an ERROR message to a participant (best effort)."""
        try:
            await self._send_to_participant(
                participant_id,
                {
                    "type": MessageType.ERROR.value,
                    "message": error_message,
                    "timestamp": datetime.now(tz=UTC).isoformat(),
                },
            )
        except Exception:
            logger.exception("Error sending error message to %s", participant_id)

    # ------------------------------------------------------------------
    # Delivery
    # ------------------------------------------------------------------

    async def _broadcast_joint_update(
        self,
        workspace_id: str,
        room_id: str,
        changed_joints: list[dict],
        source: str,
    ):
        """Send a JOINT_UPDATE carrying ``changed_joints`` to the room's consumers."""
        await self._broadcast_to_consumers(
            workspace_id,
            room_id,
            {
                "type": MessageType.JOINT_UPDATE.value,
                "data": changed_joints,
                "timestamp": datetime.now(tz=UTC).isoformat(),
                "source": source,
            },
        )

    async def _broadcast_to_consumers(
        self, workspace_id: str, room_id: str, message: dict
    ):
        """Send message to all consumers in room; drop those whose send fails."""
        room = self._get_room(workspace_id, room_id)
        if not room:
            return
        # Snapshot: room.consumers may change while we await each send.
        await self._send_to_many(list(room.consumers), message)

    async def _broadcast_to_all_participants(
        self, workspace_id: str, room_id: str, message: dict
    ):
        """Send message to all participants (producer + consumers) in room.

        Participants whose send fails are dropped.
        """
        room = self._get_room(workspace_id, room_id)
        if not room:
            return

        participants = ([room.producer] if room.producer else []) + list(
            room.consumers
        )
        sent_count = await self._send_to_many(participants, message)
        logger.debug(
            "Broadcast message to %d/%d participants in room %s",
            sent_count,
            len(participants),
            room_id,
        )

    async def _send_to_many(self, participant_ids: list[str], message: dict) -> int:
        """Send ``message`` to every connected id in ``participant_ids``.

        Ids without a registered connection are skipped. Participants whose
        send raises are dropped after the loop, but only if the socket that
        failed is still the registered one.

        Returns:
            The number of successful sends.
        """
        message_text = json.dumps(message)
        failed: list[tuple[str, WebSocket]] = []
        sent_count = 0

        for participant_id in participant_ids:
            websocket = self.connections.get(participant_id)
            if websocket is None:
                continue
            try:
                await websocket.send_text(message_text)
            except Exception:
                logger.exception("Error sending message to %s", participant_id)
                failed.append((participant_id, websocket))
            else:
                sent_count += 1

        for participant_id, websocket in failed:
            self._drop_participant(participant_id, websocket)

        return sent_count

    async def _send_to_participant(self, participant_id: str, message: dict):
        """Send message to a specific participant; drop it if the send fails."""
        websocket = self.connections.get(participant_id)
        if websocket is None:
            return

        # Serialize outside the try so an encoding bug is not mistaken for a
        # broken connection.
        message_text = json.dumps(message)
        try:
            await websocket.send_text(message_text)
        except Exception:
            logger.exception("Error sending message to %s", participant_id)
            self._drop_participant(participant_id, websocket)

    def _drop_participant(
        self, participant_id: str, websocket: WebSocket | None = None
    ) -> None:
        """Remove a participant from its room and forget its connection.

        If ``websocket`` is given, nothing happens unless it is still the
        connection registered for ``participant_id``. This stops a stale
        caller from removing a newer connection that reused the id.
        """
        if websocket is not None and self.connections.get(participant_id) is not websocket:
            return

        metadata = self.connection_metadata.pop(participant_id, None)
        if metadata:
            self.leave_room(
                metadata["workspace_id"], metadata["room_id"], participant_id
            )
        self.connections.pop(participant_id, None)

    # ------------------------------------------------------------------
    # Introspection / API helpers
    # ------------------------------------------------------------------

    def get_connection_stats(self) -> dict:
        """Get connection statistics and per-connection metadata.

        Counts are derived from ``connection_metadata``; ``total_connections``
        is the size of ``connections``.
        """
        by_role = {"producer": 0, "consumer": 0}
        by_workspace: dict[str, dict[str, int]] = {}
        active_connections: list[dict] = []

        for participant_id, metadata in self.connection_metadata.items():
            role_name = metadata["role"].value
            workspace_id = metadata["workspace_id"]

            by_role[role_name] += 1

            workspace_stats = by_workspace.setdefault(
                workspace_id, {"producer": 0, "consumer": 0, "rooms": 0}
            )
            workspace_stats[role_name] += 1
            if workspace_id in self.workspaces:
                workspace_stats["rooms"] = len(self.workspaces[workspace_id])

            active_connections.append({
                "participant_id": participant_id,
                "workspace_id": workspace_id,
                "room_id": metadata["room_id"],
                "role": role_name,
                "connected_at": metadata["connected_at"].isoformat(),
                "last_activity": metadata["last_activity"].isoformat(),
                "message_count": metadata["message_count"],
            })

        return {
            "total_connections": len(self.connections),
            "total_workspaces": len(self.workspaces),
            "total_rooms": sum(len(rooms) for rooms in self.workspaces.values()),
            "connections_by_role": by_role,
            "connections_by_workspace": by_workspace,
            "active_connections": active_connections,
        }

    async def send_command_to_room(
        self, workspace_id: str, room_id: str, joints: list[dict]
    ) -> int:
        """Apply joint values from the API and broadcast changes to consumers.

        Returns:
            The number of joints whose value changed.

        Raises:
            ValueError: if the room does not exist (from ``update_joints``).
        """
        changed_joints = self.update_joints(workspace_id, room_id, joints)
        if changed_joints:
            await self._broadcast_joint_update(
                workspace_id, room_id, changed_joints, "api"
            )
        return len(changed_joints)
|
|