feat(memory): complete M.2-M.5 memory upgrade phases with tests

- M.2: ForgettingCurve, MemoryDecay, MemoryReinforcement (selective forgetting)
- M.3: DailyDigestGenerator, ReminderScheduler, ProactiveInformer (proactive reminders)
- M.4: MemoryExtractor with LLM-based memory extraction from conversations
- M.5: MemoryRecallInjector with token budget control for prompt injection
- All phases include comprehensive unit tests (109 tests passing)
- Updated checklist.md to mark all tasks complete
This commit is contained in:
2026-04-05 14:09:51 +08:00
parent 9bfa0dcc11
commit 11160ec4d2
22 changed files with 4117 additions and 186 deletions

View File

@@ -0,0 +1,239 @@
"""
MemoryExtractor
Automatically extracts memories from conversations using LLM.
Extracts 5 types: fact, preference, goal, pain_point, event.
Deduplicates against existing memories (similarity > 0.85 → reinforce instead of create).
"""
import json
import logging
from dataclasses import dataclass, field
from typing import Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.conversation import Message
from app.models.memory import UserMemory
logger = logging.getLogger(__name__)
MEMORY_TYPES = ("fact", "preference", "goal", "pain_point", "event")
EXTRACT_PROMPT = """从以下对话中提取用户的记忆信息,以 JSON 格式返回。
对话内容:
{conversation_text}
提取以下类型(只提取明确信息,不要猜测):
- fact: 关于用户的客观事实(职业、 location、技能、健康状况等
- preference: 用户的偏好和习惯(回答风格偏好、沟通偏好、生活习惯等)
- goal: 用户提到的目标或计划(想做什么、计划做什么、目标是什么)
- pain_point: 反复出现或明显困扰用户的问题
- event: 今天发生的重要事件
输出格式(只输出 JSON不要其他内容
[
{{"type": "fact", "content": "...", "confidence": 0.9}},
{{"type": "goal", "content": "...", "confidence": 0.7}}
]"""
@dataclass
class ExtractedMemory:
"""A memory extracted from conversation."""
memory_type: str # "fact" | "preference" | "goal" | "pain_point" | "event"
content: str
confidence: float # 0.0-1.0
source_conversation_id: str | None = None
class MemoryExtractor:
"""Extract memories from conversations using LLM."""
SIMILARITY_THRESHOLD = 0.85
async def extract_from_conversation(
self,
db: AsyncSession,
user_id: str,
conversation_id: str,
messages: list[Message],
) -> list[ExtractedMemory]:
"""Extract memories from conversation messages.
1. Build conversation text
2. Call LLM to extract memories
3. Parse JSON response
4. Deduplicate against existing memories
5. Return new memories
"""
if len(messages) < 2:
return []
# 1. Build conversation text
conversation_text = "\n".join(f"[{m.role}] {m.content}" for m in messages[-10:])
# 2. Call LLM
extracted = await self._call_llm_extract(conversation_text)
if not extracted:
return []
# 3. Build ExtractedMemory objects
new_memories = [
ExtractedMemory(
memory_type=m["type"],
content=m["content"],
confidence=m.get("confidence", 0.5),
source_conversation_id=conversation_id,
)
for m in extracted
if m.get("type") in MEMORY_TYPES and m.get("content")
]
# 4. Deduplicate
new_memories = await self._deduplicate(db, user_id, new_memories)
return new_memories
async def _call_llm_extract(self, conversation_text: str) -> list[dict]:
"""Call LLM to extract memories from conversation text."""
from app.services.llm_service import get_llm
from langchain_core.messages import HumanMessage, SystemMessage
prompt = EXTRACT_PROMPT.format(conversation_text=conversation_text)
try:
llm = get_llm()
response = await llm.invoke(
[
SystemMessage(
content="你是一个记忆提取助手。从对话中提取用户的记忆信息只返回JSON数组不要其他内容。"
),
HumanMessage(content=prompt),
]
)
content = response.content.strip()
# Try to extract JSON from response
if content.startswith("["):
return json.loads(content)
# Try to find JSON in response
start = content.find("[")
end = content.rfind("]") + 1
if start != -1 and end != 0:
return json.loads(content[start:end])
return []
except (json.JSONDecodeError, Exception) as e:
logger.warning(f"Memory extraction LLM call failed: {e}")
return []
async def _deduplicate(
self,
db: AsyncSession,
user_id: str,
new_memories: list[ExtractedMemory],
) -> list[ExtractedMemory]:
"""Filter duplicates against existing UserMemory.
Similarity > 0.85 → reinforce existing instead of creating new.
Returns only truly new memories.
"""
if not new_memories:
return []
result = await db.execute(
select(UserMemory)
.where(
UserMemory.user_id == user_id,
UserMemory.is_archived == False,
)
.limit(20)
)
existing = list(result.scalars().all())
deduplicated = []
for new_mem in new_memories:
is_duplicate = False
for existing_mem in existing:
if self._is_similar(new_mem.content, existing_mem.content):
# Reinforce existing memory instead of creating new
await self._reinforce_existing(db, existing_mem)
is_duplicate = True
break
if not is_duplicate:
deduplicated.append(new_mem)
return deduplicated
def _is_similar(self, text1: str, text2: str) -> bool:
"""Simple similarity check using keyword overlap.
In production would use embedding similarity.
"""
# Simple word overlap ratio
words1 = set(text1.lower().split())
words2 = set(text2.lower().split())
if not words1 or not words2:
return False
overlap = len(words1 & words2)
union = len(words1 | words2)
jaccard = overlap / union if union > 0 else 0
# Also check substring
if jaccard > 0.5:
return True
if len(text1) > 5 and len(text2) > 5:
if text1[:20].lower() == text2[:20].lower():
return True
return False
async def _reinforce_existing(
self,
db: AsyncSession,
memory: UserMemory,
) -> None:
"""Reinforce an existing memory instead of creating a duplicate."""
from app.services.memory.reinforcement import MemoryReinforcement
reinforcement = MemoryReinforcement()
reinforcement.trigger(memory)
await db.commit()
async def save_memories(
self,
db: AsyncSession,
user_id: str,
conversation_id: str,
memories: list[ExtractedMemory],
) -> list[UserMemory]:
"""Save extracted memories as UserMemory records."""
from app.services.memory.importance_scorer import ImportanceScorer
saved = []
scorer = ImportanceScorer()
for mem in memories:
user_mem = UserMemory(
user_id=user_id,
memory_type=mem.memory_type,
content=mem.content,
source_conversation_id=mem.source_conversation_id,
importance_score=0.5, # Will be updated by scorer
importance_level="medium",
)
# Update importance based on content
scorer.update_memory_importance(user_mem)
db.add(user_mem)
saved.append(user_mem)
if saved:
await db.commit()
for mem in saved:
await db.refresh(mem)
return saved