feat: 报销预审会话状态管理与工作台交互增强

- 新增差旅报销会话状态管理与对话模型重构
- 增强风险观测服务与运行时聊天上下文作用域
- 优化工作台图标资源、助理意图识别与摘要工具
- 完善报销创建视图样式与差旅详情页标准调整交互
- 补充风险观测、运行时聊天与报销端点测试覆盖
This commit is contained in:
caoxiaozhu
2026-06-04 11:03:29 +08:00
parent 87da5df91b
commit 1cbf3fee44
60 changed files with 4156 additions and 393 deletions

View File

@@ -0,0 +1,78 @@
from __future__ import annotations
import asyncio
import json
from collections.abc import AsyncIterator
from typing import Annotated, Any
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from app.api.deps import get_db
from app.schemas.common import ErrorResponse
from app.schemas.steward import StewardPlanRequest, StewardPlanResponse
from app.services.runtime_chat import RuntimeChatService
from app.services.steward_intent_agent import StewardIntentAgent
from app.services.steward_planner import StewardPlannerService
# Router for the steward planning endpoints; every route registered on it is served under the /steward prefix.
router = APIRouter(prefix="/steward")
# SQLAlchemy Session parameter type that FastAPI resolves through the get_db dependency.
DbSession = Annotated[Session, Depends(get_db)]
@router.post(
    "/plans",
    response_model=StewardPlanResponse,
    summary="生成小财管家任务计划",
    description="把首页自然语言和附件元信息拆解为可确认、可追踪、可分派的财务任务计划。",
    responses={
        status.HTTP_400_BAD_REQUEST: {
            "model": ErrorResponse,
            "description": "请求缺少任务描述,无法生成小财管家计划。",
        }
    },
)
def create_steward_plan(payload: StewardPlanRequest, db: DbSession) -> StewardPlanResponse:
    """Build a steward task plan and return it in a single response.

    Args:
        payload: The user's message, attachments and caller context.
        db: Session handed to the planner's runtime chat service.

    Returns:
        The plan produced by the steward planner.

    Raises:
        HTTPException: 400 with the planner's message when building the
            planner or the plan raises ``ValueError``.
    """
    try:
        planner = _build_steward_planner(db)
        plan = planner.build_plan(payload)
    except ValueError as planning_error:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(planning_error),
        ) from planning_error
    return plan
@router.post(
    "/plans/stream",
    summary="流式生成小财管家任务计划",
    description="以 NDJSON 逐条返回小财管家的过程摘要事件,最后返回完整任务计划。",
)
async def stream_steward_plan(payload: StewardPlanRequest, db: DbSession) -> StreamingResponse:
    """Stream the steward plan as newline-delimited JSON events.

    The planner is constructed up front; the plan itself is produced lazily
    by the event generator while the response body is being consumed.
    """
    planner = _build_steward_planner(db)
    event_stream = _iter_steward_plan_events(payload, planner)
    return StreamingResponse(event_stream, media_type="application/x-ndjson")
async def _iter_steward_plan_events(
    payload: StewardPlanRequest,
    planner: StewardPlannerService,
) -> AsyncIterator[str]:
    """Yield the steward plan as NDJSON lines.

    Emits one ``thinking`` event per entry in ``plan.thinking_events``,
    followed by a single ``plan`` event carrying the full plan. If the
    planner raises ``ValueError`` a single ``error`` event is emitted
    instead and the stream ends.

    Args:
        payload: The plan request forwarded to the planner.
        planner: Planner whose synchronous ``build_plan`` produces the plan.

    Yields:
        One JSON document per line, each terminated by a newline.
    """
    try:
        # build_plan is a synchronous call; running it directly inside this
        # coroutine would stall the event loop (and every other request on
        # it) for as long as planning takes, so hand it to a worker thread.
        plan = await asyncio.to_thread(planner.build_plan, payload)
    except ValueError as exc:
        yield _encode_stream_event("error", {"message": str(exc)})
        return
    for event in plan.thinking_events:
        yield _encode_stream_event("thinking", event.model_dump(mode="json"))
        # Short pacing delay between consecutive thinking events.
        await asyncio.sleep(0.18)
    yield _encode_stream_event("plan", plan.model_dump(mode="json"))
def _encode_stream_event(event: str, data: dict[str, Any]) -> str:
return json.dumps({"event": event, "data": data}, ensure_ascii=False) + "\n"
def _build_steward_planner(db: Session) -> StewardPlannerService:
    """Wire a planner whose intent agent talks to the runtime chat service on ``db``."""
    chat_service = RuntimeChatService(db)
    intent_agent = StewardIntentAgent(chat_service)
    return StewardPlannerService(intent_agent=intent_agent)

View File

@@ -22,6 +22,7 @@ from app.api.v1.endpoints.receipt_folder import router as receipt_folder_router
from app.api.v1.endpoints.reimbursements import router as reimbursements_router
from app.api.v1.endpoints.risk_observations import router as risk_observations_router
from app.api.v1.endpoints.settings import router as settings_router
from app.api.v1.endpoints.steward import router as steward_router
from app.api.v1.endpoints.system_logs import router as system_logs_router
router = APIRouter()
@@ -47,4 +48,5 @@ router.include_router(employee_profiles_router, tags=["employee-profiles"])
router.include_router(reimbursements_router, prefix="/reimbursements", tags=["reimbursements"])
router.include_router(risk_observations_router, tags=["risk-observations"])
router.include_router(settings_router, tags=["settings"])
router.include_router(steward_router, tags=["steward"])
router.include_router(system_logs_router, tags=["system-logs"])

View File

@@ -126,6 +126,7 @@ class ExpenseClaimStandardAdjustmentRisk(BaseModel):
item_id: str | None = Field(default=None, max_length=120)
title: str | None = Field(default=None, max_length=120)
risk: str | None = Field(default=None, max_length=500)
application_days: int | None = Field(default=None, ge=1, le=365)
original_amount: Decimal | None = None
reimbursable_amount: Decimal | None = None

View File

@@ -0,0 +1,90 @@
from __future__ import annotations
from typing import Any, Literal
from pydantic import BaseModel, Field
# Task kinds a steward plan may contain.
StewardTaskType = Literal["expense_application", "reimbursement"]
# Downstream assistants a task can be assigned to.
StewardAssignedAgent = Literal["application_assistant", "reimbursement_assistant"]
# How a plan was produced: from a model function call, or from the rule-based fallback.
StewardPlanningSource = Literal["llm_function_call", "rule_fallback"]
# Status values allowed on a single steward task.
StewardTaskStatus = Literal[
"planned",
"needs_confirmation",
"ready_to_delegate",
"delegated",
"completed",
"blocked",
]
# Status values allowed on a confirmation action.
StewardConfirmationStatus = Literal["pending", "confirmed", "rejected"]
# Metadata for one attachment sent with a plan request: file name, MIME type and optional OCR output.
# Only `name` is required; the OCR fields default to empty values.
class StewardAttachmentInput(BaseModel):
name: str = Field(description="附件原始文件名。")
media_type: str = Field(default="", description="附件 MIME 类型。")
ocr_summary: str = Field(default="", description="可选 OCR 摘要。")
ocr_fields: dict[str, Any] = Field(default_factory=dict, description="可选 OCR 结构化字段。")
# Request body for steward planning: the user's natural-language message plus optional
# user id, client timestamp, attachments and free-form caller context.
class StewardPlanRequest(BaseModel):
message: str = Field(description="用户在首页输入的自然语言任务。")
user_id: str | None = Field(default=None, description="当前用户 ID。")
# Kept as a plain string here; this schema does not parse or validate the ISO format.
client_now_iso: str | None = Field(default=None, description="客户端当前时间 ISO 字符串。")
attachments: list[StewardAttachmentInput] = Field(default_factory=list, description="随本次输入上传的附件。")
context_json: dict[str, Any] = Field(default_factory=dict, description="调用方上下文。")
# One user-facing progress summary entry of a plan; `status` defaults to "completed".
class StewardThinkingEvent(BaseModel):
event_id: str = Field(description="过程摘要事件 ID。")
stage: str = Field(description="阶段编码。")
title: str = Field(description="面向用户展示的阶段标题。")
content: str = Field(description="面向用户展示的过程摘要。")
status: str = Field(default="completed", description="事件状态。")
# A single task in a steward plan, with its suggested downstream assistant,
# normalized ontology fields and the fields still missing.
class StewardTask(BaseModel):
task_id: str = Field(description="小财管家任务 ID。")
task_type: StewardTaskType = Field(description="任务类型。")
assigned_agent: StewardAssignedAgent = Field(description="建议分派的下游助手。")
title: str = Field(description="任务标题。")
summary: str = Field(description="任务摘要。")
# New tasks default to waiting for user confirmation.
status: StewardTaskStatus = Field(default="needs_confirmation", description="任务状态。")
# Validated to lie within [0.0, 1.0].
confidence: float = Field(default=0.0, ge=0.0, le=1.0, description="识别置信度。")
ontology_fields: dict[str, str] = Field(default_factory=dict, description="归一化后的业务本体字段。")
missing_fields: list[str] = Field(default_factory=list, description="仍缺失的本体字段。")
confirmation_required: bool = Field(default=True, description="执行前是否需要用户确认。")
# A suggested grouping of attachments under one scene, optionally tied to a task.
class StewardAttachmentGroup(BaseModel):
group_id: str = Field(description="附件归集组 ID。")
# None when the group is not tied to a specific task.
target_task_id: str | None = Field(default=None, description="建议归属的任务 ID。")
scene: str = Field(description="归集场景编码。")
scene_label: str = Field(description="归集场景展示名。")
attachment_names: list[str] = Field(default_factory=list, description="建议纳入的附件名称。")
excluded_attachment_names: list[str] = Field(default_factory=list, description="建议排除或单独处理的附件名称。")
# Validated to lie within [0.0, 1.0].
confidence: float = Field(default=0.0, ge=0.0, le=1.0, description="归集置信度。")
rationale: str = Field(default="", description="归集依据。")
confirmation_required: bool = Field(default=True, description="归集前是否需要用户确认。")
# A confirmation the user is asked to give, optionally linked to a task and/or an attachment group.
class StewardConfirmationAction(BaseModel):
confirmation_id: str = Field(description="确认动作 ID。")
action_type: str = Field(description="确认动作类型。")
label: str = Field(description="确认按钮文案。")
description: str = Field(default="", description="确认动作说明。")
target_task_id: str | None = Field(default=None, description="关联任务 ID。")
attachment_group_id: str | None = Field(default=None, description="关联附件归集组 ID。")
# Starts out as "pending".
status: StewardConfirmationStatus = Field(default="pending", description="确认状态。")
payload: dict[str, Any] = Field(default_factory=dict, description="确认后继续执行所需载荷。")
# Full steward plan returned to the client: summary, progress events, tasks,
# attachment grouping suggestions, pending confirmations and model call traces.
class StewardPlanResponse(BaseModel):
plan_id: str = Field(description="小财管家计划 ID。")
# Free-form string (not restricted to the task status Literal).
plan_status: str = Field(default="needs_confirmation", description="计划状态。")
# Defaults to the rule-based fallback source.
planning_source: StewardPlanningSource = Field(default="rule_fallback", description="计划生成来源。")
summary: str = Field(description="计划摘要。")
thinking_events: list[StewardThinkingEvent] = Field(default_factory=list, description="过程摘要事件。")
tasks: list[StewardTask] = Field(default_factory=list, description="拆解后的任务。")
attachment_groups: list[StewardAttachmentGroup] = Field(default_factory=list, description="附件归集建议。")
confirmation_groups: list[StewardConfirmationAction] = Field(default_factory=list, description="等待用户确认的动作。")
model_call_traces: list[dict[str, Any]] = Field(default_factory=list, description="模型工具调用轨迹。")

