Files
JARVIS/backend/app/services/memory_service.py

656 lines
19 KiB
Python
Raw Normal View History

2026-03-21 10:13:29 +08:00
"""
Jarvis 记忆系统 (基于 Mem0)
2026-03-21 10:13:29 +08:00
三层记忆: 短期(对话历史) 中期(摘要) 长期(用户画像)
底层使用 Mem0 实现事实提取时间线矛盾解决和遗忘机制
2026-03-21 10:13:29 +08:00
"""
import logging
import os
import re
import json
from datetime import UTC, datetime
from typing import Optional, Any
2026-03-21 10:13:29 +08:00
from sqlalchemy import select, desc, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.conversation import Conversation, Message
from app.models.memory import UserMemory
from app.models.user import User
from app.services.brain_service import BrainService
from app.services.memory.frequency_tracker import FrequencyTracker
from app.services.memory.emotion_analyzer import EmotionAnalyzer
from app.services.memory.impact_evaluator import ImpactEvaluator
from app.services.memory.importance_scorer import ImportanceScorer
from app.config import settings as _settings
try:
from mem0 import Memory
MEM0_AVAILABLE = True
except ImportError:
MEM0_AVAILABLE = False
Memory = None
logger = logging.getLogger(__name__)
async def _get_user_embedding_config(db: AsyncSession, user_id: str) -> dict | None:
    """Return the user's first usable embedding model configuration.

    Args:
        db: Async SQLAlchemy session used to load the ``User`` row.
        user_id: Primary key of the user whose ``llm_config`` is inspected.

    Returns:
        A dict with ``model``, ``base_url`` and ``api_key`` built from the
        first entry of ``llm_config["embedding"]`` that is enabled and names a
        model, or ``None`` if the user, the config, or such an entry is
        missing. Missing ``base_url`` / ``api_key`` fall back to the
        application settings.
    """
    lookup = select(User).where(User.id == user_id)
    user = (await db.execute(lookup)).scalar_one_or_none()
    if not (user and user.llm_config):
        return None

    usable_entries = (
        entry
        for entry in user.llm_config.get("embedding", [])
        if entry.get("enabled") and entry.get("model")
    )
    chosen = next(usable_entries, None)
    if chosen is None:
        return None

    # The API key falls back first to the embedding-specific key, then to the
    # general OpenAI key.
    api_key = (
        chosen.get("api_key")
        or _settings.EMBEDDING_API_KEY
        or _settings.OPENAI_API_KEY
    )
    return {
        "model": chosen.get("model"),
        "base_url": chosen.get("base_url") or _settings.EMBEDDING_BASE_URL,
        "api_key": api_key,
    }
async def _get_user_chat_config(db: AsyncSession, user_id: str) -> dict | None:
    """Return the user's first usable chat model configuration.

    Args:
        db: Async SQLAlchemy session used to load the ``User`` row.
        user_id: Primary key of the user whose ``llm_config`` is inspected.

    Returns:
        A dict with ``model``, ``base_url`` and ``api_key`` built from the
        first entry of ``llm_config["chat"]`` that is enabled and names a
        model, or ``None`` if the user, the config, or such an entry is
        missing. Missing ``base_url`` / ``api_key`` fall back to the
        application's OpenAI settings.
    """
    lookup = select(User).where(User.id == user_id)
    user = (await db.execute(lookup)).scalar_one_or_none()
    if not (user and user.llm_config):
        return None

    usable_entries = (
        entry
        for entry in user.llm_config.get("chat", [])
        if entry.get("enabled") and entry.get("model")
    )
    chosen = next(usable_entries, None)
    if chosen is None:
        return None

    return {
        "model": chosen.get("model"),
        "base_url": chosen.get("base_url") or _settings.OPENAI_BASE_URL,
        "api_key": chosen.get("api_key") or _settings.OPENAI_API_KEY,
    }
