feat: add agent visibility APIs and harden runtime verification
Add Day 4 visibility endpoints and response models, strengthen collaboration/task verification behavior, and patch conversation schema startup migration for agent_state compatibility. Extend backend regression coverage for runtime schemas, verifier behavior, visibility APIs, router auth, and legacy conversation list loading.
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -324,6 +324,25 @@ ANALYST_INSIGHTS_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
COORDINATOR_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||||
|
|
||||||
|
你是 Jarvis 的协作协调官,负责把复杂请求收束成最小受控协作,而不是放任系统进入自由 swarm。
|
||||||
|
|
||||||
|
## 你的职责:
|
||||||
|
- 先判断当前请求是否真的需要拆解;不需要时应明确建议继续走 direct
|
||||||
|
- 只有在明显多步骤、跨领域、需要多角色配合时,才拆成 2~4 个子任务
|
||||||
|
- 每个子任务必须清晰写出 `title`、`role`、`goal`、`expected_evidence`
|
||||||
|
- 角色建议只能来自现有 top-level agent:`schedule_planner`、`librarian`、`analyst`、`executor`
|
||||||
|
- 汇总时基于子任务结果回收,不依赖单点硬编码拼接
|
||||||
|
|
||||||
|
## 边界:
|
||||||
|
- 禁止无限递归拆分
|
||||||
|
- 禁止创建新的 runtime agent / worker
|
||||||
|
- 禁止把一个简单请求硬拆成多个空泛步骤
|
||||||
|
- 如果证据不足、子任务未闭环,必须把风险明确暴露出来
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
VERIFIER_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
VERIFIER_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||||
|
|
||||||
你是 Jarvis 的验证官,负责对执行结果做最小但明确的核验。
|
你是 Jarvis 的验证官,负责对执行结果做最小但明确的核验。
|
||||||
|
|||||||
@@ -57,6 +57,19 @@ TOP_LEVEL_AGENT_ROUTING_HINTS: dict[str, tuple[str, ...]] = {
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TOP_LEVEL_AGENT_ALLOWED_SPAWN_ROLES: dict[str, tuple[str, ...]] = {
|
||||||
|
AgentRole.MASTER.value: (
|
||||||
|
AgentRole.SCHEDULE_PLANNER.value,
|
||||||
|
AgentRole.EXECUTOR.value,
|
||||||
|
AgentRole.LIBRARIAN.value,
|
||||||
|
AgentRole.ANALYST.value,
|
||||||
|
),
|
||||||
|
AgentRole.SCHEDULE_PLANNER.value: (AgentRole.SCHEDULE_PLANNER.value,),
|
||||||
|
AgentRole.EXECUTOR.value: (AgentRole.EXECUTOR.value,),
|
||||||
|
AgentRole.LIBRARIAN.value: (AgentRole.LIBRARIAN.value,),
|
||||||
|
AgentRole.ANALYST.value: (AgentRole.ANALYST.value,),
|
||||||
|
}
|
||||||
|
|
||||||
SUB_COMMANDER_PARENT_AGENT_IDS: dict[str, str] = {
|
SUB_COMMANDER_PARENT_AGENT_IDS: dict[str, str] = {
|
||||||
"schedule_analysis": AgentRole.SCHEDULE_PLANNER.value,
|
"schedule_analysis": AgentRole.SCHEDULE_PLANNER.value,
|
||||||
"schedule_planning": AgentRole.SCHEDULE_PLANNER.value,
|
"schedule_planning": AgentRole.SCHEDULE_PLANNER.value,
|
||||||
@@ -77,6 +90,8 @@ BUILTIN_AGENT_MANIFESTS: tuple[AgentManifest, ...] = tuple(
|
|||||||
system_prompt_key=role.value,
|
system_prompt_key=role.value,
|
||||||
routing_hints=list(TOP_LEVEL_AGENT_ROUTING_HINTS[role.value]),
|
routing_hints=list(TOP_LEVEL_AGENT_ROUTING_HINTS[role.value]),
|
||||||
default_sub_commanders=list(TOP_LEVEL_AGENT_DEFAULT_SUB_COMMANDERS[role.value]),
|
default_sub_commanders=list(TOP_LEVEL_AGENT_DEFAULT_SUB_COMMANDERS[role.value]),
|
||||||
|
can_spawn_children=bool(TOP_LEVEL_AGENT_ALLOWED_SPAWN_ROLES[role.value]),
|
||||||
|
allowed_spawn_role_values=list(TOP_LEVEL_AGENT_ALLOWED_SPAWN_ROLES[role.value]),
|
||||||
skill_context_key=role.value.replace("agent_", ""),
|
skill_context_key=role.value.replace("agent_", ""),
|
||||||
)
|
)
|
||||||
for role in AgentRole
|
for role in AgentRole
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from app.agents.registry.models import (
|
|||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class RegistryIndexes:
|
class RegistryIndexes:
|
||||||
agent_by_id: Mapping[str, AgentManifest]
|
agent_by_id: Mapping[str, AgentManifest]
|
||||||
|
agent_by_role_value: Mapping[str, AgentManifest]
|
||||||
sub_commander_by_id: Mapping[str, SubCommanderManifest]
|
sub_commander_by_id: Mapping[str, SubCommanderManifest]
|
||||||
capability_by_id: Mapping[str, CapabilityManifest]
|
capability_by_id: Mapping[str, CapabilityManifest]
|
||||||
specialist_template_by_id: Mapping[str, SpecialistTemplateManifest]
|
specialist_template_by_id: Mapping[str, SpecialistTemplateManifest]
|
||||||
@@ -24,6 +25,7 @@ class RegistryIndexes:
|
|||||||
skill_context_key_by_agent_id: Mapping[str, str]
|
skill_context_key_by_agent_id: Mapping[str, str]
|
||||||
capability_id_by_tool_name: Mapping[str, str]
|
capability_id_by_tool_name: Mapping[str, str]
|
||||||
capability_ids_by_sub_commander_id: Mapping[str, tuple[str, ...]]
|
capability_ids_by_sub_commander_id: Mapping[str, tuple[str, ...]]
|
||||||
|
spawnable_role_values_by_agent_id: Mapping[str, tuple[str, ...]]
|
||||||
|
|
||||||
|
|
||||||
def summarize_registry_indexes(indexes: RegistryIndexes) -> dict[str, int]:
|
def summarize_registry_indexes(indexes: RegistryIndexes) -> dict[str, int]:
|
||||||
@@ -50,6 +52,9 @@ def build_registry_indexes(bundle: RegistryBundle) -> RegistryIndexes:
|
|||||||
|
|
||||||
return RegistryIndexes(
|
return RegistryIndexes(
|
||||||
agent_by_id=MappingProxyType(agent_by_id),
|
agent_by_id=MappingProxyType(agent_by_id),
|
||||||
|
agent_by_role_value=MappingProxyType({
|
||||||
|
agent.role_value: agent for agent in bundle.agents
|
||||||
|
}),
|
||||||
sub_commander_by_id=MappingProxyType(sub_commander_by_id),
|
sub_commander_by_id=MappingProxyType(sub_commander_by_id),
|
||||||
capability_by_id=MappingProxyType(capability_by_id),
|
capability_by_id=MappingProxyType(capability_by_id),
|
||||||
specialist_template_by_id=MappingProxyType(specialist_template_by_id),
|
specialist_template_by_id=MappingProxyType(specialist_template_by_id),
|
||||||
@@ -73,4 +78,9 @@ def build_registry_indexes(bundle: RegistryBundle) -> RegistryIndexes:
|
|||||||
sub_commander.sub_commander_id: tuple(sub_commander.capability_ids)
|
sub_commander.sub_commander_id: tuple(sub_commander.capability_ids)
|
||||||
for sub_commander in bundle.sub_commanders
|
for sub_commander in bundle.sub_commanders
|
||||||
}),
|
}),
|
||||||
|
spawnable_role_values_by_agent_id=MappingProxyType({
|
||||||
|
agent.agent_id: tuple(agent.allowed_spawn_role_values)
|
||||||
|
for agent in bundle.agents
|
||||||
|
if agent.can_spawn_children and agent.allowed_spawn_role_values
|
||||||
|
}),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
class PermissionClass(str, Enum):
|
class PermissionClass(str, Enum):
|
||||||
@@ -23,6 +23,8 @@ class AgentManifest(BaseModel):
|
|||||||
system_prompt_key: str
|
system_prompt_key: str
|
||||||
routing_hints: list[str]
|
routing_hints: list[str]
|
||||||
default_sub_commanders: list[str]
|
default_sub_commanders: list[str]
|
||||||
|
can_spawn_children: bool = False
|
||||||
|
allowed_spawn_role_values: list[str] = Field(default_factory=list)
|
||||||
skill_context_key: str | None = None
|
skill_context_key: str | None = None
|
||||||
continuity_policy: str | None = None
|
continuity_policy: str | None = None
|
||||||
clarification_policy: str | None = None
|
clarification_policy: str | None = None
|
||||||
|
|||||||
@@ -1,10 +1,25 @@
|
|||||||
from app.agents.schemas.event import AgentEvent
|
from app.agents.schemas.event import AgentEvent
|
||||||
from app.agents.schemas.task import AgentTask, TaskResult, TaskLifecycleStatus, VerificationStatus
|
from app.agents.schemas.message import AgentMessage
|
||||||
|
from app.agents.schemas.task import (
|
||||||
|
AgentTask,
|
||||||
|
CollaborationBudget,
|
||||||
|
InterruptRecord,
|
||||||
|
RecoveryRecord,
|
||||||
|
TaskLifecycleStatus,
|
||||||
|
TaskResult,
|
||||||
|
TaskResultStatus,
|
||||||
|
VerificationStatus,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AgentEvent",
|
"AgentEvent",
|
||||||
|
"AgentMessage",
|
||||||
"AgentTask",
|
"AgentTask",
|
||||||
|
"CollaborationBudget",
|
||||||
|
"InterruptRecord",
|
||||||
|
"RecoveryRecord",
|
||||||
"TaskLifecycleStatus",
|
"TaskLifecycleStatus",
|
||||||
"TaskResult",
|
"TaskResult",
|
||||||
|
"TaskResultStatus",
|
||||||
"VerificationStatus",
|
"VerificationStatus",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -11,6 +11,18 @@ AgentEventType = Literal[
|
|||||||
"agent.tool.result",
|
"agent.tool.result",
|
||||||
"agent.verify.started",
|
"agent.verify.started",
|
||||||
"agent.verify.completed",
|
"agent.verify.completed",
|
||||||
|
"agent.created",
|
||||||
|
"agent.spawn.blocked",
|
||||||
|
"agent.message.sent",
|
||||||
|
"agent.message.received",
|
||||||
|
"agent.interrupt.requested",
|
||||||
|
"agent.interrupt.completed",
|
||||||
|
"agent.recovery.started",
|
||||||
|
"agent.recovery.completed",
|
||||||
|
"agent.task.interrupted",
|
||||||
|
"agent.task.recovered",
|
||||||
|
"agent.task.reassigned",
|
||||||
|
"agent.collaboration.budget.updated",
|
||||||
"agent.error",
|
"agent.error",
|
||||||
]
|
]
|
||||||
AgentEventSeverity = Literal["info", "warning", "error"]
|
AgentEventSeverity = Literal["info", "warning", "error"]
|
||||||
@@ -24,5 +36,11 @@ class AgentEvent(BaseModel):
|
|||||||
agent_id: str | None = None
|
agent_id: str | None = None
|
||||||
sub_commander_id: str | None = None
|
sub_commander_id: str | None = None
|
||||||
task_id: str | None = None
|
task_id: str | None = None
|
||||||
|
parent_task_id: str | None = None
|
||||||
|
child_task_id: str | None = None
|
||||||
|
thread_id: str | None = None
|
||||||
|
message_id: str | None = None
|
||||||
|
interrupt_id: str | None = None
|
||||||
|
recovery_id: str | None = None
|
||||||
payload: dict[str, Any] = Field(default_factory=dict)
|
payload: dict[str, Any] = Field(default_factory=dict)
|
||||||
severity: AgentEventSeverity = "info"
|
severity: AgentEventSeverity = "info"
|
||||||
|
|||||||
29
backend/app/agents/schemas/message.py
Normal file
29
backend/app/agents/schemas/message.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
AgentMessageType = Literal[
|
||||||
|
"task_request",
|
||||||
|
"task_update",
|
||||||
|
"handoff",
|
||||||
|
"verification_request",
|
||||||
|
"verification_feedback",
|
||||||
|
"interrupt_notice",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class AgentMessage(BaseModel):
    # Pydantic model for one message sent from one agent to another.
    # Required: message_id, thread_id, from_agent_id, to_agent_id and
    # content_summary. Everything else has a default.

    message_id: str
    thread_id: str
    from_agent_id: str
    to_agent_id: str

    # Optional links to a task and to the message this one replies to.
    task_id: str | None = None
    reply_to_message_id: str | None = None

    # One of the AgentMessageType literals; "task_update" when omitted.
    message_type: AgentMessageType = "task_update"
    content_summary: str

    # Timezone-aware (UTC) creation time, computed per instance.
    created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))

    # Free-form extra data; each instance gets its own empty dict.
    payload: dict[str, Any] = Field(default_factory=dict)
|
||||||
@@ -8,6 +8,41 @@ from pydantic import BaseModel, Field
|
|||||||
|
|
||||||
TaskLifecycleStatus = Literal["pending", "in_progress", "completed", "failed", "blocked"]
|
TaskLifecycleStatus = Literal["pending", "in_progress", "completed", "failed", "blocked"]
|
||||||
VerificationStatus = Literal["passed", "failed", "skipped"]
|
VerificationStatus = Literal["passed", "failed", "skipped"]
|
||||||
|
TaskResultStatus = Literal["completed", "failed", "blocked", "passed", "skipped"]
|
||||||
|
InterruptStatus = Literal["requested", "acknowledged", "resolved"]
|
||||||
|
BudgetMode = Literal["direct", "collaboration"]
|
||||||
|
|
||||||
|
|
||||||
|
class InterruptRecord(BaseModel):
    # Pydantic model describing one interrupt. `interrupt_id` and `reason`
    # are required; `status` starts at "requested".

    interrupt_id: str
    reason: str
    status: InterruptStatus = "requested"

    # Optional provenance fields; both default to None.
    requested_by: str | None = None
    source_event_id: str | None = None

    # Timezone-aware (UTC) time, computed per instance.
    requested_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))

    # Free-form extra data; each instance gets its own empty dict.
    payload: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class RecoveryRecord(BaseModel):
    # Pydantic model describing one recovery. Only `recovery_id` is required.

    recovery_id: str

    # Optional references; all default to None.
    source_interrupt_id: str | None = None
    strategy: str | None = None
    resumed_from_task_id: str | None = None
    resumed_from_thread_id: str | None = None

    # Timezone-aware (UTC) time, computed per instance.
    recovered_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))

    # Free-form extra data; each instance gets its own empty dict.
    payload: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class CollaborationBudget(BaseModel):
    # Pydantic model holding budget counters. Every field is optional:
    # `mode` defaults to "direct" and each counter defaults to None.
    # NOTE(review): None presumably means "no limit set" — confirm with callers.

    mode: BudgetMode = "direct"

    # max_* / remaining_* pairs for parallel tasks, tool calls and iterations.
    max_parallel_tasks: int | None = None
    remaining_parallel_tasks: int | None = None
    max_tool_calls: int | None = None
    remaining_tool_calls: int | None = None
    max_iterations: int | None = None
    remaining_iterations: int | None = None

    escalation_threshold: int | None = None

    # Free-form extra data; each instance gets its own empty dict.
    metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class AgentTask(BaseModel):
