Files
JARVIS/backend/app/services/memory_service.py
WIN-JHFT4D3SIVT\caoxiaozhu 9bfa0dcc11 feat(memory): Day M.1 complete - importance scoring system
- Add FrequencyTracker: increment(), get_frequency_score(), get_recency_score(), get_time_decay()
- Add EmotionAnalyzer: EMOTION_KEYWORDS dict, extract(), calculate_score(), get_emotion_profile()
- Add ImpactEvaluator: evaluate(), get_topic_overlap(), rank_by_impact()
- Add ImportanceScorer: composite scoring (freq 35% + recency 20% + emotion 25% + impact 20%)
- Update UserMemory model: frequency_count, emotion_tags, importance_score, importance_level, associated_topics
- Integrate ImportanceScorer into memory_service.py (recall + importance update)
- Add 37 tests for all memory scoring components
- Fix urgency patterns: remove overly broad '今天' that matched neutral text
- Update memory-update checklist: mark all M.1 tasks complete
2026-04-05 13:22:23 +08:00

656 lines
19 KiB
Python

"""
Jarvis 记忆系统 (基于 Mem0)
三层记忆: 短期(对话历史) → 中期(摘要) → 长期(用户画像)
底层使用 Mem0 实现事实提取、时间线、矛盾解决和遗忘机制
"""
import logging
import os
import re
import json
from datetime import UTC, datetime
from typing import Optional, Any
from sqlalchemy import select, desc, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.conversation import Conversation, Message
from app.models.memory import UserMemory
from app.models.user import User
from app.services.brain_service import BrainService
from app.services.memory.frequency_tracker import FrequencyTracker
from app.services.memory.emotion_analyzer import EmotionAnalyzer
from app.services.memory.impact_evaluator import ImpactEvaluator
from app.services.memory.importance_scorer import ImportanceScorer
from app.config import settings as _settings
try:
from mem0 import Memory
MEM0_AVAILABLE = True
except ImportError:
MEM0_AVAILABLE = False
Memory = None
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
async def _get_user_embedding_config(db: AsyncSession, user_id: str) -> dict | None:
    """Return the first enabled embedding model configured by the user.

    Looks up the ``User`` row and scans ``llm_config["embedding"]`` for the
    first entry that is enabled and names a model.

    Returns:
        A dict with ``model``, ``base_url`` and ``api_key`` (missing values
        fall back to the global settings), or ``None`` when the user, the
        config, or a usable entry is absent.
    """
    lookup = await db.execute(select(User).where(User.id == user_id))
    account = lookup.scalar_one_or_none()
    if not (account and account.llm_config):
        return None

    candidates = account.llm_config.get("embedding", [])
    chosen = next(
        (entry for entry in candidates if entry.get("enabled") and entry.get("model")),
        None,
    )
    if chosen is None:
        return None

    return {
        "model": chosen.get("model"),
        "base_url": chosen.get("base_url") or _settings.EMBEDDING_BASE_URL,
        # Key precedence: per-model key, then embedding key, then OpenAI key.
        "api_key": chosen.get("api_key")
        or _settings.EMBEDDING_API_KEY
        or _settings.OPENAI_API_KEY,
    }
async def _get_user_chat_config(db: AsyncSession, user_id: str) -> dict | None:
    """Return the first enabled chat model configured by the user.

    Looks up the ``User`` row and scans ``llm_config["chat"]`` for the first
    entry that is enabled and names a model.

    Returns:
        A dict with ``model``, ``base_url`` and ``api_key`` (missing values
        fall back to the global OpenAI settings), or ``None`` when the user,
        the config, or a usable entry is absent.
    """
    lookup = await db.execute(select(User).where(User.id == user_id))
    account = lookup.scalar_one_or_none()
    if not (account and account.llm_config):
        return None

    candidates = account.llm_config.get("chat", [])
    chosen = next(
        (entry for entry in candidates if entry.get("enabled") and entry.get("model")),
        None,
    )
    if chosen is None:
        return None

    return {
        "model": chosen.get("model"),
        "base_url": chosen.get("base_url") or _settings.OPENAI_BASE_URL,
        "api_key": chosen.get("api_key") or _settings.OPENAI_API_KEY,
    }
