Files
JARVIS/backend/app/agents/orchestration/task_graph.py

129 lines
4.6 KiB
Python

from __future__ import annotations
from uuid import uuid4
from app.agents.schemas.orchestration import ParallelWorthiness, TaskGraph, TaskNode
# Ordered (role, trigger keywords) table. _infer_roles scans it top to bottom
# and selects a role when any of its keywords occurs as a substring of the
# query, so the order here is the order of the resulting task nodes.
#
# Every keyword must be non-empty: an empty string is a substring of every
# query and would make its role match unconditionally (the librarian entry
# previously carried such a blank keyword, which also made the "analyst"
# fallback in _infer_roles unreachable).
ROLE_KEYWORDS: list[tuple[str, tuple[str, ...]]] = [
    ("librarian", ("检索", "资料", "文档", "知识库", "年报", "forum", "search")),
    ("analyst", ("分析", "判断", "风险", "总结", "对比", "洞察", "结论")),
    ("schedule_planner", ("计划", "安排", "下周", "日程", "提醒", "优先级")),
    ("executor", ("执行", "创建", "更新", "落库", "提交", "发帖")),
]
def build_bounded_task_graph(
    *,
    query_text: str,
    parallel_worthiness: ParallelWorthiness,
    max_nodes: int = 4,
) -> TaskGraph | None:
    """Build a small, bounded task graph of role-based sub-tasks.

    Roles are inferred from ``query_text``. Each selected role becomes one
    node without dependencies; when more than one role is selected, a final
    ``master`` merge node depending on all of them is appended.

    Args:
        query_text: Request text used for role inference and for node goals.
        parallel_worthiness: Supplies ``estimated_subtasks`` (caps the number
            of role nodes), ``preferred_mode`` (decides parallel vs. serial
            execution) and ``score`` (reported in the rationale).
        max_nodes: Upper bound on the total node count, merge node included.

    Returns:
        The task graph, or ``None`` when no role could be inferred or when
        ``max_nodes`` leaves no room for even a single node.
    """
    if max_nodes < 1:
        # A zero/negative budget used to yield a zero or negative slice bound
        # below (an empty graph, or silently dropping the last role).
        return None

    roles = _infer_roles(query_text)
    if not roles:
        return None

    # One slot is reserved for the merge node, but at least one role node is
    # always kept: a single node needs no merge, so max_nodes == 1 still fits.
    role_limit = max(
        1,
        min(max_nodes - 1, max(1, parallel_worthiness.estimated_subtasks)),
    )
    independent_roles = roles[:role_limit]

    # The mode is the same for every role node, so decide it once.
    run_in_parallel = (
        parallel_worthiness.preferred_mode in {"collaboration", "parallel"}
        and len(independent_roles) > 1
    )
    execution_mode = "parallel" if run_in_parallel else "serial"

    nodes: list[TaskNode] = [
        TaskNode(
            # Short random suffix keeps node ids distinct across graphs.
            node_id=f"task-{index}-{uuid4().hex[:6]}",
            title=_build_title(role),
            role=role,
            goal=_build_goal(role, query_text),
            depends_on=[],
            execution_mode=execution_mode,
            expected_evidence=_build_expected_evidence(role),
        )
        for index, role in enumerate(independent_roles, start=1)
    ]

    if len(nodes) > 1:
        # Fan-in: the merge node waits for every role node built above.
        nodes.append(
            TaskNode(
                node_id=f"merge-{uuid4().hex[:6]}",
                title="汇总并收敛最终结论",
                role="master",
                goal="汇总前置子任务结果,形成统一可验证的输出。",
                depends_on=[node.node_id for node in nodes],
                execution_mode="serial",
                expected_evidence=[{"type": "merge", "detail": "merged summary and conflict notes"}],
            )
        )

    return TaskGraph(
        nodes=nodes,
        # Entry points are the nodes that do not wait on anything.
        entry_node_ids=[node.node_id for node in nodes if not node.depends_on],
        max_parallelism=max(1, len(independent_roles)),
        rationale=_build_rationale(parallel_worthiness, independent_roles),
    )
def render_task_graph_summary(task_graph: TaskGraph | None) -> str | None:
    """Render the first four nodes of a task graph as a bullet list.

    Returns ``None`` when there is no graph or the graph has no nodes.
    Otherwise returns a newline-joined string: a header line followed by one
    line per node showing its execution mode, title, role and, when present,
    the comma-separated ids it depends on.
    """
    if task_graph is None:
        return None
    if not task_graph.nodes:
        return None

    def _describe(node) -> str:
        # Only nodes with dependencies get the trailing "deps=" part.
        suffix = ""
        if node.depends_on:
            suffix = f" deps={','.join(node.depends_on)}"
        return f" - [{node.execution_mode}] {node.title} ({node.role}){suffix}"

    return "\n".join(["- 任务图:", *map(_describe, task_graph.nodes[:4])])
def _infer_roles(query_text: str) -> list[str]:
    """Return the roles whose keywords occur in ``query_text``.

    Roles come back in ``ROLE_KEYWORDS`` order. Matching is a
    case-insensitive substring test. When nothing matches (including an
    empty or ``None`` query) the result falls back to ``["analyst"]``.
    """
    # casefold so ASCII keywords such as "search" also match "Search".
    text = (query_text or "").casefold()
    selected = [
        role
        for role, keywords in ROLE_KEYWORDS
        # An empty keyword is a substring of every string and would select
        # its role unconditionally, so blank entries are ignored.
        if any(keyword and keyword.casefold() in text for keyword in keywords)
    ]
    return selected or ["analyst"]
def _build_title(role: str) -> str:
mapping = {
"librarian": "收集事实与外部/内部证据",
"analyst": "形成判断与风险分析",
"schedule_planner": "整理计划和优先级",
"executor": "执行必要操作并回收结果",
}
return mapping.get(role, "处理子任务")
def _build_goal(role: str, query_text: str) -> str:
mapping = {
"librarian": f"围绕请求收集支持结论的事实和资料:{query_text}",
"analyst": f"基于当前请求输出结构化判断:{query_text}",
"schedule_planner": f"把当前请求收束为计划、安排或优先级:{query_text}",
"executor": f"基于请求执行必要动作并返回结果:{query_text}",
}
return mapping.get(role, query_text)
def _build_expected_evidence(role: str) -> list[dict[str, str]]:
mapping = {
"librarian": [{"type": "evidence", "detail": "retrieval findings"}],
"analyst": [{"type": "analysis", "detail": "structured judgment"}],
"schedule_planner": [{"type": "plan", "detail": "explicit schedule or priorities"}],
"executor": [{"type": "execution", "detail": "tool output or mutation result"}],
}
return mapping.get(role, [{"type": "summary", "detail": "task summary"}])
def _build_rationale(parallel_worthiness: ParallelWorthiness, roles: list[str]) -> str:
return (
f"preferred_mode={parallel_worthiness.preferred_mode}; "
f"score={parallel_worthiness.score:.2f}; "
f"roles={','.join(roles)}"
)