feat(backend): add ontology and orchestrator API endpoints

New endpoints:
- server/src/app/api/v1/endpoints/ontology.py: ontology API
- server/src/app/api/v1/endpoints/orchestrator.py: orchestrator API

New schemas:
- server/src/app/schemas/ontology.py: ontology data schemas
- server/src/app/schemas/orchestrator.py: orchestrator data schemas
- server/src/app/schemas/user_agent.py: user agent data schemas

New services:
- server/src/app/services/ontology.py: ontology business logic
- server/src/app/services/orchestrator.py: orchestrator business logic
- server/src/app/services/runtime_chat.py: runtime chat service
- server/src/app/services/user_agent.py: user agent service

New tests:
- server/tests/test_ontology_service.py
- server/tests/test_orchestrator_service.py
- server/tests/test_user_agent_service.py
This commit is contained in:
caoxiaozhu
2026-05-12 01:24:39 +00:00
parent 19da459bb3
commit 22d47cbf2b
12 changed files with 4262 additions and 0 deletions

View File

@@ -0,0 +1,36 @@
from __future__ import annotations
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from app.api.deps import get_db
from app.schemas.common import ErrorResponse
from app.schemas.ontology import OntologyParseRequest, OntologyParseResult
from app.services.ontology import SemanticOntologyService
# Router for the ontology endpoints; routes declared on it live under "/ontology".
router = APIRouter(prefix="/ontology")
# Dependency alias: a SQLAlchemy Session obtained through the get_db dependency.
DbSession = Annotated[Session, Depends(get_db)]
@router.post(
    "/parse",
    response_model=OntologyParseResult,
    summary="解析自然语言为语义本体",
    description="把自然语言问题解析成 Day 3 约定的 8 个核心字段,并写入 AgentRun 与 SemanticParseLog。",
    responses={
        status.HTTP_400_BAD_REQUEST: {
            "model": ErrorResponse,
            "description": "请求缺少有效 query 或解析请求格式不合法。",
        }
    },
)
def parse_ontology(payload: OntologyParseRequest, db: DbSession) -> OntologyParseResult:
    """Parse a natural-language query into an ontology result.

    Delegates to ``SemanticOntologyService.parse``. A ``ValueError`` raised
    while building the service or parsing is converted into an HTTP 400 whose
    detail is the exception text.
    """
    try:
        parsed = SemanticOntologyService(db).parse(payload)
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(exc),
        ) from exc
    return parsed

View File

@@ -0,0 +1,33 @@
from __future__ import annotations
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from app.api.deps import get_db
from app.schemas.common import ErrorResponse
from app.schemas.orchestrator import OrchestratorRequest, OrchestratorResponse
from app.services.orchestrator import OrchestratorService
# Router for the orchestrator endpoints; routes declared on it live under "/orchestrator".
router = APIRouter(prefix="/orchestrator")
# Dependency alias: a SQLAlchemy Session obtained through the get_db dependency.
DbSession = Annotated[Session, Depends(get_db)]
@router.post(
    "/run",
    response_model=OrchestratorResponse,
    summary="运行 Orchestrator 统一调度",
    description="统一接收用户消息、定时任务和系统事件,完成语义解析、权限判断、路由和占位执行。",
    responses={
        status.HTTP_400_BAD_REQUEST: {
            "model": ErrorResponse,
            # Separator restored: the two clauses previously ran together.
            "description": "请求缺少 message 或 task_id,无法启动调度。",
        }
    },
)
def run_orchestrator(payload: OrchestratorRequest, db: DbSession) -> OrchestratorResponse:
    """Run one orchestration cycle for the posted request.

    Delegates to ``OrchestratorService.run``. A ``ValueError`` escaping the
    service is converted into an HTTP 400 whose detail is the exception text.

    NOTE(review): the service's ``run`` wraps message resolution in a broad
    ``except Exception`` and returns a ``failed`` response, so the documented
    "missing message or task_id" case may never reach this handler — confirm
    which layer is meant to reject such requests.
    """
    try:
        return OrchestratorService(db).run(payload)
    except ValueError as exc:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc

View File

@@ -0,0 +1,116 @@
from __future__ import annotations
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field
# Closed set of business scenarios a query can be classified into.
OntologyScenario = Literal[
    "expense",
    "accounts_receivable",
    "accounts_payable",
    "knowledge",
    "unknown",
]
# Closed set of user intents.
OntologyIntent = Literal["query", "explain", "compare", "risk_check", "draft", "operate"]
# Permission levels attached to a parsed action.
OntologyPermissionLevel = Literal["read", "draft_write", "approval_required", "forbidden"]
# Which parsing strategy produced the result.
OntologyParseStrategy = Literal["llm_primary", "rule_fallback"]
# One business object extracted from the query: its type, raw and normalized
# value, role, and a field-level confidence constrained to [0, 1].
class OntologyEntity(BaseModel):
    # Unknown keys are rejected rather than ignored.
    model_config = ConfigDict(extra="forbid")

    type: str = Field(description="业务对象类型,例如 employee / customer / vendor。")
    value: str = Field(description="从原始问题中提取的对象值。")
    normalized_value: str = Field(description="标准化后的对象值。")
    role: str = Field(default="target", description="对象角色,例如 target / filter / threshold。")
    confidence: float = Field(default=0.0, ge=0.0, le=1.0, description="字段级置信度。")
# Time range recognised in the query. Every field has a default, so the model
# can be built with no arguments (it is used as a default_factory elsewhere).
class OntologyTimeRange(BaseModel):
    # Unknown keys are rejected rather than ignored.
    model_config = ConfigDict(extra="forbid")

    raw: str = Field(default="", description="命中的原始时间表达。")
    # Dates are carried as strings; the descriptions state ISO format.
    start_date: str | None = Field(default=None, description="ISO 格式起始日期。")
    end_date: str | None = Field(default=None, description="ISO 格式结束日期。")
    granularity: str | None = Field(
        default=None,
        description="day / week / month / quarter / year。",
    )
# A metric requested by the query, with optional aggregation, unit, sort
# direction and Top-N limit (top_n must be >= 1 when given).
class OntologyMetric(BaseModel):
    # Unknown keys are rejected rather than ignored.
    model_config = ConfigDict(extra="forbid")

    name: str = Field(description="指标名,例如 amount / count / overdue。")
    aggregation: str | None = Field(default=None, description="sum / count / max 等聚合口径。")
    unit: str | None = Field(default=None, description="金额、数量等单位。")
    sort: str | None = Field(default=None, description="asc / desc 排序方向。")
    top_n: int | None = Field(default=None, ge=1, description="Top N 口径。")
# A filter / threshold / ordering constraint: field, operator and value, plus
# an optional currency for amount-type constraints.
class OntologyConstraint(BaseModel):
    # Unknown keys are rejected rather than ignored.
    model_config = ConfigDict(extra="forbid")

    field: str = Field(description="约束字段,例如 department / status / amount。")
    operator: str = Field(description="操作符,例如 = / > / < / desc。")
    value: str | int | float | bool = Field(description="约束值。")
    currency: str | None = Field(default=None, description="金额类约束使用的币种。")
# Permission verdict for the parsed action. Defaults to an allowed "read"
# with an empty reason, so the model can be built with no arguments.
class OntologyPermission(BaseModel):
    # Unknown keys are rejected rather than ignored.
    model_config = ConfigDict(extra="forbid")

    level: OntologyPermissionLevel = Field(default="read", description="动作权限等级。")
    allowed: bool = Field(default=True, description="是否可直接执行当前动作。")
    reason: str = Field(default="", description="权限判断原因。")
# A field-level error or hint: which field, an error code, and a message
# described as intended for front-end display.
class OntologyFieldError(BaseModel):
    # Unknown keys are rejected rather than ignored.
    model_config = ConfigDict(extra="forbid")

    field: str = Field(description="发生问题的字段。")
    code: str = Field(description="错误码。")
    message: str = Field(description="面向前端展示的说明。")
# Request body for ontology parsing: a non-empty query plus optional user id
# and free-form caller context.
class OntologyParseRequest(BaseModel):
    # min_length=1 rejects the empty string; a whitespace-only query still passes.
    query: str = Field(min_length=1, description="自然语言问题。")
    user_id: str | None = Field(default=None, description="当前请求用户 ID。")
    context_json: dict[str, Any] = Field(
        default_factory=dict,
        description="用户上下文,例如角色、部门、是否管理员。",
    )
# Full result of a semantic parse. Every field except run_id has a default,
# so a minimal result only needs the id of the associated run.
class OntologyParseResult(BaseModel):
    scenario: OntologyScenario = Field(default="unknown", description="业务场景。")
    intent: OntologyIntent = Field(default="query", description="用户意图。")
    entities: list[OntologyEntity] = Field(default_factory=list, description="业务对象列表。")
    time_range: OntologyTimeRange = Field(
        default_factory=OntologyTimeRange,
        description="时间范围。",
    )
    metrics: list[OntologyMetric] = Field(default_factory=list, description="指标解析结果。")
    constraints: list[OntologyConstraint] = Field(
        default_factory=list,
        description="过滤、阈值、排序等约束。",
    )
    risk_flags: list[str] = Field(default_factory=list, description="风险信号列表。")
    permission: OntologyPermission = Field(
        default_factory=OntologyPermission,
        description="权限结果。",
    )
    # Overall confidence, constrained to [0, 1].
    confidence: float = Field(default=0.0, ge=0.0, le=1.0, description="整体置信度。")
    missing_slots: list[str] = Field(default_factory=list, description="继续处理所缺少的关键槽位。")
    ambiguity: list[str] = Field(default_factory=list, description="当前识别中的潜在歧义。")
    parse_strategy: OntologyParseStrategy = Field(
        default="rule_fallback",
        description="本次语义解析使用的主策略。",
    )
    clarification_required: bool = Field(default=False, description="是否需要追问。")
    clarification_question: str | None = Field(default=None, description="推荐追问问题。")
    # The only required field.
    run_id: str = Field(description="关联的 AgentRun.run_id。")
    field_errors: list[OntologyFieldError] = Field(
        default_factory=list,
        description="字段级错误或提示。",
    )

View File

@@ -0,0 +1,46 @@
from __future__ import annotations
from typing import Any, Literal
from pydantic import BaseModel, Field
# Where an orchestration request originates.
OrchestratorSource = Literal["user_message", "schedule", "system_event"]
# Downstream agents a request can be routed to.
OrchestratorAgent = Literal["user_agent", "hermes"]
# Final run states exposed in the API response.
OrchestratorStatus = Literal["succeeded", "blocked", "failed"]
# Request body for an orchestration run. Only `source` is required; message
# and task_id are both optional at the schema level.
class OrchestratorRequest(BaseModel):
    source: OrchestratorSource = Field(description="请求来源。")
    # Separator restored in the two descriptions below: their clauses ran together.
    user_id: str | None = Field(default=None, description="当前用户 ID,任务触发可为空。")
    message: str | None = Field(default=None, description="用户消息或任务描述。")
    task_id: str | None = Field(default=None, description="任务资产 ID,schedule 触发时优先使用。")
    context_json: dict[str, Any] = Field(
        default_factory=dict,
        description="用户上下文、测试开关或调用方附加信息。",
    )
# Condensed trace of one run: parsed scenario/intent, tool-call counters
# (both must be >= 0), selected capability codes and a degradation flag.
class OrchestratorTraceSummary(BaseModel):
    # Plain strings here, not the Literal aliases used by the ontology schema.
    scenario: str = Field(description="语义场景。")
    intent: str = Field(description="语义意图。")
    tool_count: int = Field(default=0, ge=0, description="工具调用总数。")
    failed_tool_count: int = Field(default=0, ge=0, description="失败工具调用数量。")
    selected_capability_codes: list[str] = Field(
        default_factory=list,
        description="本次路由命中的能力编码。",
    )
    degraded: bool = Field(default=False, description="是否发生降级。")
