Files
JARVIS/backend/app/agents/schemas/orchestration.py

212 lines
7.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import re
from datetime import datetime, timezone
from typing import Any, Literal
from uuid import uuid4
from pydantic import BaseModel, Field
from app.agents.schemas.skills import SkillShortlistEntry
# Closed set of modes accepted by ``ExecutionDecision.mode``.
ExecutionMode = Literal["direct", "collaboration", "parallel", "delegated"]
# Closed set of modes accepted by ``ParallelWorthiness.preferred_mode``;
# the same values as ``ExecutionMode`` except that "delegated" is absent.
ParallelPreference = Literal["direct", "collaboration", "parallel"]
class ParallelWorthiness(BaseModel):
    """Verdict on whether a request is worth splitting into subtasks.

    ``assess_parallel_worthiness`` in this module builds instances of this
    model. Every field has a default, so ``ParallelWorthiness()`` is a
    "direct, one subtask, score 0" verdict.
    """

    should_parallelize: bool = False
    # The model itself does not constrain the range of the score.
    score: float = 0.0
    estimated_subtasks: int = 1
    preferred_mode: ParallelPreference = "direct"

    # Short tags naming the signals that contributed to the score.
    reasons: list[str] = Field(default_factory=list)
    # NOTE(review): nothing visible in this module fills this list --
    # presumably set by callers; confirm.
    risk_flags: list[str] = Field(default_factory=list)
class TaskNode(BaseModel):
    """A single entry of ``TaskGraph.nodes``.

    ``node_id`` and ``title`` are required; all other fields are optional
    and fall back to the defaults declared below.
    """

    node_id: str
    title: str
    role: str | None = None
    goal: str | None = None

    # presumably the ``node_id`` values this node waits on -- the model does
    # not validate them against any graph; confirm with the producer.
    depends_on: list[str] = Field(default_factory=list)
    execution_mode: Literal["serial", "parallel"] = "serial"

    # Free-form dicts; their key schema is not defined in this module.
    expected_evidence: list[dict[str, Any]] = Field(default_factory=list)
class TaskGraph(BaseModel):
    """A collection of ``TaskNode`` objects plus graph-level settings.

    No field is required: an empty ``TaskGraph()`` gets a fresh random
    ``graph_id``, no nodes and a ``max_parallelism`` of 1.
    """

    # A new UUID4 string is generated for every instance that omits it.
    graph_id: str = Field(default_factory=lambda: str(uuid4()))
    nodes: list[TaskNode] = Field(default_factory=list)

    # presumably ``node_id`` values of the starting nodes -- not validated
    # against ``nodes`` by this model; confirm with the producer.
    entry_node_ids: list[str] = Field(default_factory=list)
    max_parallelism: int = 1
    rationale: str | None = None
class SubTaskSpec(BaseModel):
    """Specification of one subtask: identity, instructions and budgets.

    The five identifying/instruction fields are required; context, tool
    list, budgets, expectations and dependencies all have defaults.
    """

    # --- required ---------------------------------------------------------
    subtask_id: str
    parent_run_id: str
    title: str
    role: str
    goal: str

    # --- optional, defaulted ----------------------------------------------
    # Free-form dict; its keys are not defined in this module.
    context_slice: dict[str, Any] = Field(default_factory=dict)
    allowed_tools: list[str] = Field(default_factory=list)

    # Default budgets: 1200 tokens and 2 tool calls. The model only stores
    # these numbers; enforcement (if any) lives outside this module.
    budget_tokens: int = 1200
    budget_tool_calls: int = 2

    expected_output_schema: dict[str, Any] = Field(default_factory=dict)
    expected_evidence: list[dict[str, Any]] = Field(default_factory=list)
    # presumably ``subtask_id`` values of prerequisite subtasks -- confirm.
    dependencies: list[str] = Field(default_factory=list)
class SubTaskResult(BaseModel):
    """Outcome record for a subtask.

    ``subtask_id`` and ``status`` are required; ``status`` must be one of
    "completed", "failed" or "blocked".
    """

    subtask_id: str
    status: Literal["completed", "failed", "blocked"]
    summary: str | None = None

    # Free-form payloads; their key schema is not defined in this module.
    evidence: list[dict[str, Any]] = Field(default_factory=list)
    output: dict[str, Any] = Field(default_factory=dict)
class MergeReport(BaseModel):
    """Report describing the result of a merge step.

    Only ``status`` is required ("merged", "conflicted" or "fallback");
    ``merge_id`` defaults to a fresh UUID4 string.
    """

    merge_id: str = Field(default_factory=lambda: str(uuid4()))
    status: Literal["merged", "conflicted", "fallback"]
    summary: str | None = None

    # Free-form evidence dicts; key schema is not defined in this module.
    evidence_union: list[dict[str, Any]] = Field(default_factory=list)
    conflict_flags: list[str] = Field(default_factory=list)

    resolution_strategy: str | None = None
    resolved_summary: str | None = None
    # NOTE(review): stored independently of ``status``; the model does not
    # force it to agree with ``status == "fallback"``.
    fallback_used: bool = False