class Mem0Client:
    """Factory and cache of Mem0 ``Memory`` instances, one per user.

    Each user gets a dedicated Chroma collection (``jarvis_memory_<user_id>``)
    stored under ``_persist_dir``. Instances are built lazily and kept for the
    lifetime of the process.
    """

    # Cache of initialised Memory objects keyed by user id (class-level, shared).
    _instances: dict[str, Memory] = {}
    # Directory handed to the Chroma vector store; created on first use.
    _persist_dir: str = "./data/mem0"

    async def get_memory(self, db: AsyncSession, user_id: str) -> Memory:
        """Return the cached Mem0 instance for ``user_id``, creating it on first use."""
        cache_key = user_id
        if cache_key not in self._instances:
            self._instances[cache_key] = await self._init_memory(db, user_id)
        return self._instances[cache_key]

    async def _init_memory(self, db: AsyncSession, user_id: str) -> Memory:
        """Build a Mem0 instance for ``user_id``.

        Starts from the global settings and, when a session and user id are
        given, overrides the LLM / embedder with the user's own configuration.
        A failure to read the user's configuration is logged and the global
        settings are used instead.

        Raises:
            RuntimeError: if the ``mem0`` package could not be imported.
        """
        if not MEM0_AVAILABLE:
            raise RuntimeError("mem0ai 未安装,请运行: pip install mem0ai")
        os.makedirs(self._persist_dir, exist_ok=True)

        # Global defaults.
        llm_config = {
            "model": _settings.OPENAI_MODEL,
            "base_url": _settings.OPENAI_BASE_URL,
            "api_key": _settings.OPENAI_API_KEY,
        }
        embed_config = _settings.EMBEDDING_MODEL
        embed_base_url = _settings.EMBEDDING_BASE_URL
        embed_api_key = _settings.EMBEDDING_API_KEY or _settings.OPENAI_API_KEY

        if db and user_id:
            # Per-user overrides are best effort: log and keep the defaults
            # rather than failing memory initialisation altogether.
            try:
                user_chat = await _get_user_chat_config(db, user_id)
                if user_chat:
                    llm_config = user_chat
            except Exception:
                logger.warning(
                    "[MemoryService] failed to load user chat config for Mem0; "
                    "using global settings",
                    exc_info=True,
                )
            try:
                user_embed = await _get_user_embedding_config(db, user_id)
                if user_embed:
                    embed_config = user_embed["model"]
                    embed_base_url = user_embed["base_url"]
                    embed_api_key = user_embed["api_key"]
            except Exception:
                logger.warning(
                    "[MemoryService] failed to load user embedding config for Mem0; "
                    "using global settings",
                    exc_info=True,
                )

        config = {
            "vector_store": {
                "provider": "chroma",
                "config": {
                    "collection_name": f"jarvis_memory_{user_id}",
                    "path": self._persist_dir,
                },
            },
            "llm": {
                "provider": "openai",
                "config": {
                    "model": llm_config["model"],
                    "api_key": llm_config["api_key"],
                    "base_url": llm_config["base_url"],
                },
            },
            "embedder": {
                "provider": "openai",
                "config": {
                    "model": embed_config,
                    "api_key": embed_api_key,
                    "base_url": embed_base_url,
                },
            },
        }
        return Memory.from_config(config)
# Shared client instance used by get_mem0().
_mem0_client = Mem0Client()


async def get_mem0(db: AsyncSession, user_id: str) -> Memory:
    """Return the Mem0 ``Memory`` instance for ``user_id`` via the shared client."""
    memory = await _mem0_client.get_memory(db, user_id)
    return memory
