feat(memory): complete M.2-M.5 memory upgrade phases with tests

- M.2: ForgettingCurve, MemoryDecay, MemoryReinforcement (selective forgetting)
- M.3: DailyDigestGenerator, ReminderScheduler, ProactiveInformer (proactive reminders)
- M.4: MemoryExtractor with LLM-based memory extraction from conversations
- M.5: MemoryRecallInjector with token budget control for prompt injection
- All phases include comprehensive unit tests (109 tests passing)
- Updated checklist.md to mark all tasks complete
This commit is contained in:
2026-04-05 14:09:51 +08:00
parent 9bfa0dcc11
commit 11160ec4d2
22 changed files with 4117 additions and 186 deletions

View File

@@ -0,0 +1,168 @@
"""
MemoryRecallInjector
Injects relevant memories into LLM system prompt before response generation.
Token budget: 800 by default (configurable).
Priority: pain_point > goal > preference > fact > event
"""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.memory import UserMemory
MEMORY_TYPE_PRIORITY = {
"pain_point": 1,
"goal": 2,
"preference": 3,
"fact": 4,
"event": 5,
}
DEFAULT_TOKEN_BUDGET = 800
class MemoryRecallInjector:
"""Inject relevant memories into system prompt with token budget control."""
def __init__(self, token_budget: int = DEFAULT_TOKEN_BUDGET):
self.token_budget = token_budget
async def build_context(
self,
db: "AsyncSession",
user_id: str,
current_message: str,
token_budget: int | None = None,
) -> str:
"""Build memory context string for injection into system prompt.
1. Recall relevant memories (top_k=20)
2. Filter out archived memories
3. Rank by importance × relevance × type priority
4. Select within token budget
5. Format as system prompt fragment
"""
budget = token_budget or self.token_budget
# 1. Recall candidates using existing memory service
candidates = await recall_user_memories_for_injection(
db, user_id, current_message, top_k=20
)
# 2. Filter archived and non-UserMemory
active = [
m
for m in candidates
if isinstance(m, UserMemory) and not getattr(m, "is_archived", False)
]
# 3. Rank
ranked = self._rank(active, current_message)
# 4. Budget select
selected = self._budget_select(ranked, budget)
# 5. Format
return self._format(selected)
def _rank(
self,
memories: list["UserMemory"],
query: str,
) -> list["UserMemory"]:
"""Rank memories by: relevance * 0.6 + importance * 0.4 * type_boost.
pain_point/goal get 1.0 type boost, others get 0.8.
"""
def score(m: "UserMemory") -> float:
relevance = (
getattr(m, "similarity_score", 0.5) if hasattr(m, "similarity_score") else 0.5
)
importance = getattr(m, "importance_score", 0.5) or 0.5
mem_type = getattr(m, "memory_type", None) or "fact"
type_boost = 1.0 if mem_type in ("goal", "pain_point") else 0.8
return relevance * 0.6 + importance * 0.4 * type_boost
return sorted(memories, key=score, reverse=True)
def _budget_select(
self,
memories: list["UserMemory"],
token_budget: int,
) -> list["UserMemory"]:
"""Greedy selection until token budget runs out.
Rough estimate: 1 token ≈ 2 characters, fixed overhead = 20 tokens.
"""
selected = []
used = 20 # "[关于你的记忆]\n"
for m in memories:
content = getattr(m, "content", "") or ""
cost = len(content) // 2 + 10
if used + cost > token_budget:
break
selected.append(m)
used += cost
return selected
def _format(self, memories: list["UserMemory"]) -> str:
"""Format memories as system prompt fragment."""
if not memories:
return ""
lines = ["[关于你的记忆]"]
for m in memories:
mem_type = getattr(m, "memory_type", None) or ""
content = getattr(m, "content", "") or ""
type_label = f"[{mem_type}]" if mem_type else ""
lines.append(f"- {type_label} {content}".strip())
return "\n".join(lines)
async def recall_user_memories_for_injection(
    db: "AsyncSession",
    user_id: str,
    query: str,
    top_k: int = 5,
) -> list:
    """Recall user memories for injection (used by MemoryRecallInjector).

    This is a simplified version of recall_user_memories that returns
    UserMemory objects directly instead of dicts.

    The user's memories are read ordered by importance (descending), then
    creation time (descending). If the query yields tokens and at least one
    memory's content contains one of them (case-insensitive substring
    match), only the matching memories are returned. Otherwise the first
    rows in that order are returned.

    Args:
        db: Session used to execute the query.
        user_id: Owner of the memories.
        query: Free text used to derive match tokens.
        top_k: Maximum number of memories to return. Values <= 0 return
            an empty list.

    Returns:
        A list of at most ``top_k`` UserMemory objects.
    """
    # Nothing requested: skip the database round-trip. This also stops a
    # negative top_k from being used as a "drop the last N" slice bound.
    if top_k <= 0:
        return []

    import re

    from sqlalchemy import select

    from app.models.memory import UserMemory

    def _extract_query_tokens(q: str) -> list[str]:
        """Return de-duplicated match tokens in first-seen order.

        Tokens are lowercase ASCII alphanumeric runs of length >= 3, plus
        runs of characters in U+4E00..U+9FFF of length >= 4, each kept as
        a single token.
        """
        normalized = (q or "").lower()
        tokens = [token for token in re.findall(r"[a-z0-9]+", normalized) if len(token) >= 3]
        for chunk in re.findall(r"[\u4e00-\u9fff]+", q or ""):
            stripped = chunk.strip()
            if len(stripped) >= 4:
                tokens.append(stripped)
        # dict.fromkeys de-duplicates while preserving order.
        return list(dict.fromkeys(tokens))

    query_tokens = _extract_query_tokens(query)

    stmt = (
        select(UserMemory)
        .where(UserMemory.user_id == user_id)
        .order_by(UserMemory.importance_score.desc(), UserMemory.created_at.desc())
    )
    if not query_tokens:
        # Without tokens only the first top_k rows are ever used, so let the
        # database apply the limit instead of fetching every row.
        stmt = stmt.limit(top_k)

    result = await db.execute(stmt)
    memories = list(result.scalars().all())

    if query_tokens:
        matched = []
        for m in memories:
            # Lowercase once per memory rather than once per token.
            content = (getattr(m, "content", "") or "").lower()
            if any(token in content for token in query_tokens):
                matched.append(m)
        if matched:
            return matched[:top_k]

    # No tokens, or no memory matched: fall back to the first rows in order.
    return memories[:top_k]