239 lines
7.7 KiB
Python
239 lines
7.7 KiB
Python
"""Agent Session 管理 - Phase 10.3
|
||
|
||
支持会话层级管理和子会话创建。
|
||
"""
|
||
|
||
import json
|
||
import os
|
||
import uuid
|
||
from dataclasses import asdict, dataclass, field
|
||
from datetime import datetime
|
||
from typing import Any
|
||
|
||
|
||
@dataclass
|
||
class SessionContext:
|
||
"""会话上下文"""
|
||
|
||
session_id: str
|
||
parent_session_id: str | None = None
|
||
root_session_id: str | None = None
|
||
depth: int = 0
|
||
user_id: str | None = None
|
||
created_at: str | None = None
|
||
last_active: str | None = None
|
||
message_count: int = 0
|
||
metadata: dict[str, Any] = field(default_factory=dict)
|
||
|
||
def __post_init__(self):
|
||
if self.created_at is None:
|
||
self.created_at = datetime.now().isoformat()
|
||
if self.last_active is None:
|
||
self.last_active = self.created_at
|
||
|
||
|
||
@dataclass
|
||
class SessionPersistence:
|
||
"""会话持久化"""
|
||
|
||
def __init__(self, persistence_dir: str | None = None):
|
||
if persistence_dir is None:
|
||
persistence_dir = os.path.join(
|
||
os.path.dirname(__file__), "..", "..", "..", "data", "sessions"
|
||
)
|
||
self.persistence_dir = persistence_dir
|
||
|
||
def _get_session_path(self, session_id: str) -> str:
|
||
return os.path.join(self.persistence_dir, f"{session_id}.json")
|
||
|
||
def save(self, session: "AgentSession") -> bool:
|
||
"""保存会话"""
|
||
try:
|
||
os.makedirs(self.persistence_dir, exist_ok=True)
|
||
path = self._get_session_path(session.session_id)
|
||
data = {
|
||
"session_id": session.session_id,
|
||
"parent_session_id": session.context.parent_session_id,
|
||
"root_session_id": session.context.root_session_id,
|
||
"depth": session.context.depth,
|
||
"user_id": session.context.user_id,
|
||
"created_at": session.context.created_at,
|
||
"last_active": session.context.last_active,
|
||
"message_count": session.context.message_count,
|
||
"metadata": session.context.metadata,
|
||
"history": session._history,
|
||
}
|
||
with open(path, "w", encoding="utf-8") as f:
|
||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||
return True
|
||
except Exception:
|
||
return False
|
||
|
||
def load(self, session_id: str) -> dict[str, Any] | None:
|
||
"""加载会话"""
|
||
try:
|
||
path = self._get_session_path(session_id)
|
||
if not os.path.exists(path):
|
||
return None
|
||
with open(path, "r", encoding="utf-8") as f:
|
||
return json.load(f)
|
||
except Exception:
|
||
return None
|
||
|
||
def delete(self, session_id: str) -> bool:
|
||
"""删除会话"""
|
||
try:
|
||
path = self._get_session_path(session_id)
|
||
if os.path.exists(path):
|
||
os.remove(path)
|
||
return True
|
||
except Exception:
|
||
return False
|
||
|
||
def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]:
|
||
"""列出所有会话"""
|
||
sessions = []
|
||
try:
|
||
os.makedirs(self.persistence_dir, exist_ok=True)
|
||
for filename in os.listdir(self.persistence_dir):
|
||
if filename.endswith(".json"):
|
||
path = os.path.join(self.persistence_dir, filename)
|
||
with open(path, "r", encoding="utf-8") as f:
|
||
data = json.load(f)
|
||
if user_id is None or data.get("user_id") == user_id:
|
||
sessions.append(data)
|
||
except Exception:
|
||
pass
|
||
return sessions
|
||
|
||
|
||
class AgentSession:
    """Agent session manager.

    Supports:
    - session hierarchy (parent / root / depth)
    - child-session creation
    - session summaries
    - persistence through ``SessionPersistence``

    NOTE(review): the constructor never loads an existing session's history, so
    constructing with the id of an already-persisted session and then calling
    ``initialize()``/``process_message()`` overwrites the stored history.
    NOTE(review): the async methods perform blocking file I/O on the event loop.
    """

    def __init__(
        self,
        session_id: str | None = None,
        user_id: str | None = None,
        parent_session_id: str | None = None,
    ):
        """Create an in-memory session; nothing is written until a save.

        Args:
            session_id: Explicit id. When falsy, the first 8 characters of a
                random UUID string are used.
            user_id: Owning user, if any.
            parent_session_id: Parent id for a child session. If the parent's
                persisted record is found, it supplies ``root_session_id`` and
                ``depth``. Otherwise the parent is treated as the root and the
                depth is 1.
        """
        self.session_id = session_id or str(uuid.uuid4())[:8]
        self.context = SessionContext(
            session_id=self.session_id,
            user_id=user_id,
            parent_session_id=parent_session_id,
            depth=0 if parent_session_id is None else 1,
        )
        self._history: list[dict[str, Any]] = []
        self._persistence = SessionPersistence()

        # With a parent, derive root_session_id / depth from its saved record.
        if parent_session_id:
            # Fallback when the parent was never persisted: a child should not
            # report root_session_id=None.
            self.context.root_session_id = parent_session_id
            parent_data = self._persistence.load(parent_session_id)
            if parent_data:
                self.context.root_session_id = (
                    parent_data.get("root_session_id") or parent_session_id
                )
                self.context.depth = parent_data.get("depth", 0) + 1

    async def initialize(self) -> dict[str, Any]:
        """Refresh ``last_active``, persist the session, and return its hierarchy info."""
        self.context.last_active = datetime.now().isoformat()
        self._persistence.save(self)
        return {
            "session_id": self.session_id,
            "depth": self.context.depth,
            "parent_session_id": self.context.parent_session_id,
            "root_session_id": self.context.root_session_id,
        }

    async def process_message(self, message: str, response: str) -> None:
        """Record a user message and the assistant response, then persist.

        Both history entries share one timestamp, which also becomes
        ``last_active``.
        """
        now = datetime.now().isoformat()
        self.context.message_count += 1
        self.context.last_active = now
        self._history.extend(
            [
                {"role": "user", "content": message, "timestamp": now},
                {"role": "assistant", "content": response, "timestamp": now},
            ]
        )
        self._persistence.save(self)

    async def spawn_child_session(self, user_id: str | None = None) -> "AgentSession":
        """Create, initialize, and persist a child of this session.

        The child's root and depth come from this in-memory session rather than
        from its persisted record, so they are correct even when this session
        has not been saved (or its save failed).

        Args:
            user_id: Owner of the child; defaults to this session's user.
        """
        child = AgentSession(
            user_id=user_id or self.context.user_id,
            parent_session_id=self.session_id,
        )
        child.context.root_session_id = self.context.root_session_id or self.session_id
        child.context.depth = self.context.depth + 1
        await child.initialize()
        return child

    async def get_session_summary(self) -> dict[str, Any]:
        """Return the session's context fields plus the current history length."""
        return {
            "session_id": self.session_id,
            "parent_session_id": self.context.parent_session_id,
            "root_session_id": self.context.root_session_id,
            "depth": self.context.depth,
            "user_id": self.context.user_id,
            "created_at": self.context.created_at,
            "last_active": self.context.last_active,
            "message_count": self.context.message_count,
            "history_length": len(self._history),
        }

    async def persist(self) -> bool:
        """Save the session; returns True on success, False on failure."""
        return self._persistence.save(self)

    def get_history(self) -> list[dict[str, Any]]:
        """Return a shallow copy of the history (the entry dicts are shared)."""
        return self._history.copy()

    def add_metadata(self, key: str, value: Any) -> None:
        """Set a metadata entry in memory (persisted on the next save)."""
        self.context.metadata[key] = value

    def get_metadata(self, key: str, default: Any = None) -> Any:
        """Return the metadata value for ``key``, or ``default`` if absent."""
        return self.context.metadata.get(key, default)
|
||
|
||
|
||
# Process-wide, in-memory registry of sessions registered by
# create_agent_session(), keyed by session_id. It is not persisted.
_sessions: dict[str, AgentSession] = dict()
|
||
|
||
|
||
def get_agent_session(session_id: str) -> AgentSession | None:
    """Look up a session in the in-memory registry.

    Returns:
        The registered session, or None if no session has that id.
    """
    try:
        return _sessions[session_id]
    except KeyError:
        return None
|
||
|
||
|
||
def create_agent_session(
    session_id: str | None = None,
    user_id: str | None = None,
    parent_session_id: str | None = None,
) -> AgentSession:
    """Construct an AgentSession and register it in the in-memory registry.

    The session is neither initialized nor saved here; call ``initialize()``
    on the returned object for that.

    NOTE(review): an existing registry entry with the same id is replaced
    without warning.
    """
    new_session = AgentSession(
        session_id=session_id,
        user_id=user_id,
        parent_session_id=parent_session_id,
    )
    _sessions[new_session.session_id] = new_session
    return new_session
|