Files
X-Agents/core/agents/agent/memory.py
DESKTOP-72TV0V4\caoxiaozhu 1afa88e812 feat: 增强 core/agents 工具和 API
- 新增 loop.py Agent 运行循环
- 优化 memory.py 记忆模块
- 扩展 api/routes.py 接口
- 更新 tools 模块:builtin.py, manager.py, __init__.py
- 新增 .env.example 配置示例
- 更新 requirements.txt 依赖

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-15 19:49:40 +08:00

978 lines
32 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Memory management for agent sessions."""
import json
import logging
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Any
import aiohttp
logger = logging.getLogger(__name__)
class SessionMemory:
"""短期会话记忆 - 内存中的会话消息存储,支持 Markdown 持久化"""
def __init__(self, max_messages: int = 50, workspace: Path | str | None = None):
"""初始化会话记忆
Args:
max_messages: 每个会话保留的最大消息数
workspace: 工作区目录,用于持久化会话文件
"""
self.max_messages = max_messages
self._sessions: dict[str, list[dict[str, Any]]] = defaultdict(list)
# 持久化支持
self.workspace = Path(workspace) if workspace else None
self.sessions_dir = None
if self.workspace:
self.sessions_dir = self.workspace / "sessions"
self.sessions_dir.mkdir(parents=True, exist_ok=True)
# 启动时加载所有会话
self._load_all_sessions()
def _get_session_file(self, session_id: str) -> Path | None:
"""获取会话文件路径"""
if not self.sessions_dir:
return None
# 清理 session_id 中的非法文件名字符
safe_id = "".join(c if c.isalnum() or c in "-_" else "_" for c in session_id)
return self.sessions_dir / f"{safe_id}.md"
def _load_all_sessions(self) -> None:
"""启动时加载所有会话文件"""
if not self.sessions_dir or not self.sessions_dir.exists():
return
for session_file in self.sessions_dir.glob("*.md"):
session_id = session_file.stem
self._load_session(session_id)
logger.info(f"Loaded session from file: {session_id}")
def _load_session(self, session_id: str) -> list[dict[str, Any]]:
"""从文件加载单个会话
Args:
session_id: 会话ID
Returns:
消息列表
"""
session_file = self._get_session_file(session_id)
if not session_file or not session_file.exists():
return []
try:
content = session_file.read_text(encoding="utf-8")
messages = []
lines = content.strip().split("\n")
current_message = {}
for line in lines:
line = line.strip()
if not line:
continue
# 解析 "## 消息 N" 格式
if line.startswith("## 消息"):
# 保存上一条消息
if current_message:
messages.append(current_message)
current_message = {
"role": "",
"timestamp": "",
"content": "",
}
continue
# 解析 "角色: xxx"
if line.startswith("角色:") and current_message is not None:
current_message["role"] = line.split(":", 1)[1].strip()
continue
# 解析 "时间: xxx"
if line.startswith("时间:") and current_message is not None:
current_message["timestamp"] = line.split(":", 1)[1].strip()
continue
# 解析 "内容: xxx"
if line.startswith("内容:") and current_message is not None:
current_message["content"] = line.split(":", 1)[1].strip()
continue
# 保存最后一条消息
if current_message and current_message.get("role"):
messages.append(current_message)
# 加载到内存
if messages:
self._sessions[session_id] = messages[-self.max_messages:]
return messages
except Exception as e:
logger.error(f"Error loading session {session_id}: {e}")
return []
def _save_session(self, session_id: str) -> None:
"""将会话保存到文件
Args:
session_id: 会话ID
"""
session_file = self._get_session_file(session_id)
if not session_file:
return
messages = self._sessions.get(session_id, [])
if not messages:
# 如果会话为空,删除文件
if session_file.exists():
session_file.unlink()
return
# 构建 Markdown 内容(使用产品经理指定的格式)
created_time = messages[0].get("timestamp", datetime.now().isoformat()) if messages else datetime.now().isoformat()
created_time_str = created_time.replace("T", " ") if "T" in created_time else created_time
lines = [
f"# 会话: {session_id}",
f"创建时间: {created_time_str}",
"",
]
for i, msg in enumerate(messages, 1):
role = msg.get("role", "unknown")
timestamp = msg.get("timestamp", "")
content = msg.get("content", "")
# 格式化时间
if "T" in timestamp:
timestamp = timestamp.replace("T", " ")
lines.append(f"## 消息 {i}")
lines.append(f"角色: {role}")
lines.append(f"时间: {timestamp}")
lines.append(f"内容: {content}")
lines.append("")
try:
session_file.write_text("\n".join(lines), encoding="utf-8")
except Exception as e:
logger.error(f"Error saving session {session_id}: {e}")
def add_message(self, session_id: str, role: str, content: str, metadata: dict | None = None) -> None:
"""添加消息到会话
Args:
session_id: 会话ID
role: 消息角色 (user/assistant/system)
content: 消息内容
metadata: 附加元数据
"""
message = {
"role": role,
"content": content,
"timestamp": datetime.now().isoformat(),
}
if metadata:
message["metadata"] = metadata
session_messages = self._sessions[session_id]
session_messages.append(message)
# 超过最大消息数时,移除最旧的消息
if len(session_messages) > self.max_messages:
self._sessions[session_id] = session_messages[-self.max_messages:]
# 持久化到文件
self._save_session(session_id)
def get_history(self, session_id: str, max_messages: int = 0) -> list[dict[str, Any]]:
"""获取会话历史
Args:
session_id: 会话ID
max_messages: 返回的最大消息数0表示全部
Returns:
消息列表
"""
# 如果内存中没有,尝试从文件加载
if session_id not in self._sessions:
self._load_session(session_id)
messages = self._sessions.get(session_id, [])
if max_messages > 0 and len(messages) > max_messages:
return messages[-max_messages:]
return messages
def clear_session(self, session_id: str) -> None:
"""清除会话记忆
Args:
session_id: 会话ID
"""
if session_id in self._sessions:
del self._sessions[session_id]
# 删除会话文件
session_file = self._get_session_file(session_id)
if session_file and session_file.exists():
session_file.unlink()
def get_session_count(self) -> int:
"""获取当前会话数量"""
return len(self._sessions)
def list_sessions(self) -> list[str]:
"""列出所有会话ID"""
return list(self._sessions.keys())
class RemoteMemoryClient:
"""与Go端Memory API对接的客户端"""
def __init__(self, base_url: str, agent_id: str, user_id: str = "default"):
"""初始化远程记忆客户端
Args:
base_url: Go服务端地址
agent_id: Agent ID
user_id: 用户ID
"""
self.base_url = base_url.rstrip("/")
self.agent_id = agent_id
self.user_id = user_id
self._session = None
async def _get_session(self) -> aiohttp.ClientSession:
"""获取或创建aiohttp session"""
if self._session is None or self._session.closed:
self._session = aiohttp.ClientSession()
return self._session
async def close(self) -> None:
"""关闭session"""
if self._session and not self._session.closed:
await self._session.close()
async def create_memory(
self,
content: str,
memory_type: str = "conversation",
importance: int = 5,
) -> dict[str, Any] | None:
"""创建记忆
Args:
content: 记忆内容
memory_type: 记忆类型 (conversation/experience/lessons)
importance: 重要性评分 1-10
Returns:
创建的记忆对象
"""
url = f"{self.base_url}/api/agent/{self.agent_id}/memories"
payload = {
"agent_id": self.agent_id,
"user_id": self.user_id,
"content": content,
"memory_type": memory_type,
"importance": importance,
}
try:
session = await self._get_session()
async with session.post(url, json=payload) as response:
if response.status == 200:
return await response.json()
logger.warning(f"Failed to create memory: {response.status}")
return None
except Exception as e:
logger.error(f"Error creating memory: {e}")
return None
async def get_memories(
self,
limit: int = 10,
offset: int = 0,
memory_type: str | None = None,
category: str | None = None,
) -> list[dict[str, Any]]:
"""获取记忆列表
Args:
limit: 返回数量限制
offset: 偏移量
memory_type: 记忆类型筛选
category: 分类筛选
Returns:
记忆列表
"""
url = f"{self.base_url}/api/agent/{self.agent_id}/memories"
params = {
"user_id": self.user_id,
"limit": limit,
"offset": offset,
}
if memory_type:
params["memory_type"] = memory_type
if category:
params["category"] = category
try:
session = await self._get_session()
async with session.get(url, params=params) as response:
if response.status == 200:
result = await response.json()
return result if isinstance(result, list) else result.get("list", [])
return []
except Exception as e:
logger.error(f"Error getting memories: {e}")
return []
async def search_memories(
self,
keyword: str,
tags: str | None = None,
category: str | None = None,
memory_type: str | None = None,
min_score: int = 0,
limit: int = 10,
offset: int = 0,
) -> list[dict[str, Any]]:
"""搜索记忆(关键词搜索)
Args:
keyword: 搜索关键词
tags: 标签筛选
category: 分类筛选
memory_type: 记忆类型筛选
min_score: 最低重要性分数
limit: 返回数量限制
offset: 偏移量
Returns:
记忆列表
"""
url = f"{self.base_url}/api/agent/{self.agent_id}/memories/search"
payload = {
"agent_id": self.agent_id,
"user_id": self.user_id,
"keyword": keyword,
"limit": limit,
"offset": offset,
}
if tags:
payload["tags"] = tags
if category:
payload["category"] = category
if memory_type:
payload["memory_type"] = memory_type
if min_score > 0:
payload["min_score"] = min_score
try:
session = await self._get_session()
async with session.post(url, json=payload) as response:
if response.status == 200:
result = await response.json()
return result.get("list", [])
return []
except Exception as e:
logger.error(f"Error searching memories: {e}")
return []
async def get_categories(self) -> list[str]:
"""获取记忆分类列表
Returns:
分类列表
"""
url = f"{self.base_url}/api/agent/{self.agent_id}/memories/categories"
params = {"user_id": self.user_id}
try:
session = await self._get_session()
async with session.get(url, params=params) as response:
if response.status == 200:
result = await response.json()
return result.get("categories", [])
return []
except Exception as e:
logger.error(f"Error getting categories: {e}")
return []
async def get_tags(self) -> list[str]:
"""获取记忆标签列表
Returns:
标签列表
"""
url = f"{self.base_url}/api/agent/{self.agent_id}/memories/tags"
params = {"user_id": self.user_id}
try:
session = await self._get_session()
async with session.get(url, params=params) as response:
if response.status == 200:
result = await response.json()
return result.get("tags", [])
return []
except Exception as e:
logger.error(f"Error getting tags: {e}")
return []
async def delete_memory(self, memory_id: str) -> bool:
"""删除记忆
Args:
memory_id: 记忆ID
Returns:
是否删除成功
"""
url = f"{self.base_url}/api/agent/{self.agent_id}/memories/{memory_id}"
try:
session = await self._get_session()
async with session.delete(url) as response:
return response.status == 200
except Exception as e:
logger.error(f"Error deleting memory: {e}")
return False
class AgentMemory:
"""Manages agent memory and session history."""
def __init__(self, workspace: Path):
"""Initialize the memory manager.
Args:
workspace: Workspace directory for storing memory
"""
self.workspace = workspace
self.memory_dir = workspace / "memory"
self.memory_dir.mkdir(exist_ok=True)
self.long_term_file = self.memory_dir / "MEMORY.md"
# Session-specific history
self.sessions_dir = self.memory_dir / "sessions"
self.sessions_dir.mkdir(exist_ok=True)
# Legacy history file (for backward compatibility)
self.history_file = self.memory_dir / "HISTORY.md"
def _get_session_file(self, session_key: str) -> Path:
"""Get session file path."""
# Sanitize session_key for filename
safe_key = "".join(c if c.isalnum() or c in "-_" else "_" for c in session_key)
return self.sessions_dir / f"{safe_key}.md"
def get_memory_context(self) -> str:
"""Get long-term memory content.
Returns:
Memory context string
"""
if self.long_term_file.exists():
return self.long_term_file.read_text(encoding="utf-8")
return ""
def add_to_memory(self, content: str) -> None:
"""Add content to long-term memory.
Args:
content: Content to add to memory
"""
with open(self.long_term_file, "a", encoding="utf-8") as f:
f.write(f"\n{content}")
def add_to_history(self, role: str, content: str, session_key: str | None = None) -> None:
"""Add an entry to conversation history.
Args:
role: Message role (user/assistant)
content: Message content
session_key: Session identifier for session-specific history
"""
timestamp = datetime.now().isoformat()
# If session_key provided, save to session file
if session_key:
self._add_to_session_history(session_key, role, content, timestamp)
else:
# Legacy: save to global history file
legacy_timestamp = datetime.now().strftime("%Y-%m-%d %H:%M")
entry = f"[{legacy_timestamp}] {role}: {content}\n"
with open(self.history_file, "a", encoding="utf-8") as f:
f.write(entry)
def _add_to_session_history(self, session_key: str, role: str, content: str, timestamp: str) -> None:
"""Add message to session-specific history file."""
session_file = self._get_session_file(session_key)
# Format timestamp for display
display_timestamp = timestamp.replace("T", " ") if "T" in timestamp else timestamp
# Determine header format based on whether file exists
header = ""
if not session_file.exists():
header = f"# 会话: {session_key}\n创建时间: {display_timestamp}\n\n"
# Count existing messages to determine message number
msg_count = 1
if session_file.exists():
try:
existing = session_file.read_text(encoding="utf-8")
msg_count = existing.count("## 消息") + 1
except:
pass
# Check if content contains tool_calls or tool_result markers
# Format as Markdown (产品经理指定格式)
entry_lines = [
f"## 消息 {msg_count}",
f"角色: {role}",
f"时间: {display_timestamp}",
]
# Handle tool_calls and tool_result content
if content.startswith("[tool_calls]"):
entry_lines.append(f"工具调用: {content[len('[tool_calls]'):]}")
entry_lines.append(f"内容: ")
elif content.startswith("[tool_result]"):
entry_lines.append(f"工具结果: {content[len('[tool_result]'):]}")
entry_lines.append(f"内容: ")
else:
entry_lines.append(f"内容: {content}")
entry = "\n".join(entry_lines) + "\n\n"
with open(session_file, "a", encoding="utf-8") as f:
if header:
f.write(header)
f.write(entry)
def get_history(
self,
session_key: str | None = None,
max_messages: int = 10,
) -> list[dict[str, Any]]:
"""Get conversation history.
Args:
session_key: Optional session key for session-specific history
max_messages: Maximum number of messages to return
Returns:
List of history messages
"""
# If session_key provided, load from session file
if session_key:
return self._get_session_history(session_key, max_messages)
# Legacy: load from global history file
return self._get_legacy_history(max_messages)
def _get_session_history(self, session_key: str, max_messages: int) -> list[dict[str, Any]]:
"""Get history from session file."""
session_file = self._get_session_file(session_key)
if not session_file.exists():
return []
try:
content = session_file.read_text(encoding="utf-8")
lines = content.strip().split("\n")
messages = []
current_message = {}
for line in lines:
line = line.strip()
if not line:
continue
# Skip headers
if line.startswith("#"):
continue
# Parse "## 消息 N"
if line.startswith("## 消息"):
# Save previous message
if current_message and current_message.get("role"):
messages.append(current_message)
current_message = {
"role": "",
"timestamp": "",
"content": "",
}
continue
# Parse "角色: xxx"
if line.startswith("角色:") and current_message is not None:
current_message["role"] = line.split(":", 1)[1].strip()
continue
# Parse "时间: xxx"
if line.startswith("时间:") and current_message is not None:
current_message["timestamp"] = line.split(":", 1)[1].strip()
continue
# Parse "工具调用: xxx" - for tool_calls
if line.startswith("工具调用:") and current_message is not None:
tool_calls_json = line.split(":", 1)[1].strip()
try:
current_message["tool_calls"] = json.loads(tool_calls_json)
except json.JSONDecodeError:
pass
continue
# Parse "工具结果: xxx" - for tool_result
if line.startswith("工具结果:") and current_message is not None:
tool_result_json = line.split(":", 1)[1].strip()
try:
tool_result = json.loads(tool_result_json)
current_message["tool_call_id"] = tool_result.get("tool_call_id", "")
current_message["name"] = tool_result.get("name", "")
current_message["content"] = tool_result.get("content", "")
except json.JSONDecodeError:
pass
continue
# Parse "内容: xxx"
if line.startswith("内容:") and current_message is not None:
current_message["content"] = line.split(":", 1)[1].strip()
continue
# Content line
if current_message:
if current_message.get("content"):
current_message["content"] += "\n" + line
else:
current_message["content"] = line
# Save last message
if current_message:
messages.append(current_message)
# Return most recent messages
if max_messages > 0 and len(messages) > max_messages:
return messages[-max_messages:]
return messages
except Exception as e:
logger.error(f"Error reading session history: {e}")
return []
def _get_legacy_history(self, max_messages: int) -> list[dict[str, Any]]:
"""Get history from legacy history file."""
if not self.history_file.exists():
return []
try:
content = self.history_file.read_text(encoding="utf-8")
lines = content.strip().split("\n")
messages = []
for line in lines[-max_messages * 2:]:
if ": " in line:
try:
_, rest = line.split("] ", 1)
role, content = rest.split(": ", 1)
messages.append({"role": role, "content": content})
except ValueError:
continue
return messages[-max_messages:] if max_messages > 0 else messages
except Exception as e:
logger.error(f"Error reading legacy history: {e}")
return []
def clear_session(self, session_key: str) -> None:
"""Clear a specific session's history.
Args:
session_key: Session key to clear
"""
session_file = self._get_session_file(session_key)
if session_file.exists():
session_file.unlink()
for line in lines[-max_messages * 2:]:
if ": " in line:
# Skip timestamp prefix
try:
_, rest = line.split("] ", 1)
role, content = rest.split(": ", 1)
messages.append({"role": role, "content": content})
except ValueError:
continue
return messages[-max_messages:]
return []
def clear_session(self, session_key: str) -> None:
"""Clear a specific session's history.
Args:
session_key: Session key to clear
"""
# In a full implementation, you'd handle session-specific storage
pass
# Vector memory integration.
# The vector-memory module is optional: when the import fails, every imported
# name is bound to None and VECTOR_MEMORY_AVAILABLE is False so the rest of the
# module can check the flag instead of failing at import time.
try:
    from .vector_memory import (
        VectorMemoryStore,
        HybridMemorySearch,
        EmbeddingProvider,
        create_vector_memory_store,
    )
    VECTOR_MEMORY_AVAILABLE = True
except ImportError:
    # Placeholders keep the names defined for code that references them.
    VectorMemoryStore = None
    HybridMemorySearch = None
    EmbeddingProvider = None
    create_vector_memory_store = None
    VECTOR_MEMORY_AVAILABLE = False
class EnhancedAgentMemory(AgentMemory):
    """AgentMemory extended with an optional vector store for semantic search."""

    def __init__(
        self,
        workspace: Path,
        enable_vector_search: bool = False,
        vector_persist_dir: str | None = None,
        embedding_provider: str = "openai",
        embedding_model: str = "text-embedding-3-small",
    ):
        """Set up the base memory and, when requested and available, the vector store.

        Args:
            workspace: Workspace directory for storing memory.
            enable_vector_search: Request vector search; only honoured when the
                vector-memory module could be imported.
            vector_persist_dir: Directory for vector store persistence.
            embedding_provider: Provider type (openai, anthropic, local).
            embedding_model: Model name for embeddings.
        """
        super().__init__(workspace)

        self.enable_vector_search = enable_vector_search and VECTOR_MEMORY_AVAILABLE
        self.vector_store = None
        self.hybrid_search = None
        self._embedding_provider_type = embedding_provider
        self._embedding_model = embedding_model

        if not self.enable_vector_search:
            return

        store_options = {
            "persist_dir": vector_persist_dir,
            "provider_type": embedding_provider,
            "model": embedding_model,
        }
        try:
            self.vector_store = create_vector_memory_store(**store_options)
            if self.vector_store:
                self.hybrid_search = HybridMemorySearch(self.vector_store)
                logger.info(f"Vector search enabled for agent memory (provider: {embedding_provider})")
        except Exception as e:
            # Initialisation failure downgrades to plain file-based memory.
            logger.warning(f"Failed to initialize vector store: {e}")
            self.enable_vector_search = False

    async def add_memory_with_embedding(
        self,
        content: str,
        agent_id: str,
        user_id: str = "default",
        memory_type: str = "conversation",
        importance: int = 5,
    ) -> str | None:
        """Store a memory in the Markdown file and, if present, the vector store.

        Args:
            content: Memory content.
            agent_id: Agent ID.
            user_id: User ID.
            memory_type: Type of memory.
            importance: Importance score (1-10).

        Returns:
            Whatever the vector store's ``add_memory`` returns, or ``None`` when
            no vector store is configured.
        """
        # The Markdown file is always updated (base-class behaviour).
        self.add_to_memory(content)

        if not self.vector_store:
            return None

        memory_id = await self.vector_store.add_memory(
            content=content,
            agent_id=agent_id,
            user_id=user_id,
            memory_type=memory_type,
            importance=importance,
        )
        return memory_id

    async def search_memories(
        self,
        query: str,
        agent_id: str | None = None,
        user_id: str | None = None,
        n_results: int = 5,
    ) -> list[dict[str, Any]]:
        """Search memories through the hybrid search backend.

        Args:
            query: Search query.
            agent_id: Filter by agent ID.
            user_id: Filter by user ID.
            n_results: Number of results.

        Returns:
            The backend's results, or ``[]`` (with a warning) when vector search
            is not set up.
        """
        if self.hybrid_search:
            matches = await self.hybrid_search.search(
                query=query,
                agent_id=agent_id,
                user_id=user_id,
                n_results=n_results,
            )
            return matches

        logger.warning("Vector search not enabled")
        return []
# Intelligent memory system integration.
# The intelligent-memory module is optional: when the import fails, every
# imported name is bound to None and INTELLIGENT_MEMORY_AVAILABLE is False so
# the rest of the module can check the flag instead of failing at import time.
try:
    from .intelligent_memory import (
        IntelligentMemorySystem,
        MemorySummarizer,
        ContextCompressor,
        MemoryDecayManager,
        EvergreenManager,
        SummarizationConfig,
        create_intelligent_memory_system,
    )
    INTELLIGENT_MEMORY_AVAILABLE = True
except ImportError:
    # Placeholders keep the names defined for code that references them.
    IntelligentMemorySystem = None
    MemorySummarizer = None
    ContextCompressor = None
    MemoryDecayManager = None
    EvergreenManager = None
    SummarizationConfig = None
    create_intelligent_memory_system = None
    INTELLIGENT_MEMORY_AVAILABLE = False
class CompleteAgentMemory:
    """Facade combining base, vector-search and intelligent memory features.

    Attributes are populated only for the features that are requested and
    importable: ``base`` is always set, ``enhanced`` and ``intelligent`` may be
    ``None``, in which case the corresponding methods return neutral values.
    """

    def __init__(
        self,
        workspace: Path,
        llm_provider=None,
        enable_vector_search: bool = False,
        vector_persist_dir: str | None = None,
        embedding_provider: str = "openai",
        embedding_model: str = "text-embedding-3-small",
        context_window: int = 200000,
    ):
        """Initialise the complete memory manager.

        Args:
            workspace: Workspace directory.
            llm_provider: LLM provider passed to the intelligent memory system.
            enable_vector_search: Enable vector search.
            vector_persist_dir: Vector store persistence directory.
            embedding_provider: Embedding provider type.
            embedding_model: Embedding model name.
            context_window: Model context window size.
        """
        # Base memory (always available).
        self.base = AgentMemory(workspace)

        # Enhanced memory with vector search.
        self.enhanced = None
        if enable_vector_search and VECTOR_MEMORY_AVAILABLE:
            self.enhanced = EnhancedAgentMemory(
                workspace=workspace,
                enable_vector_search=True,
                vector_persist_dir=vector_persist_dir,
                embedding_provider=embedding_provider,
                embedding_model=embedding_model,
            )

        # Intelligent memory system.
        self.intelligent = None
        if INTELLIGENT_MEMORY_AVAILABLE:
            self.intelligent = create_intelligent_memory_system(
                llm_provider=llm_provider,
                context_window=context_window,
            )

    # --- Delegation to the base memory -------------------------------------

    def get_memory_context(self) -> str:
        """Return the long-term memory text from the base memory."""
        return self.base.get_memory_context()

    def add_to_memory(self, content: str) -> None:
        """Append ``content`` to long-term memory via the base memory."""
        self.base.add_to_memory(content)

    def add_to_history(self, role: str, content: str, session_key: str | None = None) -> None:
        """Add a history entry via the base memory.

        Args:
            role: Message role.
            content: Message content.
            session_key: Optional session identifier. Forwarding it lets
                entries be written to the same per-session history that
                ``get_history(session_key=...)`` reads; omitted, the entry goes
                to the legacy global history as before.
        """
        self.base.add_to_history(role, content, session_key)

    def get_history(self, session_key: str | None = None, max_messages: int = 10) -> list[dict[str, Any]]:
        """Return conversation history from the base memory."""
        return self.base.get_history(session_key, max_messages)

    # --- Delegation to the enhanced (vector) memory -------------------------

    async def add_memory_with_embedding(self, *args, **kwargs):
        """Forward to the enhanced memory; returns ``None`` when it is not set up."""
        if self.enhanced:
            return await self.enhanced.add_memory_with_embedding(*args, **kwargs)
        return None

    async def search_memories(self, *args, **kwargs):
        """Forward to the enhanced memory; returns ``[]`` when it is not set up."""
        if self.enhanced:
            return await self.enhanced.search_memories(*args, **kwargs)
        return []

    # --- Intelligent memory --------------------------------------------------

    async def process_message(
        self,
        messages: list[dict],
        current_tokens: int,
        agent_id: str,
        user_id: str = "default",
    ):
        """Process messages with the intelligent memory system.

        Returns ``(messages, None)`` unchanged when the intelligent system is
        not available; otherwise returns whatever its ``process_message`` returns.
        """
        if not self.intelligent:
            return messages, None
        return await self.intelligent.process_message(
            messages, current_tokens, agent_id, user_id
        )

    def get_evergreen_context(self, memories: list[dict]) -> str:
        """Return evergreen context from the intelligent system (``""`` if unavailable)."""
        if not self.intelligent:
            return ""
        return self.intelligent.get_evergreen_context(memories)

    def apply_decay(self, memories: list[dict]) -> list[dict]:
        """Apply decay via the intelligent system (input returned unchanged if unavailable)."""
        if not self.intelligent:
            return memories
        return self.intelligent.apply_decay(memories)