class Mem0Client:
    """Per-user cache of Mem0 ``Memory`` instances.

    Every user gets a dedicated Chroma collection named
    ``jarvis_memory_<user_id>`` stored under ``_persist_dir``. Instances are
    created lazily on first request and reused afterwards.
    """

    # Cache keyed by user_id and shared through the class attribute. Entries
    # are never evicted or rebuilt here, so a later change to a user's
    # llm_config is not picked up by an already cached instance.
    _instances: dict[str, Memory] = {}
    _persist_dir: str = "./data/mem0"

    async def get_memory(self, db: AsyncSession, user_id: str) -> Memory:
        """Return the Mem0 instance for ``user_id``, creating it on first use.

        Args:
            db: Async session used to read the user's model configuration when
                the instance has to be built.
            user_id: Owner of the memory store; also the cache key.

        Raises:
            RuntimeError: If the ``mem0`` package could not be imported.
        """
        cache_key = user_id
        if cache_key not in self._instances:
            self._instances[cache_key] = await self._init_memory(db, user_id)
        return self._instances[cache_key]

    async def _init_memory(self, db: AsyncSession, user_id: str) -> Memory:
        """Build a Mem0 instance configured for ``user_id``.

        Application-wide settings are the defaults; the user's own chat and
        embedding model configuration overrides them when present. A failure
        while reading the user's configuration is logged and the defaults are
        kept, so a broken user config never prevents memory initialisation.

        Raises:
            RuntimeError: If the ``mem0`` package could not be imported.
        """
        if not MEM0_AVAILABLE:
            raise RuntimeError("mem0ai 未安装,请运行: pip install mem0ai")
        os.makedirs(self._persist_dir, exist_ok=True)

        llm_config = {
            "model": _settings.OPENAI_MODEL,
            "base_url": _settings.OPENAI_BASE_URL,
            "api_key": _settings.OPENAI_API_KEY,
        }
        embed_config = _settings.EMBEDDING_MODEL
        embed_base_url = _settings.EMBEDDING_BASE_URL
        embed_api_key = _settings.EMBEDDING_API_KEY or _settings.OPENAI_API_KEY

        if db and user_id:
            try:
                user_chat = await _get_user_chat_config(db, user_id)
                if user_chat:
                    llm_config = user_chat
            except Exception:
                # Best effort: keep the application defaults, but leave a trace.
                logger.warning(
                    "[Mem0Client] failed to load chat model config for user %s; "
                    "using application defaults",
                    user_id,
                    exc_info=True,
                )
            try:
                user_embed = await _get_user_embedding_config(db, user_id)
                if user_embed:
                    embed_config = user_embed["model"]
                    embed_base_url = user_embed["base_url"]
                    embed_api_key = user_embed["api_key"]
            except Exception:
                logger.warning(
                    "[Mem0Client] failed to load embedding model config for user %s; "
                    "using application defaults",
                    user_id,
                    exc_info=True,
                )

        # NOTE(review): Mem0's OpenAI provider documents the endpoint override as
        # `openai_base_url` — confirm the installed mem0 version accepts `base_url`.
        config = {
            "vector_store": {
                "provider": "chroma",
                "config": {
                    "collection_name": f"jarvis_memory_{user_id}",
                    "path": self._persist_dir,
                },
            },
            "llm": {
                "provider": "openai",
                "config": {
                    "model": llm_config["model"],
                    "api_key": llm_config["api_key"],
                    "base_url": llm_config["base_url"],
                },
            },
            "embedder": {
                "provider": "openai",
                "config": {
                    "model": embed_config,
                    "api_key": embed_api_key,
                    "base_url": embed_base_url,
                },
            },
        }
        return Memory.from_config(config)
# Module-wide client; every caller shares its per-user instance cache.
_mem0_client = Mem0Client()


async def get_mem0(db: AsyncSession, user_id: str) -> Memory:
    """Return the Mem0 instance belonging to ``user_id``.

    Thin wrapper that delegates to the module-level ``Mem0Client``.
    """
    memory = await _mem0_client.get_memory(db, user_id)
    return memory
2026-03-21 10:13:29 +08:00
# ———— Short-term memory: conversation history ————


