# File-viewer header residue: 440 lines, 13 KiB, Python.
"""
|
|
Jarvis 记忆系统 (基于 Mem0)
|
|
三层记忆: 短期(对话历史) → 中期(摘要) → 长期(用户画像)
|
|
底层使用 Mem0 实现事实提取、时间线、矛盾解决和遗忘机制
|
|
"""
|
|
|
|
import os
|
|
from datetime import datetime
|
|
from typing import Optional, Any
|
|
from sqlalchemy import select, desc, func
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from app.models.conversation import Conversation, Message
|
|
from app.models.user import User
|
|
from app.services.brain_service import BrainService
|
|
from app.config import settings as _settings
|
|
|
|
# Optional dependency: probe for Mem0 once at import time and record the
# outcome so that later code can fail with a clear message instead of a
# NameError.
try:
    from mem0 import Memory
except ImportError:
    # Keep the name bound so annotations that mention `Memory` still resolve.
    Memory = None
    MEM0_AVAILABLE = False
else:
    MEM0_AVAILABLE = True
|
|
|
|
|
|
async def _get_user_embedding_config(db: AsyncSession, user_id: str) -> dict | None:
    """Return the user's first usable embedding-model configuration.

    Loads the ``User`` row for ``user_id`` and scans the ``"embedding"`` list
    of its ``llm_config``. The first entry that is enabled and names a model
    wins. A missing ``base_url`` falls back to the application setting, and a
    missing ``api_key`` falls back to the embedding key and then the OpenAI key.

    Returns:
        A dict with ``model``, ``base_url`` and ``api_key`` keys, or ``None``
        when the user is missing, has no ``llm_config``, or no entry qualifies.
    """
    lookup = await db.execute(select(User).where(User.id == user_id))
    account = lookup.scalar_one_or_none()
    if not (account and account.llm_config):
        return None

    candidates = account.llm_config.get("embedding", [])
    chosen = next(
        (entry for entry in candidates if entry.get("enabled") and entry.get("model")),
        None,
    )
    if chosen is None:
        return None

    base_url = chosen.get("base_url") or _settings.EMBEDDING_BASE_URL
    api_key = (
        chosen.get("api_key")
        or _settings.EMBEDDING_API_KEY
        or _settings.OPENAI_API_KEY
    )
    return {
        "model": chosen.get("model"),
        "base_url": base_url,
        "api_key": api_key,
    }
|
|
|
|
|
|
async def _get_user_chat_config(db: AsyncSession, user_id: str) -> dict | None:
    """Return the user's first usable chat-model configuration.

    Loads the ``User`` row for ``user_id`` and walks the ``"chat"`` list of
    its ``llm_config`` in order. The first entry that is enabled and names a
    model is returned, with a missing ``base_url`` / ``api_key`` filled in
    from the application's OpenAI settings.

    Returns:
        A dict with ``model``, ``base_url`` and ``api_key`` keys, or ``None``
        when the user is missing, has no ``llm_config``, or no entry qualifies.
    """
    query = select(User).where(User.id == user_id)
    account = (await db.execute(query)).scalar_one_or_none()
    if not (account and account.llm_config):
        return None

    for entry in account.llm_config.get("chat", []):
        if not (entry.get("enabled") and entry.get("model")):
            continue
        resolved = {"model": entry.get("model")}
        resolved["base_url"] = entry.get("base_url") or _settings.OPENAI_BASE_URL
        resolved["api_key"] = entry.get("api_key") or _settings.OPENAI_API_KEY
        return resolved

    return None
