Files
JARVIS/backend/app/services/memory/recall_injector.py
WIN-JHFT4D3SIVT\caoxiaozhu 11160ec4d2 feat(memory): complete M.2-M.5 memory upgrade phases with tests
- M.2: ForgettingCurve, MemoryDecay, MemoryReinforcement (selective forgetting)
- M.3: DailyDigestGenerator, ReminderScheduler, ProactiveInformer (proactive reminders)
- M.4: MemoryExtractor with LLM-based memory extraction from conversations
- M.5: MemoryRecallInjector with token budget control for prompt injection
- All phases include comprehensive unit tests (109 tests passing)
- Updated checklist.md to mark all tasks complete
2026-04-05 14:09:51 +08:00

169 lines
5.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
MemoryRecallInjector
Injects relevant memories into LLM system prompt before response generation.
Token budget: 800 by default (configurable).
Priority: pain_point > goal > preference > fact > event
"""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.memory import UserMemory
MEMORY_TYPE_PRIORITY = {
"pain_point": 1,
"goal": 2,
"preference": 3,
"fact": 4,
"event": 5,
}
DEFAULT_TOKEN_BUDGET = 800
class MemoryRecallInjector:
"""Inject relevant memories into system prompt with token budget control."""
def __init__(self, token_budget: int = DEFAULT_TOKEN_BUDGET):
self.token_budget = token_budget
async def build_context(
self,
db: "AsyncSession",
user_id: str,
current_message: str,
token_budget: int | None = None,
) -> str:
"""Build memory context string for injection into system prompt.
1. Recall relevant memories (top_k=20)
2. Filter out archived memories
3. Rank by importance × relevance × type priority
4. Select within token budget
5. Format as system prompt fragment
"""
budget = token_budget or self.token_budget
# 1. Recall candidates using existing memory service
candidates = await recall_user_memories_for_injection(
db, user_id, current_message, top_k=20
)
# 2. Filter archived and non-UserMemory
active = [
m
for m in candidates
if isinstance(m, UserMemory) and not getattr(m, "is_archived", False)
]
# 3. Rank
ranked = self._rank(active, current_message)
# 4. Budget select
selected = self._budget_select(ranked, budget)
# 5. Format
return self._format(selected)
def _rank(
self,
memories: list["UserMemory"],
query: str,
) -> list["UserMemory"]:
"""Rank memories by: relevance * 0.6 + importance * 0.4 * type_boost.
pain_point/goal get 1.0 type boost, others get 0.8.
"""
def score(m: "UserMemory") -> float:
relevance = (
getattr(m, "similarity_score", 0.5) if hasattr(m, "similarity_score") else 0.5
)
importance = getattr(m, "importance_score", 0.5) or 0.5
mem_type = getattr(m, "memory_type", None) or "fact"
type_boost = 1.0 if mem_type in ("goal", "pain_point") else 0.8
return relevance * 0.6 + importance * 0.4 * type_boost
return sorted(memories, key=score, reverse=True)
def _budget_select(
self,
memories: list["UserMemory"],
token_budget: int,
) -> list["UserMemory"]:
"""Greedy selection until token budget runs out.
Rough estimate: 1 token ≈ 2 characters, fixed overhead = 20 tokens.
"""
selected = []
used = 20 # "[关于你的记忆]\n"
for m in memories:
content = getattr(m, "content", "") or ""
cost = len(content) // 2 + 10
if used + cost > token_budget:
break
selected.append(m)
used += cost
return selected
def _format(self, memories: list["UserMemory"]) -> str:
"""Format memories as system prompt fragment."""
if not memories:
return ""
lines = ["[关于你的记忆]"]
for m in memories:
mem_type = getattr(m, "memory_type", None) or ""
content = getattr(m, "content", "") or ""
type_label = f"[{mem_type}]" if mem_type else ""
lines.append(f"- {type_label} {content}".strip())
return "\n".join(lines)
async def recall_user_memories_for_injection(
    db: "AsyncSession",
    user_id: str,
    query: str,
    top_k: int = 5,
) -> list:
    """Fetch a user's memories as UserMemory objects, preferring keyword hits.

    All memories for ``user_id`` are loaded ordered by importance (descending)
    and then creation time (descending). Keywords are extracted from ``query``;
    if at least one memory's content contains any keyword (case-insensitive
    substring match), only those matching memories are returned. Otherwise the
    unfiltered list is used. In both cases the result is cut to ``top_k``.

    Args:
        db: Async session used to execute the select.
        user_id: Owner whose memories are loaded.
        query: Free text to extract keywords from; may be empty or None.
        top_k: Maximum number of memories to return.

    Returns:
        A list of at most ``top_k`` UserMemory rows.
    """
    import re

    from sqlalchemy import select

    from app.models.memory import UserMemory

    raw_query = query or ""

    # Keywords: lowercase ASCII alphanumeric words of 3+ characters ...
    keywords = [
        word for word in re.findall(r"[a-z0-9]+", raw_query.lower()) if len(word) >= 3
    ]
    # ... plus contiguous CJK runs of 4+ characters, taken verbatim.
    for run in re.findall(r"[\u4e00-\u9fff]+", raw_query):
        run = run.strip()
        if len(run) >= 4:
            keywords.append(run)
    # De-duplicate while keeping first-seen order.
    keywords = list(dict.fromkeys(keywords))

    statement = (
        select(UserMemory)
        .where(UserMemory.user_id == user_id)
        .order_by(UserMemory.importance_score.desc(), UserMemory.created_at.desc())
    )
    rows = await db.execute(statement)
    everything = list(rows.scalars().all())

    hits = []
    if keywords:
        for mem in everything:
            haystack = (getattr(mem, "content", "") or "").lower()
            if any(keyword in haystack for keyword in keywords):
                hits.append(mem)

    # Fall back to the unfiltered list when there are no keywords or no hits.
    pool = hits if hits else everything
    return pool[:top_k]