- 新增 loop.py Agent 运行循环 - 优化 memory.py 记忆模块 - 扩展 api/routes.py 接口 - 更新 tools 模块:builtin.py, manager.py, __init__.py - 新增 .env.example 配置示例 - 更新 requirements.txt 依赖 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
978 lines
32 KiB
Python
978 lines
32 KiB
Python
"""Memory management for agent sessions."""
|
||
|
||
import json
|
||
import logging
|
||
from collections import defaultdict
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from typing import Any
|
||
|
||
import aiohttp
|
||
|
||
# Module-wide logger; the memory classes below report load/save and remote-API
# errors through it.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class SessionMemory:
|
||
"""短期会话记忆 - 内存中的会话消息存储,支持 Markdown 持久化"""
|
||
|
||
def __init__(self, max_messages: int = 50, workspace: Path | str | None = None):
|
||
"""初始化会话记忆
|
||
|
||
Args:
|
||
max_messages: 每个会话保留的最大消息数
|
||
workspace: 工作区目录,用于持久化会话文件
|
||
"""
|
||
self.max_messages = max_messages
|
||
self._sessions: dict[str, list[dict[str, Any]]] = defaultdict(list)
|
||
|
||
# 持久化支持
|
||
self.workspace = Path(workspace) if workspace else None
|
||
self.sessions_dir = None
|
||
if self.workspace:
|
||
self.sessions_dir = self.workspace / "sessions"
|
||
self.sessions_dir.mkdir(parents=True, exist_ok=True)
|
||
# 启动时加载所有会话
|
||
self._load_all_sessions()
|
||
|
||
def _get_session_file(self, session_id: str) -> Path | None:
|
||
"""获取会话文件路径"""
|
||
if not self.sessions_dir:
|
||
return None
|
||
# 清理 session_id 中的非法文件名字符
|
||
safe_id = "".join(c if c.isalnum() or c in "-_" else "_" for c in session_id)
|
||
return self.sessions_dir / f"{safe_id}.md"
|
||
|
||
def _load_all_sessions(self) -> None:
|
||
"""启动时加载所有会话文件"""
|
||
if not self.sessions_dir or not self.sessions_dir.exists():
|
||
return
|
||
|
||
for session_file in self.sessions_dir.glob("*.md"):
|
||
session_id = session_file.stem
|
||
self._load_session(session_id)
|
||
logger.info(f"Loaded session from file: {session_id}")
|
||
|
||
def _load_session(self, session_id: str) -> list[dict[str, Any]]:
|
||
"""从文件加载单个会话
|
||
|
||
Args:
|
||
session_id: 会话ID
|
||
|
||
Returns:
|
||
消息列表
|
||
"""
|
||
session_file = self._get_session_file(session_id)
|
||
if not session_file or not session_file.exists():
|
||
return []
|
||
|
||
try:
|
||
content = session_file.read_text(encoding="utf-8")
|
||
messages = []
|
||
lines = content.strip().split("\n")
|
||
|
||
current_message = {}
|
||
for line in lines:
|
||
line = line.strip()
|
||
if not line:
|
||
continue
|
||
|
||
# 解析 "## 消息 N" 格式
|
||
if line.startswith("## 消息"):
|
||
# 保存上一条消息
|
||
if current_message:
|
||
messages.append(current_message)
|
||
|
||
current_message = {
|
||
"role": "",
|
||
"timestamp": "",
|
||
"content": "",
|
||
}
|
||
continue
|
||
|
||
# 解析 "角色: xxx"
|
||
if line.startswith("角色:") and current_message is not None:
|
||
current_message["role"] = line.split(":", 1)[1].strip()
|
||
continue
|
||
|
||
# 解析 "时间: xxx"
|
||
if line.startswith("时间:") and current_message is not None:
|
||
current_message["timestamp"] = line.split(":", 1)[1].strip()
|
||
continue
|
||
|
||
# 解析 "内容: xxx"
|
||
if line.startswith("内容:") and current_message is not None:
|
||
current_message["content"] = line.split(":", 1)[1].strip()
|
||
continue
|
||
|
||
# 保存最后一条消息
|
||
if current_message and current_message.get("role"):
|
||
messages.append(current_message)
|
||
|
||
# 加载到内存
|
||
if messages:
|
||
self._sessions[session_id] = messages[-self.max_messages:]
|
||
|
||
return messages
|
||
|
||
except Exception as e:
|
||
logger.error(f"Error loading session {session_id}: {e}")
|
||
return []
|
||
|
||
def _save_session(self, session_id: str) -> None:
|
||
"""将会话保存到文件
|
||
|
||
Args:
|
||
session_id: 会话ID
|
||
"""
|
||
session_file = self._get_session_file(session_id)
|
||
if not session_file:
|
||
return
|
||
|
||
messages = self._sessions.get(session_id, [])
|
||
if not messages:
|
||
# 如果会话为空,删除文件
|
||
if session_file.exists():
|
||
session_file.unlink()
|
||
return
|
||
|
||
# 构建 Markdown 内容(使用产品经理指定的格式)
|
||
created_time = messages[0].get("timestamp", datetime.now().isoformat()) if messages else datetime.now().isoformat()
|
||
created_time_str = created_time.replace("T", " ") if "T" in created_time else created_time
|
||
|
||
lines = [
|
||
f"# 会话: {session_id}",
|
||
f"创建时间: {created_time_str}",
|
||
"",
|
||
]
|
||
|
||
for i, msg in enumerate(messages, 1):
|
||
role = msg.get("role", "unknown")
|
||
timestamp = msg.get("timestamp", "")
|
||
content = msg.get("content", "")
|
||
|
||
# 格式化时间
|
||
if "T" in timestamp:
|
||
timestamp = timestamp.replace("T", " ")
|
||
|
||
lines.append(f"## 消息 {i}")
|
||
lines.append(f"角色: {role}")
|
||
lines.append(f"时间: {timestamp}")
|
||
lines.append(f"内容: {content}")
|
||
lines.append("")
|
||
|
||
try:
|
||
session_file.write_text("\n".join(lines), encoding="utf-8")
|
||
except Exception as e:
|
||
logger.error(f"Error saving session {session_id}: {e}")
|
||
|
||
def add_message(self, session_id: str, role: str, content: str, metadata: dict | None = None) -> None:
|
||
"""添加消息到会话
|
||
|
||
Args:
|
||
session_id: 会话ID
|
||
role: 消息角色 (user/assistant/system)
|
||
content: 消息内容
|
||
metadata: 附加元数据
|
||
"""
|
||
message = {
|
||
"role": role,
|
||
"content": content,
|
||
"timestamp": datetime.now().isoformat(),
|
||
}
|
||
if metadata:
|
||
message["metadata"] = metadata
|
||
|
||
session_messages = self._sessions[session_id]
|
||
session_messages.append(message)
|
||
|
||
# 超过最大消息数时,移除最旧的消息
|
||
if len(session_messages) > self.max_messages:
|
||
self._sessions[session_id] = session_messages[-self.max_messages:]
|
||
|
||
# 持久化到文件
|
||
self._save_session(session_id)
|
||
|
||
def get_history(self, session_id: str, max_messages: int = 0) -> list[dict[str, Any]]:
|
||
"""获取会话历史
|
||
|
||
Args:
|
||
session_id: 会话ID
|
||
max_messages: 返回的最大消息数,0表示全部
|
||
|
||
Returns:
|
||
消息列表
|
||
"""
|
||
# 如果内存中没有,尝试从文件加载
|
||
if session_id not in self._sessions:
|
||
self._load_session(session_id)
|
||
|
||
messages = self._sessions.get(session_id, [])
|
||
if max_messages > 0 and len(messages) > max_messages:
|
||
return messages[-max_messages:]
|
||
return messages
|
||
|
||
def clear_session(self, session_id: str) -> None:
|
||
"""清除会话记忆
|
||
|
||
Args:
|
||
session_id: 会话ID
|
||
"""
|
||
if session_id in self._sessions:
|
||
del self._sessions[session_id]
|
||
|
||
# 删除会话文件
|
||
session_file = self._get_session_file(session_id)
|
||
if session_file and session_file.exists():
|
||
session_file.unlink()
|
||
|
||
def get_session_count(self) -> int:
|
||
"""获取当前会话数量"""
|
||
return len(self._sessions)
|
||
|
||
def list_sessions(self) -> list[str]:
|
||
"""列出所有会话ID"""
|
||
return list(self._sessions.keys())
|
||
|
||
|
||
class RemoteMemoryClient:
|
||
"""与Go端Memory API对接的客户端"""
|
||
|
||
def __init__(self, base_url: str, agent_id: str, user_id: str = "default"):
|
||
"""初始化远程记忆客户端
|
||
|
||
Args:
|
||
base_url: Go服务端地址
|
||
agent_id: Agent ID
|
||
user_id: 用户ID
|
||
"""
|
||
self.base_url = base_url.rstrip("/")
|
||
self.agent_id = agent_id
|
||
self.user_id = user_id
|
||
self._session = None
|
||
|
||
async def _get_session(self) -> aiohttp.ClientSession:
|
||
"""获取或创建aiohttp session"""
|
||
if self._session is None or self._session.closed:
|
||
self._session = aiohttp.ClientSession()
|
||
return self._session
|
||
|
||
async def close(self) -> None:
|
||
"""关闭session"""
|
||
if self._session and not self._session.closed:
|
||
await self._session.close()
|
||
|
||
async def create_memory(
|
||
self,
|
||
content: str,
|
||
memory_type: str = "conversation",
|
||
importance: int = 5,
|
||
) -> dict[str, Any] | None:
|
||
"""创建记忆
|
||
|
||
Args:
|
||
content: 记忆内容
|
||
memory_type: 记忆类型 (conversation/experience/lessons)
|
||
importance: 重要性评分 1-10
|
||
|
||
Returns:
|
||
创建的记忆对象
|
||
"""
|
||
url = f"{self.base_url}/api/agent/{self.agent_id}/memories"
|
||
payload = {
|
||
"agent_id": self.agent_id,
|
||
"user_id": self.user_id,
|
||
"content": content,
|
||
"memory_type": memory_type,
|
||
"importance": importance,
|
||
}
|
||
|
||
try:
|
||
session = await self._get_session()
|
||
async with session.post(url, json=payload) as response:
|
||
if response.status == 200:
|
||
return await response.json()
|
||
logger.warning(f"Failed to create memory: {response.status}")
|
||
return None
|
||
except Exception as e:
|
||
logger.error(f"Error creating memory: {e}")
|
||
return None
|
||
|
||
async def get_memories(
|
||
self,
|
||
limit: int = 10,
|
||
offset: int = 0,
|
||
memory_type: str | None = None,
|
||
category: str | None = None,
|
||
) -> list[dict[str, Any]]:
|
||
"""获取记忆列表
|
||
|
||
Args:
|
||
limit: 返回数量限制
|
||
offset: 偏移量
|
||
memory_type: 记忆类型筛选
|
||
category: 分类筛选
|
||
|
||
Returns:
|
||
记忆列表
|
||
"""
|
||
url = f"{self.base_url}/api/agent/{self.agent_id}/memories"
|
||
params = {
|
||
"user_id": self.user_id,
|
||
"limit": limit,
|
||
"offset": offset,
|
||
}
|
||
if memory_type:
|
||
params["memory_type"] = memory_type
|
||
if category:
|
||
params["category"] = category
|
||
|
||
try:
|
||
session = await self._get_session()
|
||
async with session.get(url, params=params) as response:
|
||
if response.status == 200:
|
||
result = await response.json()
|
||
return result if isinstance(result, list) else result.get("list", [])
|
||
return []
|
||
except Exception as e:
|
||
logger.error(f"Error getting memories: {e}")
|
||
return []
|
||
|
||
async def search_memories(
|
||
self,
|
||
keyword: str,
|
||
tags: str | None = None,
|
||
category: str | None = None,
|
||
memory_type: str | None = None,
|
||
min_score: int = 0,
|
||
limit: int = 10,
|
||
offset: int = 0,
|
||
) -> list[dict[str, Any]]:
|
||
"""搜索记忆(关键词搜索)
|
||
|
||
Args:
|
||
keyword: 搜索关键词
|
||
tags: 标签筛选
|
||
category: 分类筛选
|
||
memory_type: 记忆类型筛选
|
||
min_score: 最低重要性分数
|
||
limit: 返回数量限制
|
||
offset: 偏移量
|
||
|
||
Returns:
|
||
记忆列表
|
||
"""
|
||
url = f"{self.base_url}/api/agent/{self.agent_id}/memories/search"
|
||
payload = {
|
||
"agent_id": self.agent_id,
|
||
"user_id": self.user_id,
|
||
"keyword": keyword,
|
||
"limit": limit,
|
||
"offset": offset,
|
||
}
|
||
if tags:
|
||
payload["tags"] = tags
|
||
if category:
|
||
payload["category"] = category
|
||
if memory_type:
|
||
payload["memory_type"] = memory_type
|
||
if min_score > 0:
|
||
payload["min_score"] = min_score
|
||
|
||
try:
|
||
session = await self._get_session()
|
||
async with session.post(url, json=payload) as response:
|
||
if response.status == 200:
|
||
result = await response.json()
|
||
return result.get("list", [])
|
||
return []
|
||
except Exception as e:
|
||
logger.error(f"Error searching memories: {e}")
|
||
return []
|
||
|
||
async def get_categories(self) -> list[str]:
|
||
"""获取记忆分类列表
|
||
|
||
Returns:
|
||
分类列表
|
||
"""
|
||
url = f"{self.base_url}/api/agent/{self.agent_id}/memories/categories"
|
||
params = {"user_id": self.user_id}
|
||
|
||
try:
|
||
session = await self._get_session()
|
||
async with session.get(url, params=params) as response:
|
||
if response.status == 200:
|
||
result = await response.json()
|
||
return result.get("categories", [])
|
||
return []
|
||
except Exception as e:
|
||
logger.error(f"Error getting categories: {e}")
|
||
return []
|
||
|
||
async def get_tags(self) -> list[str]:
|
||
"""获取记忆标签列表
|
||
|
||
Returns:
|
||
标签列表
|
||
"""
|
||
url = f"{self.base_url}/api/agent/{self.agent_id}/memories/tags"
|
||
params = {"user_id": self.user_id}
|
||
|
||
try:
|
||
session = await self._get_session()
|
||
async with session.get(url, params=params) as response:
|
||
if response.status == 200:
|
||
result = await response.json()
|
||
return result.get("tags", [])
|
||
return []
|
||
except Exception as e:
|
||
logger.error(f"Error getting tags: {e}")
|
||
return []
|
||
|
||
async def delete_memory(self, memory_id: str) -> bool:
|
||
"""删除记忆
|
||
|
||
Args:
|
||
memory_id: 记忆ID
|
||
|
||
Returns:
|
||
是否删除成功
|
||
"""
|
||
url = f"{self.base_url}/api/agent/{self.agent_id}/memories/{memory_id}"
|
||
|
||
try:
|
||
session = await self._get_session()
|
||
async with session.delete(url) as response:
|
||
return response.status == 200
|
||
except Exception as e:
|
||
logger.error(f"Error deleting memory: {e}")
|
||
return False
|
||
|
||
|
||
class AgentMemory:
|
||
"""Manages agent memory and session history."""
|
||
|
||
def __init__(self, workspace: Path):
|
||
"""Initialize the memory manager.
|
||
|
||
Args:
|
||
workspace: Workspace directory for storing memory
|
||
"""
|
||
self.workspace = workspace
|
||
self.memory_dir = workspace / "memory"
|
||
self.memory_dir.mkdir(exist_ok=True)
|
||
|
||
self.long_term_file = self.memory_dir / "MEMORY.md"
|
||
|
||
# Session-specific history
|
||
self.sessions_dir = self.memory_dir / "sessions"
|
||
self.sessions_dir.mkdir(exist_ok=True)
|
||
|
||
# Legacy history file (for backward compatibility)
|
||
self.history_file = self.memory_dir / "HISTORY.md"
|
||
|
||
def _get_session_file(self, session_key: str) -> Path:
|
||
"""Get session file path."""
|
||
# Sanitize session_key for filename
|
||
safe_key = "".join(c if c.isalnum() or c in "-_" else "_" for c in session_key)
|
||
return self.sessions_dir / f"{safe_key}.md"
|
||
|
||
def get_memory_context(self) -> str:
|
||
"""Get long-term memory content.
|
||
|
||
Returns:
|
||
Memory context string
|
||
"""
|
||
if self.long_term_file.exists():
|
||
return self.long_term_file.read_text(encoding="utf-8")
|
||
return ""
|
||
|
||
def add_to_memory(self, content: str) -> None:
|
||
"""Add content to long-term memory.
|
||
|
||
Args:
|
||
content: Content to add to memory
|
||
"""
|
||
with open(self.long_term_file, "a", encoding="utf-8") as f:
|
||
f.write(f"\n{content}")
|
||
|
||
def add_to_history(self, role: str, content: str, session_key: str | None = None) -> None:
|
||
"""Add an entry to conversation history.
|
||
|
||
Args:
|
||
role: Message role (user/assistant)
|
||
content: Message content
|
||
session_key: Session identifier for session-specific history
|
||
"""
|
||
timestamp = datetime.now().isoformat()
|
||
|
||
# If session_key provided, save to session file
|
||
if session_key:
|
||
self._add_to_session_history(session_key, role, content, timestamp)
|
||
else:
|
||
# Legacy: save to global history file
|
||
legacy_timestamp = datetime.now().strftime("%Y-%m-%d %H:%M")
|
||
entry = f"[{legacy_timestamp}] {role}: {content}\n"
|
||
with open(self.history_file, "a", encoding="utf-8") as f:
|
||
f.write(entry)
|
||
|
||
def _add_to_session_history(self, session_key: str, role: str, content: str, timestamp: str) -> None:
|
||
"""Add message to session-specific history file."""
|
||
session_file = self._get_session_file(session_key)
|
||
|
||
# Format timestamp for display
|
||
display_timestamp = timestamp.replace("T", " ") if "T" in timestamp else timestamp
|
||
|
||
# Determine header format based on whether file exists
|
||
header = ""
|
||
if not session_file.exists():
|
||
header = f"# 会话: {session_key}\n创建时间: {display_timestamp}\n\n"
|
||
|
||
# Count existing messages to determine message number
|
||
msg_count = 1
|
||
if session_file.exists():
|
||
try:
|
||
existing = session_file.read_text(encoding="utf-8")
|
||
msg_count = existing.count("## 消息") + 1
|
||
except:
|
||
pass
|
||
|
||
# Check if content contains tool_calls or tool_result markers
|
||
# Format as Markdown (产品经理指定格式)
|
||
entry_lines = [
|
||
f"## 消息 {msg_count}",
|
||
f"角色: {role}",
|
||
f"时间: {display_timestamp}",
|
||
]
|
||
|
||
# Handle tool_calls and tool_result content
|
||
if content.startswith("[tool_calls]"):
|
||
entry_lines.append(f"工具调用: {content[len('[tool_calls]'):]}")
|
||
entry_lines.append(f"内容: ")
|
||
elif content.startswith("[tool_result]"):
|
||
entry_lines.append(f"工具结果: {content[len('[tool_result]'):]}")
|
||
entry_lines.append(f"内容: ")
|
||
else:
|
||
entry_lines.append(f"内容: {content}")
|
||
|
||
entry = "\n".join(entry_lines) + "\n\n"
|
||
|
||
with open(session_file, "a", encoding="utf-8") as f:
|
||
if header:
|
||
f.write(header)
|
||
f.write(entry)
|
||
|
||
def get_history(
|
||
self,
|
||
session_key: str | None = None,
|
||
max_messages: int = 10,
|
||
) -> list[dict[str, Any]]:
|
||
"""Get conversation history.
|
||
|
||
Args:
|
||
session_key: Optional session key for session-specific history
|
||
max_messages: Maximum number of messages to return
|
||
|
||
Returns:
|
||
List of history messages
|
||
"""
|
||
# If session_key provided, load from session file
|
||
if session_key:
|
||
return self._get_session_history(session_key, max_messages)
|
||
|
||
# Legacy: load from global history file
|
||
return self._get_legacy_history(max_messages)
|
||
|
||
def _get_session_history(self, session_key: str, max_messages: int) -> list[dict[str, Any]]:
|
||
"""Get history from session file."""
|
||
session_file = self._get_session_file(session_key)
|
||
if not session_file.exists():
|
||
return []
|
||
|
||
try:
|
||
content = session_file.read_text(encoding="utf-8")
|
||
lines = content.strip().split("\n")
|
||
messages = []
|
||
|
||
current_message = {}
|
||
for line in lines:
|
||
line = line.strip()
|
||
if not line:
|
||
continue
|
||
|
||
# Skip headers
|
||
if line.startswith("#"):
|
||
continue
|
||
|
||
# Parse "## 消息 N"
|
||
if line.startswith("## 消息"):
|
||
# Save previous message
|
||
if current_message and current_message.get("role"):
|
||
messages.append(current_message)
|
||
|
||
current_message = {
|
||
"role": "",
|
||
"timestamp": "",
|
||
"content": "",
|
||
}
|
||
continue
|
||
|
||
# Parse "角色: xxx"
|
||
if line.startswith("角色:") and current_message is not None:
|
||
current_message["role"] = line.split(":", 1)[1].strip()
|
||
continue
|
||
|
||
# Parse "时间: xxx"
|
||
if line.startswith("时间:") and current_message is not None:
|
||
current_message["timestamp"] = line.split(":", 1)[1].strip()
|
||
continue
|
||
|
||
# Parse "工具调用: xxx" - for tool_calls
|
||
if line.startswith("工具调用:") and current_message is not None:
|
||
tool_calls_json = line.split(":", 1)[1].strip()
|
||
try:
|
||
current_message["tool_calls"] = json.loads(tool_calls_json)
|
||
except json.JSONDecodeError:
|
||
pass
|
||
continue
|
||
|
||
# Parse "工具结果: xxx" - for tool_result
|
||
if line.startswith("工具结果:") and current_message is not None:
|
||
tool_result_json = line.split(":", 1)[1].strip()
|
||
try:
|
||
tool_result = json.loads(tool_result_json)
|
||
current_message["tool_call_id"] = tool_result.get("tool_call_id", "")
|
||
current_message["name"] = tool_result.get("name", "")
|
||
current_message["content"] = tool_result.get("content", "")
|
||
except json.JSONDecodeError:
|
||
pass
|
||
continue
|
||
|
||
# Parse "内容: xxx"
|
||
if line.startswith("内容:") and current_message is not None:
|
||
current_message["content"] = line.split(":", 1)[1].strip()
|
||
continue
|
||
|
||
# Content line
|
||
if current_message:
|
||
if current_message.get("content"):
|
||
current_message["content"] += "\n" + line
|
||
else:
|
||
current_message["content"] = line
|
||
|
||
# Save last message
|
||
if current_message:
|
||
messages.append(current_message)
|
||
|
||
# Return most recent messages
|
||
if max_messages > 0 and len(messages) > max_messages:
|
||
return messages[-max_messages:]
|
||
return messages
|
||
|
||
except Exception as e:
|
||
logger.error(f"Error reading session history: {e}")
|
||
return []
|
||
|
||
def _get_legacy_history(self, max_messages: int) -> list[dict[str, Any]]:
|
||
"""Get history from legacy history file."""
|
||
if not self.history_file.exists():
|
||
return []
|
||
|
||
try:
|
||
content = self.history_file.read_text(encoding="utf-8")
|
||
lines = content.strip().split("\n")
|
||
messages = []
|
||
|
||
for line in lines[-max_messages * 2:]:
|
||
if ": " in line:
|
||
try:
|
||
_, rest = line.split("] ", 1)
|
||
role, content = rest.split(": ", 1)
|
||
messages.append({"role": role, "content": content})
|
||
except ValueError:
|
||
continue
|
||
|
||
return messages[-max_messages:] if max_messages > 0 else messages
|
||
|
||
except Exception as e:
|
||
logger.error(f"Error reading legacy history: {e}")
|
||
return []
|
||
|
||
def clear_session(self, session_key: str) -> None:
|
||
"""Clear a specific session's history.
|
||
|
||
Args:
|
||
session_key: Session key to clear
|
||
"""
|
||
session_file = self._get_session_file(session_key)
|
||
if session_file.exists():
|
||
session_file.unlink()
|
||
|
||
for line in lines[-max_messages * 2:]:
|
||
if ": " in line:
|
||
# Skip timestamp prefix
|
||
try:
|
||
_, rest = line.split("] ", 1)
|
||
role, content = rest.split(": ", 1)
|
||
messages.append({"role": role, "content": content})
|
||
except ValueError:
|
||
continue
|
||
|
||
return messages[-max_messages:]
|
||
|
||
return []
|
||
|
||
def clear_session(self, session_key: str) -> None:
|
||
"""Clear a specific session's history.
|
||
|
||
Args:
|
||
session_key: Session key to clear
|
||
"""
|
||
# In a full implementation, you'd handle session-specific storage
|
||
pass
|
||
|
||
|
||
# Vector memory integration (optional). If the sibling ``vector_memory``
# module cannot be imported, every exported name is bound to None and
# VECTOR_MEMORY_AVAILABLE is False so callers can feature-test it.
try:
    from .vector_memory import (
        EmbeddingProvider,
        HybridMemorySearch,
        VectorMemoryStore,
        create_vector_memory_store,
    )