|
|
|
|
|
|
class Mem0Client:
    """Factory and cache for per-user Mem0 ``Memory`` instances.

    Every user gets a dedicated Chroma collection named
    ``jarvis_memory_<user_id>`` stored under ``_persist_dir``. Instances are
    built lazily on first request and then reused; a cached instance is never
    invalidated here, so later changes to a user's model settings are not
    picked up until the process restarts.
    """

    # Cache of initialised Memory objects keyed by user id. It lives on the
    # class, so it is shared by every Mem0Client instance.
    _instances: dict[str, Memory] = {}
    # Directory handed to Chroma as its on-disk storage path.
    _persist_dir: str = "./data/mem0"

    async def get_memory(self, db: AsyncSession, user_id: str) -> Memory:
        """Return the cached ``Memory`` for ``user_id``, creating it on first use.

        Args:
            db: Session used to read the user's model configuration when a
                new instance has to be built.
            user_id: Identifier of the user; also used as the cache key.

        Raises:
            RuntimeError: If an instance must be created but ``mem0ai`` is
                not installed.
        """
        cache_key = user_id

        if cache_key not in self._instances:
            self._instances[cache_key] = await self._init_memory(db, user_id)

        return self._instances[cache_key]

    async def _init_memory(self, db: AsyncSession, user_id: str) -> Memory:
        """Build a new ``Memory`` configured for ``user_id``.

        Application-wide settings are the defaults. When the user has their
        own chat / embedding configuration it takes precedence. Looking up the
        user's configuration is best-effort: a failure is reported and the
        defaults are kept.

        Raises:
            RuntimeError: If ``mem0ai`` is not installed.
        """
        if not MEM0_AVAILABLE:
            raise RuntimeError("mem0ai 未安装,请运行: pip install mem0ai")

        os.makedirs(self._persist_dir, exist_ok=True)

        # Defaults taken from application settings.
        llm_config = {
            "model": _settings.OPENAI_MODEL,
            "base_url": _settings.OPENAI_BASE_URL,
            "api_key": _settings.OPENAI_API_KEY,
        }
        embed_config = _settings.EMBEDDING_MODEL
        embed_base_url = _settings.EMBEDDING_BASE_URL
        embed_api_key = _settings.EMBEDDING_API_KEY or _settings.OPENAI_API_KEY

        if db and user_id:
            # Per-user overrides. Errors are reported instead of being
            # silently dropped, but never prevent falling back to defaults.
            try:
                user_chat = await _get_user_chat_config(db, user_id)
            except Exception as exc:
                print(f"Mem0 chat config lookup error: {exc}")
            else:
                if user_chat:
                    llm_config = user_chat

            try:
                user_embed = await _get_user_embedding_config(db, user_id)
            except Exception as exc:
                print(f"Mem0 embedding config lookup error: {exc}")
            else:
                if user_embed:
                    embed_config = user_embed["model"]
                    embed_base_url = user_embed["base_url"]
                    embed_api_key = user_embed["api_key"]

        config = {
            "vector_store": {
                "provider": "chroma",
                "config": {
                    # One collection per user keeps users' memories separate.
                    "collection_name": f"jarvis_memory_{user_id}",
                    "path": self._persist_dir,
                },
            },
            "llm": {
                "provider": "openai",
                "config": {
                    "model": llm_config["model"],
                    "api_key": llm_config["api_key"],
                    "base_url": llm_config["base_url"],
                },
            },
            "embedder": {
                "provider": "openai",
                "config": {
                    "model": embed_config,
                    "api_key": embed_api_key,
                    "base_url": embed_base_url,
                },
            },
        }

        return Memory.from_config(config)
|
|
|
|
|
|
# Module-level client created once at import time; get_mem0() delegates to it.
_mem0_client = Mem0Client()
|
|
|
|
|
|
async def get_mem0(db: AsyncSession, user_id: str) -> Memory:
    """Return the Mem0 ``Memory`` instance for ``user_id``.

    Thin wrapper around the module-level ``Mem0Client``; ``db`` is forwarded
    so the client can read the user's configuration if it has to build a new
    instance.
    """
    memory = await _mem0_client.get_memory(db, user_id)
    return memory
|
|
|
|
|
|
# ———— 短期记忆: 对话历史 ————
|
|
|
|
|
|
async def load_conversation_history(
    db: AsyncSession,
    conversation_id: str,
    limit: int = 20,
) -> list[Message]:
    """Load the most recent messages of a conversation, oldest first.

    The query selects the newest ``limit`` messages (descending
    ``created_at``) and the result is then reversed, so callers still receive
    the messages in chronological order. Ordering ascending and then applying
    ``limit`` would instead return the first ``limit`` messages ever written,
    so a long conversation would never surface its recent turns.

    Args:
        db: Async database session.
        conversation_id: Conversation whose messages are loaded.
        limit: Maximum number of messages to return.

    Returns:
        Up to ``limit`` messages in ascending ``created_at`` order.
    """
    result = await db.execute(
        select(Message)
        .where(Message.conversation_id == conversation_id)
        .order_by(desc(Message.created_at))
        .limit(limit)
    )
    newest_first = list(result.scalars().all())
    # Restore chronological order for prompt building / summarisation.
    newest_first.reverse()
    return newest_first