# Response of an orchestration run: run id, routing decision, permission
# level, final status, a result payload and the trace summary.
class OrchestratorResponse(BaseModel):
    run_id: str = Field(description="本次运行的唯一 run_id。")
    # None is allowed, i.e. a response may carry no downstream agent.
    selected_agent: OrchestratorAgent | None = Field(
        default=None,
        description="最终路由到的下游 Agent。",
    )
    route_reason: str = Field(description="路由原因摘要。")
    permission_level: str = Field(description="权限级别。")
    status: OrchestratorStatus = Field(description="最终运行状态。")
    result: dict[str, Any] = Field(default_factory=dict, description="对前端可直接展示的最小结果。")
    requires_confirmation: bool = Field(default=False, description="是否需要用户或管理员确认。")
    trace_summary: OrchestratorTraceSummary = Field(description="简化后的 Trace 摘要。")

View File

@@ -0,0 +1,58 @@
from __future__ import annotations
from typing import Any, Literal
from pydantic import BaseModel, Field
from app.schemas.ontology import OntologyParseResult
# Kinds of sources a citation can point at.
UserAgentCitationType = Literal["rule", "knowledge"]


# A single citation attached to an answer: source type, code and title are
# required; version, update time and excerpt are optional.
class UserAgentCitation(BaseModel):
    source_type: UserAgentCitationType = Field(description="引用来源类型。")
    code: str = Field(description="来源编码。")
    title: str = Field(description="来源标题。")
    version: str | None = Field(default=None, description="引用版本。")
    # Carried as a string; no format is enforced by this schema.
    updated_at: str | None = Field(default=None, description="来源更新时间。")
    excerpt: str | None = Field(default=None, description="面向用户展示的引用摘要。")
# A suggested follow-up action: display label, action type and an optional
# description (empty string by default).
class UserAgentSuggestedAction(BaseModel):
    label: str = Field(description="建议动作文案。")
    action_type: str = Field(description="动作类型,例如 open_detail / create_draft。")
    description: str = Field(default="", description="动作说明。")
# Draft content returned with an answer: type, title and body, plus a flag
# saying whether manual confirmation is needed.
class UserAgentDraftPayload(BaseModel):
    draft_type: str = Field(description="草稿类型。")
    title: str = Field(description="草稿标题。")
    body: str = Field(description="草稿正文。")
    # Defaults to True: confirmation is required unless explicitly turned off.
    confirmation_required: bool = Field(default=True, description="是否需要人工确认。")
# Input handed to the User Agent: the run id, the original message, the parsed
# ontology, caller context, the tool payload and routing flags.
class UserAgentRequest(BaseModel):
    run_id: str = Field(description="关联的 AgentRun.run_id。")
    user_id: str | None = Field(default=None, description="当前请求用户 ID。")
    # Required, but with no min_length — an empty string is accepted.
    message: str = Field(description="原始用户问题。")
    ontology: OntologyParseResult = Field(description="语义解析结果。")
    context_json: dict[str, Any] = Field(default_factory=dict, description="附加上下文。")
    tool_payload: dict[str, Any] = Field(default_factory=dict, description="工具返回的原始结果。")
    selected_capability_codes: list[str] = Field(
        default_factory=list,
        description="本次命中的能力编码。",
    )
    degraded: bool = Field(default=False, description="当前是否发生降级。")
    requires_confirmation: bool = Field(default=False, description="是否要求确认。")
# Output of the User Agent: the answer text plus optional citations,
# suggested actions, a draft payload, risk flags and a confirmation flag.
class UserAgentResponse(BaseModel):
    answer: str = Field(description="面向用户展示的自然语言回答。")
    citations: list[UserAgentCitation] = Field(default_factory=list, description="规则或知识引用。")
    suggested_actions: list[UserAgentSuggestedAction] = Field(
        default_factory=list,
        description="建议的下一步动作。",
    )
    draft_payload: UserAgentDraftPayload | None = Field(default=None, description="可选草稿内容。")
    risk_flags: list[str] = Field(default_factory=list, description="本次回答关联的风险标签。")
    requires_confirmation: bool = Field(default=False, description="是否需要人工确认。")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,887 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import UTC, datetime
from time import perf_counter
from typing import Any
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from app.core.agent_enums import (
AgentAssetStatus,
AgentAssetType,
AgentName,
AgentPermissionLevel,
AgentRunSource,
AgentRunStatus,
AgentToolType,
)
from app.core.logging import get_logger
from app.models.financial_record import (
AccountsPayableRecord,
AccountsReceivableRecord,
ExpenseClaim,
)
from app.schemas.agent_asset import AgentAssetListItem, AgentAssetRead
from app.schemas.ontology import OntologyParseRequest, OntologyParseResult
from app.schemas.orchestrator import (
OrchestratorRequest,
OrchestratorResponse,
OrchestratorTraceSummary,
)
from app.schemas.user_agent import UserAgentRequest, UserAgentResponse
from app.services.agent_assets import AgentAssetService
from app.services.agent_foundation import AgentFoundationService
from app.services.agent_runs import AgentRunService
from app.services.ontology import SemanticOntologyService
from app.services.user_agent import UserAgentService
# Module logger for the orchestrator service.
logger = get_logger("app.services.orchestrator")
# Maps an ontology scenario to the domain value used when filtering assets.
# "knowledge" and "system" get special treatment in _select_capabilities.
SCENARIO_TO_DOMAIN = {
    "expense": "expense",
    "accounts_receivable": "ar",
    "accounts_payable": "ap",
    "knowledge": "knowledge",
    "unknown": "system",
}
@dataclass(slots=True)
class ExecutionOutcome:
    """Result of one execution branch, consumed by ``OrchestratorService.run``."""

    # Run status value produced by the branch (an AgentRunStatus value).
    status: str
    # Payload returned to the caller; the branches include a "message" key.
    result: dict[str, Any]
    # True when at least one tool call fell back to its degraded response.
    degraded: bool
    # Number of tool calls attempted by the branch.
    tool_count: int
    # Number of those calls that failed.
    failed_tool_count: int
class OrchestratorService:
def __init__(self, db: Session) -> None:
    """Keep the session and build the collaborating services on top of it."""
    self.db = db
    # All collaborators share the same session.
    self.asset_service = AgentAssetService(db)
    self.run_service = AgentRunService(db)
    self.ontology_service = SemanticOntologyService(db)
    self.user_agent_service = UserAgentService(db)
def run(self, payload: OrchestratorRequest) -> OrchestratorResponse:
    """Run one orchestration cycle for ``payload`` and report the outcome.

    Creates a run record in RUNNING state, resolves the message, parses it
    through the ontology service, selects an agent and capabilities, and then
    either blocks (forbidden permission / clarification required) or hands
    over to the Hermes or User Agent executor. The run record is updated with
    the final status before the response is returned.

    Any exception raised inside the ``try`` is logged, recorded on the run as
    FAILED and turned into a ``failed`` response; it is not re-raised.
    Exceptions from the foundation check or from ``create_run`` happen before
    the ``try`` and propagate to the caller.
    """
    AgentFoundationService(self.db).ensure_foundation_ready()
    route_json: dict[str, Any] = {
        "orchestrated_by": AgentName.ORCHESTRATOR.value,
        "stage": "created",
    }
    # Persist the run first so every later stage, including failure, can update it.
    run = self.run_service.create_run(
        agent=AgentName.ORCHESTRATOR.value,
        source=payload.source,
        user_id=payload.user_id,
        task_id=payload.task_id,
        ontology_json={},
        route_json=route_json,
        permission_level=AgentPermissionLevel.READ.value,
        status=AgentRunStatus.RUNNING.value,
        result_summary="Orchestrator 已接收请求。",
    )
    try:
        message, task_asset = self._resolve_message(payload)
        ontology = self.ontology_service.parse_for_run(
            OntologyParseRequest(
                query=message,
                user_id=payload.user_id,
                context_json=payload.context_json,
            ),
            run_id=run.run_id,
        )
        # Failure-injection hook driven by caller-supplied context_json.
        # NOTE(review): confirm this switch is acceptable outside tests.
        if payload.context_json.get("simulate_orchestrator_exception"):
            raise RuntimeError("simulated orchestrator exception")
        selected_agent, route_reason = self._select_agent(payload, ontology)
        capabilities = self._select_capabilities(
            payload=payload,
            ontology=ontology,
            task_asset=task_asset,
        )
        selected_capability_codes = self._flatten_capability_codes(capabilities)
        requires_confirmation = (
            ontology.permission.level == AgentPermissionLevel.APPROVAL_REQUIRED.value
        )
        route_json = {
            "orchestrated_by": AgentName.ORCHESTRATOR.value,
            "stage": "routed",
            "selected_agent": selected_agent,
            "route_reason": route_reason,
            "selected_capability_codes": selected_capability_codes,
            "ontology_run_id": ontology.run_id,
        }
        if ontology.permission.level == AgentPermissionLevel.FORBIDDEN.value:
            # Forbidden: no tool is invoked, the permission reason is returned.
            outcome = ExecutionOutcome(
                status=AgentRunStatus.BLOCKED.value,
                result={
                    "message": ontology.permission.reason,
                    "clarification_question": ontology.clarification_question,
                    "degraded": False,
                },
                degraded=False,
                tool_count=0,
                failed_tool_count=0,
            )
            # Only the local is cleared; route_json["selected_agent"] keeps the
            # agent chosen above.
            selected_agent = None
            route_reason = "permission_forbidden"
            route_json["stage"] = "blocked"
            route_json["route_reason"] = route_reason
        elif ontology.clarification_required:
            # Clarification needed: return the follow-up question, run no tool.
            outcome = ExecutionOutcome(
                status=AgentRunStatus.BLOCKED.value,
                result={
                    "message": ontology.clarification_question or "需要补充更多上下文。",
                    "clarification_required": True,
                    "missing_slots": ontology.missing_slots,
                    "ambiguity": ontology.ambiguity,
                    "parse_strategy": ontology.parse_strategy,
                    "degraded": False,
                },
                degraded=False,
                tool_count=0,
                failed_tool_count=0,
            )
            route_reason = "clarification_required"
            route_json["stage"] = "clarification"
            route_json["route_reason"] = route_reason
        elif selected_agent == AgentName.HERMES.value:
            outcome = self._execute_hermes(
                payload=payload,
                run_id=run.run_id,
                ontology=ontology,
                capabilities=capabilities,
                requires_confirmation=requires_confirmation,
                task_asset=task_asset,
            )
        else:
            outcome = self._execute_user_agent(
                payload=payload,
                run_id=run.run_id,
                ontology=ontology,
                capabilities=capabilities,
                requires_confirmation=requires_confirmation,
            )
        # A successful outcome is downgraded to BLOCKED when confirmation is
        # required. requires_confirmation is defined above as
        # "permission.level == APPROVAL_REQUIRED", so the third condition
        # repeats the first.
        final_status = (
            AgentRunStatus.BLOCKED.value
            if requires_confirmation
            and outcome.status == AgentRunStatus.SUCCEEDED.value
            and ontology.permission.level == AgentPermissionLevel.APPROVAL_REQUIRED.value
            else outcome.status
        )
        # Fall back to a generic summary when the outcome has no message text.
        result_message = (
            str(outcome.result.get("message", "")).strip()
            or "Orchestrator 执行完成。"
        )
        self.run_service.update_run(
            run.run_id,
            agent=selected_agent or AgentName.ORCHESTRATOR.value,
            ontology_json=self._build_ontology_json(ontology),
            route_json={
                **route_json,
                "requires_confirmation": requires_confirmation,
                "degraded": outcome.degraded,
            },
            permission_level=ontology.permission.level,
            status=final_status,
            result_summary=result_message,
            error_message=None,
            finished_at=datetime.now(UTC),
        )
        return OrchestratorResponse(
            run_id=run.run_id,
            selected_agent=selected_agent,
            route_reason=route_reason,
            permission_level=ontology.permission.level,
            status=self._normalize_response_status(final_status),
            result=outcome.result,
            requires_confirmation=requires_confirmation,
            trace_summary=OrchestratorTraceSummary(
                scenario=ontology.scenario,
                intent=ontology.intent,
                tool_count=outcome.tool_count,
                failed_tool_count=outcome.failed_tool_count,
                selected_capability_codes=selected_capability_codes,
                degraded=outcome.degraded,
            ),
        )
    except Exception as exc:
        # NOTE(review): this also catches the ValueError raised by
        # _resolve_message, so that case yields a failed response here instead
        # of propagating — confirm against the API layer's 400 mapping.
        logger.exception("Orchestrator run failed run_id=%s", run.run_id)
        # route_json is whichever dict was current when the failure happened.
        self.run_service.update_run(
            run.run_id,
            agent=AgentName.ORCHESTRATOR.value,
            route_json={**route_json, "stage": "failed"},
            status=AgentRunStatus.FAILED.value,
            result_summary="Orchestrator 执行失败。",
            error_message=str(exc),
            finished_at=datetime.now(UTC),
        )
        # The exception text is included in the message returned to the caller.
        return OrchestratorResponse(
            run_id=run.run_id,
            selected_agent=None,
            route_reason="orchestrator_exception",
            permission_level=AgentPermissionLevel.READ.value,
            status="failed",
            result={"message": f"Orchestrator 执行失败:{exc}"},
            requires_confirmation=False,
            trace_summary=OrchestratorTraceSummary(
                scenario="unknown",
                intent="query",
                tool_count=0,
                failed_tool_count=0,
                selected_capability_codes=[],
                degraded=False,
            ),
        )
