- Add FrequencyTracker: increment(), get_frequency_score(), get_recency_score(), get_time_decay() - Add EmotionAnalyzer: EMOTION_KEYWORDS dict, extract(), calculate_score(), get_emotion_profile() - Add ImpactEvaluator: evaluate(), get_topic_overlap(), rank_by_impact() - Add ImportanceScorer: composite scoring (freq 35% + recency 20% + emotion 25% + impact 20%) - Update UserMemory model: frequency_count, emotion_tags, importance_score, importance_level, associated_topics - Integrate ImportanceScorer into memory_service.py (recall + importance update) - Add 37 tests for all memory scoring components - Fix urgency patterns: remove overly broad '今天' that matched neutral text - Update memory-update checklist: mark all M.1 tasks complete
656 lines
19 KiB
Python
656 lines
19 KiB
Python
"""
|
|
Jarvis 记忆系统 (基于 Mem0)
|
|
三层记忆: 短期(对话历史) → 中期(摘要) → 长期(用户画像)
|
|
底层使用 Mem0 实现事实提取、时间线、矛盾解决和遗忘机制
|
|
"""
|
|
|
|
import logging
|
|
import os
|
|
import re
|
|
import json
|
|
from datetime import UTC, datetime
|
|
from typing import Optional, Any
|
|
from sqlalchemy import select, desc, func
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from app.models.conversation import Conversation, Message
|
|
from app.models.memory import UserMemory
|
|
from app.models.user import User
|
|
from app.services.brain_service import BrainService
|
|
from app.services.memory.frequency_tracker import FrequencyTracker
|
|
from app.services.memory.emotion_analyzer import EmotionAnalyzer
|
|
from app.services.memory.impact_evaluator import ImpactEvaluator
|
|
from app.services.memory.importance_scorer import ImportanceScorer
|
|
from app.config import settings as _settings
|
|
|
|
# mem0 is an optional dependency: when it cannot be imported, ``Memory`` is
# bound to ``None`` and ``MEM0_AVAILABLE`` records that fact so callers can
# fail with a clear message instead of a NameError.
try:
    from mem0 import Memory
except ImportError:
    MEM0_AVAILABLE = False
    Memory = None
else:
    MEM0_AVAILABLE = True


# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def _get_user_embedding_config(db: AsyncSession, user_id: str) -> dict | None:
    """Return the user's first enabled embedding model configuration.

    The user row is loaded by ``user_id`` and ``llm_config["embedding"]`` is
    scanned in order for the first entry with a truthy ``enabled`` flag and a
    truthy ``model`` value.

    Returns:
        A dict with ``model``, ``base_url`` and ``api_key`` keys, or ``None``
        when the user does not exist, has no ``llm_config``, or has no
        matching entry. A missing ``base_url`` falls back to
        ``EMBEDDING_BASE_URL``; a missing ``api_key`` falls back to
        ``EMBEDDING_API_KEY`` and then ``OPENAI_API_KEY`` from settings.
    """
    lookup = await db.execute(select(User).where(User.id == user_id))
    account = lookup.scalar_one_or_none()
    if not account or not account.llm_config:
        return None

    candidates = account.llm_config.get("embedding", [])
    chosen = next(
        (entry for entry in candidates if entry.get("enabled") and entry.get("model")),
        None,
    )
    if chosen is None:
        return None

    # Per-user values win; app-level settings fill whatever is missing.
    api_key = (
        chosen.get("api_key")
        or _settings.EMBEDDING_API_KEY
        or _settings.OPENAI_API_KEY
    )
    return {
        "model": chosen.get("model"),
        "base_url": chosen.get("base_url") or _settings.EMBEDDING_BASE_URL,
        "api_key": api_key,
    }
