feat(agents): enhance agent core with state management

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
2026-04-08 00:10:58 +08:00
parent 4702cc8ed2
commit de08165e07
4 changed files with 464 additions and 3 deletions

View File

@@ -27,6 +27,9 @@ from app.agents.prompts import (
MASTER_SYSTEM_PROMPT,
SCHEDULE_PLANNER_SYSTEM_PROMPT,
)
from app.agents.orchestration.result_merge import merge_task_results
from app.agents.orchestration.scheduler import build_subtask_specs, ensure_child_links
from app.agents.orchestration.subagent_runtime import subtask_spec_to_agent_task
from app.agents.registry import load_builtin_registry_indexes
from app.agents.runtime_metrics import (
coerce_cost_thresholds,
@@ -36,6 +39,14 @@ from app.agents.runtime_metrics import (
)
from app.agents.schemas.event import AgentEvent
from app.agents.schemas.message import AgentMessage
from app.agents.schemas.orchestration import (
ExecutionDecision,
MergeReport,
RuntimeRequestContext,
TaskGraph,
VerificationReport,
render_runtime_request_context_summary,
)
from app.agents.schemas.task import (
AgentTask,
CollaborationBudget,
@@ -44,6 +55,7 @@ from app.agents.schemas.task import (
TaskResult,
)
from app.agents.skill_registry import build_skill_context
from app.agents.skills.retriever import build_shortlisted_skill_context
from app.agents.state import AgentRole, AgentState
from app.agents.tools import SUB_COMMANDER_TOOLSETS
from app.agents.tools.time_reasoning import normalize_tool_time_arguments
@@ -1148,6 +1160,57 @@ def _parse_json_action(content: str, allowed_tools: list[str]) -> dict[str, Any]
return None
def _looks_like_internal_tool_markup(content: str) -> bool:
text = (content or "").strip()
if not text:
return False
lowered = text.lower()
xml_markers = (
"<minimax:tool_call",
"</minimax:tool_call>",
"<invoke name=",
"</invoke>",
"<parameter name=",
"</parameter>",
)
if any(marker in lowered for marker in xml_markers):
return True
return "分发说明" in text and ("<invoke name=" in lowered or "tool_call" in lowered)
def _clean_tool_result_for_user(tool_result: str | None) -> str:
text = (tool_result or "").strip()
if not text:
return ""
cleaned_lines = [
re.sub(r"^\[[^\]]+\]\s*", "", line).strip()
for line in text.splitlines()
if line.strip()
]
return "\n".join(line for line in cleaned_lines if line).strip()
def _build_internal_markup_fallback_response(
    state: AgentState,
    *,
    sub_commander: str,
) -> str | None:
    """Build a user-facing reply from the last tool result, if one is usable.

    Reads ``state["last_tool_result"]``, cleans it for display, and returns
    ``None`` when nothing remains or when ``_tool_result_indicates_failure``
    flags the cleaned text.

    Args:
        state: Agent state; only ``last_tool_result`` is read.
        sub_commander: Name of the running sub-commander. The value
            ``"librarian_retrieval"`` gets retrieval-specific wording.

    Returns:
        The reply text, or ``None`` when no usable tool result exists.
    """
    raw_result = state.get("last_tool_result")
    user_text = _clean_tool_result_for_user(raw_result)

    if not user_text:
        return None
    if _tool_result_indicates_failure(user_text):
        return None

    # Every sub-commander except the retrieval one returns the cleaned text as-is.
    if sub_commander != "librarian_retrieval":
        return user_text

    # The missing-knowledge check runs on the raw, uncleaned tool result.
    if _is_missing_knowledge_result(raw_result):
        return "这次检索没有拿到有效证据。我先不展示内部调度过程;如果您愿意,我可以直接基于常识回答,或改为联网搜索后再整理。"
    return f"我已经完成检索,直接给您可用信息:\n\n{user_text}"
def _has_active_structured_continuation(state: AgentState) -> bool:
pending_action = state.get("pending_action") or {}
routing_decision = state.get("routing_decision") or {}
@@ -1205,6 +1268,70 @@ def _build_structured_continuity_summary(state: AgentState) -> str | None:
return "\n".join(lines)
def _build_retrospective_context_summary(state: AgentState) -> str | None:
retrospectives = list(state.get("recalled_retrospectives") or [])
if not retrospectives:
return None
lines = ["【相关历史复盘】"]
for item in retrospectives[:2]:
if not isinstance(item, dict):
continue
request_summary = str(item.get("request_summary") or item.get("task_type") or "").strip()
execution_mode = str(item.get("execution_mode") or "").strip()
success_score = float(item.get("success_score") or 0.0)
reusable_patterns = list(item.get("reusable_patterns") or [])
avoid_patterns = list(item.get("avoid_patterns") or [])
summary_parts = [request_summary[:80] or execution_mode or "历史任务"]
if execution_mode:
summary_parts.append(f"mode={execution_mode}")
summary_parts.append(f"score={success_score:.2f}")
if reusable_patterns:
summary_parts.append(f"可复用={','.join(reusable_patterns[:2])}")
elif avoid_patterns:
summary_parts.append(f"避坑={','.join(avoid_patterns[:2])}")
lines.append(f"- {''.join(summary_parts)}")
return "\n".join(lines) if len(lines) > 1 else None
def _estimate_request_complexity(user_query: str, selected_roles: list[str]) -> float:
text = (user_query or "").strip()
base = min(len(text) / 120.0, 1.0)
role_boost = min(len(selected_roles) * 0.2, 0.6)
return round(min(base + role_boost, 1.0), 2)
def _record_execution_decision(
    state: AgentState,
    *,
    user_query: str,
    mode: Literal["direct", "collaboration", "parallel"],
    reason: str,
    selected_roles: list[str] | None = None,
    parallel_worthiness_score: float | None = None,
) -> None:
    """Store the routing decision on the state and emit a trace event for it.

    Builds an ``ExecutionDecision``, writes its JSON-mode dump to
    ``state["execution_decision"]``, and appends an
    ``"agent.execution.decided"`` event carrying that same payload.

    Args:
        state: Agent state; mutated in place.
        user_query: Request text, used for the complexity score.
        mode: Execution mode that was chosen.
        reason: Short reason for the choice.
        selected_roles: Roles involved; ``None`` is treated as an empty list.
        parallel_worthiness_score: Optional score passed through unchanged.
    """
    request_context = state.get("runtime_request_context") or {}

    # Prefer the runtime request id, then the conversation id, then a
    # generated placeholder built from a random UUID prefix.
    resolved_request_id = str(
        request_context.get("request_id") or state.get("conversation_id") or ""
    )
    if not resolved_request_id:
        resolved_request_id = f"request-{uuid4().hex[:8]}"

    role_names = list(selected_roles or [])
    decision_payload = ExecutionDecision(
        request_id=resolved_request_id,
        mode=mode,
        reason=reason,
        complexity_score=_estimate_request_complexity(user_query, role_names),
        parallel_worthiness_score=parallel_worthiness_score,
        selected_roles=role_names,
    ).model_dump(mode="json")

    state["execution_decision"] = decision_payload
    _append_event_trace(
        state,
        "agent.execution.decided",
        payload=decision_payload,
    )
def _build_system_messages(
state: AgentState, system_prompt: str, role: AgentRole, sub_commander: str
) -> list[BaseMessage]:
@@ -1214,6 +1341,19 @@ def _build_system_messages(
if current_datetime_context:
messages.append(SystemMessage(content=current_datetime_context))
runtime_request_context = state.get("runtime_request_context")
if isinstance(runtime_request_context, dict) and runtime_request_context:
try:
runtime_context_model = RuntimeRequestContext.model_validate(runtime_request_context)
except Exception:
runtime_context_model = None
if runtime_context_model is not None:
messages.append(
SystemMessage(
content=render_runtime_request_context_summary(runtime_context_model)
)
)
continuity_summary = _build_structured_continuity_summary(state)
if continuity_summary:
messages.append(SystemMessage(content=continuity_summary))
@@ -1226,6 +1366,10 @@ def _build_system_messages(
if collaboration_summary:
messages.append(SystemMessage(content=collaboration_summary))
retrospective_summary = _build_retrospective_context_summary(state)
if retrospective_summary:
messages.append(SystemMessage(content=retrospective_summary))
role_context_map = {
AgentRole.SCHEDULE_PLANNER: state.get("schedule_context_summary"),
AgentRole.LIBRARIAN: state.get("knowledge_context"),
@@ -1237,7 +1381,11 @@ def _build_system_messages(
role_skill_key = ROLE_SKILL_CONTEXT.get(role)
if role_skill_key:
skill_context = build_skill_context(role_skill_key)
shortlisted_context = build_shortlisted_skill_context(
state.get("skill_shortlist"),
agent_type=role_skill_key,
)
skill_context = shortlisted_context or build_skill_context(role_skill_key)
if skill_context:
messages.append(SystemMessage(content=skill_context))
@@ -1322,6 +1470,29 @@ def _build_collaboration_tasks(user_query: str) -> list[AgentTask]:
return tasks
def _build_collaboration_plan_from_task_graph(
    state: AgentState,
    user_query: str,
) -> tuple[list[AgentTask], list[dict[str, Any]]]:
    """Derive collaboration tasks from ``state["task_graph"]`` when possible.

    Falls back to ``_build_collaboration_tasks(user_query)`` with an empty
    subtask list when the state holds no task graph with nodes, or when the
    stored graph fails validation. The validation fallback is also recorded
    as an ``"agent.rollback.triggered"`` event.

    Args:
        state: Agent state; ``task_graph`` is read, and the event trace is
            appended to on validation failure.
        user_query: The user's request text, forwarded to the planners.

    Returns:
        A pair ``(tasks, scheduled_subtasks)``: the tasks to run and the
        JSON-mode dumps of the specs they came from (empty on fallback).
    """
    raw_task_graph = state.get("task_graph")
    if not isinstance(raw_task_graph, dict) or not raw_task_graph.get("nodes"):
        return _build_collaboration_tasks(user_query), []

    try:
        task_graph = TaskGraph.model_validate(raw_task_graph)
    except ValueError:
        # Assumes TaskGraph is a pydantic model, whose ValidationError derives
        # from ValueError. A malformed graph should degrade to the legacy
        # planner instead of aborting the collaboration flow.
        _append_event_trace(
            state,
            "agent.rollback.triggered",
            payload={
                "layer": "parallel_task_graph",
                "reason": "task_graph_validation_failed",
            },
            severity="warning",
        )
        return _build_collaboration_tasks(user_query), []

    # Specs with the "master" role are filtered out of the plan.
    specs = [
        spec
        for spec in build_subtask_specs(task_graph, query_text=user_query)
        if spec.role != "master"
    ]
    child_links = ensure_child_links(specs)

    tasks: list[AgentTask] = []
    for spec in specs:
        task = subtask_spec_to_agent_task(spec)
        # child_links is looked up by subtask id; a missing id means no children.
        task.child_task_ids = child_links.get(spec.subtask_id, [])
        tasks.append(task)
    return tasks, [spec.model_dump(mode="json") for spec in specs]
def _build_collaboration_context_summary(state: AgentState) -> str | None:
if state.get("execution_mode") != "collaboration":
return None
@@ -2076,7 +2247,36 @@ async def _run_sub_commander(
)
_record_response_usage(state, response)
response_text = _stringify_message_content(response.content)
response_text_stripped = response_text.strip()
parsed = _parse_json_action(response_text, allowed_tools)
if parsed is None and response_text_stripped and _looks_like_internal_tool_markup(
response_text_stripped
):
if int(state.get("retry_count") or 0) >= int(state.get("max_retries") or 0):
state["fallback_parse_error"] = "internal_tool_markup"
state["final_response"] = _build_internal_markup_fallback_response(
state,
sub_commander=sub_commander,
) or (
"这次内部调度没有正确收束成最终答复。我先不展示内部调用过程;您重试一次,我会直接用自然语言回答。"
)
break
if not _guard_sub_commander_budget(
state, "iteration_count", "max_iterations", "max_iterations_exceeded"
):
parsed = None
break
state["iteration_count"] = int(state.get("iteration_count") or 0) + 1
state["retry_count"] = int(state.get("retry_count") or 0) + 1
retry_instruction = SystemMessage(
content=(
"上一轮输出了内部调度或工具调用标记,这是协议错误。"
"不要再输出分发说明、XML 标签、<invoke>、<parameter>、JSON 或 tool_call。"
"请直接面向用户给出最终自然语言答复;如果已有工具结果,就基于结果整理;"
"如果工具没有找到证据,可以基于常识直接回答。"
)
)
continue
if parsed is None and response_text.strip() and state.get("tool_round_count"):
state["fallback_parse_error"] = None
state["final_response"] = response_text.strip()
@@ -2382,6 +2582,35 @@ def _build_collaboration_final_response(task_results: list[TaskResult | dict[str
return "\n".join(lines)
def _build_serial_fallback_response(
    user_query: str,
    task_results: list[TaskResult | dict[str, Any]],
    merge_report: MergeReport | dict[str, Any] | None,
) -> str:
    """Compose the user-facing message for the conservative fallback path.

    The message starts with a fixed notice and the original request, then
    lists up to three completed subtask results, then up to three conflict
    flags from the merge report, and ends with a suggestion when no subtask
    completed.

    Args:
        user_query: The user's original request text.
        task_results: Subtask results; each goes through
            ``normalize_task_result`` first.
        merge_report: A ``MergeReport``, a plain dict, or ``None``.

    Returns:
        The message lines joined with ``"\\n"``.
    """
    finished = [
        result
        for result in [normalize_task_result(entry) for entry in task_results]
        if result.status == "completed"
    ]

    if isinstance(merge_report, MergeReport):
        merge_payload = merge_report.model_dump(mode="json")
    else:
        merge_payload = dict(merge_report or {})

    response_lines = [
        "并行/协作结果出现冲突或失败,我已切回保守收敛路径。",
        f"原始请求:{user_query}",
    ]

    if finished:
        response_lines.append("当前仍可确认的结果:")
        response_lines.extend(
            f"- [{result.owner_agent_id or 'unknown'}] {result.summary or '已完成'}"
            for result in finished[:3]
        )

    conflict_flags = merge_payload.get("conflict_flags")
    if conflict_flags:
        response_lines.append("冲突/回退原因:")
        response_lines.extend(f"- {flag}" for flag in list(conflict_flags)[:3])

    if not finished:
        response_lines.append("目前没有足够稳定的子任务结果,建议改走 direct 或更小范围的 collaboration。")

    return "\n".join(response_lines)
def _verify_collaboration_results(
state: AgentState,
tasks: list[AgentTask],
@@ -2411,6 +2640,14 @@ def _verify_collaboration_results(
}
for item in normalized_results
]
merge_report = merge_task_results([item.model_dump(mode="json") for item in normalized_results])
state["merge_report"] = merge_report.model_dump(mode="json")
_append_event_trace(
state,
"agent.merge.completed",
payload=state["merge_report"],
)
if missing_task_ids:
summary = f"协作结果不完整,缺少任务结果: {', '.join(missing_task_ids)}"
verdict = verify_task_result(
@@ -2426,20 +2663,51 @@ def _verify_collaboration_results(
verdict = verify_task_result(
status="failed", summary=summary, evidence=verification_evidence
)
elif merge_report.status == "conflicted":
verdict = verify_task_result(
status="failed",
summary=merge_report.summary,
evidence=[
*verification_evidence,
{"type": "merge_conflict", "conflict_flags": merge_report.conflict_flags},
],
)
elif merge_report.status == "fallback":
verdict = verify_task_result(
status="failed",
summary=merge_report.summary,
evidence=[
*verification_evidence,
{"type": "merge_fallback", "conflict_flags": merge_report.conflict_flags},
],
)
else:
summary = f"协作模式已完成 {len(normalized_results)}/{len(tasks)} 个子任务,并为每个子任务回收了结果与 evidence。"
summary = (
merge_report.summary
or f"协作模式已完成 {len(normalized_results)}/{len(tasks)} 个子任务,并为每个子任务回收了结果与 evidence。"
)
verdict = verify_task_result(
status="passed", summary=summary, evidence=verification_evidence
)
updated_state = apply_verification_verdict(state, verdict)
state.update(updated_state)
state["verification_report"] = VerificationReport(
status=state.get("verification_status") or "skipped",
summary=state.get("verification_summary"),
evidence=list(state.get("verification_evidence") or []),
).model_dump(mode="json")
_append_event_trace(
state,
"agent.verify.completed",
payload=state["verification_report"],
)
async def _run_collaboration_flow(state: AgentState, user_query: str) -> AgentState:
_set_phase(state, "phase_2_controlled_collaboration", reason="collaboration_flow_started")
_record_checkpoint(state, "collaboration.tasks_planning", reason="collaboration_flow_started")
tasks = _build_collaboration_tasks(user_query)
tasks, scheduled_subtasks = _build_collaboration_plan_from_task_graph(state, user_query)
if len(tasks) < 2:
state["execution_mode"] = "direct"
state["routing_decision"] = {"mode": "direct", "reason": "collaboration_plan_fell_back"}
@@ -2487,6 +2755,7 @@ async def _run_collaboration_flow(state: AgentState, user_query: str) -> AgentSt
"agent.collaboration.budget.updated",
payload=budget_snapshot,
)
state["scheduled_subtasks"] = scheduled_subtasks
state["active_tasks"] = [task.model_dump(mode="json") for task in tasks]
_record_checkpoint(
state, "collaboration.tasks_ready", reason="tasks_built", payload={"task_count": len(tasks)}
@@ -2500,6 +2769,21 @@ async def _run_collaboration_flow(state: AgentState, user_query: str) -> AgentSt
_set_phase(state, "phase_3_dynamic_collaboration", reason="collaboration_workers_dispatch")
for task in tasks:
scheduled_subtask = next(
(item for item in scheduled_subtasks if item.get("subtask_id") == task.task_id),
None,
)
if scheduled_subtask is not None:
_append_event_trace(
state,
"agent.subtask.started",
payload={
"subtask_id": scheduled_subtask.get("subtask_id"),
"role": scheduled_subtask.get("role"),
"dependencies": scheduled_subtask.get("dependencies") or [],
},
task_id=task.task_id,
)
_record_checkpoint(
state,
"collaboration.task_dispatch",
@@ -2583,6 +2867,17 @@ async def _run_collaboration_flow(state: AgentState, user_query: str) -> AgentSt
child_task_id=(task.child_task_ids or [None])[0],
message_id=str(state.get("last_message_id") or "") or None,
)
if scheduled_subtask is not None:
_append_event_trace(
state,
"agent.subtask.completed",
payload={
"subtask_id": scheduled_subtask.get("subtask_id"),
"status": task_result.status,
"summary": task_result.summary,
},
task_id=task.task_id,
)
_apply_task_result_to_state(state, task, task_result)
if task_result.status != "completed":
@@ -2618,6 +2913,26 @@ async def _run_collaboration_flow(state: AgentState, user_query: str) -> AgentSt
},
severity="error" if state.get("verification_status") == "failed" else "info",
)
merge_report = state.get("merge_report") or {}
if state.get("verification_status") == "failed" and merge_report.get("fallback_used"):
state["final_response"] = _build_serial_fallback_response(
user_query,
state.get("task_results") or [],
merge_report,
)
state["routing_decision"] = {
"mode": "direct",
"reason": "fallback_to_serial_recovery",
}
_append_event_trace(
state,
"agent.rollback.triggered",
payload={
"layer": "collaboration_runtime",
"reason": "merge_fallback_used",
},
severity="warning",
)
_record_checkpoint(
state,
"collaboration.completed",
@@ -2679,23 +2994,99 @@ async def master_node(state: AgentState) -> AgentState:
return state
state["current_agent"] = _normalize_current_agent(state.get("current_agent"))
parallel_worthiness = state.get("parallel_worthiness")
if not isinstance(parallel_worthiness, dict):
runtime_request_context = state.get("runtime_request_context") or {}
if isinstance(runtime_request_context, dict):
candidate_parallel = runtime_request_context.get("parallel_worthiness")
if isinstance(candidate_parallel, dict):
parallel_worthiness = candidate_parallel
if isinstance(parallel_worthiness, dict) and parallel_worthiness:
_append_event_trace(
state,
"agent.parallel.assessed",
payload=parallel_worthiness,
)
skill_shortlist = list(state.get("skill_shortlist") or [])
if skill_shortlist:
_append_event_trace(
state,
"agent.skill.shortlisted",
payload={
"count": len(skill_shortlist),
"skills": [str(item.get("skill_name") or "") for item in skill_shortlist[:4]],
},
)
task_graph = state.get("task_graph")
if isinstance(task_graph, dict) and task_graph.get("nodes"):
_append_event_trace(
state,
"agent.task_graph.built",
payload={
"graph_id": task_graph.get("graph_id"),
"node_count": len(task_graph.get("nodes") or []),
"entry_node_ids": task_graph.get("entry_node_ids") or [],
"max_parallelism": task_graph.get("max_parallelism"),
},
)
elif (
isinstance(parallel_worthiness, dict)
and parallel_worthiness.get("preferred_mode") in {"collaboration", "parallel"}
and not (state.get("feature_flags") or {}).get("ENABLE_PARALLEL_TASK_GRAPH", True)
):
_append_event_trace(
state,
"agent.rollback.triggered",
payload={
"layer": "parallel_task_graph",
"reason": "feature_flag_disabled",
},
)
structured_continuity_route = _route_from_structured_continuity(state, user_query)
clarification_route = _route_from_clarification_context(state, user_query)
if structured_continuity_route is not None:
state["execution_mode"] = "direct"
routed_agent = structured_continuity_route
_record_execution_decision(
state,
user_query=user_query,
mode="direct",
reason="continue_pending_action",
selected_roles=[routed_agent.value],
)
elif clarification_route is not None:
state["execution_mode"] = "direct"
routed_agent = clarification_route
_record_execution_decision(
state,
user_query=user_query,
mode="direct",
reason="clarification_follow_up",
selected_roles=[routed_agent.value],
)
elif _is_short_confirmation(user_query) and _previous_turn_proposed_schedule_creation(
state.get("messages", [])
):
state["execution_mode"] = "direct"
routed_agent = AgentRole.SCHEDULE_PLANNER
_record_execution_decision(
state,
user_query=user_query,
mode="direct",
reason="schedule_confirmation_follow_up",
selected_roles=[routed_agent.value],
)
else:
request_mode, routing_metadata = _select_request_mode(user_query)
state["routing_decision"] = routing_metadata
_record_execution_decision(
state,
user_query=user_query,
mode=request_mode,
reason=str(routing_metadata.get("reason") or request_mode),
selected_roles=list(routing_metadata.get("roles") or []),
)
if request_mode == "collaboration":
collaboration_state = await _run_collaboration_flow(state, user_query)
if collaboration_state.get(

View File

@@ -1,5 +1,24 @@
from app.agents.schemas.event import AgentEvent
from app.agents.schemas.learning import (
LearningDecision,
LearningSignal,
PatternCandidate,
SessionRetrospective,
SkillCandidate,
)
from app.agents.schemas.message import AgentMessage
from app.agents.schemas.orchestration import (
ExecutionDecision,
MergeReport,
ParallelWorthiness,
RuntimeRequestContext,
SubTaskResult,
SubTaskSpec,
TaskGraph,
TaskNode,
VerificationReport,
)
from app.agents.schemas.skills import SkillActivationRecord, SkillShortlistEntry
from app.agents.schemas.task import (
AgentTask,
CollaborationBudget,
@@ -14,12 +33,28 @@ from app.agents.schemas.task import (
# Public names re-exported by this package, kept in alphabetical order.
__all__ = [
    "AgentEvent",
    "AgentMessage",
    "AgentTask",
    "CollaborationBudget",
    "ExecutionDecision",
    "InterruptRecord",
    "LearningDecision",
    "LearningSignal",
    "MergeReport",
    "ParallelWorthiness",
    "PatternCandidate",
    "RecoveryRecord",
    "RuntimeRequestContext",
    "SessionRetrospective",
    "SkillActivationRecord",
    "SkillCandidate",
    "SkillShortlistEntry",
    "SubTaskResult",
    "SubTaskSpec",
    "TaskGraph",
    "TaskLifecycleStatus",
    "TaskNode",
    "TaskResult",
    "TaskResultStatus",
    "VerificationReport",
    "VerificationStatus",
]

View File

@@ -7,10 +7,21 @@ from pydantic import BaseModel, Field
AgentEventType = Literal[
"agent.execution.decided",
"agent.parallel.assessed",
"agent.skill.shortlisted",
"agent.task_graph.built",
"agent.subtask.started",
"agent.subtask.completed",
"agent.merge.completed",
"agent.tool.start",
"agent.tool.result",
"agent.verify.started",
"agent.verify.completed",
"agent.retrospective.created",
"agent.learning.decision",
"agent.skill.lifecycle.changed",
"agent.rollback.triggered",
"agent.created",
"agent.spawn.blocked",
"agent.message.sent",

View File

@@ -138,6 +138,18 @@ class AgentState(TypedDict):
memory_context: str | None
current_datetime_context: str | None
current_datetime_reference: dict[str, str] | None
runtime_request_context: dict[str, Any] | None
task_graph: dict[str, Any] | None
scheduled_subtasks: list[dict[str, Any]]
recalled_retrospectives: list[dict[str, Any]]
retrospective_shortlist: list[dict[str, Any]]
skill_shortlist: list[dict[str, Any]]
skill_activation_records: list[dict[str, Any]]
execution_decision: dict[str, Any] | None
merge_report: dict[str, Any] | None
verification_report: dict[str, Any] | None
feature_flags: dict[str, bool]
observability_report: dict[str, Any] | None
turn_context: dict[str, Any] | None
routing_decision: dict[str, Any] | None
@@ -254,6 +266,18 @@ def initial_state(user_id: str, conversation_id: str) -> AgentState:
memory_context=None,
current_datetime_context=None,
current_datetime_reference=None,
runtime_request_context=None,
task_graph=None,
scheduled_subtasks=[],
recalled_retrospectives=[],
retrospective_shortlist=[],
skill_shortlist=[],
skill_activation_records=[],
execution_decision=None,
merge_report=None,
verification_report=None,
feature_flags={},
observability_report=None,
turn_context=None,
routing_decision=None,
continuity_state=None,