def _resolve_message(
self,
payload: OrchestratorRequest,
) -> tuple[str, AgentAssetRead | None]:
task_asset = None
if payload.task_id:
task_asset = self.asset_service.get_asset(payload.task_id)
if payload.message and payload.message.strip():
return payload.message.strip(), task_asset
if task_asset is not None:
description = str(task_asset.description or "").strip()
scenario_text = " ".join(str(item) for item in task_asset.scenario_json)
message = f"{task_asset.name} {description} {scenario_text}".strip()
return message, task_asset
if payload.source == AgentRunSource.SCHEDULE.value:
return "定时风险巡检任务", task_asset
raise ValueError("message 或 task_id 至少需要提供一个。")
@staticmethod
def _select_agent(
    payload: OrchestratorRequest,
    ontology: OntologyParseResult,
) -> tuple[str, str]:
    """Choose the downstream agent and a short reason code.

    Schedule-sourced requests, and system events whose intent is
    ``risk_check``, go to Hermes; everything else goes to the User Agent.

    Returns:
        ``(agent_name, route_reason)``.
    """
    if payload.source == AgentRunSource.SCHEDULE.value:
        return AgentName.HERMES.value, "schedule_source_defaults_to_hermes"
    if payload.source == AgentRunSource.SYSTEM_EVENT.value and ontology.intent == "risk_check":
        return AgentName.HERMES.value, "system_event_risk_check_routes_to_hermes"
    # The former "scheduled risk_check" branch was removed: every schedule
    # source already returns in the first check, so it could never run.
    if ontology.intent in {"query", "explain", "draft", "compare", "operate"}:
        return AgentName.USER_AGENT.value, f"{ontology.intent}_routes_to_user_agent"
    return AgentName.USER_AGENT.value, "user_message_defaults_to_user_agent"
def _select_capabilities(
    self,
    *,
    payload: OrchestratorRequest,
    ontology: OntologyParseResult,
    task_asset: AgentAssetRead | None,
) -> dict[str, list[AgentAssetListItem | AgentAssetRead]]:
    """Collect the ranked active assets relevant to this request.

    Rules and skills are filtered by the domain mapped from the scenario
    (rules skip the filter for "knowledge"/"system", skills only for
    "system"); MCPs are listed without a domain argument. Tasks are the
    supplied active task asset, or the ranked active tasks for schedule
    sources, or empty.

    Returns:
        A dict with the keys ``rules``, ``skills``, ``mcps`` and ``tasks``.
    """
    domain = SCENARIO_TO_DOMAIN.get(ontology.scenario)
    active = AgentAssetStatus.ACTIVE.value

    def ranked_active(asset_type: str, **filters: Any) -> list[AgentAssetListItem]:
        # List active assets of one type and rank them against the ontology.
        listed = self.asset_service.list_assets(
            asset_type=asset_type,
            status=active,
            **filters,
        )
        return self._rank_assets(listed, ontology)

    rule_domain = None if domain in {"knowledge", "system"} else domain
    skill_domain = None if domain in {"system"} else domain

    selection: dict[str, list[AgentAssetListItem | AgentAssetRead]] = {}
    selection["rules"] = ranked_active(AgentAssetType.RULE.value, domain=rule_domain)
    selection["skills"] = ranked_active(AgentAssetType.SKILL.value, domain=skill_domain)
    selection["mcps"] = ranked_active(AgentAssetType.MCP.value)

    task_items: list[AgentAssetListItem | AgentAssetRead] = []
    if task_asset is not None and task_asset.status == active:
        task_items.append(task_asset)
    elif payload.source == AgentRunSource.SCHEDULE.value:
        task_items = ranked_active(AgentAssetType.TASK.value)
    selection["tasks"] = task_items
    return selection
def _execute_user_agent(
    self,
    *,
    payload: OrchestratorRequest,
    run_id: str,
    ontology: OntologyParseResult,
    capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]],
    requires_confirmation: bool,
) -> ExecutionOutcome:
    """Execute the User Agent branch with exactly one tool call.

    When confirmation is required, only a confirmation summary is produced
    and the outcome is BLOCKED. Otherwise the next step decides which tool
    runs (database query, knowledge search, rule check, or the draft
    placeholder for every other step); its payload is passed to the User
    Agent service and the outcome is SUCCEEDED, flagged as degraded when the
    tool fell back.

    The four tool branches used to repeat the same invoke / respond / outcome
    sequence; they now differ only in the tool description chosen below.

    NOTE(review): the User Agent receives ``payload.message`` (or "" when it
    is absent), not the text resolved from a task asset — confirm intended.
    """
    selected_capability_codes = self._flatten_capability_codes(capabilities)

    if requires_confirmation:
        response, degraded = self._invoke_tool(
            run_id=run_id,
            tool_type=AgentToolType.LLM.value,
            tool_name="user_agent.confirmation_placeholder",
            request_json={
                "message": payload.message,
                "permission_level": ontology.permission.level,
            },
            context_json=payload.context_json,
            executor=lambda: {
                "confirmation_title": "操作需要确认",
                "message": f"{ontology.permission.reason} 当前仅返回确认摘要,不直接执行动作。",
            },
            fallback_factory=lambda exc: {
                "confirmation_title": "操作需要确认",
                "message": f"确认摘要生成失败,已阻断自动执行:{exc}",
            },
        )
        return ExecutionOutcome(
            status=AgentRunStatus.BLOCKED.value,
            result={**response, "degraded": degraded},
            degraded=degraded,
            tool_count=1,
            failed_tool_count=1 if degraded else 0,
        )

    # Pick the tool for this step: type, name, executor and the text that
    # prefixes the degraded message.
    next_step = self._resolve_next_step(ontology, payload.source)
    if next_step == "query_database":
        tool_type = AgentToolType.DATABASE.value
        tool_name = self._database_tool_name(ontology.scenario)
        failure_notice = "数据库查询暂时不可用,已返回降级说明"

        def executor() -> dict[str, Any]:
            return self._build_database_answer(ontology)

    elif next_step == "search_knowledge":
        tool_type = AgentToolType.DATABASE.value
        tool_name = "knowledge.search"
        failure_notice = "知识检索暂时不可用,建议稍后重试"

        def executor() -> dict[str, Any]:
            return self._build_knowledge_answer(ontology, capabilities)

    elif next_step == "run_rule":
        tool_type = AgentToolType.RULE_ENGINE.value
        tool_name = self._rule_tool_name(capabilities)
        failure_notice = "规则检查暂时不可用,已返回人工复核建议"

        def executor() -> dict[str, Any]:
            return self._build_rule_answer(ontology)

    else:
        # Every remaining step ends in the draft placeholder.
        tool_type = AgentToolType.LLM.value
        tool_name = "user_agent.draft_placeholder"
        failure_notice = "草稿生成暂时不可用,请稍后再试"

        def executor() -> dict[str, Any]:
            return {
                "message": (
                    f"已生成 {ontology.scenario} 场景草稿,"
                    "占位能力后续由 Day 5 User Agent 接管。"
                ),
                "draft_only": True,
            }

    tool_payload, degraded = self._invoke_tool(
        run_id=run_id,
        tool_type=tool_type,
        tool_name=tool_name,
        request_json=self._build_ontology_json(ontology),
        context_json=payload.context_json,
        executor=executor,
        fallback_factory=lambda exc: {
            "message": f"{failure_notice}:{exc}",
            "degraded": True,
        },
    )
    agent_response = self.user_agent_service.respond(
        UserAgentRequest(
            run_id=run_id,
            user_id=payload.user_id,
            message=payload.message or "",
            ontology=ontology,
            context_json=payload.context_json,
            tool_payload=tool_payload,
            selected_capability_codes=selected_capability_codes,
            degraded=degraded,
            requires_confirmation=requires_confirmation,
        )
    )
    return ExecutionOutcome(
        status=AgentRunStatus.SUCCEEDED.value,
        result=self._build_user_agent_result(agent_response, degraded=degraded),
        degraded=degraded,
        tool_count=1,
        failed_tool_count=1 if degraded else 0,
    )
def _execute_hermes(
    self,
    *,
    payload: OrchestratorRequest,
    run_id: str,
    ontology: OntologyParseResult,
    capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]],
    requires_confirmation: bool,
    task_asset: AgentAssetRead | None,
) -> ExecutionOutcome:
    """Execute the Hermes branch: one rule-engine call, then one MCP call.

    When confirmation is required nothing is invoked and a BLOCKED outcome is
    returned. Otherwise both tools run in order and the outcome is SUCCEEDED,
    flagged as degraded if either call fell back.
    """
    if requires_confirmation:
        blocked_result = {
            "message": "Hermes 不会自动执行需要确认的高风险动作,已阻断。",
            "degraded": False,
        }
        return ExecutionOutcome(
            status=AgentRunStatus.BLOCKED.value,
            result=blocked_result,
            degraded=False,
            tool_count=0,
            failed_tool_count=0,
        )

    def rule_fallback(exc: Exception) -> dict[str, Any]:
        return {
            "message": f"规则巡检失败,已降级为待人工复核:{exc}",
            "degraded": True,
        }

    def mcp_fallback(exc: Exception) -> dict[str, Any]:
        return {
            "message": f"MCP 调用失败,已使用缓存快照降级:{exc}",
            "fallback": "used_cached_snapshot",
        }

    rule_response, rule_degraded = self._invoke_tool(
        run_id=run_id,
        tool_type=AgentToolType.RULE_ENGINE.value,
        tool_name=self._rule_tool_name(capabilities),
        request_json=self._build_ontology_json(ontology),
        context_json=payload.context_json,
        executor=lambda: self._build_rule_answer(ontology),
        fallback_factory=rule_fallback,
    )

    has_task = task_asset is not None
    mcp_response, mcp_degraded = self._invoke_tool(
        run_id=run_id,
        tool_type=AgentToolType.MCP.value,
        tool_name=self._mcp_tool_name(capabilities),
        request_json={
            "task_code": task_asset.code if has_task else "",
            "scenario": ontology.scenario,
        },
        context_json=payload.context_json,
        executor=lambda: self._build_mcp_answer(task_asset, ontology),
        fallback_factory=mcp_fallback,
    )

    any_degraded = rule_degraded or mcp_degraded
    summary = self._build_hermes_message(
        task_asset=task_asset,
        ontology=ontology,
        rule_response=rule_response,
        mcp_response=mcp_response,
        degraded=any_degraded,
    )
    return ExecutionOutcome(
        status=AgentRunStatus.SUCCEEDED.value,
        result={
            "message": summary,
            "report_type": task_asset.code if has_task else "hermes_runtime",
            "degraded": any_degraded,
        },
        degraded=any_degraded,
        tool_count=2,
        # One failure per degraded call.
        failed_tool_count=sum((rule_degraded, mcp_degraded)),
    )
