feat: enhance agent orchestration, knowledge flow and UI refinements
This commit is contained in:
@@ -1,23 +1,154 @@
|
||||
"""
|
||||
Jarvis 记忆系统
|
||||
Jarvis 记忆系统 (基于 Mem0)
|
||||
三层记忆: 短期(对话历史) → 中期(摘要) → 长期(用户画像)
|
||||
底层使用 Mem0 实现事实提取、时间线、矛盾解决和遗忘机制
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import Optional, Any
|
||||
from sqlalchemy import select, desc, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.models.memory import MemorySummary, UserMemory
|
||||
from app.models.conversation import Conversation, Message
|
||||
from app.models.user import User
|
||||
from app.services.brain_service import BrainService
|
||||
from app.services.llm_service import get_llm
|
||||
from app.agents.context import get_current_user
|
||||
from app.config import settings as _settings
|
||||
|
||||
try:
|
||||
from mem0 import Memory
|
||||
|
||||
MEM0_AVAILABLE = True
|
||||
except ImportError:
|
||||
MEM0_AVAILABLE = False
|
||||
Memory = None
|
||||
|
||||
|
||||
async def _get_user_embedding_config(db: AsyncSession, user_id: str) -> dict | None:
|
||||
"""从用户配置中获取 embedding 模型配置"""
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user or not user.llm_config:
|
||||
return None
|
||||
|
||||
embedding_models = user.llm_config.get("embedding", [])
|
||||
for model in embedding_models:
|
||||
if model.get("enabled") and model.get("model"):
|
||||
return {
|
||||
"model": model.get("model"),
|
||||
"base_url": model.get("base_url") or _settings.EMBEDDING_BASE_URL,
|
||||
"api_key": model.get("api_key")
|
||||
or _settings.EMBEDDING_API_KEY
|
||||
or _settings.OPENAI_API_KEY,
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
async def _get_user_chat_config(db: AsyncSession, user_id: str) -> dict | None:
|
||||
"""从用户配置中获取 chat 模型配置"""
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user or not user.llm_config:
|
||||
return None
|
||||
|
||||
chat_models = user.llm_config.get("chat", [])
|
||||
for model in chat_models:
|
||||
if model.get("enabled") and model.get("model"):
|
||||
return {
|
||||
"model": model.get("model"),
|
||||
"base_url": model.get("base_url") or _settings.OPENAI_BASE_URL,
|
||||
"api_key": model.get("api_key") or _settings.OPENAI_API_KEY,
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
class Mem0Client:
|
||||
"""Mem0 客户端 - 按用户隔离"""
|
||||
|
||||
_instances: dict[str, Memory] = {}
|
||||
_persist_dir: str = "./data/mem0"
|
||||
|
||||
async def get_memory(self, db: AsyncSession, user_id: str) -> Memory:
|
||||
"""获取指定用户的 Mem0 实例"""
|
||||
cache_key = user_id
|
||||
|
||||
if cache_key not in self._instances:
|
||||
self._instances[cache_key] = await self._init_memory(db, user_id)
|
||||
|
||||
return self._instances[cache_key]
|
||||
|
||||
async def _init_memory(self, db: AsyncSession, user_id: str) -> Memory:
|
||||
if not MEM0_AVAILABLE:
|
||||
raise RuntimeError("mem0ai 未安装,请运行: pip install mem0ai")
|
||||
|
||||
os.makedirs(self._persist_dir, exist_ok=True)
|
||||
|
||||
llm_config = {
|
||||
"model": _settings.OPENAI_MODEL,
|
||||
"base_url": _settings.OPENAI_BASE_URL,
|
||||
"api_key": _settings.OPENAI_API_KEY,
|
||||
}
|
||||
|
||||
embed_config = _settings.EMBEDDING_MODEL
|
||||
embed_base_url = _settings.EMBEDDING_BASE_URL
|
||||
embed_api_key = _settings.EMBEDDING_API_KEY or _settings.OPENAI_API_KEY
|
||||
|
||||
if db and user_id:
|
||||
try:
|
||||
user_chat = await _get_user_chat_config(db, user_id)
|
||||
if user_chat:
|
||||
llm_config = user_chat
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
user_embed = await _get_user_embedding_config(db, user_id)
|
||||
if user_embed:
|
||||
embed_config = user_embed["model"]
|
||||
embed_base_url = user_embed["base_url"]
|
||||
embed_api_key = user_embed["api_key"]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
config = {
|
||||
"vector_store": {
|
||||
"provider": "chroma",
|
||||
"config": {
|
||||
"collection_name": f"jarvis_memory_{user_id}",
|
||||
"path": self._persist_dir,
|
||||
},
|
||||
},
|
||||
"llm": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": llm_config["model"],
|
||||
"api_key": llm_config["api_key"],
|
||||
"base_url": llm_config["base_url"],
|
||||
},
|
||||
},
|
||||
"embedder": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": embed_config,
|
||||
"api_key": embed_api_key,
|
||||
"base_url": embed_base_url,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return Memory.from_config(config)
|
||||
|
||||
|
||||
_mem0_client = Mem0Client()
|
||||
|
||||
|
||||
async def get_mem0(db: AsyncSession, user_id: str) -> Memory:
|
||||
"""获取指定用户的 Mem0 实例"""
|
||||
return await _mem0_client.get_memory(db, user_id)
|
||||
|
||||
|
||||
# ———— 短期记忆: 对话历史 ————
|
||||
|
||||
|
||||
async def load_conversation_history(
|
||||
db: AsyncSession,
|
||||
conversation_id: str,
|
||||
@@ -36,8 +167,7 @@ async def load_conversation_history(
|
||||
async def get_conversation_turn_count(db: AsyncSession, conversation_id: str) -> int:
|
||||
"""获取对话轮数(用户消息数)"""
|
||||
result = await db.execute(
|
||||
select(func.count(Message.id))
|
||||
.where(
|
||||
select(func.count(Message.id)).where(
|
||||
Message.conversation_id == conversation_id,
|
||||
Message.role == "user",
|
||||
)
|
||||
@@ -47,14 +177,15 @@ async def get_conversation_turn_count(db: AsyncSession, conversation_id: str) ->
|
||||
|
||||
# ———— 中期记忆: 对话摘要 ————
|
||||
|
||||
SUMMARIZE_THRESHOLD = 8 # 超过此轮数则摘要
|
||||
MAX_HISTORY_TURNS = 10 # Agent 最多看到的对话历史轮数
|
||||
SUMMARIZE_THRESHOLD = 8
|
||||
MAX_HISTORY_TURNS = 10
|
||||
|
||||
|
||||
async def should_summarize(db: AsyncSession, conversation_id: str) -> bool:
|
||||
"""判断当前对话是否需要摘要"""
|
||||
from app.models.memory import MemorySummary
|
||||
|
||||
turn_count = await get_conversation_turn_count(db, conversation_id)
|
||||
# 检查是否已有摘要覆盖到当前轮数
|
||||
result = await db.execute(
|
||||
select(MemorySummary)
|
||||
.where(MemorySummary.conversation_id == conversation_id)
|
||||
@@ -72,17 +203,21 @@ async def generate_summary(
|
||||
conversation_id: str,
|
||||
messages: list[Message],
|
||||
) -> str:
|
||||
"""调用 LLM 生成对话摘要"""
|
||||
history_text = "\n".join(
|
||||
f"[{m.role}] {m.content}" for m in messages
|
||||
)
|
||||
llm = get_llm()
|
||||
"""生成对话摘要"""
|
||||
from app.services.llm_service import get_llm
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
response = await llm.invoke([
|
||||
SystemMessage(content="你是一个记忆助手。请用简洁的中文总结以下对话的核心内容,"
|
||||
"提取关键信息、用户偏好、待办事项等。不超过150字。"),
|
||||
HumanMessage(content=history_text),
|
||||
])
|
||||
|
||||
history_text = "\n".join(f"[{m.role}] {m.content}" for m in messages)
|
||||
llm = get_llm()
|
||||
response = await llm.invoke(
|
||||
[
|
||||
SystemMessage(
|
||||
content="你是一个记忆助手。请用简洁的中文总结以下对话的核心内容,"
|
||||
"提取关键信息、用户偏好、待办事项等。不超过150字。"
|
||||
),
|
||||
HumanMessage(content=history_text),
|
||||
]
|
||||
)
|
||||
return response.content.strip()
|
||||
|
||||
|
||||
@@ -92,8 +227,10 @@ async def save_summary(
|
||||
conversation_id: str,
|
||||
summary_text: str,
|
||||
turn_count: int,
|
||||
) -> MemorySummary:
|
||||
"""保存对话摘要"""
|
||||
) -> Any:
|
||||
"""保存对话摘要到数据库"""
|
||||
from app.models.memory import MemorySummary
|
||||
|
||||
summary = MemorySummary(
|
||||
user_id=user_id,
|
||||
conversation_id=conversation_id,
|
||||
@@ -109,8 +246,10 @@ async def save_summary(
|
||||
async def get_summaries(
|
||||
db: AsyncSession,
|
||||
conversation_id: str,
|
||||
) -> list[MemorySummary]:
|
||||
) -> list[Any]:
|
||||
"""获取某对话的所有历史摘要"""
|
||||
from app.models.memory import MemorySummary
|
||||
|
||||
result = await db.execute(
|
||||
select(MemorySummary)
|
||||
.where(MemorySummary.conversation_id == conversation_id)
|
||||
@@ -119,31 +258,7 @@ async def get_summaries(
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
# ———— 长期记忆: 用户画像 ————
|
||||
|
||||
EXTRACTION_PROMPT = """从以下对话中提取关于用户的关键信息。
|
||||
只提取事实性的、可能对未来对话有帮助的信息,如:
|
||||
- 用户的身份/职业/背景
|
||||
- 用户的偏好和习惯
|
||||
- 用户的目标和计划
|
||||
- 重要的事件和日期
|
||||
- 用户的观点和态度
|
||||
|
||||
每条记忆格式: [类型] 内容
|
||||
类型: fact(事实) | preference(偏好) | goal(目标) | habit(习惯)
|
||||
|
||||
如果没有提取到任何记忆,回复"无"。
|
||||
"""
|
||||
|
||||
FACT_TYPES = {"fact", "preference", "goal", "habit"}
|
||||
|
||||
|
||||
def _parse_fact_line(line: str) -> tuple[str, str] | None:
|
||||
"""解析一行记忆: [fact] 内容 -> (type, content)"""
|
||||
m = re.match(r"\[(\w+)\]\s*(.+)", line.strip())
|
||||
if m and m.group(1) in FACT_TYPES:
|
||||
return m.group(1), m.group(2).strip()
|
||||
return None
|
||||
# ———— 长期记忆: 基于 Mem0 ————
|
||||
|
||||
|
||||
async def extract_user_memories(
|
||||
@@ -151,55 +266,34 @@ async def extract_user_memories(
|
||||
user_id: str,
|
||||
conversation_id: str,
|
||||
messages: list[Message],
|
||||
) -> list[UserMemory]:
|
||||
"""从对话中提取用户记忆并保存"""
|
||||
) -> list[dict]:
|
||||
"""
|
||||
从对话中提取用户记忆并存储到 Mem0。
|
||||
Mem0 会自动处理:
|
||||
- 事实提取
|
||||
- 时间线追踪
|
||||
- 矛盾解决
|
||||
- 遗忘机制
|
||||
"""
|
||||
if len(messages) < 2:
|
||||
return []
|
||||
|
||||
history_text = "\n".join(
|
||||
f"[{m.role}] {m.content}" for m in messages[-10:]
|
||||
)
|
||||
history_text = "\n".join(f"[{m.role}] {m.content}" for m in messages[-10:])
|
||||
|
||||
llm = get_llm()
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
response = await llm.invoke([
|
||||
SystemMessage(content=EXTRACTION_PROMPT),
|
||||
HumanMessage(content=history_text),
|
||||
])
|
||||
|
||||
text = response.content.strip()
|
||||
if text == "无" or not text:
|
||||
return []
|
||||
|
||||
memories = []
|
||||
for line in text.split("\n"):
|
||||
parsed = _parse_fact_line(line)
|
||||
if not parsed:
|
||||
continue
|
||||
mem_type, content = parsed
|
||||
# 检查是否已有完全相同的记忆
|
||||
existing = await db.execute(
|
||||
select(UserMemory).where(
|
||||
UserMemory.user_id == user_id,
|
||||
UserMemory.content == content,
|
||||
)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
continue
|
||||
|
||||
mem = UserMemory(
|
||||
try:
|
||||
mem0 = await get_mem0(db, user_id)
|
||||
result = mem0.add(
|
||||
messages=[{"role": m.role, "content": m.content} for m in messages[-10:]],
|
||||
user_id=user_id,
|
||||
memory_type=mem_type,
|
||||
content=content,
|
||||
importance=5,
|
||||
source_conversation_id=conversation_id,
|
||||
metadata={
|
||||
"conversation_id": conversation_id,
|
||||
"source": "jarvis_memory",
|
||||
},
|
||||
)
|
||||
db.add(mem)
|
||||
memories.append(mem)
|
||||
|
||||
if memories:
|
||||
await db.commit()
|
||||
return memories
|
||||
return result.get("results", [])
|
||||
except Exception as e:
|
||||
print(f"Mem0 extract error: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def recall_user_memories(
|
||||
@@ -207,41 +301,45 @@ async def recall_user_memories(
|
||||
user_id: str,
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
) -> list[UserMemory]:
|
||||
"""根据当前输入召回相关的用户记忆(简单关键词匹配)"""
|
||||
# 先尝试语义相似(通过 LLM 判断)
|
||||
# 降级: 直接从数据库取最近的重要记忆
|
||||
result = await db.execute(
|
||||
select(UserMemory)
|
||||
.where(UserMemory.user_id == user_id)
|
||||
.order_by(desc(UserMemory.importance), desc(UserMemory.recall_count))
|
||||
.limit(top_k)
|
||||
)
|
||||
memories = list(result.scalars().all())
|
||||
|
||||
# 重置召回标记
|
||||
for m in memories:
|
||||
m.is_recalled = False
|
||||
await db.commit()
|
||||
|
||||
return memories
|
||||
) -> list[dict]:
|
||||
"""
|
||||
根据当前输入召回相关的用户记忆。
|
||||
使用 Mem0 的语义搜索。
|
||||
"""
|
||||
try:
|
||||
mem0 = await get_mem0(db, user_id)
|
||||
results = mem0.search(
|
||||
query=query,
|
||||
filters={"user_id": user_id},
|
||||
limit=top_k,
|
||||
)
|
||||
return results.get("results", [])
|
||||
except Exception as e:
|
||||
print(f"Mem0 search error: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def mark_memory_recalled(db: AsyncSession, memory_id: str):
|
||||
"""标记记忆已被召回使用"""
|
||||
result = await db.execute(
|
||||
select(UserMemory).where(UserMemory.id == memory_id)
|
||||
)
|
||||
mem = result.scalar_one_or_none()
|
||||
if mem:
|
||||
mem.is_recalled = True
|
||||
mem.recall_count = (mem.recall_count or 0) + 1
|
||||
mem.last_recalled_at = datetime.now(UTC)
|
||||
await db.commit()
|
||||
async def get_user_profile(db: AsyncSession, user_id: str) -> dict:
|
||||
"""
|
||||
获取用户画像。
|
||||
Mem0 的 profile API 会返回 static 和 dynamic facts。
|
||||
"""
|
||||
try:
|
||||
mem0 = await get_mem0(db, user_id)
|
||||
result = mem0.history(user_id=user_id)
|
||||
return {
|
||||
"memories": result.get("results", []),
|
||||
"static": [],
|
||||
"dynamic": [],
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Mem0 profile error: {e}")
|
||||
return {"memories": [], "static": [], "dynamic": []}
|
||||
|
||||
|
||||
# ———— 记忆组装: 供 Agent 使用的上下文 ————
|
||||
|
||||
|
||||
async def build_memory_context(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
@@ -254,25 +352,22 @@ async def build_memory_context(
|
||||
"""
|
||||
parts = []
|
||||
|
||||
# 1. 用户画像(长期记忆)
|
||||
user_memories = await recall_user_memories(db, user_id, current_query, top_k=5)
|
||||
if user_memories:
|
||||
memories = await recall_user_memories(db, user_id, current_query, top_k=5)
|
||||
if memories:
|
||||
lines = []
|
||||
for m in user_memories:
|
||||
tag = f"[{m.memory_type}]"
|
||||
lines.append(f" {tag} {m.content}")
|
||||
await mark_memory_recalled(db, m.id)
|
||||
parts.append("【用户记忆】\n" + "\n".join(lines))
|
||||
for m in memories:
|
||||
memory_text = m.get("memory", m.get("text", ""))
|
||||
if memory_text:
|
||||
lines.append(f" - {memory_text}")
|
||||
if lines:
|
||||
parts.append("【用户记忆】\n" + "\n".join(lines))
|
||||
|
||||
# 2. 对话摘要(中期记忆)
|
||||
summaries = await get_summaries(db, conversation_id)
|
||||
if summaries:
|
||||
# 只取最近2条
|
||||
recent = summaries[-2:]
|
||||
lines = [f"[对话摘要{i+1}] {s.summary_text}" for i, s in enumerate(recent)]
|
||||
lines = [f"[对话摘要{i + 1}] {s.summary_text}" for i, s in enumerate(recent)]
|
||||
parts.append("【之前对话摘要】\n" + "\n".join(lines))
|
||||
|
||||
# 3. 知识大脑(长期项目记忆)
|
||||
brain_memories = await BrainService(db).recall_memories(user_id, current_query, top_k=3)
|
||||
if brain_memories:
|
||||
lines = []
|
||||
@@ -292,7 +387,7 @@ async def try_auto_summarize(
|
||||
) -> bool:
|
||||
"""
|
||||
检查是否需要摘要,如果需要则生成并保存。
|
||||
返回是否执行了摘要。
|
||||
同时将对话内容存入 Mem0 进行记忆提取。
|
||||
"""
|
||||
if not await should_summarize(db, conversation_id):
|
||||
return False
|
||||
@@ -306,8 +401,39 @@ async def try_auto_summarize(
|
||||
turn_count = await get_conversation_turn_count(db, conversation_id)
|
||||
await save_summary(db, user_id, conversation_id, summary_text, turn_count)
|
||||
|
||||
# 同时提取用户记忆
|
||||
await extract_user_memories(db, user_id, conversation_id, messages)
|
||||
return True
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
print(f"Auto summarize error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def forget_memory(db: AsyncSession, user_id: str, memory_id: str) -> bool:
|
||||
"""
|
||||
主动遗忘某条记忆。
|
||||
"""
|
||||
try:
|
||||
mem0 = await get_mem0(db, user_id)
|
||||
mem0.delete(memory_id, user_id=user_id)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Mem0 delete error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def update_memory(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
memory_id: str,
|
||||
content: str,
|
||||
) -> bool:
|
||||
"""
|
||||
更新某条记忆。Mem0 会自动处理矛盾检测。
|
||||
"""
|
||||
try:
|
||||
mem0 = await get_mem0(db, user_id)
|
||||
mem0.update(memory_id, content, user_id=user_id)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Mem0 update error: {e}")
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user