|
|
|
|
|
|
async def _get_user_chat_config(db: AsyncSession, user_id: str) -> dict | None:
    """Return the user's first enabled chat model configuration.

    The user row is loaded by ``user_id`` and ``llm_config["chat"]`` is
    scanned in order; entries without a truthy ``enabled`` flag or ``model``
    value are skipped.

    Returns:
        A dict with ``model``, ``base_url`` and ``api_key`` keys, or ``None``
        when the user does not exist, has no ``llm_config``, or has no
        matching entry. Missing ``base_url`` / ``api_key`` values fall back
        to ``OPENAI_BASE_URL`` / ``OPENAI_API_KEY`` from settings.
    """
    lookup = await db.execute(select(User).where(User.id == user_id))
    account = lookup.scalar_one_or_none()
    if not account or not account.llm_config:
        return None

    for entry in account.llm_config.get("chat", []):
        if not (entry.get("enabled") and entry.get("model")):
            continue
        # First usable entry wins.
        return {
            "model": entry.get("model"),
            "base_url": entry.get("base_url") or _settings.OPENAI_BASE_URL,
            "api_key": entry.get("api_key") or _settings.OPENAI_API_KEY,
        }
    return None
|
|
|
|
|
|
class Mem0Client:
    """Factory and cache of Mem0 ``Memory`` instances, one per user.

    Each user gets a dedicated Chroma collection named
    ``jarvis_memory_<user_id>`` stored under ``_persist_dir``. Instances are
    built lazily on first request and then reused.
    """

    # Cache of initialised Mem0 instances keyed by user id. It is a class
    # attribute, so it is shared by every Mem0Client object.
    _instances: dict[str, Memory] = {}
    # Directory handed to the Chroma vector store; created on demand.
    _persist_dir: str = "./data/mem0"

    async def get_memory(self, db: AsyncSession, user_id: str) -> Memory:
        """Return the cached Mem0 instance for ``user_id``, creating it if needed.

        Raises:
            RuntimeError: if the ``mem0`` package is not installed.
        """
        cache_key = user_id

        if cache_key not in self._instances:
            self._instances[cache_key] = await self._init_memory(db, user_id)

        return self._instances[cache_key]

    async def _init_memory(self, db: AsyncSession, user_id: str) -> Memory:
        """Build a Mem0 instance configured for ``user_id``.

        App-level settings provide the default LLM and embedder. When both
        ``db`` and ``user_id`` are truthy, the user's own chat / embedding
        configuration overrides those defaults. A failure while loading the
        per-user configuration is logged and the defaults are kept, so a
        broken user config never prevents memory from working.

        Raises:
            RuntimeError: if the ``mem0`` package is not installed.
        """
        if not MEM0_AVAILABLE:
            raise RuntimeError("mem0ai 未安装,请运行: pip install mem0ai")

        os.makedirs(self._persist_dir, exist_ok=True)

        llm_config = {
            "model": _settings.OPENAI_MODEL,
            "base_url": _settings.OPENAI_BASE_URL,
            "api_key": _settings.OPENAI_API_KEY,
        }

        embed_config = _settings.EMBEDDING_MODEL
        embed_base_url = _settings.EMBEDDING_BASE_URL
        embed_api_key = _settings.EMBEDDING_API_KEY or _settings.OPENAI_API_KEY

        if db and user_id:
            try:
                user_chat = await _get_user_chat_config(db, user_id)
                if user_chat:
                    llm_config = user_chat
            except Exception:
                # Best effort: keep the default LLM, but leave a trace
                # instead of failing silently.
                logger.warning(
                    "[Mem0Client] failed to load chat model config for user %s; "
                    "using default settings",
                    user_id,
                    exc_info=True,
                )

            try:
                user_embed = await _get_user_embedding_config(db, user_id)
                if user_embed:
                    embed_config = user_embed["model"]
                    embed_base_url = user_embed["base_url"]
                    embed_api_key = user_embed["api_key"]
            except Exception:
                # Best effort: keep the default embedder, but leave a trace.
                logger.warning(
                    "[Mem0Client] failed to load embedding model config for user %s; "
                    "using default settings",
                    user_id,
                    exc_info=True,
                )

        config = {
            "vector_store": {
                "provider": "chroma",
                "config": {
                    "collection_name": f"jarvis_memory_{user_id}",
                    "path": self._persist_dir,
                },
            },
            "llm": {
                "provider": "openai",
                "config": {
                    "model": llm_config["model"],
                    "api_key": llm_config["api_key"],
                    "base_url": llm_config["base_url"],
                },
            },
            "embedder": {
                "provider": "openai",
                "config": {
                    "model": embed_config,
                    "api_key": embed_api_key,
                    "base_url": embed_base_url,
                },
            },
        }

        return Memory.from_config(config)
