240 lines
7.5 KiB
Python
240 lines
7.5 KiB
Python
|
|
"""
|
|||
|
|
MemoryExtractor
|
|||
|
|
|
|||
|
|
Automatically extracts memories from conversations using LLM.
|
|||
|
|
Extracts 5 types: fact, preference, goal, pain_point, event.
|
|||
|
|
Deduplicates against existing memories (similarity > 0.85 → reinforce instead of create).
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import json
|
|||
|
|
import logging
|
|||
|
|
from dataclasses import dataclass, field
|
|||
|
|
from typing import Any
|
|||
|
|
|
|||
|
|
from sqlalchemy import select
|
|||
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
|
|
|||
|
|
from app.models.conversation import Message
|
|||
|
|
from app.models.memory import UserMemory
|
|||
|
|
|
|||
|
|
logger = logging.getLogger(__name__)

# Memory categories accepted by MemoryExtractor; keep in sync with the
# category list spelled out in EXTRACT_PROMPT below.
MEMORY_TYPES = ("fact", "preference", "goal", "pain_point", "event")

# User prompt for the extraction LLM call. It is filled with
# ``EXTRACT_PROMPT.format(conversation_text=...)``, so the literal braces in the
# JSON example are doubled (``{{`` / ``}}``) to survive str.format.
EXTRACT_PROMPT = """从以下对话中提取用户的记忆信息,以 JSON 格式返回。

对话内容:
{conversation_text}

提取以下类型(只提取明确信息,不要猜测):
- fact: 关于用户的客观事实(职业、所在地、技能、健康状况等)
- preference: 用户的偏好和习惯(回答风格偏好、沟通偏好、生活习惯等)
- goal: 用户提到的目标或计划(想做什么、计划做什么、目标是什么)
- pain_point: 反复出现或明显困扰用户的问题
- event: 今天发生的重要事件

