"""WebSocket connection management - Phase 10.2

Manages the lifecycle of WebSocket connections.
"""
|
|
|
|
import asyncio
|
|
import json
|
|
from typing import Any, Callable
|
|
from dataclasses import dataclass
|
|
|
|
|
|
@dataclass
|
|
class WSConnection:
|
|
"""WebSocket 连接"""
|
|
|
|
session_id: str
|
|
websocket: Any # WebSocket 连接
|
|
user_id: str | None = None
|
|
created_at: float | None = None
|
|
last_ping: float | None = None
|
|
|
|
|
|
class WebSocketManager:
    """WebSocket connection manager.

    Keeps every active connection keyed by session ID, runs one heartbeat
    task per connection, and dispatches incoming messages to handlers
    registered by message type.
    """

    def __init__(self, ping_interval: float = 30.0):
        """Create an empty manager.

        Args:
            ping_interval: Heartbeat interval in seconds.
        """
        self._connections: dict[str, WSConnection] = {}
        self._handlers: dict[str, Callable] = {}
        self._ping_interval = ping_interval
        self._ping_tasks: dict[str, asyncio.Task] = {}

    async def connect(self, session_id: str, websocket: Any, user_id: str | None = None) -> bool:
        """Register a connection and start its heartbeat task.

        Args:
            session_id: Session ID.
            websocket: WebSocket connection object.
            user_id: User ID.

        Returns:
            True if the connection was registered, False if the session ID
            is already registered.
        """
        import time

        if session_id in self._connections:
            return False

        # Take one timestamp so created_at and last_ping agree exactly.
        now = time.time()
        self._connections[session_id] = WSConnection(
            session_id=session_id,
            websocket=websocket,
            user_id=user_id,
            created_at=now,
            last_ping=now,
        )

        # Start the heartbeat.
        self._ping_tasks[session_id] = asyncio.create_task(self._ping_loop(session_id))
        return True

    async def disconnect(self, session_id: str) -> bool:
        """Remove a connection and stop its heartbeat task.

        Only the manager's bookkeeping is updated; no method is called on
        the websocket object itself.

        Args:
            session_id: Session ID.

        Returns:
            True if the connection was removed, False if it was not registered.
        """
        if session_id not in self._connections:
            return False

        # Stop the heartbeat. When this is called from inside the heartbeat
        # task itself (send failure in _ping_loop), do not cancel it: the loop
        # exits on its own right after, and cancelling would make the task
        # finish in a cancelled state instead of completing normally.
        task = self._ping_tasks.pop(session_id, None)
        if task is not None and task is not asyncio.current_task():
            task.cancel()

        del self._connections[session_id]
        return True

    async def send(self, session_id: str, message: dict[str, Any]) -> bool:
        """Send a message to one session.

        Args:
            session_id: Session ID.
            message: Message payload.

        Returns:
            True if the message was sent, False if the session is unknown or
            sending raised an exception.
        """
        conn = self._connections.get(session_id)
        if conn is None:
            return False

        try:
            await conn.websocket.send_json(message)
        except Exception:
            # Best-effort delivery: failures are reported through the return
            # value rather than raised to the caller.
            return False
        return True

    async def broadcast(self, message: dict[str, Any]) -> int:
        """Send a message to every registered session.

        Args:
            message: Message payload.

        Returns:
            Number of sessions the message was sent to successfully.
        """
        count = 0
        # Iterate over a snapshot: the registry may change while awaiting.
        for session_id in list(self._connections):
            if await self.send(session_id, message):
                count += 1
        return count

    async def _ping_loop(self, session_id: str) -> None:
        """Heartbeat loop for one session.

        Sleeps for the ping interval, then sends a ping message. The loop
        ends when the session is no longer registered; if sending fails the
        session is disconnected.

        Args:
            session_id: Session ID.
        """
        import time

        while session_id in self._connections:
            await asyncio.sleep(self._ping_interval)

            # The session may have been removed while sleeping.
            conn = self._connections.get(session_id)
            if conn is None:
                break

            try:
                await conn.websocket.send_json({"type": "ping", "timestamp": time.time()})
                conn.last_ping = time.time()
            except Exception:
                await self.disconnect(session_id)
                break

    def register_handler(self, event_type: str, handler: Callable) -> None:
        """Register a message handler, replacing any existing one for the type.

        Args:
            event_type: Event type.
            handler: Callable invoked as ``handler(session_id, data)``; may be
                a plain function or a coroutine function.
        """
        self._handlers[event_type] = handler

    async def handle_message(self, session_id: str, message: dict[str, Any]) -> None:
        """Dispatch a message to the handler registered for its ``type``.

        Messages whose type has no registered handler are ignored.

        Args:
            session_id: Session ID.
            message: Message payload; ``type`` selects the handler and
                ``data`` is passed to it.
        """
        import inspect

        handler = self._handlers.get(message.get("type"))
        if not handler:
            return

        # Handlers are typed as plain Callable, so only await the result when
        # it is awaitable; a synchronous handler returning None must not
        # raise TypeError here.
        result = handler(session_id, message.get("data"))
        if inspect.isawaitable(result):
            await result

    def get_connection(self, session_id: str) -> WSConnection | None:
        """Look up a connection.

        Args:
            session_id: Session ID.

        Returns:
            The connection record, or None if the session is not registered.
        """
        return self._connections.get(session_id)

    def list_connections(self) -> list[WSConnection]:
        """List all connections.

        Returns:
            A new list containing every registered connection record.
        """
        return list(self._connections.values())

    def is_connected(self, session_id: str) -> bool:
        """Check whether a session is registered.

        Args:
            session_id: Session ID.

        Returns:
            True if the session is registered.
        """
        return session_id in self._connections
|
|
|
|
|
|
# Global singleton, created on first use by get_websocket_manager().
_ws_manager: WebSocketManager | None = None


def get_websocket_manager() -> WebSocketManager:
    """Return the global WebSocket manager, creating it on first call."""
    global _ws_manager
    if _ws_manager is not None:
        return _ws_manager
    _ws_manager = WebSocketManager()
    return _ws_manager
|