|
|
|
|
|
|
# Single module-wide client; its per-user instance cache backs get_mem0().
_mem0_client = Mem0Client()


async def get_mem0(db: AsyncSession, user_id: str) -> Memory:
    """Return the Mem0 instance for ``user_id`` via the module-level client."""
    memory = await _mem0_client.get_memory(db, user_id)
    return memory
|
|
|
|
|
|
# ———— 短期记忆: 对话历史 ————
|
|
|
|
|
|
async def load_conversation_history(
    db: AsyncSession,
    conversation_id: str,
    limit: int = 20,
) -> list[Message]:
    """Load the most recent messages of a conversation, oldest first.

    Args:
        db: Async database session.
        conversation_id: Conversation whose messages are loaded.
        limit: Maximum number of messages to return.

    Returns:
        Up to ``limit`` messages in chronological order. When the
        conversation holds more than ``limit`` messages, the newest ones are
        kept, so callers always see the current end of the conversation
        rather than its opening.
    """
    result = await db.execute(
        select(Message)
        .where(Message.conversation_id == conversation_id)
        # Newest first so that LIMIT keeps the tail of the conversation.
        .order_by(desc(Message.created_at))
        .limit(limit)
    )
    recent = list(result.scalars().all())
    # Restore chronological order for consumers that render a transcript.
    recent.reverse()
    return recent
|
|
|
|
|
|
async def get_conversation_turn_count(db: AsyncSession, conversation_id: str) -> int:
    """Return the number of turns in a conversation.

    A turn is counted as one message whose ``role`` is ``"user"``. Returns 0
    when the count query yields no value.
    """
    count_stmt = select(func.count(Message.id)).where(
        Message.conversation_id == conversation_id,
        Message.role == "user",
    )
    outcome = await db.execute(count_stmt)
    turns = outcome.scalar()
    return turns or 0
|
|
|
|
|
|
# ———— 中期记忆: 对话摘要 ————
|
|
|
|
# Number of new user turns (since the last stored summary, or since the start
# of the conversation) that triggers another summary; see should_summarize().
SUMMARIZE_THRESHOLD = 8

# Upper bound on history turns. Not referenced by the code visible in this
# module section.
MAX_HISTORY_TURNS = 10
|
|
|
|
|
|
async def should_summarize(db: AsyncSession, conversation_id: str) -> bool:
    """Decide whether the conversation is due for a new summary.

    Compares the current user-turn count with the ``turn_count`` stored on
    the summary that has the highest ``turn_count`` for this conversation
    (or with zero when no summary exists). Returns ``True`` once at least
    ``SUMMARIZE_THRESHOLD`` turns have accumulated since then.
    """
    from app.models.memory import MemorySummary

    current_turns = await get_conversation_turn_count(db, conversation_id)

    latest_stmt = (
        select(MemorySummary)
        .where(MemorySummary.conversation_id == conversation_id)
        .order_by(desc(MemorySummary.turn_count))
        .limit(1)
    )
    last_summary = (await db.execute(latest_stmt)).scalar_one_or_none()

    # No summary yet means every turn so far is still unsummarised.
    covered_turns = last_summary.turn_count if last_summary else 0
    return current_turns - covered_turns >= SUMMARIZE_THRESHOLD