输出格式(只输出 JSON,不要其他内容):
[
  {{"type": "fact", "content": "...", "confidence": 0.9}},
  {{"type": "goal", "content": "...", "confidence": 0.7}}
]"""
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
|
|||
|
|
class ExtractedMemory:
|
|||
|
|
"""A memory extracted from conversation."""
|
|||
|
|
|
|||
|
|
memory_type: str # "fact" | "preference" | "goal" | "pain_point" | "event"
|
|||
|
|
content: str
|
|||
|
|
confidence: float # 0.0-1.0
|
|||
|
|
source_conversation_id: str | None = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
class MemoryExtractor:
    """Extract memories from conversations using an LLM.

    Pipeline: render the most recent messages as text, ask the LLM for a JSON
    array of memories, validate/normalize each item, then drop items that
    duplicate an existing (non-archived) memory, reinforcing that memory instead.
    """

    # Character-bigram Jaccard similarity above which two memory texts are
    # treated as the same memory (see _is_similar).
    SIMILARITY_THRESHOLD = 0.85
    # How many of the most recent messages are rendered into the prompt.
    MAX_CONTEXT_MESSAGES = 10
    # How many existing memories are loaded for deduplication.
    EXISTING_MEMORY_LIMIT = 20

    async def extract_from_conversation(
        self,
        db: AsyncSession,
        user_id: str,
        conversation_id: str,
        messages: list[Message],
    ) -> list[ExtractedMemory]:
        """Extract new memories from conversation messages.

        1. Build conversation text from the last ``MAX_CONTEXT_MESSAGES`` messages
        2. Call the LLM to extract candidate memories
        3. Validate and normalize the candidates
        4. Deduplicate against existing memories (reinforcing matches)
        5. Return the memories that are genuinely new

        Args:
            db: Session used to load existing memories and commit reinforcements.
            user_id: Owner of the memories.
            conversation_id: Stored as ``source_conversation_id`` on each result.
            messages: Conversation messages; each must expose ``role`` and ``content``.

        Returns:
            New, not-yet-persisted memories. Returns ``[]`` for fewer than two
            messages, on LLM failure, or when the LLM output is unusable.
        """
        if len(messages) < 2:
            return []

        # 1. Build conversation text
        recent = messages[-self.MAX_CONTEXT_MESSAGES:]
        conversation_text = "\n".join(f"[{m.role}] {m.content}" for m in recent)

        # 2. Call LLM
        extracted = await self._call_llm_extract(conversation_text)
        if not extracted:
            return []

        # 3. Build ExtractedMemory objects, skipping malformed items
        new_memories = []
        for item in extracted:
            memory = self._to_extracted_memory(item, conversation_id)
            if memory is not None:
                new_memories.append(memory)

        # 4. Deduplicate
        return await self._deduplicate(db, user_id, new_memories)

    def _to_extracted_memory(self, item: dict, conversation_id: str) -> ExtractedMemory | None:
        """Convert one raw LLM item into an ExtractedMemory, or None if unusable."""
        memory_type = item.get("type")
        if isinstance(memory_type, str):
            # Tolerate casing/whitespace variations such as "Fact" or " goal ".
            memory_type = memory_type.strip().lower()
        if memory_type not in MEMORY_TYPES:
            return None

        raw_content = item.get("content")
        if not raw_content:
            return None
        content = str(raw_content).strip()
        if not content:
            return None

        return ExtractedMemory(
            memory_type=memory_type,
            content=content,
            confidence=self._coerce_confidence(item.get("confidence", 0.5)),
            source_conversation_id=conversation_id,
        )

    @staticmethod
    def _coerce_confidence(value: Any, default: float = 0.5) -> float:
        """Return *value* as a float clamped to [0.0, 1.0]; *default* if not numeric."""
        try:
            confidence = float(value)
        except (TypeError, ValueError):
            return default
        if confidence != confidence:  # NaN
            return default
        return min(max(confidence, 0.0), 1.0)

    async def _call_llm_extract(self, conversation_text: str) -> list[dict]:
        """Ask the LLM to extract memories and parse its JSON-array reply.

        Best-effort: an LLM error, a reply without a JSON array, or invalid JSON
        is logged (warning) and yields ``[]``.

        Returns:
            The dict elements of the parsed array; non-dict elements are dropped.
        """
        import inspect

        from app.services.llm_service import get_llm
        from langchain_core.messages import HumanMessage, SystemMessage

        prompt = EXTRACT_PROMPT.format(conversation_text=conversation_text)
        chat_messages = [
            SystemMessage(
                content="你是一个记忆提取助手。从对话中提取用户的记忆信息,只返回JSON数组,不要其他内容。"
            ),
            HumanMessage(content=prompt),
        ]

        try:
            llm = get_llm()
            # LangChain chat models expose a synchronous ``invoke`` and an async
            # ``ainvoke``; awaiting the return value of ``invoke`` raises TypeError.
            ainvoke = getattr(llm, "ainvoke", None)
            if ainvoke is not None:
                response = await ainvoke(chat_messages)
            else:
                response = llm.invoke(chat_messages)
                if inspect.isawaitable(response):
                    response = await response
        except Exception as e:
            logger.warning("Memory extraction LLM call failed: %s", e)
            return []

        content = str(getattr(response, "content", response)).strip()

        # The model may wrap the array in prose or a code fence; parse the
        # outermost [...] span (this also covers a bare array).
        start = content.find("[")
        end = content.rfind("]") + 1
        if start == -1 or end <= start:
            return []
        try:
            parsed = json.loads(content[start:end])
        except json.JSONDecodeError as e:
            logger.warning("Memory extraction LLM call failed: %s", e)
            return []

        if not isinstance(parsed, list):
            return []
        return [item for item in parsed if isinstance(item, dict)]

    async def _deduplicate(
        self,
        db: AsyncSession,
        user_id: str,
        new_memories: list[ExtractedMemory],
    ) -> list[ExtractedMemory]:
        """Filter duplicates against existing UserMemory rows and within the batch.

        A new memory similar (see ``_is_similar``) to an existing, non-archived
        memory is dropped, and that existing memory is reinforced instead. Each
        existing memory is reinforced at most once per call. The session is
        committed once at the end if anything was reinforced.

        Returns:
            Only the genuinely new memories, in their original order.
        """
        if not new_memories:
            return []

        # NOTE(review): there is no ORDER BY, so which rows fill the limit is
        # database-dependent; consider ordering by recency or importance.
        result = await db.execute(
            select(UserMemory)
            .where(
                UserMemory.user_id == user_id,
                UserMemory.is_archived == False,  # noqa: E712 - SQL expression
            )
            .limit(self.EXISTING_MEMORY_LIMIT)
        )
        existing = list(result.scalars().all())

        reinforced_ids: set[int] = set()  # id() of rows already reinforced this call
        deduplicated: list[ExtractedMemory] = []
        for new_mem in new_memories:
            match = next(
                (old for old in existing if self._is_similar(new_mem.content, old.content)),
                None,
            )
            if match is not None:
                # Reinforce existing memory instead of creating a new one.
                if id(match) not in reinforced_ids:
                    reinforced_ids.add(id(match))
                    await self._reinforce_existing(db, match, commit=False)
                continue
            # Collapse near-duplicates produced within the same extraction batch.
            if any(self._is_similar(new_mem.content, kept.content) for kept in deduplicated):
                continue
            deduplicated.append(new_mem)

        # Commit once after the loop. Committing mid-loop can expire the loaded
        # rows (if the session uses expire_on_commit=True, SQLAlchemy's default),
        # and reading expired attributes on an AsyncSession needs implicit IO.
        if reinforced_ids:
            await db.commit()

        return deduplicated

    @staticmethod
    def _char_bigrams(text: str) -> set[str]:
        """Return the set of adjacent character pairs in *text*, ignoring whitespace."""
        compact = "".join(text.split())
        return {compact[i : i + 2] for i in range(len(compact) - 1)}

    def _is_similar(self, text1: str, text2: str) -> bool:
        """Heuristically decide whether two memory texts describe the same thing.

        The texts are similar if any of these holds:

        * word-level Jaccard overlap > 0.5 (whitespace-delimited languages);
        * character-bigram Jaccard > ``SIMILARITY_THRESHOLD``;
        * both texts are longer than 5 characters and their first 20
          characters match case-insensitively.

        Empty or None texts are never similar. In production this would use
        embedding similarity.
        """
        if not text1 or not text2:
            return False

        words1 = set(text1.lower().split())
        words2 = set(text2.lower().split())
        if not words1 or not words2:
            return False

        # Word-level Jaccard overlap.
        if len(words1 & words2) / len(words1 | words2) > 0.5:
            return True

        # Character-bigram Jaccard. Needed for CJK text, where split() yields a
        # single "word" per sentence and the word check degenerates to exact match.
        grams1 = self._char_bigrams(text1.lower())
        grams2 = self._char_bigrams(text2.lower())
        if grams1 and grams2:
            if len(grams1 & grams2) / len(grams1 | grams2) > self.SIMILARITY_THRESHOLD:
                return True

        # Shared prefix: the first 20 characters match (case-insensitive).
        return len(text1) > 5 and len(text2) > 5 and text1[:20].lower() == text2[:20].lower()

    async def _reinforce_existing(
        self,
        db: AsyncSession,
        memory: UserMemory,
        commit: bool = True,
    ) -> None:
        """Reinforce an existing memory instead of creating a duplicate.

        Calls ``MemoryReinforcement().trigger(memory)`` and, when *commit* is
        true (the default), commits the session. Pass ``commit=False`` to batch
        several reinforcements into a single commit made by the caller.
        """
        from app.services.memory.reinforcement import MemoryReinforcement

        reinforcement = MemoryReinforcement()
        reinforcement.trigger(memory)
        if commit:
            await db.commit()

    async def save_memories(
        self,
        db: AsyncSession,
        user_id: str,
        conversation_id: str,
        memories: list[ExtractedMemory],
    ) -> list[UserMemory]:
        """Save extracted memories as UserMemory records.

        Each record starts at importance 0.5 / "medium" and is then rescored by
        ``ImportanceScorer.update_memory_importance`` before being added. All
        records are committed together, then refreshed from the database.

        Args:
            db: Session the records are added to and committed on.
            user_id: Owner of the new records.
            conversation_id: Fallback source conversation for memories whose
                ``source_conversation_id`` is not set.
            memories: Memories to persist (typically the output of
                ``extract_from_conversation``).

        Returns:
            The saved records; an empty list (and no commit) if *memories* is empty.
        """
        from app.services.memory.importance_scorer import ImportanceScorer

        saved = []
        scorer = ImportanceScorer()

        for mem in memories:
            user_mem = UserMemory(
                user_id=user_id,
                memory_type=mem.memory_type,
                content=mem.content,
                source_conversation_id=mem.source_conversation_id or conversation_id,
                importance_score=0.5,  # Will be updated by scorer
                importance_level="medium",
            )
            # Update importance based on content
            scorer.update_memory_importance(user_mem)

            db.add(user_mem)
            saved.append(user_mem)

        if saved:
            await db.commit()
            for mem in saved:
                await db.refresh(mem)

        return saved
|