Files
JARVIS/backend/app/agents/schemas/orchestration.py

212 lines
7.7 KiB
Python
Raw Normal View History

from __future__ import annotations
import re
from datetime import datetime, timezone
from typing import Any, Literal
from uuid import uuid4
from pydantic import BaseModel, Field
from app.agents.schemas.skills import SkillShortlistEntry
# Execution styles the parallel-worthiness heuristic may recommend.
ParallelPreference = Literal["direct", "collaboration", "parallel"]
# Every execution mode an ExecutionDecision can carry: the heuristic
# preferences plus "delegated". Nested Literals are flattened by ``typing``,
# so the resulting arguments are ("direct", "collaboration", "parallel", "delegated").
ExecutionMode = Literal[ParallelPreference, "delegated"]
class ParallelWorthiness(BaseModel):
    """Heuristic verdict on whether a request is worth splitting into subtasks."""

    # Final yes/no decision on parallel execution.
    should_parallelize: bool = Field(default=False)
    # Aggregate heuristic score (assess_parallel_worthiness keeps it in [0, 1]).
    score: float = Field(default=0.0)
    # Number of subtasks the request is expected to decompose into.
    estimated_subtasks: int = Field(default=1)
    # Recommended execution style.
    preferred_mode: ParallelPreference = Field(default="direct")
    # Machine-readable tags explaining what contributed to the score.
    reasons: list[str] = Field(default_factory=list)
    # Tags describing risks of parallelizing (not filled by assess_parallel_worthiness).
    risk_flags: list[str] = Field(default_factory=list)
class TaskNode(BaseModel):
    """A single unit of work inside a task graph."""

    # Identifier that other nodes reference via ``depends_on``.
    node_id: str
    # Short human-readable label for the node.
    title: str
    # Agent role assigned to the node, if any.
    role: str | None = Field(default=None)
    # Free-form objective for the node, if any.
    goal: str | None = Field(default=None)
    # node_id values this node presumably waits on — ordering is enforced elsewhere.
    depends_on: list[str] = Field(default_factory=list)
    # Whether the node is intended to run serially or alongside siblings.
    execution_mode: Literal["serial", "parallel"] = Field(default="serial")
    # Evidence descriptors the node is expected to produce (schema not fixed here).
    expected_evidence: list[dict[str, Any]] = Field(default_factory=list)
class TaskGraph(BaseModel):
    """A planned decomposition of a request into dependent task nodes."""

    # Random UUID4 string generated per instance.
    graph_id: str = Field(default_factory=lambda: str(uuid4()))
    # All nodes of the graph.
    nodes: list[TaskNode] = Field(default_factory=list)
    # node_id values of the nodes that can start first.
    entry_node_ids: list[str] = Field(default_factory=list)
    # Upper bound on concurrently running nodes.
    max_parallelism: int = Field(default=1)
    # Optional explanation of why the graph was shaped this way.
    rationale: str | None = Field(default=None)
class SubTaskSpec(BaseModel):
    """Specification handed to a worker for one subtask of a parent run."""

    subtask_id: str
    # Identifier of the run that spawned this subtask.
    parent_run_id: str
    title: str
    # Agent role expected to execute the subtask.
    role: str
    goal: str
    # Portion of the parent context passed to the worker.
    context_slice: dict[str, Any] = Field(default_factory=dict)
    # Tool names the worker may use.
    allowed_tools: list[str] = Field(default_factory=list)
    # Token budget for the subtask (enforcement is not visible here).
    budget_tokens: int = Field(default=1200)
    # Tool-call budget for the subtask (enforcement is not visible here).
    budget_tool_calls: int = Field(default=2)
    # Schema the worker's output is expected to follow.
    expected_output_schema: dict[str, Any] = Field(default_factory=dict)
    # Evidence descriptors the worker is expected to return.
    expected_evidence: list[dict[str, Any]] = Field(default_factory=list)
    # subtask_id values presumably required to finish first.
    dependencies: list[str] = Field(default_factory=list)
class SubTaskResult(BaseModel):
    """Outcome reported for a single subtask."""

    subtask_id: str
    # Terminal state of the subtask.
    status: Literal["completed", "failed", "blocked"]
    # Optional short description of the result.
    summary: str | None = Field(default=None)
    # Evidence items produced while executing the subtask.
    evidence: list[dict[str, Any]] = Field(default_factory=list)
    # Structured output payload.
    output: dict[str, Any] = Field(default_factory=dict)