|
|
|
|
|
|
async def generate_summary(
    db: AsyncSession,
    conversation_id: str,
    messages: list[Message],
) -> str:
    """Ask the LLM for a short summary of the given messages.

    The messages are rendered as ``[role] content`` lines and sent together
    with a fixed system prompt. ``db`` and ``conversation_id`` are accepted
    but not used by this function.

    Returns:
        The response content with surrounding whitespace stripped.
    """
    from app.services.llm_service import get_llm
    from langchain_core.messages import HumanMessage, SystemMessage

    transcript = "\n".join(f"[{msg.role}] {msg.content}" for msg in messages)

    system_prompt = SystemMessage(
        content="你是一个记忆助手。请用简洁的中文总结以下对话的核心内容,"
        "提取关键信息、用户偏好、待办事项等。不超过150字。"
    )
    user_prompt = HumanMessage(content=transcript)

    llm = get_llm()
    # NOTE(review): ``invoke`` is awaited here. If get_llm() returns a plain
    # LangChain chat model, the coroutine API is ``ainvoke`` - confirm.
    reply = await llm.invoke([system_prompt, user_prompt])
    return reply.content.strip()
|
|
|
|
|
|
async def save_summary(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
    summary_text: str,
    turn_count: int,
) -> Any:
    """Persist a conversation summary and return the stored row.

    The new ``MemorySummary`` is added, committed, and refreshed from the
    database before being returned.
    """
    from app.models.memory import MemorySummary

    fields = {
        "user_id": user_id,
        "conversation_id": conversation_id,
        "summary_text": summary_text,
        "turn_count": turn_count,
    }
    record = MemorySummary(**fields)

    db.add(record)
    await db.commit()
    # Reload so database-populated columns are available on the instance.
    await db.refresh(record)
    return record
|
|
|
|
|
|
async def get_summaries(
    db: AsyncSession,
    conversation_id: str,
) -> list[Any]:
    """Return every stored summary of a conversation, ordered by ``summary_at``."""
    from app.models.memory import MemorySummary

    stmt = select(MemorySummary).where(MemorySummary.conversation_id == conversation_id)
    stmt = stmt.order_by(MemorySummary.summary_at)

    rows = await db.execute(stmt)
    return list(rows.scalars().all())
|
|
|
|
|
|
# ———— 长期记忆: 基于 Mem0 ————
|
|
|
|
|
|
async def extract_user_memories(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
    messages: list[Message],
) -> list[dict]:
    """Feed the tail of a conversation to Mem0 for long-term memory extraction.

    Only the last 10 messages are submitted. Fewer than 2 messages are
    treated as not worth extracting from.

    Args:
        db: Async database session, used to resolve the user's Mem0 instance.
        user_id: Owner of the memories.
        conversation_id: Stored in the Mem0 metadata of each extracted memory.
        messages: Conversation messages, oldest first.

    Returns:
        The ``results`` list from Mem0's ``add`` response, or an empty list
        when there is too little input or Mem0 fails. Failures are logged
        and never raised.
    """
    if len(messages) < 2:
        return []

    recent_messages = messages[-10:]

    try:
        mem0 = await get_mem0(db, user_id)
        result = mem0.add(
            messages=[{"role": m.role, "content": m.content} for m in recent_messages],
            user_id=user_id,
            metadata={
                "conversation_id": conversation_id,
                "source": "jarvis_memory",
            },
        )
        return result.get("results", [])
    except Exception:
        # Extraction is best effort; log through the module logger so the
        # failure shows up alongside the other memory-service warnings.
        logger.warning("Mem0 extract error", exc_info=True)
        return []
