feat(backend): add ontology and orchestrator API endpoints
New endpoints:
- server/src/app/api/v1/endpoints/ontology.py: ontology API
- server/src/app/api/v1/endpoints/orchestrator.py: orchestrator API

New schemas:
- server/src/app/schemas/ontology.py: ontology data schemas
- server/src/app/schemas/orchestrator.py: orchestrator data schemas
- server/src/app/schemas/user_agent.py: user agent data schemas

New services:
- server/src/app/services/ontology.py: ontology business logic
- server/src/app/services/orchestrator.py: orchestrator business logic
- server/src/app/services/runtime_chat.py: runtime chat service
- server/src/app/services/user_agent.py: user agent service

New tests:
- server/tests/test_ontology_service.py
- server/tests/test_orchestrator_service.py
- server/tests/test_user_agent_service.py
This commit is contained in:
36
server/src/app/api/v1/endpoints/ontology.py
Normal file
36
server/src/app/api/v1/endpoints/ontology.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.api.deps import get_db
|
||||||
|
from app.schemas.common import ErrorResponse
|
||||||
|
from app.schemas.ontology import OntologyParseRequest, OntologyParseResult
|
||||||
|
from app.services.ontology import SemanticOntologyService
|
||||||
|
|
||||||
|
# Router for the ontology endpoints; every route declared on it lives under "/ontology".
router = APIRouter(prefix="/ontology")
|
||||||
|
# SQLAlchemy Session parameter type that FastAPI resolves through the get_db dependency.
DbSession = Annotated[Session, Depends(get_db)]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/parse",
    response_model=OntologyParseResult,
    summary="解析自然语言为语义本体",
    description="把自然语言问题解析成 Day 3 约定的 8 个核心字段,并写入 AgentRun 与 SemanticParseLog。",
    responses={
        status.HTTP_400_BAD_REQUEST: {
            "model": ErrorResponse,
            "description": "请求缺少有效 query 或解析请求格式不合法。",
        },
    },
)
def parse_ontology(payload: OntologyParseRequest, db: DbSession) -> OntologyParseResult:
    """Parse a natural-language query into the semantic ontology structure.

    Args:
        payload: Validated parse request (query text plus optional user context).
        db: Database session handed to the ontology service.

    Returns:
        The parse result produced by ``SemanticOntologyService.parse``.

    Raises:
        HTTPException: 400 when the service raises ``ValueError``; the error
            text becomes the response detail.
    """
    try:
        parsed = SemanticOntologyService(db).parse(payload)
    except ValueError as exc:
        # Translate a service-level validation failure into a client error.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(exc),
        ) from exc
    return parsed
|
||||||
33
server/src/app/api/v1/endpoints/orchestrator.py
Normal file
33
server/src/app/api/v1/endpoints/orchestrator.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.api.deps import get_db
|
||||||
|
from app.schemas.common import ErrorResponse
|
||||||
|
from app.schemas.orchestrator import OrchestratorRequest, OrchestratorResponse
|
||||||
|
from app.services.orchestrator import OrchestratorService
|
||||||
|
|
||||||
|
# Router for the orchestrator endpoints; every route declared on it lives under "/orchestrator".
router = APIRouter(prefix="/orchestrator")
|
||||||
|
# SQLAlchemy Session parameter type that FastAPI resolves through the get_db dependency.
DbSession = Annotated[Session, Depends(get_db)]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/run",
    response_model=OrchestratorResponse,
    summary="运行 Orchestrator 统一调度",
    description=(
        "统一接收用户消息、定时任务和系统事件,"
        "完成语义解析、权限判断、路由和占位执行。"
    ),
    responses={
        status.HTTP_400_BAD_REQUEST: {
            "model": ErrorResponse,
            "description": "请求缺少 message 或 task_id,无法启动调度。",
        },
    },
)
def run_orchestrator(payload: OrchestratorRequest, db: DbSession) -> OrchestratorResponse:
    """Run one request through the orchestrator service.

    Args:
        payload: Validated orchestrator request (source, message / task_id, context).
        db: Database session handed to the orchestrator service.

    Returns:
        The response produced by ``OrchestratorService.run``.

    Raises:
        HTTPException: 400 when the service raises ``ValueError``; the error
            text becomes the response detail.
    """
    try:
        response = OrchestratorService(db).run(payload)
    except ValueError as exc:
        # Translate a service-level validation failure into a client error.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(exc),
        ) from exc
    return response
|
||||||
116
server/src/app/schemas/ontology.py
Normal file
116
server/src/app/schemas/ontology.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
# Closed vocabularies used by the ontology schemas below.

# Business scenario of a parsed query.
OntologyScenario = Literal["expense", "accounts_receivable", "accounts_payable", "knowledge", "unknown"]

# What the user wants to do.
OntologyIntent = Literal[
    "query",
    "explain",
    "compare",
    "risk_check",
    "draft",
    "operate",
]

# Permission level attached to the requested action.
OntologyPermissionLevel = Literal[
    "read",
    "draft_write",
    "approval_required",
    "forbidden",
]

