fix: harden L3 runtime continuity and tool execution
Align the L3 graph, agent service, and sync tool shims on one canonical continuity contract so clarification resumes and persisted snapshots behave consistently. Add targeted regressions and hardening notes covering system-message coalescing, async bridge usage, and continuity rehydration.
This commit is contained in:
@@ -30,6 +30,56 @@ from app.agents.state import initial_state
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
MEMORY_SECTION_HEADERS = (
|
||||
"【用户记忆】",
|
||||
"【之前对话摘要】",
|
||||
"【知识大脑】",
|
||||
)
|
||||
|
||||
|
||||
def _split_memory_context_sections(memory_context: str | None) -> dict[str, str]:
|
||||
text = (memory_context or "").strip()
|
||||
if not text:
|
||||
return {}
|
||||
|
||||
sections: dict[str, str] = {}
|
||||
current_header: str | None = None
|
||||
current_lines: list[str] = []
|
||||
|
||||
for line in text.splitlines():
|
||||
stripped = line.strip()
|
||||
if stripped in MEMORY_SECTION_HEADERS:
|
||||
if current_header and current_lines:
|
||||
sections[current_header] = "\n".join(current_lines).strip()
|
||||
current_header = stripped
|
||||
current_lines = [stripped]
|
||||
continue
|
||||
if current_header:
|
||||
current_lines.append(line)
|
||||
|
||||
if current_header and current_lines:
|
||||
sections[current_header] = "\n".join(current_lines).strip()
|
||||
|
||||
return sections
|
||||
|
||||
|
||||
def _derive_role_memory_contexts(memory_context: str | None) -> dict[str, str | None]:
|
||||
sections = _split_memory_context_sections(memory_context)
|
||||
user_memory = sections.get("【用户记忆】")
|
||||
summaries = sections.get("【之前对话摘要】")
|
||||
knowledge = sections.get("【知识大脑】")
|
||||
|
||||
def _join_parts(*parts: str | None) -> str | None:
|
||||
values = [part for part in parts if part]
|
||||
return "\n\n".join(values) if values else None
|
||||
|
||||
return {
|
||||
"schedule_context_summary": _join_parts(user_memory, summaries),
|
||||
"knowledge_context": knowledge,
|
||||
"analysis_report": _join_parts(summaries, knowledge),
|
||||
}
|
||||
|
||||
|
||||
def _is_streaming_rejection_error(error: Exception, user_llm_config: dict | None) -> bool:
|
||||
capabilities = resolve_provider_capabilities(user_llm_config)
|
||||
error_text = str(error).lower()
|
||||
@@ -87,11 +137,122 @@ _CONTINUITY_SNAPSHOT_FIELDS = (
|
||||
)
|
||||
|
||||
|
||||
def _normalize_legacy_turn_context(turn_context: Any, current_agent: Any) -> dict[str, Any] | None:
|
||||
if not isinstance(turn_context, dict):
|
||||
return None
|
||||
normalized = dict(turn_context)
|
||||
active_agent = normalized.pop("active_agent", None)
|
||||
active_sub_flow = normalized.pop("active_sub_flow", None)
|
||||
if isinstance(active_agent, str) and active_agent and "active_agent" not in normalized:
|
||||
normalized["active_agent"] = active_agent
|
||||
if isinstance(active_sub_flow, str) and active_sub_flow and "active_sub_commander" not in normalized:
|
||||
normalized["active_sub_commander"] = active_sub_flow
|
||||
if not normalized.get("active_agent") and isinstance(current_agent, str) and current_agent:
|
||||
normalized["active_agent"] = current_agent
|
||||
return normalized or None
|
||||
|
||||
|
||||
def _normalize_legacy_pending_action(pending_action: Any) -> dict[str, Any] | None:
|
||||
if not isinstance(pending_action, dict):
|
||||
return None
|
||||
normalized = dict(pending_action)
|
||||
legacy_action_type = normalized.pop("action_type", None)
|
||||
if legacy_action_type and "type" not in normalized:
|
||||
normalized["type"] = legacy_action_type
|
||||
legacy_agent = normalized.pop("agent", None)
|
||||
legacy_sub_flow = normalized.pop("sub_flow", None)
|
||||
if legacy_agent and "owner_agent" not in normalized:
|
||||
normalized["owner_agent"] = legacy_agent
|
||||
if legacy_sub_flow and "owner_sub_commander" not in normalized:
|
||||
normalized["owner_sub_commander"] = legacy_sub_flow
|
||||
legacy_status = normalized.get("status")
|
||||
if legacy_status == "awaiting_confirmation":
|
||||
normalized["status"] = "pending"
|
||||
elif legacy_status == "awaiting_clarification":
|
||||
normalized["status"] = "blocked_on_clarification"
|
||||
return normalized or None
|
||||
|
||||
|
||||
def _normalize_legacy_clarification_context(
|
||||
clarification_context: Any,
|
||||
pending_action: dict[str, Any] | None,
|
||||
current_agent: Any,
|
||||
) -> dict[str, Any] | None:
|
||||
if not isinstance(clarification_context, dict):
|
||||
return None
|
||||
normalized = dict(clarification_context)
|
||||
active_agent = normalized.pop("active_agent", None)
|
||||
sub_flow = normalized.pop("sub_flow", None)
|
||||
awaiting_user_input = normalized.pop("awaiting_user_input", None)
|
||||
if isinstance(active_agent, str) and active_agent and "owning_agent" not in normalized:
|
||||
normalized["owning_agent"] = active_agent
|
||||
if isinstance(sub_flow, str) and sub_flow and "owning_sub_commander" not in normalized:
|
||||
normalized["owning_sub_commander"] = sub_flow
|
||||
if "target_action" not in normalized:
|
||||
target_action = None
|
||||
if pending_action:
|
||||
pending_type = pending_action.get("type")
|
||||
if isinstance(pending_type, str) and pending_type and pending_type != "clarification":
|
||||
target_action = pending_type
|
||||
if target_action is None and isinstance(sub_flow, str) and sub_flow.startswith("create_"):
|
||||
target_action = sub_flow
|
||||
if target_action:
|
||||
normalized["target_action"] = target_action
|
||||
if not normalized.get("owning_agent") and isinstance(current_agent, str) and current_agent:
|
||||
normalized["owning_agent"] = current_agent
|
||||
if awaiting_user_input is True and "status" not in normalized:
|
||||
normalized["status"] = "pending"
|
||||
return normalized or None
|
||||
|
||||
|
||||
def _normalize_legacy_continuity_state(
|
||||
continuity_state: Any,
|
||||
clarification_context: dict[str, Any] | None,
|
||||
) -> dict[str, Any] | None:
|
||||
if not isinstance(continuity_state, dict):
|
||||
return None
|
||||
normalized = dict(continuity_state)
|
||||
normalized.pop("active_agent", None)
|
||||
normalized.pop("active_sub_flow", None)
|
||||
legacy_status = normalized.get("status")
|
||||
if legacy_status == "awaiting_clarification":
|
||||
normalized["status"] = "fresh"
|
||||
if clarification_context and "mode" not in normalized:
|
||||
normalized["mode"] = "resume_after_clarification"
|
||||
return normalized or None
|
||||
|
||||
|
||||
def _normalize_continuity_snapshot(state: dict[str, Any]) -> dict[str, Any]:
|
||||
normalized = dict(state)
|
||||
current_agent = normalized.get("current_agent")
|
||||
pending_action = _normalize_legacy_pending_action(normalized.get("pending_action"))
|
||||
clarification_context = _normalize_legacy_clarification_context(
|
||||
normalized.get("clarification_context"),
|
||||
pending_action,
|
||||
current_agent,
|
||||
)
|
||||
continuity_state = _normalize_legacy_continuity_state(
|
||||
normalized.get("continuity_state"),
|
||||
clarification_context,
|
||||
)
|
||||
turn_context = _normalize_legacy_turn_context(normalized.get("turn_context"), current_agent)
|
||||
if pending_action is not None:
|
||||
normalized["pending_action"] = pending_action
|
||||
if clarification_context is not None:
|
||||
normalized["clarification_context"] = clarification_context
|
||||
if continuity_state is not None:
|
||||
normalized["continuity_state"] = continuity_state
|
||||
if turn_context is not None:
|
||||
normalized["turn_context"] = turn_context
|
||||
return normalized
|
||||
|
||||
|
||||
def _build_continuity_snapshot(state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
normalized_state = _normalize_continuity_snapshot(state)
|
||||
snapshot = {
|
||||
field: state.get(field)
|
||||
field: normalized_state.get(field)
|
||||
for field in _CONTINUITY_SNAPSHOT_FIELDS
|
||||
if state.get(field) is not None
|
||||
if normalized_state.get(field) is not None
|
||||
}
|
||||
if not snapshot:
|
||||
return None
|
||||
@@ -116,7 +277,7 @@ def _extract_continuity_snapshot(payload: Any) -> dict[str, Any] | None:
|
||||
return None
|
||||
state = payload.get("state")
|
||||
if isinstance(state, dict):
|
||||
return state
|
||||
return _normalize_continuity_snapshot(state)
|
||||
return None
|
||||
|
||||
|
||||
@@ -187,7 +348,7 @@ class AgentService:
|
||||
return None
|
||||
|
||||
async def _load_continuity_snapshot(self, conversation: Conversation) -> dict[str, Any] | None:
|
||||
snapshot = _extract_continuity_snapshot(conversation.agent_state)
|
||||
snapshot = _extract_continuity_snapshot(getattr(conversation, "agent_state", None))
|
||||
if snapshot:
|
||||
return snapshot
|
||||
|
||||
@@ -358,6 +519,7 @@ class AgentService:
|
||||
current_datetime_reference=current_datetime_reference,
|
||||
user_llm_config=user_llm_config,
|
||||
)
|
||||
state.update(_derive_role_memory_contexts(memory_ctx))
|
||||
|
||||
yield self._build_progress_event("thinking", "Jarvis 正在分析请求", agent="master", step="理解你的问题")
|
||||
|
||||
@@ -464,7 +626,10 @@ class AgentService:
|
||||
"kind": "agent_continuity_state",
|
||||
**continuity_snapshot,
|
||||
}] if continuity_snapshot else None)
|
||||
conv.agent_state = continuity_snapshot
|
||||
conv.agent_state = ({
|
||||
"kind": "agent_continuity_state",
|
||||
**continuity_snapshot,
|
||||
} if continuity_snapshot else None)
|
||||
await BrainService(self.db).create_event(
|
||||
user_id,
|
||||
**_build_assistant_event_payload(collected),
|
||||
@@ -557,7 +722,7 @@ class AgentService:
|
||||
current_datetime_reference=current_datetime_reference,
|
||||
user_llm_config=user_llm_config,
|
||||
)
|
||||
|
||||
state.update(_derive_role_memory_contexts(memory_ctx))
|
||||
result_state = await graph.ainvoke(state)
|
||||
response_content = result_state.get("final_response") or str(result_state.get("messages", [AIMessage(content="")])[-1].content)
|
||||
except Exception as e:
|
||||
@@ -585,7 +750,10 @@ class AgentService:
|
||||
"kind": "agent_continuity_state",
|
||||
**continuity_snapshot,
|
||||
}] if continuity_snapshot else None)
|
||||
conv.agent_state = continuity_snapshot
|
||||
conv.agent_state = ({
|
||||
"kind": "agent_continuity_state",
|
||||
**continuity_snapshot,
|
||||
} if continuity_snapshot else None)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(assistant_msg)
|
||||
|
||||
|
||||
@@ -4,12 +4,15 @@ Jarvis 记忆系统 (基于 Mem0)
|
||||
底层使用 Mem0 实现事实提取、时间线、矛盾解决和遗忘机制
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
import re
|
||||
from datetime import UTC, datetime
|
||||
from typing import Optional, Any
|
||||
from sqlalchemy import select, desc, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.models.conversation import Conversation, Message
|
||||
from app.models.memory import UserMemory
|
||||
from app.models.user import User
|
||||
from app.services.brain_service import BrainService
|
||||
from app.config import settings as _settings
|
||||
@@ -23,6 +26,9 @@ except ImportError:
|
||||
Memory = None
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _get_user_embedding_config(db: AsyncSession, user_id: str) -> dict | None:
|
||||
"""从用户配置中获取 embedding 模型配置"""
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
@@ -296,6 +302,23 @@ async def extract_user_memories(
|
||||
return []
|
||||
|
||||
|
||||
def _extract_memory_query_tokens(query: str) -> list[str]:
|
||||
normalized_query = (query or "").lower()
|
||||
tokens = [token for token in re.findall(r"[a-z0-9]+", normalized_query) if len(token) >= 3]
|
||||
|
||||
for chunk in re.findall(r"[\u4e00-\u9fff]+", query or ""):
|
||||
stripped_chunk = chunk.strip()
|
||||
if len(stripped_chunk) >= 4:
|
||||
tokens.append(stripped_chunk)
|
||||
if len(stripped_chunk) > 6:
|
||||
tokens.extend(
|
||||
stripped_chunk[index:index + 4]
|
||||
for index in range(len(stripped_chunk) - 3)
|
||||
)
|
||||
|
||||
return list(dict.fromkeys(tokens))
|
||||
|
||||
|
||||
async def recall_user_memories(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
@@ -304,7 +327,7 @@ async def recall_user_memories(
|
||||
) -> list[dict]:
|
||||
"""
|
||||
根据当前输入召回相关的用户记忆。
|
||||
使用 Mem0 的语义搜索。
|
||||
使用 Mem0 的语义搜索;如果 Mem0 不可用或失败,则回退到本地 UserMemory。
|
||||
"""
|
||||
try:
|
||||
mem0 = await get_mem0(db, user_id)
|
||||
@@ -313,10 +336,56 @@ async def recall_user_memories(
|
||||
filters={"user_id": user_id},
|
||||
limit=top_k,
|
||||
)
|
||||
return results.get("results", [])
|
||||
mem0_results = results.get("results", [])
|
||||
if mem0_results:
|
||||
return mem0_results
|
||||
except Exception as e:
|
||||
print(f"Mem0 search error: {e}")
|
||||
return []
|
||||
|
||||
query_tokens = _extract_memory_query_tokens(query)
|
||||
statement = select(UserMemory).where(UserMemory.user_id == user_id)
|
||||
result = await db.execute(statement.order_by(UserMemory.importance.desc(), UserMemory.created_at.desc()))
|
||||
fallback_memories = list(result.scalars().all())
|
||||
|
||||
if _contains_hint(_normalize_query(query), MEMORY_QUERY_HINTS) or _matches_memory_query_pattern(_normalize_query(query)):
|
||||
return fallback_memories[:top_k]
|
||||
|
||||
if query_tokens:
|
||||
matched_memories = [
|
||||
memory for memory in fallback_memories
|
||||
if any(token in (memory.content or '').lower() for token in query_tokens)
|
||||
]
|
||||
return matched_memories[:top_k]
|
||||
|
||||
return []
|
||||
|
||||
|
||||
async def _mark_memories_recalled(db: AsyncSession, memories: list[UserMemory]) -> None:
|
||||
recalled_at = datetime.now(UTC)
|
||||
updated = False
|
||||
for memory in memories:
|
||||
memory.is_recalled = True
|
||||
memory.recall_count = (memory.recall_count or 0) + 1
|
||||
memory.last_recalled_at = recalled_at
|
||||
updated = True
|
||||
if updated:
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def _run_tolerated_section(
|
||||
db: AsyncSession,
|
||||
section_name: str,
|
||||
builder,
|
||||
) -> str:
|
||||
try:
|
||||
return await builder()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[MemoryService] %s失败,继续构建剩余上下文",
|
||||
section_name,
|
||||
exc_info=True,
|
||||
)
|
||||
return ""
|
||||
|
||||
|
||||
async def get_user_profile(db: AsyncSession, user_id: str) -> dict:
|
||||
@@ -339,6 +408,131 @@ async def get_user_profile(db: AsyncSession, user_id: str) -> dict:
|
||||
|
||||
# ———— 记忆组装: 供 Agent 使用的上下文 ————
|
||||
|
||||
MEMORY_QUERY_HINTS = (
|
||||
"记住",
|
||||
"记下",
|
||||
"记一下",
|
||||
"记着",
|
||||
"提醒",
|
||||
"偏好",
|
||||
"习惯",
|
||||
)
|
||||
MEMORY_QUERY_PATTERNS = (
|
||||
re.compile(r"\bremember\s+(?:that\s+)?i\b"),
|
||||
)
|
||||
GROUNDING_QUERY_HINTS = (
|
||||
"根据文档",
|
||||
"严格根据",
|
||||
"只根据",
|
||||
"文档内容",
|
||||
"grounded",
|
||||
"strictly based on",
|
||||
"based on the document",
|
||||
"based on the docs",
|
||||
"document only",
|
||||
"docs only",
|
||||
"only use the document",
|
||||
"only use the docs",
|
||||
)
|
||||
AVOID_USER_MEMORY_HINTS = (
|
||||
"不要结合我的个人偏好",
|
||||
"不要结合个人偏好",
|
||||
"不要结合偏好",
|
||||
"不要结合我的记忆",
|
||||
"不要结合记忆",
|
||||
)
|
||||
|
||||
|
||||
def _normalize_query(text: str) -> str:
|
||||
return text.strip().lower()
|
||||
|
||||
|
||||
def _contains_hint(text: str, hints: tuple[str, ...]) -> bool:
|
||||
return any(hint in text for hint in hints)
|
||||
|
||||
|
||||
def _matches_memory_query_pattern(text: str) -> bool:
|
||||
return any(pattern.search(text) for pattern in MEMORY_QUERY_PATTERNS)
|
||||
|
||||
|
||||
def _should_include_user_memories(query: str) -> bool:
|
||||
normalized_query = _normalize_query(query)
|
||||
if _contains_hint(normalized_query, GROUNDING_QUERY_HINTS):
|
||||
return False
|
||||
if _contains_hint(normalized_query, AVOID_USER_MEMORY_HINTS):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _should_include_summaries(query: str) -> bool:
|
||||
normalized_query = _normalize_query(query)
|
||||
if _contains_hint(normalized_query, GROUNDING_QUERY_HINTS):
|
||||
return False
|
||||
if _contains_hint(normalized_query, MEMORY_QUERY_HINTS):
|
||||
return False
|
||||
if _matches_memory_query_pattern(normalized_query):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
async def _build_user_memory_section(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
current_query: str,
|
||||
) -> str:
|
||||
memories = await recall_user_memories(db, user_id, current_query, top_k=5)
|
||||
if not memories:
|
||||
return ""
|
||||
|
||||
lines = []
|
||||
recalled_user_memories: list[UserMemory] = []
|
||||
for memory in memories:
|
||||
if isinstance(memory, UserMemory):
|
||||
memory_text = memory.content
|
||||
memory_type = memory.memory_type
|
||||
recalled_user_memories.append(memory)
|
||||
else:
|
||||
memory_text = memory.get("memory", memory.get("text", ""))
|
||||
memory_type = memory.get("memory_type")
|
||||
|
||||
if not memory_text:
|
||||
continue
|
||||
|
||||
if memory_type:
|
||||
lines.append(f" [{memory_type}] {memory_text}")
|
||||
else:
|
||||
lines.append(f" - {memory_text}")
|
||||
|
||||
if not lines:
|
||||
return ""
|
||||
|
||||
if recalled_user_memories:
|
||||
await _mark_memories_recalled(db, recalled_user_memories)
|
||||
return "【用户记忆】\n" + "\n".join(lines)
|
||||
|
||||
|
||||
async def _build_summary_section(db: AsyncSession, conversation_id: str) -> str:
|
||||
summaries = await get_summaries(db, conversation_id)
|
||||
if not summaries:
|
||||
return ""
|
||||
|
||||
recent = summaries[-2:]
|
||||
lines = [f"[对话摘要{i + 1}] {summary.summary_text}" for i, summary in enumerate(recent)]
|
||||
return "【之前对话摘要】\n" + "\n".join(lines)
|
||||
|
||||
|
||||
async def _build_brain_section(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
current_query: str,
|
||||
) -> str:
|
||||
brain_memories = await BrainService(db).recall_memories(user_id, current_query, top_k=3)
|
||||
if not brain_memories:
|
||||
return ""
|
||||
|
||||
lines = [f"- {memory.title}: {memory.content}" for memory in brain_memories]
|
||||
return "【知识大脑】\n" + "\n".join(lines)
|
||||
|
||||
|
||||
async def build_memory_context(
|
||||
db: AsyncSession,
|
||||
@@ -350,30 +544,33 @@ async def build_memory_context(
|
||||
构建完整的记忆上下文字符串,
|
||||
供注入到 Agent system prompt 中使用。
|
||||
"""
|
||||
parts = []
|
||||
parts: list[str] = []
|
||||
|
||||
memories = await recall_user_memories(db, user_id, current_query, top_k=5)
|
||||
if memories:
|
||||
lines = []
|
||||
for m in memories:
|
||||
memory_text = m.get("memory", m.get("text", ""))
|
||||
if memory_text:
|
||||
lines.append(f" - {memory_text}")
|
||||
if lines:
|
||||
parts.append("【用户记忆】\n" + "\n".join(lines))
|
||||
if _should_include_user_memories(current_query):
|
||||
user_memory_section = await _run_tolerated_section(
|
||||
db,
|
||||
"用户记忆召回",
|
||||
lambda: _build_user_memory_section(db, user_id, current_query),
|
||||
)
|
||||
if user_memory_section:
|
||||
parts.append(user_memory_section)
|
||||
|
||||
summaries = await get_summaries(db, conversation_id)
|
||||
if summaries:
|
||||
recent = summaries[-2:]
|
||||
lines = [f"[对话摘要{i + 1}] {s.summary_text}" for i, s in enumerate(recent)]
|
||||
parts.append("【之前对话摘要】\n" + "\n".join(lines))
|
||||
if _should_include_summaries(current_query):
|
||||
summary_section = await _run_tolerated_section(
|
||||
db,
|
||||
"对话摘要加载",
|
||||
lambda: _build_summary_section(db, conversation_id),
|
||||
)
|
||||
if summary_section:
|
||||
parts.append(summary_section)
|
||||
|
||||
brain_memories = await BrainService(db).recall_memories(user_id, current_query, top_k=3)
|
||||
if brain_memories:
|
||||
lines = []
|
||||
for memory in brain_memories:
|
||||
lines.append(f"- {memory.title}: {memory.content}")
|
||||
parts.append("【知识大脑】\n" + "\n".join(lines))
|
||||
brain_section = await _run_tolerated_section(
|
||||
db,
|
||||
"知识大脑召回",
|
||||
lambda: _build_brain_section(db, user_id, current_query),
|
||||
)
|
||||
if brain_section:
|
||||
parts.append(brain_section)
|
||||
|
||||
if not parts:
|
||||
return ""
|
||||
|
||||
Reference in New Issue
Block a user