async def load_conversation_history(
    db: AsyncSession,
    conversation_id: str,
    limit: int = 20,
) -> list[Message]:
    """Load the most recent messages of a conversation, oldest first.

    Args:
        db: Async SQLAlchemy session.
        conversation_id: Conversation whose messages are loaded.
        limit: Maximum number of messages to return.

    Returns:
        At most ``limit`` messages in chronological order. When the
        conversation holds more than ``limit`` messages, the latest ones are
        kept.
    """
    result = await db.execute(
        select(Message)
        .where(Message.conversation_id == conversation_id)
        # Newest first so that LIMIT keeps the latest messages; an ascending
        # order would truncate to the earliest ones and drop recent context.
        .order_by(desc(Message.created_at))
        .limit(limit)
    )
    messages = list(result.scalars().all())
    # Restore chronological order for prompt building and summarisation.
    messages.reverse()
    return messages
async def get_conversation_turn_count(db: AsyncSession, conversation_id: str) -> int:
    """Count the turns of a conversation, i.e. its messages with role ``"user"``.

    Returns:
        The number of user messages, or ``0`` when the query yields no value.
    """
    user_message_count = select(func.count(Message.id)).where(
        Message.conversation_id == conversation_id,
        Message.role == "user",
    )
    count = (await db.execute(user_message_count)).scalar()
    return count or 0
# ———— Mid-term memory: conversation summaries ————
# Number of user turns accumulated since the latest summary (or since the start
# of the conversation) at which should_summarize() reports a new summary is due.
SUMMARIZE_THRESHOLD = 8
# NOTE(review): not referenced in the visible part of this file — presumably a
# cap on history turns consumed elsewhere; confirm against callers.
MAX_HISTORY_TURNS = 10
2026-03-21 10:13:29 +08:00
async def should_summarize(db: AsyncSession, conversation_id: str) -> bool:
    """Decide whether the conversation has grown enough to need a new summary.

    Compares the current number of user turns with the turn count recorded on
    the latest stored summary (treated as 0 when there is none).

    Returns:
        ``True`` when at least ``SUMMARIZE_THRESHOLD`` turns have accumulated
        since the latest summary.
    """
    from app.models.memory import MemorySummary

    turn_count = await get_conversation_turn_count(db, conversation_id)

    latest_summary_query = (
        select(MemorySummary)
        .where(MemorySummary.conversation_id == conversation_id)
        .order_by(desc(MemorySummary.turn_count))
        .limit(1)
    )
    latest_summary = (await db.execute(latest_summary_query)).scalar_one_or_none()

    turns_already_covered = latest_summary.turn_count if latest_summary else 0
    return turn_count - turns_already_covered >= SUMMARIZE_THRESHOLD
async def generate_summary(
    db: AsyncSession,
    conversation_id: str,
    messages: list[Message],
) -> str:
    """Ask the LLM for a short summary of the given messages.

    Args:
        db: Unused by this function; kept for a uniform call signature.
        conversation_id: Unused by this function.
        messages: Messages to summarise; each is rendered as ``[role] content``.

    Returns:
        The LLM response content with surrounding whitespace removed.
    """
    from app.services.llm_service import get_llm

    from langchain_core.messages import HumanMessage, SystemMessage

    transcript_lines = [f"[{message.role}] {message.content}" for message in messages]
    history_text = "\n".join(transcript_lines)

    system_prompt = SystemMessage(
        content="你是一个记忆助手。请用简洁的中文总结以下对话的核心内容,"
        "提取关键信息、用户偏好、待办事项等。不超过150字。"
    )
    user_prompt = HumanMessage(content=history_text)

    llm = get_llm()
    # NOTE(review): `invoke` is awaited here. If get_llm() returns a plain
    # LangChain chat model, the awaitable variant is `ainvoke` — confirm.
    response = await llm.invoke([system_prompt, user_prompt])
    return response.content.strip()
async def save_summary(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
    summary_text: str,
    turn_count: int,
) -> Any:
    """Persist a conversation summary and return the stored row.

    Args:
        db: Async SQLAlchemy session; the new row is committed on it.
        user_id: Owner of the conversation.
        conversation_id: Conversation the summary belongs to.
        summary_text: Summary content.
        turn_count: Number of user turns at the time of summarisation.

    Returns:
        The ``MemorySummary`` instance, refreshed from the database after the
        commit.
    """
    from app.models.memory import MemorySummary

    fields = {
        "user_id": user_id,
        "conversation_id": conversation_id,
        "summary_text": summary_text,
        "turn_count": turn_count,
    }
    record = MemorySummary(**fields)
    db.add(record)
    await db.commit()
    await db.refresh(record)
    return record
