fix: harden L3 runtime continuity and tool execution

Align the L3 graph, agent service, and sync tool shims on one canonical continuity contract so clarification resumes and persisted snapshots behave consistently. Add targeted regressions and hardening notes covering system-message coalescing, async bridge usage, and continuity rehydration.
This commit is contained in:
2026-04-03 13:14:59 +08:00
parent b3f9b5e715
commit 4972b4e6b1
18 changed files with 4755 additions and 735 deletions

View File

@@ -30,6 +30,56 @@ from app.agents.state import initial_state
logger = logging.getLogger(__name__)
MEMORY_SECTION_HEADERS = (
"【用户记忆】",
"【之前对话摘要】",
"【知识大脑】",
)
def _split_memory_context_sections(memory_context: str | None) -> dict[str, str]:
text = (memory_context or "").strip()
if not text:
return {}
sections: dict[str, str] = {}
current_header: str | None = None
current_lines: list[str] = []
for line in text.splitlines():
stripped = line.strip()
if stripped in MEMORY_SECTION_HEADERS:
if current_header and current_lines:
sections[current_header] = "\n".join(current_lines).strip()
current_header = stripped
current_lines = [stripped]
continue
if current_header:
current_lines.append(line)
if current_header and current_lines:
sections[current_header] = "\n".join(current_lines).strip()
return sections
def _derive_role_memory_contexts(memory_context: str | None) -> dict[str, str | None]:
    """Build the per-role context strings from the combined memory context.

    The combined context is split into its sections, which are then recombined
    into three state fields:

    * ``schedule_context_summary``: user memories followed by summaries.
    * ``knowledge_context``: the knowledge-brain section alone.
    * ``analysis_report``: summaries followed by the knowledge-brain section.

    Each value is ``None`` when none of its source sections is present.
    """
    by_header = _split_memory_context_sections(memory_context)

    def _combine(*headers: str) -> str | None:
        # Keep only sections that exist and are non-empty, in the given order.
        present = [by_header[header] for header in headers if by_header.get(header)]
        if not present:
            return None
        return "\n\n".join(present)

    return {
        "schedule_context_summary": _combine("【用户记忆】", "【之前对话摘要】"),
        "knowledge_context": by_header.get("【知识大脑】"),
        "analysis_report": _combine("【之前对话摘要】", "【知识大脑】"),
    }
