"""WebSocket 连接管理 - Phase 10.2 管理 WebSocket 连接的生命周期。 """ import asyncio import json from typing import Any, Callable from dataclasses import dataclass @dataclass class WSConnection: """WebSocket 连接""" session_id: str websocket: Any # WebSocket 连接 user_id: str | None = None created_at: float | None = None last_ping: float | None = None class WebSocketManager: """WebSocket 连接管理器 管理所有 WebSocket 连接的生命周期。 """ def __init__(self, ping_interval: float = 30.0): """ Args: ping_interval: 心跳间隔(秒) """ self._connections: dict[str, WSConnection] = {} self._handlers: dict[str, Callable] = {} self._ping_interval = ping_interval self._ping_tasks: dict[str, asyncio.Task] = {} async def connect(self, session_id: str, websocket: Any, user_id: str | None = None) -> bool: """建立连接 Args: session_id: 会话 ID websocket: WebSocket 连接 user_id: 用户 ID Returns: 是否连接成功 """ import time if session_id in self._connections: return False conn = WSConnection( session_id=session_id, websocket=websocket, user_id=user_id, created_at=time.time(), last_ping=time.time(), ) self._connections[session_id] = conn # 启动心跳 self._ping_tasks[session_id] = asyncio.create_task(self._ping_loop(session_id)) return True async def disconnect(self, session_id: str) -> bool: """断开连接 Args: session_id: 会话 ID Returns: 是否断开成功 """ if session_id not in self._connections: return False # 停止心跳 if session_id in self._ping_tasks: self._ping_tasks[session_id].cancel() del self._ping_tasks[session_id] del self._connections[session_id] return True async def send(self, session_id: str, message: dict[str, Any]) -> bool: """发送消息 Args: session_id: 会话 ID message: 消息内容 Returns: 是否发送成功 """ if session_id not in self._connections: return False try: conn = self._connections[session_id] await conn.websocket.send_json(message) return True except Exception: return False async def broadcast(self, message: dict[str, Any]) -> int: """广播消息 Args: message: 消息内容 Returns: 发送成功的数量 """ count = 0 for session_id in list(self._connections.keys()): if await self.send(session_id, message): count += 1 return count async def _ping_loop(self, session_id: str) -> None: """心跳循环 Args: session_id: 会话 ID """ import time while session_id in self._connections: await asyncio.sleep(self._ping_interval) if session_id not in self._connections: break try: conn = self._connections[session_id] await conn.websocket.send_json({"type": "ping", "timestamp": time.time()}) conn.last_ping = time.time() except Exception: await self.disconnect(session_id) break def register_handler(self, event_type: str, handler: Callable) -> None: """注册消息处理器 Args: event_type: 事件类型 handler: 处理函数 """ self._handlers[event_type] = handler async def handle_message(self, session_id: str, message: dict[str, Any]) -> None: """处理消息 Args: session_id: 会话 ID message: 消息内容 """ msg_type = message.get("type") handler = self._handlers.get(msg_type) if handler: await handler(session_id, message.get("data")) def get_connection(self, session_id: str) -> WSConnection | None: """获取连接 Args: session_id: 会话 ID Returns: 连接信息或 None """ return self._connections.get(session_id) def list_connections(self) -> list[WSConnection]: """列出所有连接 Returns: 连接列表 """ return list(self._connections.values()) def is_connected(self, session_id: str) -> bool: """检查是否连接 Args: session_id: 会话 ID Returns: 是否已连接 """ return session_id in self._connections # 全局单例 _ws_manager: WebSocketManager | None = None def get_websocket_manager() -> WebSocketManager: """获取全局 WebSocket 管理器""" global _ws_manager if _ws_manager is None: _ws_manager = WebSocketManager() return _ws_manager