class VerificationReport(BaseModel):
    """Result of a verification step.

    ``status`` is required and must be "passed", "failed" or "skipped".
    """

    status: Literal["passed", "failed", "skipped"]
    summary: str | None = None
    # Free-form evidence dicts; key schema is not defined in this module.
    evidence: list[dict[str, Any]] = Field(default_factory=list)
class ExecutionDecision(BaseModel):
    """Record of which execution mode was chosen for a request, and why.

    ``reason`` is the only required field. ``request_id`` defaults to a
    fresh UUID4 string and ``mode`` defaults to "direct".
    """

    request_id: str = Field(default_factory=lambda: str(uuid4()))
    mode: ExecutionMode = "direct"
    reason: str

    complexity_score: float = 0.0
    # ``None`` when no parallel-worthiness score was supplied.
    parallel_worthiness_score: float | None = None
    selected_roles: list[str] = Field(default_factory=list)
class RuntimeRequestContext(BaseModel):
    """Per-request context bundle assembled before a request is executed.

    ``user_id`` is the only required field. ``request_id`` defaults to a
    fresh UUID4 string and ``assembled_at`` to the current time in UTC.
    ``render_runtime_request_context_summary`` in this module turns an
    instance into a text digest.
    """

    # --- identity ---------------------------------------------------------
    request_id: str = Field(default_factory=lambda: str(uuid4()))
    session_id: str | None = None
    user_id: str
    conversation_id: str | None = None

    # --- query ------------------------------------------------------------
    query_text: str | None = None
    raw_user_query: str | None = None

    # --- recalled material ------------------------------------------------
    recalled_memories: list[str] = Field(default_factory=list)
    # Free-form dicts; the summary renderer reads the keys "summary",
    # "summary_text" and "task_type" from each shortlist entry.
    retrospective_shortlist: list[dict[str, Any]] = Field(default_factory=list)
    recalled_retrospectives: list[dict[str, Any]] = Field(default_factory=list)
    skill_shortlist: list[SkillShortlistEntry] = Field(default_factory=list)
    shortlisted_skills: list[str] = Field(default_factory=list)

    # --- planning ---------------------------------------------------------
    parallel_worthiness: ParallelWorthiness = Field(default_factory=ParallelWorthiness)
    task_graph: TaskGraph | None = None
    # Both mode fields accept only "direct" / "collaboration", a narrower
    # set than the module-level ``ExecutionMode`` alias.
    recommended_runtime_mode: Literal["direct", "collaboration"] = "direct"
    execution_mode: Literal["direct", "collaboration"] | None = None
    current_agent_role: str | None = None
    conversation_state_ref: str | None = None

    # --- bookkeeping ------------------------------------------------------
    # The summary renderer reads the "total_ms" key when present.
    assembly_metrics: dict[str, float] = Field(default_factory=dict)
    # Timezone-aware (UTC) timestamp taken when the instance is created.
    assembled_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