|
|
|
|
|
|
def _extract_memory_query_tokens(query: str) -> list[str]:
    """Split a query into search tokens for the local-memory fallback.

    Two kinds of tokens are produced, in this order and without duplicates:

    * lower-cased ASCII alphanumeric runs of at least 3 characters;
    * runs of CJK ideographs (U+4E00-U+9FFF) of at least 4 characters; runs
      longer than 6 characters additionally contribute every 4-character
      window.

    A falsy ``query`` yields an empty list.
    """
    text = query or ""
    # A dict is used as an insertion-ordered set.
    ordered: dict[str, None] = {}

    for word in re.findall(r"[a-z0-9]+", text.lower()):
        if len(word) >= 3:
            ordered.setdefault(word)

    for run in re.findall(r"[\u4e00-\u9fff]+", text):
        run = run.strip()
        if len(run) < 4:
            continue
        ordered.setdefault(run)
        if len(run) > 6:
            # Sliding 4-character windows let long phrases match partially.
            for start in range(len(run) - 3):
                ordered.setdefault(run[start : start + 4])

    return list(ordered)
|
|
|
|
|
|
async def recall_user_memories(
    db: AsyncSession,
    user_id: str,
    query: str,
    top_k: int = 5,
) -> list[Any]:
    """Recall the user's memories that are relevant to ``query``.

    Mem0's semantic search is tried first. If it fails or returns nothing,
    the function falls back to the local ``UserMemory`` table:

    * when the query contains a memory hint or matches a memory pattern, the
      ``top_k`` memories ranked by importance score and then recency are
      returned;
    * otherwise only memories whose content contains one of the query tokens
      are returned (at most ``top_k``).

    Args:
        db: Async database session.
        user_id: Owner of the memories.
        query: Current user input.
        top_k: Maximum number of memories to return.

    Returns:
        Mem0 result dicts when Mem0 answered, otherwise ``UserMemory`` rows;
        an empty list when nothing matches.
    """
    try:
        mem0 = await get_mem0(db, user_id)
        # NOTE(review): user scoping is passed through ``filters`` here while
        # ``add`` uses ``user_id=`` - confirm against the installed mem0 version.
        results = mem0.search(
            query=query,
            filters={"user_id": user_id},
            limit=top_k,
        )
        mem0_results = results.get("results", [])
        if mem0_results:
            return mem0_results
    except Exception:
        logger.warning(
            "Mem0 search error; falling back to local UserMemory",
            exc_info=True,
        )

    # Normalise once; tolerate a missing query like the tokenizer does.
    normalized_query = _normalize_query(query or "")
    is_memory_query = _contains_hint(
        normalized_query, MEMORY_QUERY_HINTS
    ) or _matches_memory_query_pattern(normalized_query)
    query_tokens = _extract_memory_query_tokens(query)

    if not is_memory_query and not query_tokens:
        # Nothing can match, so skip the database round trip.
        return []

    statement = (
        select(UserMemory)
        .where(UserMemory.user_id == user_id)
        .order_by(UserMemory.importance_score.desc(), UserMemory.created_at.desc())
    )
    if is_memory_query:
        # Only the top-ranked rows are needed; let the database cut them.
        statement = statement.limit(top_k)

    result = await db.execute(statement)
    fallback_memories = list(result.scalars().all())

    if is_memory_query:
        return fallback_memories[:top_k]

    matched_memories = [
        memory
        for memory in fallback_memories
        if any(token in (memory.content or "").lower() for token in query_tokens)
    ]
    return matched_memories[:top_k]
|
|
|
|
|
|
async def _mark_memories_recalled(db: AsyncSession, memories: list[UserMemory]) -> None:
    """Record a recall on each memory and refresh its importance.

    For every memory the recall flag is set, ``recall_count`` is incremented
    (a missing count is treated as 0), ``last_recalled_at`` is stamped with
    the current UTC time and ``frequency_count`` is kept equal to
    ``recall_count``. ``ImportanceScorer.update_memory_importance`` is then
    called on the memory. All changes are committed once at the end; nothing
    is committed for an empty list.
    """
    if not memories:
        return

    # One timestamp for the whole batch so all rows agree.
    recalled_at = datetime.now(UTC)
    scorer = ImportanceScorer()

    for memory in memories:
        memory.is_recalled = True
        memory.recall_count = (memory.recall_count or 0) + 1
        memory.last_recalled_at = recalled_at
        memory.frequency_count = memory.recall_count  # Keep in sync

        # Update importance score on recall
        scorer.update_memory_importance(memory)

    await db.commit()
