"""Memory system for persistent agent memory."""
|
||
|
|
|
||
|
|
import asyncio
import json
import logging
import weakref
from pathlib import Path
from typing import Any, Callable, Optional
|
||
|
|
|
||
|
|
# tiktoken is an optional dependency: when it is missing, token counts fall
# back to a character-based approximation elsewhere in this module.
try:
    import tiktoken
except ImportError:
    HAS_TIKTOKEN = False
else:
    HAS_TIKTOKEN = True
|
||
|
|
|
||
|
|
|
||
|
|
# Tool schema handed to the provider during consolidation. It declares a single
# function, `save_memory`, with two required string arguments: a history
# paragraph and the full rewritten long-term memory document.
_SAVE_MEMORY_TOOL = [
    {
        "type": "function",
        "function": {
            "name": "save_memory",
            "description": "Save the memory consolidation result to persistent storage.",
            "parameters": {
                "type": "object",
                "properties": {
                    "history_entry": {
                        "type": "string",
                        "description": "A paragraph summarizing key events/decisions/topics.",
                    },
                    "memory_update": {
                        "type": "string",
                        "description": (
                            "Full updated long-term memory as markdown. "
                            "Include all existing facts plus new ones."
                        ),
                    },
                },
                "required": ["history_entry", "memory_update"],
            },
        },
    }
]
|
||
|
|
|
||
|
|
|
||
|
|
class MemoryStore:
    """Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log).

    Both files live in a ``memory`` directory directly under the workspace and
    are read and written as UTF-8 text.
    """

    def __init__(self, workspace: Path):
        """Create ``<workspace>/memory`` if needed and record the file paths.

        Args:
            workspace: Directory under which the ``memory`` folder is kept.
        """
        self.memory_dir = workspace / "memory"
        self.memory_dir.mkdir(parents=True, exist_ok=True)
        self.memory_file = self.memory_dir / "MEMORY.md"
        self.history_file = self.memory_dir / "HISTORY.md"

    def read_long_term(self) -> str:
        """Return the contents of MEMORY.md, or ``""`` when the file is absent."""
        # EAFP instead of exists()+read_text(): avoids a second filesystem
        # lookup and the window in which the file could vanish between the two.
        try:
            return self.memory_file.read_text(encoding="utf-8")
        except FileNotFoundError:
            return ""

    def write_long_term(self, content: str) -> None:
        """Overwrite MEMORY.md with ``content``."""
        self.memory_file.write_text(content, encoding="utf-8")

    def append_history(self, entry: str) -> None:
        """Append ``entry`` to HISTORY.md followed by one blank line.

        Trailing whitespace on ``entry`` is stripped first so that entries are
        always separated by exactly one empty line.
        """
        with open(self.history_file, "a", encoding="utf-8") as f:
            f.write(entry.rstrip() + "\n\n")

    def get_memory_context(self) -> str:
        """Return long-term memory under a markdown heading, or ``""`` if empty."""
        long_term = self.read_long_term()
        return f"## Long-term Memory\n{long_term}" if long_term else ""
|
||
|
|
|
||
|
|
|
||
|
|
def _estimate_tokens(text: str) -> int:
    """Estimate how many tokens ``text`` occupies.

    When tiktoken was imported successfully, the ``cl100k_base`` encoding is
    used. If tiktoken is unavailable, or encoding raises, the estimate is one
    token per four characters with a floor of 1.
    """
    if HAS_TIKTOKEN:
        try:
            return len(tiktoken.get_encoding("cl100k_base").encode(text))
        except Exception:
            # Best effort: any tokenizer failure drops to the heuristic below.
            pass
    approx = len(text) // 4
    return approx if approx > 1 else 1