@staticmethod
def _resolve_next_step(ontology: OntologyParseResult, source: str) -> str:
if ontology.clarification_required:
return "ask_clarification"
if ontology.intent == "draft":
return "create_draft"
if ontology.scenario == "knowledge" or ontology.intent == "explain":
return "search_knowledge"
if ontology.intent == "risk_check" or source == AgentRunSource.SCHEDULE.value:
return "run_rule"
if ontology.intent in {"query", "compare"}:
return "query_database"
return "create_draft"
@staticmethod
def _flatten_capability_codes(
capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]],
) -> list[str]:
codes: list[str] = []
for items in capabilities.values():
for item in items[:2]:
if item.code not in codes:
codes.append(item.code)
return codes
def _rank_assets(
self,
items: list[AgentAssetListItem],
ontology: OntologyParseResult,
) -> list[AgentAssetListItem]:
def score(item: AgentAssetListItem) -> tuple[int, str]:
item_tags = {str(value) for value in item.scenario_json or []}
weight = 0
if ontology.scenario in item_tags:
weight += 3
if ontology.intent in item_tags:
weight += 2
for risk_flag in ontology.risk_flags:
if risk_flag in item_tags:
weight += 4
return weight, item.code
ranked = sorted(items, key=score, reverse=True)
if not ranked:
return []
scored = [item for item in ranked if score(item)[0] > 0]
return scored or ranked[:1]
def _invoke_tool(
    self,
    *,
    run_id: str,
    tool_type: str,
    tool_name: str,
    request_json: dict[str, Any],
    context_json: dict[str, Any],
    executor,
    fallback_factory,
) -> tuple[dict[str, Any], bool]:
    """Run one tool call, record it on the run, and fall back on failure.

    Args:
        run_id: Run the tool call is recorded against.
        tool_type: Tool category; also compared with the simulated-failure
            switch in ``context_json``.
        tool_name: Name stored with the recorded call.
        request_json: Request payload stored with the recorded call.
        context_json: Caller context, consulted for the failure switch.
        executor: Zero-argument callable producing the response dict.
        fallback_factory: Callable taking the raised exception and returning
            the degraded response dict.

    Returns:
        ``(response, degraded)`` where ``degraded`` is True when the fallback
        response was used.
    """
    started = perf_counter()
    try:
        self._maybe_raise_simulated_failure(tool_type, context_json)
        response = executor()
        duration_ms = int((perf_counter() - started) * 1000)
        # This recording call sits inside the try, so a failure while
        # recording a success is handled below as a tool failure too.
        self.run_service.record_tool_call(
            run_id=run_id,
            tool_type=tool_type,
            tool_name=tool_name,
            request_json=request_json,
            response_json=response,
            status="succeeded",
            duration_ms=duration_ms,
        )
        return response, False
    except Exception as exc:
        duration_ms = int((perf_counter() - started) * 1000)
        response = fallback_factory(exc)
        # Not guarded: an error while recording the failure propagates.
        self.run_service.record_tool_call(
            run_id=run_id,
            tool_type=tool_type,
            tool_name=tool_name,
            request_json=request_json,
            response_json=response,
            status="failed",
            duration_ms=duration_ms,
            error_message=str(exc),
        )
        return response, True
@staticmethod
def _maybe_raise_simulated_failure(tool_type: str, context_json: dict[str, Any]) -> None:
expected = str(context_json.get("simulate_tool_failure") or "").strip().lower()
if not expected:
return
if expected == tool_type.lower():
raise RuntimeError(f"simulated {tool_type} failure")
def _build_database_answer(self, ontology: OntologyParseResult) -> dict[str, Any]:
    """Return aggregate counts and amounts for the ontology's scenario.

    For "expense" the claims are counted and summed, optionally restricted to
    the employee names found among the entities. For "accounts_receivable"
    the outstanding amount over receivables is returned. Every other scenario
    uses the payables table.

    NOTE(review): scenarios other than expense and accounts_receivable fall
    through to payables — confirm that is intended for e.g. "unknown".
    """
    scenario = ontology.scenario

    if scenario == "expense":
        claim_count = select(func.count()).select_from(ExpenseClaim)
        claim_total = select(
            func.coalesce(func.sum(ExpenseClaim.amount), 0)
        ).select_from(ExpenseClaim)
        names = [
            entity.normalized_value
            for entity in ontology.entities
            if entity.type == "employee"
        ]
        if names:
            by_employee = ExpenseClaim.employee_name.in_(names)
            claim_count = claim_count.where(by_employee)
            claim_total = claim_total.where(by_employee)
        record_count = int(self.db.scalar(claim_count) or 0)
        amount = float(self.db.scalar(claim_total) or 0)
        return {
            "record_count": record_count,
            "total_amount": round(amount, 2),
        }

    # Receivables and payables share the same two aggregate queries.
    if scenario == "accounts_receivable":
        ledger = AccountsReceivableRecord
    else:
        ledger = AccountsPayableRecord
    record_count = int(
        self.db.scalar(select(func.count()).select_from(ledger)) or 0
    )
    outstanding = float(
        self.db.scalar(
            select(func.coalesce(func.sum(ledger.amount_outstanding), 0))
        )
        or 0
    )
    return {
        "record_count": record_count,
        "outstanding_amount": round(outstanding, 2),
    }
@staticmethod
def _build_user_query_result(
ontology: OntologyParseResult,
response: dict[str, Any],
) -> dict[str, Any]:
if ontology.scenario == "expense":
return {
"message": (
f"已路由到 User Agent占位查询结果命中 {response['record_count']} 笔报销,"
f"金额合计 {response['total_amount']} 元。"
),
"data": response,
}
if ontology.scenario == "accounts_receivable":
return {
"message": (
f"已路由到 User Agent占位查询结果命中 {response['record_count']} 条应收,"
f"未回款金额 {response['outstanding_amount']} 元。"
),
"data": response,
}
return {
"message": (
f"已路由到 User Agent占位查询结果命中 {response['record_count']} 条应付,"
f"待付金额 {response['outstanding_amount']} 元。"
),
"data": response,
}
@staticmethod
def _build_user_agent_result(
response: UserAgentResponse,
*,
degraded: bool,
) -> dict[str, Any]:
result = {
"message": response.answer,
"answer": response.answer,
"citations": [item.model_dump() for item in response.citations],
"suggested_actions": [item.model_dump() for item in response.suggested_actions],
"risk_flags": response.risk_flags,
"requires_confirmation": response.requires_confirmation,
"degraded": degraded,
}
if response.draft_payload is not None:
result["draft_payload"] = response.draft_payload.model_dump()
return result
@staticmethod
def _build_knowledge_answer(
ontology: OntologyParseResult,
capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]],
) -> dict[str, Any]:
referenced = [item.code for item in capabilities["rules"][:1]] or [
"knowledge.policy.default"
]
return {
"message": f"已路由到 User Agent占位知识结果建议先查看 {', '.join(referenced)}",
"references": referenced,
}
@staticmethod
def _build_rule_answer(ontology: OntologyParseResult) -> dict[str, Any]:
risk_text = (
"".join(ontology.risk_flags)
if ontology.risk_flags
else "未识别到明确风险标签"
)
return {
"message": f"已完成占位规则检查,风险标签:{risk_text}",
"risk_flags": ontology.risk_flags,
}
@staticmethod
def _build_mcp_answer(
task_asset: AgentAssetRead | None,
ontology: OntologyParseResult,
) -> dict[str, Any]:
return {
"message": (
f"已调用占位 MCP 快照,任务={task_asset.code if task_asset else 'none'}"
f"scenario={ontology.scenario}"
),
"snapshot": "stubbed",
}
@staticmethod
def _build_hermes_message(
*,
task_asset: AgentAssetRead | None,
ontology: OntologyParseResult,
rule_response: dict[str, Any],
mcp_response: dict[str, Any],
degraded: bool,
) -> str:
task_code = task_asset.code if task_asset is not None else "task.unspecified"
suffix = ",其中部分能力已降级。" if degraded else ""
return (
f"Hermes 占位执行完成:任务 {task_code}"
f"场景 {ontology.scenario},规则结果={rule_response.get('message', '')}"
f"MCP 结果={mcp_response.get('message', '')}{suffix}"
)
@staticmethod
def _database_tool_name(scenario: str) -> str:
if scenario == "expense":
return "database.expense_claims.lookup"
if scenario == "accounts_receivable":
return "database.accounts_receivable.lookup"
return "database.accounts_payable.lookup"
@staticmethod
def _rule_tool_name(
capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]],
) -> str:
if capabilities["rules"]:
return capabilities["rules"][0].code
return "rule_engine.default_risk_check"
@staticmethod
def _mcp_tool_name(
capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]],
) -> str:
if capabilities["mcps"]:
return capabilities["mcps"][0].code
return "mcp.default_snapshot"
@staticmethod
def _build_ontology_json(ontology: OntologyParseResult) -> dict[str, Any]:
return {
"scenario": ontology.scenario,
"intent": ontology.intent,
"entities": [item.model_dump() for item in ontology.entities],
"time_range": ontology.time_range.model_dump(),
"metrics": [item.model_dump() for item in ontology.metrics],
"constraints": [item.model_dump() for item in ontology.constraints],
"risk_flags": ontology.risk_flags,
"permission": ontology.permission.model_dump(),
}
@staticmethod
def _normalize_response_status(status: str) -> str:
    """Collapse a run status into ``failed``, ``blocked`` or ``succeeded``.

    Only the FAILED and BLOCKED values of ``AgentRunStatus`` are singled out;
    every other status is reported as ``succeeded``.
    """
    special_cases = {
        AgentRunStatus.FAILED.value: "failed",
        AgentRunStatus.BLOCKED.value: "blocked",
    }
    return special_cases.get(status, "succeeded")

View File

@@ -0,0 +1,252 @@
from __future__ import annotations
from http import HTTPStatus
from typing import Any
from sqlalchemy.orm import Session
from app.core.logging import get_logger
from app.services.model_connectivity import (
AZURE_API_VERSION,
ConnectivityCheckError,
_build_azure_deployment_base,
_build_headers,
_ensure_path,
_normalize_endpoint,
_send_json_request,
)
from app.services.settings import SettingsService
# Module-level logger, obtained from the project's get_logger factory in app.core.logging.
logger = get_logger("app.services.runtime_chat")
class RuntimeChatService:
def __init__(self, db: Session) -> None:
    """Keep the session and build the settings service used to load model slots."""
    self.settings_service = SettingsService(db)
    self.db = db
def complete(
    self,
    messages: list[dict[str, str]],
    *,
    slot_priority: tuple[str, ...] = ("main", "backup"),
    max_tokens: int = 500,
    temperature: float = 0.2,
) -> str | None:
    """Return the first non-empty chat reply from the configured model slots.

    Slots are tried in ``slot_priority`` order. A slot is skipped when it has
    no usable chat configuration, when the request raises, or when the reply
    is empty. Returns ``None`` if no slot produced text.
    """
    for slot in slot_priority:
        slot_config = self._load_chat_slot(slot)
        if slot_config is None:
            continue
        try:
            reply = self._request_chat_completion(
                slot_config,
                messages,
                max_tokens=max_tokens,
                temperature=temperature,
            )
        except Exception as exc:
            # Deliberate best-effort: log the failure and fall through to the next slot.
            logger.warning(
                "Runtime chat request failed slot=%s provider=%s: %s",
                slot,
                slot_config["provider"],
                exc,
            )
        else:
            if reply:
                return reply.strip()
    return None
def _load_chat_slot(self, slot: str) -> dict[str, str] | None:
    """Load and validate the runtime model configuration for ``slot``.

    Returns ``None`` when the settings service raises ``ValueError``, when
    the slot's capability is not ``chat``, when provider, endpoint or model
    is blank, or when a non-Ollama provider has no API key. Otherwise returns
    the trimmed ``slot``/``provider``/``endpoint``/``model``/``apiKey`` values.
    """
    try:
        raw = self.settings_service.get_runtime_model_config(slot)
    except ValueError:
        return None
    if raw["capability"] != "chat":
        return None

    provider, endpoint, model, api_key = (
        str(raw[field] or "").strip()
        for field in ("provider", "endpoint", "model", "apiKey")
    )
    if not (provider and endpoint and model):
        return None
    if not api_key and provider != "Ollama":
        logger.info("Skip runtime chat slot=%s because api key is empty", slot)
        return None
    return {
        "slot": slot,
        "provider": provider,
        "endpoint": endpoint,
        "model": model,
        "apiKey": api_key,
    }