|
|
|
|
|
|
async def _run_tolerated_section(
    db: AsyncSession,
    section_name: str,
    builder,
) -> str:
    """Await ``builder()`` and return its text, or ``""`` if it raises.

    Any exception is logged as a warning (with traceback) under
    ``section_name`` and swallowed, so one failing section does not stop the
    remaining context from being built. ``db`` is accepted but not used by
    this helper.
    """
    section_text = ""
    try:
        section_text = await builder()
    except Exception:
        logger.warning(
            "[MemoryService] %s失败,继续构建剩余上下文",
            section_name,
            exc_info=True,
        )
    return section_text
|
|
|
|
|
|
async def get_user_profile(db: AsyncSession, user_id: str) -> dict:
    """Return the user's profile as known to Mem0.

    All memories stored for ``user_id`` are fetched with ``Memory.get_all``.
    (``Memory.history`` is the change log of a single memory and is keyed by
    memory id, so it cannot list a user's memories.)

    Returns:
        A dict with ``memories`` (the stored memories), plus ``static`` and
        ``dynamic``, which are always empty lists here. On any failure the
        error is logged and all three lists are empty.
    """
    try:
        mem0 = await get_mem0(db, user_id)
        result = mem0.get_all(user_id=user_id)
        # Accept both the ``{"results": [...]}`` envelope and a bare list.
        if isinstance(result, dict):
            memories = result.get("results", [])
        else:
            memories = list(result or [])
        return {
            "memories": memories,
            "static": [],
            "dynamic": [],
        }
    except Exception:
        logger.warning("Mem0 profile error", exc_info=True)
        return {"memories": [], "static": [], "dynamic": []}
|
|
|
|
|
|
# ———— 记忆组装: 供 Agent 使用的上下文 ————
|
|
|
|
# Substrings that mark a query as being about the user's own memories
# (remember / note down / remind / preference / habit).
MEMORY_QUERY_HINTS = (
    "记住",
    "记下",
    "记一下",
    "记着",
    "提醒",
    "偏好",
    "习惯",
)

# English counterpart of the hints above; applied to the normalised
# (lower-cased) query, e.g. "remember that i ..." / "remember i ...".
MEMORY_QUERY_PATTERNS = (re.compile(r"\bremember\s+(?:that\s+)?i\b"),)

# Substrings that ask for an answer grounded strictly in a document; they
# switch off both user memories and conversation summaries.
GROUNDING_QUERY_HINTS = (
    "根据文档",
    "严格根据",
    "只根据",
    "文档内容",
    "grounded",
    "strictly based on",
    "based on the document",
    "based on the docs",
    "document only",
    "docs only",
    "only use the document",
    "only use the docs",
)

# Substrings with which the user explicitly opts out of personal
# preferences / memories for this query.
AVOID_USER_MEMORY_HINTS = (
    "不要结合我的个人偏好",
    "不要结合个人偏好",
    "不要结合偏好",
    "不要结合我的记忆",
    "不要结合记忆",
)
|
|
|
|
|
|
def _normalize_query(text: str) -> str:
    """Return ``text`` stripped of surrounding whitespace and lower-cased.

    A falsy value (``None`` or ``""``) is normalised to the empty string, in
    line with how ``_extract_memory_query_tokens`` treats a missing query,
    so hint checks on an absent query are simply negative instead of
    raising ``AttributeError``.
    """
    return (text or "").strip().lower()
