Files
JARVIS/backend/app/agents/session/manager.py

239 lines
7.7 KiB
Python
Raw Normal View History

"""Agent Session 管理 - Phase 10.3
支持会话层级管理和子会话创建
"""
import json
import logging
import os
import uuid
from dataclasses import asdict, dataclass, field
from datetime import datetime
from typing import Any
@dataclass
class SessionContext:
"""会话上下文"""
session_id: str
parent_session_id: str | None = None
root_session_id: str | None = None
depth: int = 0
user_id: str | None = None
created_at: str | None = None
last_active: str | None = None
message_count: int = 0
metadata: dict[str, Any] = field(default_factory=dict)
def __post_init__(self):
if self.created_at is None:
self.created_at = datetime.now().isoformat()
if self.last_active is None:
self.last_active = self.created_at
@dataclass
class SessionPersistence:
"""会话持久化"""
def __init__(self, persistence_dir: str | None = None):
if persistence_dir is None:
persistence_dir = os.path.join(
os.path.dirname(__file__), "..", "..", "..", "data", "sessions"
)
self.persistence_dir = persistence_dir
def _get_session_path(self, session_id: str) -> str:
return os.path.join(self.persistence_dir, f"{session_id}.json")
def save(self, session: "AgentSession") -> bool:
"""保存会话"""
try:
os.makedirs(self.persistence_dir, exist_ok=True)
path = self._get_session_path(session.session_id)
data = {
"session_id": session.session_id,
"parent_session_id": session.context.parent_session_id,
"root_session_id": session.context.root_session_id,
"depth": session.context.depth,
"user_id": session.context.user_id,
"created_at": session.context.created_at,
"last_active": session.context.last_active,
"message_count": session.context.message_count,
"metadata": session.context.metadata,
"history": session._history,
}
with open(path, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
return True
except Exception:
return False
def load(self, session_id: str) -> dict[str, Any] | None:
"""加载会话"""
try:
path = self._get_session_path(session_id)
if not os.path.exists(path):
return None
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception:
return None
def delete(self, session_id: str) -> bool:
"""删除会话"""
try:
path = self._get_session_path(session_id)
if os.path.exists(path):
os.remove(path)
return True
except Exception:
return False
def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]:
"""列出所有会话"""
sessions = []
try:
os.makedirs(self.persistence_dir, exist_ok=True)
for filename in os.listdir(self.persistence_dir):
if filename.endswith(".json"):
path = os.path.join(self.persistence_dir, filename)
with open(path, "r", encoding="utf-8") as f:
data = json.load(f)
if user_id is None or data.get("user_id") == user_id:
sessions.append(data)
except Exception:
pass
return sessions
class AgentSession:
    """Agent session manager.

    Supports:
    - session hierarchy (parent / root / depth)
    - spawning child sessions
    - session summaries
    - persistence through a ``SessionPersistence`` using its default directory
    """

    def __init__(
        self,
        session_id: str | None = None,
        user_id: str | None = None,
        parent_session_id: str | None = None,
    ):
        """Create an in-memory session. Nothing is written until initialize() or persist().

        Args:
            session_id: Explicit id. A short random id is generated when falsy.
            user_id: Owner of the session, if any.
            parent_session_id: Id of the parent session. If the parent's
                record can be loaded from disk, depth and root are derived
                from it. Otherwise the parent is assumed to be a root
                session (depth 1, root = parent).
        """
        # NOTE(review): only the first 8 hex chars of a UUID4 are kept (32
        # random bits). Collisions become non-negligible at tens of thousands
        # of sessions, and two colliding sessions would share one file.
        # Consider the full UUID.
        self.session_id = session_id or str(uuid.uuid4())[:8]
        self._history: list[dict[str, Any]] = []
        self._persistence = SessionPersistence()

        depth = 0
        root_session_id: str | None = None
        if parent_session_id is not None:
            # Fallback for a parent with no persisted record (e.g. never
            # initialized): treat it as a root so depth and root agree.
            depth = 1
            root_session_id = parent_session_id
            parent_data = self._persistence.load(parent_session_id)
            if parent_data:
                root_session_id = (
                    parent_data.get("root_session_id") or parent_session_id
                )
                depth = parent_data.get("depth", 0) + 1

        self.context = SessionContext(
            session_id=self.session_id,
            user_id=user_id,
            parent_session_id=parent_session_id,
            root_session_id=root_session_id,
            depth=depth,
        )

    async def initialize(self) -> dict[str, Any]:
        """Refresh ``last_active``, save the session, and return its identity.

        NOTE: the save writes the current in-memory state. This class never
        reads its own record back, so initializing a fresh instance whose
        ``session_id`` already has a file on disk replaces that file's
        history with an empty one.
        """
        self.context.last_active = datetime.now().isoformat()
        self._persistence.save(self)
        return {
            "session_id": self.session_id,
            "depth": self.context.depth,
            "parent_session_id": self.context.parent_session_id,
            "root_session_id": self.context.root_session_id,
        }

    async def process_message(self, message: str, response: str) -> None:
        """Record one user/assistant exchange and save the session.

        ``message_count`` counts exchanges, so the history grows by two
        entries per call. Both entries and ``last_active`` share a single
        timestamp.
        """
        now = datetime.now().isoformat()
        self.context.message_count += 1
        self.context.last_active = now
        self._history.append({"role": "user", "content": message, "timestamp": now})
        self._history.append({"role": "assistant", "content": response, "timestamp": now})
        self._persistence.save(self)

    async def spawn_child_session(self, user_id: str | None = None) -> "AgentSession":
        """Create, initialize and return a child of this session.

        The child inherits this session's ``user_id`` unless one is given.
        Root and depth come from this in-memory session, not from its
        persisted record. That record may be missing (this session was
        never saved) or stale.
        """
        child = AgentSession(
            user_id=user_id or self.context.user_id,
            parent_session_id=self.session_id,
        )
        child.context.root_session_id = self.context.root_session_id or self.session_id
        child.context.depth = self.context.depth + 1
        await child.initialize()
        return child

    async def get_session_summary(self) -> dict[str, Any]:
        """Return the session's context fields plus the history length."""
        return {
            "session_id": self.session_id,
            "parent_session_id": self.context.parent_session_id,
            "root_session_id": self.context.root_session_id,
            "depth": self.context.depth,
            "user_id": self.context.user_id,
            "created_at": self.context.created_at,
            "last_active": self.context.last_active,
            "message_count": self.context.message_count,
            "history_length": len(self._history),
        }

    async def persist(self) -> bool:
        """Save the session. Returns the persistence layer's success flag."""
        return self._persistence.save(self)

    def get_history(self) -> list[dict[str, Any]]:
        """Return a shallow copy of the history (the entry dicts are shared)."""
        return self._history.copy()

    def add_metadata(self, key: str, value: Any) -> None:
        """Set a metadata entry in memory. It is saved on the next save."""
        self.context.metadata[key] = value

    def get_metadata(self, key: str) -> Any:
        """Return a metadata entry, or None if it is absent."""
        return self.context.metadata.get(key)
# Process-wide, in-memory registry of sessions keyed by session_id
# (populated by create_agent_session; no removal path is visible in this module).
_sessions: dict[str, AgentSession] = dict()
def get_agent_session(session_id: str) -> AgentSession | None:
    """Look up a session in the module's in-memory registry.

    Returns:
        The registered AgentSession, or None if *session_id* is not
        registered. Only the registry is consulted, never the disk.
    """
    try:
        return _sessions[session_id]
    except KeyError:
        return None
def create_agent_session(
    session_id: str | None = None,
    user_id: str | None = None,
    parent_session_id: str | None = None,
) -> AgentSession:
    """Build an AgentSession and register it in the in-memory registry.

    The session is not initialized or saved here; callers must
    ``await session.initialize()`` for that. An existing registry entry
    with the same id is replaced.
    """
    new_session = AgentSession(
        session_id=session_id,
        user_id=user_id,
        parent_session_id=parent_session_id,
    )
    registry_key = new_session.session_id
    _sessions[registry_key] = new_session
    return new_session