Add FastAPI backend with agent system
This commit is contained in:
304
backend/app/services/memory_service.py
Normal file
304
backend/app/services/memory_service.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
Jarvis 记忆系统
|
||||
三层记忆: 短期(对话历史) → 中期(摘要) → 长期(用户画像)
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from sqlalchemy import select, desc, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.models.memory import MemorySummary, UserMemory
|
||||
from app.models.conversation import Conversation, Message
|
||||
from app.services.llm_service import get_llm
|
||||
from app.agents.context import get_current_user
|
||||
|
||||
|
||||
# ———— 短期记忆: 对话历史 ————
|
||||
|
||||
async def load_conversation_history(
|
||||
db: AsyncSession,
|
||||
conversation_id: str,
|
||||
limit: int = 20,
|
||||
) -> list[Message]:
|
||||
"""加载指定对话的历史消息"""
|
||||
result = await db.execute(
|
||||
select(Message)
|
||||
.where(Message.conversation_id == conversation_id)
|
||||
.order_by(Message.created_at)
|
||||
.limit(limit)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
async def get_conversation_turn_count(db: AsyncSession, conversation_id: str) -> int:
|
||||
"""获取对话轮数(用户消息数)"""
|
||||
result = await db.execute(
|
||||
select(func.count(Message.id))
|
||||
.where(
|
||||
Message.conversation_id == conversation_id,
|
||||
Message.role == "user",
|
||||
)
|
||||
)
|
||||
return result.scalar() or 0
|
||||
|
||||
|
||||
# ———— 中期记忆: 对话摘要 ————
|
||||
|
||||
SUMMARIZE_THRESHOLD = 8 # 超过此轮数则摘要
|
||||
MAX_HISTORY_TURNS = 10 # Agent 最多看到的对话历史轮数
|
||||
|
||||
|
||||
async def should_summarize(db: AsyncSession, conversation_id: str) -> bool:
|
||||
"""判断当前对话是否需要摘要"""
|
||||
turn_count = await get_conversation_turn_count(db, conversation_id)
|
||||
# 检查是否已有摘要覆盖到当前轮数
|
||||
result = await db.execute(
|
||||
select(MemorySummary)
|
||||
.where(MemorySummary.conversation_id == conversation_id)
|
||||
.order_by(desc(MemorySummary.turn_count))
|
||||
.limit(1)
|
||||
)
|
||||
latest_summary = result.scalar_one_or_none()
|
||||
if latest_summary:
|
||||
return turn_count - latest_summary.turn_count >= SUMMARIZE_THRESHOLD
|
||||
return turn_count >= SUMMARIZE_THRESHOLD
|
||||
|
||||
|
||||
async def generate_summary(
|
||||
db: AsyncSession,
|
||||
conversation_id: str,
|
||||
messages: list[Message],
|
||||
) -> str:
|
||||
"""调用 LLM 生成对话摘要"""
|
||||
history_text = "\n".join(
|
||||
f"[{m.role}] {m.content}" for m in messages
|
||||
)
|
||||
llm = get_llm()
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
response = await llm.invoke([
|
||||
SystemMessage(content="你是一个记忆助手。请用简洁的中文总结以下对话的核心内容,"
|
||||
"提取关键信息、用户偏好、待办事项等。不超过150字。"),
|
||||
HumanMessage(content=history_text),
|
||||
])
|
||||
return response.content.strip()
|
||||
|
||||
|
||||
async def save_summary(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
conversation_id: str,
|
||||
summary_text: str,
|
||||
turn_count: int,
|
||||
) -> MemorySummary:
|
||||
"""保存对话摘要"""
|
||||
summary = MemorySummary(
|
||||
user_id=user_id,
|
||||
conversation_id=conversation_id,
|
||||
summary_text=summary_text,
|
||||
turn_count=turn_count,
|
||||
)
|
||||
db.add(summary)
|
||||
await db.commit()
|
||||
await db.refresh(summary)
|
||||
return summary
|
||||
|
||||
|
||||
async def get_summaries(
|
||||
db: AsyncSession,
|
||||
conversation_id: str,
|
||||
) -> list[MemorySummary]:
|
||||
"""获取某对话的所有历史摘要"""
|
||||
result = await db.execute(
|
||||
select(MemorySummary)
|
||||
.where(MemorySummary.conversation_id == conversation_id)
|
||||
.order_by(MemorySummary.summary_at)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
# ———— 长期记忆: 用户画像 ————
|
||||
|
||||
EXTRACTION_PROMPT = """从以下对话中提取关于用户的关键信息。
|
||||
只提取事实性的、可能对未来对话有帮助的信息,如:
|
||||
- 用户的身份/职业/背景
|
||||
- 用户的偏好和习惯
|
||||
- 用户的目标和计划
|
||||
- 重要的事件和日期
|
||||
- 用户的观点和态度
|
||||
|
||||
每条记忆格式: [类型] 内容
|
||||
类型: fact(事实) | preference(偏好) | goal(目标) | habit(习惯)
|
||||
|
||||
如果没有提取到任何记忆,回复"无"。
|
||||
"""
|
||||
|
||||
FACT_TYPES = {"fact", "preference", "goal", "habit"}
|
||||
|
||||
|
||||
def _parse_fact_line(line: str) -> tuple[str, str] | None:
|
||||
"""解析一行记忆: [fact] 内容 -> (type, content)"""
|
||||
m = re.match(r"\[(\w+)\]\s*(.+)", line.strip())
|
||||
if m and m.group(1) in FACT_TYPES:
|
||||
return m.group(1), m.group(2).strip()
|
||||
return None
|
||||
|
||||
|
||||
async def extract_user_memories(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
conversation_id: str,
|
||||
messages: list[Message],
|
||||
) -> list[UserMemory]:
|
||||
"""从对话中提取用户记忆并保存"""
|
||||
if len(messages) < 2:
|
||||
return []
|
||||
|
||||
history_text = "\n".join(
|
||||
f"[{m.role}] {m.content}" for m in messages[-10:]
|
||||
)
|
||||
|
||||
llm = get_llm()
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
response = await llm.invoke([
|
||||
SystemMessage(content=EXTRACTION_PROMPT),
|
||||
HumanMessage(content=history_text),
|
||||
])
|
||||
|
||||
text = response.content.strip()
|
||||
if text == "无" or not text:
|
||||
return []
|
||||
|
||||
memories = []
|
||||
for line in text.split("\n"):
|
||||
parsed = _parse_fact_line(line)
|
||||
if not parsed:
|
||||
continue
|
||||
mem_type, content = parsed
|
||||
# 检查是否已有完全相同的记忆
|
||||
existing = await db.execute(
|
||||
select(UserMemory).where(
|
||||
UserMemory.user_id == user_id,
|
||||
UserMemory.content == content,
|
||||
)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
continue
|
||||
|
||||
mem = UserMemory(
|
||||
user_id=user_id,
|
||||
memory_type=mem_type,
|
||||
content=content,
|
||||
importance=5,
|
||||
source_conversation_id=conversation_id,
|
||||
)
|
||||
db.add(mem)
|
||||
memories.append(mem)
|
||||
|
||||
if memories:
|
||||
await db.commit()
|
||||
return memories
|
||||
|
||||
|
||||
async def recall_user_memories(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
) -> list[UserMemory]:
|
||||
"""根据当前输入召回相关的用户记忆(简单关键词匹配)"""
|
||||
# 先尝试语义相似(通过 LLM 判断)
|
||||
# 降级: 直接从数据库取最近的重要记忆
|
||||
result = await db.execute(
|
||||
select(UserMemory)
|
||||
.where(UserMemory.user_id == user_id)
|
||||
.order_by(desc(UserMemory.importance), desc(UserMemory.recall_count))
|
||||
.limit(top_k)
|
||||
)
|
||||
memories = list(result.scalars().all())
|
||||
|
||||
# 重置召回标记
|
||||
for m in memories:
|
||||
m.is_recalled = False
|
||||
await db.commit()
|
||||
|
||||
return memories
|
||||
|
||||
|
||||
async def mark_memory_recalled(db: AsyncSession, memory_id: str):
|
||||
"""标记记忆已被召回使用"""
|
||||
result = await db.execute(
|
||||
select(UserMemory).where(UserMemory.id == memory_id)
|
||||
)
|
||||
mem = result.scalar_one_or_none()
|
||||
if mem:
|
||||
mem.is_recalled = True
|
||||
mem.recall_count = (mem.recall_count or 0) + 1
|
||||
mem.last_recalled_at = datetime.utcnow()
|
||||
await db.commit()
|
||||
|
||||
|
||||
# ———— 记忆组装: 供 Agent 使用的上下文 ————
|
||||
|
||||
async def build_memory_context(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
conversation_id: str,
|
||||
current_query: str,
|
||||
) -> str:
|
||||
"""
|
||||
构建完整的记忆上下文字符串,
|
||||
供注入到 Agent system prompt 中使用。
|
||||
"""
|
||||
parts = []
|
||||
|
||||
# 1. 用户画像(长期记忆)
|
||||
user_memories = await recall_user_memories(db, user_id, current_query, top_k=5)
|
||||
if user_memories:
|
||||
lines = []
|
||||
for m in user_memories:
|
||||
tag = f"[{m.memory_type}]"
|
||||
lines.append(f" {tag} {m.content}")
|
||||
await mark_memory_recalled(db, m.id)
|
||||
parts.append("【用户记忆】\n" + "\n".join(lines))
|
||||
|
||||
# 2. 对话摘要(中期记忆)
|
||||
summaries = await get_summaries(db, conversation_id)
|
||||
if summaries:
|
||||
# 只取最近2条
|
||||
recent = summaries[-2:]
|
||||
lines = [f"[对话摘要{i+1}] {s.summary_text}" for i, s in enumerate(recent)]
|
||||
parts.append("【之前对话摘要】\n" + "\n".join(lines))
|
||||
|
||||
if not parts:
|
||||
return ""
|
||||
return "\n\n".join(parts)
|
||||
|
||||
|
||||
async def try_auto_summarize(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
conversation_id: str,
|
||||
) -> bool:
|
||||
"""
|
||||
检查是否需要摘要,如果需要则生成并保存。
|
||||
返回是否执行了摘要。
|
||||
"""
|
||||
if not await should_summarize(db, conversation_id):
|
||||
return False
|
||||
|
||||
messages = await load_conversation_history(db, conversation_id, limit=30)
|
||||
if len(messages) < 3:
|
||||
return False
|
||||
|
||||
try:
|
||||
summary_text = await generate_summary(db, conversation_id, messages)
|
||||
turn_count = await get_conversation_turn_count(db, conversation_id)
|
||||
await save_summary(db, user_id, conversation_id, summary_text, turn_count)
|
||||
|
||||
# 同时提取用户记忆
|
||||
await extract_user_memories(db, user_id, conversation_id, messages)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
Reference in New Issue
Block a user