def assess_parallel_worthiness(
    query_text: str,
    *,
    retrospective_count: int = 0,
    skill_count: int = 0,
) -> ParallelWorthiness:
    """Score how much a request would benefit from being split into subtasks.

    The score is the sum of fixed weights, one per signal that fires:

    * 0.35 ("multi_step_request") -- the stripped, lower-cased query
      contains at least one multi-step marker;
    * 0.25 ("multi_source_context") -- it contains at least two distinct
      artifact markers;
    * 0.15 ("compound_instruction") -- the raw query contains two or more
      of the separator characters matched by the pattern below;
    * 0.10 ("historical_support") -- ``retrospective_count`` is positive;
    * 0.10 ("skill_candidates_available") -- ``skill_count`` is positive.

    Args:
        query_text: The user query. A falsy value (``None``, "") is
            treated as an empty string.
        retrospective_count: Only whether it is greater than zero matters.
        skill_count: Only whether it is greater than zero matters.

    Returns:
        A ``ParallelWorthiness`` with the score capped at 1.0 and rounded to
        three decimals. A score >= 0.55 yields mode "parallel" (3 estimated
        subtasks at >= 0.8, otherwise 2); a score >= 0.3 yields
        "collaboration" with 2 subtasks; anything lower yields "direct"
        with 1 subtask. ``risk_flags`` is left at its default.
    """
    raw_text = query_text or ""
    normalized = raw_text.strip().lower()
    reasons: list[str] = []
    score = 0.0

    multi_step_markers = ("然后", "接着", "同时", "并且", "最后", "汇总", "对比", "分析", "research")
    artifact_markers = ("文档", "代码", "文件", "数据库", "论坛", "知识库", "计划")

    if any(marker in normalized for marker in multi_step_markers):
        score += 0.35
        reasons.append("multi_step_request")
    if sum(1 for marker in artifact_markers if marker in normalized) >= 2:
        score += 0.25
        reasons.append("multi_source_context")
    # Separators are counted on the raw text, not the normalized copy.
    if len(re.findall(r"[,、;;]", raw_text)) >= 2:
        score += 0.15
        reasons.append("compound_instruction")
    if retrospective_count > 0:
        score += 0.1
        reasons.append("historical_support")
    if skill_count > 0:
        score += 0.1
        reasons.append("skill_candidates_available")

    # Round once, BEFORE thresholding. The weights are binary floats, so
    # 0.35 + 0.1 + 0.1 evaluates to 0.5499999999999999 and
    # 0.35 + 0.25 + 0.1 + 0.1 to 0.7999999999999999. Comparing the unrounded
    # sum made those cases miss the 0.55 / 0.8 cut-offs even though the
    # reported (rounded) score sat exactly on them.
    score = round(min(score, 1.0), 3)

    should_parallelize = score >= 0.55
    preferred_mode: ParallelPreference
    if should_parallelize:
        preferred_mode = "parallel"
        estimated_subtasks = 3 if score >= 0.8 else 2
    elif score >= 0.3:
        preferred_mode = "collaboration"
        estimated_subtasks = 2
    else:
        preferred_mode = "direct"
        estimated_subtasks = 1

    return ParallelWorthiness(
        should_parallelize=should_parallelize,
        score=score,
        estimated_subtasks=estimated_subtasks,
        preferred_mode=preferred_mode,
        reasons=reasons,
    )
def render_runtime_request_context_summary(context: RuntimeRequestContext) -> str:
    """Render *context* as a newline-joined, bullet-style text digest.

    Sections are emitted in a fixed order and each optional section is
    skipped when its source data is empty:

    * header, recommended runtime mode and parallel-worthiness figures
      (always present);
    * parallel-worthiness reasons;
    * context assembly time, when ``assembly_metrics`` has ``"total_ms"``;
    * task graph overview plus at most the first 4 nodes;
    * at most the first 3 retrospective shortlist entries, each summary
      truncated to 160 characters;
    * at most the first 3 skill shortlist entries;
    * a fixed notice line when any memories were recalled (the memories
      themselves are not rendered).

    Args:
        context: The request context to summarise.

    Returns:
        The digest as a single string with ``"\\n"`` line separators.
    """
    worthiness = context.parallel_worthiness
    lines = [
        "【Runtime Request Context】",
        f"- 推荐运行模式: {context.recommended_runtime_mode}",
        (
            f"- 并行潜力: score={worthiness.score:.2f}, "
            f"preferred={worthiness.preferred_mode}, "
            f"estimated_subtasks={worthiness.estimated_subtasks}"
        ),
    ]
    if worthiness.reasons:
        lines.append(f"- 并行判断依据: {', '.join(worthiness.reasons)}")

    if context.assembly_metrics:
        total_ms = context.assembly_metrics.get("total_ms")
        if total_ms is not None:
            lines.append(f"- 上下文装配耗时: {total_ms:.1f} ms")

    graph = context.task_graph
    if graph and graph.nodes:
        lines.append(
            f"- 任务图: nodes={len(graph.nodes)}, max_parallelism={graph.max_parallelism}"
        )
        for node in graph.nodes[:4]:
            # ``role`` is optional on a node; leave it out instead of
            # printing the literal text "None" into the summary.
            details = []
            if node.role:
                details.append(node.role)
            if node.depends_on:
                details.append(f"deps={len(node.depends_on)}")
            suffix = f" ({', '.join(details)})" if details else ""
            lines.append(f" - [{node.execution_mode}] {node.title}{suffix}")

    if context.retrospective_shortlist:
        lines.append("- 历史复盘命中:")
        for item in context.retrospective_shortlist[:3]:
            # Entries are untyped dicts, so coerce to ``str`` before
            # stripping; a non-string summary must not break rendering.
            summary = str(item.get("summary") or item.get("summary_text") or "").strip()
            task_type = item.get("task_type") or "unknown"
            lines.append(f" - [{task_type}] {summary[:160]}")

    if context.skill_shortlist:
        lines.append("- 技能候选:")
        for item in context.skill_shortlist[:3]:
            rationale = f": {item.rationale}" if item.rationale else ""
            lines.append(
                f" - {item.skill_name} ({item.injection_mode}, score={item.score:.2f})"
                + rationale
            )

    if context.recalled_memories:
        lines.append("- 记忆上下文已装配,可在回答中按需引用。")
    return "\n".join(lines)