class MergeReport(BaseModel):
    """Result of combining several subtask results into one answer."""

    # Random UUID4 string generated per instance.
    merge_id: str = Field(default_factory=lambda: str(uuid4()))
    # Overall merge outcome.
    status: Literal["merged", "conflicted", "fallback"]
    summary: str | None = Field(default=None)
    # Evidence gathered from all merged results.
    evidence_union: list[dict[str, Any]] = Field(default_factory=list)
    # Tags describing detected conflicts between results.
    conflict_flags: list[str] = Field(default_factory=list)
    # Name/description of how conflicts were resolved, if they were.
    resolution_strategy: str | None = Field(default=None)
    resolved_summary: str | None = Field(default=None)
    # Whether a fallback path was taken.
    fallback_used: bool = Field(default=False)
class VerificationReport(BaseModel):
    """Outcome of a verification pass."""

    # Whether verification passed, failed, or was not run.
    status: Literal["passed", "failed", "skipped"]
    summary: str | None = Field(default=None)
    # Evidence items supporting the verdict.
    evidence: list[dict[str, Any]] = Field(default_factory=list)
class ExecutionDecision(BaseModel):
    """Chosen execution mode for a request, with its justification."""

    # Random UUID4 string generated per instance unless supplied.
    request_id: str = Field(default_factory=lambda: str(uuid4()))
    mode: ExecutionMode = Field(default="direct")
    # Required human-readable justification for the chosen mode.
    reason: str
    complexity_score: float = Field(default=0.0)
    # Parallel-worthiness score, when one was computed.
    parallel_worthiness_score: float | None = Field(default=None)
    # Agent roles selected to handle the request.
    selected_roles: list[str] = Field(default_factory=list)
class RuntimeRequestContext(BaseModel):
    """Everything assembled for a single request before the runtime executes it."""

    # Random UUID4 string generated per instance unless supplied.
    request_id: str = Field(default_factory=lambda: str(uuid4()))
    session_id: str | None = Field(default=None)
    user_id: str
    conversation_id: str | None = Field(default=None)
    # Query text as used by the runtime; presumably may differ from raw_user_query
    # (e.g. after rewriting) — confirm with the assembler.
    query_text: str | None = Field(default=None)
    raw_user_query: str | None = Field(default=None)
    # Recalled memory snippets.
    recalled_memories: list[str] = Field(default_factory=list)
    # Retrospective entries; the summary renderer reads "summary",
    # "summary_text" and "task_type" keys from them.
    retrospective_shortlist: list[dict[str, Any]] = Field(default_factory=list)
    recalled_retrospectives: list[dict[str, Any]] = Field(default_factory=list)
    # Structured skill candidates.
    skill_shortlist: list[SkillShortlistEntry] = Field(default_factory=list)
    # Skill names only.
    shortlisted_skills: list[str] = Field(default_factory=list)
    parallel_worthiness: ParallelWorthiness = Field(default_factory=ParallelWorthiness)
    task_graph: TaskGraph | None = Field(default=None)
    recommended_runtime_mode: Literal["direct", "collaboration"] = Field(default="direct")
    # Mode actually chosen, if already decided.
    execution_mode: Literal["direct", "collaboration"] | None = Field(default=None)
    current_agent_role: str | None = Field(default=None)
    conversation_state_ref: str | None = Field(default=None)
    # Timing metrics keyed by name; "total_ms" is shown in the summary.
    assembly_metrics: dict[str, float] = Field(default_factory=dict)
    # Timezone-aware UTC timestamp taken at construction time.
    assembled_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