def _is_streaming_rejection_error(error: Exception, user_llm_config: dict | None) -> bool:
capabilities = resolve_provider_capabilities(user_llm_config)
error_text = str(error).lower()
@@ -87,11 +137,122 @@ _CONTINUITY_SNAPSHOT_FIELDS = (
)
def _normalize_legacy_turn_context(turn_context: Any, current_agent: Any) -> dict[str, Any] | None:
if not isinstance(turn_context, dict):
return None
normalized = dict(turn_context)
active_agent = normalized.pop("active_agent", None)
active_sub_flow = normalized.pop("active_sub_flow", None)
if isinstance(active_agent, str) and active_agent and "active_agent" not in normalized:
normalized["active_agent"] = active_agent
if isinstance(active_sub_flow, str) and active_sub_flow and "active_sub_commander" not in normalized:
normalized["active_sub_commander"] = active_sub_flow
if not normalized.get("active_agent") and isinstance(current_agent, str) and current_agent:
normalized["active_agent"] = current_agent
return normalized or None
def _normalize_legacy_pending_action(pending_action: Any) -> dict[str, Any] | None:
if not isinstance(pending_action, dict):
return None
normalized = dict(pending_action)
legacy_action_type = normalized.pop("action_type", None)
if legacy_action_type and "type" not in normalized:
normalized["type"] = legacy_action_type
legacy_agent = normalized.pop("agent", None)
legacy_sub_flow = normalized.pop("sub_flow", None)
if legacy_agent and "owner_agent" not in normalized:
normalized["owner_agent"] = legacy_agent
if legacy_sub_flow and "owner_sub_commander" not in normalized:
normalized["owner_sub_commander"] = legacy_sub_flow
legacy_status = normalized.get("status")
if legacy_status == "awaiting_confirmation":
normalized["status"] = "pending"
elif legacy_status == "awaiting_clarification":
normalized["status"] = "blocked_on_clarification"
return normalized or None
def _normalize_legacy_clarification_context(
clarification_context: Any,
pending_action: dict[str, Any] | None,
current_agent: Any,
) -> dict[str, Any] | None:
# Convert a persisted clarification context to the canonical key names.
#
# Parameters:
#   clarification_context: stored value; anything that is not a dict yields None.
#   pending_action: pending action whose "type" may supply "target_action".
#   current_agent: fallback value for "owning_agent".
# Returns the normalized copy, or None when that copy ends up empty.
#
# NOTE(review): indentation is not visible in this view, so the exact nesting
# of the "target_action" fallback steps below must be confirmed against the
# real file before relying on these comments.
if not isinstance(clarification_context, dict):
return None
# Work on a shallow copy so the caller's dict is left untouched.
normalized = dict(clarification_context)
# The three legacy keys are always removed from the copy, whether or not
# their values are reused below.
active_agent = normalized.pop("active_agent", None)
sub_flow = normalized.pop("sub_flow", None)
awaiting_user_input = normalized.pop("awaiting_user_input", None)
# Legacy owner fields are renamed, but never override an existing canonical key.
if isinstance(active_agent, str) and active_agent and "owning_agent" not in normalized:
normalized["owning_agent"] = active_agent
if isinstance(sub_flow, str) and sub_flow and "owning_sub_commander" not in normalized:
normalized["owning_sub_commander"] = sub_flow
# Derive "target_action" only when the stored context does not carry one.
if "target_action" not in normalized:
target_action = None
if pending_action:
pending_type = pending_action.get("type")
# A pending action of type "clarification" is not used as a target.
if isinstance(pending_type, str) and pending_type and pending_type != "clarification":
target_action = pending_type
# Otherwise a legacy sub_flow beginning with "create_" is used as the target.
if target_action is None and isinstance(sub_flow, str) and sub_flow.startswith("create_"):
target_action = sub_flow
if target_action:
normalized["target_action"] = target_action
# Fall back to the current agent when no owning agent is set.
if not normalized.get("owning_agent") and isinstance(current_agent, str) and current_agent:
normalized["owning_agent"] = current_agent
# The legacy boolean flag becomes status "pending" unless a status already exists.
if awaiting_user_input is True and "status" not in normalized:
normalized["status"] = "pending"
return normalized or None
def _normalize_legacy_continuity_state(
continuity_state: Any,
clarification_context: dict[str, Any] | None,
) -> dict[str, Any] | None:
# Convert a persisted continuity state to the canonical shape.
#
# Parameters:
#   continuity_state: stored value; anything that is not a dict yields None.
#   clarification_context: already-normalized clarification context, used to
#     decide whether a resume mode should be recorded.
# Returns the normalized copy, or None when that copy ends up empty.
if not isinstance(continuity_state, dict):
return None
# Work on a shallow copy so the caller's dict is left untouched.
normalized = dict(continuity_state)
# These legacy keys are dropped here without being carried over.
normalized.pop("active_agent", None)
normalized.pop("active_sub_flow", None)
legacy_status = normalized.get("status")
# The legacy "awaiting_clarification" status is rewritten to "fresh".
if legacy_status == "awaiting_clarification":
normalized["status"] = "fresh"
# NOTE(review): indentation is not visible in this view; confirm whether the
# mode assignment below is nested inside the status branch above or runs for
# every state that has a clarification context.
if clarification_context and "mode" not in normalized:
normalized["mode"] = "resume_after_clarification"
return normalized or None
def _normalize_continuity_snapshot(state: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of *state* with its continuity fields normalized.

    The pending action is normalized first because the clarification context
    consults it, and the clarification context is normalized before the
    continuity state for the same reason. A field is overwritten only when its
    normalizer returns a value; otherwise the stored value is left as it was.
    """
    snapshot = dict(state)
    agent = snapshot.get("current_agent")

    pending = _normalize_legacy_pending_action(snapshot.get("pending_action"))
    clarification = _normalize_legacy_clarification_context(
        snapshot.get("clarification_context"),
        pending,
        agent,
    )
    continuity = _normalize_legacy_continuity_state(
        snapshot.get("continuity_state"),
        clarification,
    )
    turn = _normalize_legacy_turn_context(snapshot.get("turn_context"), agent)

    replacements = {
        "pending_action": pending,
        "clarification_context": clarification,
        "continuity_state": continuity,
        "turn_context": turn,
    }
    snapshot.update(
        {field: value for field, value in replacements.items() if value is not None}
    )
    return snapshot
def _build_continuity_snapshot(state: dict[str, Any]) -> dict[str, Any] | None:
normalized_state = _normalize_continuity_snapshot(state)
snapshot = {
field: state.get(field)
field: normalized_state.get(field)
for field in _CONTINUITY_SNAPSHOT_FIELDS
if state.get(field) is not None
if normalized_state.get(field) is not None
}
if not snapshot:
return None
@@ -116,7 +277,7 @@ def _extract_continuity_snapshot(payload: Any) -> dict[str, Any] | None:
return None
state = payload.get("state")
if isinstance(state, dict):
return state
return _normalize_continuity_snapshot(state)
return None
@@ -187,7 +348,7 @@ class AgentService:
return None
async def _load_continuity_snapshot(self, conversation: Conversation) -> dict[str, Any] | None:
snapshot = _extract_continuity_snapshot(conversation.agent_state)
snapshot = _extract_continuity_snapshot(getattr(conversation, "agent_state", None))
if snapshot:
return snapshot
@@ -358,6 +519,7 @@ class AgentService:
current_datetime_reference=current_datetime_reference,
user_llm_config=user_llm_config,
)
state.update(_derive_role_memory_contexts(memory_ctx))
yield self._build_progress_event("thinking", "Jarvis 正在分析请求", agent="master", step="理解你的问题")
@@ -464,7 +626,10 @@ class AgentService:
"kind": "agent_continuity_state",
**continuity_snapshot,
}] if continuity_snapshot else None)
conv.agent_state = continuity_snapshot
conv.agent_state = ({
"kind": "agent_continuity_state",
**continuity_snapshot,
} if continuity_snapshot else None)
await BrainService(self.db).create_event(
user_id,
**_build_assistant_event_payload(collected),
@@ -557,7 +722,7 @@ class AgentService:
current_datetime_reference=current_datetime_reference,
user_llm_config=user_llm_config,
)
state.update(_derive_role_memory_contexts(memory_ctx))
result_state = await graph.ainvoke(state)
response_content = result_state.get("final_response") or str(result_state.get("messages", [AIMessage(content="")])[-1].content)
except Exception as e:
@@ -585,7 +750,10 @@ class AgentService:
"kind": "agent_continuity_state",
**continuity_snapshot,
}] if continuity_snapshot else None)
conv.agent_state = continuity_snapshot
conv.agent_state = ({
"kind": "agent_continuity_state",
**continuity_snapshot,
} if continuity_snapshot else None)
await self.db.commit()
await self.db.refresh(assistant_msg)

View File

@@ -4,12 +4,15 @@ Jarvis 记忆系统 (基于 Mem0)
底层使用 Mem0 实现事实提取、时间线、矛盾解决和遗忘机制
"""
import logging
import os
from datetime import datetime
import re
from datetime import UTC, datetime
from typing import Optional, Any
from sqlalchemy import select, desc, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.conversation import Conversation, Message
from app.models.memory import UserMemory
from app.models.user import User
from app.services.brain_service import BrainService
from app.config import settings as _settings
@@ -23,6 +26,9 @@ except ImportError:
Memory = None
logger = logging.getLogger(__name__)
async def _get_user_embedding_config(db: AsyncSession, user_id: str) -> dict | None:
"""从用户配置中获取 embedding 模型配置"""
result = await db.execute(select(User).where(User.id == user_id))
@@ -296,6 +302,23 @@ async def extract_user_memories(
return []
def _extract_memory_query_tokens(query: str) -> list[str]:
normalized_query = (query or "").lower()
tokens = [token for token in re.findall(r"[a-z0-9]+", normalized_query) if len(token) >= 3]
for chunk in re.findall(r"[\u4e00-\u9fff]+", query or ""):
stripped_chunk = chunk.strip()
if len(stripped_chunk) >= 4:
tokens.append(stripped_chunk)
if len(stripped_chunk) > 6:
tokens.extend(
stripped_chunk[index:index + 4]
for index in range(len(stripped_chunk) - 3)
)
return list(dict.fromkeys(tokens))
async def recall_user_memories(
db: AsyncSession,
user_id: str,
@@ -304,7 +327,7 @@ async def recall_user_memories(
) -> list[dict]:
"""
根据当前输入召回相关的用户记忆。
使用 Mem0 的语义搜索。
使用 Mem0 的语义搜索;如果 Mem0 不可用或失败,则回退到本地 UserMemory
"""
try:
mem0 = await get_mem0(db, user_id)
@@ -313,10 +336,56 @@ async def recall_user_memories(
filters={"user_id": user_id},
limit=top_k,
)
return results.get("results", [])
mem0_results = results.get("results", [])
if mem0_results:
return mem0_results
except Exception as e:
print(f"Mem0 search error: {e}")
return []
query_tokens = _extract_memory_query_tokens(query)
statement = select(UserMemory).where(UserMemory.user_id == user_id)
result = await db.execute(statement.order_by(UserMemory.importance.desc(), UserMemory.created_at.desc()))
fallback_memories = list(result.scalars().all())
if _contains_hint(_normalize_query(query), MEMORY_QUERY_HINTS) or _matches_memory_query_pattern(_normalize_query(query)):
return fallback_memories[:top_k]
if query_tokens:
matched_memories = [
memory for memory in fallback_memories
if any(token in (memory.content or '').lower() for token in query_tokens)
]
return matched_memories[:top_k]
return []
async def _mark_memories_recalled(db: AsyncSession, memories: list[UserMemory]) -> None:
    """Record a recall on every memory in *memories* and commit the change.

    Each memory is flagged as recalled, its recall counter is incremented (a
    missing counter is treated as zero) and its last-recalled timestamp is set
    to one shared timezone-aware UTC instant. The session is committed only
    when at least one memory was updated.
    """
    stamp = datetime.now(UTC)
    touched = 0
    for touched, record in enumerate(memories, start=1):
        record.is_recalled = True
        record.recall_count = (record.recall_count or 0) + 1
        record.last_recalled_at = stamp
    if touched:
        await db.commit()
async def _run_tolerated_section(
    db: AsyncSession,
    section_name: str,
    builder,
) -> str:
    """Await *builder* and return its text, tolerating any failure.

    Args:
        db: Accepted for call-site symmetry; this function does not use it.
        section_name: Human-readable section label, used in the warning log.
        builder: Zero-argument callable returning an awaitable that yields the
            section text.

    Returns:
        The builder's result, or ``""`` when it raised. The exception is
        logged with its traceback and not propagated, so one failing section
        does not stop the remaining context from being built.
    """
    # NOTE(review): `db` is never touched here, so a failed builder leaves the
    # session in whatever state the error produced. Confirm that is intended.
    section_text = ""
    try:
        section_text = await builder()
    except Exception:
        logger.warning(
            "[MemoryService] %s失败,继续构建剩余上下文",
            section_name,
            exc_info=True,
        )
    return section_text
async def get_user_profile(db: AsyncSession, user_id: str) -> dict:
@@ -339,6 +408,131 @@ async def get_user_profile(db: AsyncSession, user_id: str) -> dict:
# ———— 记忆组装: 供 Agent 使用的上下文 ————
# Substring hints marking a query as a "remember this / my preferences" request.
# They are compared against the stripped, lower-cased query; a hit makes
# _should_include_summaries() return False.
MEMORY_QUERY_HINTS = (
"记住",
"记下",
"记一下",
"记着",
"提醒",
"偏好",
"习惯",
)
# Compiled patterns for memory-style requests that plain substring hints cannot
# express (currently English "remember [that] I"). They are searched against
# the lower-cased query, so the patterns themselves are written in lower case.
MEMORY_QUERY_PATTERNS = (
re.compile(r"\bremember\s+(?:that\s+)?i\b"),
)
# Substring hints for requests that ask for an answer grounded in a document.
# A hit suppresses both the user-memory section and the summary section.
# Entries are lower case because the query is lower-cased before matching.
GROUNDING_QUERY_HINTS = (
"根据文档",
"严格根据",
"只根据",
"文档内容",
"grounded",
"strictly based on",
"based on the document",
"based on the docs",
"document only",
"docs only",
"only use the document",
"only use the docs",
)
# Substring hints with which the user explicitly opts out of personal
# preferences or memories; a hit suppresses the user-memory section.
AVOID_USER_MEMORY_HINTS = (
"不要结合我的个人偏好",
"不要结合个人偏好",
"不要结合偏好",
"不要结合我的记忆",
"不要结合记忆",
)
def _normalize_query(text: str) -> str:
return text.strip().lower()
def _contains_hint(text: str, hints: tuple[str, ...]) -> bool:
return any(hint in text for hint in hints)
def _matches_memory_query_pattern(text: str) -> bool:
    """Return True when any pattern in ``MEMORY_QUERY_PATTERNS`` is found in *text*.

    ``search`` is used, so a pattern may match anywhere in the text.
    """
    for compiled in MEMORY_QUERY_PATTERNS:
        if compiled.search(text) is not None:
            return True
    return False
def _should_include_user_memories(query: str) -> bool:
    """Decide whether the user-memory section belongs in the context.

    The section is left out when the normalized query asks for a
    document-grounded answer or explicitly opts out of personal
    preferences or memories; otherwise it is included.
    """
    needle = _normalize_query(query)
    opted_out = (
        _contains_hint(needle, GROUNDING_QUERY_HINTS)
        or _contains_hint(needle, AVOID_USER_MEMORY_HINTS)
    )
    return not opted_out
def _should_include_summaries(query: str) -> bool:
    """Decide whether the conversation-summary section belongs in the context.

    Summaries are left out when the normalized query asks for a
    document-grounded answer, or when it reads as a memory request (a
    memory hint substring or a memory query pattern); otherwise they are
    included.
    """
    needle = _normalize_query(query)
    suppressed = (
        _contains_hint(needle, GROUNDING_QUERY_HINTS)
        or _contains_hint(needle, MEMORY_QUERY_HINTS)
        or _matches_memory_query_pattern(needle)
    )
    return not suppressed
async def _build_user_memory_section(
    db: AsyncSession,
    user_id: str,
    current_query: str,
) -> str:
    """Render the recalled user memories as the "【用户记忆】" section.

    Recalled items may be local ``UserMemory`` rows or plain dicts (read via
    their ``"memory"`` / ``"text"`` and ``"memory_type"`` keys). Items without
    text are skipped. Items with a memory type are rendered as
    ``[type] text``, the rest as ``- text``.

    Only ``UserMemory`` rows that actually appear in the rendered section are
    passed to the recall bookkeeping, so a row with empty content no longer
    has its recall counter bumped without ever being shown.

    Returns:
        The section text, or ``""`` when nothing was recalled or every
        recalled item was empty.
    """
    memories = await recall_user_memories(db, user_id, current_query, top_k=5)
    if not memories:
        return ""

    lines: list[str] = []
    rendered_user_memories: list[UserMemory] = []
    for memory in memories:
        is_local_row = isinstance(memory, UserMemory)
        if is_local_row:
            memory_text = memory.content
            memory_type = memory.memory_type
        else:
            memory_text = memory.get("memory", memory.get("text", ""))
            memory_type = memory.get("memory_type")

        if not memory_text:
            continue

        # Track the row only once it is known to be rendered.
        if is_local_row:
            rendered_user_memories.append(memory)

        if memory_type:
            lines.append(f" [{memory_type}] {memory_text}")
        else:
            lines.append(f" - {memory_text}")

    if not lines:
        return ""

    if rendered_user_memories:
        await _mark_memories_recalled(db, rendered_user_memories)

    return "【用户记忆】\n" + "\n".join(lines)
async def _build_summary_section(db: AsyncSession, conversation_id: str) -> str:
    """Render the last two conversation summaries as the "【之前对话摘要】" section.

    Summaries are numbered from 1 within the section, in the order returned
    by ``get_summaries``. Returns ``""`` when the conversation has none.
    """
    history = await get_summaries(db, conversation_id)
    if not history:
        return ""

    rendered: list[str] = []
    for position, entry in enumerate(history[-2:], start=1):
        rendered.append(f"[对话摘要{position}] {entry.summary_text}")
    return "【之前对话摘要】\n" + "\n".join(rendered)
async def _build_brain_section(
    db: AsyncSession,
    user_id: str,
    current_query: str,
) -> str:
    """Render up to three recalled brain memories as the "【知识大脑】" section.

    Each memory is rendered as ``- title: content``. Returns ``""`` when
    ``BrainService.recall_memories`` yields nothing.
    """
    service = BrainService(db)
    recalled = await service.recall_memories(user_id, current_query, top_k=3)
    if not recalled:
        return ""

    rendered: list[str] = []
    for item in recalled:
        rendered.append(f"- {item.title}: {item.content}")
    return "【知识大脑】\n" + "\n".join(rendered)
async def build_memory_context(
db: AsyncSession,
@@ -350,30 +544,33 @@ async def build_memory_context(
构建完整的记忆上下文字符串,
供注入到 Agent system prompt 中使用。
"""
parts = []
parts: list[str] = []
memories = await recall_user_memories(db, user_id, current_query, top_k=5)
if memories:
lines = []
for m in memories:
memory_text = m.get("memory", m.get("text", ""))
if memory_text:
lines.append(f" - {memory_text}")
if lines:
parts.append("【用户记忆】\n" + "\n".join(lines))
if _should_include_user_memories(current_query):
user_memory_section = await _run_tolerated_section(
db,
"用户记忆召回",
lambda: _build_user_memory_section(db, user_id, current_query),
)
if user_memory_section:
parts.append(user_memory_section)
summaries = await get_summaries(db, conversation_id)
if summaries:
recent = summaries[-2:]
lines = [f"[对话摘要{i + 1}] {s.summary_text}" for i, s in enumerate(recent)]
parts.append("【之前对话摘要】\n" + "\n".join(lines))
if _should_include_summaries(current_query):
summary_section = await _run_tolerated_section(
db,
"对话摘要加载",
lambda: _build_summary_section(db, conversation_id),
)
if summary_section:
parts.append(summary_section)
brain_memories = await BrainService(db).recall_memories(user_id, current_query, top_k=3)
if brain_memories:
lines = []
for memory in brain_memories:
lines.append(f"- {memory.title}: {memory.content}")
parts.append("【知识大脑】\n" + "\n".join(lines))
brain_section = await _run_tolerated_section(
db,
"知识大脑召回",
lambda: _build_brain_section(db, user_id, current_query),
)
if brain_section:
parts.append(brain_section)
if not parts:
return ""