"""
MemoryExtractor

Automatically extracts memories from conversations using an LLM.
Extracts 5 types: fact, preference, goal, pain_point, event.

New memories are deduplicated against the user's existing memories: a new
memory that looks like a near-duplicate of an existing one reinforces the
existing record instead of creating a new one. Similarity is currently a
keyword-overlap heuristic (see ``MemoryExtractor._is_similar``);
``MemoryExtractor.SIMILARITY_THRESHOLD`` (0.85) is reserved for an
embedding-based comparison.
"""

import inspect
import json
import logging
import math
import re
from dataclasses import dataclass, field
from typing import Any

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.conversation import Message
from app.models.memory import UserMemory

logger = logging.getLogger(__name__)

MEMORY_TYPES = ("fact", "preference", "goal", "pain_point", "event")

# Chinese ideographs and Japanese kana. Text in these scripts is usually not
# space-separated, so whitespace splitting alone cannot measure word overlap.
_CJK_CHAR = re.compile(r"[\u3040-\u30ff\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff]")

EXTRACT_PROMPT = """从以下对话中提取用户的记忆信息,以 JSON 格式返回。

对话内容:
{conversation_text}

提取以下类型(只提取明确信息,不要猜测):
- fact: 关于用户的客观事实(职业、 location、技能、健康状况等)
- preference: 用户的偏好和习惯(回答风格偏好、沟通偏好、生活习惯等)
- goal: 用户提到的目标或计划(想做什么、计划做什么、目标是什么)
- pain_point: 反复出现或明显困扰用户的问题
- event: 今天发生的重要事件

输出格式(只输出 JSON,不要其他内容):
[
    {{"type": "fact", "content": "...", "confidence": 0.9}},
    {{"type": "goal", "content": "...", "confidence": 0.7}}
]"""


@dataclass
class ExtractedMemory:
    """A memory extracted from a conversation (not yet persisted)."""

    memory_type: str  # one of MEMORY_TYPES
    content: str
    confidence: float  # 0.0-1.0
    source_conversation_id: str | None = None


