- M.2: ForgettingCurve, MemoryDecay, MemoryReinforcement (selective forgetting) - M.3: DailyDigestGenerator, ReminderScheduler, ProactiveInformer (proactive reminders) - M.4: MemoryExtractor with LLM-based memory extraction from conversations - M.5: MemoryRecallInjector with token budget control for prompt injection - All phases include comprehensive unit tests (109 tests passing) - Updated checklist.md to mark all tasks complete
240 lines
7.5 KiB
Python
240 lines
7.5 KiB
Python
"""
MemoryExtractor

Automatically extracts memories from conversations using an LLM.
Extracts 5 types: fact, preference, goal, pain_point, event.
Deduplicates against existing memories with a keyword-overlap heuristic
(a match reinforces the existing memory instead of creating a new one).
"""
|
||
|
||
import json
|
||
import logging
|
||
from dataclasses import dataclass, field
|
||
from typing import Any
|
||
|
||
from sqlalchemy import select
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.models.conversation import Message
|
||
from app.models.memory import UserMemory
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Memory categories accepted from the LLM output.
MEMORY_TYPES = (
    "fact",
    "preference",
    "goal",
    "pain_point",
    "event",
)

# Prompt template (Chinese) sent to the LLM. ``{conversation_text}`` is the
# only ``str.format`` placeholder; the doubled braces in the JSON example
# render as literal ``{`` / ``}`` after formatting.
EXTRACT_PROMPT = """从以下对话中提取用户的记忆信息,以 JSON 格式返回。

对话内容:
{conversation_text}

提取以下类型(只提取明确信息,不要猜测):
- fact: 关于用户的客观事实(职业、 location、技能、健康状况等)
- preference: 用户的偏好和习惯(回答风格偏好、沟通偏好、生活习惯等)
- goal: 用户提到的目标或计划(想做什么、计划做什么、目标是什么)
- pain_point: 反复出现或明显困扰用户的问题
- event: 今天发生的重要事件

输出格式(只输出 JSON,不要其他内容):
[
{{"type": "fact", "content": "...", "confidence": 0.9}},
{{"type": "goal", "content": "...", "confidence": 0.7}}
]"""
|
||
|
||
|
||
@dataclass
|
||
class ExtractedMemory:
|
||
"""A memory extracted from conversation."""
|
||
|
||
memory_type: str # "fact" | "preference" | "goal" | "pain_point" | "event"
|
||
content: str
|
||
confidence: float # 0.0-1.0
|
||
source_conversation_id: str | None = None
|
||
|
||
|
||
class MemoryExtractor:
    """Extract memories from conversations using an LLM.

    Pipeline: render the last messages as text, ask the LLM for a JSON array
    of ``{"type", "content", "confidence"}`` objects, validate the items, and
    drop those that look like memories the user already has (the existing
    memory is reinforced instead).
    """

    # NOTE(review): this constant is not referenced by the code in this class;
    # ``_is_similar`` uses a fixed 0.5 keyword-overlap cut-off. It is kept
    # because external code may read it.
    SIMILARITY_THRESHOLD = 0.85

    async def extract_from_conversation(
        self,
        db: AsyncSession,
        user_id: str,
        conversation_id: str,
        messages: list[Message],
    ) -> list[ExtractedMemory]:
        """Extract new memories from conversation messages.

        Steps:
            1. Build the conversation text from the last 10 messages.
            2. Call the LLM to extract candidate memories.
            3. Validate each candidate and build ``ExtractedMemory`` objects.
            4. Deduplicate against existing memories.

        Args:
            db: Session used to look up (and reinforce) existing memories.
            user_id: Owner of the memories.
            conversation_id: Recorded as the source of every extracted memory.
            messages: Conversation messages; each must expose ``role`` and
                ``content``.

        Returns:
            Memories not matched by an existing one. Empty when fewer than two
            messages are given or the LLM yields nothing usable.
        """
        if len(messages) < 2:
            return []

        # 1. Build conversation text (only the 10 most recent messages).
        conversation_text = "\n".join(
            f"[{m.role}] {m.content}" for m in messages[-10:]
        )

        # 2. Call LLM.
        extracted = await self._call_llm_extract(conversation_text)
        if not extracted:
            return []

        # 3. Validate and convert; malformed items are skipped, not fatal.
        new_memories: list[ExtractedMemory] = []
        for raw in extracted:
            memory = self._to_extracted_memory(raw, conversation_id)
            if memory is not None:
                new_memories.append(memory)

        # 4. Deduplicate.
        return await self._deduplicate(db, user_id, new_memories)

    @staticmethod
    def _normalize_confidence(value: Any, default: float = 0.5) -> float:
        """Coerce an LLM-supplied confidence to a float clamped to [0.0, 1.0].

        Values that cannot be converted (``None``, non-numeric strings, NaN)
        fall back to ``default``.
        """
        try:
            confidence = float(value)
        except (TypeError, ValueError):
            return default
        if confidence != confidence:  # NaN is the only value unequal to itself
            return default
        return max(0.0, min(1.0, confidence))

    def _to_extracted_memory(
        self, raw: Any, conversation_id: str
    ) -> ExtractedMemory | None:
        """Turn one parsed LLM item into an ``ExtractedMemory``.

        Returns ``None`` when the item is not a dict, has a type outside
        ``MEMORY_TYPES``, or has no non-blank string content.
        """
        if not isinstance(raw, dict):
            return None
        if raw.get("type") not in MEMORY_TYPES:
            return None
        content = raw.get("content")
        if not isinstance(content, str) or not content.strip():
            return None
        return ExtractedMemory(
            memory_type=raw["type"],
            content=content.strip(),
            confidence=self._normalize_confidence(raw.get("confidence", 0.5)),
            source_conversation_id=conversation_id,
        )

    @staticmethod
    def _parse_extraction_response(content: str) -> list[dict]:
        """Pull the JSON array out of an LLM reply.

        The span from the first ``[`` to the last ``]`` is parsed, so replies
        wrapped in code fences or followed by commentary still work. Non-dict
        array items are discarded.

        Raises:
            json.JSONDecodeError: If the bracketed span is not valid JSON.
        """
        if not isinstance(content, str):
            return []
        start = content.find("[")
        end = content.rfind("]")
        if start == -1 or end < start:
            return []
        parsed = json.loads(content[start : end + 1])
        if not isinstance(parsed, list):
            return []
        return [item for item in parsed if isinstance(item, dict)]

    async def _call_llm_extract(self, conversation_text: str) -> list[dict]:
        """Call the LLM to extract memories from conversation text.

        Best-effort: any failure (LLM error, malformed JSON) is logged as a
        warning and reported as an empty list.
        """
        import inspect

        from app.services.llm_service import get_llm
        from langchain_core.messages import HumanMessage, SystemMessage

        prompt = EXTRACT_PROMPT.format(conversation_text=conversation_text)

        try:
            llm = get_llm()
            response = llm.invoke(
                [
                    SystemMessage(
                        content="你是一个记忆提取助手。从对话中提取用户的记忆信息,只返回JSON数组,不要其他内容。"
                    ),
                    HumanMessage(content=prompt),
                ]
            )
            # LangChain's standard ``invoke`` is synchronous and returns the
            # message directly; awaiting that raises TypeError. Await only
            # when the object returned by ``get_llm()`` is actually async.
            # NOTE(review): if get_llm() is a plain LangChain model, prefer
            # ``await llm.ainvoke(...)`` so the event loop is not blocked.
            if inspect.isawaitable(response):
                response = await response
            return self._parse_extraction_response(response.content)
        except Exception as e:  # deliberate best-effort boundary
            logger.warning("Memory extraction LLM call failed: %s", e)
            return []

    async def _deduplicate(
        self,
        db: AsyncSession,
        user_id: str,
        new_memories: list[ExtractedMemory],
    ) -> list[ExtractedMemory]:
        """Filter out candidates that duplicate a known memory.

        A candidate similar to an existing ``UserMemory`` reinforces that
        record instead of being returned. A candidate similar to one already
        accepted from the same batch is dropped.

        Returns:
            Only the memories considered new, in input order.
        """
        if not new_memories:
            return []

        # NOTE(review): only 20 non-archived rows are compared, with no
        # ORDER BY, so which rows are checked is up to the database.
        result = await db.execute(
            select(UserMemory)
            .where(
                UserMemory.user_id == user_id,
                UserMemory.is_archived == False,  # noqa: E712 (SQL expression)
            )
            .limit(20)
        )
        existing = list(result.scalars().all())

        deduplicated: list[ExtractedMemory] = []
        for new_mem in new_memories:
            match = next(
                (
                    candidate
                    for candidate in existing
                    if self._is_similar(new_mem.content, candidate.content)
                ),
                None,
            )
            if match is not None:
                # Reinforce existing memory instead of creating a new one.
                await self._reinforce_existing(db, match)
                continue
            # The LLM may emit the same memory twice in one reply.
            if any(
                self._is_similar(new_mem.content, kept.content)
                for kept in deduplicated
            ):
                continue
            deduplicated.append(new_mem)

        return deduplicated

    def _is_similar(self, text1: str, text2: str) -> bool:
        """Cheap lexical similarity check.

        Two texts are similar when the Jaccard overlap of their lower-cased,
        whitespace-separated words exceeds 0.5, or when both are longer than
        5 characters and their first 20 characters match case-insensitively.
        Empty or ``None`` input is never similar.

        NOTE(review): text without whitespace (e.g. Chinese) forms a single
        token, so the overlap test only fires on exact matches there.
        In production this would use embedding similarity.
        """
        if not text1 or not text2:
            return False

        words1 = set(text1.lower().split())
        words2 = set(text2.lower().split())
        if not words1 or not words2:
            return False

        jaccard = len(words1 & words2) / len(words1 | words2)
        if jaccard > 0.5:
            return True

        # Shared-prefix check.
        if len(text1) > 5 and len(text2) > 5:
            return text1[:20].lower() == text2[:20].lower()
        return False

    async def _reinforce_existing(
        self,
        db: AsyncSession,
        memory: UserMemory,
    ) -> None:
        """Reinforce an existing memory instead of creating a duplicate.

        Delegates to ``MemoryReinforcement.trigger`` and commits the session.
        """
        from app.services.memory.reinforcement import MemoryReinforcement

        MemoryReinforcement().trigger(memory)
        await db.commit()

    async def save_memories(
        self,
        db: AsyncSession,
        user_id: str,
        conversation_id: str,
        memories: list[ExtractedMemory],
    ) -> list[UserMemory]:
        """Save extracted memories as ``UserMemory`` records.

        Args:
            db: Session the records are added to and committed on.
            user_id: Owner of the new records.
            conversation_id: Source conversation used for any memory that does
                not carry its own ``source_conversation_id``.
            memories: Memories to persist.

        Returns:
            The persisted records, refreshed after commit. Empty when
            ``memories`` is empty, in which case nothing is committed.
        """
        from app.services.memory.importance_scorer import ImportanceScorer

        saved: list[UserMemory] = []
        scorer = ImportanceScorer()

        for mem in memories:
            user_mem = UserMemory(
                user_id=user_id,
                memory_type=mem.memory_type,
                content=mem.content,
                source_conversation_id=mem.source_conversation_id or conversation_id,
                importance_score=0.5,  # Placeholder; updated by the scorer below
                importance_level="medium",
            )
            # Update importance based on content.
            scorer.update_memory_importance(user_mem)

            db.add(user_mem)
            saved.append(user_mem)

        if not saved:
            return saved

        await db.commit()
        for record in saved:
            await db.refresh(record)

        return saved
|