# ———— 短期记忆: 对话历史 ————
async def load_conversation_history(
    db: AsyncSession,
    conversation_id: str,
    limit: int = 20,
) -> list[Message]:
    """Load the most recent messages of a conversation in chronological order.

    The previous query sorted ascending and then applied ``LIMIT``, which
    returned the *oldest* ``limit`` messages. Once a conversation grew past
    the limit, callers never saw the recent turns. The query now selects the
    newest ``limit`` rows and reverses them, so the result is still
    oldest-first.

    Args:
        db: Async database session.
        conversation_id: Conversation whose messages are loaded.
        limit: Maximum number of messages to return.

    Returns:
        Up to ``limit`` messages, ordered by ``created_at`` ascending.
    """
    result = await db.execute(
        select(Message)
        .where(Message.conversation_id == conversation_id)
        .order_by(desc(Message.created_at))
        .limit(limit)
    )
    recent_messages = list(result.scalars().all())
    # Rows arrive newest-first; callers expect chronological order.
    recent_messages.reverse()
    return recent_messages
async def get_conversation_turn_count(db: AsyncSession, conversation_id: str) -> int:
    """Return the number of turns in a conversation, counted as user messages."""
    counting = select(func.count(Message.id)).where(
        Message.conversation_id == conversation_id,
        Message.role == "user",
    )
    outcome = await db.execute(counting)
    total = outcome.scalar()
    return total if total else 0
# ———— 中期记忆: 对话摘要 ————
# Number of user turns (in total, or since the latest stored summary) that
# must accumulate before should_summarize() reports that a summary is due.
SUMMARIZE_THRESHOLD = 8
# Not referenced anywhere in the visible part of this module.
MAX_HISTORY_TURNS = 10
async def should_summarize(db: AsyncSession, conversation_id: str) -> bool:
    """Tell whether enough new user turns have accumulated to warrant a summary.

    The current turn count is compared with the turn count recorded on the
    latest summary (the one with the highest ``turn_count``); without any
    summary the total turn count is compared directly.
    """
    from app.models.memory import MemorySummary

    current_turns = await get_conversation_turn_count(db, conversation_id)

    latest_query = (
        select(MemorySummary)
        .where(MemorySummary.conversation_id == conversation_id)
        .order_by(desc(MemorySummary.turn_count))
        .limit(1)
    )
    previous = (await db.execute(latest_query)).scalar_one_or_none()

    # Turns already covered by a summary; zero when none exists yet.
    covered_turns = previous.turn_count if previous else 0
    return current_turns - covered_turns >= SUMMARIZE_THRESHOLD
async def generate_summary(
    db: AsyncSession,
    conversation_id: str,
    messages: list[Message],
) -> str:
    """Ask the LLM for a short Chinese summary of the given messages.

    Args:
        db: Unused here; kept for interface symmetry with the other helpers.
        conversation_id: Unused here.
        messages: Messages rendered as ``[role] content`` lines for the prompt.

    Returns:
        The model's reply text with surrounding whitespace removed.
    """
    from app.services.llm_service import get_llm
    from langchain_core.messages import HumanMessage, SystemMessage

    transcript = "\n".join([f"[{m.role}] {m.content}" for m in messages])
    llm = get_llm()

    instruction = SystemMessage(
        content="你是一个记忆助手。请用简洁的中文总结以下对话的核心内容,"
        "提取关键信息、用户偏好、待办事项等。不超过150字。"
    )
    prompt = [instruction, HumanMessage(content=transcript)]

    # NOTE(review): this awaits `invoke`. If get_llm() returns a plain
    # LangChain chat model, `invoke` is synchronous and `ainvoke` would be
    # the awaitable variant — confirm what get_llm() actually returns.
    reply = await llm.invoke(prompt)
    return reply.content.strip()
async def save_summary(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
    summary_text: str,
    turn_count: int,
) -> Any:
    """Persist a conversation summary and return the refreshed row.

    Args:
        db: Async database session; the insert is committed here.
        user_id: Owner of the conversation.
        conversation_id: Conversation the summary belongs to.
        summary_text: Summary body.
        turn_count: User-turn count at the time of summarisation.

    Returns:
        The stored ``MemorySummary`` instance after ``db.refresh``.
    """
    from app.models.memory import MemorySummary

    fields = {
        "user_id": user_id,
        "conversation_id": conversation_id,
        "summary_text": summary_text,
        "turn_count": turn_count,
    }
    record = MemorySummary(**fields)
    db.add(record)
    await db.commit()
    await db.refresh(record)
    return record