class MemoryExtractor:
    """Extract memories from conversations using an LLM."""

    # Intended threshold for embedding (cosine) similarity. The current
    # keyword heuristic in ``_is_similar`` does not use it; kept for callers
    # and for a future embedding-based implementation.
    SIMILARITY_THRESHOLD = 0.85
    # Token-overlap (Jaccard) ratio above which ``_is_similar`` reports a match.
    KEYWORD_OVERLAP_THRESHOLD = 0.5
    # Number of most recent messages sent to the LLM.
    MAX_CONTEXT_MESSAGES = 10
    # Maximum number of existing memories loaded for duplicate detection.
    DEDUP_CANDIDATE_LIMIT = 20
    # Confidence used when the LLM omits it or returns an unusable value.
    DEFAULT_CONFIDENCE = 0.5

    async def extract_from_conversation(
        self,
        db: AsyncSession,
        user_id: str,
        conversation_id: str,
        messages: list[Message],
    ) -> list[ExtractedMemory]:
        """Extract new memories from conversation messages.

        Steps:
            1. Build a transcript from the last ``MAX_CONTEXT_MESSAGES`` messages.
            2. Ask the LLM to extract memories as a JSON array.
            3. Keep only items with a known type and non-empty string content;
               normalise confidence into [0.0, 1.0].
            4. Deduplicate against the user's existing memories (duplicates
               reinforce the existing record) and within this batch.

        Args:
            db: Session used to load existing memories (and to commit
                reinforcements of duplicates).
            user_id: Owner of the memories.
            conversation_id: Recorded as ``source_conversation_id`` on every
                extracted memory.
            messages: Conversation messages exposing ``role`` and ``content``.

        Returns:
            Memories that are not duplicates. They are not persisted; pass
            them to ``save_memories``.
        """
        # Fewer than two messages rarely carries enough context to extract from.
        if len(messages) < 2:
            return []

        recent = messages[-self.MAX_CONTEXT_MESSAGES:]
        conversation_text = "\n".join(f"[{m.role}] {m.content}" for m in recent)

        extracted = await self._call_llm_extract(conversation_text)
        if not extracted:
            return []

        new_memories: list[ExtractedMemory] = []
        for item in extracted:
            memory = self._to_extracted_memory(item, conversation_id)
            if memory is not None:
                new_memories.append(memory)

        return await self._deduplicate(db, user_id, new_memories)

    def _to_extracted_memory(
        self, item: Any, conversation_id: str | None
    ) -> ExtractedMemory | None:
        """Validate one raw LLM item and convert it to an ``ExtractedMemory``.

        Returns ``None`` when the item is not a dict, its ``type`` is not in
        ``MEMORY_TYPES``, or its ``content`` is not a non-empty string.
        """
        if not isinstance(item, dict):
            return None
        memory_type = item.get("type")
        content = item.get("content")
        if memory_type not in MEMORY_TYPES or not isinstance(content, str):
            return None
        content = content.strip()
        if not content:
            return None
        return ExtractedMemory(
            memory_type=memory_type,
            content=content,
            confidence=self._coerce_confidence(item.get("confidence")),
            source_conversation_id=conversation_id,
        )

    def _coerce_confidence(self, value: Any) -> float:
        """Return ``value`` as a float clamped to [0.0, 1.0].

        LLM output is untrusted: a missing, non-numeric or NaN value falls
        back to ``DEFAULT_CONFIDENCE``.
        """
        try:
            confidence = float(value)
        except (TypeError, ValueError):
            return self.DEFAULT_CONFIDENCE
        if math.isnan(confidence):
            return self.DEFAULT_CONFIDENCE
        return min(max(confidence, 0.0), 1.0)

    async def _call_llm_extract(self, conversation_text: str) -> list[dict]:
        """Ask the LLM to extract memories and return the parsed JSON items.

        Best-effort: an LLM failure or an unparseable response is logged and
        yields ``[]``, so extraction never breaks the caller.

        Returns:
            The dict items of the first JSON array in the response, or ``[]``.
        """
        from langchain_core.messages import HumanMessage, SystemMessage

        from app.services.llm_service import get_llm

        prompt = EXTRACT_PROMPT.format(conversation_text=conversation_text)
        llm_messages = [
            SystemMessage(
                content="你是一个记忆提取助手。从对话中提取用户的记忆信息,只返回JSON数组,不要其他内容。"
            ),
            HumanMessage(content=prompt),
        ]
        try:
            llm = get_llm()
            response = await self._invoke_llm(llm, llm_messages)
        except Exception as e:  # boundary: memory extraction is best-effort
            logger.warning("Memory extraction LLM call failed: %s", e)
            return []

        return self._parse_json_array(self._response_text(response))

    @staticmethod
    async def _invoke_llm(llm: Any, llm_messages: list[Any]) -> Any:
        """Call ``llm`` without blocking on a synchronous ``invoke``.

        LangChain chat models expose a synchronous ``invoke`` and an async
        ``ainvoke``; awaiting the return value of a synchronous ``invoke``
        raises ``TypeError``. An ``invoke`` that is itself a coroutine
        function (e.g. a custom async wrapper) is awaited directly.
        """
        invoke = llm.invoke
        if inspect.iscoroutinefunction(invoke):
            return await invoke(llm_messages)
        ainvoke = getattr(llm, "ainvoke", None)
        if ainvoke is not None:
            return await ainvoke(llm_messages)
        result = invoke(llm_messages)
        return await result if inspect.isawaitable(result) else result

    @staticmethod
    def _response_text(response: Any) -> str:
        """Return the text of an LLM response.

        ``response.content`` is normally a string. LangChain also allows a
        list of content parts (strings, or dicts with a ``"text"`` key).
        A response without ``content`` is treated as the content itself.
        """
        content = getattr(response, "content", response)
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts: list[str] = []
            for part in content:
                if isinstance(part, str):
                    parts.append(part)
                elif isinstance(part, dict) and isinstance(part.get("text"), str):
                    parts.append(part["text"])
            return "".join(parts)
        return "" if content is None else str(content)

    @staticmethod
    def _parse_json_array(text: str) -> list[dict]:
        """Return the dict items of the first JSON array embedded in ``text``.

        Tolerates surrounding prose and Markdown code fences. Each ``[`` is
        tried as the start of a JSON document, and decoding stops at the end
        of that document. Trailing text or stray brackets therefore do not
        break parsing. Returns ``[]`` when no array containing dicts is found.
        """
        decoder = json.JSONDecoder()
        found_array = False
        start = text.find("[")
        while start != -1:
            try:
                parsed, _ = decoder.raw_decode(text[start:])
            except (ValueError, RecursionError):
                parsed = None
            if isinstance(parsed, list):
                found_array = True
                items = [item for item in parsed if isinstance(item, dict)]
                if items:
                    return items
            start = text.find("[", start + 1)
        if not found_array:
            logger.warning("Memory extraction response contained no valid JSON array")
        return []

    async def _deduplicate(
        self,
        db: AsyncSession,
        user_id: str,
        new_memories: list[ExtractedMemory],
    ) -> list[ExtractedMemory]:
        """Drop new memories that duplicate existing or earlier new ones.

        If a new memory is similar (per ``_is_similar``) to one of the user's
        non-archived memories, that existing record is reinforced and the new
        memory is dropped. A new memory similar to one already accepted from
        this same batch is also dropped.

        NOTE(review): at most ``DEDUP_CANDIDATE_LIMIT`` existing memories are
        compared, and the query has no ORDER BY, so which ones are loaded is
        up to the database. Consider ordering by recency once the model's
        timestamp column is confirmed.

        Returns:
            The memories that are new, in their original order.
        """
        if not new_memories:
            return []

        result = await db.execute(
            select(UserMemory)
            .where(
                UserMemory.user_id == user_id,
                UserMemory.is_archived == False,  # noqa: E712 - SQL expression
            )
            .limit(self.DEDUP_CANDIDATE_LIMIT)
        )
        existing = list(result.scalars().all())

        accepted: list[ExtractedMemory] = []
        for new_mem in new_memories:
            match = next(
                (old for old in existing if self._is_similar(new_mem.content, old.content)),
                None,
            )
            if match is not None:
                # Reinforce the existing memory instead of creating a new one.
                await self._reinforce_existing(db, match)
                continue
            # The LLM may return near-identical items in a single response.
            if any(self._is_similar(new_mem.content, kept.content) for kept in accepted):
                continue
            accepted.append(new_mem)
        return accepted

    @staticmethod
    def _tokenize(text: str) -> set[str]:
        """Split ``text`` into lower-cased comparison tokens.

        Whitespace-delimited words are used as-is. Words that contain CJK
        characters are broken into overlapping character bigrams, so partial
        overlap between two unsegmented Chinese/Japanese sentences can still
        be measured.
        """
        tokens: set[str] = set()
        for word in text.lower().split():
            if len(word) > 1 and _CJK_CHAR.search(word):
                tokens.update(word[i : i + 2] for i in range(len(word) - 1))
            else:
                tokens.add(word)
        return tokens

    def _is_similar(self, text1: str, text2: str) -> bool:
        """Heuristic near-duplicate check between two memory texts.

        The texts are considered similar when either:
          * the Jaccard overlap of their tokens (see ``_tokenize``) exceeds
            ``KEYWORD_OVERLAP_THRESHOLD``, or
          * both are longer than 5 characters and their first 20 characters
            match case-insensitively.

        Empty or ``None`` texts are never similar. In production this would
        use embedding similarity.
        """
        if not text1 or not text2:
            return False
        tokens1 = self._tokenize(text1)
        tokens2 = self._tokenize(text2)
        if not tokens1 or not tokens2:
            return False

        jaccard = len(tokens1 & tokens2) / len(tokens1 | tokens2)
        if jaccard > self.KEYWORD_OVERLAP_THRESHOLD:
            return True

        # Same opening phrase: catches rewordings that differ only at the end.
        return (
            len(text1) > 5
            and len(text2) > 5
            and text1[:20].lower() == text2[:20].lower()
        )

    async def _reinforce_existing(
        self,
        db: AsyncSession,
        memory: UserMemory,
    ) -> None:
        """Reinforce an existing memory instead of creating a duplicate.

        Delegates to ``MemoryReinforcement.trigger``, which presumably
        updates the memory's reinforcement fields (not visible here), then
        commits the session immediately.
        """
        from app.services.memory.reinforcement import MemoryReinforcement

        reinforcement = MemoryReinforcement()
        reinforcement.trigger(memory)
        await db.commit()

    async def save_memories(
        self,
        db: AsyncSession,
        user_id: str,
        conversation_id: str,
        memories: list[ExtractedMemory],
    ) -> list[UserMemory]:
        """Persist extracted memories as ``UserMemory`` records.

        Each record starts at medium importance (0.5), which
        ``ImportanceScorer.update_memory_importance`` then adjusts from the
        content. All records are committed together and then refreshed from
        the database.

        Args:
            db: Session the records are added to and committed on.
            user_id: Owner of the memories.
            conversation_id: Fallback source conversation for memories whose
                ``source_conversation_id`` is unset.
            memories: Memories to persist, typically the output of
                ``extract_from_conversation``.

        Returns:
            The saved records (empty when ``memories`` is empty).
        """
        from app.services.memory.importance_scorer import ImportanceScorer

        if not memories:
            return []

        scorer = ImportanceScorer()
        saved: list[UserMemory] = []
        for mem in memories:
            user_mem = UserMemory(
                user_id=user_id,
                memory_type=mem.memory_type,
                content=mem.content,
                source_conversation_id=mem.source_conversation_id or conversation_id,
                importance_score=0.5,  # placeholder; recomputed by the scorer below
                importance_level="medium",
            )
            scorer.update_memory_importance(user_mem)
            db.add(user_mem)
            saved.append(user_mem)

        await db.commit()
        for user_mem in saved:
            await db.refresh(user_mem)
        return saved