- M.2: ForgettingCurve, MemoryDecay, MemoryReinforcement (selective forgetting) - M.3: DailyDigestGenerator, ReminderScheduler, ProactiveInformer (proactive reminders) - M.4: MemoryExtractor with LLM-based memory extraction from conversations - M.5: MemoryRecallInjector with token budget control for prompt injection - All phases include comprehensive unit tests (109 tests passing) - Updated checklist.md to mark all tasks complete
169 lines
5.1 KiB
Python
169 lines
5.1 KiB
Python
"""
|
||
MemoryRecallInjector
|
||
|
||
Injects relevant memories into LLM system prompt before response generation.
|
||
Token budget: 800 by default (configurable).
|
||
Priority: pain_point > goal > preference > fact > event
|
||
"""
|
||
|
||
from typing import TYPE_CHECKING
|
||
|
||
if TYPE_CHECKING:
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from app.models.memory import UserMemory
|
||
|
||
|
||
# Rank of each memory type; a lower number means a higher priority
# (pain_point > goal > preference > fact > event).
MEMORY_TYPE_PRIORITY = {
    type_name: rank
    for rank, type_name in enumerate(
        ("pain_point", "goal", "preference", "fact", "event"),
        start=1,
    )
}

# Default token budget for the injected memory fragment (overridable per
# injector instance and per call).
DEFAULT_TOKEN_BUDGET = 800
|
||
|
||
|
||
class MemoryRecallInjector:
    """Inject relevant memories into a system prompt under a token budget.

    ``build_context`` recalls candidate memories, drops archived ones, ranks
    them, greedily fills the token budget, and formats the selection as a
    prompt fragment.
    """

    # Rough token-cost model used by ``_budget_select``: ~2 characters per
    # token, a fixed cost for the header line, and a fixed per-line overhead
    # for the "- [type] " prefix and newline.
    _CHARS_PER_TOKEN = 2
    _HEADER_TOKENS = 20
    _PER_ITEM_TOKENS = 10

    # Number of candidates requested from the recall step before ranking.
    _RECALL_TOP_K = 20

    def __init__(self, token_budget: int = DEFAULT_TOKEN_BUDGET):
        """Create an injector.

        Args:
            token_budget: Budget used by ``build_context`` when the call does
                not pass its own ``token_budget``.
        """
        self.token_budget = token_budget

    async def build_context(
        self,
        db: "AsyncSession",
        user_id: str,
        current_message: str,
        token_budget: int | None = None,
    ) -> str:
        """Build the memory fragment to inject into the system prompt.

        Steps:
            1. Recall up to 20 candidates via
               ``recall_user_memories_for_injection``.
            2. Keep only ``UserMemory`` instances that are not archived.
            3. Rank them with ``_rank``.
            4. Greedily select within the token budget (``_budget_select``).
            5. Format the selection (``_format``).

        Args:
            db: Async SQLAlchemy session used for recall.
            user_id: Owner of the memories.
            current_message: The user's latest message, used as the recall
                query.
            token_budget: Per-call budget override. ``None`` means use
                ``self.token_budget``; an explicit ``0`` is honoured.

        Returns:
            The formatted fragment, or ``""`` when nothing is selected.
        """
        # The module-level import of UserMemory is guarded by TYPE_CHECKING,
        # so the name does not exist at runtime; import it here so the
        # isinstance() check below does not raise NameError.
        from app.models.memory import UserMemory

        # ``is None`` (not ``or``) so that an explicit budget of 0 is respected.
        budget = self.token_budget if token_budget is None else token_budget

        candidates = await recall_user_memories_for_injection(
            db, user_id, current_message, top_k=self._RECALL_TOP_K
        )

        active = [
            m
            for m in candidates
            if isinstance(m, UserMemory) and not getattr(m, "is_archived", False)
        ]

        ranked = self._rank(active, current_message)
        selected = self._budget_select(ranked, budget)
        return self._format(selected)

    def _rank(
        self,
        memories: list["UserMemory"],
        query: str,
    ) -> list["UserMemory"]:
        """Return ``memories`` sorted best-first.

        score = relevance * 0.6 + importance * 0.4 * type_boost, where
        type_boost is 1.0 for ``pain_point``/``goal`` and 0.8 otherwise.

        - relevance is the ``similarity_score`` attribute, or 0.5 when it is
          absent or ``None``.
        - importance is ``importance_score``, or 0.5 when it is absent or
          ``None``. A stored 0.0 stays 0.0.

        Equal scores are ordered by ``MEMORY_TYPE_PRIORITY`` (unknown types
        last), then by input order, because the sort is stable.

        ``query`` is not used by the current scoring.
        """
        unknown_priority = len(MEMORY_TYPE_PRIORITY) + 1

        def _or_default(value, default: float = 0.5):
            # Only substitute for a missing value; 0.0 is a legitimate score.
            return default if value is None else value

        def sort_key(m: "UserMemory") -> tuple:
            relevance = _or_default(getattr(m, "similarity_score", None))
            importance = _or_default(getattr(m, "importance_score", None))
            mem_type = getattr(m, "memory_type", None) or "fact"
            type_boost = 1.0 if mem_type in ("goal", "pain_point") else 0.8
            score = relevance * 0.6 + importance * 0.4 * type_boost
            priority = MEMORY_TYPE_PRIORITY.get(mem_type, unknown_priority)
            # Ascending sort: highest score first, then highest type priority.
            return (-score, priority)

        return sorted(memories, key=sort_key)

    def _budget_select(
        self,
        memories: list["UserMemory"],
        token_budget: int,
    ) -> list["UserMemory"]:
        """Greedily pick memories, in rank order, that fit in ``token_budget``.

        Estimated cost is the header cost (20 tokens) plus, per memory,
        ``len(content) // 2 + 10`` tokens (~2 characters per token). A memory
        that would overflow the budget is skipped rather than ending the
        scan. This lets shorter, lower-ranked memories use the remaining space.
        """
        selected = []
        used = self._HEADER_TOKENS  # cost of the header line added by _format
        for m in memories:
            content = getattr(m, "content", "") or ""
            cost = len(content) // self._CHARS_PER_TOKEN + self._PER_ITEM_TOKENS
            if used + cost > token_budget:
                continue
            selected.append(m)
            used += cost
        return selected

    def _format(self, memories: list["UserMemory"]) -> str:
        """Format memories as a system-prompt fragment.

        Output is a header line followed by one ``- [type] content`` line per
        memory. The ``[type]`` label is omitted when the memory has no type.
        Returns ``""`` for an empty selection.
        """
        if not memories:
            return ""
        lines = ["[关于你的记忆]"]
        for m in memories:
            mem_type = getattr(m, "memory_type", None) or ""
            content = (getattr(m, "content", "") or "").strip()
            # Join only the non-empty parts so a missing label does not leave
            # a double space after the dash.
            parts = ["-"]
            if mem_type:
                parts.append(f"[{mem_type}]")
            if content:
                parts.append(content)
            lines.append(" ".join(parts))
        return "\n".join(lines)
