Files
JARVIS/backend/app/agents/orchestration/scheduler.py

94 lines
3.2 KiB
Python

from __future__ import annotations
from collections import defaultdict, deque
from uuid import uuid4
from app.agents.orchestration.budget import build_subtask_budget
from app.agents.schemas.orchestration import SubTaskSpec, TaskGraph, TaskNode
class ParallelExecutionScheduler:
    """Turn a ``TaskGraph`` into a list of ``SubTaskSpec`` objects, one per node.

    Nodes are emitted in the order produced by ``_topological_nodes``. Each spec
    gets a budget from ``build_subtask_budget`` and a fixed expected-output
    schema.
    """

    # Token budget assigned to every subtask (the budget helper's result is not
    # consulted for tokens).
    DEFAULT_BUDGET_TOKENS = 1200
    # Tool-call budget used when the budget helper reports a falsy
    # ``max_tool_calls`` (None or 0).
    DEFAULT_TOOL_CALLS = 2

    def plan(self, task_graph: TaskGraph, *, query_text: str) -> list[SubTaskSpec]:
        """Build one ``SubTaskSpec`` per node of ``task_graph``.

        Args:
            task_graph: Graph whose nodes become subtasks. Its ``graph_id`` is
                used as every spec's ``parent_run_id``. Its ``max_parallelism``,
                floored at 1, is forwarded to the budget helper.
            query_text: The originating query. It is used as the goal for nodes
                without one and is copied into each spec's context slice.

        Returns:
            Specs in the order returned by ``_topological_nodes``. A node without
            a role gets ``"master"``. ``allowed_tools`` is always empty.
        """
        ordered_nodes = _topological_nodes(task_graph)
        # Same value for every node; compute it once.
        max_parallel_tasks = max(1, task_graph.max_parallelism)
        specs: list[SubTaskSpec] = []
        for node in ordered_nodes:
            budget = build_subtask_budget(
                execution_mode=node.execution_mode,
                max_parallel_tasks=max_parallel_tasks,
                metadata={
                    "task_graph_id": task_graph.graph_id,
                    # Copy so the budget metadata does not alias the node's list.
                    "depends_on": list(node.depends_on),
                },
            )
            specs.append(
                SubTaskSpec(
                    subtask_id=node.node_id,
                    parent_run_id=task_graph.graph_id,
                    title=node.title,
                    role=node.role or "master",
                    goal=node.goal or query_text,
                    context_slice=_build_context_slice(node, query_text),
                    allowed_tools=[],
                    budget_tokens=self.DEFAULT_BUDGET_TOKENS,
                    budget_tool_calls=budget.max_tool_calls or self.DEFAULT_TOOL_CALLS,
                    expected_output_schema={
                        "summary": "string",
                        "evidence": "list",
                        "status": "completed|failed|blocked",
                    },
                    expected_evidence=node.expected_evidence,
                    # Fresh list per spec; editing one spec's dependencies must
                    # not change the source node or other structures.
                    dependencies=list(node.depends_on),
                )
            )
        return specs
def build_subtask_specs(task_graph: TaskGraph, *, query_text: str) -> list[SubTaskSpec]:
    """Plan ``task_graph`` with a freshly constructed default scheduler.

    Module-level convenience entry point equivalent to
    ``ParallelExecutionScheduler().plan(task_graph, query_text=query_text)``.
    """
    scheduler = ParallelExecutionScheduler()
    specs = scheduler.plan(task_graph, query_text=query_text)
    return specs
def _build_context_slice(node: TaskNode, query_text: str) -> dict[str, object]:
    """Return the context dict attached to the subtask spec built for ``node``.

    Args:
        node: The task-graph node being turned into a subtask.
        query_text: The originating query text.

    Returns:
        A new dict with keys ``query``, ``role``, ``title``, ``goal`` and
        ``depends_on``. ``role`` and ``goal`` are the node's raw values; no
        defaults are substituted here.
    """
    return {
        "query": query_text,
        "role": node.role,
        "title": node.title,
        "goal": node.goal,
        # Copy: the slice travels with the spec. It should not alias the node's
        # own dependency list, which would let a mutation leak back into it.
        "depends_on": list(node.depends_on),
    }
def _topological_nodes(task_graph: TaskGraph) -> list[TaskNode]:
    """Order the graph's nodes so each node follows its in-graph dependencies.

    Uses Kahn's algorithm. The ready queue is seeded in declaration order and
    processed FIFO. Dependencies that name ids not present in the graph are
    ignored.

    Fallbacks (no exception is raised):

    * Duplicate node ids: the declared order is returned unchanged, because an
      id-keyed graph cannot represent both nodes.
    * Cycles: every node that could be ordered comes first, in topological
      order. The remaining nodes (those on a cycle or downstream of one) follow
      in declared order, so the acyclic part keeps a valid ordering.

    Returns:
        A new list containing every node of ``task_graph.nodes``.
    """
    nodes = list(task_graph.nodes)
    by_id = {node.node_id: node for node in nodes}
    if len(by_id) != len(nodes):
        # Duplicate ids: keep the declared order, as the original fallback did.
        return nodes

    indegree = {node_id: 0 for node_id in by_id}
    edges: dict[str, list[str]] = defaultdict(list)
    for node in nodes:
        for dep in node.depends_on:
            if dep not in by_id:
                continue  # dependency outside this graph; nothing to wait for
            edges[dep].append(node.node_id)
            indegree[node.node_id] += 1

    ready = deque(node_id for node_id, count in indegree.items() if count == 0)
    ordered: list[TaskNode] = []
    while ready:
        node_id = ready.popleft()
        ordered.append(by_id[node_id])
        for target in edges.get(node_id, []):
            indegree[target] -= 1
            if indegree[target] == 0:
                ready.append(target)

    if len(ordered) != len(nodes):
        # Cycle detected: keep the resolved prefix instead of discarding it,
        # then append the unresolved nodes in declared order.
        placed = {node.node_id for node in ordered}
        ordered.extend(node for node in nodes if node.node_id not in placed)
    return ordered
def ensure_child_links(specs: list[SubTaskSpec]) -> dict[str, list[str]]:
    """Invert each spec's ``dependencies`` into a parent -> children map.

    Keys are dependency ids, in first-seen order. Values are the ids of the
    specs that list that dependency, in spec order. Specs nobody depends on do
    not appear as keys. Dependency ids that match no spec still do.
    """
    children: dict[str, list[str]] = {}
    for spec in specs:
        child_id = spec.subtask_id
        for parent_id in spec.dependencies:
            children.setdefault(parent_id, []).append(child_id)
    return children