239 lines
7.7 KiB
Python
239 lines
7.7 KiB
Python
|
|
"""Agent Session 管理 - Phase 10.3
|
|||
|
|
|
|||
|
|
支持会话层级管理和子会话创建。
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import json
|
|||
|
|
import os
|
|||
|
|
import uuid
|
|||
|
|
from dataclasses import asdict, dataclass, field
|
|||
|
|
from datetime import datetime
|
|||
|
|
from typing import Any
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
|
|||
|
|
class SessionContext:
|
|||
|
|
"""会话上下文"""
|
|||
|
|
|
|||
|
|
session_id: str
|
|||
|
|
parent_session_id: str | None = None
|
|||
|
|
root_session_id: str | None = None
|
|||
|
|
depth: int = 0
|
|||
|
|
user_id: str | None = None
|
|||
|
|
created_at: str | None = None
|
|||
|
|
last_active: str | None = None
|
|||
|
|
message_count: int = 0
|
|||
|
|
metadata: dict[str, Any] = field(default_factory=dict)
|
|||
|
|
|
|||
|
|
def __post_init__(self):
|
|||
|
|
if self.created_at is None:
|
|||
|
|
self.created_at = datetime.now().isoformat()
|
|||
|
|
if self.last_active is None:
|
|||
|
|
self.last_active = self.created_at
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
|
|||
|
|
class SessionPersistence:
|
|||
|
|
"""会话持久化"""
|
|||
|
|
|
|||
|
|
def __init__(self, persistence_dir: str | None = None):
|
|||
|
|
if persistence_dir is None:
|
|||
|
|
persistence_dir = os.path.join(
|
|||
|
|
os.path.dirname(__file__), "..", "..", "..", "data", "sessions"
|
|||
|
|
)
|
|||
|
|
self.persistence_dir = persistence_dir
|
|||
|
|
|
|||
|
|
def _get_session_path(self, session_id: str) -> str:
|
|||
|
|
return os.path.join(self.persistence_dir, f"{session_id}.json")
|
|||
|
|
|
|||
|
|
def save(self, session: "AgentSession") -> bool:
|
|||
|
|
"""保存会话"""
|
|||
|
|
try:
|
|||
|
|
os.makedirs(self.persistence_dir, exist_ok=True)
|
|||
|
|
path = self._get_session_path(session.session_id)
|
|||
|
|
data = {
|
|||
|
|
"session_id": session.session_id,
|
|||
|
|
"parent_session_id": session.context.parent_session_id,
|
|||
|
|
"root_session_id": session.context.root_session_id,
|
|||
|
|
"depth": session.context.depth,
|
|||
|
|
"user_id": session.context.user_id,
|
|||
|
|
"created_at": session.context.created_at,
|
|||
|
|
"last_active": session.context.last_active,
|
|||
|
|
"message_count": session.context.message_count,
|
|||
|
|
"metadata": session.context.metadata,
|
|||
|
|
"history": session._history,
|
|||
|
|
}
|
|||
|
|
with open(path, "w", encoding="utf-8") as f:
|
|||
|
|
json.dump(data, f, indent=2, ensure_ascii=False)
|
|||
|
|
return True
|
|||
|
|
except Exception:
|
|||
|
|
return False
|
|||
|
|
|
|||
|
|
def load(self, session_id: str) -> dict[str, Any] | None:
|
|||
|
|
"""加载会话"""
|
|||
|
|
try:
|
|||
|
|
path = self._get_session_path(session_id)
|
|||
|
|
if not os.path.exists(path):
|
|||
|
|
return None
|
|||
|
|
with open(path, "r", encoding="utf-8") as f:
|
|||
|
|
return json.load(f)
|
|||
|
|
except Exception:
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
def delete(self, session_id: str) -> bool:
|
|||
|
|
"""删除会话"""
|
|||
|
|
try:
|
|||
|
|
path = self._get_session_path(session_id)
|
|||
|
|
if os.path.exists(path):
|
|||
|
|
os.remove(path)
|
|||
|
|
return True
|
|||
|
|
except Exception:
|
|||
|
|
return False
|
|||
|
|
|
|||
|
|
def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]:
|
|||
|
|
"""列出所有会话"""
|
|||
|
|
sessions = []
|
|||
|
|
try:
|
|||
|
|
os.makedirs(self.persistence_dir, exist_ok=True)
|
|||
|
|
for filename in os.listdir(self.persistence_dir):
|
|||
|
|
if filename.endswith(".json"):
|
|||
|
|
path = os.path.join(self.persistence_dir, filename)
|
|||
|
|
with open(path, "r", encoding="utf-8") as f:
|
|||
|
|
data = json.load(f)
|
|||
|
|
if user_id is None or data.get("user_id") == user_id:
|
|||
|
|
sessions.append(data)
|
|||
|
|
except Exception:
|
|||
|
|
pass
|
|||
|
|
return sessions
|
|||
|
|
|
|||
|
|
|
|||
|
|
class AgentSession:
    """Agent session manager.

    Supports:
    - session hierarchy (parent / root / depth)
    - spawning child sessions
    - session summaries
    - persistence through a ``SessionPersistence`` store

    NOTE(review): the async methods call the store synchronously, so file
    writes block the event loop. Also, constructing a session with the id of
    an already-persisted session does not restore its history; the next
    save overwrites the stored copy. Confirm both are intended.
    """

    def __init__(
        self,
        session_id: str | None = None,
        user_id: str | None = None,
        parent_session_id: str | None = None,
        persistence: "SessionPersistence | None" = None,
    ):
        """Create a session in memory (nothing is written until saved).

        Args:
            session_id: Explicit id; when falsy a random id made of the first
                8 characters of a UUID4 is generated. NOTE(review): 8 hex
                chars is only 32 bits, so collisions (and overwritten files)
                become likely with many stored sessions.
            user_id: Owning user, if any.
            parent_session_id: Id of the parent session. If the parent is
                found in the store, depth and root are derived from it;
                otherwise the parent is treated as a root (depth 1, root =
                parent).
            persistence: Store to use; defaults to ``SessionPersistence()``
                with its default directory.
        """
        self.session_id = session_id or str(uuid.uuid4())[:8]
        is_child = parent_session_id is not None
        self.context = SessionContext(
            session_id=self.session_id,
            user_id=user_id,
            parent_session_id=parent_session_id,
            # Fallback when the parent is not in the store: assume it is the root.
            root_session_id=parent_session_id if is_child else None,
            depth=1 if is_child else 0,
        )
        self._history: list[dict[str, Any]] = []
        self._persistence = (
            persistence if persistence is not None else SessionPersistence()
        )

        if is_child:
            parent_data = self._persistence.load(parent_session_id)
            if parent_data:
                self.context.root_session_id = (
                    parent_data.get("root_session_id") or parent_session_id
                )
                # `or 0` also tolerates an explicit JSON null depth.
                self.context.depth = (parent_data.get("depth") or 0) + 1

    async def initialize(self) -> dict[str, Any]:
        """Mark the session active, persist it, and return its hierarchy info."""
        self.context.last_active = datetime.now().isoformat()
        self._persistence.save(self)
        return {
            "session_id": self.session_id,
            "depth": self.context.depth,
            "parent_session_id": self.context.parent_session_id,
            "root_session_id": self.context.root_session_id,
        }

    async def process_message(self, message: str, response: str) -> None:
        """Record one user/assistant exchange and persist the session.

        ``message_count`` grows by one per call (one exchange), while two
        entries (user, then assistant) are appended to the history.
        """
        self.context.message_count += 1
        self.context.last_active = datetime.now().isoformat()
        self._history.append(
            {
                "role": "user",
                "content": message,
                "timestamp": datetime.now().isoformat(),
            }
        )
        self._history.append(
            {
                "role": "assistant",
                "content": response,
                "timestamp": datetime.now().isoformat(),
            }
        )
        self._persistence.save(self)

    async def spawn_child_session(self, user_id: str | None = None) -> "AgentSession":
        """Create, initialize and return a child of this session.

        Args:
            user_id: Owner of the child; defaults to this session's user.
        """
        child = AgentSession(
            user_id=user_id or self.context.user_id,
            parent_session_id=self.session_id,
            # Keep the whole tree in one store.
            persistence=self._persistence,
        )
        # Derive the hierarchy from this in-memory parent instead of the
        # store: if this session was never persisted the constructor could
        # only fall back to depth 1.
        child.context.root_session_id = self.context.root_session_id or self.session_id
        child.context.depth = self.context.depth + 1
        await child.initialize()
        return child

    async def get_session_summary(self) -> dict[str, Any]:
        """Return context fields plus the number of history entries."""
        return {
            "session_id": self.session_id,
            "parent_session_id": self.context.parent_session_id,
            "root_session_id": self.context.root_session_id,
            "depth": self.context.depth,
            "user_id": self.context.user_id,
            "created_at": self.context.created_at,
            "last_active": self.context.last_active,
            "message_count": self.context.message_count,
            "history_length": len(self._history),
        }

    async def persist(self) -> bool:
        """Save the session; return ``False`` if the store reported failure."""
        return self._persistence.save(self)

    def get_history(self) -> list[dict[str, Any]]:
        """Return a copy of the history; mutating it does not affect the session.

        Each entry dict is copied as well, so callers cannot alter stored
        messages through the returned list.
        """
        return [dict(entry) for entry in self._history]

    def add_metadata(self, key: str, value: Any) -> None:
        """Set a metadata entry (in memory only until the next save)."""
        self.context.metadata[key] = value

    def get_metadata(self, key: str, default: Any = None) -> Any:
        """Return the metadata value for *key*, or *default* if absent."""
        return self.context.metadata.get(key, default)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Process-wide in-memory registry of sessions created via
# create_agent_session(), keyed by session id. Nothing shown in this
# module removes entries.
_sessions: dict[str, AgentSession] = dict()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_agent_session(session_id: str) -> AgentSession | None:
    """Look up a registered in-memory session by id.

    Returns:
        The session, or ``None`` if no session with that id was registered.
    """
    try:
        return _sessions[session_id]
    except KeyError:
        return None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def create_agent_session(
    session_id: str | None = None,
    user_id: str | None = None,
    parent_session_id: str | None = None,
) -> AgentSession:
    """Construct a session and register it in the in-memory registry.

    The session is not initialized or persisted here. A registry entry with
    the same id is replaced.

    Args:
        session_id: Explicit id, or ``None`` to let the session generate one.
        user_id: Owning user, if any.
        parent_session_id: Id of the parent session, if any.

    Returns:
        The newly created session.
    """
    new_session = AgentSession(
        session_id=session_id,
        user_id=user_id,
        parent_session_id=parent_session_id,
    )
    registry_key = new_session.session_id
    _sessions[registry_key] = new_session
    return new_session
|