def assess_parallel_worthiness(
    query_text: str,
    *,
    retrospective_count: int = 0,
    skill_count: int = 0,
) -> ParallelWorthiness:
    """Heuristically score how much a request would benefit from parallel execution.

    The score is a sum of fixed bonuses, capped at 1.0 and rounded to 3 decimals:

    * +0.35 ``multi_step_request``: any multi-step marker (e.g. "然后", "research").
    * +0.25 ``multi_source_context``: at least two distinct artifact markers.
    * +0.15 ``compound_instruction``: at least two list separators
      (ASCII "," ";", full-width "，" "；", or "、").
    * +0.10 ``historical_support``: ``retrospective_count > 0``.
    * +0.10 ``skill_candidates_available``: ``skill_count > 0``.

    Mode selection uses the rounded score: ``>= 0.55`` -> "parallel" (3 subtasks
    when ``>= 0.8``, otherwise 2); ``>= 0.3`` -> "collaboration" (2 subtasks);
    anything lower -> "direct" (1 subtask).

    Args:
        query_text: The user's request. ``None`` or empty is treated as "".
        retrospective_count: Count of related retrospectives; only whether it
            is positive matters.
        skill_count: Count of skill candidates; only whether it is positive matters.

    Returns:
        A ParallelWorthiness whose ``score`` is exactly the value the thresholds
        were evaluated against. ``risk_flags`` is left empty.
    """
    text = query_text or ""
    normalized = text.strip().lower()
    reasons: list[str] = []
    score = 0.0

    multi_step_markers = ("然后", "接着", "同时", "并且", "最后", "汇总", "对比", "分析", "research")
    artifact_markers = ("文档", "代码", "文件", "数据库", "论坛", "知识库", "计划")

    if any(marker in normalized for marker in multi_step_markers):
        score += 0.35
        reasons.append("multi_step_request")
    if sum(1 for marker in artifact_markers if marker in normalized) >= 2:
        score += 0.25
        reasons.append("multi_source_context")
    # Chinese text normally separates clauses with full-width "，" and "；",
    # so count those alongside ASCII "," / ";" and the enumeration comma "、".
    if len(re.findall(r"[,，、;；]", text)) >= 2:
        score += 0.15
        reasons.append("compound_instruction")
    if retrospective_count > 0:
        score += 0.1
        reasons.append("historical_support")
    if skill_count > 0:
        score += 0.1
        reasons.append("skill_candidates_available")

    # Round *before* applying thresholds. Float accumulation otherwise lands
    # just below them (0.35 + 0.1 + 0.1 evaluates to 0.5499999999999999, and
    # 0.7 + 0.1 to 0.7999999999999999), so a request reported with score
    # 0.55 / 0.8 would miss the 0.55 / 0.8 cut-offs it visibly meets.
    score = round(min(score, 1.0), 3)

    should_parallelize = score >= 0.55
    preferred_mode: ParallelPreference
    if should_parallelize:
        preferred_mode = "parallel"
        estimated_subtasks = 3 if score >= 0.8 else 2
    elif score >= 0.3:
        preferred_mode = "collaboration"
        estimated_subtasks = 2
    else:
        preferred_mode = "direct"
        estimated_subtasks = 1

    return ParallelWorthiness(
        should_parallelize=should_parallelize,
        score=score,
        estimated_subtasks=estimated_subtasks,
        preferred_mode=preferred_mode,
        reasons=reasons,
    )
def render_runtime_request_context_summary(context: RuntimeRequestContext) -> str:
    """Render a compact, line-oriented text summary of ``context``.

    Sections appear in a fixed order, and optional ones only when they have content:

    * recommended runtime mode and parallel-worthiness stats (always);
    * the reasons behind the parallel verdict;
    * total assembly time from ``assembly_metrics["total_ms"]``;
    * up to 4 task-graph nodes;
    * up to 3 retrospective hits (summaries clipped to 160 characters);
    * up to 3 skill candidates;
    * a one-line note when recalled memories are present.

    Args:
        context: The assembled request context to summarise.

    Returns:
        The summary lines joined by newlines, with no trailing newline.
    """

    def retrospective_line(entry: dict) -> str:
        # Prefer "summary", fall back to "summary_text"; blank when neither is set.
        text = (entry.get("summary") or entry.get("summary_text") or "").strip()
        label = entry.get("task_type") or "unknown"
        return f" - [{label}] {text[:160]}"

    def skill_line(skill) -> str:
        rationale = f": {skill.rationale}" if skill.rationale else ""
        return f" - {skill.skill_name} ({skill.injection_mode}, score={skill.score:.2f}){rationale}"

    worthiness = context.parallel_worthiness
    out = [
        "【Runtime Request Context】",
        f"- 推荐运行模式: {context.recommended_runtime_mode}",
        f"- 并行潜力: score={worthiness.score:.2f}, "
        f"preferred={worthiness.preferred_mode}, "
        f"estimated_subtasks={worthiness.estimated_subtasks}",
    ]
    if worthiness.reasons:
        out.append("- 并行判断依据: " + ", ".join(worthiness.reasons))

    total_ms = context.assembly_metrics.get("total_ms") if context.assembly_metrics else None
    if total_ms is not None:
        out.append(f"- 上下文装配耗时: {total_ms:.1f} ms")

    graph = context.task_graph
    if graph and graph.nodes:
        out.append(f"- 任务图: nodes={len(graph.nodes)}, max_parallelism={graph.max_parallelism}")
        for node in graph.nodes[:4]:
            dep_note = f", deps={len(node.depends_on)}" if node.depends_on else ""
            out.append(f" - [{node.execution_mode}] {node.title} ({node.role}{dep_note})")

    if context.retrospective_shortlist:
        out.append("- 历史复盘命中:")
        out.extend(retrospective_line(entry) for entry in context.retrospective_shortlist[:3])

    if context.skill_shortlist:
        out.append("- 技能候选:")
        out.extend(skill_line(skill) for skill in context.skill_shortlist[:3])

    if context.recalled_memories:
        out.append("- 记忆上下文已装配,可在回答中按需引用。")

    return "\n".join(out)