63 lines
2.0 KiB
Python
63 lines
2.0 KiB
Python
|
|
"""
|
||
|
|
会话管理器 - 管理 Agent 的会话历史
|
||
|
|
"""
|
||
|
|
from typing import Any, Optional
|
||
|
|
from collections import defaultdict
|
||
|
|
from datetime import datetime
|
||
|
|
|
||
|
|
|
||
|
|
class SessionManager:
    """In-memory store of per-session message history and metadata.

    Messages are kept in ``self.sessions`` (session id -> list of message
    dicts) and arbitrary key/value metadata in ``self.metadata``
    (session id -> dict). Each session's history is capped at
    ``max_history`` messages; the oldest messages are dropped first.
    """

    def __init__(self, max_history: int = 10):
        """Initialize an empty session manager.

        Args:
            max_history: Maximum number of messages retained per session.
                A value of zero or less retains no messages.
        """
        self.max_history = max_history
        # defaultdict so the first message of a new session creates its list.
        self.sessions: dict[str, list[dict]] = defaultdict(list)
        self.metadata: dict[str, dict] = {}

    def add_message(self, session_id: str, role: str, content: str) -> None:
        """Append a message to a session and trim the history to the cap.

        The stored message is a dict with the keys ``"role"``, ``"content"``
        and ``"timestamp"`` (``datetime.now().isoformat()``).

        Args:
            session_id: Identifier of the session; created if unknown.
            role: Role label stored with the message.
            content: Message text.
        """
        history = self.sessions[session_id]
        history.append({
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat(),
        })

        # Trim by an explicit overflow count instead of
        # ``history[-self.max_history:]``: with max_history == 0 that slice
        # is ``[-0:]``, which keeps the whole list and never trims.
        overflow = len(history) - self.max_history
        if overflow > 0:
            self.sessions[session_id] = history[overflow:]

    def get_history(self, session_id: str) -> list[dict]:
        """Return the message history of a session.

        For a known session the stored list itself is returned (not a copy).
        For an unknown session a new empty list is returned and no session
        entry is created, because ``.get`` does not trigger the defaultdict
        factory.
        """
        return self.sessions.get(session_id, [])

    def clear_session(self, session_id: str) -> None:
        """Remove a session's history and metadata; unknown ids are ignored."""
        self.sessions.pop(session_id, None)
        self.metadata.pop(session_id, None)

    def set_metadata(self, session_id: str, key: str, value: Any) -> None:
        """Set one metadata entry for a session, creating its dict if needed."""
        self.metadata.setdefault(session_id, {})[key] = value

    def get_metadata(self, session_id: str, key: str, default: Any = None) -> Any:
        """Return a session's metadata value, or ``default`` if the session
        or the key is missing."""
        return self.metadata.get(session_id, {}).get(key, default)

    def list_sessions(self) -> list[str]:
        """Return the ids of all sessions that currently have a history list."""
        return list(self.sessions)

    def get_session_count(self) -> int:
        """Return the number of sessions that currently have a history list."""
        return len(self.sessions)
|