diff --git a/backend/app/agents/graph.py b/backend/app/agents/graph.py index a24feb0..a16b613 100644 --- a/backend/app/agents/graph.py +++ b/backend/app/agents/graph.py @@ -14,6 +14,7 @@ from langgraph.graph import END, StateGraph from app.agents.prompts import ( ANALYST_SYSTEM_PROMPT, + COORDINATOR_SYSTEM_PROMPT, EXECUTOR_SYSTEM_PROMPT, JSON_ACTION_FALLBACK_PROMPT, LIBRARIAN_SYSTEM_PROMPT, @@ -22,11 +23,13 @@ from app.agents.prompts import ( ) from app.agents.registry import load_builtin_registry_indexes from app.agents.schemas.event import AgentEvent +from app.agents.schemas.message import AgentMessage +from app.agents.schemas.task import AgentTask, CollaborationBudget, InterruptRecord, RecoveryRecord, TaskResult from app.agents.skill_registry import build_skill_context from app.agents.state import AgentRole, AgentState from app.agents.tools import SUB_COMMANDER_TOOLSETS from app.agents.tools.time_reasoning import normalize_tool_time_arguments -from app.agents.verifier import apply_verification_verdict, verify_task_result +from app.agents.verifier import apply_verification_verdict, normalize_task_result, verify_task_result from app.services.llm_service import ( create_llm_from_config, default_provider_capabilities, @@ -115,8 +118,13 @@ SCHEDULE_CONFIRMATION_HINTS = ( "是否需要我现在创建", ) SCHEDULE_CONFIRMATION_QUESTION_MARKERS = ("是否", "要不要", "吗", "?", "?") - - +COLLABORATION_STEP_MARKERS = ("然后", "再", "并且", "同时", "顺便", "最后", "分别", "拆成", "协作", "整合") +COLLABORATION_ROLE_ORDER = { + AgentRole.LIBRARIAN: 0, + AgentRole.ANALYST: 1, + AgentRole.SCHEDULE_PLANNER: 2, + AgentRole.EXECUTOR: 3, +} def _role_value(role: AgentRole | str | None) -> str: if isinstance(role, AgentRole): return role.value @@ -195,6 +203,296 @@ def _summary_state_key(target: str) -> Literal["schedule_context_summary", "know return cast(Literal["schedule_context_summary", "knowledge_context", "analysis_report"], target) +def _ensure_message_thread(state: AgentState) -> str: + thread_id = 
str(state.get("thread_id") or "").strip() + if thread_id: + return thread_id + thread_id = f"thread-{uuid4()}" + state["thread_id"] = thread_id + return thread_id + + +def _bump_message_sequence(state: AgentState) -> tuple[str, int]: + sequence = int(state.get("message_sequence") or 0) + 1 + state["message_sequence"] = sequence + message_id = f"msg-{sequence}" + state["last_message_id"] = message_id + return message_id, sequence + + +def _task_payload(task: AgentTask | dict[str, Any] | None) -> dict[str, Any]: + if isinstance(task, AgentTask): + return task.model_dump(mode="json") + return dict(task or {}) + + +def _budget_snapshot_from_state( + state: AgentState, + *, + mode: Literal["direct", "collaboration"], + remaining_parallel_tasks: int | None = None, + metadata: dict[str, Any] | None = None, +) -> dict[str, Any]: + metadata_payload = { + "max_spawn_depth": 1, + "max_child_agents": remaining_parallel_tasks, + "max_messages_per_thread": 24, + "max_messages_per_turn": 8, + "max_parallel_collaborators": remaining_parallel_tasks, + "recovery_attempt_limit": _get_state_int(state, "max_retries") or 1, + "current_depth": _get_state_int(state, "collaboration_depth"), + **(metadata or {}), + } + budget = CollaborationBudget( + mode=mode, + max_parallel_tasks=remaining_parallel_tasks, + remaining_parallel_tasks=remaining_parallel_tasks, + max_tool_calls=_get_state_int(state, "max_tool_rounds") or None, + remaining_tool_calls=max(_get_state_int(state, "max_tool_rounds") - _get_state_int(state, "tool_round_count"), 0) + if _get_state_int(state, "max_tool_rounds") + else None, + max_iterations=_get_state_int(state, "max_iterations") or None, + remaining_iterations=max(_get_state_int(state, "max_iterations") - _get_state_int(state, "iteration_count"), 0) + if _get_state_int(state, "max_iterations") + else None, + metadata=metadata_payload, + ) + return budget.model_dump(mode="json") + + +def _store_budget_snapshot(state: AgentState, budget_snapshot: dict[str, Any]) -> 
None: + state["budget_state"] = budget_snapshot + state["collaboration_budget_history"] = [ + *(state.get("collaboration_budget_history") or []), + budget_snapshot, + ] + + +def _budget_metadata(state: AgentState) -> dict[str, Any]: + budget_state = state.get("budget_state") or {} + if isinstance(budget_state, CollaborationBudget): + return dict(budget_state.metadata or {}) + if isinstance(budget_state, dict): + return dict(budget_state.get("metadata") or {}) + return {} + + +def _budget_limit(state: AgentState, key: str, default: int) -> int: + value = _budget_metadata(state).get(key) + return value if isinstance(value, int) and value > 0 else default + + +def _append_message_trace( + state: AgentState, + *, + from_agent_id: str, + to_agent_id: str, + message_type: str, + content_summary: str, + task_id: str | None = None, + reply_to_message_id: str | None = None, + payload: dict[str, Any] | None = None, + message_id: str | None = None, + message_index: int | None = None, +) -> dict[str, Any]: + thread_id = _ensure_message_thread(state) + if message_id is None or message_index is None: + message_id, message_index = _bump_message_sequence(state) + message = AgentMessage( + message_id=message_id, + thread_id=thread_id, + from_agent_id=from_agent_id, + to_agent_id=to_agent_id, + task_id=task_id, + reply_to_message_id=reply_to_message_id, + message_type=cast(Any, message_type), + content_summary=content_summary[:400], + payload=payload or {}, + ) + serialized = message.model_dump(mode="json") + state["message_trace"] = [*(state.get("message_trace") or []), serialized] + return serialized + + +def _create_child_agent( + state: AgentState, + *, + role: AgentRole, + task: AgentTask, +) -> str | None: + parent_agent_id = str(state.get("agent_id") or AgentRole.MASTER.value) + if not _spawn_permission_for_role(state, role): + _append_event_trace( + state, + "agent.spawn.blocked", + payload={"reason": "role_policy_blocked", "target_role": role.value}, + severity="warning", + 
task_id=task.task_id, + parent_task_id=task.parent_task_id, + child_task_id=task.task_id, + ) + return None + current_depth = _get_state_int(state, "collaboration_depth") + if current_depth >= _budget_limit(state, "max_spawn_depth", 1): + _append_event_trace( + state, + "agent.spawn.blocked", + payload={"reason": "max_spawn_depth_exceeded", "target_role": role.value, "depth": current_depth}, + severity="warning", + task_id=task.task_id, + parent_task_id=task.parent_task_id, + child_task_id=task.task_id, + ) + return None + spawned_agent_ids = list(state.get("spawned_agent_ids") or []) + if len(spawned_agent_ids) >= _budget_limit(state, "max_child_agents", 4): + _append_event_trace( + state, + "agent.spawn.blocked", + payload={"reason": "max_child_agents_exceeded", "target_role": role.value}, + severity="warning", + task_id=task.task_id, + parent_task_id=task.parent_task_id, + child_task_id=task.task_id, + ) + return None + child_agent_id = f"{role.value}-{uuid4().hex[:8]}" + state["spawned_agent_ids"] = [*spawned_agent_ids, child_agent_id] + _append_event_trace( + state, + "agent.created", + payload={"parent_agent_id": parent_agent_id, "child_agent_id": child_agent_id, "target_role": role.value}, + task_id=task.task_id, + parent_task_id=task.parent_task_id, + child_task_id=task.task_id, + ) + return child_agent_id + + +def _collaboration_task_from_state(state: AgentState) -> dict[str, Any]: + return dict((state.get("turn_context") or {}).get("collaboration_task") or {}) + + +def _record_interrupt( + state: AgentState, + *, + reason: str, + task: AgentTask | dict[str, Any] | None = None, + payload: dict[str, Any] | None = None, +) -> InterruptRecord: + task_payload = _task_payload(task) if task is not None else _collaboration_task_from_state(state) + interrupt = InterruptRecord( + interrupt_id=f"interrupt-{uuid4()}", + reason=reason, + requested_by=_role_value(state.get("current_agent")), + payload=payload or {}, + ) + interrupt_payload = 
interrupt.model_dump(mode="json") + state["interrupted_tasks"] = [*(state.get("interrupted_tasks") or []), interrupt_payload] + state["recovery_points"] = [ + *(state.get("recovery_points") or []), + { + "interrupt_id": interrupt.interrupt_id, + "task_id": task_payload.get("task_id"), + "thread_id": state.get("thread_id"), + "message_id": state.get("last_message_id"), + }, + ] + _append_event_trace( + state, + "agent.interrupt.requested", + payload={"reason": reason, **(payload or {})}, + severity="warning", + task_id=task_payload.get("task_id"), + parent_task_id=task_payload.get("parent_task_id"), + child_task_id=(task_payload.get("child_task_ids") or [None])[0], + interrupt_id=interrupt.interrupt_id, + ) + _append_event_trace( + state, + "agent.task.interrupted", + payload={"reason": reason, **(payload or {})}, + severity="warning", + task_id=task_payload.get("task_id"), + parent_task_id=task_payload.get("parent_task_id"), + child_task_id=(task_payload.get("child_task_ids") or [None])[0], + interrupt_id=interrupt.interrupt_id, + ) + _append_event_trace( + state, + "agent.interrupt.completed", + payload={"reason": reason, **(payload or {})}, + severity="warning", + task_id=task_payload.get("task_id"), + parent_task_id=task_payload.get("parent_task_id"), + child_task_id=(task_payload.get("child_task_ids") or [None])[0], + interrupt_id=interrupt.interrupt_id, + ) + return interrupt + + +def _record_recovery( + state: AgentState, + *, + interrupt: InterruptRecord, + strategy: str, + task: AgentTask | dict[str, Any] | None = None, + payload: dict[str, Any] | None = None, +) -> RecoveryRecord: + task_payload = _task_payload(task) if task is not None else _collaboration_task_from_state(state) + recovery = RecoveryRecord( + recovery_id=f"recovery-{uuid4()}", + source_interrupt_id=interrupt.interrupt_id, + strategy=strategy, + resumed_from_task_id=task_payload.get("task_id"), + resumed_from_thread_id=str(state.get("thread_id") or "") or None, + payload=payload or {}, + ) 
+ recovery_payload = recovery.model_dump(mode="json") + state["recovery_trace"] = [*(state.get("recovery_trace") or []), recovery_payload] + _append_event_trace( + state, + "agent.recovery.started", + payload={"strategy": strategy, **(payload or {})}, + task_id=task_payload.get("task_id"), + parent_task_id=task_payload.get("parent_task_id"), + child_task_id=(task_payload.get("child_task_ids") or [None])[0], + recovery_id=recovery.recovery_id, + interrupt_id=interrupt.interrupt_id, + ) + _append_event_trace( + state, + "agent.task.recovered", + payload={"strategy": strategy, **(payload or {})}, + task_id=task_payload.get("task_id"), + parent_task_id=task_payload.get("parent_task_id"), + child_task_id=(task_payload.get("child_task_ids") or [None])[0], + recovery_id=recovery.recovery_id, + interrupt_id=interrupt.interrupt_id, + ) + _append_event_trace( + state, + "agent.recovery.completed", + payload={"strategy": strategy, **(payload or {})}, + task_id=task_payload.get("task_id"), + parent_task_id=task_payload.get("parent_task_id"), + child_task_id=(task_payload.get("child_task_ids") or [None])[0], + recovery_id=recovery.recovery_id, + interrupt_id=interrupt.interrupt_id, + ) + return recovery + + +def _spawn_permission_for_role(state: AgentState, role: AgentRole) -> bool: + indexes = load_builtin_registry_indexes() + current_role_value = _normalize_current_agent(state.get("current_agent")) + current_manifest = indexes.agent_by_role_value.get(current_role_value) + if current_manifest is None or not current_manifest.can_spawn_children: + return False + allowed_roles = indexes.spawnable_role_values_by_agent_id.get(current_manifest.agent_id, ()) + return role.value in allowed_roles + + def _get_llm_for_state(state: AgentState): user_llm_config = state.get("user_llm_config") llm = create_llm_from_config(user_llm_config) if user_llm_config else get_llm() @@ -378,29 +676,78 @@ def _should_clear_schedule_creation_continuity(state: AgentState, created_entiti return 
any(entity.get("type") == "reminder" for entity in created_entities) -def _route_agent_from_user_query(user_query: str) -> AgentRole: +def _classify_request_signals(user_query: str) -> dict[str, Any]: text = (user_query or "").strip().lower() - has_accounting_signal = any(keyword in text for keyword in ACCOUNTING_INTENT_KEYWORDS) - has_schedule_signal = bool(re.search(r"\d{1,2}月\d{1,2}日", text) or any(keyword in text for keyword in SCHEDULE_KEYWORDS)) + has_schedule_signal = not has_accounting_signal and ( + bool(re.search(r"\d{1,2}月\d{1,2}日", text)) or any(keyword in text for keyword in SCHEDULE_KEYWORDS) + ) has_analysis_signal = any(keyword in text for keyword in ANALYSIS_KEYWORDS) + has_knowledge_signal = any(keyword in text for keyword in KNOWLEDGE_KEYWORDS) + has_execution_signal = has_accounting_signal or any(keyword in text for keyword in EXECUTION_KEYWORDS) - if has_accounting_signal: + roles: list[AgentRole] = [] + if has_knowledge_signal: + roles.append(AgentRole.LIBRARIAN) + if has_analysis_signal: + roles.append(AgentRole.ANALYST) + if has_schedule_signal: + roles.append(AgentRole.SCHEDULE_PLANNER) + if has_execution_signal: + roles.append(AgentRole.EXECUTOR) + + return { + "text": text, + "has_accounting_signal": has_accounting_signal, + "has_schedule_signal": has_schedule_signal, + "has_analysis_signal": has_analysis_signal, + "has_knowledge_signal": has_knowledge_signal, + "has_execution_signal": has_execution_signal, + "roles": roles, + } + + +def _select_request_mode(user_query: str) -> tuple[Literal["direct", "collaboration"], dict[str, Any]]: + signals = _classify_request_signals(user_query) + roles: list[AgentRole] = signals["roles"] + text = signals["text"] + has_multi_step_signal = any(marker in text for marker in COLLABORATION_STEP_MARKERS) + is_explicit_collaboration_request = "协作" in text or ("拆" in text and "任务" in text) + is_long_request = len((user_query or "").strip()) >= 24 + + should_collaborate = len(roles) >= 3 or ( + len(roles) >= 
2 and (has_multi_step_signal or is_explicit_collaboration_request or is_long_request) + ) + selected_mode: Literal["direct", "collaboration"] = "collaboration" if should_collaborate else "direct" + metadata = { + "mode": selected_mode, + "reason": "multi_role_request" if should_collaborate else "single_role_or_simple_request", + "roles": [role.value for role in roles], + "multi_step": has_multi_step_signal, + } + return selected_mode, metadata + + +def _route_agent_from_user_query(user_query: str) -> AgentRole: + signals = _classify_request_signals(user_query) + text = signals["text"] + + if signals["has_accounting_signal"]: return AgentRole.EXECUTOR - if has_schedule_signal: + if signals["has_schedule_signal"]: return AgentRole.SCHEDULE_PLANNER - if has_analysis_signal: + if signals["has_analysis_signal"]: return AgentRole.ANALYST - if any(keyword in text for keyword in KNOWLEDGE_KEYWORDS): + if signals["has_knowledge_signal"]: return AgentRole.LIBRARIAN if any(pattern in text for pattern in GENERAL_QA_PATTERNS): return AgentRole.MASTER - if any(keyword in text for keyword in EXECUTION_KEYWORDS): + if signals["has_execution_signal"]: return AgentRole.EXECUTOR return AgentRole.MASTER @@ -583,6 +930,10 @@ def _build_system_messages(state: AgentState, system_prompt: str, role: AgentRol if clarification_summary: messages.append(SystemMessage(content=clarification_summary)) + collaboration_summary = _build_collaboration_context_summary(state) + if collaboration_summary: + messages.append(SystemMessage(content=collaboration_summary)) + role_context_map = { AgentRole.SCHEDULE_PLANNER: state.get("schedule_context_summary"), AgentRole.LIBRARIAN: state.get("knowledge_context"), @@ -603,6 +954,103 @@ def _build_system_messages(state: AgentState, system_prompt: str, role: AgentRol return messages +def _assign_agent_for_task_role(role_value: str | None) -> AgentRole: + try: + return AgentRole(str(role_value or AgentRole.MASTER.value)) + except ValueError: + return 
AgentRole.MASTER + + +def _build_task_title(role: AgentRole) -> str: + titles = { + AgentRole.LIBRARIAN: "补齐事实与证据", + AgentRole.ANALYST: "给出分析与判断", + AgentRole.SCHEDULE_PLANNER: "形成安排与计划", + AgentRole.EXECUTOR: "执行必要动作", + } + return titles.get(role, "处理协作任务") + + +def _build_task_goal(role: AgentRole, user_query: str) -> str: + goals = { + AgentRole.LIBRARIAN: f"围绕用户请求检索并整理能直接支撑结论的证据:{user_query}", + AgentRole.ANALYST: f"基于已有上下文与证据,对用户请求给出结论、风险和建议:{user_query}", + AgentRole.SCHEDULE_PLANNER: f"把用户请求收束为明确安排、节奏或日程建议:{user_query}", + AgentRole.EXECUTOR: f"基于用户请求和前序结论执行必要动作,并回收结果:{user_query}", + } + return goals.get(role, user_query) + + +def _build_expected_evidence(role: AgentRole) -> list[dict[str, Any]]: + evidence_map = { + AgentRole.LIBRARIAN: [{"type": "evidence", "detail": "retrieval findings or source-backed context"}], + AgentRole.ANALYST: [{"type": "analysis", "detail": "judgment with supporting rationale"}], + AgentRole.SCHEDULE_PLANNER: [{"type": "plan", "detail": "explicit schedule or next-step proposal"}], + AgentRole.EXECUTOR: [{"type": "execution", "detail": "tool output or execution confirmation"}], + } + return evidence_map.get(role, [{"type": "summary", "detail": "task evidence"}]) + + +def _build_collaboration_tasks(user_query: str) -> list[AgentTask]: + _, metadata = _select_request_mode(user_query) + raw_roles = metadata.get("roles") or [] + roles = sorted( + (_assign_agent_for_task_role(role_value) for role_value in raw_roles), + key=lambda role: COLLABORATION_ROLE_ORDER.get(role, 99), + ) + unique_roles: list[AgentRole] = [] + for role in roles: + if role not in unique_roles and role != AgentRole.MASTER: + unique_roles.append(role) + + parent_task_id = f"task-root-{uuid4().hex[:8]}" + task_ids = [f"task-{index}-{uuid4().hex[:8]}" for index, _ in enumerate(unique_roles[:4], start=1)] + tasks: list[AgentTask] = [] + for index, role in enumerate(unique_roles[:4]): + tasks.append( + AgentTask( + task_id=task_ids[index], + 
title=_build_task_title(role), + owner_agent_id=role.value, + role=role.value, + goal=_build_task_goal(role, user_query), + parent_task_id=parent_task_id, + child_task_ids=[task_ids[index + 1]] if index + 1 < len(task_ids) else [], + expected_evidence=_build_expected_evidence(role), + ) + ) + return tasks + + +def _build_collaboration_context_summary(state: AgentState) -> str | None: + if state.get("execution_mode") != "collaboration": + return None + + lines = [COORDINATOR_SYSTEM_PROMPT] + current_task = (state.get("turn_context") or {}).get("collaboration_task") or {} + if current_task: + lines.extend( + [ + "current_collaboration_task:", + f"- task_id: {current_task.get('task_id')}", + f"- title: {current_task.get('title')}", + f"- role: {current_task.get('role')}", + f"- goal: {current_task.get('goal')}", + ] + ) + + task_results = state.get("task_results") or [] + if task_results: + lines.append("completed_collaboration_results:") + for item in task_results[-4:]: + normalized = item.model_dump(mode="json") if isinstance(item, TaskResult) else dict(item) + lines.append( + f"- {normalized.get('task_id')} | {normalized.get('status')} | {normalized.get('summary') or ''}" + ) + + return "\n".join(lines) + + def _maybe_reset_turn_budgets(state: AgentState) -> None: messages = state.get("messages") or [] if not messages: @@ -643,7 +1091,13 @@ def _append_event_trace( payload: dict[str, Any] | None = None, severity: str = "info", task_id: str | None = None, + parent_task_id: str | None = None, + child_task_id: str | None = None, + interrupt_id: str | None = None, + recovery_id: str | None = None, + message_id: str | None = None, ) -> None: + thread_id = _ensure_message_thread(state) event = AgentEvent( event_id=f"evt-{uuid4()}", event_type=cast(Any, event_type), @@ -651,6 +1105,12 @@ def _append_event_trace( agent_id=_role_value(state.get("current_agent")), sub_commander_id=state.get("current_sub_commander"), task_id=task_id, + parent_task_id=parent_task_id, + 
child_task_id=child_task_id, + thread_id=thread_id, + message_id=message_id or (str(state.get("last_message_id") or "") or None), + interrupt_id=interrupt_id, + recovery_id=recovery_id, payload=payload or {}, severity=cast(Any, severity), ) @@ -707,28 +1167,45 @@ def _update_task_result_summary(state: AgentState, tool_summaries: list[dict[str def _record_sub_commander(state: AgentState, role: AgentRole, sub_commander: str, user_query: str) -> None: + thread_id = _ensure_message_thread(state) + message_id, message_index = _bump_message_sequence(state) + current_agent_id = str(state.get("agent_id") or AgentRole.MASTER.value) state["current_agent"] = role.value state["current_sub_commander"] = sub_commander state["active_agents"] = _normalize_active_agents(state.get("active_agents")) if role not in state["active_agents"]: state["active_agents"] = [*state["active_agents"], role] state["active_sub_commanders"] = [*(state.get("active_sub_commanders") or []), sub_commander] - state["sub_commander_trace"] = [ - *(state.get("sub_commander_trace") or []), - { - "agent": _role_value(role), - "sub_commander": sub_commander, - "query": user_query, - }, - ] - state["retrieval_trace"] = [ - *(state.get("retrieval_trace") or []), - { - "agent": _role_value(role), - "sub_commander": sub_commander, - "query": user_query, - }, - ] + state["agent_trace"] = [*(state.get("agent_trace") or []), role.value] + trace_entry = { + "agent": _role_value(role), + "agent_id": current_agent_id, + "sub_commander": sub_commander, + "query": user_query, + "thread_id": thread_id, + "message_id": message_id, + "message_index": message_index, + } + state["sub_commander_trace"] = [*(state.get("sub_commander_trace") or []), trace_entry] + state["retrieval_trace"] = [*(state.get("retrieval_trace") or []), trace_entry] + _append_message_trace( + state, + from_agent_id=str(state.get("parent_agent_id") or AgentRole.MASTER.value), + to_agent_id=current_agent_id, + message_type="task_request", + 
content_summary=user_query, + task_id=((state.get("turn_context") or {}).get("collaboration_task") or {}).get("task_id"), + payload={"sub_commander": sub_commander}, + message_id=message_id, + message_index=message_index, + ) + _append_event_trace( + state, + "agent.message.received", + payload={"sub_commander": sub_commander, "summary": user_query[:200]}, + task_id=((state.get("turn_context") or {}).get("collaboration_task") or {}).get("task_id"), + message_id=message_id, + ) def _stop_sub_commander_due_to_budget(state: AgentState, reason: str) -> None: @@ -1096,6 +1573,7 @@ async def _run_sub_commander( state["last_tool_result"] = None state["tool_strategy_used"] = None state["fallback_parse_error"] = None + start_tool_index = len(state.get("tool_outcomes") or []) if not _guard_sub_commander_budget(state, "tool_round_count", "max_tool_rounds", "max_tool_rounds_exceeded"): pass @@ -1251,7 +1729,7 @@ async def _run_sub_commander( state[_summary_state_key(summary_target)] = state.get("final_response") task_result_summary = state.get("task_result_summary") - tool_outcomes = list(state.get("tool_outcomes") or []) + tool_outcomes = list(state.get("tool_outcomes") or [])[start_tool_index:] has_tool_failure = any( _tool_result_indicates_failure(outcome.get("result_preview")) for outcome in tool_outcomes @@ -1307,6 +1785,322 @@ async def _run_sub_commander( return state +def _derive_task_status(state: AgentState) -> Literal["completed", "failed", "blocked"]: + if state.get("clarification_needed") or state.get("stop_reason") == "clarification_needed": + return "blocked" + if state.get("verification_status") == "failed": + return "failed" + if state.get("stop_reason") and state.get("stop_reason") != "clarification_needed": + return "blocked" + if state.get("final_response"): + return "completed" + return "failed" + + +def _build_task_evidence(state: AgentState, start_index: int) -> list[dict[str, Any]]: + tool_outcomes = [dict(item) for item in 
list(state.get("tool_outcomes") or [])[start_index:]] + if tool_outcomes: + evidence = tool_outcomes + else: + evidence = [] + + if state.get("verification_status") or state.get("verification_summary"): + evidence.append( + { + "type": "verification", + "status": state.get("verification_status"), + "summary": state.get("verification_summary"), + } + ) + if not evidence and state.get("final_response"): + evidence.append( + { + "type": "response_summary", + "content": _stringify_message_content(state.get("final_response"))[:400], + } + ) + return evidence + + +def _collect_task_result(task: AgentTask, state: AgentState, start_tool_index: int) -> TaskResult: + status = _derive_task_status(state) + next_action = state.get("clarification_question") + if not next_action and status != "completed" and state.get("stop_reason"): + next_action = f"resolve_{state['stop_reason']}" + return TaskResult( + task_id=task.task_id, + status=status, + summary=_stringify_message_content(state.get("final_response") or state.get("verification_summary")), + evidence=_build_task_evidence(state, start_tool_index), + owner_agent_id=task.owner_agent_id, + parent_task_id=task.parent_task_id, + child_task_ids=list(task.child_task_ids), + thread_id=str(state.get("thread_id") or "") or None, + message_id=str(state.get("last_message_id") or "") or None, + message_index=int(state.get("message_sequence") or 0) or None, + interrupt_records=list(state.get("interrupted_tasks") or []), + recovery_records=list(state.get("recovery_trace") or []), + budget_snapshot=state.get("budget_state"), + next_action=next_action, + output_data={ + "role": task.role, + "sub_commander": state.get("current_sub_commander"), + "verification_status": state.get("verification_status"), + }, + ) + + +def _apply_task_result_to_state(state: AgentState, task: AgentTask, task_result: TaskResult) -> None: + normalized_result = normalize_task_result(task_result, default_task_id=task.task_id) + serialized_result = 
normalized_result.model_dump(mode="json") + state["task_results"] = [*(state.get("task_results") or []), serialized_result] + + updated_tasks: list[dict[str, Any]] = [] + completed_entry: dict[str, Any] | None = None + pending_tasks: list[dict[str, Any]] = [] + for existing_task in state.get("active_tasks") or []: + normalized_task = existing_task.model_dump(mode="json") if isinstance(existing_task, AgentTask) else dict(existing_task) + if normalized_task.get("task_id") == task.task_id: + normalized_task["status"] = normalized_result.status + normalized_task["evidence"] = list(normalized_result.evidence) + normalized_task["result_summary"] = normalized_result.summary + if normalized_result.next_action: + normalized_task["next_action"] = normalized_result.next_action + completed_entry = { + "task_id": task.task_id, + "role": task.role, + "owner_agent_id": task.owner_agent_id, + "status": normalized_result.status, + "summary": normalized_result.summary, + "next_action": normalized_result.next_action, + } + updated_tasks.append(normalized_task) + if normalized_task.get("status") != "completed": + pending_tasks.append(normalized_task) + state["active_tasks"] = updated_tasks + state["pending_tasks"] = pending_tasks + + if completed_entry is not None: + state["completed_tasks"] = [*(state.get("completed_tasks") or []), completed_entry] + + +def _build_collaboration_final_response(task_results: list[TaskResult | dict[str, Any]]) -> str: + normalized_results = [ + normalize_task_result(item) + for item in task_results + ] + lines = [f"已按协作模式回收 {len(normalized_results)} 个子任务结果:"] + final_completed_summary: str | None = None + for index, result in enumerate(normalized_results, start=1): + owner = result.owner_agent_id or "unknown" + summary = result.summary or "未返回摘要。" + lines.append(f"{index}. 
[{owner}] ({result.status}) {summary}") + if result.status == "completed" and result.summary: + final_completed_summary = result.summary + if final_completed_summary: + lines.append(f"汇总结论:{final_completed_summary}") + elif normalized_results: + blocked_result = next((item for item in reversed(normalized_results) if item.next_action), None) + if blocked_result and blocked_result.next_action: + lines.append(f"下一步:{blocked_result.next_action}") + else: + lines.append("汇总结论:协作任务已执行,但最终摘要为空。") + return "\n".join(lines) + + +def _verify_collaboration_results( + state: AgentState, + tasks: list[AgentTask], + task_results: list[TaskResult | dict[str, Any]] | None = None, +) -> None: + expected_task_ids = {task.task_id for task in tasks} + normalized_results = [ + normalize_task_result(item) + for item in (task_results if task_results is not None else (state.get("task_results") or [])) + if normalize_task_result(item).task_id in expected_task_ids + ] + result_by_task_id = {item.task_id: item for item in normalized_results} + missing_task_ids = [task.task_id for task in tasks if task.task_id not in result_by_task_id] + failed_or_blocked = [ + item.task_id + for item in normalized_results + if item.status in {"failed", "blocked"} + ] + missing_evidence = [ + item.task_id + for item in normalized_results + if not item.evidence + ] + + verification_evidence = [ + { + "task_id": item.task_id, + "owner_agent_id": item.owner_agent_id, + "status": item.status, + "evidence_count": len(item.evidence or []), + } + for item in normalized_results + ] + if missing_task_ids: + summary = f"协作结果不完整,缺少任务结果: {', '.join(missing_task_ids)}" + verdict = verify_task_result(status="failed", summary=summary, evidence=verification_evidence) + elif failed_or_blocked: + summary = f"协作结果未闭环,存在失败或阻塞任务: {', '.join(failed_or_blocked)}" + verdict = verify_task_result(status="failed", summary=summary, evidence=verification_evidence) + elif missing_evidence: + summary = f"协作结果证据不足,缺少 evidence 的任务: {', 
'.join(missing_evidence)}" + verdict = verify_task_result(status="failed", summary=summary, evidence=verification_evidence) + else: + summary = f"协作模式已完成 {len(normalized_results)}/{len(tasks)} 个子任务,并为每个子任务回收了结果与 evidence。" + verdict = verify_task_result(status="passed", summary=summary, evidence=verification_evidence) + + updated_state = apply_verification_verdict(state, verdict) + state.update(updated_state) + + +async def _run_collaboration_flow(state: AgentState, user_query: str) -> AgentState: + tasks = _build_collaboration_tasks(user_query) + if len(tasks) < 2: + state["execution_mode"] = "direct" + state["routing_decision"] = {"mode": "direct", "reason": "collaboration_plan_fell_back"} + return state + + base_history = list(state.get("messages", [])) + root_agent_id = str(state.get("root_agent_id") or state.get("agent_id") or AgentRole.MASTER.value) + coordinator_agent_id = str(state.get("agent_id") or AgentRole.MASTER.value) + state["execution_mode"] = "collaboration" + state["routing_decision"] = { + "mode": "collaboration", + "reason": "multi_role_request", + "task_count": len(tasks), + "roles": [task.role for task in tasks], + } + budget_snapshot = _budget_snapshot_from_state( + state, + mode="collaboration", + remaining_parallel_tasks=len(tasks), + metadata={"task_count": len(tasks)}, + ) + _store_budget_snapshot(state, budget_snapshot) + _append_event_trace( + state, + "agent.collaboration.budget.updated", + payload=budget_snapshot, + ) + state["active_tasks"] = [task.model_dump(mode="json") for task in tasks] + parent_task_id = next((task.parent_task_id for task in tasks if task.parent_task_id), None) or "root" + state["task_hierarchy"] = {parent_task_id: [task.task_id for task in tasks]} + state["task_results"] = [] + state["next_step"] = None + + for task in tasks: + state["current_agent"] = AgentRole.MASTER.value + state["agent_id"] = coordinator_agent_id + state["parent_agent_id"] = None + state["root_agent_id"] = root_agent_id + 
state["collaboration_depth"] = 0 + assigned_role = _assign_agent_for_task_role(task.role) + child_agent_id = _create_child_agent(state, role=assigned_role, task=task) + if child_agent_id is None: + interrupt = _record_interrupt( + state, + reason="spawn_blocked", + task=task, + payload={"target_role": assigned_role.value}, + ) + _record_recovery( + state, + interrupt=interrupt, + strategy="fallback_to_direct_role_execution", + task=task, + payload={"target_role": assigned_role.value}, + ) + child_agent_id = f"blocked-{assigned_role.value}" + state["messages"] = list(base_history) + state["current_agent"] = assigned_role.value + state["agent_id"] = child_agent_id + state["parent_agent_id"] = coordinator_agent_id + state["root_agent_id"] = root_agent_id + state["collaboration_depth"] = 1 + task.thread_id = str(state.get("thread_id") or "") or None + task.message_id = str(state.get("last_message_id") or "") or None + task.message_index = int(state.get("message_sequence") or 0) or None + task.collaboration_budget = state.get("budget_state") + state["turn_context"] = {"collaboration_task": task.model_dump(mode="json")} + state["final_response"] = None + state["verification_status"] = None + state["verification_summary"] = None + state["verification_evidence"] = [] + state["clarification_needed"] = False + state["clarification_question"] = None + state["stop_reason"] = None + start_tool_index = len(state.get("tool_outcomes") or []) + + await _run_sub_commander( + state, + assigned_role, + ROLE_SYSTEM_PROMPTS[assigned_role], + task.goal or user_query, + use_tools=True, + summary_target=None, + ) + + task_result = _collect_task_result(task, state, start_tool_index) + _append_message_trace( + state, + from_agent_id=child_agent_id, + to_agent_id=coordinator_agent_id, + message_type="task_update", + content_summary=task_result.summary or task.title, + task_id=task.task_id, + reply_to_message_id=task.message_id, + payload={"status": task_result.status, "owner_agent_id": 
task_result.owner_agent_id}, + ) + _append_event_trace( + state, + "agent.message.sent", + payload={"status": task_result.status, "summary": (task_result.summary or "")[:200]}, + task_id=task.task_id, + parent_task_id=task.parent_task_id, + child_task_id=(task.child_task_ids or [None])[0], + message_id=str(state.get("last_message_id") or "") or None, + ) + _apply_task_result_to_state(state, task, task_result) + + if task_result.status != "completed": + break + + state["turn_context"] = None + state["current_agent"] = AgentRole.MASTER.value + state["agent_id"] = coordinator_agent_id + state["parent_agent_id"] = None + state["root_agent_id"] = root_agent_id + state["collaboration_depth"] = 0 + state["final_response"] = _build_collaboration_final_response(state.get("task_results") or []) + _append_event_trace( + state, + "agent.verify.started", + payload={ + "summary_present": bool(state["final_response"]), + "evidence_count": len(state.get("task_results") or []), + }, + ) + _verify_collaboration_results(state, tasks, state.get("task_results") or []) + _append_event_trace( + state, + "agent.verify.completed", + payload={ + "status": state.get("verification_status"), + "summary": state.get("verification_summary"), + "evidence_count": len(state.get("verification_evidence") or []), + }, + severity="error" if state.get("verification_status") == "failed" else "info", + ) + state["messages"] = [*base_history, AIMessage(content=state["final_response"])] + state["should_respond"] = True + return state + + def _can_delegate_within_hop_budget(state: AgentState) -> bool: return _get_state_int(state, "routing_hops") < _get_state_int(state, "max_routing_hops") @@ -1323,7 +2117,7 @@ async def master_node(state: AgentState) -> AgentState: user_messages = _filter_user_messages(state["messages"]) user_query = _stringify_message_content(user_messages[-1].content).strip() if user_messages else "" - state["current_agent"] = _normalize_current_agent(state.get("current_agent")) + 
state["current_agent"] = state.get("current_agent") or AgentRole.MASTER state["active_agents"] = _normalize_active_agents(state.get("active_agents")) if _is_simple_greeting(user_query): @@ -1350,15 +2144,27 @@ async def master_node(state: AgentState) -> AgentState: state["messages"] = [*state.get("messages", []), AIMessage(content=state["final_response"])] return state + state["current_agent"] = _normalize_current_agent(state.get("current_agent")) + structured_continuity_route = _route_from_structured_continuity(state, user_query) clarification_route = _route_from_clarification_context(state, user_query) if structured_continuity_route is not None: + state["execution_mode"] = "direct" routed_agent = structured_continuity_route elif clarification_route is not None: + state["execution_mode"] = "direct" routed_agent = clarification_route elif _is_short_confirmation(user_query) and _previous_turn_proposed_schedule_creation(state.get("messages", [])): + state["execution_mode"] = "direct" routed_agent = AgentRole.SCHEDULE_PLANNER else: + request_mode, routing_metadata = _select_request_mode(user_query) + state["routing_decision"] = routing_metadata + if request_mode == "collaboration": + collaboration_state = await _run_collaboration_flow(state, user_query) + if collaboration_state.get("execution_mode") == "collaboration" or collaboration_state.get("final_response"): + return collaboration_state + state["execution_mode"] = "direct" routed_agent = _route_agent_from_user_query(user_query) if routed_agent != AgentRole.MASTER: if not _can_delegate_within_hop_budget(state): @@ -1515,10 +2321,13 @@ def get_agent_graph(callbacks: list | None = None): __all__ = [ + "_build_collaboration_tasks", "_build_verifier_hints", "_choose_sub_commander", "_parse_json_action", "_route_agent_from_user_query", + "_select_request_mode", + "_run_collaboration_flow", "_run_sub_commander", "create_agent_graph", "get_agent_graph", diff --git a/backend/app/agents/prompts.py 
b/backend/app/agents/prompts.py index 79b1b8f..bcd6cd3 100644 --- a/backend/app/agents/prompts.py +++ b/backend/app/agents/prompts.py @@ -324,6 +324,25 @@ ANALYST_INSIGHTS_PROMPT = f"""{JARVIS_PERSONA_PROMPT} """ +COORDINATOR_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT} + +你是 Jarvis 的协作协调官,负责把复杂请求收束成最小受控协作,而不是放任系统进入自由 swarm。 + +## 你的职责: +- 先判断当前请求是否真的需要拆解;不需要时应明确建议继续走 direct +- 只有在明显多步骤、跨领域、需要多角色配合时,才拆成 2~4 个子任务 +- 每个子任务必须清晰写出 `title`、`role`、`goal`、`expected_evidence` +- 角色建议只能来自现有 top-level agent:`schedule_planner`、`librarian`、`analyst`、`executor` +- 汇总时基于子任务结果回收,不依赖单点硬编码拼接 + +## 边界: +- 禁止无限递归拆分 +- 禁止创建新的 runtime agent / worker +- 禁止把一个简单请求硬拆成多个空泛步骤 +- 如果证据不足、子任务未闭环,必须把风险明确暴露出来 +""" + + VERIFIER_PROMPT = f"""{JARVIS_PERSONA_PROMPT} 你是 Jarvis 的验证官,负责对执行结果做最小但明确的核验。 diff --git a/backend/app/agents/registry/builtins.py b/backend/app/agents/registry/builtins.py index ea9d663..7f9dd4c 100644 --- a/backend/app/agents/registry/builtins.py +++ b/backend/app/agents/registry/builtins.py @@ -57,6 +57,19 @@ TOP_LEVEL_AGENT_ROUTING_HINTS: dict[str, tuple[str, ...]] = { ), } +TOP_LEVEL_AGENT_ALLOWED_SPAWN_ROLES: dict[str, tuple[str, ...]] = { + AgentRole.MASTER.value: ( + AgentRole.SCHEDULE_PLANNER.value, + AgentRole.EXECUTOR.value, + AgentRole.LIBRARIAN.value, + AgentRole.ANALYST.value, + ), + AgentRole.SCHEDULE_PLANNER.value: (AgentRole.SCHEDULE_PLANNER.value,), + AgentRole.EXECUTOR.value: (AgentRole.EXECUTOR.value,), + AgentRole.LIBRARIAN.value: (AgentRole.LIBRARIAN.value,), + AgentRole.ANALYST.value: (AgentRole.ANALYST.value,), +} + SUB_COMMANDER_PARENT_AGENT_IDS: dict[str, str] = { "schedule_analysis": AgentRole.SCHEDULE_PLANNER.value, "schedule_planning": AgentRole.SCHEDULE_PLANNER.value, @@ -77,6 +90,8 @@ BUILTIN_AGENT_MANIFESTS: tuple[AgentManifest, ...] 
= tuple( system_prompt_key=role.value, routing_hints=list(TOP_LEVEL_AGENT_ROUTING_HINTS[role.value]), default_sub_commanders=list(TOP_LEVEL_AGENT_DEFAULT_SUB_COMMANDERS[role.value]), + can_spawn_children=bool(TOP_LEVEL_AGENT_ALLOWED_SPAWN_ROLES[role.value]), + allowed_spawn_role_values=list(TOP_LEVEL_AGENT_ALLOWED_SPAWN_ROLES[role.value]), skill_context_key=role.value.replace("agent_", ""), ) for role in AgentRole diff --git a/backend/app/agents/registry/indexes.py b/backend/app/agents/registry/indexes.py index 93fbbdf..19dd8d3 100644 --- a/backend/app/agents/registry/indexes.py +++ b/backend/app/agents/registry/indexes.py @@ -16,6 +16,7 @@ from app.agents.registry.models import ( @dataclass(frozen=True) class RegistryIndexes: agent_by_id: Mapping[str, AgentManifest] + agent_by_role_value: Mapping[str, AgentManifest] sub_commander_by_id: Mapping[str, SubCommanderManifest] capability_by_id: Mapping[str, CapabilityManifest] specialist_template_by_id: Mapping[str, SpecialistTemplateManifest] @@ -24,6 +25,7 @@ class RegistryIndexes: skill_context_key_by_agent_id: Mapping[str, str] capability_id_by_tool_name: Mapping[str, str] capability_ids_by_sub_commander_id: Mapping[str, tuple[str, ...]] + spawnable_role_values_by_agent_id: Mapping[str, tuple[str, ...]] def summarize_registry_indexes(indexes: RegistryIndexes) -> dict[str, int]: @@ -50,6 +52,9 @@ def build_registry_indexes(bundle: RegistryBundle) -> RegistryIndexes: return RegistryIndexes( agent_by_id=MappingProxyType(agent_by_id), + agent_by_role_value=MappingProxyType({ + agent.role_value: agent for agent in bundle.agents + }), sub_commander_by_id=MappingProxyType(sub_commander_by_id), capability_by_id=MappingProxyType(capability_by_id), specialist_template_by_id=MappingProxyType(specialist_template_by_id), @@ -73,4 +78,9 @@ def build_registry_indexes(bundle: RegistryBundle) -> RegistryIndexes: sub_commander.sub_commander_id: tuple(sub_commander.capability_ids) for sub_commander in bundle.sub_commanders }), + 
spawnable_role_values_by_agent_id=MappingProxyType({ + agent.agent_id: tuple(agent.allowed_spawn_role_values) + for agent in bundle.agents + if agent.can_spawn_children and agent.allowed_spawn_role_values + }), ) diff --git a/backend/app/agents/registry/models.py b/backend/app/agents/registry/models.py index c102feb..366a7ee 100644 --- a/backend/app/agents/registry/models.py +++ b/backend/app/agents/registry/models.py @@ -1,6 +1,6 @@ from enum import Enum -from pydantic import BaseModel +from pydantic import BaseModel, Field class PermissionClass(str, Enum): @@ -23,6 +23,8 @@ class AgentManifest(BaseModel): system_prompt_key: str routing_hints: list[str] default_sub_commanders: list[str] + can_spawn_children: bool = False + allowed_spawn_role_values: list[str] = Field(default_factory=list) skill_context_key: str | None = None continuity_policy: str | None = None clarification_policy: str | None = None diff --git a/backend/app/agents/schemas/__init__.py b/backend/app/agents/schemas/__init__.py index ad3609f..a3cab43 100644 --- a/backend/app/agents/schemas/__init__.py +++ b/backend/app/agents/schemas/__init__.py @@ -1,10 +1,25 @@ from app.agents.schemas.event import AgentEvent -from app.agents.schemas.task import AgentTask, TaskResult, TaskLifecycleStatus, VerificationStatus +from app.agents.schemas.message import AgentMessage +from app.agents.schemas.task import ( + AgentTask, + CollaborationBudget, + InterruptRecord, + RecoveryRecord, + TaskLifecycleStatus, + TaskResult, + TaskResultStatus, + VerificationStatus, +) __all__ = [ "AgentEvent", + "AgentMessage", "AgentTask", + "CollaborationBudget", + "InterruptRecord", + "RecoveryRecord", "TaskLifecycleStatus", "TaskResult", + "TaskResultStatus", "VerificationStatus", ] diff --git a/backend/app/agents/schemas/event.py b/backend/app/agents/schemas/event.py index f08d1e1..ebb5095 100644 --- a/backend/app/agents/schemas/event.py +++ b/backend/app/agents/schemas/event.py @@ -11,6 +11,18 @@ AgentEventType = Literal[ 
"agent.tool.result", "agent.verify.started", "agent.verify.completed", + "agent.created", + "agent.spawn.blocked", + "agent.message.sent", + "agent.message.received", + "agent.interrupt.requested", + "agent.interrupt.completed", + "agent.recovery.started", + "agent.recovery.completed", + "agent.task.interrupted", + "agent.task.recovered", + "agent.task.reassigned", + "agent.collaboration.budget.updated", "agent.error", ] AgentEventSeverity = Literal["info", "warning", "error"] @@ -24,5 +36,11 @@ class AgentEvent(BaseModel): agent_id: str | None = None sub_commander_id: str | None = None task_id: str | None = None + parent_task_id: str | None = None + child_task_id: str | None = None + thread_id: str | None = None + message_id: str | None = None + interrupt_id: str | None = None + recovery_id: str | None = None payload: dict[str, Any] = Field(default_factory=dict) severity: AgentEventSeverity = "info" diff --git a/backend/app/agents/schemas/message.py b/backend/app/agents/schemas/message.py new file mode 100644 index 0000000..8da49a0 --- /dev/null +++ b/backend/app/agents/schemas/message.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any, Literal + +from pydantic import BaseModel, Field + + +AgentMessageType = Literal[ + "task_request", + "task_update", + "handoff", + "verification_request", + "verification_feedback", + "interrupt_notice", +] + + +class AgentMessage(BaseModel): + message_id: str + thread_id: str + from_agent_id: str + to_agent_id: str + task_id: str | None = None + reply_to_message_id: str | None = None + message_type: AgentMessageType = "task_update" + content_summary: str + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + payload: dict[str, Any] = Field(default_factory=dict) diff --git a/backend/app/agents/schemas/task.py b/backend/app/agents/schemas/task.py index 1f32dd6..dbea254 100644 --- a/backend/app/agents/schemas/task.py +++ 
b/backend/app/agents/schemas/task.py @@ -8,6 +8,41 @@ from pydantic import BaseModel, Field TaskLifecycleStatus = Literal["pending", "in_progress", "completed", "failed", "blocked"] VerificationStatus = Literal["passed", "failed", "skipped"] +TaskResultStatus = Literal["completed", "failed", "blocked", "passed", "skipped"] +InterruptStatus = Literal["requested", "acknowledged", "resolved"] +BudgetMode = Literal["direct", "collaboration"] + + +class InterruptRecord(BaseModel): + interrupt_id: str + reason: str + status: InterruptStatus = "requested" + requested_by: str | None = None + source_event_id: str | None = None + requested_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + payload: dict[str, Any] = Field(default_factory=dict) + + +class RecoveryRecord(BaseModel): + recovery_id: str + source_interrupt_id: str | None = None + strategy: str | None = None + resumed_from_task_id: str | None = None + resumed_from_thread_id: str | None = None + recovered_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + payload: dict[str, Any] = Field(default_factory=dict) + + +class CollaborationBudget(BaseModel): + mode: BudgetMode = "direct" + max_parallel_tasks: int | None = None + remaining_parallel_tasks: int | None = None + max_tool_calls: int | None = None + remaining_tool_calls: int | None = None + max_iterations: int | None = None + remaining_iterations: int | None = None + escalation_threshold: int | None = None + metadata: dict[str, Any] = Field(default_factory=dict) class AgentTask(BaseModel): @@ -17,8 +52,16 @@ class AgentTask(BaseModel): owner_agent_id: str | None = None role: str | None = None goal: str | None = None + parent_task_id: str | None = None + child_task_ids: list[str] = Field(default_factory=list) + thread_id: str | None = None + message_id: str | None = None + message_index: int | None = None expected_evidence: list[dict[str, Any]] = Field(default_factory=list) evidence: list[dict[str, Any]] = 
Field(default_factory=list) + interrupt_records: list[InterruptRecord | dict[str, Any]] = Field(default_factory=list) + recovery_records: list[RecoveryRecord | dict[str, Any]] = Field(default_factory=list) + collaboration_budget: CollaborationBudget | dict[str, Any] | None = None result_summary: str | None = None created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) @@ -26,7 +69,17 @@ class AgentTask(BaseModel): class TaskResult(BaseModel): task_id: str - status: VerificationStatus + status: TaskResultStatus summary: str | None = None evidence: list[dict[str, Any]] = Field(default_factory=list) + owner_agent_id: str | None = None + parent_task_id: str | None = None + child_task_ids: list[str] = Field(default_factory=list) + thread_id: str | None = None + message_id: str | None = None + message_index: int | None = None + interrupt_records: list[InterruptRecord | dict[str, Any]] = Field(default_factory=list) + recovery_records: list[RecoveryRecord | dict[str, Any]] = Field(default_factory=list) + budget_snapshot: CollaborationBudget | dict[str, Any] | None = None + next_action: str | None = None output_data: dict[str, Any] | None = None diff --git a/backend/app/agents/state.py b/backend/app/agents/state.py index e1219ed..f7959fb 100644 --- a/backend/app/agents/state.py +++ b/backend/app/agents/state.py @@ -3,8 +3,9 @@ from enum import Enum from typing import Annotated, Any, Literal, TypedDict from app.agents.schemas.event import AgentEvent -from app.agents.schemas.task import AgentTask, TaskResult, VerificationStatus -from langchain_core.messages import BaseMessage +from app.agents.schemas.message import AgentMessage +from app.agents.schemas.task import AgentTask, CollaborationBudget, InterruptRecord, RecoveryRecord, TaskResult, VerificationStatus +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from langgraph.graph.message import 
add_messages @@ -24,12 +25,27 @@ class ConversationTurn: model: str | None = None +def turn_to_message(turn: ConversationTurn) -> BaseMessage: + if turn.role == "user": + return HumanMessage(content=turn.content) + return AIMessage(content=turn.content) + + class AgentState(TypedDict): messages: Annotated[list[BaseMessage], add_messages] user_id: str conversation_id: str + parent_conversation_id: str | None + thread_id: str | None + last_message_id: str | None + message_sequence: int + agent_id: str | None + parent_agent_id: str | None + root_agent_id: str | None + collaboration_depth: int + spawned_agent_ids: list[str] - execution_mode: Literal["direct", "delegated", "verified"] + execution_mode: Literal["direct", "collaboration", "delegated", "verified"] current_agent: str | None next_step: str | None active_agents: list[AgentRole] @@ -38,11 +54,16 @@ class AgentState(TypedDict): sub_commander_trace: list[dict[str, Any]] agent_trace: list[str] event_trace: list[AgentEvent | dict[str, Any]] + message_trace: list[AgentMessage | dict[str, Any]] pending_tasks: list[dict[str, Any]] completed_tasks: list[dict[str, Any]] active_tasks: list[AgentTask | dict[str, Any]] task_results: list[TaskResult | dict[str, Any]] + task_hierarchy: dict[str, list[str]] + interrupted_tasks: list[InterruptRecord | dict[str, Any]] + recovery_trace: list[RecoveryRecord | dict[str, Any]] + recovery_points: list[dict[str, Any]] tool_calls: list[dict[str, Any]] last_tool_result: str | None action_results: list[dict[str, Any]] @@ -54,7 +75,8 @@ class AgentState(TypedDict): verification_status: VerificationStatus | None verification_summary: str | None verification_evidence: list[dict[str, Any]] - budget_state: dict[str, Any] | None + budget_state: CollaborationBudget | dict[str, Any] | None + collaboration_budget_history: list[CollaborationBudget | dict[str, Any]] tool_strategy_used: str | None tool_round_count: int @@ -102,6 +124,15 @@ def initial_state(user_id: str, conversation_id: str) -> 
AgentState: messages=[], user_id=user_id, conversation_id=conversation_id, + parent_conversation_id=None, + thread_id=None, + last_message_id=None, + message_sequence=0, + agent_id=AgentRole.MASTER.value, + parent_agent_id=None, + root_agent_id=AgentRole.MASTER.value, + collaboration_depth=0, + spawned_agent_ids=[], execution_mode="direct", current_agent=AgentRole.MASTER.value, next_step=None, @@ -111,10 +142,15 @@ def initial_state(user_id: str, conversation_id: str) -> AgentState: sub_commander_trace=[], agent_trace=[AgentRole.MASTER.value], event_trace=[], + message_trace=[], pending_tasks=[], completed_tasks=[], active_tasks=[], task_results=[], + task_hierarchy={}, + interrupted_tasks=[], + recovery_trace=[], + recovery_points=[], tool_calls=[], last_tool_result=None, action_results=[], @@ -126,6 +162,7 @@ def initial_state(user_id: str, conversation_id: str) -> AgentState: verification_summary=None, verification_evidence=[], budget_state=None, + collaboration_budget_history=[], tool_strategy_used=None, tool_round_count=0, max_tool_rounds=2, diff --git a/backend/app/agents/verifier.py b/backend/app/agents/verifier.py index 7aa3e83..653cf7a 100644 --- a/backend/app/agents/verifier.py +++ b/backend/app/agents/verifier.py @@ -1,10 +1,10 @@ from __future__ import annotations -from typing import Any +from typing import Any, cast from pydantic import BaseModel, Field -from app.agents.schemas.task import AgentTask, TaskResult, VerificationStatus +from app.agents.schemas.task import AgentTask, TaskResult, TaskResultStatus, VerificationStatus from app.agents.state import AgentState @@ -14,6 +14,34 @@ class VerificationVerdict(BaseModel): evidence: list[dict[str, Any]] = Field(default_factory=list) +def normalize_task_result( + task_result: TaskResult | dict[str, Any], + *, + default_task_id: str | None = None, +) -> TaskResult: + payload = task_result.model_dump(mode="json") if isinstance(task_result, TaskResult) else dict(task_result or {}) + normalized_status = 
payload.get("status") + if normalized_status not in {"completed", "failed", "blocked", "passed", "skipped"}: + normalized_status = "failed" + return TaskResult( + task_id=str(payload.get("task_id") or default_task_id or "unknown-task"), + status=cast(TaskResultStatus, normalized_status), + summary=payload.get("summary"), + evidence=list(payload.get("evidence") or []), + owner_agent_id=payload.get("owner_agent_id"), + parent_task_id=payload.get("parent_task_id"), + child_task_ids=list(payload.get("child_task_ids") or []), + thread_id=payload.get("thread_id"), + message_id=payload.get("message_id"), + message_index=payload.get("message_index") if isinstance(payload.get("message_index"), int) else None, + interrupt_records=list(payload.get("interrupt_records") or []), + recovery_records=list(payload.get("recovery_records") or []), + budget_snapshot=payload.get("budget_snapshot") if isinstance(payload.get("budget_snapshot"), dict) else None, + next_action=payload.get("next_action"), + output_data=payload.get("output_data") if isinstance(payload.get("output_data"), dict) else None, + ) + + def verify_task_result( *, task: AgentTask | dict[str, Any] | None = None, @@ -30,8 +58,13 @@ def verify_task_result( if status is not None: return VerificationVerdict(status=status, summary=normalized_summary, evidence=normalized_evidence) - if normalized_result.get("status") in {"passed", "failed", "skipped"}: - inferred_status = normalized_result["status"] + normalized_status = normalized_result.get("status") + if normalized_status in {"passed", "failed", "skipped"}: + inferred_status = normalized_status + elif normalized_status == "completed": + inferred_status = "passed" + elif normalized_status == "blocked": + inferred_status = "skipped" elif normalized_result.get("success") is True: inferred_status = "passed" elif normalized_result.get("success") is False: @@ -57,4 +90,4 @@ def apply_verification_verdict(state: AgentState, verdict: VerificationVerdict) return 
AgentState(**next_state) -__all__ = ["VerificationVerdict", "apply_verification_verdict", "verify_task_result"] +__all__ = ["VerificationVerdict", "apply_verification_verdict", "normalize_task_result", "verify_task_result"] diff --git a/backend/app/database.py b/backend/app/database.py index b6927ff..fbd1552 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -1,10 +1,13 @@ -from sqlalchemy import text -from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker -from sqlalchemy.orm import DeclarativeBase -from app.config import settings +from collections.abc import AsyncGenerator import os import re +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.orm import DeclarativeBase + +from app.config import settings + os.makedirs(settings.DATA_DIR, exist_ok=True) engine = create_async_engine( @@ -24,12 +27,9 @@ class Base(DeclarativeBase): pass -async def get_db() -> AsyncSession: +async def get_db() -> AsyncGenerator[AsyncSession, None]: async with async_session() as session: - try: - yield session - finally: - await session.close() + yield session async def init_db(): @@ -37,6 +37,7 @@ async def init_db(): await conn.run_sync(Base.metadata.create_all) await ensure_log_columns(conn) await ensure_message_columns(conn) + await ensure_conversation_columns(conn) await ensure_document_columns(conn) await ensure_user_columns(conn) await ensure_forum_columns(conn) @@ -79,6 +80,20 @@ async def ensure_message_columns(conn): await conn.execute(text(ddl)) +async def ensure_conversation_columns(conn): + rows = await _get_table_info(conn, 'conversations') + if not rows: + return + + columns = {row[1] for row in rows} + required_columns = { + 'agent_state': "ALTER TABLE conversations ADD COLUMN agent_state JSON", + } + for column, ddl in required_columns.items(): + if column not in columns: + await conn.execute(text(ddl)) + + async def 
ensure_document_columns(conn): result = await conn.execute(text("PRAGMA table_info(documents)")) rows = result.fetchall() diff --git a/backend/app/routers/agent.py b/backend/app/routers/agent.py index 6f74760..487559f 100644 --- a/backend/app/routers/agent.py +++ b/backend/app/routers/agent.py @@ -1,12 +1,33 @@ -from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy.ext.asyncio import AsyncSession +from datetime import datetime +from typing import Any + +from fastapi import APIRouter, Depends, HTTPException, Query from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + from app.database import get_db from app.models.agent import Agent +from app.models.conversation import Conversation from app.models.skill import Skill from app.models.user import User from app.routers.auth import get_current_user -from app.schemas.agent import AgentCreate, AgentOut, AgentStats, AgentConfigUpdate, AgentConfigOut +from app.schemas.agent import ( + AgentConfigOut, + AgentConfigUpdate, + AgentCreate, + AgentOut, + AgentStats, + AgentVisibilityEvidenceOut, + AgentVisibilityEventsResponse, + AgentVisibilityEventOut, + AgentVisibilityTaskSummaryOut, + AgentVisibilityThreadMessageOut, + AgentVisibilityThreadOut, + AgentVisibilityTopologyNodeOut, + AgentVisibilityTopologyOut, + AgentVisibilityVerifierOut, +) +from app.services.agent_service import _extract_continuity_snapshot router = APIRouter(prefix="/api/agents", tags=["Agent"]) @@ -21,6 +42,147 @@ SUB_COMMANDERS_BY_ROLE = { "librarian": ["librarian_retrieval", "librarian_graph"], "analyst": ["analyst_progress", "analyst_insights"], } +ALLOWED_AGENT_ROLES = set(DEFAULT_AGENT_ROLES) | { + role + for sub_roles in SUB_COMMANDERS_BY_ROLE.values() + for role in sub_roles +} + + +def _parse_visibility_datetime(value: str | None) -> datetime | None: + if value is None: + return None + try: + return datetime.fromisoformat(value.replace("Z", "+00:00")) + except ValueError as exc: + raise 
HTTPException(status_code=400, detail="时间参数必须是 ISO 8601 格式") from exc + + +async def _get_visibility_state( + conversation_id: str, + *, + current_user: User, + db: AsyncSession, +) -> dict[str, Any]: + result = await db.execute( + select(Conversation).where( + Conversation.id == conversation_id, + Conversation.user_id == current_user.id, + ) + ) + conversation = result.scalar_one_or_none() + if conversation is None: + raise HTTPException(status_code=404, detail="对话不存在") + snapshot = _extract_continuity_snapshot(conversation.agent_state) + if snapshot is None: + raise HTTPException(status_code=404, detail="当前会话暂无可视化运行时数据") + return snapshot + + +def _coerce_event_payload(event: dict[str, Any]) -> AgentVisibilityEventOut: + return AgentVisibilityEventOut.model_validate(event) + + +def _filter_events( + events: list[dict[str, Any]], + *, + agent_id: str | None, + thread_id: str | None, + event_type: str | None, + started_after: datetime | None, + ended_before: datetime | None, +) -> list[dict[str, Any]]: + filtered: list[dict[str, Any]] = [] + for event in events: + if agent_id and event.get("agent_id") != agent_id: + continue + if thread_id and event.get("thread_id") != thread_id: + continue + if event_type and event.get("event_type") != event_type: + continue + timestamp_raw = event.get("timestamp") + timestamp = None + if isinstance(timestamp_raw, str): + try: + timestamp = datetime.fromisoformat(timestamp_raw.replace("Z", "+00:00")) + except ValueError: + timestamp = None + if started_after and timestamp and timestamp < started_after: + continue + if ended_before and timestamp and timestamp > ended_before: + continue + filtered.append(event) + return filtered + + +def _summarize_tasks(tasks: list[dict[str, Any]], task_results: list[dict[str, Any]]) -> list[AgentVisibilityTaskSummaryOut]: + result_by_task_id = {item.get("task_id"): item for item in task_results} + summaries: list[AgentVisibilityTaskSummaryOut] = [] + for task in tasks: + task_id = 
str(task.get("task_id") or "") + result = result_by_task_id.get(task_id) or {} + evidence = result.get("evidence") or task.get("evidence") or [] + summaries.append( + AgentVisibilityTaskSummaryOut( + task_id=task_id, + role=task.get("role"), + owner_agent_id=task.get("owner_agent_id") or result.get("owner_agent_id"), + status=result.get("status") or task.get("status"), + summary=result.get("summary") or task.get("result_summary"), + evidence_count=len(evidence), + ) + ) + return summaries + + +def _build_topology_nodes( + state: dict[str, Any], + tasks: list[dict[str, Any]], + task_results: list[dict[str, Any]], +) -> list[AgentVisibilityTopologyNodeOut]: + task_counts: dict[str, int] = {} + completed_counts: dict[str, int] = {} + for task in tasks: + owner = str(task.get("owner_agent_id") or "") + if owner: + task_counts[owner] = task_counts.get(owner, 0) + 1 + for result in task_results: + owner = str(result.get("owner_agent_id") or "") + if owner and result.get("status") == "completed": + completed_counts[owner] = completed_counts.get(owner, 0) + 1 + + root_agent_id = str(state.get("root_agent_id") or state.get("agent_id") or "") or None + current_agent = str(state.get("current_agent") or "") or None + nodes: dict[str, AgentVisibilityTopologyNodeOut] = {} + if root_agent_id: + nodes[root_agent_id] = AgentVisibilityTopologyNodeOut( + agent_id=root_agent_id, + role=root_agent_id.split("-")[0], + parent_agent_id=None, + source="root", + task_count=task_counts.get(root_agent_id, 0), + completed_task_count=completed_counts.get(root_agent_id, 0), + ) + for agent_id in state.get("spawned_agent_ids") or []: + agent_id = str(agent_id) + nodes[agent_id] = AgentVisibilityTopologyNodeOut( + agent_id=agent_id, + role=agent_id.split("-")[0], + parent_agent_id=root_agent_id, + source="spawned", + task_count=task_counts.get(agent_id, 0), + completed_task_count=completed_counts.get(agent_id, 0), + ) + if current_agent and current_agent not in nodes: + nodes[current_agent] = 
AgentVisibilityTopologyNodeOut( + agent_id=current_agent, + role=current_agent.split("-")[0], + parent_agent_id=None if current_agent == root_agent_id else root_agent_id, + source="current", + task_count=task_counts.get(current_agent, 0), + completed_task_count=completed_counts.get(current_agent, 0), + ) + return list(nodes.values()) def record_agent_call(agent_id: str): @@ -83,6 +245,7 @@ async def get_agent_hierarchy_stats( @router.get("/config/{agent_id}", response_model=AgentConfigOut) async def get_agent_config( agent_id: str, + current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): result = await db.execute(select(Agent).where(Agent.role == agent_id)) @@ -172,12 +335,159 @@ async def update_agent_config( ) +@router.get("/visibility/events", response_model=AgentVisibilityEventsResponse) +async def get_visibility_events( + conversation_id: str, + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), + agent_id: str | None = None, + thread_id: str | None = None, + event_type: str | None = None, + started_after: str | None = None, + ended_before: str | None = None, + limit: int = Query(default=50, ge=1, le=200), + offset: int = Query(default=0, ge=0), +): + state = await _get_visibility_state(conversation_id, current_user=current_user, db=db) + events = [dict(item) for item in state.get("event_trace") or []] + filtered = _filter_events( + events, + agent_id=agent_id, + thread_id=thread_id, + event_type=event_type, + started_after=_parse_visibility_datetime(started_after), + ended_before=_parse_visibility_datetime(ended_before), + ) + paged = filtered[offset:offset + limit] + return AgentVisibilityEventsResponse( + conversation_id=conversation_id, + total=len(filtered), + limit=limit, + offset=offset, + items=[_coerce_event_payload(item) for item in paged], + ) + + +@router.get("/visibility/topology", response_model=AgentVisibilityTopologyOut) +async def get_visibility_topology( + conversation_id: 
@router.get("/visibility/tasks/{task_id}/evidence", response_model=AgentVisibilityEvidenceOut)
async def get_visibility_task_evidence(
    task_id: str,
    conversation_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the task, its result and the evidence recorded for ``task_id``.

    The data is read from the state returned by ``_get_visibility_state`` for
    the given conversation: the task comes from ``active_tasks``, the result
    from ``task_results``, and the verifier section combines the result's
    ``type == "verification"`` evidence entry with the state-level
    ``verification_evidence`` entries that carry the same ``task_id``.

    Raises:
        HTTPException: 404 when neither a task nor a result with ``task_id``
            exists in the state.
    """
    state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)
    tasks = [dict(item) for item in state.get("active_tasks") or []]
    task = next((item for item in tasks if item.get("task_id") == task_id), None)
    task_results = [dict(item) for item in state.get("task_results") or []]
    result = next((item for item in task_results if item.get("task_id") == task_id), None)
    if task is None and result is None:
        raise HTTPException(status_code=404, detail="任务不存在")

    # Filter the result's evidence once; non-dict entries are ignored by both
    # the tool-outcome list and the verification lookup.
    evidence_entries = [
        entry
        for entry in (result or {}).get("evidence") or []
        if isinstance(entry, dict)
    ]
    tool_outcomes = [dict(entry) for entry in evidence_entries if entry.get("tool_name")]
    verification_entry = next(
        (dict(entry) for entry in evidence_entries if entry.get("type") == "verification"),
        {},
    )
    verifier = {
        "status": verification_entry.get("status"),
        "summary": verification_entry.get("summary"),
        # Guard against non-dict entries the same way the filters above do, so a
        # malformed snapshot entry is skipped instead of raising AttributeError.
        "evidence": [
            dict(item)
            for item in state.get("verification_evidence") or []
            if isinstance(item, dict) and item.get("task_id") == task_id
        ],
    }
    return AgentVisibilityEvidenceOut(
        conversation_id=conversation_id,
        task_id=task_id,
        task=task,
        result=result,
        tool_outcomes=tool_outcomes,
        verifier=verifier,
    )
class AgentVisibilityEventOut(BaseModel):
    """One agent event, as listed in ``AgentVisibilityEventsResponse.items``."""

    # Required identity of the event.
    event_id: str
    event_type: str
    timestamp: datetime
    # Optional correlation ids; each defaults to None when not supplied.
    conversation_id: str | None = None
    agent_id: str | None = None
    sub_commander_id: str | None = None
    task_id: str | None = None
    parent_task_id: str | None = None
    child_task_id: str | None = None
    thread_id: str | None = None
    message_id: str | None = None
    interrupt_id: str | None = None
    recovery_id: str | None = None
    # Free-form event data; a fresh empty dict per instance when omitted.
    payload: dict[str, Any] = Field(default_factory=dict)
    # Defaults to "info" when the event does not state a severity.
    severity: str = "info"
class AgentVisibilityEvidenceOut(BaseModel):
    """Evidence bundle for a single task within one conversation."""

    conversation_id: str
    task_id: str
    # Raw task / result dicts; either may be None when not available.
    task: dict[str, Any] | None = None
    result: dict[str, Any] | None = None
    # Tool-related evidence entries; a fresh empty list per instance when omitted.
    tool_outcomes: list[dict[str, Any]] = Field(default_factory=list)
    # Required verifier section (no default); its keys are not constrained by this schema.
    verifier: dict[str, Any]
"parent_agent_id", + "root_agent_id", + "collaboration_depth", + "thread_id", + "last_message_id", + "message_sequence", + "spawned_agent_ids", + "current_sub_commander", + "active_sub_commanders", + "sub_commander_trace", + "event_trace", + "message_trace", + "active_tasks", + "task_results", + "task_hierarchy", + "verification_status", + "verification_summary", + "verification_evidence", + "budget_state", + "collaboration_budget_history", ) diff --git a/backend/tests/backend/app/agents/test_agent_schemas.py b/backend/tests/backend/app/agents/test_agent_schemas.py new file mode 100644 index 0000000..e85b151 --- /dev/null +++ b/backend/tests/backend/app/agents/test_agent_schemas.py @@ -0,0 +1,167 @@ +from app.agents.schemas.event import AgentEvent +from app.agents.schemas.task import AgentTask, CollaborationBudget, InterruptRecord, RecoveryRecord, TaskResult + + +def test_agent_task_accepts_day1_fields(): + task = AgentTask( + task_id="task-1", + title="Verify foundation", + status="in_progress", + owner_agent_id="executor", + role="verifier", + goal="check output", + expected_evidence=[{"type": "assertion"}], + evidence=[{"type": "log"}], + result_summary="running", + ) + + assert task.task_id == "task-1" + assert task.owner_agent_id == "executor" + assert task.status == "in_progress" + assert task.expected_evidence == [{"type": "assertion"}] + assert task.evidence == [{"type": "log"}] + assert task.result_summary == "running" + + +def test_agent_task_accepts_day3_runtime_fields(): + task = AgentTask( + task_id="task-2", + title="Recover interrupted collaboration", + owner_agent_id="executor", + parent_task_id="task-1", + child_task_ids=["task-2a"], + thread_id="thread-1", + message_id="msg-1", + message_index=2, + interrupt_records=[ + InterruptRecord( + interrupt_id="interrupt-1", + reason="manual stop", + requested_by="coordinator", + ) + ], + recovery_records=[ + RecoveryRecord( + recovery_id="recovery-1", + source_interrupt_id="interrupt-1", + 
def test_agent_event_accepts_day1_fields():
    """AgentEvent exposes the day-1 fields it was constructed with."""
    expected = {
        "event_id": "evt-1",
        "event_type": "agent.verify.completed",
        "conversation_id": "conv-1",
        "payload": {"status": "passed"},
        "severity": "info",
    }

    event = AgentEvent(
        event_id="evt-1",
        event_type="agent.verify.completed",
        conversation_id="conv-1",
        agent_id="executor",
        sub_commander_id="executor_tasks",
        task_id="task-1",
        payload={"status": "passed"},
        severity="info",
    )

    # Same five checks as before, driven by the expectation table.
    for field_name, value in expected.items():
        assert getattr(event, field_name) == value
"interrupt-1" + assert event.recovery_id == "recovery-1" + assert event.severity == "warning" + + +def test_task_result_supports_collaboration_result_fields(): + result = TaskResult( + task_id="task-1", + status="completed", + summary="retrieval finished", + evidence=[{"type": "source"}], + owner_agent_id="librarian", + next_action="handoff_to_analyst", + ) + + assert result.status == "completed" + assert result.owner_agent_id == "librarian" + assert result.next_action == "handoff_to_analyst" + + +def test_task_result_supports_day3_thread_budget_and_recovery_fields(): + result = TaskResult( + task_id="task-2", + status="blocked", + owner_agent_id="executor", + parent_task_id="task-1", + child_task_ids=["task-2a"], + thread_id="thread-1", + message_id="msg-4", + message_index=4, + interrupt_records=[{"interrupt_id": "interrupt-1", "reason": "budget exceeded"}], + recovery_records=[{"recovery_id": "recovery-1", "strategy": "resume_after_budget_reset"}], + budget_snapshot=CollaborationBudget( + mode="collaboration", + max_parallel_tasks=2, + remaining_parallel_tasks=0, + max_tool_calls=4, + remaining_tool_calls=0, + ), + next_action="resume_after_budget_reset", + ) + + assert result.parent_task_id == "task-1" + assert result.child_task_ids == ["task-2a"] + assert result.thread_id == "thread-1" + assert result.message_id == "msg-4" + assert result.message_index == 4 + assert result.interrupt_records[0].interrupt_id == "interrupt-1" + assert result.recovery_records[0].recovery_id == "recovery-1" + assert result.budget_snapshot.mode == "collaboration" + assert result.budget_snapshot.remaining_parallel_tasks == 0 + assert result.next_action == "resume_after_budget_reset" diff --git a/backend/tests/backend/app/agents/test_graph.py b/backend/tests/backend/app/agents/test_graph.py index a43e873..1cf3f63 100644 --- a/backend/tests/backend/app/agents/test_graph.py +++ b/backend/tests/backend/app/agents/test_graph.py @@ -2,23 +2,34 @@ import sys from types import 
SimpleNamespace from unittest.mock import Mock +import pytest + sys.modules.setdefault("trafilatura", Mock()) import app.agents.graph as graph_module from langchain_core.messages import AIMessage, HumanMessage from app.agents.graph import ( + _build_collaboration_tasks, _build_verifier_hints, _choose_sub_commander, + _create_child_agent, _execute_tool_calls, _parse_json_action, + _record_interrupt, + _record_recovery, _route_agent_from_user_query, + _select_request_mode, + _spawn_permission_for_role, + _run_collaboration_flow, _run_sub_commander, create_agent_graph, master_node, planner_node, route_agent, ) +from app.agents.schemas.message import AgentMessage +from app.agents.schemas.task import AgentTask from app.agents.state import AgentRole, initial_state from app.agents.tools import SUB_COMMANDER_TOOLSETS @@ -30,6 +41,15 @@ def _base_state(message: str, user_llm_config: dict | None = None) -> dict: 'messages': [HumanMessage(content=message)], 'user_id': 'u1', 'conversation_id': 'c1', + 'parent_conversation_id': None, + 'thread_id': None, + 'last_message_id': None, + 'message_sequence': 0, + 'agent_id': AgentRole.MASTER.value, + 'parent_agent_id': None, + 'root_agent_id': AgentRole.MASTER.value, + 'collaboration_depth': 0, + 'spawned_agent_ids': [], 'execution_mode': 'direct', 'current_agent': AgentRole.MASTER.value, 'next_step': None, @@ -39,10 +59,15 @@ def _base_state(message: str, user_llm_config: dict | None = None) -> dict: 'sub_commander_trace': [], 'agent_trace': [AgentRole.MASTER.value], 'event_trace': [], + 'message_trace': [], 'pending_tasks': [], 'completed_tasks': [], 'active_tasks': [], 'task_results': [], + 'task_hierarchy': {}, + 'interrupted_tasks': [], + 'recovery_trace': [], + 'recovery_points': [], 'tool_calls': [], 'last_tool_result': None, 'action_results': [], @@ -54,6 +79,7 @@ def _base_state(message: str, user_llm_config: dict | None = None) -> dict: 'verification_summary': None, 'verification_evidence': [], 'budget_state': None, + 
def test_create_child_agent_blocks_disallowed_spawn_role():
    """A librarian asking to spawn an analyst gets no child and a blocked event."""
    librarian = AgentRole.LIBRARIAN.value
    analyst = AgentRole.ANALYST.value

    state = _base_state('test')
    state.update({'current_agent': librarian, 'agent_id': librarian})
    analysis_task = AgentTask(
        task_id='task-1',
        title='分析',
        role=analyst,
        owner_agent_id=analyst,
        goal='输出分析',
        expected_evidence=[{'type': 'analysis'}],
    )

    spawned = _create_child_agent(state, role=AgentRole.ANALYST, task=analysis_task)

    assert spawned is None
    assert state['spawned_agent_ids'] == []
    # The refusal must be the most recent event on the trace.
    last_event = state['event_trace'][-1]
    assert last_event['event_type'] == 'agent.spawn.blocked'
    assert last_event['payload']['reason'] == 'role_policy_blocked'
def test_build_collaboration_tasks_generates_structured_owned_tasks():
    """A search/analyse/plan request yields three owned tasks in role order."""
    query = '先帮我搜索竞品资料,然后分析风险,再给我安排下周计划'
    expected_roles = [
        role.value
        for role in (AgentRole.LIBRARIAN, AgentRole.ANALYST, AgentRole.SCHEDULE_PLANNER)
    ]

    planned = _build_collaboration_tasks(query)

    assert len(planned) == 3
    assert [item.role for item in planned] == expected_roles
    # Every task needs an owner and a non-empty evidence expectation.
    for item in planned:
        assert item.owner_agent_id
        assert item.expected_evidence
def test_verify_collaboration_results_ignores_stale_results_outside_current_plan():
    """A failed result for a task outside the plan must not change the verdict."""
    librarian = AgentRole.LIBRARIAN.value
    planned_task = AgentTask(
        task_id='task-1',
        title='补齐事实与证据',
        role=librarian,
        owner_agent_id=librarian,
        goal='检索资料',
        expected_evidence=[{'type': 'evidence'}],
    )
    # Result left over from a task that is not part of the current plan.
    stale_result = {
        'task_id': 'stale-task',
        'status': 'failed',
        'summary': 'stale failure',
        'evidence': [{'type': 'verification'}],
    }
    current_result = {
        'task_id': 'task-1',
        'status': 'completed',
        'summary': 'done',
        'evidence': [{'type': 'verification'}],
        'owner_agent_id': librarian,
    }
    state = _base_state('test')

    graph_module._verify_collaboration_results(
        state,
        [planned_task],
        [stale_result, current_result],
    )

    assert state['verification_status'] == 'passed'
    assert '1/1 个子任务' in state['verification_summary']
_state, _llm: type('Caps', (), {'supports_native_tools': True})()) + monkeypatch.setattr(graph_module, '_choose_sub_commander', lambda _role, _query: 'librarian_retrieval') + monkeypatch.setattr(graph_module, '_record_sub_commander', lambda *_args, **_kwargs: None) + + await graph_module._run_sub_commander( + state, + AgentRole.LIBRARIAN, + 'prompt', + '查一下资料', + use_tools=True, + ) + + assert state['final_response'] == '当前回合完成' + assert state['verification_status'] == 'passed' + + +async def test_run_collaboration_flow_collects_task_results_and_verifies(monkeypatch): + planned_tasks = [ + AgentTask( + task_id='task-1', + title='补齐事实与证据', + role=AgentRole.LIBRARIAN.value, + owner_agent_id=AgentRole.LIBRARIAN.value, + goal='检索资料', + expected_evidence=[{'type': 'evidence'}], + ), + AgentTask( + task_id='task-2', + title='给出分析与判断', + role=AgentRole.ANALYST.value, + owner_agent_id=AgentRole.ANALYST.value, + goal='输出分析', + expected_evidence=[{'type': 'analysis'}], + ), + ] + + async def fake_run_sub_commander(state, role, manager_prompt, user_query, *, use_tools, summary_target=None): + state['current_agent'] = role.value + state['current_sub_commander'] = f'{role.value}_worker' + state['final_response'] = f'{role.value} finished' + state['verification_status'] = 'passed' + state['verification_summary'] = f'{role.value} verified' + state['tool_outcomes'] = [ + *(state.get('tool_outcomes') or []), + { + 'tool_name': f'{role.value}_tool', + 'args': {'query': user_query}, + 'result_preview': 'ok', + 'verifier_hints': {'tool_name': f'{role.value}_tool'}, + }, + ] + state['messages'] = [*state.get('messages', []), AIMessage(content=state['final_response'])] + return state + + monkeypatch.setattr(graph_module, '_build_collaboration_tasks', lambda user_query: planned_tasks) + monkeypatch.setattr(graph_module, '_run_sub_commander', fake_run_sub_commander) + + state = _base_state('先帮我搜索竞品资料,然后分析风险') + result = await _run_collaboration_flow(state, '先帮我搜索竞品资料,然后分析风险') + + assert 
async def test_master_node_enters_collaboration_mode_for_complex_multi_role_request(monkeypatch):
    """master_node hands a multi-role request to the collaboration flow."""
    reply = 'collaboration done'

    async def fake_collaboration_flow(state, user_query):
        # Stand-in for the real flow: mark the mode and append the final reply.
        history = list(state.get('messages', []))
        state['execution_mode'] = 'collaboration'
        state['final_response'] = reply
        state['messages'] = history + [AIMessage(content=reply)]
        return state

    monkeypatch.setattr(graph_module, '_run_collaboration_flow', fake_collaboration_flow)

    outcome = await master_node(_base_state('先帮我搜索竞品资料,然后分析风险,再给我安排下周计划'))

    assert outcome['execution_mode'] == 'collaboration'
    assert outcome['final_response'] == 'collaboration done'
def test_coordinator_prompt_limits_collaboration_scope():
    """The coordinator prompt states its task-count limit, recursion ban and roles."""
    required_fragments = (
        "2~4 个子任务",
        "禁止无限递归拆分",
        "schedule_planner",
        "librarian",
    )
    for fragment in required_fragments:
        assert fragment in COORDINATOR_SYSTEM_PROMPT
indexes.agent_by_role_value assert indexes.sub_commander_by_id assert indexes.capability_by_id assert isinstance(indexes.specialist_template_by_id, Mapping) @@ -362,6 +363,14 @@ def test_build_registry_indexes_exposes_prompt_keys_skill_context_keys_and_capab sub_commander.sub_commander_id: tuple(sub_commander.capability_ids) for sub_commander in bundle.sub_commanders } + assert indexes.agent_by_role_value == { + agent.role_value: agent for agent in bundle.agents + } + assert indexes.spawnable_role_values_by_agent_id == { + agent.agent_id: tuple(agent.allowed_spawn_role_values) + for agent in bundle.agents + if agent.can_spawn_children and agent.allowed_spawn_role_values + } def test_validate_registry_bundle_accepts_loaded_builtin_registry_bundle() -> None: diff --git a/backend/tests/backend/app/agents/test_schema_verifier.py b/backend/tests/backend/app/agents/test_schema_verifier.py index dbb04eb..32b8c23 100644 --- a/backend/tests/backend/app/agents/test_schema_verifier.py +++ b/backend/tests/backend/app/agents/test_schema_verifier.py @@ -1,66 +1,135 @@ -from app.agents.schemas.event import AgentEvent -from app.agents.schemas.task import AgentTask -from app.agents.verifier import verify_task_result +from app.agents.schemas import AgentEvent, AgentTask, TaskResult +from app.agents.schemas.task import CollaborationBudget, InterruptRecord, RecoveryRecord +from app.agents.state import initial_state +from app.agents.verifier import apply_verification_verdict, normalize_task_result, verify_task_result -def test_agent_task_accepts_day1_fields(): +def test_agent_task_supports_day3_interrupt_recovery_and_budget_fields(): + interrupt = InterruptRecord(interrupt_id="interrupt-1", reason="user_cancel") + recovery = RecoveryRecord(recovery_id="recovery-1", source_interrupt_id="interrupt-1", resumed_from_task_id="task-1") + budget = CollaborationBudget( + mode="collaboration", + max_parallel_tasks=3, + remaining_parallel_tasks=2, + max_tool_calls=6, + remaining_tool_calls=4, + 
metadata={"phase": "day3"}, + ) + task = AgentTask( task_id="task-1", - title="Verify foundation", - status="in_progress", - owner_agent_id="executor", - role="verifier", - goal="check output", - expected_evidence=[{"type": "assertion"}], - evidence=[{"type": "log"}], - result_summary="running", + title="Recover interrupted collaboration task", + owner_agent_id="analyst", + role="analyst", + parent_task_id="parent-1", + child_task_ids=["child-1"], + thread_id="thread-1", + message_id="message-1", + message_index=3, + interrupt_records=[interrupt], + recovery_records=[recovery], + collaboration_budget=budget, ) - assert task.task_id == "task-1" - assert task.owner_agent_id == "executor" - assert task.status == "in_progress" - assert task.expected_evidence == [{"type": "assertion"}] - assert task.evidence == [{"type": "log"}] - assert task.result_summary == "running" + payload = task.model_dump(mode="json") + + assert payload["parent_task_id"] == "parent-1" + assert payload["child_task_ids"] == ["child-1"] + assert payload["thread_id"] == "thread-1" + assert payload["message_id"] == "message-1" + assert payload["message_index"] == 3 + assert payload["interrupt_records"][0]["interrupt_id"] == "interrupt-1" + assert payload["recovery_records"][0]["recovery_id"] == "recovery-1" + assert payload["collaboration_budget"]["mode"] == "collaboration" + assert payload["collaboration_budget"]["remaining_tool_calls"] == 4 -def test_agent_event_accepts_day1_fields(): +def test_agent_event_supports_day3_thread_interrupt_and_recovery_metadata(): event = AgentEvent( event_id="evt-1", - event_type="agent.verify.completed", + event_type="agent.task.recovered", conversation_id="conv-1", agent_id="executor", - sub_commander_id="executor_tasks", task_id="task-1", - payload={"status": "passed"}, - severity="info", + parent_task_id="parent-1", + child_task_id="child-1", + thread_id="thread-1", + message_id="message-1", + interrupt_id="interrupt-1", + recovery_id="recovery-1", + 
def test_normalize_task_result_preserves_day3_metadata_fields():
    """normalize_task_result carries thread, interrupt, recovery and budget data through."""
    raw_result = {
        "task_id": "task-1",
        "status": "completed",
        "summary": "Recovered successfully.",
        "owner_agent_id": "executor",
        "parent_task_id": "parent-1",
        "child_task_ids": ["child-1"],
        "thread_id": "thread-1",
        "message_id": "message-1",
        "message_index": 2,
        "interrupt_records": [{"interrupt_id": "interrupt-1", "reason": "user_pause"}],
        "recovery_records": [{"recovery_id": "recovery-1", "source_interrupt_id": "interrupt-1"}],
        "budget_snapshot": {"mode": "collaboration", "max_parallel_tasks": 4},
        "next_action": "notify_user",
        "output_data": {"ok": True},
    }

    result = normalize_task_result(raw_result)

    # Flat attributes, checked in the same order as the original assertions.
    assert result.parent_task_id == "parent-1"
    assert result.child_task_ids == ["child-1"]
    assert result.thread_id == "thread-1"
    assert result.message_id == "message-1"
    assert result.message_index == 2

    # Nested records are reached through attribute access on the parsed result.
    first_interrupt = result.interrupt_records[0]
    first_recovery = result.recovery_records[0]
    assert first_interrupt.interrupt_id == "interrupt-1"
    assert first_recovery.recovery_id == "recovery-1"

    budget = result.budget_snapshot
    assert budget.mode == "collaboration"
    assert budget.max_parallel_tasks == 4

    assert result.next_action == "notify_user"
    assert result.output_data == {"ok": True}
+ assert updated_state["verification_evidence"] == [ + { + "task_id": "task-1", + "thread_id": "thread-1", + "interrupt_id": "interrupt-1", + "recovery_id": "recovery-1", + } + ] diff --git a/backend/tests/backend/app/agents/test_verifier.py b/backend/tests/backend/app/agents/test_verifier.py new file mode 100644 index 0000000..4187238 --- /dev/null +++ b/backend/tests/backend/app/agents/test_verifier.py @@ -0,0 +1,39 @@ +from app.agents.schemas.task import AgentTask +from app.agents.verifier import verify_task_result + + +def test_verifier_verdict_is_separate_from_task_lifecycle_status(): + task = AgentTask(task_id="task-1", title="Verify", status="blocked", result_summary="waiting") + + verdict = verify_task_result(task=task) + + assert verdict.status == "skipped" + assert verdict.summary == "waiting" + + +def test_verifier_prefers_explicit_result_success_signal(): + verdict = verify_task_result(result={"success": True, "summary": "all checks passed"}) + + assert verdict.status == "passed" + assert verdict.summary == "all checks passed" + + +def test_verifier_treats_completed_task_result_as_passed(): + verdict = verify_task_result(result={"status": "completed", "summary": "done", "evidence": [{"type": "log"}]}) + + assert verdict.status == "passed" + assert verdict.summary == "done" + + +def test_verifier_treats_blocked_task_result_as_skipped(): + verdict = verify_task_result(result={"status": "blocked", "summary": "waiting on user"}) + + assert verdict.status == "skipped" + assert verdict.summary == "waiting on user" + + +def test_verifier_fails_when_no_verification_input_exists(): + verdict = verify_task_result() + + assert verdict.status == "failed" + assert verdict.summary == "No verification input available." 
diff --git a/backend/tests/backend/app/agents/test_visibility_api.py b/backend/tests/backend/app/agents/test_visibility_api.py new file mode 100644 index 0000000..aa147e9 --- /dev/null +++ b/backend/tests/backend/app/agents/test_visibility_api.py @@ -0,0 +1,619 @@ +from datetime import datetime, timedelta, timezone + +import pytest +from fastapi import FastAPI +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine + +import app.models # noqa: F401 +from app.database import Base, get_db +from app.models.conversation import Conversation +from app.models.user import User +from app.routers.agent import router as agent_router +from app.routers.auth import get_current_user +from app.services.auth_service import get_password_hash + + +@pytest.fixture +async def visibility_env(tmp_path): + db_path = tmp_path / 'test_visibility_api.db' + engine = create_async_engine(f'sqlite+aiosqlite:///{db_path}', future=True) + session_factory = async_sessionmaker(engine, expire_on_commit=False) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + now = datetime.now(timezone.utc) + snapshot = { + 'kind': 'agent_continuity_state', + 'version': 1, + 'state': { + 'agent_id': 'master', + 'root_agent_id': 'master', + 'current_agent': 'analyst-1234abcd', + 'thread_id': 'thread-1', + 'spawned_agent_ids': ['analyst-1234abcd'], + 'event_trace': [ + { + 'event_id': 'evt-1', + 'event_type': 'agent.created', + 'timestamp': (now - timedelta(minutes=10)).isoformat(), + 'conversation_id': 'placeholder', + 'agent_id': 'master', + 'thread_id': 'thread-1', + 'task_id': 'task-1', + 'payload': {'child_agent_id': 'analyst-1234abcd'}, + 'severity': 'info', + }, + { + 'event_id': 'evt-2', + 'event_type': 'agent.tool.result', + 'timestamp': (now - timedelta(minutes=5)).isoformat(), + 'conversation_id': 'placeholder', + 'agent_id': 'analyst-1234abcd', + 'thread_id': 'thread-1', + 'task_id': 'task-1', + 
'payload': {'tool_name': 'search_web', 'result_preview': 'ok'}, + 'severity': 'info', + }, + ], + 'message_trace': [ + { + 'message_id': 'msg-1', + 'thread_id': 'thread-1', + 'from_agent_id': 'master', + 'to_agent_id': 'analyst-1234abcd', + 'task_id': 'task-1', + 'message_type': 'task_request', + 'content_summary': 'Analyze the issue', + 'created_at': (now - timedelta(minutes=9)).isoformat(), + 'payload': {}, + }, + { + 'message_id': 'msg-2', + 'thread_id': 'thread-1', + 'from_agent_id': 'analyst-1234abcd', + 'to_agent_id': 'master', + 'task_id': 'task-1', + 'reply_to_message_id': 'msg-1', + 'message_type': 'task_update', + 'content_summary': 'Done', + 'created_at': (now - timedelta(minutes=4)).isoformat(), + 'payload': {'status': 'completed'}, + }, + ], + 'active_tasks': [ + { + 'task_id': 'task-1', + 'title': 'Analyze issue', + 'role': 'analyst', + 'owner_agent_id': 'analyst-1234abcd', + 'status': 'completed', + 'thread_id': 'thread-1', + 'result_summary': 'Analysis complete', + 'evidence': [ + { + 'tool_name': 'search_web', + 'args': {'query': 'jarvis visibility'}, + 'result_preview': 'ok', + } + ], + } + ], + 'task_results': [ + { + 'task_id': 'task-1', + 'status': 'completed', + 'summary': 'Analysis complete', + 'owner_agent_id': 'analyst-1234abcd', + 'thread_id': 'thread-1', + 'evidence': [ + { + 'tool_name': 'search_web', + 'args': {'query': 'jarvis visibility'}, + 'result_preview': 'ok', + }, + { + 'type': 'verification', + 'status': 'passed', + 'summary': 'Verified', + }, + ], + } + ], + 'task_hierarchy': {'root-task': ['task-1']}, + 'tool_outcomes': [ + { + 'tool_name': 'search_web', + 'args': {'query': 'jarvis visibility'}, + 'result_preview': 'ok', + 'verifier_hints': {'tool_name': 'search_web'}, + } + ], + 'verification_status': 'passed', + 'verification_summary': 'All task evidence verified.', + 'verification_evidence': [ + {'task_id': 'task-1', 'status': 'passed', 'summary': 'Verified'} + ], + }, + } + + async with session_factory() as session: + 
user = User( + username='visibility_user', + email='visibility@example.com', + hashed_password=get_password_hash('secret123'), + full_name='Visibility Tester', + ) + session.add(user) + await session.flush() + conversation = Conversation(user_id=user.id, title='Visibility test', agent_state=snapshot) + session.add(conversation) + await session.commit() + await session.refresh(user) + await session.refresh(conversation) + + snapshot['state']['event_trace'][0]['conversation_id'] = conversation.id + snapshot['state']['event_trace'][1]['conversation_id'] = conversation.id + conversation.agent_state = snapshot + await session.commit() + await session.refresh(conversation) + + async def override_get_db(): + async with session_factory() as session: + yield session + + async def override_get_current_user(): + return user + + test_app = FastAPI() + test_app.include_router(agent_router) + test_app.dependency_overrides[get_db] = override_get_db + test_app.dependency_overrides[get_current_user] = override_get_current_user + + try: + yield test_app, { + 'conversation_id': conversation.id, + 'thread_id': 'thread-1', + 'task_id': 'task-1', + 'started_after': (now - timedelta(minutes=11)).isoformat(), + 'ended_before': (now - timedelta(minutes=1)).isoformat(), + } + finally: + await engine.dispose() + + +@pytest.mark.asyncio +async def test_visibility_events_support_filters_and_pagination(visibility_env): + app, ids = visibility_env + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.get( + '/api/agents/visibility/events', + params={ + 'conversation_id': ids['conversation_id'], + 'agent_id': 'analyst-1234abcd', + 'thread_id': ids['thread_id'], + 'event_type': 'agent.tool.result', + 'limit': 1, + 'offset': 0, + }, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload['total'] == 1 + assert payload['limit'] == 1 + assert payload['items'][0]['event_id'] 
== 'evt-2' + + +@pytest.mark.asyncio +async def test_visibility_topology_returns_nodes_edges_and_task_summary(visibility_env): + app, ids = visibility_env + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.get( + '/api/agents/visibility/topology', + params={'conversation_id': ids['conversation_id']}, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload['root_agent_id'] == 'master' + assert payload['current_agent'] == 'analyst-1234abcd' + assert any(node['agent_id'] == 'analyst-1234abcd' for node in payload['nodes']) + assert any(edge['child_agent_id'] == 'analyst-1234abcd' for edge in payload['edges']) + assert payload['tasks'][0]['task_id'] == ids['task_id'] + assert payload['task_hierarchy'] == {'root-task': ['task-1']} + + +@pytest.mark.asyncio +async def test_visibility_task_evidence_returns_tool_and_verifier_evidence(visibility_env): + app, ids = visibility_env + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.get( + f'/api/agents/visibility/tasks/{ids["task_id"]}/evidence', + params={'conversation_id': ids['conversation_id']}, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload['task']['task_id'] == ids['task_id'] + assert payload['result']['status'] == 'completed' + assert payload['tool_outcomes'][0]['tool_name'] == 'search_web' + assert payload['verifier']['status'] == 'passed' + + +@pytest.mark.asyncio +async def test_visibility_task_evidence_uses_task_evidence_instead_of_global_tool_outcomes(tmp_path): + db_path = tmp_path / 'test_visibility_api_task_evidence_filter.db' + engine = create_async_engine(f'sqlite+aiosqlite:///{db_path}', future=True) + session_factory = async_sessionmaker(engine, expire_on_commit=False) + + async with engine.begin() as conn: + await 
conn.run_sync(Base.metadata.create_all) + + snapshot = { + 'kind': 'agent_continuity_state', + 'version': 1, + 'state': { + 'agent_id': 'master', + 'root_agent_id': 'master', + 'current_agent': 'analyst-1234abcd', + 'thread_id': 'thread-1', + 'spawned_agent_ids': ['analyst-1234abcd'], + 'event_trace': [], + 'message_trace': [], + 'active_tasks': [ + { + 'task_id': 'task-1', + 'title': 'Analyze issue', + 'role': 'analyst', + 'owner_agent_id': 'analyst-1234abcd', + 'status': 'completed', + 'thread_id': 'thread-1', + 'result_summary': 'Analysis complete', + 'evidence': [], + } + ], + 'task_results': [ + { + 'task_id': 'task-1', + 'status': 'completed', + 'summary': 'Analysis complete', + 'owner_agent_id': 'analyst-1234abcd', + 'thread_id': 'thread-1', + 'evidence': [ + { + 'tool_name': 'search_web', + 'args': {'query': 'jarvis visibility'}, + 'result_preview': 'task-specific', + }, + { + 'type': 'verification', + 'status': 'passed', + 'summary': 'Verified', + }, + ], + } + ], + 'task_hierarchy': {'root-task': ['task-1']}, + 'tool_outcomes': [ + { + 'tool_name': 'search_web', + 'args': {'query': 'jarvis visibility'}, + 'result_preview': 'global-duplicate', + 'verifier_hints': {'tool_name': 'search_web'}, + } + ], + 'verification_status': 'passed', + 'verification_summary': 'All task evidence verified.', + 'verification_evidence': [ + {'task_id': 'task-1', 'status': 'passed', 'summary': 'Verified'} + ], + }, + } + + async with session_factory() as session: + user = User( + username='task_evidence_user', + email='task-evidence@example.com', + hashed_password=get_password_hash('secret123'), + full_name='Task Evidence Tester', + ) + session.add(user) + await session.flush() + conversation = Conversation(user_id=user.id, title='Task evidence test', agent_state=snapshot) + session.add(conversation) + await session.commit() + await session.refresh(user) + await session.refresh(conversation) + + async def override_get_db(): + async with session_factory() as session: + yield 
session + + async def override_get_current_user(): + return user + + test_app = FastAPI() + test_app.include_router(agent_router) + test_app.dependency_overrides[get_db] = override_get_db + test_app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=test_app) + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.get( + '/api/agents/visibility/tasks/task-1/evidence', + params={'conversation_id': conversation.id}, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload['tool_outcomes'] == [ + { + 'tool_name': 'search_web', + 'args': {'query': 'jarvis visibility'}, + 'result_preview': 'task-specific', + } + ] + + await engine.dispose() + + +@pytest.mark.asyncio +async def test_visibility_thread_messages_returns_thread_history(visibility_env): + app, ids = visibility_env + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.get( + f'/api/agents/visibility/threads/{ids["thread_id"]}/messages', + params={'conversation_id': ids['conversation_id']}, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload['thread_id'] == ids['thread_id'] + assert payload['total'] == 2 + assert payload['items'][1]['reply_to_message_id'] == 'msg-1' + + +@pytest.mark.asyncio +async def test_visibility_verifier_returns_verdict(visibility_env): + app, ids = visibility_env + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.get( + '/api/agents/visibility/verifier', + params={'conversation_id': ids['conversation_id']}, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload['status'] == 'passed' + assert payload['summary'] == 'All task evidence verified.' 
+ assert payload['evidence'][0]['task_id'] == ids['task_id'] + + +@pytest.mark.asyncio +async def test_visibility_events_reject_invalid_datetime(visibility_env): + app, ids = visibility_env + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.get( + '/api/agents/visibility/events', + params={ + 'conversation_id': ids['conversation_id'], + 'started_after': 'not-a-date', + }, + ) + + assert response.status_code == 400 + assert response.json()['detail'] == '时间参数必须是 ISO 8601 格式' + + +@pytest.mark.asyncio +async def test_visibility_events_support_time_window_and_offset_pagination(visibility_env): + app, ids = visibility_env + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.get( + '/api/agents/visibility/events', + params={ + 'conversation_id': ids['conversation_id'], + 'started_after': ids['started_after'], + 'ended_before': ids['ended_before'], + 'limit': 1, + 'offset': 1, + }, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload['total'] == 2 + assert payload['limit'] == 1 + assert payload['offset'] == 1 + assert len(payload['items']) == 1 + assert payload['items'][0]['event_id'] == 'evt-2' + + +@pytest.mark.asyncio +async def test_visibility_topology_includes_task_counts_for_root_and_child(visibility_env): + app, ids = visibility_env + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.get( + '/api/agents/visibility/topology', + params={'conversation_id': ids['conversation_id']}, + ) + + assert response.status_code == 200 + payload = response.json() + nodes = {node['agent_id']: node for node in payload['nodes']} + assert nodes['master']['task_count'] == 0 + assert nodes['master']['completed_task_count'] == 0 + assert 
nodes['analyst-1234abcd']['task_count'] == 1 + assert nodes['analyst-1234abcd']['completed_task_count'] == 1 + + +@pytest.mark.asyncio +async def test_visibility_task_evidence_returns_404_for_unknown_task(visibility_env): + app, ids = visibility_env + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.get( + '/api/agents/visibility/tasks/missing-task/evidence', + params={'conversation_id': ids['conversation_id']}, + ) + + assert response.status_code == 404 + assert response.json()['detail'] == '任务不存在' + + +@pytest.mark.asyncio +async def test_visibility_thread_messages_returns_404_for_unknown_thread(visibility_env): + app, ids = visibility_env + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.get( + '/api/agents/visibility/threads/missing-thread/messages', + params={'conversation_id': ids['conversation_id']}, + ) + + assert response.status_code == 404 + assert response.json()['detail'] == '线程不存在' + + +@pytest.mark.asyncio +async def test_visibility_returns_404_when_conversation_is_missing(visibility_env): + app, _ids = visibility_env + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.get( + '/api/agents/visibility/events', + params={'conversation_id': 'missing-conversation'}, + ) + + assert response.status_code == 404 + assert response.json()['detail'] == '对话不存在' + + +@pytest.mark.asyncio +async def test_visibility_returns_404_when_snapshot_is_missing(tmp_path): + db_path = tmp_path / 'test_visibility_api_missing_snapshot.db' + engine = create_async_engine(f'sqlite+aiosqlite:///{db_path}', future=True) + session_factory = async_sessionmaker(engine, expire_on_commit=False) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + async with 
session_factory() as session: + user = User( + username='missing_snapshot_user', + email='missing-snapshot@example.com', + hashed_password=get_password_hash('secret123'), + full_name='Missing Snapshot Tester', + ) + session.add(user) + await session.flush() + conversation = Conversation(user_id=user.id, title='Missing snapshot test', agent_state=None) + session.add(conversation) + await session.commit() + await session.refresh(user) + await session.refresh(conversation) + + async def override_get_db(): + async with session_factory() as session: + yield session + + async def override_get_current_user(): + return user + + test_app = FastAPI() + test_app.include_router(agent_router) + test_app.dependency_overrides[get_db] = override_get_db + test_app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=test_app) + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.get( + '/api/agents/visibility/verifier', + params={'conversation_id': conversation.id}, + ) + + assert response.status_code == 404 + assert response.json()['detail'] == '当前会话暂无可视化运行时数据' + + await engine.dispose() + + +@pytest.mark.asyncio +async def test_visibility_verifier_returns_empty_verdict_when_state_is_unverified(tmp_path): + db_path = tmp_path / 'test_visibility_api_empty_verifier.db' + engine = create_async_engine(f'sqlite+aiosqlite:///{db_path}', future=True) + session_factory = async_sessionmaker(engine, expire_on_commit=False) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + snapshot = { + 'kind': 'agent_continuity_state', + 'version': 1, + 'state': { + 'agent_id': 'master', + 'root_agent_id': 'master', + 'current_agent': 'master', + 'event_trace': [], + 'message_trace': [], + 'active_tasks': [], + 'task_results': [], + 'task_hierarchy': {}, + 'tool_outcomes': [], + 'verification_status': None, + 'verification_summary': None, + 
'verification_evidence': [], + }, + } + + async with session_factory() as session: + user = User( + username='empty_verifier_user', + email='empty-verifier@example.com', + hashed_password=get_password_hash('secret123'), + full_name='Empty Verifier Tester', + ) + session.add(user) + await session.flush() + conversation = Conversation(user_id=user.id, title='Empty verifier test', agent_state=snapshot) + session.add(conversation) + await session.commit() + await session.refresh(user) + await session.refresh(conversation) + + async def override_get_db(): + async with session_factory() as session: + yield session + + async def override_get_current_user(): + return user + + test_app = FastAPI() + test_app.include_router(agent_router) + test_app.dependency_overrides[get_db] = override_get_db + test_app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=test_app) + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.get( + '/api/agents/visibility/verifier', + params={'conversation_id': conversation.id}, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload['status'] is None + assert payload['summary'] is None + assert payload['evidence'] == [] + + await engine.dispose() diff --git a/backend/tests/backend/app/test_agent_router.py b/backend/tests/backend/app/test_agent_router.py index 50abafb..e3822d1 100644 --- a/backend/tests/backend/app/test_agent_router.py +++ b/backend/tests/backend/app/test_agent_router.py @@ -53,19 +53,17 @@ async def agent_env(tmp_path): is_active=True, owner_id=user.id, ) - session.add_all([ - Agent( - name='SCHEDULE PLANNER', - role='schedule_planner', - description='日程规划师', - system_prompt='prompt', - is_active=True, - ), - skill_a, - skill_b, - ]) + agent = Agent( + name='SCHEDULE PLANNER', + role='schedule_planner', + description='日程规划师', + system_prompt='prompt', + is_active=True, + ) + 
session.add_all([agent, skill_a, skill_b]) await session.commit() await session.refresh(user) + await session.refresh(agent) await session.refresh(skill_a) await session.refresh(skill_b) @@ -82,7 +80,7 @@ async def agent_env(tmp_path): test_app.dependency_overrides[get_current_user] = override_get_current_user try: - yield test_app, {'skill_a_id': skill_a.id, 'skill_b_id': skill_b.id} + yield test_app, {'agent_id': agent.id, 'skill_a_id': skill_a.id, 'skill_b_id': skill_b.id} finally: await engine.dispose() @@ -116,6 +114,32 @@ async def test_update_agent_config_persists_selected_skill_ids(agent_env): assert get_response.json()['selected_skill_ids'] == [ids['skill_a_id'], ids['skill_b_id']] +@pytest.mark.asyncio +async def test_get_agent_config_requires_authentication(agent_env): + app, _ids = agent_env + + async def override_get_current_user_unauthorized(): + raise RuntimeError('should not be called') + + app.dependency_overrides.pop(get_current_user, None) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.get('/api/agents/config/schedule_planner') + + assert response.status_code == 401 + + +@pytest.mark.asyncio +async def test_get_agent_requires_authentication(agent_env): + app, ids = agent_env + app.dependency_overrides.pop(get_current_user, None) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.get(f"/api/agents/{ids['agent_id']}") + + assert response.status_code == 401 + + @pytest.mark.asyncio async def test_update_agent_config_preserves_selected_skill_ids_when_omitted(agent_env): app, ids = agent_env @@ -148,3 +172,84 @@ async def test_update_agent_config_rejects_invalid_selected_skill_ids(agent_env) assert response.status_code == 400 assert response.json()['detail'] == '存在无效的技能绑定' + + +@pytest.mark.asyncio +async def 
test_create_agent_requires_superuser(agent_env): + app, _ids = agent_env + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.post( + '/api/agents', + json={ + 'name': 'Runtime Planner', + 'role': 'schedule_planning', + 'description': 'runtime', + 'system_prompt': 'prompt', + 'spawn_permission': True, + }, + ) + + assert response.status_code == 403 + assert response.json()['detail'] == '仅管理员可创建 Agent' + + +@pytest.mark.asyncio +async def test_create_agent_requires_spawn_permission_for_runtime_role(agent_env): + app, _ids = agent_env + + async def override_admin_user(): + return User( + username='admin_user', + email='admin@example.com', + hashed_password='x', + is_superuser=True, + ) + + app.dependency_overrides[get_current_user] = override_admin_user + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.post( + '/api/agents', + json={ + 'name': 'Runtime Planner', + 'role': 'schedule_planning', + 'description': 'runtime', + 'system_prompt': 'prompt', + }, + ) + + assert response.status_code == 400 + assert response.json()['detail'] == '缺少 spawn_permission,禁止直接创建 runtime agent' + + +@pytest.mark.asyncio +async def test_create_agent_accepts_allowed_runtime_role_for_superuser(agent_env): + app, _ids = agent_env + + async def override_admin_user(): + return User( + username='admin_user', + email='admin@example.com', + hashed_password='x', + is_superuser=True, + ) + + app.dependency_overrides[get_current_user] = override_admin_user + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.post( + '/api/agents', + json={ + 'name': 'Runtime Planner', + 'role': 'schedule_planning', + 'description': 'runtime', + 'system_prompt': 'prompt', + 'spawn_permission': True, + }, + ) + + assert 
response.status_code == 201 + payload = response.json() + assert payload['name'] == 'Runtime Planner' + assert payload['role'] == 'schedule_planning' diff --git a/backend/tests/backend/app/test_conversation_router.py b/backend/tests/backend/app/test_conversation_router.py new file mode 100644 index 0000000..f157beb --- /dev/null +++ b/backend/tests/backend/app/test_conversation_router.py @@ -0,0 +1,75 @@ +import pytest +from fastapi import FastAPI +from httpx import ASGITransport, AsyncClient +from sqlalchemy import text +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine + +import app.models # noqa: F401 +from app.database import Base, get_db, ensure_conversation_columns +from app.models.conversation import Conversation +from app.models.user import User +from app.routers.auth import get_current_user +from app.routers.conversation import router as conversation_router +from app.services.auth_service import get_password_hash + + +@pytest.fixture +async def conversation_env(tmp_path): + db_path = tmp_path / 'test_conversation_router.db' + engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True) + session_factory = async_sessionmaker(engine, expire_on_commit=False) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + await conn.execute(text('ALTER TABLE conversations DROP COLUMN agent_state')) + await ensure_conversation_columns(conn) + + async with session_factory() as session: + user = User( + username='conversation_user', + email='conversation@example.com', + hashed_password=get_password_hash('secret123'), + full_name='Conversation Tester', + is_active=True, + ) + session.add(user) + await session.flush() + session.add( + Conversation( + user_id=user.id, + title='Existing conversation', + message_count=3, + ) + ) + await session.commit() + await session.refresh(user) + + async def override_get_db(): + async with session_factory() as session: + yield session + + async def 
override_get_current_user(): + return user + + test_app = FastAPI() + test_app.include_router(conversation_router) + test_app.dependency_overrides[get_db] = override_get_db + test_app.dependency_overrides[get_current_user] = override_get_current_user + + try: + yield test_app + finally: + await engine.dispose() + + +@pytest.mark.asyncio +async def test_list_conversations_succeeds_when_agent_state_column_was_missing(conversation_env): + transport = ASGITransport(app=conversation_env) + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.get('/api/conversations') + + assert response.status_code == 200 + payload = response.json() + assert len(payload) == 1 + assert payload[0]['title'] == 'Existing conversation' + assert payload[0]['message_count'] == 3