def _request_chat_completion(
self,
config: dict[str, str],
messages: list[dict[str, str]],
*,
max_tokens: int,
temperature: float,
) -> str:
provider = config["provider"]
endpoint = config["endpoint"]
model = config["model"]
api_key = config["apiKey"]
if provider == "Azure OpenAI":
return self._request_azure_openai(
endpoint=endpoint,
model=model,
api_key=api_key,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
)
if provider == "Ollama":
return self._request_ollama(
endpoint=endpoint,
model=model,
api_key=api_key,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
)
return self._request_openai_compatible(
endpoint=endpoint,
model=model,
api_key=api_key,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
)
def _request_openai_compatible(
    self,
    *,
    endpoint: str,
    model: str,
    api_key: str,
    messages: list[dict[str, str]],
    max_tokens: int,
    temperature: float,
) -> str:
    """POST a chat completion to ``<endpoint>/chat/completions`` with a bearer key.

    Raises ``ConnectivityCheckError`` for any status code of 400 or above;
    otherwise returns the text extracted from the response payload.
    """
    request_body = {
        "model": model,
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    status_code, payload = _send_json_request(
        "POST",
        _ensure_path(_normalize_endpoint(endpoint), "chat/completions"),
        headers=_build_headers(api_key=api_key, use_bearer=True),
        payload=request_body,
    )
    if status_code < HTTPStatus.BAD_REQUEST:
        return self._extract_openai_text(payload)
    raise ConnectivityCheckError(
        f"模型接口返回异常状态 {status_code}",
        status_code=status_code,
    )
def _request_ollama(
    self,
    *,
    endpoint: str,
    model: str,
    api_key: str,
    messages: list[dict[str, str]],
    max_tokens: int,
    temperature: float,
) -> str:
    """POST a non-streaming chat request to ``<endpoint>/api/chat``.

    Raises ``ConnectivityCheckError`` for any status code of 400 or above.
    Returns the trimmed ``message.content`` text, or an empty string when the
    payload does not carry string content at that path.
    """
    url = _ensure_path(_normalize_endpoint(endpoint), "api/chat")
    status_code, payload = _send_json_request(
        "POST",
        url,
        headers=_build_headers(api_key=api_key, use_bearer=False),
        payload={
            "model": model,
            "messages": messages,
            "stream": False,
            "options": {
                "num_predict": max_tokens,
                "temperature": temperature,
            },
        },
    )
    if status_code >= HTTPStatus.BAD_REQUEST:
        raise ConnectivityCheckError(
            f"Ollama 返回异常状态 {status_code}",
            status_code=status_code,
        )
    # Check each level explicitly: str() on a null content would produce the
    # literal "None" as the answer, and a null "message" or non-dict payload
    # would raise AttributeError instead of yielding an empty reply.
    message = payload.get("message") if isinstance(payload, dict) else None
    content = message.get("content") if isinstance(message, dict) else None
    return content.strip() if isinstance(content, str) else ""
def _request_azure_openai(
    self,
    *,
    endpoint: str,
    model: str,
    api_key: str,
    messages: list[dict[str, str]],
    max_tokens: int,
    temperature: float,
) -> str:
    """POST a chat completion to the Azure OpenAI deployment for ``model``.

    The URL is the deployment base plus ``/chat/completions`` and the
    ``api-version`` query parameter; the key is sent as an API-key header.
    Raises ``ConnectivityCheckError`` for any status code of 400 or above.
    """
    base = _build_azure_deployment_base(endpoint, model)
    request_body = {
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    status_code, payload = _send_json_request(
        "POST",
        f"{base}/chat/completions?api-version={AZURE_API_VERSION}",
        headers=_build_headers(api_key=api_key, use_bearer=False, use_api_key=True),
        payload=request_body,
    )
    if status_code < HTTPStatus.BAD_REQUEST:
        return self._extract_openai_text(payload)
    raise ConnectivityCheckError(
        f"Azure OpenAI 返回异常状态 {status_code}",
        status_code=status_code,
    )
@staticmethod
def _extract_openai_text(payload: Any) -> str:
if not isinstance(payload, dict):
return ""
choices = payload.get("choices")
if not isinstance(choices, list) or not choices:
return ""
first_choice = choices[0]
if not isinstance(first_choice, dict):
return ""
message = first_choice.get("message")
if isinstance(message, dict):
content = message.get("content", "")
if isinstance(content, str):
return content.strip()
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
parts.append(str(item.get("text", "")))
return "\n".join(part.strip() for part in parts if part.strip()).strip()
text = first_choice.get("text")
if isinstance(text, str):
return text.strip()
return ""

View File

@@ -0,0 +1,547 @@
from __future__ import annotations
import json
import re
from sqlalchemy.orm import Session
from app.core.agent_enums import AgentAssetStatus, AgentAssetType
from app.schemas.agent_asset import AgentAssetListItem
from app.schemas.user_agent import (
UserAgentCitation,
UserAgentDraftPayload,
UserAgentRequest,
UserAgentResponse,
UserAgentSuggestedAction,
)
from app.services.agent_assets import AgentAssetService
from app.services.agent_foundation import AgentFoundationService
from app.services.runtime_chat import RuntimeChatService
# Display label for each ontology scenario; used when composing the subject
# line, the draft title and the "no matching rule" explanation.
SCENARIO_LABELS = {
    "expense": "报销",
    "accounts_receivable": "应收",
    "accounts_payable": "应付",
    "knowledge": "知识",
    "unknown": "通用",
}
# User-facing explanation for each known risk flag; _build_risk_answer falls
# back to a generic "needs manual confirmation" sentence for unknown flags.
RISK_REASON_MAP = {
    "duplicate_expense": "检测到同员工、同金额或近似单据存在重复提交迹象。",
    "amount_over_limit": "金额超过当前制度或预算阈值,需要补充例外说明。",
    "invoice_anomaly": "票据或附件完整性不满足当前规则要求,需要补件或人工复核。",
    "ar_overdue": "应收账款已出现逾期,存在回款延迟风险。",
    "ap_overdue": "应付付款已出现逾期,可能影响供应商履约或合作关系。",
}
# Messages that, after whitespace removal, count as a bare "I want to claim
# expenses" request with no details (exact match, see _is_generic_expense_prompt).
GENERIC_EXPENSE_PROMPTS = {
    "报销",
    "我要报销",
    "我想报销",
    "帮我报销",
    "我要申请报销",
    "发起报销",
    "提交报销",
}
# Substrings that mark a message as an explicit request to create a draft;
# their absence makes an expense draft intent "implicit".
EXPLICIT_DRAFT_KEYWORDS = ("生成", "草稿", "起草", "创建", "发起", "准备")
# Display label for each normalized expense_type entity value.
EXPENSE_TYPE_LABELS = {
    "travel": "差旅",
    "hotel": "住宿",
    "transport": "交通",
    "meal": "餐费",
    "meeting": "会务",
    "entertainment": "招待",
}
class UserAgentService:
def __init__(self, db: Session) -> None:
    """Keep the session and build the asset and runtime-chat services on top of it."""
    self.db = db
    self.runtime_chat_service = RuntimeChatService(db)
    self.asset_service = AgentAssetService(db)
def respond(self, payload: UserAgentRequest) -> UserAgentResponse:
    """Produce the User Agent reply for an orchestrated request.

    The answer is chosen in this order:
    1. degraded run with a tool ``message``: that message is echoed and no
       draft payload is attached;
    2. guided answer (generic or implicit expense request), when one applies;
    3. model-generated answer, falling back to the deterministic answer when
       the model returns nothing.
    Citations, suggested actions and risk flags are attached in every case.
    """
    AgentFoundationService(self.db).ensure_foundation_ready()
    citations = self._build_rule_citations(payload)
    suggested_actions = self._build_suggested_actions(payload)
    risk_flags = self._resolve_risk_flags(payload)
    draft_payload = None
    if payload.ontology.intent == "draft":
        draft_payload = self._build_draft_payload(payload)

    # Fields shared by every response variant.
    shared = {
        "citations": citations,
        "suggested_actions": suggested_actions,
        "risk_flags": risk_flags,
        "requires_confirmation": payload.requires_confirmation,
    }
    if payload.degraded and payload.tool_payload.get("message"):
        return UserAgentResponse(answer=str(payload.tool_payload["message"]), **shared)

    # Only the non-degraded variants carry the draft payload.
    shared["draft_payload"] = draft_payload
    guided_answer = self._build_guided_answer(payload)
    if guided_answer:
        return UserAgentResponse(answer=guided_answer, **shared)

    fallback_answer = self._build_fallback_answer(
        payload,
        citations=citations,
        draft_payload=draft_payload,
    )
    model_answer = self._generate_answer_with_model(
        payload,
        citations=citations,
        suggested_actions=suggested_actions,
        risk_flags=risk_flags,
        draft_payload=draft_payload,
        fallback_answer=fallback_answer,
    )
    return UserAgentResponse(answer=model_answer or fallback_answer, **shared)
def _build_fallback_answer(
self,
payload: UserAgentRequest,
*,
citations: list[UserAgentCitation],
draft_payload: UserAgentDraftPayload | None,
) -> str:
if payload.ontology.intent in {"query", "compare"}:
return self._build_query_answer(payload)
if payload.ontology.intent == "risk_check":
return self._build_risk_answer(payload, citations)
if payload.ontology.intent == "draft" and draft_payload is not None:
return (
f"已生成 {draft_payload.title},当前仅返回待人工确认的草稿内容,"
"仍需人工确认后再进入正式流程。"
)
return self._build_explain_answer(payload, citations)
def _build_guided_answer(self, payload: UserAgentRequest) -> str | None:
if not self._is_generic_expense_prompt(payload):
return self._build_implicit_expense_draft_guidance(payload)
attachment_names = self._resolve_attachment_names(payload)
attachment_hint = ""
if attachment_names:
attachment_hint = (
f" 我已带入 {len(attachment_names)} 份附件名称,但目前还不能直接读取附件内容,"
"仍需要你补充关键信息。"
)
return (
"可以帮你发起报销。请补充费用类型、发生时间、金额、事由和相关对象,"
"或者直接上传票据附件,我再继续帮你判断能否报、缺什么材料以及生成报销草稿。"
f"{attachment_hint}"
)
def _build_implicit_expense_draft_guidance(
self,
payload: UserAgentRequest,
) -> str | None:
if not self._is_implicit_expense_draft_request(payload):
return None
amount_text = next(
(item.value for item in payload.ontology.entities if item.type == "amount"),
"",
)
expense_type = next(
(
EXPENSE_TYPE_LABELS.get(item.normalized_value, item.value)
for item in payload.ontology.entities
if item.type == "expense_type"
),
"报销",
)
time_text = payload.ontology.time_range.raw or "本次"
amount_hint = f",金额 {amount_text}" if amount_text else ""
return (
f"已识别到一笔{time_text}{expense_type}支出{amount_hint}"
"如果要继续生成报销草稿,还需要补充客户单位、参与人员、费用明细和票据附件。"
"你也可以继续上传发票或图片,我会把这些信息带入后续对话。"
)
def _generate_answer_with_model(
    self,
    payload: UserAgentRequest,
    *,
    citations: list[UserAgentCitation],
    suggested_actions: list[UserAgentSuggestedAction],
    risk_flags: list[str],
    draft_payload: UserAgentDraftPayload | None,
    fallback_answer: str,
) -> str | None:
    """Ask the runtime chat model to phrase the final answer.

    Builds the prompt from the collected facts, requests a completion
    (420 tokens, temperature 0.2) and returns the sanitised text, or ``None``
    when nothing usable came back.
    """
    prompt = self._build_model_messages(
        payload,
        citations=citations,
        suggested_actions=suggested_actions,
        risk_flags=risk_flags,
        draft_payload=draft_payload,
        fallback_answer=fallback_answer,
    )
    raw_reply = self.runtime_chat_service.complete(
        prompt,
        max_tokens=420,
        temperature=0.2,
    )
    return self._sanitize_model_answer(raw_reply)
def _sanitize_model_answer(self, answer: str | None) -> str | None:
if not answer:
return None
cleaned = re.sub(r"<think>.*?</think>", "", answer, flags=re.DOTALL | re.IGNORECASE)
cleaned = cleaned.strip()
return cleaned or None
def _build_model_messages(
self,
payload: UserAgentRequest,
*,
citations: list[UserAgentCitation],
suggested_actions: list[UserAgentSuggestedAction],
risk_flags: list[str],
draft_payload: UserAgentDraftPayload | None,
fallback_answer: str,
) -> list[dict[str, str]]:
facts = {
"run_id": payload.run_id,
"user_message": payload.message,
"ontology": payload.ontology.model_dump(mode="json"),
"context": {
"entry_source": payload.context_json.get("entry_source"),
"user_name": payload.context_json.get("name"),
"user_role": payload.context_json.get("role"),
"request_context": payload.context_json.get("request_context"),
"attachment_count": payload.context_json.get("attachment_count"),
"attachment_names": self._resolve_attachment_names(payload),
},
"tool_payload": payload.tool_payload,
"citations": [item.model_dump(mode="json") for item in citations],
"suggested_actions": [
item.model_dump(mode="json") for item in suggested_actions
],
"risk_flags": risk_flags,
"draft_payload": (
draft_payload.model_dump(mode="json")
if draft_payload is not None
else None
),
"selected_capability_codes": payload.selected_capability_codes,
"requires_confirmation": payload.requires_confirmation,
"fallback_answer": fallback_answer,
}
system_prompt = (
"你是企业财务共享场景中的中文智能助手,负责和最终用户直接对话。"
"你只能基于提供的事实回答,不能编造制度、流程结果或附件内容。"
"如果用户问题很笼统,例如“我要报销”,优先告诉用户你可以协助什么,"
"并明确要求补充费用类型、金额、时间、事由、参与对象或上传票据。"
"如果上下文里只有附件名称,必须明确说明你只拿到了附件名称,"
"不能假装已看过图片、PDF 或发票内容。"
"不要声称已经提交、审批、付款、入账或真正执行了任何动作;如果只是建议、草稿或待确认,要明确说清楚。"
"若给出了风险标签、制度引用或建议动作,可以简洁吸收进回答,但不要新增未提供的事实。"
"只输出最终给用户看的自然语言,不要输出 JSON、Markdown、标题、"
"<think> 标签或任何中间推理。"
"使用简体中文,控制在 2 到 4 句。"
)
user_prompt = (
"请根据以下事实生成最终答复,优先保持准确、具体、可执行:\n"
f"{json.dumps(facts, ensure_ascii=False, indent=2)}"
)
return [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
def _build_query_answer(self, payload: UserAgentRequest) -> str:
scenario = payload.ontology.scenario
data = payload.tool_payload
subject = self._resolve_subject(payload)
if scenario == "expense":
record_count = int(data.get("record_count") or 0)
total_amount = float(data.get("total_amount") or 0)
return (
f"{subject}共命中 {record_count} 笔报销,金额合计 {total_amount:.2f} 元。"
"如需继续处理,可以查看明细或生成处理意见草稿。"
)
if scenario == "accounts_receivable":
record_count = int(data.get("record_count") or 0)
outstanding_amount = float(data.get("outstanding_amount") or 0)
return (
f"{subject}共命中 {record_count} 条应收,未回款金额 {outstanding_amount:.2f} 元。"
"建议结合账龄和客户分布继续排查逾期风险。"
)
if scenario == "accounts_payable":
record_count = int(data.get("record_count") or 0)
outstanding_amount = float(data.get("outstanding_amount") or 0)
return (
f"{subject}共命中 {record_count} 条应付,待付金额 {outstanding_amount:.2f} 元。"
"如需推进动作,建议先生成付款建议草稿并发起人工确认。"
)
return "已完成当前查询,但暂时没有更多结构化结果可展示。"
def _build_explain_answer(
self,
payload: UserAgentRequest,
citations: list[UserAgentCitation],
) -> str:
if citations:
titles = "".join(item.title for item in citations[:2])
summary = citations[0].excerpt or "请结合制度全文进一步确认。"
return f"已检索到相关依据:{titles}。核心说明:{summary}"
return (
f"当前还没有与“{SCENARIO_LABELS.get(payload.ontology.scenario, '当前问题')}"
"强匹配的已上线规则引用,建议先人工复核或补充更具体的单据上下文。"
)
def _build_risk_answer(
self,
payload: UserAgentRequest,
citations: list[UserAgentCitation],
) -> str:
risk_flags = self._resolve_risk_flags(payload)
if not risk_flags:
return "当前未识别到明确风险标签,建议继续查看原始明细或补充更多上下文。"
reasons = [RISK_REASON_MAP.get(flag, f"{flag} 需要人工进一步确认。") for flag in risk_flags]
citation_text = (
f" 参考规则:{''.join(item.title for item in citations[:2])}"
if citations
else ""
)
return (
f"本次识别到 {len(risk_flags)} 类风险:{''.join(risk_flags)}"
f"触发原因:{''.join(reasons)}"
"建议先复核明细、附件和审批链,再决定是否继续处理。"
f"{citation_text}"
)
def _build_draft_payload(self, payload: UserAgentRequest) -> UserAgentDraftPayload:
    """Build a confirmation-required draft for the request's scenario.

    The title is the scenario label plus a fixed suffix; the body lists the
    subject, a fixed conclusion and recommendation, and the original message.
    """
    scenario = payload.ontology.scenario
    scenario_label = SCENARIO_LABELS.get(scenario, "业务")
    subject = self._resolve_subject(payload)
    body_lines = [
        f"主题:{subject}",
        "结论:已根据当前语义解析结果生成草稿,尚未自动执行。",
        "建议:请先核对明细、规则命中和所需附件,再由人工确认是否提交正式流程。",
        f"原始问题:{payload.message}",
    ]
    return UserAgentDraftPayload(
        draft_type=scenario,
        title=f"{scenario_label}处理意见草稿",
        body="\n".join(body_lines),
        confirmation_required=True,
    )
def _build_suggested_actions(
    self,
    payload: UserAgentRequest,
) -> list[UserAgentSuggestedAction]:
    """Pick the two follow-up actions offered alongside the answer.

    A generic expense prompt takes precedence; otherwise the pair is chosen
    by intent (query/compare, risk_check, draft), with a rule-lookup pair as
    the default.
    """
    intent = payload.ontology.intent
    # Each entry is (label, action_type, description).
    if self._is_generic_expense_prompt(payload):
        specs = (
            ("上传票据", "ask_clarification", "上传发票、行程单或付款截图,继续识别报销内容。"),
            ("补充报销信息", "ask_clarification", "补充费用类型、金额、时间和事由后继续处理。"),
        )
    elif intent in {"query", "compare"}:
        specs = (
            ("查看明细", "open_detail", "继续查看命中记录和过滤条件。"),
            ("生成处理意见", "create_draft", "把当前查询结果整理成可确认草稿。"),
        )
    elif intent == "risk_check":
        specs = (
            ("人工复核风险", "manual_review", "优先检查明细、附件和规则命中原因。"),
            ("生成整改建议", "create_draft", "把风险说明整理成处理意见草稿。"),
        )
    elif intent == "draft":
        specs = (
            ("复制草稿", "copy_draft", "复制当前草稿后交由人工确认。"),
            ("补充上下文", "ask_clarification", "补充单据编号、客户或供应商信息以完善草稿。"),
        )
    else:
        specs = (
            ("查看规则全文", "open_rule", "继续查看引用规则或知识内容。"),
            ("补充问题上下文", "ask_clarification", "补充业务对象、时间或单据范围,提升回答准确度。"),
        )
    return [
        UserAgentSuggestedAction(
            label=label,
            action_type=action_type,
            description=description,
        )
        for label, action_type, description in specs
    ]
def _build_rule_citations(self, payload: UserAgentRequest) -> list[UserAgentCitation]:
    """Cite up to two active rule assets relevant to the request.

    Active rules in the scenario's domain are ranked against the ontology;
    the top two are loaded in full and turned into citations with a short
    excerpt of their current version content. Assets that can no longer be
    loaded are skipped.
    """
    candidates = self.asset_service.list_assets(
        asset_type=AgentAssetType.RULE.value,
        status=AgentAssetStatus.ACTIVE.value,
        domain=self._resolve_domain(payload.ontology.scenario),
    )
    citations: list[UserAgentCitation] = []
    for candidate in self._rank_rule_assets(candidates, payload)[:2]:
        detail = self.asset_service.get_asset(candidate.id)
        if detail is not None:
            content = str(detail.current_version_content or "")
            citations.append(
                UserAgentCitation(
                    source_type="rule",
                    code=detail.code,
                    title=detail.name,
                    version=detail.current_version,
                    updated_at=detail.updated_at.date().isoformat(),
                    excerpt=self._extract_excerpt(content),
                )
            )
    return citations
@staticmethod
def _resolve_risk_flags(payload: UserAgentRequest) -> list[str]:
tool_flags = payload.tool_payload.get("risk_flags")
if isinstance(tool_flags, list) and tool_flags:
return [str(item) for item in tool_flags]
return [str(item) for item in payload.ontology.risk_flags]
@staticmethod
def _resolve_subject(payload: UserAgentRequest) -> str:
named_entities = [
item.value
for item in payload.ontology.entities
if item.type in {"employee", "customer", "vendor", "project"}
]
if named_entities:
return f"{''.join(named_entities)} 相关数据"
return f"{SCENARIO_LABELS.get(payload.ontology.scenario, '当前')}场景数据"
@staticmethod
def _is_generic_expense_prompt(payload: UserAgentRequest) -> bool:
if payload.ontology.scenario != "expense":
return False
normalized_message = re.sub(r"\s+", "", payload.message)
return normalized_message in GENERIC_EXPENSE_PROMPTS
@staticmethod
def _is_implicit_expense_draft_request(payload: UserAgentRequest) -> bool:
if payload.ontology.scenario != "expense" or payload.ontology.intent != "draft":
return False
compact_message = re.sub(r"\s+", "", payload.message)
if any(keyword in compact_message for keyword in EXPLICIT_DRAFT_KEYWORDS):
return False
return True
@staticmethod
def _resolve_attachment_names(payload: UserAgentRequest) -> list[str]:
names = payload.context_json.get("attachment_names")
if not isinstance(names, list):
return []
return [str(name) for name in names if str(name).strip()]
@staticmethod
def _resolve_domain(scenario: str) -> str | None:
if scenario == "expense":
return "expense"
if scenario == "accounts_receivable":
return "ar"
if scenario == "accounts_payable":
return "ap"
return None
@staticmethod
def _rank_rule_assets(
items: list[AgentAssetListItem],
payload: UserAgentRequest,
) -> list[AgentAssetListItem]:
def score(item: AgentAssetListItem) -> tuple[int, str]:
tags = {str(value) for value in item.scenario_json or []}
weight = 0
if payload.ontology.scenario in tags:
weight += 3
if payload.ontology.intent in tags:
weight += 2
for risk_flag in payload.ontology.risk_flags:
if risk_flag in tags:
weight += 4
return weight, item.code
ranked = sorted(items, key=score, reverse=True)
return [item for item in ranked if score(item)[0] > 0]
@staticmethod
def _extract_excerpt(content: str) -> str:
lines = [line.strip() for line in str(content).splitlines() if line.strip()]
cleaned: list[str] = []
for line in lines:
normalized = re.sub(r"^[#>\-\*\d\.\s`]+", "", line).strip()
if normalized:
cleaned.append(normalized)
if len(cleaned) >= 2:
break
return "".join(cleaned[:2])

View File

@@ -0,0 +1,397 @@
from __future__ import annotations
from collections.abc import Generator
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from app.api.deps import get_db
from app.db.base import Base
from app.main import create_app
from app.schemas.ontology import OntologyParseRequest
from app.services.ontology import LlmOntologyParseResult, SemanticOntologyService
def build_session_factory() -> sessionmaker[Session]:
    """Create a fresh in-memory SQLite database and return a session factory for it.

    All tables registered on ``Base.metadata`` are created up front. The
    engine uses ``StaticPool`` with ``check_same_thread`` disabled.
    """
    memory_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=memory_engine)
    factory = sessionmaker(bind=memory_engine, autoflush=False, autocommit=False)
    return factory
def build_client() -> tuple[TestClient, sessionmaker[Session]]:
    """Return a test client whose ``get_db`` dependency uses a fresh in-memory database.

    The session factory is returned as well so tests can inspect or seed the
    same database directly.
    """
    session_factory = build_session_factory()

    def override_db() -> Generator[Session, None, None]:
        # One session per request, always closed afterwards.
        session = session_factory()
        try:
            yield session
        finally:
            session.close()

    application = create_app()
    application.dependency_overrides[get_db] = override_db
    return TestClient(application), session_factory
# Day 3 evaluation set, one row per case:
# (query, expected scenario, expected intent, expected permission level, context_json, case id)
_EVALUATION_ROWS = [
    ("查一下本周报销超标风险", "expense", "risk_check", "read", {}, "expense-risk-check"),
    ("张三 4 月差旅报销金额是多少", "expense", "query", "read", {}, "expense-query-employee-month"),
    ("为什么酒店超标报销不能直接通过", "expense", "explain", "read", {}, "expense-explain-policy"),
    ("列出金额最高的10笔报销", "expense", "query", "read", {}, "expense-topn-query"),
    ("帮我生成张三4月差旅报销草稿", "expense", "draft", "draft_write", {}, "expense-draft"),
    (
        "我今天去客户现场招待了客户花销了1000元",
        "expense",
        "draft",
        "draft_write",
        {},
        "expense-narrative-draft",
    ),
    ("客户 A 这个月还有多少应收", "accounts_receivable", "query", "read", {}, "ar-query-customer-month"),
    ("对比客户A和客户B本月应收差异", "accounts_receivable", "compare", "read", {}, "ar-compare-customers"),
    ("检查客户B逾期应收风险", "accounts_receivable", "risk_check", "read", {}, "ar-risk-check"),
    ("生成客户A回款跟进草稿", "accounts_receivable", "draft", "draft_write", {}, "ar-draft"),
    ("查询客户B账龄明细", "accounts_receivable", "query", "read", {}, "ar-aging-query"),
    ("供应商 B 明天要付多少钱", "accounts_payable", "query", "read", {}, "ap-query-vendor-tomorrow"),
    ("对比供应商A和供应商B本月应付差异", "accounts_payable", "compare", "read", {}, "ap-compare-vendors"),
    ("检查供应商B逾期付款风险", "accounts_payable", "risk_check", "read", {}, "ap-risk-check"),
    ("生成供应商A付款沟通草稿", "accounts_payable", "draft", "draft_write", {}, "ap-draft"),
    (
        "帮我安排付款给供应商B",
        "accounts_payable",
        "operate",
        "approval_required",
        {"role_codes": ["finance"]},
        "ap-operate-approval-required",
    ),
    ("公司财务制度在哪里看", "knowledge", "query", "read", {}, "knowledge-query"),
    ("规则中心的审核依据是什么", "knowledge", "explain", "read", {}, "knowledge-explain"),
    ("知识库里有没有双人复核制度", "knowledge", "query", "read", {}, "knowledge-query-library"),
    (
        "帮我直接付款给供应商B",
        "accounts_payable",
        "operate",
        "forbidden",
        {"role_codes": ["user"]},
        "forbidden-direct-payment",
    ),
    (
        "帮我上线付款双人复核规则",
        "accounts_payable",
        "operate",
        "forbidden",
        {"role_codes": ["user"]},
        "forbidden-activate-rule",
    ),
    (
        "帮我删除今天的报销记录",
        "expense",
        "operate",
        "forbidden",
        {"role_codes": ["user"]},
        "forbidden-delete-expense",
    ),
]
EVALUATION_CASES = [
    pytest.param(query, scenario, intent, permission, context_json, id=case_id)
    for query, scenario, intent, permission, context_json, case_id in _EVALUATION_ROWS
]
@pytest.mark.parametrize("query,scenario,intent,permission,context_json", EVALUATION_CASES)
def test_semantic_ontology_service_matches_day3_evaluation_set(
    query: str,
    scenario: str,
    intent: str,
    permission: str,
    context_json: dict,
) -> None:
    """Each evaluation query parses to the expected scenario, intent and permission level."""
    request = OntologyParseRequest(
        query=query,
        user_id="pytest",
        context_json=context_json,
    )
    with build_session_factory()() as db:
        parsed = SemanticOntologyService(db).parse(request)
    observed = (parsed.scenario, parsed.intent, parsed.permission.level)
    assert observed == (scenario, intent, permission)
    assert parsed.run_id.startswith("run_")
def test_semantic_ontology_service_extracts_entities_time_and_constraints() -> None:
    """A detailed expense query yields the employee, expense type, month range and amount filter."""
    request = OntologyParseRequest(
        query="张三 2026年4月差旅报销金额超过5000元的明细",
        user_id="pytest",
    )
    with build_session_factory()() as db:
        parsed = SemanticOntologyService(db).parse(request)

    assert (parsed.scenario, parsed.intent) == ("expense", "query")
    assert parsed.time_range.start_date == "2026-04-01"
    assert parsed.time_range.end_date == "2026-04-30"

    entity_pairs = [(entity.type, entity.normalized_value) for entity in parsed.entities]
    assert ("employee", "张三") in entity_pairs
    assert ("expense_type", "travel") in entity_pairs
    assert any(
        (constraint.field, constraint.operator, constraint.value) == ("amount", ">", 5000)
        for constraint in parsed.constraints
    )
def test_semantic_ontology_service_prefers_expense_for_customer_entertainment_narrative() -> None:
    """A narrated client-entertainment spend is an expense draft that still needs clarification."""
    request = OntologyParseRequest(
        query="我今天去客户现场招待了客户花销了1000元",
        user_id="pytest",
    )
    with build_session_factory()() as db:
        parsed = SemanticOntologyService(db).parse(request)

    assert parsed.scenario == "expense"
    assert parsed.intent == "draft"
    assert parsed.permission.level == "draft_write"
    assert parsed.time_range.raw == "今天"
    assert parsed.clarification_required is True
    for slot in ("customer_name", "participants"):
        assert slot in parsed.missing_slots
    entity_pairs = [(entity.type, entity.normalized_value) for entity in parsed.entities]
    assert ("expense_type", "entertainment") in entity_pairs
def test_semantic_ontology_service_uses_model_parse_when_available(monkeypatch) -> None:
    """When the model parse returns a result, it drives the outcome and the strategy is llm_primary."""
    question = "请补充费用类型、金额和票据附件。"
    model_result = dict(
        scenario="expense",
        intent="draft",
        confidence=0.91,
        clarification_required=True,
        clarification_question=question,
        missing_slots=["expense_type", "amount", "attachments"],
        ambiguity=[],
        entity_hints=[],
    )

    def fake_parse_with_model(**kwargs):
        return LlmOntologyParseResult(**model_result)

    with build_session_factory()() as db:
        service = SemanticOntologyService(db)
        monkeypatch.setattr(service, "_parse_with_model", fake_parse_with_model)
        parsed = service.parse(
            OntologyParseRequest(
                query="我要报销",
                user_id="pytest",
            )
        )

    assert parsed.scenario == "expense"
    assert parsed.intent == "draft"
    assert parsed.parse_strategy == "llm_primary"
    assert parsed.clarification_required is True
    assert "expense_type" in parsed.missing_slots
    assert parsed.clarification_question == question
def test_parse_ontology_endpoint_returns_eight_fields_and_writes_trace() -> None:
    """The parse endpoint returns every contract field and persists a retrievable run trace."""
    client, _ = build_client()
    request_body = {
        "query": "查一下本周报销超标风险",
        "user_id": "pytest",
        "context_json": {"role_codes": ["finance"]},
    }
    response = client.post("/api/v1/ontology/parse", json=request_body)
    assert response.status_code == 200

    payload = response.json()
    assert payload["scenario"] == "expense"
    assert payload["intent"] == "risk_check"
    assert payload["permission"]["level"] == "read"
    assert payload["run_id"].startswith("run_")
    required_fields = {
        "scenario",
        "intent",
        "entities",
        "time_range",
        "metrics",
        "constraints",
        "risk_flags",
        "permission",
        "confidence",
        "missing_slots",
        "ambiguity",
        "parse_strategy",
        "clarification_required",
        "clarification_question",
        "run_id",
        "field_errors",
    }
    assert not (required_fields - set(payload))

    run_response = client.get(f"/api/v1/agent-runs/{payload['run_id']}")
    assert run_response.status_code == 200
    run_payload = run_response.json()
    # The trace stores the parse twice: on the run itself and in the parse log.
    for section in ("ontology_json", "semantic_parse"):
        assert run_payload[section]["scenario"] == "expense"
        assert run_payload[section]["intent"] == "risk_check"
def test_parse_ontology_endpoint_returns_forbidden_for_unprivileged_payment_request() -> None:
    """A direct-payment request from a plain user parses successfully but is marked forbidden."""
    client, _ = build_client()
    request_body = {
        "query": "帮我直接付款给供应商B",
        "user_id": "pytest",
        "context_json": {"role_codes": ["user"]},
    }
    response = client.post("/api/v1/ontology/parse", json=request_body)
    assert response.status_code == 200

    payload = response.json()
    assert (payload["scenario"], payload["intent"]) == ("accounts_payable", "operate")
    assert payload["permission"]["level"] == "forbidden"
    assert payload["clarification_required"] is True
    assert payload["field_errors"]

View File

@@ -0,0 +1,241 @@
from __future__ import annotations
from collections.abc import Generator
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from app.api.deps import get_db
from app.db.base import Base
from app.main import create_app
from app.services.agent_assets import AgentAssetService
def build_client() -> tuple[TestClient, sessionmaker[Session]]:
    """Return a test client backed by a fresh in-memory SQLite database.

    All tables on ``Base.metadata`` are created first, and the app's
    ``get_db`` dependency is overridden to hand out sessions from the
    returned factory.
    """
    memory_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=memory_engine)
    session_factory = sessionmaker(bind=memory_engine, autoflush=False, autocommit=False)

    def override_db() -> Generator[Session, None, None]:
        # One session per request, always closed afterwards.
        session = session_factory()
        try:
            yield session
        finally:
            session.close()

    application = create_app()
    application.dependency_overrides[get_db] = override_db
    return TestClient(application), session_factory
def test_orchestrator_routes_user_query_to_user_agent() -> None:
    """A receivables question from a finance user is routed to the user agent.

    Checks both the immediate API response and the persisted agent run.
    """
    client, _ = build_client()
    request_body = {
        "source": "user_message",
        "user_id": "pytest",
        "message": "客户A这个月还有多少应收",
        "context_json": {"role_codes": ["finance"]},
    }

    response = client.post("/api/v1/orchestrator/run", json=request_body)

    assert response.status_code == 200
    body = response.json()
    assert body["selected_agent"] == "user_agent"
    assert body["permission_level"] == "read"
    assert body["status"] == "succeeded"
    result = body["result"]
    assert result["answer"]
    assert result["suggested_actions"]
    assert body["trace_summary"]["tool_count"] >= 1

    # The stored run must reflect the same routing decision.
    persisted_run = client.get(f"/api/v1/agent-runs/{body['run_id']}").json()
    assert persisted_run["agent"] == "user_agent"
    assert persisted_run["route_json"]["selected_agent"] == "user_agent"
    assert persisted_run["semantic_parse"]["scenario"] == "accounts_receivable"
    assert persisted_run["tool_calls"][0]["tool_type"] == "database"
def test_orchestrator_routes_schedule_to_hermes() -> None:
    """A scheduled run of the daily risk-scan task is routed to Hermes.

    The task asset is looked up by code among the active task assets; its id
    is then submitted as a ``schedule`` source request.
    """
    task_code = "task.hermes.daily_risk_scan"
    client, session_factory = build_client()
    with session_factory() as db:
        active_tasks = AgentAssetService(db).list_assets(asset_type="task", status="active")
        # Use a default so a missing asset fails with a readable message
        # instead of a bare StopIteration.
        task = next((item for item in active_tasks if item.code == task_code), None)
        assert task is not None, f"active task asset {task_code!r} not found"
        # Capture the id while the session is still open so the value does not
        # depend on attribute access against a closed session.
        task_id = task.id

    response = client.post(
        "/api/v1/orchestrator/run",
        json={
            "source": "schedule",
            "task_id": task_id,
            "context_json": {"role_codes": ["finance"]},
        },
    )

    assert response.status_code == 200
    payload = response.json()
    assert payload["selected_agent"] == "hermes"
    assert payload["status"] == "succeeded"
    assert payload["trace_summary"]["tool_count"] == 2

    run_detail = client.get(f"/api/v1/agent-runs/{payload['run_id']}").json()
    assert run_detail["agent"] == "hermes"
    assert run_detail["route_json"]["selected_agent"] == "hermes"
    assert len(run_detail["tool_calls"]) == 2
def test_orchestrator_forbidden_request_does_not_call_downstream_agent() -> None:
    """A direct-payment request from a plain ``user`` role is blocked.

    No agent is selected and no tool call is recorded on the persisted run.
    """
    client, _ = build_client()
    request_body = {
        "source": "user_message",
        "user_id": "pytest",
        "message": "帮我直接付款给供应商B",
        "context_json": {"role_codes": ["user"]},
    }

    response = client.post("/api/v1/orchestrator/run", json=request_body)

    assert response.status_code == 200
    body = response.json()
    assert body["selected_agent"] is None
    assert body["permission_level"] == "forbidden"
    assert body["status"] == "blocked"
    assert body["trace_summary"]["tool_count"] == 0

    # The run is attributed to the orchestrator itself and has no tool calls.
    persisted_run = client.get(f"/api/v1/agent-runs/{body['run_id']}").json()
    assert persisted_run["agent"] == "orchestrator"
    assert persisted_run["tool_calls"] == []
def test_orchestrator_approval_required_returns_confirmation_result() -> None:
    """A payment-scheduling request from a finance user needs confirmation.

    The user agent is selected, but the run is blocked and the result message
    asks for confirmation.
    """
    client, _ = build_client()
    request_body = {
        "source": "user_message",
        "user_id": "pytest",
        "message": "帮我安排付款给供应商B",
        "context_json": {"role_codes": ["finance"]},
    }

    response = client.post("/api/v1/orchestrator/run", json=request_body)

    assert response.status_code == 200
    body = response.json()
    assert body["selected_agent"] == "user_agent"
    assert body["permission_level"] == "approval_required"
    assert body["requires_confirmation"] is True
    assert body["status"] == "blocked"
    assert "确认" in body["result"]["message"]
def test_orchestrator_user_agent_draft_returns_structured_payload() -> None:
    """An explicit expense-draft request yields a draft payload needing confirmation."""
    client, _ = build_client()
    request_body = {
        "source": "user_message",
        "user_id": "pytest",
        "message": "帮我生成张三4月差旅报销草稿",
        "context_json": {"role_codes": ["finance"]},
    }

    response = client.post("/api/v1/orchestrator/run", json=request_body)

    assert response.status_code == 200
    body = response.json()
    assert body["selected_agent"] == "user_agent"
    assert body["status"] == "succeeded"
    result = body["result"]
    assert result["draft_payload"]["confirmation_required"] is True
    assert result["suggested_actions"]
def test_orchestrator_treats_expense_narrative_as_draft_instead_of_ar_query() -> None:
    """A narrative about money spent on a client is handled as an expense draft.

    The message mentions a customer, but it must not be answered as an
    accounts-receivable query; the run is blocked pending clarification and
    no tool is called.
    """
    client, _ = build_client()
    request_body = {
        "source": "user_message",
        "user_id": "pytest",
        "message": "我今天去客户现场招待了客户花销了1000元",
        "context_json": {"role_codes": ["finance"]},
    }

    response = client.post("/api/v1/orchestrator/run", json=request_body)

    assert response.status_code == 200
    body = response.json()
    assert body["selected_agent"] == "user_agent"
    assert body["permission_level"] == "draft_write"
    assert body["status"] == "blocked"
    assert body["route_reason"] == "clarification_required"

    trace = body["trace_summary"]
    assert trace["scenario"] == "expense"
    assert trace["intent"] == "draft"
    assert trace["tool_count"] == 0

    reply = body["result"]["message"]
    assert "应收场景数据" not in reply
    assert "请补充" in reply
def test_orchestrator_tool_failure_is_logged_and_degraded() -> None:
    """A simulated database tool failure degrades the run without failing it.

    The run still reports ``succeeded``, while the trace summary and the
    persisted tool call record the failure.
    """
    client, _ = build_client()
    request_context = {
        "role_codes": ["finance"],
        "simulate_tool_failure": "database",
    }
    request_body = {
        "source": "user_message",
        "user_id": "pytest",
        "message": "查一下本周报销金额",
        "context_json": request_context,
    }

    response = client.post("/api/v1/orchestrator/run", json=request_body)

    assert response.status_code == 200
    body = response.json()
    assert body["selected_agent"] == "user_agent"
    assert body["status"] == "succeeded"
    assert body["trace_summary"]["failed_tool_count"] == 1
    assert body["trace_summary"]["degraded"] is True

    persisted_run = client.get(f"/api/v1/agent-runs/{body['run_id']}").json()
    first_call = persisted_run["tool_calls"][0]
    assert first_call["status"] == "failed"
    assert "simulated database failure" in first_call["error_message"]
def test_orchestrator_exception_is_written_to_agent_run() -> None:
    """A simulated orchestrator exception is persisted on the agent run.

    The endpoint still answers 200 with a ``failed`` status, and the stored
    run carries the error message.
    """
    client, _ = build_client()
    request_context = {
        "role_codes": ["finance"],
        "simulate_orchestrator_exception": True,
    }
    request_body = {
        "source": "user_message",
        "user_id": "pytest",
        "message": "查一下本周报销金额",
        "context_json": request_context,
    }

    response = client.post("/api/v1/orchestrator/run", json=request_body)

    assert response.status_code == 200
    body = response.json()
    assert body["status"] == "failed"

    persisted_run = client.get(f"/api/v1/agent-runs/{body['run_id']}").json()
    assert persisted_run["status"] == "failed"
    assert "simulated orchestrator exception" in persisted_run["error_message"]

View File

@@ -0,0 +1,179 @@
from __future__ import annotations
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from app.db.base import Base
from app.schemas.ontology import OntologyParseRequest
from app.schemas.user_agent import UserAgentRequest
from app.services.ontology import SemanticOntologyService
from app.services.user_agent import UserAgentService
def build_session_factory() -> sessionmaker[Session]:
    """Return a session factory bound to a fresh in-memory SQLite database.

    All tables registered on ``Base.metadata`` are created before the factory
    is returned.
    """
    # StaticPool makes every session share one connection, so they all see
    # the same in-memory database.
    memory_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=memory_engine)
    factory = sessionmaker(bind=memory_engine, autoflush=False, autocommit=False)
    return factory