|
||
|
||
|
||
async def recall_user_memories_for_injection(
    db: "AsyncSession",
    user_id: str,
    query: str,
    top_k: int = 5,
) -> list:
    """Recall a user's memories for prompt injection.

    This is a simplified variant of ``recall_user_memories`` that returns
    ``UserMemory`` objects directly instead of dicts. It is used by
    ``MemoryRecallInjector``.

    All of the user's memories are loaded, ordered by ``importance_score``
    descending and then ``created_at`` descending. Archived memories (truthy
    ``is_archived``) are dropped. If ``query`` yields keyword tokens and at
    least one memory's content contains one of them, only matching memories
    are returned. Otherwise the result falls back to the first memories in
    the importance order.

    Args:
        db: Async SQLAlchemy session.
        user_id: Owner of the memories.
        query: Free-text query, typically the user's latest message.
        top_k: Maximum number of memories to return.

    Returns:
        At most ``top_k`` ``UserMemory`` objects; ``[]`` when ``top_k <= 0``.
    """
    import re

    from sqlalchemy import select

    from app.models.memory import UserMemory

    if top_k <= 0:
        # Nothing can be returned; skip the database round-trip entirely
        # (a negative top_k would otherwise slice oddly, e.g. [:-1]).
        return []

    def _extract_query_tokens(q: str) -> list[str]:
        """Return de-duplicated keyword tokens from ``q``, in order.

        Lowercased ASCII alphanumeric words of length >= 3, plus runs of CJK
        characters of length >= 4.
        """
        normalized = (q or "").lower()
        tokens = [token for token in re.findall(r"[a-z0-9]+", normalized) if len(token) >= 3]
        # NOTE(review): CJK runs are kept whole, so an unpunctuated Chinese
        # sentence only matches a memory containing that exact run.
        for chunk in re.findall(r"[\u4e00-\u9fff]+", q or ""):
            stripped = chunk.strip()
            if len(stripped) >= 4:
                tokens.append(stripped)
        return list(dict.fromkeys(tokens))

    result = await db.execute(
        select(UserMemory)
        .where(UserMemory.user_id == user_id)
        .order_by(UserMemory.importance_score.desc(), UserMemory.created_at.desc())
    )
    # Drop archived memories *before* truncating to top_k. Otherwise archived
    # rows can occupy every slot and leave the caller nothing once it filters
    # them out.
    memories = [
        m for m in result.scalars().all() if not getattr(m, "is_archived", False)
    ]

    query_tokens = _extract_query_tokens(query)
    if query_tokens:
        matched = []
        for m in memories:
            content = (getattr(m, "content", "") or "").lower()
            if any(token in content for token in query_tokens):
                matched.append(m)
        if matched:
            return matched[:top_k]

    # No tokens, or nothing matched: fall back to the importance ordering.
    return memories[:top_k]
|