# Which parsing strategy produced the result.
OntologyParseStrategy = Literal[
    "llm_primary",
    "rule_fallback",
]
|
||||||
|
|
||||||
|
|
||||||
|
class OntologyEntity(BaseModel):
    # One business object extracted from the query. Unknown keys are rejected.
    model_config = ConfigDict(extra="forbid")

    # Business object type, e.g. employee / customer / vendor.
    type: str = Field(
        ...,
        description="业务对象类型,例如 employee / customer / vendor。",
    )
    # Value exactly as it was extracted from the original question.
    value: str = Field(
        ...,
        description="从原始问题中提取的对象值。",
    )
    # Normalized form of the value.
    normalized_value: str = Field(
        ...,
        description="标准化后的对象值。",
    )
    # Role of the object, e.g. target / filter / threshold.
    role: str = Field(
        "target",
        description="对象角色,例如 target / filter / threshold。",
    )
    # Field-level confidence, constrained to [0, 1].
    confidence: float = Field(
        0.0,
        ge=0.0,
        le=1.0,
        description="字段级置信度。",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class OntologyTimeRange(BaseModel):
    # Time range recognised in the query. Unknown keys are rejected.
    model_config = ConfigDict(extra="forbid")

    # The raw time expression that was matched.
    raw: str = Field("", description="命中的原始时间表达。")
    # Start date in ISO format, when known.
    start_date: str | None = Field(None, description="ISO 格式起始日期。")
    # End date in ISO format, when known.
    end_date: str | None = Field(None, description="ISO 格式结束日期。")
    # Granularity: day / week / month / quarter / year.
    granularity: str | None = Field(None, description="day / week / month / quarter / year。")
|
||||||
|
|
||||||
|
|
||||||
|
class OntologyMetric(BaseModel):
    # One metric requested by the query. Unknown keys are rejected.
    model_config = ConfigDict(extra="forbid")

    # Metric name, e.g. amount / count / overdue.
    name: str = Field(
        ...,
        description="指标名,例如 amount / count / overdue。",
    )
    # Aggregation such as sum / count / max.
    aggregation: str | None = Field(
        None,
        description="sum / count / max 等聚合口径。",
    )
    # Unit of the metric (amount, quantity, ...).
    unit: str | None = Field(
        None,
        description="金额、数量等单位。",
    )
    # Sort direction: asc / desc.
    sort: str | None = Field(
        None,
        description="asc / desc 排序方向。",
    )
    # Top-N limit; must be at least 1 when given.
    top_n: int | None = Field(
        None,
        ge=1,
        description="Top N 口径。",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class OntologyConstraint(BaseModel):
    # One filter / threshold / ordering constraint. Unknown keys are rejected.
    model_config = ConfigDict(extra="forbid")

    # Constrained field, e.g. department / status / amount.
    field: str = Field(
        ...,
        description="约束字段,例如 department / status / amount。",
    )
    # Operator, e.g. = / > / < / desc.
    operator: str = Field(
        ...,
        description="操作符,例如 = / > / < / desc。",
    )
    # Constraint value.
    value: str | int | float | bool = Field(
        ...,
        description="约束值。",
    )
    # Currency used by amount-type constraints.
    currency: str | None = Field(
        None,
        description="金额类约束使用的币种。",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class OntologyPermission(BaseModel):
    # Permission verdict for the requested action. Unknown keys are rejected.
    model_config = ConfigDict(extra="forbid")

    # Permission level of the action; defaults to "read".
    level: OntologyPermissionLevel = Field(
        "read",
        description="动作权限等级。",
    )
    # Whether the current action may be executed directly.
    allowed: bool = Field(
        True,
        description="是否可直接执行当前动作。",
    )
    # Reason behind the permission decision.
    reason: str = Field(
        "",
        description="权限判断原因。",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class OntologyFieldError(BaseModel):
    # Field-level error or hint. Unknown keys are rejected.
    model_config = ConfigDict(extra="forbid")

    # Field the problem refers to.
    field: str = Field(
        ...,
        description="发生问题的字段。",
    )
    # Error code.
    code: str = Field(
        ...,
        description="错误码。",
    )
    # Explanation intended for display in the frontend.
    message: str = Field(
        ...,
        description="面向前端展示的说明。",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class OntologyParseRequest(BaseModel):
    # Input of the ontology parser.

    # Natural-language question; must contain at least one character.
    query: str = Field(..., min_length=1, description="自然语言问题。")
    # ID of the requesting user, if any.
    user_id: str | None = Field(None, description="当前请求用户 ID。")
    # User context such as role, department, admin flag.
    context_json: dict[str, Any] = Field(default_factory=dict, description="用户上下文,例如角色、部门、是否管理员。")
|
||||||
|
|
||||||
|
|
||||||
|
class OntologyParseResult(BaseModel):
    # Output of the ontology parser; only run_id has no default.

    # Business scenario.
    scenario: OntologyScenario = Field("unknown", description="业务场景。")
    # User intent.
    intent: OntologyIntent = Field("query", description="用户意图。")
    # Business objects found in the query.
    entities: list[OntologyEntity] = Field(
        default_factory=list,
        description="业务对象列表。",
    )
    # Time range; defaults to an empty OntologyTimeRange.
    time_range: OntologyTimeRange = Field(default_factory=OntologyTimeRange, description="时间范围。")
    # Parsed metrics.
    metrics: list[OntologyMetric] = Field(
        default_factory=list,
        description="指标解析结果。",
    )
    # Filters, thresholds, ordering and similar constraints.
    constraints: list[OntologyConstraint] = Field(default_factory=list, description="过滤、阈值、排序等约束。")
    # Risk signals.
    risk_flags: list[str] = Field(
        default_factory=list,
        description="风险信号列表。",
    )
    # Permission verdict; defaults to a default-constructed OntologyPermission.
    permission: OntologyPermission = Field(default_factory=OntologyPermission, description="权限结果。")
    # Overall confidence, constrained to [0, 1].
    confidence: float = Field(
        0.0,
        ge=0.0,
        le=1.0,
        description="整体置信度。",
    )
    # Key slots that are still missing before processing can continue.
    missing_slots: list[str] = Field(
        default_factory=list,
        description="继续处理所缺少的关键槽位。",
    )
    # Potential ambiguities in the current interpretation.
    ambiguity: list[str] = Field(
        default_factory=list,
        description="当前识别中的潜在歧义。",
    )
    # Main strategy used for this parse.
    parse_strategy: OntologyParseStrategy = Field("rule_fallback", description="本次语义解析使用的主策略。")
    # Whether a follow-up question is needed.
    clarification_required: bool = Field(
        False,
        description="是否需要追问。",
    )
    # Suggested follow-up question.
    clarification_question: str | None = Field(
        None,
        description="推荐追问问题。",
    )
    # run_id of the associated AgentRun (required).
    run_id: str = Field(
        ...,
        description="关联的 AgentRun.run_id。",
    )
    # Field-level errors or hints.
    field_errors: list[OntologyFieldError] = Field(default_factory=list, description="字段级错误或提示。")
|
||||||
46
server/src/app/schemas/orchestrator.py
Normal file
46
server/src/app/schemas/orchestrator.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
# Where an orchestrator request comes from.
OrchestratorSource = Literal[
    "user_message",
    "schedule",
    "system_event",
]

# Downstream agents a request can be routed to.
OrchestratorAgent = Literal[
    "user_agent",
    "hermes",
]

# Final run status reported to the caller.
OrchestratorStatus = Literal[
    "succeeded",
    "blocked",
    "failed",
]
|
||||||
|
|
||||||
|
|
||||||
|
class OrchestratorRequest(BaseModel):
    # Input of the orchestrator; only source is required.

    # Request source.
    source: OrchestratorSource = Field(
        ...,
        description="请求来源。",
    )
    # Current user ID; may be empty for task-triggered runs.
    user_id: str | None = Field(
        None,
        description="当前用户 ID,任务触发可为空。",
    )
    # User message or task description.
    message: str | None = Field(
        None,
        description="用户消息或任务描述。",
    )
    # Task asset ID; preferred for schedule-triggered runs.
    task_id: str | None = Field(
        None,
        description="任务资产 ID,schedule 触发时优先使用。",
    )
    # User context, test switches or extra caller information.
    context_json: dict[str, Any] = Field(default_factory=dict, description="用户上下文、测试开关或调用方附加信息。")
|
||||||
|
|
||||||
|
|
||||||
|
class OrchestratorTraceSummary(BaseModel):
    # Condensed trace of one orchestrator run.

    # Semantic scenario.
    scenario: str = Field(
        ...,
        description="语义场景。",
    )
    # Semantic intent.
    intent: str = Field(
        ...,
        description="语义意图。",
    )
    # Total number of tool calls (non-negative).
    tool_count: int = Field(
        0,
        ge=0,
        description="工具调用总数。",
    )
    # Number of failed tool calls (non-negative).
    failed_tool_count: int = Field(
        0,
        ge=0,
        description="失败工具调用数量。",
    )
    # Capability codes selected by routing.
    selected_capability_codes: list[str] = Field(default_factory=list, description="本次路由命中的能力编码。")
    # Whether a degradation happened.
    degraded: bool = Field(
        False,
        description="是否发生降级。",
    )
|
||||||
|
degraded: bool = Field(default=False, description="是否发生降级。")
|
||||||
|
|
||||||
|
|
||||||
|
class OrchestratorResponse(BaseModel):
    # Output of the orchestrator.

    # Unique run_id of this run.
    run_id: str = Field(
        ...,
        description="本次运行的唯一 run_id。",
    )
    # Downstream agent finally routed to; None when no agent was selected.
    selected_agent: OrchestratorAgent | None = Field(None, description="最终路由到的下游 Agent。")
    # Short summary of the routing reason.
    route_reason: str = Field(
        ...,
        description="路由原因摘要。",
    )
    # Permission level.
    permission_level: str = Field(
        ...,
        description="权限级别。",
    )
    # Final run status.
    status: OrchestratorStatus = Field(
        ...,
        description="最终运行状态。",
    )
    # Minimal result that the frontend can display directly.
    result: dict[str, Any] = Field(
        default_factory=dict,
        description="对前端可直接展示的最小结果。",
    )
    # Whether a user or admin confirmation is required.
    requires_confirmation: bool = Field(
        False,
        description="是否需要用户或管理员确认。",
    )
    # Condensed trace summary.
    trace_summary: OrchestratorTraceSummary = Field(
        ...,
        description="简化后的 Trace 摘要。",
    )
|
||||||
58
server/src/app/schemas/user_agent.py
Normal file
58
server/src/app/schemas/user_agent.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.schemas.ontology import OntologyParseResult
|
||||||
|
|
||||||
|
# Kinds of source a user-agent citation can point to.
UserAgentCitationType = Literal["rule", "knowledge"]
|
||||||
|
|
||||||
|
|
||||||
|
class UserAgentCitation(BaseModel):
    # Reference to a rule or knowledge source backing an answer.

    # Type of the cited source.
    source_type: UserAgentCitationType = Field(
        ...,
        description="引用来源类型。",
    )
    # Source code.
    code: str = Field(
        ...,
        description="来源编码。",
    )
    # Source title.
    title: str = Field(
        ...,
        description="来源标题。",
    )
    # Cited version.
    version: str | None = Field(
        None,
        description="引用版本。",
    )
    # When the source was last updated.
    updated_at: str | None = Field(
        None,
        description="来源更新时间。",
    )
    # Excerpt shown to the user.
    excerpt: str | None = Field(
        None,
        description="面向用户展示的引用摘要。",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class UserAgentSuggestedAction(BaseModel):
    # A next step suggested to the user.

    # Label text of the suggested action.
    label: str = Field(
        ...,
        description="建议动作文案。",
    )
    # Action type, e.g. open_detail / create_draft.
    action_type: str = Field(
        ...,
        description="动作类型,例如 open_detail / create_draft。",
    )
    # Explanation of the action.
    description: str = Field(
        "",
        description="动作说明。",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class UserAgentDraftPayload(BaseModel):
    # Draft content attached to an answer.

    # Draft type.
    draft_type: str = Field(
        ...,
        description="草稿类型。",
    )
    # Draft title.
    title: str = Field(
        ...,
        description="草稿标题。",
    )
    # Draft body.
    body: str = Field(
        ...,
        description="草稿正文。",
    )
    # Whether manual confirmation is required; defaults to True.
    confirmation_required: bool = Field(
        True,
        description="是否需要人工确认。",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class UserAgentRequest(BaseModel):
    # Input handed to the user agent.

    # run_id of the associated AgentRun.
    run_id: str = Field(
        ...,
        description="关联的 AgentRun.run_id。",
    )
    # ID of the requesting user, if any.
    user_id: str | None = Field(
        None,
        description="当前请求用户 ID。",
    )
    # The original user question.
    message: str = Field(
        ...,
        description="原始用户问题。",
    )
    # Semantic parse result.
    ontology: OntologyParseResult = Field(
        ...,
        description="语义解析结果。",
    )
    # Additional context.
    context_json: dict[str, Any] = Field(
        default_factory=dict,
        description="附加上下文。",
    )
    # Raw result returned by the tool.
    tool_payload: dict[str, Any] = Field(
        default_factory=dict,
        description="工具返回的原始结果。",
    )
    # Capability codes selected for this run.
    selected_capability_codes: list[str] = Field(default_factory=list, description="本次命中的能力编码。")
    # Whether a degradation is currently in effect.
    degraded: bool = Field(
        False,
        description="当前是否发生降级。",
    )
    # Whether confirmation is demanded.
    requires_confirmation: bool = Field(
        False,
        description="是否要求确认。",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class UserAgentResponse(BaseModel):
    # Output of the user agent.

    # Natural-language answer shown to the user.
    answer: str = Field(
        ...,
        description="面向用户展示的自然语言回答。",
    )
    # Rule or knowledge citations.
    citations: list[UserAgentCitation] = Field(
        default_factory=list,
        description="规则或知识引用。",
    )
    # Suggested next actions.
    suggested_actions: list[UserAgentSuggestedAction] = Field(default_factory=list, description="建议的下一步动作。")
    # Optional draft content.
    draft_payload: UserAgentDraftPayload | None = Field(
        None,
        description="可选草稿内容。",
    )
    # Risk tags associated with this answer.
    risk_flags: list[str] = Field(
        default_factory=list,
        description="本次回答关联的风险标签。",
    )
    # Whether manual confirmation is required.
    requires_confirmation: bool = Field(
        False,
        description="是否需要人工确认。",
    )
|
||||||
1470
server/src/app/services/ontology.py
Normal file
1470
server/src/app/services/ontology.py
Normal file
File diff suppressed because it is too large
Load Diff
887
server/src/app/services/orchestrator.py
Normal file
887
server/src/app/services/orchestrator.py
Normal file
@@ -0,0 +1,887 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from time import perf_counter
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.agent_enums import (
|
||||||
|
AgentAssetStatus,
|
||||||
|
AgentAssetType,
|
||||||
|
AgentName,
|
||||||
|
AgentPermissionLevel,
|
||||||
|
AgentRunSource,
|
||||||
|
AgentRunStatus,
|
||||||
|
AgentToolType,
|
||||||
|
)
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.models.financial_record import (
|
||||||
|
AccountsPayableRecord,
|
||||||
|
AccountsReceivableRecord,
|
||||||
|
ExpenseClaim,
|
||||||
|
)
|
||||||
|
from app.schemas.agent_asset import AgentAssetListItem, AgentAssetRead
|
||||||
|
from app.schemas.ontology import OntologyParseRequest, OntologyParseResult
|
||||||
|
from app.schemas.orchestrator import (
|
||||||
|
OrchestratorRequest,
|
||||||
|
OrchestratorResponse,
|
||||||
|
OrchestratorTraceSummary,
|
||||||
|
)
|
||||||
|
from app.schemas.user_agent import UserAgentRequest, UserAgentResponse
|
||||||
|
from app.services.agent_assets import AgentAssetService
|
||||||
|
from app.services.agent_foundation import AgentFoundationService
|
||||||
|
from app.services.agent_runs import AgentRunService
|
||||||
|
from app.services.ontology import SemanticOntologyService
|
||||||
|
from app.services.user_agent import UserAgentService
|
||||||
|
|
||||||
|
# Module-level logger obtained from the project's logging helper.
logger = get_logger("app.services.orchestrator")
|
||||||
|
|
||||||
|
# Ontology scenario -> asset domain value used when listing rule / skill assets.
SCENARIO_TO_DOMAIN = dict(
    expense="expense",
    accounts_receivable="ar",
    accounts_payable="ap",
    knowledge="knowledge",
    unknown="system",
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class ExecutionOutcome:
|
||||||
|
status: str
|
||||||
|
result: dict[str, Any]
|
||||||
|
degraded: bool
|
||||||
|
tool_count: int
|
||||||
|
failed_tool_count: int
|
||||||
|
|
||||||
|
|
||||||
|
class OrchestratorService:
|
||||||
|
def __init__(self, db: Session) -> None:
    """Keep the session and build the collaborating services on top of it.

    Args:
        db: SQLAlchemy session shared by every service created here.
    """
    self.db = db
    # Asset lookup and listing (rules, skills, MCPs, tasks).
    self.asset_service = AgentAssetService(db)
    # Creation and update of AgentRun records.
    self.run_service = AgentRunService(db)
    # Natural-language to ontology parsing.
    self.ontology_service = SemanticOntologyService(db)
    # Produces the final user-facing answer.
    self.user_agent_service = UserAgentService(db)
|
||||||
|
|
||||||
|
def run(self, payload: OrchestratorRequest) -> OrchestratorResponse:
    """Handle one orchestration request and persist its outcome on an AgentRun.

    The run record is created first. The request is then resolved to a text
    message, parsed into an ontology, routed to an agent and executed. The
    final state (or the failure) is written back to the run record.

    Args:
        payload: Incoming request (source, optional message / task_id, context).

    Returns:
        The response for the caller. Unexpected failures after the request
        has been resolved are reported as a ``failed`` response, not raised.

    Raises:
        ValueError: If no message can be resolved from the request. The run is
            marked as failed before the error propagates, so the HTTP layer
            can answer with 400 as its endpoint documents.
    """
    AgentFoundationService(self.db).ensure_foundation_ready()
    route_json: dict[str, Any] = {
        "orchestrated_by": AgentName.ORCHESTRATOR.value,
        "stage": "created",
    }
    run = self.run_service.create_run(
        agent=AgentName.ORCHESTRATOR.value,
        source=payload.source,
        user_id=payload.user_id,
        task_id=payload.task_id,
        ontology_json={},
        route_json=route_json,
        permission_level=AgentPermissionLevel.READ.value,
        status=AgentRunStatus.RUNNING.value,
        result_summary="Orchestrator 已接收请求。",
    )

    # Tracks whether we got past request resolution; a ValueError before that
    # point is a caller error rather than an internal failure.
    request_resolved = False
    try:
        message, task_asset = self._resolve_message(payload)
        request_resolved = True
        ontology = self.ontology_service.parse_for_run(
            OntologyParseRequest(
                query=message,
                user_id=payload.user_id,
                context_json=payload.context_json,
            ),
            run_id=run.run_id,
        )
        # Test hook: lets callers exercise the failure path on demand.
        if payload.context_json.get("simulate_orchestrator_exception"):
            raise RuntimeError("simulated orchestrator exception")
        selected_agent, route_reason = self._select_agent(payload, ontology)
        capabilities = self._select_capabilities(
            payload=payload,
            ontology=ontology,
            task_asset=task_asset,
        )
        selected_capability_codes = self._flatten_capability_codes(capabilities)
        requires_confirmation = (
            ontology.permission.level == AgentPermissionLevel.APPROVAL_REQUIRED.value
        )

        route_json = {
            "orchestrated_by": AgentName.ORCHESTRATOR.value,
            "stage": "routed",
            "selected_agent": selected_agent,
            "route_reason": route_reason,
            "selected_capability_codes": selected_capability_codes,
            "ontology_run_id": ontology.run_id,
        }

        if ontology.permission.level == AgentPermissionLevel.FORBIDDEN.value:
            # Forbidden requests never reach an agent.
            outcome = ExecutionOutcome(
                status=AgentRunStatus.BLOCKED.value,
                result={
                    "message": ontology.permission.reason,
                    "clarification_question": ontology.clarification_question,
                    "degraded": False,
                },
                degraded=False,
                tool_count=0,
                failed_tool_count=0,
            )
            selected_agent = None
            route_reason = "permission_forbidden"
            route_json["stage"] = "blocked"
            route_json["route_reason"] = route_reason
        elif ontology.clarification_required:
            # The parser needs more input; return the follow-up question.
            outcome = ExecutionOutcome(
                status=AgentRunStatus.BLOCKED.value,
                result={
                    "message": ontology.clarification_question or "需要补充更多上下文。",
                    "clarification_required": True,
                    "missing_slots": ontology.missing_slots,
                    "ambiguity": ontology.ambiguity,
                    "parse_strategy": ontology.parse_strategy,
                    "degraded": False,
                },
                degraded=False,
                tool_count=0,
                failed_tool_count=0,
            )
            route_reason = "clarification_required"
            route_json["stage"] = "clarification"
            route_json["route_reason"] = route_reason
        elif selected_agent == AgentName.HERMES.value:
            outcome = self._execute_hermes(
                payload=payload,
                run_id=run.run_id,
                ontology=ontology,
                capabilities=capabilities,
                requires_confirmation=requires_confirmation,
                task_asset=task_asset,
            )
        else:
            outcome = self._execute_user_agent(
                payload=payload,
                run_id=run.run_id,
                ontology=ontology,
                capabilities=capabilities,
                requires_confirmation=requires_confirmation,
            )

        # A run awaiting confirmation is never reported as succeeded.
        # requires_confirmation already means "permission level is
        # approval_required", so no second level check is needed here.
        final_status = (
            AgentRunStatus.BLOCKED.value
            if requires_confirmation and outcome.status == AgentRunStatus.SUCCEEDED.value
            else outcome.status
        )
        result_message = (
            str(outcome.result.get("message", "")).strip()
            or "Orchestrator 执行完成。"
        )
        self.run_service.update_run(
            run.run_id,
            agent=selected_agent or AgentName.ORCHESTRATOR.value,
            ontology_json=self._build_ontology_json(ontology),
            route_json={
                **route_json,
                "requires_confirmation": requires_confirmation,
                "degraded": outcome.degraded,
            },
            permission_level=ontology.permission.level,
            status=final_status,
            result_summary=result_message,
            error_message=None,
            finished_at=datetime.now(UTC),
        )
        return OrchestratorResponse(
            run_id=run.run_id,
            selected_agent=selected_agent,
            route_reason=route_reason,
            permission_level=ontology.permission.level,
            status=self._normalize_response_status(final_status),
            result=outcome.result,
            requires_confirmation=requires_confirmation,
            trace_summary=OrchestratorTraceSummary(
                scenario=ontology.scenario,
                intent=ontology.intent,
                tool_count=outcome.tool_count,
                failed_tool_count=outcome.failed_tool_count,
                selected_capability_codes=selected_capability_codes,
                degraded=outcome.degraded,
            ),
        )
    except Exception as exc:
        logger.exception("Orchestrator run failed run_id=%s", run.run_id)
        self.run_service.update_run(
            run.run_id,
            agent=AgentName.ORCHESTRATOR.value,
            route_json={**route_json, "stage": "failed"},
            status=AgentRunStatus.FAILED.value,
            result_summary="Orchestrator 执行失败。",
            error_message=str(exc),
            finished_at=datetime.now(UTC),
        )
        if isinstance(exc, ValueError) and not request_resolved:
            # Invalid request (e.g. neither message nor task_id): let the
            # caller see the ValueError instead of a 200 "failed" payload.
            raise
        return OrchestratorResponse(
            run_id=run.run_id,
            selected_agent=None,
            route_reason="orchestrator_exception",
            permission_level=AgentPermissionLevel.READ.value,
            status="failed",
            result={"message": f"Orchestrator 执行失败:{exc}"},
            requires_confirmation=False,
            trace_summary=OrchestratorTraceSummary(
                scenario="unknown",
                intent="query",
                tool_count=0,
                failed_tool_count=0,
                selected_capability_codes=[],
                degraded=False,
            ),
        )
|
||||||
|
|
||||||
|
def _resolve_message(
|
||||||
|
self,
|
||||||
|
payload: OrchestratorRequest,
|
||||||
|
) -> tuple[str, AgentAssetRead | None]:
|
||||||
|
task_asset = None
|
||||||
|
if payload.task_id:
|
||||||
|
task_asset = self.asset_service.get_asset(payload.task_id)
|
||||||
|
|
||||||
|
if payload.message and payload.message.strip():
|
||||||
|
return payload.message.strip(), task_asset
|
||||||
|
|
||||||
|
if task_asset is not None:
|
||||||
|
description = str(task_asset.description or "").strip()
|
||||||
|
scenario_text = " ".join(str(item) for item in task_asset.scenario_json)
|
||||||
|
message = f"{task_asset.name} {description} {scenario_text}".strip()
|
||||||
|
return message, task_asset
|
||||||
|
|
||||||
|
if payload.source == AgentRunSource.SCHEDULE.value:
|
||||||
|
return "定时风险巡检任务", task_asset
|
||||||
|
|
||||||
|
raise ValueError("message 或 task_id 至少需要提供一个。")
|
||||||
|
|
||||||
|
@staticmethod
def _select_agent(
    payload: OrchestratorRequest,
    ontology: OntologyParseResult,
) -> tuple[str, str]:
    """Pick the downstream agent for a request and say why.

    Routing rules, first match wins:
      * every schedule-triggered request goes to Hermes;
      * a system event whose intent is ``risk_check`` goes to Hermes;
      * everything else goes to the user agent.

    Args:
        payload: Incoming request; only ``source`` is read.
        ontology: Parse result; only ``intent`` is read.

    Returns:
        ``(agent_name, route_reason)``.
    """
    if payload.source == AgentRunSource.SCHEDULE.value:
        # This covers scheduled risk checks too, so no separate
        # "scheduled risk_check" rule is needed after it.
        return AgentName.HERMES.value, "schedule_source_defaults_to_hermes"
    if payload.source == AgentRunSource.SYSTEM_EVENT.value and ontology.intent == "risk_check":
        return AgentName.HERMES.value, "system_event_risk_check_routes_to_hermes"
    if ontology.intent in {"query", "explain", "draft", "compare", "operate"}:
        return AgentName.USER_AGENT.value, f"{ontology.intent}_routes_to_user_agent"
    return AgentName.USER_AGENT.value, "user_message_defaults_to_user_agent"
|
||||||
|
|
||||||
|
def _select_capabilities(
    self,
    *,
    payload: OrchestratorRequest,
    ontology: OntologyParseResult,
    task_asset: AgentAssetRead | None,
) -> dict[str, list[AgentAssetListItem | AgentAssetRead]]:
    """Collect the active assets that apply to this request, grouped by kind.

    Rules and skills are filtered by the domain mapped from the ontology
    scenario (rules skip the filter for "knowledge" and "system", skills only
    for "system"). MCPs are never domain-filtered. Tasks are either the
    request's own active task asset or, for schedule-triggered requests
    without one, the ranked active task assets.

    Args:
        payload: Incoming request; only ``source`` is read.
        ontology: Parse result used for the domain lookup and for ranking.
        task_asset: Task asset resolved for the request, if any.

    Returns:
        Mapping with the keys ``rules``, ``skills``, ``mcps`` and ``tasks``.
    """
    active = AgentAssetStatus.ACTIVE.value
    domain = SCENARIO_TO_DOMAIN.get(ontology.scenario)

    def ranked_active(asset_type: str, **filters: Any) -> list[AgentAssetListItem | AgentAssetRead]:
        # List the active assets of one type and rank them against the ontology.
        listed = self.asset_service.list_assets(
            asset_type=asset_type,
            status=active,
            **filters,
        )
        return self._rank_assets(listed, ontology)

    rule_domain = None if domain in {"knowledge", "system"} else domain
    skill_domain = None if domain == "system" else domain

    rules = ranked_active(AgentAssetType.RULE.value, domain=rule_domain)
    skills = ranked_active(AgentAssetType.SKILL.value, domain=skill_domain)
    mcps = ranked_active(AgentAssetType.MCP.value)

    tasks: list[AgentAssetListItem | AgentAssetRead]
    if task_asset is not None and task_asset.status == active:
        tasks = [task_asset]
    elif payload.source == AgentRunSource.SCHEDULE.value:
        tasks = ranked_active(AgentAssetType.TASK.value)
    else:
        tasks = []

    return {"rules": rules, "skills": skills, "mcps": mcps, "tasks": tasks}
|
||||||
|
|
||||||
|
def _execute_user_agent(
    self,
    *,
    payload: OrchestratorRequest,
    run_id: str,
    ontology: OntologyParseResult,
    capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]],
    requires_confirmation: bool,
) -> ExecutionOutcome:
    """Execute a request that was routed to the user agent.

    When confirmation is required, only a confirmation summary is produced and
    the outcome is blocked. Otherwise exactly one tool is invoked, chosen by
    ``_resolve_next_step`` (database query, knowledge search, rule check, or a
    draft placeholder for any other step), and its payload is handed to the
    user agent to build the answer.

    Args:
        payload: Incoming orchestrator request.
        run_id: run_id of the AgentRun the tool calls belong to.
        ontology: Parse result for the request.
        capabilities: Assets selected for the request, grouped by kind.
        requires_confirmation: Whether the action needs confirmation first.

    Returns:
        The execution outcome; ``failed_tool_count`` is 1 when the tool call
        degraded to its fallback, else 0.
    """
    if requires_confirmation:
        response, degraded = self._invoke_tool(
            run_id=run_id,
            tool_type=AgentToolType.LLM.value,
            tool_name="user_agent.confirmation_placeholder",
            request_json={
                "message": payload.message,
                "permission_level": ontology.permission.level,
            },
            context_json=payload.context_json,
            executor=lambda: {
                "confirmation_title": "操作需要确认",
                "message": f"{ontology.permission.reason} 当前仅返回确认摘要,不直接执行动作。",
            },
            fallback_factory=lambda exc: {
                "confirmation_title": "操作需要确认",
                "message": f"确认摘要生成失败,已阻断自动执行:{exc}",
            },
        )
        return ExecutionOutcome(
            status=AgentRunStatus.BLOCKED.value,
            result={**response, "degraded": degraded},
            degraded=degraded,
            tool_count=1,
            failed_tool_count=1 if degraded else 0,
        )

    selected_capability_codes = self._flatten_capability_codes(capabilities)
    next_step = self._resolve_next_step(ontology, payload.source)

    # Each branch only chooses the tool; the shared helper does the rest.
    if next_step == "query_database":
        tool_kwargs: dict[str, Any] = {
            "tool_type": AgentToolType.DATABASE.value,
            "tool_name": self._database_tool_name(ontology.scenario),
            "executor": lambda: self._build_database_answer(ontology),
            "fallback_factory": lambda exc: {
                "message": f"数据库查询暂时不可用,已返回降级说明:{exc}",
                "degraded": True,
            },
        }
    elif next_step == "search_knowledge":
        tool_kwargs = {
            "tool_type": AgentToolType.DATABASE.value,
            "tool_name": "knowledge.search",
            "executor": lambda: self._build_knowledge_answer(ontology, capabilities),
            "fallback_factory": lambda exc: {
                "message": f"知识检索暂时不可用,建议稍后重试:{exc}",
                "degraded": True,
            },
        }
    elif next_step == "run_rule":
        tool_kwargs = {
            "tool_type": AgentToolType.RULE_ENGINE.value,
            "tool_name": self._rule_tool_name(capabilities),
            "executor": lambda: self._build_rule_answer(ontology),
            "fallback_factory": lambda exc: {
                "message": f"规则检查暂时不可用,已返回人工复核建议:{exc}",
                "degraded": True,
            },
        }
    else:
        # Any other step falls back to the draft placeholder.
        tool_kwargs = {
            "tool_type": AgentToolType.LLM.value,
            "tool_name": "user_agent.draft_placeholder",
            "executor": lambda: {
                "message": (
                    f"已生成 {ontology.scenario} 场景草稿,"
                    "占位能力后续由 Day 5 User Agent 接管。"
                ),
                "draft_only": True,
            },
            "fallback_factory": lambda exc: {
                "message": f"草稿生成暂时不可用,请稍后再试:{exc}",
                "degraded": True,
            },
        }

    return self._run_user_agent_tool_step(
        payload=payload,
        run_id=run_id,
        ontology=ontology,
        selected_capability_codes=selected_capability_codes,
        **tool_kwargs,
    )

def _run_user_agent_tool_step(
    self,
    *,
    payload: OrchestratorRequest,
    run_id: str,
    ontology: OntologyParseResult,
    selected_capability_codes: list[str],
    tool_type: str,
    tool_name: str,
    executor: Any,
    fallback_factory: Any,
) -> ExecutionOutcome:
    """Invoke one tool, let the user agent answer from it, wrap the outcome."""
    tool_payload, degraded = self._invoke_tool(
        run_id=run_id,
        tool_type=tool_type,
        tool_name=tool_name,
        request_json=self._build_ontology_json(ontology),
        context_json=payload.context_json,
        executor=executor,
        fallback_factory=fallback_factory,
    )
    # NOTE(review): payload.message is empty when the text was derived from a
    # task asset, so the user agent then receives "" — confirm this is intended.
    reply = self.user_agent_service.respond(
        UserAgentRequest(
            run_id=run_id,
            user_id=payload.user_id,
            message=payload.message or "",
            ontology=ontology,
            context_json=payload.context_json,
            tool_payload=tool_payload,
            selected_capability_codes=selected_capability_codes,
            degraded=degraded,
            # This step is only reached when no confirmation is required.
            requires_confirmation=False,
        )
    )
    return ExecutionOutcome(
        status=AgentRunStatus.SUCCEEDED.value,
        result=self._build_user_agent_result(reply, degraded=degraded),
        degraded=degraded,
        tool_count=1,
        failed_tool_count=1 if degraded else 0,
    )
|
||||||
|
|
||||||
|
def _execute_hermes(
    self,
    *,
    payload: OrchestratorRequest,
    run_id: str,
    ontology: OntologyParseResult,
    capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]],
    requires_confirmation: bool,
    task_asset: AgentAssetRead | None,
) -> ExecutionOutcome:
    """Run the Hermes path: one rule-engine tool call followed by one MCP tool call.

    When ``requires_confirmation`` is true nothing is invoked and a BLOCKED
    outcome with zero tool calls is returned. Otherwise both tools go through
    ``_invoke_tool`` and the outcome is SUCCEEDED; it is flagged ``degraded``
    if either call reported degradation.
    """
    if requires_confirmation:
        blocked_result = {
            "message": "Hermes 不会自动执行需要确认的高风险动作,已阻断。",
            "degraded": False,
        }
        return ExecutionOutcome(
            status=AgentRunStatus.BLOCKED.value,
            result=blocked_result,
            degraded=False,
            tool_count=0,
            failed_tool_count=0,
        )

    def run_rule_check() -> dict[str, Any]:
        return self._build_rule_answer(ontology)

    def rule_fallback(exc: Exception) -> dict[str, Any]:
        return {
            "message": f"规则巡检失败,已降级为待人工复核:{exc}",
            "degraded": True,
        }

    def run_mcp_snapshot() -> dict[str, Any]:
        return self._build_mcp_answer(task_asset, ontology)

    def mcp_fallback(exc: Exception) -> dict[str, Any]:
        return {
            "message": f"MCP 调用失败,已使用缓存快照降级:{exc}",
            "fallback": "used_cached_snapshot",
        }

    rule_response, rule_degraded = self._invoke_tool(
        run_id=run_id,
        tool_type=AgentToolType.RULE_ENGINE.value,
        tool_name=self._rule_tool_name(capabilities),
        request_json=self._build_ontology_json(ontology),
        context_json=payload.context_json,
        executor=run_rule_check,
        fallback_factory=rule_fallback,
    )
    mcp_response, mcp_degraded = self._invoke_tool(
        run_id=run_id,
        tool_type=AgentToolType.MCP.value,
        tool_name=self._mcp_tool_name(capabilities),
        request_json={
            "task_code": task_asset.code if task_asset is not None else "",
            "scenario": ontology.scenario,
        },
        context_json=payload.context_json,
        executor=run_mcp_snapshot,
        fallback_factory=mcp_fallback,
    )

    degraded = rule_degraded or mcp_degraded
    summary = self._build_hermes_message(
        task_asset=task_asset,
        ontology=ontology,
        rule_response=rule_response,
        mcp_response=mcp_response,
        degraded=degraded,
    )
    report_type = task_asset.code if task_asset is not None else "hermes_runtime"
    return ExecutionOutcome(
        status=AgentRunStatus.SUCCEEDED.value,
        result={
            "message": summary,
            "report_type": report_type,
            "degraded": degraded,
        },
        degraded=degraded,
        tool_count=2,
        # Each degraded call counts as one failed tool.
        failed_tool_count=int(rule_degraded) + int(mcp_degraded),
    )
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _resolve_next_step(ontology: OntologyParseResult, source: str) -> str:
|
||||||
|
if ontology.clarification_required:
|
||||||
|
return "ask_clarification"
|
||||||
|
if ontology.intent == "draft":
|
||||||
|
return "create_draft"
|
||||||
|
if ontology.scenario == "knowledge" or ontology.intent == "explain":
|
||||||
|
return "search_knowledge"
|
||||||
|
if ontology.intent == "risk_check" or source == AgentRunSource.SCHEDULE.value:
|
||||||
|
return "run_rule"
|
||||||
|
if ontology.intent in {"query", "compare"}:
|
||||||
|
return "query_database"
|
||||||
|
return "create_draft"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _flatten_capability_codes(
|
||||||
|
capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]],
|
||||||
|
) -> list[str]:
|
||||||
|
codes: list[str] = []
|
||||||
|
for items in capabilities.values():
|
||||||
|
for item in items[:2]:
|
||||||
|
if item.code not in codes:
|
||||||
|
codes.append(item.code)
|
||||||
|
return codes
|
||||||
|
|
||||||
|
def _rank_assets(
|
||||||
|
self,
|
||||||
|
items: list[AgentAssetListItem],
|
||||||
|
ontology: OntologyParseResult,
|
||||||
|
) -> list[AgentAssetListItem]:
|
||||||
|
def score(item: AgentAssetListItem) -> tuple[int, str]:
|
||||||
|
item_tags = {str(value) for value in item.scenario_json or []}
|
||||||
|
weight = 0
|
||||||
|
if ontology.scenario in item_tags:
|
||||||
|
weight += 3
|
||||||
|
if ontology.intent in item_tags:
|
||||||
|
weight += 2
|
||||||
|
for risk_flag in ontology.risk_flags:
|
||||||
|
if risk_flag in item_tags:
|
||||||
|
weight += 4
|
||||||
|
return weight, item.code
|
||||||
|
|
||||||
|
ranked = sorted(items, key=score, reverse=True)
|
||||||
|
if not ranked:
|
||||||
|
return []
|
||||||
|
scored = [item for item in ranked if score(item)[0] > 0]
|
||||||
|
return scored or ranked[:1]
|
||||||
|
|
||||||
|
def _invoke_tool(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
run_id: str,
|
||||||
|
tool_type: str,
|
||||||
|
tool_name: str,
|
||||||
|
request_json: dict[str, Any],
|
||||||
|
context_json: dict[str, Any],
|
||||||
|
executor,
|
||||||
|
fallback_factory,
|
||||||
|
) -> tuple[dict[str, Any], bool]:
|
||||||
|
started = perf_counter()
|
||||||
|
try:
|
||||||
|
self._maybe_raise_simulated_failure(tool_type, context_json)
|
||||||
|
response = executor()
|
||||||
|
duration_ms = int((perf_counter() - started) * 1000)
|
||||||
|
self.run_service.record_tool_call(
|
||||||
|
run_id=run_id,
|
||||||
|
tool_type=tool_type,
|
||||||
|
tool_name=tool_name,
|
||||||
|
request_json=request_json,
|
||||||
|
response_json=response,
|
||||||
|
status="succeeded",
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
)
|
||||||
|
return response, False
|
||||||
|
except Exception as exc:
|
||||||
|
duration_ms = int((perf_counter() - started) * 1000)
|
||||||
|
response = fallback_factory(exc)
|
||||||
|
self.run_service.record_tool_call(
|
||||||
|
run_id=run_id,
|
||||||
|
tool_type=tool_type,
|
||||||
|
tool_name=tool_name,
|
||||||
|
request_json=request_json,
|
||||||
|
response_json=response,
|
||||||
|
status="failed",
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
error_message=str(exc),
|
||||||
|
)
|
||||||
|
return response, True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _maybe_raise_simulated_failure(tool_type: str, context_json: dict[str, Any]) -> None:
|
||||||
|
expected = str(context_json.get("simulate_tool_failure") or "").strip().lower()
|
||||||
|
if not expected:
|
||||||
|
return
|
||||||
|
if expected == tool_type.lower():
|
||||||
|
raise RuntimeError(f"simulated {tool_type} failure")
|
||||||
|
|
||||||
|
def _build_database_answer(self, ontology: OntologyParseResult) -> dict[str, Any]:
    """Return a record count and an amount total for the ontology's scenario.

    * ``expense``: counts and sums ``ExpenseClaim`` rows, restricted to the
      employee names found in the ontology entities when there are any;
      returns ``record_count`` and ``total_amount``.
    * ``accounts_receivable``: counts rows and sums ``amount_outstanding``
      over ``AccountsReceivableRecord``.
    * any other scenario: the same over ``AccountsPayableRecord``.

    The non-expense branches return ``record_count`` and
    ``outstanding_amount``. Amounts are rounded to two decimals.
    """
    run_scalar = self.db.scalar
    scenario = ontology.scenario

    if scenario == "expense":
        claim_count_query = select(func.count()).select_from(ExpenseClaim)
        claim_amount_query = select(
            func.coalesce(func.sum(ExpenseClaim.amount), 0)
        ).select_from(ExpenseClaim)

        employees = [
            entity.normalized_value
            for entity in ontology.entities
            if entity.type == "employee"
        ]
        if employees:
            claim_count_query = claim_count_query.where(
                ExpenseClaim.employee_name.in_(employees)
            )
            claim_amount_query = claim_amount_query.where(
                ExpenseClaim.employee_name.in_(employees)
            )

        claim_count = int(run_scalar(claim_count_query) or 0)
        claim_total = float(run_scalar(claim_amount_query) or 0)
        return {
            "record_count": claim_count,
            "total_amount": round(claim_total, 2),
        }

    # Receivable and payable share one query shape; everything that is not
    # receivable falls through to the payable ledger.
    if scenario == "accounts_receivable":
        ledger = AccountsReceivableRecord
    else:
        ledger = AccountsPayableRecord

    ledger_count = int(run_scalar(select(func.count()).select_from(ledger)) or 0)
    ledger_outstanding = float(
        run_scalar(select(func.coalesce(func.sum(ledger.amount_outstanding), 0)))
        or 0
    )
    return {
        "record_count": ledger_count,
        "outstanding_amount": round(ledger_outstanding, 2),
    }
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_user_query_result(
|
||||||
|
ontology: OntologyParseResult,
|
||||||
|
response: dict[str, Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
if ontology.scenario == "expense":
|
||||||
|
return {
|
||||||
|
"message": (
|
||||||
|
f"已路由到 User Agent,占位查询结果:命中 {response['record_count']} 笔报销,"
|
||||||
|
f"金额合计 {response['total_amount']} 元。"
|
||||||
|
),
|
||||||
|
"data": response,
|
||||||
|
}
|
||||||
|
if ontology.scenario == "accounts_receivable":
|
||||||
|
return {
|
||||||
|
"message": (
|
||||||
|
f"已路由到 User Agent,占位查询结果:命中 {response['record_count']} 条应收,"
|
||||||
|
f"未回款金额 {response['outstanding_amount']} 元。"
|
||||||
|
),
|
||||||
|
"data": response,
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"message": (
|
||||||
|
f"已路由到 User Agent,占位查询结果:命中 {response['record_count']} 条应付,"
|
||||||
|
f"待付金额 {response['outstanding_amount']} 元。"
|
||||||
|
),
|
||||||
|
"data": response,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_user_agent_result(
|
||||||
|
response: UserAgentResponse,
|
||||||
|
*,
|
||||||
|
degraded: bool,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
result = {
|
||||||
|
"message": response.answer,
|
||||||
|
"answer": response.answer,
|
||||||
|
"citations": [item.model_dump() for item in response.citations],
|
||||||
|
"suggested_actions": [item.model_dump() for item in response.suggested_actions],
|
||||||
|
"risk_flags": response.risk_flags,
|
||||||
|
"requires_confirmation": response.requires_confirmation,
|
||||||
|
"degraded": degraded,
|
||||||
|
}
|
||||||
|
if response.draft_payload is not None:
|
||||||
|
result["draft_payload"] = response.draft_payload.model_dump()
|
||||||
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_knowledge_answer(
|
||||||
|
ontology: OntologyParseResult,
|
||||||
|
capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
referenced = [item.code for item in capabilities["rules"][:1]] or [
|
||||||
|
"knowledge.policy.default"
|
||||||
|
]
|
||||||
|
return {
|
||||||
|
"message": f"已路由到 User Agent,占位知识结果:建议先查看 {', '.join(referenced)}。",
|
||||||
|
"references": referenced,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_rule_answer(ontology: OntologyParseResult) -> dict[str, Any]:
|
||||||
|
risk_text = (
|
||||||
|
"、".join(ontology.risk_flags)
|
||||||
|
if ontology.risk_flags
|
||||||
|
else "未识别到明确风险标签"
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"message": f"已完成占位规则检查,风险标签:{risk_text}。",
|
||||||
|
"risk_flags": ontology.risk_flags,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_mcp_answer(
|
||||||
|
task_asset: AgentAssetRead | None,
|
||||||
|
ontology: OntologyParseResult,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"message": (
|
||||||
|
f"已调用占位 MCP 快照,任务={task_asset.code if task_asset else 'none'},"
|
||||||
|
f"scenario={ontology.scenario}。"
|
||||||
|
),
|
||||||
|
"snapshot": "stubbed",
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_hermes_message(
|
||||||
|
*,
|
||||||
|
task_asset: AgentAssetRead | None,
|
||||||
|
ontology: OntologyParseResult,
|
||||||
|
rule_response: dict[str, Any],
|
||||||
|
mcp_response: dict[str, Any],
|
||||||
|
degraded: bool,
|
||||||
|
) -> str:
|
||||||
|
task_code = task_asset.code if task_asset is not None else "task.unspecified"
|
||||||
|
suffix = ",其中部分能力已降级。" if degraded else "。"
|
||||||
|
return (
|
||||||
|
f"Hermes 占位执行完成:任务 {task_code},"
|
||||||
|
f"场景 {ontology.scenario},规则结果={rule_response.get('message', '')},"
|
||||||
|
f"MCP 结果={mcp_response.get('message', '')}{suffix}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _database_tool_name(scenario: str) -> str:
|
||||||
|
if scenario == "expense":
|
||||||
|
return "database.expense_claims.lookup"
|
||||||
|
if scenario == "accounts_receivable":
|
||||||
|
return "database.accounts_receivable.lookup"
|
||||||
|
return "database.accounts_payable.lookup"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _rule_tool_name(
|
||||||
|
capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]],
|
||||||
|
) -> str:
|
||||||
|
if capabilities["rules"]:
|
||||||
|
return capabilities["rules"][0].code
|
||||||
|
return "rule_engine.default_risk_check"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _mcp_tool_name(
|
||||||
|
capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]],
|
||||||
|
) -> str:
|
||||||
|
if capabilities["mcps"]:
|
||||||
|
return capabilities["mcps"][0].code
|
||||||
|
return "mcp.default_snapshot"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_ontology_json(ontology: OntologyParseResult) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"scenario": ontology.scenario,
|
||||||
|
"intent": ontology.intent,
|
||||||
|
"entities": [item.model_dump() for item in ontology.entities],
|
||||||
|
"time_range": ontology.time_range.model_dump(),
|
||||||
|
"metrics": [item.model_dump() for item in ontology.metrics],
|
||||||
|
"constraints": [item.model_dump() for item in ontology.constraints],
|
||||||
|
"risk_flags": ontology.risk_flags,
|
||||||
|
"permission": ontology.permission.model_dump(),
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
def _normalize_response_status(status: str) -> str:
    """Collapse a run status into ``"failed"``, ``"blocked"`` or ``"succeeded"``.

    Every status other than FAILED or BLOCKED is reported as ``"succeeded"``.
    """
    terminal_labels = {
        AgentRunStatus.FAILED.value: "failed",
        AgentRunStatus.BLOCKED.value: "blocked",
    }
    return terminal_labels.get(status, "succeeded")
|
||||||
252
server/src/app/services/runtime_chat.py
Normal file
252
server/src/app/services/runtime_chat.py
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from http import HTTPStatus
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.services.model_connectivity import (
|
||||||
|
AZURE_API_VERSION,
|
||||||
|
ConnectivityCheckError,
|
||||||
|
_build_azure_deployment_base,
|
||||||
|
_build_headers,
|
||||||
|
_ensure_path,
|
||||||
|
_normalize_endpoint,
|
||||||
|
_send_json_request,
|
||||||
|
)
|
||||||
|
from app.services.settings import SettingsService
|
||||||
|
|
||||||
|
logger = get_logger("app.services.runtime_chat")
|
||||||
|
|
||||||
|
|
||||||
|
class RuntimeChatService:
    """Send chat messages to the configured runtime model slots and return the reply text.

    Slots are tried in priority order. A slot is skipped when it is not
    configured for chat, is missing provider/endpoint/model, or (for every
    provider except Ollama) has no API key.
    """

    def __init__(self, db: Session) -> None:
        """Keep the session and create the settings service used to read slot configs."""
        self.db = db
        self.settings_service = SettingsService(db)

    def complete(
        self,
        messages: list[dict[str, str]],
        *,
        slot_priority: tuple[str, ...] = ("main", "backup"),
        max_tokens: int = 500,
        temperature: float = 0.2,
    ) -> str | None:
        """Return the first non-empty reply from the slots in ``slot_priority``.

        Args:
            messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
            slot_priority: Slot names to try, in order.
            max_tokens: Upper bound on generated tokens passed to the provider.
            temperature: Sampling temperature passed to the provider.

        Returns:
            The stripped reply text, or ``None`` when no slot is usable or
            every slot failed or answered with empty text.
        """
        for slot in slot_priority:
            config = self._load_chat_slot(slot)
            if config is None:
                continue

            try:
                response_text = self._request_chat_completion(
                    config,
                    messages,
                    max_tokens=max_tokens,
                    temperature=temperature,
                )
            except Exception as exc:
                # Deliberate best-effort boundary: a failing slot is logged and
                # the next slot is tried instead of failing the caller.
                logger.warning(
                    "Runtime chat request failed slot=%s provider=%s: %s",
                    slot,
                    config["provider"],
                    exc,
                )
                continue

            # Strip before testing so a whitespace-only reply falls through to
            # the next slot instead of being returned as "".
            text = (response_text or "").strip()
            if text:
                return text

        return None

    def _load_chat_slot(self, slot: str) -> dict[str, str] | None:
        """Load and validate the config of one slot; return ``None`` when it is unusable."""
        try:
            config = self.settings_service.get_runtime_model_config(slot)
        except ValueError:
            return None

        if config["capability"] != "chat":
            return None

        provider = str(config["provider"] or "").strip()
        endpoint = str(config["endpoint"] or "").strip()
        model = str(config["model"] or "").strip()
        api_key = str(config["apiKey"] or "").strip()

        if not provider or not endpoint or not model:
            return None

        # Ollama is the only provider accepted without an API key.
        if provider != "Ollama" and not api_key:
            logger.info("Skip runtime chat slot=%s because api key is empty", slot)
            return None

        return {
            "slot": slot,
            "provider": provider,
            "endpoint": endpoint,
            "model": model,
            "apiKey": api_key,
        }

    def _request_chat_completion(
        self,
        config: dict[str, str],
        messages: list[dict[str, str]],
        *,
        max_tokens: int,
        temperature: float,
    ) -> str:
        """Dispatch to the provider-specific request method.

        ``"Azure OpenAI"`` and ``"Ollama"`` have dedicated paths; every other
        provider is treated as OpenAI-compatible.
        """
        provider = config["provider"]
        request_kwargs = {
            "endpoint": config["endpoint"],
            "model": config["model"],
            "api_key": config["apiKey"],
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }

        if provider == "Azure OpenAI":
            return self._request_azure_openai(**request_kwargs)
        if provider == "Ollama":
            return self._request_ollama(**request_kwargs)
        return self._request_openai_compatible(**request_kwargs)

    def _request_openai_compatible(
        self,
        *,
        endpoint: str,
        model: str,
        api_key: str,
        messages: list[dict[str, str]],
        max_tokens: int,
        temperature: float,
    ) -> str:
        """POST to ``<endpoint>/chat/completions`` with a bearer key and return the reply text.

        Raises:
            ConnectivityCheckError: When the response status is 400 or above.
        """
        url = _ensure_path(_normalize_endpoint(endpoint), "chat/completions")
        status_code, payload = _send_json_request(
            "POST",
            url,
            headers=_build_headers(api_key=api_key, use_bearer=True),
            payload={
                "model": model,
                "messages": messages,
                "max_tokens": max_tokens,
                "temperature": temperature,
            },
        )
        if status_code >= HTTPStatus.BAD_REQUEST:
            raise ConnectivityCheckError(
                f"模型接口返回异常状态 {status_code}。",
                status_code=status_code,
            )
        return self._extract_openai_text(payload)

    def _request_ollama(
        self,
        *,
        endpoint: str,
        model: str,
        api_key: str,
        messages: list[dict[str, str]],
        max_tokens: int,
        temperature: float,
    ) -> str:
        """POST to ``<endpoint>/api/chat`` (non-streaming) and return the reply text.

        Raises:
            ConnectivityCheckError: When the response status is 400 or above.
        """
        url = _ensure_path(_normalize_endpoint(endpoint), "api/chat")
        status_code, payload = _send_json_request(
            "POST",
            url,
            headers=_build_headers(api_key=api_key, use_bearer=False),
            payload={
                "model": model,
                "messages": messages,
                "stream": False,
                "options": {
                    "num_predict": max_tokens,
                    "temperature": temperature,
                },
            },
        )
        if status_code >= HTTPStatus.BAD_REQUEST:
            raise ConnectivityCheckError(
                f"Ollama 返回异常状态 {status_code}。",
                status_code=status_code,
            )
        return self._extract_ollama_text(payload)

    def _request_azure_openai(
        self,
        *,
        endpoint: str,
        model: str,
        api_key: str,
        messages: list[dict[str, str]],
        max_tokens: int,
        temperature: float,
    ) -> str:
        """POST to the Azure deployment's ``chat/completions`` URL and return the reply text.

        ``model`` is passed to ``_build_azure_deployment_base`` together with
        the endpoint; it is not sent in the request body.

        Raises:
            ConnectivityCheckError: When the response status is 400 or above.
        """
        deployment_base = _build_azure_deployment_base(endpoint, model)
        url = f"{deployment_base}/chat/completions?api-version={AZURE_API_VERSION}"
        status_code, payload = _send_json_request(
            "POST",
            url,
            headers=_build_headers(api_key=api_key, use_bearer=False, use_api_key=True),
            payload={
                "messages": messages,
                "max_tokens": max_tokens,
                "temperature": temperature,
            },
        )
        if status_code >= HTTPStatus.BAD_REQUEST:
            raise ConnectivityCheckError(
                f"Azure OpenAI 返回异常状态 {status_code}。",
                status_code=status_code,
            )
        return self._extract_openai_text(payload)

    @staticmethod
    def _extract_ollama_text(payload: Any) -> str:
        """Return the stripped ``message.content`` string of an Ollama chat payload.

        Anything that does not have the expected shape (non-dict payload,
        non-dict ``message``, non-string ``content``) yields ``""``. This
        avoids turning a JSON ``null`` content into the literal text ``"None"``.
        """
        if not isinstance(payload, dict):
            return ""
        message = payload.get("message")
        if not isinstance(message, dict):
            return ""
        content = message.get("content")
        if not isinstance(content, str):
            return ""
        return content.strip()

    @staticmethod
    def _extract_openai_text(payload: Any) -> str:
        """Return the reply text of the first choice in an OpenAI-style payload.

        ``message.content`` is used when it is a string, or the joined
        ``"text"`` parts when it is a list. Otherwise the choice's legacy
        ``text`` field is used. Unexpected shapes yield ``""``.
        """
        if not isinstance(payload, dict):
            return ""

        choices = payload.get("choices")
        if not isinstance(choices, list) or not choices:
            return ""

        first_choice = choices[0]
        if not isinstance(first_choice, dict):
            return ""

        message = first_choice.get("message")
        if isinstance(message, dict):
            content = message.get("content", "")
            if isinstance(content, str):
                return content.strip()
            if isinstance(content, list):
                parts = [
                    str(item.get("text", ""))
                    for item in content
                    if isinstance(item, dict) and item.get("type") == "text"
                ]
                return "\n".join(part.strip() for part in parts if part.strip()).strip()

        text = first_choice.get("text")
        if isinstance(text, str):
            return text.strip()

        return ""
|
||||||
547
server/src/app/services/user_agent.py
Normal file
547
server/src/app/services/user_agent.py
Normal file
@@ -0,0 +1,547 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.agent_enums import AgentAssetStatus, AgentAssetType
|
||||||
|
from app.schemas.agent_asset import AgentAssetListItem
|
||||||
|
from app.schemas.user_agent import (
|
||||||
|
UserAgentCitation,
|
||||||
|
UserAgentDraftPayload,
|
||||||
|
UserAgentRequest,
|
||||||
|
UserAgentResponse,
|
||||||
|
UserAgentSuggestedAction,
|
||||||
|
)
|
||||||
|
from app.services.agent_assets import AgentAssetService
|
||||||
|
from app.services.agent_foundation import AgentFoundationService
|
||||||
|
from app.services.runtime_chat import RuntimeChatService
|
||||||
|
|
||||||
|
# Scenario code -> Chinese display label used in user-facing text.
SCENARIO_LABELS: dict[str, str] = {
    "expense": "报销",
    "accounts_receivable": "应收",
    "accounts_payable": "应付",
    "knowledge": "知识",
    "unknown": "通用",
}

# Risk code -> user-facing explanation of the risk.
# NOTE(review): presumably keyed by ontology risk flags; confirm against the lookup site.
RISK_REASON_MAP: dict[str, str] = {
    "duplicate_expense": "检测到同员工、同金额或近似单据存在重复提交迹象。",
    "amount_over_limit": "金额超过当前制度或预算阈值,需要补充例外说明。",
    "invoice_anomaly": "票据或附件完整性不满足当前规则要求,需要补件或人工复核。",
    "ar_overdue": "应收账款已出现逾期,存在回款延迟风险。",
    "ap_overdue": "应付付款已出现逾期,可能影响供应商履约或合作关系。",
}

# Short expense prompts that carry no detail.
# NOTE(review): presumably matched against the whole user message; confirm in the matcher.
GENERIC_EXPENSE_PROMPTS: set[str] = {
    "报销",
    "我要报销",
    "我想报销",
    "帮我报销",
    "我要申请报销",
    "发起报销",
    "提交报销",
}

# Keywords signalling an explicit request to create a draft.
EXPLICIT_DRAFT_KEYWORDS: tuple[str, ...] = (
    "生成",
    "草稿",
    "起草",
    "创建",
    "发起",
    "准备",
)

# Normalised expense-type value -> Chinese display label.
EXPENSE_TYPE_LABELS: dict[str, str] = {
    "travel": "差旅",
    "hotel": "住宿",
    "transport": "交通",
    "meal": "餐费",
    "meeting": "会务",
    "entertainment": "招待",
}
|
||||||
|
|
||||||
|
|
||||||
|
class UserAgentService:
|
||||||
|
def __init__(self, db: Session) -> None:
    """Keep the session and build the asset and runtime-chat services on top of it."""
    self.db = db
    self.asset_service = AgentAssetService(db)
    self.runtime_chat_service = RuntimeChatService(db)
|
||||||
|
|
||||||
|
def respond(self, payload: UserAgentRequest) -> UserAgentResponse:
    """Produce the User Agent reply for one request.

    The answer text is chosen in this order:

    1. the tool's own message, when the request is degraded and has one
       (this response carries no draft payload);
    2. a guided prompt from ``_build_guided_answer``, when it returns one;
    3. the model-generated answer, falling back to the rule-based answer.

    Citations, suggested actions, risk flags and the confirmation flag are
    attached in every case.
    """
    AgentFoundationService(self.db).ensure_foundation_ready()

    citations = self._build_rule_citations(payload)
    suggested_actions = self._build_suggested_actions(payload)
    risk_flags = self._resolve_risk_flags(payload)

    draft_payload = None
    if payload.ontology.intent == "draft":
        draft_payload = self._build_draft_payload(payload)

    # Fields shared by every response variant below.
    shared_fields = dict(
        citations=citations,
        suggested_actions=suggested_actions,
        risk_flags=risk_flags,
        requires_confirmation=payload.requires_confirmation,
    )

    if payload.degraded and payload.tool_payload.get("message"):
        return UserAgentResponse(
            answer=str(payload.tool_payload["message"]),
            **shared_fields,
        )

    guided_answer = self._build_guided_answer(payload)
    if guided_answer:
        return UserAgentResponse(
            answer=guided_answer,
            draft_payload=draft_payload,
            **shared_fields,
        )

    fallback_answer = self._build_fallback_answer(
        payload,
        citations=citations,
        draft_payload=draft_payload,
    )
    model_answer = self._generate_answer_with_model(
        payload,
        citations=citations,
        suggested_actions=suggested_actions,
        risk_flags=risk_flags,
        draft_payload=draft_payload,
        fallback_answer=fallback_answer,
    )
    return UserAgentResponse(
        answer=model_answer or fallback_answer,
        draft_payload=draft_payload,
        **shared_fields,
    )
|
||||||
|
|
||||||
|
def _build_fallback_answer(
|
||||||
|
self,
|
||||||
|
payload: UserAgentRequest,
|
||||||
|
*,
|
||||||
|
citations: list[UserAgentCitation],
|
||||||
|
draft_payload: UserAgentDraftPayload | None,
|
||||||
|
) -> str:
|
||||||
|
if payload.ontology.intent in {"query", "compare"}:
|
||||||
|
return self._build_query_answer(payload)
|
||||||
|
|
||||||
|
if payload.ontology.intent == "risk_check":
|
||||||
|
return self._build_risk_answer(payload, citations)
|
||||||
|
|
||||||
|
if payload.ontology.intent == "draft" and draft_payload is not None:
|
||||||
|
return (
|
||||||
|
f"已生成 {draft_payload.title},当前仅返回待人工确认的草稿内容,"
|
||||||
|
"仍需人工确认后再进入正式流程。"
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._build_explain_answer(payload, citations)
|
||||||
|
|
||||||
|
def _build_guided_answer(self, payload: UserAgentRequest) -> str | None:
|
||||||
|
if not self._is_generic_expense_prompt(payload):
|
||||||
|
return self._build_implicit_expense_draft_guidance(payload)
|
||||||
|
|
||||||
|
attachment_names = self._resolve_attachment_names(payload)
|
||||||
|
attachment_hint = ""
|
||||||
|
if attachment_names:
|
||||||
|
attachment_hint = (
|
||||||
|
f" 我已带入 {len(attachment_names)} 份附件名称,但目前还不能直接读取附件内容,"
|
||||||
|
"仍需要你补充关键信息。"
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
"可以帮你发起报销。请补充费用类型、发生时间、金额、事由和相关对象,"
|
||||||
|
"或者直接上传票据附件,我再继续帮你判断能否报、缺什么材料以及生成报销草稿。"
|
||||||
|
f"{attachment_hint}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_implicit_expense_draft_guidance(
|
||||||
|
self,
|
||||||
|
payload: UserAgentRequest,
|
||||||
|
) -> str | None:
|
||||||
|
if not self._is_implicit_expense_draft_request(payload):
|
||||||
|
return None
|
||||||
|
|
||||||
|
amount_text = next(
|
||||||
|
(item.value for item in payload.ontology.entities if item.type == "amount"),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
expense_type = next(
|
||||||
|
(
|
||||||
|
EXPENSE_TYPE_LABELS.get(item.normalized_value, item.value)
|
||||||
|
for item in payload.ontology.entities
|
||||||
|
if item.type == "expense_type"
|
||||||
|
),
|
||||||
|
"报销",
|
||||||
|
)
|
||||||
|
time_text = payload.ontology.time_range.raw or "本次"
|
||||||
|
amount_hint = f",金额 {amount_text}" if amount_text else ""
|
||||||
|
|
||||||
|
return (
|
||||||
|
f"已识别到一笔{time_text}的{expense_type}支出{amount_hint}。"
|
||||||
|
"如果要继续生成报销草稿,还需要补充客户单位、参与人员、费用明细和票据附件。"
|
||||||
|
"你也可以继续上传发票或图片,我会把这些信息带入后续对话。"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _generate_answer_with_model(
|
||||||
|
self,
|
||||||
|
payload: UserAgentRequest,
|
||||||
|
*,
|
||||||
|
citations: list[UserAgentCitation],
|
||||||
|
suggested_actions: list[UserAgentSuggestedAction],
|
||||||
|
risk_flags: list[str],
|
||||||
|
draft_payload: UserAgentDraftPayload | None,
|
||||||
|
fallback_answer: str,
|
||||||
|
) -> str | None:
|
||||||
|
messages = self._build_model_messages(
|
||||||
|
payload,
|
||||||
|
citations=citations,
|
||||||
|
suggested_actions=suggested_actions,
|
||||||
|
risk_flags=risk_flags,
|
||||||
|
draft_payload=draft_payload,
|
||||||
|
fallback_answer=fallback_answer,
|
||||||
|
)
|
||||||
|
return self._sanitize_model_answer(
|
||||||
|
self.runtime_chat_service.complete(
|
||||||
|
messages,
|
||||||
|
max_tokens=420,
|
||||||
|
temperature=0.2,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _sanitize_model_answer(self, answer: str | None) -> str | None:
|
||||||
|
if not answer:
|
||||||
|
return None
|
||||||
|
|
||||||
|
cleaned = re.sub(r"<think>.*?</think>", "", answer, flags=re.DOTALL | re.IGNORECASE)
|
||||||
|
cleaned = cleaned.strip()
|
||||||
|
return cleaned or None
|
||||||
|
|
||||||
|
def _build_model_messages(
|
||||||
|
self,
|
||||||
|
payload: UserAgentRequest,
|
||||||
|
*,
|
||||||
|
citations: list[UserAgentCitation],
|
||||||
|
suggested_actions: list[UserAgentSuggestedAction],
|
||||||
|
risk_flags: list[str],
|
||||||
|
draft_payload: UserAgentDraftPayload | None,
|
||||||
|
fallback_answer: str,
|
||||||
|
) -> list[dict[str, str]]:
|
||||||
|
facts = {
|
||||||
|
"run_id": payload.run_id,
|
||||||
|
"user_message": payload.message,
|
||||||
|
"ontology": payload.ontology.model_dump(mode="json"),
|
||||||
|
"context": {
|
||||||
|
"entry_source": payload.context_json.get("entry_source"),
|
||||||
|
"user_name": payload.context_json.get("name"),
|
||||||
|
"user_role": payload.context_json.get("role"),
|
||||||
|
"request_context": payload.context_json.get("request_context"),
|
||||||
|
"attachment_count": payload.context_json.get("attachment_count"),
|
||||||
|
"attachment_names": self._resolve_attachment_names(payload),
|
||||||
|
},
|
||||||
|
"tool_payload": payload.tool_payload,
|
||||||
|
"citations": [item.model_dump(mode="json") for item in citations],
|
||||||
|
"suggested_actions": [
|
||||||
|
item.model_dump(mode="json") for item in suggested_actions
|
||||||
|
],
|
||||||
|
"risk_flags": risk_flags,
|
||||||
|
"draft_payload": (
|
||||||
|
draft_payload.model_dump(mode="json")
|
||||||
|
if draft_payload is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
"selected_capability_codes": payload.selected_capability_codes,
|
||||||
|
"requires_confirmation": payload.requires_confirmation,
|
||||||
|
"fallback_answer": fallback_answer,
|
||||||
|
}
|
||||||
|
|
||||||
|
system_prompt = (
|
||||||
|
"你是企业财务共享场景中的中文智能助手,负责和最终用户直接对话。"
|
||||||
|
"你只能基于提供的事实回答,不能编造制度、流程结果或附件内容。"
|
||||||
|
"如果用户问题很笼统,例如“我要报销”,优先告诉用户你可以协助什么,"
|
||||||
|
"并明确要求补充费用类型、金额、时间、事由、参与对象或上传票据。"
|
||||||
|
"如果上下文里只有附件名称,必须明确说明你只拿到了附件名称,"
|
||||||
|
"不能假装已看过图片、PDF 或发票内容。"
|
||||||
|
"不要声称已经提交、审批、付款、入账或真正执行了任何动作;如果只是建议、草稿或待确认,要明确说清楚。"
|
||||||
|
"若给出了风险标签、制度引用或建议动作,可以简洁吸收进回答,但不要新增未提供的事实。"
|
||||||
|
"只输出最终给用户看的自然语言,不要输出 JSON、Markdown、标题、"
|
||||||
|
"<think> 标签或任何中间推理。"
|
||||||
|
"使用简体中文,控制在 2 到 4 句。"
|
||||||
|
)
|
||||||
|
user_prompt = (
|
||||||
|
"请根据以下事实生成最终答复,优先保持准确、具体、可执行:\n"
|
||||||
|
f"{json.dumps(facts, ensure_ascii=False, indent=2)}"
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
|
{"role": "user", "content": user_prompt},
|
||||||
|
]
|
||||||
|
|
||||||
|
def _build_query_answer(self, payload: UserAgentRequest) -> str:
|
||||||
|
scenario = payload.ontology.scenario
|
||||||
|
data = payload.tool_payload
|
||||||
|
subject = self._resolve_subject(payload)
|
||||||
|
|
||||||
|
if scenario == "expense":
|
||||||
|
record_count = int(data.get("record_count") or 0)
|
||||||
|
total_amount = float(data.get("total_amount") or 0)
|
||||||
|
return (
|
||||||
|
f"{subject}共命中 {record_count} 笔报销,金额合计 {total_amount:.2f} 元。"
|
||||||
|
"如需继续处理,可以查看明细或生成处理意见草稿。"
|
||||||
|
)
|
||||||
|
|
||||||
|
if scenario == "accounts_receivable":
|
||||||
|
record_count = int(data.get("record_count") or 0)
|
||||||
|
outstanding_amount = float(data.get("outstanding_amount") or 0)
|
||||||
|
return (
|
||||||
|
f"{subject}共命中 {record_count} 条应收,未回款金额 {outstanding_amount:.2f} 元。"
|
||||||
|
"建议结合账龄和客户分布继续排查逾期风险。"
|
||||||
|
)
|
||||||
|
|
||||||
|
if scenario == "accounts_payable":
|
||||||
|
record_count = int(data.get("record_count") or 0)
|
||||||
|
outstanding_amount = float(data.get("outstanding_amount") or 0)
|
||||||
|
return (
|
||||||
|
f"{subject}共命中 {record_count} 条应付,待付金额 {outstanding_amount:.2f} 元。"
|
||||||
|
"如需推进动作,建议先生成付款建议草稿并发起人工确认。"
|
||||||
|
)
|
||||||
|
|
||||||
|
return "已完成当前查询,但暂时没有更多结构化结果可展示。"
|
||||||
|
|
||||||
|
def _build_explain_answer(
|
||||||
|
self,
|
||||||
|
payload: UserAgentRequest,
|
||||||
|
citations: list[UserAgentCitation],
|
||||||
|
) -> str:
|
||||||
|
if citations:
|
||||||
|
titles = "、".join(item.title for item in citations[:2])
|
||||||
|
summary = citations[0].excerpt or "请结合制度全文进一步确认。"
|
||||||
|
return f"已检索到相关依据:{titles}。核心说明:{summary}"
|
||||||
|
|
||||||
|
return (
|
||||||
|
f"当前还没有与“{SCENARIO_LABELS.get(payload.ontology.scenario, '当前问题')}”"
|
||||||
|
"强匹配的已上线规则引用,建议先人工复核或补充更具体的单据上下文。"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_risk_answer(
|
||||||
|
self,
|
||||||
|
payload: UserAgentRequest,
|
||||||
|
citations: list[UserAgentCitation],
|
||||||
|
) -> str:
|
||||||
|
risk_flags = self._resolve_risk_flags(payload)
|
||||||
|
if not risk_flags:
|
||||||
|
return "当前未识别到明确风险标签,建议继续查看原始明细或补充更多上下文。"
|
||||||
|
|
||||||
|
reasons = [RISK_REASON_MAP.get(flag, f"{flag} 需要人工进一步确认。") for flag in risk_flags]
|
||||||
|
citation_text = (
|
||||||
|
f" 参考规则:{'、'.join(item.title for item in citations[:2])}。"
|
||||||
|
if citations
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
f"本次识别到 {len(risk_flags)} 类风险:{'、'.join(risk_flags)}。"
|
||||||
|
f"触发原因:{';'.join(reasons)}。"
|
||||||
|
"建议先复核明细、附件和审批链,再决定是否继续处理。"
|
||||||
|
f"{citation_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_draft_payload(self, payload: UserAgentRequest) -> UserAgentDraftPayload:
    """Create a draft payload that always requires human confirmation.

    The title is derived from the scenario label; the body is four fixed
    lines (subject, conclusion, advice, original question).
    """
    label = SCENARIO_LABELS.get(payload.ontology.scenario, "业务")
    body_lines = [
        f"主题:{self._resolve_subject(payload)}",
        "结论:已根据当前语义解析结果生成草稿,尚未自动执行。",
        "建议:请先核对明细、规则命中和所需附件,再由人工确认是否提交正式流程。",
        f"原始问题:{payload.message}",
    ]
    return UserAgentDraftPayload(
        draft_type=payload.ontology.scenario,
        title=f"{label}处理意见草稿",
        body="\n".join(body_lines),
        confirmation_required=True,
    )
|
||||||
|
|
||||||
|
def _build_suggested_actions(
    self,
    payload: UserAgentRequest,
) -> list[UserAgentSuggestedAction]:
    """Pick the two follow-up actions offered alongside the answer.

    A generic expense prompt takes precedence over the intent; otherwise the
    pair is chosen by ``payload.ontology.intent``, with a rule/clarification
    pair as the default. Each spec below is (label, action_type, description).
    """
    if self._is_generic_expense_prompt(payload):
        specs = [
            ("上传票据", "ask_clarification", "上传发票、行程单或付款截图,继续识别报销内容。"),
            ("补充报销信息", "ask_clarification", "补充费用类型、金额、时间和事由后继续处理。"),
        ]
    elif payload.ontology.intent in {"query", "compare"}:
        specs = [
            ("查看明细", "open_detail", "继续查看命中记录和过滤条件。"),
            ("生成处理意见", "create_draft", "把当前查询结果整理成可确认草稿。"),
        ]
    elif payload.ontology.intent == "risk_check":
        specs = [
            ("人工复核风险", "manual_review", "优先检查明细、附件和规则命中原因。"),
            ("生成整改建议", "create_draft", "把风险说明整理成处理意见草稿。"),
        ]
    elif payload.ontology.intent == "draft":
        specs = [
            ("复制草稿", "copy_draft", "复制当前草稿后交由人工确认。"),
            ("补充上下文", "ask_clarification", "补充单据编号、客户或供应商信息以完善草稿。"),
        ]
    else:
        specs = [
            ("查看规则全文", "open_rule", "继续查看引用规则或知识内容。"),
            ("补充问题上下文", "ask_clarification", "补充业务对象、时间或单据范围,提升回答准确度。"),
        ]

    return [
        UserAgentSuggestedAction(
            label=label,
            action_type=action_type,
            description=description,
        )
        for label, action_type, description in specs
    ]
|
||||||
|
|
||||||
|
def _build_rule_citations(self, payload: UserAgentRequest) -> list[UserAgentCitation]:
    """Turn the best-matching active rule assets into at most two citations.

    Active rule assets for the scenario's domain are listed and ranked; the
    top two are then loaded in full. Assets whose detail lookup returns
    ``None`` are skipped, so fewer than two citations may come back.
    """
    domain = self._resolve_domain(payload.ontology.scenario)
    candidates = self.asset_service.list_assets(
        asset_type=AgentAssetType.RULE.value,
        status=AgentAssetStatus.ACTIVE.value,
        domain=domain,
    )
    top_ranked = self._rank_rule_assets(candidates, payload)[:2]

    # Lazy generator: each detail is fetched right before its citation is built.
    details = (self.asset_service.get_asset(asset.id) for asset in top_ranked)
    return [
        UserAgentCitation(
            source_type="rule",
            code=rule.code,
            title=rule.name,
            version=rule.current_version,
            updated_at=rule.updated_at.date().isoformat(),
            excerpt=self._extract_excerpt(str(rule.current_version_content or "")),
        )
        for rule in details
        if rule is not None
    ]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _resolve_risk_flags(payload: UserAgentRequest) -> list[str]:
|
||||||
|
tool_flags = payload.tool_payload.get("risk_flags")
|
||||||
|
if isinstance(tool_flags, list) and tool_flags:
|
||||||
|
return [str(item) for item in tool_flags]
|
||||||
|
return [str(item) for item in payload.ontology.risk_flags]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _resolve_subject(payload: UserAgentRequest) -> str:
|
||||||
|
named_entities = [
|
||||||
|
item.value
|
||||||
|
for item in payload.ontology.entities
|
||||||
|
if item.type in {"employee", "customer", "vendor", "project"}
|
||||||
|
]
|
||||||
|
if named_entities:
|
||||||
|
return f"{'、'.join(named_entities)} 相关数据"
|
||||||
|
return f"{SCENARIO_LABELS.get(payload.ontology.scenario, '当前')}场景数据"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_generic_expense_prompt(payload: UserAgentRequest) -> bool:
|
||||||
|
if payload.ontology.scenario != "expense":
|
||||||
|
return False
|
||||||
|
normalized_message = re.sub(r"\s+", "", payload.message)
|
||||||
|
return normalized_message in GENERIC_EXPENSE_PROMPTS
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_implicit_expense_draft_request(payload: UserAgentRequest) -> bool:
|
||||||
|
if payload.ontology.scenario != "expense" or payload.ontology.intent != "draft":
|
||||||
|
return False
|
||||||
|
|
||||||
|
compact_message = re.sub(r"\s+", "", payload.message)
|
||||||
|
if any(keyword in compact_message for keyword in EXPLICIT_DRAFT_KEYWORDS):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _resolve_attachment_names(payload: UserAgentRequest) -> list[str]:
|
||||||
|
names = payload.context_json.get("attachment_names")
|
||||||
|
if not isinstance(names, list):
|
||||||
|
return []
|
||||||
|
return [str(name) for name in names if str(name).strip()]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _resolve_domain(scenario: str) -> str | None:
|
||||||
|
if scenario == "expense":
|
||||||
|
return "expense"
|
||||||
|
if scenario == "accounts_receivable":
|
||||||
|
return "ar"
|
||||||
|
if scenario == "accounts_payable":
|
||||||
|
return "ap"
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _rank_rule_assets(
|
||||||
|
items: list[AgentAssetListItem],
|
||||||
|
payload: UserAgentRequest,
|
||||||
|
) -> list[AgentAssetListItem]:
|
||||||
|
def score(item: AgentAssetListItem) -> tuple[int, str]:
|
||||||
|
tags = {str(value) for value in item.scenario_json or []}
|
||||||
|
weight = 0
|
||||||
|
if payload.ontology.scenario in tags:
|
||||||
|
weight += 3
|
||||||
|
if payload.ontology.intent in tags:
|
||||||
|
weight += 2
|
||||||
|
for risk_flag in payload.ontology.risk_flags:
|
||||||
|
if risk_flag in tags:
|
||||||
|
weight += 4
|
||||||
|
return weight, item.code
|
||||||
|
|
||||||
|
ranked = sorted(items, key=score, reverse=True)
|
||||||
|
return [item for item in ranked if score(item)[0] > 0]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_excerpt(content: str) -> str:
|
||||||
|
lines = [line.strip() for line in str(content).splitlines() if line.strip()]
|
||||||
|
cleaned: list[str] = []
|
||||||
|
for line in lines:
|
||||||
|
normalized = re.sub(r"^[#>\-\*\d\.\s`]+", "", line).strip()
|
||||||
|
if normalized:
|
||||||
|
cleaned.append(normalized)
|
||||||
|
if len(cleaned) >= 2:
|
||||||
|
break
|
||||||
|
return ";".join(cleaned[:2])
|
||||||
397
server/tests/test_ontology_service.py
Normal file
397
server/tests/test_ontology_service.py
Normal file
@@ -0,0 +1,397 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
|
from app.api.deps import get_db
|
||||||
|
from app.db.base import Base
|
||||||
|
from app.main import create_app
|
||||||
|
from app.schemas.ontology import OntologyParseRequest
|
||||||
|
from app.services.ontology import LlmOntologyParseResult, SemanticOntologyService
|
||||||
|
|
||||||
|
|
||||||
|
def build_session_factory() -> sessionmaker[Session]:
    """Create a session factory bound to a fresh in-memory SQLite database.

    All tables known to ``Base.metadata`` are created before the factory is
    returned.
    """
    memory_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=memory_engine)
    factory = sessionmaker(bind=memory_engine, autoflush=False, autocommit=False)
    return factory
|
||||||
|
|
||||||
|
|
||||||
|
def build_client() -> tuple[TestClient, sessionmaker[Session]]:
    """Return a TestClient whose ``get_db`` dependency uses an in-memory database.

    The session factory is returned as well so tests can inspect the same
    database the application writes to.
    """
    factory = build_session_factory()

    def override_db() -> Generator[Session, None, None]:
        # One session per request, always closed afterwards.
        session = factory()
        try:
            yield session
        finally:
            session.close()

    application = create_app()
    application.dependency_overrides[get_db] = override_db
    return TestClient(application), factory
|
||||||
|
|
||||||
|
|
||||||
|
# Day-3 evaluation set. Each row is:
# (case id, query, expected scenario, expected intent, expected permission level, context_json)
EVALUATION_CASES = [
    pytest.param(query, scenario, intent, permission, context_json, id=case_id)
    for case_id, query, scenario, intent, permission, context_json in (
        ("expense-risk-check", "查一下本周报销超标风险", "expense", "risk_check", "read", {}),
        ("expense-query-employee-month", "张三 4 月差旅报销金额是多少", "expense", "query", "read", {}),
        ("expense-explain-policy", "为什么酒店超标报销不能直接通过", "expense", "explain", "read", {}),
        ("expense-topn-query", "列出金额最高的10笔报销", "expense", "query", "read", {}),
        ("expense-draft", "帮我生成张三4月差旅报销草稿", "expense", "draft", "draft_write", {}),
        (
            "expense-narrative-draft",
            "我今天去客户现场,招待了客户,花销了1000元",
            "expense",
            "draft",
            "draft_write",
            {},
        ),
        ("ar-query-customer-month", "客户 A 这个月还有多少应收", "accounts_receivable", "query", "read", {}),
        ("ar-compare-customers", "对比客户A和客户B本月应收差异", "accounts_receivable", "compare", "read", {}),
        ("ar-risk-check", "检查客户B逾期应收风险", "accounts_receivable", "risk_check", "read", {}),
        ("ar-draft", "生成客户A回款跟进草稿", "accounts_receivable", "draft", "draft_write", {}),
        ("ar-aging-query", "查询客户B账龄明细", "accounts_receivable", "query", "read", {}),
        ("ap-query-vendor-tomorrow", "供应商 B 明天要付多少钱", "accounts_payable", "query", "read", {}),
        ("ap-compare-vendors", "对比供应商A和供应商B本月应付差异", "accounts_payable", "compare", "read", {}),
        ("ap-risk-check", "检查供应商B逾期付款风险", "accounts_payable", "risk_check", "read", {}),
        ("ap-draft", "生成供应商A付款沟通草稿", "accounts_payable", "draft", "draft_write", {}),
        (
            "ap-operate-approval-required",
            "帮我安排付款给供应商B",
            "accounts_payable",
            "operate",
            "approval_required",
            {"role_codes": ["finance"]},
        ),
        ("knowledge-query", "公司财务制度在哪里看", "knowledge", "query", "read", {}),
        ("knowledge-explain", "规则中心的审核依据是什么", "knowledge", "explain", "read", {}),
        ("knowledge-query-library", "知识库里有没有双人复核制度", "knowledge", "query", "read", {}),
        (
            "forbidden-direct-payment",
            "帮我直接付款给供应商B",
            "accounts_payable",
            "operate",
            "forbidden",
            {"role_codes": ["user"]},
        ),
        (
            "forbidden-activate-rule",
            "帮我上线付款双人复核规则",
            "accounts_payable",
            "operate",
            "forbidden",
            {"role_codes": ["user"]},
        ),
        (
            "forbidden-delete-expense",
            "帮我删除今天的报销记录",
            "expense",
            "operate",
            "forbidden",
            {"role_codes": ["user"]},
        ),
    )
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("query,scenario,intent,permission,context_json", EVALUATION_CASES)
def test_semantic_ontology_service_matches_day3_evaluation_set(
    query: str,
    scenario: str,
    intent: str,
    permission: str,
    context_json: dict,
) -> None:
    """Each evaluation query parses to the expected scenario, intent and permission level."""
    with build_session_factory()() as db:
        parsed = SemanticOntologyService(db).parse(
            OntologyParseRequest(query=query, user_id="pytest", context_json=context_json)
        )

        assert parsed.scenario == scenario
        assert parsed.intent == intent
        assert parsed.permission.level == permission
        assert parsed.run_id.startswith("run_")
|
||||||
|
|
||||||
|
|
||||||
|
def test_semantic_ontology_service_extracts_entities_time_and_constraints() -> None:
    """A detailed expense query yields the employee, expense type, month range and amount filter."""

    def has_entity(parsed, entity_type: str, normalized: str) -> bool:
        return any(
            entity.type == entity_type and entity.normalized_value == normalized
            for entity in parsed.entities
        )

    with build_session_factory()() as db:
        parsed = SemanticOntologyService(db).parse(
            OntologyParseRequest(
                query="张三 2026年4月差旅报销金额超过5000元的明细",
                user_id="pytest",
            )
        )

        assert parsed.scenario == "expense"
        assert parsed.intent == "query"
        assert parsed.time_range.start_date == "2026-04-01"
        assert parsed.time_range.end_date == "2026-04-30"
        assert has_entity(parsed, "employee", "张三")
        assert has_entity(parsed, "expense_type", "travel")
        assert any(
            rule.field == "amount" and rule.operator == ">" and rule.value == 5000
            for rule in parsed.constraints
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_semantic_ontology_service_prefers_expense_for_customer_entertainment_narrative() -> None:
    """A narrative about entertaining a customer is parsed as an expense draft needing clarification."""
    narrative = "我今天去客户现场,招待了客户,花销了1000元"

    with build_session_factory()() as db:
        parsed = SemanticOntologyService(db).parse(
            OntologyParseRequest(query=narrative, user_id="pytest")
        )

        assert parsed.scenario == "expense"
        assert parsed.intent == "draft"
        assert parsed.permission.level == "draft_write"
        assert parsed.time_range.raw == "今天"
        assert parsed.clarification_required is True
        for slot in ("customer_name", "participants"):
            assert slot in parsed.missing_slots
        assert any(
            entity.type == "expense_type" and entity.normalized_value == "entertainment"
            for entity in parsed.entities
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_semantic_ontology_service_uses_model_parse_when_available(monkeypatch) -> None:
    """When ``_parse_with_model`` returns a result, the service reports the ``llm_primary`` strategy."""
    question = "请补充费用类型、金额和票据附件。"

    def fake_parse_with_model(**kwargs):
        # Stand-in for the model call; a fresh result is built on every call.
        return LlmOntologyParseResult(
            scenario="expense",
            intent="draft",
            confidence=0.91,
            clarification_required=True,
            clarification_question=question,
            missing_slots=["expense_type", "amount", "attachments"],
            ambiguity=[],
            entity_hints=[],
        )

    with build_session_factory()() as db:
        service = SemanticOntologyService(db)
        monkeypatch.setattr(service, "_parse_with_model", fake_parse_with_model)

        parsed = service.parse(OntologyParseRequest(query="我要报销", user_id="pytest"))

        assert parsed.scenario == "expense"
        assert parsed.intent == "draft"
        assert parsed.parse_strategy == "llm_primary"
        assert parsed.clarification_required is True
        assert "expense_type" in parsed.missing_slots
        assert parsed.clarification_question == question
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_ontology_endpoint_returns_eight_fields_and_writes_trace() -> None:
    """The parse endpoint returns all expected fields and the run can be read back afterwards."""
    expected_keys = {
        "scenario",
        "intent",
        "entities",
        "time_range",
        "metrics",
        "constraints",
        "risk_flags",
        "permission",
        "confidence",
        "missing_slots",
        "ambiguity",
        "parse_strategy",
        "clarification_required",
        "clarification_question",
        "run_id",
        "field_errors",
    }
    client, _ = build_client()

    parse_response = client.post(
        "/api/v1/ontology/parse",
        json={
            "query": "查一下本周报销超标风险",
            "user_id": "pytest",
            "context_json": {"role_codes": ["finance"]},
        },
    )

    assert parse_response.status_code == 200
    parsed = parse_response.json()
    assert parsed["scenario"] == "expense"
    assert parsed["intent"] == "risk_check"
    assert parsed["permission"]["level"] == "read"
    assert parsed["run_id"].startswith("run_")
    assert not (expected_keys - set(parsed))

    run_response = client.get(f"/api/v1/agent-runs/{parsed['run_id']}")

    assert run_response.status_code == 200
    run_detail = run_response.json()
    # Both the stored ontology and the semantic parse log carry the same result.
    for section in ("ontology_json", "semantic_parse"):
        assert run_detail[section]["scenario"] == "expense"
        assert run_detail[section]["intent"] == "risk_check"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_ontology_endpoint_returns_forbidden_for_unprivileged_payment_request() -> None:
    """A direct-payment request from a plain user parses successfully but is marked forbidden."""
    client, _ = build_client()
    request_body = {
        "query": "帮我直接付款给供应商B",
        "user_id": "pytest",
        "context_json": {"role_codes": ["user"]},
    }

    response = client.post("/api/v1/ontology/parse", json=request_body)

    assert response.status_code == 200
    parsed = response.json()
    assert parsed["scenario"] == "accounts_payable"
    assert parsed["intent"] == "operate"
    assert parsed["permission"]["level"] == "forbidden"
    assert parsed["clarification_required"] is True
    assert parsed["field_errors"]
|
||||||
241
server/tests/test_orchestrator_service.py
Normal file
241
server/tests/test_orchestrator_service.py
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
|
from app.api.deps import get_db
|
||||||
|
from app.db.base import Base
|
||||||
|
from app.main import create_app
|
||||||
|
from app.services.agent_assets import AgentAssetService
|
||||||
|
|
||||||
|
|
||||||
|
def build_client() -> tuple[TestClient, sessionmaker[Session]]:
    """Return a TestClient backed by a fresh in-memory SQLite database.

    All tables are created up front, ``get_db`` is overridden to hand out
    sessions from the returned factory, and the factory is returned so tests
    can query the same database.
    """
    memory_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=memory_engine)
    factory = sessionmaker(bind=memory_engine, autoflush=False, autocommit=False)

    def override_db() -> Generator[Session, None, None]:
        # One session per request, always closed afterwards.
        session = factory()
        try:
            yield session
        finally:
            session.close()

    application = create_app()
    application.dependency_overrides[get_db] = override_db
    return TestClient(application), factory
|
||||||
|
|
||||||
|
|
||||||
|
def test_orchestrator_routes_user_query_to_user_agent() -> None:
    """A read-only receivable question is routed to the user agent and traced with a database tool call."""
    client, _ = build_client()
    request_body = {
        "source": "user_message",
        "user_id": "pytest",
        "message": "客户A这个月还有多少应收",
        "context_json": {"role_codes": ["finance"]},
    }

    response = client.post("/api/v1/orchestrator/run", json=request_body)

    assert response.status_code == 200
    body = response.json()
    assert body["selected_agent"] == "user_agent"
    assert body["permission_level"] == "read"
    assert body["status"] == "succeeded"
    result = body["result"]
    assert result["answer"]
    assert result["suggested_actions"]
    assert body["trace_summary"]["tool_count"] >= 1

    run_detail = client.get(f"/api/v1/agent-runs/{body['run_id']}").json()
    assert run_detail["agent"] == "user_agent"
    assert run_detail["route_json"]["selected_agent"] == "user_agent"
    assert run_detail["semantic_parse"]["scenario"] == "accounts_receivable"
    assert run_detail["tool_calls"][0]["tool_type"] == "database"
|
||||||
|
|
||||||
|
|
||||||
|
def test_orchestrator_routes_schedule_to_hermes() -> None:
    """A scheduled run of the daily risk-scan task is routed to hermes with two tool calls."""
    client, session_factory = build_client()

    with session_factory() as db:
        active_tasks = AgentAssetService(db).list_assets(asset_type="task", status="active")
        task = next(
            candidate
            for candidate in active_tasks
            if candidate.code == "task.hermes.daily_risk_scan"
        )

    request_body = {
        "source": "schedule",
        "task_id": task.id,
        "context_json": {"role_codes": ["finance"]},
    }
    response = client.post("/api/v1/orchestrator/run", json=request_body)

    assert response.status_code == 200
    body = response.json()
    assert body["selected_agent"] == "hermes"
    assert body["status"] == "succeeded"
    assert body["trace_summary"]["tool_count"] == 2

    run_detail = client.get(f"/api/v1/agent-runs/{body['run_id']}").json()
    assert run_detail["agent"] == "hermes"
    assert run_detail["route_json"]["selected_agent"] == "hermes"
    assert len(run_detail["tool_calls"]) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_orchestrator_forbidden_request_does_not_call_downstream_agent() -> None:
    """A forbidden request is blocked by the orchestrator with no agent selected and no tool calls."""
    client, _ = build_client()
    request_body = {
        "source": "user_message",
        "user_id": "pytest",
        "message": "帮我直接付款给供应商B",
        "context_json": {"role_codes": ["user"]},
    }

    response = client.post("/api/v1/orchestrator/run", json=request_body)

    assert response.status_code == 200
    body = response.json()
    assert body["selected_agent"] is None
    assert body["permission_level"] == "forbidden"
    assert body["status"] == "blocked"
    assert body["trace_summary"]["tool_count"] == 0

    run_detail = client.get(f"/api/v1/agent-runs/{body['run_id']}").json()
    assert run_detail["agent"] == "orchestrator"
    assert run_detail["tool_calls"] == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_orchestrator_approval_required_returns_confirmation_result() -> None:
    """A payment request from finance is blocked pending confirmation and says so in the message."""
    client, _ = build_client()
    request_body = {
        "source": "user_message",
        "user_id": "pytest",
        "message": "帮我安排付款给供应商B",
        "context_json": {"role_codes": ["finance"]},
    }

    response = client.post("/api/v1/orchestrator/run", json=request_body)

    assert response.status_code == 200
    body = response.json()
    assert body["selected_agent"] == "user_agent"
    assert body["permission_level"] == "approval_required"
    assert body["requires_confirmation"] is True
    assert body["status"] == "blocked"
    assert "确认" in body["result"]["message"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_orchestrator_user_agent_draft_returns_structured_payload() -> None:
    """An explicit draft request succeeds and returns a draft payload that requires confirmation."""
    client, _ = build_client()
    request_body = {
        "source": "user_message",
        "user_id": "pytest",
        "message": "帮我生成张三4月差旅报销草稿",
        "context_json": {"role_codes": ["finance"]},
    }

    response = client.post("/api/v1/orchestrator/run", json=request_body)

    assert response.status_code == 200
    body = response.json()
    assert body["selected_agent"] == "user_agent"
    assert body["status"] == "succeeded"
    result = body["result"]
    assert result["draft_payload"]["confirmation_required"] is True
    assert result["suggested_actions"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_orchestrator_treats_expense_narrative_as_draft_instead_of_ar_query() -> None:
    """An expense narrative mentioning a customer becomes a blocked expense draft asking for details."""
    client, _ = build_client()
    request_body = {
        "source": "user_message",
        "user_id": "pytest",
        "message": "我今天去客户现场,招待了客户,花销了1000元",
        "context_json": {"role_codes": ["finance"]},
    }

    response = client.post("/api/v1/orchestrator/run", json=request_body)

    assert response.status_code == 200
    body = response.json()
    assert body["selected_agent"] == "user_agent"
    assert body["permission_level"] == "draft_write"
    assert body["status"] == "blocked"
    assert body["route_reason"] == "clarification_required"

    trace = body["trace_summary"]
    assert trace["scenario"] == "expense"
    assert trace["intent"] == "draft"
    assert trace["tool_count"] == 0

    message = body["result"]["message"]
    assert "应收场景数据" not in message
    assert "请补充" in message
|
||||||
|
|
||||||
|
|
||||||
|
def test_orchestrator_tool_failure_is_logged_and_degraded() -> None:
    """A simulated database tool failure still succeeds, flagged as degraded, with the failure recorded."""
    client, _ = build_client()
    request_body = {
        "source": "user_message",
        "user_id": "pytest",
        "message": "查一下本周报销金额",
        "context_json": {
            "role_codes": ["finance"],
            "simulate_tool_failure": "database",
        },
    }

    response = client.post("/api/v1/orchestrator/run", json=request_body)

    assert response.status_code == 200
    body = response.json()
    assert body["selected_agent"] == "user_agent"
    assert body["status"] == "succeeded"
    trace = body["trace_summary"]
    assert trace["failed_tool_count"] == 1
    assert trace["degraded"] is True

    run_detail = client.get(f"/api/v1/agent-runs/{body['run_id']}").json()
    first_call = run_detail["tool_calls"][0]
    assert first_call["status"] == "failed"
    assert "simulated database failure" in first_call["error_message"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_orchestrator_exception_is_written_to_agent_run() -> None:
    """A simulated orchestrator exception returns HTTP 200 with status failed and is stored on the run."""
    client, _ = build_client()
    request_body = {
        "source": "user_message",
        "user_id": "pytest",
        "message": "查一下本周报销金额",
        "context_json": {
            "role_codes": ["finance"],
            "simulate_orchestrator_exception": True,
        },
    }

    response = client.post("/api/v1/orchestrator/run", json=request_body)

    assert response.status_code == 200
    body = response.json()
    assert body["status"] == "failed"

    run_detail = client.get(f"/api/v1/agent-runs/{body['run_id']}").json()
    assert run_detail["status"] == "failed"
    assert "simulated orchestrator exception" in run_detail["error_message"]
|
||||||
179
server/tests/test_user_agent_service.py
Normal file
179
server/tests/test_user_agent_service.py
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
|
from app.db.base import Base
|
||||||
|
from app.schemas.ontology import OntologyParseRequest
|
||||||
|
from app.schemas.user_agent import UserAgentRequest
|
||||||
|
from app.services.ontology import SemanticOntologyService
|
||||||
|
from app.services.user_agent import UserAgentService
|
||||||
|
|
||||||
|
|
||||||
|
def build_session_factory() -> sessionmaker[Session]:
    """Create a session factory bound to a fresh in-memory SQLite database.

    All tables registered on ``Base.metadata`` are created before the factory
    is returned. The engine uses ``StaticPool`` and disables SQLite's
    same-thread check.

    Returns:
        A ``sessionmaker`` with autoflush and autocommit turned off.
    """
    engine_options = {
        "connect_args": {"check_same_thread": False},
        "poolclass": StaticPool,
    }
    memory_engine = create_engine("sqlite+pysqlite:///:memory:", **engine_options)
    Base.metadata.create_all(bind=memory_engine)
    factory = sessionmaker(bind=memory_engine, autoflush=False, autocommit=False)
    return factory
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_agent_query_returns_readable_answer_and_actions() -> None:
    """Check that a query answer shows the formatted total and an action.

    The tool payload carries a total of ``8800.0``; the answer is expected to
    contain it rendered as ``8800.00`` and at least one suggested action must
    be returned.
    """
    question = "张三 4 月差旅报销金额是多少"
    session_factory = build_session_factory()
    with session_factory() as db:
        parse_request = OntologyParseRequest(query=question, user_id="pytest")
        ontology = SemanticOntologyService(db).parse(parse_request)

        agent_request = UserAgentRequest(
            run_id=ontology.run_id,
            user_id="pytest",
            message=question,
            ontology=ontology,
            tool_payload={"record_count": 2, "total_amount": 8800.0},
        )
        response = UserAgentService(db).respond(agent_request)

        assert "8800.00" in response.answer
        assert len(response.suggested_actions) >= 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_agent_prefers_runtime_model_answer_when_available(monkeypatch) -> None:
    """Check that the model-generated answer is returned when one exists.

    ``_generate_answer_with_model`` is patched on the service instance to
    return a fixed string, and the response answer must equal that string.
    """

    def fake_model_answer(*args, **kwargs):
        # Stand-in for the runtime model call; ignores all arguments.
        return "这是模型回答"

    question = "张三 4 月差旅报销金额是多少"
    session_factory = build_session_factory()
    with session_factory() as db:
        parse_request = OntologyParseRequest(query=question, user_id="pytest")
        ontology = SemanticOntologyService(db).parse(parse_request)

        service = UserAgentService(db)
        monkeypatch.setattr(service, "_generate_answer_with_model", fake_model_answer)

        agent_request = UserAgentRequest(
            run_id=ontology.run_id,
            user_id="pytest",
            message=question,
            ontology=ontology,
            tool_payload={"record_count": 2, "total_amount": 8800.0},
        )
        response = service.respond(agent_request)

        assert response.answer == "这是模型回答"
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_agent_sanitizes_model_thinking_blocks() -> None:
    """Check that a ``<think>`` block is removed from a raw model answer.

    Only the text after the thinking block is expected to remain, with no
    leading newline.
    """
    raw_answer = "<think>内部推理</think>\n最终答复"
    session_factory = build_session_factory()
    with session_factory() as db:
        service = UserAgentService(db)
        cleaned_answer = service._sanitize_model_answer(raw_answer)

        assert cleaned_answer == "最终答复"
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_agent_guides_generic_expense_request() -> None:
    """Check that a bare expense request gets guidance text in the answer.

    The answer must ask for the expense type and mention uploading receipts,
    even though the tool payload carries aggregate figures.
    """
    message = "我要报销"
    session_factory = build_session_factory()
    with session_factory() as db:
        parse_request = OntologyParseRequest(query=message, user_id="pytest")
        ontology = SemanticOntologyService(db).parse(parse_request)

        agent_request = UserAgentRequest(
            run_id=ontology.run_id,
            user_id="pytest",
            message=message,
            ontology=ontology,
            tool_payload={"record_count": 9, "total_amount": 12345.0},
        )
        response = UserAgentService(db).respond(agent_request)

        for expected_phrase in ("补充费用类型", "上传票据"):
            assert expected_phrase in response.answer
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_agent_guides_implicit_expense_draft_request() -> None:
    """Check the answer for a message that describes spending in passing.

    The message never asks for a reimbursement explicitly. The answer is
    expected to echo the amount, mention receipt attachments and refer to an
    expense draft.
    """
    message = "我今天去客户现场,招待了客户,花销了1000元"
    session_factory = build_session_factory()
    with session_factory() as db:
        parse_request = OntologyParseRequest(query=message, user_id="pytest")
        ontology = SemanticOntologyService(db).parse(parse_request)

        agent_request = UserAgentRequest(
            run_id=ontology.run_id,
            user_id="pytest",
            message=message,
            ontology=ontology,
            tool_payload={"draft_only": True},
        )
        response = UserAgentService(db).respond(agent_request)

        for expected_phrase in ("1000元", "票据附件", "报销草稿"):
            assert expected_phrase in response.answer
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_agent_risk_response_includes_rule_citations() -> None:
    """Check that a risk check echoes the flags and cites a rule.

    The response must return the risk flags from the tool payload, include at
    least one citation whose ``source_type`` is ``"rule"``, and name the flag
    in the answer text.
    """
    message = "检查重复报销风险"
    session_factory = build_session_factory()
    with session_factory() as db:
        parse_request = OntologyParseRequest(query=message, user_id="pytest")
        ontology = SemanticOntologyService(db).parse(parse_request)

        agent_request = UserAgentRequest(
            run_id=ontology.run_id,
            user_id="pytest",
            message=message,
            ontology=ontology,
            tool_payload={"risk_flags": ["duplicate_expense"]},
        )
        response = UserAgentService(db).respond(agent_request)

        assert response.risk_flags == ["duplicate_expense"]
        assert any(item.source_type == "rule" for item in response.citations)
        assert "duplicate_expense" in response.answer
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_agent_draft_returns_structured_payload() -> None:
    """Check that a draft request returns a draft payload needing confirmation.

    The response must carry a non-empty ``draft_payload`` with
    ``confirmation_required`` set, and the answer must state that manual
    confirmation is pending.
    """
    message = "帮我生成张三4月差旅报销草稿"
    session_factory = build_session_factory()
    with session_factory() as db:
        parse_request = OntologyParseRequest(query=message, user_id="pytest")
        ontology = SemanticOntologyService(db).parse(parse_request)

        agent_request = UserAgentRequest(
            run_id=ontology.run_id,
            user_id="pytest",
            message=message,
            ontology=ontology,
            tool_payload={"draft_only": True},
        )
        response = UserAgentService(db).respond(agent_request)

        draft = response.draft_payload
        assert draft is not None
        assert draft.confirmation_required is True
        assert "待人工确认" in response.answer
|
||||||
Reference in New Issue
Block a user