feat(memory): complete M.2-M.5 memory upgrade phases with tests
- M.2: ForgettingCurve, MemoryDecay, MemoryReinforcement (selective forgetting) - M.3: DailyDigestGenerator, ReminderScheduler, ProactiveInformer (proactive reminders) - M.4: MemoryExtractor with LLM-based memory extraction from conversations - M.5: MemoryRecallInjector with token budget control for prompt injection - All phases include comprehensive unit tests (109 tests passing) - Updated checklist.md to mark all tasks complete
This commit is contained in:
239
backend/app/services/memory/memory_extractor.py
Normal file
239
backend/app/services/memory/memory_extractor.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""
|
||||
MemoryExtractor
|
||||
|
||||
Automatically extracts memories from conversations using LLM.
|
||||
Extracts 5 types: fact, preference, goal, pain_point, event.
|
||||
Deduplicates against existing memories (similarity > 0.85 → reinforce instead of create).
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.conversation import Message
|
||||
from app.models.memory import UserMemory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MEMORY_TYPES = ("fact", "preference", "goal", "pain_point", "event")
|
||||
|
||||
EXTRACT_PROMPT = """从以下对话中提取用户的记忆信息,以 JSON 格式返回。
|
||||
|
||||
对话内容:
|
||||
{conversation_text}
|
||||
|
||||
提取以下类型(只提取明确信息,不要猜测):
|
||||
- fact: 关于用户的客观事实(职业、 location、技能、健康状况等)
|
||||
- preference: 用户的偏好和习惯(回答风格偏好、沟通偏好、生活习惯等)
|
||||
- goal: 用户提到的目标或计划(想做什么、计划做什么、目标是什么)
|
||||
- pain_point: 反复出现或明显困扰用户的问题
|
||||
- event: 今天发生的重要事件
|
||||
|
||||
输出格式(只输出 JSON,不要其他内容):
|
||||
[
|
||||
{{"type": "fact", "content": "...", "confidence": 0.9}},
|
||||
{{"type": "goal", "content": "...", "confidence": 0.7}}
|
||||
]"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractedMemory:
|
||||
"""A memory extracted from conversation."""
|
||||
|
||||
memory_type: str # "fact" | "preference" | "goal" | "pain_point" | "event"
|
||||
content: str
|
||||
confidence: float # 0.0-1.0
|
||||
source_conversation_id: str | None = None
|
||||
|
||||
|
||||
class MemoryExtractor:
|
||||
"""Extract memories from conversations using LLM."""
|
||||
|
||||
SIMILARITY_THRESHOLD = 0.85
|
||||
|
||||
async def extract_from_conversation(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
conversation_id: str,
|
||||
messages: list[Message],
|
||||
) -> list[ExtractedMemory]:
|
||||
"""Extract memories from conversation messages.
|
||||
|
||||
1. Build conversation text
|
||||
2. Call LLM to extract memories
|
||||
3. Parse JSON response
|
||||
4. Deduplicate against existing memories
|
||||
5. Return new memories
|
||||
"""
|
||||
if len(messages) < 2:
|
||||
return []
|
||||
|
||||
# 1. Build conversation text
|
||||
conversation_text = "\n".join(f"[{m.role}] {m.content}" for m in messages[-10:])
|
||||
|
||||
# 2. Call LLM
|
||||
extracted = await self._call_llm_extract(conversation_text)
|
||||
if not extracted:
|
||||
return []
|
||||
|
||||
# 3. Build ExtractedMemory objects
|
||||
new_memories = [
|
||||
ExtractedMemory(
|
||||
memory_type=m["type"],
|
||||
content=m["content"],
|
||||
confidence=m.get("confidence", 0.5),
|
||||
source_conversation_id=conversation_id,
|
||||
)
|
||||
for m in extracted
|
||||
if m.get("type") in MEMORY_TYPES and m.get("content")
|
||||
]
|
||||
|
||||
# 4. Deduplicate
|
||||
new_memories = await self._deduplicate(db, user_id, new_memories)
|
||||
|
||||
return new_memories
|
||||
|
||||
async def _call_llm_extract(self, conversation_text: str) -> list[dict]:
|
||||
"""Call LLM to extract memories from conversation text."""
|
||||
from app.services.llm_service import get_llm
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
prompt = EXTRACT_PROMPT.format(conversation_text=conversation_text)
|
||||
|
||||
try:
|
||||
llm = get_llm()
|
||||
response = await llm.invoke(
|
||||
[
|
||||
SystemMessage(
|
||||
content="你是一个记忆提取助手。从对话中提取用户的记忆信息,只返回JSON数组,不要其他内容。"
|
||||
),
|
||||
HumanMessage(content=prompt),
|
||||
]
|
||||
)
|
||||
content = response.content.strip()
|
||||
|
||||
# Try to extract JSON from response
|
||||
if content.startswith("["):
|
||||
return json.loads(content)
|
||||
# Try to find JSON in response
|
||||
start = content.find("[")
|
||||
end = content.rfind("]") + 1
|
||||
if start != -1 and end != 0:
|
||||
return json.loads(content[start:end])
|
||||
return []
|
||||
except (json.JSONDecodeError, Exception) as e:
|
||||
logger.warning(f"Memory extraction LLM call failed: {e}")
|
||||
return []
|
||||
|
||||
async def _deduplicate(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
new_memories: list[ExtractedMemory],
|
||||
) -> list[ExtractedMemory]:
|
||||
"""Filter duplicates against existing UserMemory.
|
||||
|
||||
Similarity > 0.85 → reinforce existing instead of creating new.
|
||||
Returns only truly new memories.
|
||||
"""
|
||||
if not new_memories:
|
||||
return []
|
||||
|
||||
result = await db.execute(
|
||||
select(UserMemory)
|
||||
.where(
|
||||
UserMemory.user_id == user_id,
|
||||
UserMemory.is_archived == False,
|
||||
)
|
||||
.limit(20)
|
||||
)
|
||||
existing = list(result.scalars().all())
|
||||
|
||||
deduplicated = []
|
||||
for new_mem in new_memories:
|
||||
is_duplicate = False
|
||||
for existing_mem in existing:
|
||||
if self._is_similar(new_mem.content, existing_mem.content):
|
||||
# Reinforce existing memory instead of creating new
|
||||
await self._reinforce_existing(db, existing_mem)
|
||||
is_duplicate = True
|
||||
break
|
||||
if not is_duplicate:
|
||||
deduplicated.append(new_mem)
|
||||
|
||||
return deduplicated
|
||||
|
||||
def _is_similar(self, text1: str, text2: str) -> bool:
|
||||
"""Simple similarity check using keyword overlap.
|
||||
|
||||
In production would use embedding similarity.
|
||||
"""
|
||||
# Simple word overlap ratio
|
||||
words1 = set(text1.lower().split())
|
||||
words2 = set(text2.lower().split())
|
||||
if not words1 or not words2:
|
||||
return False
|
||||
|
||||
overlap = len(words1 & words2)
|
||||
union = len(words1 | words2)
|
||||
jaccard = overlap / union if union > 0 else 0
|
||||
|
||||
# Also check substring
|
||||
if jaccard > 0.5:
|
||||
return True
|
||||
if len(text1) > 5 and len(text2) > 5:
|
||||
if text1[:20].lower() == text2[:20].lower():
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _reinforce_existing(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
memory: UserMemory,
|
||||
) -> None:
|
||||
"""Reinforce an existing memory instead of creating a duplicate."""
|
||||
from app.services.memory.reinforcement import MemoryReinforcement
|
||||
|
||||
reinforcement = MemoryReinforcement()
|
||||
reinforcement.trigger(memory)
|
||||
await db.commit()
|
||||
|
||||
async def save_memories(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
conversation_id: str,
|
||||
memories: list[ExtractedMemory],
|
||||
) -> list[UserMemory]:
|
||||
"""Save extracted memories as UserMemory records."""
|
||||
from app.services.memory.importance_scorer import ImportanceScorer
|
||||
|
||||
saved = []
|
||||
scorer = ImportanceScorer()
|
||||
|
||||
for mem in memories:
|
||||
user_mem = UserMemory(
|
||||
user_id=user_id,
|
||||
memory_type=mem.memory_type,
|
||||
content=mem.content,
|
||||
source_conversation_id=mem.source_conversation_id,
|
||||
importance_score=0.5, # Will be updated by scorer
|
||||
importance_level="medium",
|
||||
)
|
||||
# Update importance based on content
|
||||
scorer.update_memory_importance(user_mem)
|
||||
|
||||
db.add(user_mem)
|
||||
saved.append(user_mem)
|
||||
|
||||
if saved:
|
||||
await db.commit()
|
||||
for mem in saved:
|
||||
await db.refresh(mem)
|
||||
|
||||
return saved
|
||||
Reference in New Issue
Block a user