def test_user_agent_query_returns_readable_answer_and_actions() -> None:
    """An expense-amount query is answered with the formatted total and suggested actions."""
    question = "张三 4 月差旅报销金额是多少"
    session_factory = build_session_factory()
    with session_factory() as db:
        ontology = SemanticOntologyService(db).parse(
            OntologyParseRequest(query=question, user_id="pytest")
        )
        agent_request = UserAgentRequest(
            run_id=ontology.run_id,
            user_id="pytest",
            message=question,
            ontology=ontology,
            tool_payload={"record_count": 2, "total_amount": 8800.0},
        )
        response = UserAgentService(db).respond(agent_request)

        # The total from the tool payload appears with two decimals.
        assert "8800.00" in response.answer
        assert len(response.suggested_actions) >= 1
def test_user_agent_prefers_runtime_model_answer_when_available(monkeypatch) -> None:
    """When the model-backed generator yields text, that text becomes the answer."""
    question = "张三 4 月差旅报销金额是多少"

    def fake_model_answer(*args, **kwargs):
        # Stand-in for the model call: accepts anything, returns a fixed reply.
        return "这是模型回答"

    session_factory = build_session_factory()
    with session_factory() as db:
        ontology = SemanticOntologyService(db).parse(
            OntologyParseRequest(query=question, user_id="pytest")
        )
        service = UserAgentService(db)
        monkeypatch.setattr(service, "_generate_answer_with_model", fake_model_answer)

        agent_request = UserAgentRequest(
            run_id=ontology.run_id,
            user_id="pytest",
            message=question,
            ontology=ontology,
            tool_payload={"record_count": 2, "total_amount": 8800.0},
        )
        response = service.respond(agent_request)

        assert response.answer == "这是模型回答"