|
|
|
|
|
|
async def get_conversation_turn_count(db: AsyncSession, conversation_id: str) -> int:
    """Count the turns of a conversation, i.e. its messages with role ``"user"``.

    Returns ``0`` when the count query yields no value.
    """
    count_stmt = select(func.count(Message.id)).where(
        Message.conversation_id == conversation_id,
        Message.role == "user",
    )
    outcome = await db.execute(count_stmt)
    total = outcome.scalar()
    return total or 0
|
|
|
|
|
|
# ———— 中期记忆: 对话摘要 ————
|
|
|
|
# Number of user turns that must accumulate (since the latest stored summary,
# or since the start of the conversation) before should_summarize() is True.
SUMMARIZE_THRESHOLD = 8
# NOTE(review): not referenced in the visible part of this module; presumably
# a cap on the history turns used by callers - confirm before changing.
MAX_HISTORY_TURNS = 10
|
|
|
|
|
|
async def should_summarize(db: AsyncSession, conversation_id: str) -> bool:
    """Decide whether the conversation is due for a new summary.

    Compares the current user-turn count with the ``turn_count`` stored on
    the summary that has the highest ``turn_count`` for this conversation
    (or with zero when none exists). A summary is due once the difference
    reaches ``SUMMARIZE_THRESHOLD``.
    """
    from app.models.memory import MemorySummary

    current_turns = await get_conversation_turn_count(db, conversation_id)

    latest_stmt = (
        select(MemorySummary)
        .where(MemorySummary.conversation_id == conversation_id)
        .order_by(desc(MemorySummary.turn_count))
        .limit(1)
    )
    latest = (await db.execute(latest_stmt)).scalar_one_or_none()

    # Turns already covered by a summary; nothing is covered without one.
    covered_turns = latest.turn_count if latest else 0
    return current_turns - covered_turns >= SUMMARIZE_THRESHOLD
|
|
|
|
|
|
async def generate_summary(
    db: AsyncSession,
    conversation_id: str,
    messages: list[Message],
) -> str:
    """Ask the LLM for a short summary of ``messages``.

    The messages are rendered as ``[role] content`` lines and sent with a
    system prompt requesting a concise Chinese summary (key facts, user
    preferences, to-dos; at most 150 characters).

    Args:
        db: Unused here; kept so the signature stays stable for callers.
        conversation_id: Unused here; kept for the same reason.
        messages: Messages to summarise, in the order they should be read.

    Returns:
        The model's reply text with surrounding whitespace stripped.
    """
    import inspect

    from app.services.llm_service import get_llm
    from langchain_core.messages import HumanMessage, SystemMessage

    history_text = "\n".join(f"[{m.role}] {m.content}" for m in messages)
    prompt = [
        SystemMessage(
            content="你是一个记忆助手。请用简洁的中文总结以下对话的核心内容,"
            "提取关键信息、用户偏好、待办事项等。不超过150字。"
        ),
        HumanMessage(content=history_text),
    ]

    llm = get_llm()
    # LangChain models expose the coroutine API as `ainvoke`; `invoke` is
    # synchronous there, so awaiting its result fails. Prefer `ainvoke`, but
    # keep working with a wrapper whose `invoke` is itself a coroutine.
    ainvoke = getattr(llm, "ainvoke", None)
    if callable(ainvoke):
        response = await ainvoke(prompt)
    else:
        response = llm.invoke(prompt)
        if inspect.isawaitable(response):
            response = await response

    return response.content.strip()
|
|
|
|
|
|
async def save_summary(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
    summary_text: str,
    turn_count: int,
) -> Any:
    """Persist a conversation summary and return the stored row.

    Creates a ``MemorySummary`` from the arguments, commits it, and refreshes
    it from the database before returning it.
    """
    from app.models.memory import MemorySummary

    fields = {
        "user_id": user_id,
        "conversation_id": conversation_id,
        "summary_text": summary_text,
        "turn_count": turn_count,
    }
    record = MemorySummary(**fields)

    db.add(record)
    await db.commit()
    await db.refresh(record)
    return record
