""" MemoryRecallInjector Injects relevant memories into LLM system prompt before response generation. Token budget: 800 by default (configurable). Priority: pain_point > goal > preference > fact > event """ from typing import TYPE_CHECKING if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncSession from app.models.memory import UserMemory MEMORY_TYPE_PRIORITY = { "pain_point": 1, "goal": 2, "preference": 3, "fact": 4, "event": 5, } DEFAULT_TOKEN_BUDGET = 800 class MemoryRecallInjector: """Inject relevant memories into system prompt with token budget control.""" def __init__(self, token_budget: int = DEFAULT_TOKEN_BUDGET): self.token_budget = token_budget async def build_context( self, db: "AsyncSession", user_id: str, current_message: str, token_budget: int | None = None, ) -> str: """Build memory context string for injection into system prompt. 1. Recall relevant memories (top_k=20) 2. Filter out archived memories 3. Rank by importance × relevance × type priority 4. Select within token budget 5. Format as system prompt fragment """ budget = token_budget or self.token_budget # 1. Recall candidates using existing memory service candidates = await recall_user_memories_for_injection( db, user_id, current_message, top_k=20 ) # 2. Filter archived and non-UserMemory active = [ m for m in candidates if isinstance(m, UserMemory) and not getattr(m, "is_archived", False) ] # 3. Rank ranked = self._rank(active, current_message) # 4. Budget select selected = self._budget_select(ranked, budget) # 5. Format return self._format(selected) def _rank( self, memories: list["UserMemory"], query: str, ) -> list["UserMemory"]: """Rank memories by: relevance * 0.6 + importance * 0.4 * type_boost. pain_point/goal get 1.0 type boost, others get 0.8. """ def score(m: "UserMemory") -> float: relevance = ( getattr(m, "similarity_score", 0.5) if hasattr(m, "similarity_score") else 0.5 ) importance = getattr(m, "importance_score", 0.5) or 0.5 mem_type = getattr(m, "memory_type", None) or "fact" type_boost = 1.0 if mem_type in ("goal", "pain_point") else 0.8 return relevance * 0.6 + importance * 0.4 * type_boost return sorted(memories, key=score, reverse=True) def _budget_select( self, memories: list["UserMemory"], token_budget: int, ) -> list["UserMemory"]: """Greedy selection until token budget runs out. Rough estimate: 1 token ≈ 2 characters, fixed overhead = 20 tokens. """ selected = [] used = 20 # "[关于你的记忆]\n" for m in memories: content = getattr(m, "content", "") or "" cost = len(content) // 2 + 10 if used + cost > token_budget: break selected.append(m) used += cost return selected def _format(self, memories: list["UserMemory"]) -> str: """Format memories as system prompt fragment.""" if not memories: return "" lines = ["[关于你的记忆]"] for m in memories: mem_type = getattr(m, "memory_type", None) or "" content = getattr(m, "content", "") or "" type_label = f"[{mem_type}]" if mem_type else "" lines.append(f"- {type_label} {content}".strip()) return "\n".join(lines) async def recall_user_memories_for_injection( db: "AsyncSession", user_id: str, query: str, top_k: int = 5, ) -> list: """Recall user memories for injection (used by MemoryRecallInjector). This is a simplified version of recall_user_memories that returns UserMemory objects directly instead of dicts. """ import re from sqlalchemy import select from app.models.memory import UserMemory def _extract_query_tokens(q: str) -> list[str]: normalized = (q or "").lower() tokens = [token for token in re.findall(r"[a-z0-9]+", normalized) if len(token) >= 3] for chunk in re.findall(r"[\u4e00-\u9fff]+", q or ""): stripped = chunk.strip() if len(stripped) >= 4: tokens.append(stripped) return list(dict.fromkeys(tokens)) query_tokens = _extract_query_tokens(query) result = await db.execute( select(UserMemory) .where(UserMemory.user_id == user_id) .order_by(UserMemory.importance_score.desc(), UserMemory.created_at.desc()) ) memories = list(result.scalars().all()) if query_tokens: matched = [ m for m in memories if any(token in ((getattr(m, "content", "") or "").lower()) for token in query_tokens) ] if matched: return matched[:top_k] return memories[:top_k] return memories[:top_k]