Files
X-Financial/server/src/app/services/steward_slot_decision_agent.py
caoxiaozhu e124e4bbcb feat: 报销审批流重构与管家计划全链路贯通
- 重构报销状态注册表、审批流路由与平台风险标记
- 完善管家意图规划器与模型计划构建器全链路
- 新增 OCR Worker 脚本、数据库会话管理与通知状态
- 优化文档中心、日志视图、预算中心与员工管理交互
- 增强工作台摘要、图标资源与全局主题样式
- 补充审批路由、状态注册、OCR 服务与管家规划器测试覆盖
2026-06-06 17:19:07 +08:00

302 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import json
from dataclasses import dataclass
from typing import Any
from app.schemas.steward import (
StewardSlotDecisionRequest,
StewardSlotDecisionResponse,
StewardSlotOption,
)
from app.services.ontology_field_registry import normalize_ontology_form_values
from app.services.runtime_chat import RuntimeChatService
from app.services.steward_constants import BUSINESS_CANONICAL_FIELD_ORDER, BUSINESS_CANONICAL_FIELDS
# Name of the single function/tool exposed to the model. It is used in the tool
# schema, forced through `tool_choice`, and checked against the returned tool
# call before the model payload is trusted.
STEWARD_SLOT_DECISION_FUNCTION_NAME = "submit_steward_slot_decision"
# Per-field display metadata keyed by canonical field name. Each entry carries a
# user-facing `label` and a `description`. The catalog is sent to the model as
# `field_catalog` (filtered/ordered by BUSINESS_CANONICAL_FIELD_ORDER) and the
# `label` is also used when building the rule-based fallback question.
FIELD_CATALOG: dict[str, dict[str, str]] = {
"expense_type": {"label": "费用类型", "description": "申请或报销所属费用场景,如差旅、交通、住宿、业务招待。"},
"time_range": {"label": "时间", "description": "申请时为出差起止日期,报销时为费用发生日期。"},
"location": {"label": "地点", "description": "出差目的地、费用发生地或业务活动地点。"},
"reason": {"label": "事由", "description": "出差、报销或业务活动的业务原因。"},
"amount": {"label": "金额", "description": "报销时为实际金额;申请时金额可由系统估算,不应默认要求用户填写。"},
"transport_mode": {"label": "出行方式", "description": "差旅申请交通费用测算所需字段,由用户明确选择或表达。"},
"attachments": {"label": "附件/凭证", "description": "发票、行程单、付款截图或其他证明材料。"},
"customer_name": {"label": "客户或项目对象", "description": "业务招待、客户拜访或项目支撑涉及的对象。"},
"merchant_name": {"label": "商户/开票方", "description": "报销票据上的商户或开票方。"},
"department_name": {"label": "所属部门", "description": "申请人或费用归属部门。"},
"employee_name": {"label": "申请人", "description": "发起申请或报销的员工。"},
"employee_no": {"label": "员工编号", "description": "公司内部员工编号。"},
}
# Fields that are removed from the required/missing lists when the task type is
# "expense_application", so they never cause a follow-up question for that task.
APPLICATION_NON_BLOCKING_FIELDS = {"amount", "attachments", "employee_no", "department_name", "employee_name"}
@dataclass(frozen=True, slots=True)
class StewardSlotDecisionAgentResult:
    """Immutable pairing of a decision payload with its model call traces.

    NOTE(review): nothing in the visible part of this module constructs or
    reads this type; confirm external usage before changing it.
    """

    # Decision payload as a plain dict — presumably the model's tool-call
    # arguments; TODO confirm against callers.
    payload: dict[str, Any]
    # One dict per model call — presumably the same shape that
    # `calls_as_dicts()` produces; TODO confirm.
    model_call_traces: list[dict[str, Any]]