|
|
|
|
|
|
async def get_summaries(
    db: AsyncSession,
    conversation_id: str,
) -> list[Any]:
    """Return every stored summary of a conversation, ordered by ``summary_at``."""
    from app.models.memory import MemorySummary

    stmt = (
        select(MemorySummary)
        .where(MemorySummary.conversation_id == conversation_id)
        .order_by(MemorySummary.summary_at)
    )
    rows = await db.execute(stmt)
    return [*rows.scalars().all()]
|
|
|
|
|
|
# ———— 长期记忆: 基于 Mem0 ————
|
|
|
|
|
|
async def extract_user_memories(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
    messages: list[Message],
) -> list[dict]:
    """Feed the latest messages of a conversation to Mem0 for memory extraction.

    Only the last 10 messages are sent, as ``{"role", "content"}`` dicts,
    tagged with the conversation id and the source ``"jarvis_memory"``.
    The operation is best-effort: any failure while obtaining the Mem0 client
    or adding the messages is printed and an empty list is returned.

    Args:
        db: Async database session, forwarded to the Mem0 client lookup.
        user_id: Owner of the memories.
        conversation_id: Stored in the memory metadata.
        messages: Conversation messages; fewer than two is a no-op.

    Returns:
        The ``"results"`` list of Mem0's response, or ``[]`` when there is
        nothing to extract or the call fails.
    """
    if len(messages) < 2:
        return []

    # Build the payload once; it is the only representation Mem0 needs.
    recent_messages = [
        {"role": m.role, "content": m.content} for m in messages[-10:]
    ]

    try:
        mem0 = await get_mem0(db, user_id)
        result = mem0.add(
            messages=recent_messages,
            user_id=user_id,
            metadata={
                "conversation_id": conversation_id,
                "source": "jarvis_memory",
            },
        )
        return result.get("results", [])
    except Exception as e:
        print(f"Mem0 extract error: {e}")
        return []
|
|
|
|
|
|
async def recall_user_memories(
    db: AsyncSession,
    user_id: str,
    query: str,
    top_k: int = 5,
) -> list[dict]:
    """Recall the user's memories most relevant to ``query`` via Mem0 search.

    The search is scoped with the ``user_id`` keyword, the same way
    ``extract_user_memories`` scopes ``add``; the equivalent ``filters``
    entry is kept as well. The call is best-effort: any failure is printed
    and an empty list is returned.

    Args:
        db: Async database session, forwarded to the Mem0 client lookup.
        user_id: Owner of the memories to search.
        query: Text to search for.
        top_k: Maximum number of memories to return.

    Returns:
        The ``"results"`` list of Mem0's response, or ``[]`` on failure.
    """
    try:
        mem0 = await get_mem0(db, user_id)
        results = mem0.search(
            query=query,
            user_id=user_id,
            filters={"user_id": user_id},
            limit=top_k,
        )
        return results.get("results", [])
    except Exception as e:
        print(f"Mem0 search error: {e}")
        return []
|
|
|
|
|
|
async def get_user_profile(db: AsyncSession, user_id: str) -> dict:
    """Return the user's stored memories in a profile-shaped dict.

    Memories are listed with ``get_all(user_id=...)``. ``history`` is not
    usable for this: it expects a single memory id and returns that memory's
    change log, so calling it with ``user_id`` always failed and left the
    profile empty.

    Returns:
        A dict with ``"memories"`` (the listed memories), ``"static"`` and
        ``"dynamic"``. The last two are always empty lists here. On any
        failure the error is printed and all three lists are empty.
    """
    try:
        mem0 = await get_mem0(db, user_id)
        result = mem0.get_all(user_id=user_id)
        # The response is normally {"results": [...]}; also accept a bare
        # list in case the installed version returns one.
        if isinstance(result, dict):
            memories = result.get("results", [])
        else:
            memories = list(result or [])
        return {
            "memories": memories,
            "static": [],
            "dynamic": [],
        }
    except Exception as e:
        print(f"Mem0 profile error: {e}")
        return {"memories": [], "static": [], "dynamic": []}
