Files
JARVIS/backend/app/agents/transport/websocket.py

208 lines
5.1 KiB
Python

"""WebSocket 连接管理 - Phase 10.2
管理 WebSocket 连接的生命周期。
"""
import asyncio
import inspect
import json
import time
from dataclasses import dataclass
from typing import Any, Callable
@dataclass
class WSConnection:
"""WebSocket 连接"""
session_id: str
websocket: Any # WebSocket 连接
user_id: str | None = None
created_at: float | None = None
last_ping: float | None = None
class WebSocketManager:
    """Registry and lifecycle manager for WebSocket connections.

    Keeps at most one connection per session ID, runs a heartbeat task per
    connection that periodically sends ``{"type": "ping", ...}`` frames, and
    dispatches incoming messages to handlers registered by message type.
    """

    def __init__(self, ping_interval: float = 30.0):
        """
        Args:
            ping_interval: Seconds to wait between heartbeat pings.
        """
        self._connections: dict[str, WSConnection] = {}
        # Message "type" -> handler, invoked as handler(session_id, data).
        self._handlers: dict[str, Callable] = {}
        self._ping_interval = ping_interval
        # One heartbeat task per connected session.
        self._ping_tasks: dict[str, asyncio.Task] = {}

    async def connect(self, session_id: str, websocket: Any, user_id: str | None = None) -> bool:
        """Register a connection and start its heartbeat task.

        Args:
            session_id: Session ID used as the connection key.
            websocket: The underlying WebSocket object; it must provide an
                awaitable ``send_json(message)`` method.
            user_id: Optional ID of the user owning the connection.

        Returns:
            True if the connection was registered, False if ``session_id``
            is already connected (the existing connection is left untouched).
        """
        if session_id in self._connections:
            return False
        # Use a single timestamp so created_at and last_ping agree exactly.
        now = time.time()
        self._connections[session_id] = WSConnection(
            session_id=session_id,
            websocket=websocket,
            user_id=user_id,
            created_at=now,
            last_ping=now,
        )
        self._ping_tasks[session_id] = asyncio.create_task(self._ping_loop(session_id))
        return True

    async def disconnect(self, session_id: str) -> bool:
        """Unregister a connection and cancel its heartbeat task.

        The underlying WebSocket is not closed here; that is left to the caller.

        Args:
            session_id: Session ID to remove.

        Returns:
            True if the session was connected and has been removed, else False.
        """
        if session_id not in self._connections:
            return False
        task = self._ping_tasks.pop(session_id, None)
        if task is not None:
            task.cancel()
        del self._connections[session_id]
        return True

    async def send(self, session_id: str, message: dict[str, Any]) -> bool:
        """Send a JSON message to one session.

        Args:
            session_id: Target session ID.
            message: JSON-serializable message payload.

        Returns:
            True on success; False if the session is unknown or sending raised.
        """
        conn = self._connections.get(session_id)
        if conn is None:
            return False
        try:
            await conn.websocket.send_json(message)
        except Exception:
            # Best-effort delivery: report failure instead of raising. A dead
            # connection is removed by its heartbeat loop once a ping fails.
            return False
        return True

    async def broadcast(self, message: dict[str, Any]) -> int:
        """Send a message to every connected session, one after another.

        Args:
            message: JSON-serializable message payload.

        Returns:
            Number of sessions the message was delivered to successfully.
        """
        count = 0
        # Snapshot the keys: sessions may disconnect while we await each send.
        for session_id in list(self._connections):
            if await self.send(session_id, message):
                count += 1
        return count

    async def _ping_loop(self, session_id: str) -> None:
        """Heartbeat loop: ping the session every ``ping_interval`` seconds.

        Exits when the session is no longer registered. On a failed ping the
        session is disconnected and the loop ends normally.

        Args:
            session_id: Session ID whose connection is pinged.
        """
        while session_id in self._connections:
            await asyncio.sleep(self._ping_interval)
            conn = self._connections.get(session_id)
            if conn is None:
                break
            try:
                await conn.websocket.send_json({"type": "ping", "timestamp": time.time()})
            except Exception:
                # Drop our own task handle first so disconnect() does not call
                # cancel() on the very task that is executing this code (which
                # would leave the task in the cancelled state).
                if self._ping_tasks.get(session_id) is asyncio.current_task():
                    del self._ping_tasks[session_id]
                await self.disconnect(session_id)
                break
            conn.last_ping = time.time()

    def register_handler(self, event_type: str, handler: Callable) -> None:
        """Register the handler for one message type.

        Registering the same ``event_type`` again replaces the previous handler.

        Args:
            event_type: Value of the message's "type" field to match.
            handler: Called as ``handler(session_id, data)``; may be a plain
                function or a coroutine function.
        """
        self._handlers[event_type] = handler

    async def handle_message(self, session_id: str, message: dict[str, Any]) -> None:
        """Dispatch an incoming message to the handler for its "type".

        Messages whose "type" is not a string, or has no registered handler,
        are ignored. The handler's result is awaited only if it is awaitable,
        so both sync and async handlers are supported.

        Args:
            session_id: Session the message came from.
            message: Decoded message; its "data" field is passed to the handler.
        """
        msg_type = message.get("type")
        # "type" comes from the client; a non-str (possibly unhashable) value
        # can never match a registered handler and must not crash the lookup.
        if not isinstance(msg_type, str):
            return
        handler = self._handlers.get(msg_type)
        if handler is None:
            return
        result = handler(session_id, message.get("data"))
        if inspect.isawaitable(result):
            await result

    def get_connection(self, session_id: str) -> WSConnection | None:
        """Return the connection registered for ``session_id``, or None."""
        return self._connections.get(session_id)

    def list_connections(self) -> list[WSConnection]:
        """Return a snapshot list of all registered connections."""
        return list(self._connections.values())

    def is_connected(self, session_id: str) -> bool:
        """Return True if ``session_id`` is currently registered."""
        return session_id in self._connections
# Process-wide manager instance; created lazily by get_websocket_manager().
_ws_manager: WebSocketManager | None = None


def get_websocket_manager() -> WebSocketManager:
    """Return the shared WebSocketManager, constructing it on first call.

    Every later call returns that same instance.
    """
    global _ws_manager
    if _ws_manager is not None:
        return _ws_manager
    _ws_manager = WebSocketManager()
    return _ws_manager