def test_user_agent_sanitizes_model_thinking_blocks() -> None:
    """A ``<think>`` block and the newline after it are stripped from model output."""
    raw_answer = "<think>内部推理</think>\n最终答复"
    session_factory = build_session_factory()
    with session_factory() as db:
        cleaned = UserAgentService(db)._sanitize_model_answer(raw_answer)
        assert cleaned == "最终答复"
def test_user_agent_guides_generic_expense_request() -> None:
    """A bare "I want to claim expenses" message gets guidance on what to supply.

    The answer must prompt for the expense type and for uploading receipts.
    """
    question = "我要报销"
    session_factory = build_session_factory()
    with session_factory() as db:
        ontology = SemanticOntologyService(db).parse(
            OntologyParseRequest(query=question, user_id="pytest")
        )
        agent_request = UserAgentRequest(
            run_id=ontology.run_id,
            user_id="pytest",
            message=question,
            ontology=ontology,
            tool_payload={"record_count": 9, "total_amount": 12345.0},
        )
        response = UserAgentService(db).respond(agent_request)

        answer = response.answer
        assert "补充费用类型" in answer
        assert "上传票据" in answer
def test_user_agent_guides_implicit_expense_draft_request() -> None:
    """A spending narrative is answered as an expense-draft prompt.

    The answer must echo the amount and mention receipt attachments and the
    expense draft.
    """
    question = "我今天去客户现场招待了客户花销了1000元"
    session_factory = build_session_factory()
    with session_factory() as db:
        ontology = SemanticOntologyService(db).parse(
            OntologyParseRequest(query=question, user_id="pytest")
        )
        agent_request = UserAgentRequest(
            run_id=ontology.run_id,
            user_id="pytest",
            message=question,
            ontology=ontology,
            tool_payload={"draft_only": True},
        )
        response = UserAgentService(db).respond(agent_request)

        answer = response.answer
        assert "1000元" in answer
        assert "票据附件" in answer
        assert "报销草稿" in answer
