"""Memory system for persistent agent memory."""

import asyncio
import json
import logging
import weakref
from pathlib import Path
from typing import Any, Callable, Optional

try:
    import tiktoken

    HAS_TIKTOKEN = True
except ImportError:
    HAS_TIKTOKEN = False

logger = logging.getLogger(__name__)

# Tool schema handed to the provider so the consolidation result comes back
# as structured arguments (history entry + full memory rewrite).
_SAVE_MEMORY_TOOL = [
    {
        "type": "function",
        "function": {
            "name": "save_memory",
            "description": "Save the memory consolidation result to persistent storage.",
            "parameters": {
                "type": "object",
                "properties": {
                    "history_entry": {
                        "type": "string",
                        "description": "A paragraph summarizing key events/decisions/topics.",
                    },
                    "memory_update": {
                        "type": "string",
                        "description": "Full updated long-term memory as markdown. Include all existing facts plus new ones.",
                    },
                },
                "required": ["history_entry", "memory_update"],
            },
        },
    }
]


class MemoryStore:
    """Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""

    def __init__(self, workspace: Path):
        """Create ``<workspace>/memory`` if needed and remember both file paths."""
        self.memory_dir = workspace / "memory"
        self.memory_dir.mkdir(parents=True, exist_ok=True)
        self.memory_file = self.memory_dir / "MEMORY.md"
        self.history_file = self.memory_dir / "HISTORY.md"

    def read_long_term(self) -> str:
        """Return the contents of MEMORY.md, or ``""`` when the file does not exist."""
        if self.memory_file.exists():
            return self.memory_file.read_text(encoding="utf-8")
        return ""

    def write_long_term(self, content: str) -> None:
        """Overwrite MEMORY.md with *content*."""
        self.memory_file.write_text(content, encoding="utf-8")

    def append_history(self, entry: str) -> None:
        """Append *entry* to HISTORY.md, followed by one blank line."""
        with open(self.history_file, "a", encoding="utf-8") as f:
            # Trailing whitespace is stripped so entries are always separated
            # by exactly one blank line.
            f.write(entry.rstrip() + "\n\n")

    def get_memory_context(self) -> str:
        """Return long-term memory under a markdown heading, or ``""`` if empty."""
        long_term = self.read_long_term()
        return f"## Long-term Memory\n{long_term}" if long_term else ""


def _estimate_tokens(text: str) -> int:
    """Estimate the token count of *text*.

    Uses tiktoken's ``cl100k_base`` encoding when tiktoken is importable;
    otherwise (or if encoding fails) falls back to roughly four characters
    per token, never returning less than 1 on the fallback path.
    """
    if HAS_TIKTOKEN:
        try:
            enc = tiktoken.get_encoding("cl100k_base")
            return len(enc.encode(text))
        except Exception:
            # Deliberate best-effort: any tokenizer failure degrades to the
            # character-based heuristic below.
            pass
    return max(1, len(text) // 4)


def _estimate_message_tokens(message: dict[str, Any]) -> int:
    """Estimate prompt tokens for a single chat message.

    Counts the textual content (string, or text parts of a list; other parts
    and other content types are JSON-serialized), the ``name`` and
    ``tool_call_id`` fields when they are non-empty strings, and the
    JSON-serialized ``tool_calls``. Always returns at least 1.
    """
    content = message.get("content")
    parts: list[str] = []

    if isinstance(content, str):
        parts.append(content)
    elif isinstance(content, list):
        for part in content:
            if isinstance(part, dict) and part.get("type") == "text":
                text = part.get("text", "")
                if text:
                    parts.append(text)
            else:
                parts.append(json.dumps(part, ensure_ascii=False))
    elif content is not None:
        parts.append(json.dumps(content, ensure_ascii=False))

    for key in ("name", "tool_call_id"):
        value = message.get(key)
        if isinstance(value, str) and value:
            parts.append(value)

    if message.get("tool_calls"):
        parts.append(json.dumps(message["tool_calls"], ensure_ascii=False))

    payload = "\n".join(parts)
    return max(1, _estimate_tokens(payload)) if payload else 1


class MemoryConsolidator:
    """Owns consolidation policy, locking, and session offset updates."""

    def __init__(
        self,
        workspace: Path,
        provider: Any,
        model: str,
        sessions: Any,
        context_window_tokens: int = 200000,
    ):
        """Store collaborators and create the backing :class:`MemoryStore`.

        Args:
            workspace: Directory under which the ``memory`` folder lives.
            provider: Object exposing an awaitable ``chat_with_retry``.
            model: Model identifier forwarded to the provider.
            sessions: Object exposing ``save(session)``.
            context_window_tokens: Token threshold used by
                :meth:`maybe_consolidate_by_tokens`.
        """
        self.store = MemoryStore(workspace)
        self.provider = provider
        self.model = model
        self.sessions = sessions
        self.context_window_tokens = context_window_tokens
        # Weak values: a lock is dropped once no caller holds a reference to it.
        self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = (
            weakref.WeakValueDictionary()
        )

    def get_lock(self, session_key: str) -> asyncio.Lock:
        """Return the shared consolidation lock for one session."""
        lock = self._locks.get(session_key)
        if lock is None:
            # Only allocate on a miss; the local strong reference keeps the
            # new lock alive until the caller takes ownership of it.
            lock = asyncio.Lock()
            self._locks[session_key] = lock
        return lock

    async def consolidate_messages(self, messages: list[dict[str, object]]) -> bool:
        """Archive a selected message chunk into persistent memory.

        Asks the provider to call ``save_memory``; appends the returned
        ``history_entry`` to the history log and rewrites long-term memory
        when ``memory_update`` differs from the current content.

        Returns:
            ``True`` when *messages* is empty or the tool call was processed,
            ``False`` when the provider made no tool call, returned unusable
            arguments, or an exception occurred during the call/persist step.
        """
        if not messages:
            return True

        current_memory = self.store.read_long_term()
        prompt = f"""Process this conversation and call the save_memory tool.

