feat(memory): complete M.2-M.5 memory upgrade phases with tests
- M.2: ForgettingCurve, MemoryDecay, MemoryReinforcement (selective forgetting) - M.3: DailyDigestGenerator, ReminderScheduler, ProactiveInformer (proactive reminders) - M.4: MemoryExtractor with LLM-based memory extraction from conversations - M.5: MemoryRecallInjector with token budget control for prompt injection - All phases include comprehensive unit tests (109 tests passing) - Updated checklist.md to mark all tasks complete
This commit is contained in:
168
backend/app/services/memory/recall_injector.py
Normal file
168
backend/app/services/memory/recall_injector.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""
|
||||
MemoryRecallInjector
|
||||
|
||||
Injects relevant memories into LLM system prompt before response generation.
|
||||
Token budget: 800 by default (configurable).
|
||||
Priority: pain_point > goal > preference > fact > event
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.models.memory import UserMemory
|
||||
|
||||
|
||||
MEMORY_TYPE_PRIORITY = {
|
||||
"pain_point": 1,
|
||||
"goal": 2,
|
||||
"preference": 3,
|
||||
"fact": 4,
|
||||
"event": 5,
|
||||
}
|
||||
|
||||
DEFAULT_TOKEN_BUDGET = 800
|
||||
|
||||
|
||||
class MemoryRecallInjector:
|
||||
"""Inject relevant memories into system prompt with token budget control."""
|
||||
|
||||
def __init__(self, token_budget: int = DEFAULT_TOKEN_BUDGET):
|
||||
self.token_budget = token_budget
|
||||
|
||||
async def build_context(
|
||||
self,
|
||||
db: "AsyncSession",
|
||||
user_id: str,
|
||||
current_message: str,
|
||||
token_budget: int | None = None,
|
||||
) -> str:
|
||||
"""Build memory context string for injection into system prompt.
|
||||
|
||||
1. Recall relevant memories (top_k=20)
|
||||
2. Filter out archived memories
|
||||
3. Rank by importance × relevance × type priority
|
||||
4. Select within token budget
|
||||
5. Format as system prompt fragment
|
||||
"""
|
||||
budget = token_budget or self.token_budget
|
||||
|
||||
# 1. Recall candidates using existing memory service
|
||||
candidates = await recall_user_memories_for_injection(
|
||||
db, user_id, current_message, top_k=20
|
||||
)
|
||||
|
||||
# 2. Filter archived and non-UserMemory
|
||||
active = [
|
||||
m
|
||||
for m in candidates
|
||||
if isinstance(m, UserMemory) and not getattr(m, "is_archived", False)
|
||||
]
|
||||
|
||||
# 3. Rank
|
||||
ranked = self._rank(active, current_message)
|
||||
|
||||
# 4. Budget select
|
||||
selected = self._budget_select(ranked, budget)
|
||||
|
||||
# 5. Format
|
||||
return self._format(selected)
|
||||
|
||||
def _rank(
|
||||
self,
|
||||
memories: list["UserMemory"],
|
||||
query: str,
|
||||
) -> list["UserMemory"]:
|
||||
"""Rank memories by: relevance * 0.6 + importance * 0.4 * type_boost.
|
||||
|
||||
pain_point/goal get 1.0 type boost, others get 0.8.
|
||||
"""
|
||||
|
||||
def score(m: "UserMemory") -> float:
|
||||
relevance = (
|
||||
getattr(m, "similarity_score", 0.5) if hasattr(m, "similarity_score") else 0.5
|
||||
)
|
||||
importance = getattr(m, "importance_score", 0.5) or 0.5
|
||||
mem_type = getattr(m, "memory_type", None) or "fact"
|
||||
type_boost = 1.0 if mem_type in ("goal", "pain_point") else 0.8
|
||||
return relevance * 0.6 + importance * 0.4 * type_boost
|
||||
|
||||
return sorted(memories, key=score, reverse=True)
|
||||
|
||||
def _budget_select(
|
||||
self,
|
||||
memories: list["UserMemory"],
|
||||
token_budget: int,
|
||||
) -> list["UserMemory"]:
|
||||
"""Greedy selection until token budget runs out.
|
||||
|
||||
Rough estimate: 1 token ≈ 2 characters, fixed overhead = 20 tokens.
|
||||
"""
|
||||
selected = []
|
||||
used = 20 # "[关于你的记忆]\n"
|
||||
for m in memories:
|
||||
content = getattr(m, "content", "") or ""
|
||||
cost = len(content) // 2 + 10
|
||||
if used + cost > token_budget:
|
||||
break
|
||||
selected.append(m)
|
||||
used += cost
|
||||
return selected
|
||||
|
||||
def _format(self, memories: list["UserMemory"]) -> str:
|
||||
"""Format memories as system prompt fragment."""
|
||||
if not memories:
|
||||
return ""
|
||||
lines = ["[关于你的记忆]"]
|
||||
for m in memories:
|
||||
mem_type = getattr(m, "memory_type", None) or ""
|
||||
content = getattr(m, "content", "") or ""
|
||||
type_label = f"[{mem_type}]" if mem_type else ""
|
||||
lines.append(f"- {type_label} {content}".strip())
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
async def recall_user_memories_for_injection(
|
||||
db: "AsyncSession",
|
||||
user_id: str,
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
) -> list:
|
||||
"""Recall user memories for injection (used by MemoryRecallInjector).
|
||||
|
||||
This is a simplified version of recall_user_memories that returns
|
||||
UserMemory objects directly instead of dicts.
|
||||
"""
|
||||
import re
|
||||
from sqlalchemy import select
|
||||
from app.models.memory import UserMemory
|
||||
|
||||
def _extract_query_tokens(q: str) -> list[str]:
|
||||
normalized = (q or "").lower()
|
||||
tokens = [token for token in re.findall(r"[a-z0-9]+", normalized) if len(token) >= 3]
|
||||
for chunk in re.findall(r"[\u4e00-\u9fff]+", q or ""):
|
||||
stripped = chunk.strip()
|
||||
if len(stripped) >= 4:
|
||||
tokens.append(stripped)
|
||||
return list(dict.fromkeys(tokens))
|
||||
|
||||
query_tokens = _extract_query_tokens(query)
|
||||
|
||||
result = await db.execute(
|
||||
select(UserMemory)
|
||||
.where(UserMemory.user_id == user_id)
|
||||
.order_by(UserMemory.importance_score.desc(), UserMemory.created_at.desc())
|
||||
)
|
||||
memories = list(result.scalars().all())
|
||||
|
||||
if query_tokens:
|
||||
matched = [
|
||||
m
|
||||
for m in memories
|
||||
if any(token in ((getattr(m, "content", "") or "").lower()) for token in query_tokens)
|
||||
]
|
||||
if matched:
|
||||
return matched[:top_k]
|
||||
return memories[:top_k]
|
||||
|
||||
return memories[:top_k]
|
||||
Reference in New Issue
Block a user