def test_user_agent_risk_response_includes_rule_citations() -> None:
    """A risk-check reply carries the risk flags, a rule citation and names the flag."""
    question = "检查重复报销风险"
    session_factory = build_session_factory()
    with session_factory() as db:
        ontology = SemanticOntologyService(db).parse(
            OntologyParseRequest(query=question, user_id="pytest")
        )
        agent_request = UserAgentRequest(
            run_id=ontology.run_id,
            user_id="pytest",
            message=question,
            ontology=ontology,
            tool_payload={"risk_flags": ["duplicate_expense"]},
        )
        response = UserAgentService(db).respond(agent_request)

        assert response.risk_flags == ["duplicate_expense"]
        # At least one citation must come from a rule source.
        assert any(citation.source_type == "rule" for citation in response.citations)
        assert "duplicate_expense" in response.answer
def test_user_agent_draft_returns_structured_payload() -> None:
    """An explicit draft request returns a draft payload awaiting manual confirmation."""
    question = "帮我生成张三4月差旅报销草稿"
    session_factory = build_session_factory()
    with session_factory() as db:
        ontology = SemanticOntologyService(db).parse(
            OntologyParseRequest(query=question, user_id="pytest")
        )
        agent_request = UserAgentRequest(
            run_id=ontology.run_id,
            user_id="pytest",
            message=question,
            ontology=ontology,
            tool_payload={"draft_only": True},
        )
        response = UserAgentService(db).respond(agent_request)

        draft = response.draft_payload
        assert draft is not None
        assert draft.confirmation_required is True
        assert "待人工确认" in response.answer