async def get_summaries(
    db: AsyncSession,
    conversation_id: str,
) -> list[Any]:
    """Return every stored summary of a conversation, ordered by ``summary_at``."""
    from app.models.memory import MemorySummary

    query = select(MemorySummary).where(
        MemorySummary.conversation_id == conversation_id
    )
    query = query.order_by(MemorySummary.summary_at)
    rows = await db.execute(query)
    return list(rows.scalars().all())
# ———— 长期记忆: 基于 Mem0 ————
async def extract_user_memories(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
    messages: list[Message],
) -> list[dict]:
    """Feed the latest messages of a conversation to Mem0 for memory extraction.

    Only the last ten messages are submitted. Fewer than two messages are
    ignored. Any failure (Mem0 unavailable, provider error, ...) is logged
    and reported as an empty result so the caller is never interrupted.

    Args:
        db: Async database session, used to resolve the user's Mem0 instance.
        user_id: Owner of the memories.
        conversation_id: Stored in the memory metadata.
        messages: Conversation messages, oldest first.

    Returns:
        The ``results`` list of Mem0's ``add`` response, or ``[]``.
    """
    if len(messages) < 2:
        return []

    recent_messages = messages[-10:]
    try:
        mem0 = await get_mem0(db, user_id)
        result = mem0.add(
            messages=[{"role": m.role, "content": m.content} for m in recent_messages],
            user_id=user_id,
            metadata={
                "conversation_id": conversation_id,
                "source": "jarvis_memory",
            },
        )
        return result.get("results", [])
    except Exception:
        # Extraction is best effort; use the module logger rather than stdout.
        logger.warning("Mem0 extract error", exc_info=True)
        return []
def _extract_memory_query_tokens(query: str) -> list[str]:
normalized_query = (query or "").lower()
tokens = [token for token in re.findall(r"[a-z0-9]+", normalized_query) if len(token) >= 3]
for chunk in re.findall(r"[\u4e00-\u9fff]+", query or ""):
stripped_chunk = chunk.strip()
if len(stripped_chunk) >= 4:
tokens.append(stripped_chunk)
if len(stripped_chunk) > 6:
tokens.extend(
stripped_chunk[index : index + 4] for index in range(len(stripped_chunk) - 3)
)
return list(dict.fromkeys(tokens))
async def recall_user_memories(
    db: AsyncSession,
    user_id: str,
    query: str,
    top_k: int = 5,
) -> list[dict] | list[UserMemory]:
    """Recall memories relevant to ``query`` for the given user.

    Mem0 semantic search is tried first. When it fails or returns nothing,
    the local ``UserMemory`` table is used instead:

    * if the query looks like a memory request (hint substring or pattern),
      the top memories by importance and recency are returned;
    * otherwise memories whose content contains one of the query tokens are
      returned;
    * if neither applies, the result is empty and the table is not queried.

    Args:
        db: Async database session.
        user_id: Owner of the memories.
        query: Current user input.
        top_k: Maximum number of memories to return.

    Returns:
        Mem0 result dicts, or ``UserMemory`` rows from the local fallback.
    """
    try:
        mem0 = await get_mem0(db, user_id)
        results = mem0.search(
            query=query,
            filters={"user_id": user_id},
            limit=top_k,
        )
        mem0_results = results.get("results", [])
        if mem0_results:
            return mem0_results
    except Exception:
        logger.warning(
            "Mem0 search error; falling back to local UserMemory", exc_info=True
        )

    normalized_query = _normalize_query(query)
    is_memory_request = _contains_hint(
        normalized_query, MEMORY_QUERY_HINTS
    ) or _matches_memory_query_pattern(normalized_query)
    query_tokens = _extract_memory_query_tokens(query)

    # Nothing could match: skip the table scan entirely.
    if not is_memory_request and not query_tokens:
        return []

    statement = (
        select(UserMemory)
        .where(UserMemory.user_id == user_id)
        .order_by(UserMemory.importance_score.desc(), UserMemory.created_at.desc())
    )
    if is_memory_request and top_k > 0:
        # Only the head of the ranking is needed; let the database trim it.
        statement = statement.limit(top_k)
    result = await db.execute(statement)
    fallback_memories = list(result.scalars().all())

    if is_memory_request:
        return fallback_memories[:top_k]

    matched_memories = [
        memory
        for memory in fallback_memories
        if any(token in (memory.content or "").lower() for token in query_tokens)
    ]
    return matched_memories[:top_k]
