# Planning helpers that turn a TaskGraph into an ordered list of SubTaskSpec
# objects, plus small utilities for inspecting the resulting dependency links.
from __future__ import annotations

from collections import defaultdict, deque
from uuid import uuid4

from app.agents.orchestration.budget import build_subtask_budget
from app.agents.schemas.orchestration import SubTaskSpec, TaskGraph, TaskNode


class ParallelExecutionScheduler:
    """Build one ``SubTaskSpec`` per node of a ``TaskGraph``, in dependency order."""

    def plan(self, task_graph: TaskGraph, *, query_text: str) -> list[SubTaskSpec]:
        """Return the sub-task specs for *task_graph*.

        Nodes are visited in the order produced by ``_topological_nodes``.

        Args:
            task_graph: Graph whose nodes are converted to specs.
            query_text: Query string; used as the goal for nodes that have no
                goal of their own and copied into every context slice.

        Returns:
            One ``SubTaskSpec`` per visited node, in visiting order.
        """
        return [
            self._spec_for_node(node, task_graph, query_text)
            for node in _topological_nodes(task_graph)
        ]

    def _spec_for_node(
        self,
        node: TaskNode,
        task_graph: TaskGraph,
        query_text: str,
    ) -> SubTaskSpec:
        """Build the budget and the ``SubTaskSpec`` for a single graph node."""
        # The budget is computed first; only its ``max_tool_calls`` is read below.
        node_budget = build_subtask_budget(
            execution_mode=node.execution_mode,
            # Never request fewer than one parallel task.
            max_parallel_tasks=max(1, task_graph.max_parallelism),
            metadata={
                "task_graph_id": task_graph.graph_id,
                "depends_on": node.depends_on,
            },
        )
        return SubTaskSpec(
            subtask_id=node.node_id,
            parent_run_id=task_graph.graph_id,
            title=node.title,
            # Falsy role / goal fall back to "master" / the query text.
            role=node.role or "master",
            goal=node.goal or query_text,
            context_slice=_build_context_slice(node, query_text),
            allowed_tools=[],
            budget_tokens=1200,
            # A falsy tool-call limit (None or 0) falls back to 2.
            budget_tool_calls=node_budget.max_tool_calls or 2,
            expected_output_schema={
                "summary": "string",
                "evidence": "list",
                "status": "completed|failed|blocked",
            },
            expected_evidence=node.expected_evidence,
            dependencies=node.depends_on,
        )


def build_subtask_specs(task_graph: TaskGraph, *, query_text: str) -> list[SubTaskSpec]:
    """Plan *task_graph* with a fresh ``ParallelExecutionScheduler``."""
    scheduler = ParallelExecutionScheduler()
    return scheduler.plan(task_graph, query_text=query_text)


def _build_context_slice(node: TaskNode, query_text: str) -> dict[str, object]:
    """Return the context mapping attached to the spec built for *node*.

    The mapping carries the query text together with the node's raw role,
    title, goal and dependency list (no fallbacks are applied here).
    """
    context: dict[str, object] = {
        "query": query_text,
        "role": node.role,
        "title": node.title,
        "goal": node.goal,
        "depends_on": node.depends_on,
    }
    return context


def _topological_nodes(task_graph: TaskGraph) -> list[TaskNode]:
    """Return the graph's nodes ordered so dependencies come before dependents.

    Dependencies naming a node id that is not in the graph are ignored for
    ordering purposes. If not every node can be emitted (for example because
    of a dependency cycle or duplicated node ids), a copy of the graph's own
    node order is returned instead.
    """
    nodes = task_graph.nodes

    nodes_by_id: dict[str, TaskNode] = {}
    pending: dict[str, int] = {}  # node id -> number of unmet dependencies
    for item in nodes:
        nodes_by_id[item.node_id] = item
        pending[item.node_id] = 0

    dependents: dict[str, list[str]] = defaultdict(list)
    for item in nodes:
        for dep_id in item.depends_on:
            if dep_id in nodes_by_id:
                dependents[dep_id].append(item.node_id)
                pending[item.node_id] += 1

    # Start from every node that has no (known) dependencies.
    queue = deque(node_id for node_id, unmet in pending.items() if not unmet)
    result: list[TaskNode] = []
    while queue:
        current = queue.popleft()
        result.append(nodes_by_id[current])
        # ``.get`` avoids inserting empty entries into the defaultdict.
        for child_id in dependents.get(current, ()):
            pending[child_id] -= 1
            if not pending[child_id]:
                queue.append(child_id)

    # Incomplete ordering: fall back to the original node order.
    return result if len(result) == len(nodes) else list(nodes)


def ensure_child_links(specs: list[SubTaskSpec]) -> dict[str, list[str]]:
    """Map each dependency id to the ids of the specs that depend on it.

    Only ids that appear in at least one spec's ``dependencies`` become keys;
    children are listed in the order the specs are given.
    """
    children: dict[str, list[str]] = {}
    for spec in specs:
        for parent_id in spec.dependencies:
            children.setdefault(parent_id, []).append(spec.subtask_id)
    return children