Align the L3 graph, agent service, and sync tool shims on one canonical continuity contract so clarification resumes and persisted snapshots behave consistently. Add targeted regressions and hardening notes covering system-message coalescing, async bridge usage, and continuity rehydration.
637 lines
18 KiB
Python
637 lines
18 KiB
Python
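"""
Jarvis memory system (built on Mem0).

Three memory tiers: short-term (conversation history) -> mid-term (summaries) -> long-term (user profile).
Mem0 is used underneath for fact extraction, timelines, contradiction resolution, and forgetting.
"""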
|
|
|
|
import logging
|
|
import os
|
|
import re
|
|
from datetime import UTC, datetime
|
|
from typing import Optional, Any
|
|
from sqlalchemy import select, desc, func
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from app.models.conversation import Conversation, Message
|
|
from app.models.memory import UserMemory
|
|
from app.models.user import User
|
|
from app.services.brain_service import BrainService
|
|
from app.config import settings as _settings
|
|
|
|
# mem0 is an optional dependency: remember whether the import worked so the
# rest of the module can check the flag instead of failing at import time.
try:
    from mem0 import Memory
except ImportError:
    # Keep the name bound so annotations that mention Memory still evaluate.
    Memory = None
    MEM0_AVAILABLE = False
else:
    MEM0_AVAILABLE = True


# Module-level logger shared by every helper in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def _get_user_embedding_config(db: AsyncSession, user_id: str) -> dict | None:
    """Return the user's first usable embedding-model config, or ``None``.

    Loads the ``User`` row for ``user_id`` and scans the ``"embedding"`` list
    of its ``llm_config``. The first entry that is enabled and names a model
    wins. A missing ``base_url`` or ``api_key`` falls back to the application
    settings.

    Returns:
        A dict with ``model``, ``base_url`` and ``api_key`` keys, or ``None``
        when the user, their config, or a usable entry is missing.
    """
    lookup = await db.execute(select(User).where(User.id == user_id))
    account = lookup.scalar_one_or_none()
    if not account or not account.llm_config:
        return None

    chosen = next(
        (
            entry
            for entry in account.llm_config.get("embedding", [])
            if entry.get("enabled") and entry.get("model")
        ),
        None,
    )
    if chosen is None:
        return None

    return {
        "model": chosen.get("model"),
        "base_url": chosen.get("base_url") or _settings.EMBEDDING_BASE_URL,
        # Key precedence: per-model key, then embedding key, then OpenAI key.
        "api_key": chosen.get("api_key")
        or _settings.EMBEDDING_API_KEY
        or _settings.OPENAI_API_KEY,
    }
|
|
|
|
|
|
async def _get_user_chat_config(db: AsyncSession, user_id: str) -> dict | None:
    """Return the user's first usable chat-model config, or ``None``.

    Loads the ``User`` row for ``user_id`` and scans the ``"chat"`` list of
    its ``llm_config``. The first entry that is enabled and names a model
    wins. A missing ``base_url`` or ``api_key`` falls back to the OpenAI
    values in the application settings.

    Returns:
        A dict with ``model``, ``base_url`` and ``api_key`` keys, or ``None``
        when the user, their config, or a usable entry is missing.
    """
    lookup = await db.execute(select(User).where(User.id == user_id))
    account = lookup.scalar_one_or_none()
    if not account or not account.llm_config:
        return None

    chosen = next(
        (
            entry
            for entry in account.llm_config.get("chat", [])
            if entry.get("enabled") and entry.get("model")
        ),
        None,
    )
    if chosen is None:
        return None

    return {
        "model": chosen.get("model"),
        "base_url": chosen.get("base_url") or _settings.OPENAI_BASE_URL,
        "api_key": chosen.get("api_key") or _settings.OPENAI_API_KEY,
    }
|
|
|
|
|
|
class Mem0Client:
    """Factory and cache for per-user Mem0 ``Memory`` instances.

    Each user gets their own Chroma collection (``jarvis_memory_<user_id>``)
    so memories are isolated per user.
    """

    # Cache of initialised Memory objects keyed by user id. It is defined on
    # the class and mutated in place, so all Mem0Client instances share it.
    _instances: dict[str, Memory] = {}
    # Directory handed to the Chroma vector store; created on first use.
    _persist_dir: str = "./data/mem0"

    async def get_memory(self, db: AsyncSession, user_id: str) -> Memory:
        """Return the cached Mem0 instance for ``user_id``, creating it if needed.

        NOTE: an instance is only built when the user id is not yet cached,
        so later changes to the user's model config are not picked up by an
        already cached instance.
        """
        cache_key = user_id

        if cache_key not in self._instances:
            self._instances[cache_key] = await self._init_memory(db, user_id)

        return self._instances[cache_key]

    async def _init_memory(self, db: AsyncSession, user_id: str) -> Memory:
        """Build a Mem0 ``Memory`` configured for ``user_id``.

        Starts from the application-wide LLM and embedding settings and
        overrides them with the user's own chat/embedding config when one can
        be loaded. Loading the overrides is best-effort: a failure is logged
        and the application defaults are kept.

        Raises:
            RuntimeError: if the ``mem0`` package is not installed.
        """
        if not MEM0_AVAILABLE:
            raise RuntimeError("mem0ai 未安装,请运行: pip install mem0ai")

        os.makedirs(self._persist_dir, exist_ok=True)

        # Application-level defaults.
        llm_config = {
            "model": _settings.OPENAI_MODEL,
            "base_url": _settings.OPENAI_BASE_URL,
            "api_key": _settings.OPENAI_API_KEY,
        }
        embed_config = _settings.EMBEDDING_MODEL
        embed_base_url = _settings.EMBEDDING_BASE_URL
        embed_api_key = _settings.EMBEDDING_API_KEY or _settings.OPENAI_API_KEY

        if db and user_id:
            try:
                user_chat = await _get_user_chat_config(db, user_id)
                if user_chat:
                    llm_config = user_chat
            except Exception:
                # Best-effort override: keep the defaults, but leave a trace
                # instead of hiding the failure.
                logger.warning(
                    "[MemoryService] failed to load chat config for user %s; using defaults",
                    user_id,
                    exc_info=True,
                )

            try:
                user_embed = await _get_user_embedding_config(db, user_id)
                if user_embed:
                    embed_config = user_embed["model"]
                    embed_base_url = user_embed["base_url"]
                    embed_api_key = user_embed["api_key"]
            except Exception:
                logger.warning(
                    "[MemoryService] failed to load embedding config for user %s; using defaults",
                    user_id,
                    exc_info=True,
                )

        config = {
            "vector_store": {
                "provider": "chroma",
                "config": {
                    "collection_name": f"jarvis_memory_{user_id}",
                    "path": self._persist_dir,
                },
            },
            "llm": {
                "provider": "openai",
                "config": {
                    "model": llm_config["model"],
                    "api_key": llm_config["api_key"],
                    "base_url": llm_config["base_url"],
                },
            },
            "embedder": {
                "provider": "openai",
                "config": {
                    "model": embed_config,
                    "api_key": embed_api_key,
                    "base_url": embed_base_url,
                },
            },
        }

        return Memory.from_config(config)
|
|
|
|
|
|
# Single module-wide client; created once when the module is imported.
_mem0_client = Mem0Client()


async def get_mem0(db: AsyncSession, user_id: str) -> Memory:
    """Return the Mem0 instance for ``user_id`` via the module-level client."""
    memory = await _mem0_client.get_memory(db, user_id)
    return memory
|
|
|
|
|
|
# ———— 短期记忆: 对话历史 ————
|
|
|
|
|
|
async def load_conversation_history(
    db: AsyncSession,
    conversation_id: str,
    limit: int = 20,
) -> list[Message]:
    """Load the most recent messages of a conversation in chronological order.

    Args:
        db: Async database session.
        conversation_id: Conversation whose messages are loaded.
        limit: Maximum number of messages to return.

    Returns:
        Up to ``limit`` messages, oldest first. When the conversation holds
        more than ``limit`` messages, the newest ``limit`` are returned.
    """
    # Sort newest-first so LIMIT keeps the latest messages; sorting ascending
    # here would keep the oldest ones once a conversation outgrows `limit`.
    result = await db.execute(
        select(Message)
        .where(Message.conversation_id == conversation_id)
        .order_by(desc(Message.created_at))
        .limit(limit)
    )
    newest_first = list(result.scalars().all())
    # Callers build transcripts from this list, so hand it back oldest-first.
    newest_first.reverse()
    return newest_first
|
|
|
|
|
|
async def get_conversation_turn_count(db: AsyncSession, conversation_id: str) -> int:
    """Count the conversation's turns, defined as messages with role ``"user"``."""
    count_query = select(func.count(Message.id)).where(
        Message.conversation_id == conversation_id,
        Message.role == "user",
    )
    counted = (await db.execute(count_query)).scalar()
    # A missing/zero scalar is reported as 0.
    return counted if counted else 0
|
|
|
|
|
|
# ———— 中期记忆: 对话摘要 ————
|
|
|
|
# SUMMARIZE_THRESHOLD: number of user turns since the latest summary (or since
# the start of the conversation) that triggers a new summary.
# MAX_HISTORY_TURNS: not referenced in the visible part of this module;
# presumably a history cap used by callers -- TODO confirm.
SUMMARIZE_THRESHOLD, MAX_HISTORY_TURNS = 8, 10
|
|
|
|
|
|
async def should_summarize(db: AsyncSession, conversation_id: str) -> bool:
    """Tell whether the conversation has enough new turns to be summarised.

    Compares the current user-turn count with the ``turn_count`` stored on
    the summary that has the highest ``turn_count``. Without any summary the
    comparison is against zero.
    """
    from app.models.memory import MemorySummary

    user_turns = await get_conversation_turn_count(db, conversation_id)

    latest_query = (
        select(MemorySummary)
        .where(MemorySummary.conversation_id == conversation_id)
        .order_by(desc(MemorySummary.turn_count))
        .limit(1)
    )
    last_summary = (await db.execute(latest_query)).scalar_one_or_none()

    if not last_summary:
        return user_turns >= SUMMARIZE_THRESHOLD
    return user_turns - last_summary.turn_count >= SUMMARIZE_THRESHOLD
|
|
|
|
|
|
async def generate_summary(
    db: AsyncSession,
    conversation_id: str,
    messages: list[Message],
) -> str:
    """Ask the LLM for a short Chinese summary of ``messages``.

    Args:
        db: Unused here; kept for interface compatibility.
        conversation_id: Unused here; kept for interface compatibility.
        messages: Messages rendered as ``[role] content`` lines for the prompt.

    Returns:
        The stripped text content of the model response.
    """
    import inspect

    from app.services.llm_service import get_llm
    from langchain_core.messages import HumanMessage, SystemMessage

    history_text = "\n".join(f"[{m.role}] {m.content}" for m in messages)
    llm = get_llm()
    response = llm.invoke(
        [
            SystemMessage(
                content="你是一个记忆助手。请用简洁的中文总结以下对话的核心内容,"
                "提取关键信息、用户偏好、待办事项等。不超过150字。"
            ),
            HumanMessage(content=history_text),
        ]
    )
    # NOTE(review): `invoke` is synchronous on LangChain runnables, in which
    # case awaiting its result raises TypeError. Await only when the object
    # returned by get_llm() really hands back an awaitable. If it is a plain
    # LangChain model, switching to `ainvoke` would also avoid blocking the
    # event loop -- confirm against get_llm().
    if inspect.isawaitable(response):
        response = await response
    return response.content.strip()
|
|
|
|
|
|
async def save_summary(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
    summary_text: str,
    turn_count: int,
) -> Any:
    """Persist a conversation summary and return the refreshed row.

    The row is added, committed, and then refreshed from the database before
    being returned.
    """
    from app.models.memory import MemorySummary

    fields = {
        "user_id": user_id,
        "conversation_id": conversation_id,
        "summary_text": summary_text,
        "turn_count": turn_count,
    }
    record = MemorySummary(**fields)
    db.add(record)
    await db.commit()
    await db.refresh(record)
    return record
|
|
|
|
|
|
async def get_summaries(
    db: AsyncSession,
    conversation_id: str,
) -> list[Any]:
    """Return every stored summary of the conversation, ordered by ``summary_at``."""
    from app.models.memory import MemorySummary

    ordered_query = (
        select(MemorySummary)
        .where(MemorySummary.conversation_id == conversation_id)
        .order_by(MemorySummary.summary_at)
    )
    rows = await db.execute(ordered_query)
    return list(rows.scalars().all())
|
|
|
|
|
|
# ———— 长期记忆: 基于 Mem0 ————
|
|
|
|
|
|
async def extract_user_memories(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
    messages: list[Message],
) -> list[dict]:
    """Feed the tail of a conversation to Mem0 for memory extraction.

    Only the last 10 messages are sent. The module description credits Mem0
    with fact extraction, timeline tracking, contradiction resolution and
    forgetting.

    Args:
        db: Async database session, used to resolve the user's Mem0 instance.
        user_id: Owner of the memories.
        conversation_id: Stored in the memory metadata.
        messages: Conversation messages; fewer than two is treated as
            nothing to extract.

    Returns:
        The ``"results"`` list from Mem0's response, or ``[]`` when there is
        too little input or Mem0 fails. Failures are logged, never raised.
    """
    if len(messages) < 2:
        return []

    try:
        mem0 = await get_mem0(db, user_id)
        result = mem0.add(
            messages=[{"role": m.role, "content": m.content} for m in messages[-10:]],
            user_id=user_id,
            metadata={
                "conversation_id": conversation_id,
                "source": "jarvis_memory",
            },
        )
        return result.get("results", [])
    except Exception as e:
        # Extraction is best-effort; log through the module logger so the
        # failure lands in the same place as the other memory warnings.
        logger.warning("Mem0 extract error: %s", e, exc_info=True)
        return []
|
|
|
|
|
|
def _extract_memory_query_tokens(query: str) -> list[str]:
|
|
normalized_query = (query or "").lower()
|
|
tokens = [token for token in re.findall(r"[a-z0-9]+", normalized_query) if len(token) >= 3]
|
|
|
|
for chunk in re.findall(r"[\u4e00-\u9fff]+", query or ""):
|
|
stripped_chunk = chunk.strip()
|
|
if len(stripped_chunk) >= 4:
|
|
tokens.append(stripped_chunk)
|
|
if len(stripped_chunk) > 6:
|
|
tokens.extend(
|
|
stripped_chunk[index:index + 4]
|
|
for index in range(len(stripped_chunk) - 3)
|
|
)
|
|
|
|
return list(dict.fromkeys(tokens))
|
|
|
|
|
|
async def recall_user_memories(
    db: AsyncSession,
    user_id: str,
    query: str,
    top_k: int = 5,
) -> list[Any]:
    """Recall memories relevant to ``query`` for ``user_id``.

    Mem0 search is tried first. If it returns nothing or fails, the local
    ``UserMemory`` table is used instead:

    * If the query contains a memory hint or matches a memory pattern, the
      top ``top_k`` rows by importance, then recency, are returned.
    * Otherwise rows whose content contains any query token are returned.

    Returns:
        Mem0 result dicts, or ``UserMemory`` rows from the fallback, at most
        ``top_k`` items. Callers must handle both shapes.
    """
    try:
        mem0 = await get_mem0(db, user_id)
        results = mem0.search(
            query=query,
            filters={"user_id": user_id},
            limit=top_k,
        )
        mem0_results = results.get("results", [])
        if mem0_results:
            return mem0_results
    except Exception as e:
        # Fall through to the local table; keep a trace of why Mem0 was skipped.
        logger.warning("Mem0 search error: %s", e, exc_info=True)

    # Normalise once and tolerate a missing query, as the tokenizer does.
    normalized_query = _normalize_query(query or "")
    wants_top_memories = _contains_hint(
        normalized_query, MEMORY_QUERY_HINTS
    ) or _matches_memory_query_pattern(normalized_query)
    query_tokens = _extract_memory_query_tokens(query)

    # Nothing can match: skip loading the user's whole memory table.
    if not wants_top_memories and not query_tokens:
        return []

    statement = select(UserMemory).where(UserMemory.user_id == user_id)
    result = await db.execute(
        statement.order_by(UserMemory.importance.desc(), UserMemory.created_at.desc())
    )
    fallback_memories = list(result.scalars().all())

    if wants_top_memories:
        return fallback_memories[:top_k]

    matched_memories = [
        memory
        for memory in fallback_memories
        if any(token in (memory.content or "").lower() for token in query_tokens)
    ]
    return matched_memories[:top_k]
|
|
|
|
|
|
async def _mark_memories_recalled(db: AsyncSession, memories: list[UserMemory]) -> None:
    """Stamp each memory as recalled and commit if at least one was touched.

    Every memory gets ``is_recalled`` set, ``recall_count`` incremented
    (a missing count is treated as 0) and ``last_recalled_at`` set to one
    shared UTC timestamp.
    """
    stamp = datetime.now(UTC)
    touched = 0
    for item in memories:
        item.is_recalled = True
        item.recall_count = (item.recall_count or 0) + 1
        item.last_recalled_at = stamp
        touched += 1

    if not touched:
        return
    await db.commit()
|
|
|
|
|
|
async def _run_tolerated_section(
    db: AsyncSession,
    section_name: str,
    builder,
) -> str:
    """Await ``builder()`` and return its text, or ``""`` if it raises.

    A failure is logged as a warning with the traceback and ``section_name``,
    so one broken section does not abort the rest of the context.
    ``db`` is accepted but not used by this helper.
    """
    try:
        section_text = await builder()
    except Exception:
        logger.warning(
            "[MemoryService] %s失败,继续构建剩余上下文",
            section_name,
            exc_info=True,
        )
        return ""
    else:
        return section_text
|
|
|
|
|
|
async def get_user_profile(db: AsyncSession, user_id: str) -> dict:
    """Return the user's stored Mem0 memories in a profile-shaped dict.

    Returns:
        ``{"memories": [...], "static": [], "dynamic": []}``. ``static`` and
        ``dynamic`` are always empty here. On any Mem0 failure the error is
        logged and all three lists are empty.
    """
    try:
        mem0 = await get_mem0(db, user_id)
        # `Memory.history` is keyed by a single memory id, so calling it with
        # `user_id=` could never succeed; listing a user's memories is
        # `get_all`.
        result = mem0.get_all(user_id=user_id)
        # NOTE(review): depending on the mem0 version the payload is either a
        # {"results": [...]} dict or a bare list -- accept both.
        if isinstance(result, dict):
            memories = result.get("results", [])
        else:
            memories = list(result or [])
        return {
            "memories": memories,
            "static": [],
            "dynamic": [],
        }
    except Exception as e:
        logger.warning("Mem0 profile error: %s", e, exc_info=True)
        return {"memories": [], "static": [], "dynamic": []}
|
|
|
|
|
|
# ———— 记忆组装: 供 Agent 使用的上下文 ————
|
|
|
|
# Substrings marking a query as being about remembering something
# (remember / note down / remind / preference / habit).
MEMORY_QUERY_HINTS = ("记住", "记下", "记一下", "记着", "提醒", "偏好", "习惯")

# Regexes for English "remember (that) I ..." phrasing; applied to the
# normalised (lower-cased) query.
MEMORY_QUERY_PATTERNS = (re.compile(r"\bremember\s+(?:that\s+)?i\b"),)

# Substrings marking a query that must be answered from documents only.
GROUNDING_QUERY_HINTS = (
    "根据文档", "严格根据", "只根据", "文档内容",
    "grounded", "strictly based on",
    "based on the document", "based on the docs",
    "document only", "docs only",
    "only use the document", "only use the docs",
)

# Substrings with which the user opts out of personal memories/preferences.
AVOID_USER_MEMORY_HINTS = (
    "不要结合我的个人偏好", "不要结合个人偏好", "不要结合偏好",
    "不要结合我的记忆", "不要结合记忆",
)
|
|
|
|
|
|
def _normalize_query(text: str) -> str:
|
|
return text.strip().lower()
|
|
|
|
|
|
def _contains_hint(text: str, hints: tuple[str, ...]) -> bool:
|
|
return any(hint in text for hint in hints)
|
|
|
|
|
|
def _matches_memory_query_pattern(text: str) -> bool:
    """Return ``True`` when any regex in ``MEMORY_QUERY_PATTERNS`` is found in ``text``."""
    for compiled in MEMORY_QUERY_PATTERNS:
        if compiled.search(text):
            return True
    return False
|
|
|
|
|
|
def _should_include_user_memories(query: str) -> bool:
    """Decide whether user memories belong in the context for ``query``.

    They are left out when the query asks for a document-grounded answer or
    explicitly opts out of personal memories/preferences.
    """
    cleaned = _normalize_query(query)
    blocked = _contains_hint(cleaned, GROUNDING_QUERY_HINTS) or _contains_hint(
        cleaned, AVOID_USER_MEMORY_HINTS
    )
    return not blocked
|
|
|
|
|
|
def _should_include_summaries(query: str) -> bool:
    """Decide whether conversation summaries belong in the context for ``query``.

    They are left out for document-grounded queries and for queries that are
    about remembering something (hint substring or regex match).
    """
    cleaned = _normalize_query(query)
    blocked = (
        _contains_hint(cleaned, GROUNDING_QUERY_HINTS)
        or _contains_hint(cleaned, MEMORY_QUERY_HINTS)
        or _matches_memory_query_pattern(cleaned)
    )
    return not blocked
|
|
|
|
|
|
async def _build_user_memory_section(
    db: AsyncSession,
    user_id: str,
    current_query: str,
) -> str:
    """Render up to five recalled memories as the user-memory context section.

    Recalled items may be ``UserMemory`` rows or Mem0 result dicts. For dicts
    the text is read from ``"memory"``, falling back to ``"text"``. Items
    without text are skipped. When at least one line was rendered, the
    recalled ``UserMemory`` rows are marked as recalled.

    Returns:
        The section text, or ``""`` when nothing could be rendered.
    """
    recalled = await recall_user_memories(db, user_id, current_query, top_k=5)
    if not recalled:
        return ""

    rendered: list[str] = []
    orm_hits: list[UserMemory] = []
    for item in recalled:
        if isinstance(item, UserMemory):
            text, kind = item.content, item.memory_type
            orm_hits.append(item)
        else:
            text = item.get("memory", item.get("text", ""))
            kind = item.get("memory_type")

        if text:
            # Typed memories carry a "[type]" tag; untyped ones a dash bullet.
            rendered.append(f" [{kind}] {text}" if kind else f" - {text}")

    if not rendered:
        return ""

    if orm_hits:
        await _mark_memories_recalled(db, orm_hits)
    return "【用户记忆】\n" + "\n".join(rendered)
|
|
|
|
|
|
async def _build_summary_section(db: AsyncSession, conversation_id: str) -> str:
    """Render the last two stored summaries as a context section.

    Returns ``""`` when the conversation has no summaries.
    """
    stored = await get_summaries(db, conversation_id)
    if not stored:
        return ""

    rendered = []
    for position, entry in enumerate(stored[-2:], start=1):
        rendered.append(f"[对话摘要{position}] {entry.summary_text}")
    return "【之前对话摘要】\n" + "\n".join(rendered)
|
|
|
|
|
|
async def _build_brain_section(
    db: AsyncSession,
    user_id: str,
    current_query: str,
) -> str:
    """Render up to three ``BrainService`` recall hits as a context section.

    Returns ``""`` when the recall yields nothing.
    """
    service = BrainService(db)
    hits = await service.recall_memories(user_id, current_query, top_k=3)
    if not hits:
        return ""

    bullet_lines = []
    for hit in hits:
        bullet_lines.append(f"- {hit.title}: {hit.content}")
    return "【知识大脑】\n" + "\n".join(bullet_lines)
|
|
|
|
|
|
async def build_memory_context(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
    current_query: str,
) -> str:
    """Assemble the memory context string injected into the agent system prompt.

    Sections are built in a fixed order: user memories, conversation
    summaries, knowledge brain. The first two are gated by the query; the
    brain section is always attempted. Each section is built through
    ``_run_tolerated_section``, so a failing section is logged and skipped.

    Returns:
        The non-empty sections joined by blank lines, or ``""`` if none.
    """
    # (gate, label used in the failure log, coroutine factory); a gate of
    # None means the section is always attempted.
    section_plan = (
        (
            _should_include_user_memories,
            "用户记忆召回",
            lambda: _build_user_memory_section(db, user_id, current_query),
        ),
        (
            _should_include_summaries,
            "对话摘要加载",
            lambda: _build_summary_section(db, conversation_id),
        ),
        (
            None,
            "知识大脑召回",
            lambda: _build_brain_section(db, user_id, current_query),
        ),
    )

    collected: list[str] = []
    for gate, label, factory in section_plan:
        if gate is not None and not gate(current_query):
            continue
        section_text = await _run_tolerated_section(db, label, factory)
        if section_text:
            collected.append(section_text)

    return "\n\n".join(collected) if collected else ""
|
|
|
|
|
|
async def try_auto_summarize(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
) -> bool:
    """Summarise the conversation when due and extract memories from it.

    Steps: check ``should_summarize``; load up to 30 messages (fewer than 3
    aborts); generate and save the summary with the current turn count; pass
    the same messages to ``extract_user_memories``.

    Returns:
        ``True`` when a summary was generated and saved; ``False`` when no
        summary was due, there were too few messages, or a step failed.
        Failures are logged, never raised.
    """
    if not await should_summarize(db, conversation_id):
        return False

    messages = await load_conversation_history(db, conversation_id, limit=30)
    if len(messages) < 3:
        return False

    try:
        summary_text = await generate_summary(db, conversation_id, messages)
        turn_count = await get_conversation_turn_count(db, conversation_id)
        await save_summary(db, user_id, conversation_id, summary_text, turn_count)

        await extract_user_memories(db, user_id, conversation_id, messages)
        return True
    except Exception as e:
        # Summarisation is a background nicety; report through the module
        # logger rather than stdout so failures are visible in the logs.
        logger.warning("Auto summarize error: %s", e, exc_info=True)
        return False
|
|
|
|
|
|
async def forget_memory(db: AsyncSession, user_id: str, memory_id: str) -> bool:
    """Delete one memory from the user's Mem0 store.

    Returns:
        ``True`` when Mem0 accepted the delete, ``False`` on any failure.
        Failures are logged, never raised.
    """
    try:
        mem0 = await get_mem0(db, user_id)
        # `Memory.delete` is addressed by memory id only; passing `user_id=`
        # made every call fail with a TypeError. The instance returned by
        # get_mem0 is already bound to this user's own collection.
        mem0.delete(memory_id)
        return True
    except Exception as e:
        logger.warning("Mem0 delete error: %s", e, exc_info=True)
        return False
|
|
|
|
|
|
async def update_memory(
    db: AsyncSession,
    user_id: str,
    memory_id: str,
    content: str,
) -> bool:
    """Replace the text of one memory in the user's Mem0 store.

    Returns:
        ``True`` when Mem0 accepted the update, ``False`` on any failure.
        Failures are logged, never raised.
    """
    try:
        mem0 = await get_mem0(db, user_id)
        # `Memory.update` takes the memory id and the new text; passing
        # `user_id=` made every call fail with a TypeError. The instance
        # returned by get_mem0 is already bound to this user's own collection.
        mem0.update(memory_id, content)
        return True
    except Exception as e:
        logger.warning("Mem0 update error: %s", e, exc_info=True)
        return False
|