async def _mark_memories_recalled(db: AsyncSession, memories: list[UserMemory]) -> None:
    """Record a recall on each memory and refresh its importance score.

    For every memory the recall flag, recall counter and last-recalled
    timestamp are updated, ``frequency_count`` is kept equal to
    ``recall_count``, and the importance is recomputed by
    ``ImportanceScorer.update_memory_importance``. All changes are committed
    once at the end; an empty list is a no-op.
    """
    # Kept function-local as in the original.
    from app.services.memory.importance_scorer import ImportanceScorer

    if not memories:
        return

    recalled_at = datetime.now(UTC)
    scorer = ImportanceScorer()
    for memory in memories:
        memory.is_recalled = True
        memory.recall_count = (memory.recall_count or 0) + 1
        memory.last_recalled_at = recalled_at
        memory.frequency_count = memory.recall_count  # Keep in sync
        # Update importance score on recall
        scorer.update_memory_importance(memory)
    await db.commit()
async def _run_tolerated_section(
    db: AsyncSession,
    section_name: str,
    builder,
) -> str:
    """Await ``builder()`` and return its text, or ``""`` if it raises.

    Failures are logged as warnings with the traceback so that one broken
    context section does not prevent the remaining sections from being built.

    Args:
        db: Unused here.
        section_name: Label inserted into the warning message.
        builder: Zero-argument callable returning an awaitable of ``str``.
    """
    try:
        section_text = await builder()
    except Exception:
        logger.warning(
            "[MemoryService] %s失败,继续构建剩余上下文",
            section_name,
            exc_info=True,
        )
        return ""
    else:
        return section_text
async def get_user_profile(db: AsyncSession, user_id: str) -> dict:
    """Return the user's stored Mem0 memories in a profile-shaped dict.

    The original called ``mem0.history(user_id=...)``. Mem0's open-source
    ``Memory.history`` is documented as taking a *memory id*, so that call
    would raise and the function would always return the empty profile.
    ``get_all`` is the per-user listing call.
    NOTE(review): confirm against the installed mem0 version.

    Returns:
        ``{"memories": [...], "static": [], "dynamic": []}``. ``static`` and
        ``dynamic`` are always empty here. On any failure all three lists
        are empty.
    """
    try:
        mem0 = await get_mem0(db, user_id)
        result = mem0.get_all(user_id=user_id)
        # Accept either a {"results": [...]} dict or a bare list.
        if isinstance(result, dict):
            memories = result.get("results", [])
        else:
            memories = list(result or [])
        return {
            "memories": memories,
            "static": [],
            "dynamic": [],
        }
    except Exception:
        logger.warning("Mem0 profile error", exc_info=True)
        return {"memories": [], "static": [], "dynamic": []}
# ———— 记忆组装: 供 Agent 使用的上下文 ————
# Substrings that mark a query as a "remember / preference" request. When one
# is present, the local fallback in recall_user_memories() returns the top
# memories without token matching, and _should_include_summaries() leaves
# conversation summaries out of the context.
MEMORY_QUERY_HINTS = (
"记住",
"记下",
"记一下",
"记着",
"提醒",
"偏好",
"习惯",
)
# Regex counterpart of MEMORY_QUERY_HINTS for English phrasing such as
# "remember that I ..."; callers match it against the lower-cased query.
MEMORY_QUERY_PATTERNS = (re.compile(r"\bremember\s+(?:that\s+)?i\b"),)
# Substrings asking for an answer based strictly on documents. When one is
# present, both user memories and conversation summaries are skipped.
GROUNDING_QUERY_HINTS = (
"根据文档",
"严格根据",
"只根据",
"文档内容",
"grounded",
"strictly based on",
"based on the document",
"based on the docs",
"document only",
"docs only",
"only use the document",
"only use the docs",
)
# Substrings with which the user opts out of personal memory / preferences;
# only the user-memory section is skipped for these.
AVOID_USER_MEMORY_HINTS = (
"不要结合我的个人偏好",
"不要结合个人偏好",
"不要结合偏好",
"不要结合我的记忆",
"不要结合记忆",
)
def _normalize_query(text: str) -> str:
return text.strip().lower()
def _contains_hint(text: str, hints: tuple[str, ...]) -> bool:
return any(hint in text for hint in hints)
def _matches_memory_query_pattern(text: str) -> bool:
    """Return ``True`` when any pattern in ``MEMORY_QUERY_PATTERNS`` is found in ``text``."""
    for pattern in MEMORY_QUERY_PATTERNS:
        if pattern.search(text):
            return True
    return False