|
class AgentTask(BaseModel):
|
||||||
@@ -17,8 +52,16 @@ class AgentTask(BaseModel):
|
|||||||
owner_agent_id: str | None = None
|
owner_agent_id: str | None = None
|
||||||
role: str | None = None
|
role: str | None = None
|
||||||
goal: str | None = None
|
goal: str | None = None
|
||||||
|
parent_task_id: str | None = None
|
||||||
|
child_task_ids: list[str] = Field(default_factory=list)
|
||||||
|
thread_id: str | None = None
|
||||||
|
message_id: str | None = None
|
||||||
|
message_index: int | None = None
|
||||||
expected_evidence: list[dict[str, Any]] = Field(default_factory=list)
|
expected_evidence: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
evidence: list[dict[str, Any]] = Field(default_factory=list)
|
evidence: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
interrupt_records: list[InterruptRecord | dict[str, Any]] = Field(default_factory=list)
|
||||||
|
recovery_records: list[RecoveryRecord | dict[str, Any]] = Field(default_factory=list)
|
||||||
|
collaboration_budget: CollaborationBudget | dict[str, Any] | None = None
|
||||||
result_summary: str | None = None
|
result_summary: str | None = None
|
||||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
@@ -26,7 +69,17 @@ class AgentTask(BaseModel):
|
|||||||
|
|
||||||
class TaskResult(BaseModel):
|
class TaskResult(BaseModel):
|
||||||
task_id: str
|
task_id: str
|
||||||
status: VerificationStatus
|
status: TaskResultStatus
|
||||||
summary: str | None = None
|
summary: str | None = None
|
||||||
evidence: list[dict[str, Any]] = Field(default_factory=list)
|
evidence: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
owner_agent_id: str | None = None
|
||||||
|
parent_task_id: str | None = None
|
||||||
|
child_task_ids: list[str] = Field(default_factory=list)
|
||||||
|
thread_id: str | None = None
|
||||||
|
message_id: str | None = None
|
||||||
|
message_index: int | None = None
|
||||||
|
interrupt_records: list[InterruptRecord | dict[str, Any]] = Field(default_factory=list)
|
||||||
|
recovery_records: list[RecoveryRecord | dict[str, Any]] = Field(default_factory=list)
|
||||||
|
budget_snapshot: CollaborationBudget | dict[str, Any] | None = None
|
||||||
|
next_action: str | None = None
|
||||||
output_data: dict[str, Any] | None = None
|
output_data: dict[str, Any] | None = None
|
||||||
|
|||||||
@@ -3,8 +3,9 @@ from enum import Enum
|
|||||||
from typing import Annotated, Any, Literal, TypedDict
|
from typing import Annotated, Any, Literal, TypedDict
|
||||||
|
|
||||||
from app.agents.schemas.event import AgentEvent
|
from app.agents.schemas.event import AgentEvent
|
||||||
from app.agents.schemas.task import AgentTask, TaskResult, VerificationStatus
|
from app.agents.schemas.message import AgentMessage
|
||||||
from langchain_core.messages import BaseMessage
|
from app.agents.schemas.task import AgentTask, CollaborationBudget, InterruptRecord, RecoveryRecord, TaskResult, VerificationStatus
|
||||||
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||||
from langgraph.graph.message import add_messages
|
from langgraph.graph.message import add_messages
|
||||||
|
|
||||||
|
|
||||||
@@ -24,12 +25,27 @@ class ConversationTurn:
|
|||||||
model: str | None = None
|
model: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def turn_to_message(turn: ConversationTurn) -> BaseMessage:
    """Convert a conversation turn into a LangChain message object.

    A turn whose ``role`` equals ``"user"`` becomes a ``HumanMessage``; any
    other role becomes an ``AIMessage``. The turn's ``content`` is passed
    through unchanged.
    """
    message_cls = HumanMessage if turn.role == "user" else AIMessage
    return message_cls(content=turn.content)
|
||||||
|
|
||||||
|
|
||||||
class AgentState(TypedDict):
|
class AgentState(TypedDict):
|
||||||
messages: Annotated[list[BaseMessage], add_messages]
|
messages: Annotated[list[BaseMessage], add_messages]
|
||||||
user_id: str
|
user_id: str
|
||||||
conversation_id: str
|
conversation_id: str
|
||||||
|
parent_conversation_id: str | None
|
||||||
|
thread_id: str | None
|
||||||
|
last_message_id: str | None
|
||||||
|
message_sequence: int
|
||||||
|
agent_id: str | None
|
||||||
|
parent_agent_id: str | None
|
||||||
|
root_agent_id: str | None
|
||||||
|
collaboration_depth: int
|
||||||
|
spawned_agent_ids: list[str]
|
||||||
|
|
||||||
execution_mode: Literal["direct", "delegated", "verified"]
|
execution_mode: Literal["direct", "collaboration", "delegated", "verified"]
|
||||||
current_agent: str | None
|
current_agent: str | None
|
||||||
next_step: str | None
|
next_step: str | None
|
||||||
active_agents: list[AgentRole]
|
active_agents: list[AgentRole]
|
||||||
@@ -38,11 +54,16 @@ class AgentState(TypedDict):
|
|||||||
sub_commander_trace: list[dict[str, Any]]
|
sub_commander_trace: list[dict[str, Any]]
|
||||||
agent_trace: list[str]
|
agent_trace: list[str]
|
||||||
event_trace: list[AgentEvent | dict[str, Any]]
|
event_trace: list[AgentEvent | dict[str, Any]]
|
||||||
|
message_trace: list[AgentMessage | dict[str, Any]]
|
||||||
|
|
||||||
pending_tasks: list[dict[str, Any]]
|
pending_tasks: list[dict[str, Any]]
|
||||||
completed_tasks: list[dict[str, Any]]
|
completed_tasks: list[dict[str, Any]]
|
||||||
active_tasks: list[AgentTask | dict[str, Any]]
|
active_tasks: list[AgentTask | dict[str, Any]]
|
||||||
task_results: list[TaskResult | dict[str, Any]]
|
task_results: list[TaskResult | dict[str, Any]]
|
||||||
|
task_hierarchy: dict[str, list[str]]
|
||||||
|
interrupted_tasks: list[InterruptRecord | dict[str, Any]]
|
||||||
|
recovery_trace: list[RecoveryRecord | dict[str, Any]]
|
||||||
|
recovery_points: list[dict[str, Any]]
|
||||||
tool_calls: list[dict[str, Any]]
|
tool_calls: list[dict[str, Any]]
|
||||||
last_tool_result: str | None
|
last_tool_result: str | None
|
||||||
action_results: list[dict[str, Any]]
|
action_results: list[dict[str, Any]]
|
||||||
@@ -54,7 +75,8 @@ class AgentState(TypedDict):
|
|||||||
verification_status: VerificationStatus | None
|
verification_status: VerificationStatus | None
|
||||||
verification_summary: str | None
|
verification_summary: str | None
|
||||||
verification_evidence: list[dict[str, Any]]
|
verification_evidence: list[dict[str, Any]]
|
||||||
budget_state: dict[str, Any] | None
|
budget_state: CollaborationBudget | dict[str, Any] | None
|
||||||
|
collaboration_budget_history: list[CollaborationBudget | dict[str, Any]]
|
||||||
|
|
||||||
tool_strategy_used: str | None
|
tool_strategy_used: str | None
|
||||||
tool_round_count: int
|
tool_round_count: int
|
||||||
@@ -102,6 +124,15 @@ def initial_state(user_id: str, conversation_id: str) -> AgentState:
|
|||||||
messages=[],
|
messages=[],
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
|
parent_conversation_id=None,
|
||||||
|
thread_id=None,
|
||||||
|
last_message_id=None,
|
||||||
|
message_sequence=0,
|
||||||
|
agent_id=AgentRole.MASTER.value,
|
||||||
|
parent_agent_id=None,
|
||||||
|
root_agent_id=AgentRole.MASTER.value,
|
||||||
|
collaboration_depth=0,
|
||||||
|
spawned_agent_ids=[],
|
||||||
execution_mode="direct",
|
execution_mode="direct",
|
||||||
current_agent=AgentRole.MASTER.value,
|
current_agent=AgentRole.MASTER.value,
|
||||||
next_step=None,
|
next_step=None,
|
||||||
@@ -111,10 +142,15 @@ def initial_state(user_id: str, conversation_id: str) -> AgentState:
|
|||||||
sub_commander_trace=[],
|
sub_commander_trace=[],
|
||||||
agent_trace=[AgentRole.MASTER.value],
|
agent_trace=[AgentRole.MASTER.value],
|
||||||
event_trace=[],
|
event_trace=[],
|
||||||
|
message_trace=[],
|
||||||
pending_tasks=[],
|
pending_tasks=[],
|
||||||
completed_tasks=[],
|
completed_tasks=[],
|
||||||
active_tasks=[],
|
active_tasks=[],
|
||||||
task_results=[],
|
task_results=[],
|
||||||
|
task_hierarchy={},
|
||||||
|
interrupted_tasks=[],
|
||||||
|
recovery_trace=[],
|
||||||
|
recovery_points=[],
|
||||||
tool_calls=[],
|
tool_calls=[],
|
||||||
last_tool_result=None,
|
last_tool_result=None,
|
||||||
action_results=[],
|
action_results=[],
|
||||||
@@ -126,6 +162,7 @@ def initial_state(user_id: str, conversation_id: str) -> AgentState:
|
|||||||
verification_summary=None,
|
verification_summary=None,
|
||||||
verification_evidence=[],
|
verification_evidence=[],
|
||||||
budget_state=None,
|
budget_state=None,
|
||||||
|
collaboration_budget_history=[],
|
||||||
tool_strategy_used=None,
|
tool_strategy_used=None,
|
||||||
tool_round_count=0,
|
tool_round_count=0,
|
||||||
max_tool_rounds=2,
|
max_tool_rounds=2,
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from app.agents.schemas.task import AgentTask, TaskResult, VerificationStatus
|
from app.agents.schemas.task import AgentTask, TaskResult, TaskResultStatus, VerificationStatus
|
||||||
from app.agents.state import AgentState
|
from app.agents.state import AgentState
|
||||||
|
|
||||||
|
|
||||||
@@ -14,6 +14,34 @@ class VerificationVerdict(BaseModel):
|
|||||||
evidence: list[dict[str, Any]] = Field(default_factory=list)
|
evidence: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_task_result(
    task_result: TaskResult | dict[str, Any],
    *,
    default_task_id: str | None = None,
) -> TaskResult:
    """Coerce a task result (model or loose dict) into a ``TaskResult``.

    Args:
        task_result: An existing ``TaskResult`` or a dict with the same keys.
            A falsy value is treated as an empty dict.
        default_task_id: Used when the payload carries no truthy ``task_id``.
            If this is also missing, ``"unknown-task"`` is used.

    Returns:
        A freshly built ``TaskResult``. An unrecognised or missing ``status``
        is normalised to ``"failed"``; ``message_index`` and ``output_data``
        are kept only when they have the expected type.
    """
    if isinstance(task_result, TaskResult):
        payload = task_result.model_dump(mode="json")
    else:
        payload = dict(task_result or {})

    normalized_status = payload.get("status")
    if normalized_status not in {"completed", "failed", "blocked", "passed", "skipped"}:
        normalized_status = "failed"

    # A dict payload may carry the budget snapshot as a model instance, which
    # TaskResult.budget_snapshot's type allows. Dump it the same way a
    # TaskResult input is dumped above, instead of discarding it.
    budget_snapshot = payload.get("budget_snapshot")
    if isinstance(budget_snapshot, BaseModel):
        budget_snapshot = budget_snapshot.model_dump(mode="json")
    elif not isinstance(budget_snapshot, dict):
        budget_snapshot = None

    message_index = payload.get("message_index")
    output_data = payload.get("output_data")

    return TaskResult(
        task_id=str(payload.get("task_id") or default_task_id or "unknown-task"),
        status=cast(TaskResultStatus, normalized_status),
        summary=payload.get("summary"),
        evidence=list(payload.get("evidence") or []),
        owner_agent_id=payload.get("owner_agent_id"),
        parent_task_id=payload.get("parent_task_id"),
        child_task_ids=list(payload.get("child_task_ids") or []),
        thread_id=payload.get("thread_id"),
        message_id=payload.get("message_id"),
        message_index=message_index if isinstance(message_index, int) else None,
        interrupt_records=list(payload.get("interrupt_records") or []),
        recovery_records=list(payload.get("recovery_records") or []),
        budget_snapshot=budget_snapshot,
        next_action=payload.get("next_action"),
        output_data=output_data if isinstance(output_data, dict) else None,
    )