|
||
|
|
|
||
|
|
|
||
|
|
def _estimate_message_tokens(message: dict[str, Any]) -> int:
    """Estimate the prompt tokens contributed by one chat message.

    The estimate covers the message content (plain string, list of parts, or
    any other non-None value serialized as JSON), the ``name`` and
    ``tool_call_id`` fields when they are non-empty strings, and the JSON form
    of ``tool_calls`` when present. The result is never less than 1.
    """
    body = message.get("content")
    fragments = []

    if isinstance(body, str):
        fragments.append(body)
    elif isinstance(body, list):
        for item in body:
            if not (isinstance(item, dict) and item.get("type") == "text"):
                # Non-text parts are counted by their JSON representation.
                fragments.append(json.dumps(item, ensure_ascii=False))
                continue
            text = item.get("text", "")
            if text:
                fragments.append(text)
    elif body is not None:
        fragments.append(json.dumps(body, ensure_ascii=False))

    for field in ("name", "tool_call_id"):
        field_value = message.get(field)
        if field_value and isinstance(field_value, str):
            fragments.append(field_value)

    tool_calls = message.get("tool_calls")
    if tool_calls:
        fragments.append(json.dumps(tool_calls, ensure_ascii=False))

    joined = "\n".join(fragments)
    if not joined:
        return 1
    return max(1, _estimate_tokens(joined))