def _should_include_user_memories(query: str) -> bool:
    """Decide whether recalled user memories belong in the context for ``query``.

    They are left out when the query asks for a document-grounded answer or
    explicitly opts out of personal memory / preferences.
    """
    normalized = _normalize_query(query)
    excluded = _contains_hint(normalized, GROUNDING_QUERY_HINTS) or _contains_hint(
        normalized, AVOID_USER_MEMORY_HINTS
    )
    return not excluded
def _should_include_summaries(query: str) -> bool:
    """Decide whether conversation summaries belong in the context for ``query``.

    They are left out when the query asks for a document-grounded answer or
    is itself a memory request (hint substring or pattern match).
    """
    normalized = _normalize_query(query)
    excluded = (
        _contains_hint(normalized, GROUNDING_QUERY_HINTS)
        or _contains_hint(normalized, MEMORY_QUERY_HINTS)
        or _matches_memory_query_pattern(normalized)
    )
    return not excluded
async def _build_user_memory_section(
    db: AsyncSession,
    user_id: str,
    current_query: str,
) -> str:
    """Render the recalled user memories as a context section.

    ``recall_user_memories`` may yield Mem0 result dicts or local
    ``UserMemory`` rows; both are rendered one per line, prefixed with the
    memory type when one is known. Entries without text are skipped.

    Only ``UserMemory`` rows that are actually rendered are marked as
    recalled. Previously, rows with empty content were also counted as
    recalled even though they never reached the prompt.

    Returns:
        The section text, or ``""`` when nothing can be shown.
    """
    memories = await recall_user_memories(db, user_id, current_query, top_k=5)
    if not memories:
        return ""

    lines = []
    recalled_user_memories: list[UserMemory] = []
    for memory in memories:
        is_local = isinstance(memory, UserMemory)
        if is_local:
            memory_text = memory.content
            memory_type = memory.memory_type
        else:
            memory_text = memory.get("memory", memory.get("text", ""))
            memory_type = memory.get("memory_type")
        if not memory_text:
            continue
        if is_local:
            # Track only what is really shown to the agent.
            recalled_user_memories.append(memory)
        if memory_type:
            lines.append(f" [{memory_type}] {memory_text}")
        else:
            lines.append(f" - {memory_text}")

    if not lines:
        return ""
    if recalled_user_memories:
        await _mark_memories_recalled(db, recalled_user_memories)
    return "【用户记忆】\n" + "\n".join(lines)
async def _build_summary_section(db: AsyncSession, conversation_id: str) -> str:
    """Render the last two stored summaries of the conversation as a context section.

    Returns ``""`` when the conversation has no summaries.
    """
    history = await get_summaries(db, conversation_id)
    if not history:
        return ""

    rendered = []
    # Labels are numbered from 1 within the rendered pair.
    for position, item in enumerate(history[-2:], start=1):
        rendered.append(f"[对话摘要{position}] {item.summary_text}")
    return "【之前对话摘要】\n" + "\n".join(rendered)