View File

@@ -254,6 +254,7 @@ class ExpenseClaimAttachmentOperationsMixin:
)
self._sync_claim_from_items(claim)
self._refresh_claim_pre_review_flags(claim, is_application_claim=False)
self.db.commit()
self.db.refresh(claim)
@@ -356,6 +357,7 @@ class ExpenseClaimAttachmentOperationsMixin:
item.invoice_id = None
self._sync_claim_from_items(claim)
self._refresh_claim_pre_review_flags(claim, is_application_claim=False)
self.db.commit()
self.db.refresh(claim)

View File

@@ -139,3 +139,61 @@ class ExpenseClaimPreReviewMixin:
)
]
return [*preserved_flags, next_flag]
def _refresh_claim_pre_review_flags(
self,
claim: ExpenseClaim,
*,
is_application_claim: bool | None = None,
reviewed_at: datetime | None = None,
) -> bool:
if claim is None:
return False
if is_application_claim is None:
is_application_claim = self._is_expense_application_claim(claim)
reviewed_at = reviewed_at or datetime.now(UTC)
if is_application_claim:
preserved_flags = [
flag
for flag in list(claim.risk_flags_json or [])
if not (
isinstance(flag, dict)
and str(flag.get("source") or "").strip() == "submission_review"
and str(flag.get("hit_source") or "").strip() == "rule_center"
)
]
application_review = self.evaluate_platform_risk_rules(
claim,
business_stage="expense_application",
)
review_flags = [*preserved_flags, *list(application_review.get("flags") or [])]
else:
review_result = self._run_ai_submission_review(claim)
review_flags = list(review_result.get("risk_flags") or [])
blocking_count = self._count_ai_pre_review_blocking_risks(review_flags)
claim.risk_flags_json = self._replace_ai_pre_review_flag(
review_flags,
self._build_ai_pre_review_flag(
passed=blocking_count <= 0,
blocking_count=blocking_count,
reviewed_at=reviewed_at,
business_stage=risk_business_stage_for_claim(
is_application_claim=is_application_claim,
),
),
)
if not is_application_claim:
claim.approval_stage = "\u5f85\u63d0\u4ea4"
claim.submitted_at = None
return True
@staticmethod
def _has_ai_pre_review_flag(claim: ExpenseClaim) -> bool:
return any(
isinstance(flag, dict)
and str(flag.get("source") or "").strip() == "ai_pre_review"
for flag in list(claim.risk_flags_json or [])
)

View File