|
||||||
|
|
||||||
|
|
||||||
def verify_task_result(
|
def verify_task_result(
|
||||||
*,
|
*,
|
||||||
task: AgentTask | dict[str, Any] | None = None,
|
task: AgentTask | dict[str, Any] | None = None,
|
||||||
@@ -30,8 +58,13 @@ def verify_task_result(
|
|||||||
if status is not None:
|
if status is not None:
|
||||||
return VerificationVerdict(status=status, summary=normalized_summary, evidence=normalized_evidence)
|
return VerificationVerdict(status=status, summary=normalized_summary, evidence=normalized_evidence)
|
||||||
|
|
||||||
if normalized_result.get("status") in {"passed", "failed", "skipped"}:
|
normalized_status = normalized_result.get("status")
|
||||||
inferred_status = normalized_result["status"]
|
if normalized_status in {"passed", "failed", "skipped"}:
|
||||||
|
inferred_status = normalized_status
|
||||||
|
elif normalized_status == "completed":
|
||||||
|
inferred_status = "passed"
|
||||||
|
elif normalized_status == "blocked":
|
||||||
|
inferred_status = "skipped"
|
||||||
elif normalized_result.get("success") is True:
|
elif normalized_result.get("success") is True:
|
||||||
inferred_status = "passed"
|
inferred_status = "passed"
|
||||||
elif normalized_result.get("success") is False:
|
elif normalized_result.get("success") is False:
|
||||||
@@ -57,4 +90,4 @@ def apply_verification_verdict(state: AgentState, verdict: VerificationVerdict)
|
|||||||
return AgentState(**next_state)
|
return AgentState(**next_state)
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["VerificationVerdict", "apply_verification_verdict", "verify_task_result"]
|
__all__ = ["VerificationVerdict", "apply_verification_verdict", "normalize_task_result", "verify_task_result"]
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
from sqlalchemy import text
|
from collections.abc import AsyncGenerator
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
|
||||||
from app.config import settings
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
os.makedirs(settings.DATA_DIR, exist_ok=True)
|
os.makedirs(settings.DATA_DIR, exist_ok=True)
|
||||||
|
|
||||||
engine = create_async_engine(
|
engine = create_async_engine(
|
||||||
@@ -24,12 +27,9 @@ class Base(DeclarativeBase):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
async def get_db() -> AsyncSession:
|
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||||
async with async_session() as session:
|
async with async_session() as session:
|
||||||
try:
|
yield session
|
||||||
yield session
|
|
||||||
finally:
|
|
||||||
await session.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def init_db():
|
async def init_db():
|
||||||
@@ -37,6 +37,7 @@ async def init_db():
|
|||||||
await conn.run_sync(Base.metadata.create_all)
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
await ensure_log_columns(conn)
|
await ensure_log_columns(conn)
|
||||||
await ensure_message_columns(conn)
|
await ensure_message_columns(conn)
|
||||||
|
await ensure_conversation_columns(conn)
|
||||||
await ensure_document_columns(conn)
|
await ensure_document_columns(conn)
|
||||||
await ensure_user_columns(conn)
|
await ensure_user_columns(conn)
|
||||||
await ensure_forum_columns(conn)
|
await ensure_forum_columns(conn)
|
||||||
@@ -79,6 +80,20 @@ async def ensure_message_columns(conn):
|
|||||||
await conn.execute(text(ddl))
|
await conn.execute(text(ddl))
|
||||||
|
|
||||||
|
|
||||||
|
async def ensure_conversation_columns(conn):
    """Add any required columns missing from the ``conversations`` table.

    Returns without doing anything when ``_get_table_info`` yields no rows
    for the table. Otherwise each required column that is absent gets its
    ``ALTER TABLE`` statement executed on ``conn``.
    """
    table_info = await _get_table_info(conn, 'conversations')
    if not table_info:
        return

    # NOTE(review): index 1 is presumably the column name, as in SQLite's
    # PRAGMA table_info rows — confirm against _get_table_info.
    existing = {info[1] for info in table_info}
    migrations = {
        'agent_state': "ALTER TABLE conversations ADD COLUMN agent_state JSON",
    }
    pending = [ddl for name, ddl in migrations.items() if name not in existing]
    for statement in pending:
        await conn.execute(text(statement))
|
||||||
|
|
||||||
|
|
||||||
async def ensure_document_columns(conn):
|
async def ensure_document_columns(conn):
|
||||||
result = await conn.execute(text("PRAGMA table_info(documents)"))
|
result = await conn.execute(text("PRAGMA table_info(documents)"))
|
||||||
rows = result.fetchall()
|
rows = result.fetchall()
|
||||||
|
|||||||
@@ -1,12 +1,33 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException
|
from datetime import datetime
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.models.agent import Agent
|
from app.models.agent import Agent
|
||||||
|
from app.models.conversation import Conversation
|
||||||
from app.models.skill import Skill
|
from app.models.skill import Skill
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.routers.auth import get_current_user
|
from app.routers.auth import get_current_user
|
||||||
from app.schemas.agent import AgentCreate, AgentOut, AgentStats, AgentConfigUpdate, AgentConfigOut
|
from app.schemas.agent import (
|
||||||
|
AgentConfigOut,
|
||||||
|
AgentConfigUpdate,
|
||||||
|
AgentCreate,
|
||||||
|
AgentOut,
|
||||||
|
AgentStats,
|
||||||
|
AgentVisibilityEvidenceOut,
|
||||||
|
AgentVisibilityEventsResponse,
|
||||||
|
AgentVisibilityEventOut,
|
||||||
|
AgentVisibilityTaskSummaryOut,
|
||||||
|
AgentVisibilityThreadMessageOut,
|
||||||
|
AgentVisibilityThreadOut,
|
||||||
|
AgentVisibilityTopologyNodeOut,
|
||||||
|
AgentVisibilityTopologyOut,
|
||||||
|
AgentVisibilityVerifierOut,
|
||||||
|
)
|
||||||
|
from app.services.agent_service import _extract_continuity_snapshot
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/agents", tags=["Agent"])
|
router = APIRouter(prefix="/api/agents", tags=["Agent"])
|
||||||
|
|
||||||
@@ -21,6 +42,147 @@ SUB_COMMANDERS_BY_ROLE = {
|
|||||||
"librarian": ["librarian_retrieval", "librarian_graph"],
|
"librarian": ["librarian_retrieval", "librarian_graph"],
|
||||||
"analyst": ["analyst_progress", "analyst_insights"],
|
"analyst": ["analyst_progress", "analyst_insights"],
|
||||||
}
|
}
|
||||||
|
ALLOWED_AGENT_ROLES = set(DEFAULT_AGENT_ROLES) | {
|
||||||
|
role
|
||||||
|
for sub_roles in SUB_COMMANDERS_BY_ROLE.values()
|
||||||
|
for role in sub_roles
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_visibility_datetime(value: str | None) -> datetime | None:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||||
|
except ValueError as exc:
|
||||||
|
raise HTTPException(status_code=400, detail="时间参数必须是 ISO 8601 格式") from exc
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_visibility_state(
    conversation_id: str,
    *,
    current_user: User,
    db: AsyncSession,
) -> dict[str, Any]:
    """Load the continuity snapshot for one of the current user's conversations.

    The conversation is looked up by id and must belong to ``current_user``.

    Returns:
        The value ``_extract_continuity_snapshot`` produces from the
        conversation's ``agent_state``.

    Raises:
        HTTPException: 404 when no matching conversation exists, or when
            the snapshot extracted from it is ``None``.
    """
    query = select(Conversation).where(
        Conversation.id == conversation_id,
        Conversation.user_id == current_user.id,
    )
    conversation = (await db.execute(query)).scalar_one_or_none()
    if conversation is None:
        raise HTTPException(status_code=404, detail="对话不存在")

    snapshot = _extract_continuity_snapshot(conversation.agent_state)
    if snapshot is None:
        raise HTTPException(status_code=404, detail="当前会话暂无可视化运行时数据")
    return snapshot
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_event_payload(event: dict[str, Any]) -> AgentVisibilityEventOut:
    """Validate a raw event dict into an ``AgentVisibilityEventOut``."""
    validated = AgentVisibilityEventOut.model_validate(event)
    return validated
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_events(
|
||||||
|
events: list[dict[str, Any]],
|
||||||
|
*,
|
||||||
|
agent_id: str | None,
|
||||||
|
thread_id: str | None,
|
||||||
|
event_type: str | None,
|
||||||
|
started_after: datetime | None,
|
||||||
|
ended_before: datetime | None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
filtered: list[dict[str, Any]] = []
|
||||||
|
for event in events:
|
||||||
|
if agent_id and event.get("agent_id") != agent_id:
|
||||||
|
continue
|
||||||
|
if thread_id and event.get("thread_id") != thread_id:
|
||||||
|
continue
|
||||||
|
if event_type and event.get("event_type") != event_type:
|
||||||
|
continue
|
||||||
|
timestamp_raw = event.get("timestamp")
|
||||||
|
timestamp = None
|
||||||
|
if isinstance(timestamp_raw, str):
|
||||||
|
try:
|
||||||
|
timestamp = datetime.fromisoformat(timestamp_raw.replace("Z", "+00:00"))
|
||||||
|
except ValueError:
|
||||||
|
timestamp = None
|
||||||
|
if started_after and timestamp and timestamp < started_after:
|
||||||
|
continue
|
||||||
|
if ended_before and timestamp and timestamp > ended_before:
|
||||||
|
continue
|
||||||
|
filtered.append(event)
|
||||||
|
return filtered
|
||||||
|
|
||||||
|
|
||||||
|
def _summarize_tasks(tasks: list[dict[str, Any]], task_results: list[dict[str, Any]]) -> list[AgentVisibilityTaskSummaryOut]:
    """Build one summary per task, preferring values from its task result.

    Results are indexed by their ``task_id``. For status, summary and
    evidence the result's value wins when truthy and the task's own value is
    the fallback; for ``owner_agent_id`` the task's value wins instead.
    """
    results_index = {entry.get("task_id"): entry for entry in task_results}

    def _summary_for(task: dict[str, Any]) -> AgentVisibilityTaskSummaryOut:
        identifier = str(task.get("task_id") or "")
        outcome = results_index.get(identifier) or {}
        evidence_items = outcome.get("evidence") or task.get("evidence") or []
        return AgentVisibilityTaskSummaryOut(
            task_id=identifier,
            role=task.get("role"),
            owner_agent_id=task.get("owner_agent_id") or outcome.get("owner_agent_id"),
            status=outcome.get("status") or task.get("status"),
            summary=outcome.get("summary") or task.get("result_summary"),
            evidence_count=len(evidence_items),
        )

    return [_summary_for(task) for task in tasks]
|
||||||
|
|
||||||
|
|
||||||
|
def _build_topology_nodes(
    state: dict[str, Any],
    tasks: list[dict[str, Any]],
    task_results: list[dict[str, Any]],
) -> list[AgentVisibilityTopologyNodeOut]:
    """Build the topology nodes for the root, spawned and current agents.

    Args:
        state: Snapshot dict; reads ``root_agent_id`` (falling back to
            ``agent_id``), ``current_agent`` and ``spawned_agent_ids``.
        tasks: Task dicts; each non-empty ``owner_agent_id`` adds to that
            agent's ``task_count``.
        task_results: Result dicts; each one with ``status == "completed"``
            adds to its owner's ``completed_task_count``.

    Returns:
        Nodes in insertion order: root first, then spawned agents, then the
        current agent if it was not already listed. A node's ``role`` is the
        text of its id before the first ``-``.

    The first node registered for an id wins. This keeps a spawned id equal
    to the root id from replacing the root node (which would make the root
    its own parent), and ``None``/empty entries in ``spawned_agent_ids`` are
    skipped instead of becoming nodes named ``"None"`` or ``""``.
    """
    task_counts: dict[str, int] = {}
    for task in tasks:
        owner = str(task.get("owner_agent_id") or "")
        if owner:
            task_counts[owner] = task_counts.get(owner, 0) + 1

    completed_counts: dict[str, int] = {}
    for result in task_results:
        owner = str(result.get("owner_agent_id") or "")
        if owner and result.get("status") == "completed":
            completed_counts[owner] = completed_counts.get(owner, 0) + 1

    root_agent_id = str(state.get("root_agent_id") or state.get("agent_id") or "") or None
    current_agent = str(state.get("current_agent") or "") or None
    nodes: dict[str, AgentVisibilityTopologyNodeOut] = {}

    def _register(agent_id: str, *, parent_agent_id: str | None, source: str) -> None:
        # Register a node once; later registrations of the same id are ignored.
        if agent_id in nodes:
            return
        nodes[agent_id] = AgentVisibilityTopologyNodeOut(
            agent_id=agent_id,
            role=agent_id.split("-")[0],
            parent_agent_id=parent_agent_id,
            source=source,
            task_count=task_counts.get(agent_id, 0),
            completed_task_count=completed_counts.get(agent_id, 0),
        )

    if root_agent_id:
        _register(root_agent_id, parent_agent_id=None, source="root")

    for raw_id in state.get("spawned_agent_ids") or []:
        spawned_id = "" if raw_id is None else str(raw_id)
        if spawned_id:
            _register(spawned_id, parent_agent_id=root_agent_id, source="spawned")

    if current_agent:
        # An unlisted current agent can never be the root here, so its parent
        # is simply the root id (or None when there is no root).
        _register(current_agent, parent_agent_id=root_agent_id, source="current")

    return list(nodes.values())
|
||||||
|
|
||||||
|
|
||||||
def record_agent_call(agent_id: str):
|
def record_agent_call(agent_id: str):
|
||||||
@@ -83,6 +245,7 @@ async def get_agent_hierarchy_stats(
|
|||||||
@router.get("/config/{agent_id}", response_model=AgentConfigOut)
|
@router.get("/config/{agent_id}", response_model=AgentConfigOut)
|
||||||
async def get_agent_config(
|
async def get_agent_config(
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
result = await db.execute(select(Agent).where(Agent.role == agent_id))
|
result = await db.execute(select(Agent).where(Agent.role == agent_id))
|
||||||
@@ -172,12 +335,159 @@ async def update_agent_config(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/visibility/events", response_model=AgentVisibilityEventsResponse)
async def get_visibility_events(
    conversation_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
    agent_id: str | None = None,
    thread_id: str | None = None,
    event_type: str | None = None,
    started_after: str | None = None,
    ended_before: str | None = None,
    limit: int = Query(default=50, ge=1, le=200),
    offset: int = Query(default=0, ge=0),
):
    # Returns one page of the conversation's ``event_trace`` after filtering.
    # ``total`` is the number of events matching the filters, not the page size.
    # The snapshot is loaded first, so a missing conversation yields 404 before
    # the time parameters are parsed (which may yield 400).
    state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)
    raw_events = state.get("event_trace") or []
    events = [dict(entry) for entry in raw_events]

    lower_bound = _parse_visibility_datetime(started_after)
    upper_bound = _parse_visibility_datetime(ended_before)
    matching = _filter_events(
        events,
        agent_id=agent_id,
        thread_id=thread_id,
        event_type=event_type,
        started_after=lower_bound,
        ended_before=upper_bound,
    )

    page = matching[offset:offset + limit]
    return AgentVisibilityEventsResponse(
        conversation_id=conversation_id,
        total=len(matching),
        limit=limit,
        offset=offset,
        items=[_coerce_event_payload(entry) for entry in page],
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/visibility/topology", response_model=AgentVisibilityTopologyOut)
async def get_visibility_topology(
    conversation_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    # Returns the agent topology of a conversation: nodes, root-to-child
    # edges, per-task summaries and the stored task hierarchy.
    state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)
    tasks = [dict(entry) for entry in state.get("active_tasks") or []]
    task_results = [dict(entry) for entry in state.get("task_results") or []]
    nodes = _build_topology_nodes(state, tasks, task_results)
    root_agent_id = str(state.get("root_agent_id") or state.get("agent_id") or "") or None

    # Every edge starts at the root; nodes without a parent and the root
    # itself contribute no edge.
    edges = []
    if root_agent_id:
        for node in nodes:
            if node.parent_agent_id and node.agent_id != root_agent_id:
                edges.append({"parent_agent_id": root_agent_id, "child_agent_id": node.agent_id})

    current_agent = str(state.get("current_agent") or "") or None
    task_summaries = _summarize_tasks(tasks, task_results)
    task_hierarchy = dict(state.get("task_hierarchy") or {})
    return AgentVisibilityTopologyOut(
        conversation_id=conversation_id,
        root_agent_id=root_agent_id,
        current_agent=current_agent,
        nodes=nodes,
        edges=edges,
        tasks=task_summaries,
        task_hierarchy=task_hierarchy,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/visibility/tasks/{task_id}/evidence", response_model=AgentVisibilityEvidenceOut)
async def get_visibility_task_evidence(
    task_id: str,
    conversation_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    # Returns the task, its result, the tool outcomes recorded in the result's
    # evidence, and the verifier view for one task of a conversation.
    # Raises 404 when the id matches neither an active task nor a task result.
    state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)
    tasks = [dict(item) for item in state.get("active_tasks") or []]
    task = next((item for item in tasks if item.get("task_id") == task_id), None)
    task_results = [dict(item) for item in state.get("task_results") or []]
    result = next((item for item in task_results if item.get("task_id") == task_id), None)
    if task is None and result is None:
        raise HTTPException(status_code=404, detail="任务不存在")

    # Only dict-shaped evidence entries are usable; anything else is ignored.
    result_evidence = [
        dict(evidence)
        for evidence in (result or {}).get("evidence") or []
        if isinstance(evidence, dict)
    ]
    tool_outcomes = [evidence for evidence in result_evidence if evidence.get("tool_name")]
    verification_entry = next(
        (evidence for evidence in result_evidence if evidence.get("type") == "verification"),
        None,
    )
    verifier = {
        "status": (verification_entry or {}).get("status"),
        "summary": (verification_entry or {}).get("summary"),
        # Same dict guard as above: a non-dict entry used to raise
        # AttributeError on ``.get`` and turn the request into a 500.
        "evidence": [
            dict(item)
            for item in state.get("verification_evidence") or []
            if isinstance(item, dict) and item.get("task_id") == task_id
        ],
    }
    return AgentVisibilityEvidenceOut(
        conversation_id=conversation_id,
        task_id=task_id,
        task=task,
        result=result,
        tool_outcomes=tool_outcomes,
        verifier=verifier,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/visibility/threads/{thread_id}/messages", response_model=AgentVisibilityThreadOut)
async def get_visibility_thread_messages(
    thread_id: str,
    conversation_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    # Returns every ``message_trace`` entry that belongs to the given thread.
    # Non-dict entries are ignored; an empty match is reported as 404.
    state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)

    thread_messages = []
    for entry in state.get("message_trace") or []:
        if not isinstance(entry, dict) or entry.get("thread_id") != thread_id:
            continue
        thread_messages.append(AgentVisibilityThreadMessageOut.model_validate(entry))

    if not thread_messages:
        raise HTTPException(status_code=404, detail="线程不存在")
    return AgentVisibilityThreadOut(
        conversation_id=conversation_id,
        thread_id=thread_id,
        total=len(thread_messages),
        items=thread_messages,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/visibility/verifier", response_model=AgentVisibilityVerifierOut)
async def get_visibility_verifier(
    conversation_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    # Returns the conversation-level verification status, summary and evidence
    # exactly as stored in the snapshot.
    state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)
    verification_status = state.get("verification_status")
    verification_summary = state.get("verification_summary")
    verification_evidence = list(state.get("verification_evidence") or [])
    return AgentVisibilityVerifierOut(
        conversation_id=conversation_id,
        status=verification_status,
        summary=verification_summary,
        evidence=verification_evidence,
    )
|
||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=AgentOut, status_code=201)
|
@router.post("", response_model=AgentOut, status_code=201)
|
||||||
async def create_agent(
|
async def create_agent(
|
||||||
data: AgentCreate,
|
data: AgentCreate,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
|
if not current_user.is_superuser:
|
||||||
|
raise HTTPException(status_code=403, detail="仅管理员可创建 Agent")
|
||||||
|
if not data.spawn_permission:
|
||||||
|
raise HTTPException(status_code=400, detail="缺少 spawn_permission,禁止直接创建 runtime agent")
|
||||||
|
if data.role not in ALLOWED_AGENT_ROLES:
|
||||||
|
raise HTTPException(status_code=400, detail="不支持的 Agent 角色")
|
||||||
|
|
||||||
agent = Agent(
|
agent = Agent(
|
||||||
name=data.name,
|
name=data.name,
|
||||||
role=data.role,
|
role=data.role,
|
||||||
@@ -193,6 +503,7 @@ async def create_agent(
|
|||||||
@router.get("/{agent_id}", response_model=AgentOut)
|
@router.get("/{agent_id}", response_model=AgentOut)
|
||||||
async def get_agent(
|
async def get_agent(
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
result = await db.execute(select(Agent).where(Agent.id == agent_id))
|
result = await db.execute(select(Agent).where(Agent.id == agent_id))
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
from pydantic import BaseModel
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
class AgentCreate(BaseModel):
|
class AgentCreate(BaseModel):
|
||||||
@@ -6,6 +9,7 @@ class AgentCreate(BaseModel):
|
|||||||
role: str
|
role: str
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
system_prompt: str
|
system_prompt: str
|
||||||
|
spawn_permission: bool = False
|
||||||
|
|
||||||
|
|
||||||
class AgentOut(BaseModel):
|
class AgentOut(BaseModel):
|
||||||
@@ -55,3 +59,93 @@ class AgentConfigOut(BaseModel):
|
|||||||
selected_skill_ids: list[str]
|
selected_skill_ids: list[str]
|
||||||
|
|
||||||
model_config = {"from_attributes": True}
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class AgentVisibilityEventOut(BaseModel):
    # Response shape of a single agent event.
    # Required: event_id, event_type, timestamp.
    event_id: str
    event_type: str
    timestamp: datetime

    # Optional correlation ids; each defaults to None when absent.
    conversation_id: str | None = None
    agent_id: str | None = None
    sub_commander_id: str | None = None
    task_id: str | None = None
    parent_task_id: str | None = None
    child_task_id: str | None = None
    thread_id: str | None = None
    message_id: str | None = None
    interrupt_id: str | None = None
    recovery_id: str | None = None

    # Free-form event data (a fresh empty dict per instance by default) and
    # a severity label that defaults to "info".
    payload: dict[str, Any] = Field(default_factory=dict)
    severity: str = "info"
|
||||||
|
|
||||||
|
|
||||||
|
class AgentVisibilityEventsResponse(BaseModel):
    # Paginated list of events for one conversation.
    conversation_id: str

    # Paging metadata: a total count plus the limit/offset of this page.
    total: int
    limit: int
    offset: int

    # The events on this page.
    items: list[AgentVisibilityEventOut]
|
||||||
|
|
||||||
|
|
||||||
|
class AgentVisibilityTaskSummaryOut(BaseModel):
    # Compact per-task summary. Only task_id is required.
    task_id: str

    role: str | None = None
    owner_agent_id: str | None = None
    status: str | None = None
    summary: str | None = None

    # Number of evidence entries attached to the task; 0 by default.
    evidence_count: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class AgentVisibilityTopologyNodeOut(BaseModel):
    # One agent in the topology view. agent_id and source are required.
    agent_id: str
    role: str | None = None
    parent_agent_id: str | None = None

    # Label saying how the node was discovered (set by the producer).
    source: str

    # Task counters; both default to 0.
    task_count: int = 0
    completed_task_count: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class AgentVisibilityTopologyOut(BaseModel):
    # Topology of the agents involved in one conversation.
    conversation_id: str
    root_agent_id: str | None = None
    current_agent: str | None = None

    # Graph content: nodes, and edges expressed as string-to-string dicts.
    nodes: list[AgentVisibilityTopologyNodeOut]
    edges: list[dict[str, str]]

    # Task view: per-task summaries plus a mapping from a string key to a
    # list of strings (empty dict by default).
    tasks: list[AgentVisibilityTaskSummaryOut]
    task_hierarchy: dict[str, list[str]] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentVisibilityEvidenceOut(BaseModel):
    # Evidence bundle for one task of a conversation.
    conversation_id: str
    task_id: str

    # Raw task and result dicts; either may be None.
    task: dict[str, Any] | None = None
    result: dict[str, Any] | None = None

    # Tool outcome dicts (empty list by default) and the required verifier dict.
    tool_outcomes: list[dict[str, Any]] = Field(default_factory=list)
    verifier: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class AgentVisibilityThreadMessageOut(BaseModel):
    # One message exchanged inside a thread.
    message_id: str
    thread_id: str

    # Sender and recipient agent ids (both required).
    from_agent_id: str
    to_agent_id: str

    # Optional links to a task and to the message being replied to.
    task_id: str | None = None
    reply_to_message_id: str | None = None

    message_type: str
    content_summary: str
    created_at: datetime

    # Free-form message data; a fresh empty dict per instance by default.
    payload: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentVisibilityThreadOut(BaseModel):
    # All messages of one thread in a conversation, with their count.
    conversation_id: str
    thread_id: str
    total: int
    items: list[AgentVisibilityThreadMessageOut]
|
||||||
|
|
||||||
|
|
||||||
|
class AgentVisibilityVerifierOut(BaseModel):
    # Verification view for a conversation.
    conversation_id: str

    # Status and summary are optional and default to None.
    status: str | None = None
    summary: str | None = None

    # Evidence dicts; an empty list per instance by default.
    evidence: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
|||||||
@@ -134,6 +134,27 @@ _CONTINUITY_SNAPSHOT_FIELDS = (
|
|||||||
"current_agent",
|
"current_agent",
|
||||||
"next_step",
|
"next_step",
|
||||||
"agent_trace",
|
"agent_trace",
|
||||||
|
"agent_id",
|
||||||
|
"parent_agent_id",
|
||||||
|
"root_agent_id",
|
||||||
|
"collaboration_depth",
|
||||||
|
"thread_id",
|
||||||
|
"last_message_id",
|
||||||
|
"message_sequence",
|
||||||
|
"spawned_agent_ids",
|
||||||
|
"current_sub_commander",
|
||||||
|
"active_sub_commanders",
|
||||||
|
"sub_commander_trace",
|
||||||
|
"event_trace",
|
||||||
|
"message_trace",
|
||||||
|
"active_tasks",
|
||||||
|
"task_results",
|
||||||
|
"task_hierarchy",
|
||||||
|
"verification_status",
|
||||||
|
"verification_summary",
|
||||||
|
"verification_evidence",
|
||||||
|
"budget_state",
|
||||||
|
"collaboration_budget_history",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
167
backend/tests/backend/app/agents/test_agent_schemas.py
Normal file
167
backend/tests/backend/app/agents/test_agent_schemas.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
from app.agents.schemas.event import AgentEvent
|
||||||
|
from app.agents.schemas.task import AgentTask, CollaborationBudget, InterruptRecord, RecoveryRecord, TaskResult
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_task_accepts_day1_fields():
    """AgentTask exposes the Day 1 fields it was constructed with."""
    day1_fields = {
        "task_id": "task-1",
        "title": "Verify foundation",
        "status": "in_progress",
        "owner_agent_id": "executor",
        "role": "verifier",
        "goal": "check output",
        "expected_evidence": [{"type": "assertion"}],
        "evidence": [{"type": "log"}],
        "result_summary": "running",
    }

    task = AgentTask(**day1_fields)

    assert task.task_id == "task-1"
    assert task.owner_agent_id == "executor"
    assert task.status == "in_progress"
    assert task.expected_evidence == [{"type": "assertion"}]
    assert task.evidence == [{"type": "log"}]
    assert task.result_summary == "running"
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_task_accepts_day3_runtime_fields():
    """AgentTask carries hierarchy, thread, interrupt, recovery and budget data."""
    interrupt = InterruptRecord(
        interrupt_id="interrupt-1",
        reason="manual stop",
        requested_by="coordinator",
    )
    recovery = RecoveryRecord(
        recovery_id="recovery-1",
        source_interrupt_id="interrupt-1",
        resumed_from_task_id="task-2",
        resumed_from_thread_id="thread-1",
        strategy="resume_from_checkpoint",
    )
    budget = CollaborationBudget(
        mode="collaboration",
        max_parallel_tasks=2,
        remaining_parallel_tasks=1,
        max_tool_calls=4,
        remaining_tool_calls=3,
        max_iterations=5,
        remaining_iterations=4,
        escalation_threshold=1,
        metadata={"max_spawn_depth": 2},
    )

    task = AgentTask(
        task_id="task-2",
        title="Recover interrupted collaboration",
        owner_agent_id="executor",
        parent_task_id="task-1",
        child_task_ids=["task-2a"],
        thread_id="thread-1",
        message_id="msg-1",
        message_index=2,
        interrupt_records=[interrupt],
        recovery_records=[recovery],
        collaboration_budget=budget,
    )

    assert task.parent_task_id == "task-1"
    assert task.child_task_ids == ["task-2a"]
    assert task.thread_id == "thread-1"
    assert task.message_id == "msg-1"
    assert task.message_index == 2
    assert task.interrupt_records[0].interrupt_id == "interrupt-1"
    assert task.recovery_records[0].recovery_id == "recovery-1"
    assert task.collaboration_budget.mode == "collaboration"
    assert task.collaboration_budget.metadata == {"max_spawn_depth": 2}
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_event_accepts_day1_fields():
    """AgentEvent exposes the Day 1 fields it was constructed with."""
    day1_fields = {
        "event_id": "evt-1",
        "event_type": "agent.verify.completed",
        "conversation_id": "conv-1",
        "agent_id": "executor",
        "sub_commander_id": "executor_tasks",
        "task_id": "task-1",
        "payload": {"status": "passed"},
        "severity": "info",
    }

    event = AgentEvent(**day1_fields)

    assert event.event_id == "evt-1"
    assert event.event_type == "agent.verify.completed"
    assert event.conversation_id == "conv-1"
    assert event.payload == {"status": "passed"}
    assert event.severity == "info"
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_event_accepts_day3_trace_fields():
    """AgentEvent carries the task, thread, interrupt and recovery trace ids."""
    trace_fields = {
        "event_id": "evt-2",
        "event_type": "agent.collaboration.budget.updated",
        "conversation_id": "conv-1",
        "agent_id": "coordinator",
        "task_id": "task-2",
        "parent_task_id": "task-1",
        "child_task_id": "task-2a",
        "thread_id": "thread-1",
        "message_id": "msg-3",
        "interrupt_id": "interrupt-1",
        "recovery_id": "recovery-1",
        "payload": {"remaining_parallel_tasks": 1},
        "severity": "warning",
    }

    event = AgentEvent(**trace_fields)

    assert event.parent_task_id == "task-1"
    assert event.child_task_id == "task-2a"
    assert event.thread_id == "thread-1"
    assert event.message_id == "msg-3"
    assert event.interrupt_id == "interrupt-1"
    assert event.recovery_id == "recovery-1"
    assert event.severity == "warning"
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_result_supports_collaboration_result_fields():
    """TaskResult keeps status, owner and next action for a finished task."""
    result_fields = {
        "task_id": "task-1",
        "status": "completed",
        "summary": "retrieval finished",
        "evidence": [{"type": "source"}],
        "owner_agent_id": "librarian",
        "next_action": "handoff_to_analyst",
    }

    result = TaskResult(**result_fields)

    assert result.status == "completed"
    assert result.owner_agent_id == "librarian"
    assert result.next_action == "handoff_to_analyst"
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_result_supports_day3_thread_budget_and_recovery_fields():
    """TaskResult accepts dict-shaped interrupt/recovery records and a budget snapshot."""
    exhausted_budget = CollaborationBudget(
        mode="collaboration",
        max_parallel_tasks=2,
        remaining_parallel_tasks=0,
        max_tool_calls=4,
        remaining_tool_calls=0,
    )

    result = TaskResult(
        task_id="task-2",
        status="blocked",
        owner_agent_id="executor",
        parent_task_id="task-1",
        child_task_ids=["task-2a"],
        thread_id="thread-1",
        message_id="msg-4",
        message_index=4,
        interrupt_records=[{"interrupt_id": "interrupt-1", "reason": "budget exceeded"}],
        recovery_records=[{"recovery_id": "recovery-1", "strategy": "resume_after_budget_reset"}],
        budget_snapshot=exhausted_budget,
        next_action="resume_after_budget_reset",
    )

    assert result.parent_task_id == "task-1"
    assert result.child_task_ids == ["task-2a"]
    assert result.thread_id == "thread-1"
    assert result.message_id == "msg-4"
    assert result.message_index == 4
    assert result.interrupt_records[0].interrupt_id == "interrupt-1"
    assert result.recovery_records[0].recovery_id == "recovery-1"
    assert result.budget_snapshot.mode == "collaboration"
    assert result.budget_snapshot.remaining_parallel_tasks == 0
    assert result.next_action == "resume_after_budget_reset"
|
||||||
@@ -2,23 +2,34 @@ import sys
|
|||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
sys.modules.setdefault("trafilatura", Mock())
|
sys.modules.setdefault("trafilatura", Mock())
|
||||||
|
|
||||||
import app.agents.graph as graph_module
|
import app.agents.graph as graph_module
|
||||||
from langchain_core.messages import AIMessage, HumanMessage
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
|
|
||||||
from app.agents.graph import (
|
from app.agents.graph import (
|
||||||
|
_build_collaboration_tasks,
|
||||||
_build_verifier_hints,
|
_build_verifier_hints,
|
||||||
_choose_sub_commander,
|
_choose_sub_commander,
|
||||||
|
_create_child_agent,
|
||||||
_execute_tool_calls,
|
_execute_tool_calls,
|
||||||
_parse_json_action,
|
_parse_json_action,
|
||||||
|
_record_interrupt,
|
||||||
|
_record_recovery,
|
||||||
_route_agent_from_user_query,
|
_route_agent_from_user_query,
|
||||||
|
_select_request_mode,
|
||||||
|
_spawn_permission_for_role,
|
||||||
|
_run_collaboration_flow,
|
||||||
_run_sub_commander,
|
_run_sub_commander,
|
||||||
create_agent_graph,
|
create_agent_graph,
|
||||||
master_node,
|
master_node,
|
||||||
planner_node,
|
planner_node,
|
||||||
route_agent,
|
route_agent,
|
||||||
)
|
)
|
||||||
|
from app.agents.schemas.message import AgentMessage
|
||||||
|
from app.agents.schemas.task import AgentTask
|
||||||
from app.agents.state import AgentRole, initial_state
|
from app.agents.state import AgentRole, initial_state
|
||||||
from app.agents.tools import SUB_COMMANDER_TOOLSETS
|
from app.agents.tools import SUB_COMMANDER_TOOLSETS
|
||||||
|
|
||||||
@@ -30,6 +41,15 @@ def _base_state(message: str, user_llm_config: dict | None = None) -> dict:
|
|||||||
'messages': [HumanMessage(content=message)],
|
'messages': [HumanMessage(content=message)],
|
||||||
'user_id': 'u1',
|
'user_id': 'u1',
|
||||||
'conversation_id': 'c1',
|
'conversation_id': 'c1',
|
||||||
|
'parent_conversation_id': None,
|
||||||
|
'thread_id': None,
|
||||||
|
'last_message_id': None,
|
||||||
|
'message_sequence': 0,
|
||||||
|
'agent_id': AgentRole.MASTER.value,
|
||||||
|
'parent_agent_id': None,
|
||||||
|
'root_agent_id': AgentRole.MASTER.value,
|
||||||
|
'collaboration_depth': 0,
|
||||||
|
'spawned_agent_ids': [],
|
||||||
'execution_mode': 'direct',
|
'execution_mode': 'direct',
|
||||||
'current_agent': AgentRole.MASTER.value,
|
'current_agent': AgentRole.MASTER.value,
|
||||||
'next_step': None,
|
'next_step': None,
|
||||||
@@ -39,10 +59,15 @@ def _base_state(message: str, user_llm_config: dict | None = None) -> dict:
|
|||||||
'sub_commander_trace': [],
|
'sub_commander_trace': [],
|
||||||
'agent_trace': [AgentRole.MASTER.value],
|
'agent_trace': [AgentRole.MASTER.value],
|
||||||
'event_trace': [],
|
'event_trace': [],
|
||||||
|
'message_trace': [],
|
||||||
'pending_tasks': [],
|
'pending_tasks': [],
|
||||||
'completed_tasks': [],
|
'completed_tasks': [],
|
||||||
'active_tasks': [],
|
'active_tasks': [],
|
||||||
'task_results': [],
|
'task_results': [],
|
||||||
|
'task_hierarchy': {},
|
||||||
|
'interrupted_tasks': [],
|
||||||
|
'recovery_trace': [],
|
||||||
|
'recovery_points': [],
|
||||||
'tool_calls': [],
|
'tool_calls': [],
|
||||||
'last_tool_result': None,
|
'last_tool_result': None,
|
||||||
'action_results': [],
|
'action_results': [],
|
||||||
@@ -54,6 +79,7 @@ def _base_state(message: str, user_llm_config: dict | None = None) -> dict:
|
|||||||
'verification_summary': None,
|
'verification_summary': None,
|
||||||
'verification_evidence': [],
|
'verification_evidence': [],
|
||||||
'budget_state': None,
|
'budget_state': None,
|
||||||
|
'collaboration_budget_history': [],
|
||||||
'tool_strategy_used': None,
|
'tool_strategy_used': None,
|
||||||
'tool_round_count': 0,
|
'tool_round_count': 0,
|
||||||
'max_tool_rounds': 2,
|
'max_tool_rounds': 2,
|
||||||
@@ -286,6 +312,66 @@ def test_initial_state_sets_structured_continuity_defaults():
|
|||||||
assert state['tool_outcomes'] == []
|
assert state['tool_outcomes'] == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_spawn_permission_for_role_uses_registry_policy():
    """Spawn permission depends on which agent is current in the state."""
    state = _base_state('test')
    expectations = (
        (AgentRole.MASTER.value, AgentRole.LIBRARIAN, True),
        (AgentRole.MASTER.value, AgentRole.MASTER, False),
        (AgentRole.LIBRARIAN.value, AgentRole.LIBRARIAN, True),
        (AgentRole.LIBRARIAN.value, AgentRole.EXECUTOR, False),
    )

    for current_agent, target_role, allowed in expectations:
        state['current_agent'] = current_agent
        assert _spawn_permission_for_role(state, target_role) is allowed
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_child_agent_blocks_disallowed_spawn_role():
    """A librarian asking for an analyst child gets no agent and a blocked event."""
    state = _base_state('test')
    state.update({
        'current_agent': AgentRole.LIBRARIAN.value,
        'agent_id': AgentRole.LIBRARIAN.value,
    })
    analyst_task = AgentTask(
        task_id='task-1',
        title='分析',
        role=AgentRole.ANALYST.value,
        owner_agent_id=AgentRole.ANALYST.value,
        goal='输出分析',
        expected_evidence=[{'type': 'analysis'}],
    )

    child_agent_id = _create_child_agent(state, role=AgentRole.ANALYST, task=analyst_task)

    assert child_agent_id is None
    assert state['spawned_agent_ids'] == []
    last_event = state['event_trace'][-1]
    assert last_event['event_type'] == 'agent.spawn.blocked'
    assert last_event['payload']['reason'] == 'role_policy_blocked'
|
||||||
|
|
||||||
|
|
||||||
|
def test_record_interrupt_and_recovery_write_day3_traces():
    """Recording an interrupt and then a recovery fills the traces in order."""
    state = _base_state('test')
    state['current_agent'] = AgentRole.EXECUTOR.value
    executor_task = AgentTask(
        task_id='task-1',
        title='执行动作',
        role=AgentRole.EXECUTOR.value,
        owner_agent_id=AgentRole.EXECUTOR.value,
        goal='执行必要动作',
        expected_evidence=[{'type': 'execution'}],
    )

    interrupt = _record_interrupt(
        state,
        reason='spawn_blocked',
        task=executor_task,
        payload={'target_role': AgentRole.EXECUTOR.value},
    )
    recovery = _record_recovery(
        state,
        interrupt=interrupt,
        strategy='fallback_to_direct_role_execution',
        task=executor_task,
    )

    assert state['interrupted_tasks'][-1]['interrupt_id'] == interrupt.interrupt_id
    assert state['recovery_trace'][-1]['recovery_id'] == recovery.recovery_id
    assert state['recovery_points'][-1]['task_id'] == 'task-1'

    expected_event_types = [
        'agent.interrupt.requested',
        'agent.task.interrupted',
        'agent.interrupt.completed',
        'agent.recovery.started',
        'agent.task.recovered',
        'agent.recovery.completed',
    ]
    recorded_event_types = [event['event_type'] for event in state['event_trace']]
    assert recorded_event_types == expected_event_types
|
||||||
|
|
||||||
|
|
||||||
async def test_master_node_sets_next_step_when_routing_to_schedule_planner(monkeypatch):
|
async def test_master_node_sets_next_step_when_routing_to_schedule_planner(monkeypatch):
|
||||||
monkeypatch.setattr('app.agents.graph._get_llm_for_state', lambda state: FailIfCalledLLM())
|
monkeypatch.setattr('app.agents.graph._get_llm_for_state', lambda state: FailIfCalledLLM())
|
||||||
|
|
||||||
@@ -347,6 +433,222 @@ async def test_planner_node_clears_next_step_after_consuming_routed_turn(monkeyp
|
|||||||
assert result['final_response'] is not None
|
assert result['final_response'] is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_select_request_mode_prefers_collaboration_for_multi_role_request():
    """A search + analysis + planning request should be classified as collaboration."""
    query = '先帮我搜索竞品资料,然后分析风险,再给我安排下周计划'

    mode, metadata = _select_request_mode(query)

    assert mode == 'collaboration'
    assert metadata['reason'] == 'multi_role_request'
    # Each of the three roles implied by the request must be reported.
    for expected_role in (AgentRole.LIBRARIAN, AgentRole.ANALYST, AgentRole.SCHEDULE_PLANNER):
        assert expected_role.value in metadata['roles']
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_collaboration_tasks_generates_structured_owned_tasks():
    """The planner should emit one owned task with expected evidence per role, in request order."""
    planned = _build_collaboration_tasks('先帮我搜索竞品资料,然后分析风险,再给我安排下周计划')

    assert len(planned) == 3

    expected_roles = [
        AgentRole.LIBRARIAN.value,
        AgentRole.ANALYST.value,
        AgentRole.SCHEDULE_PLANNER.value,
    ]
    actual_roles = [item.role for item in planned]
    assert actual_roles == expected_roles

    # Every task needs an owner and a non-empty evidence expectation.
    assert all(item.owner_agent_id for item in planned)
    assert all(item.expected_evidence for item in planned)
|
||||||
|
|
||||||
|
|
||||||
|
def test_verify_collaboration_results_uses_explicit_task_results_snapshot():
    """Verification should judge the results passed in, not the stale list stored in state."""
    librarian = AgentRole.LIBRARIAN.value
    planned_task = AgentTask(
        task_id='task-1',
        title='补齐事实与证据',
        role=librarian,
        owner_agent_id=librarian,
        goal='检索资料',
        expected_evidence=[{'type': 'evidence'}],
    )

    state = _base_state('test')
    # A failed result for an unrelated task is left in state on purpose.
    stale_result = {
        'task_id': 'stale-task',
        'status': 'failed',
        'summary': 'stale failure',
        'evidence': [{'type': 'verification'}],
    }
    state['task_results'] = [stale_result]

    explicit_result = {
        'task_id': 'task-1',
        'status': 'completed',
        'summary': 'done',
        'evidence': [{'type': 'verification'}],
        'owner_agent_id': librarian,
    }
    graph_module._verify_collaboration_results(state, [planned_task], [explicit_result])

    assert state['verification_status'] == 'passed'
    assert '1/1 个子任务' in state['verification_summary']
|
||||||
|
|
||||||
|
|
||||||
|
def test_verify_collaboration_results_ignores_stale_results_outside_current_plan():
    """A failed result whose task_id is not in the current plan must not affect the verdict."""
    librarian = AgentRole.LIBRARIAN.value
    current_plan = [
        AgentTask(
            task_id='task-1',
            title='补齐事实与证据',
            role=librarian,
            owner_agent_id=librarian,
            goal='检索资料',
            expected_evidence=[{'type': 'evidence'}],
        ),
    ]
    state = _base_state('test')

    # The first result belongs to no planned task; the second matches 'task-1'.
    out_of_plan_result = {
        'task_id': 'stale-task',
        'status': 'failed',
        'summary': 'stale failure',
        'evidence': [{'type': 'verification'}],
    }
    in_plan_result = {
        'task_id': 'task-1',
        'status': 'completed',
        'summary': 'done',
        'evidence': [{'type': 'verification'}],
        'owner_agent_id': librarian,
    }
    graph_module._verify_collaboration_results(
        state,
        current_plan,
        [out_of_plan_result, in_plan_result],
    )

    assert state['verification_status'] == 'passed'
    assert '1/1 个子任务' in state['verification_summary']
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_run_sub_commander_verifies_only_current_turn_tool_outcomes(monkeypatch):
    """A failed tool outcome left over in state must not fail verification of the current turn."""

    class FakeBoundLLM:
        """Minimal LLM stand-in that always answers with a canned response."""

        def __init__(self, response):
            self._response = response

        def bind_tools(self, _toolset):
            return self

        async def ainvoke(self, _messages):
            return self._response

    class Caps:
        """Capability stub advertising native tool support."""

        supports_native_tools = True

    state = _base_state('查一下资料')
    # Pre-existing outcome whose preview reads like a failure.
    state['tool_outcomes'] = [
        {
            'tool_name': 'stale_tool',
            'args': {'query': 'old'},
            'result_preview': '工具执行失败: stale',
            'verifier_hints': {'tool_name': 'stale_tool'},
        }
    ]

    response = AIMessage(content='当前回合完成')

    def fake_get_llm(_state):
        return FakeBoundLLM(response)

    def fake_resolve_capabilities(_state, _llm):
        return Caps()

    def fake_choose_sub_commander(_role, _query):
        return 'librarian_retrieval'

    def fake_record_sub_commander(*_args, **_kwargs):
        return None

    monkeypatch.setattr(graph_module, '_get_llm_for_state', fake_get_llm)
    monkeypatch.setattr(graph_module, '_resolve_capabilities', fake_resolve_capabilities)
    monkeypatch.setattr(graph_module, '_choose_sub_commander', fake_choose_sub_commander)
    monkeypatch.setattr(graph_module, '_record_sub_commander', fake_record_sub_commander)

    await graph_module._run_sub_commander(
        state,
        AgentRole.LIBRARIAN,
        'prompt',
        '查一下资料',
        use_tools=True,
    )

    assert state['final_response'] == '当前回合完成'
    assert state['verification_status'] == 'passed'
|
||||||
|
|
||||||
|
|
||||||
|
async def test_run_collaboration_flow_collects_task_results_and_verifies(monkeypatch):
    """Run the collaboration flow over two stubbed sub-tasks and check results, traces and verdict."""
    librarian = AgentRole.LIBRARIAN.value
    analyst = AgentRole.ANALYST.value
    planned_tasks = [
        AgentTask(
            task_id='task-1',
            title='补齐事实与证据',
            role=librarian,
            owner_agent_id=librarian,
            goal='检索资料',
            expected_evidence=[{'type': 'evidence'}],
        ),
        AgentTask(
            task_id='task-2',
            title='给出分析与判断',
            role=analyst,
            owner_agent_id=analyst,
            goal='输出分析',
            expected_evidence=[{'type': 'analysis'}],
        ),
    ]

    async def fake_run_sub_commander(state, role, manager_prompt, user_query, *, use_tools, summary_target=None):
        # Stand-in for the real sub-commander: mark the role as done and verified,
        # append one successful tool outcome and one AI message.
        role_name = role.value
        tool_name = f'{role_name}_tool'
        state['current_agent'] = role_name
        state['current_sub_commander'] = f'{role_name}_worker'
        state['final_response'] = f'{role_name} finished'
        state['verification_status'] = 'passed'
        state['verification_summary'] = f'{role_name} verified'
        new_outcome = {
            'tool_name': tool_name,
            'args': {'query': user_query},
            'result_preview': 'ok',
            'verifier_hints': {'tool_name': tool_name},
        }
        state['tool_outcomes'] = list(state.get('tool_outcomes') or []) + [new_outcome]
        state['messages'] = list(state.get('messages', [])) + [AIMessage(content=state['final_response'])]
        return state

    monkeypatch.setattr(graph_module, '_build_collaboration_tasks', lambda user_query: planned_tasks)
    monkeypatch.setattr(graph_module, '_run_sub_commander', fake_run_sub_commander)

    query = '先帮我搜索竞品资料,然后分析风险'
    state = _base_state(query)
    result = await _run_collaboration_flow(state, query)

    # Task bookkeeping.
    assert result['execution_mode'] == 'collaboration'
    assert len(result['active_tasks']) == 2
    assert len(result['task_results']) == 2
    assert result['task_results'][0]['status'] == 'completed'
    assert result['task_results'][1]['owner_agent_id'] == analyst

    # Aggregate verification and the user-facing summary.
    assert result['verification_status'] == 'passed'
    assert '协作模式已完成 2/2 个子任务' in result['verification_summary']
    assert '已按协作模式回收 2 个子任务结果' in result['final_response']

    # Inter-agent messages.
    assert len(result['message_trace']) >= 2
    assert all(entry['message_type'] == 'task_update' for entry in result['message_trace'])
    assert result['message_trace'][-1]['message_type'] == 'task_update'

    # Events: children were created and messaged, and no spawn was blocked.
    event_types = [entry['event_type'] for entry in result['event_trace']]
    assert 'agent.created' in event_types
    assert 'agent.message.sent' in event_types
    assert 'agent.spawn.blocked' not in event_types

    assert result['spawned_agent_ids']
    assert all(not spawned_id.startswith('blocked-') for spawned_id in result['spawned_agent_ids'])
    assert result['task_hierarchy']
|
||||||
|
|
||||||
|
|
||||||
|
async def test_master_node_enters_collaboration_mode_for_complex_multi_role_request(monkeypatch):
    """master_node should hand a multi-role request to the (stubbed) collaboration flow."""

    async def fake_collaboration_flow(state, user_query):
        # Leave a recognisable marker so the test can tell this path was taken.
        reply = 'collaboration done'
        state['execution_mode'] = 'collaboration'
        state['final_response'] = reply
        state['messages'] = list(state.get('messages', [])) + [AIMessage(content=reply)]
        return state

    monkeypatch.setattr(graph_module, '_run_collaboration_flow', fake_collaboration_flow)

    initial = _base_state('先帮我搜索竞品资料,然后分析风险,再给我安排下周计划')
    outcome = await master_node(initial)

    assert outcome['execution_mode'] == 'collaboration'
    assert outcome['final_response'] == 'collaboration done'
|
||||||
|
|
||||||
|
|
||||||
async def test_master_node_returns_stable_reply_for_simple_greeting(monkeypatch):
|
async def test_master_node_returns_stable_reply_for_simple_greeting(monkeypatch):
|
||||||
monkeypatch.setattr('app.agents.graph._get_llm_for_state', lambda state: FailIfCalledLLM())
|
monkeypatch.setattr('app.agents.graph._get_llm_for_state', lambda state: FailIfCalledLLM())
|
||||||
|
|
||||||
@@ -1160,6 +1462,8 @@ async def test_execute_tool_calls_records_schema_events_and_aggregate_summaries(
|
|||||||
assert all(event['conversation_id'] == 'c1' for event in state['event_trace'])
|
assert all(event['conversation_id'] == 'c1' for event in state['event_trace'])
|
||||||
assert all(event['agent_id'] == AgentRole.MASTER.value for event in state['event_trace'])
|
assert all(event['agent_id'] == AgentRole.MASTER.value for event in state['event_trace'])
|
||||||
assert all(event['task_id'] == 'task-1' for event in state['event_trace'])
|
assert all(event['task_id'] == 'task-1' for event in state['event_trace'])
|
||||||
|
assert all(event['thread_id'] is not None for event in state['event_trace'])
|
||||||
|
assert all(event['message_id'] is None for event in state['event_trace'])
|
||||||
|
|
||||||
|
|
||||||
async def test_execute_tool_calls_aggregates_multiple_tool_turns_without_overwrite(monkeypatch):
|
async def test_execute_tool_calls_aggregates_multiple_tool_turns_without_overwrite(monkeypatch):
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from app.agents.prompts import MASTER_SYSTEM_PROMPT
|
from app.agents.prompts import COORDINATOR_SYSTEM_PROMPT, MASTER_SYSTEM_PROMPT
|
||||||
|
|
||||||
|
|
||||||
def test_master_prompt_forbids_subagent_rollcall_in_simple_greetings():
|
def test_master_prompt_forbids_subagent_rollcall_in_simple_greetings():
|
||||||
@@ -10,3 +10,10 @@ def test_master_prompt_does_not_include_full_canned_answers_for_greetings_or_ide
|
|||||||
assert 'Jarvis:您好。我在。' not in MASTER_SYSTEM_PROMPT
|
assert 'Jarvis:您好。我在。' not in MASTER_SYSTEM_PROMPT
|
||||||
assert 'Jarvis:我是 Jarvis。' not in MASTER_SYSTEM_PROMPT
|
assert 'Jarvis:我是 Jarvis。' not in MASTER_SYSTEM_PROMPT
|
||||||
assert 'Jarvis:主要做三件事。' not in MASTER_SYSTEM_PROMPT
|
assert 'Jarvis:主要做三件事。' not in MASTER_SYSTEM_PROMPT
|
||||||
|
|
||||||
|
|
||||||
|
def test_coordinator_prompt_limits_collaboration_scope():
    """The coordinator prompt must state the sub-task cap, the recursion ban and the known roles."""
    required_fragments = (
        "2~4 个子任务",
        "禁止无限递归拆分",
        "schedule_planner",
        "librarian",
    )
    for fragment in required_fragments:
        assert fragment in COORDINATOR_SYSTEM_PROMPT
|
||||||
|
|||||||
@@ -307,6 +307,7 @@ def test_build_registry_indexes_exposes_manifest_lookups_by_id() -> None:
|
|||||||
indexes = build_registry_indexes(bundle)
|
indexes = build_registry_indexes(bundle)
|
||||||
|
|
||||||
assert indexes.agent_by_id
|
assert indexes.agent_by_id
|
||||||
|
assert indexes.agent_by_role_value
|
||||||
assert indexes.sub_commander_by_id
|
assert indexes.sub_commander_by_id
|
||||||
assert indexes.capability_by_id
|
assert indexes.capability_by_id
|
||||||
assert isinstance(indexes.specialist_template_by_id, Mapping)
|
assert isinstance(indexes.specialist_template_by_id, Mapping)
|
||||||
@@ -362,6 +363,14 @@ def test_build_registry_indexes_exposes_prompt_keys_skill_context_keys_and_capab
|
|||||||
sub_commander.sub_commander_id: tuple(sub_commander.capability_ids)
|
sub_commander.sub_commander_id: tuple(sub_commander.capability_ids)
|
||||||
for sub_commander in bundle.sub_commanders
|
for sub_commander in bundle.sub_commanders
|
||||||
}
|
}
|
||||||
|
assert indexes.agent_by_role_value == {
|
||||||
|
agent.role_value: agent for agent in bundle.agents
|
||||||
|
}
|
||||||
|
assert indexes.spawnable_role_values_by_agent_id == {
|
||||||
|
agent.agent_id: tuple(agent.allowed_spawn_role_values)
|
||||||
|
for agent in bundle.agents
|
||||||
|
if agent.can_spawn_children and agent.allowed_spawn_role_values
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_validate_registry_bundle_accepts_loaded_builtin_registry_bundle() -> None:
|
def test_validate_registry_bundle_accepts_loaded_builtin_registry_bundle() -> None:
|
||||||
|
|||||||
@@ -1,66 +1,135 @@
|
|||||||
from app.agents.schemas.event import AgentEvent
|
from app.agents.schemas import AgentEvent, AgentTask, TaskResult
|
||||||
from app.agents.schemas.task import AgentTask
|
from app.agents.schemas.task import CollaborationBudget, InterruptRecord, RecoveryRecord
|
||||||
from app.agents.verifier import verify_task_result
|
from app.agents.state import initial_state
|
||||||
|
from app.agents.verifier import apply_verification_verdict, normalize_task_result, verify_task_result
|
||||||
|
|
||||||
|
|
||||||
def test_agent_task_accepts_day1_fields():
|
def test_agent_task_supports_day3_interrupt_recovery_and_budget_fields():
|
||||||
|
interrupt = InterruptRecord(interrupt_id="interrupt-1", reason="user_cancel")
|
||||||
|
recovery = RecoveryRecord(recovery_id="recovery-1", source_interrupt_id="interrupt-1", resumed_from_task_id="task-1")
|
||||||
|
budget = CollaborationBudget(
|
||||||
|
mode="collaboration",
|
||||||
|
max_parallel_tasks=3,
|
||||||
|
remaining_parallel_tasks=2,
|
||||||
|
max_tool_calls=6,
|
||||||
|
remaining_tool_calls=4,
|
||||||
|
metadata={"phase": "day3"},
|
||||||
|
)
|
||||||
|
|
||||||
task = AgentTask(
|
task = AgentTask(
|
||||||
task_id="task-1",
|
task_id="task-1",
|
||||||
title="Verify foundation",
|
title="Recover interrupted collaboration task",
|
||||||
status="in_progress",
|
owner_agent_id="analyst",
|
||||||
owner_agent_id="executor",
|
role="analyst",
|
||||||
role="verifier",
|
parent_task_id="parent-1",
|
||||||
goal="check output",
|
child_task_ids=["child-1"],
|
||||||
expected_evidence=[{"type": "assertion"}],
|
thread_id="thread-1",
|
||||||
evidence=[{"type": "log"}],
|
message_id="message-1",
|
||||||
result_summary="running",
|
message_index=3,
|
||||||
|
interrupt_records=[interrupt],
|
||||||
|
recovery_records=[recovery],
|
||||||
|
collaboration_budget=budget,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert task.task_id == "task-1"
|
payload = task.model_dump(mode="json")
|
||||||
assert task.owner_agent_id == "executor"
|
|
||||||
assert task.status == "in_progress"
|
assert payload["parent_task_id"] == "parent-1"
|
||||||
assert task.expected_evidence == [{"type": "assertion"}]
|
assert payload["child_task_ids"] == ["child-1"]
|
||||||
assert task.evidence == [{"type": "log"}]
|
assert payload["thread_id"] == "thread-1"
|
||||||
assert task.result_summary == "running"
|
assert payload["message_id"] == "message-1"
|
||||||
|
assert payload["message_index"] == 3
|
||||||
|
assert payload["interrupt_records"][0]["interrupt_id"] == "interrupt-1"
|
||||||
|
assert payload["recovery_records"][0]["recovery_id"] == "recovery-1"
|
||||||
|
assert payload["collaboration_budget"]["mode"] == "collaboration"
|
||||||
|
assert payload["collaboration_budget"]["remaining_tool_calls"] == 4
|
||||||
|
|
||||||
|
|
||||||
def test_agent_event_accepts_day1_fields():
|
def test_agent_event_supports_day3_thread_interrupt_and_recovery_metadata():
|
||||||
event = AgentEvent(
|
event = AgentEvent(
|
||||||
event_id="evt-1",
|
event_id="evt-1",
|
||||||
event_type="agent.verify.completed",
|
event_type="agent.task.recovered",
|
||||||
conversation_id="conv-1",
|
conversation_id="conv-1",
|
||||||
agent_id="executor",
|
agent_id="executor",
|
||||||
sub_commander_id="executor_tasks",
|
|
||||||
task_id="task-1",
|
task_id="task-1",
|
||||||
payload={"status": "passed"},
|
parent_task_id="parent-1",
|
||||||
severity="info",
|
child_task_id="child-1",
|
||||||
|
thread_id="thread-1",
|
||||||
|
message_id="message-1",
|
||||||
|
interrupt_id="interrupt-1",
|
||||||
|
recovery_id="recovery-1",
|
||||||
|
severity="warning",
|
||||||
|
payload={"status": "resumed"},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert event.event_id == "evt-1"
|
payload = event.model_dump(mode="json")
|
||||||
assert event.event_type == "agent.verify.completed"
|
|
||||||
assert event.conversation_id == "conv-1"
|
assert payload["event_type"] == "agent.task.recovered"
|
||||||
assert event.payload == {"status": "passed"}
|
assert payload["parent_task_id"] == "parent-1"
|
||||||
assert event.severity == "info"
|
assert payload["child_task_id"] == "child-1"
|
||||||
|
assert payload["thread_id"] == "thread-1"
|
||||||
|
assert payload["message_id"] == "message-1"
|
||||||
|
assert payload["interrupt_id"] == "interrupt-1"
|
||||||
|
assert payload["recovery_id"] == "recovery-1"
|
||||||
|
assert payload["severity"] == "warning"
|
||||||
|
|
||||||
|
|
||||||
def test_verifier_verdict_is_separate_from_task_lifecycle_status():
|
def test_normalize_task_result_preserves_day3_metadata_fields():
|
||||||
task = AgentTask(task_id="task-1", title="Verify", status="blocked", result_summary="waiting")
|
result = normalize_task_result(
|
||||||
|
{
|
||||||
|
"task_id": "task-1",
|
||||||
|
"status": "completed",
|
||||||
|
"summary": "Recovered successfully.",
|
||||||
|
"owner_agent_id": "executor",
|
||||||
|
"parent_task_id": "parent-1",
|
||||||
|
"child_task_ids": ["child-1"],
|
||||||
|
"thread_id": "thread-1",
|
||||||
|
"message_id": "message-1",
|
||||||
|
"message_index": 2,
|
||||||
|
"interrupt_records": [{"interrupt_id": "interrupt-1", "reason": "user_pause"}],
|
||||||
|
"recovery_records": [{"recovery_id": "recovery-1", "source_interrupt_id": "interrupt-1"}],
|
||||||
|
"budget_snapshot": {"mode": "collaboration", "max_parallel_tasks": 4},
|
||||||
|
"next_action": "notify_user",
|
||||||
|
"output_data": {"ok": True},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
verdict = verify_task_result(task=task)
|
assert result.parent_task_id == "parent-1"
|
||||||
|
assert result.child_task_ids == ["child-1"]
|
||||||
assert verdict.status == "skipped"
|
assert result.thread_id == "thread-1"
|
||||||
assert verdict.summary == "waiting"
|
assert result.message_id == "message-1"
|
||||||
|
assert result.message_index == 2
|
||||||
|
assert result.interrupt_records[0].interrupt_id == "interrupt-1"
|
||||||
|
assert result.recovery_records[0].recovery_id == "recovery-1"
|
||||||
|
assert result.budget_snapshot.mode == "collaboration"
|
||||||
|
assert result.budget_snapshot.max_parallel_tasks == 4
|
||||||
|
assert result.next_action == "notify_user"
|
||||||
|
assert result.output_data == {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
def test_verifier_prefers_explicit_result_success_signal():
|
def test_apply_verification_verdict_updates_state_with_recovery_evidence():
|
||||||
verdict = verify_task_result(result={"success": True, "summary": "all checks passed"})
|
state = initial_state("u1", "c1")
|
||||||
|
|
||||||
assert verdict.status == "passed"
|
verdict = verify_task_result(
|
||||||
assert verdict.summary == "all checks passed"
|
status="passed",
|
||||||
|
summary="Interrupt and recovery chain verified.",
|
||||||
|
evidence=[
|
||||||
|
{
|
||||||
|
"task_id": "task-1",
|
||||||
|
"thread_id": "thread-1",
|
||||||
|
"interrupt_id": "interrupt-1",
|
||||||
|
"recovery_id": "recovery-1",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
updated_state = apply_verification_verdict(state, verdict)
|
||||||
|
|
||||||
|
assert updated_state["verification_status"] == "passed"
|
||||||
def test_verifier_fails_when_no_verification_input_exists():
|
assert updated_state["verification_summary"] == "Interrupt and recovery chain verified."
|
||||||
verdict = verify_task_result()
|
assert updated_state["verification_evidence"] == [
|
||||||
|
{
|
||||||
assert verdict.status == "failed"
|
"task_id": "task-1",
|
||||||
assert verdict.summary == "No verification input available."
|
"thread_id": "thread-1",
|
||||||
|
"interrupt_id": "interrupt-1",
|
||||||
|
"recovery_id": "recovery-1",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|||||||
39
backend/tests/backend/app/agents/test_verifier.py
Normal file
39
backend/tests/backend/app/agents/test_verifier.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
from app.agents.schemas.task import AgentTask
|
||||||
|
from app.agents.verifier import verify_task_result
|
||||||
|
|
||||||
|
|
||||||
|
def test_verifier_verdict_is_separate_from_task_lifecycle_status():
    """A blocked task yields a 'skipped' verdict that carries the task's result summary."""
    blocked_task = AgentTask(
        task_id="task-1",
        title="Verify",
        status="blocked",
        result_summary="waiting",
    )

    outcome = verify_task_result(task=blocked_task)

    assert outcome.status == "skipped"
    assert outcome.summary == "waiting"
|
||||||
|
|
||||||
|
|
||||||
|
def test_verifier_prefers_explicit_result_success_signal():
    """A result dict with success=True is reported as passed with its own summary."""
    successful_result = {"success": True, "summary": "all checks passed"}

    outcome = verify_task_result(result=successful_result)

    assert outcome.status == "passed"
    assert outcome.summary == "all checks passed"
|
||||||
|
|
||||||
|
|
||||||
|
def test_verifier_treats_completed_task_result_as_passed():
    """A result whose status is 'completed' is reported as passed."""
    completed_result = {
        "status": "completed",
        "summary": "done",
        "evidence": [{"type": "log"}],
    }

    outcome = verify_task_result(result=completed_result)

    assert outcome.status == "passed"
    assert outcome.summary == "done"
|
||||||
|
|
||||||
|
|
||||||
|
def test_verifier_treats_blocked_task_result_as_skipped():
    """A result whose status is 'blocked' is reported as skipped, not failed."""
    blocked_result = {"status": "blocked", "summary": "waiting on user"}

    outcome = verify_task_result(result=blocked_result)

    assert outcome.status == "skipped"
    assert outcome.summary == "waiting on user"
|
||||||
|
|
||||||
|
|
||||||
|
def test_verifier_fails_when_no_verification_input_exists():
    """Calling the verifier with no arguments fails with a fixed explanatory summary."""
    outcome = verify_task_result()

    assert outcome.status == "failed"
    assert outcome.summary == "No verification input available."
|
||||||
619
backend/tests/backend/app/agents/test_visibility_api.py
Normal file
619
backend/tests/backend/app/agents/test_visibility_api.py
Normal file
@@ -0,0 +1,619 @@
|
|||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
import app.models # noqa: F401
|
||||||
|
from app.database import Base, get_db
|
||||||
|
from app.models.conversation import Conversation
|
||||||
|
from app.models.user import User
|
||||||
|
from app.routers.agent import router as agent_router
|
||||||
|
from app.routers.auth import get_current_user
|
||||||
|
from app.services.auth_service import get_password_hash
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
async def visibility_env(tmp_path):
    """Provide a FastAPI app backed by a seeded temporary SQLite database.

    Creates the schema in a per-test SQLite file, inserts one user and one
    conversation whose ``agent_state`` holds an agent continuity snapshot
    (two events, two messages, one completed task with evidence), mounts the
    agent router, and overrides the DB and current-user dependencies.

    Yields:
        A ``(app, ids)`` tuple, where ``ids`` holds ``conversation_id``,
        ``thread_id``, ``task_id`` and the ISO timestamps ``started_after`` /
        ``ended_before`` that bracket both seeded events.
    """
    db_path = tmp_path / 'test_visibility_api.db'
    engine = create_async_engine(f'sqlite+aiosqlite:///{db_path}', future=True)
    session_factory = async_sessionmaker(engine, expire_on_commit=False)

    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

    # All seeded timestamps are expressed relative to this single instant.
    now = datetime.now(timezone.utc)
    snapshot = {
        'kind': 'agent_continuity_state',
        'version': 1,
        'state': {
            'agent_id': 'master',
            'root_agent_id': 'master',
            'current_agent': 'analyst-1234abcd',
            'thread_id': 'thread-1',
            'spawned_agent_ids': ['analyst-1234abcd'],
            # 'conversation_id' is a placeholder here; it is patched below once
            # the conversation row has an id.
            'event_trace': [
                {
                    'event_id': 'evt-1',
                    'event_type': 'agent.created',
                    'timestamp': (now - timedelta(minutes=10)).isoformat(),
                    'conversation_id': 'placeholder',
                    'agent_id': 'master',
                    'thread_id': 'thread-1',
                    'task_id': 'task-1',
                    'payload': {'child_agent_id': 'analyst-1234abcd'},
                    'severity': 'info',
                },
                {
                    'event_id': 'evt-2',
                    'event_type': 'agent.tool.result',
                    'timestamp': (now - timedelta(minutes=5)).isoformat(),
                    'conversation_id': 'placeholder',
                    'agent_id': 'analyst-1234abcd',
                    'thread_id': 'thread-1',
                    'task_id': 'task-1',
                    'payload': {'tool_name': 'search_web', 'result_preview': 'ok'},
                    'severity': 'info',
                },
            ],
            # Request from master to the analyst, then the analyst's reply.
            'message_trace': [
                {
                    'message_id': 'msg-1',
                    'thread_id': 'thread-1',
                    'from_agent_id': 'master',
                    'to_agent_id': 'analyst-1234abcd',
                    'task_id': 'task-1',
                    'message_type': 'task_request',
                    'content_summary': 'Analyze the issue',
                    'created_at': (now - timedelta(minutes=9)).isoformat(),
                    'payload': {},
                },
                {
                    'message_id': 'msg-2',
                    'thread_id': 'thread-1',
                    'from_agent_id': 'analyst-1234abcd',
                    'to_agent_id': 'master',
                    'task_id': 'task-1',
                    'reply_to_message_id': 'msg-1',
                    'message_type': 'task_update',
                    'content_summary': 'Done',
                    'created_at': (now - timedelta(minutes=4)).isoformat(),
                    'payload': {'status': 'completed'},
                },
            ],
            'active_tasks': [
                {
                    'task_id': 'task-1',
                    'title': 'Analyze issue',
                    'role': 'analyst',
                    'owner_agent_id': 'analyst-1234abcd',
                    'status': 'completed',
                    'thread_id': 'thread-1',
                    'result_summary': 'Analysis complete',
                    'evidence': [
                        {
                            'tool_name': 'search_web',
                            'args': {'query': 'jarvis visibility'},
                            'result_preview': 'ok',
                        }
                    ],
                }
            ],
            # The result carries both a tool evidence entry and a verification entry.
            'task_results': [
                {
                    'task_id': 'task-1',
                    'status': 'completed',
                    'summary': 'Analysis complete',
                    'owner_agent_id': 'analyst-1234abcd',
                    'thread_id': 'thread-1',
                    'evidence': [
                        {
                            'tool_name': 'search_web',
                            'args': {'query': 'jarvis visibility'},
                            'result_preview': 'ok',
                        },
                        {
                            'type': 'verification',
                            'status': 'passed',
                            'summary': 'Verified',
                        },
                    ],
                }
            ],
            'task_hierarchy': {'root-task': ['task-1']},
            'tool_outcomes': [
                {
                    'tool_name': 'search_web',
                    'args': {'query': 'jarvis visibility'},
                    'result_preview': 'ok',
                    'verifier_hints': {'tool_name': 'search_web'},
                }
            ],
            'verification_status': 'passed',
            'verification_summary': 'All task evidence verified.',
            'verification_evidence': [
                {'task_id': 'task-1', 'status': 'passed', 'summary': 'Verified'}
            ],
        },
    }

    async with session_factory() as session:
        user = User(
            username='visibility_user',
            email='visibility@example.com',
            hashed_password=get_password_hash('secret123'),
            full_name='Visibility Tester',
        )
        session.add(user)
        # Flush so user.id is populated before the conversation references it.
        await session.flush()
        conversation = Conversation(user_id=user.id, title='Visibility test', agent_state=snapshot)
        session.add(conversation)
        await session.commit()
        await session.refresh(user)
        await session.refresh(conversation)

        # conversation.id is read only after the first commit/refresh, so the
        # placeholder ids in the event trace are replaced and the snapshot is
        # saved a second time.
        snapshot['state']['event_trace'][0]['conversation_id'] = conversation.id
        snapshot['state']['event_trace'][1]['conversation_id'] = conversation.id
        conversation.agent_state = snapshot
        await session.commit()
        await session.refresh(conversation)

    async def override_get_db():
        # Each request gets its own session from the test database.
        async with session_factory() as session:
            yield session

    async def override_get_current_user():
        # Authentication is bypassed: every request runs as the seeded user.
        return user

    test_app = FastAPI()
    test_app.include_router(agent_router)
    test_app.dependency_overrides[get_db] = override_get_db
    test_app.dependency_overrides[get_current_user] = override_get_current_user

    try:
        yield test_app, {
            'conversation_id': conversation.id,
            'thread_id': 'thread-1',
            'task_id': 'task-1',
            # Window that contains both seeded events (at -10 and -5 minutes).
            'started_after': (now - timedelta(minutes=11)).isoformat(),
            'ended_before': (now - timedelta(minutes=1)).isoformat(),
        }
    finally:
        # Dispose of the engine even if the test fails.
        await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_visibility_events_support_filters_and_pagination(visibility_env):
    """Filtering events by agent, thread and type with limit=1 returns only the tool-result event."""
    test_app, ids = visibility_env
    query_params = {
        'conversation_id': ids['conversation_id'],
        'agent_id': 'analyst-1234abcd',
        'thread_id': ids['thread_id'],
        'event_type': 'agent.tool.result',
        'limit': 1,
        'offset': 0,
    }

    async with AsyncClient(transport=ASGITransport(app=test_app), base_url='http://testserver') as client:
        response = await client.get('/api/agents/visibility/events', params=query_params)

    assert response.status_code == 200
    body = response.json()
    assert body['total'] == 1
    assert body['limit'] == 1
    assert body['items'][0]['event_id'] == 'evt-2'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_visibility_topology_returns_nodes_edges_and_task_summary(visibility_env):
    """The topology endpoint exposes the root agent, the child agent node/edge and the task summary."""
    test_app, ids = visibility_env
    child_agent_id = 'analyst-1234abcd'

    async with AsyncClient(transport=ASGITransport(app=test_app), base_url='http://testserver') as client:
        response = await client.get(
            '/api/agents/visibility/topology',
            params={'conversation_id': ids['conversation_id']},
        )

    assert response.status_code == 200
    body = response.json()
    assert body['root_agent_id'] == 'master'
    assert body['current_agent'] == child_agent_id
    # The child agent must appear both as a node and as the child end of an edge.
    assert any(node['agent_id'] == child_agent_id for node in body['nodes'])
    assert any(edge['child_agent_id'] == child_agent_id for edge in body['edges'])
    assert body['tasks'][0]['task_id'] == ids['task_id']
    assert body['task_hierarchy'] == {'root-task': ['task-1']}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_visibility_task_evidence_returns_tool_and_verifier_evidence(visibility_env):
    """The task-evidence endpoint returns the task, its result, tool outcomes and verifier status."""
    test_app, ids = visibility_env
    evidence_url = f'/api/agents/visibility/tasks/{ids["task_id"]}/evidence'

    async with AsyncClient(transport=ASGITransport(app=test_app), base_url='http://testserver') as client:
        response = await client.get(
            evidence_url,
            params={'conversation_id': ids['conversation_id']},
        )

    assert response.status_code == 200
    body = response.json()
    assert body['task']['task_id'] == ids['task_id']
    assert body['result']['status'] == 'completed'
    assert body['tool_outcomes'][0]['tool_name'] == 'search_web'
    assert body['verifier']['status'] == 'passed'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_visibility_task_evidence_uses_task_evidence_instead_of_global_tool_outcomes(tmp_path):
    """Task evidence endpoint reports the task's own tool evidence, not the global list.

    The stored snapshot records the same ``search_web`` call twice: once in the
    task result's ``evidence`` (preview ``task-specific``) and once in the
    state-wide ``tool_outcomes`` (preview ``global-duplicate``). The response
    is expected to contain only the task-scoped entry.
    """
    db_path = tmp_path / 'test_visibility_api_task_evidence_filter.db'
    engine = create_async_engine(f'sqlite+aiosqlite:///{db_path}', future=True)
    session_factory = async_sessionmaker(engine, expire_on_commit=False)

    # Dispose the engine even when setup or an assertion fails, so a failing
    # test does not leak the connection pool.
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        # Tool evidence attached to the task result; this is what the API must return.
        task_tool_evidence = {
            'tool_name': 'search_web',
            'args': {'query': 'jarvis visibility'},
            'result_preview': 'task-specific',
        }
        snapshot = {
            'kind': 'agent_continuity_state',
            'version': 1,
            'state': {
                'agent_id': 'master',
                'root_agent_id': 'master',
                'current_agent': 'analyst-1234abcd',
                'thread_id': 'thread-1',
                'spawned_agent_ids': ['analyst-1234abcd'],
                'event_trace': [],
                'message_trace': [],
                'active_tasks': [
                    {
                        'task_id': 'task-1',
                        'title': 'Analyze issue',
                        'role': 'analyst',
                        'owner_agent_id': 'analyst-1234abcd',
                        'status': 'completed',
                        'thread_id': 'thread-1',
                        'result_summary': 'Analysis complete',
                        'evidence': [],
                    }
                ],
                'task_results': [
                    {
                        'task_id': 'task-1',
                        'status': 'completed',
                        'summary': 'Analysis complete',
                        'owner_agent_id': 'analyst-1234abcd',
                        'thread_id': 'thread-1',
                        'evidence': [
                            task_tool_evidence,
                            {
                                'type': 'verification',
                                'status': 'passed',
                                'summary': 'Verified',
                            },
                        ],
                    }
                ],
                'task_hierarchy': {'root-task': ['task-1']},
                # Same tool call again at state level; must NOT leak into the response.
                'tool_outcomes': [
                    {
                        'tool_name': 'search_web',
                        'args': {'query': 'jarvis visibility'},
                        'result_preview': 'global-duplicate',
                        'verifier_hints': {'tool_name': 'search_web'},
                    }
                ],
                'verification_status': 'passed',
                'verification_summary': 'All task evidence verified.',
                'verification_evidence': [
                    {'task_id': 'task-1', 'status': 'passed', 'summary': 'Verified'}
                ],
            },
        }

        async with session_factory() as session:
            user = User(
                username='task_evidence_user',
                email='task-evidence@example.com',
                hashed_password=get_password_hash('secret123'),
                full_name='Task Evidence Tester',
            )
            session.add(user)
            # Flush so user.id is populated before the conversation references it.
            await session.flush()
            conversation = Conversation(user_id=user.id, title='Task evidence test', agent_state=snapshot)
            session.add(conversation)
            await session.commit()
            await session.refresh(user)
            await session.refresh(conversation)

        async def override_get_db():
            async with session_factory() as session:
                yield session

        async def override_get_current_user():
            return user

        test_app = FastAPI()
        test_app.include_router(agent_router)
        test_app.dependency_overrides[get_db] = override_get_db
        test_app.dependency_overrides[get_current_user] = override_get_current_user

        transport = ASGITransport(app=test_app)
        async with AsyncClient(transport=transport, base_url='http://testserver') as client:
            response = await client.get(
                '/api/agents/visibility/tasks/task-1/evidence',
                params={'conversation_id': conversation.id},
            )

        assert response.status_code == 200
        payload = response.json()
        assert payload['tool_outcomes'] == [task_tool_evidence]
    finally:
        await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_visibility_thread_messages_returns_thread_history(visibility_env):
    """Thread messages endpoint returns the thread's messages with reply links."""
    app, ids = visibility_env
    messages_url = f'/api/agents/visibility/threads/{ids["thread_id"]}/messages'

    async with AsyncClient(transport=ASGITransport(app=app), base_url='http://testserver') as http:
        resp = await http.get(
            messages_url,
            params={'conversation_id': ids['conversation_id']},
        )

    assert resp.status_code == 200
    body = resp.json()
    assert body['thread_id'] == ids['thread_id']
    assert body['total'] == 2
    # The second message is expected to reference the first one.
    assert body['items'][1]['reply_to_message_id'] == 'msg-1'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_visibility_verifier_returns_verdict(visibility_env):
    """Verifier endpoint returns status, summary and per-task evidence."""
    app, ids = visibility_env

    async with AsyncClient(transport=ASGITransport(app=app), base_url='http://testserver') as http:
        resp = await http.get(
            '/api/agents/visibility/verifier',
            params={'conversation_id': ids['conversation_id']},
        )

    assert resp.status_code == 200
    verdict = resp.json()
    assert verdict['status'] == 'passed'
    assert verdict['summary'] == 'All task evidence verified.'
    assert verdict['evidence'][0]['task_id'] == ids['task_id']
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_visibility_events_reject_invalid_datetime(visibility_env):
    """Events endpoint answers 400 when ``started_after`` is not a parseable timestamp."""
    app, ids = visibility_env
    query = {
        'conversation_id': ids['conversation_id'],
        'started_after': 'not-a-date',
    }

    async with AsyncClient(transport=ASGITransport(app=app), base_url='http://testserver') as http:
        resp = await http.get('/api/agents/visibility/events', params=query)

    assert resp.status_code == 400
    assert resp.json()['detail'] == '时间参数必须是 ISO 8601 格式'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_visibility_events_support_time_window_and_offset_pagination(visibility_env):
    """Events endpoint applies the time window and returns the requested page.

    Two events are expected inside the window; with ``limit=1`` and
    ``offset=1`` the single returned item must be ``evt-2``.
    """
    app, ids = visibility_env
    query = {
        'conversation_id': ids['conversation_id'],
        'started_after': ids['started_after'],
        'ended_before': ids['ended_before'],
        'limit': 1,
        'offset': 1,
    }

    async with AsyncClient(transport=ASGITransport(app=app), base_url='http://testserver') as http:
        resp = await http.get('/api/agents/visibility/events', params=query)

    assert resp.status_code == 200
    page = resp.json()
    assert page['total'] == 2
    assert page['limit'] == 1
    assert page['offset'] == 1
    assert len(page['items']) == 1
    assert page['items'][0]['event_id'] == 'evt-2'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_visibility_topology_includes_task_counts_for_root_and_child(visibility_env):
    """Topology nodes carry per-agent task counters for both root and child agents."""
    app, ids = visibility_env

    async with AsyncClient(transport=ASGITransport(app=app), base_url='http://testserver') as http:
        resp = await http.get(
            '/api/agents/visibility/topology',
            params={'conversation_id': ids['conversation_id']},
        )

    assert resp.status_code == 200
    # Index nodes by agent id for direct lookups below.
    by_agent = {entry['agent_id']: entry for entry in resp.json()['nodes']}

    root = by_agent['master']
    assert root['task_count'] == 0
    assert root['completed_task_count'] == 0

    child = by_agent['analyst-1234abcd']
    assert child['task_count'] == 1
    assert child['completed_task_count'] == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_visibility_task_evidence_returns_404_for_unknown_task(visibility_env):
    """Requesting evidence for a task id that is not in the snapshot yields 404."""
    app, ids = visibility_env

    async with AsyncClient(transport=ASGITransport(app=app), base_url='http://testserver') as http:
        resp = await http.get(
            '/api/agents/visibility/tasks/missing-task/evidence',
            params={'conversation_id': ids['conversation_id']},
        )

    assert resp.status_code == 404
    assert resp.json()['detail'] == '任务不存在'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_visibility_thread_messages_returns_404_for_unknown_thread(visibility_env):
    """Requesting messages for a thread id that is not in the snapshot yields 404."""
    app, ids = visibility_env

    async with AsyncClient(transport=ASGITransport(app=app), base_url='http://testserver') as http:
        resp = await http.get(
            '/api/agents/visibility/threads/missing-thread/messages',
            params={'conversation_id': ids['conversation_id']},
        )

    assert resp.status_code == 404
    assert resp.json()['detail'] == '线程不存在'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_visibility_returns_404_when_conversation_is_missing(visibility_env):
    """An unknown ``conversation_id`` on the events endpoint yields 404."""
    app, _ids = visibility_env

    async with AsyncClient(transport=ASGITransport(app=app), base_url='http://testserver') as http:
        resp = await http.get(
            '/api/agents/visibility/events',
            params={'conversation_id': 'missing-conversation'},
        )

    assert resp.status_code == 404
    assert resp.json()['detail'] == '对话不存在'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_visibility_returns_404_when_snapshot_is_missing(tmp_path):
    """Verifier endpoint answers 404 for a conversation whose ``agent_state`` is ``None``.

    The conversation exists and belongs to the requesting user, so the 404 is
    specifically about the missing runtime snapshot, as pinned by the detail
    message.
    """
    db_path = tmp_path / 'test_visibility_api_missing_snapshot.db'
    engine = create_async_engine(f'sqlite+aiosqlite:///{db_path}', future=True)
    session_factory = async_sessionmaker(engine, expire_on_commit=False)

    # Dispose the engine even when setup or an assertion fails, so a failing
    # test does not leak the connection pool.
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        async with session_factory() as session:
            user = User(
                username='missing_snapshot_user',
                email='missing-snapshot@example.com',
                hashed_password=get_password_hash('secret123'),
                full_name='Missing Snapshot Tester',
            )
            session.add(user)
            # Flush so user.id is populated before the conversation references it.
            await session.flush()
            conversation = Conversation(user_id=user.id, title='Missing snapshot test', agent_state=None)
            session.add(conversation)
            await session.commit()
            await session.refresh(user)
            await session.refresh(conversation)

        async def override_get_db():
            async with session_factory() as session:
                yield session

        async def override_get_current_user():
            return user

        test_app = FastAPI()
        test_app.include_router(agent_router)
        test_app.dependency_overrides[get_db] = override_get_db
        test_app.dependency_overrides[get_current_user] = override_get_current_user

        transport = ASGITransport(app=test_app)
        async with AsyncClient(transport=transport, base_url='http://testserver') as client:
            response = await client.get(
                '/api/agents/visibility/verifier',
                params={'conversation_id': conversation.id},
            )

        assert response.status_code == 404
        assert response.json()['detail'] == '当前会话暂无可视化运行时数据'
    finally:
        await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_visibility_verifier_returns_empty_verdict_when_state_is_unverified(tmp_path):
    """Verifier endpoint returns an empty verdict when the snapshot has no verification data.

    The snapshot is present but its ``verification_status`` and
    ``verification_summary`` are ``None`` and ``verification_evidence`` is
    empty; the endpoint must answer 200 and echo those empty values.
    """
    db_path = tmp_path / 'test_visibility_api_empty_verifier.db'
    engine = create_async_engine(f'sqlite+aiosqlite:///{db_path}', future=True)
    session_factory = async_sessionmaker(engine, expire_on_commit=False)

    # Dispose the engine even when setup or an assertion fails, so a failing
    # test does not leak the connection pool.
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        snapshot = {
            'kind': 'agent_continuity_state',
            'version': 1,
            'state': {
                'agent_id': 'master',
                'root_agent_id': 'master',
                'current_agent': 'master',
                'event_trace': [],
                'message_trace': [],
                'active_tasks': [],
                'task_results': [],
                'task_hierarchy': {},
                'tool_outcomes': [],
                'verification_status': None,
                'verification_summary': None,
                'verification_evidence': [],
            },
        }

        async with session_factory() as session:
            user = User(
                username='empty_verifier_user',
                email='empty-verifier@example.com',
                hashed_password=get_password_hash('secret123'),
                full_name='Empty Verifier Tester',
            )
            session.add(user)
            # Flush so user.id is populated before the conversation references it.
            await session.flush()
            conversation = Conversation(user_id=user.id, title='Empty verifier test', agent_state=snapshot)
            session.add(conversation)
            await session.commit()
            await session.refresh(user)
            await session.refresh(conversation)

        async def override_get_db():
            async with session_factory() as session:
                yield session

        async def override_get_current_user():
            return user

        test_app = FastAPI()
        test_app.include_router(agent_router)
        test_app.dependency_overrides[get_db] = override_get_db
        test_app.dependency_overrides[get_current_user] = override_get_current_user

        transport = ASGITransport(app=test_app)
        async with AsyncClient(transport=transport, base_url='http://testserver') as client:
            response = await client.get(
                '/api/agents/visibility/verifier',
                params={'conversation_id': conversation.id},
            )

        assert response.status_code == 200
        payload = response.json()
        assert payload['status'] is None
        assert payload['summary'] is None
        assert payload['evidence'] == []
    finally:
        await engine.dispose()
|
||||||
@@ -53,19 +53,17 @@ async def agent_env(tmp_path):
|
|||||||
is_active=True,
|
is_active=True,
|
||||||
owner_id=user.id,
|
owner_id=user.id,
|
||||||
)
|
)
|
||||||
session.add_all([
|
agent = Agent(
|
||||||
Agent(
|
name='SCHEDULE PLANNER',
|
||||||
name='SCHEDULE PLANNER',
|
role='schedule_planner',
|
||||||
role='schedule_planner',
|
description='日程规划师',
|
||||||
description='日程规划师',
|
system_prompt='prompt',
|
||||||
system_prompt='prompt',
|
is_active=True,
|
||||||
is_active=True,
|
)
|
||||||
),
|
session.add_all([agent, skill_a, skill_b])
|
||||||
skill_a,
|
|
||||||
skill_b,
|
|
||||||
])
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(user)
|
await session.refresh(user)
|
||||||
|
await session.refresh(agent)
|
||||||
await session.refresh(skill_a)
|
await session.refresh(skill_a)
|
||||||
await session.refresh(skill_b)
|
await session.refresh(skill_b)
|
||||||
|
|
||||||
@@ -82,7 +80,7 @@ async def agent_env(tmp_path):
|
|||||||
test_app.dependency_overrides[get_current_user] = override_get_current_user
|
test_app.dependency_overrides[get_current_user] = override_get_current_user
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield test_app, {'skill_a_id': skill_a.id, 'skill_b_id': skill_b.id}
|
yield test_app, {'agent_id': agent.id, 'skill_a_id': skill_a.id, 'skill_b_id': skill_b.id}
|
||||||
finally:
|
finally:
|
||||||
await engine.dispose()
|
await engine.dispose()
|
||||||
|
|
||||||
@@ -116,6 +114,32 @@ async def test_update_agent_config_persists_selected_skill_ids(agent_env):
|
|||||||
assert get_response.json()['selected_skill_ids'] == [ids['skill_a_id'], ids['skill_b_id']]
|
assert get_response.json()['selected_skill_ids'] == [ids['skill_a_id'], ids['skill_b_id']]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_get_agent_config_requires_authentication(agent_env):
    """Agent config lookup answers 401 once the auth override is removed.

    The fixture installs a ``get_current_user`` override; removing it means
    the request goes through the router's real authentication dependency
    without any credentials attached.
    """
    app, _ids = agent_env

    # Drop the fixture's authenticated-user override for this request.
    app.dependency_overrides.pop(get_current_user, None)

    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url='http://testserver') as client:
        response = await client.get('/api/agents/config/schedule_planner')

    assert response.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_get_agent_requires_authentication(agent_env):
    """Fetching a single agent answers 401 once the auth override is removed."""
    app, ids = agent_env

    # Drop the fixture's authenticated-user override for this request.
    app.dependency_overrides.pop(get_current_user, None)
    agent_url = f"/api/agents/{ids['agent_id']}"

    async with AsyncClient(transport=ASGITransport(app=app), base_url='http://testserver') as http:
        resp = await http.get(agent_url)

    assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_agent_config_preserves_selected_skill_ids_when_omitted(agent_env):
|
async def test_update_agent_config_preserves_selected_skill_ids_when_omitted(agent_env):
|
||||||
app, ids = agent_env
|
app, ids = agent_env
|
||||||
@@ -148,3 +172,84 @@ async def test_update_agent_config_rejects_invalid_selected_skill_ids(agent_env)
|
|||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
assert response.json()['detail'] == '存在无效的技能绑定'
|
assert response.json()['detail'] == '存在无效的技能绑定'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_create_agent_requires_superuser(agent_env):
    """Creating an agent as the fixture's default user is rejected with 403.

    NOTE(review): relies on the ``agent_env`` user not being a superuser;
    confirm against the fixture.
    """
    app, _ids = agent_env
    new_agent = {
        'name': 'Runtime Planner',
        'role': 'schedule_planning',
        'description': 'runtime',
        'system_prompt': 'prompt',
        'spawn_permission': True,
    }

    async with AsyncClient(transport=ASGITransport(app=app), base_url='http://testserver') as http:
        resp = await http.post('/api/agents', json=new_agent)

    assert resp.status_code == 403
    assert resp.json()['detail'] == '仅管理员可创建 Agent'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_create_agent_requires_spawn_permission_for_runtime_role(agent_env):
    """A superuser creating an agent without ``spawn_permission`` is rejected with 400."""
    app, _ids = agent_env

    async def fake_admin():
        # Transient (unsaved) superuser standing in for the authenticated caller.
        return User(
            username='admin_user',
            email='admin@example.com',
            hashed_password='x',
            is_superuser=True,
        )

    app.dependency_overrides[get_current_user] = fake_admin

    # The payload deliberately omits 'spawn_permission'.
    new_agent = {
        'name': 'Runtime Planner',
        'role': 'schedule_planning',
        'description': 'runtime',
        'system_prompt': 'prompt',
    }

    async with AsyncClient(transport=ASGITransport(app=app), base_url='http://testserver') as http:
        resp = await http.post('/api/agents', json=new_agent)

    assert resp.status_code == 400
    assert resp.json()['detail'] == '缺少 spawn_permission,禁止直接创建 runtime agent'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_create_agent_accepts_allowed_runtime_role_for_superuser(agent_env):
    """A superuser supplying ``spawn_permission`` can create the agent (201)."""
    app, _ids = agent_env

    async def fake_admin():
        # Transient (unsaved) superuser standing in for the authenticated caller.
        return User(
            username='admin_user',
            email='admin@example.com',
            hashed_password='x',
            is_superuser=True,
        )

    app.dependency_overrides[get_current_user] = fake_admin

    new_agent = {
        'name': 'Runtime Planner',
        'role': 'schedule_planning',
        'description': 'runtime',
        'system_prompt': 'prompt',
        'spawn_permission': True,
    }

    async with AsyncClient(transport=ASGITransport(app=app), base_url='http://testserver') as http:
        resp = await http.post('/api/agents', json=new_agent)

    assert resp.status_code == 201
    created = resp.json()
    assert created['name'] == 'Runtime Planner'
    assert created['role'] == 'schedule_planning'
|
||||||
|
|||||||
75
backend/tests/backend/app/test_conversation_router.py
Normal file
75
backend/tests/backend/app/test_conversation_router.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
import app.models # noqa: F401
|
||||||
|
from app.database import Base, get_db, ensure_conversation_columns
|
||||||
|
from app.models.conversation import Conversation
|
||||||
|
from app.models.user import User
|
||||||
|
from app.routers.auth import get_current_user
|
||||||
|
from app.routers.conversation import router as conversation_router
|
||||||
|
from app.services.auth_service import get_password_hash
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
async def conversation_env(tmp_path):
    """Yield a FastAPI app whose database went through the legacy-schema patch.

    The schema is created in full, the ``agent_state`` column is dropped to
    imitate an older database, and ``ensure_conversation_columns`` is then run
    on the same connection. One user and one conversation (``message_count``
    of 3) are seeded, and the app's ``get_db`` / ``get_current_user``
    dependencies are overridden to use them. The engine is disposed on teardown.
    """
    db_path = tmp_path / 'test_conversation_router.db'
    engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
    session_factory = async_sessionmaker(engine, expire_on_commit=False)

    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
        # Imitate a pre-migration table, then apply the startup patch.
        # NOTE(review): ALTER TABLE ... DROP COLUMN needs SQLite 3.35+ — confirm CI version.
        await conn.execute(text('ALTER TABLE conversations DROP COLUMN agent_state'))
        await ensure_conversation_columns(conn)

    async with session_factory() as session:
        user = User(
            username='conversation_user',
            email='conversation@example.com',
            hashed_password=get_password_hash('secret123'),
            full_name='Conversation Tester',
            is_active=True,
        )
        session.add(user)
        # Flush so user.id is populated before the conversation references it.
        await session.flush()

        seeded = Conversation(
            user_id=user.id,
            title='Existing conversation',
            message_count=3,
        )
        session.add(seeded)
        await session.commit()
        await session.refresh(user)

    async def _db_override():
        async with session_factory() as db:
            yield db

    async def _user_override():
        return user

    test_app = FastAPI()
    test_app.include_router(conversation_router)
    test_app.dependency_overrides[get_db] = _db_override
    test_app.dependency_overrides[get_current_user] = _user_override

    try:
        yield test_app
    finally:
        await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_list_conversations_succeeds_when_agent_state_column_was_missing(conversation_env):
    """Listing conversations works on a database patched by the fixture's migration step."""
    async with AsyncClient(
        transport=ASGITransport(app=conversation_env),
        base_url='http://testserver',
    ) as http:
        resp = await http.get('/api/conversations')

    assert resp.status_code == 200
    conversations = resp.json()
    assert len(conversations) == 1

    only = conversations[0]
    assert only['title'] == 'Existing conversation'
    assert only['message_count'] == 3
|
||||||
Reference in New Issue
Block a user