diff --git a/server/src/app/api/v1/endpoints/ontology.py b/server/src/app/api/v1/endpoints/ontology.py new file mode 100644 index 0000000..8438bcf --- /dev/null +++ b/server/src/app/api/v1/endpoints/ontology.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from typing import Annotated + +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.orm import Session + +from app.api.deps import get_db +from app.schemas.common import ErrorResponse +from app.schemas.ontology import OntologyParseRequest, OntologyParseResult +from app.services.ontology import SemanticOntologyService + +router = APIRouter(prefix="/ontology") +DbSession = Annotated[Session, Depends(get_db)] + + +@router.post( + "/parse", + response_model=OntologyParseResult, + summary="解析自然语言为语义本体", + description=( + "把自然语言问题解析成 Day 3 约定的 8 个核心字段," + "并写入 AgentRun 与 SemanticParseLog。" + ), + responses={ + status.HTTP_400_BAD_REQUEST: { + "model": ErrorResponse, + "description": "请求缺少有效 query 或解析请求格式不合法。", + } + }, +) +def parse_ontology(payload: OntologyParseRequest, db: DbSession) -> OntologyParseResult: + try: + return SemanticOntologyService(db).parse(payload) + except ValueError as exc: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc diff --git a/server/src/app/api/v1/endpoints/orchestrator.py b/server/src/app/api/v1/endpoints/orchestrator.py new file mode 100644 index 0000000..4944ec6 --- /dev/null +++ b/server/src/app/api/v1/endpoints/orchestrator.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from typing import Annotated + +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.orm import Session + +from app.api.deps import get_db +from app.schemas.common import ErrorResponse +from app.schemas.orchestrator import OrchestratorRequest, OrchestratorResponse +from app.services.orchestrator import OrchestratorService + +router = APIRouter(prefix="/orchestrator") +DbSession = Annotated[Session, Depends(get_db)] + + +@router.post( + "/run", + response_model=OrchestratorResponse, + summary="运行 Orchestrator 统一调度", + description="统一接收用户消息、定时任务和系统事件,完成语义解析、权限判断、路由和占位执行。", + responses={ + status.HTTP_400_BAD_REQUEST: { + "model": ErrorResponse, + "description": "请求缺少 message 或 task_id,无法启动调度。", + } + }, +) +def run_orchestrator(payload: OrchestratorRequest, db: DbSession) -> OrchestratorResponse: + try: + return OrchestratorService(db).run(payload) + except ValueError as exc: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc diff --git a/server/src/app/schemas/ontology.py b/server/src/app/schemas/ontology.py new file mode 100644 index 0000000..3bd92e9 --- /dev/null +++ b/server/src/app/schemas/ontology.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +from typing import Any, Literal + +from pydantic import BaseModel, ConfigDict, Field + +OntologyScenario = Literal[ + "expense", + "accounts_receivable", + "accounts_payable", + "knowledge", + "unknown", +] +OntologyIntent = Literal["query", "explain", "compare", "risk_check", "draft", "operate"] +OntologyPermissionLevel = Literal["read", "draft_write", "approval_required", "forbidden"] +OntologyParseStrategy = Literal["llm_primary", "rule_fallback"] + + +class OntologyEntity(BaseModel): + model_config = ConfigDict(extra="forbid") + + type: str = Field(description="业务对象类型,例如 employee / customer / vendor。") + value: str = Field(description="从原始问题中提取的对象值。") + normalized_value: str = Field(description="标准化后的对象值。") + role: str = Field(default="target", description="对象角色,例如 target / filter / threshold。") + confidence: float = Field(default=0.0, ge=0.0, le=1.0, description="字段级置信度。") + + +class OntologyTimeRange(BaseModel): + model_config = ConfigDict(extra="forbid") + + raw: str = Field(default="", description="命中的原始时间表达。") + start_date: str | None = Field(default=None, description="ISO 格式起始日期。") + end_date: str | None = Field(default=None, description="ISO 格式结束日期。") + granularity: str | None = Field( + default=None, + description="day / week / month / quarter / year。", + ) + + +class OntologyMetric(BaseModel): + model_config = ConfigDict(extra="forbid") + + name: str = Field(description="指标名,例如 amount / count / overdue。") + aggregation: str | None = Field(default=None, description="sum / count / max 等聚合口径。") + unit: str | None = Field(default=None, description="金额、数量等单位。") + sort: str | None = Field(default=None, description="asc / desc 排序方向。") + top_n: int | None = Field(default=None, ge=1, description="Top N 口径。") + + +class OntologyConstraint(BaseModel): + model_config = ConfigDict(extra="forbid") + + field: str = Field(description="约束字段,例如 department / status / amount。") + operator: str = Field(description="操作符,例如 = / > / < / desc。") + value: str | int | float | bool = Field(description="约束值。") + currency: str | None = Field(default=None, description="金额类约束使用的币种。") + + +class OntologyPermission(BaseModel): + model_config = ConfigDict(extra="forbid") + + level: OntologyPermissionLevel = Field(default="read", description="动作权限等级。") + allowed: bool = Field(default=True, description="是否可直接执行当前动作。") + reason: str = Field(default="", description="权限判断原因。") + + +class OntologyFieldError(BaseModel): + model_config = ConfigDict(extra="forbid") + + field: str = Field(description="发生问题的字段。") + code: str = Field(description="错误码。") + message: str = Field(description="面向前端展示的说明。") + + +class OntologyParseRequest(BaseModel): + query: str = Field(min_length=1, description="自然语言问题。") + user_id: str | None = Field(default=None, description="当前请求用户 ID。") + context_json: dict[str, Any] = Field( + default_factory=dict, + description="用户上下文,例如角色、部门、是否管理员。", + ) + + +class OntologyParseResult(BaseModel): + scenario: OntologyScenario = Field(default="unknown", description="业务场景。") + intent: OntologyIntent = Field(default="query", description="用户意图。") + entities: list[OntologyEntity] = Field(default_factory=list, description="业务对象列表。") + time_range: OntologyTimeRange = Field( + default_factory=OntologyTimeRange, + description="时间范围。", + ) + metrics: list[OntologyMetric] = Field(default_factory=list, description="指标解析结果。") + constraints: list[OntologyConstraint] = Field( + default_factory=list, + description="过滤、阈值、排序等约束。", + ) + risk_flags: list[str] = Field(default_factory=list, description="风险信号列表。") + permission: OntologyPermission = Field( + default_factory=OntologyPermission, + description="权限结果。", + ) + confidence: float = Field(default=0.0, ge=0.0, le=1.0, description="整体置信度。") + missing_slots: list[str] = Field(default_factory=list, description="继续处理所缺少的关键槽位。") + ambiguity: list[str] = Field(default_factory=list, description="当前识别中的潜在歧义。") + parse_strategy: OntologyParseStrategy = Field( + default="rule_fallback", + description="本次语义解析使用的主策略。", + ) + clarification_required: bool = Field(default=False, description="是否需要追问。") + clarification_question: str | None = Field(default=None, description="推荐追问问题。") + run_id: str = Field(description="关联的 AgentRun.run_id。") + field_errors: list[OntologyFieldError] = Field( + default_factory=list, + description="字段级错误或提示。", + ) diff --git a/server/src/app/schemas/orchestrator.py b/server/src/app/schemas/orchestrator.py new file mode 100644 index 0000000..9489f25 --- /dev/null +++ b/server/src/app/schemas/orchestrator.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from typing import Any, Literal + +from pydantic import BaseModel, Field + +OrchestratorSource = Literal["user_message", "schedule", "system_event"] +OrchestratorAgent = Literal["user_agent", "hermes"] +OrchestratorStatus = Literal["succeeded", "blocked", "failed"] + + +class OrchestratorRequest(BaseModel): + source: OrchestratorSource = Field(description="请求来源。") + user_id: str | None = Field(default=None, description="当前用户 ID,任务触发可为空。") + message: str | None = Field(default=None, description="用户消息或任务描述。") + task_id: str | None = Field(default=None, description="任务资产 ID,schedule 触发时优先使用。") + context_json: dict[str, Any] = Field( + default_factory=dict, + description="用户上下文、测试开关或调用方附加信息。", + ) + + +class OrchestratorTraceSummary(BaseModel): + scenario: str = Field(description="语义场景。") + intent: str = Field(description="语义意图。") + tool_count: int = Field(default=0, ge=0, description="工具调用总数。") + failed_tool_count: int = Field(default=0, ge=0, description="失败工具调用数量。") + selected_capability_codes: list[str] = Field( + default_factory=list, + description="本次路由命中的能力编码。", + ) + degraded: bool = Field(default=False, description="是否发生降级。") + + +class OrchestratorResponse(BaseModel): + run_id: str = Field(description="本次运行的唯一 run_id。") + selected_agent: OrchestratorAgent | None = Field( + default=None, + description="最终路由到的下游 Agent。", + ) + route_reason: str = Field(description="路由原因摘要。") + permission_level: str = Field(description="权限级别。") + status: OrchestratorStatus = Field(description="最终运行状态。") + result: dict[str, Any] = Field(default_factory=dict, description="对前端可直接展示的最小结果。") + requires_confirmation: bool = Field(default=False, description="是否需要用户或管理员确认。") + trace_summary: OrchestratorTraceSummary = Field(description="简化后的 Trace 摘要。") diff --git a/server/src/app/schemas/user_agent.py b/server/src/app/schemas/user_agent.py new file mode 100644 index 0000000..7db7291 --- /dev/null +++ b/server/src/app/schemas/user_agent.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from typing import Any, Literal + +from pydantic import BaseModel, Field + +from app.schemas.ontology import OntologyParseResult + +UserAgentCitationType = Literal["rule", "knowledge"] + + +class UserAgentCitation(BaseModel): + source_type: UserAgentCitationType = Field(description="引用来源类型。") + code: str = Field(description="来源编码。") + title: str = Field(description="来源标题。") + version: str | None = Field(default=None, description="引用版本。") + updated_at: str | None = Field(default=None, description="来源更新时间。") + excerpt: str | None = Field(default=None, description="面向用户展示的引用摘要。") + + +class UserAgentSuggestedAction(BaseModel): + label: str = Field(description="建议动作文案。") + action_type: str = Field(description="动作类型,例如 open_detail / create_draft。") + description: str = Field(default="", description="动作说明。") + + +class UserAgentDraftPayload(BaseModel): + draft_type: str = Field(description="草稿类型。") + title: str = Field(description="草稿标题。") + body: str = Field(description="草稿正文。") + confirmation_required: bool = Field(default=True, description="是否需要人工确认。") + + +class UserAgentRequest(BaseModel): + run_id: str = Field(description="关联的 AgentRun.run_id。") + user_id: str | None = Field(default=None, description="当前请求用户 ID。") + message: str = Field(description="原始用户问题。") + ontology: OntologyParseResult = Field(description="语义解析结果。") + context_json: dict[str, Any] = Field(default_factory=dict, description="附加上下文。") + tool_payload: dict[str, Any] = Field(default_factory=dict, description="工具返回的原始结果。") + selected_capability_codes: list[str] = Field( + default_factory=list, + description="本次命中的能力编码。", + ) + degraded: bool = Field(default=False, description="当前是否发生降级。") + requires_confirmation: bool = Field(default=False, description="是否要求确认。") + + +class UserAgentResponse(BaseModel): + answer: str = Field(description="面向用户展示的自然语言回答。") + citations: list[UserAgentCitation] = Field(default_factory=list, description="规则或知识引用。") + suggested_actions: list[UserAgentSuggestedAction] = Field( + default_factory=list, + description="建议的下一步动作。", + ) + draft_payload: UserAgentDraftPayload | None = Field(default=None, description="可选草稿内容。") + risk_flags: list[str] = Field(default_factory=list, description="本次回答关联的风险标签。") + requires_confirmation: bool = Field(default=False, description="是否需要人工确认。") diff --git a/server/src/app/services/ontology.py b/server/src/app/services/ontology.py new file mode 100644 index 0000000..bd50486 --- /dev/null +++ b/server/src/app/services/ontology.py @@ -0,0 +1,1470 @@ +from __future__ import annotations + +import calendar +import json +import re +from dataclasses import dataclass +from datetime import UTC, date, datetime, timedelta +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field, ValidationError +from sqlalchemy import select +from sqlalchemy.orm import Session + +from app.core.agent_enums import ( + AgentName, + AgentPermissionLevel, + AgentRunSource, + AgentRunStatus, +) +from app.core.logging import get_logger +from app.models.employee import Employee +from app.models.financial_record import ( + AccountsPayableRecord, + AccountsReceivableRecord, + ExpenseClaim, +) +from app.models.organization import OrganizationUnit +from app.schemas.ontology import ( + OntologyConstraint, + OntologyEntity, + OntologyFieldError, + OntologyIntent, + OntologyMetric, + OntologyParseRequest, + OntologyParseResult, + OntologyPermission, + OntologyScenario, + OntologyTimeRange, +) +from app.services.agent_foundation import AgentFoundationService +from app.services.agent_runs import AgentRunService +from app.services.runtime_chat import RuntimeChatService + +logger = get_logger("app.services.ontology") + +DATE_RANGE_PATTERN = re.compile( + r"(?P\d{4}-\d{1,2}-\d{1,2})\s*(?:到|至|~|-)\s*(?P\d{4}-\d{1,2}-\d{1,2})" +) +EXPLICIT_MONTH_PATTERN = re.compile(r"(?P\d{4})年(?P\d{1,2})月") +EXPLICIT_DATE_PATTERN = re.compile( + r"(?P\d{4})[年/-](?P\d{1,2})[月/-](?P\d{1,2})日?" +) +MONTH_DAY_RANGE_PATTERN = re.compile( + r"(?P\d{1,2})月(?P\d{1,2})日?\s*(?:到|至|~|-)\s*" + r"(?P\d{1,2})月(?P\d{1,2})日?" +) +MONTH_DAY_PATTERN = re.compile(r"(?P\d{1,2})月(?P\d{1,2})日?") +AMOUNT_PATTERN = re.compile( + r"(?P超过|大于|高于|不少于|不低于|小于|低于|少于|至多|不超过|<=|>=|<|>|=|=)?\s*" + r"(?P\d+(?:\.\d+)?)\s*(?P万元|万|元)?" +) +TOP_N_PATTERN = re.compile(r"(?:top|TOP|前|最高的?|最低的?)\s*(?P\d+)") + +SCENARIO_KEYWORDS = { + "expense": ( + ("报销", 0.20), + ("报账", 0.20), + ("差旅", 0.20), + ("费用", 0.14), + ("发票", 0.14), + ("票据", 0.12), + ("借款", 0.12), + ("住宿", 0.10), + ("餐费", 0.10), + ("招待", 0.18), + ("招待费", 0.18), + ("花销", 0.16), + ("花了", 0.14), + ("支出", 0.14), + ("垫付", 0.14), + ), + "accounts_receivable": ( + ("应收", 0.22), + ("回款", 0.20), + ("收款", 0.18), + ("账龄", 0.18), + ("客户欠款", 0.22), + ), + "accounts_payable": ( + ("应付", 0.22), + ("付款", 0.20), + ("请款", 0.18), + ("供应商", 0.20), + ("待付", 0.16), + ("打款", 0.18), + ), + "knowledge": ( + ("制度", 0.20), + ("规则", 0.20), + ("办法", 0.18), + ("依据", 0.18), + ("政策", 0.16), + ("知识库", 0.18), + ), +} + +QUERY_KEYWORDS = ( + "查", + "查询", + "查看", + "列出", + "统计", + "汇总", + "多少", + "几笔", + "金额", + "明细", +) +EXPLAIN_KEYWORDS = ("为什么", "依据", "原因", "怎么处理", "是否可以", "能不能", "按什么规则") +COMPARE_KEYWORDS = ("对比", "比较", "相比", "差异", "变化") +RISK_KEYWORDS = ("风险", "异常", "重复", "超标", "超预算", "逾期", "验真", "巡检") +DRAFT_KEYWORDS = ("生成", "草稿", "起草", "拟一份", "创建", "发起", "准备") +OPERATE_KEYWORDS = ( + "直接付款", + "帮我付款", + "安排付款", + "发起付款", + "直接审批", + "审批通过", + "帮我审批", + "驳回", + "上线", + "激活", + "停用", + "删除", +) + +EXPENSE_TYPE_KEYWORDS = { + "差旅": "travel", + "住宿": "hotel", + "酒店": "hotel", + "交通": "transport", + "餐费": "meal", + "会务": "meeting", + "招待费": "entertainment", + "招待": "entertainment", +} + +EXPENSE_NARRATIVE_KEYWORDS = ( + "报销", + "报账", + "招待", + "招待费", + "花销", + "花了", + "支出", + "垫付", + "打车", + "车费", + "餐费", + "住宿", + "发票", + "票据", + "差旅", + "客户现场", +) + +AR_CORE_KEYWORDS = ("应收", "回款", "收款", "账龄", "欠款", "未回款") +AP_CORE_KEYWORDS = ("应付", "付款", "请款", "待付", "打款", "未付款") +GENERIC_EXPENSE_PROMPTS = { + "报销", + "我要报销", + "我想报销", + "帮我报销", + "我要申请报销", + "发起报销", + "提交报销", +} +MISSING_SLOT_LABELS = { + "expense_type": "费用类型", + "amount": "金额", + "customer_name": "客户单位", + "vendor_name": "供应商", + "participants": "参与人员", + "attachments": "票据附件", + "time_range": "发生时间", + "reason": "事由说明", + "document_id": "单据号", +} + +STATUS_KEYWORDS = { + "逾期": "overdue", + "待审批": "pending", + "待审": "pending", + "已审批": "approved", + "已通过": "approved", + "已付款": "paid", + "未付款": "unpaid", + "未回款": "unreceived", +} + +PRIVILEGED_ROLE_CODES = {"manager", "finance", "approver", "executive"} + + +@dataclass(slots=True) +class ReferenceCatalog: + employees: list[str] + departments: list[str] + customers: list[str] + vendors: list[str] + projects: list[str] + + +class LlmOntologyEntityHint(BaseModel): + model_config = ConfigDict(extra="ignore") + + type: str + value: str + normalized_value: str | None = None + role: str = "target" + confidence: float = Field(default=0.72, ge=0.0, le=1.0) + + +class LlmOntologyParseResult(BaseModel): + model_config = ConfigDict(extra="ignore") + + scenario: OntologyScenario = Field(default="unknown") + intent: OntologyIntent = Field(default="query") + confidence: float = Field(default=0.0, ge=0.0, le=1.0) + clarification_required: bool = False + clarification_question: str | None = None + missing_slots: list[str] = Field(default_factory=list) + ambiguity: list[str] = Field(default_factory=list) + entity_hints: list[LlmOntologyEntityHint] = Field(default_factory=list) + + +class SemanticOntologyService: + def __init__(self, db: Session) -> None: + self.db = db + self.run_service = AgentRunService(db) + self.runtime_chat_service = RuntimeChatService(db) + + def parse(self, payload: OntologyParseRequest) -> OntologyParseResult: + analyzed = self._analyze(payload) + run = self.run_service.create_run( + agent=AgentName.ORCHESTRATOR.value, + source=AgentRunSource.USER_MESSAGE.value, + user_id=payload.user_id, + ontology_json=self._build_ontology_json(analyzed), + route_json={ + "stage": "semantic_parse", + "clarification_required": analyzed["clarification_required"], + "field_error_count": len(analyzed["field_errors"]), + }, + permission_level=analyzed["permission"].level, + status=( + AgentRunStatus.BLOCKED.value + if analyzed["clarification_required"] + or analyzed["permission"].level == AgentPermissionLevel.FORBIDDEN.value + else AgentRunStatus.SUCCEEDED.value + ), + result_summary=self._build_result_summary( + analyzed["scenario"], + analyzed["intent"], + analyzed["permission"].level, + analyzed["confidence"], + ), + error_message=( + analyzed["permission"].reason + if analyzed["permission"].level == AgentPermissionLevel.FORBIDDEN.value + else None + ), + ) + self._record_semantic_parse( + run_id=run.run_id, + payload=payload, + analyzed=analyzed, + ) + return self._build_result(analyzed, run.run_id) + + def parse_for_run(self, payload: OntologyParseRequest, *, run_id: str) -> OntologyParseResult: + analyzed = self._analyze(payload) + self._record_semantic_parse(run_id=run_id, payload=payload, analyzed=analyzed) + return self._build_result(analyzed, run_id) + + def _analyze(self, payload: OntologyParseRequest) -> dict[str, object]: + query = payload.query.strip() + if not query: + raise ValueError("query 不能为空。") + + AgentFoundationService(self.db).ensure_foundation_ready() + reference = self._load_reference_catalog() + compact_query = self._compact(query) + + entities = self._extract_entities(query, compact_query, reference) + rule_scenario, scenario_score = self._detect_scenario(compact_query) + time_range, _time_score = self._extract_time_range(query, compact_query) + if rule_scenario == "unknown": + inferred_scenario = self._infer_scenario_from_entities(entities) + if inferred_scenario is not None: + rule_scenario = inferred_scenario + scenario_score = 0.18 + + if self._looks_like_expense_narrative( + compact_query, + scenario=rule_scenario, + entities=entities, + time_range=time_range, + ): + rule_scenario = "expense" + scenario_score = max(scenario_score, 0.24) + + rule_intent, intent_score = self._detect_intent( + compact_query, + scenario=rule_scenario, + entities=entities, + time_range=time_range, + ) + metrics = self._extract_metrics(compact_query) + constraints = self._extract_constraints(compact_query, entities) + model_parse = self._parse_with_model( + payload=payload, + query=query, + compact_query=compact_query, + fallback_scenario=rule_scenario, + fallback_intent=rule_intent, + entities=entities, + time_range=time_range, + metrics=metrics, + constraints=constraints, + ) + scenario = self._resolve_scenario(rule_scenario, model_parse) + entities = self._merge_entities( + entities, + model_parse.entity_hints if model_parse is not None else [], + ) + intent = self._resolve_intent( + compact_query, + fallback_intent=rule_intent, + scenario=scenario, + entities=entities, + time_range=time_range, + model_parse=model_parse, + ) + missing_slots = self._normalize_short_text_list( + model_parse.missing_slots if model_parse is not None else [] + ) + missing_slots = self._normalize_short_text_list( + missing_slots + + self._infer_default_missing_slots( + compact_query, + scenario=scenario, + intent=intent, + entities=entities, + time_range=time_range, + context_json=payload.context_json or {}, + ) + ) + ambiguity = self._normalize_short_text_list( + model_parse.ambiguity if model_parse is not None else [] + ) + risk_flags = self._extract_risk_flags(compact_query, scenario) + permission = self._resolve_permission( + compact_query, + payload.context_json or {}, + intent, + ) + + field_errors = self._build_field_errors( + scenario=scenario, + intent=intent, + entities=entities, + permission=permission, + missing_slots=missing_slots, + ambiguity=ambiguity, + ) + clarification_required, clarification_question = self._build_clarification( + scenario=scenario, + intent=intent, + entities=entities, + permission=permission, + missing_slots=missing_slots, + ambiguity=ambiguity, + model_clarification_required=bool( + model_parse is not None and model_parse.clarification_required + ), + model_clarification_question=( + model_parse.clarification_question if model_parse is not None else None + ), + ) + fallback_confidence = self._compute_confidence( + scenario=scenario, + scenario_score=scenario_score, + intent_score=intent_score, + entities=entities, + time_range=time_range, + metrics=metrics, + constraints=constraints, + risk_flags=risk_flags, + clarification_required=clarification_required, + permission=permission, + ) + confidence = self._resolve_confidence( + model_confidence=( + model_parse.confidence + if model_parse is not None + else None + ), + fallback_confidence=fallback_confidence, + clarification_required=clarification_required, + permission=permission, + ) + return { + "scenario": scenario, + "intent": intent, + "entities": entities, + "time_range": time_range, + "metrics": metrics, + "constraints": constraints, + "risk_flags": risk_flags, + "permission": permission, + "confidence": confidence, + "missing_slots": missing_slots, + "ambiguity": ambiguity, + "parse_strategy": "llm_primary" if model_parse is not None else "rule_fallback", + "clarification_required": clarification_required, + "clarification_question": clarification_question, + "field_errors": field_errors, + } + + def _record_semantic_parse( + self, + *, + run_id: str, + payload: OntologyParseRequest, + analyzed: dict[str, object], + ) -> None: + self.run_service.record_semantic_parse( + run_id=run_id, + user_id=payload.user_id, + raw_query=payload.query.strip(), + scenario=str(analyzed["scenario"]), + intent=str(analyzed["intent"]), + entities_json=[item.model_dump() for item in analyzed["entities"]], + time_range_json=analyzed["time_range"].model_dump(), + metrics_json=[item.model_dump() for item in analyzed["metrics"]], + constraints_json=[item.model_dump() for item in analyzed["constraints"]], + risk_flags_json=list(analyzed["risk_flags"]), + permission_json=analyzed["permission"].model_dump(), + confidence=float(analyzed["confidence"]), + ) + logger.info( + "Parsed ontology run_id=%s scenario=%s intent=%s permission=%s", + run_id, + analyzed["scenario"], + analyzed["intent"], + analyzed["permission"].level, + ) + + @staticmethod + def _build_ontology_json(analyzed: dict[str, object]) -> dict[str, object]: + return { + "scenario": analyzed["scenario"], + "intent": analyzed["intent"], + "entities": [item.model_dump() for item in analyzed["entities"]], + "time_range": analyzed["time_range"].model_dump(), + "metrics": [item.model_dump() for item in analyzed["metrics"]], + "constraints": [item.model_dump() for item in analyzed["constraints"]], + "risk_flags": list(analyzed["risk_flags"]), + "permission": analyzed["permission"].model_dump(), + "missing_slots": list(analyzed["missing_slots"]), + "ambiguity": list(analyzed["ambiguity"]), + "parse_strategy": analyzed["parse_strategy"], + "confidence": analyzed["confidence"], + } + + @staticmethod + def _build_result(analyzed: dict[str, object], run_id: str) -> OntologyParseResult: + return OntologyParseResult( + scenario=analyzed["scenario"], + intent=analyzed["intent"], + entities=analyzed["entities"], + time_range=analyzed["time_range"], + metrics=analyzed["metrics"], + constraints=analyzed["constraints"], + risk_flags=analyzed["risk_flags"], + permission=analyzed["permission"], + confidence=analyzed["confidence"], + missing_slots=analyzed["missing_slots"], + ambiguity=analyzed["ambiguity"], + parse_strategy=analyzed["parse_strategy"], + clarification_required=analyzed["clarification_required"], + clarification_question=analyzed["clarification_question"], + run_id=run_id, + field_errors=analyzed["field_errors"], + ) + + def _load_reference_catalog(self) -> ReferenceCatalog: + employees = self._read_distinct_values(select(Employee.name)) + departments = self._read_distinct_values(select(OrganizationUnit.name)) + departments += self._read_distinct_values(select(ExpenseClaim.department_name)) + customers = self._read_distinct_values(select(AccountsReceivableRecord.customer_name)) + vendors = self._read_distinct_values(select(AccountsPayableRecord.vendor_name)) + projects = self._read_distinct_values(select(ExpenseClaim.project_code)) + + return ReferenceCatalog( + employees=self._dedupe_and_sort(employees), + departments=self._dedupe_and_sort(departments), + customers=self._dedupe_and_sort(customers), + vendors=self._dedupe_and_sort(vendors), + projects=self._dedupe_and_sort(projects), + ) + + def _read_distinct_values(self, stmt) -> list[str]: + values = self.db.scalars(stmt.distinct()).all() + return [str(item).strip() for item in values if item] + + @staticmethod + def _dedupe_and_sort(values: list[str]) -> list[str]: + items = {str(item).strip() for item in values if str(item).strip()} + return sorted(items, key=lambda item: (-len(item), item)) + + @staticmethod + def _compact(text: str) -> str: + return re.sub(r"\s+", "", text).lower() + + def _detect_scenario(self, compact_query: str) -> tuple[str, float]: + scores = {key: 0.0 for key in SCENARIO_KEYWORDS} + for scenario, keywords in SCENARIO_KEYWORDS.items(): + for keyword, weight in keywords: + if keyword in compact_query: + scores[scenario] += weight + + best_scenario = max(scores, key=scores.get) + best_score = scores[best_scenario] + if best_score <= 0: + return "unknown", 0.0 + + if best_scenario == "knowledge": + business_scores = [ + scores["expense"], + scores["accounts_receivable"], + scores["accounts_payable"], + ] + if max(business_scores) > 0: + best_scenario = ("expense", "accounts_receivable", "accounts_payable")[ + business_scores.index(max(business_scores)) + ] + best_score = max(business_scores) + + return best_scenario, round(min(best_score, 0.34), 2) + + def _detect_intent( + self, + compact_query: str, + *, + scenario: str, + entities: list[OntologyEntity], + time_range: OntologyTimeRange, + ) -> tuple[str, float]: + if any(keyword in compact_query for keyword in OPERATE_KEYWORDS): + return "operate", 0.30 + if any(keyword in compact_query for keyword in DRAFT_KEYWORDS): + return "draft", 0.26 + if scenario == "expense" and self._is_generic_expense_prompt(compact_query): + return "draft", 0.24 + if any(keyword in compact_query for keyword in COMPARE_KEYWORDS): + return "compare", 0.24 + if any(keyword in compact_query for keyword in EXPLAIN_KEYWORDS): + return "explain", 0.22 + if any(keyword in compact_query for keyword in RISK_KEYWORDS): + return "risk_check", 0.24 + if any(keyword in compact_query for keyword in QUERY_KEYWORDS): + return "query", 0.20 + if self._looks_like_expense_narrative( + compact_query, + scenario=scenario, + entities=entities, + time_range=time_range, + ): + return "draft", 0.22 + return "query", 0.10 + + @staticmethod + def _is_generic_expense_prompt(compact_query: str) -> bool: + return compact_query in GENERIC_EXPENSE_PROMPTS + + @staticmethod + def _looks_like_expense_narrative( + compact_query: str, + *, + scenario: str, + entities: list[OntologyEntity], + time_range: OntologyTimeRange, + ) -> bool: + if scenario not in {"expense", "accounts_receivable", "accounts_payable", "unknown"}: + return False + + if any(keyword in compact_query for keyword in AR_CORE_KEYWORDS + AP_CORE_KEYWORDS): + return False + + entity_types = {item.type for item in entities} + has_expense_signal = any( + keyword in compact_query for keyword in EXPENSE_NARRATIVE_KEYWORDS + ) or "expense_type" in entity_types + has_context_signal = bool(time_range.start_date) or "amount" in entity_types + + return has_expense_signal and has_context_signal + + def _parse_with_model( + self, + *, + payload: OntologyParseRequest, + query: str, + compact_query: str, + fallback_scenario: str, + fallback_intent: str, + entities: list[OntologyEntity], + time_range: OntologyTimeRange, + metrics: list[OntologyMetric], + constraints: list[OntologyConstraint], + ) -> LlmOntologyParseResult | None: + messages = self._build_model_messages( + payload=payload, + query=query, + compact_query=compact_query, + fallback_scenario=fallback_scenario, + fallback_intent=fallback_intent, + entities=entities, + time_range=time_range, + metrics=metrics, + constraints=constraints, + ) + response_text = self.runtime_chat_service.complete( + messages, + max_tokens=600, + temperature=0.0, + ) + payload_json = self._extract_json_payload(response_text) + if payload_json is None: + return None + + try: + return LlmOntologyParseResult.model_validate(payload_json) + except ValidationError as exc: + logger.warning("Semantic model output validation failed: %s", exc) + return None + + @staticmethod + def _build_model_messages( + *, + payload: OntologyParseRequest, + query: str, + compact_query: str, + fallback_scenario: str, + fallback_intent: str, + entities: list[OntologyEntity], + time_range: OntologyTimeRange, + metrics: list[OntologyMetric], + constraints: list[OntologyConstraint], + ) -> list[dict[str, str]]: + facts = { + "query": query, + "compact_query": compact_query, + "context": { + "entry_source": payload.context_json.get("entry_source"), + "attachment_names": payload.context_json.get("attachment_names", []), + "attachment_count": payload.context_json.get("attachment_count", 0), + "request_context": payload.context_json.get("request_context"), + "role_codes": payload.context_json.get("role_codes", []), + }, + "rule_candidates": { + "scenario": fallback_scenario, + "intent": fallback_intent, + "entities": [item.model_dump(mode="json") for item in entities], + "time_range": time_range.model_dump(mode="json"), + "metrics": [item.model_dump(mode="json") for item in metrics], + "constraints": [item.model_dump(mode="json") for item in constraints], + }, + } + + system_prompt = ( + "你是企业财务共享平台的语义解析器。" + "你的任务是把用户输入解析为固定 JSON,用于后续路由、追问和权限判断。" + "只输出 JSON 对象,不要输出 Markdown、代码块、解释、标题或 。" + "场景 scenario 只能是:expense, accounts_receivable, " + "accounts_payable, knowledge, unknown。" + "意图 intent 只能是:query, explain, compare, risk_check, draft, operate。" + "如果用户是在描述一笔待处理费用、待报销事项、上传票据或希望整理报销," + "即使没有明确说“生成草稿”,也优先使用 expense + draft。" + "出现“客户”不等于应收,出现“供应商”不等于应付,必须结合动作词和业务目标判断。" + "只有明确查询、统计、列出、多少、明细、对比时才优先使用 query 或 compare。" + "信息不足时 clarification_required=true,并给出一句简短中文追问。" + "missing_slots 使用简短 snake_case,例如 expense_type, amount, " + "customer_name, participants, attachments。" + "entity_hints 只填写你比较确定的业务对象;如果不确定,可以返回空数组。" + ) + user_prompt = ( + "请根据以下事实输出 JSON:\n" + f"{json.dumps(facts, ensure_ascii=False, indent=2, default=str)}\n\n" + "输出格式:\n" + "{\n" + ' "scenario": "expense",\n' + ' "intent": "draft",\n' + ' "confidence": 0.88,\n' + ' "clarification_required": true,\n' + ' "clarification_question": "请补充客户单位、参与人员和票据附件。",\n' + ' "missing_slots": ["customer_name", "participants", "attachments"],\n' + ' "ambiguity": [],\n' + ' "entity_hints": [\n' + ' {"type": "expense_type", "value": "招待", ' + '"normalized_value": "entertainment", "role": "filter", ' + '"confidence": 0.86}\n' + " ]\n" + "}" + ) + return [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + + @staticmethod + def _extract_json_payload(response_text: str | None) -> dict[str, Any] | None: + if not response_text: + return None + + cleaned = re.sub(r".*?", "", response_text, flags=re.DOTALL | re.IGNORECASE) + cleaned = cleaned.strip() + if not cleaned: + return None + + fenced_match = re.search(r"```(?:json)?\s*(\{.*\})\s*```", cleaned, flags=re.DOTALL) + candidates = [fenced_match.group(1)] if fenced_match else [] + candidates.extend([cleaned]) + + start = cleaned.find("{") + end = cleaned.rfind("}") + if start != -1 and end != -1 and end > start: + candidates.append(cleaned[start : end + 1]) + + for candidate in candidates: + try: + parsed = json.loads(candidate) + except json.JSONDecodeError: + continue + if isinstance(parsed, dict): + return parsed + + return None + + @staticmethod + def _resolve_scenario( + fallback_scenario: str, + model_parse: LlmOntologyParseResult | None, + ) -> str: + if model_parse is None: + return fallback_scenario + if model_parse.scenario == "unknown" and fallback_scenario != "unknown": + return fallback_scenario + return model_parse.scenario + + def _resolve_intent( + self, + compact_query: str, + *, + fallback_intent: str, + scenario: str, + entities: list[OntologyEntity], + time_range: OntologyTimeRange, + model_parse: LlmOntologyParseResult | None, + ) -> str: + candidate = model_parse.intent if model_parse is not None else fallback_intent + if candidate == "query" and scenario == "expense": + if self._is_generic_expense_prompt(compact_query) or fallback_intent == "draft": + return "draft" + return candidate + + @staticmethod + def _merge_entities( + base_entities: list[OntologyEntity], + entity_hints: list[LlmOntologyEntityHint], + ) -> list[OntologyEntity]: + merged: dict[tuple[str, str], OntologyEntity] = { + (item.type, item.normalized_value): item for item in base_entities + } + + for hint in entity_hints: + value = str(hint.value or "").strip() + if not value: + continue + normalized_value = str(hint.normalized_value or value).strip() + key = (str(hint.type).strip(), normalized_value) + candidate = OntologyEntity( + type=str(hint.type).strip(), + value=value, + normalized_value=normalized_value, + role=str(hint.role or "target").strip() or "target", + confidence=float(hint.confidence), + ) + existing = merged.get(key) + if existing is None or existing.confidence < candidate.confidence: + merged[key] = candidate + + return list(merged.values()) + + @staticmethod + def _normalize_short_text_list(values: list[str]) -> list[str]: + normalized: list[str] = [] + seen: set[str] = set() + for value in values: + cleaned = str(value or "").strip() + if not cleaned or cleaned in seen: + continue + normalized.append(cleaned) + seen.add(cleaned) + return normalized[:6] + + def _infer_default_missing_slots( + self, + compact_query: str, + *, + scenario: str, + intent: str, + entities: list[OntologyEntity], + time_range: OntologyTimeRange, + context_json: dict[str, Any], + ) -> list[str]: + if scenario != "expense" or intent != "draft": + return [] + + entity_types = {item.type for item in entities} + attachment_count = int(context_json.get("attachment_count") or 0) + missing_slots: list[str] = [] + + if self._is_generic_expense_prompt(compact_query): + if "expense_type" not in entity_types: + missing_slots.append("expense_type") + if "amount" not in entity_types: + missing_slots.append("amount") + if not time_range.start_date: + missing_slots.append("time_range") + missing_slots.append("reason") + if attachment_count <= 0: + missing_slots.append("attachments") + return missing_slots + + if any( + item.normalized_value == "entertainment" + for item in entities + if item.type == "expense_type" + ): + if "customer" not in entity_types: + missing_slots.append("customer_name") + missing_slots.append("participants") + if attachment_count <= 0: + missing_slots.append("attachments") + + return missing_slots + + @staticmethod + def _resolve_confidence( + *, + model_confidence: float | None, + fallback_confidence: float, + clarification_required: bool, + permission: OntologyPermission, + ) -> float: + confidence = fallback_confidence if model_confidence is None else float(model_confidence) + confidence = max(0.0, min(confidence, 0.98)) + if permission.level == AgentPermissionLevel.FORBIDDEN.value: + confidence = max(confidence, 0.86) + if clarification_required and permission.level != AgentPermissionLevel.FORBIDDEN.value: + confidence = min(confidence, 0.58) + return round(confidence, 2) + + def _extract_entities( + self, + query: str, + compact_query: str, + reference: ReferenceCatalog, + ) -> list[OntologyEntity]: + entities: dict[tuple[str, str], OntologyEntity] = {} + + def upsert(entity: OntologyEntity) -> None: + key = (entity.type, entity.normalized_value) + if key not in entities: + entities[key] = entity + + for match in re.finditer(r"客户\s*([A-Za-z0-9一二三四五六七八九十]+)", query): + suffix = match.group(1).strip() + normalized = f"客户{suffix}".replace(" ", "") + upsert(self._make_entity("customer", match.group(0).strip(), normalized, role="filter")) + + for match in re.finditer(r"供应商\s*([A-Za-z0-9一二三四五六七八九十]+)", query): + suffix = match.group(1).strip() + normalized = f"供应商{suffix}".replace(" ", "") + upsert(self._make_entity("vendor", match.group(0).strip(), normalized, role="filter")) + + employee_match = re.search( + r"(?P[赵钱孙李周吴郑王冯陈褚卫蒋沈韩杨朱秦许何吕施张孔曹严华金魏陶姜" + r"戚谢邹喻柏水窦章云苏潘葛范彭郎鲁韦昌马苗凤花方俞任袁柳鲍史唐费廉岑" + r"薛雷贺倪汤滕殷罗毕郝邬安常乐于时傅卞康伍余元卜顾孟平黄和穆萧尹姚邵" + r"湛汪祁毛禹狄米贝明臧计成戴宋庞熊纪舒屈项祝董梁杜阮蓝闵席季强贾路江" + r"童颜郭梅盛林钟徐邱骆高夏蔡田樊胡凌霍虞万支柯管卢莫房裘缪解应宗丁宣" + r"邓洪包左石崔吉龚程嵇邢裴陆荣翁荀羊惠甄曲家封芮储靳汲邴糜松井段富巫" + r"乌焦巴弓牧隗山谷车侯伊宫宁仇栾刘景詹束龙叶司黎薄印白怀蒲邰从鄂索咸" + r"籍卓蔺屠蒙池乔阴胥能苍双闻莘党翟谭贡姬申扶堵冉宰郦雍桑桂牛寿通边扈" + r"燕冀浦尚农温别庄晏柴瞿阎连茹习艾容向古易慎戈廖庾终暨居衡步都耿满弘" + r"匡国文寇广禄阙东欧殳沃利蔚越夔隆师巩聂晁勾敖融冷辛阚那简饶曾关蒯相" + r"查后荆游竺权盖益桓公][\u4e00-\u9fa5]{1,2})(?=\s*(?:\d{4}年|\d{1,2}月|本月|" + r"上月|本周|报销|差旅|费用|申请))", + query, + ) + if employee_match: + name = employee_match.group("name") + upsert(self._make_entity("employee", name, name, role="filter")) + + for name in reference.employees: + if self._compact(name) in compact_query: + upsert(self._make_entity("employee", name, name, role="filter")) + for name in reference.departments: + if self._compact(name) in compact_query: + upsert(self._make_entity("department", name, name, role="filter")) + for name in reference.customers: + if self._compact(name) in compact_query: + upsert(self._make_entity("customer", name, name, role="filter")) + for name in reference.vendors: + if self._compact(name) in compact_query: + upsert(self._make_entity("vendor", name, name, role="filter")) + for code in reference.projects: + if self._compact(code) in compact_query: + upsert(self._make_entity("project", code, code, role="filter")) + + for code in re.findall(r"PRJ-[A-Z]+-\d+", query, flags=re.IGNORECASE): + upsert(self._make_entity("project", code, code.upper(), role="filter")) + for code in re.findall(r"EXP-\d{6}-\d{3}", query, flags=re.IGNORECASE): + upsert(self._make_entity("expense_claim", code, code.upper())) + for code in re.findall(r"AR-\d{6}-\d{3}", query, flags=re.IGNORECASE): + upsert(self._make_entity("receivable", code, code.upper())) + for code in re.findall(r"AP-\d{6}-\d{3}", query, flags=re.IGNORECASE): + upsert(self._make_entity("payable", code, code.upper())) + for code in re.findall(r"INV-[A-Z]+-\d+", query, flags=re.IGNORECASE): + upsert(self._make_entity("invoice", code, code.upper())) + for code in re.findall(r"CTR-[A-Z]+-\d+", query, flags=re.IGNORECASE): + upsert(self._make_entity("contract", code, code.upper())) + + for label, normalized in EXPENSE_TYPE_KEYWORDS.items(): + if label in query: + upsert(self._make_entity("expense_type", label, normalized, role="filter")) + + for amount in self._extract_amount_entities(query): + upsert(amount) + + return list(entities.values()) + + def _extract_amount_entities(self, query: str) -> list[OntologyEntity]: + entities: list[OntologyEntity] = [] + for match in AMOUNT_PATTERN.finditer(query): + raw_value = match.group("value") + unit = match.group("unit") + prefix = match.group("prefix") + if raw_value is None: + continue + if prefix is None and unit is None: + continue + + amount_value = self._normalize_amount(raw_value, unit) + display_value = f"{raw_value}{unit or ''}" + role = "threshold" if prefix else "target" + entities.append( + self._make_entity( + "amount", + display_value, + str(amount_value), + role=role, + confidence=0.9, + ) + ) + return entities + + @staticmethod + def _make_entity( + entity_type: str, + value: str, + normalized_value: str, + *, + role: str = "target", + confidence: float = 0.92, + ) -> OntologyEntity: + return OntologyEntity( + type=entity_type, + value=value, + normalized_value=normalized_value, + role=role, + confidence=confidence, + ) + + @staticmethod + def _infer_scenario_from_entities(entities: list[OntologyEntity]) -> str | None: + entity_types = {item.type for item in entities} + if entity_types & {"vendor", "payable"}: + return "accounts_payable" + if entity_types & {"customer", "receivable", "contract"}: + return "accounts_receivable" + if entity_types & {"employee", "expense_claim", "expense_type"}: + return "expense" + return None + + def _extract_time_range( + self, + query: str, + compact_query: str, + ) -> tuple[OntologyTimeRange, float]: + today = datetime.now(UTC).date() + + direct_mappings = { + "今天": self._single_day_range(today, "今天", "day"), + "昨日": self._single_day_range(today - timedelta(days=1), "昨日", "day"), + "昨天": self._single_day_range(today - timedelta(days=1), "昨天", "day"), + "明天": self._single_day_range(today + timedelta(days=1), "明天", "day"), + } + for keyword, value in direct_mappings.items(): + if keyword in query: + return value, 0.10 + + if "本周" in query or "这周" in query or "本星期" in query: + start = today - timedelta(days=today.weekday()) + end = start + timedelta(days=6) + return self._range(start, end, "本周", "week"), 0.10 + if "上周" in query: + end = today - timedelta(days=today.weekday() + 1) + start = end - timedelta(days=6) + return self._range(start, end, "上周", "week"), 0.10 + if "本月" in query or "这个月" in query: + start = date(today.year, today.month, 1) + end = date(today.year, today.month, calendar.monthrange(today.year, today.month)[1]) + return self._range(start, end, "本月", "month"), 0.10 + if "上月" in query: + year = today.year if today.month > 1 else today.year - 1 + month = today.month - 1 if today.month > 1 else 12 + start = date(year, month, 1) + end = date(year, month, calendar.monthrange(year, month)[1]) + return self._range(start, end, "上月", "month"), 0.10 + if "本季度" in query or "这个季度" in query: + quarter = (today.month - 1) // 3 + start_month = quarter * 3 + 1 + end_month = start_month + 2 + start = date(today.year, start_month, 1) + end = date(today.year, end_month, calendar.monthrange(today.year, end_month)[1]) + return self._range(start, end, "本季度", "quarter"), 0.10 + if "今年" in query: + return ( + self._range(date(today.year, 1, 1), date(today.year, 12, 31), "今年", "year"), + 0.10, + ) + + match = DATE_RANGE_PATTERN.search(query) + if match: + start = self._parse_iso_date(match.group("start")) + end = self._parse_iso_date(match.group("end")) + if start and end: + return self._range(start, end, match.group(0), "custom"), 0.10 + + match = EXPLICIT_DATE_PATTERN.search(query) + if match: + explicit = date( + int(match.group("year")), + int(match.group("month")), + int(match.group("day")), + ) + return self._single_day_range(explicit, match.group(0), "day"), 0.10 + + match = EXPLICIT_MONTH_PATTERN.search(query) + if match: + year = int(match.group("year")) + month = int(match.group("month")) + start = date(year, month, 1) + end = date(year, month, calendar.monthrange(year, month)[1]) + return self._range(start, end, match.group(0), "month"), 0.10 + + match = MONTH_DAY_RANGE_PATTERN.search(query) + if match: + start = date(today.year, int(match.group("start_month")), int(match.group("start_day"))) + end = date(today.year, int(match.group("end_month")), int(match.group("end_day"))) + return self._range(start, end, match.group(0), "custom"), 0.10 + + match = MONTH_DAY_PATTERN.search(compact_query) + if match: + explicit = date(today.year, int(match.group("month")), int(match.group("day"))) + return self._single_day_range(explicit, match.group(0), "day"), 0.08 + + month_match = re.search(r"(?P\d{1,2})月", compact_query) + if month_match: + month = int(month_match.group("month")) + start = date(today.year, month, 1) + end = date(today.year, month, calendar.monthrange(today.year, month)[1]) + return self._range(start, end, month_match.group(0), "month"), 0.08 + + return OntologyTimeRange(), 0.0 + + @staticmethod + def _single_day_range(target: date, raw: str, granularity: str) -> OntologyTimeRange: + return OntologyTimeRange( + raw=raw, + start_date=target.isoformat(), + end_date=target.isoformat(), + granularity=granularity, + ) + + @staticmethod + def _range(start: date, end: date, raw: str, granularity: str) -> OntologyTimeRange: + return OntologyTimeRange( + raw=raw, + start_date=start.isoformat(), + end_date=end.isoformat(), + granularity=granularity, + ) + + @staticmethod + def _parse_iso_date(value: str) -> date | None: + try: + return date.fromisoformat(value) + except ValueError: + return None + + def _extract_metrics(self, compact_query: str) -> list[OntologyMetric]: + metrics: dict[str, OntologyMetric] = {} + + def upsert(metric: OntologyMetric) -> None: + metrics[metric.name] = metric + + if any( + keyword in compact_query + for keyword in ("多少钱", "金额", "总额", "支出", "回款", "应收", "应付") + ): + upsert(OntologyMetric(name="amount", aggregation="sum", unit="CNY")) + if any(keyword in compact_query for keyword in ("多少笔", "几笔", "数量", "条数", "单数")): + upsert(OntologyMetric(name="count", aggregation="count", unit="records")) + if "超标" in compact_query or "超预算" in compact_query: + upsert(OntologyMetric(name="amount_over_limit")) + if "逾期" in compact_query or "账龄" in compact_query: + upsert(OntologyMetric(name="overdue")) + if "重复" in compact_query: + upsert(OntologyMetric(name="duplicate_expense")) + + top_match = TOP_N_PATTERN.search(compact_query) + if top_match: + metrics["amount"] = OntologyMetric( + name="amount", + aggregation="sum", + unit="CNY", + sort="desc" if "最低" not in compact_query else "asc", + top_n=int(top_match.group("top")), + ) + + return list(metrics.values()) + + def _extract_constraints( + self, + compact_query: str, + entities: list[OntologyEntity], + ) -> list[OntologyConstraint]: + constraints: dict[tuple[str, str, str, str | None], OntologyConstraint] = {} + + def upsert(constraint: OntologyConstraint) -> None: + key = ( + constraint.field, + constraint.operator, + str(constraint.value), + constraint.currency, + ) + if key not in constraints: + constraints[key] = constraint + + for entity in entities: + if entity.type in { + "employee", + "department", + "customer", + "vendor", + "project", + "expense_type", + }: + upsert( + OntologyConstraint( + field=entity.type, + operator="=", + value=entity.normalized_value, + ) + ) + + for keyword, normalized in STATUS_KEYWORDS.items(): + if keyword in compact_query: + upsert(OntologyConstraint(field="status", operator="=", value=normalized)) + + for amount_match in AMOUNT_PATTERN.finditer(compact_query): + if not amount_match.group("prefix"): + continue + + operator = self._normalize_operator(amount_match.group("prefix")) + value = self._normalize_amount(amount_match.group("value"), amount_match.group("unit")) + upsert( + OntologyConstraint( + field="amount", + operator=operator, + value=value, + currency="CNY", + ) + ) + break + + top_match = TOP_N_PATTERN.search(compact_query) + if top_match: + top_n = int(top_match.group("top")) + upsert(OntologyConstraint(field="top_n", operator="=", value=top_n)) + upsert( + OntologyConstraint( + field="sort_by", + operator="desc" if "最低" not in compact_query else "asc", + value="amount", + ) + ) + + return list(constraints.values()) + + def _extract_risk_flags(self, compact_query: str, scenario: str) -> list[str]: + risk_flags: list[str] = [] + + def append(flag: str) -> None: + if flag not in risk_flags: + risk_flags.append(flag) + + if "重复" in compact_query: + append("duplicate_expense") + if any( + keyword in compact_query + for keyword in ("发票异常", "票据异常", "验真失败", "附件缺失", "补件") + ): + append("invoice_anomaly") + if any(keyword in compact_query for keyword in ("超标", "超预算", "超限")): + append("amount_over_limit") + if scenario == "accounts_receivable" and any( + keyword in compact_query for keyword in ("逾期", "账龄", "欠款", "未回款") + ): + append("ar_overdue") + if scenario == "accounts_payable" and any( + keyword in compact_query for keyword in ("逾期", "待付", "付款风险", "未付款") + ): + append("ap_overdue") + + return risk_flags + + def _resolve_permission( + self, + compact_query: str, + context_json: dict, + intent: str, + ) -> OntologyPermission: + role_codes = { + str(item).strip().lower() + for item in context_json.get("role_codes", []) + if str(item).strip() + } + is_admin = bool(context_json.get("is_admin")) + privileged = is_admin or bool(role_codes & PRIVILEGED_ROLE_CODES) + + if intent in {"query", "explain", "compare", "risk_check"}: + return OntologyPermission( + level=AgentPermissionLevel.READ.value, + allowed=True, + reason="只读查询。", + ) + if intent == "draft": + return OntologyPermission( + level=AgentPermissionLevel.DRAFT_WRITE.value, + allowed=True, + reason="允许生成草稿,但不会直接提交业务动作。", + ) + + if any(keyword in compact_query for keyword in OPERATE_KEYWORDS) or "付款" in compact_query: + if privileged: + return OntologyPermission( + level=AgentPermissionLevel.APPROVAL_REQUIRED.value, + allowed=False, + reason="涉及付款、审批或上线动作,必须进入人工审批链。", + ) + return OntologyPermission( + level=AgentPermissionLevel.FORBIDDEN.value, + allowed=False, + reason="当前账号缺少财务或审批权限,只能查看结果或生成草稿。", + ) + + return OntologyPermission( + level=AgentPermissionLevel.APPROVAL_REQUIRED.value, + allowed=False, + reason="操作类请求需要人工审批确认。", + ) + + def _build_field_errors( + self, + *, + scenario: str, + intent: str, + entities: list[OntologyEntity], + permission: OntologyPermission, + missing_slots: list[str], + ambiguity: list[str], + ) -> list[OntologyFieldError]: + errors: list[OntologyFieldError] = [] + if scenario == "unknown": + errors.append( + OntologyFieldError( + field="scenario", + code="scenario_unknown", + message="未识别出明确业务场景,请补充是报销、应收、应付还是制度问题。", + ) + ) + if intent == "compare" and len([item for item in entities if item.type != "amount"]) < 2: + errors.append( + OntologyFieldError( + field="entities", + code="compare_target_missing", + message="对比类问题请至少给出两个对象,或给出更明确的对比范围。", + ) + ) + if missing_slots: + errors.append( + OntologyFieldError( + field="missing_slots", + code="required_slot_missing", + message=( + "继续处理前还缺少关键信息:" + f"{'、'.join(self._display_slot_label(item) for item in missing_slots)}。" + ), + ) + ) + if ambiguity: + errors.append( + OntologyFieldError( + field="ambiguity", + code="ambiguity_detected", + message=f"当前问题存在歧义:{';'.join(ambiguity)}。", + ) + ) + if permission.level == AgentPermissionLevel.FORBIDDEN.value: + errors.append( + OntologyFieldError( + field="permission", + code="permission_forbidden", + message=permission.reason, + ) + ) + return errors + + def _build_clarification( + self, + *, + scenario: str, + intent: str, + entities: list[OntologyEntity], + permission: OntologyPermission, + missing_slots: list[str], + ambiguity: list[str], + model_clarification_required: bool, + model_clarification_question: str | None, + ) -> tuple[bool, str | None]: + if permission.level == AgentPermissionLevel.FORBIDDEN.value: + return True, "当前动作超出权限范围。是否改为生成草稿或建议?" + if model_clarification_required: + question = str(model_clarification_question or "").strip() + if question: + return True, question + if missing_slots: + return True, self._build_missing_slot_question(missing_slots) + if ambiguity: + return True, f"当前问题存在歧义,请进一步说明:{';'.join(ambiguity)}。" + if scenario == "unknown": + return True, "请说明这是报销、应收、应付,还是制度知识问题?" + if intent == "compare" and len([item for item in entities if item.type != "amount"]) < 2: + return True, "请补充需要对比的两个对象,例如两个客户、两个供应商或两个员工。" + if missing_slots: + return True, self._build_missing_slot_question(missing_slots) + if ambiguity: + return True, f"当前问题存在歧义,请进一步说明:{';'.join(ambiguity)}。" + return False, None + + @staticmethod + def _display_slot_label(slot: str) -> str: + return MISSING_SLOT_LABELS.get(slot, slot) + + def _build_missing_slot_question(self, missing_slots: list[str]) -> str: + labels = [self._display_slot_label(item) for item in missing_slots[:4]] + if not labels: + return "请补充更多上下文后再继续。" + return f"请补充{'、'.join(labels)},我再继续帮你解析和处理。" + + @staticmethod + def _compute_confidence( + *, + scenario: str, + scenario_score: float, + intent_score: float, + entities: list[OntologyEntity], + time_range: OntologyTimeRange, + metrics: list[OntologyMetric], + constraints: list[OntologyConstraint], + risk_flags: list[str], + clarification_required: bool, + permission: OntologyPermission, + ) -> float: + confidence = 0.18 + scenario_score + intent_score + confidence += min(0.16, len(entities) * 0.04) + if time_range.start_date: + confidence += 0.10 + if metrics: + confidence += 0.06 + if constraints: + confidence += 0.06 + if risk_flags: + confidence += 0.08 + if permission.level == AgentPermissionLevel.FORBIDDEN.value: + confidence = max(confidence, 0.86) + + if scenario == "unknown": + confidence = min(confidence, 0.45) + if clarification_required and permission.level != AgentPermissionLevel.FORBIDDEN.value: + confidence = min(confidence, 0.58) + + return round(min(confidence, 0.98), 2) + + @staticmethod + def _build_result_summary( + scenario: str, + intent: str, + permission_level: str, + confidence: float, + ) -> str: + return ( + f"语义解析完成:scenario={scenario}, intent={intent}, " + f"permission={permission_level}, confidence={confidence:.2f}" + ) + + @staticmethod + def _normalize_operator(value: str) -> str: + mapping = { + "超过": ">", + "大于": ">", + "高于": ">", + ">": ">", + ">=": ">=", + "不少于": ">=", + "不低于": ">=", + "小于": "<", + "低于": "<", + "少于": "<", + "<": "<", + "<=": "<=", + "至多": "<=", + "不超过": "<=", + "=": "=", + "=": "=", + } + return mapping.get(value, value) + + @staticmethod + def _normalize_amount(raw_value: str | None, unit: str | None) -> int | float: + numeric = float(raw_value or 0) + if unit in {"万", "万元"}: + numeric *= 10000 + return int(numeric) if numeric.is_integer() else round(numeric, 2) diff --git a/server/src/app/services/orchestrator.py b/server/src/app/services/orchestrator.py new file mode 100644 index 0000000..31c1f3a --- /dev/null +++ b/server/src/app/services/orchestrator.py @@ -0,0 +1,887 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import UTC, datetime +from time import perf_counter +from typing import Any + +from sqlalchemy import func, select +from sqlalchemy.orm import Session + +from app.core.agent_enums import ( + AgentAssetStatus, + AgentAssetType, + AgentName, + AgentPermissionLevel, + AgentRunSource, + AgentRunStatus, + AgentToolType, +) +from app.core.logging import get_logger +from app.models.financial_record import ( + AccountsPayableRecord, + AccountsReceivableRecord, + ExpenseClaim, +) +from app.schemas.agent_asset import AgentAssetListItem, AgentAssetRead +from app.schemas.ontology import OntologyParseRequest, OntologyParseResult +from app.schemas.orchestrator import ( + OrchestratorRequest, + OrchestratorResponse, + OrchestratorTraceSummary, +) +from app.schemas.user_agent import UserAgentRequest, UserAgentResponse +from app.services.agent_assets import AgentAssetService +from app.services.agent_foundation import AgentFoundationService +from app.services.agent_runs import AgentRunService +from app.services.ontology import SemanticOntologyService +from app.services.user_agent import UserAgentService + +logger = get_logger("app.services.orchestrator") + +SCENARIO_TO_DOMAIN = { + "expense": "expense", + "accounts_receivable": "ar", + "accounts_payable": "ap", + "knowledge": "knowledge", + "unknown": "system", +} + + +@dataclass(slots=True) +class ExecutionOutcome: + status: str + result: dict[str, Any] + degraded: bool + tool_count: int + failed_tool_count: int + + +class OrchestratorService: + def __init__(self, db: Session) -> None: + self.db = db + self.asset_service = AgentAssetService(db) + self.run_service = AgentRunService(db) + self.ontology_service = SemanticOntologyService(db) + self.user_agent_service = UserAgentService(db) + + def run(self, payload: OrchestratorRequest) -> OrchestratorResponse: + AgentFoundationService(self.db).ensure_foundation_ready() + route_json: dict[str, Any] = { + "orchestrated_by": AgentName.ORCHESTRATOR.value, + "stage": "created", + } + run = self.run_service.create_run( + agent=AgentName.ORCHESTRATOR.value, + source=payload.source, + user_id=payload.user_id, + task_id=payload.task_id, + ontology_json={}, + route_json=route_json, + permission_level=AgentPermissionLevel.READ.value, + status=AgentRunStatus.RUNNING.value, + result_summary="Orchestrator 已接收请求。", + ) + + try: + message, task_asset = self._resolve_message(payload) + ontology = self.ontology_service.parse_for_run( + OntologyParseRequest( + query=message, + user_id=payload.user_id, + context_json=payload.context_json, + ), + run_id=run.run_id, + ) + if payload.context_json.get("simulate_orchestrator_exception"): + raise RuntimeError("simulated orchestrator exception") + selected_agent, route_reason = self._select_agent(payload, ontology) + capabilities = self._select_capabilities( + payload=payload, + ontology=ontology, + task_asset=task_asset, + ) + selected_capability_codes = self._flatten_capability_codes(capabilities) + requires_confirmation = ( + ontology.permission.level == AgentPermissionLevel.APPROVAL_REQUIRED.value + ) + + route_json = { + "orchestrated_by": AgentName.ORCHESTRATOR.value, + "stage": "routed", + "selected_agent": selected_agent, + "route_reason": route_reason, + "selected_capability_codes": selected_capability_codes, + "ontology_run_id": ontology.run_id, + } + + if ontology.permission.level == AgentPermissionLevel.FORBIDDEN.value: + outcome = ExecutionOutcome( + status=AgentRunStatus.BLOCKED.value, + result={ + "message": ontology.permission.reason, + "clarification_question": ontology.clarification_question, + "degraded": False, + }, + degraded=False, + tool_count=0, + failed_tool_count=0, + ) + selected_agent = None + route_reason = "permission_forbidden" + route_json["stage"] = "blocked" + route_json["route_reason"] = route_reason + elif ontology.clarification_required: + outcome = ExecutionOutcome( + status=AgentRunStatus.BLOCKED.value, + result={ + "message": ontology.clarification_question or "需要补充更多上下文。", + "clarification_required": True, + "missing_slots": ontology.missing_slots, + "ambiguity": ontology.ambiguity, + "parse_strategy": ontology.parse_strategy, + "degraded": False, + }, + degraded=False, + tool_count=0, + failed_tool_count=0, + ) + route_reason = "clarification_required" + route_json["stage"] = "clarification" + route_json["route_reason"] = route_reason + elif selected_agent == AgentName.HERMES.value: + outcome = self._execute_hermes( + payload=payload, + run_id=run.run_id, + ontology=ontology, + capabilities=capabilities, + requires_confirmation=requires_confirmation, + task_asset=task_asset, + ) + else: + outcome = self._execute_user_agent( + payload=payload, + run_id=run.run_id, + ontology=ontology, + capabilities=capabilities, + requires_confirmation=requires_confirmation, + ) + + final_status = ( + AgentRunStatus.BLOCKED.value + if requires_confirmation + and outcome.status == AgentRunStatus.SUCCEEDED.value + and ontology.permission.level == AgentPermissionLevel.APPROVAL_REQUIRED.value + else outcome.status + ) + result_message = ( + str(outcome.result.get("message", "")).strip() + or "Orchestrator 执行完成。" + ) + self.run_service.update_run( + run.run_id, + agent=selected_agent or AgentName.ORCHESTRATOR.value, + ontology_json=self._build_ontology_json(ontology), + route_json={ + **route_json, + "requires_confirmation": requires_confirmation, + "degraded": outcome.degraded, + }, + permission_level=ontology.permission.level, + status=final_status, + result_summary=result_message, + error_message=None, + finished_at=datetime.now(UTC), + ) + return OrchestratorResponse( + run_id=run.run_id, + selected_agent=selected_agent, + route_reason=route_reason, + permission_level=ontology.permission.level, + status=self._normalize_response_status(final_status), + result=outcome.result, + requires_confirmation=requires_confirmation, + trace_summary=OrchestratorTraceSummary( + scenario=ontology.scenario, + intent=ontology.intent, + tool_count=outcome.tool_count, + failed_tool_count=outcome.failed_tool_count, + selected_capability_codes=selected_capability_codes, + degraded=outcome.degraded, + ), + ) + except Exception as exc: + logger.exception("Orchestrator run failed run_id=%s", run.run_id) + self.run_service.update_run( + run.run_id, + agent=AgentName.ORCHESTRATOR.value, + route_json={**route_json, "stage": "failed"}, + status=AgentRunStatus.FAILED.value, + result_summary="Orchestrator 执行失败。", + error_message=str(exc), + finished_at=datetime.now(UTC), + ) + return OrchestratorResponse( + run_id=run.run_id, + selected_agent=None, + route_reason="orchestrator_exception", + permission_level=AgentPermissionLevel.READ.value, + status="failed", + result={"message": f"Orchestrator 执行失败:{exc}"}, + requires_confirmation=False, + trace_summary=OrchestratorTraceSummary( + scenario="unknown", + intent="query", + tool_count=0, + failed_tool_count=0, + selected_capability_codes=[], + degraded=False, + ), + ) + + def _resolve_message( + self, + payload: OrchestratorRequest, + ) -> tuple[str, AgentAssetRead | None]: + task_asset = None + if payload.task_id: + task_asset = self.asset_service.get_asset(payload.task_id) + + if payload.message and payload.message.strip(): + return payload.message.strip(), task_asset + + if task_asset is not None: + description = str(task_asset.description or "").strip() + scenario_text = " ".join(str(item) for item in task_asset.scenario_json) + message = f"{task_asset.name} {description} {scenario_text}".strip() + return message, task_asset + + if payload.source == AgentRunSource.SCHEDULE.value: + return "定时风险巡检任务", task_asset + + raise ValueError("message 或 task_id 至少需要提供一个。") + + @staticmethod + def _select_agent( + payload: OrchestratorRequest, + ontology: OntologyParseResult, + ) -> tuple[str, str]: + if payload.source == AgentRunSource.SCHEDULE.value: + return AgentName.HERMES.value, "schedule_source_defaults_to_hermes" + if payload.source == AgentRunSource.SYSTEM_EVENT.value and ontology.intent == "risk_check": + return AgentName.HERMES.value, "system_event_risk_check_routes_to_hermes" + if ontology.intent == "risk_check" and payload.source == AgentRunSource.SCHEDULE.value: + return AgentName.HERMES.value, "scheduled_risk_check_routes_to_hermes" + if ontology.intent in {"query", "explain", "draft", "compare", "operate"}: + return AgentName.USER_AGENT.value, f"{ontology.intent}_routes_to_user_agent" + return AgentName.USER_AGENT.value, "user_message_defaults_to_user_agent" + + def _select_capabilities( + self, + *, + payload: OrchestratorRequest, + ontology: OntologyParseResult, + task_asset: AgentAssetRead | None, + ) -> dict[str, list[AgentAssetListItem | AgentAssetRead]]: + domain_value = SCENARIO_TO_DOMAIN.get(ontology.scenario) + rules = self._rank_assets( + self.asset_service.list_assets( + asset_type=AgentAssetType.RULE.value, + status=AgentAssetStatus.ACTIVE.value, + domain=domain_value if domain_value not in {"knowledge", "system"} else None, + ), + ontology, + ) + skills = self._rank_assets( + self.asset_service.list_assets( + asset_type=AgentAssetType.SKILL.value, + status=AgentAssetStatus.ACTIVE.value, + domain=domain_value if domain_value not in {"system"} else None, + ), + ontology, + ) + mcps = self._rank_assets( + self.asset_service.list_assets( + asset_type=AgentAssetType.MCP.value, + status=AgentAssetStatus.ACTIVE.value, + ), + ontology, + ) + tasks: list[AgentAssetListItem | AgentAssetRead] = [] + if task_asset is not None and task_asset.status == AgentAssetStatus.ACTIVE.value: + tasks.append(task_asset) + elif payload.source == AgentRunSource.SCHEDULE.value: + tasks = self._rank_assets( + self.asset_service.list_assets( + asset_type=AgentAssetType.TASK.value, + status=AgentAssetStatus.ACTIVE.value, + ), + ontology, + ) + + return { + "rules": rules, + "skills": skills, + "mcps": mcps, + "tasks": tasks, + } + + def _execute_user_agent( + self, + *, + payload: OrchestratorRequest, + run_id: str, + ontology: OntologyParseResult, + capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]], + requires_confirmation: bool, + ) -> ExecutionOutcome: + selected_capability_codes = self._flatten_capability_codes(capabilities) + if requires_confirmation: + response, degraded = self._invoke_tool( + run_id=run_id, + tool_type=AgentToolType.LLM.value, + tool_name="user_agent.confirmation_placeholder", + request_json={ + "message": payload.message, + "permission_level": ontology.permission.level, + }, + context_json=payload.context_json, + executor=lambda: { + "confirmation_title": "操作需要确认", + "message": f"{ontology.permission.reason} 当前仅返回确认摘要,不直接执行动作。", + }, + fallback_factory=lambda exc: { + "confirmation_title": "操作需要确认", + "message": f"确认摘要生成失败,已阻断自动执行:{exc}", + }, + ) + return ExecutionOutcome( + status=AgentRunStatus.BLOCKED.value, + result={**response, "degraded": degraded}, + degraded=degraded, + tool_count=1, + failed_tool_count=1 if degraded else 0, + ) + + next_step = self._resolve_next_step(ontology, payload.source) + if next_step == "query_database": + tool_payload, degraded = self._invoke_tool( + run_id=run_id, + tool_type=AgentToolType.DATABASE.value, + tool_name=self._database_tool_name(ontology.scenario), + request_json=self._build_ontology_json(ontology), + context_json=payload.context_json, + executor=lambda: self._build_database_answer(ontology), + fallback_factory=lambda exc: { + "message": f"数据库查询暂时不可用,已返回降级说明:{exc}", + "degraded": True, + }, + ) + result = self._build_user_agent_result( + self.user_agent_service.respond( + UserAgentRequest( + run_id=run_id, + user_id=payload.user_id, + message=payload.message or "", + ontology=ontology, + context_json=payload.context_json, + tool_payload=tool_payload, + selected_capability_codes=selected_capability_codes, + degraded=degraded, + requires_confirmation=requires_confirmation, + ) + ), + degraded=degraded, + ) + return ExecutionOutcome( + status=AgentRunStatus.SUCCEEDED.value, + result=result, + degraded=degraded, + tool_count=1, + failed_tool_count=1 if degraded else 0, + ) + + if next_step == "search_knowledge": + tool_payload, degraded = self._invoke_tool( + run_id=run_id, + tool_type=AgentToolType.DATABASE.value, + tool_name="knowledge.search", + request_json=self._build_ontology_json(ontology), + context_json=payload.context_json, + executor=lambda: self._build_knowledge_answer(ontology, capabilities), + fallback_factory=lambda exc: { + "message": f"知识检索暂时不可用,建议稍后重试:{exc}", + "degraded": True, + }, + ) + result = self._build_user_agent_result( + self.user_agent_service.respond( + UserAgentRequest( + run_id=run_id, + user_id=payload.user_id, + message=payload.message or "", + ontology=ontology, + context_json=payload.context_json, + tool_payload=tool_payload, + selected_capability_codes=selected_capability_codes, + degraded=degraded, + requires_confirmation=requires_confirmation, + ) + ), + degraded=degraded, + ) + return ExecutionOutcome( + status=AgentRunStatus.SUCCEEDED.value, + result=result, + degraded=degraded, + tool_count=1, + failed_tool_count=1 if degraded else 0, + ) + + if next_step == "run_rule": + tool_payload, degraded = self._invoke_tool( + run_id=run_id, + tool_type=AgentToolType.RULE_ENGINE.value, + tool_name=self._rule_tool_name(capabilities), + request_json=self._build_ontology_json(ontology), + context_json=payload.context_json, + executor=lambda: self._build_rule_answer(ontology), + fallback_factory=lambda exc: { + "message": f"规则检查暂时不可用,已返回人工复核建议:{exc}", + "degraded": True, + }, + ) + result = self._build_user_agent_result( + self.user_agent_service.respond( + UserAgentRequest( + run_id=run_id, + user_id=payload.user_id, + message=payload.message or "", + ontology=ontology, + context_json=payload.context_json, + tool_payload=tool_payload, + selected_capability_codes=selected_capability_codes, + degraded=degraded, + requires_confirmation=requires_confirmation, + ) + ), + degraded=degraded, + ) + return ExecutionOutcome( + status=AgentRunStatus.SUCCEEDED.value, + result=result, + degraded=degraded, + tool_count=1, + failed_tool_count=1 if degraded else 0, + ) + + tool_payload, degraded = self._invoke_tool( + run_id=run_id, + tool_type=AgentToolType.LLM.value, + tool_name="user_agent.draft_placeholder", + request_json=self._build_ontology_json(ontology), + context_json=payload.context_json, + executor=lambda: { + "message": ( + f"已生成 {ontology.scenario} 场景草稿," + "占位能力后续由 Day 5 User Agent 接管。" + ), + "draft_only": True, + }, + fallback_factory=lambda exc: { + "message": f"草稿生成暂时不可用,请稍后再试:{exc}", + "degraded": True, + }, + ) + result = self._build_user_agent_result( + self.user_agent_service.respond( + UserAgentRequest( + run_id=run_id, + user_id=payload.user_id, + message=payload.message or "", + ontology=ontology, + context_json=payload.context_json, + tool_payload=tool_payload, + selected_capability_codes=selected_capability_codes, + degraded=degraded, + requires_confirmation=requires_confirmation, + ) + ), + degraded=degraded, + ) + return ExecutionOutcome( + status=AgentRunStatus.SUCCEEDED.value, + result=result, + degraded=degraded, + tool_count=1, + failed_tool_count=1 if degraded else 0, + ) + + def _execute_hermes( + self, + *, + payload: OrchestratorRequest, + run_id: str, + ontology: OntologyParseResult, + capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]], + requires_confirmation: bool, + task_asset: AgentAssetRead | None, + ) -> ExecutionOutcome: + if requires_confirmation: + return ExecutionOutcome( + status=AgentRunStatus.BLOCKED.value, + result={ + "message": "Hermes 不会自动执行需要确认的高风险动作,已阻断。", + "degraded": False, + }, + degraded=False, + tool_count=0, + failed_tool_count=0, + ) + + rule_response, rule_degraded = self._invoke_tool( + run_id=run_id, + tool_type=AgentToolType.RULE_ENGINE.value, + tool_name=self._rule_tool_name(capabilities), + request_json=self._build_ontology_json(ontology), + context_json=payload.context_json, + executor=lambda: self._build_rule_answer(ontology), + fallback_factory=lambda exc: { + "message": f"规则巡检失败,已降级为待人工复核:{exc}", + "degraded": True, + }, + ) + mcp_response, mcp_degraded = self._invoke_tool( + run_id=run_id, + tool_type=AgentToolType.MCP.value, + tool_name=self._mcp_tool_name(capabilities), + request_json={ + "task_code": task_asset.code if task_asset is not None else "", + "scenario": ontology.scenario, + }, + context_json=payload.context_json, + executor=lambda: self._build_mcp_answer(task_asset, ontology), + fallback_factory=lambda exc: { + "message": f"MCP 调用失败,已使用缓存快照降级:{exc}", + "fallback": "used_cached_snapshot", + }, + ) + degraded = rule_degraded or mcp_degraded + failed_tool_count = int(rule_degraded) + int(mcp_degraded) + result = { + "message": self._build_hermes_message( + task_asset=task_asset, + ontology=ontology, + rule_response=rule_response, + mcp_response=mcp_response, + degraded=degraded, + ), + "report_type": task_asset.code if task_asset is not None else "hermes_runtime", + "degraded": degraded, + } + return ExecutionOutcome( + status=AgentRunStatus.SUCCEEDED.value, + result=result, + degraded=degraded, + tool_count=2, + failed_tool_count=failed_tool_count, + ) + + @staticmethod + def _resolve_next_step(ontology: OntologyParseResult, source: str) -> str: + if ontology.clarification_required: + return "ask_clarification" + if ontology.intent == "draft": + return "create_draft" + if ontology.scenario == "knowledge" or ontology.intent == "explain": + return "search_knowledge" + if ontology.intent == "risk_check" or source == AgentRunSource.SCHEDULE.value: + return "run_rule" + if ontology.intent in {"query", "compare"}: + return "query_database" + return "create_draft" + + @staticmethod + def _flatten_capability_codes( + capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]], + ) -> list[str]: + codes: list[str] = [] + for items in capabilities.values(): + for item in items[:2]: + if item.code not in codes: + codes.append(item.code) + return codes + + def _rank_assets( + self, + items: list[AgentAssetListItem], + ontology: OntologyParseResult, + ) -> list[AgentAssetListItem]: + def score(item: AgentAssetListItem) -> tuple[int, str]: + item_tags = {str(value) for value in item.scenario_json or []} + weight = 0 + if ontology.scenario in item_tags: + weight += 3 + if ontology.intent in item_tags: + weight += 2 + for risk_flag in ontology.risk_flags: + if risk_flag in item_tags: + weight += 4 + return weight, item.code + + ranked = sorted(items, key=score, reverse=True) + if not ranked: + return [] + scored = [item for item in ranked if score(item)[0] > 0] + return scored or ranked[:1] + + def _invoke_tool( + self, + *, + run_id: str, + tool_type: str, + tool_name: str, + request_json: dict[str, Any], + context_json: dict[str, Any], + executor, + fallback_factory, + ) -> tuple[dict[str, Any], bool]: + started = perf_counter() + try: + self._maybe_raise_simulated_failure(tool_type, context_json) + response = executor() + duration_ms = int((perf_counter() - started) * 1000) + self.run_service.record_tool_call( + run_id=run_id, + tool_type=tool_type, + tool_name=tool_name, + request_json=request_json, + response_json=response, + status="succeeded", + duration_ms=duration_ms, + ) + return response, False + except Exception as exc: + duration_ms = int((perf_counter() - started) * 1000) + response = fallback_factory(exc) + self.run_service.record_tool_call( + run_id=run_id, + tool_type=tool_type, + tool_name=tool_name, + request_json=request_json, + response_json=response, + status="failed", + duration_ms=duration_ms, + error_message=str(exc), + ) + return response, True + + @staticmethod + def _maybe_raise_simulated_failure(tool_type: str, context_json: dict[str, Any]) -> None: + expected = str(context_json.get("simulate_tool_failure") or "").strip().lower() + if not expected: + return + if expected == tool_type.lower(): + raise RuntimeError(f"simulated {tool_type} failure") + + def _build_database_answer(self, ontology: OntologyParseResult) -> dict[str, Any]: + if ontology.scenario == "expense": + count_stmt = select(func.count()).select_from(ExpenseClaim) + amount_stmt = select( + func.coalesce(func.sum(ExpenseClaim.amount), 0) + ).select_from(ExpenseClaim) + employee_names = [ + item.normalized_value + for item in ontology.entities + if item.type == "employee" + ] + if employee_names: + count_stmt = count_stmt.where(ExpenseClaim.employee_name.in_(employee_names)) + amount_stmt = amount_stmt.where(ExpenseClaim.employee_name.in_(employee_names)) + total_count = int(self.db.scalar(count_stmt) or 0) + total_amount = float(self.db.scalar(amount_stmt) or 0) + return { + "record_count": total_count, + "total_amount": round(total_amount, 2), + } + + if ontology.scenario == "accounts_receivable": + total_count = int( + self.db.scalar( + select(func.count()).select_from(AccountsReceivableRecord) + ) + or 0 + ) + total_amount = float( + self.db.scalar( + select(func.coalesce(func.sum(AccountsReceivableRecord.amount_outstanding), 0)) + ) + or 0 + ) + return { + "record_count": total_count, + "outstanding_amount": round(total_amount, 2), + } + + total_count = int( + self.db.scalar(select(func.count()).select_from(AccountsPayableRecord)) + or 0 + ) + total_amount = float( + self.db.scalar( + select(func.coalesce(func.sum(AccountsPayableRecord.amount_outstanding), 0)) + ) + or 0 + ) + return { + "record_count": total_count, + "outstanding_amount": round(total_amount, 2), + } + + @staticmethod + def _build_user_query_result( + ontology: OntologyParseResult, + response: dict[str, Any], + ) -> dict[str, Any]: + if ontology.scenario == "expense": + return { + "message": ( + f"已路由到 User Agent,占位查询结果:命中 {response['record_count']} 笔报销," + f"金额合计 {response['total_amount']} 元。" + ), + "data": response, + } + if ontology.scenario == "accounts_receivable": + return { + "message": ( + f"已路由到 User Agent,占位查询结果:命中 {response['record_count']} 条应收," + f"未回款金额 {response['outstanding_amount']} 元。" + ), + "data": response, + } + return { + "message": ( + f"已路由到 User Agent,占位查询结果:命中 {response['record_count']} 条应付," + f"待付金额 {response['outstanding_amount']} 元。" + ), + "data": response, + } + + @staticmethod + def _build_user_agent_result( + response: UserAgentResponse, + *, + degraded: bool, + ) -> dict[str, Any]: + result = { + "message": response.answer, + "answer": response.answer, + "citations": [item.model_dump() for item in response.citations], + "suggested_actions": [item.model_dump() for item in response.suggested_actions], + "risk_flags": response.risk_flags, + "requires_confirmation": response.requires_confirmation, + "degraded": degraded, + } + if response.draft_payload is not None: + result["draft_payload"] = response.draft_payload.model_dump() + return result + + @staticmethod + def _build_knowledge_answer( + ontology: OntologyParseResult, + capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]], + ) -> dict[str, Any]: + referenced = [item.code for item in capabilities["rules"][:1]] or [ + "knowledge.policy.default" + ] + return { + "message": f"已路由到 User Agent,占位知识结果:建议先查看 {', '.join(referenced)}。", + "references": referenced, + } + + @staticmethod + def _build_rule_answer(ontology: OntologyParseResult) -> dict[str, Any]: + risk_text = ( + "、".join(ontology.risk_flags) + if ontology.risk_flags + else "未识别到明确风险标签" + ) + return { + "message": f"已完成占位规则检查,风险标签:{risk_text}。", + "risk_flags": ontology.risk_flags, + } + + @staticmethod + def _build_mcp_answer( + task_asset: AgentAssetRead | None, + ontology: OntologyParseResult, + ) -> dict[str, Any]: + return { + "message": ( + f"已调用占位 MCP 快照,任务={task_asset.code if task_asset else 'none'}," + f"scenario={ontology.scenario}。" + ), + "snapshot": "stubbed", + } + + @staticmethod + def _build_hermes_message( + *, + task_asset: AgentAssetRead | None, + ontology: OntologyParseResult, + rule_response: dict[str, Any], + mcp_response: dict[str, Any], + degraded: bool, + ) -> str: + task_code = task_asset.code if task_asset is not None else "task.unspecified" + suffix = ",其中部分能力已降级。" if degraded else "。" + return ( + f"Hermes 占位执行完成:任务 {task_code}," + f"场景 {ontology.scenario},规则结果={rule_response.get('message', '')}," + f"MCP 结果={mcp_response.get('message', '')}{suffix}" + ) + + @staticmethod + def _database_tool_name(scenario: str) -> str: + if scenario == "expense": + return "database.expense_claims.lookup" + if scenario == "accounts_receivable": + return "database.accounts_receivable.lookup" + return "database.accounts_payable.lookup" + + @staticmethod + def _rule_tool_name( + capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]], + ) -> str: + if capabilities["rules"]: + return capabilities["rules"][0].code + return "rule_engine.default_risk_check" + + @staticmethod + def _mcp_tool_name( + capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]], + ) -> str: + if capabilities["mcps"]: + return capabilities["mcps"][0].code + return "mcp.default_snapshot" + + @staticmethod + def _build_ontology_json(ontology: OntologyParseResult) -> dict[str, Any]: + return { + "scenario": ontology.scenario, + "intent": ontology.intent, + "entities": [item.model_dump() for item in ontology.entities], + "time_range": ontology.time_range.model_dump(), + "metrics": [item.model_dump() for item in ontology.metrics], + "constraints": [item.model_dump() for item in ontology.constraints], + "risk_flags": ontology.risk_flags, + "permission": ontology.permission.model_dump(), + } + + @staticmethod + def _normalize_response_status(status: str) -> str: + if status == AgentRunStatus.FAILED.value: + return "failed" + if status == AgentRunStatus.BLOCKED.value: + return "blocked" + return "succeeded" diff --git a/server/src/app/services/runtime_chat.py b/server/src/app/services/runtime_chat.py new file mode 100644 index 0000000..0e61b1e --- /dev/null +++ b/server/src/app/services/runtime_chat.py @@ -0,0 +1,252 @@ +from __future__ import annotations + +from http import HTTPStatus +from typing import Any + +from sqlalchemy.orm import Session + +from app.core.logging import get_logger +from app.services.model_connectivity import ( + AZURE_API_VERSION, + ConnectivityCheckError, + _build_azure_deployment_base, + _build_headers, + _ensure_path, + _normalize_endpoint, + _send_json_request, +) +from app.services.settings import SettingsService + +logger = get_logger("app.services.runtime_chat") + + +class RuntimeChatService: + def __init__(self, db: Session) -> None: + self.db = db + self.settings_service = SettingsService(db) + + def complete( + self, + messages: list[dict[str, str]], + *, + slot_priority: tuple[str, ...] = ("main", "backup"), + max_tokens: int = 500, + temperature: float = 0.2, + ) -> str | None: + for slot in slot_priority: + config = self._load_chat_slot(slot) + if config is None: + continue + + try: + response_text = self._request_chat_completion( + config, + messages, + max_tokens=max_tokens, + temperature=temperature, + ) + except Exception as exc: + logger.warning( + "Runtime chat request failed slot=%s provider=%s: %s", + slot, + config["provider"], + exc, + ) + continue + + if response_text: + return response_text.strip() + + return None + + def _load_chat_slot(self, slot: str) -> dict[str, str] | None: + try: + config = self.settings_service.get_runtime_model_config(slot) + except ValueError: + return None + + if config["capability"] != "chat": + return None + + provider = str(config["provider"] or "").strip() + endpoint = str(config["endpoint"] or "").strip() + model = str(config["model"] or "").strip() + api_key = str(config["apiKey"] or "").strip() + + if not provider or not endpoint or not model: + return None + + if provider != "Ollama" and not api_key: + logger.info("Skip runtime chat slot=%s because api key is empty", slot) + return None + + return { + "slot": slot, + "provider": provider, + "endpoint": endpoint, + "model": model, + "apiKey": api_key, + } + + def _request_chat_completion( + self, + config: dict[str, str], + messages: list[dict[str, str]], + *, + max_tokens: int, + temperature: float, + ) -> str: + provider = config["provider"] + endpoint = config["endpoint"] + model = config["model"] + api_key = config["apiKey"] + + if provider == "Azure OpenAI": + return self._request_azure_openai( + endpoint=endpoint, + model=model, + api_key=api_key, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + ) + + if provider == "Ollama": + return self._request_ollama( + endpoint=endpoint, + model=model, + api_key=api_key, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + ) + + return self._request_openai_compatible( + endpoint=endpoint, + model=model, + api_key=api_key, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + ) + + def _request_openai_compatible( + self, + *, + endpoint: str, + model: str, + api_key: str, + messages: list[dict[str, str]], + max_tokens: int, + temperature: float, + ) -> str: + url = _ensure_path(_normalize_endpoint(endpoint), "chat/completions") + status_code, payload = _send_json_request( + "POST", + url, + headers=_build_headers(api_key=api_key, use_bearer=True), + payload={ + "model": model, + "messages": messages, + "max_tokens": max_tokens, + "temperature": temperature, + }, + ) + if status_code >= HTTPStatus.BAD_REQUEST: + raise ConnectivityCheckError( + f"模型接口返回异常状态 {status_code}。", + status_code=status_code, + ) + return self._extract_openai_text(payload) + + def _request_ollama( + self, + *, + endpoint: str, + model: str, + api_key: str, + messages: list[dict[str, str]], + max_tokens: int, + temperature: float, + ) -> str: + url = _ensure_path(_normalize_endpoint(endpoint), "api/chat") + status_code, payload = _send_json_request( + "POST", + url, + headers=_build_headers(api_key=api_key, use_bearer=False), + payload={ + "model": model, + "messages": messages, + "stream": False, + "options": { + "num_predict": max_tokens, + "temperature": temperature, + }, + }, + ) + if status_code >= HTTPStatus.BAD_REQUEST: + raise ConnectivityCheckError( + f"Ollama 返回异常状态 {status_code}。", + status_code=status_code, + ) + return str((payload or {}).get("message", {}).get("content", "")).strip() + + def _request_azure_openai( + self, + *, + endpoint: str, + model: str, + api_key: str, + messages: list[dict[str, str]], + max_tokens: int, + temperature: float, + ) -> str: + deployment_base = _build_azure_deployment_base(endpoint, model) + url = f"{deployment_base}/chat/completions?api-version={AZURE_API_VERSION}" + status_code, payload = _send_json_request( + "POST", + url, + headers=_build_headers(api_key=api_key, use_bearer=False, use_api_key=True), + payload={ + "messages": messages, + "max_tokens": max_tokens, + "temperature": temperature, + }, + ) + if status_code >= HTTPStatus.BAD_REQUEST: + raise ConnectivityCheckError( + f"Azure OpenAI 返回异常状态 {status_code}。", + status_code=status_code, + ) + return self._extract_openai_text(payload) + + @staticmethod + def _extract_openai_text(payload: Any) -> str: + if not isinstance(payload, dict): + return "" + + choices = payload.get("choices") + if not isinstance(choices, list) or not choices: + return "" + + first_choice = choices[0] + if not isinstance(first_choice, dict): + return "" + + message = first_choice.get("message") + if isinstance(message, dict): + content = message.get("content", "") + if isinstance(content, str): + return content.strip() + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + parts.append(str(item.get("text", ""))) + return "\n".join(part.strip() for part in parts if part.strip()).strip() + + text = first_choice.get("text") + if isinstance(text, str): + return text.strip() + + return "" diff --git a/server/src/app/services/user_agent.py b/server/src/app/services/user_agent.py new file mode 100644 index 0000000..624345e --- /dev/null +++ b/server/src/app/services/user_agent.py @@ -0,0 +1,547 @@ +from __future__ import annotations + +import json +import re + +from sqlalchemy.orm import Session + +from app.core.agent_enums import AgentAssetStatus, AgentAssetType +from app.schemas.agent_asset import AgentAssetListItem +from app.schemas.user_agent import ( + UserAgentCitation, + UserAgentDraftPayload, + UserAgentRequest, + UserAgentResponse, + UserAgentSuggestedAction, +) +from app.services.agent_assets import AgentAssetService +from app.services.agent_foundation import AgentFoundationService +from app.services.runtime_chat import RuntimeChatService + +SCENARIO_LABELS = { + "expense": "报销", + "accounts_receivable": "应收", + "accounts_payable": "应付", + "knowledge": "知识", + "unknown": "通用", +} + +RISK_REASON_MAP = { + "duplicate_expense": "检测到同员工、同金额或近似单据存在重复提交迹象。", + "amount_over_limit": "金额超过当前制度或预算阈值,需要补充例外说明。", + "invoice_anomaly": "票据或附件完整性不满足当前规则要求,需要补件或人工复核。", + "ar_overdue": "应收账款已出现逾期,存在回款延迟风险。", + "ap_overdue": "应付付款已出现逾期,可能影响供应商履约或合作关系。", +} + +GENERIC_EXPENSE_PROMPTS = { + "报销", + "我要报销", + "我想报销", + "帮我报销", + "我要申请报销", + "发起报销", + "提交报销", +} + +EXPLICIT_DRAFT_KEYWORDS = ("生成", "草稿", "起草", "创建", "发起", "准备") + +EXPENSE_TYPE_LABELS = { + "travel": "差旅", + "hotel": "住宿", + "transport": "交通", + "meal": "餐费", + "meeting": "会务", + "entertainment": "招待", +} + + +class UserAgentService: + def __init__(self, db: Session) -> None: + self.db = db + self.asset_service = AgentAssetService(db) + self.runtime_chat_service = RuntimeChatService(db) + + def respond(self, payload: UserAgentRequest) -> UserAgentResponse: + AgentFoundationService(self.db).ensure_foundation_ready() + citations = self._build_rule_citations(payload) + suggested_actions = self._build_suggested_actions(payload) + risk_flags = self._resolve_risk_flags(payload) + draft_payload = ( + self._build_draft_payload(payload) + if payload.ontology.intent == "draft" + else None + ) + + if payload.degraded and payload.tool_payload.get("message"): + return UserAgentResponse( + answer=str(payload.tool_payload["message"]), + citations=citations, + suggested_actions=suggested_actions, + risk_flags=risk_flags, + requires_confirmation=payload.requires_confirmation, + ) + + guided_answer = self._build_guided_answer(payload) + if guided_answer: + return UserAgentResponse( + answer=guided_answer, + citations=citations, + suggested_actions=suggested_actions, + draft_payload=draft_payload, + risk_flags=risk_flags, + requires_confirmation=payload.requires_confirmation, + ) + + fallback_answer = self._build_fallback_answer( + payload, + citations=citations, + draft_payload=draft_payload, + ) + answer = self._generate_answer_with_model( + payload, + citations=citations, + suggested_actions=suggested_actions, + risk_flags=risk_flags, + draft_payload=draft_payload, + fallback_answer=fallback_answer, + ) + + return UserAgentResponse( + answer=answer or fallback_answer, + citations=citations, + suggested_actions=suggested_actions, + draft_payload=draft_payload, + risk_flags=risk_flags, + requires_confirmation=payload.requires_confirmation, + ) + + def _build_fallback_answer( + self, + payload: UserAgentRequest, + *, + citations: list[UserAgentCitation], + draft_payload: UserAgentDraftPayload | None, + ) -> str: + if payload.ontology.intent in {"query", "compare"}: + return self._build_query_answer(payload) + + if payload.ontology.intent == "risk_check": + return self._build_risk_answer(payload, citations) + + if payload.ontology.intent == "draft" and draft_payload is not None: + return ( + f"已生成 {draft_payload.title},当前仅返回待人工确认的草稿内容," + "仍需人工确认后再进入正式流程。" + ) + + return self._build_explain_answer(payload, citations) + + def _build_guided_answer(self, payload: UserAgentRequest) -> str | None: + if not self._is_generic_expense_prompt(payload): + return self._build_implicit_expense_draft_guidance(payload) + + attachment_names = self._resolve_attachment_names(payload) + attachment_hint = "" + if attachment_names: + attachment_hint = ( + f" 我已带入 {len(attachment_names)} 份附件名称,但目前还不能直接读取附件内容," + "仍需要你补充关键信息。" + ) + + return ( + "可以帮你发起报销。请补充费用类型、发生时间、金额、事由和相关对象," + "或者直接上传票据附件,我再继续帮你判断能否报、缺什么材料以及生成报销草稿。" + f"{attachment_hint}" + ) + + def _build_implicit_expense_draft_guidance( + self, + payload: UserAgentRequest, + ) -> str | None: + if not self._is_implicit_expense_draft_request(payload): + return None + + amount_text = next( + (item.value for item in payload.ontology.entities if item.type == "amount"), + "", + ) + expense_type = next( + ( + EXPENSE_TYPE_LABELS.get(item.normalized_value, item.value) + for item in payload.ontology.entities + if item.type == "expense_type" + ), + "报销", + ) + time_text = payload.ontology.time_range.raw or "本次" + amount_hint = f",金额 {amount_text}" if amount_text else "" + + return ( + f"已识别到一笔{time_text}的{expense_type}支出{amount_hint}。" + "如果要继续生成报销草稿,还需要补充客户单位、参与人员、费用明细和票据附件。" + "你也可以继续上传发票或图片,我会把这些信息带入后续对话。" + ) + + def _generate_answer_with_model( + self, + payload: UserAgentRequest, + *, + citations: list[UserAgentCitation], + suggested_actions: list[UserAgentSuggestedAction], + risk_flags: list[str], + draft_payload: UserAgentDraftPayload | None, + fallback_answer: str, + ) -> str | None: + messages = self._build_model_messages( + payload, + citations=citations, + suggested_actions=suggested_actions, + risk_flags=risk_flags, + draft_payload=draft_payload, + fallback_answer=fallback_answer, + ) + return self._sanitize_model_answer( + self.runtime_chat_service.complete( + messages, + max_tokens=420, + temperature=0.2, + ) + ) + + def _sanitize_model_answer(self, answer: str | None) -> str | None: + if not answer: + return None + + cleaned = re.sub(r".*?", "", answer, flags=re.DOTALL | re.IGNORECASE) + cleaned = cleaned.strip() + return cleaned or None + + def _build_model_messages( + self, + payload: UserAgentRequest, + *, + citations: list[UserAgentCitation], + suggested_actions: list[UserAgentSuggestedAction], + risk_flags: list[str], + draft_payload: UserAgentDraftPayload | None, + fallback_answer: str, + ) -> list[dict[str, str]]: + facts = { + "run_id": payload.run_id, + "user_message": payload.message, + "ontology": payload.ontology.model_dump(mode="json"), + "context": { + "entry_source": payload.context_json.get("entry_source"), + "user_name": payload.context_json.get("name"), + "user_role": payload.context_json.get("role"), + "request_context": payload.context_json.get("request_context"), + "attachment_count": payload.context_json.get("attachment_count"), + "attachment_names": self._resolve_attachment_names(payload), + }, + "tool_payload": payload.tool_payload, + "citations": [item.model_dump(mode="json") for item in citations], + "suggested_actions": [ + item.model_dump(mode="json") for item in suggested_actions + ], + "risk_flags": risk_flags, + "draft_payload": ( + draft_payload.model_dump(mode="json") + if draft_payload is not None + else None + ), + "selected_capability_codes": payload.selected_capability_codes, + "requires_confirmation": payload.requires_confirmation, + "fallback_answer": fallback_answer, + } + + system_prompt = ( + "你是企业财务共享场景中的中文智能助手,负责和最终用户直接对话。" + "你只能基于提供的事实回答,不能编造制度、流程结果或附件内容。" + "如果用户问题很笼统,例如“我要报销”,优先告诉用户你可以协助什么," + "并明确要求补充费用类型、金额、时间、事由、参与对象或上传票据。" + "如果上下文里只有附件名称,必须明确说明你只拿到了附件名称," + "不能假装已看过图片、PDF 或发票内容。" + "不要声称已经提交、审批、付款、入账或真正执行了任何动作;如果只是建议、草稿或待确认,要明确说清楚。" + "若给出了风险标签、制度引用或建议动作,可以简洁吸收进回答,但不要新增未提供的事实。" + "只输出最终给用户看的自然语言,不要输出 JSON、Markdown、标题、" + " 标签或任何中间推理。" + "使用简体中文,控制在 2 到 4 句。" + ) + user_prompt = ( + "请根据以下事实生成最终答复,优先保持准确、具体、可执行:\n" + f"{json.dumps(facts, ensure_ascii=False, indent=2)}" + ) + return [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + + def _build_query_answer(self, payload: UserAgentRequest) -> str: + scenario = payload.ontology.scenario + data = payload.tool_payload + subject = self._resolve_subject(payload) + + if scenario == "expense": + record_count = int(data.get("record_count") or 0) + total_amount = float(data.get("total_amount") or 0) + return ( + f"{subject}共命中 {record_count} 笔报销,金额合计 {total_amount:.2f} 元。" + "如需继续处理,可以查看明细或生成处理意见草稿。" + ) + + if scenario == "accounts_receivable": + record_count = int(data.get("record_count") or 0) + outstanding_amount = float(data.get("outstanding_amount") or 0) + return ( + f"{subject}共命中 {record_count} 条应收,未回款金额 {outstanding_amount:.2f} 元。" + "建议结合账龄和客户分布继续排查逾期风险。" + ) + + if scenario == "accounts_payable": + record_count = int(data.get("record_count") or 0) + outstanding_amount = float(data.get("outstanding_amount") or 0) + return ( + f"{subject}共命中 {record_count} 条应付,待付金额 {outstanding_amount:.2f} 元。" + "如需推进动作,建议先生成付款建议草稿并发起人工确认。" + ) + + return "已完成当前查询,但暂时没有更多结构化结果可展示。" + + def _build_explain_answer( + self, + payload: UserAgentRequest, + citations: list[UserAgentCitation], + ) -> str: + if citations: + titles = "、".join(item.title for item in citations[:2]) + summary = citations[0].excerpt or "请结合制度全文进一步确认。" + return f"已检索到相关依据:{titles}。核心说明:{summary}" + + return ( + f"当前还没有与“{SCENARIO_LABELS.get(payload.ontology.scenario, '当前问题')}”" + "强匹配的已上线规则引用,建议先人工复核或补充更具体的单据上下文。" + ) + + def _build_risk_answer( + self, + payload: UserAgentRequest, + citations: list[UserAgentCitation], + ) -> str: + risk_flags = self._resolve_risk_flags(payload) + if not risk_flags: + return "当前未识别到明确风险标签,建议继续查看原始明细或补充更多上下文。" + + reasons = [RISK_REASON_MAP.get(flag, f"{flag} 需要人工进一步确认。") for flag in risk_flags] + citation_text = ( + f" 参考规则:{'、'.join(item.title for item in citations[:2])}。" + if citations + else "" + ) + return ( + f"本次识别到 {len(risk_flags)} 类风险:{'、'.join(risk_flags)}。" + f"触发原因:{';'.join(reasons)}。" + "建议先复核明细、附件和审批链,再决定是否继续处理。" + f"{citation_text}" + ) + + def _build_draft_payload(self, payload: UserAgentRequest) -> UserAgentDraftPayload: + scenario_label = SCENARIO_LABELS.get(payload.ontology.scenario, "业务") + subject = self._resolve_subject(payload) + title = f"{scenario_label}处理意见草稿" + body = ( + f"主题:{subject}\n" + "结论:已根据当前语义解析结果生成草稿,尚未自动执行。\n" + "建议:请先核对明细、规则命中和所需附件,再由人工确认是否提交正式流程。\n" + f"原始问题:{payload.message}" + ) + return UserAgentDraftPayload( + draft_type=payload.ontology.scenario, + title=title, + body=body, + confirmation_required=True, + ) + + def _build_suggested_actions( + self, + payload: UserAgentRequest, + ) -> list[UserAgentSuggestedAction]: + if self._is_generic_expense_prompt(payload): + return [ + UserAgentSuggestedAction( + label="上传票据", + action_type="ask_clarification", + description="上传发票、行程单或付款截图,继续识别报销内容。", + ), + UserAgentSuggestedAction( + label="补充报销信息", + action_type="ask_clarification", + description="补充费用类型、金额、时间和事由后继续处理。", + ), + ] + + if payload.ontology.intent in {"query", "compare"}: + return [ + UserAgentSuggestedAction( + label="查看明细", + action_type="open_detail", + description="继续查看命中记录和过滤条件。", + ), + UserAgentSuggestedAction( + label="生成处理意见", + action_type="create_draft", + description="把当前查询结果整理成可确认草稿。", + ), + ] + + if payload.ontology.intent == "risk_check": + return [ + UserAgentSuggestedAction( + label="人工复核风险", + action_type="manual_review", + description="优先检查明细、附件和规则命中原因。", + ), + UserAgentSuggestedAction( + label="生成整改建议", + action_type="create_draft", + description="把风险说明整理成处理意见草稿。", + ), + ] + + if payload.ontology.intent == "draft": + return [ + UserAgentSuggestedAction( + label="复制草稿", + action_type="copy_draft", + description="复制当前草稿后交由人工确认。", + ), + UserAgentSuggestedAction( + label="补充上下文", + action_type="ask_clarification", + description="补充单据编号、客户或供应商信息以完善草稿。", + ), + ] + + return [ + UserAgentSuggestedAction( + label="查看规则全文", + action_type="open_rule", + description="继续查看引用规则或知识内容。", + ), + UserAgentSuggestedAction( + label="补充问题上下文", + action_type="ask_clarification", + description="补充业务对象、时间或单据范围,提升回答准确度。", + ), + ] + + def _build_rule_citations(self, payload: UserAgentRequest) -> list[UserAgentCitation]: + domain = self._resolve_domain(payload.ontology.scenario) + items = self.asset_service.list_assets( + asset_type=AgentAssetType.RULE.value, + status=AgentAssetStatus.ACTIVE.value, + domain=domain, + ) + ranked = self._rank_rule_assets(items, payload) + citations: list[UserAgentCitation] = [] + for item in ranked[:2]: + detail = self.asset_service.get_asset(item.id) + if detail is None: + continue + excerpt = self._extract_excerpt(str(detail.current_version_content or "")) + citations.append( + UserAgentCitation( + source_type="rule", + code=detail.code, + title=detail.name, + version=detail.current_version, + updated_at=detail.updated_at.date().isoformat(), + excerpt=excerpt, + ) + ) + return citations + + @staticmethod + def _resolve_risk_flags(payload: UserAgentRequest) -> list[str]: + tool_flags = payload.tool_payload.get("risk_flags") + if isinstance(tool_flags, list) and tool_flags: + return [str(item) for item in tool_flags] + return [str(item) for item in payload.ontology.risk_flags] + + @staticmethod + def _resolve_subject(payload: UserAgentRequest) -> str: + named_entities = [ + item.value + for item in payload.ontology.entities + if item.type in {"employee", "customer", "vendor", "project"} + ] + if named_entities: + return f"{'、'.join(named_entities)} 相关数据" + return f"{SCENARIO_LABELS.get(payload.ontology.scenario, '当前')}场景数据" + + @staticmethod + def _is_generic_expense_prompt(payload: UserAgentRequest) -> bool: + if payload.ontology.scenario != "expense": + return False + normalized_message = re.sub(r"\s+", "", payload.message) + return normalized_message in GENERIC_EXPENSE_PROMPTS + + @staticmethod + def _is_implicit_expense_draft_request(payload: UserAgentRequest) -> bool: + if payload.ontology.scenario != "expense" or payload.ontology.intent != "draft": + return False + + compact_message = re.sub(r"\s+", "", payload.message) + if any(keyword in compact_message for keyword in EXPLICIT_DRAFT_KEYWORDS): + return False + + return True + + @staticmethod + def _resolve_attachment_names(payload: UserAgentRequest) -> list[str]: + names = payload.context_json.get("attachment_names") + if not isinstance(names, list): + return [] + return [str(name) for name in names if str(name).strip()] + + @staticmethod + def _resolve_domain(scenario: str) -> str | None: + if scenario == "expense": + return "expense" + if scenario == "accounts_receivable": + return "ar" + if scenario == "accounts_payable": + return "ap" + return None + + @staticmethod + def _rank_rule_assets( + items: list[AgentAssetListItem], + payload: UserAgentRequest, + ) -> list[AgentAssetListItem]: + def score(item: AgentAssetListItem) -> tuple[int, str]: + tags = {str(value) for value in item.scenario_json or []} + weight = 0 + if payload.ontology.scenario in tags: + weight += 3 + if payload.ontology.intent in tags: + weight += 2 + for risk_flag in payload.ontology.risk_flags: + if risk_flag in tags: + weight += 4 + return weight, item.code + + ranked = sorted(items, key=score, reverse=True) + return [item for item in ranked if score(item)[0] > 0] + + @staticmethod + def _extract_excerpt(content: str) -> str: + lines = [line.strip() for line in str(content).splitlines() if line.strip()] + cleaned: list[str] = [] + for line in lines: + normalized = re.sub(r"^[#>\-\*\d\.\s`]+", "", line).strip() + if normalized: + cleaned.append(normalized) + if len(cleaned) >= 2: + break + return ";".join(cleaned[:2]) diff --git a/server/tests/test_ontology_service.py b/server/tests/test_ontology_service.py new file mode 100644 index 0000000..6086771 --- /dev/null +++ b/server/tests/test_ontology_service.py @@ -0,0 +1,397 @@ +from __future__ import annotations + +from collections.abc import Generator + +import pytest +from fastapi.testclient import TestClient +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.pool import StaticPool + +from app.api.deps import get_db +from app.db.base import Base +from app.main import create_app +from app.schemas.ontology import OntologyParseRequest +from app.services.ontology import LlmOntologyParseResult, SemanticOntologyService + + +def build_session_factory() -> sessionmaker[Session]: + engine = create_engine( + "sqlite+pysqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + Base.metadata.create_all(bind=engine) + return sessionmaker(bind=engine, autoflush=False, autocommit=False) + + +def build_client() -> tuple[TestClient, sessionmaker[Session]]: + session_factory = build_session_factory() + app = create_app() + + def override_db() -> Generator[Session, None, None]: + db = session_factory() + try: + yield db + finally: + db.close() + + app.dependency_overrides[get_db] = override_db + return TestClient(app), session_factory + + +EVALUATION_CASES = [ + pytest.param( + "查一下本周报销超标风险", + "expense", + "risk_check", + "read", + {}, + id="expense-risk-check", + ), + pytest.param( + "张三 4 月差旅报销金额是多少", + "expense", + "query", + "read", + {}, + id="expense-query-employee-month", + ), + pytest.param( + "为什么酒店超标报销不能直接通过", + "expense", + "explain", + "read", + {}, + id="expense-explain-policy", + ), + pytest.param( + "列出金额最高的10笔报销", + "expense", + "query", + "read", + {}, + id="expense-topn-query", + ), + pytest.param( + "帮我生成张三4月差旅报销草稿", + "expense", + "draft", + "draft_write", + {}, + id="expense-draft", + ), + pytest.param( + "我今天去客户现场,招待了客户,花销了1000元", + "expense", + "draft", + "draft_write", + {}, + id="expense-narrative-draft", + ), + pytest.param( + "客户 A 这个月还有多少应收", + "accounts_receivable", + "query", + "read", + {}, + id="ar-query-customer-month", + ), + pytest.param( + "对比客户A和客户B本月应收差异", + "accounts_receivable", + "compare", + "read", + {}, + id="ar-compare-customers", + ), + pytest.param( + "检查客户B逾期应收风险", + "accounts_receivable", + "risk_check", + "read", + {}, + id="ar-risk-check", + ), + pytest.param( + "生成客户A回款跟进草稿", + "accounts_receivable", + "draft", + "draft_write", + {}, + id="ar-draft", + ), + pytest.param( + "查询客户B账龄明细", + "accounts_receivable", + "query", + "read", + {}, + id="ar-aging-query", + ), + pytest.param( + "供应商 B 明天要付多少钱", + "accounts_payable", + "query", + "read", + {}, + id="ap-query-vendor-tomorrow", + ), + pytest.param( + "对比供应商A和供应商B本月应付差异", + "accounts_payable", + "compare", + "read", + {}, + id="ap-compare-vendors", + ), + pytest.param( + "检查供应商B逾期付款风险", + "accounts_payable", + "risk_check", + "read", + {}, + id="ap-risk-check", + ), + pytest.param( + "生成供应商A付款沟通草稿", + "accounts_payable", + "draft", + "draft_write", + {}, + id="ap-draft", + ), + pytest.param( + "帮我安排付款给供应商B", + "accounts_payable", + "operate", + "approval_required", + {"role_codes": ["finance"]}, + id="ap-operate-approval-required", + ), + pytest.param( + "公司财务制度在哪里看", + "knowledge", + "query", + "read", + {}, + id="knowledge-query", + ), + pytest.param( + "规则中心的审核依据是什么", + "knowledge", + "explain", + "read", + {}, + id="knowledge-explain", + ), + pytest.param( + "知识库里有没有双人复核制度", + "knowledge", + "query", + "read", + {}, + id="knowledge-query-library", + ), + pytest.param( + "帮我直接付款给供应商B", + "accounts_payable", + "operate", + "forbidden", + {"role_codes": ["user"]}, + id="forbidden-direct-payment", + ), + pytest.param( + "帮我上线付款双人复核规则", + "accounts_payable", + "operate", + "forbidden", + {"role_codes": ["user"]}, + id="forbidden-activate-rule", + ), + pytest.param( + "帮我删除今天的报销记录", + "expense", + "operate", + "forbidden", + {"role_codes": ["user"]}, + id="forbidden-delete-expense", + ), +] + + +@pytest.mark.parametrize("query,scenario,intent,permission,context_json", EVALUATION_CASES) +def test_semantic_ontology_service_matches_day3_evaluation_set( + query: str, + scenario: str, + intent: str, + permission: str, + context_json: dict, +) -> None: + session_factory = build_session_factory() + with session_factory() as db: + result = SemanticOntologyService(db).parse( + OntologyParseRequest( + query=query, + user_id="pytest", + context_json=context_json, + ) + ) + + assert result.scenario == scenario + assert result.intent == intent + assert result.permission.level == permission + assert result.run_id.startswith("run_") + + +def test_semantic_ontology_service_extracts_entities_time_and_constraints() -> None: + session_factory = build_session_factory() + with session_factory() as db: + result = SemanticOntologyService(db).parse( + OntologyParseRequest( + query="张三 2026年4月差旅报销金额超过5000元的明细", + user_id="pytest", + ) + ) + + assert result.scenario == "expense" + assert result.intent == "query" + assert result.time_range.start_date == "2026-04-01" + assert result.time_range.end_date == "2026-04-30" + assert any( + item.type == "employee" and item.normalized_value == "张三" + for item in result.entities + ) + assert any( + item.type == "expense_type" and item.normalized_value == "travel" + for item in result.entities + ) + assert any( + item.field == "amount" and item.operator == ">" and item.value == 5000 + for item in result.constraints + ) + + +def test_semantic_ontology_service_prefers_expense_for_customer_entertainment_narrative() -> None: + session_factory = build_session_factory() + with session_factory() as db: + result = SemanticOntologyService(db).parse( + OntologyParseRequest( + query="我今天去客户现场,招待了客户,花销了1000元", + user_id="pytest", + ) + ) + + assert result.scenario == "expense" + assert result.intent == "draft" + assert result.permission.level == "draft_write" + assert result.time_range.raw == "今天" + assert result.clarification_required is True + assert "customer_name" in result.missing_slots + assert "participants" in result.missing_slots + assert any( + item.type == "expense_type" and item.normalized_value == "entertainment" + for item in result.entities + ) + + +def test_semantic_ontology_service_uses_model_parse_when_available(monkeypatch) -> None: + session_factory = build_session_factory() + with session_factory() as db: + service = SemanticOntologyService(db) + monkeypatch.setattr( + service, + "_parse_with_model", + lambda **kwargs: LlmOntologyParseResult( + scenario="expense", + intent="draft", + confidence=0.91, + clarification_required=True, + clarification_question="请补充费用类型、金额和票据附件。", + missing_slots=["expense_type", "amount", "attachments"], + ambiguity=[], + entity_hints=[], + ), + ) + + result = service.parse( + OntologyParseRequest( + query="我要报销", + user_id="pytest", + ) + ) + + assert result.scenario == "expense" + assert result.intent == "draft" + assert result.parse_strategy == "llm_primary" + assert result.clarification_required is True + assert "expense_type" in result.missing_slots + assert result.clarification_question == "请补充费用类型、金额和票据附件。" + + +def test_parse_ontology_endpoint_returns_eight_fields_and_writes_trace() -> None: + client, _ = build_client() + + response = client.post( + "/api/v1/ontology/parse", + json={ + "query": "查一下本周报销超标风险", + "user_id": "pytest", + "context_json": {"role_codes": ["finance"]}, + }, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["scenario"] == "expense" + assert payload["intent"] == "risk_check" + assert payload["permission"]["level"] == "read" + assert payload["run_id"].startswith("run_") + assert set(payload) >= { + "scenario", + "intent", + "entities", + "time_range", + "metrics", + "constraints", + "risk_flags", + "permission", + "confidence", + "missing_slots", + "ambiguity", + "parse_strategy", + "clarification_required", + "clarification_question", + "run_id", + "field_errors", + } + + run_response = client.get(f"/api/v1/agent-runs/{payload['run_id']}") + + assert run_response.status_code == 200 + run_payload = run_response.json() + assert run_payload["ontology_json"]["scenario"] == "expense" + assert run_payload["ontology_json"]["intent"] == "risk_check" + assert run_payload["semantic_parse"]["scenario"] == "expense" + assert run_payload["semantic_parse"]["intent"] == "risk_check" + + +def test_parse_ontology_endpoint_returns_forbidden_for_unprivileged_payment_request() -> None: + client, _ = build_client() + + response = client.post( + "/api/v1/ontology/parse", + json={ + "query": "帮我直接付款给供应商B", + "user_id": "pytest", + "context_json": {"role_codes": ["user"]}, + }, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["scenario"] == "accounts_payable" + assert payload["intent"] == "operate" + assert payload["permission"]["level"] == "forbidden" + assert payload["clarification_required"] is True + assert payload["field_errors"] diff --git a/server/tests/test_orchestrator_service.py b/server/tests/test_orchestrator_service.py new file mode 100644 index 0000000..14d440f --- /dev/null +++ b/server/tests/test_orchestrator_service.py @@ -0,0 +1,241 @@ +from __future__ import annotations + +from collections.abc import Generator + +from fastapi.testclient import TestClient +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.pool import StaticPool + +from app.api.deps import get_db +from app.db.base import Base +from app.main import create_app +from app.services.agent_assets import AgentAssetService + + +def build_client() -> tuple[TestClient, sessionmaker[Session]]: + engine = create_engine( + "sqlite+pysqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + Base.metadata.create_all(bind=engine) + session_factory = sessionmaker(bind=engine, autoflush=False, autocommit=False) + app = create_app() + + def override_db() -> Generator[Session, None, None]: + db = session_factory() + try: + yield db + finally: + db.close() + + app.dependency_overrides[get_db] = override_db + return TestClient(app), session_factory + + +def test_orchestrator_routes_user_query_to_user_agent() -> None: + client, _ = build_client() + + response = client.post( + "/api/v1/orchestrator/run", + json={ + "source": "user_message", + "user_id": "pytest", + "message": "客户A这个月还有多少应收", + "context_json": {"role_codes": ["finance"]}, + }, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["selected_agent"] == "user_agent" + assert payload["permission_level"] == "read" + assert payload["status"] == "succeeded" + assert payload["result"]["answer"] + assert payload["result"]["suggested_actions"] + assert payload["trace_summary"]["tool_count"] >= 1 + + run_detail = client.get(f"/api/v1/agent-runs/{payload['run_id']}").json() + assert run_detail["agent"] == "user_agent" + assert run_detail["route_json"]["selected_agent"] == "user_agent" + assert run_detail["semantic_parse"]["scenario"] == "accounts_receivable" + assert run_detail["tool_calls"][0]["tool_type"] == "database" + + +def test_orchestrator_routes_schedule_to_hermes() -> None: + client, session_factory = build_client() + + with session_factory() as db: + task = next( + item + for item in AgentAssetService(db).list_assets(asset_type="task", status="active") + if item.code == "task.hermes.daily_risk_scan" + ) + + response = client.post( + "/api/v1/orchestrator/run", + json={ + "source": "schedule", + "task_id": task.id, + "context_json": {"role_codes": ["finance"]}, + }, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["selected_agent"] == "hermes" + assert payload["status"] == "succeeded" + assert payload["trace_summary"]["tool_count"] == 2 + + run_detail = client.get(f"/api/v1/agent-runs/{payload['run_id']}").json() + assert run_detail["agent"] == "hermes" + assert run_detail["route_json"]["selected_agent"] == "hermes" + assert len(run_detail["tool_calls"]) == 2 + + +def test_orchestrator_forbidden_request_does_not_call_downstream_agent() -> None: + client, _ = build_client() + + response = client.post( + "/api/v1/orchestrator/run", + json={ + "source": "user_message", + "user_id": "pytest", + "message": "帮我直接付款给供应商B", + "context_json": {"role_codes": ["user"]}, + }, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["selected_agent"] is None + assert payload["permission_level"] == "forbidden" + assert payload["status"] == "blocked" + assert payload["trace_summary"]["tool_count"] == 0 + + run_detail = client.get(f"/api/v1/agent-runs/{payload['run_id']}").json() + assert run_detail["agent"] == "orchestrator" + assert run_detail["tool_calls"] == [] + + +def test_orchestrator_approval_required_returns_confirmation_result() -> None: + client, _ = build_client() + + response = client.post( + "/api/v1/orchestrator/run", + json={ + "source": "user_message", + "user_id": "pytest", + "message": "帮我安排付款给供应商B", + "context_json": {"role_codes": ["finance"]}, + }, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["selected_agent"] == "user_agent" + assert payload["permission_level"] == "approval_required" + assert payload["requires_confirmation"] is True + assert payload["status"] == "blocked" + assert "确认" in payload["result"]["message"] + + +def test_orchestrator_user_agent_draft_returns_structured_payload() -> None: + client, _ = build_client() + + response = client.post( + "/api/v1/orchestrator/run", + json={ + "source": "user_message", + "user_id": "pytest", + "message": "帮我生成张三4月差旅报销草稿", + "context_json": {"role_codes": ["finance"]}, + }, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["selected_agent"] == "user_agent" + assert payload["status"] == "succeeded" + assert payload["result"]["draft_payload"]["confirmation_required"] is True + assert payload["result"]["suggested_actions"] + + +def test_orchestrator_treats_expense_narrative_as_draft_instead_of_ar_query() -> None: + client, _ = build_client() + + response = client.post( + "/api/v1/orchestrator/run", + json={ + "source": "user_message", + "user_id": "pytest", + "message": "我今天去客户现场,招待了客户,花销了1000元", + "context_json": {"role_codes": ["finance"]}, + }, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["selected_agent"] == "user_agent" + assert payload["permission_level"] == "draft_write" + assert payload["status"] == "blocked" + assert payload["route_reason"] == "clarification_required" + assert payload["trace_summary"]["scenario"] == "expense" + assert payload["trace_summary"]["intent"] == "draft" + assert payload["trace_summary"]["tool_count"] == 0 + assert "应收场景数据" not in payload["result"]["message"] + assert "请补充" in payload["result"]["message"] + + +def test_orchestrator_tool_failure_is_logged_and_degraded() -> None: + client, _ = build_client() + + response = client.post( + "/api/v1/orchestrator/run", + json={ + "source": "user_message", + "user_id": "pytest", + "message": "查一下本周报销金额", + "context_json": { + "role_codes": ["finance"], + "simulate_tool_failure": "database", + }, + }, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["selected_agent"] == "user_agent" + assert payload["status"] == "succeeded" + assert payload["trace_summary"]["failed_tool_count"] == 1 + assert payload["trace_summary"]["degraded"] is True + + run_detail = client.get(f"/api/v1/agent-runs/{payload['run_id']}").json() + assert run_detail["tool_calls"][0]["status"] == "failed" + assert "simulated database failure" in run_detail["tool_calls"][0]["error_message"] + + +def test_orchestrator_exception_is_written_to_agent_run() -> None: + client, _ = build_client() + + response = client.post( + "/api/v1/orchestrator/run", + json={ + "source": "user_message", + "user_id": "pytest", + "message": "查一下本周报销金额", + "context_json": { + "role_codes": ["finance"], + "simulate_orchestrator_exception": True, + }, + }, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["status"] == "failed" + + run_detail = client.get(f"/api/v1/agent-runs/{payload['run_id']}").json() + assert run_detail["status"] == "failed" + assert "simulated orchestrator exception" in run_detail["error_message"] diff --git a/server/tests/test_user_agent_service.py b/server/tests/test_user_agent_service.py new file mode 100644 index 0000000..3c838ed --- /dev/null +++ b/server/tests/test_user_agent_service.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.pool import StaticPool + +from app.db.base import Base +from app.schemas.ontology import OntologyParseRequest +from app.schemas.user_agent import UserAgentRequest +from app.services.ontology import SemanticOntologyService +from app.services.user_agent import UserAgentService + + +def build_session_factory() -> sessionmaker[Session]: + engine = create_engine( + "sqlite+pysqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + Base.metadata.create_all(bind=engine) + return sessionmaker(bind=engine, autoflush=False, autocommit=False) + + +def test_user_agent_query_returns_readable_answer_and_actions() -> None: + session_factory = build_session_factory() + with session_factory() as db: + ontology = SemanticOntologyService(db).parse( + OntologyParseRequest( + query="张三 4 月差旅报销金额是多少", + user_id="pytest", + ) + ) + response = UserAgentService(db).respond( + UserAgentRequest( + run_id=ontology.run_id, + user_id="pytest", + message="张三 4 月差旅报销金额是多少", + ontology=ontology, + tool_payload={"record_count": 2, "total_amount": 8800.0}, + ) + ) + + assert "8800.00" in response.answer + assert len(response.suggested_actions) >= 1 + + +def test_user_agent_prefers_runtime_model_answer_when_available(monkeypatch) -> None: + session_factory = build_session_factory() + with session_factory() as db: + ontology = SemanticOntologyService(db).parse( + OntologyParseRequest( + query="张三 4 月差旅报销金额是多少", + user_id="pytest", + ) + ) + service = UserAgentService(db) + monkeypatch.setattr( + service, + "_generate_answer_with_model", + lambda *args, **kwargs: "这是模型回答", + ) + + response = service.respond( + UserAgentRequest( + run_id=ontology.run_id, + user_id="pytest", + message="张三 4 月差旅报销金额是多少", + ontology=ontology, + tool_payload={"record_count": 2, "total_amount": 8800.0}, + ) + ) + + assert response.answer == "这是模型回答" + + +def test_user_agent_sanitizes_model_thinking_blocks() -> None: + session_factory = build_session_factory() + with session_factory() as db: + service = UserAgentService(db) + + assert ( + service._sanitize_model_answer("内部推理\n最终答复") + == "最终答复" + ) + + +def test_user_agent_guides_generic_expense_request() -> None: + session_factory = build_session_factory() + with session_factory() as db: + ontology = SemanticOntologyService(db).parse( + OntologyParseRequest( + query="我要报销", + user_id="pytest", + ) + ) + response = UserAgentService(db).respond( + UserAgentRequest( + run_id=ontology.run_id, + user_id="pytest", + message="我要报销", + ontology=ontology, + tool_payload={"record_count": 9, "total_amount": 12345.0}, + ) + ) + + assert "补充费用类型" in response.answer + assert "上传票据" in response.answer + + +def test_user_agent_guides_implicit_expense_draft_request() -> None: + session_factory = build_session_factory() + with session_factory() as db: + ontology = SemanticOntologyService(db).parse( + OntologyParseRequest( + query="我今天去客户现场,招待了客户,花销了1000元", + user_id="pytest", + ) + ) + response = UserAgentService(db).respond( + UserAgentRequest( + run_id=ontology.run_id, + user_id="pytest", + message="我今天去客户现场,招待了客户,花销了1000元", + ontology=ontology, + tool_payload={"draft_only": True}, + ) + ) + + assert "1000元" in response.answer + assert "票据附件" in response.answer + assert "报销草稿" in response.answer + + +def test_user_agent_risk_response_includes_rule_citations() -> None: + session_factory = build_session_factory() + with session_factory() as db: + ontology = SemanticOntologyService(db).parse( + OntologyParseRequest( + query="检查重复报销风险", + user_id="pytest", + ) + ) + response = UserAgentService(db).respond( + UserAgentRequest( + run_id=ontology.run_id, + user_id="pytest", + message="检查重复报销风险", + ontology=ontology, + tool_payload={"risk_flags": ["duplicate_expense"]}, + ) + ) + + assert response.risk_flags == ["duplicate_expense"] + assert any(item.source_type == "rule" for item in response.citations) + assert "duplicate_expense" in response.answer + + +def test_user_agent_draft_returns_structured_payload() -> None: + session_factory = build_session_factory() + with session_factory() as db: + ontology = SemanticOntologyService(db).parse( + OntologyParseRequest( + query="帮我生成张三4月差旅报销草稿", + user_id="pytest", + ) + ) + response = UserAgentService(db).respond( + UserAgentRequest( + run_id=ontology.run_id, + user_id="pytest", + message="帮我生成张三4月差旅报销草稿", + ontology=ontology, + tool_payload={"draft_only": True}, + ) + ) + + assert response.draft_payload is not None + assert response.draft_payload.confirmation_required is True + assert "待人工确认" in response.answer