except ImportError:
    VectorMemoryStore = HybridMemorySearch = EmbeddingProvider = None
    create_vector_memory_store = None
    VECTOR_MEMORY_AVAILABLE = False
else:
    VECTOR_MEMORY_AVAILABLE = True
|
||
|
||
|
||
class EnhancedAgentMemory(AgentMemory):
    """AgentMemory with optional semantic search over a vector store.

    Vector features are active only when they were requested, the optional
    ``vector_memory`` backend imported, and a store plus hybrid search were
    actually created. ``enable_vector_search`` reports that outcome.
    """

    def __init__(
        self,
        workspace: Path,
        enable_vector_search: bool = False,
        vector_persist_dir: str | None = None,
        embedding_provider: str = "openai",
        embedding_model: str = "text-embedding-3-small",
    ):
        """Initialize enhanced memory manager.

        Args:
            workspace: Workspace directory for storing memory.
            enable_vector_search: Request vector search (needs the optional backend).
            vector_persist_dir: Directory for vector store persistence.
            embedding_provider: Provider type (openai, anthropic, local).
            embedding_model: Model name for embeddings.
        """
        super().__init__(workspace)

        self.vector_store = None
        self.hybrid_search = None
        self._embedding_provider_type = embedding_provider
        self._embedding_model = embedding_model

        if enable_vector_search and VECTOR_MEMORY_AVAILABLE:
            try:
                self.vector_store = create_vector_memory_store(
                    persist_dir=vector_persist_dir,
                    provider_type=embedding_provider,
                    model=embedding_model,
                )
                if self.vector_store:
                    self.hybrid_search = HybridMemorySearch(self.vector_store)
                    logger.info(f"Vector search enabled for agent memory (provider: {embedding_provider})")
            except Exception as e:  # optional backend: degrade instead of failing
                logger.warning(f"Failed to initialize vector store: {e}")

        # Reflect what actually came up: the factory may return a falsy store
        # without raising, in which case search is not available.
        self.enable_vector_search = self.hybrid_search is not None

    async def add_memory_with_embedding(
        self,
        content: str,
        agent_id: str,
        user_id: str = "default",
        memory_type: str = "conversation",
        importance: int = 5,
    ) -> str | None:
        """Store a memory in MEMORY.md and, if available, the vector store.

        Args:
            content: Memory content.
            agent_id: Agent ID.
            user_id: User ID.
            memory_type: Type of memory.
            importance: Importance score (1-10).

        Returns:
            Whatever ``vector_store.add_memory`` returns (presumably the
            memory ID -- confirm in vector_memory), or None without a store.
        """
        # Always persist to the Markdown file (base class behaviour).
        self.add_to_memory(content)

        if self.vector_store:
            return await self.vector_store.add_memory(
                content=content,
                agent_id=agent_id,
                user_id=user_id,
                memory_type=memory_type,
                importance=importance,
            )
        return None

    async def search_memories(
        self,
        query: str,
        agent_id: str | None = None,
        user_id: str | None = None,
        n_results: int = 5,
    ) -> list[dict[str, Any]]:
        """Search memories through the hybrid search backend.

        Args:
            query: Search query.
            agent_id: Optional agent ID filter.
            user_id: Optional user ID filter.
            n_results: Number of results.

        Returns:
            Matching memories, or ``[]`` (with a warning) when vector search
            is not available.
        """
        if not self.hybrid_search:
            logger.warning("Vector search not enabled")
            return []

        return await self.hybrid_search.search(
            query=query,
            agent_id=agent_id,
            user_id=user_id,
            n_results=n_results,
        )
