Files
JARVIS/backend/app/agents/session/manager.py

239 lines
7.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Agent Session 管理 - Phase 10.3
支持会话层级管理和子会话创建。
"""
import json
import os
import uuid
from dataclasses import asdict, dataclass, field
from datetime import datetime
from typing import Any
@dataclass
class SessionContext:
"""会话上下文"""
session_id: str
parent_session_id: str | None = None
root_session_id: str | None = None
depth: int = 0
user_id: str | None = None
created_at: str | None = None
last_active: str | None = None
message_count: int = 0
metadata: dict[str, Any] = field(default_factory=dict)
def __post_init__(self):
if self.created_at is None:
self.created_at = datetime.now().isoformat()
if self.last_active is None:
self.last_active = self.created_at
@dataclass
class SessionPersistence:
"""会话持久化"""
def __init__(self, persistence_dir: str | None = None):
if persistence_dir is None:
persistence_dir = os.path.join(
os.path.dirname(__file__), "..", "..", "..", "data", "sessions"
)
self.persistence_dir = persistence_dir
def _get_session_path(self, session_id: str) -> str:
return os.path.join(self.persistence_dir, f"{session_id}.json")
def save(self, session: "AgentSession") -> bool:
"""保存会话"""
try:
os.makedirs(self.persistence_dir, exist_ok=True)
path = self._get_session_path(session.session_id)
data = {
"session_id": session.session_id,
"parent_session_id": session.context.parent_session_id,
"root_session_id": session.context.root_session_id,
"depth": session.context.depth,
"user_id": session.context.user_id,
"created_at": session.context.created_at,
"last_active": session.context.last_active,
"message_count": session.context.message_count,
"metadata": session.context.metadata,
"history": session._history,
}
with open(path, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
return True
except Exception:
return False
def load(self, session_id: str) -> dict[str, Any] | None:
"""加载会话"""
try:
path = self._get_session_path(session_id)
if not os.path.exists(path):
return None
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception:
return None
def delete(self, session_id: str) -> bool:
"""删除会话"""
try:
path = self._get_session_path(session_id)
if os.path.exists(path):
os.remove(path)
return True
except Exception:
return False
def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]:
"""列出所有会话"""
sessions = []
try:
os.makedirs(self.persistence_dir, exist_ok=True)
for filename in os.listdir(self.persistence_dir):
if filename.endswith(".json"):
path = os.path.join(self.persistence_dir, filename)
with open(path, "r", encoding="utf-8") as f:
data = json.load(f)
if user_id is None or data.get("user_id") == user_id:
sessions.append(data)
except Exception:
pass
return sessions
class AgentSession:
"""Agent 会话管理器
支持:
- 会话层级parent/root/depth
- 子会话创建
- 会话摘要
- 持久化
"""
def __init__(
self,
session_id: str | None = None,
user_id: str | None = None,
parent_session_id: str | None = None,
):
self.session_id = session_id or str(uuid.uuid4())[:8]
self.context = SessionContext(
session_id=self.session_id,
user_id=user_id,
parent_session_id=parent_session_id,
depth=0 if parent_session_id is None else 1,
)
self._history: list[dict[str, Any]] = []
self._persistence = SessionPersistence()
# 如果有父会话,设置 root_session_id
if parent_session_id:
parent_data = self._persistence.load(parent_session_id)
if parent_data:
self.context.root_session_id = (
parent_data.get("root_session_id") or parent_session_id
)
self.context.depth = parent_data.get("depth", 0) + 1
async def initialize(self) -> dict[str, Any]:
"""初始化会话"""
self.context.last_active = datetime.now().isoformat()
self._persistence.save(self)
return {
"session_id": self.session_id,
"depth": self.context.depth,
"parent_session_id": self.context.parent_session_id,
"root_session_id": self.context.root_session_id,
}
async def process_message(self, message: str, response: str) -> None:
"""处理消息并记录到历史"""
self.context.message_count += 1
self.context.last_active = datetime.now().isoformat()
self._history.append(
{
"role": "user",
"content": message,
"timestamp": datetime.now().isoformat(),
}
)
self._history.append(
{
"role": "assistant",
"content": response,
"timestamp": datetime.now().isoformat(),
}
)
self._persistence.save(self)
async def spawn_child_session(self, user_id: str | None = None) -> "AgentSession":
"""创建子会话"""
child = AgentSession(
user_id=user_id or self.context.user_id,
parent_session_id=self.session_id,
)
child.context.root_session_id = self.context.root_session_id or self.session_id
await child.initialize()
return child
async def get_session_summary(self) -> dict[str, Any]:
"""获取会话摘要"""
return {
"session_id": self.session_id,
"parent_session_id": self.context.parent_session_id,
"root_session_id": self.context.root_session_id,
"depth": self.context.depth,
"user_id": self.context.user_id,
"created_at": self.context.created_at,
"last_active": self.context.last_active,
"message_count": self.context.message_count,
"history_length": len(self._history),
}
async def persist(self) -> bool:
"""持久化会话"""
return self._persistence.save(self)
def get_history(self) -> list[dict[str, Any]]:
"""获取会话历史"""
return self._history.copy()
def add_metadata(self, key: str, value: Any) -> None:
"""添加会话元数据"""
self.context.metadata[key] = value
def get_metadata(self, key: str) -> Any:
"""获取会话元数据"""
return self.context.metadata.get(key)
# Global in-memory session registry, keyed by session id.
_sessions: dict[str, AgentSession] = {}
def get_agent_session(session_id: str) -> AgentSession | None:
    """Look up a session in the in-memory registry.

    Returns:
        The registered session, or ``None`` when ``session_id`` is unknown.
    """
    try:
        return _sessions[session_id]
    except KeyError:
        return None
def create_agent_session(
    session_id: str | None = None,
    user_id: str | None = None,
    parent_session_id: str | None = None,
) -> AgentSession:
    """Build a new session and register it in the in-memory registry.

    An entry already registered under the same id is replaced.

    Args:
        session_id: Explicit id for the session, if any.
        user_id: Owner of the session, if any.
        parent_session_id: Id of the parent session, if any.

    Returns:
        The newly created session.
    """
    options = {
        "session_id": session_id,
        "user_id": user_id,
        "parent_session_id": parent_session_id,
    }
    created = AgentSession(**options)
    _sessions.update({created.session_id: created})
    return created