@@ -48,6 +48,7 @@ from app.services.expense_claim_attachment_analysis import ExpenseClaimAttachmen
from app.services.expense_claim_attachment_document import ExpenseClaimAttachmentDocumentMixin
from app.services.expense_claim_attachment_operations import ExpenseClaimAttachmentOperationsMixin
from app.services.expense_claim_budget_flow import ExpenseClaimBudgetFlowMixin
from app.services.expense_claim_workflow_constants import DIRECT_MANAGER_APPROVAL_STAGE
from app.services.expense_claim_document_item_builder import ExpenseClaimDocumentItemBuilderMixin
from app.services.expense_claim_document_parsing import ExpenseClaimDocumentParsingMixin
from app.services.expense_claim_draft_flow import ExpenseClaimDraftFlowMixin
@@ -278,6 +279,9 @@ class ExpenseClaimService(
if payload.reason is not None:
claim.reason = self._normalize_optional_text(payload.reason, allow_empty=True) or "待补充"
if not self._is_expense_application_claim(claim):
self._refresh_claim_pre_review_flags(claim, is_application_claim=False)
self.db.commit()
self.db.refresh(claim)
@@ -306,6 +310,146 @@ class ExpenseClaimService(
normalized = Decimal(value or Decimal("0.00")).quantize(Decimal("0.01"))
return f"{normalized:.2f}"
@staticmethod
def _normalize_standard_adjustment_days(value: Any) -> int | None:
if value is None:
return None
if isinstance(value, int):
return value if 1 <= value <= 365 else None
text = str(value or "").strip()
if not text:
return None
match = re.search(r"\d{1,3}", text)
if not match:
return None
days = int(match.group(0))
return days if 1 <= days <= 365 else None
@staticmethod
def _normalize_standard_adjustment_text(value: Any) -> str:
text = str(value or "").strip()
if not text or text in {"-", "N/A", "n/a"}:
return ""
if text in {"待补充", "未知", "暂无", "非必填"}:
return ""
return text
def _iter_standard_adjustment_application_details(self, claim: ExpenseClaim) -> list[dict[str, Any]]:
details: list[dict[str, Any]] = []
for flag in list(claim.risk_flags_json or []):
if not isinstance(flag, dict):
continue
detail = flag.get("application_detail") or flag.get("applicationDetail")
if isinstance(detail, dict):
details.append(detail)
related = flag.get("related_application") or flag.get("relatedApplication")
if isinstance(related, dict):
details.append(related)
return details
def _resolve_standard_adjustment_days(
self,
claim: ExpenseClaim,
item: ExpenseClaimItem,
entry: Any,
) -> int:
direct_days = self._normalize_standard_adjustment_days(getattr(entry, "application_days", None))
if direct_days is not None:
return direct_days
for detail in self._iter_standard_adjustment_application_details(claim):
for key in ("application_days", "applicationDays", "days"):
detail_days = self._normalize_standard_adjustment_days(detail.get(key))
if detail_days is not None:
return detail_days
candidates = [
getattr(entry, "risk", None),
getattr(entry, "title", None),
item.item_reason,
claim.reason,
]
for text in candidates:
match = re.search(r"(\d{1,3})\s*(?:天|晚|夜)", str(text or ""))
if match:
days = self._normalize_standard_adjustment_days(match.group(1))
if days is not None:
return days
return 1
def _resolve_standard_adjustment_location(
self,
claim: ExpenseClaim,
item: ExpenseClaimItem,
) -> str:
for value in (item.item_location, claim.location):
text = self._normalize_standard_adjustment_text(value)
if text:
return text
for detail in self._iter_standard_adjustment_application_details(claim):
for key in ("application_location", "applicationLocation", "location", "city"):
text = self._normalize_standard_adjustment_text(detail.get(key))
if text:
return text
return ""
def _resolve_policy_standard_reimbursable_amount(
self,
*,
claim: ExpenseClaim,
item: ExpenseClaimItem,
entry: Any,
current_user: CurrentUserContext,
) -> Decimal | None:
item_type = str(item.item_type or "").strip().lower()
if item_type not in {"hotel", "hotel_ticket"}:
return None
location = self._resolve_standard_adjustment_location(claim, item)
grade = str(claim.employee_grade or current_user.grade or "").strip()
if not location or not grade:
return None
try:
from app.services.travel_reimbursement_calculator import TravelReimbursementCalculatorService
result = TravelReimbursementCalculatorService(self.db).calculate(
TravelReimbursementCalculatorRequest(
days=self._resolve_standard_adjustment_days(claim, item, entry),
location=location,
grade=grade,
),
current_user,
)
except Exception:
return None
return self._normalize_standard_adjustment_amount(result.hotel_amount)
def _resolve_standard_adjustment_reimbursable_amount(
self,
*,
claim: ExpenseClaim,
item: ExpenseClaimItem,
entry: Any,
original_amount: Decimal,
current_user: CurrentUserContext,
) -> Decimal:
policy_amount = self._resolve_policy_standard_reimbursable_amount(
claim=claim,
item=item,
entry=entry,
current_user=current_user,
)
if policy_amount is not None:
return min(max(policy_amount, Decimal("0.00")), original_amount)
entry_amount = self._normalize_standard_adjustment_amount(entry.reimbursable_amount)
if entry_amount is not None:
return min(max(entry_amount, Decimal("0.00")), original_amount)
return original_amount
def accept_standard_adjustment(
self,
*,
@@ -340,11 +484,13 @@ class ExpenseClaimService(
self._normalize_standard_adjustment_amount(entry.original_amount)
or Decimal(item.item_amount or Decimal("0.00")).quantize(Decimal("0.01"))
)
reimbursable_amount = (
self._normalize_standard_adjustment_amount(entry.reimbursable_amount)
or original_amount
reimbursable_amount = self._resolve_standard_adjustment_reimbursable_amount(
claim=claim,
item=item,
entry=entry,
original_amount=original_amount,
current_user=current_user,
)
reimbursable_amount = min(max(reimbursable_amount, Decimal("0.00")), original_amount)
employee_absorbed_amount = (original_amount - reimbursable_amount).quantize(Decimal("0.01"))
item_label = (
str(item.item_reason or "").strip()
@@ -456,6 +602,7 @@ class ExpenseClaimService(
self._refresh_item_attachment_analysis(item)
self._sync_claim_from_items(claim)
self._refresh_claim_pre_review_flags(claim, is_application_claim=False)
self.db.commit()
self.db.refresh(claim)
@@ -510,6 +657,7 @@ class ExpenseClaimService(
self.db.add(item)
self._sync_claim_from_items(claim)
self._refresh_claim_pre_review_flags(claim, is_application_claim=False)
self.db.commit()
self.db.refresh(claim)
@@ -548,6 +696,7 @@ class ExpenseClaimService(
self.db.delete(item)
self._sync_claim_from_items(claim)
self._refresh_claim_pre_review_flags(claim, is_application_claim=False)
self.db.commit()
self.db.refresh(claim)
@@ -645,12 +794,13 @@ class ExpenseClaimService(
budget_flags,
business_stage="reimbursement",
)
review_result = self._run_ai_submission_review(claim)
if not self._has_ai_pre_review_flag(claim):
self._refresh_claim_pre_review_flags(claim, is_application_claim=False)
claim.status = "submitted"
claim.approval_stage = DIRECT_MANAGER_APPROVAL_STAGE
claim.submitted_at = datetime.now(UTC)
claim.status = str(review_result.get("status") or "supplement")
claim.approval_stage = str(review_result.get("approval_stage") or "待补充")
claim.risk_flags_json = list(review_result.get("risk_flags") or [])
claim.submitted_at = datetime.now(UTC) if claim.status == "submitted" else None
self.db.commit()
self.db.refresh(claim)
@@ -872,11 +1022,3 @@ class ExpenseClaimService(

View File

@@ -33,17 +33,25 @@ FEEDBACK_STATUS_MAP = {
class RiskObservationService:
_storage_ready_cache: set[str] = set()
def __init__(self, db: Session) -> None:
self.db = db
def ensure_storage_ready(self) -> None:
    """Create the risk-observation tables, at most once per database bind.

    Binds that were already prepared are remembered in the class-level
    ``_storage_ready_cache`` (keyed by the bind's URL, or its ``id`` when it
    has no URL), so repeated calls skip the ``create_all`` round trip.
    """
    bind = self.db.get_bind()
    cache_key = str(getattr(bind, "url", "") or id(bind))
    if cache_key in self._storage_ready_cache:
        return
    # Pass the bind exactly once: repeating the ``bind=`` keyword in the
    # call is a syntax error.
    Base.metadata.create_all(
        bind=bind,
        tables=[
            RiskObservation.__table__,
            RiskObservationFeedback.__table__,
        ],
    )
    self._storage_ready_cache.add(cache_key)
def upsert_observation(
self,

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import json
from dataclasses import dataclass
from http import HTTPStatus
from time import monotonic, sleep
@@ -61,6 +62,23 @@ class RuntimeChatResult:
return [item.model_dump() for item in self.calls]
# One tool/function call returned by a chat model.
@dataclass(slots=True)
class RuntimeChatToolCall:
# Name of the function the model asked to call.
name: str
# Arguments of the call as a parsed dict.
arguments: dict[str, Any]
# Identifier of the call, when one is available.
call_id: str | None = None
# The arguments as a JSON string.
raw_arguments: str = ""
@dataclass(slots=True)
class RuntimeToolCallResult:
tool_call: RuntimeChatToolCall | None
calls: list[RuntimeChatCallTrace]
def calls_as_dicts(self) -> list[dict[str, Any]]:
return [item.model_dump() for item in self.calls]
class RuntimeChatService:
def __init__(self, db: Session) -> None:
self.db = db
@@ -208,6 +226,131 @@ class RuntimeChatService:
return RuntimeChatResult(None, calls)
# Request a single tool/function call from the configured chat slots.
# Slots are tried in `slot_priority` order, for up to `max_attempts` rounds.
# Returns a RuntimeToolCallResult: `tool_call` is the first tool call any slot produced
# (None when no slot is configured or none produced one); `calls` traces every slot visit.
def complete_with_tool_call(
self,
messages: list[dict[str, Any]],
*,
tools: list[dict[str, Any]],
tool_choice: dict[str, Any] | str | None = None,
slot_priority: tuple[str, ...] = ("main", "backup"),
max_tokens: int = 1200,
temperature: float = 0.1,
timeout_seconds: int | None = None,
slot_timeouts: dict[str, int] | None = None,
max_attempts: int | None = None,
) -> RuntimeToolCallResult:
configs: list[dict[str, str]] = []
calls: list[RuntimeChatCallTrace] = []
# Resolve the configured slots; unconfigured ones are recorded as skipped traces.
for slot in slot_priority:
config = self._load_chat_slot(slot)
if config is None:
calls.append(
RuntimeChatCallTrace(
slot=slot,
provider="",
model="",
attempt=0,
status="skipped",
skipped_reason="not_configured",
)
)
continue
configs.append(config)
if not configs:
return RuntimeToolCallResult(None, calls)
# NOTE: `or` means a falsy argument (None or 0) falls back to the module default.
resolved_timeout_seconds = timeout_seconds or DEFAULT_RUNTIME_CHAT_TIMEOUT_SECONDS
resolved_slot_timeouts = dict(slot_timeouts or {})
resolved_max_attempts = max_attempts or DEFAULT_RUNTIME_CHAT_RETRY_ATTEMPTS
for attempt in range(1, resolved_max_attempts + 1):
for config in configs:
cache_key = self._build_slot_cache_key(config)
# A slot whose cooldown deadline (monotonic clock) has not passed yet is skipped.
if _slot_failure_until.get(cache_key, 0.0) > monotonic():
logger.info(
"Skip runtime chat tool slot=%s provider=%s because it is in cooldown",
config["slot"],
config["provider"],
)
calls.append(
RuntimeChatCallTrace(
slot=config["slot"],
provider=config["provider"],
model=config["model"],
attempt=attempt,
status="skipped",
skipped_reason="cooldown",
)
)
continue
started = monotonic()
try:
tool_call = self._request_chat_tool_call(
config,
messages,
tools=tools,
tool_choice=tool_choice,
max_tokens=max_tokens,
temperature=temperature,
# A per-slot timeout overrides the shared timeout when present.
timeout_seconds=resolved_slot_timeouts.get(
config["slot"],
resolved_timeout_seconds,
),
)
duration_ms = int((monotonic() - started) * 1000)
if tool_call is not None:
# Success clears any cooldown entry for this slot and ends the search.
_slot_failure_until.pop(cache_key, None)
calls.append(
RuntimeChatCallTrace(
slot=config["slot"],
provider=config["provider"],
model=config["model"],
attempt=attempt,
status="succeeded",
duration_ms=duration_ms,
)
)
return RuntimeToolCallResult(tool_call, calls)
# The request went through but carried no tool call: trace it as "empty" and move on.
calls.append(
RuntimeChatCallTrace(
slot=config["slot"],
provider=config["provider"],
model=config["model"],
attempt=attempt,
status="empty",
duration_ms=duration_ms,
error_message="模型未返回工具调用。",
)
)
except Exception as exc:
# Any exception puts the slot into cooldown and is traced as "failed".
# NOTE(review): this also covers errors raised while parsing the returned tool
# arguments, not only request failures — confirm a cooldown is intended there.
duration_ms = int((monotonic() - started) * 1000)
_slot_failure_until[cache_key] = (
monotonic() + DEFAULT_RUNTIME_CHAT_FAILURE_COOLDOWN_SECONDS
)
calls.append(
RuntimeChatCallTrace(
slot=config["slot"],
provider=config["provider"],
model=config["model"],
attempt=attempt,
status="failed",
duration_ms=duration_ms,
error_message=str(exc),
)
)
logger.warning(
"Runtime chat tool request failed slot=%s provider=%s attempt=%s/%s: %s",
config["slot"],
config["provider"],
attempt,
resolved_max_attempts,
exc,
)
# Blocking pause between rounds; skipped after the final round.
if attempt < resolved_max_attempts:
sleep(DEFAULT_RUNTIME_CHAT_RETRY_DELAY_SECONDS)
return RuntimeToolCallResult(None, calls)
@staticmethod
def _build_slot_cache_key(config: dict[str, str]) -> str:
return "|".join(
@@ -295,6 +438,51 @@ class RuntimeChatService:
timeout_seconds=timeout_seconds,
)
def _request_chat_tool_call(
self,
config: dict[str, str],
messages: list[dict[str, Any]],
*,
tools: list[dict[str, Any]],
tool_choice: dict[str, Any] | str | None,
max_tokens: int,
temperature: float,
timeout_seconds: int,
) -> RuntimeChatToolCall | None:
provider = config["provider"]
endpoint = config["endpoint"]
model = config["model"]
api_key = config["apiKey"]
if provider == "Azure OpenAI":
return self._request_azure_openai_tool_call(
endpoint=endpoint,
model=model,
api_key=api_key,
messages=messages,
tools=tools,
tool_choice=tool_choice,
max_tokens=max_tokens,
temperature=temperature,
timeout_seconds=timeout_seconds,
)
if provider == "Ollama":
raise ConnectivityCheckError("Ollama 暂不支持小财管家 function calling。")
return self._request_openai_compatible_tool_call(
provider=provider,
endpoint=endpoint,
model=model,
api_key=api_key,
messages=messages,
tools=tools,
tool_choice=tool_choice,
max_tokens=max_tokens,
temperature=temperature,
timeout_seconds=timeout_seconds,
)
def _request_openai_compatible(
self,
*,
@@ -331,6 +519,46 @@ class RuntimeChatService:
)
return self._extract_openai_text(payload)
def _request_openai_compatible_tool_call(
    self,
    *,
    provider: str,
    endpoint: str,
    model: str,
    api_key: str,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]],
    tool_choice: dict[str, Any] | str | None,
    max_tokens: int,
    temperature: float,
    timeout_seconds: int,
) -> RuntimeChatToolCall | None:
    """POST a tool-enabled chat completion to an OpenAI-compatible endpoint.

    A falsy ``tool_choice`` is sent as ``"auto"``. For the ``"GLM"``
    provider the request additionally disables thinking.

    Returns:
        The tool call extracted from the response, or ``None`` if it has none.

    Raises:
        ConnectivityCheckError: When the response status is 400 or above.
    """
    body: dict[str, Any] = {
        "model": model,
        "messages": messages,
        "tools": tools,
        "tool_choice": tool_choice or "auto",
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    if provider == "GLM":
        body["thinking"] = {"type": "disabled"}

    status_code, payload = _send_json_request(
        "POST",
        _ensure_path(_normalize_endpoint(endpoint), "chat/completions"),
        headers=_build_headers(api_key=api_key, use_bearer=True),
        payload=body,
        timeout_seconds=timeout_seconds,
    )
    if status_code < HTTPStatus.BAD_REQUEST:
        return self._extract_openai_tool_call(payload)
    raise ConnectivityCheckError(
        f"模型接口返回异常状态 {status_code}",
        status_code=status_code,
    )
def _request_ollama(
self,
*,
@@ -396,6 +624,41 @@ class RuntimeChatService:
)
return self._extract_openai_text(payload)
def _request_azure_openai_tool_call(
    self,
    *,
    endpoint: str,
    model: str,
    api_key: str,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]],
    tool_choice: dict[str, Any] | str | None,
    max_tokens: int,
    temperature: float,
    timeout_seconds: int,
) -> RuntimeChatToolCall | None:
    """POST a tool-enabled chat completion to an Azure OpenAI deployment.

    The deployment URL is derived from ``endpoint`` and ``model``; the model
    name is not repeated in the request body. A falsy ``tool_choice`` is
    sent as ``"auto"``.

    Returns:
        The tool call extracted from the response, or ``None`` if it has none.

    Raises:
        ConnectivityCheckError: When the response status is 400 or above.
    """
    deployment_base = _build_azure_deployment_base(endpoint, model)
    request_headers = _build_headers(api_key=api_key, use_bearer=False, use_api_key=True)
    body = {
        "messages": messages,
        "tools": tools,
        "tool_choice": tool_choice or "auto",
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    status_code, payload = _send_json_request(
        "POST",
        f"{deployment_base}/chat/completions?api-version={AZURE_API_VERSION}",
        headers=request_headers,
        payload=body,
        timeout_seconds=timeout_seconds,
    )
    if status_code < HTTPStatus.BAD_REQUEST:
        return self._extract_openai_tool_call(payload)
    raise ConnectivityCheckError(
        f"Azure OpenAI 返回异常状态 {status_code}",
        status_code=status_code,
    )
@staticmethod
def _extract_openai_text(payload: Any) -> str:
if not isinstance(payload, dict):
@@ -426,3 +689,74 @@ class RuntimeChatService:
return text.strip()
return ""
@staticmethod
def _extract_openai_tool_call(payload: Any) -> RuntimeChatToolCall | None:
if not isinstance(payload, dict):
return None
choices = payload.get("choices")
if not isinstance(choices, list) or not choices:
return None
first_choice = choices[0]
if not isinstance(first_choice, dict):
return None
message = first_choice.get("message")
if not isinstance(message, dict):
return None
tool_calls = message.get("tool_calls")
if isinstance(tool_calls, list) and tool_calls:
first_tool = tool_calls[0]
if isinstance(first_tool, dict):
function_payload = first_tool.get("function")
if isinstance(function_payload, dict):
return RuntimeChatService._build_runtime_tool_call(
name=function_payload.get("name"),
arguments=function_payload.get("arguments"),
call_id=first_tool.get("id"),
)
function_call = message.get("function_call")
if isinstance(function_call, dict):
return RuntimeChatService._build_runtime_tool_call(
name=function_call.get("name"),
arguments=function_call.get("arguments"),
call_id=None,
)
return None
@staticmethod
def _build_runtime_tool_call(
*,
name: Any,
arguments: Any,
call_id: Any,
) -> RuntimeChatToolCall | None:
tool_name = str(name or "").strip()
if not tool_name:
return None
raw_arguments = ""
if isinstance(arguments, dict):
parsed_arguments = arguments
raw_arguments = json.dumps(arguments, ensure_ascii=False)
else:
raw_arguments = str(arguments or "").strip()
if not raw_arguments:
parsed_arguments = {}
else:
parsed = json.loads(raw_arguments)
if not isinstance(parsed, dict):
raise ValueError("工具调用参数必须是 JSON object。")
parsed_arguments = parsed
return RuntimeChatToolCall(
name=tool_name,
arguments=parsed_arguments,
call_id=str(call_id).strip() if call_id else None,
raw_arguments=raw_arguments,
)

View File

@@ -0,0 +1,18 @@
from __future__ import annotations
# Canonical business field names, kept as an ordered tuple.
BUSINESS_CANONICAL_FIELD_ORDER = (
"expense_type",
"time_range",
"location",
"reason",
"amount",
"transport_mode",
"attachments",
"customer_name",
"merchant_name",
"department_name",
"employee_name",
"employee_no",
)
# The same names as an immutable set, for membership checks.
BUSINESS_CANONICAL_FIELDS = frozenset(BUSINESS_CANONICAL_FIELD_ORDER)

View File

@@ -0,0 +1,220 @@
from __future__ import annotations
import json
from dataclasses import dataclass
from datetime import date
from typing import Any
from app.schemas.steward import StewardPlanRequest
from app.services.ontology_field_registry import normalize_ontology_form_values
from app.services.runtime_chat import RuntimeChatService
STEWARD_INTENT_FUNCTION_NAME = "submit_steward_intent_plan"
# Immutable result of one intent detection.
@dataclass(frozen=True, slots=True)
class StewardIntentAgentResult:
# Arguments of the model's function call.
payload: dict[str, Any]
# Per-slot traces of the model requests made for this detection.
model_call_traces: list[dict[str, Any]]
class StewardIntentAgent:
    """Detect the steward's compound finance intents through LLM function calling.

    The agent sends the user's request to the runtime chat service together
    with a single function schema and forces the model to answer by calling
    that function. Traces of the most recent detection are kept on
    ``last_call_traces``.
    """

    def __init__(self, runtime_chat_service: RuntimeChatService) -> None:
        """Store the chat service used for model requests and start with no traces."""
        self.runtime_chat_service = runtime_chat_service
        self.last_call_traces: list[dict[str, Any]] = []

    def detect(
        self,
        request: StewardPlanRequest,
        *,
        base_date: date,
        canonical_fields: list[str],
    ) -> StewardIntentAgentResult | None:
        """Ask the model for a structured intent plan.

        Args:
            request: The plan request (message, attachments, caller context).
            base_date: Reference date the model must resolve relative dates against.
            canonical_fields: Ontology field names the model is allowed to use.

        Returns:
            The function-call arguments plus the call traces, or ``None`` when
            no tool call came back or it named a different function.
        """
        result = self.runtime_chat_service.complete_with_tool_call(
            self._build_messages(request, base_date=base_date, canonical_fields=canonical_fields),
            tools=[self._build_intent_tool_schema(canonical_fields)],
            # Force the model to answer through the intent function.
            tool_choice={
                "type": "function",
                "function": {"name": STEWARD_INTENT_FUNCTION_NAME},
            },
            max_tokens=1800,
            temperature=0.1,
            timeout_seconds=18,
            max_attempts=1,
        )
        # Traces are recorded even when detection fails, so callers can inspect them.
        self.last_call_traces = result.calls_as_dicts()
        if result.tool_call is None or result.tool_call.name != STEWARD_INTENT_FUNCTION_NAME:
            return None
        return StewardIntentAgentResult(
            payload=result.tool_call.arguments,
            model_call_traces=self.last_call_traces,
        )

    @staticmethod
    def _build_messages(
        request: StewardPlanRequest,
        *,
        base_date: date,
        canonical_fields: list[str],
    ) -> list[dict[str, Any]]:
        """Build the system and user messages for the intent request.

        The user message is a JSON document containing the request message,
        the date context, the allowed ontology fields, the normalized review
        form values, an allow-listed subset of ``context_json`` and the named
        attachments (1-based ``index``).
        """
        # Only these caller-context keys are forwarded to the model.
        forwarded_context_keys = {
            "entry_source",
            "session_type",
            "role_codes",
            "username",
            "name",
            "department_name",
            "employee_grade",
            "employee_no",
            "client_timezone_offset_minutes",
        }
        context_payload = {
            "message": request.message,
            "base_date": base_date.isoformat(),
            "client_now_iso": request.client_now_iso,
            "user_id": request.user_id,
            "canonical_ontology_fields": canonical_fields,
            "review_form_values": normalize_ontology_form_values(
                request.context_json.get("review_form_values")
            ),
            "context_json": {
                key: value
                for key, value in request.context_json.items()
                if key in forwarded_context_keys
            },
            "attachments": [
                {
                    "index": index + 1,
                    "name": item.name,
                    "media_type": item.media_type,
                    "ocr_summary": item.ocr_summary,
                    "ocr_fields": item.ocr_fields,
                }
                for index, item in enumerate(request.attachments)
                if item.name
            ],
        }
        return [
            {
                "role": "system",
                "content": (
                    "你是 X-Financial 的小财管家意图识别智能体。"
                    "你必须通过 function calling 输出结构化计划,不能只返回普通文本。"
                    "当前版本只支持 expense_application 和 reimbursement 两类任务;"
                    "你只做识别、拆解、归集和确认点规划,不能执行入库、绑定附件或提交审批。"
                    # This sentence needs its terminator; without it the rule
                    # runs straight into the alias-mapping rule that follows.
                    "所有 ontology_fields 只能使用调用方给出的 canonical_ontology_fields。"
                    "如果输入里出现 occurred_date、transport_type、reason_value 等别名,必须映射为 canonical 字段。"
                    "相对日期必须以 base_date 为准转换为明确日期。"
                    "thinking_events 只能是面向用户的过程摘要,不能暴露内部推理链。"
                ),
            },
            {
                "role": "user",
                "content": json.dumps(context_payload, ensure_ascii=False),
            },
        ]

    @staticmethod
    def _build_intent_tool_schema(canonical_fields: list[str]) -> dict[str, Any]:
        """Return the function schema the model must call to submit its plan.

        The schema requires three arrays: ``thinking_events``, ``tasks`` and
        ``attachment_groups``. ``missing_fields`` entries are restricted to
        ``canonical_fields``.
        """
        return {
            "type": "function",
            "function": {
                "name": STEWARD_INTENT_FUNCTION_NAME,
                "description": "提交小财管家的复合财务意图识别结果。",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "thinking_events": {
                            "type": "array",
                            "items": {
                                "type": "object",
                                "properties": {
                                    "stage": {"type": "string"},
                                    "title": {"type": "string"},
                                    "content": {"type": "string"},
                                },
                                "required": ["stage", "title", "content"],
                            },
                        },
                        "tasks": {
                            "type": "array",
                            "items": {
                                "type": "object",
                                "properties": {
                                    "task_type": {
                                        "type": "string",
                                        "enum": ["expense_application", "reimbursement"],
                                    },
                                    "title": {"type": "string"},
                                    "summary": {"type": "string"},
                                    "confidence": {
                                        "type": "number",
                                        "minimum": 0,
                                        "maximum": 1,
                                    },
                                    "ontology_fields": {
                                        "type": "object",
                                        "additionalProperties": {"type": "string"},
                                    },
                                    "missing_fields": {
                                        "type": "array",
                                        "items": {
                                            "type": "string",
                                            "enum": canonical_fields,
                                        },
                                    },
                                },
                                "required": [
                                    "task_type",
                                    "title",
                                    "summary",
                                    "confidence",
                                    "ontology_fields",
                                    "missing_fields",
                                ],
                            },
                        },
                        "attachment_groups": {
                            "type": "array",
                            "items": {
                                "type": "object",
                                "properties": {
                                    # Optional 1-based reference into ``tasks``.
                                    "target_task_index": {
                                        "type": "integer",
                                        "minimum": 1,
                                    },
                                    "scene": {"type": "string"},
                                    "scene_label": {"type": "string"},
                                    "attachment_names": {
                                        "type": "array",
                                        "items": {"type": "string"},
                                    },
                                    "excluded_attachment_names": {
                                        "type": "array",
                                        "items": {"type": "string"},
                                    },
                                    "confidence": {
                                        "type": "number",
                                        "minimum": 0,
                                        "maximum": 1,
                                    },
                                    "rationale": {"type": "string"},
                                },
                                "required": [
                                    "scene",
                                    "scene_label",
                                    "attachment_names",
                                    "excluded_attachment_names",
                                    "confidence",
                                    "rationale",
                                ],
                            },
                        },
                    },
                    "required": ["thinking_events", "tasks", "attachment_groups"],
                },
            },
        }

View File

@@ -0,0 +1,365 @@
from __future__ import annotations
import re
import uuid
from datetime import date
from typing import Any
from app.schemas.steward import (
StewardAttachmentGroup,
StewardAttachmentInput,
StewardPlanRequest,
StewardPlanResponse,
StewardTask,
StewardThinkingEvent,
)
from app.services.ontology_field_registry import normalize_ontology_form_values
from app.services.steward_constants import BUSINESS_CANONICAL_FIELDS
from app.services.steward_intent_agent import StewardIntentAgentResult
class StewardModelPlanBuilder:
    """Convert the model's function-calling payload into the steward's server-side plan."""

    def __init__(self, planner: Any) -> None:
        # The planner object whose helper methods (text cleanup, rule-based
        # field extraction, summaries, ...) this builder calls.
        self.planner = planner
def build(
    self,
    intent_result: StewardIntentAgentResult,
    *,
    request: StewardPlanRequest,
    base_date: date,
) -> StewardPlanResponse | None:
    """Turn one intent-agent result into a plan response.

    Args:
        intent_result: Result carrying the model ``payload`` and the
            ``model_call_traces`` to attach to the response.
        request: The original plan request (message context, attachments).
        base_date: Reference date handed to the field normalisation.

    Returns:
        The assembled plan, or ``None`` when the payload produced no
        usable task.
    """
    payload = intent_result.payload
    uploads = request.attachments
    planner = self.planner

    tasks = self._build_tasks_from_model_payload(payload, request, base_date)
    if not tasks:
        return None

    groups = self._build_attachment_groups_from_model_payload(payload, uploads, tasks)
    if uploads and not groups:
        # Files were uploaded but the model grouped none of them: use the
        # planner's rule-based grouping instead.
        groups = planner._build_attachment_groups(uploads, tasks)

    confirmations = planner._build_confirmation_actions(tasks, groups)
    events = self._build_llm_thinking_events(
        payload,
        tasks=tasks,
        attachment_groups=groups,
        attachments=uploads,
    )

    plan_id = f"steward_plan_{uuid.uuid4().hex[:12]}"
    status = "needs_confirmation" if confirmations else "ready_to_delegate"
    return StewardPlanResponse(
        plan_id=plan_id,
        plan_status=status,
        planning_source="llm_function_call",
        summary=planner._build_summary(tasks, groups),
        thinking_events=events,
        tasks=tasks,
        attachment_groups=groups,
        confirmation_groups=confirmations,
        model_call_traces=intent_result.model_call_traces,
    )
def _build_tasks_from_model_payload(
    self,
    payload: dict[str, Any],
    request: StewardPlanRequest,
    base_date: date,
) -> list[StewardTask]:
    """Build validated tasks from the ``tasks`` array of the model payload.

    Entries that are not dicts, or whose ``task_type`` is not one of
    ``expense_application`` / ``reimbursement``, are skipped.

    Args:
        payload: Arguments returned by the model's function call.
        request: The original plan request.
        base_date: Reference date for date normalisation.

    Returns:
        The accepted tasks, in payload order; empty when ``tasks`` is not
        a list.
    """
    raw_tasks = payload.get("tasks")
    if not isinstance(raw_tasks, list):
        return []
    tasks: list[StewardTask] = []
    for raw_task in raw_tasks:
        if not isinstance(raw_task, dict):
            continue
        task_type = str(raw_task.get("task_type") or "").strip()
        if task_type not in {"expense_application", "reimbursement"}:
            continue
        # Numbered by position among ACCEPTED tasks, not by payload position.
        # NOTE(review): _resolve_model_group_target_task_id indexes into this
        # accepted list, so a skipped entry shifts later indices relative to
        # the model's own numbering - confirm this is acceptable.
        task_index = len(tasks) + 1
        fields = self._sanitize_model_ontology_fields(
            raw_task.get("ontology_fields"),
            request=request,
            base_date=base_date,
        )
        # Title + summary text is re-run through the planner's rule-based
        # extraction to fill fields the model left out.
        supplement_segment = " ".join(
            [
                str(raw_task.get("title") or ""),
                str(raw_task.get("summary") or ""),
            ]
        )
        supplement_fields = self.planner._extract_ontology_fields(
            supplement_segment,
            task_type,
            base_date,
            request,
        )
        # setdefault: rule-derived values never overwrite existing keys.
        for key, value in supplement_fields.items():
            fields.setdefault(key, value)
        assigned_agent = (
            "application_assistant"
            if task_type == "expense_application"
            else "reimbursement_assistant"
        )
        task_id = f"task_{'app' if task_type == 'expense_application' else 'reim'}_{task_index:03d}"
        title_prefix = "费用申请" if task_type == "expense_application" else "费用报销"
        # Prefer the model's cleaned title/summary; fall back to generated text.
        title = self.planner._clean_text(raw_task.get("title")) or self.planner._build_task_title(
            title_prefix,
            fields,
            task_index,
        )
        summary = self.planner._clean_text(raw_task.get("summary")) or self.planner._build_task_summary(
            supplement_segment,
            fields,
        )
        missing_fields = self._sanitize_model_missing_fields(
            raw_task.get("missing_fields"),
            task_type=task_type,
            fields=fields,
        )
        tasks.append(
            StewardTask(
                task_id=task_id,
                task_type=task_type,  # type: ignore[arg-type]
                assigned_agent=assigned_agent,  # type: ignore[arg-type]
                title=title,
                summary=summary,
                # Every model-built task is created pending user confirmation.
                status="needs_confirmation",
                confidence=self._resolve_model_confidence(
                    raw_task.get("confidence"),
                    segment=supplement_segment,
                    fields=fields,
                    task_type=task_type,
                ),
                ontology_fields=fields,
                missing_fields=missing_fields,
                confirmation_required=True,
            )
        )
    return tasks
def _sanitize_model_ontology_fields(
    self,
    raw_fields: Any,
    *,
    request: StewardPlanRequest,
    base_date: date,
) -> dict[str, str]:
    """Merge context form values with the model's ontology fields.

    Values from ``request.context_json["review_form_values"]`` seed the
    result; normalised model values then overwrite them. Only keys in
    ``BUSINESS_CANONICAL_FIELDS`` are kept.

    When ``raw_fields`` is not a dict, the seeded context values are
    returned directly, without the attachment backfill or final filter.

    Args:
        raw_fields: The model's ``ontology_fields`` value (untrusted shape).
        request: The original plan request.
        base_date: Reference date for per-field normalisation.

    Returns:
        Canonical field name -> value.
    """
    merged: dict[str, str] = {}
    context_values = normalize_ontology_form_values(request.context_json.get("review_form_values"))
    for name, raw in context_values.items():
        if name in BUSINESS_CANONICAL_FIELDS and str(raw or "").strip():
            merged[name] = raw

    if not isinstance(raw_fields, dict):
        return merged

    for name, raw in normalize_ontology_form_values(raw_fields).items():
        if name in BUSINESS_CANONICAL_FIELDS:
            cleaned = self._normalize_model_field_value(name, raw, base_date)
            if cleaned:
                merged[name] = cleaned

    if request.attachments and not merged.get("attachments"):
        # NOTE(review): names are joined with an empty separator - confirm
        # that no delimiter is intended here.
        merged["attachments"] = "".join(upload.name for upload in request.attachments if upload.name)

    return {
        name: value
        for name, value in merged.items()
        if name in BUSINESS_CANONICAL_FIELDS and value
    }
def _build_attachment_groups_from_model_payload(
    self,
    payload: dict[str, Any],
    attachments: list[StewardAttachmentInput],
    tasks: list[StewardTask],
) -> list[StewardAttachmentGroup]:
    """Build attachment groups from the payload's ``attachment_groups`` array.

    Only names of actually uploaded attachments are kept; a group with
    neither included nor excluded names is dropped.

    Args:
        payload: Arguments returned by the model's function call.
        attachments: Attachments uploaded with the request.
        tasks: Tasks already built from the same payload.

    Returns:
        The accepted groups; empty when nothing was uploaded or the
        payload value is not a list.
    """
    candidate_groups = payload.get("attachment_groups")
    if not attachments or not isinstance(candidate_groups, list):
        return []

    known_names = {upload.name for upload in attachments if upload.name}
    clean = self.planner._clean_text
    accepted: list[StewardAttachmentGroup] = []

    for entry in candidate_groups:
        if not isinstance(entry, dict):
            continue
        included = self._filter_uploaded_attachment_names(entry.get("attachment_names"), known_names)
        excluded = self._filter_uploaded_attachment_names(
            entry.get("excluded_attachment_names"),
            known_names,
        )
        if not (included or excluded):
            continue

        scene = clean(entry.get("scene")) or "other"
        sequence = len(accepted) + 1  # numbered among accepted groups
        accepted.append(
            StewardAttachmentGroup(
                group_id=f"ag_{self._slug_scene(scene)}_{sequence:03d}",
                target_task_id=self._resolve_model_group_target_task_id(entry, tasks),
                scene=scene,
                scene_label=clean(entry.get("scene_label")) or "待确认费用",
                attachment_names=included,
                excluded_attachment_names=excluded,
                confidence=self._clamp_confidence(entry.get("confidence"), default=0.68),
                rationale=clean(entry.get("rationale")) or "模型根据附件线索生成归集建议。",
                confirmation_required=True,
            )
        )
    return accepted
def _build_llm_thinking_events(
    self,
    payload: dict[str, Any],
    *,
    tasks: list[StewardTask],
    attachment_groups: list[StewardAttachmentGroup],
    attachments: list[StewardAttachmentInput],
) -> list[StewardThinkingEvent]:
    """Build the user-facing process events for a model-generated plan.

    A fixed leading event is followed by up to four events taken from
    the payload's ``thinking_events``; entries lacking a title or content
    are skipped. If none survives, the planner's rule-based events (minus
    their first entry) are appended instead.
    """
    clean = self.planner._clean_text
    events = [
        StewardThinkingEvent(
            event_id="intent_agent_function_call",
            stage="llm_function_call",
            title="意图识别智能体接管",
            content=(
                "已调用系统主模型的 submit_steward_intent_plan 工具,"
                "把用户话术转换为可校验的结构化财务任务计划。"
            ),
        )
    ]

    model_events = payload.get("thinking_events")
    candidates = model_events[:4] if isinstance(model_events, list) else []
    for entry in candidates:
        if not isinstance(entry, dict):
            continue
        title = clean(entry.get("title"))
        content = clean(entry.get("content"))
        if title and content:
            events.append(
                StewardThinkingEvent(
                    # len(events) counts the leading event, so numbering starts at 001.
                    event_id=f"intent_agent_model_{len(events):03d}",
                    stage=clean(entry.get("stage")) or "model_summary",
                    title=title,
                    content=content,
                )
            )

    if len(events) == 1:
        rule_events = self.planner._build_thinking_events(tasks, attachment_groups, attachments)
        events.extend(rule_events[1:])
    return events
def _sanitize_model_missing_fields(
    self,
    raw_missing_fields: Any,
    *,
    task_type: str,
    fields: dict[str, str],
) -> list[str]:
    """Combine the model's missing-field list with the planner's own check.

    Model entries are kept only if they are canonical field names, not
    already listed, and have no value in ``fields``. Names reported by
    ``planner._resolve_missing_fields`` are then appended if absent.

    Returns:
        De-duplicated field names, model-reported ones first.
    """
    if isinstance(raw_missing_fields, list):
        reported = [str(item or "").strip() for item in raw_missing_fields]
    else:
        reported = []

    accepted: list[str] = []
    for name in reported:
        if name in accepted or name not in BUSINESS_CANONICAL_FIELDS or fields.get(name):
            continue
        accepted.append(name)

    for name in self.planner._resolve_missing_fields(task_type, fields):
        if name not in accepted:
            accepted.append(name)
    return accepted
def _resolve_model_confidence(
    self,
    value: Any,
    *,
    segment: str,
    fields: dict[str, str],
    task_type: str,
) -> float:
    """Return the model's confidence clamped to [0, 1].

    The planner's rule-based confidence for the same task serves as the
    default when ``value`` cannot be parsed as a number.
    """
    rule_based = self.planner._resolve_task_confidence(segment, fields, task_type)
    return self._clamp_confidence(value, default=rule_based)
def _normalize_model_field_value(self, key: str, value: Any, base_date: date) -> str:
    """Clean one model-supplied field value and apply per-field normalisation.

    ``time_range`` goes through the planner's date extraction (keeping the
    cleaned text if nothing is extracted); ``expense_type`` and
    ``transport_mode`` are mapped through their alias normalisers. Any
    other key returns the cleaned text.

    Returns:
        The normalised value, or ``""`` when the cleaned text is empty.
    """
    text = self.planner._clean_text(value)
    if not text:
        return ""

    if key == "time_range":
        extracted = self.planner._extract_time_range(text, base_date)
        return extracted or text
    if key == "expense_type":
        return self._normalize_expense_type_value(text)
    if key == "transport_mode":
        return self._normalize_transport_mode_value(text)
    return text
@staticmethod
def _normalize_expense_type_value(value: str) -> str:
normalized = str(value or "").strip().lower()
if normalized in {"travel", "travel_application", "差旅", "差旅费", "出差"}:
return "travel"
if normalized in {"transport", "traffic", "交通", "交通费", "打车", "出租车"}:
return "transport"
if normalized in {"entertainment", "meal", "招待", "接待", "餐饮", "业务招待"}:
return "entertainment"
if normalized in {"office", "办公", "办公用品"}:
return "office"
return normalized
@staticmethod
def _normalize_transport_mode_value(value: str) -> str:
normalized = str(value or "").strip().lower()
if normalized in {"train", "高铁", "动车", "火车"}:
return "train"
if normalized in {"flight", "air", "飞机", "机票", "航班"}:
return "flight"
if normalized in {"taxi", "出租车", "的士", "网约车", "打车"}:
return "taxi"
if normalized in {"subway", "地铁"}:
return "subway"
return normalized
@staticmethod
def _filter_uploaded_attachment_names(raw_names: Any, uploaded_names: set[str]) -> list[str]:
if not isinstance(raw_names, list):
return []
names: list[str] = []
for raw_name in raw_names:
name = str(raw_name or "").strip()
if name in uploaded_names and name not in names:
names.append(name)
return names
@staticmethod
def _resolve_model_group_target_task_id(raw_group: dict[str, Any], tasks: list[StewardTask]) -> str | None:
try:
target_index = int(raw_group.get("target_task_index") or 0)
except (TypeError, ValueError):
target_index = 0
if target_index > 0 and target_index <= len(tasks):
return tasks[target_index - 1].task_id
target_task_id = str(raw_group.get("target_task_id") or "").strip()
if target_task_id and any(task.task_id == target_task_id for task in tasks):
return target_task_id
return None
@staticmethod
def _slug_scene(value: str) -> str:
normalized = re.sub(r"[^a-zA-Z0-9_]+", "_", str(value or "").strip().lower()).strip("_")
return normalized or "other"
@staticmethod
def _clamp_confidence(value: Any, *, default: float) -> float:
try:
parsed = float(value)
except (TypeError, ValueError):
parsed = default
return round(min(1.0, max(0.0, parsed)), 2)

View File

@@ -0,0 +1,645 @@
from __future__ import annotations
import re
import uuid
from dataclasses import dataclass
from datetime import UTC, date, datetime, timedelta
from typing import Any
from app.schemas.steward import (
StewardAttachmentGroup,
StewardAttachmentInput,
StewardConfirmationAction,
StewardPlanRequest,
StewardPlanResponse,
StewardTask,
StewardThinkingEvent,
)
from app.services.steward_constants import BUSINESS_CANONICAL_FIELD_ORDER, BUSINESS_CANONICAL_FIELDS
from app.services.ontology_field_registry import normalize_ontology_form_values
from app.services.steward_intent_agent import StewardIntentAgent
from app.services.steward_model_plan_builder import StewardModelPlanBuilder
# City names the rule-based location extraction can recognise
# (used by StewardPlannerService._extract_location).
CITY_NAMES = (
    "北京",
    "上海",
    "广州",
    "深圳",
    "杭州",
    "南京",
    "苏州",
    "成都",
    "重庆",
    "天津",
    "武汉",
    "西安",
    "长沙",
    "郑州",
    "青岛",
    "厦门",
    "福州",
    "合肥",
    "济南",
    "沈阳",
    "大连",
    "宁波",
    "无锡",
)
# Matches one clause (bounded by the listed punctuation) that contains an
# application keyword. Not referenced by the code visible in this module.
APPLICATION_SPLIT_PATTERN = re.compile(r"(?:^|[,。;;])[^,。;;]*?(?:申请|出差申请|差旅申请)[^,。;;]*")
# Matches a reimbursement trigger phrase; group 1 captures the text after
# it, up to the next listed punctuation mark or newline.
REIMBURSEMENT_PATTERN = re.compile(r"(?:我要报销|还需要报销|需要报销|报销)([^,。;;?!\n]+)")
# "<month>月<day>" with an optional 日/号 suffix; the year is not captured.
MONTH_DAY_PATTERN = re.compile(r"(?P<month>\d{1,2})\s*月\s*(?P<day>\d{1,2})\s*(?:日|号)?")
# Full date: four-digit year, then month and day separated by "-", "/" or
# 年/月, with an optional trailing 日.
ISO_DATE_PATTERN = re.compile(r"(?P<year>\d{4})[-/年](?P<month>\d{1,2})[-/月](?P<day>\d{1,2})(?:日)?")
@dataclass(frozen=True)
class PlannedTaskDraft:
    """Immutable draft of one task cut out of the user message by the rule-based splitter."""

    # "expense_application" or "reimbursement" (the values used by _extract_task_drafts).
    task_type: str
    # The slice of the user message this task was derived from.
    segment: str
    # 1-based position of the draft within the plan; used in the task id.
    index: int
class StewardPlannerService:
    """First-version steward planning service: it only produces plans and performs no persistence-type actions."""

    def __init__(self, intent_agent: StewardIntentAgent | None = None) -> None:
        # Optional model-backed intent agent; when None, build_plan goes
        # straight to the rule-based plan.
        self.intent_agent = intent_agent
def build_plan(self, request: StewardPlanRequest) -> StewardPlanResponse:
    """Build a steward plan for the request, preferring the model and falling back to rules.

    Args:
        request: The plan request (message, attachments, client time, context).

    Returns:
        The model-generated plan when the intent agent yields a usable
        one; otherwise the rule-based fallback plan, which carries the
        fallback reason and any collected model call traces.

    Raises:
        ValueError: If the cleaned message is empty.
    """
    message = self._clean_text(request.message)
    if not message:
        raise ValueError("小财管家需要一段任务描述。")
    base_date = self._resolve_base_date(request.client_now_iso, request.context_json)
    model_call_traces: list[dict[str, Any]] = []
    fallback_reason = ""
    # NOTE(review): block nesting below is reconstructed from a flattened
    # source; the two "no usable plan" lines are assumed to sit at try level,
    # covering both a None result and a None plan - confirm against the file.
    if self.intent_agent is not None:
        try:
            intent_result = self.intent_agent.detect(
                request,
                base_date=base_date,
                canonical_fields=list(BUSINESS_CANONICAL_FIELD_ORDER),
            )
            if intent_result is not None:
                model_call_traces = intent_result.model_call_traces
                llm_plan = StewardModelPlanBuilder(self).build(
                    intent_result,
                    request=request,
                    base_date=base_date,
                )
                if llm_plan is not None:
                    return llm_plan
            # Prefer the agent's own trace record when it exposes one.
            model_call_traces = getattr(self.intent_agent, "last_call_traces", []) or model_call_traces
            fallback_reason = "主模型未返回可用的 function calling 计划,已切换到规则兜底。"
        except Exception as exc:
            # Deliberate best-effort boundary: any model failure degrades
            # to the rule-based plan; the error text is put into the
            # user-facing fallback reason.
            model_call_traces = getattr(self.intent_agent, "last_call_traces", []) or model_call_traces
            fallback_reason = f"主模型 function calling 调用失败,已切换到规则兜底:{exc}"
    return self._build_rule_fallback_plan(
        request,
        base_date=base_date,
        model_call_traces=model_call_traces,
        fallback_reason=fallback_reason,
    )
def _build_rule_fallback_plan(
    self,
    request: StewardPlanRequest,
    *,
    base_date: date,
    model_call_traces: list[dict[str, Any]] | None = None,
    fallback_reason: str = "",
) -> StewardPlanResponse:
    """Build a plan using only the rule-based extraction.

    Args:
        request: The plan request.
        base_date: Reference date for relative-date resolution.
        model_call_traces: Traces from a preceding model attempt, if any.
        fallback_reason: When non-empty, shown as the first thinking event.

    Returns:
        A plan whose ``planning_source`` is ``"rule_fallback"``.
    """
    text = self._clean_text(request.message)
    tasks = [self._build_task(draft, base_date, request) for draft in self._extract_task_drafts(text)]
    if not tasks:
        # Nothing could be split out: keep the whole message as one task.
        tasks = [self._build_fallback_task(text, base_date, request)]

    groups = self._build_attachment_groups(request.attachments, tasks)
    confirmations = self._build_confirmation_actions(tasks, groups)
    events = self._build_thinking_events(tasks, groups, request.attachments)
    if fallback_reason:
        notice = StewardThinkingEvent(
            event_id="intent_agent_rule_fallback",
            stage="rule_fallback",
            title="意图识别智能体进入兜底模式",
            content=fallback_reason,
        )
        events = [notice, *events]

    status = "needs_confirmation" if confirmations else "ready_to_delegate"
    return StewardPlanResponse(
        plan_id=f"steward_plan_{uuid.uuid4().hex[:12]}",
        plan_status=status,
        planning_source="rule_fallback",
        summary=self._build_summary(tasks, groups),
        thinking_events=events,
        tasks=tasks,
        attachment_groups=groups,
        confirmation_groups=confirmations,
        model_call_traces=model_call_traces or [],
    )
def _extract_task_drafts(self, message: str) -> list[PlannedTaskDraft]:
    """Split a message into at most one application draft plus reimbursement drafts.

    The text before the first reimbursement phrase (or the whole message
    if there is none) becomes an application draft when it looks like
    one. Every ``REIMBURSEMENT_PATTERN`` match then yields one
    reimbursement draft. Drafts are numbered from 1 in that order.
    """
    cut = self._find_first_reimbursement_index(message)
    application_text = message if cut < 0 else message[:cut]

    pending: list[tuple[str, str]] = []
    if self._looks_like_application(application_text):
        pending.append(("expense_application", application_text))
    pending.extend(
        ("reimbursement", f"报销{hit.group(1)}")
        for hit in REIMBURSEMENT_PATTERN.finditer(message)
    )

    return [
        PlannedTaskDraft(
            task_type=kind,
            segment=text.strip(",。;; "),
            index=position,
        )
        for position, (kind, text) in enumerate(pending, start=1)
    ]
@staticmethod
def _find_first_reimbursement_index(message: str) -> int:
candidates = [message.find(item) for item in ("我要报销", "还需要报销", "需要报销", "报销")]
positives = [item for item in candidates if item >= 0]
return min(positives) if positives else -1
@staticmethod
def _looks_like_application(text: str) -> bool:
compact = re.sub(r"\s+", "", text)
return bool(compact) and "申请" in compact and bool(re.search(r"出差|差旅|费用|交通|住宿|采购|会务|会议", compact))
def _build_task(
    self,
    draft: PlannedTaskDraft,
    base_date: date,
    request: StewardPlanRequest,
) -> StewardTask:
    """Turn a rule-based draft into a task awaiting confirmation.

    Args:
        draft: Task type, source text segment and 1-based index.
        base_date: Reference date for relative-date resolution.
        request: The plan request (context values, attachments).
    """
    is_application = draft.task_type == "expense_application"
    fields = self._extract_ontology_fields(draft.segment, draft.task_type, base_date, request)
    missing = self._resolve_missing_fields(draft.task_type, fields)

    if is_application:
        id_tag, agent, title_prefix = "app", "application_assistant", "费用申请"
    else:
        id_tag, agent, title_prefix = "reim", "reimbursement_assistant", "费用报销"

    title = self._build_task_title(title_prefix, fields, draft.index)
    return StewardTask(
        task_id=f"task_{id_tag}_{draft.index:03d}",
        task_type=draft.task_type,  # type: ignore[arg-type]
        assigned_agent=agent,  # type: ignore[arg-type]
        title=title,
        summary=self._build_task_summary(draft.segment, fields),
        status="needs_confirmation",
        confidence=self._resolve_task_confidence(draft.segment, fields, draft.task_type),
        ontology_fields=fields,
        missing_fields=missing,
        confirmation_required=True,
    )
def _build_fallback_task(
    self,
    message: str,
    base_date: date,
    request: StewardPlanRequest,
) -> StewardTask:
    """Build a single low-confidence task from the whole message.

    The task is a reimbursement when the message mentions "报销" or any
    attachment was uploaded, otherwise an application. Confidence is
    capped at 0.58.
    """
    if "报销" in message or request.attachments:
        kind = "reimbursement"
    else:
        kind = "expense_application"
    whole_message = PlannedTaskDraft(task_type=kind, segment=message, index=1)
    built = self._build_task(whole_message, base_date, request)
    capped = min(built.confidence, 0.58)
    return built.model_copy(update={"confidence": capped})
def _extract_ontology_fields(
    self,
    segment: str,
    task_type: str,
    base_date: date,
    request: StewardPlanRequest,
) -> dict[str, str]:
    """Derive canonical ontology fields for one text segment.

    Context form values (``request.context_json["review_form_values"]``)
    take precedence; rule-based extraction only fills fields that are
    still empty. ``attachments`` is always overwritten when the request
    carries attachments.

    Returns:
        Canonical field name -> non-empty value.
    """
    context_values = normalize_ontology_form_values(request.context_json.get("review_form_values"))
    collected: dict[str, str] = {}
    for name, raw in context_values.items():
        if name in BUSINESS_CANONICAL_FIELDS and str(raw or "").strip():
            collected[name] = raw

    extracted = (
        ("expense_type", self._infer_expense_type(segment, task_type)),
        ("time_range", self._extract_time_range(segment, base_date)),
        ("location", self._extract_location(segment)),
        ("reason", self._extract_reason(segment, task_type)),
        ("transport_mode", self._extract_transport_mode(segment)),
    )
    for name, candidate in extracted:
        if candidate and not collected.get(name):
            collected[name] = candidate

    if request.attachments:
        # NOTE(review): names are joined with an empty separator - confirm
        # that no delimiter is intended here.
        collected["attachments"] = "".join(upload.name for upload in request.attachments if upload.name)

    return {
        name: value
        for name, value in collected.items()
        if name in BUSINESS_CANONICAL_FIELDS and value
    }
@staticmethod
def _infer_expense_type(segment: str, task_type: str) -> str:
compact = re.sub(r"\s+", "", segment)
if re.search(r"招待|接待|餐饮|宴请|客户吃饭|业务餐", compact):
return "entertainment"
if re.search(r"出差|差旅|住宿|酒店|机票|航班|高铁|火车", compact):
return "travel"
if re.search(r"交通|出租车|的士|网约车|打车|地铁|公交", compact):
return "transport" if task_type == "reimbursement" else "travel"
return "travel" if task_type == "expense_application" else "other"
def _extract_time_range(self, segment: str, base_date: date) -> str:
compact = re.sub(r"\s+", "", segment)
if "昨天" in compact:
return (base_date - timedelta(days=1)).isoformat()
if "前天" in compact:
return (base_date - timedelta(days=2)).isoformat()
if "明天" in compact:
return (base_date + timedelta(days=1)).isoformat()
if "后天" in compact:
return (base_date + timedelta(days=2)).isoformat()
iso_match = ISO_DATE_PATTERN.search(compact)
if iso_match:
return self._safe_date(
int(iso_match.group("year")),
int(iso_match.group("month")),
int(iso_match.group("day")),
)
month_day = MONTH_DAY_PATTERN.search(compact)
if month_day:
return self._safe_date(
base_date.year,
int(month_day.group("month")),
int(month_day.group("day")),
)
return ""
@staticmethod
def _safe_date(year: int, month: int, day: int) -> str:
try:
return date(year, month, day).isoformat()
except ValueError:
return ""
@staticmethod
def _extract_location(segment: str) -> str:
    """Return the first known city found in the segment, or ``""``.

    Whitespace is removed first. Each prefix is tried in turn as
    "<prefix><city>"; a final pass checks plain containment in
    ``CITY_NAMES`` order.
    """
    compact = re.sub(r"\s+", "", segment)
    # NOTE(review): three of the four prefixes are empty strings. With an
    # empty prefix the pattern matches the leftmost city anywhere in the
    # text, so the "前往" prefix and the containment loop below can never
    # change the result. The empty literals look like lost characters -
    # confirm the intended prefixes against the original file.
    for prefix in ("", "", "", "前往"):
        match = re.search(fr"{prefix}({'|'.join(CITY_NAMES)})", compact)
        if match:
            return match.group(1)
    for city in CITY_NAMES:
        if city in compact:
            return city
    return ""
@staticmethod
def _extract_reason(segment: str, task_type: str) -> str:
cleaned = re.sub(r"\s+", "", segment).strip(",。;; ")
if task_type == "expense_application":
match = re.search(r"(辅助|支持|协助|支撑|参加|拜访|调研|实施|部署|审核).+", cleaned)
if match:
return StewardPlannerService._strip_trailing_connectors(match.group(0))
reason = re.sub(r"^.*?(?:出差|差旅)", "", cleaned).strip(",。;;的费用")
return StewardPlannerService._strip_trailing_connectors(reason) or cleaned
cleaned = re.sub(r"^报销", "", cleaned)
cleaned = re.sub(r"^(?:昨天|前天|明天|后天|\d{1,2}月\d{1,2}(?:日|号)?)的?", "", cleaned)
return cleaned.strip(",。;; ") or segment.strip()
@staticmethod
def _strip_trailing_connectors(value: str) -> str:
cleaned = str(value or "").strip(",。;; ")
return re.sub(r"(?:并且|而且|同时|另外|还需要|需要)$", "", cleaned).strip(",。;; ")
@staticmethod
def _extract_transport_mode(segment: str) -> str:
compact = re.sub(r"\s+", "", segment)
if re.search(r"高铁|动车|火车", compact):
return "train"
if re.search(r"飞机|机票|航班", compact):
return "flight"
if re.search(r"出租车|的士|网约车|打车", compact):
return "taxi"
if "交通" in compact:
return "other"
return ""
@staticmethod
def _resolve_missing_fields(task_type: str, fields: dict[str, str]) -> list[str]:
required = ["expense_type", "time_range", "reason"]
if task_type == "expense_application":
required.append("location")
return [key for key in required if not str(fields.get(key) or "").strip()]
@staticmethod
def _resolve_task_confidence(segment: str, fields: dict[str, str], task_type: str) -> float:
compact = re.sub(r"\s+", "", segment)
intent_score = 1.0 if ("申请" in compact if task_type == "expense_application" else "报销" in compact) else 0.45
time_score = 1.0 if fields.get("time_range") else 0.0
location_score = 1.0 if fields.get("location") else 0.2
scene_score = 1.0 if fields.get("expense_type") and fields["expense_type"] != "other" else 0.35
confidence = min(1.0, 0.35 * intent_score + 0.25 * time_score + 0.2 * location_score + 0.2 * scene_score)
return round(max(0.45, confidence), 2)
def _build_attachment_groups(
    self,
    attachments: list[StewardAttachmentInput],
    tasks: list[StewardTask],
) -> list[StewardAttachmentGroup]:
    """Group uploaded attachments with the rule-based classifier.

    Named attachments classified as travel or transport form one travel
    group (the remaining names are listed as excluded). If there are
    none, all remaining named attachments form one "other" group.

    Returns:
        Zero or one group.
    """
    if not attachments:
        return []

    travel_scenes = {"travel", "transport"}
    related: list[str] = []
    others: list[str] = []
    for upload in attachments:
        if not upload.name:
            continue
        bucket = related if self._classify_attachment(upload) in travel_scenes else others
        bucket.append(upload.name)

    if related:
        target = self._resolve_attachment_target_task(tasks)
        # More matching files raise the confidence, capped at +0.18.
        score = 0.72 + min(0.18, len(related) * 0.04)
        return [
            StewardAttachmentGroup(
                group_id="ag_travel_001",
                target_task_id=target.task_id if target else None,
                scene="travel",
                scene_label="差旅相关费用",
                attachment_names=related,
                excluded_attachment_names=others,
                confidence=round(score, 2),
                rationale="附件名称或 OCR 摘要中包含差旅、交通、住宿、火车、机票等线索。",
                confirmation_required=True,
            )
        ]

    if others:
        return [
            StewardAttachmentGroup(
                group_id="ag_other_001",
                target_task_id=None,
                scene="other",
                scene_label="待人工确认费用",
                attachment_names=others,
                excluded_attachment_names=[],
                confidence=0.5,
                rationale="当前附件缺少可稳定归属到申请或报销任务的差旅线索。",
                confirmation_required=True,
            )
        ]
    return []
@staticmethod
def _resolve_attachment_target_task(tasks: list[StewardTask]) -> StewardTask | None:
reimbursement_tasks = [item for item in tasks if item.task_type == "reimbursement"]
for task in reimbursement_tasks:
if task.ontology_fields.get("expense_type") == "travel":
return task
return reimbursement_tasks[0] if reimbursement_tasks else None
@staticmethod
def _classify_attachment(attachment: StewardAttachmentInput) -> str:
text = " ".join(
[
attachment.name,
attachment.media_type,
attachment.ocr_summary,
" ".join(f"{key}:{value}" for key, value in attachment.ocr_fields.items()),
]
)
compact = re.sub(r"\s+", "", text).lower()
if re.search(r"招待|接待|餐饮|宴请|客户|meal|entertainment", compact):
return "entertainment"
if re.search(r"酒店|住宿|差旅|出差|高铁|火车|动车|机票|航班|train|flight|hotel|travel", compact):
return "travel"
if re.search(r"出租车|的士|网约车|打车|交通|taxi|transport", compact):
return "transport"
return "other"
def _build_confirmation_actions(
    self,
    tasks: list[StewardTask],
    attachment_groups: list[StewardAttachmentGroup],
) -> list[StewardConfirmationAction]:
    """Create one confirmation action per task, then one per attachment group.

    Returns:
        Task confirmations first (in task order), followed by
        attachment-group confirmations (in group order).
    """
    task_actions: list[StewardConfirmationAction] = []
    for task in tasks:
        if task.task_type == "expense_application":
            action_type, label = "confirm_create_application", "确认创建申请单"
        else:
            action_type, label = "confirm_create_reimbursement_draft", "确认创建报销草稿"
        agent_name = self._agent_label(task.assigned_agent)
        task_actions.append(
            StewardConfirmationAction(
                confirmation_id=f"confirm_{task.task_id}",
                action_type=action_type,
                label=label,
                description=f"确认后把“{task.title}”交给{agent_name}继续核对。",
                target_task_id=task.task_id,
                payload={
                    "task_id": task.task_id,
                    "task_type": task.task_type,
                    "assigned_agent": task.assigned_agent,
                    "ontology_fields": task.ontology_fields,
                },
            )
        )

    group_actions = [
        StewardConfirmationAction(
            confirmation_id=f"confirm_{group.group_id}",
            action_type="confirm_attachment_group",
            label="确认附件归集",
            description=f"确认后将 {len(group.attachment_names)} 份附件按“{group.scene_label}”归集。",
            target_task_id=group.target_task_id,
            attachment_group_id=group.group_id,
            payload={
                "attachment_group_id": group.group_id,
                "target_task_id": group.target_task_id,
                "attachment_names": group.attachment_names,
                "excluded_attachment_names": group.excluded_attachment_names,
            },
        )
        for group in attachment_groups
    ]
    return task_actions + group_actions
@staticmethod
def _agent_label(assigned_agent: str) -> str:
return "申请助手" if assigned_agent == "application_assistant" else "报销助手"
def _build_thinking_events(
    self,
    tasks: list[StewardTask],
    attachment_groups: list[StewardAttachmentGroup],
    attachments: list[StewardAttachmentInput],
) -> list[StewardThinkingEvent]:
    """Build the rule-based, user-facing process events for a plan.

    Order: entry, task split, ontology mapping, attachment correlation
    (only when attachments exist), delegation gate.
    """
    applications = sum(1 for task in tasks if task.task_type == "expense_application")
    reimbursements = sum(1 for task in tasks if task.task_type == "reimbursement")
    intent_text = self._summarize_task_intents(tasks)
    ontology_text = self._summarize_ontology_coverage(tasks)
    delegation_text = self._summarize_delegation_targets(tasks)

    entry_event = StewardThinkingEvent(
        event_id="intent_agent_entry",
        stage="intent_agent",
        title="意图识别智能体接管",
        content=(
            f"检测到复合财务话术,当前不是单一助手会话;"
            f"已进入小财管家编排模式,候选任务共 {len(tasks)} 个。"
        ),
    )
    split_event = StewardThinkingEvent(
        event_id="intent_task_split",
        stage="task_split",
        title=f"拆分申请 {applications} 个、报销 {reimbursements}",
        content=intent_text,
    )
    mapping_event = StewardThinkingEvent(
        event_id="intent_ontology_mapping",
        stage="ontology_mapping",
        title="映射业务本体字段",
        content=ontology_text,
    )
    events = [entry_event, split_event, mapping_event]

    if attachments:
        events.append(
            StewardThinkingEvent(
                event_id="intent_attachment_correlation",
                stage="attachment_correlation",
                title="关联附件与任务线索",
                content=self._summarize_attachment_correlation(attachment_groups, len(attachments)),
            )
        )

    events.append(
        StewardThinkingEvent(
            event_id="intent_delegation_gate",
            stage="delegation_gate",
            title="生成确认点并准备分派",
            content=f"{delegation_text} 创建单据、生成草稿、绑定附件和提交审批都会等待用户确认。",
        )
    )
    return events
@staticmethod
def _summarize_task_intents(tasks: list[StewardTask]) -> str:
if not tasks:
return "当前输入尚未形成稳定任务,先保留为待确认财务事项。"
parts = []
for task in tasks:
task_label = "申请" if task.task_type == "expense_application" else "报销"
fields = task.ontology_fields
anchors = []
if fields.get("time_range"):
anchors.append(fields["time_range"])
if fields.get("location"):
anchors.append(fields["location"])
if fields.get("expense_type"):
anchors.append(fields["expense_type"])
anchor_text = "".join(anchors) if anchors else "待补充关键字段"
parts.append(f"{task_label}{task.title}{anchor_text}")
return "".join(parts)
@staticmethod
def _summarize_ontology_coverage(tasks: list[StewardTask]) -> str:
canonical_keys = []
missing_keys = []
for task in tasks:
canonical_keys.extend(task.ontology_fields.keys())
missing_keys.extend(task.missing_fields)
unique_keys = sorted({item for item in canonical_keys if item})
unique_missing = sorted({item for item in missing_keys if item})
mapped = "".join(unique_keys) if unique_keys else "暂无稳定字段"
missing = ";缺失字段:" + "".join(unique_missing) if unique_missing else ""
return f"已使用 canonical ontology fields{mapped}{missing}。兼容字段只作为输入别名,不直接进入业务逻辑。"
@staticmethod
def _summarize_attachment_correlation(
attachment_groups: list[StewardAttachmentGroup],
total_attachment_count: int,
) -> str:
grouped_names = []
excluded_names = []
for group in attachment_groups:
grouped_names.extend(group.attachment_names)
excluded_names.extend(group.excluded_attachment_names)
grouped_text = "".join(grouped_names) if grouped_names else "暂无可稳定归集附件"
excluded_text = ";排除或单独确认:" + "".join(excluded_names) if excluded_names else ""
return f"已核对 {total_attachment_count} 份附件,建议归集:{grouped_text}{excluded_text}"
@staticmethod
def _summarize_delegation_targets(tasks: list[StewardTask]) -> str:
application_count = sum(1 for item in tasks if item.assigned_agent == "application_assistant")
reimbursement_count = sum(1 for item in tasks if item.assigned_agent == "reimbursement_assistant")
parts = []
if application_count:
parts.append(f"{application_count} 个申请任务交给申请助手")
if reimbursement_count:
parts.append(f"{reimbursement_count} 个报销任务交给报销助手")
return "".join(parts) + "" if parts else "尚无可分派任务。"
@staticmethod
def _build_summary(tasks: list[StewardTask], attachment_groups: list[StewardAttachmentGroup]) -> str:
parts = [f"我识别到 {len(tasks)} 个待处理任务"]
if attachment_groups:
grouped = sum(len(item.attachment_names) for item in attachment_groups)
parts.append(f"并形成 {grouped} 份附件的归集建议")
parts.append(",请确认后我再分派给对应助手执行。")
return "".join(parts)
@staticmethod
def _build_task_title(prefix: str, fields: dict[str, str], index: int) -> str:
location = fields.get("location", "")
time_range = fields.get("time_range", "")
expense_type = fields.get("expense_type", "")
subject = location or {"travel": "差旅", "transport": "交通", "entertainment": "招待"}.get(expense_type, "")
if subject and time_range:
return f"{prefix} {time_range} {subject}"
if subject:
return f"{prefix} {subject}"
return f"{prefix} {index}"
@staticmethod
def _build_task_summary(segment: str, fields: dict[str, str]) -> str:
field_parts = []
for key, label in (
("time_range", "时间"),
("location", "地点"),
("expense_type", "费用类型"),
("reason", "事由"),
("transport_mode", "交通方式"),
):
value = fields.get(key)
if value:
field_parts.append(f"{label}{value}")
return "".join(field_parts) or segment
@staticmethod
def _resolve_base_date(client_now_iso: str | None, context_json: dict[str, Any]) -> date:
raw_value = client_now_iso or str(context_json.get("client_now_iso") or "").strip()
if raw_value:
try:
parsed = datetime.fromisoformat(raw_value.replace("Z", "+00:00"))
return parsed.date()
except ValueError:
pass
return datetime.now(UTC).date()
@staticmethod
def _clean_text(value: Any) -> str:
return re.sub(r"\s+", " ", str(value or "")).strip()