|
|
|
|
|
|
def _contains_hint(text: str, hints: tuple[str, ...]) -> bool:
    """Return ``True`` if any of ``hints`` occurs as a substring of ``text``."""
    for hint in hints:
        if hint in text:
            return True
    return False
|
|
|
|
|
|
def _matches_memory_query_pattern(text: str) -> bool:
    """Return ``True`` if ``text`` matches any pattern in ``MEMORY_QUERY_PATTERNS``."""
    for pattern in MEMORY_QUERY_PATTERNS:
        if pattern.search(text):
            return True
    return False
|
|
|
|
|
|
def _should_include_user_memories(query: str) -> bool:
    """Tell whether user memories may be added to the context for ``query``.

    They are left out when the normalised query asks for a document-grounded
    answer or explicitly opts out of personal memories / preferences.
    """
    normalized = _normalize_query(query)
    excluded = _contains_hint(normalized, GROUNDING_QUERY_HINTS) or _contains_hint(
        normalized, AVOID_USER_MEMORY_HINTS
    )
    return not excluded
|
|
|
|
|
|
def _should_include_summaries(query: str) -> bool:
    """Tell whether conversation summaries may be added for ``query``.

    Summaries are left out when the normalised query asks for a
    document-grounded answer, contains a memory hint, or matches a memory
    query pattern.
    """
    normalized = _normalize_query(query)
    excluded = (
        _contains_hint(normalized, GROUNDING_QUERY_HINTS)
        or _contains_hint(normalized, MEMORY_QUERY_HINTS)
        or _matches_memory_query_pattern(normalized)
    )
    return not excluded
|
|
|
|
|
|
async def _build_user_memory_section(
    db: AsyncSession,
    user_id: str,
    current_query: str,
) -> str:
    """Render the recalled user memories as a context section.

    Up to 5 memories are recalled for ``current_query``. Each one may be a
    ``UserMemory`` row or a Mem0 result dict (text under ``memory`` or
    ``text``). Entries without text are skipped. ``UserMemory`` rows that
    were recalled are marked as recalled before returning.

    Returns:
        The section text, or ``""`` when there is nothing to show.
    """
    recalled = await recall_user_memories(db, user_id, current_query, top_k=5)
    if not recalled:
        return ""

    rendered: list[str] = []
    local_rows: list[UserMemory] = []

    for item in recalled:
        if isinstance(item, UserMemory):
            text, kind = item.content, item.memory_type
            local_rows.append(item)
        else:
            text = item.get("memory", item.get("text", ""))
            kind = item.get("memory_type")

        if not text:
            continue

        # Typed memories show their type as a tag; others are plain bullets.
        rendered.append(f" [{kind}] {text}" if kind else f" - {text}")

    if not rendered:
        return ""

    if local_rows:
        await _mark_memories_recalled(db, local_rows)

    return "【用户记忆】\n" + "\n".join(rendered)
|
|
|
|
|
|
async def _build_summary_section(db: AsyncSession, conversation_id: str) -> str:
    """Render the two most recent conversation summaries as a context section.

    Returns ``""`` when the conversation has no stored summary.
    """
    stored = await get_summaries(db, conversation_id)
    if not stored:
        return ""

    # get_summaries() is ordered by summary_at, so the tail is the newest.
    body = "\n".join(
        f"[对话摘要{number}] {item.summary_text}"
        for number, item in enumerate(stored[-2:], start=1)
    )
    return "【之前对话摘要】\n" + body
|
|
|
|
|
|
async def _build_brain_section(
    db: AsyncSession,
    user_id: str,
    current_query: str,
) -> str:
    """Render up to 3 memories recalled from ``BrainService`` as a context section.

    Each entry is shown as ``- title: content``. Returns ``""`` when nothing
    is recalled.
    """
    service = BrainService(db)
    hits = await service.recall_memories(user_id, current_query, top_k=3)
    if not hits:
        return ""

    body = "\n".join(f"- {hit.title}: {hit.content}" for hit in hits)
    return "【知识大脑】\n" + body