## Current Long-term Memory
{current_memory or "(empty)"}

## Conversation to Process
{self._format_messages(messages)}"""

        try:
            response = await self.provider.chat_with_retry(
                messages=[
                    {"role": "system", "content": "You are a memory consolidation agent."},
                    {"role": "user", "content": prompt},
                ],
                tools=_SAVE_MEMORY_TOOL,
                model=self.model,
            )
            if not response.has_tool_calls:
                logger.warning("Memory consolidation: provider did not call save_memory")
                return False

            args = response.tool_calls[0].arguments
            # Arguments may arrive as a JSON string or wrapped in a list.
            if isinstance(args, str):
                args = json.loads(args)
            if isinstance(args, list):
                args = args[0] if args else {}
            if not isinstance(args, dict):
                logger.warning(
                    "Memory consolidation: unexpected arguments type %s",
                    type(args).__name__,
                )
                return False

            if entry := args.get("history_entry"):
                self.store.append_history(str(entry))
            if update := args.get("memory_update"):
                update = str(update)
                # Skip the write when nothing changed.
                if update != current_memory:
                    self.store.write_long_term(update)
            return True
        except Exception:
            # Best-effort: report failure to the caller, but leave a trace.
            logger.exception("Memory consolidation failed")
            return False

    def _format_messages(self, messages: list[dict]) -> str:
        """Render messages as ``[timestamp] ROLE: content`` lines.

        Messages with falsy content are skipped. The timestamp is truncated
        to 16 characters; a missing, ``None`` or empty timestamp/role is
        rendered as ``?`` instead of raising.
        """
        lines = []
        for message in messages:
            content = message.get("content")
            if not content:
                continue
            stamp = str(message.get("timestamp") or "?")[:16]
            role = str(message.get("role") or "?").upper()
            lines.append(f"[{stamp}] {role}: {content}")
        return "\n".join(lines)

    def pick_consolidation_boundary(
        self,
        session: Any,
        tokens_to_remove: int,
    ) -> Optional[tuple[int, int]]:
        """Pick a user-turn boundary that removes enough old prompt tokens.

        Scans forward from ``session.last_consolidated`` and returns
        ``(index, tokens_before_index)`` for the first user message (after the
        start) at which the accumulated estimate reaches *tokens_to_remove*.
        If no boundary reaches the target, the last user-turn boundary seen is
        returned; ``None`` when there is nothing to scan, the target is not
        positive, or no later user message exists.
        """
        start = session.last_consolidated
        if start >= len(session.messages) or tokens_to_remove <= 0:
            return None

        removed_tokens = 0
        last_boundary: Optional[tuple[int, int]] = None
        for idx in range(start, len(session.messages)):
            message = session.messages[idx]
            # A boundary is recorded *before* counting the user message, so
            # the chunk [start:idx] never splits a turn.
            if idx > start and message.get("role") == "user":
                last_boundary = (idx, removed_tokens)
                if removed_tokens >= tokens_to_remove:
                    return last_boundary
            removed_tokens += _estimate_message_tokens(message)

        return last_boundary

    async def archive_unconsolidated(self, session: Any) -> bool:
        """Archive the full unconsolidated tail for /new-style session rollover.

        Returns ``True`` when there is nothing to archive, otherwise the
        result of :meth:`consolidate_messages`. ``session.last_consolidated``
        is not modified here.
        """
        lock = self.get_lock(session.key)
        async with lock:
            snapshot = session.messages[session.last_consolidated:]
            if not snapshot:
                return True
            return await self.consolidate_messages(snapshot)

    async def maybe_consolidate_by_tokens(self, session: Any) -> None:
        """Archive old messages once the unconsolidated tail fills the context window.

        Performs a single pass: when the estimated tokens of the
        unconsolidated messages reach ``context_window_tokens``, a user-turn
        boundary is chosen that aims to bring the tail down to half the
        window, that chunk is consolidated, and on success
        ``session.last_consolidated`` is advanced and the session saved.
        """
        if not session.messages or self.context_window_tokens <= 0:
            return

        lock = self.get_lock(session.key)
        async with lock:
            target = self.context_window_tokens // 2

            # Simple estimation without full prompt build
            estimated = sum(
                _estimate_message_tokens(m)
                for m in session.messages[session.last_consolidated:]
            )
            if estimated < self.context_window_tokens:
                return

            # Find boundary and consolidate
            boundary = self.pick_consolidation_boundary(
                session, max(1, estimated - target)
            )
            if boundary is None:
                return

            end_idx = boundary[0]
            chunk = session.messages[session.last_consolidated:end_idx]
            if not chunk:
                return

            # Only advance the offset when the chunk was actually persisted.
            if await self.consolidate_messages(chunk):
                session.last_consolidated = end_idx
                self.sessions.save(session)