async def _build_brain_section(
    db: AsyncSession,
    user_id: str,
    current_query: str,
) -> str:
    """Render up to three ``BrainService`` recall hits as a context section.

    Returns ``""`` when the recall yields nothing.
    """
    service = BrainService(db)
    hits = await service.recall_memories(user_id, current_query, top_k=3)
    if not hits:
        return ""
    body = "\n".join(f"- {hit.title}: {hit.content}" for hit in hits)
    return "【知识大脑】\n" + body
async def build_memory_context(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
    current_query: str,
) -> str:
    """Assemble the memory context string injected into the agent's system prompt.

    Sections are built in a fixed order: user memories, conversation
    summaries, knowledge brain. The first two are skipped when the query
    rules them out. Each section is built through ``_run_tolerated_section``,
    so a failing section is dropped instead of aborting the whole context.

    Returns:
        Non-empty sections joined by a blank line, or ``""`` when there are none.
    """
    # (enabled, label used in the failure log, section builder)
    plan = (
        (
            _should_include_user_memories(current_query),
            "用户记忆召回",
            lambda: _build_user_memory_section(db, user_id, current_query),
        ),
        (
            _should_include_summaries(current_query),
            "对话摘要加载",
            lambda: _build_summary_section(db, conversation_id),
        ),
        (
            True,
            "知识大脑召回",
            lambda: _build_brain_section(db, user_id, current_query),
        ),
    )

    sections: list[str] = []
    for enabled, label, factory in plan:
        if not enabled:
            continue
        text = await _run_tolerated_section(db, label, factory)
        if text:
            sections.append(text)
    return "\n\n".join(sections)
async def try_auto_summarize(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
) -> bool:
    """Summarise the conversation when due and feed it to Mem0.

    When ``should_summarize`` reports that a summary is due and at least
    three messages were loaded, a summary is generated and saved, and the
    same messages are passed to ``extract_user_memories``.

    Returns:
        ``True`` when a summary was generated and stored; ``False`` when no
        summary was due, too few messages were available, or any step failed
        (the failure is logged with its traceback).
    """
    if not await should_summarize(db, conversation_id):
        return False
    messages = await load_conversation_history(db, conversation_id, limit=30)
    if len(messages) < 3:
        return False
    try:
        summary_text = await generate_summary(db, conversation_id, messages)
        turn_count = await get_conversation_turn_count(db, conversation_id)
        await save_summary(db, user_id, conversation_id, summary_text, turn_count)
        await extract_user_memories(db, user_id, conversation_id, messages)
        return True
    except Exception:
        # Best effort: report through the module logger rather than stdout.
        logger.warning("Auto summarize error", exc_info=True)
        return False
async def forget_memory(db: AsyncSession, user_id: str, memory_id: str) -> bool:
    """Delete one memory from the user's Mem0 store.

    The memory is addressed by id only. ``user_id`` selects the Mem0
    instance (and with it the user's own collection) but is no longer
    forwarded to ``delete``.
    NOTE(review): Mem0's open-source ``Memory.delete`` is documented as
    ``delete(memory_id)``; the extra ``user_id=`` keyword would raise there
    and make this function always return ``False`` — confirm against the
    installed version.

    Returns:
        ``True`` on success, ``False`` when the deletion failed (logged).
    """
    try:
        mem0 = await get_mem0(db, user_id)
        mem0.delete(memory_id)
        return True
    except Exception:
        logger.warning("Mem0 delete error", exc_info=True)
        return False
async def update_memory(
    db: AsyncSession,
    user_id: str,
    memory_id: str,
    content: str,
) -> bool:
    """Replace the content of one memory in the user's Mem0 store.

    The memory is addressed by id only. ``user_id`` selects the Mem0
    instance (and with it the user's own collection) but is no longer
    forwarded to ``update``.
    NOTE(review): Mem0's open-source ``Memory.update`` is documented as
    ``update(memory_id, data)``; the extra ``user_id=`` keyword would raise
    there and make this function always return ``False`` — confirm against
    the installed version.

    Returns:
        ``True`` on success, ``False`` when the update failed (logged).
    """
    try:
        mem0 = await get_mem0(db, user_id)
        mem0.update(memory_id, content)
        return True
    except Exception:
        logger.warning("Mem0 update error", exc_info=True)
        return False