|
|
|
|
|
|
async def build_memory_context(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
    current_query: str,
) -> str:
    """Assemble the memory context to inject into the agent's system prompt.

    Sections are built in this order, each one tolerating its own failure:
    user memories (unless the query excludes them), conversation summaries
    (unless the query excludes them), then the knowledge-brain recall.
    Non-empty sections are joined with a blank line.

    Returns:
        The combined context, or ``""`` when every section is empty.
    """
    sections: list[str] = []

    async def collect(label: str, builder) -> None:
        # Run one section and keep its text only when it produced something.
        text = await _run_tolerated_section(db, label, builder)
        if text:
            sections.append(text)

    if _should_include_user_memories(current_query):
        await collect(
            "用户记忆召回",
            lambda: _build_user_memory_section(db, user_id, current_query),
        )

    if _should_include_summaries(current_query):
        await collect(
            "对话摘要加载",
            lambda: _build_summary_section(db, conversation_id),
        )

    await collect(
        "知识大脑召回",
        lambda: _build_brain_section(db, user_id, current_query),
    )

    # Joining an empty list already yields "".
    return "\n\n".join(sections)
|
|
|
|
|
|
async def try_auto_summarize(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
) -> bool:
    """Summarise the conversation when it is due and extract long-term memories.

    When ``should_summarize`` says so, up to 30 messages are loaded, a
    summary is generated and saved together with the current turn count, and
    the same messages are handed to Mem0 for memory extraction.

    Returns:
        ``True`` when a summary was generated and saved. ``False`` when no
        summary is due, fewer than 3 messages are available, or any step
        fails. Failures are logged with their traceback and never raised.
    """
    if not await should_summarize(db, conversation_id):
        return False

    messages = await load_conversation_history(db, conversation_id, limit=30)
    if len(messages) < 3:
        return False

    try:
        summary_text = await generate_summary(db, conversation_id, messages)
        turn_count = await get_conversation_turn_count(db, conversation_id)
        await save_summary(db, user_id, conversation_id, summary_text, turn_count)

        await extract_user_memories(db, user_id, conversation_id, messages)
        return True
    except Exception:
        # Summarisation is a background nicety; report through the module
        # logger rather than stdout so the failure is not lost.
        logger.exception("Auto summarize error")
        return False
|
|
|
|
|
|
async def forget_memory(db: AsyncSession, user_id: str, memory_id: str) -> bool:
    """Delete one memory from the user's Mem0 store.

    The Mem0 instance is resolved per user, so the deletion is issued
    against that user's own store.

    Returns:
        ``True`` when Mem0 reported no error, ``False`` otherwise. Failures
        are logged and never raised.
    """
    try:
        mem0 = await get_mem0(db, user_id)
        # Memory.delete is keyed by memory id only; passing an extra
        # ``user_id`` keyword makes the call fail before anything is deleted.
        mem0.delete(memory_id)
        return True
    except Exception:
        logger.warning("Mem0 delete error", exc_info=True)
        return False
|
|
|
|
|
|
async def update_memory(
    db: AsyncSession,
    user_id: str,
    memory_id: str,
    content: str,
) -> bool:
    """Replace the content of one memory in the user's Mem0 store.

    The Mem0 instance is resolved per user, so the update is issued against
    that user's own store.

    Returns:
        ``True`` when Mem0 reported no error, ``False`` otherwise. Failures
        are logged and never raised.
    """
    try:
        mem0 = await get_mem0(db, user_id)
        # Memory.update takes the memory id and the new text; an extra
        # ``user_id`` keyword makes the call fail before anything is written.
        mem0.update(memory_id, content)
        return True
    except Exception:
        logger.warning("Mem0 update error", exc_info=True)
        return False
|