|
||
|
|
|
||
|
|
|
||
|
|
class MemoryConsolidator:
    """Owns consolidation policy, locking, and session offset updates.

    Consolidation asks the provider to summarize a chunk of session messages
    through the ``save_memory`` tool and persists the result via MemoryStore.
    """

    def __init__(
        self,
        workspace: Path,
        provider: Any,
        model: str,
        sessions: Any,
        context_window_tokens: int = 200000,
    ):
        """Wire up storage, the provider, and the consolidation budget.

        Args:
            workspace: Directory passed to MemoryStore for the memory files.
            provider: Object exposing an awaitable ``chat_with_retry``.
            model: Model identifier forwarded to the provider.
            sessions: Object exposing ``save(session)``.
            context_window_tokens: Estimated-token threshold that triggers
                consolidation; values <= 0 disable token-based consolidation.
        """
        self.store = MemoryStore(workspace)
        self.provider = provider
        self.model = model
        self.sessions = sessions
        self.context_window_tokens = context_window_tokens
        # Values are held weakly, so a lock is discarded once no caller
        # references it any more.
        self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()

    def get_lock(self, session_key: str) -> asyncio.Lock:
        """Return the shared consolidation lock for one session."""
        lock = self._locks.get(session_key)
        if lock is None:
            # Only allocate when there is no live lock for this key.
            lock = asyncio.Lock()
            self._locks[session_key] = lock
        return lock

    async def consolidate_messages(self, messages: list[dict[str, object]]) -> bool:
        """Archive a selected message chunk into persistent memory.

        Args:
            messages: Chat messages to summarize.

        Returns:
            True when there was nothing to do or the provider's tool call was
            processed; False when the provider made no tool call, returned
            unusable arguments, or any step raised.
        """
        if not messages:
            return True

        current_memory = self.store.read_long_term()
        memory_text = current_memory or "(empty)"
        conversation = self._format_messages(messages)
        prompt = f"""Process this conversation and call the save_memory tool.

## Current Long-term Memory
{memory_text}

## Conversation to Process
{conversation}"""

        try:
            response = await self.provider.chat_with_retry(
                messages=[
                    {"role": "system", "content": "You are a memory consolidation agent."},
                    {"role": "user", "content": prompt},
                ],
                tools=_SAVE_MEMORY_TOOL,
                model=self.model,
            )

            if not response.has_tool_calls:
                logging.getLogger(__name__).warning(
                    "Memory consolidation: provider did not call save_memory"
                )
                return False

            args = response.tool_calls[0].arguments
            # Arguments may arrive as a JSON string or wrapped in a list.
            if isinstance(args, str):
                args = json.loads(args)
            if isinstance(args, list):
                args = args[0] if args else {}
            if not isinstance(args, dict):
                logging.getLogger(__name__).warning(
                    "Memory consolidation: unexpected arguments type %s",
                    type(args).__name__,
                )
                return False

            if entry := args.get("history_entry"):
                self.store.append_history(str(entry))
            if update := args.get("memory_update"):
                update = str(update)
                # Skip the write when the memory document is unchanged.
                if update != current_memory:
                    self.store.write_long_term(update)

            return True
        except Exception:
            # Consolidation is best effort: report failure to the caller, but
            # leave a trace instead of discarding the error silently.
            logging.getLogger(__name__).exception("Memory consolidation failed")
            return False

    def _format_messages(self, messages: list[dict]) -> str:
        """Render messages as ``[timestamp] ROLE: content`` lines.

        Messages with empty content are skipped. The timestamp is cut to 16
        characters; a missing or empty timestamp is shown as ``?`` and a
        missing role as ``UNKNOWN``.
        """
        lines = []
        for message in messages:
            content = message.get("content")
            if not content:
                continue
            # str(... or "?") tolerates a None or non-string timestamp, which
            # would otherwise raise when sliced.
            stamp = str(message.get("timestamp") or "?")[:16]
            role = str(message.get("role") or "unknown").upper()
            lines.append(f"[{stamp}] {role}: {content}")
        return "\n".join(lines)

    def pick_consolidation_boundary(
        self,
        session: Any,
        tokens_to_remove: int,
    ) -> Optional[tuple[int, int]]:
        """Pick a user-turn boundary that removes enough old prompt tokens.

        Scans forward from ``session.last_consolidated``. Every later message
        whose role is ``user`` is a candidate cut point.

        Returns:
            ``(index, removed_tokens)`` for the first candidate at which the
            estimated tokens before ``index`` reach ``tokens_to_remove``; if no
            candidate reaches it, the last candidate seen; ``None`` when there
            is no candidate, nothing left to scan, or ``tokens_to_remove <= 0``.
        """
        start = session.last_consolidated
        if start >= len(session.messages) or tokens_to_remove <= 0:
            return None

        removed_tokens = 0
        last_boundary: Optional[tuple[int, int]] = None
        for idx in range(start, len(session.messages)):
            message = session.messages[idx]
            # The first scanned message is never a boundary: cutting there
            # would archive nothing.
            if idx > start and message.get("role") == "user":
                last_boundary = (idx, removed_tokens)
                if removed_tokens >= tokens_to_remove:
                    return last_boundary
            removed_tokens += _estimate_message_tokens(message)

        return last_boundary

    async def archive_unconsolidated(self, session: Any) -> bool:
        """Archive the full unconsolidated tail for /new-style session rollover.

        Returns True when there is nothing to archive, otherwise the result of
        ``consolidate_messages``. ``session.last_consolidated`` is not modified
        here.
        """
        lock = self.get_lock(session.key)
        async with lock:
            snapshot = session.messages[session.last_consolidated:]
            if not snapshot:
                return True
            return await self.consolidate_messages(snapshot)

    async def maybe_consolidate_by_tokens(self, session: Any) -> None:
        """Archive old messages once the unconsolidated tail outgrows the window.

        Nothing happens until the estimated tokens of the unconsolidated
        messages reach ``context_window_tokens``. At that point a single pass
        archives the oldest messages up to a user-turn boundary, aiming to
        bring the estimate down to half the window. On success the session
        offset is advanced and the session is saved.
        """
        if not session.messages or self.context_window_tokens <= 0:
            return

        lock = self.get_lock(session.key)
        async with lock:
            target = self.context_window_tokens // 2
            # Simple estimation without full prompt build
            estimated = sum(
                _estimate_message_tokens(m)
                for m in session.messages[session.last_consolidated:]
            )

            if estimated < self.context_window_tokens:
                return

            # Find boundary and consolidate
            boundary = self.pick_consolidation_boundary(session, max(1, estimated - target))
            if boundary is None:
                return

            end_idx = boundary[0]
            chunk = session.messages[session.last_consolidated:end_idx]
            if not chunk:
                return

            # Only advance the offset when the chunk was actually archived.
            if await self.consolidate_messages(chunk):
                session.last_consolidated = end_idx
                self.sessions.save(session)
|