async def get_summaries(
    db: AsyncSession,
    conversation_id: str,
) -> list[Any]:
    """Return every stored summary of a conversation, ordered by ``summary_at``."""
    from app.models.memory import MemorySummary

    summaries_query = (
        select(MemorySummary)
        .where(MemorySummary.conversation_id == conversation_id)
        .order_by(MemorySummary.summary_at)
    )
    rows = await db.execute(summaries_query)
    return list(rows.scalars().all())
# ———— Long-term memory: backed by Mem0 ————


async def extract_user_memories(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
    messages: list[Message],
) -> list[dict]:
    """Feed the tail of a conversation to Mem0 so it can extract user memories.

    Only the last 10 messages are submitted. According to the original notes
    of this module, Mem0 itself handles fact extraction, timeline tracking,
    contradiction resolution and forgetting.

    Args:
        db: Async session, used to resolve the user's Mem0 instance.
        user_id: Owner of the memories.
        conversation_id: Stored in the metadata of every extracted memory.
        messages: Conversation messages; fewer than two yields no extraction.

    Returns:
        The ``results`` list of Mem0's response, or an empty list when there is
        too little input or Mem0 fails (the failure is logged, not raised).
    """
    if len(messages) < 2:
        return []

    try:
        mem0 = await get_mem0(db, user_id)
        result = mem0.add(
            messages=[{"role": m.role, "content": m.content} for m in messages[-10:]],
            user_id=user_id,
            metadata={
                "conversation_id": conversation_id,
                "source": "jarvis_memory",
            },
        )
        return result.get("results", [])
    except Exception:
        # Memory extraction is best effort and must not break the caller.
        logger.warning("Mem0 extract error", exc_info=True)
        return []
2026-03-21 10:13:29 +08:00
def _extract_memory_query_tokens(query: str) -> list[str]:
normalized_query = (query or "").lower()
tokens = [token for token in re.findall(r"[a-z0-9]+", normalized_query) if len(token) >= 3]
for chunk in re.findall(r"[\u4e00-\u9fff]+", query or ""):
stripped_chunk = chunk.strip()
if len(stripped_chunk) >= 4:
tokens.append(stripped_chunk)
if len(stripped_chunk) > 6:
tokens.extend(
stripped_chunk[index : index + 4] for index in range(len(stripped_chunk) - 3)
)
return list(dict.fromkeys(tokens))
2026-03-21 10:13:29 +08:00
async def recall_user_memories(
    db: AsyncSession,
    user_id: str,
    query: str,
    top_k: int = 5,
) -> list[Any]:
    """Recall the user memories relevant to the current input.

    Mem0's semantic search is tried first. When Mem0 is unavailable, fails, or
    returns nothing, the locally stored ``UserMemory`` rows are used instead.

    Args:
        db: Async SQLAlchemy session.
        user_id: Owner of the memories.
        query: Current user input; a falsy value is treated as an empty query.
        top_k: Maximum number of memories to return.

    Returns:
        Either Mem0 result dicts, or ``UserMemory`` rows from the fallback
        (ordered by importance score, then creation time, both descending).
        The fallback returns the top rows unfiltered when the query looks like
        a memory request, otherwise only rows whose content contains one of
        the query tokens, and an empty list when no tokens can be extracted.
    """
    try:
        mem0 = await get_mem0(db, user_id)
        # NOTE(review): the user scope is passed through `filters`; confirm the
        # installed mem0 version accepts this instead of a `user_id=` keyword.
        results = mem0.search(
            query=query,
            filters={"user_id": user_id},
            limit=top_k,
        )
        mem0_results = results.get("results", [])
        if mem0_results:
            return mem0_results
    except Exception:
        # Fall through to the local store; the failure is only logged.
        logger.warning("Mem0 search error", exc_info=True)

    statement = select(UserMemory).where(UserMemory.user_id == user_id)
    result = await db.execute(
        statement.order_by(UserMemory.importance_score.desc(), UserMemory.created_at.desc())
    )
    fallback_memories = list(result.scalars().all())

    # Normalise once; `or ""` keeps a None query from crashing the fallback,
    # in line with how _extract_memory_query_tokens treats it.
    normalized_query = _normalize_query(query or "")
    if _contains_hint(normalized_query, MEMORY_QUERY_HINTS) or _matches_memory_query_pattern(
        normalized_query
    ):
        return fallback_memories[:top_k]

    query_tokens = _extract_memory_query_tokens(query)
    if not query_tokens:
        return []

    matched_memories = [
        memory
        for memory in fallback_memories
        if any(token in (memory.content or "").lower() for token in query_tokens)
    ]
    return matched_memories[:top_k]
