feat(backend): add ontology and orchestrator API endpoints
New endpoints: - server/src/app/api/v1/endpoints/ontology.py: ontology API - server/src/app/api/v1/endpoints/orchestrator.py: orchestrator API New schemas: - server/src/app/schemas/ontology.py: ontology data schemas - server/src/app/schemas/orchestrator.py: orchestrator data schemas - server/src/app/schemas/user_agent.py: user agent data schemas New services: - server/src/app/services/ontology.py: ontology business logic - server/src/app/services/orchestrator.py: orchestrator business logic - server/src/app/services/runtime_chat.py: runtime chat service - server/src/app/services/user_agent.py: user agent service New tests: - server/tests/test_ontology_service.py - server/tests/test_orchestrator_service.py - server/tests/test_user_agent_service.py
This commit is contained in:
36
server/src/app/api/v1/endpoints/ontology.py
Normal file
36
server/src/app/api/v1/endpoints/ontology.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.deps import get_db
|
||||
from app.schemas.common import ErrorResponse
|
||||
from app.schemas.ontology import OntologyParseRequest, OntologyParseResult
|
||||
from app.services.ontology import SemanticOntologyService
|
||||
|
||||
router = APIRouter(prefix="/ontology")
|
||||
DbSession = Annotated[Session, Depends(get_db)]
|
||||
|
||||
|
||||
@router.post(
|
||||
"/parse",
|
||||
response_model=OntologyParseResult,
|
||||
summary="解析自然语言为语义本体",
|
||||
description=(
|
||||
"把自然语言问题解析成 Day 3 约定的 8 个核心字段,"
|
||||
"并写入 AgentRun 与 SemanticParseLog。"
|
||||
),
|
||||
responses={
|
||||
status.HTTP_400_BAD_REQUEST: {
|
||||
"model": ErrorResponse,
|
||||
"description": "请求缺少有效 query 或解析请求格式不合法。",
|
||||
}
|
||||
},
|
||||
)
|
||||
def parse_ontology(payload: OntologyParseRequest, db: DbSession) -> OntologyParseResult:
|
||||
try:
|
||||
return SemanticOntologyService(db).parse(payload)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
|
||||
33
server/src/app/api/v1/endpoints/orchestrator.py
Normal file
33
server/src/app/api/v1/endpoints/orchestrator.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.deps import get_db
|
||||
from app.schemas.common import ErrorResponse
|
||||
from app.schemas.orchestrator import OrchestratorRequest, OrchestratorResponse
|
||||
from app.services.orchestrator import OrchestratorService
|
||||
|
||||
router = APIRouter(prefix="/orchestrator")
|
||||
DbSession = Annotated[Session, Depends(get_db)]
|
||||
|
||||
|
||||
@router.post(
|
||||
"/run",
|
||||
response_model=OrchestratorResponse,
|
||||
summary="运行 Orchestrator 统一调度",
|
||||
description="统一接收用户消息、定时任务和系统事件,完成语义解析、权限判断、路由和占位执行。",
|
||||
responses={
|
||||
status.HTTP_400_BAD_REQUEST: {
|
||||
"model": ErrorResponse,
|
||||
"description": "请求缺少 message 或 task_id,无法启动调度。",
|
||||
}
|
||||
},
|
||||
)
|
||||
def run_orchestrator(payload: OrchestratorRequest, db: DbSession) -> OrchestratorResponse:
|
||||
try:
|
||||
return OrchestratorService(db).run(payload)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
|
||||
116
server/src/app/schemas/ontology.py
Normal file
116
server/src/app/schemas/ontology.py
Normal file
@@ -0,0 +1,116 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
OntologyScenario = Literal[
|
||||
"expense",
|
||||
"accounts_receivable",
|
||||
"accounts_payable",
|
||||
"knowledge",
|
||||
"unknown",
|
||||
]
|
||||
OntologyIntent = Literal["query", "explain", "compare", "risk_check", "draft", "operate"]
|
||||
OntologyPermissionLevel = Literal["read", "draft_write", "approval_required", "forbidden"]
|
||||
OntologyParseStrategy = Literal["llm_primary", "rule_fallback"]
|
||||
|
||||
|
||||
class OntologyEntity(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
type: str = Field(description="业务对象类型,例如 employee / customer / vendor。")
|
||||
value: str = Field(description="从原始问题中提取的对象值。")
|
||||
normalized_value: str = Field(description="标准化后的对象值。")
|
||||
role: str = Field(default="target", description="对象角色,例如 target / filter / threshold。")
|
||||
confidence: float = Field(default=0.0, ge=0.0, le=1.0, description="字段级置信度。")
|
||||
|
||||
|
||||
class OntologyTimeRange(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
raw: str = Field(default="", description="命中的原始时间表达。")
|
||||
start_date: str | None = Field(default=None, description="ISO 格式起始日期。")
|
||||
end_date: str | None = Field(default=None, description="ISO 格式结束日期。")
|
||||
granularity: str | None = Field(
|
||||
default=None,
|
||||
description="day / week / month / quarter / year。",
|
||||
)
|
||||
|
||||
|
||||
class OntologyMetric(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
name: str = Field(description="指标名,例如 amount / count / overdue。")
|
||||
aggregation: str | None = Field(default=None, description="sum / count / max 等聚合口径。")
|
||||
unit: str | None = Field(default=None, description="金额、数量等单位。")
|
||||
sort: str | None = Field(default=None, description="asc / desc 排序方向。")
|
||||
top_n: int | None = Field(default=None, ge=1, description="Top N 口径。")
|
||||
|
||||
|
||||
class OntologyConstraint(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
field: str = Field(description="约束字段,例如 department / status / amount。")
|
||||
operator: str = Field(description="操作符,例如 = / > / < / desc。")
|
||||
value: str | int | float | bool = Field(description="约束值。")
|
||||
currency: str | None = Field(default=None, description="金额类约束使用的币种。")
|
||||
|
||||
|
||||
class OntologyPermission(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
level: OntologyPermissionLevel = Field(default="read", description="动作权限等级。")
|
||||
allowed: bool = Field(default=True, description="是否可直接执行当前动作。")
|
||||
reason: str = Field(default="", description="权限判断原因。")
|
||||
|
||||
|
||||
class OntologyFieldError(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
field: str = Field(description="发生问题的字段。")
|
||||
code: str = Field(description="错误码。")
|
||||
message: str = Field(description="面向前端展示的说明。")
|
||||
|
||||
|
||||
class OntologyParseRequest(BaseModel):
|
||||
query: str = Field(min_length=1, description="自然语言问题。")
|
||||
user_id: str | None = Field(default=None, description="当前请求用户 ID。")
|
||||
context_json: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="用户上下文,例如角色、部门、是否管理员。",
|
||||
)
|
||||
|
||||
|
||||
class OntologyParseResult(BaseModel):
|
||||
scenario: OntologyScenario = Field(default="unknown", description="业务场景。")
|
||||
intent: OntologyIntent = Field(default="query", description="用户意图。")
|
||||
entities: list[OntologyEntity] = Field(default_factory=list, description="业务对象列表。")
|
||||
time_range: OntologyTimeRange = Field(
|
||||
default_factory=OntologyTimeRange,
|
||||
description="时间范围。",
|
||||
)
|
||||
metrics: list[OntologyMetric] = Field(default_factory=list, description="指标解析结果。")
|
||||
constraints: list[OntologyConstraint] = Field(
|
||||
default_factory=list,
|
||||
description="过滤、阈值、排序等约束。",
|
||||
)
|
||||
risk_flags: list[str] = Field(default_factory=list, description="风险信号列表。")
|
||||
permission: OntologyPermission = Field(
|
||||
default_factory=OntologyPermission,
|
||||
description="权限结果。",
|
||||
)
|
||||
confidence: float = Field(default=0.0, ge=0.0, le=1.0, description="整体置信度。")
|
||||
missing_slots: list[str] = Field(default_factory=list, description="继续处理所缺少的关键槽位。")
|
||||
ambiguity: list[str] = Field(default_factory=list, description="当前识别中的潜在歧义。")
|
||||
parse_strategy: OntologyParseStrategy = Field(
|
||||
default="rule_fallback",
|
||||
description="本次语义解析使用的主策略。",
|
||||
)
|
||||
clarification_required: bool = Field(default=False, description="是否需要追问。")
|
||||
clarification_question: str | None = Field(default=None, description="推荐追问问题。")
|
||||
run_id: str = Field(description="关联的 AgentRun.run_id。")
|
||||
field_errors: list[OntologyFieldError] = Field(
|
||||
default_factory=list,
|
||||
description="字段级错误或提示。",
|
||||
)
|
||||
46
server/src/app/schemas/orchestrator.py
Normal file
46
server/src/app/schemas/orchestrator.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
OrchestratorSource = Literal["user_message", "schedule", "system_event"]
|
||||
OrchestratorAgent = Literal["user_agent", "hermes"]
|
||||
OrchestratorStatus = Literal["succeeded", "blocked", "failed"]
|
||||
|
||||
|
||||
class OrchestratorRequest(BaseModel):
|
||||
source: OrchestratorSource = Field(description="请求来源。")
|
||||
user_id: str | None = Field(default=None, description="当前用户 ID,任务触发可为空。")
|
||||
message: str | None = Field(default=None, description="用户消息或任务描述。")
|
||||
task_id: str | None = Field(default=None, description="任务资产 ID,schedule 触发时优先使用。")
|
||||
context_json: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="用户上下文、测试开关或调用方附加信息。",
|
||||
)
|
||||
|
||||
|
||||
class OrchestratorTraceSummary(BaseModel):
|
||||
scenario: str = Field(description="语义场景。")
|
||||
intent: str = Field(description="语义意图。")
|
||||
tool_count: int = Field(default=0, ge=0, description="工具调用总数。")
|
||||
failed_tool_count: int = Field(default=0, ge=0, description="失败工具调用数量。")
|
||||
selected_capability_codes: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="本次路由命中的能力编码。",
|
||||
)
|
||||
degraded: bool = Field(default=False, description="是否发生降级。")
|
||||
|
||||
|
||||
class OrchestratorResponse(BaseModel):
|
||||
run_id: str = Field(description="本次运行的唯一 run_id。")
|
||||
selected_agent: OrchestratorAgent | None = Field(
|
||||
default=None,
|
||||
description="最终路由到的下游 Agent。",
|
||||
)
|
||||
route_reason: str = Field(description="路由原因摘要。")
|
||||
permission_level: str = Field(description="权限级别。")
|
||||
status: OrchestratorStatus = Field(description="最终运行状态。")
|
||||
result: dict[str, Any] = Field(default_factory=dict, description="对前端可直接展示的最小结果。")
|
||||
requires_confirmation: bool = Field(default=False, description="是否需要用户或管理员确认。")
|
||||
trace_summary: OrchestratorTraceSummary = Field(description="简化后的 Trace 摘要。")
|
||||
58
server/src/app/schemas/user_agent.py
Normal file
58
server/src/app/schemas/user_agent.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.schemas.ontology import OntologyParseResult
|
||||
|
||||
UserAgentCitationType = Literal["rule", "knowledge"]
|
||||
|
||||
|
||||
class UserAgentCitation(BaseModel):
|
||||
source_type: UserAgentCitationType = Field(description="引用来源类型。")
|
||||
code: str = Field(description="来源编码。")
|
||||
title: str = Field(description="来源标题。")
|
||||
version: str | None = Field(default=None, description="引用版本。")
|
||||
updated_at: str | None = Field(default=None, description="来源更新时间。")
|
||||
excerpt: str | None = Field(default=None, description="面向用户展示的引用摘要。")
|
||||
|
||||
|
||||
class UserAgentSuggestedAction(BaseModel):
|
||||
label: str = Field(description="建议动作文案。")
|
||||
action_type: str = Field(description="动作类型,例如 open_detail / create_draft。")
|
||||
description: str = Field(default="", description="动作说明。")
|
||||
|
||||
|
||||
class UserAgentDraftPayload(BaseModel):
|
||||
draft_type: str = Field(description="草稿类型。")
|
||||
title: str = Field(description="草稿标题。")
|
||||
body: str = Field(description="草稿正文。")
|
||||
confirmation_required: bool = Field(default=True, description="是否需要人工确认。")
|
||||
|
||||
|
||||
class UserAgentRequest(BaseModel):
|
||||
run_id: str = Field(description="关联的 AgentRun.run_id。")
|
||||
user_id: str | None = Field(default=None, description="当前请求用户 ID。")
|
||||
message: str = Field(description="原始用户问题。")
|
||||
ontology: OntologyParseResult = Field(description="语义解析结果。")
|
||||
context_json: dict[str, Any] = Field(default_factory=dict, description="附加上下文。")
|
||||
tool_payload: dict[str, Any] = Field(default_factory=dict, description="工具返回的原始结果。")
|
||||
selected_capability_codes: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="本次命中的能力编码。",
|
||||
)
|
||||
degraded: bool = Field(default=False, description="当前是否发生降级。")
|
||||
requires_confirmation: bool = Field(default=False, description="是否要求确认。")
|
||||
|
||||
|
||||
class UserAgentResponse(BaseModel):
|
||||
answer: str = Field(description="面向用户展示的自然语言回答。")
|
||||
citations: list[UserAgentCitation] = Field(default_factory=list, description="规则或知识引用。")
|
||||
suggested_actions: list[UserAgentSuggestedAction] = Field(
|
||||
default_factory=list,
|
||||
description="建议的下一步动作。",
|
||||
)
|
||||
draft_payload: UserAgentDraftPayload | None = Field(default=None, description="可选草稿内容。")
|
||||
risk_flags: list[str] = Field(default_factory=list, description="本次回答关联的风险标签。")
|
||||
requires_confirmation: bool = Field(default=False, description="是否需要人工确认。")
|
||||
1470
server/src/app/services/ontology.py
Normal file
1470
server/src/app/services/ontology.py
Normal file
File diff suppressed because it is too large
Load Diff
887
server/src/app/services/orchestrator.py
Normal file
887
server/src/app/services/orchestrator.py
Normal file
@@ -0,0 +1,887 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from time import perf_counter
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.agent_enums import (
|
||||
AgentAssetStatus,
|
||||
AgentAssetType,
|
||||
AgentName,
|
||||
AgentPermissionLevel,
|
||||
AgentRunSource,
|
||||
AgentRunStatus,
|
||||
AgentToolType,
|
||||
)
|
||||
from app.core.logging import get_logger
|
||||
from app.models.financial_record import (
|
||||
AccountsPayableRecord,
|
||||
AccountsReceivableRecord,
|
||||
ExpenseClaim,
|
||||
)
|
||||
from app.schemas.agent_asset import AgentAssetListItem, AgentAssetRead
|
||||
from app.schemas.ontology import OntologyParseRequest, OntologyParseResult
|
||||
from app.schemas.orchestrator import (
|
||||
OrchestratorRequest,
|
||||
OrchestratorResponse,
|
||||
OrchestratorTraceSummary,
|
||||
)
|
||||
from app.schemas.user_agent import UserAgentRequest, UserAgentResponse
|
||||
from app.services.agent_assets import AgentAssetService
|
||||
from app.services.agent_foundation import AgentFoundationService
|
||||
from app.services.agent_runs import AgentRunService
|
||||
from app.services.ontology import SemanticOntologyService
|
||||
from app.services.user_agent import UserAgentService
|
||||
|
||||
logger = get_logger("app.services.orchestrator")
|
||||
|
||||
SCENARIO_TO_DOMAIN = {
|
||||
"expense": "expense",
|
||||
"accounts_receivable": "ar",
|
||||
"accounts_payable": "ap",
|
||||
"knowledge": "knowledge",
|
||||
"unknown": "system",
|
||||
}
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ExecutionOutcome:
|
||||
status: str
|
||||
result: dict[str, Any]
|
||||
degraded: bool
|
||||
tool_count: int
|
||||
failed_tool_count: int
|
||||
|
||||
|
||||
class OrchestratorService:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
self.asset_service = AgentAssetService(db)
|
||||
self.run_service = AgentRunService(db)
|
||||
self.ontology_service = SemanticOntologyService(db)
|
||||
self.user_agent_service = UserAgentService(db)
|
||||
|
||||
def run(self, payload: OrchestratorRequest) -> OrchestratorResponse:
|
||||
AgentFoundationService(self.db).ensure_foundation_ready()
|
||||
route_json: dict[str, Any] = {
|
||||
"orchestrated_by": AgentName.ORCHESTRATOR.value,
|
||||
"stage": "created",
|
||||
}
|
||||
run = self.run_service.create_run(
|
||||
agent=AgentName.ORCHESTRATOR.value,
|
||||
source=payload.source,
|
||||
user_id=payload.user_id,
|
||||
task_id=payload.task_id,
|
||||
ontology_json={},
|
||||
route_json=route_json,
|
||||
permission_level=AgentPermissionLevel.READ.value,
|
||||
status=AgentRunStatus.RUNNING.value,
|
||||
result_summary="Orchestrator 已接收请求。",
|
||||
)
|
||||
|
||||
try:
|
||||
message, task_asset = self._resolve_message(payload)
|
||||
ontology = self.ontology_service.parse_for_run(
|
||||
OntologyParseRequest(
|
||||
query=message,
|
||||
user_id=payload.user_id,
|
||||
context_json=payload.context_json,
|
||||
),
|
||||
run_id=run.run_id,
|
||||
)
|
||||
if payload.context_json.get("simulate_orchestrator_exception"):
|
||||
raise RuntimeError("simulated orchestrator exception")
|
||||
selected_agent, route_reason = self._select_agent(payload, ontology)
|
||||
capabilities = self._select_capabilities(
|
||||
payload=payload,
|
||||
ontology=ontology,
|
||||
task_asset=task_asset,
|
||||
)
|
||||
selected_capability_codes = self._flatten_capability_codes(capabilities)
|
||||
requires_confirmation = (
|
||||
ontology.permission.level == AgentPermissionLevel.APPROVAL_REQUIRED.value
|
||||
)
|
||||
|
||||
route_json = {
|
||||
"orchestrated_by": AgentName.ORCHESTRATOR.value,
|
||||
"stage": "routed",
|
||||
"selected_agent": selected_agent,
|
||||
"route_reason": route_reason,
|
||||
"selected_capability_codes": selected_capability_codes,
|
||||
"ontology_run_id": ontology.run_id,
|
||||
}
|
||||
|
||||
if ontology.permission.level == AgentPermissionLevel.FORBIDDEN.value:
|
||||
outcome = ExecutionOutcome(
|
||||
status=AgentRunStatus.BLOCKED.value,
|
||||
result={
|
||||
"message": ontology.permission.reason,
|
||||
"clarification_question": ontology.clarification_question,
|
||||
"degraded": False,
|
||||
},
|
||||
degraded=False,
|
||||
tool_count=0,
|
||||
failed_tool_count=0,
|
||||
)
|
||||
selected_agent = None
|
||||
route_reason = "permission_forbidden"
|
||||
route_json["stage"] = "blocked"
|
||||
route_json["route_reason"] = route_reason
|
||||
elif ontology.clarification_required:
|
||||
outcome = ExecutionOutcome(
|
||||
status=AgentRunStatus.BLOCKED.value,
|
||||
result={
|
||||
"message": ontology.clarification_question or "需要补充更多上下文。",
|
||||
"clarification_required": True,
|
||||
"missing_slots": ontology.missing_slots,
|
||||
"ambiguity": ontology.ambiguity,
|
||||
"parse_strategy": ontology.parse_strategy,
|
||||
"degraded": False,
|
||||
},
|
||||
degraded=False,
|
||||
tool_count=0,
|
||||
failed_tool_count=0,
|
||||
)
|
||||
route_reason = "clarification_required"
|
||||
route_json["stage"] = "clarification"
|
||||
route_json["route_reason"] = route_reason
|
||||
elif selected_agent == AgentName.HERMES.value:
|
||||
outcome = self._execute_hermes(
|
||||
payload=payload,
|
||||
run_id=run.run_id,
|
||||
ontology=ontology,
|
||||
capabilities=capabilities,
|
||||
requires_confirmation=requires_confirmation,
|
||||
task_asset=task_asset,
|
||||
)
|
||||
else:
|
||||
outcome = self._execute_user_agent(
|
||||
payload=payload,
|
||||
run_id=run.run_id,
|
||||
ontology=ontology,
|
||||
capabilities=capabilities,
|
||||
requires_confirmation=requires_confirmation,
|
||||
)
|
||||
|
||||
final_status = (
|
||||
AgentRunStatus.BLOCKED.value
|
||||
if requires_confirmation
|
||||
and outcome.status == AgentRunStatus.SUCCEEDED.value
|
||||
and ontology.permission.level == AgentPermissionLevel.APPROVAL_REQUIRED.value
|
||||
else outcome.status
|
||||
)
|
||||
result_message = (
|
||||
str(outcome.result.get("message", "")).strip()
|
||||
or "Orchestrator 执行完成。"
|
||||
)
|
||||
self.run_service.update_run(
|
||||
run.run_id,
|
||||
agent=selected_agent or AgentName.ORCHESTRATOR.value,
|
||||
ontology_json=self._build_ontology_json(ontology),
|
||||
route_json={
|
||||
**route_json,
|
||||
"requires_confirmation": requires_confirmation,
|
||||
"degraded": outcome.degraded,
|
||||
},
|
||||
permission_level=ontology.permission.level,
|
||||
status=final_status,
|
||||
result_summary=result_message,
|
||||
error_message=None,
|
||||
finished_at=datetime.now(UTC),
|
||||
)
|
||||
return OrchestratorResponse(
|
||||
run_id=run.run_id,
|
||||
selected_agent=selected_agent,
|
||||
route_reason=route_reason,
|
||||
permission_level=ontology.permission.level,
|
||||
status=self._normalize_response_status(final_status),
|
||||
result=outcome.result,
|
||||
requires_confirmation=requires_confirmation,
|
||||
trace_summary=OrchestratorTraceSummary(
|
||||
scenario=ontology.scenario,
|
||||
intent=ontology.intent,
|
||||
tool_count=outcome.tool_count,
|
||||
failed_tool_count=outcome.failed_tool_count,
|
||||
selected_capability_codes=selected_capability_codes,
|
||||
degraded=outcome.degraded,
|
||||
),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception("Orchestrator run failed run_id=%s", run.run_id)
|
||||
self.run_service.update_run(
|
||||
run.run_id,
|
||||
agent=AgentName.ORCHESTRATOR.value,
|
||||
route_json={**route_json, "stage": "failed"},
|
||||
status=AgentRunStatus.FAILED.value,
|
||||
result_summary="Orchestrator 执行失败。",
|
||||
error_message=str(exc),
|
||||
finished_at=datetime.now(UTC),
|
||||
)
|
||||
return OrchestratorResponse(
|
||||
run_id=run.run_id,
|
||||
selected_agent=None,
|
||||
route_reason="orchestrator_exception",
|
||||
permission_level=AgentPermissionLevel.READ.value,
|
||||
status="failed",
|
||||
result={"message": f"Orchestrator 执行失败:{exc}"},
|
||||
requires_confirmation=False,
|
||||
trace_summary=OrchestratorTraceSummary(
|
||||
scenario="unknown",
|
||||
intent="query",
|
||||
tool_count=0,
|
||||
failed_tool_count=0,
|
||||
selected_capability_codes=[],
|
||||
degraded=False,
|
||||
),
|
||||
)
|
||||
|
||||
def _resolve_message(
|
||||
self,
|
||||
payload: OrchestratorRequest,
|
||||
) -> tuple[str, AgentAssetRead | None]:
|
||||
task_asset = None
|
||||
if payload.task_id:
|
||||
task_asset = self.asset_service.get_asset(payload.task_id)
|
||||
|
||||
if payload.message and payload.message.strip():
|
||||
return payload.message.strip(), task_asset
|
||||
|
||||
if task_asset is not None:
|
||||
description = str(task_asset.description or "").strip()
|
||||
scenario_text = " ".join(str(item) for item in task_asset.scenario_json)
|
||||
message = f"{task_asset.name} {description} {scenario_text}".strip()
|
||||
return message, task_asset
|
||||
|
||||
if payload.source == AgentRunSource.SCHEDULE.value:
|
||||
return "定时风险巡检任务", task_asset
|
||||
|
||||
raise ValueError("message 或 task_id 至少需要提供一个。")
|
||||
|
||||
@staticmethod
|
||||
def _select_agent(
|
||||
payload: OrchestratorRequest,
|
||||
ontology: OntologyParseResult,
|
||||
) -> tuple[str, str]:
|
||||
if payload.source == AgentRunSource.SCHEDULE.value:
|
||||
return AgentName.HERMES.value, "schedule_source_defaults_to_hermes"
|
||||
if payload.source == AgentRunSource.SYSTEM_EVENT.value and ontology.intent == "risk_check":
|
||||
return AgentName.HERMES.value, "system_event_risk_check_routes_to_hermes"
|
||||
if ontology.intent == "risk_check" and payload.source == AgentRunSource.SCHEDULE.value:
|
||||
return AgentName.HERMES.value, "scheduled_risk_check_routes_to_hermes"
|
||||
if ontology.intent in {"query", "explain", "draft", "compare", "operate"}:
|
||||
return AgentName.USER_AGENT.value, f"{ontology.intent}_routes_to_user_agent"
|
||||
return AgentName.USER_AGENT.value, "user_message_defaults_to_user_agent"
|
||||
|
||||
def _select_capabilities(
|
||||
self,
|
||||
*,
|
||||
payload: OrchestratorRequest,
|
||||
ontology: OntologyParseResult,
|
||||
task_asset: AgentAssetRead | None,
|
||||
) -> dict[str, list[AgentAssetListItem | AgentAssetRead]]:
|
||||
domain_value = SCENARIO_TO_DOMAIN.get(ontology.scenario)
|
||||
rules = self._rank_assets(
|
||||
self.asset_service.list_assets(
|
||||
asset_type=AgentAssetType.RULE.value,
|
||||
status=AgentAssetStatus.ACTIVE.value,
|
||||
domain=domain_value if domain_value not in {"knowledge", "system"} else None,
|
||||
),
|
||||
ontology,
|
||||
)
|
||||
skills = self._rank_assets(
|
||||
self.asset_service.list_assets(
|
||||
asset_type=AgentAssetType.SKILL.value,
|
||||
status=AgentAssetStatus.ACTIVE.value,
|
||||
domain=domain_value if domain_value not in {"system"} else None,
|
||||
),
|
||||
ontology,
|
||||
)
|
||||
mcps = self._rank_assets(
|
||||
self.asset_service.list_assets(
|
||||
asset_type=AgentAssetType.MCP.value,
|
||||
status=AgentAssetStatus.ACTIVE.value,
|
||||
),
|
||||
ontology,
|
||||
)
|
||||
tasks: list[AgentAssetListItem | AgentAssetRead] = []
|
||||
if task_asset is not None and task_asset.status == AgentAssetStatus.ACTIVE.value:
|
||||
tasks.append(task_asset)
|
||||
elif payload.source == AgentRunSource.SCHEDULE.value:
|
||||
tasks = self._rank_assets(
|
||||
self.asset_service.list_assets(
|
||||
asset_type=AgentAssetType.TASK.value,
|
||||
status=AgentAssetStatus.ACTIVE.value,
|
||||
),
|
||||
ontology,
|
||||
)
|
||||
|
||||
return {
|
||||
"rules": rules,
|
||||
"skills": skills,
|
||||
"mcps": mcps,
|
||||
"tasks": tasks,
|
||||
}
|
||||
|
||||
def _execute_user_agent(
|
||||
self,
|
||||
*,
|
||||
payload: OrchestratorRequest,
|
||||
run_id: str,
|
||||
ontology: OntologyParseResult,
|
||||
capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]],
|
||||
requires_confirmation: bool,
|
||||
) -> ExecutionOutcome:
|
||||
selected_capability_codes = self._flatten_capability_codes(capabilities)
|
||||
if requires_confirmation:
|
||||
response, degraded = self._invoke_tool(
|
||||
run_id=run_id,
|
||||
tool_type=AgentToolType.LLM.value,
|
||||
tool_name="user_agent.confirmation_placeholder",
|
||||
request_json={
|
||||
"message": payload.message,
|
||||
"permission_level": ontology.permission.level,
|
||||
},
|
||||
context_json=payload.context_json,
|
||||
executor=lambda: {
|
||||
"confirmation_title": "操作需要确认",
|
||||
"message": f"{ontology.permission.reason} 当前仅返回确认摘要,不直接执行动作。",
|
||||
},
|
||||
fallback_factory=lambda exc: {
|
||||
"confirmation_title": "操作需要确认",
|
||||
"message": f"确认摘要生成失败,已阻断自动执行:{exc}",
|
||||
},
|
||||
)
|
||||
return ExecutionOutcome(
|
||||
status=AgentRunStatus.BLOCKED.value,
|
||||
result={**response, "degraded": degraded},
|
||||
degraded=degraded,
|
||||
tool_count=1,
|
||||
failed_tool_count=1 if degraded else 0,
|
||||
)
|
||||
|
||||
next_step = self._resolve_next_step(ontology, payload.source)
|
||||
if next_step == "query_database":
|
||||
tool_payload, degraded = self._invoke_tool(
|
||||
run_id=run_id,
|
||||
tool_type=AgentToolType.DATABASE.value,
|
||||
tool_name=self._database_tool_name(ontology.scenario),
|
||||
request_json=self._build_ontology_json(ontology),
|
||||
context_json=payload.context_json,
|
||||
executor=lambda: self._build_database_answer(ontology),
|
||||
fallback_factory=lambda exc: {
|
||||
"message": f"数据库查询暂时不可用,已返回降级说明:{exc}",
|
||||
"degraded": True,
|
||||
},
|
||||
)
|
||||
result = self._build_user_agent_result(
|
||||
self.user_agent_service.respond(
|
||||
UserAgentRequest(
|
||||
run_id=run_id,
|
||||
user_id=payload.user_id,
|
||||
message=payload.message or "",
|
||||
ontology=ontology,
|
||||
context_json=payload.context_json,
|
||||
tool_payload=tool_payload,
|
||||
selected_capability_codes=selected_capability_codes,
|
||||
degraded=degraded,
|
||||
requires_confirmation=requires_confirmation,
|
||||
)
|
||||
),
|
||||
degraded=degraded,
|
||||
)
|
||||
return ExecutionOutcome(
|
||||
status=AgentRunStatus.SUCCEEDED.value,
|
||||
result=result,
|
||||
degraded=degraded,
|
||||
tool_count=1,
|
||||
failed_tool_count=1 if degraded else 0,
|
||||
)
|
||||
|
||||
if next_step == "search_knowledge":
|
||||
tool_payload, degraded = self._invoke_tool(
|
||||
run_id=run_id,
|
||||
tool_type=AgentToolType.DATABASE.value,
|
||||
tool_name="knowledge.search",
|
||||
request_json=self._build_ontology_json(ontology),
|
||||
context_json=payload.context_json,
|
||||
executor=lambda: self._build_knowledge_answer(ontology, capabilities),
|
||||
fallback_factory=lambda exc: {
|
||||
"message": f"知识检索暂时不可用,建议稍后重试:{exc}",
|
||||
"degraded": True,
|
||||
},
|
||||
)
|
||||
result = self._build_user_agent_result(
|
||||
self.user_agent_service.respond(
|
||||
UserAgentRequest(
|
||||
run_id=run_id,
|
||||
user_id=payload.user_id,
|
||||
message=payload.message or "",
|
||||
ontology=ontology,
|
||||
context_json=payload.context_json,
|
||||
tool_payload=tool_payload,
|
||||
selected_capability_codes=selected_capability_codes,
|
||||
degraded=degraded,
|
||||
requires_confirmation=requires_confirmation,
|
||||
)
|
||||
),
|
||||
degraded=degraded,
|
||||
)
|
||||
return ExecutionOutcome(
|
||||
status=AgentRunStatus.SUCCEEDED.value,
|
||||
result=result,
|
||||
degraded=degraded,
|
||||
tool_count=1,
|
||||
failed_tool_count=1 if degraded else 0,
|
||||
)
|
||||
|
||||
if next_step == "run_rule":
|
||||
tool_payload, degraded = self._invoke_tool(
|
||||
run_id=run_id,
|
||||
tool_type=AgentToolType.RULE_ENGINE.value,
|
||||
tool_name=self._rule_tool_name(capabilities),
|
||||
request_json=self._build_ontology_json(ontology),
|
||||
context_json=payload.context_json,
|
||||
executor=lambda: self._build_rule_answer(ontology),
|
||||
fallback_factory=lambda exc: {
|
||||
"message": f"规则检查暂时不可用,已返回人工复核建议:{exc}",
|
||||
"degraded": True,
|
||||
},
|
||||
)
|
||||
result = self._build_user_agent_result(
|
||||
self.user_agent_service.respond(
|
||||
UserAgentRequest(
|
||||
run_id=run_id,
|
||||
user_id=payload.user_id,
|
||||
message=payload.message or "",
|
||||
ontology=ontology,
|
||||
context_json=payload.context_json,
|
||||
tool_payload=tool_payload,
|
||||
selected_capability_codes=selected_capability_codes,
|
||||
degraded=degraded,
|
||||
requires_confirmation=requires_confirmation,
|
||||
)
|
||||
),
|
||||
degraded=degraded,
|
||||
)
|
||||
return ExecutionOutcome(
|
||||
status=AgentRunStatus.SUCCEEDED.value,
|
||||
result=result,
|
||||
degraded=degraded,
|
||||
tool_count=1,
|
||||
failed_tool_count=1 if degraded else 0,
|
||||
)
|
||||
|
||||
tool_payload, degraded = self._invoke_tool(
|
||||
run_id=run_id,
|
||||
tool_type=AgentToolType.LLM.value,
|
||||
tool_name="user_agent.draft_placeholder",
|
||||
request_json=self._build_ontology_json(ontology),
|
||||
context_json=payload.context_json,
|
||||
executor=lambda: {
|
||||
"message": (
|
||||
f"已生成 {ontology.scenario} 场景草稿,"
|
||||
"占位能力后续由 Day 5 User Agent 接管。"
|
||||
),
|
||||
"draft_only": True,
|
||||
},
|
||||
fallback_factory=lambda exc: {
|
||||
"message": f"草稿生成暂时不可用,请稍后再试:{exc}",
|
||||
"degraded": True,
|
||||
},
|
||||
)
|
||||
result = self._build_user_agent_result(
|
||||
self.user_agent_service.respond(
|
||||
UserAgentRequest(
|
||||
run_id=run_id,
|
||||
user_id=payload.user_id,
|
||||
message=payload.message or "",
|
||||
ontology=ontology,
|
||||
context_json=payload.context_json,
|
||||
tool_payload=tool_payload,
|
||||
selected_capability_codes=selected_capability_codes,
|
||||
degraded=degraded,
|
||||
requires_confirmation=requires_confirmation,
|
||||
)
|
||||
),
|
||||
degraded=degraded,
|
||||
)
|
||||
return ExecutionOutcome(
|
||||
status=AgentRunStatus.SUCCEEDED.value,
|
||||
result=result,
|
||||
degraded=degraded,
|
||||
tool_count=1,
|
||||
failed_tool_count=1 if degraded else 0,
|
||||
)
|
||||
|
||||
def _execute_hermes(
|
||||
self,
|
||||
*,
|
||||
payload: OrchestratorRequest,
|
||||
run_id: str,
|
||||
ontology: OntologyParseResult,
|
||||
capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]],
|
||||
requires_confirmation: bool,
|
||||
task_asset: AgentAssetRead | None,
|
||||
) -> ExecutionOutcome:
|
||||
if requires_confirmation:
|
||||
return ExecutionOutcome(
|
||||
status=AgentRunStatus.BLOCKED.value,
|
||||
result={
|
||||
"message": "Hermes 不会自动执行需要确认的高风险动作,已阻断。",
|
||||
"degraded": False,
|
||||
},
|
||||
degraded=False,
|
||||
tool_count=0,
|
||||
failed_tool_count=0,
|
||||
)
|
||||
|
||||
rule_response, rule_degraded = self._invoke_tool(
|
||||
run_id=run_id,
|
||||
tool_type=AgentToolType.RULE_ENGINE.value,
|
||||
tool_name=self._rule_tool_name(capabilities),
|
||||
request_json=self._build_ontology_json(ontology),
|
||||
context_json=payload.context_json,
|
||||
executor=lambda: self._build_rule_answer(ontology),
|
||||
fallback_factory=lambda exc: {
|
||||
"message": f"规则巡检失败,已降级为待人工复核:{exc}",
|
||||
"degraded": True,
|
||||
},
|
||||
)
|
||||
mcp_response, mcp_degraded = self._invoke_tool(
|
||||
run_id=run_id,
|
||||
tool_type=AgentToolType.MCP.value,
|
||||
tool_name=self._mcp_tool_name(capabilities),
|
||||
request_json={
|
||||
"task_code": task_asset.code if task_asset is not None else "",
|
||||
"scenario": ontology.scenario,
|
||||
},
|
||||
context_json=payload.context_json,
|
||||
executor=lambda: self._build_mcp_answer(task_asset, ontology),
|
||||
fallback_factory=lambda exc: {
|
||||
"message": f"MCP 调用失败,已使用缓存快照降级:{exc}",
|
||||
"fallback": "used_cached_snapshot",
|
||||
},
|
||||
)
|
||||
degraded = rule_degraded or mcp_degraded
|
||||
failed_tool_count = int(rule_degraded) + int(mcp_degraded)
|
||||
result = {
|
||||
"message": self._build_hermes_message(
|
||||
task_asset=task_asset,
|
||||
ontology=ontology,
|
||||
rule_response=rule_response,
|
||||
mcp_response=mcp_response,
|
||||
degraded=degraded,
|
||||
),
|
||||
"report_type": task_asset.code if task_asset is not None else "hermes_runtime",
|
||||
"degraded": degraded,
|
||||
}
|
||||
return ExecutionOutcome(
|
||||
status=AgentRunStatus.SUCCEEDED.value,
|
||||
result=result,
|
||||
degraded=degraded,
|
||||
tool_count=2,
|
||||
failed_tool_count=failed_tool_count,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_next_step(ontology: OntologyParseResult, source: str) -> str:
|
||||
if ontology.clarification_required:
|
||||
return "ask_clarification"
|
||||
if ontology.intent == "draft":
|
||||
return "create_draft"
|
||||
if ontology.scenario == "knowledge" or ontology.intent == "explain":
|
||||
return "search_knowledge"
|
||||
if ontology.intent == "risk_check" or source == AgentRunSource.SCHEDULE.value:
|
||||
return "run_rule"
|
||||
if ontology.intent in {"query", "compare"}:
|
||||
return "query_database"
|
||||
return "create_draft"
|
||||
|
||||
@staticmethod
|
||||
def _flatten_capability_codes(
|
||||
capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]],
|
||||
) -> list[str]:
|
||||
codes: list[str] = []
|
||||
for items in capabilities.values():
|
||||
for item in items[:2]:
|
||||
if item.code not in codes:
|
||||
codes.append(item.code)
|
||||
return codes
|
||||
|
||||
def _rank_assets(
|
||||
self,
|
||||
items: list[AgentAssetListItem],
|
||||
ontology: OntologyParseResult,
|
||||
) -> list[AgentAssetListItem]:
|
||||
def score(item: AgentAssetListItem) -> tuple[int, str]:
|
||||
item_tags = {str(value) for value in item.scenario_json or []}
|
||||
weight = 0
|
||||
if ontology.scenario in item_tags:
|
||||
weight += 3
|
||||
if ontology.intent in item_tags:
|
||||
weight += 2
|
||||
for risk_flag in ontology.risk_flags:
|
||||
if risk_flag in item_tags:
|
||||
weight += 4
|
||||
return weight, item.code
|
||||
|
||||
ranked = sorted(items, key=score, reverse=True)
|
||||
if not ranked:
|
||||
return []
|
||||
scored = [item for item in ranked if score(item)[0] > 0]
|
||||
return scored or ranked[:1]
|
||||
|
||||
def _invoke_tool(
|
||||
self,
|
||||
*,
|
||||
run_id: str,
|
||||
tool_type: str,
|
||||
tool_name: str,
|
||||
request_json: dict[str, Any],
|
||||
context_json: dict[str, Any],
|
||||
executor,
|
||||
fallback_factory,
|
||||
) -> tuple[dict[str, Any], bool]:
|
||||
started = perf_counter()
|
||||
try:
|
||||
self._maybe_raise_simulated_failure(tool_type, context_json)
|
||||
response = executor()
|
||||
duration_ms = int((perf_counter() - started) * 1000)
|
||||
self.run_service.record_tool_call(
|
||||
run_id=run_id,
|
||||
tool_type=tool_type,
|
||||
tool_name=tool_name,
|
||||
request_json=request_json,
|
||||
response_json=response,
|
||||
status="succeeded",
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
return response, False
|
||||
except Exception as exc:
|
||||
duration_ms = int((perf_counter() - started) * 1000)
|
||||
response = fallback_factory(exc)
|
||||
self.run_service.record_tool_call(
|
||||
run_id=run_id,
|
||||
tool_type=tool_type,
|
||||
tool_name=tool_name,
|
||||
request_json=request_json,
|
||||
response_json=response,
|
||||
status="failed",
|
||||
duration_ms=duration_ms,
|
||||
error_message=str(exc),
|
||||
)
|
||||
return response, True
|
||||
|
||||
@staticmethod
|
||||
def _maybe_raise_simulated_failure(tool_type: str, context_json: dict[str, Any]) -> None:
|
||||
expected = str(context_json.get("simulate_tool_failure") or "").strip().lower()
|
||||
if not expected:
|
||||
return
|
||||
if expected == tool_type.lower():
|
||||
raise RuntimeError(f"simulated {tool_type} failure")
|
||||
|
||||
def _build_database_answer(self, ontology: OntologyParseResult) -> dict[str, Any]:
|
||||
if ontology.scenario == "expense":
|
||||
count_stmt = select(func.count()).select_from(ExpenseClaim)
|
||||
amount_stmt = select(
|
||||
func.coalesce(func.sum(ExpenseClaim.amount), 0)
|
||||
).select_from(ExpenseClaim)
|
||||
employee_names = [
|
||||
item.normalized_value
|
||||
for item in ontology.entities
|
||||
if item.type == "employee"
|
||||
]
|
||||
if employee_names:
|
||||
count_stmt = count_stmt.where(ExpenseClaim.employee_name.in_(employee_names))
|
||||
amount_stmt = amount_stmt.where(ExpenseClaim.employee_name.in_(employee_names))
|
||||
total_count = int(self.db.scalar(count_stmt) or 0)
|
||||
total_amount = float(self.db.scalar(amount_stmt) or 0)
|
||||
return {
|
||||
"record_count": total_count,
|
||||
"total_amount": round(total_amount, 2),
|
||||
}
|
||||
|
||||
if ontology.scenario == "accounts_receivable":
|
||||
total_count = int(
|
||||
self.db.scalar(
|
||||
select(func.count()).select_from(AccountsReceivableRecord)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
total_amount = float(
|
||||
self.db.scalar(
|
||||
select(func.coalesce(func.sum(AccountsReceivableRecord.amount_outstanding), 0))
|
||||
)
|
||||
or 0
|
||||
)
|
||||
return {
|
||||
"record_count": total_count,
|
||||
"outstanding_amount": round(total_amount, 2),
|
||||
}
|
||||
|
||||
total_count = int(
|
||||
self.db.scalar(select(func.count()).select_from(AccountsPayableRecord))
|
||||
or 0
|
||||
)
|
||||
total_amount = float(
|
||||
self.db.scalar(
|
||||
select(func.coalesce(func.sum(AccountsPayableRecord.amount_outstanding), 0))
|
||||
)
|
||||
or 0
|
||||
)
|
||||
return {
|
||||
"record_count": total_count,
|
||||
"outstanding_amount": round(total_amount, 2),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_user_query_result(
|
||||
ontology: OntologyParseResult,
|
||||
response: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
if ontology.scenario == "expense":
|
||||
return {
|
||||
"message": (
|
||||
f"已路由到 User Agent,占位查询结果:命中 {response['record_count']} 笔报销,"
|
||||
f"金额合计 {response['total_amount']} 元。"
|
||||
),
|
||||
"data": response,
|
||||
}
|
||||
if ontology.scenario == "accounts_receivable":
|
||||
return {
|
||||
"message": (
|
||||
f"已路由到 User Agent,占位查询结果:命中 {response['record_count']} 条应收,"
|
||||
f"未回款金额 {response['outstanding_amount']} 元。"
|
||||
),
|
||||
"data": response,
|
||||
}
|
||||
return {
|
||||
"message": (
|
||||
f"已路由到 User Agent,占位查询结果:命中 {response['record_count']} 条应付,"
|
||||
f"待付金额 {response['outstanding_amount']} 元。"
|
||||
),
|
||||
"data": response,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_user_agent_result(
|
||||
response: UserAgentResponse,
|
||||
*,
|
||||
degraded: bool,
|
||||
) -> dict[str, Any]:
|
||||
result = {
|
||||
"message": response.answer,
|
||||
"answer": response.answer,
|
||||
"citations": [item.model_dump() for item in response.citations],
|
||||
"suggested_actions": [item.model_dump() for item in response.suggested_actions],
|
||||
"risk_flags": response.risk_flags,
|
||||
"requires_confirmation": response.requires_confirmation,
|
||||
"degraded": degraded,
|
||||
}
|
||||
if response.draft_payload is not None:
|
||||
result["draft_payload"] = response.draft_payload.model_dump()
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _build_knowledge_answer(
|
||||
ontology: OntologyParseResult,
|
||||
capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]],
|
||||
) -> dict[str, Any]:
|
||||
referenced = [item.code for item in capabilities["rules"][:1]] or [
|
||||
"knowledge.policy.default"
|
||||
]
|
||||
return {
|
||||
"message": f"已路由到 User Agent,占位知识结果:建议先查看 {', '.join(referenced)}。",
|
||||
"references": referenced,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_rule_answer(ontology: OntologyParseResult) -> dict[str, Any]:
|
||||
risk_text = (
|
||||
"、".join(ontology.risk_flags)
|
||||
if ontology.risk_flags
|
||||
else "未识别到明确风险标签"
|
||||
)
|
||||
return {
|
||||
"message": f"已完成占位规则检查,风险标签:{risk_text}。",
|
||||
"risk_flags": ontology.risk_flags,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_mcp_answer(
|
||||
task_asset: AgentAssetRead | None,
|
||||
ontology: OntologyParseResult,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"message": (
|
||||
f"已调用占位 MCP 快照,任务={task_asset.code if task_asset else 'none'},"
|
||||
f"scenario={ontology.scenario}。"
|
||||
),
|
||||
"snapshot": "stubbed",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_hermes_message(
|
||||
*,
|
||||
task_asset: AgentAssetRead | None,
|
||||
ontology: OntologyParseResult,
|
||||
rule_response: dict[str, Any],
|
||||
mcp_response: dict[str, Any],
|
||||
degraded: bool,
|
||||
) -> str:
|
||||
task_code = task_asset.code if task_asset is not None else "task.unspecified"
|
||||
suffix = ",其中部分能力已降级。" if degraded else "。"
|
||||
return (
|
||||
f"Hermes 占位执行完成:任务 {task_code},"
|
||||
f"场景 {ontology.scenario},规则结果={rule_response.get('message', '')},"
|
||||
f"MCP 结果={mcp_response.get('message', '')}{suffix}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _database_tool_name(scenario: str) -> str:
|
||||
if scenario == "expense":
|
||||
return "database.expense_claims.lookup"
|
||||
if scenario == "accounts_receivable":
|
||||
return "database.accounts_receivable.lookup"
|
||||
return "database.accounts_payable.lookup"
|
||||
|
||||
@staticmethod
|
||||
def _rule_tool_name(
|
||||
capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]],
|
||||
) -> str:
|
||||
if capabilities["rules"]:
|
||||
return capabilities["rules"][0].code
|
||||
return "rule_engine.default_risk_check"
|
||||
|
||||
@staticmethod
|
||||
def _mcp_tool_name(
|
||||
capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]],
|
||||
) -> str:
|
||||
if capabilities["mcps"]:
|
||||
return capabilities["mcps"][0].code
|
||||
return "mcp.default_snapshot"
|
||||
|
||||
@staticmethod
|
||||
def _build_ontology_json(ontology: OntologyParseResult) -> dict[str, Any]:
|
||||
return {
|
||||
"scenario": ontology.scenario,
|
||||
"intent": ontology.intent,
|
||||
"entities": [item.model_dump() for item in ontology.entities],
|
||||
"time_range": ontology.time_range.model_dump(),
|
||||
"metrics": [item.model_dump() for item in ontology.metrics],
|
||||
"constraints": [item.model_dump() for item in ontology.constraints],
|
||||
"risk_flags": ontology.risk_flags,
|
||||
"permission": ontology.permission.model_dump(),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _normalize_response_status(status: str) -> str:
|
||||
if status == AgentRunStatus.FAILED.value:
|
||||
return "failed"
|
||||
if status == AgentRunStatus.BLOCKED.value:
|
||||
return "blocked"
|
||||
return "succeeded"
|
||||
252
server/src/app/services/runtime_chat.py
Normal file
252
server/src/app/services/runtime_chat.py
Normal file
@@ -0,0 +1,252 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.services.model_connectivity import (
|
||||
AZURE_API_VERSION,
|
||||
ConnectivityCheckError,
|
||||
_build_azure_deployment_base,
|
||||
_build_headers,
|
||||
_ensure_path,
|
||||
_normalize_endpoint,
|
||||
_send_json_request,
|
||||
)
|
||||
from app.services.settings import SettingsService
|
||||
|
||||
logger = get_logger("app.services.runtime_chat")
|
||||
|
||||
|
||||
class RuntimeChatService:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
self.settings_service = SettingsService(db)
|
||||
|
||||
def complete(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
*,
|
||||
slot_priority: tuple[str, ...] = ("main", "backup"),
|
||||
max_tokens: int = 500,
|
||||
temperature: float = 0.2,
|
||||
) -> str | None:
|
||||
for slot in slot_priority:
|
||||
config = self._load_chat_slot(slot)
|
||||
if config is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
response_text = self._request_chat_completion(
|
||||
config,
|
||||
messages,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Runtime chat request failed slot=%s provider=%s: %s",
|
||||
slot,
|
||||
config["provider"],
|
||||
exc,
|
||||
)
|
||||
continue
|
||||
|
||||
if response_text:
|
||||
return response_text.strip()
|
||||
|
||||
return None
|
||||
|
||||
def _load_chat_slot(self, slot: str) -> dict[str, str] | None:
|
||||
try:
|
||||
config = self.settings_service.get_runtime_model_config(slot)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
if config["capability"] != "chat":
|
||||
return None
|
||||
|
||||
provider = str(config["provider"] or "").strip()
|
||||
endpoint = str(config["endpoint"] or "").strip()
|
||||
model = str(config["model"] or "").strip()
|
||||
api_key = str(config["apiKey"] or "").strip()
|
||||
|
||||
if not provider or not endpoint or not model:
|
||||
return None
|
||||
|
||||
if provider != "Ollama" and not api_key:
|
||||
logger.info("Skip runtime chat slot=%s because api key is empty", slot)
|
||||
return None
|
||||
|
||||
return {
|
||||
"slot": slot,
|
||||
"provider": provider,
|
||||
"endpoint": endpoint,
|
||||
"model": model,
|
||||
"apiKey": api_key,
|
||||
}
|
||||
|
||||
def _request_chat_completion(
|
||||
self,
|
||||
config: dict[str, str],
|
||||
messages: list[dict[str, str]],
|
||||
*,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
) -> str:
|
||||
provider = config["provider"]
|
||||
endpoint = config["endpoint"]
|
||||
model = config["model"]
|
||||
api_key = config["apiKey"]
|
||||
|
||||
if provider == "Azure OpenAI":
|
||||
return self._request_azure_openai(
|
||||
endpoint=endpoint,
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
if provider == "Ollama":
|
||||
return self._request_ollama(
|
||||
endpoint=endpoint,
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
return self._request_openai_compatible(
|
||||
endpoint=endpoint,
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
def _request_openai_compatible(
|
||||
self,
|
||||
*,
|
||||
endpoint: str,
|
||||
model: str,
|
||||
api_key: str,
|
||||
messages: list[dict[str, str]],
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
) -> str:
|
||||
url = _ensure_path(_normalize_endpoint(endpoint), "chat/completions")
|
||||
status_code, payload = _send_json_request(
|
||||
"POST",
|
||||
url,
|
||||
headers=_build_headers(api_key=api_key, use_bearer=True),
|
||||
payload={
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
},
|
||||
)
|
||||
if status_code >= HTTPStatus.BAD_REQUEST:
|
||||
raise ConnectivityCheckError(
|
||||
f"模型接口返回异常状态 {status_code}。",
|
||||
status_code=status_code,
|
||||
)
|
||||
return self._extract_openai_text(payload)
|
||||
|
||||
def _request_ollama(
|
||||
self,
|
||||
*,
|
||||
endpoint: str,
|
||||
model: str,
|
||||
api_key: str,
|
||||
messages: list[dict[str, str]],
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
) -> str:
|
||||
url = _ensure_path(_normalize_endpoint(endpoint), "api/chat")
|
||||
status_code, payload = _send_json_request(
|
||||
"POST",
|
||||
url,
|
||||
headers=_build_headers(api_key=api_key, use_bearer=False),
|
||||
payload={
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"stream": False,
|
||||
"options": {
|
||||
"num_predict": max_tokens,
|
||||
"temperature": temperature,
|
||||
},
|
||||
},
|
||||
)
|
||||
if status_code >= HTTPStatus.BAD_REQUEST:
|
||||
raise ConnectivityCheckError(
|
||||
f"Ollama 返回异常状态 {status_code}。",
|
||||
status_code=status_code,
|
||||
)
|
||||
return str((payload or {}).get("message", {}).get("content", "")).strip()
|
||||
|
||||
def _request_azure_openai(
|
||||
self,
|
||||
*,
|
||||
endpoint: str,
|
||||
model: str,
|
||||
api_key: str,
|
||||
messages: list[dict[str, str]],
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
) -> str:
|
||||
deployment_base = _build_azure_deployment_base(endpoint, model)
|
||||
url = f"{deployment_base}/chat/completions?api-version={AZURE_API_VERSION}"
|
||||
status_code, payload = _send_json_request(
|
||||
"POST",
|
||||
url,
|
||||
headers=_build_headers(api_key=api_key, use_bearer=False, use_api_key=True),
|
||||
payload={
|
||||
"messages": messages,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
},
|
||||
)
|
||||
if status_code >= HTTPStatus.BAD_REQUEST:
|
||||
raise ConnectivityCheckError(
|
||||
f"Azure OpenAI 返回异常状态 {status_code}。",
|
||||
status_code=status_code,
|
||||
)
|
||||
return self._extract_openai_text(payload)
|
||||
|
||||
@staticmethod
|
||||
def _extract_openai_text(payload: Any) -> str:
|
||||
if not isinstance(payload, dict):
|
||||
return ""
|
||||
|
||||
choices = payload.get("choices")
|
||||
if not isinstance(choices, list) or not choices:
|
||||
return ""
|
||||
|
||||
first_choice = choices[0]
|
||||
if not isinstance(first_choice, dict):
|
||||
return ""
|
||||
|
||||
message = first_choice.get("message")
|
||||
if isinstance(message, dict):
|
||||
content = message.get("content", "")
|
||||
if isinstance(content, str):
|
||||
return content.strip()
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for item in content:
|
||||
if isinstance(item, dict) and item.get("type") == "text":
|
||||
parts.append(str(item.get("text", "")))
|
||||
return "\n".join(part.strip() for part in parts if part.strip()).strip()
|
||||
|
||||
text = first_choice.get("text")
|
||||
if isinstance(text, str):
|
||||
return text.strip()
|
||||
|
||||
return ""
|
||||
547
server/src/app/services/user_agent.py
Normal file
547
server/src/app/services/user_agent.py
Normal file
@@ -0,0 +1,547 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.agent_enums import AgentAssetStatus, AgentAssetType
|
||||
from app.schemas.agent_asset import AgentAssetListItem
|
||||
from app.schemas.user_agent import (
|
||||
UserAgentCitation,
|
||||
UserAgentDraftPayload,
|
||||
UserAgentRequest,
|
||||
UserAgentResponse,
|
||||
UserAgentSuggestedAction,
|
||||
)
|
||||
from app.services.agent_assets import AgentAssetService
|
||||
from app.services.agent_foundation import AgentFoundationService
|
||||
from app.services.runtime_chat import RuntimeChatService
|
||||
|
||||
SCENARIO_LABELS = {
|
||||
"expense": "报销",
|
||||
"accounts_receivable": "应收",
|
||||
"accounts_payable": "应付",
|
||||
"knowledge": "知识",
|
||||
"unknown": "通用",
|
||||
}
|
||||
|
||||
RISK_REASON_MAP = {
|
||||
"duplicate_expense": "检测到同员工、同金额或近似单据存在重复提交迹象。",
|
||||
"amount_over_limit": "金额超过当前制度或预算阈值,需要补充例外说明。",
|
||||
"invoice_anomaly": "票据或附件完整性不满足当前规则要求,需要补件或人工复核。",
|
||||
"ar_overdue": "应收账款已出现逾期,存在回款延迟风险。",
|
||||
"ap_overdue": "应付付款已出现逾期,可能影响供应商履约或合作关系。",
|
||||
}
|
||||
|
||||
GENERIC_EXPENSE_PROMPTS = {
|
||||
"报销",
|
||||
"我要报销",
|
||||
"我想报销",
|
||||
"帮我报销",
|
||||
"我要申请报销",
|
||||
"发起报销",
|
||||
"提交报销",
|
||||
}
|
||||
|
||||
EXPLICIT_DRAFT_KEYWORDS = ("生成", "草稿", "起草", "创建", "发起", "准备")
|
||||
|
||||
EXPENSE_TYPE_LABELS = {
|
||||
"travel": "差旅",
|
||||
"hotel": "住宿",
|
||||
"transport": "交通",
|
||||
"meal": "餐费",
|
||||
"meeting": "会务",
|
||||
"entertainment": "招待",
|
||||
}
|
||||
|
||||
|
||||
class UserAgentService:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
self.asset_service = AgentAssetService(db)
|
||||
self.runtime_chat_service = RuntimeChatService(db)
|
||||
|
||||
def respond(self, payload: UserAgentRequest) -> UserAgentResponse:
|
||||
AgentFoundationService(self.db).ensure_foundation_ready()
|
||||
citations = self._build_rule_citations(payload)
|
||||
suggested_actions = self._build_suggested_actions(payload)
|
||||
risk_flags = self._resolve_risk_flags(payload)
|
||||
draft_payload = (
|
||||
self._build_draft_payload(payload)
|
||||
if payload.ontology.intent == "draft"
|
||||
else None
|
||||
)
|
||||
|
||||
if payload.degraded and payload.tool_payload.get("message"):
|
||||
return UserAgentResponse(
|
||||
answer=str(payload.tool_payload["message"]),
|
||||
citations=citations,
|
||||
suggested_actions=suggested_actions,
|
||||
risk_flags=risk_flags,
|
||||
requires_confirmation=payload.requires_confirmation,
|
||||
)
|
||||
|
||||
guided_answer = self._build_guided_answer(payload)
|
||||
if guided_answer:
|
||||
return UserAgentResponse(
|
||||
answer=guided_answer,
|
||||
citations=citations,
|
||||
suggested_actions=suggested_actions,
|
||||
draft_payload=draft_payload,
|
||||
risk_flags=risk_flags,
|
||||
requires_confirmation=payload.requires_confirmation,
|
||||
)
|
||||
|
||||
fallback_answer = self._build_fallback_answer(
|
||||
payload,
|
||||
citations=citations,
|
||||
draft_payload=draft_payload,
|
||||
)
|
||||
answer = self._generate_answer_with_model(
|
||||
payload,
|
||||
citations=citations,
|
||||
suggested_actions=suggested_actions,
|
||||
risk_flags=risk_flags,
|
||||
draft_payload=draft_payload,
|
||||
fallback_answer=fallback_answer,
|
||||
)
|
||||
|
||||
return UserAgentResponse(
|
||||
answer=answer or fallback_answer,
|
||||
citations=citations,
|
||||
suggested_actions=suggested_actions,
|
||||
draft_payload=draft_payload,
|
||||
risk_flags=risk_flags,
|
||||
requires_confirmation=payload.requires_confirmation,
|
||||
)
|
||||
|
||||
def _build_fallback_answer(
|
||||
self,
|
||||
payload: UserAgentRequest,
|
||||
*,
|
||||
citations: list[UserAgentCitation],
|
||||
draft_payload: UserAgentDraftPayload | None,
|
||||
) -> str:
|
||||
if payload.ontology.intent in {"query", "compare"}:
|
||||
return self._build_query_answer(payload)
|
||||
|
||||
if payload.ontology.intent == "risk_check":
|
||||
return self._build_risk_answer(payload, citations)
|
||||
|
||||
if payload.ontology.intent == "draft" and draft_payload is not None:
|
||||
return (
|
||||
f"已生成 {draft_payload.title},当前仅返回待人工确认的草稿内容,"
|
||||
"仍需人工确认后再进入正式流程。"
|
||||
)
|
||||
|
||||
return self._build_explain_answer(payload, citations)
|
||||
|
||||
def _build_guided_answer(self, payload: UserAgentRequest) -> str | None:
|
||||
if not self._is_generic_expense_prompt(payload):
|
||||
return self._build_implicit_expense_draft_guidance(payload)
|
||||
|
||||
attachment_names = self._resolve_attachment_names(payload)
|
||||
attachment_hint = ""
|
||||
if attachment_names:
|
||||
attachment_hint = (
|
||||
f" 我已带入 {len(attachment_names)} 份附件名称,但目前还不能直接读取附件内容,"
|
||||
"仍需要你补充关键信息。"
|
||||
)
|
||||
|
||||
return (
|
||||
"可以帮你发起报销。请补充费用类型、发生时间、金额、事由和相关对象,"
|
||||
"或者直接上传票据附件,我再继续帮你判断能否报、缺什么材料以及生成报销草稿。"
|
||||
f"{attachment_hint}"
|
||||
)
|
||||
|
||||
def _build_implicit_expense_draft_guidance(
|
||||
self,
|
||||
payload: UserAgentRequest,
|
||||
) -> str | None:
|
||||
if not self._is_implicit_expense_draft_request(payload):
|
||||
return None
|
||||
|
||||
amount_text = next(
|
||||
(item.value for item in payload.ontology.entities if item.type == "amount"),
|
||||
"",
|
||||
)
|
||||
expense_type = next(
|
||||
(
|
||||
EXPENSE_TYPE_LABELS.get(item.normalized_value, item.value)
|
||||
for item in payload.ontology.entities
|
||||
if item.type == "expense_type"
|
||||
),
|
||||
"报销",
|
||||
)
|
||||
time_text = payload.ontology.time_range.raw or "本次"
|
||||
amount_hint = f",金额 {amount_text}" if amount_text else ""
|
||||
|
||||
return (
|
||||
f"已识别到一笔{time_text}的{expense_type}支出{amount_hint}。"
|
||||
"如果要继续生成报销草稿,还需要补充客户单位、参与人员、费用明细和票据附件。"
|
||||
"你也可以继续上传发票或图片,我会把这些信息带入后续对话。"
|
||||
)
|
||||
|
||||
def _generate_answer_with_model(
|
||||
self,
|
||||
payload: UserAgentRequest,
|
||||
*,
|
||||
citations: list[UserAgentCitation],
|
||||
suggested_actions: list[UserAgentSuggestedAction],
|
||||
risk_flags: list[str],
|
||||
draft_payload: UserAgentDraftPayload | None,
|
||||
fallback_answer: str,
|
||||
) -> str | None:
|
||||
messages = self._build_model_messages(
|
||||
payload,
|
||||
citations=citations,
|
||||
suggested_actions=suggested_actions,
|
||||
risk_flags=risk_flags,
|
||||
draft_payload=draft_payload,
|
||||
fallback_answer=fallback_answer,
|
||||
)
|
||||
return self._sanitize_model_answer(
|
||||
self.runtime_chat_service.complete(
|
||||
messages,
|
||||
max_tokens=420,
|
||||
temperature=0.2,
|
||||
)
|
||||
)
|
||||
|
||||
def _sanitize_model_answer(self, answer: str | None) -> str | None:
|
||||
if not answer:
|
||||
return None
|
||||
|
||||
cleaned = re.sub(r"<think>.*?</think>", "", answer, flags=re.DOTALL | re.IGNORECASE)
|
||||
cleaned = cleaned.strip()
|
||||
return cleaned or None
|
||||
|
||||
def _build_model_messages(
|
||||
self,
|
||||
payload: UserAgentRequest,
|
||||
*,
|
||||
citations: list[UserAgentCitation],
|
||||
suggested_actions: list[UserAgentSuggestedAction],
|
||||
risk_flags: list[str],
|
||||
draft_payload: UserAgentDraftPayload | None,
|
||||
fallback_answer: str,
|
||||
) -> list[dict[str, str]]:
|
||||
facts = {
|
||||
"run_id": payload.run_id,
|
||||
"user_message": payload.message,
|
||||
"ontology": payload.ontology.model_dump(mode="json"),
|
||||
"context": {
|
||||
"entry_source": payload.context_json.get("entry_source"),
|
||||
"user_name": payload.context_json.get("name"),
|
||||
"user_role": payload.context_json.get("role"),
|
||||
"request_context": payload.context_json.get("request_context"),
|
||||
"attachment_count": payload.context_json.get("attachment_count"),
|
||||
"attachment_names": self._resolve_attachment_names(payload),
|
||||
},
|
||||
"tool_payload": payload.tool_payload,
|
||||
"citations": [item.model_dump(mode="json") for item in citations],
|
||||
"suggested_actions": [
|
||||
item.model_dump(mode="json") for item in suggested_actions
|
||||
],
|
||||
"risk_flags": risk_flags,
|
||||
"draft_payload": (
|
||||
draft_payload.model_dump(mode="json")
|
||||
if draft_payload is not None
|
||||
else None
|
||||
),
|
||||
"selected_capability_codes": payload.selected_capability_codes,
|
||||
"requires_confirmation": payload.requires_confirmation,
|
||||
"fallback_answer": fallback_answer,
|
||||
}
|
||||
|
||||
system_prompt = (
|
||||
"你是企业财务共享场景中的中文智能助手,负责和最终用户直接对话。"
|
||||
"你只能基于提供的事实回答,不能编造制度、流程结果或附件内容。"
|
||||
"如果用户问题很笼统,例如“我要报销”,优先告诉用户你可以协助什么,"
|
||||
"并明确要求补充费用类型、金额、时间、事由、参与对象或上传票据。"
|
||||
"如果上下文里只有附件名称,必须明确说明你只拿到了附件名称,"
|
||||
"不能假装已看过图片、PDF 或发票内容。"
|
||||
"不要声称已经提交、审批、付款、入账或真正执行了任何动作;如果只是建议、草稿或待确认,要明确说清楚。"
|
||||
"若给出了风险标签、制度引用或建议动作,可以简洁吸收进回答,但不要新增未提供的事实。"
|
||||
"只输出最终给用户看的自然语言,不要输出 JSON、Markdown、标题、"
|
||||
"<think> 标签或任何中间推理。"
|
||||
"使用简体中文,控制在 2 到 4 句。"
|
||||
)
|
||||
user_prompt = (
|
||||
"请根据以下事实生成最终答复,优先保持准确、具体、可执行:\n"
|
||||
f"{json.dumps(facts, ensure_ascii=False, indent=2)}"
|
||||
)
|
||||
return [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
]
|
||||
|
||||
def _build_query_answer(self, payload: UserAgentRequest) -> str:
|
||||
scenario = payload.ontology.scenario
|
||||
data = payload.tool_payload
|
||||
subject = self._resolve_subject(payload)
|
||||
|
||||
if scenario == "expense":
|
||||
record_count = int(data.get("record_count") or 0)
|
||||
total_amount = float(data.get("total_amount") or 0)
|
||||
return (
|
||||
f"{subject}共命中 {record_count} 笔报销,金额合计 {total_amount:.2f} 元。"
|
||||
"如需继续处理,可以查看明细或生成处理意见草稿。"
|
||||
)
|
||||
|
||||
if scenario == "accounts_receivable":
|
||||
record_count = int(data.get("record_count") or 0)
|
||||
outstanding_amount = float(data.get("outstanding_amount") or 0)
|
||||
return (
|
||||
f"{subject}共命中 {record_count} 条应收,未回款金额 {outstanding_amount:.2f} 元。"
|
||||
"建议结合账龄和客户分布继续排查逾期风险。"
|
||||
)
|
||||
|
||||
if scenario == "accounts_payable":
|
||||
record_count = int(data.get("record_count") or 0)
|
||||
outstanding_amount = float(data.get("outstanding_amount") or 0)
|
||||
return (
|
||||
f"{subject}共命中 {record_count} 条应付,待付金额 {outstanding_amount:.2f} 元。"
|
||||
"如需推进动作,建议先生成付款建议草稿并发起人工确认。"
|
||||
)
|
||||
|
||||
return "已完成当前查询,但暂时没有更多结构化结果可展示。"
|
||||
|
||||
def _build_explain_answer(
|
||||
self,
|
||||
payload: UserAgentRequest,
|
||||
citations: list[UserAgentCitation],
|
||||
) -> str:
|
||||
if citations:
|
||||
titles = "、".join(item.title for item in citations[:2])
|
||||
summary = citations[0].excerpt or "请结合制度全文进一步确认。"
|
||||
return f"已检索到相关依据:{titles}。核心说明:{summary}"
|
||||
|
||||
return (
|
||||
f"当前还没有与“{SCENARIO_LABELS.get(payload.ontology.scenario, '当前问题')}”"
|
||||
"强匹配的已上线规则引用,建议先人工复核或补充更具体的单据上下文。"
|
||||
)
|
||||
|
||||
def _build_risk_answer(
|
||||
self,
|
||||
payload: UserAgentRequest,
|
||||
citations: list[UserAgentCitation],
|
||||
) -> str:
|
||||
risk_flags = self._resolve_risk_flags(payload)
|
||||
if not risk_flags:
|
||||
return "当前未识别到明确风险标签,建议继续查看原始明细或补充更多上下文。"
|
||||
|
||||
reasons = [RISK_REASON_MAP.get(flag, f"{flag} 需要人工进一步确认。") for flag in risk_flags]
|
||||
citation_text = (
|
||||
f" 参考规则:{'、'.join(item.title for item in citations[:2])}。"
|
||||
if citations
|
||||
else ""
|
||||
)
|
||||
return (
|
||||
f"本次识别到 {len(risk_flags)} 类风险:{'、'.join(risk_flags)}。"
|
||||
f"触发原因:{';'.join(reasons)}。"
|
||||
"建议先复核明细、附件和审批链,再决定是否继续处理。"
|
||||
f"{citation_text}"
|
||||
)
|
||||
|
||||
def _build_draft_payload(self, payload: UserAgentRequest) -> UserAgentDraftPayload:
|
||||
scenario_label = SCENARIO_LABELS.get(payload.ontology.scenario, "业务")
|
||||
subject = self._resolve_subject(payload)
|
||||
title = f"{scenario_label}处理意见草稿"
|
||||
body = (
|
||||
f"主题:{subject}\n"
|
||||
"结论:已根据当前语义解析结果生成草稿,尚未自动执行。\n"
|
||||
"建议:请先核对明细、规则命中和所需附件,再由人工确认是否提交正式流程。\n"
|
||||
f"原始问题:{payload.message}"
|
||||
)
|
||||
return UserAgentDraftPayload(
|
||||
draft_type=payload.ontology.scenario,
|
||||
title=title,
|
||||
body=body,
|
||||
confirmation_required=True,
|
||||
)
|
||||
|
||||
def _build_suggested_actions(
|
||||
self,
|
||||
payload: UserAgentRequest,
|
||||
) -> list[UserAgentSuggestedAction]:
|
||||
if self._is_generic_expense_prompt(payload):
|
||||
return [
|
||||
UserAgentSuggestedAction(
|
||||
label="上传票据",
|
||||
action_type="ask_clarification",
|
||||
description="上传发票、行程单或付款截图,继续识别报销内容。",
|
||||
),
|
||||
UserAgentSuggestedAction(
|
||||
label="补充报销信息",
|
||||
action_type="ask_clarification",
|
||||
description="补充费用类型、金额、时间和事由后继续处理。",
|
||||
),
|
||||
]
|
||||
|
||||
if payload.ontology.intent in {"query", "compare"}:
|
||||
return [
|
||||
UserAgentSuggestedAction(
|
||||
label="查看明细",
|
||||
action_type="open_detail",
|
||||
description="继续查看命中记录和过滤条件。",
|
||||
),
|
||||
UserAgentSuggestedAction(
|
||||
label="生成处理意见",
|
||||
action_type="create_draft",
|
||||
description="把当前查询结果整理成可确认草稿。",
|
||||
),
|
||||
]
|
||||
|
||||
if payload.ontology.intent == "risk_check":
|
||||
return [
|
||||
UserAgentSuggestedAction(
|
||||
label="人工复核风险",
|
||||
action_type="manual_review",
|
||||
description="优先检查明细、附件和规则命中原因。",
|
||||
),
|
||||
UserAgentSuggestedAction(
|
||||
label="生成整改建议",
|
||||
action_type="create_draft",
|
||||
description="把风险说明整理成处理意见草稿。",
|
||||
),
|
||||
]
|
||||
|
||||
if payload.ontology.intent == "draft":
|
||||
return [
|
||||
UserAgentSuggestedAction(
|
||||
label="复制草稿",
|
||||
action_type="copy_draft",
|
||||
description="复制当前草稿后交由人工确认。",
|
||||
),
|
||||
UserAgentSuggestedAction(
|
||||
label="补充上下文",
|
||||
action_type="ask_clarification",
|
||||
description="补充单据编号、客户或供应商信息以完善草稿。",
|
||||
),
|
||||
]
|
||||
|
||||
return [
|
||||
UserAgentSuggestedAction(
|
||||
label="查看规则全文",
|
||||
action_type="open_rule",
|
||||
description="继续查看引用规则或知识内容。",
|
||||
),
|
||||
UserAgentSuggestedAction(
|
||||
label="补充问题上下文",
|
||||
action_type="ask_clarification",
|
||||
description="补充业务对象、时间或单据范围,提升回答准确度。",
|
||||
),
|
||||
]
|
||||
|
||||
def _build_rule_citations(self, payload: UserAgentRequest) -> list[UserAgentCitation]:
|
||||
domain = self._resolve_domain(payload.ontology.scenario)
|
||||
items = self.asset_service.list_assets(
|
||||
asset_type=AgentAssetType.RULE.value,
|
||||
status=AgentAssetStatus.ACTIVE.value,
|
||||
domain=domain,
|
||||
)
|
||||
ranked = self._rank_rule_assets(items, payload)
|
||||
citations: list[UserAgentCitation] = []
|
||||
for item in ranked[:2]:
|
||||
detail = self.asset_service.get_asset(item.id)
|
||||
if detail is None:
|
||||
continue
|
||||
excerpt = self._extract_excerpt(str(detail.current_version_content or ""))
|
||||
citations.append(
|
||||
UserAgentCitation(
|
||||
source_type="rule",
|
||||
code=detail.code,
|
||||
title=detail.name,
|
||||
version=detail.current_version,
|
||||
updated_at=detail.updated_at.date().isoformat(),
|
||||
excerpt=excerpt,
|
||||
)
|
||||
)
|
||||
return citations
|
||||
|
||||
@staticmethod
|
||||
def _resolve_risk_flags(payload: UserAgentRequest) -> list[str]:
|
||||
tool_flags = payload.tool_payload.get("risk_flags")
|
||||
if isinstance(tool_flags, list) and tool_flags:
|
||||
return [str(item) for item in tool_flags]
|
||||
return [str(item) for item in payload.ontology.risk_flags]
|
||||
|
||||
@staticmethod
|
||||
def _resolve_subject(payload: UserAgentRequest) -> str:
|
||||
named_entities = [
|
||||
item.value
|
||||
for item in payload.ontology.entities
|
||||
if item.type in {"employee", "customer", "vendor", "project"}
|
||||
]
|
||||
if named_entities:
|
||||
return f"{'、'.join(named_entities)} 相关数据"
|
||||
return f"{SCENARIO_LABELS.get(payload.ontology.scenario, '当前')}场景数据"
|
||||
|
||||
@staticmethod
|
||||
def _is_generic_expense_prompt(payload: UserAgentRequest) -> bool:
|
||||
if payload.ontology.scenario != "expense":
|
||||
return False
|
||||
normalized_message = re.sub(r"\s+", "", payload.message)
|
||||
return normalized_message in GENERIC_EXPENSE_PROMPTS
|
||||
|
||||
@staticmethod
|
||||
def _is_implicit_expense_draft_request(payload: UserAgentRequest) -> bool:
|
||||
if payload.ontology.scenario != "expense" or payload.ontology.intent != "draft":
|
||||
return False
|
||||
|
||||
compact_message = re.sub(r"\s+", "", payload.message)
|
||||
if any(keyword in compact_message for keyword in EXPLICIT_DRAFT_KEYWORDS):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _resolve_attachment_names(payload: UserAgentRequest) -> list[str]:
|
||||
names = payload.context_json.get("attachment_names")
|
||||
if not isinstance(names, list):
|
||||
return []
|
||||
return [str(name) for name in names if str(name).strip()]
|
||||
|
||||
@staticmethod
|
||||
def _resolve_domain(scenario: str) -> str | None:
|
||||
if scenario == "expense":
|
||||
return "expense"
|
||||
if scenario == "accounts_receivable":
|
||||
return "ar"
|
||||
if scenario == "accounts_payable":
|
||||
return "ap"
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _rank_rule_assets(
|
||||
items: list[AgentAssetListItem],
|
||||
payload: UserAgentRequest,
|
||||
) -> list[AgentAssetListItem]:
|
||||
def score(item: AgentAssetListItem) -> tuple[int, str]:
|
||||
tags = {str(value) for value in item.scenario_json or []}
|
||||
weight = 0
|
||||
if payload.ontology.scenario in tags:
|
||||
weight += 3
|
||||
if payload.ontology.intent in tags:
|
||||
weight += 2
|
||||
for risk_flag in payload.ontology.risk_flags:
|
||||
if risk_flag in tags:
|
||||
weight += 4
|
||||
return weight, item.code
|
||||
|
||||
ranked = sorted(items, key=score, reverse=True)
|
||||
return [item for item in ranked if score(item)[0] > 0]
|
||||
|
||||
@staticmethod
|
||||
def _extract_excerpt(content: str) -> str:
|
||||
lines = [line.strip() for line in str(content).splitlines() if line.strip()]
|
||||
cleaned: list[str] = []
|
||||
for line in lines:
|
||||
normalized = re.sub(r"^[#>\-\*\d\.\s`]+", "", line).strip()
|
||||
if normalized:
|
||||
cleaned.append(normalized)
|
||||
if len(cleaned) >= 2:
|
||||
break
|
||||
return ";".join(cleaned[:2])
|
||||
397
server/tests/test_ontology_service.py
Normal file
397
server/tests/test_ontology_service.py
Normal file
@@ -0,0 +1,397 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.api.deps import get_db
|
||||
from app.db.base import Base
|
||||
from app.main import create_app
|
||||
from app.schemas.ontology import OntologyParseRequest
|
||||
from app.services.ontology import LlmOntologyParseResult, SemanticOntologyService
|
||||
|
||||
|
||||
def build_session_factory() -> sessionmaker[Session]:
|
||||
engine = create_engine(
|
||||
"sqlite+pysqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
return sessionmaker(bind=engine, autoflush=False, autocommit=False)
|
||||
|
||||
|
||||
def build_client() -> tuple[TestClient, sessionmaker[Session]]:
|
||||
session_factory = build_session_factory()
|
||||
app = create_app()
|
||||
|
||||
def override_db() -> Generator[Session, None, None]:
|
||||
db = session_factory()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
app.dependency_overrides[get_db] = override_db
|
||||
return TestClient(app), session_factory
|
||||
|
||||
|
||||
EVALUATION_CASES = [
|
||||
pytest.param(
|
||||
"查一下本周报销超标风险",
|
||||
"expense",
|
||||
"risk_check",
|
||||
"read",
|
||||
{},
|
||||
id="expense-risk-check",
|
||||
),
|
||||
pytest.param(
|
||||
"张三 4 月差旅报销金额是多少",
|
||||
"expense",
|
||||
"query",
|
||||
"read",
|
||||
{},
|
||||
id="expense-query-employee-month",
|
||||
),
|
||||
pytest.param(
|
||||
"为什么酒店超标报销不能直接通过",
|
||||
"expense",
|
||||
"explain",
|
||||
"read",
|
||||
{},
|
||||
id="expense-explain-policy",
|
||||
),
|
||||
pytest.param(
|
||||
"列出金额最高的10笔报销",
|
||||
"expense",
|
||||
"query",
|
||||
"read",
|
||||
{},
|
||||
id="expense-topn-query",
|
||||
),
|
||||
pytest.param(
|
||||
"帮我生成张三4月差旅报销草稿",
|
||||
"expense",
|
||||
"draft",
|
||||
"draft_write",
|
||||
{},
|
||||
id="expense-draft",
|
||||
),
|
||||
pytest.param(
|
||||
"我今天去客户现场,招待了客户,花销了1000元",
|
||||
"expense",
|
||||
"draft",
|
||||
"draft_write",
|
||||
{},
|
||||
id="expense-narrative-draft",
|
||||
),
|
||||
pytest.param(
|
||||
"客户 A 这个月还有多少应收",
|
||||
"accounts_receivable",
|
||||
"query",
|
||||
"read",
|
||||
{},
|
||||
id="ar-query-customer-month",
|
||||
),
|
||||
pytest.param(
|
||||
"对比客户A和客户B本月应收差异",
|
||||
"accounts_receivable",
|
||||
"compare",
|
||||
"read",
|
||||
{},
|
||||
id="ar-compare-customers",
|
||||
),
|
||||
pytest.param(
|
||||
"检查客户B逾期应收风险",
|
||||
"accounts_receivable",
|
||||
"risk_check",
|
||||
"read",
|
||||
{},
|
||||
id="ar-risk-check",
|
||||
),
|
||||
pytest.param(
|
||||
"生成客户A回款跟进草稿",
|
||||
"accounts_receivable",
|
||||
"draft",
|
||||
"draft_write",
|
||||
{},
|
||||
id="ar-draft",
|
||||
),
|
||||
pytest.param(
|
||||
"查询客户B账龄明细",
|
||||
"accounts_receivable",
|
||||
"query",
|
||||
"read",
|
||||
{},
|
||||
id="ar-aging-query",
|
||||
),
|
||||
pytest.param(
|
||||
"供应商 B 明天要付多少钱",
|
||||
"accounts_payable",
|
||||
"query",
|
||||
"read",
|
||||
{},
|
||||
id="ap-query-vendor-tomorrow",
|
||||
),
|
||||
pytest.param(
|
||||
"对比供应商A和供应商B本月应付差异",
|
||||
"accounts_payable",
|
||||
"compare",
|
||||
"read",
|
||||
{},
|
||||
id="ap-compare-vendors",
|
||||
),
|
||||
pytest.param(
|
||||
"检查供应商B逾期付款风险",
|
||||
"accounts_payable",
|
||||
"risk_check",
|
||||
"read",
|
||||
{},
|
||||
id="ap-risk-check",
|
||||
),
|
||||
pytest.param(
|
||||
"生成供应商A付款沟通草稿",
|
||||
"accounts_payable",
|
||||
"draft",
|
||||
"draft_write",
|
||||
{},
|
||||
id="ap-draft",
|
||||
),
|
||||
pytest.param(
|
||||
"帮我安排付款给供应商B",
|
||||
"accounts_payable",
|
||||
"operate",
|
||||
"approval_required",
|
||||
{"role_codes": ["finance"]},
|
||||
id="ap-operate-approval-required",
|
||||
),
|
||||
pytest.param(
|
||||
"公司财务制度在哪里看",
|
||||
"knowledge",
|
||||
"query",
|
||||
"read",
|
||||
{},
|
||||
id="knowledge-query",
|
||||
),
|
||||
pytest.param(
|
||||
"规则中心的审核依据是什么",
|
||||
"knowledge",
|
||||
"explain",
|
||||
"read",
|
||||
{},
|
||||
id="knowledge-explain",
|
||||
),
|
||||
pytest.param(
|
||||
"知识库里有没有双人复核制度",
|
||||
"knowledge",
|
||||
"query",
|
||||
"read",
|
||||
{},
|
||||
id="knowledge-query-library",
|
||||
),
|
||||
pytest.param(
|
||||
"帮我直接付款给供应商B",
|
||||
"accounts_payable",
|
||||
"operate",
|
||||
"forbidden",
|
||||
{"role_codes": ["user"]},
|
||||
id="forbidden-direct-payment",
|
||||
),
|
||||
pytest.param(
|
||||
"帮我上线付款双人复核规则",
|
||||
"accounts_payable",
|
||||
"operate",
|
||||
"forbidden",
|
||||
{"role_codes": ["user"]},
|
||||
id="forbidden-activate-rule",
|
||||
),
|
||||
pytest.param(
|
||||
"帮我删除今天的报销记录",
|
||||
"expense",
|
||||
"operate",
|
||||
"forbidden",
|
||||
{"role_codes": ["user"]},
|
||||
id="forbidden-delete-expense",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("query,scenario,intent,permission,context_json", EVALUATION_CASES)
|
||||
def test_semantic_ontology_service_matches_day3_evaluation_set(
|
||||
query: str,
|
||||
scenario: str,
|
||||
intent: str,
|
||||
permission: str,
|
||||
context_json: dict,
|
||||
) -> None:
|
||||
session_factory = build_session_factory()
|
||||
with session_factory() as db:
|
||||
result = SemanticOntologyService(db).parse(
|
||||
OntologyParseRequest(
|
||||
query=query,
|
||||
user_id="pytest",
|
||||
context_json=context_json,
|
||||
)
|
||||
)
|
||||
|
||||
assert result.scenario == scenario
|
||||
assert result.intent == intent
|
||||
assert result.permission.level == permission
|
||||
assert result.run_id.startswith("run_")
|
||||
|
||||
|
||||
def test_semantic_ontology_service_extracts_entities_time_and_constraints() -> None:
|
||||
session_factory = build_session_factory()
|
||||
with session_factory() as db:
|
||||
result = SemanticOntologyService(db).parse(
|
||||
OntologyParseRequest(
|
||||
query="张三 2026年4月差旅报销金额超过5000元的明细",
|
||||
user_id="pytest",
|
||||
)
|
||||
)
|
||||
|
||||
assert result.scenario == "expense"
|
||||
assert result.intent == "query"
|
||||
assert result.time_range.start_date == "2026-04-01"
|
||||
assert result.time_range.end_date == "2026-04-30"
|
||||
assert any(
|
||||
item.type == "employee" and item.normalized_value == "张三"
|
||||
for item in result.entities
|
||||
)
|
||||
assert any(
|
||||
item.type == "expense_type" and item.normalized_value == "travel"
|
||||
for item in result.entities
|
||||
)
|
||||
assert any(
|
||||
item.field == "amount" and item.operator == ">" and item.value == 5000
|
||||
for item in result.constraints
|
||||
)
|
||||
|
||||
|
||||
def test_semantic_ontology_service_prefers_expense_for_customer_entertainment_narrative() -> None:
|
||||
session_factory = build_session_factory()
|
||||
with session_factory() as db:
|
||||
result = SemanticOntologyService(db).parse(
|
||||
OntologyParseRequest(
|
||||
query="我今天去客户现场,招待了客户,花销了1000元",
|
||||
user_id="pytest",
|
||||
)
|
||||
)
|
||||
|
||||
assert result.scenario == "expense"
|
||||
assert result.intent == "draft"
|
||||
assert result.permission.level == "draft_write"
|
||||
assert result.time_range.raw == "今天"
|
||||
assert result.clarification_required is True
|
||||
assert "customer_name" in result.missing_slots
|
||||
assert "participants" in result.missing_slots
|
||||
assert any(
|
||||
item.type == "expense_type" and item.normalized_value == "entertainment"
|
||||
for item in result.entities
|
||||
)
|
||||
|
||||
|
||||
def test_semantic_ontology_service_uses_model_parse_when_available(monkeypatch) -> None:
|
||||
session_factory = build_session_factory()
|
||||
with session_factory() as db:
|
||||
service = SemanticOntologyService(db)
|
||||
monkeypatch.setattr(
|
||||
service,
|
||||
"_parse_with_model",
|
||||
lambda **kwargs: LlmOntologyParseResult(
|
||||
scenario="expense",
|
||||
intent="draft",
|
||||
confidence=0.91,
|
||||
clarification_required=True,
|
||||
clarification_question="请补充费用类型、金额和票据附件。",
|
||||
missing_slots=["expense_type", "amount", "attachments"],
|
||||
ambiguity=[],
|
||||
entity_hints=[],
|
||||
),
|
||||
)
|
||||
|
||||
result = service.parse(
|
||||
OntologyParseRequest(
|
||||
query="我要报销",
|
||||
user_id="pytest",
|
||||
)
|
||||
)
|
||||
|
||||
assert result.scenario == "expense"
|
||||
assert result.intent == "draft"
|
||||
assert result.parse_strategy == "llm_primary"
|
||||
assert result.clarification_required is True
|
||||
assert "expense_type" in result.missing_slots
|
||||
assert result.clarification_question == "请补充费用类型、金额和票据附件。"
|
||||
|
||||
|
||||
def test_parse_ontology_endpoint_returns_eight_fields_and_writes_trace() -> None:
|
||||
client, _ = build_client()
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/ontology/parse",
|
||||
json={
|
||||
"query": "查一下本周报销超标风险",
|
||||
"user_id": "pytest",
|
||||
"context_json": {"role_codes": ["finance"]},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["scenario"] == "expense"
|
||||
assert payload["intent"] == "risk_check"
|
||||
assert payload["permission"]["level"] == "read"
|
||||
assert payload["run_id"].startswith("run_")
|
||||
assert set(payload) >= {
|
||||
"scenario",
|
||||
"intent",
|
||||
"entities",
|
||||
"time_range",
|
||||
"metrics",
|
||||
"constraints",
|
||||
"risk_flags",
|
||||
"permission",
|
||||
"confidence",
|
||||
"missing_slots",
|
||||
"ambiguity",
|
||||
"parse_strategy",
|
||||
"clarification_required",
|
||||
"clarification_question",
|
||||
"run_id",
|
||||
"field_errors",
|
||||
}
|
||||
|
||||
run_response = client.get(f"/api/v1/agent-runs/{payload['run_id']}")
|
||||
|
||||
assert run_response.status_code == 200
|
||||
run_payload = run_response.json()
|
||||
assert run_payload["ontology_json"]["scenario"] == "expense"
|
||||
assert run_payload["ontology_json"]["intent"] == "risk_check"
|
||||
assert run_payload["semantic_parse"]["scenario"] == "expense"
|
||||
assert run_payload["semantic_parse"]["intent"] == "risk_check"
|
||||
|
||||
|
||||
def test_parse_ontology_endpoint_returns_forbidden_for_unprivileged_payment_request() -> None:
|
||||
client, _ = build_client()
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/ontology/parse",
|
||||
json={
|
||||
"query": "帮我直接付款给供应商B",
|
||||
"user_id": "pytest",
|
||||
"context_json": {"role_codes": ["user"]},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["scenario"] == "accounts_payable"
|
||||
assert payload["intent"] == "operate"
|
||||
assert payload["permission"]["level"] == "forbidden"
|
||||
assert payload["clarification_required"] is True
|
||||
assert payload["field_errors"]
|
||||
241
server/tests/test_orchestrator_service.py
Normal file
241
server/tests/test_orchestrator_service.py
Normal file
@@ -0,0 +1,241 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.api.deps import get_db
|
||||
from app.db.base import Base
|
||||
from app.main import create_app
|
||||
from app.services.agent_assets import AgentAssetService
|
||||
|
||||
|
||||
def build_client() -> tuple[TestClient, sessionmaker[Session]]:
|
||||
engine = create_engine(
|
||||
"sqlite+pysqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
session_factory = sessionmaker(bind=engine, autoflush=False, autocommit=False)
|
||||
app = create_app()
|
||||
|
||||
def override_db() -> Generator[Session, None, None]:
|
||||
db = session_factory()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
app.dependency_overrides[get_db] = override_db
|
||||
return TestClient(app), session_factory
|
||||
|
||||
|
||||
def test_orchestrator_routes_user_query_to_user_agent() -> None:
|
||||
client, _ = build_client()
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "user_message",
|
||||
"user_id": "pytest",
|
||||
"message": "客户A这个月还有多少应收",
|
||||
"context_json": {"role_codes": ["finance"]},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["selected_agent"] == "user_agent"
|
||||
assert payload["permission_level"] == "read"
|
||||
assert payload["status"] == "succeeded"
|
||||
assert payload["result"]["answer"]
|
||||
assert payload["result"]["suggested_actions"]
|
||||
assert payload["trace_summary"]["tool_count"] >= 1
|
||||
|
||||
run_detail = client.get(f"/api/v1/agent-runs/{payload['run_id']}").json()
|
||||
assert run_detail["agent"] == "user_agent"
|
||||
assert run_detail["route_json"]["selected_agent"] == "user_agent"
|
||||
assert run_detail["semantic_parse"]["scenario"] == "accounts_receivable"
|
||||
assert run_detail["tool_calls"][0]["tool_type"] == "database"
|
||||
|
||||
|
||||
def test_orchestrator_routes_schedule_to_hermes() -> None:
|
||||
client, session_factory = build_client()
|
||||
|
||||
with session_factory() as db:
|
||||
task = next(
|
||||
item
|
||||
for item in AgentAssetService(db).list_assets(asset_type="task", status="active")
|
||||
if item.code == "task.hermes.daily_risk_scan"
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "schedule",
|
||||
"task_id": task.id,
|
||||
"context_json": {"role_codes": ["finance"]},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["selected_agent"] == "hermes"
|
||||
assert payload["status"] == "succeeded"
|
||||
assert payload["trace_summary"]["tool_count"] == 2
|
||||
|
||||
run_detail = client.get(f"/api/v1/agent-runs/{payload['run_id']}").json()
|
||||
assert run_detail["agent"] == "hermes"
|
||||
assert run_detail["route_json"]["selected_agent"] == "hermes"
|
||||
assert len(run_detail["tool_calls"]) == 2
|
||||
|
||||
|
||||
def test_orchestrator_forbidden_request_does_not_call_downstream_agent() -> None:
|
||||
client, _ = build_client()
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "user_message",
|
||||
"user_id": "pytest",
|
||||
"message": "帮我直接付款给供应商B",
|
||||
"context_json": {"role_codes": ["user"]},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["selected_agent"] is None
|
||||
assert payload["permission_level"] == "forbidden"
|
||||
assert payload["status"] == "blocked"
|
||||
assert payload["trace_summary"]["tool_count"] == 0
|
||||
|
||||
run_detail = client.get(f"/api/v1/agent-runs/{payload['run_id']}").json()
|
||||
assert run_detail["agent"] == "orchestrator"
|
||||
assert run_detail["tool_calls"] == []
|
||||
|
||||
|
||||
def test_orchestrator_approval_required_returns_confirmation_result() -> None:
|
||||
client, _ = build_client()
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "user_message",
|
||||
"user_id": "pytest",
|
||||
"message": "帮我安排付款给供应商B",
|
||||
"context_json": {"role_codes": ["finance"]},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["selected_agent"] == "user_agent"
|
||||
assert payload["permission_level"] == "approval_required"
|
||||
assert payload["requires_confirmation"] is True
|
||||
assert payload["status"] == "blocked"
|
||||
assert "确认" in payload["result"]["message"]
|
||||
|
||||
|
||||
def test_orchestrator_user_agent_draft_returns_structured_payload() -> None:
|
||||
client, _ = build_client()
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "user_message",
|
||||
"user_id": "pytest",
|
||||
"message": "帮我生成张三4月差旅报销草稿",
|
||||
"context_json": {"role_codes": ["finance"]},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["selected_agent"] == "user_agent"
|
||||
assert payload["status"] == "succeeded"
|
||||
assert payload["result"]["draft_payload"]["confirmation_required"] is True
|
||||
assert payload["result"]["suggested_actions"]
|
||||
|
||||
|
||||
def test_orchestrator_treats_expense_narrative_as_draft_instead_of_ar_query() -> None:
|
||||
client, _ = build_client()
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "user_message",
|
||||
"user_id": "pytest",
|
||||
"message": "我今天去客户现场,招待了客户,花销了1000元",
|
||||
"context_json": {"role_codes": ["finance"]},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["selected_agent"] == "user_agent"
|
||||
assert payload["permission_level"] == "draft_write"
|
||||
assert payload["status"] == "blocked"
|
||||
assert payload["route_reason"] == "clarification_required"
|
||||
assert payload["trace_summary"]["scenario"] == "expense"
|
||||
assert payload["trace_summary"]["intent"] == "draft"
|
||||
assert payload["trace_summary"]["tool_count"] == 0
|
||||
assert "应收场景数据" not in payload["result"]["message"]
|
||||
assert "请补充" in payload["result"]["message"]
|
||||
|
||||
|
||||
def test_orchestrator_tool_failure_is_logged_and_degraded() -> None:
|
||||
client, _ = build_client()
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "user_message",
|
||||
"user_id": "pytest",
|
||||
"message": "查一下本周报销金额",
|
||||
"context_json": {
|
||||
"role_codes": ["finance"],
|
||||
"simulate_tool_failure": "database",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["selected_agent"] == "user_agent"
|
||||
assert payload["status"] == "succeeded"
|
||||
assert payload["trace_summary"]["failed_tool_count"] == 1
|
||||
assert payload["trace_summary"]["degraded"] is True
|
||||
|
||||
run_detail = client.get(f"/api/v1/agent-runs/{payload['run_id']}").json()
|
||||
assert run_detail["tool_calls"][0]["status"] == "failed"
|
||||
assert "simulated database failure" in run_detail["tool_calls"][0]["error_message"]
|
||||
|
||||
|
||||
def test_orchestrator_exception_is_written_to_agent_run() -> None:
|
||||
client, _ = build_client()
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "user_message",
|
||||
"user_id": "pytest",
|
||||
"message": "查一下本周报销金额",
|
||||
"context_json": {
|
||||
"role_codes": ["finance"],
|
||||
"simulate_orchestrator_exception": True,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["status"] == "failed"
|
||||
|
||||
run_detail = client.get(f"/api/v1/agent-runs/{payload['run_id']}").json()
|
||||
assert run_detail["status"] == "failed"
|
||||
assert "simulated orchestrator exception" in run_detail["error_message"]
|
||||
179
server/tests/test_user_agent_service.py
Normal file
179
server/tests/test_user_agent_service.py
Normal file
@@ -0,0 +1,179 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.db.base import Base
|
||||
from app.schemas.ontology import OntologyParseRequest
|
||||
from app.schemas.user_agent import UserAgentRequest
|
||||
from app.services.ontology import SemanticOntologyService
|
||||
from app.services.user_agent import UserAgentService
|
||||
|
||||
|
||||
def build_session_factory() -> sessionmaker[Session]:
|
||||
engine = create_engine(
|
||||
"sqlite+pysqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
return sessionmaker(bind=engine, autoflush=False, autocommit=False)
|
||||
|
||||
|
||||
def test_user_agent_query_returns_readable_answer_and_actions() -> None:
|
||||
session_factory = build_session_factory()
|
||||
with session_factory() as db:
|
||||
ontology = SemanticOntologyService(db).parse(
|
||||
OntologyParseRequest(
|
||||
query="张三 4 月差旅报销金额是多少",
|
||||
user_id="pytest",
|
||||
)
|
||||
)
|
||||
response = UserAgentService(db).respond(
|
||||
UserAgentRequest(
|
||||
run_id=ontology.run_id,
|
||||
user_id="pytest",
|
||||
message="张三 4 月差旅报销金额是多少",
|
||||
ontology=ontology,
|
||||
tool_payload={"record_count": 2, "total_amount": 8800.0},
|
||||
)
|
||||
)
|
||||
|
||||
assert "8800.00" in response.answer
|
||||
assert len(response.suggested_actions) >= 1
|
||||
|
||||
|
||||
def test_user_agent_prefers_runtime_model_answer_when_available(monkeypatch) -> None:
|
||||
session_factory = build_session_factory()
|
||||
with session_factory() as db:
|
||||
ontology = SemanticOntologyService(db).parse(
|
||||
OntologyParseRequest(
|
||||
query="张三 4 月差旅报销金额是多少",
|
||||
user_id="pytest",
|
||||
)
|
||||
)
|
||||
service = UserAgentService(db)
|
||||
monkeypatch.setattr(
|
||||
service,
|
||||
"_generate_answer_with_model",
|
||||
lambda *args, **kwargs: "这是模型回答",
|
||||
)
|
||||
|
||||
response = service.respond(
|
||||
UserAgentRequest(
|
||||
run_id=ontology.run_id,
|
||||
user_id="pytest",
|
||||
message="张三 4 月差旅报销金额是多少",
|
||||
ontology=ontology,
|
||||
tool_payload={"record_count": 2, "total_amount": 8800.0},
|
||||
)
|
||||
)
|
||||
|
||||
assert response.answer == "这是模型回答"
|
||||
|
||||
|
||||
def test_user_agent_sanitizes_model_thinking_blocks() -> None:
|
||||
session_factory = build_session_factory()
|
||||
with session_factory() as db:
|
||||
service = UserAgentService(db)
|
||||
|
||||
assert (
|
||||
service._sanitize_model_answer("<think>内部推理</think>\n最终答复")
|
||||
== "最终答复"
|
||||
)
|
||||
|
||||
|
||||
def test_user_agent_guides_generic_expense_request() -> None:
|
||||
session_factory = build_session_factory()
|
||||
with session_factory() as db:
|
||||
ontology = SemanticOntologyService(db).parse(
|
||||
OntologyParseRequest(
|
||||
query="我要报销",
|
||||
user_id="pytest",
|
||||
)
|
||||
)
|
||||
response = UserAgentService(db).respond(
|
||||
UserAgentRequest(
|
||||
run_id=ontology.run_id,
|
||||
user_id="pytest",
|
||||
message="我要报销",
|
||||
ontology=ontology,
|
||||
tool_payload={"record_count": 9, "total_amount": 12345.0},
|
||||
)
|
||||
)
|
||||
|
||||
assert "补充费用类型" in response.answer
|
||||
assert "上传票据" in response.answer
|
||||
|
||||
|
||||
def test_user_agent_guides_implicit_expense_draft_request() -> None:
|
||||
session_factory = build_session_factory()
|
||||
with session_factory() as db:
|
||||
ontology = SemanticOntologyService(db).parse(
|
||||
OntologyParseRequest(
|
||||
query="我今天去客户现场,招待了客户,花销了1000元",
|
||||
user_id="pytest",
|
||||
)
|
||||
)
|
||||
response = UserAgentService(db).respond(
|
||||
UserAgentRequest(
|
||||
run_id=ontology.run_id,
|
||||
user_id="pytest",
|
||||
message="我今天去客户现场,招待了客户,花销了1000元",
|
||||
ontology=ontology,
|
||||
tool_payload={"draft_only": True},
|
||||
)
|
||||
)
|
||||
|
||||
assert "1000元" in response.answer
|
||||
assert "票据附件" in response.answer
|
||||
assert "报销草稿" in response.answer
|
||||
|
||||
|
||||
def test_user_agent_risk_response_includes_rule_citations() -> None:
|
||||
session_factory = build_session_factory()
|
||||
with session_factory() as db:
|
||||
ontology = SemanticOntologyService(db).parse(
|
||||
OntologyParseRequest(
|
||||
query="检查重复报销风险",
|
||||
user_id="pytest",
|
||||
)
|
||||
)
|
||||
response = UserAgentService(db).respond(
|
||||
UserAgentRequest(
|
||||
run_id=ontology.run_id,
|
||||
user_id="pytest",
|
||||
message="检查重复报销风险",
|
||||
ontology=ontology,
|
||||
tool_payload={"risk_flags": ["duplicate_expense"]},
|
||||
)
|
||||
)
|
||||
|
||||
assert response.risk_flags == ["duplicate_expense"]
|
||||
assert any(item.source_type == "rule" for item in response.citations)
|
||||
assert "duplicate_expense" in response.answer
|
||||
|
||||
|
||||
def test_user_agent_draft_returns_structured_payload() -> None:
|
||||
session_factory = build_session_factory()
|
||||
with session_factory() as db:
|
||||
ontology = SemanticOntologyService(db).parse(
|
||||
OntologyParseRequest(
|
||||
query="帮我生成张三4月差旅报销草稿",
|
||||
user_id="pytest",
|
||||
)
|
||||
)
|
||||
response = UserAgentService(db).respond(
|
||||
UserAgentRequest(
|
||||
run_id=ontology.run_id,
|
||||
user_id="pytest",
|
||||
message="帮我生成张三4月差旅报销草稿",
|
||||
ontology=ontology,
|
||||
tool_payload={"draft_only": True},
|
||||
)
|
||||
)
|
||||
|
||||
assert response.draft_payload is not None
|
||||
assert response.draft_payload.confirmation_required is True
|
||||
assert "待人工确认" in response.answer
|
||||
Reference in New Issue
Block a user