2026-03-21 10:13:29 +08:00
|
|
|
|
"""
|
|
|
|
|
|
Jarvis 记忆系统
|
|
|
|
|
|
三层记忆: 短期(对话历史) → 中期(摘要) → 长期(用户画像)
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import json
import logging
import re
from datetime import datetime, timezone
from typing import Optional

from sqlalchemy import select, desc, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.memory import MemorySummary, UserMemory
from app.models.conversation import Conversation, Message
from app.services.brain_service import BrainService
from app.services.llm_service import get_llm
from app.agents.context import get_current_user
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ———— 短期记忆: 对话历史 ————
|
|
|
|
|
|
|
|
|
|
|
|
async def load_conversation_history(
    db: AsyncSession,
    conversation_id: str,
    limit: int = 20,
) -> list[Message]:
    """Load the most recent messages of a conversation, oldest first.

    Args:
        db: Async SQLAlchemy session.
        conversation_id: ID of the conversation whose messages are loaded.
        limit: Maximum number of messages to return. The *latest* ``limit``
            messages are selected, so long conversations are not truncated
            to their opening turns.

    Returns:
        Up to ``limit`` ``Message`` rows in chronological order
        (ascending ``created_at``).
    """
    # Order newest-first so LIMIT keeps the tail of the conversation (an
    # ascending ORDER BY + LIMIT would return the oldest messages forever),
    # then flip back to chronological order for transcript-style callers.
    result = await db.execute(
        select(Message)
        .where(Message.conversation_id == conversation_id)
        .order_by(desc(Message.created_at))
        .limit(limit)
    )
    recent = list(result.scalars().all())
    recent.reverse()
    return recent
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_conversation_turn_count(db: AsyncSession, conversation_id: str) -> int:
    """Return how many turns a conversation has.

    One turn is one message with role ``"user"``. Returns 0 when the count
    query yields no value.
    """
    user_message_count = (
        select(func.count(Message.id))
        .where(Message.conversation_id == conversation_id)
        .where(Message.role == "user")
    )
    count = (await db.execute(user_message_count)).scalar()
    return count if count else 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ———— 中期记忆: 对话摘要 ————
|
|
|
|
|
|
|
|
|
|
|
|
# Number of new user turns (since the latest summary, or since the start of
# the conversation) at which should_summarize() reports that a summary is due.
SUMMARIZE_THRESHOLD: int = 8
# Maximum number of conversation-history turns the agent gets to see.
# Not referenced in this module; presumably read by callers — confirm.
MAX_HISTORY_TURNS: int = 10
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def should_summarize(db: AsyncSession, conversation_id: str) -> bool:
    """Decide whether the conversation has grown enough to need a new summary.

    Returns True when the number of user turns exceeds the turn count
    recorded by the most recent summary by at least ``SUMMARIZE_THRESHOLD``.
    When no summary exists yet, the baseline is zero turns.
    """
    turns_now = await get_conversation_turn_count(db, conversation_id)

    # The summary with the highest recorded turn_count marks how far previous
    # summarization has already covered this conversation.
    newest_summary_stmt = (
        select(MemorySummary)
        .where(MemorySummary.conversation_id == conversation_id)
        .order_by(desc(MemorySummary.turn_count))
        .limit(1)
    )
    newest = (await db.execute(newest_summary_stmt)).scalar_one_or_none()

    covered_turns = newest.turn_count if newest else 0
    return turns_now - covered_turns >= SUMMARIZE_THRESHOLD
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def generate_summary(
    db: AsyncSession,
    conversation_id: str,
    messages: list[Message],
) -> str:
    """Ask the LLM for a short Chinese summary of ``messages``.

    Each message is rendered as ``[role] content`` on its own line and sent
    with a system instruction asking for at most 150 characters covering key
    facts, user preferences and to-dos.

    ``db`` and ``conversation_id`` are accepted but not used by this function.

    Returns:
        The LLM reply content with surrounding whitespace stripped.
    """
    transcript = "\n".join(f"[{msg.role}] {msg.content}" for msg in messages)

    llm = get_llm()
    from langchain_core.messages import HumanMessage, SystemMessage

    instructions = (
        "你是一个记忆助手。请用简洁的中文总结以下对话的核心内容,"
        "提取关键信息、用户偏好、待办事项等。不超过150字。"
    )
    prompt = [
        SystemMessage(content=instructions),
        HumanMessage(content=transcript),
    ]
    # NOTE(review): if get_llm() returns a plain LangChain chat model, its
    # invoke() is synchronous and cannot be awaited (the async form is
    # ainvoke()). Confirm what get_llm() returns.
    reply = await llm.invoke(prompt)
    return reply.content.strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def save_summary(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
    summary_text: str,
    turn_count: int,
) -> MemorySummary:
    """Insert a conversation summary row, commit, and return it refreshed.

    ``turn_count`` records how many user turns the summary covers; it is the
    baseline should_summarize() compares against.
    """
    record = MemorySummary(
        user_id=user_id,
        conversation_id=conversation_id,
        summary_text=summary_text,
        turn_count=turn_count,
    )
    db.add(record)
    await db.commit()
    # Reload the row so database-populated columns are present on the instance.
    await db.refresh(record)
    return record
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_summaries(
    db: AsyncSession,
    conversation_id: str,
) -> list[MemorySummary]:
    """Return every summary of a conversation, ordered by ``summary_at`` ascending."""
    stmt = (
        select(MemorySummary)
        .where(MemorySummary.conversation_id == conversation_id)
        .order_by(MemorySummary.summary_at)
    )
    rows = await db.execute(stmt)
    return [*rows.scalars().all()]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ———— 长期记忆: 用户画像 ————
|
|
|
|
|
|
|
|
|
|
|
|
# System prompt for extract_user_memories(): asks the LLM to list durable facts
# about the user, one per line as "[type] content", where type is one of
# fact / preference / goal / habit, and to answer "无" (none) if nothing applies.
EXTRACTION_PROMPT = (
    "从以下对话中提取关于用户的关键信息。\n"
    "只提取事实性的、可能对未来对话有帮助的信息,如:\n"
    "- 用户的身份/职业/背景\n"
    "- 用户的偏好和习惯\n"
    "- 用户的目标和计划\n"
    "- 重要的事件和日期\n"
    "- 用户的观点和态度\n"
    "\n"
    "每条记忆格式: [类型] 内容\n"
    "类型: fact(事实) | preference(偏好) | goal(目标) | habit(习惯)\n"
    "\n"
    '如果没有提取到任何记忆,回复"无"。\n'
)
|
|
|
|
|
|
|
|
|
|
|
|
FACT_TYPES = {"fact", "preference", "goal", "habit"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _parse_fact_line(line: str) -> tuple[str, str] | None:
|
|
|
|
|
|
"""解析一行记忆: [fact] 内容 -> (type, content)"""
|
|
|
|
|
|
m = re.match(r"\[(\w+)\]\s*(.+)", line.strip())
|
|
|
|
|
|
if m and m.group(1) in FACT_TYPES:
|
|
|
|
|
|
return m.group(1), m.group(2).strip()
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def extract_user_memories(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
    messages: list[Message],
) -> list[UserMemory]:
    """Extract long-term user memories from a conversation and persist them.

    The last 10 of ``messages`` are rendered as ``[role] content`` lines and
    sent to the LLM with ``EXTRACTION_PROMPT``. Each reply line that
    ``_parse_fact_line`` accepts becomes a ``UserMemory`` (importance 5,
    linked to ``conversation_id``). A line is skipped if the user already has
    a memory with identical content, or if the same content appeared earlier
    in this reply.

    Returns:
        The newly added memories. The list is empty when fewer than 2
        messages were given, the LLM answered "无" (or nothing), or nothing
        new was found. The session is committed only when at least one
        memory was added.
    """
    if len(messages) < 2:
        return []

    history_text = "\n".join(f"[{m.role}] {m.content}" for m in messages[-10:])

    llm = get_llm()
    from langchain_core.messages import HumanMessage, SystemMessage

    # NOTE(review): if get_llm() returns a plain LangChain chat model, its
    # invoke() is synchronous and cannot be awaited (use ainvoke()). Confirm.
    response = await llm.invoke([
        SystemMessage(content=EXTRACTION_PROMPT),
        HumanMessage(content=history_text),
    ])

    text = response.content.strip()
    if text == "无" or not text:
        return []

    memories: list[UserMemory] = []
    seen_contents: set[str] = set()
    for line in text.splitlines():
        parsed = _parse_fact_line(line)
        if not parsed:
            continue
        mem_type, content = parsed

        # Repeated within this same reply: don't rely on autoflush to catch it.
        if content in seen_contents:
            continue
        seen_contents.add(content)

        # Already stored for this user. LIMIT 1 + scalar() so pre-existing
        # duplicate rows cannot raise MultipleResultsFound, which
        # scalar_one_or_none() would.
        existing = await db.execute(
            select(UserMemory.id)
            .where(
                UserMemory.user_id == user_id,
                UserMemory.content == content,
            )
            .limit(1)
        )
        if existing.scalar() is not None:
            continue

        mem = UserMemory(
            user_id=user_id,
            memory_type=mem_type,
            content=content,
            importance=5,
            source_conversation_id=conversation_id,
        )
        db.add(mem)
        memories.append(mem)

    if memories:
        await db.commit()
    return memories
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def recall_user_memories(
    db: AsyncSession,
    user_id: str,
    query: str,
    top_k: int = 5,
) -> list[UserMemory]:
    """Return up to ``top_k`` of the user's long-term memories.

    Memories are ranked by ``importance`` descending, then ``recall_count``
    descending. ``query`` is currently NOT used: no keyword or semantic
    matching is done. This is only the "most important / most recalled"
    fallback.

    Side effect: sets ``is_recalled = False`` on every returned memory and
    commits the session. The commit happens even when nothing is returned.
    """
    ranking = (desc(UserMemory.importance), desc(UserMemory.recall_count))
    stmt = (
        select(UserMemory)
        .where(UserMemory.user_id == user_id)
        .order_by(*ranking)
        .limit(top_k)
    )
    picked = [*(await db.execute(stmt)).scalars().all()]

    # Reset the recalled flag on everything handed back.
    for memory in picked:
        memory.is_recalled = False
    await db.commit()

    return picked
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def mark_memory_recalled(db: AsyncSession, memory_id: str):
    """Record that a user memory was recalled into an agent context.

    Sets ``is_recalled``, increments ``recall_count`` (NULL counts as 0) and
    stamps ``last_recalled_at`` with the current UTC time, then commits.
    An unknown ``memory_id`` is ignored and nothing is committed.
    """
    result = await db.execute(
        select(UserMemory).where(UserMemory.id == memory_id)
    )
    mem = result.scalar_one_or_none()
    if mem is None:
        return

    mem.is_recalled = True
    mem.recall_count = (mem.recall_count or 0) + 1
    # Timezone-aware UTC. The module only imports `datetime`, so a bare `UTC`
    # name is undefined here; timezone.utc works on every Python 3 version.
    # NOTE(review): assumes last_recalled_at accepts aware datetimes — confirm
    # against the column definition.
    mem.last_recalled_at = datetime.now(timezone.utc)
    await db.commit()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ———— 记忆组装: 供 Agent 使用的上下文 ————
|
|
|
|
|
|
|
|
|
|
|
|
async def build_memory_context(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
    current_query: str,
) -> str:
    """Assemble the memory text to inject into the agent's system prompt.

    Sections, in this order and each only when non-empty:
      1. 【用户记忆】: long-term user memories from recall_user_memories();
         each one is marked recalled via mark_memory_recalled().
      2. 【之前对话摘要】: the two most recent summaries of this conversation.
      3. 【知识大脑】: up to 3 memories from BrainService.recall_memories().

    Sections are separated by a blank line. Returns "" when all are empty.
    """
    sections: list[str] = []

    # Long-term memory: user profile. Each entry is rendered, then marked recalled.
    profile = await recall_user_memories(db, user_id, current_query, top_k=5)
    if profile:
        profile_lines = []
        for mem in profile:
            tag = f"[{mem.memory_type}]"
            profile_lines.append(f"  {tag} {mem.content}")
            await mark_memory_recalled(db, mem.id)
        sections.append("【用户记忆】\n" + "\n".join(profile_lines))

    # Mid-term memory: only the latest two conversation summaries, numbered from 1.
    summaries = await get_summaries(db, conversation_id)
    if summaries:
        numbered = [
            f"[对话摘要{idx}] {summary.summary_text}"
            for idx, summary in enumerate(summaries[-2:], start=1)
        ]
        sections.append("【之前对话摘要】\n" + "\n".join(numbered))

    # Knowledge brain: long-term project memory.
    brain_hits = await BrainService(db).recall_memories(user_id, current_query, top_k=3)
    if brain_hits:
        brain_lines = [f"- {hit.title}: {hit.content}" for hit in brain_hits]
        sections.append("【知识大脑】\n" + "\n".join(brain_lines))

    return "\n\n".join(sections) if sections else ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def try_auto_summarize(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
) -> bool:
    """Summarize the conversation and extract user memories when due.

    Does nothing (returns False) unless should_summarize() says a summary is
    due and at least 3 recent messages (of up to 30 loaded) exist. Otherwise
    it generates and saves a summary, then extracts long-term user memories
    from the same messages.

    This is best-effort: any exception is logged with its traceback and
    reported as False instead of propagating. The summary may already be
    committed if the later memory-extraction step is what failed.

    Returns:
        True when both the summary and memory extraction completed.
    """
    if not await should_summarize(db, conversation_id):
        return False

    messages = await load_conversation_history(db, conversation_id, limit=30)
    if len(messages) < 3:
        return False

    try:
        summary_text = await generate_summary(db, conversation_id, messages)
        turn_count = await get_conversation_turn_count(db, conversation_id)
        await save_summary(db, user_id, conversation_id, summary_text, turn_count)

        # Also mine long-term user memories from the same messages.
        await extract_user_memories(db, user_id, conversation_id, messages)
    except Exception:
        # Must not break the chat turn, but LLM/DB failures should not vanish
        # silently either.
        logging.getLogger(__name__).exception(
            "Auto-summarize failed for conversation %s", conversation_id
        )
        return False
    return True
|