async def _mark_memories_recalled(db: AsyncSession, memories: list[UserMemory]) -> None:
    """Record a recall on each memory and refresh its importance score.

    For every memory the recall flag is set, the recall counter is incremented
    (a missing counter counts as 0), the recall timestamp is set to the current
    UTC time, ``frequency_count`` is synchronised with the recall counter, and
    the importance score is recomputed. Changes are committed once at the end;
    nothing is committed for an empty list.
    """
    if not memories:
        return

    recalled_at = datetime.now(UTC)
    scorer = ImportanceScorer()
    for memory in memories:
        memory.is_recalled = True
        memory.recall_count = (memory.recall_count or 0) + 1
        memory.last_recalled_at = recalled_at
        memory.frequency_count = memory.recall_count  # Keep in sync
        # Update importance score on recall
        scorer.update_memory_importance(memory)
    await db.commit()
async def _run_tolerated_section(
    db: AsyncSession,
    section_name: str,
    builder,
) -> str:
    """Run a context-section builder, tolerating any failure.

    Args:
        db: Not used by this helper itself.
        section_name: Label of the section, used in the warning log message.
        builder: Zero-argument callable returning an awaitable that yields the
            section text.

    Returns:
        The text produced by ``builder``, or ``""`` when it raises; the
        exception is logged with its traceback and not propagated.
    """
    section_text = ""
    try:
        section_text = await builder()
    except Exception:
        logger.warning(
            "[MemoryService] %s失败,继续构建剩余上下文",
            section_name,
            exc_info=True,
        )
    return section_text
2026-03-21 10:13:29 +08:00
async def get_user_profile(db: AsyncSession, user_id: str) -> dict:
    """Return the user's profile as assembled from their Mem0 memories.

    All memories stored for the user are listed through Mem0's ``get_all``.
    (``Memory.history`` is keyed by a memory id, so it cannot be used to list
    a user's memories.)

    Returns:
        A dict with three keys: ``memories`` (the stored memories) and
        ``static`` / ``dynamic``, which are always empty lists here. All three
        are empty when Mem0 fails; the failure is logged, not raised.
    """
    try:
        mem0 = await get_mem0(db, user_id)
        result = mem0.get_all(user_id=user_id)
    except Exception:
        logger.warning("Mem0 profile error", exc_info=True)
        return {"memories": [], "static": [], "dynamic": []}

    # Depending on the Mem0 version the listing is either wrapped in a
    # {"results": [...]} dict or returned as a bare list.
    if isinstance(result, dict):
        memories = result.get("results", [])
    else:
        memories = list(result or [])
    return {
        "memories": memories,
        "static": [],
        "dynamic": [],
    }
