Files
JARVIS/backend/app/services/memory_service.py

305 lines
8.8 KiB
Python
Raw Normal View History

2026-03-21 10:13:29 +08:00
"""
Jarvis 记忆系统
三层记忆: 短期(对话历史) 中期(摘要) 长期(用户画像)
"""
import json
import re
from datetime import datetime
from typing import Optional
from sqlalchemy import select, desc, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.memory import MemorySummary, UserMemory
from app.models.conversation import Conversation, Message
from app.services.llm_service import get_llm
from app.agents.context import get_current_user
# ———— 短期记忆: 对话历史 ————
async def load_conversation_history(
db: AsyncSession,
conversation_id: str,
limit: int = 20,
) -> list[Message]:
"""加载指定对话的历史消息"""
result = await db.execute(
select(Message)
.where(Message.conversation_id == conversation_id)
.order_by(Message.created_at)
.limit(limit)
)
return list(result.scalars().all())
async def get_conversation_turn_count(db: AsyncSession, conversation_id: str) -> int:
"""获取对话轮数(用户消息数)"""
result = await db.execute(
select(func.count(Message.id))
.where(
Message.conversation_id == conversation_id,
Message.role == "user",
)
)
return result.scalar() or 0
# ———— 中期记忆: 对话摘要 ————
SUMMARIZE_THRESHOLD = 8 # 超过此轮数则摘要
MAX_HISTORY_TURNS = 10 # Agent 最多看到的对话历史轮数
async def should_summarize(db: AsyncSession, conversation_id: str) -> bool:
"""判断当前对话是否需要摘要"""
turn_count = await get_conversation_turn_count(db, conversation_id)
# 检查是否已有摘要覆盖到当前轮数
result = await db.execute(
select(MemorySummary)
.where(MemorySummary.conversation_id == conversation_id)
.order_by(desc(MemorySummary.turn_count))
.limit(1)
)
latest_summary = result.scalar_one_or_none()
if latest_summary:
return turn_count - latest_summary.turn_count >= SUMMARIZE_THRESHOLD
return turn_count >= SUMMARIZE_THRESHOLD
async def generate_summary(
db: AsyncSession,
conversation_id: str,
messages: list[Message],
) -> str:
"""调用 LLM 生成对话摘要"""
history_text = "\n".join(
f"[{m.role}] {m.content}" for m in messages
)
llm = get_llm()
from langchain_core.messages import HumanMessage, SystemMessage
response = await llm.invoke([
SystemMessage(content="你是一个记忆助手。请用简洁的中文总结以下对话的核心内容,"
"提取关键信息、用户偏好、待办事项等。不超过150字。"),
HumanMessage(content=history_text),
])
return response.content.strip()
async def save_summary(
db: AsyncSession,
user_id: str,
conversation_id: str,
summary_text: str,
turn_count: int,
) -> MemorySummary:
"""保存对话摘要"""
summary = MemorySummary(
user_id=user_id,
conversation_id=conversation_id,
summary_text=summary_text,
turn_count=turn_count,
)
db.add(summary)
await db.commit()
await db.refresh(summary)
return summary
async def get_summaries(
db: AsyncSession,
conversation_id: str,
) -> list[MemorySummary]:
"""获取某对话的所有历史摘要"""
result = await db.execute(
select(MemorySummary)
.where(MemorySummary.conversation_id == conversation_id)
.order_by(MemorySummary.summary_at)
)
return list(result.scalars().all())
# ———— 长期记忆: 用户画像 ————
EXTRACTION_PROMPT = """从以下对话中提取关于用户的关键信息。
只提取事实性的可能对未来对话有帮助的信息
- 用户的身份/职业/背景
- 用户的偏好和习惯
- 用户的目标和计划
- 重要的事件和日期
- 用户的观点和态度
每条记忆格式: [类型] 内容
类型: fact(事实) | preference(偏好) | goal(目标) | habit(习惯)
如果没有提取到任何记忆回复""
"""
FACT_TYPES = {"fact", "preference", "goal", "habit"}
def _parse_fact_line(line: str) -> tuple[str, str] | None:
"""解析一行记忆: [fact] 内容 -> (type, content)"""
m = re.match(r"\[(\w+)\]\s*(.+)", line.strip())
if m and m.group(1) in FACT_TYPES:
return m.group(1), m.group(2).strip()
return None
async def extract_user_memories(
db: AsyncSession,
user_id: str,
conversation_id: str,
messages: list[Message],
) -> list[UserMemory]:
"""从对话中提取用户记忆并保存"""
if len(messages) < 2:
return []
history_text = "\n".join(
f"[{m.role}] {m.content}" for m in messages[-10:]
)
llm = get_llm()
from langchain_core.messages import HumanMessage, SystemMessage
response = await llm.invoke([
SystemMessage(content=EXTRACTION_PROMPT),
HumanMessage(content=history_text),
])
text = response.content.strip()
if text == "" or not text:
return []
memories = []
for line in text.split("\n"):
parsed = _parse_fact_line(line)
if not parsed:
continue
mem_type, content = parsed
# 检查是否已有完全相同的记忆
existing = await db.execute(
select(UserMemory).where(
UserMemory.user_id == user_id,
UserMemory.content == content,
)
)
if existing.scalar_one_or_none():
continue
mem = UserMemory(
user_id=user_id,
memory_type=mem_type,
content=content,
importance=5,
source_conversation_id=conversation_id,
)
db.add(mem)
memories.append(mem)
if memories:
await db.commit()
return memories
async def recall_user_memories(
db: AsyncSession,
user_id: str,
query: str,
top_k: int = 5,
) -> list[UserMemory]:
"""根据当前输入召回相关的用户记忆(简单关键词匹配)"""
# 先尝试语义相似(通过 LLM 判断)
# 降级: 直接从数据库取最近的重要记忆
result = await db.execute(
select(UserMemory)
.where(UserMemory.user_id == user_id)
.order_by(desc(UserMemory.importance), desc(UserMemory.recall_count))
.limit(top_k)
)
memories = list(result.scalars().all())
# 重置召回标记
for m in memories:
m.is_recalled = False
await db.commit()
return memories
async def mark_memory_recalled(db: AsyncSession, memory_id: str):
"""标记记忆已被召回使用"""
result = await db.execute(
select(UserMemory).where(UserMemory.id == memory_id)
)
mem = result.scalar_one_or_none()
if mem:
mem.is_recalled = True
mem.recall_count = (mem.recall_count or 0) + 1
mem.last_recalled_at = datetime.utcnow()
await db.commit()
# ———— 记忆组装: 供 Agent 使用的上下文 ————
async def build_memory_context(
db: AsyncSession,
user_id: str,
conversation_id: str,
current_query: str,
) -> str:
"""
构建完整的记忆上下文字符串
供注入到 Agent system prompt 中使用
"""
parts = []
# 1. 用户画像(长期记忆)
user_memories = await recall_user_memories(db, user_id, current_query, top_k=5)
if user_memories:
lines = []
for m in user_memories:
tag = f"[{m.memory_type}]"
lines.append(f" {tag} {m.content}")
await mark_memory_recalled(db, m.id)
parts.append("【用户记忆】\n" + "\n".join(lines))
# 2. 对话摘要(中期记忆)
summaries = await get_summaries(db, conversation_id)
if summaries:
# 只取最近2条
recent = summaries[-2:]
lines = [f"[对话摘要{i+1}] {s.summary_text}" for i, s in enumerate(recent)]
parts.append("【之前对话摘要】\n" + "\n".join(lines))
if not parts:
return ""
return "\n\n".join(parts)
async def try_auto_summarize(
db: AsyncSession,
user_id: str,
conversation_id: str,
) -> bool:
"""
检查是否需要摘要如果需要则生成并保存
返回是否执行了摘要
"""
if not await should_summarize(db, conversation_id):
return False
messages = await load_conversation_history(db, conversation_id, limit=30)
if len(messages) < 3:
return False
try:
summary_text = await generate_summary(db, conversation_id, messages)
turn_count = await get_conversation_turn_count(db, conversation_id)
await save_summary(db, user_id, conversation_id, summary_text, turn_count)
# 同时提取用户记忆
await extract_user_memories(db, user_id, conversation_id, messages)
return True
except Exception:
return False