""" Jarvis 记忆系统 (基于 Mem0) 三层记忆: 短期(对话历史) → 中期(摘要) → 长期(用户画像) 底层使用 Mem0 实现事实提取、时间线、矛盾解决和遗忘机制 """ import logging import os import re from datetime import UTC, datetime from typing import Optional, Any from sqlalchemy import select, desc, func from sqlalchemy.ext.asyncio import AsyncSession from app.models.conversation import Conversation, Message from app.models.memory import UserMemory from app.models.user import User from app.services.brain_service import BrainService from app.config import settings as _settings try: from mem0 import Memory MEM0_AVAILABLE = True except ImportError: MEM0_AVAILABLE = False Memory = None logger = logging.getLogger(__name__) async def _get_user_embedding_config(db: AsyncSession, user_id: str) -> dict | None: """从用户配置中获取 embedding 模型配置""" result = await db.execute(select(User).where(User.id == user_id)) user = result.scalar_one_or_none() if not user or not user.llm_config: return None embedding_models = user.llm_config.get("embedding", []) for model in embedding_models: if model.get("enabled") and model.get("model"): return { "model": model.get("model"), "base_url": model.get("base_url") or _settings.EMBEDDING_BASE_URL, "api_key": model.get("api_key") or _settings.EMBEDDING_API_KEY or _settings.OPENAI_API_KEY, } return None async def _get_user_chat_config(db: AsyncSession, user_id: str) -> dict | None: """从用户配置中获取 chat 模型配置""" result = await db.execute(select(User).where(User.id == user_id)) user = result.scalar_one_or_none() if not user or not user.llm_config: return None chat_models = user.llm_config.get("chat", []) for model in chat_models: if model.get("enabled") and model.get("model"): return { "model": model.get("model"), "base_url": model.get("base_url") or _settings.OPENAI_BASE_URL, "api_key": model.get("api_key") or _settings.OPENAI_API_KEY, } return None class Mem0Client: """Mem0 客户端 - 按用户隔离""" _instances: dict[str, Memory] = {} _persist_dir: str = "./data/mem0" async def get_memory(self, db: AsyncSession, user_id: str) -> Memory: """获取指定用户的 Mem0 实例""" cache_key = user_id if cache_key not in self._instances: self._instances[cache_key] = await self._init_memory(db, user_id) return self._instances[cache_key] async def _init_memory(self, db: AsyncSession, user_id: str) -> Memory: if not MEM0_AVAILABLE: raise RuntimeError("mem0ai 未安装,请运行: pip install mem0ai") os.makedirs(self._persist_dir, exist_ok=True) llm_config = { "model": _settings.OPENAI_MODEL, "base_url": _settings.OPENAI_BASE_URL, "api_key": _settings.OPENAI_API_KEY, } embed_config = _settings.EMBEDDING_MODEL embed_base_url = _settings.EMBEDDING_BASE_URL embed_api_key = _settings.EMBEDDING_API_KEY or _settings.OPENAI_API_KEY if db and user_id: try: user_chat = await _get_user_chat_config(db, user_id) if user_chat: llm_config = user_chat except Exception: pass try: user_embed = await _get_user_embedding_config(db, user_id) if user_embed: embed_config = user_embed["model"] embed_base_url = user_embed["base_url"] embed_api_key = user_embed["api_key"] except Exception: pass config = { "vector_store": { "provider": "chroma", "config": { "collection_name": f"jarvis_memory_{user_id}", "path": self._persist_dir, }, }, "llm": { "provider": "openai", "config": { "model": llm_config["model"], "api_key": llm_config["api_key"], "base_url": llm_config["base_url"], }, }, "embedder": { "provider": "openai", "config": { "model": embed_config, "api_key": embed_api_key, "base_url": embed_base_url, }, }, } return Memory.from_config(config) _mem0_client = Mem0Client() async def get_mem0(db: AsyncSession, user_id: str) -> Memory: """获取指定用户的 Mem0 实例""" return await _mem0_client.get_memory(db, user_id) # ———— 短期记忆: 对话历史 ———— async def load_conversation_history( db: AsyncSession, conversation_id: str, limit: int = 20, ) -> list[Message]: """加载指定对话的历史消息""" result = await db.execute( select(Message) .where(Message.conversation_id == conversation_id) .order_by(Message.created_at) .limit(limit) ) return list(result.scalars().all()) async def get_conversation_turn_count(db: AsyncSession, conversation_id: str) -> int: """获取对话轮数(用户消息数)""" result = await db.execute( select(func.count(Message.id)).where( Message.conversation_id == conversation_id, Message.role == "user", ) ) return result.scalar() or 0 # ———— 中期记忆: 对话摘要 ———— SUMMARIZE_THRESHOLD = 8 MAX_HISTORY_TURNS = 10 async def should_summarize(db: AsyncSession, conversation_id: str) -> bool: """判断当前对话是否需要摘要""" from app.models.memory import MemorySummary turn_count = await get_conversation_turn_count(db, conversation_id) result = await db.execute( select(MemorySummary) .where(MemorySummary.conversation_id == conversation_id) .order_by(desc(MemorySummary.turn_count)) .limit(1) ) latest_summary = result.scalar_one_or_none() if latest_summary: return turn_count - latest_summary.turn_count >= SUMMARIZE_THRESHOLD return turn_count >= SUMMARIZE_THRESHOLD async def generate_summary( db: AsyncSession, conversation_id: str, messages: list[Message], ) -> str: """生成对话摘要""" from app.services.llm_service import get_llm from langchain_core.messages import HumanMessage, SystemMessage history_text = "\n".join(f"[{m.role}] {m.content}" for m in messages) llm = get_llm() response = await llm.invoke( [ SystemMessage( content="你是一个记忆助手。请用简洁的中文总结以下对话的核心内容," "提取关键信息、用户偏好、待办事项等。不超过150字。" ), HumanMessage(content=history_text), ] ) return response.content.strip() async def save_summary( db: AsyncSession, user_id: str, conversation_id: str, summary_text: str, turn_count: int, ) -> Any: """保存对话摘要到数据库""" from app.models.memory import MemorySummary summary = MemorySummary( user_id=user_id, conversation_id=conversation_id, summary_text=summary_text, turn_count=turn_count, ) db.add(summary) await db.commit() await db.refresh(summary) return summary async def get_summaries( db: AsyncSession, conversation_id: str, ) -> list[Any]: """获取某对话的所有历史摘要""" from app.models.memory import MemorySummary result = await db.execute( select(MemorySummary) .where(MemorySummary.conversation_id == conversation_id) .order_by(MemorySummary.summary_at) ) return list(result.scalars().all()) # ———— 长期记忆: 基于 Mem0 ———— async def extract_user_memories( db: AsyncSession, user_id: str, conversation_id: str, messages: list[Message], ) -> list[dict]: """ 从对话中提取用户记忆并存储到 Mem0。 Mem0 会自动处理: - 事实提取 - 时间线追踪 - 矛盾解决 - 遗忘机制 """ if len(messages) < 2: return [] history_text = "\n".join(f"[{m.role}] {m.content}" for m in messages[-10:]) try: mem0 = await get_mem0(db, user_id) result = mem0.add( messages=[{"role": m.role, "content": m.content} for m in messages[-10:]], user_id=user_id, metadata={ "conversation_id": conversation_id, "source": "jarvis_memory", }, ) return result.get("results", []) except Exception as e: print(f"Mem0 extract error: {e}") return [] def _extract_memory_query_tokens(query: str) -> list[str]: normalized_query = (query or "").lower() tokens = [token for token in re.findall(r"[a-z0-9]+", normalized_query) if len(token) >= 3] for chunk in re.findall(r"[\u4e00-\u9fff]+", query or ""): stripped_chunk = chunk.strip() if len(stripped_chunk) >= 4: tokens.append(stripped_chunk) if len(stripped_chunk) > 6: tokens.extend( stripped_chunk[index:index + 4] for index in range(len(stripped_chunk) - 3) ) return list(dict.fromkeys(tokens)) async def recall_user_memories( db: AsyncSession, user_id: str, query: str, top_k: int = 5, ) -> list[dict]: """ 根据当前输入召回相关的用户记忆。 使用 Mem0 的语义搜索;如果 Mem0 不可用或失败,则回退到本地 UserMemory。 """ try: mem0 = await get_mem0(db, user_id) results = mem0.search( query=query, filters={"user_id": user_id}, limit=top_k, ) mem0_results = results.get("results", []) if mem0_results: return mem0_results except Exception as e: print(f"Mem0 search error: {e}") query_tokens = _extract_memory_query_tokens(query) statement = select(UserMemory).where(UserMemory.user_id == user_id) result = await db.execute(statement.order_by(UserMemory.importance.desc(), UserMemory.created_at.desc())) fallback_memories = list(result.scalars().all()) if _contains_hint(_normalize_query(query), MEMORY_QUERY_HINTS) or _matches_memory_query_pattern(_normalize_query(query)): return fallback_memories[:top_k] if query_tokens: matched_memories = [ memory for memory in fallback_memories if any(token in (memory.content or '').lower() for token in query_tokens) ] return matched_memories[:top_k] return [] async def _mark_memories_recalled(db: AsyncSession, memories: list[UserMemory]) -> None: recalled_at = datetime.now(UTC) updated = False for memory in memories: memory.is_recalled = True memory.recall_count = (memory.recall_count or 0) + 1 memory.last_recalled_at = recalled_at updated = True if updated: await db.commit() async def _run_tolerated_section( db: AsyncSession, section_name: str, builder, ) -> str: try: return await builder() except Exception: logger.warning( "[MemoryService] %s失败,继续构建剩余上下文", section_name, exc_info=True, ) return "" async def get_user_profile(db: AsyncSession, user_id: str) -> dict: """ 获取用户画像。 Mem0 的 profile API 会返回 static 和 dynamic facts。 """ try: mem0 = await get_mem0(db, user_id) result = mem0.history(user_id=user_id) return { "memories": result.get("results", []), "static": [], "dynamic": [], } except Exception as e: print(f"Mem0 profile error: {e}") return {"memories": [], "static": [], "dynamic": []} # ———— 记忆组装: 供 Agent 使用的上下文 ———— MEMORY_QUERY_HINTS = ( "记住", "记下", "记一下", "记着", "提醒", "偏好", "习惯", ) MEMORY_QUERY_PATTERNS = ( re.compile(r"\bremember\s+(?:that\s+)?i\b"), ) GROUNDING_QUERY_HINTS = ( "根据文档", "严格根据", "只根据", "文档内容", "grounded", "strictly based on", "based on the document", "based on the docs", "document only", "docs only", "only use the document", "only use the docs", ) AVOID_USER_MEMORY_HINTS = ( "不要结合我的个人偏好", "不要结合个人偏好", "不要结合偏好", "不要结合我的记忆", "不要结合记忆", ) def _normalize_query(text: str) -> str: return text.strip().lower() def _contains_hint(text: str, hints: tuple[str, ...]) -> bool: return any(hint in text for hint in hints) def _matches_memory_query_pattern(text: str) -> bool: return any(pattern.search(text) for pattern in MEMORY_QUERY_PATTERNS) def _should_include_user_memories(query: str) -> bool: normalized_query = _normalize_query(query) if _contains_hint(normalized_query, GROUNDING_QUERY_HINTS): return False if _contains_hint(normalized_query, AVOID_USER_MEMORY_HINTS): return False return True def _should_include_summaries(query: str) -> bool: normalized_query = _normalize_query(query) if _contains_hint(normalized_query, GROUNDING_QUERY_HINTS): return False if _contains_hint(normalized_query, MEMORY_QUERY_HINTS): return False if _matches_memory_query_pattern(normalized_query): return False return True async def _build_user_memory_section( db: AsyncSession, user_id: str, current_query: str, ) -> str: memories = await recall_user_memories(db, user_id, current_query, top_k=5) if not memories: return "" lines = [] recalled_user_memories: list[UserMemory] = [] for memory in memories: if isinstance(memory, UserMemory): memory_text = memory.content memory_type = memory.memory_type recalled_user_memories.append(memory) else: memory_text = memory.get("memory", memory.get("text", "")) memory_type = memory.get("memory_type") if not memory_text: continue if memory_type: lines.append(f" [{memory_type}] {memory_text}") else: lines.append(f" - {memory_text}") if not lines: return "" if recalled_user_memories: await _mark_memories_recalled(db, recalled_user_memories) return "【用户记忆】\n" + "\n".join(lines) async def _build_summary_section(db: AsyncSession, conversation_id: str) -> str: summaries = await get_summaries(db, conversation_id) if not summaries: return "" recent = summaries[-2:] lines = [f"[对话摘要{i + 1}] {summary.summary_text}" for i, summary in enumerate(recent)] return "【之前对话摘要】\n" + "\n".join(lines) async def _build_brain_section( db: AsyncSession, user_id: str, current_query: str, ) -> str: brain_memories = await BrainService(db).recall_memories(user_id, current_query, top_k=3) if not brain_memories: return "" lines = [f"- {memory.title}: {memory.content}" for memory in brain_memories] return "【知识大脑】\n" + "\n".join(lines) async def build_memory_context( db: AsyncSession, user_id: str, conversation_id: str, current_query: str, ) -> str: """ 构建完整的记忆上下文字符串, 供注入到 Agent system prompt 中使用。 """ parts: list[str] = [] if _should_include_user_memories(current_query): user_memory_section = await _run_tolerated_section( db, "用户记忆召回", lambda: _build_user_memory_section(db, user_id, current_query), ) if user_memory_section: parts.append(user_memory_section) if _should_include_summaries(current_query): summary_section = await _run_tolerated_section( db, "对话摘要加载", lambda: _build_summary_section(db, conversation_id), ) if summary_section: parts.append(summary_section) brain_section = await _run_tolerated_section( db, "知识大脑召回", lambda: _build_brain_section(db, user_id, current_query), ) if brain_section: parts.append(brain_section) if not parts: return "" return "\n\n".join(parts) async def try_auto_summarize( db: AsyncSession, user_id: str, conversation_id: str, ) -> bool: """ 检查是否需要摘要,如果需要则生成并保存。 同时将对话内容存入 Mem0 进行记忆提取。 """ if not await should_summarize(db, conversation_id): return False messages = await load_conversation_history(db, conversation_id, limit=30) if len(messages) < 3: return False try: summary_text = await generate_summary(db, conversation_id, messages) turn_count = await get_conversation_turn_count(db, conversation_id) await save_summary(db, user_id, conversation_id, summary_text, turn_count) await extract_user_memories(db, user_id, conversation_id, messages) return True except Exception as e: print(f"Auto summarize error: {e}") return False async def forget_memory(db: AsyncSession, user_id: str, memory_id: str) -> bool: """ 主动遗忘某条记忆。 """ try: mem0 = await get_mem0(db, user_id) mem0.delete(memory_id, user_id=user_id) return True except Exception as e: print(f"Mem0 delete error: {e}") return False async def update_memory( db: AsyncSession, user_id: str, memory_id: str, content: str, ) -> bool: """ 更新某条记忆。Mem0 会自动处理矛盾检测。 """ try: mem0 = await get_mem0(db, user_id) mem0.update(memory_id, content, user_id=user_id) return True except Exception as e: print(f"Mem0 update error: {e}") return False