2026-03-21 10:13:29 +08:00
# ———— Memory assembly: context handed to the Agent ————
# Substrings that mark a query as a memory request. They make the local recall
# fallback return the top memories unfiltered and suppress conversation summaries.
MEMORY_QUERY_HINTS = (
    "记住",
    "记下",
    "记一下",
    "记着",
    "提醒",
    "偏好",
    "习惯",
)
# Regex counterparts of MEMORY_QUERY_HINTS. They are written in lower case and
# are matched against queries already lower-cased by _normalize_query().
MEMORY_QUERY_PATTERNS = (re.compile(r"\bremember\s+(?:that\s+)?i\b"),)
# Substrings that ask for an answer grounded only in documents. They suppress
# both user memories and conversation summaries.
GROUNDING_QUERY_HINTS = (
    "根据文档",
    "严格根据",
    "只根据",
    "文档内容",
    "grounded",
    "strictly based on",
    "based on the document",
    "based on the docs",
    "document only",
    "docs only",
    "only use the document",
    "only use the docs",
)
# Substrings with which the user opts out of personal preferences or memories.
# They suppress the user-memory section only.
AVOID_USER_MEMORY_HINTS = (
    "不要结合我的个人偏好",
    "不要结合个人偏好",
    "不要结合偏好",
    "不要结合我的记忆",
    "不要结合记忆",
)
def _normalize_query(text: str) -> str:
return text.strip().lower()
def _contains_hint(text: str, hints: tuple[str, ...]) -> bool:
return any(hint in text for hint in hints)
def _matches_memory_query_pattern(text: str) -> bool:
    """Return ``True`` if ``text`` matches any pattern in ``MEMORY_QUERY_PATTERNS``."""
    for pattern in MEMORY_QUERY_PATTERNS:
        if pattern.search(text):
            return True
    return False
def _should_include_user_memories(query: str) -> bool:
    """Tell whether user memories may be added to the context for ``query``.

    They are left out when the query asks for a document-grounded answer or
    explicitly opts out of personal preferences or memories.
    """
    normalized = _normalize_query(query)
    opted_out = _contains_hint(normalized, GROUNDING_QUERY_HINTS) or _contains_hint(
        normalized, AVOID_USER_MEMORY_HINTS
    )
    return not opted_out
def _should_include_summaries(query: str) -> bool:
    """Tell whether conversation summaries may be added to the context.

    They are left out when the query asks for a document-grounded answer or is
    itself a memory request (by hint substring or by pattern).
    """
    normalized = _normalize_query(query)
    for blocking_hints in (GROUNDING_QUERY_HINTS, MEMORY_QUERY_HINTS):
        if _contains_hint(normalized, blocking_hints):
            return False
    return not _matches_memory_query_pattern(normalized)
async def _build_user_memory_section(
    db: AsyncSession,
    user_id: str,
    current_query: str,
) -> str:
    """Render the recalled user memories as a context section.

    Up to 5 memories are recalled for the query. Each one becomes a line
    prefixed with its memory type when one is available, otherwise a dash
    bullet. Entries without text are skipped. Recalled ``UserMemory`` rows are
    marked as recalled, but only when at least one line was rendered.

    Returns:
        The section text with its heading, or ``""`` when there is nothing to
        show.
    """
    memories = await recall_user_memories(db, user_id, current_query, top_k=5)
    if not memories:
        return ""

    rendered: list[str] = []
    local_rows: list[UserMemory] = []
    for item in memories:
        if isinstance(item, UserMemory):
            # Local fallback row.
            text, kind = item.content, item.memory_type
            local_rows.append(item)
        else:
            # Mem0 result dict: the text is under "memory", else under "text".
            text = item.get("memory", item.get("text", ""))
            kind = item.get("memory_type")
        if not text:
            continue
        rendered.append(f" [{kind}] {text}" if kind else f" - {text}")

    if not rendered:
        return ""
    if local_rows:
        await _mark_memories_recalled(db, local_rows)
    return "【用户记忆】\n" + "\n".join(rendered)
async def _build_summary_section(db: AsyncSession, conversation_id: str) -> str:
    """Render the two most recent conversation summaries as a context section.

    Returns:
        The section text with its heading and one numbered line per summary,
        or ``""`` when the conversation has no summaries.
    """
    summaries = await get_summaries(db, conversation_id)
    if not summaries:
        return ""

    lines = []
    for position, summary in enumerate(summaries[-2:], start=1):
        lines.append(f"[对话摘要{position}] {summary.summary_text}")
    return "【之前对话摘要】\n" + "\n".join(lines)
async def _build_brain_section(
    db: AsyncSession,
    user_id: str,
    current_query: str,
) -> str:
    """Render up to 3 entries recalled through ``BrainService`` as a context section.

    Returns:
        The section text with its heading and one ``- title: content`` line
        per entry, or ``""`` when nothing is recalled.
    """
    brain = BrainService(db)
    recalled = await brain.recall_memories(user_id, current_query, top_k=3)
    if not recalled:
        return ""

    lines = []
    for entry in recalled:
        lines.append(f"- {entry.title}: {entry.content}")
    return "【知识大脑】\n" + "\n".join(lines)