|
|
|
|
|
|
# ———— 记忆组装: 供 Agent 使用的上下文 ————
|
|
|
|
|
|
async def build_memory_context(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
    current_query: str,
) -> str:
    """Assemble the memory context string injected into the agent's system prompt.

    Up to three sections are produced, in this order, and joined with blank
    lines:

    1. user memories recalled from Mem0 for ``current_query`` (top 5),
    2. the two most recent stored summaries of this conversation,
    3. entries recalled from ``BrainService`` for ``current_query`` (top 3).

    Sections without content are omitted; an empty string is returned when
    no section has content.
    """
    sections = []

    recalled = await recall_user_memories(db, user_id, current_query, top_k=5)
    if recalled:
        texts = (item.get("memory", item.get("text", "")) for item in recalled)
        memory_lines = [f" - {text}" for text in texts if text]
        if memory_lines:
            sections.append("【用户记忆】\n" + "\n".join(memory_lines))

    stored_summaries = await get_summaries(db, conversation_id)
    if stored_summaries:
        summary_lines = []
        for index, summary in enumerate(stored_summaries[-2:]):
            summary_lines.append(f"[对话摘要{index + 1}] {summary.summary_text}")
        sections.append("【之前对话摘要】\n" + "\n".join(summary_lines))

    brain_hits = await BrainService(db).recall_memories(user_id, current_query, top_k=3)
    if brain_hits:
        brain_lines = [f"- {hit.title}: {hit.content}" for hit in brain_hits]
        sections.append("【知识大脑】\n" + "\n".join(brain_lines))

    # Joining an empty list yields "", which covers the no-content case.
    return "\n\n".join(sections)
|
|
|
|
|
|
async def try_auto_summarize(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
) -> bool:
    """Summarise the conversation if it is due, and extract memories from it.

    When ``should_summarize`` reports the conversation is due and at least
    three messages were loaded, a summary is generated and saved together
    with the current turn count, and the same messages are passed to
    ``extract_user_memories``.

    Returns:
        ``True`` when a summary was generated and saved; ``False`` when no
        summary was due, there were too few messages, or a step failed (the
        error is printed).
    """
    due = await should_summarize(db, conversation_id)
    if not due:
        return False

    history = await load_conversation_history(db, conversation_id, limit=30)
    if len(history) < 3:
        return False

    try:
        text = await generate_summary(db, conversation_id, history)
        turns = await get_conversation_turn_count(db, conversation_id)
        await save_summary(db, user_id, conversation_id, text, turns)

        await extract_user_memories(db, user_id, conversation_id, history)
    except Exception as e:
        print(f"Auto summarize error: {e}")
        return False

    return True
|
|
|
|
|
|
async def forget_memory(db: AsyncSession, user_id: str, memory_id: str) -> bool:
    """Delete one memory from the user's Mem0 store.

    The ``Memory`` instance returned for ``user_id`` is backed by that user's
    own collection, which is what scopes the deletion. ``delete`` itself
    takes only the memory id; passing ``user_id`` as an extra keyword made
    every call fail and return ``False``.

    Returns:
        ``True`` when the delete call completed, ``False`` when obtaining the
        client or deleting raised (the error is printed).
    """
    try:
        mem0 = await get_mem0(db, user_id)
        mem0.delete(memory_id)
        return True
    except Exception as e:
        print(f"Mem0 delete error: {e}")
        return False
|
|
|
|
|
|
async def update_memory(
    db: AsyncSession,
    user_id: str,
    memory_id: str,
    content: str,
) -> bool:
    """Replace the text of one memory in the user's Mem0 store.

    The ``Memory`` instance returned for ``user_id`` is backed by that user's
    own collection, which is what scopes the update. ``update`` takes the
    memory id and the new text; passing ``user_id`` as an extra keyword made
    every call fail and return ``False``.

    Returns:
        ``True`` when the update call completed, ``False`` when obtaining the
        client or updating raised (the error is printed).
    """
    try:
        mem0 = await get_mem0(db, user_id)
        mem0.update(memory_id, content)
        return True
    except Exception as e:
        print(f"Mem0 update error: {e}")
        return False
|