|
||
|
||
|
||
# Intelligent memory system integration (optional). If the sibling
# ``intelligent_memory`` module cannot be imported, every exported name is
# bound to None and INTELLIGENT_MEMORY_AVAILABLE is False.
try:
    from .intelligent_memory import (
        ContextCompressor,
        EvergreenManager,
        IntelligentMemorySystem,
        MemoryDecayManager,
        MemorySummarizer,
        SummarizationConfig,
        create_intelligent_memory_system,
    )
except ImportError:
    IntelligentMemorySystem = MemorySummarizer = ContextCompressor = None
    MemoryDecayManager = EvergreenManager = SummarizationConfig = None
    create_intelligent_memory_system = None
    INTELLIGENT_MEMORY_AVAILABLE = False
else:
    INTELLIGENT_MEMORY_AVAILABLE = True
|
||
|
||
|
||
class CompleteAgentMemory:
    """Facade over file-based, vector and "intelligent" memory.

    ``base`` (an AgentMemory) always exists. ``enhanced`` exists only when
    vector search was requested and the vector backend imported;
    ``intelligent`` exists only when the intelligent-memory module imported.
    Methods whose backend is missing degrade to a neutral result.
    """

    def __init__(
        self,
        workspace: Path,
        llm_provider=None,
        enable_vector_search: bool = False,
        vector_persist_dir: str | None = None,
        embedding_provider: str = "openai",
        embedding_model: str = "text-embedding-3-small",
        context_window: int = 200000,
    ):
        """Initialize complete memory manager.

        Args:
            workspace: Workspace directory.
            llm_provider: LLM provider passed to the intelligent memory system.
            enable_vector_search: Enable vector search.
            vector_persist_dir: Vector store persistence directory.
            embedding_provider: Embedding provider type.
            embedding_model: Embedding model name.
            context_window: Model context window size.
        """
        self.base = AgentMemory(workspace)

        self.enhanced = None
        if enable_vector_search and VECTOR_MEMORY_AVAILABLE:
            self.enhanced = EnhancedAgentMemory(
                workspace=workspace,
                enable_vector_search=True,
                vector_persist_dir=vector_persist_dir,
                embedding_provider=embedding_provider,
                embedding_model=embedding_model,
            )

        self.intelligent = None
        if INTELLIGENT_MEMORY_AVAILABLE:
            self.intelligent = create_intelligent_memory_system(
                llm_provider=llm_provider,
                context_window=context_window,
            )

    # --- base (file) memory -------------------------------------------------

    def get_memory_context(self) -> str:
        """Return the long-term memory text (delegates to ``base``)."""
        return self.base.get_memory_context()

    def add_to_memory(self, content: str) -> None:
        """Append to long-term memory (delegates to ``base``)."""
        self.base.add_to_memory(content)

    def add_to_history(self, role: str, content: str, session_key: str | None = None) -> None:
        """Record a history entry; *session_key* selects per-session history.

        Without a session key the entry goes to the legacy global history,
        which was the only option before *session_key* was accepted here.
        """
        self.base.add_to_history(role, content, session_key)

    def get_history(self, session_key: str | None = None, max_messages: int = 10) -> list[dict[str, Any]]:
        """Return conversation history (delegates to ``base``)."""
        return self.base.get_history(session_key, max_messages)

    def clear_session(self, session_key: str) -> None:
        """Clear a session's history (delegates to ``base.clear_session``)."""
        self.base.clear_session(session_key)

    # --- vector memory ------------------------------------------------------

    async def add_memory_with_embedding(self, *args, **kwargs):
        """Store a memory with an embedding when vector search is available.

        Arguments are those of ``EnhancedAgentMemory.add_memory_with_embedding``.
        Without the vector backend the content is still appended to long-term
        memory (as the enhanced path always does) and None is returned.
        """
        if self.enhanced:
            return await self.enhanced.add_memory_with_embedding(*args, **kwargs)

        content = kwargs.get("content", args[0] if args else None)
        if isinstance(content, str):
            self.base.add_to_memory(content)
        return None

    async def search_memories(self, *args, **kwargs):
        """Semantic search via ``enhanced``; ``[]`` when unavailable."""
        if self.enhanced:
            return await self.enhanced.search_memories(*args, **kwargs)
        return []

    # --- intelligent memory -------------------------------------------------

    async def process_message(
        self,
        messages: list[dict],
        current_tokens: int,
        agent_id: str,
        user_id: str = "default",
    ):
        """Run intelligent memory processing.

        Returns ``(messages, None)`` unchanged when the intelligent system is
        unavailable; otherwise whatever ``intelligent.process_message`` returns.
        """
        if not self.intelligent:
            return messages, None

        return await self.intelligent.process_message(
            messages, current_tokens, agent_id, user_id
        )

    def get_evergreen_context(self, memories: list[dict]) -> str:
        """Return evergreen context from *memories*; "" when unavailable."""
        if not self.intelligent:
            return ""
        return self.intelligent.get_evergreen_context(memories)

    def apply_decay(self, memories: list[dict]) -> list[dict]:
        """Apply memory decay; returns *memories* unchanged when unavailable."""
        if not self.intelligent:
            return memories
        return self.intelligent.apply_decay(memories)
|