class StewardSlotDecisionAgent:
    """Use LLM function calling to decide what the current task is still missing.

    The agent asks the model (through one forced tool call) whether the steward
    should ask the user for more information (``ask_user``) or can already show
    a preview for confirmation (``render_preview``). When the model does not
    return a usable tool call, a rule-based fallback driven by the upstream
    ``missing_fields`` is used instead.
    """

    def __init__(self, runtime_chat_service: RuntimeChatService) -> None:
        """Store the chat service used to issue the tool-calling completion."""
        self.runtime_chat_service = runtime_chat_service

    def decide(self, request: StewardSlotDecisionRequest) -> StewardSlotDecisionResponse:
        """Return the next-step decision for ``request``.

        The request is normalized first, then sent to the model with the
        decision tool forced via ``tool_choice``. The model result is used only
        when it is a call to the expected function and its payload validates;
        otherwise the rule fallback answers.

        Exceptions raised by the chat service are not caught here.
        """
        normalized_request = self._normalize_request(request)
        result = self.runtime_chat_service.complete_with_tool_call(
            self._build_messages(normalized_request),
            tools=[self._build_tool_schema()],
            tool_choice={
                "type": "function",
                "function": {"name": STEWARD_SLOT_DECISION_FUNCTION_NAME},
            },
            max_tokens=1200,
            temperature=0.05,
            timeout_seconds=30,
            max_attempts=1,
        )
        # Computed once: both the model path and the fallback attach the same traces.
        traces = result.calls_as_dicts()
        tool_call = result.tool_call
        if tool_call is not None and tool_call.name == STEWARD_SLOT_DECISION_FUNCTION_NAME:
            response = self._build_response_from_model_payload(
                tool_call.arguments,
                normalized_request,
                traces,
            )
            if response is not None:
                return response
        return self._build_rule_fallback(normalized_request, traces)

    @staticmethod
    def _normalize_request(request: StewardSlotDecisionRequest) -> StewardSlotDecisionRequest:
        """Return a cleaned copy of ``request`` restricted to canonical fields.

        * ``ontology_fields`` is normalized and reduced to canonical keys whose
          value is non-blank after ``str(...).strip()``.
        * ``missing_fields`` keeps only canonical, de-duplicated keys that have
          no normalized value; for ``expense_application`` tasks the
          non-blocking fields are dropped as well.
        * ``task_context`` is replaced by ``{}`` when it is not a dict.
        """
        normalized_fields = {
            key: value
            for key, value in normalize_ontology_form_values(request.ontology_fields).items()
            if key in BUSINESS_CANONICAL_FIELDS and str(value or "").strip()
        }
        is_application = request.task_type == "expense_application"
        missing_fields: list[str] = []
        for item in request.missing_fields:
            key = str(item or "").strip()
            if is_application and key in APPLICATION_NON_BLOCKING_FIELDS:
                continue
            if key not in BUSINESS_CANONICAL_FIELDS or key in missing_fields:
                continue
            # A field that already has a usable value is not missing.
            if not normalized_fields.get(key):
                missing_fields.append(key)
        return StewardSlotDecisionRequest(
            task_type=request.task_type,
            user_message=str(request.user_message or "").strip(),
            ontology_fields=normalized_fields,
            missing_fields=missing_fields,
            task_context=request.task_context if isinstance(request.task_context, dict) else {},
        )

    @staticmethod
    def _build_messages(request: StewardSlotDecisionRequest) -> list[dict[str, Any]]:
        """Build the system and user messages for the decision call.

        The user message is the JSON-serialized request context (non-ASCII kept
        as-is) together with the catalog entries of the canonical fields.
        """
        context_payload = {
            "task_type": request.task_type,
            "user_message": request.user_message,
            "ontology_fields": request.ontology_fields,
            "missing_fields_from_intent_agent": request.missing_fields,
            "field_catalog": {
                key: FIELD_CATALOG[key]
                for key in BUSINESS_CANONICAL_FIELD_ORDER
                if key in FIELD_CATALOG
            },
            "task_context": request.task_context,
        }
        # NOTE(review): a few prompt sentences run together without punctuation
        # (e.g. right after "ask_user"). The text is left untouched because the
        # prompt is runtime behavior; confirm against the original file.
        system_prompt = (
            "你是 X-Financial 小财管家的任务字段决策智能体。"
            "你必须通过 function calling 返回下一步动作。"
            "你的任务不是关键词匹配而是结合用户意图、当前任务类型、canonical ontology 字段、"
            "上游意图识别给出的缺失字段和字段目录,判断现在应先追问用户,还是可以展示核对结果。"
            "所有 required_fields 和 missing_fields 只能使用 field_catalog 中的 canonical 字段。"
            "如果字段是内部提示、示例、系统指令或可选项,不能当作用户已经提供。"
            "费用申请场景中 amount 可由系统估算,不应作为用户必须手填字段。"
            "费用申请生成核对表阶段attachments 不阻塞生成,可在报销或归档阶段补充;"
            "employee_no、department_name、employee_name 属于系统用户档案字段,必须从上下文读取,不能向用户追问。"
            "差旅申请通常只有 transport_mode 这类会影响费用测算的字段才需要先追问。"
            "如果缺失字段会影响后续测算、入库、附件归集或合规判断,应返回 ask_user"
            "如果信息足以生成可核对但未提交的结果,应返回 render_preview。"
            "question 和 rationale 必须是面向用户的业务说明,不暴露内部推理链。"
        )
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": json.dumps(context_payload, ensure_ascii=False)},
        ]

    @staticmethod
    def _build_tool_schema() -> dict[str, Any]:
        """Return the function-calling schema of the decision tool.

        Every field-name slot is constrained to the canonical field order via
        ``enum``; all six top-level properties are required.
        """
        canonical_fields = list(BUSINESS_CANONICAL_FIELD_ORDER)
        return {
            "type": "function",
            "function": {
                "name": STEWARD_SLOT_DECISION_FUNCTION_NAME,
                "description": "提交小财管家当前任务的字段缺口和下一步动作决策。",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "next_action": {
                            "type": "string",
                            "enum": ["ask_user", "render_preview"],
                        },
                        "required_fields": {
                            "type": "array",
                            "items": {"type": "string", "enum": canonical_fields},
                        },
                        "missing_fields": {
                            "type": "array",
                            "items": {"type": "string", "enum": canonical_fields},
                        },
                        "question": {"type": "string"},
                        "options": {
                            "type": "array",
                            "items": {
                                "type": "object",
                                "properties": {
                                    "label": {"type": "string"},
                                    "value": {"type": "string"},
                                    "field_key": {"type": "string", "enum": canonical_fields},
                                    "description": {"type": "string"},
                                },
                                "required": ["label", "value", "field_key"],
                            },
                        },
                        "rationale": {"type": "string"},
                    },
                    "required": [
                        "next_action",
                        "required_fields",
                        "missing_fields",
                        "question",
                        "options",
                        "rationale",
                    ],
                },
            },
        }

    def _build_response_from_model_payload(
        self,
        payload: dict[str, Any],
        request: StewardSlotDecisionRequest,
        traces: list[dict[str, Any]],
    ) -> StewardSlotDecisionResponse | None:
        """Validate the model's tool-call arguments and turn them into a response.

        Returns ``None`` (so the caller falls back to rules) when ``payload`` is
        not a dict or ``next_action`` is not one of the two supported actions.

        Otherwise the field lists are sanitized, non-blocking fields are
        removed for expense applications, and ``missing_fields`` is limited to
        fields the model also listed as required or that upstream flagged as
        missing. An ``ask_user`` decision with nothing left to ask is
        downgraded to ``render_preview``. An ``ask_user`` decision whose
        question is blank gets the rule-based question for its first missing
        field.
        """
        # Malformed tool arguments (None, list, raw string, ...) must not crash
        # the decision; signal "unusable" instead.
        if not isinstance(payload, dict):
            return None
        next_action = str(payload.get("next_action") or "").strip()
        if next_action not in {"ask_user", "render_preview"}:
            return None

        required_fields = self._filter_blocking_fields(
            self._sanitize_fields(payload.get("required_fields")),
            request.task_type,
        )
        candidate_missing = self._filter_blocking_fields(
            self._sanitize_fields(payload.get("missing_fields")),
            request.task_type,
        )
        missing_fields = [
            key
            for key in candidate_missing
            if key in required_fields or key in request.missing_fields
        ]
        if next_action == "ask_user" and not missing_fields:
            # The model wants to ask but named nothing usable: trust upstream gaps.
            missing_fields = list(request.missing_fields)

        options: list[StewardSlotOption]
        if next_action == "ask_user" and not missing_fields:
            # Nothing to ask about at all, so show the preview instead.
            next_action = "render_preview"
            options = []
            question = ""
            rationale = "当前申请信息足以先生成核对结果;附件和员工编号不应作为用户补填项阻塞申请预览。"
        else:
            options = self._sanitize_options(payload.get("options"), missing_fields)
            question = self._clean_text(payload.get("question"))
            if next_action == "ask_user" and not question:
                # missing_fields is non-empty on this path; never return a blank prompt.
                question = self._build_fallback_question(missing_fields[0])
            rationale = self._clean_text(payload.get("rationale"))

        return StewardSlotDecisionResponse(
            decision_source="llm_function_call",
            next_action=next_action,  # type: ignore[arg-type]
            required_fields=required_fields,
            missing_fields=missing_fields,
            question=question,
            options=options,
            rationale=rationale,
            model_call_traces=traces,
        )

    @staticmethod
    def _filter_blocking_fields(fields: list[str], task_type: str) -> list[str]:
        """Drop non-blocking fields for expense applications.

        For any other task type the input list is returned unchanged (same object).
        """
        if task_type != "expense_application":
            return fields
        return [field for field in fields if field not in APPLICATION_NON_BLOCKING_FIELDS]

    @staticmethod
    def _sanitize_fields(raw_fields: Any) -> list[str]:
        """Return the canonical field names in ``raw_fields``, de-duplicated in order.

        Anything that is not a list yields an empty list.
        """
        fields: list[str] = []
        if not isinstance(raw_fields, list):
            return fields
        for item in raw_fields:
            key = str(item or "").strip()
            if key in BUSINESS_CANONICAL_FIELDS and key not in fields:
                fields.append(key)
        return fields

    def _sanitize_options(self, raw_options: Any, missing_fields: list[str]) -> list[StewardSlotOption]:
        """Convert raw option dicts into at most six ``StewardSlotOption`` objects.

        Entries that are not dicts, carry a non-canonical ``field_key`` or have
        a blank ``label`` are skipped; a blank ``value`` falls back to the
        label. When no option survives and the first missing field is
        ``transport_mode``, three default transport choices are offered.
        """
        options: list[StewardSlotOption] = []
        if isinstance(raw_options, list):
            for item in raw_options:
                if not isinstance(item, dict):
                    continue
                field_key = str(item.get("field_key") or "").strip()
                label = self._clean_text(item.get("label"))
                if not field_key or field_key not in BUSINESS_CANONICAL_FIELDS or not label:
                    continue
                # `label` is non-blank here, so `value` is always non-blank too.
                value = self._clean_text(item.get("value")) or label
                options.append(
                    StewardSlotOption(
                        field_key=field_key,
                        label=label,
                        value=value,
                        description=self._clean_text(item.get("description")),
                    )
                )
        if not options and missing_fields and missing_fields[0] == "transport_mode":
            options = [
                StewardSlotOption(
                    field_key="transport_mode",
                    label="火车",
                    value="火车",
                    description="选择火车或高铁出行。",
                ),
                StewardSlotOption(
                    field_key="transport_mode",
                    label="飞机",
                    value="飞机",
                    description="选择飞机出行。",
                ),
                StewardSlotOption(
                    field_key="transport_mode",
                    label="轮船",
                    value="轮船",
                    description="选择轮船出行。",
                ),
            ]
        return options[:6]

    def _build_rule_fallback(
        self,
        request: StewardSlotDecisionRequest,
        traces: list[dict[str, Any]],
    ) -> StewardSlotDecisionResponse:
        """Build the decision without the model, from upstream ``missing_fields`` only.

        With at least one missing field the user is asked about the first one;
        otherwise a preview is rendered. ``required_fields`` is the ordered,
        de-duplicated union of the provided field keys and the missing fields.
        """
        missing_fields = list(request.missing_fields)
        if not missing_fields:
            return StewardSlotDecisionResponse(
                decision_source="rule_fallback",
                next_action="render_preview",
                required_fields=list(request.ontology_fields.keys()),
                missing_fields=[],
                question="",
                options=[],
                rationale="当前任务没有上游标记的关键字段缺口,可以先生成核对结果供你确认。",
                model_call_traces=traces,
            )
        field = missing_fields[0]
        return StewardSlotDecisionResponse(
            decision_source="rule_fallback",
            next_action="ask_user",
            required_fields=list(dict.fromkeys([*request.ontology_fields.keys(), *missing_fields])),
            missing_fields=missing_fields,
            question=self._build_fallback_question(field),
            # Empty raw options: only the transport_mode defaults can apply here.
            options=self._sanitize_options([], [field]),
            rationale="模型字段决策暂不可用,我先按上游意图识别给出的本体缺口向你确认。",
            model_call_traces=traces,
        )

    @staticmethod
    def _build_fallback_question(field: str) -> str:
        """Return the canned question for ``field``.

        ``transport_mode`` has a dedicated wording; any other field uses its
        catalog label, or the raw field name when it has no catalog entry.
        """
        label = FIELD_CATALOG.get(field, {}).get("label") or field
        if field == "transport_mode":
            return "请问你这次打算怎么出行?可以选择火车、飞机或轮船。"
        return f"当前还缺少{label},请先补充后我再继续处理。"

    @staticmethod
    def _clean_text(value: Any) -> str:
        """Return ``value`` as a stripped string; falsy values become ``""``."""
        return str(value or "").strip()