2026-03-21 10:13:29 +08:00
async def build_memory_context(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
    current_query: str,
) -> str:
    """Assemble the full memory context string for the Agent's system prompt.

    Sections are built in a fixed order — user memories, conversation
    summaries, knowledge brain. The first two are gated by the query; the
    knowledge-brain section is always attempted. A section that fails is
    logged and skipped, so the rest of the context is still returned.

    Returns:
        The non-empty sections joined by blank lines, or ``""`` when no
        section produced content.
    """
    # (gate, log label, builder); a gate of None means "always build".
    plan = (
        (
            _should_include_user_memories,
            "用户记忆召回",
            lambda: _build_user_memory_section(db, user_id, current_query),
        ),
        (
            _should_include_summaries,
            "对话摘要加载",
            lambda: _build_summary_section(db, conversation_id),
        ),
        (
            None,
            "知识大脑召回",
            lambda: _build_brain_section(db, user_id, current_query),
        ),
    )

    parts: list[str] = []
    for gate, label, builder in plan:
        if gate is not None and not gate(current_query):
            continue
        section = await _run_tolerated_section(db, label, builder)
        if section:
            parts.append(section)
    return "\n\n".join(parts)
async def try_auto_summarize(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
) -> bool:
    """Summarise the conversation when it is due, and extract memories from it.

    When ``should_summarize`` reports that a summary is due and at least 3
    messages are loaded (up to 30 are requested), a summary is generated and
    saved with the current turn count, and the same messages are handed to
    Mem0 for memory extraction.

    Returns:
        ``True`` when a summary was generated and saved; ``False`` when no
        summary was due, there were too few messages, or any step failed
        (the failure is logged, not raised).
    """
    if not await should_summarize(db, conversation_id):
        return False

    messages = await load_conversation_history(db, conversation_id, limit=30)
    if len(messages) < 3:
        return False

    try:
        summary_text = await generate_summary(db, conversation_id, messages)
        turn_count = await get_conversation_turn_count(db, conversation_id)
        await save_summary(db, user_id, conversation_id, summary_text, turn_count)
        await extract_user_memories(db, user_id, conversation_id, messages)
        return True
    except Exception:
        # Summarisation is a background convenience; never let it break the caller.
        logger.warning("Auto summarize error", exc_info=True)
        return False
async def forget_memory(db: AsyncSession, user_id: str, memory_id: str) -> bool:
    """Delete a single memory from the user's Mem0 store.

    Args:
        db: Async session, used to resolve the user's Mem0 instance.
        user_id: Owner of the memory; selects which Mem0 instance is used.
        memory_id: Identifier of the memory to delete.

    Returns:
        ``True`` when Mem0 accepted the deletion, ``False`` when it failed
        (the failure is logged, not raised).
    """
    try:
        mem0 = await get_mem0(db, user_id)
        # Memory.delete is keyed by the memory id alone; the user scope comes
        # from the per-user instance returned by get_mem0().
        mem0.delete(memory_id=memory_id)
        return True
    except Exception:
        logger.warning("Mem0 delete error", exc_info=True)
        return False
async def update_memory(
    db: AsyncSession,
    user_id: str,
    memory_id: str,
    content: str,
) -> bool:
    """Replace the content of a single memory in the user's Mem0 store.

    Args:
        db: Async session, used to resolve the user's Mem0 instance.
        user_id: Owner of the memory; selects which Mem0 instance is used.
        memory_id: Identifier of the memory to update.
        content: New memory text.

    Returns:
        ``True`` when Mem0 accepted the update, ``False`` when it failed
        (the failure is logged, not raised).
    """
    try:
        mem0 = await get_mem0(db, user_id)
        # Memory.update takes the memory id and the new data only; the user
        # scope comes from the per-user instance returned by get_mem0().
        mem0.update(memory_id=memory_id, data=content)
        return True
    except Exception:
        logger.warning("Mem0 update error", exc_info=True)
        return False