feat: 报销预审会话状态管理与工作台交互增强
- 新增差旅报销会话状态管理与对话模型重构 - 增强风险观测服务与运行时聊天上下文作用域 - 优化工作台图标资源、助理意图识别与摘要工具 - 完善报销创建视图样式与差旅详情页标准调整交互 - 补充风险观测、运行时聊天与报销端点测试覆盖
This commit is contained in:
78
server/src/app/api/v1/endpoints/steward.py
Normal file
78
server/src/app/api/v1/endpoints/steward.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.deps import get_db
|
||||
from app.schemas.common import ErrorResponse
|
||||
from app.schemas.steward import StewardPlanRequest, StewardPlanResponse
|
||||
from app.services.runtime_chat import RuntimeChatService
|
||||
from app.services.steward_intent_agent import StewardIntentAgent
|
||||
from app.services.steward_planner import StewardPlannerService
|
||||
|
||||
router = APIRouter(prefix="/steward")
|
||||
DbSession = Annotated[Session, Depends(get_db)]
|
||||
|
||||
|
||||
@router.post(
|
||||
"/plans",
|
||||
response_model=StewardPlanResponse,
|
||||
summary="生成小财管家任务计划",
|
||||
description="把首页自然语言和附件元信息拆解为可确认、可追踪、可分派的财务任务计划。",
|
||||
responses={
|
||||
status.HTTP_400_BAD_REQUEST: {
|
||||
"model": ErrorResponse,
|
||||
"description": "请求缺少任务描述,无法生成小财管家计划。",
|
||||
}
|
||||
},
|
||||
)
|
||||
def create_steward_plan(payload: StewardPlanRequest, db: DbSession) -> StewardPlanResponse:
|
||||
try:
|
||||
return _build_steward_planner(db).build_plan(payload)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/plans/stream",
|
||||
summary="流式生成小财管家任务计划",
|
||||
description="以 NDJSON 逐条返回小财管家的过程摘要事件,最后返回完整任务计划。",
|
||||
)
|
||||
async def stream_steward_plan(payload: StewardPlanRequest, db: DbSession) -> StreamingResponse:
|
||||
return StreamingResponse(
|
||||
_iter_steward_plan_events(payload, _build_steward_planner(db)),
|
||||
media_type="application/x-ndjson",
|
||||
)
|
||||
|
||||
|
||||
async def _iter_steward_plan_events(
|
||||
payload: StewardPlanRequest,
|
||||
planner: StewardPlannerService,
|
||||
) -> AsyncIterator[str]:
|
||||
try:
|
||||
plan = planner.build_plan(payload)
|
||||
except ValueError as exc:
|
||||
yield _encode_stream_event("error", {"message": str(exc)})
|
||||
return
|
||||
|
||||
for event in plan.thinking_events:
|
||||
yield _encode_stream_event("thinking", event.model_dump(mode="json"))
|
||||
await asyncio.sleep(0.18)
|
||||
|
||||
yield _encode_stream_event("plan", plan.model_dump(mode="json"))
|
||||
|
||||
|
||||
def _encode_stream_event(event: str, data: dict[str, Any]) -> str:
|
||||
return json.dumps({"event": event, "data": data}, ensure_ascii=False) + "\n"
|
||||
|
||||
|
||||
def _build_steward_planner(db: Session) -> StewardPlannerService:
|
||||
return StewardPlannerService(
|
||||
intent_agent=StewardIntentAgent(RuntimeChatService(db)),
|
||||
)
|
||||
@@ -22,6 +22,7 @@ from app.api.v1.endpoints.receipt_folder import router as receipt_folder_router
|
||||
from app.api.v1.endpoints.reimbursements import router as reimbursements_router
|
||||
from app.api.v1.endpoints.risk_observations import router as risk_observations_router
|
||||
from app.api.v1.endpoints.settings import router as settings_router
|
||||
from app.api.v1.endpoints.steward import router as steward_router
|
||||
from app.api.v1.endpoints.system_logs import router as system_logs_router
|
||||
|
||||
router = APIRouter()
|
||||
@@ -47,4 +48,5 @@ router.include_router(employee_profiles_router, tags=["employee-profiles"])
|
||||
router.include_router(reimbursements_router, prefix="/reimbursements", tags=["reimbursements"])
|
||||
router.include_router(risk_observations_router, tags=["risk-observations"])
|
||||
router.include_router(settings_router, tags=["settings"])
|
||||
router.include_router(steward_router, tags=["steward"])
|
||||
router.include_router(system_logs_router, tags=["system-logs"])
|
||||
|
||||
@@ -126,6 +126,7 @@ class ExpenseClaimStandardAdjustmentRisk(BaseModel):
|
||||
item_id: str | None = Field(default=None, max_length=120)
|
||||
title: str | None = Field(default=None, max_length=120)
|
||||
risk: str | None = Field(default=None, max_length=500)
|
||||
application_days: int | None = Field(default=None, ge=1, le=365)
|
||||
original_amount: Decimal | None = None
|
||||
reimbursable_amount: Decimal | None = None
|
||||
|
||||
|
||||
90
server/src/app/schemas/steward.py
Normal file
90
server/src/app/schemas/steward.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
StewardTaskType = Literal["expense_application", "reimbursement"]
|
||||
StewardAssignedAgent = Literal["application_assistant", "reimbursement_assistant"]
|
||||
StewardPlanningSource = Literal["llm_function_call", "rule_fallback"]
|
||||
StewardTaskStatus = Literal[
|
||||
"planned",
|
||||
"needs_confirmation",
|
||||
"ready_to_delegate",
|
||||
"delegated",
|
||||
"completed",
|
||||
"blocked",
|
||||
]
|
||||
StewardConfirmationStatus = Literal["pending", "confirmed", "rejected"]
|
||||
|
||||
|
||||
class StewardAttachmentInput(BaseModel):
|
||||
name: str = Field(description="附件原始文件名。")
|
||||
media_type: str = Field(default="", description="附件 MIME 类型。")
|
||||
ocr_summary: str = Field(default="", description="可选 OCR 摘要。")
|
||||
ocr_fields: dict[str, Any] = Field(default_factory=dict, description="可选 OCR 结构化字段。")
|
||||
|
||||
|
||||
class StewardPlanRequest(BaseModel):
|
||||
message: str = Field(description="用户在首页输入的自然语言任务。")
|
||||
user_id: str | None = Field(default=None, description="当前用户 ID。")
|
||||
client_now_iso: str | None = Field(default=None, description="客户端当前时间 ISO 字符串。")
|
||||
attachments: list[StewardAttachmentInput] = Field(default_factory=list, description="随本次输入上传的附件。")
|
||||
context_json: dict[str, Any] = Field(default_factory=dict, description="调用方上下文。")
|
||||
|
||||
|
||||
class StewardThinkingEvent(BaseModel):
|
||||
event_id: str = Field(description="过程摘要事件 ID。")
|
||||
stage: str = Field(description="阶段编码。")
|
||||
title: str = Field(description="面向用户展示的阶段标题。")
|
||||
content: str = Field(description="面向用户展示的过程摘要。")
|
||||
status: str = Field(default="completed", description="事件状态。")
|
||||
|
||||
|
||||
class StewardTask(BaseModel):
|
||||
task_id: str = Field(description="小财管家任务 ID。")
|
||||
task_type: StewardTaskType = Field(description="任务类型。")
|
||||
assigned_agent: StewardAssignedAgent = Field(description="建议分派的下游助手。")
|
||||
title: str = Field(description="任务标题。")
|
||||
summary: str = Field(description="任务摘要。")
|
||||
status: StewardTaskStatus = Field(default="needs_confirmation", description="任务状态。")
|
||||
confidence: float = Field(default=0.0, ge=0.0, le=1.0, description="识别置信度。")
|
||||
ontology_fields: dict[str, str] = Field(default_factory=dict, description="归一化后的业务本体字段。")
|
||||
missing_fields: list[str] = Field(default_factory=list, description="仍缺失的本体字段。")
|
||||
confirmation_required: bool = Field(default=True, description="执行前是否需要用户确认。")
|
||||
|
||||
|
||||
class StewardAttachmentGroup(BaseModel):
|
||||
group_id: str = Field(description="附件归集组 ID。")
|
||||
target_task_id: str | None = Field(default=None, description="建议归属的任务 ID。")
|
||||
scene: str = Field(description="归集场景编码。")
|
||||
scene_label: str = Field(description="归集场景展示名。")
|
||||
attachment_names: list[str] = Field(default_factory=list, description="建议纳入的附件名称。")
|
||||
excluded_attachment_names: list[str] = Field(default_factory=list, description="建议排除或单独处理的附件名称。")
|
||||
confidence: float = Field(default=0.0, ge=0.0, le=1.0, description="归集置信度。")
|
||||
rationale: str = Field(default="", description="归集依据。")
|
||||
confirmation_required: bool = Field(default=True, description="归集前是否需要用户确认。")
|
||||
|
||||
|
||||
class StewardConfirmationAction(BaseModel):
|
||||
confirmation_id: str = Field(description="确认动作 ID。")
|
||||
action_type: str = Field(description="确认动作类型。")
|
||||
label: str = Field(description="确认按钮文案。")
|
||||
description: str = Field(default="", description="确认动作说明。")
|
||||
target_task_id: str | None = Field(default=None, description="关联任务 ID。")
|
||||
attachment_group_id: str | None = Field(default=None, description="关联附件归集组 ID。")
|
||||
status: StewardConfirmationStatus = Field(default="pending", description="确认状态。")
|
||||
payload: dict[str, Any] = Field(default_factory=dict, description="确认后继续执行所需载荷。")
|
||||
|
||||
|
||||
class StewardPlanResponse(BaseModel):
|
||||
plan_id: str = Field(description="小财管家计划 ID。")
|
||||
plan_status: str = Field(default="needs_confirmation", description="计划状态。")
|
||||
planning_source: StewardPlanningSource = Field(default="rule_fallback", description="计划生成来源。")
|
||||
summary: str = Field(description="计划摘要。")
|
||||
thinking_events: list[StewardThinkingEvent] = Field(default_factory=list, description="过程摘要事件。")
|
||||
tasks: list[StewardTask] = Field(default_factory=list, description="拆解后的任务。")
|
||||
attachment_groups: list[StewardAttachmentGroup] = Field(default_factory=list, description="附件归集建议。")
|
||||
confirmation_groups: list[StewardConfirmationAction] = Field(default_factory=list, description="等待用户确认的动作。")
|
||||
model_call_traces: list[dict[str, Any]] = Field(default_factory=list, description="模型工具调用轨迹。")
|
||||
@@ -254,6 +254,7 @@ class ExpenseClaimAttachmentOperationsMixin:
|
||||
)
|
||||
|
||||
self._sync_claim_from_items(claim)
|
||||
self._refresh_claim_pre_review_flags(claim, is_application_claim=False)
|
||||
self.db.commit()
|
||||
self.db.refresh(claim)
|
||||
|
||||
@@ -356,6 +357,7 @@ class ExpenseClaimAttachmentOperationsMixin:
|
||||
item.invoice_id = None
|
||||
|
||||
self._sync_claim_from_items(claim)
|
||||
self._refresh_claim_pre_review_flags(claim, is_application_claim=False)
|
||||
self.db.commit()
|
||||
self.db.refresh(claim)
|
||||
|
||||
|
||||
@@ -139,3 +139,61 @@ class ExpenseClaimPreReviewMixin:
|
||||
)
|
||||
]
|
||||
return [*preserved_flags, next_flag]
|
||||
|
||||
def _refresh_claim_pre_review_flags(
|
||||
self,
|
||||
claim: ExpenseClaim,
|
||||
*,
|
||||
is_application_claim: bool | None = None,
|
||||
reviewed_at: datetime | None = None,
|
||||
) -> bool:
|
||||
if claim is None:
|
||||
return False
|
||||
|
||||
if is_application_claim is None:
|
||||
is_application_claim = self._is_expense_application_claim(claim)
|
||||
reviewed_at = reviewed_at or datetime.now(UTC)
|
||||
|
||||
if is_application_claim:
|
||||
preserved_flags = [
|
||||
flag
|
||||
for flag in list(claim.risk_flags_json or [])
|
||||
if not (
|
||||
isinstance(flag, dict)
|
||||
and str(flag.get("source") or "").strip() == "submission_review"
|
||||
and str(flag.get("hit_source") or "").strip() == "rule_center"
|
||||
)
|
||||
]
|
||||
application_review = self.evaluate_platform_risk_rules(
|
||||
claim,
|
||||
business_stage="expense_application",
|
||||
)
|
||||
review_flags = [*preserved_flags, *list(application_review.get("flags") or [])]
|
||||
else:
|
||||
review_result = self._run_ai_submission_review(claim)
|
||||
review_flags = list(review_result.get("risk_flags") or [])
|
||||
|
||||
blocking_count = self._count_ai_pre_review_blocking_risks(review_flags)
|
||||
claim.risk_flags_json = self._replace_ai_pre_review_flag(
|
||||
review_flags,
|
||||
self._build_ai_pre_review_flag(
|
||||
passed=blocking_count <= 0,
|
||||
blocking_count=blocking_count,
|
||||
reviewed_at=reviewed_at,
|
||||
business_stage=risk_business_stage_for_claim(
|
||||
is_application_claim=is_application_claim,
|
||||
),
|
||||
),
|
||||
)
|
||||
if not is_application_claim:
|
||||
claim.approval_stage = "\u5f85\u63d0\u4ea4"
|
||||
claim.submitted_at = None
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _has_ai_pre_review_flag(claim: ExpenseClaim) -> bool:
|
||||
return any(
|
||||
isinstance(flag, dict)
|
||||
and str(flag.get("source") or "").strip() == "ai_pre_review"
|
||||
for flag in list(claim.risk_flags_json or [])
|
||||
)
|
||||
|
||||
@@ -48,6 +48,7 @@ from app.services.expense_claim_attachment_analysis import ExpenseClaimAttachmen
|
||||
from app.services.expense_claim_attachment_document import ExpenseClaimAttachmentDocumentMixin
|
||||
from app.services.expense_claim_attachment_operations import ExpenseClaimAttachmentOperationsMixin
|
||||
from app.services.expense_claim_budget_flow import ExpenseClaimBudgetFlowMixin
|
||||
from app.services.expense_claim_workflow_constants import DIRECT_MANAGER_APPROVAL_STAGE
|
||||
from app.services.expense_claim_document_item_builder import ExpenseClaimDocumentItemBuilderMixin
|
||||
from app.services.expense_claim_document_parsing import ExpenseClaimDocumentParsingMixin
|
||||
from app.services.expense_claim_draft_flow import ExpenseClaimDraftFlowMixin
|
||||
@@ -278,6 +279,9 @@ class ExpenseClaimService(
|
||||
if payload.reason is not None:
|
||||
claim.reason = self._normalize_optional_text(payload.reason, allow_empty=True) or "待补充"
|
||||
|
||||
if not self._is_expense_application_claim(claim):
|
||||
self._refresh_claim_pre_review_flags(claim, is_application_claim=False)
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(claim)
|
||||
|
||||
@@ -306,6 +310,146 @@ class ExpenseClaimService(
|
||||
normalized = Decimal(value or Decimal("0.00")).quantize(Decimal("0.01"))
|
||||
return f"{normalized:.2f}"
|
||||
|
||||
@staticmethod
|
||||
def _normalize_standard_adjustment_days(value: Any) -> int | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, int):
|
||||
return value if 1 <= value <= 365 else None
|
||||
text = str(value or "").strip()
|
||||
if not text:
|
||||
return None
|
||||
match = re.search(r"\d{1,3}", text)
|
||||
if not match:
|
||||
return None
|
||||
days = int(match.group(0))
|
||||
return days if 1 <= days <= 365 else None
|
||||
|
||||
@staticmethod
|
||||
def _normalize_standard_adjustment_text(value: Any) -> str:
|
||||
text = str(value or "").strip()
|
||||
if not text or text in {"-", "N/A", "n/a"}:
|
||||
return ""
|
||||
if text in {"待补充", "未知", "暂无", "非必填"}:
|
||||
return ""
|
||||
return text
|
||||
|
||||
def _iter_standard_adjustment_application_details(self, claim: ExpenseClaim) -> list[dict[str, Any]]:
|
||||
details: list[dict[str, Any]] = []
|
||||
for flag in list(claim.risk_flags_json or []):
|
||||
if not isinstance(flag, dict):
|
||||
continue
|
||||
detail = flag.get("application_detail") or flag.get("applicationDetail")
|
||||
if isinstance(detail, dict):
|
||||
details.append(detail)
|
||||
related = flag.get("related_application") or flag.get("relatedApplication")
|
||||
if isinstance(related, dict):
|
||||
details.append(related)
|
||||
return details
|
||||
|
||||
def _resolve_standard_adjustment_days(
|
||||
self,
|
||||
claim: ExpenseClaim,
|
||||
item: ExpenseClaimItem,
|
||||
entry: Any,
|
||||
) -> int:
|
||||
direct_days = self._normalize_standard_adjustment_days(getattr(entry, "application_days", None))
|
||||
if direct_days is not None:
|
||||
return direct_days
|
||||
|
||||
for detail in self._iter_standard_adjustment_application_details(claim):
|
||||
for key in ("application_days", "applicationDays", "days"):
|
||||
detail_days = self._normalize_standard_adjustment_days(detail.get(key))
|
||||
if detail_days is not None:
|
||||
return detail_days
|
||||
|
||||
candidates = [
|
||||
getattr(entry, "risk", None),
|
||||
getattr(entry, "title", None),
|
||||
item.item_reason,
|
||||
claim.reason,
|
||||
]
|
||||
for text in candidates:
|
||||
match = re.search(r"(\d{1,3})\s*(?:天|晚|夜)", str(text or ""))
|
||||
if match:
|
||||
days = self._normalize_standard_adjustment_days(match.group(1))
|
||||
if days is not None:
|
||||
return days
|
||||
return 1
|
||||
|
||||
def _resolve_standard_adjustment_location(
|
||||
self,
|
||||
claim: ExpenseClaim,
|
||||
item: ExpenseClaimItem,
|
||||
) -> str:
|
||||
for value in (item.item_location, claim.location):
|
||||
text = self._normalize_standard_adjustment_text(value)
|
||||
if text:
|
||||
return text
|
||||
|
||||
for detail in self._iter_standard_adjustment_application_details(claim):
|
||||
for key in ("application_location", "applicationLocation", "location", "city"):
|
||||
text = self._normalize_standard_adjustment_text(detail.get(key))
|
||||
if text:
|
||||
return text
|
||||
return ""
|
||||
|
||||
def _resolve_policy_standard_reimbursable_amount(
|
||||
self,
|
||||
*,
|
||||
claim: ExpenseClaim,
|
||||
item: ExpenseClaimItem,
|
||||
entry: Any,
|
||||
current_user: CurrentUserContext,
|
||||
) -> Decimal | None:
|
||||
item_type = str(item.item_type or "").strip().lower()
|
||||
if item_type not in {"hotel", "hotel_ticket"}:
|
||||
return None
|
||||
|
||||
location = self._resolve_standard_adjustment_location(claim, item)
|
||||
grade = str(claim.employee_grade or current_user.grade or "").strip()
|
||||
if not location or not grade:
|
||||
return None
|
||||
|
||||
try:
|
||||
from app.services.travel_reimbursement_calculator import TravelReimbursementCalculatorService
|
||||
|
||||
result = TravelReimbursementCalculatorService(self.db).calculate(
|
||||
TravelReimbursementCalculatorRequest(
|
||||
days=self._resolve_standard_adjustment_days(claim, item, entry),
|
||||
location=location,
|
||||
grade=grade,
|
||||
),
|
||||
current_user,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return self._normalize_standard_adjustment_amount(result.hotel_amount)
|
||||
|
||||
def _resolve_standard_adjustment_reimbursable_amount(
|
||||
self,
|
||||
*,
|
||||
claim: ExpenseClaim,
|
||||
item: ExpenseClaimItem,
|
||||
entry: Any,
|
||||
original_amount: Decimal,
|
||||
current_user: CurrentUserContext,
|
||||
) -> Decimal:
|
||||
policy_amount = self._resolve_policy_standard_reimbursable_amount(
|
||||
claim=claim,
|
||||
item=item,
|
||||
entry=entry,
|
||||
current_user=current_user,
|
||||
)
|
||||
if policy_amount is not None:
|
||||
return min(max(policy_amount, Decimal("0.00")), original_amount)
|
||||
|
||||
entry_amount = self._normalize_standard_adjustment_amount(entry.reimbursable_amount)
|
||||
if entry_amount is not None:
|
||||
return min(max(entry_amount, Decimal("0.00")), original_amount)
|
||||
return original_amount
|
||||
|
||||
def accept_standard_adjustment(
|
||||
self,
|
||||
*,
|
||||
@@ -340,11 +484,13 @@ class ExpenseClaimService(
|
||||
self._normalize_standard_adjustment_amount(entry.original_amount)
|
||||
or Decimal(item.item_amount or Decimal("0.00")).quantize(Decimal("0.01"))
|
||||
)
|
||||
reimbursable_amount = (
|
||||
self._normalize_standard_adjustment_amount(entry.reimbursable_amount)
|
||||
or original_amount
|
||||
reimbursable_amount = self._resolve_standard_adjustment_reimbursable_amount(
|
||||
claim=claim,
|
||||
item=item,
|
||||
entry=entry,
|
||||
original_amount=original_amount,
|
||||
current_user=current_user,
|
||||
)
|
||||
reimbursable_amount = min(max(reimbursable_amount, Decimal("0.00")), original_amount)
|
||||
employee_absorbed_amount = (original_amount - reimbursable_amount).quantize(Decimal("0.01"))
|
||||
item_label = (
|
||||
str(item.item_reason or "").strip()
|
||||
@@ -456,6 +602,7 @@ class ExpenseClaimService(
|
||||
|
||||
self._refresh_item_attachment_analysis(item)
|
||||
self._sync_claim_from_items(claim)
|
||||
self._refresh_claim_pre_review_flags(claim, is_application_claim=False)
|
||||
self.db.commit()
|
||||
self.db.refresh(claim)
|
||||
|
||||
@@ -510,6 +657,7 @@ class ExpenseClaimService(
|
||||
self.db.add(item)
|
||||
|
||||
self._sync_claim_from_items(claim)
|
||||
self._refresh_claim_pre_review_flags(claim, is_application_claim=False)
|
||||
self.db.commit()
|
||||
self.db.refresh(claim)
|
||||
|
||||
@@ -548,6 +696,7 @@ class ExpenseClaimService(
|
||||
self.db.delete(item)
|
||||
|
||||
self._sync_claim_from_items(claim)
|
||||
self._refresh_claim_pre_review_flags(claim, is_application_claim=False)
|
||||
self.db.commit()
|
||||
self.db.refresh(claim)
|
||||
|
||||
@@ -645,12 +794,13 @@ class ExpenseClaimService(
|
||||
budget_flags,
|
||||
business_stage="reimbursement",
|
||||
)
|
||||
review_result = self._run_ai_submission_review(claim)
|
||||
if not self._has_ai_pre_review_flag(claim):
|
||||
self._refresh_claim_pre_review_flags(claim, is_application_claim=False)
|
||||
|
||||
claim.status = "submitted"
|
||||
claim.approval_stage = DIRECT_MANAGER_APPROVAL_STAGE
|
||||
claim.submitted_at = datetime.now(UTC)
|
||||
|
||||
claim.status = str(review_result.get("status") or "supplement")
|
||||
claim.approval_stage = str(review_result.get("approval_stage") or "待补充")
|
||||
claim.risk_flags_json = list(review_result.get("risk_flags") or [])
|
||||
claim.submitted_at = datetime.now(UTC) if claim.status == "submitted" else None
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(claim)
|
||||
@@ -872,11 +1022,3 @@ class ExpenseClaimService(
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -33,17 +33,25 @@ FEEDBACK_STATUS_MAP = {
|
||||
|
||||
|
||||
class RiskObservationService:
|
||||
_storage_ready_cache: set[str] = set()
|
||||
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def ensure_storage_ready(self) -> None:
|
||||
bind = self.db.get_bind()
|
||||
cache_key = str(getattr(bind, "url", "") or id(bind))
|
||||
if cache_key in self._storage_ready_cache:
|
||||
return
|
||||
|
||||
Base.metadata.create_all(
|
||||
bind=self.db.get_bind(),
|
||||
bind=bind,
|
||||
tables=[
|
||||
RiskObservation.__table__,
|
||||
RiskObservationFeedback.__table__,
|
||||
],
|
||||
)
|
||||
self._storage_ready_cache.add(cache_key)
|
||||
|
||||
def upsert_observation(
|
||||
self,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from time import monotonic, sleep
|
||||
@@ -61,6 +62,23 @@ class RuntimeChatResult:
|
||||
return [item.model_dump() for item in self.calls]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class RuntimeChatToolCall:
|
||||
name: str
|
||||
arguments: dict[str, Any]
|
||||
call_id: str | None = None
|
||||
raw_arguments: str = ""
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class RuntimeToolCallResult:
|
||||
tool_call: RuntimeChatToolCall | None
|
||||
calls: list[RuntimeChatCallTrace]
|
||||
|
||||
def calls_as_dicts(self) -> list[dict[str, Any]]:
|
||||
return [item.model_dump() for item in self.calls]
|
||||
|
||||
|
||||
class RuntimeChatService:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
@@ -208,6 +226,131 @@ class RuntimeChatService:
|
||||
|
||||
return RuntimeChatResult(None, calls)
|
||||
|
||||
def complete_with_tool_call(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
*,
|
||||
tools: list[dict[str, Any]],
|
||||
tool_choice: dict[str, Any] | str | None = None,
|
||||
slot_priority: tuple[str, ...] = ("main", "backup"),
|
||||
max_tokens: int = 1200,
|
||||
temperature: float = 0.1,
|
||||
timeout_seconds: int | None = None,
|
||||
slot_timeouts: dict[str, int] | None = None,
|
||||
max_attempts: int | None = None,
|
||||
) -> RuntimeToolCallResult:
|
||||
configs: list[dict[str, str]] = []
|
||||
calls: list[RuntimeChatCallTrace] = []
|
||||
for slot in slot_priority:
|
||||
config = self._load_chat_slot(slot)
|
||||
if config is None:
|
||||
calls.append(
|
||||
RuntimeChatCallTrace(
|
||||
slot=slot,
|
||||
provider="",
|
||||
model="",
|
||||
attempt=0,
|
||||
status="skipped",
|
||||
skipped_reason="not_configured",
|
||||
)
|
||||
)
|
||||
continue
|
||||
configs.append(config)
|
||||
if not configs:
|
||||
return RuntimeToolCallResult(None, calls)
|
||||
|
||||
resolved_timeout_seconds = timeout_seconds or DEFAULT_RUNTIME_CHAT_TIMEOUT_SECONDS
|
||||
resolved_slot_timeouts = dict(slot_timeouts or {})
|
||||
resolved_max_attempts = max_attempts or DEFAULT_RUNTIME_CHAT_RETRY_ATTEMPTS
|
||||
|
||||
for attempt in range(1, resolved_max_attempts + 1):
|
||||
for config in configs:
|
||||
cache_key = self._build_slot_cache_key(config)
|
||||
if _slot_failure_until.get(cache_key, 0.0) > monotonic():
|
||||
logger.info(
|
||||
"Skip runtime chat tool slot=%s provider=%s because it is in cooldown",
|
||||
config["slot"],
|
||||
config["provider"],
|
||||
)
|
||||
calls.append(
|
||||
RuntimeChatCallTrace(
|
||||
slot=config["slot"],
|
||||
provider=config["provider"],
|
||||
model=config["model"],
|
||||
attempt=attempt,
|
||||
status="skipped",
|
||||
skipped_reason="cooldown",
|
||||
)
|
||||
)
|
||||
continue
|
||||
started = monotonic()
|
||||
try:
|
||||
tool_call = self._request_chat_tool_call(
|
||||
config,
|
||||
messages,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
timeout_seconds=resolved_slot_timeouts.get(
|
||||
config["slot"],
|
||||
resolved_timeout_seconds,
|
||||
),
|
||||
)
|
||||
duration_ms = int((monotonic() - started) * 1000)
|
||||
if tool_call is not None:
|
||||
_slot_failure_until.pop(cache_key, None)
|
||||
calls.append(
|
||||
RuntimeChatCallTrace(
|
||||
slot=config["slot"],
|
||||
provider=config["provider"],
|
||||
model=config["model"],
|
||||
attempt=attempt,
|
||||
status="succeeded",
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
)
|
||||
return RuntimeToolCallResult(tool_call, calls)
|
||||
calls.append(
|
||||
RuntimeChatCallTrace(
|
||||
slot=config["slot"],
|
||||
provider=config["provider"],
|
||||
model=config["model"],
|
||||
attempt=attempt,
|
||||
status="empty",
|
||||
duration_ms=duration_ms,
|
||||
error_message="模型未返回工具调用。",
|
||||
)
|
||||
)
|
||||
except Exception as exc:
|
||||
duration_ms = int((monotonic() - started) * 1000)
|
||||
_slot_failure_until[cache_key] = (
|
||||
monotonic() + DEFAULT_RUNTIME_CHAT_FAILURE_COOLDOWN_SECONDS
|
||||
)
|
||||
calls.append(
|
||||
RuntimeChatCallTrace(
|
||||
slot=config["slot"],
|
||||
provider=config["provider"],
|
||||
model=config["model"],
|
||||
attempt=attempt,
|
||||
status="failed",
|
||||
duration_ms=duration_ms,
|
||||
error_message=str(exc),
|
||||
)
|
||||
)
|
||||
logger.warning(
|
||||
"Runtime chat tool request failed slot=%s provider=%s attempt=%s/%s: %s",
|
||||
config["slot"],
|
||||
config["provider"],
|
||||
attempt,
|
||||
resolved_max_attempts,
|
||||
exc,
|
||||
)
|
||||
if attempt < resolved_max_attempts:
|
||||
sleep(DEFAULT_RUNTIME_CHAT_RETRY_DELAY_SECONDS)
|
||||
|
||||
return RuntimeToolCallResult(None, calls)
|
||||
|
||||
@staticmethod
|
||||
def _build_slot_cache_key(config: dict[str, str]) -> str:
|
||||
return "|".join(
|
||||
@@ -295,6 +438,51 @@ class RuntimeChatService:
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
|
||||
def _request_chat_tool_call(
|
||||
self,
|
||||
config: dict[str, str],
|
||||
messages: list[dict[str, Any]],
|
||||
*,
|
||||
tools: list[dict[str, Any]],
|
||||
tool_choice: dict[str, Any] | str | None,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
timeout_seconds: int,
|
||||
) -> RuntimeChatToolCall | None:
|
||||
provider = config["provider"]
|
||||
endpoint = config["endpoint"]
|
||||
model = config["model"]
|
||||
api_key = config["apiKey"]
|
||||
|
||||
if provider == "Azure OpenAI":
|
||||
return self._request_azure_openai_tool_call(
|
||||
endpoint=endpoint,
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
|
||||
if provider == "Ollama":
|
||||
raise ConnectivityCheckError("Ollama 暂不支持小财管家 function calling。")
|
||||
|
||||
return self._request_openai_compatible_tool_call(
|
||||
provider=provider,
|
||||
endpoint=endpoint,
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
|
||||
def _request_openai_compatible(
|
||||
self,
|
||||
*,
|
||||
@@ -331,6 +519,46 @@ class RuntimeChatService:
|
||||
)
|
||||
return self._extract_openai_text(payload)
|
||||
|
||||
def _request_openai_compatible_tool_call(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
endpoint: str,
|
||||
model: str,
|
||||
api_key: str,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]],
|
||||
tool_choice: dict[str, Any] | str | None,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
timeout_seconds: int,
|
||||
) -> RuntimeChatToolCall | None:
|
||||
url = _ensure_path(_normalize_endpoint(endpoint), "chat/completions")
|
||||
request_payload: dict[str, Any] = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"tools": tools,
|
||||
"tool_choice": tool_choice or "auto",
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
}
|
||||
if provider == "GLM":
|
||||
request_payload["thinking"] = {"type": "disabled"}
|
||||
|
||||
status_code, payload = _send_json_request(
|
||||
"POST",
|
||||
url,
|
||||
headers=_build_headers(api_key=api_key, use_bearer=True),
|
||||
payload=request_payload,
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
if status_code >= HTTPStatus.BAD_REQUEST:
|
||||
raise ConnectivityCheckError(
|
||||
f"模型接口返回异常状态 {status_code}。",
|
||||
status_code=status_code,
|
||||
)
|
||||
return self._extract_openai_tool_call(payload)
|
||||
|
||||
def _request_ollama(
|
||||
self,
|
||||
*,
|
||||
@@ -396,6 +624,41 @@ class RuntimeChatService:
|
||||
)
|
||||
return self._extract_openai_text(payload)
|
||||
|
||||
def _request_azure_openai_tool_call(
|
||||
self,
|
||||
*,
|
||||
endpoint: str,
|
||||
model: str,
|
||||
api_key: str,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]],
|
||||
tool_choice: dict[str, Any] | str | None,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
timeout_seconds: int,
|
||||
) -> RuntimeChatToolCall | None:
|
||||
deployment_base = _build_azure_deployment_base(endpoint, model)
|
||||
url = f"{deployment_base}/chat/completions?api-version={AZURE_API_VERSION}"
|
||||
status_code, payload = _send_json_request(
|
||||
"POST",
|
||||
url,
|
||||
headers=_build_headers(api_key=api_key, use_bearer=False, use_api_key=True),
|
||||
payload={
|
||||
"messages": messages,
|
||||
"tools": tools,
|
||||
"tool_choice": tool_choice or "auto",
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
},
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
if status_code >= HTTPStatus.BAD_REQUEST:
|
||||
raise ConnectivityCheckError(
|
||||
f"Azure OpenAI 返回异常状态 {status_code}。",
|
||||
status_code=status_code,
|
||||
)
|
||||
return self._extract_openai_tool_call(payload)
|
||||
|
||||
@staticmethod
|
||||
def _extract_openai_text(payload: Any) -> str:
|
||||
if not isinstance(payload, dict):
|
||||
@@ -426,3 +689,74 @@ class RuntimeChatService:
|
||||
return text.strip()
|
||||
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _extract_openai_tool_call(payload: Any) -> RuntimeChatToolCall | None:
|
||||
if not isinstance(payload, dict):
|
||||
return None
|
||||
|
||||
choices = payload.get("choices")
|
||||
if not isinstance(choices, list) or not choices:
|
||||
return None
|
||||
|
||||
first_choice = choices[0]
|
||||
if not isinstance(first_choice, dict):
|
||||
return None
|
||||
|
||||
message = first_choice.get("message")
|
||||
if not isinstance(message, dict):
|
||||
return None
|
||||
|
||||
tool_calls = message.get("tool_calls")
|
||||
if isinstance(tool_calls, list) and tool_calls:
|
||||
first_tool = tool_calls[0]
|
||||
if isinstance(first_tool, dict):
|
||||
function_payload = first_tool.get("function")
|
||||
if isinstance(function_payload, dict):
|
||||
return RuntimeChatService._build_runtime_tool_call(
|
||||
name=function_payload.get("name"),
|
||||
arguments=function_payload.get("arguments"),
|
||||
call_id=first_tool.get("id"),
|
||||
)
|
||||
|
||||
function_call = message.get("function_call")
|
||||
if isinstance(function_call, dict):
|
||||
return RuntimeChatService._build_runtime_tool_call(
|
||||
name=function_call.get("name"),
|
||||
arguments=function_call.get("arguments"),
|
||||
call_id=None,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _build_runtime_tool_call(
|
||||
*,
|
||||
name: Any,
|
||||
arguments: Any,
|
||||
call_id: Any,
|
||||
) -> RuntimeChatToolCall | None:
|
||||
tool_name = str(name or "").strip()
|
||||
if not tool_name:
|
||||
return None
|
||||
|
||||
raw_arguments = ""
|
||||
if isinstance(arguments, dict):
|
||||
parsed_arguments = arguments
|
||||
raw_arguments = json.dumps(arguments, ensure_ascii=False)
|
||||
else:
|
||||
raw_arguments = str(arguments or "").strip()
|
||||
if not raw_arguments:
|
||||
parsed_arguments = {}
|
||||
else:
|
||||
parsed = json.loads(raw_arguments)
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError("工具调用参数必须是 JSON object。")
|
||||
parsed_arguments = parsed
|
||||
|
||||
return RuntimeChatToolCall(
|
||||
name=tool_name,
|
||||
arguments=parsed_arguments,
|
||||
call_id=str(call_id).strip() if call_id else None,
|
||||
raw_arguments=raw_arguments,
|
||||
)
|
||||
|
||||
18
server/src/app/services/steward_constants.py
Normal file
18
server/src/app/services/steward_constants.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
BUSINESS_CANONICAL_FIELD_ORDER = (
|
||||
"expense_type",
|
||||
"time_range",
|
||||
"location",
|
||||
"reason",
|
||||
"amount",
|
||||
"transport_mode",
|
||||
"attachments",
|
||||
"customer_name",
|
||||
"merchant_name",
|
||||
"department_name",
|
||||
"employee_name",
|
||||
"employee_no",
|
||||
)
|
||||
BUSINESS_CANONICAL_FIELDS = frozenset(BUSINESS_CANONICAL_FIELD_ORDER)
|
||||
220
server/src/app/services/steward_intent_agent.py
Normal file
220
server/src/app/services/steward_intent_agent.py
Normal file
@@ -0,0 +1,220 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from datetime import date
|
||||
from typing import Any
|
||||
|
||||
from app.schemas.steward import StewardPlanRequest
|
||||
from app.services.ontology_field_registry import normalize_ontology_form_values
|
||||
from app.services.runtime_chat import RuntimeChatService
|
||||
|
||||
|
||||
STEWARD_INTENT_FUNCTION_NAME = "submit_steward_intent_plan"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class StewardIntentAgentResult:
|
||||
payload: dict[str, Any]
|
||||
model_call_traces: list[dict[str, Any]]
|
||||
|
||||
|
||||
class StewardIntentAgent:
|
||||
"""使用大模型 function calling 识别小财管家的复合财务意图。"""
|
||||
|
||||
def __init__(self, runtime_chat_service: RuntimeChatService) -> None:
|
||||
self.runtime_chat_service = runtime_chat_service
|
||||
self.last_call_traces: list[dict[str, Any]] = []
|
||||
|
||||
def detect(
|
||||
self,
|
||||
request: StewardPlanRequest,
|
||||
*,
|
||||
base_date: date,
|
||||
canonical_fields: list[str],
|
||||
) -> StewardIntentAgentResult | None:
|
||||
result = self.runtime_chat_service.complete_with_tool_call(
|
||||
self._build_messages(request, base_date=base_date, canonical_fields=canonical_fields),
|
||||
tools=[self._build_intent_tool_schema(canonical_fields)],
|
||||
tool_choice={
|
||||
"type": "function",
|
||||
"function": {"name": STEWARD_INTENT_FUNCTION_NAME},
|
||||
},
|
||||
max_tokens=1800,
|
||||
temperature=0.1,
|
||||
timeout_seconds=18,
|
||||
max_attempts=1,
|
||||
)
|
||||
self.last_call_traces = result.calls_as_dicts()
|
||||
if result.tool_call is None or result.tool_call.name != STEWARD_INTENT_FUNCTION_NAME:
|
||||
return None
|
||||
return StewardIntentAgentResult(
|
||||
payload=result.tool_call.arguments,
|
||||
model_call_traces=self.last_call_traces,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_messages(
|
||||
request: StewardPlanRequest,
|
||||
*,
|
||||
base_date: date,
|
||||
canonical_fields: list[str],
|
||||
) -> list[dict[str, Any]]:
|
||||
context_payload = {
|
||||
"message": request.message,
|
||||
"base_date": base_date.isoformat(),
|
||||
"client_now_iso": request.client_now_iso,
|
||||
"user_id": request.user_id,
|
||||
"canonical_ontology_fields": canonical_fields,
|
||||
"review_form_values": normalize_ontology_form_values(
|
||||
request.context_json.get("review_form_values")
|
||||
),
|
||||
"context_json": {
|
||||
key: value
|
||||
for key, value in request.context_json.items()
|
||||
if key
|
||||
in {
|
||||
"entry_source",
|
||||
"session_type",
|
||||
"role_codes",
|
||||
"username",
|
||||
"name",
|
||||
"department_name",
|
||||
"employee_grade",
|
||||
"employee_no",
|
||||
"client_timezone_offset_minutes",
|
||||
}
|
||||
},
|
||||
"attachments": [
|
||||
{
|
||||
"index": index + 1,
|
||||
"name": item.name,
|
||||
"media_type": item.media_type,
|
||||
"ocr_summary": item.ocr_summary,
|
||||
"ocr_fields": item.ocr_fields,
|
||||
}
|
||||
for index, item in enumerate(request.attachments)
|
||||
if item.name
|
||||
],
|
||||
}
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"你是 X-Financial 的小财管家意图识别智能体。"
|
||||
"你必须通过 function calling 输出结构化计划,不能只返回普通文本。"
|
||||
"当前版本只支持 expense_application 和 reimbursement 两类任务;"
|
||||
"你只做识别、拆解、归集和确认点规划,不能执行入库、绑定附件或提交审批。"
|
||||
"所有 ontology_fields 只能使用调用方给出的 canonical_ontology_fields;"
|
||||
"如果输入里出现 occurred_date、transport_type、reason_value 等别名,必须映射为 canonical 字段。"
|
||||
"相对日期必须以 base_date 为准转换为明确日期。"
|
||||
"thinking_events 只能是面向用户的过程摘要,不能暴露内部推理链。"
|
||||
),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": json.dumps(context_payload, ensure_ascii=False),
|
||||
},
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _build_intent_tool_schema(canonical_fields: list[str]) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": STEWARD_INTENT_FUNCTION_NAME,
|
||||
"description": "提交小财管家的复合财务意图识别结果。",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"thinking_events": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"stage": {"type": "string"},
|
||||
"title": {"type": "string"},
|
||||
"content": {"type": "string"},
|
||||
},
|
||||
"required": ["stage", "title", "content"],
|
||||
},
|
||||
},
|
||||
"tasks": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"task_type": {
|
||||
"type": "string",
|
||||
"enum": ["expense_application", "reimbursement"],
|
||||
},
|
||||
"title": {"type": "string"},
|
||||
"summary": {"type": "string"},
|
||||
"confidence": {
|
||||
"type": "number",
|
||||
"minimum": 0,
|
||||
"maximum": 1,
|
||||
},
|
||||
"ontology_fields": {
|
||||
"type": "object",
|
||||
"additionalProperties": {"type": "string"},
|
||||
},
|
||||
"missing_fields": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"enum": canonical_fields,
|
||||
},
|
||||
},
|
||||
},
|
||||
"required": [
|
||||
"task_type",
|
||||
"title",
|
||||
"summary",
|
||||
"confidence",
|
||||
"ontology_fields",
|
||||
"missing_fields",
|
||||
],
|
||||
},
|
||||
},
|
||||
"attachment_groups": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"target_task_index": {
|
||||
"type": "integer",
|
||||
"minimum": 1,
|
||||
},
|
||||
"scene": {"type": "string"},
|
||||
"scene_label": {"type": "string"},
|
||||
"attachment_names": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
"excluded_attachment_names": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
"confidence": {
|
||||
"type": "number",
|
||||
"minimum": 0,
|
||||
"maximum": 1,
|
||||
},
|
||||
"rationale": {"type": "string"},
|
||||
},
|
||||
"required": [
|
||||
"scene",
|
||||
"scene_label",
|
||||
"attachment_names",
|
||||
"excluded_attachment_names",
|
||||
"confidence",
|
||||
"rationale",
|
||||
],
|
||||
},
|
||||
},
|
||||
},
|
||||
"required": ["thinking_events", "tasks", "attachment_groups"],
|
||||
},
|
||||
},
|
||||
}
|
||||
365
server/src/app/services/steward_model_plan_builder.py
Normal file
365
server/src/app/services/steward_model_plan_builder.py
Normal file
@@ -0,0 +1,365 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from datetime import date
|
||||
from typing import Any
|
||||
|
||||
from app.schemas.steward import (
|
||||
StewardAttachmentGroup,
|
||||
StewardAttachmentInput,
|
||||
StewardPlanRequest,
|
||||
StewardPlanResponse,
|
||||
StewardTask,
|
||||
StewardThinkingEvent,
|
||||
)
|
||||
from app.services.ontology_field_registry import normalize_ontology_form_values
|
||||
from app.services.steward_constants import BUSINESS_CANONICAL_FIELDS
|
||||
from app.services.steward_intent_agent import StewardIntentAgentResult
|
||||
|
||||
|
||||
class StewardModelPlanBuilder:
|
||||
"""把模型 function calling 返回值转换为小财管家的服务端计划。"""
|
||||
|
||||
def __init__(self, planner: Any) -> None:
|
||||
self.planner = planner
|
||||
|
||||
def build(
|
||||
self,
|
||||
intent_result: StewardIntentAgentResult,
|
||||
*,
|
||||
request: StewardPlanRequest,
|
||||
base_date: date,
|
||||
) -> StewardPlanResponse | None:
|
||||
tasks = self._build_tasks_from_model_payload(intent_result.payload, request, base_date)
|
||||
if not tasks:
|
||||
return None
|
||||
|
||||
attachment_groups = self._build_attachment_groups_from_model_payload(
|
||||
intent_result.payload,
|
||||
request.attachments,
|
||||
tasks,
|
||||
)
|
||||
if request.attachments and not attachment_groups:
|
||||
attachment_groups = self.planner._build_attachment_groups(request.attachments, tasks)
|
||||
confirmation_groups = self.planner._build_confirmation_actions(tasks, attachment_groups)
|
||||
thinking_events = self._build_llm_thinking_events(
|
||||
intent_result.payload,
|
||||
tasks=tasks,
|
||||
attachment_groups=attachment_groups,
|
||||
attachments=request.attachments,
|
||||
)
|
||||
|
||||
return StewardPlanResponse(
|
||||
plan_id=f"steward_plan_{uuid.uuid4().hex[:12]}",
|
||||
plan_status="needs_confirmation" if confirmation_groups else "ready_to_delegate",
|
||||
planning_source="llm_function_call",
|
||||
summary=self.planner._build_summary(tasks, attachment_groups),
|
||||
thinking_events=thinking_events,
|
||||
tasks=tasks,
|
||||
attachment_groups=attachment_groups,
|
||||
confirmation_groups=confirmation_groups,
|
||||
model_call_traces=intent_result.model_call_traces,
|
||||
)
|
||||
|
||||
def _build_tasks_from_model_payload(
|
||||
self,
|
||||
payload: dict[str, Any],
|
||||
request: StewardPlanRequest,
|
||||
base_date: date,
|
||||
) -> list[StewardTask]:
|
||||
raw_tasks = payload.get("tasks")
|
||||
if not isinstance(raw_tasks, list):
|
||||
return []
|
||||
|
||||
tasks: list[StewardTask] = []
|
||||
for raw_task in raw_tasks:
|
||||
if not isinstance(raw_task, dict):
|
||||
continue
|
||||
task_type = str(raw_task.get("task_type") or "").strip()
|
||||
if task_type not in {"expense_application", "reimbursement"}:
|
||||
continue
|
||||
|
||||
task_index = len(tasks) + 1
|
||||
fields = self._sanitize_model_ontology_fields(
|
||||
raw_task.get("ontology_fields"),
|
||||
request=request,
|
||||
base_date=base_date,
|
||||
)
|
||||
supplement_segment = " ".join(
|
||||
[
|
||||
str(raw_task.get("title") or ""),
|
||||
str(raw_task.get("summary") or ""),
|
||||
]
|
||||
)
|
||||
supplement_fields = self.planner._extract_ontology_fields(
|
||||
supplement_segment,
|
||||
task_type,
|
||||
base_date,
|
||||
request,
|
||||
)
|
||||
for key, value in supplement_fields.items():
|
||||
fields.setdefault(key, value)
|
||||
|
||||
assigned_agent = (
|
||||
"application_assistant"
|
||||
if task_type == "expense_application"
|
||||
else "reimbursement_assistant"
|
||||
)
|
||||
task_id = f"task_{'app' if task_type == 'expense_application' else 'reim'}_{task_index:03d}"
|
||||
title_prefix = "费用申请" if task_type == "expense_application" else "费用报销"
|
||||
title = self.planner._clean_text(raw_task.get("title")) or self.planner._build_task_title(
|
||||
title_prefix,
|
||||
fields,
|
||||
task_index,
|
||||
)
|
||||
summary = self.planner._clean_text(raw_task.get("summary")) or self.planner._build_task_summary(
|
||||
supplement_segment,
|
||||
fields,
|
||||
)
|
||||
missing_fields = self._sanitize_model_missing_fields(
|
||||
raw_task.get("missing_fields"),
|
||||
task_type=task_type,
|
||||
fields=fields,
|
||||
)
|
||||
tasks.append(
|
||||
StewardTask(
|
||||
task_id=task_id,
|
||||
task_type=task_type, # type: ignore[arg-type]
|
||||
assigned_agent=assigned_agent, # type: ignore[arg-type]
|
||||
title=title,
|
||||
summary=summary,
|
||||
status="needs_confirmation",
|
||||
confidence=self._resolve_model_confidence(
|
||||
raw_task.get("confidence"),
|
||||
segment=supplement_segment,
|
||||
fields=fields,
|
||||
task_type=task_type,
|
||||
),
|
||||
ontology_fields=fields,
|
||||
missing_fields=missing_fields,
|
||||
confirmation_required=True,
|
||||
)
|
||||
)
|
||||
|
||||
return tasks
|
||||
|
||||
def _sanitize_model_ontology_fields(
|
||||
self,
|
||||
raw_fields: Any,
|
||||
*,
|
||||
request: StewardPlanRequest,
|
||||
base_date: date,
|
||||
) -> dict[str, str]:
|
||||
normalized_context = normalize_ontology_form_values(request.context_json.get("review_form_values"))
|
||||
fields: dict[str, str] = {
|
||||
key: value
|
||||
for key, value in normalized_context.items()
|
||||
if key in BUSINESS_CANONICAL_FIELDS and str(value or "").strip()
|
||||
}
|
||||
if not isinstance(raw_fields, dict):
|
||||
return fields
|
||||
|
||||
normalized_model_fields = normalize_ontology_form_values(raw_fields)
|
||||
for key, value in normalized_model_fields.items():
|
||||
if key not in BUSINESS_CANONICAL_FIELDS:
|
||||
continue
|
||||
normalized_value = self._normalize_model_field_value(key, value, base_date)
|
||||
if normalized_value:
|
||||
fields[key] = normalized_value
|
||||
if request.attachments and not fields.get("attachments"):
|
||||
fields["attachments"] = "、".join(item.name for item in request.attachments if item.name)
|
||||
return {key: value for key, value in fields.items() if key in BUSINESS_CANONICAL_FIELDS and value}
|
||||
|
||||
def _build_attachment_groups_from_model_payload(
|
||||
self,
|
||||
payload: dict[str, Any],
|
||||
attachments: list[StewardAttachmentInput],
|
||||
tasks: list[StewardTask],
|
||||
) -> list[StewardAttachmentGroup]:
|
||||
raw_groups = payload.get("attachment_groups")
|
||||
if not isinstance(raw_groups, list) or not attachments:
|
||||
return []
|
||||
|
||||
uploaded_names = {item.name for item in attachments if item.name}
|
||||
groups: list[StewardAttachmentGroup] = []
|
||||
for raw_group in raw_groups:
|
||||
if not isinstance(raw_group, dict):
|
||||
continue
|
||||
attachment_names = self._filter_uploaded_attachment_names(
|
||||
raw_group.get("attachment_names"),
|
||||
uploaded_names,
|
||||
)
|
||||
excluded_names = self._filter_uploaded_attachment_names(
|
||||
raw_group.get("excluded_attachment_names"),
|
||||
uploaded_names,
|
||||
)
|
||||
if not attachment_names and not excluded_names:
|
||||
continue
|
||||
|
||||
scene = self.planner._clean_text(raw_group.get("scene")) or "other"
|
||||
groups.append(
|
||||
StewardAttachmentGroup(
|
||||
group_id=f"ag_{self._slug_scene(scene)}_{len(groups) + 1:03d}",
|
||||
target_task_id=self._resolve_model_group_target_task_id(raw_group, tasks),
|
||||
scene=scene,
|
||||
scene_label=self.planner._clean_text(raw_group.get("scene_label")) or "待确认费用",
|
||||
attachment_names=attachment_names,
|
||||
excluded_attachment_names=excluded_names,
|
||||
confidence=self._clamp_confidence(raw_group.get("confidence"), default=0.68),
|
||||
rationale=(
|
||||
self.planner._clean_text(raw_group.get("rationale"))
|
||||
or "模型根据附件线索生成归集建议。"
|
||||
),
|
||||
confirmation_required=True,
|
||||
)
|
||||
)
|
||||
|
||||
return groups
|
||||
|
||||
def _build_llm_thinking_events(
|
||||
self,
|
||||
payload: dict[str, Any],
|
||||
*,
|
||||
tasks: list[StewardTask],
|
||||
attachment_groups: list[StewardAttachmentGroup],
|
||||
attachments: list[StewardAttachmentInput],
|
||||
) -> list[StewardThinkingEvent]:
|
||||
events = [
|
||||
StewardThinkingEvent(
|
||||
event_id="intent_agent_function_call",
|
||||
stage="llm_function_call",
|
||||
title="意图识别智能体接管",
|
||||
content=(
|
||||
"已调用系统主模型的 submit_steward_intent_plan 工具,"
|
||||
"把用户话术转换为可校验的结构化财务任务计划。"
|
||||
),
|
||||
)
|
||||
]
|
||||
raw_events = payload.get("thinking_events")
|
||||
if isinstance(raw_events, list):
|
||||
for raw_event in raw_events[:4]:
|
||||
if not isinstance(raw_event, dict):
|
||||
continue
|
||||
title = self.planner._clean_text(raw_event.get("title"))
|
||||
content = self.planner._clean_text(raw_event.get("content"))
|
||||
if not title or not content:
|
||||
continue
|
||||
events.append(
|
||||
StewardThinkingEvent(
|
||||
event_id=f"intent_agent_model_{len(events):03d}",
|
||||
stage=self.planner._clean_text(raw_event.get("stage")) or "model_summary",
|
||||
title=title,
|
||||
content=content,
|
||||
)
|
||||
)
|
||||
if len(events) == 1:
|
||||
events.extend(self.planner._build_thinking_events(tasks, attachment_groups, attachments)[1:])
|
||||
return events
|
||||
|
||||
def _sanitize_model_missing_fields(
|
||||
self,
|
||||
raw_missing_fields: Any,
|
||||
*,
|
||||
task_type: str,
|
||||
fields: dict[str, str],
|
||||
) -> list[str]:
|
||||
missing_fields: list[str] = []
|
||||
if isinstance(raw_missing_fields, list):
|
||||
for item in raw_missing_fields:
|
||||
key = str(item or "").strip()
|
||||
if key in BUSINESS_CANONICAL_FIELDS and key not in missing_fields and not fields.get(key):
|
||||
missing_fields.append(key)
|
||||
for key in self.planner._resolve_missing_fields(task_type, fields):
|
||||
if key not in missing_fields:
|
||||
missing_fields.append(key)
|
||||
return missing_fields
|
||||
|
||||
def _resolve_model_confidence(
|
||||
self,
|
||||
value: Any,
|
||||
*,
|
||||
segment: str,
|
||||
fields: dict[str, str],
|
||||
task_type: str,
|
||||
) -> float:
|
||||
return self._clamp_confidence(
|
||||
value,
|
||||
default=self.planner._resolve_task_confidence(segment, fields, task_type),
|
||||
)
|
||||
|
||||
def _normalize_model_field_value(self, key: str, value: Any, base_date: date) -> str:
|
||||
cleaned = self.planner._clean_text(value)
|
||||
if not cleaned:
|
||||
return ""
|
||||
if key == "time_range":
|
||||
return self.planner._extract_time_range(cleaned, base_date) or cleaned
|
||||
if key == "expense_type":
|
||||
return self._normalize_expense_type_value(cleaned)
|
||||
if key == "transport_mode":
|
||||
return self._normalize_transport_mode_value(cleaned)
|
||||
return cleaned
|
||||
|
||||
@staticmethod
|
||||
def _normalize_expense_type_value(value: str) -> str:
|
||||
normalized = str(value or "").strip().lower()
|
||||
if normalized in {"travel", "travel_application", "差旅", "差旅费", "出差"}:
|
||||
return "travel"
|
||||
if normalized in {"transport", "traffic", "交通", "交通费", "打车", "出租车"}:
|
||||
return "transport"
|
||||
if normalized in {"entertainment", "meal", "招待", "接待", "餐饮", "业务招待"}:
|
||||
return "entertainment"
|
||||
if normalized in {"office", "办公", "办公用品"}:
|
||||
return "office"
|
||||
return normalized
|
||||
|
||||
@staticmethod
|
||||
def _normalize_transport_mode_value(value: str) -> str:
|
||||
normalized = str(value or "").strip().lower()
|
||||
if normalized in {"train", "高铁", "动车", "火车"}:
|
||||
return "train"
|
||||
if normalized in {"flight", "air", "飞机", "机票", "航班"}:
|
||||
return "flight"
|
||||
if normalized in {"taxi", "出租车", "的士", "网约车", "打车"}:
|
||||
return "taxi"
|
||||
if normalized in {"subway", "地铁"}:
|
||||
return "subway"
|
||||
return normalized
|
||||
|
||||
@staticmethod
|
||||
def _filter_uploaded_attachment_names(raw_names: Any, uploaded_names: set[str]) -> list[str]:
|
||||
if not isinstance(raw_names, list):
|
||||
return []
|
||||
names: list[str] = []
|
||||
for raw_name in raw_names:
|
||||
name = str(raw_name or "").strip()
|
||||
if name in uploaded_names and name not in names:
|
||||
names.append(name)
|
||||
return names
|
||||
|
||||
@staticmethod
|
||||
def _resolve_model_group_target_task_id(raw_group: dict[str, Any], tasks: list[StewardTask]) -> str | None:
|
||||
try:
|
||||
target_index = int(raw_group.get("target_task_index") or 0)
|
||||
except (TypeError, ValueError):
|
||||
target_index = 0
|
||||
if target_index > 0 and target_index <= len(tasks):
|
||||
return tasks[target_index - 1].task_id
|
||||
|
||||
target_task_id = str(raw_group.get("target_task_id") or "").strip()
|
||||
if target_task_id and any(task.task_id == target_task_id for task in tasks):
|
||||
return target_task_id
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _slug_scene(value: str) -> str:
|
||||
normalized = re.sub(r"[^a-zA-Z0-9_]+", "_", str(value or "").strip().lower()).strip("_")
|
||||
return normalized or "other"
|
||||
|
||||
@staticmethod
|
||||
def _clamp_confidence(value: Any, *, default: float) -> float:
|
||||
try:
|
||||
parsed = float(value)
|
||||
except (TypeError, ValueError):
|
||||
parsed = default
|
||||
return round(min(1.0, max(0.0, parsed)), 2)
|
||||
645
server/src/app/services/steward_planner.py
Normal file
645
server/src/app/services/steward_planner.py
Normal file
@@ -0,0 +1,645 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, date, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from app.schemas.steward import (
|
||||
StewardAttachmentGroup,
|
||||
StewardAttachmentInput,
|
||||
StewardConfirmationAction,
|
||||
StewardPlanRequest,
|
||||
StewardPlanResponse,
|
||||
StewardTask,
|
||||
StewardThinkingEvent,
|
||||
)
|
||||
from app.services.steward_constants import BUSINESS_CANONICAL_FIELD_ORDER, BUSINESS_CANONICAL_FIELDS
|
||||
from app.services.ontology_field_registry import normalize_ontology_form_values
|
||||
from app.services.steward_intent_agent import StewardIntentAgent
|
||||
from app.services.steward_model_plan_builder import StewardModelPlanBuilder
|
||||
|
||||
|
||||
CITY_NAMES = (
|
||||
"北京",
|
||||
"上海",
|
||||
"广州",
|
||||
"深圳",
|
||||
"杭州",
|
||||
"南京",
|
||||
"苏州",
|
||||
"成都",
|
||||
"重庆",
|
||||
"天津",
|
||||
"武汉",
|
||||
"西安",
|
||||
"长沙",
|
||||
"郑州",
|
||||
"青岛",
|
||||
"厦门",
|
||||
"福州",
|
||||
"合肥",
|
||||
"济南",
|
||||
"沈阳",
|
||||
"大连",
|
||||
"宁波",
|
||||
"无锡",
|
||||
)
|
||||
|
||||
APPLICATION_SPLIT_PATTERN = re.compile(r"(?:^|[,,。;;])[^,,。;;]*?(?:申请|出差申请|差旅申请)[^,,。;;]*")
|
||||
REIMBURSEMENT_PATTERN = re.compile(r"(?:我要报销|还需要报销|需要报销|报销)([^,,。;;!??!\n]+)")
|
||||
MONTH_DAY_PATTERN = re.compile(r"(?P<month>\d{1,2})\s*月\s*(?P<day>\d{1,2})\s*(?:日|号)?")
|
||||
ISO_DATE_PATTERN = re.compile(r"(?P<year>\d{4})[-/年](?P<month>\d{1,2})[-/月](?P<day>\d{1,2})(?:日)?")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PlannedTaskDraft:
|
||||
task_type: str
|
||||
segment: str
|
||||
index: int
|
||||
|
||||
|
||||
class StewardPlannerService:
|
||||
"""小财管家第一版规划服务:只生成计划,不执行入库类动作。"""
|
||||
|
||||
def __init__(self, intent_agent: StewardIntentAgent | None = None) -> None:
|
||||
self.intent_agent = intent_agent
|
||||
|
||||
def build_plan(self, request: StewardPlanRequest) -> StewardPlanResponse:
|
||||
message = self._clean_text(request.message)
|
||||
if not message:
|
||||
raise ValueError("小财管家需要一段任务描述。")
|
||||
|
||||
base_date = self._resolve_base_date(request.client_now_iso, request.context_json)
|
||||
model_call_traces: list[dict[str, Any]] = []
|
||||
fallback_reason = ""
|
||||
if self.intent_agent is not None:
|
||||
try:
|
||||
intent_result = self.intent_agent.detect(
|
||||
request,
|
||||
base_date=base_date,
|
||||
canonical_fields=list(BUSINESS_CANONICAL_FIELD_ORDER),
|
||||
)
|
||||
if intent_result is not None:
|
||||
model_call_traces = intent_result.model_call_traces
|
||||
llm_plan = StewardModelPlanBuilder(self).build(
|
||||
intent_result,
|
||||
request=request,
|
||||
base_date=base_date,
|
||||
)
|
||||
if llm_plan is not None:
|
||||
return llm_plan
|
||||
model_call_traces = getattr(self.intent_agent, "last_call_traces", []) or model_call_traces
|
||||
fallback_reason = "主模型未返回可用的 function calling 计划,已切换到规则兜底。"
|
||||
except Exception as exc:
|
||||
model_call_traces = getattr(self.intent_agent, "last_call_traces", []) or model_call_traces
|
||||
fallback_reason = f"主模型 function calling 调用失败,已切换到规则兜底:{exc}"
|
||||
|
||||
return self._build_rule_fallback_plan(
|
||||
request,
|
||||
base_date=base_date,
|
||||
model_call_traces=model_call_traces,
|
||||
fallback_reason=fallback_reason,
|
||||
)
|
||||
|
||||
def _build_rule_fallback_plan(
|
||||
self,
|
||||
request: StewardPlanRequest,
|
||||
*,
|
||||
base_date: date,
|
||||
model_call_traces: list[dict[str, Any]] | None = None,
|
||||
fallback_reason: str = "",
|
||||
) -> StewardPlanResponse:
|
||||
message = self._clean_text(request.message)
|
||||
task_drafts = self._extract_task_drafts(message)
|
||||
tasks = [self._build_task(draft, base_date, request) for draft in task_drafts]
|
||||
if not tasks:
|
||||
tasks = [self._build_fallback_task(message, base_date, request)]
|
||||
|
||||
attachment_groups = self._build_attachment_groups(request.attachments, tasks)
|
||||
confirmation_groups = self._build_confirmation_actions(tasks, attachment_groups)
|
||||
thinking_events = self._build_thinking_events(tasks, attachment_groups, request.attachments)
|
||||
if fallback_reason:
|
||||
thinking_events.insert(
|
||||
0,
|
||||
StewardThinkingEvent(
|
||||
event_id="intent_agent_rule_fallback",
|
||||
stage="rule_fallback",
|
||||
title="意图识别智能体进入兜底模式",
|
||||
content=fallback_reason,
|
||||
),
|
||||
)
|
||||
plan_id = f"steward_plan_{uuid.uuid4().hex[:12]}"
|
||||
|
||||
return StewardPlanResponse(
|
||||
plan_id=plan_id,
|
||||
plan_status="needs_confirmation" if confirmation_groups else "ready_to_delegate",
|
||||
planning_source="rule_fallback",
|
||||
summary=self._build_summary(tasks, attachment_groups),
|
||||
thinking_events=thinking_events,
|
||||
tasks=tasks,
|
||||
attachment_groups=attachment_groups,
|
||||
confirmation_groups=confirmation_groups,
|
||||
model_call_traces=model_call_traces or [],
|
||||
)
|
||||
|
||||
def _extract_task_drafts(self, message: str) -> list[PlannedTaskDraft]:
|
||||
drafts: list[PlannedTaskDraft] = []
|
||||
first_reimbursement = self._find_first_reimbursement_index(message)
|
||||
application_source = message[:first_reimbursement] if first_reimbursement >= 0 else message
|
||||
if self._looks_like_application(application_source):
|
||||
drafts.append(
|
||||
PlannedTaskDraft(
|
||||
task_type="expense_application",
|
||||
segment=application_source.strip(",,。;; "),
|
||||
index=len(drafts) + 1,
|
||||
)
|
||||
)
|
||||
|
||||
for match in REIMBURSEMENT_PATTERN.finditer(message):
|
||||
segment = f"报销{match.group(1)}"
|
||||
drafts.append(
|
||||
PlannedTaskDraft(
|
||||
task_type="reimbursement",
|
||||
segment=segment.strip(",,。;; "),
|
||||
index=len(drafts) + 1,
|
||||
)
|
||||
)
|
||||
|
||||
return drafts
|
||||
|
||||
@staticmethod
|
||||
def _find_first_reimbursement_index(message: str) -> int:
|
||||
candidates = [message.find(item) for item in ("我要报销", "还需要报销", "需要报销", "报销")]
|
||||
positives = [item for item in candidates if item >= 0]
|
||||
return min(positives) if positives else -1
|
||||
|
||||
@staticmethod
|
||||
def _looks_like_application(text: str) -> bool:
|
||||
compact = re.sub(r"\s+", "", text)
|
||||
return bool(compact) and "申请" in compact and bool(re.search(r"出差|差旅|费用|交通|住宿|采购|会务|会议", compact))
|
||||
|
||||
def _build_task(
|
||||
self,
|
||||
draft: PlannedTaskDraft,
|
||||
base_date: date,
|
||||
request: StewardPlanRequest,
|
||||
) -> StewardTask:
|
||||
fields = self._extract_ontology_fields(draft.segment, draft.task_type, base_date, request)
|
||||
missing_fields = self._resolve_missing_fields(draft.task_type, fields)
|
||||
task_id = f"task_{'app' if draft.task_type == 'expense_application' else 'reim'}_{draft.index:03d}"
|
||||
assigned_agent = (
|
||||
"application_assistant"
|
||||
if draft.task_type == "expense_application"
|
||||
else "reimbursement_assistant"
|
||||
)
|
||||
title_prefix = "费用申请" if draft.task_type == "expense_application" else "费用报销"
|
||||
title = self._build_task_title(title_prefix, fields, draft.index)
|
||||
return StewardTask(
|
||||
task_id=task_id,
|
||||
task_type=draft.task_type, # type: ignore[arg-type]
|
||||
assigned_agent=assigned_agent, # type: ignore[arg-type]
|
||||
title=title,
|
||||
summary=self._build_task_summary(draft.segment, fields),
|
||||
status="needs_confirmation",
|
||||
confidence=self._resolve_task_confidence(draft.segment, fields, draft.task_type),
|
||||
ontology_fields=fields,
|
||||
missing_fields=missing_fields,
|
||||
confirmation_required=True,
|
||||
)
|
||||
|
||||
def _build_fallback_task(
|
||||
self,
|
||||
message: str,
|
||||
base_date: date,
|
||||
request: StewardPlanRequest,
|
||||
) -> StewardTask:
|
||||
task_type = "reimbursement" if "报销" in message or request.attachments else "expense_application"
|
||||
draft = PlannedTaskDraft(task_type=task_type, segment=message, index=1)
|
||||
task = self._build_task(draft, base_date, request)
|
||||
return task.model_copy(update={"confidence": min(task.confidence, 0.58)})
|
||||
|
||||
def _extract_ontology_fields(
|
||||
self,
|
||||
segment: str,
|
||||
task_type: str,
|
||||
base_date: date,
|
||||
request: StewardPlanRequest,
|
||||
) -> dict[str, str]:
|
||||
normalized_context = normalize_ontology_form_values(request.context_json.get("review_form_values"))
|
||||
fields: dict[str, str] = {
|
||||
key: value
|
||||
for key, value in normalized_context.items()
|
||||
if key in BUSINESS_CANONICAL_FIELDS and str(value or "").strip()
|
||||
}
|
||||
expense_type = self._infer_expense_type(segment, task_type)
|
||||
if expense_type and not fields.get("expense_type"):
|
||||
fields["expense_type"] = expense_type
|
||||
time_range = self._extract_time_range(segment, base_date)
|
||||
if time_range and not fields.get("time_range"):
|
||||
fields["time_range"] = time_range
|
||||
location = self._extract_location(segment)
|
||||
if location and not fields.get("location"):
|
||||
fields["location"] = location
|
||||
reason = self._extract_reason(segment, task_type)
|
||||
if reason and not fields.get("reason"):
|
||||
fields["reason"] = reason
|
||||
transport_mode = self._extract_transport_mode(segment)
|
||||
if transport_mode and not fields.get("transport_mode"):
|
||||
fields["transport_mode"] = transport_mode
|
||||
if request.attachments:
|
||||
fields["attachments"] = "、".join(item.name for item in request.attachments if item.name)
|
||||
|
||||
return {key: value for key, value in fields.items() if key in BUSINESS_CANONICAL_FIELDS and value}
|
||||
|
||||
@staticmethod
|
||||
def _infer_expense_type(segment: str, task_type: str) -> str:
|
||||
compact = re.sub(r"\s+", "", segment)
|
||||
if re.search(r"招待|接待|餐饮|宴请|客户吃饭|业务餐", compact):
|
||||
return "entertainment"
|
||||
if re.search(r"出差|差旅|住宿|酒店|机票|航班|高铁|火车", compact):
|
||||
return "travel"
|
||||
if re.search(r"交通|出租车|的士|网约车|打车|地铁|公交", compact):
|
||||
return "transport" if task_type == "reimbursement" else "travel"
|
||||
return "travel" if task_type == "expense_application" else "other"
|
||||
|
||||
def _extract_time_range(self, segment: str, base_date: date) -> str:
|
||||
compact = re.sub(r"\s+", "", segment)
|
||||
if "昨天" in compact:
|
||||
return (base_date - timedelta(days=1)).isoformat()
|
||||
if "前天" in compact:
|
||||
return (base_date - timedelta(days=2)).isoformat()
|
||||
if "明天" in compact:
|
||||
return (base_date + timedelta(days=1)).isoformat()
|
||||
if "后天" in compact:
|
||||
return (base_date + timedelta(days=2)).isoformat()
|
||||
|
||||
iso_match = ISO_DATE_PATTERN.search(compact)
|
||||
if iso_match:
|
||||
return self._safe_date(
|
||||
int(iso_match.group("year")),
|
||||
int(iso_match.group("month")),
|
||||
int(iso_match.group("day")),
|
||||
)
|
||||
|
||||
month_day = MONTH_DAY_PATTERN.search(compact)
|
||||
if month_day:
|
||||
return self._safe_date(
|
||||
base_date.year,
|
||||
int(month_day.group("month")),
|
||||
int(month_day.group("day")),
|
||||
)
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _safe_date(year: int, month: int, day: int) -> str:
|
||||
try:
|
||||
return date(year, month, day).isoformat()
|
||||
except ValueError:
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _extract_location(segment: str) -> str:
|
||||
compact = re.sub(r"\s+", "", segment)
|
||||
for prefix in ("去", "到", "赴", "前往"):
|
||||
match = re.search(fr"{prefix}({'|'.join(CITY_NAMES)})", compact)
|
||||
if match:
|
||||
return match.group(1)
|
||||
for city in CITY_NAMES:
|
||||
if city in compact:
|
||||
return city
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _extract_reason(segment: str, task_type: str) -> str:
|
||||
cleaned = re.sub(r"\s+", "", segment).strip(",,。;; ")
|
||||
if task_type == "expense_application":
|
||||
match = re.search(r"(辅助|支持|协助|支撑|参加|拜访|调研|实施|部署|审核).+", cleaned)
|
||||
if match:
|
||||
return StewardPlannerService._strip_trailing_connectors(match.group(0))
|
||||
reason = re.sub(r"^.*?(?:出差|差旅)", "", cleaned).strip(",,。;;的费用")
|
||||
return StewardPlannerService._strip_trailing_connectors(reason) or cleaned
|
||||
cleaned = re.sub(r"^报销", "", cleaned)
|
||||
cleaned = re.sub(r"^(?:昨天|前天|明天|后天|\d{1,2}月\d{1,2}(?:日|号)?)的?", "", cleaned)
|
||||
return cleaned.strip(",,。;; ") or segment.strip()
|
||||
|
||||
@staticmethod
|
||||
def _strip_trailing_connectors(value: str) -> str:
|
||||
cleaned = str(value or "").strip(",,。;; ")
|
||||
return re.sub(r"(?:并且|而且|同时|另外|还需要|需要)$", "", cleaned).strip(",,。;; ")
|
||||
|
||||
@staticmethod
|
||||
def _extract_transport_mode(segment: str) -> str:
|
||||
compact = re.sub(r"\s+", "", segment)
|
||||
if re.search(r"高铁|动车|火车", compact):
|
||||
return "train"
|
||||
if re.search(r"飞机|机票|航班", compact):
|
||||
return "flight"
|
||||
if re.search(r"出租车|的士|网约车|打车", compact):
|
||||
return "taxi"
|
||||
if "交通" in compact:
|
||||
return "other"
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _resolve_missing_fields(task_type: str, fields: dict[str, str]) -> list[str]:
|
||||
required = ["expense_type", "time_range", "reason"]
|
||||
if task_type == "expense_application":
|
||||
required.append("location")
|
||||
return [key for key in required if not str(fields.get(key) or "").strip()]
|
||||
|
||||
@staticmethod
|
||||
def _resolve_task_confidence(segment: str, fields: dict[str, str], task_type: str) -> float:
|
||||
compact = re.sub(r"\s+", "", segment)
|
||||
intent_score = 1.0 if ("申请" in compact if task_type == "expense_application" else "报销" in compact) else 0.45
|
||||
time_score = 1.0 if fields.get("time_range") else 0.0
|
||||
location_score = 1.0 if fields.get("location") else 0.2
|
||||
scene_score = 1.0 if fields.get("expense_type") and fields["expense_type"] != "other" else 0.35
|
||||
confidence = min(1.0, 0.35 * intent_score + 0.25 * time_score + 0.2 * location_score + 0.2 * scene_score)
|
||||
return round(max(0.45, confidence), 2)
|
||||
|
||||
def _build_attachment_groups(
|
||||
self,
|
||||
attachments: list[StewardAttachmentInput],
|
||||
tasks: list[StewardTask],
|
||||
) -> list[StewardAttachmentGroup]:
|
||||
if not attachments:
|
||||
return []
|
||||
|
||||
classified = [(item, self._classify_attachment(item)) for item in attachments if item.name]
|
||||
travel_related = [item.name for item, scene in classified if scene in {"travel", "transport"}]
|
||||
excluded = [item.name for item, scene in classified if scene not in {"travel", "transport"}]
|
||||
target_task = self._resolve_attachment_target_task(tasks)
|
||||
|
||||
groups: list[StewardAttachmentGroup] = []
|
||||
if travel_related:
|
||||
confidence = 0.72 + min(0.18, len(travel_related) * 0.04)
|
||||
groups.append(
|
||||
StewardAttachmentGroup(
|
||||
group_id="ag_travel_001",
|
||||
target_task_id=target_task.task_id if target_task else None,
|
||||
scene="travel",
|
||||
scene_label="差旅相关费用",
|
||||
attachment_names=travel_related,
|
||||
excluded_attachment_names=excluded,
|
||||
confidence=round(confidence, 2),
|
||||
rationale="附件名称或 OCR 摘要中包含差旅、交通、住宿、火车、机票等线索。",
|
||||
confirmation_required=True,
|
||||
)
|
||||
)
|
||||
elif excluded:
|
||||
groups.append(
|
||||
StewardAttachmentGroup(
|
||||
group_id="ag_other_001",
|
||||
target_task_id=None,
|
||||
scene="other",
|
||||
scene_label="待人工确认费用",
|
||||
attachment_names=excluded,
|
||||
excluded_attachment_names=[],
|
||||
confidence=0.5,
|
||||
rationale="当前附件缺少可稳定归属到申请或报销任务的差旅线索。",
|
||||
confirmation_required=True,
|
||||
)
|
||||
)
|
||||
return groups
|
||||
|
||||
@staticmethod
|
||||
def _resolve_attachment_target_task(tasks: list[StewardTask]) -> StewardTask | None:
|
||||
reimbursement_tasks = [item for item in tasks if item.task_type == "reimbursement"]
|
||||
for task in reimbursement_tasks:
|
||||
if task.ontology_fields.get("expense_type") == "travel":
|
||||
return task
|
||||
return reimbursement_tasks[0] if reimbursement_tasks else None
|
||||
|
||||
@staticmethod
|
||||
def _classify_attachment(attachment: StewardAttachmentInput) -> str:
|
||||
text = " ".join(
|
||||
[
|
||||
attachment.name,
|
||||
attachment.media_type,
|
||||
attachment.ocr_summary,
|
||||
" ".join(f"{key}:{value}" for key, value in attachment.ocr_fields.items()),
|
||||
]
|
||||
)
|
||||
compact = re.sub(r"\s+", "", text).lower()
|
||||
if re.search(r"招待|接待|餐饮|宴请|客户|meal|entertainment", compact):
|
||||
return "entertainment"
|
||||
if re.search(r"酒店|住宿|差旅|出差|高铁|火车|动车|机票|航班|train|flight|hotel|travel", compact):
|
||||
return "travel"
|
||||
if re.search(r"出租车|的士|网约车|打车|交通|taxi|transport", compact):
|
||||
return "transport"
|
||||
return "other"
|
||||
|
||||
def _build_confirmation_actions(
|
||||
self,
|
||||
tasks: list[StewardTask],
|
||||
attachment_groups: list[StewardAttachmentGroup],
|
||||
) -> list[StewardConfirmationAction]:
|
||||
actions: list[StewardConfirmationAction] = []
|
||||
for task in tasks:
|
||||
if task.task_type == "expense_application":
|
||||
action_type = "confirm_create_application"
|
||||
label = "确认创建申请单"
|
||||
else:
|
||||
action_type = "confirm_create_reimbursement_draft"
|
||||
label = "确认创建报销草稿"
|
||||
actions.append(
|
||||
StewardConfirmationAction(
|
||||
confirmation_id=f"confirm_{task.task_id}",
|
||||
action_type=action_type,
|
||||
label=label,
|
||||
description=f"确认后把“{task.title}”交给{self._agent_label(task.assigned_agent)}继续核对。",
|
||||
target_task_id=task.task_id,
|
||||
payload={
|
||||
"task_id": task.task_id,
|
||||
"task_type": task.task_type,
|
||||
"assigned_agent": task.assigned_agent,
|
||||
"ontology_fields": task.ontology_fields,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
for group in attachment_groups:
|
||||
actions.append(
|
||||
StewardConfirmationAction(
|
||||
confirmation_id=f"confirm_{group.group_id}",
|
||||
action_type="confirm_attachment_group",
|
||||
label="确认附件归集",
|
||||
description=f"确认后将 {len(group.attachment_names)} 份附件按“{group.scene_label}”归集。",
|
||||
target_task_id=group.target_task_id,
|
||||
attachment_group_id=group.group_id,
|
||||
payload={
|
||||
"attachment_group_id": group.group_id,
|
||||
"target_task_id": group.target_task_id,
|
||||
"attachment_names": group.attachment_names,
|
||||
"excluded_attachment_names": group.excluded_attachment_names,
|
||||
},
|
||||
)
|
||||
)
|
||||
return actions
|
||||
|
||||
@staticmethod
|
||||
def _agent_label(assigned_agent: str) -> str:
|
||||
return "申请助手" if assigned_agent == "application_assistant" else "报销助手"
|
||||
|
||||
def _build_thinking_events(
|
||||
self,
|
||||
tasks: list[StewardTask],
|
||||
attachment_groups: list[StewardAttachmentGroup],
|
||||
attachments: list[StewardAttachmentInput],
|
||||
) -> list[StewardThinkingEvent]:
|
||||
application_count = sum(1 for item in tasks if item.task_type == "expense_application")
|
||||
reimbursement_count = sum(1 for item in tasks if item.task_type == "reimbursement")
|
||||
task_intent_summary = self._summarize_task_intents(tasks)
|
||||
ontology_summary = self._summarize_ontology_coverage(tasks)
|
||||
delegation_summary = self._summarize_delegation_targets(tasks)
|
||||
events = [
|
||||
StewardThinkingEvent(
|
||||
event_id="intent_agent_entry",
|
||||
stage="intent_agent",
|
||||
title="意图识别智能体接管",
|
||||
content=(
|
||||
f"检测到复合财务话术,当前不是单一助手会话;"
|
||||
f"已进入小财管家编排模式,候选任务共 {len(tasks)} 个。"
|
||||
),
|
||||
),
|
||||
StewardThinkingEvent(
|
||||
event_id="intent_task_split",
|
||||
stage="task_split",
|
||||
title=f"拆分申请 {application_count} 个、报销 {reimbursement_count} 个",
|
||||
content=task_intent_summary,
|
||||
),
|
||||
StewardThinkingEvent(
|
||||
event_id="intent_ontology_mapping",
|
||||
stage="ontology_mapping",
|
||||
title="映射业务本体字段",
|
||||
content=ontology_summary,
|
||||
),
|
||||
]
|
||||
if attachments:
|
||||
events.append(
|
||||
StewardThinkingEvent(
|
||||
event_id="intent_attachment_correlation",
|
||||
stage="attachment_correlation",
|
||||
title="关联附件与任务线索",
|
||||
content=self._summarize_attachment_correlation(attachment_groups, len(attachments)),
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
StewardThinkingEvent(
|
||||
event_id="intent_delegation_gate",
|
||||
stage="delegation_gate",
|
||||
title="生成确认点并准备分派",
|
||||
content=f"{delegation_summary} 创建单据、生成草稿、绑定附件和提交审批都会等待用户确认。",
|
||||
)
|
||||
)
|
||||
return events
|
||||
|
||||
@staticmethod
|
||||
def _summarize_task_intents(tasks: list[StewardTask]) -> str:
|
||||
if not tasks:
|
||||
return "当前输入尚未形成稳定任务,先保留为待确认财务事项。"
|
||||
parts = []
|
||||
for task in tasks:
|
||||
task_label = "申请" if task.task_type == "expense_application" else "报销"
|
||||
fields = task.ontology_fields
|
||||
anchors = []
|
||||
if fields.get("time_range"):
|
||||
anchors.append(fields["time_range"])
|
||||
if fields.get("location"):
|
||||
anchors.append(fields["location"])
|
||||
if fields.get("expense_type"):
|
||||
anchors.append(fields["expense_type"])
|
||||
anchor_text = "、".join(anchors) if anchors else "待补充关键字段"
|
||||
parts.append(f"{task_label}:{task.title}({anchor_text})")
|
||||
return ";".join(parts)
|
||||
|
||||
@staticmethod
|
||||
def _summarize_ontology_coverage(tasks: list[StewardTask]) -> str:
|
||||
canonical_keys = []
|
||||
missing_keys = []
|
||||
for task in tasks:
|
||||
canonical_keys.extend(task.ontology_fields.keys())
|
||||
missing_keys.extend(task.missing_fields)
|
||||
unique_keys = sorted({item for item in canonical_keys if item})
|
||||
unique_missing = sorted({item for item in missing_keys if item})
|
||||
mapped = "、".join(unique_keys) if unique_keys else "暂无稳定字段"
|
||||
missing = ";缺失字段:" + "、".join(unique_missing) if unique_missing else ""
|
||||
return f"已使用 canonical ontology fields:{mapped}{missing}。兼容字段只作为输入别名,不直接进入业务逻辑。"
|
||||
|
||||
@staticmethod
|
||||
def _summarize_attachment_correlation(
|
||||
attachment_groups: list[StewardAttachmentGroup],
|
||||
total_attachment_count: int,
|
||||
) -> str:
|
||||
grouped_names = []
|
||||
excluded_names = []
|
||||
for group in attachment_groups:
|
||||
grouped_names.extend(group.attachment_names)
|
||||
excluded_names.extend(group.excluded_attachment_names)
|
||||
grouped_text = "、".join(grouped_names) if grouped_names else "暂无可稳定归集附件"
|
||||
excluded_text = ";排除或单独确认:" + "、".join(excluded_names) if excluded_names else ""
|
||||
return f"已核对 {total_attachment_count} 份附件,建议归集:{grouped_text}{excluded_text}。"
|
||||
|
||||
@staticmethod
|
||||
def _summarize_delegation_targets(tasks: list[StewardTask]) -> str:
|
||||
application_count = sum(1 for item in tasks if item.assigned_agent == "application_assistant")
|
||||
reimbursement_count = sum(1 for item in tasks if item.assigned_agent == "reimbursement_assistant")
|
||||
parts = []
|
||||
if application_count:
|
||||
parts.append(f"{application_count} 个申请任务交给申请助手")
|
||||
if reimbursement_count:
|
||||
parts.append(f"{reimbursement_count} 个报销任务交给报销助手")
|
||||
return ";".join(parts) + "。" if parts else "尚无可分派任务。"
|
||||
|
||||
@staticmethod
|
||||
def _build_summary(tasks: list[StewardTask], attachment_groups: list[StewardAttachmentGroup]) -> str:
|
||||
parts = [f"我识别到 {len(tasks)} 个待处理任务"]
|
||||
if attachment_groups:
|
||||
grouped = sum(len(item.attachment_names) for item in attachment_groups)
|
||||
parts.append(f"并形成 {grouped} 份附件的归集建议")
|
||||
parts.append(",请确认后我再分派给对应助手执行。")
|
||||
return "".join(parts)
|
||||
|
||||
@staticmethod
|
||||
def _build_task_title(prefix: str, fields: dict[str, str], index: int) -> str:
|
||||
location = fields.get("location", "")
|
||||
time_range = fields.get("time_range", "")
|
||||
expense_type = fields.get("expense_type", "")
|
||||
subject = location or {"travel": "差旅", "transport": "交通", "entertainment": "招待"}.get(expense_type, "")
|
||||
if subject and time_range:
|
||||
return f"{prefix} {time_range} {subject}"
|
||||
if subject:
|
||||
return f"{prefix} {subject}"
|
||||
return f"{prefix} {index}"
|
||||
|
||||
@staticmethod
|
||||
def _build_task_summary(segment: str, fields: dict[str, str]) -> str:
|
||||
field_parts = []
|
||||
for key, label in (
|
||||
("time_range", "时间"),
|
||||
("location", "地点"),
|
||||
("expense_type", "费用类型"),
|
||||
("reason", "事由"),
|
||||
("transport_mode", "交通方式"),
|
||||
):
|
||||
value = fields.get(key)
|
||||
if value:
|
||||
field_parts.append(f"{label}:{value}")
|
||||
return ";".join(field_parts) or segment
|
||||
|
||||
@staticmethod
|
||||
def _resolve_base_date(client_now_iso: str | None, context_json: dict[str, Any]) -> date:
|
||||
raw_value = client_now_iso or str(context_json.get("client_now_iso") or "").strip()
|
||||
if raw_value:
|
||||
try:
|
||||
parsed = datetime.fromisoformat(raw_value.replace("Z", "+00:00"))
|
||||
return parsed.date()
|
||||
except ValueError:
|
||||
pass
|
||||
return datetime.now(UTC).date()
|
||||
|
||||
@staticmethod
|
||||
def _clean_text(value: Any) -> str:
|
||||
return re.sub(r"\s+", " ", str(value or "")).strip()
|
||||
@@ -1918,6 +1918,77 @@ def test_update_claim_item_reanalyzes_existing_attachment(monkeypatch, tmp_path)
|
||||
assert refreshed_meta["requirement_check"]["matches"] is False
|
||||
assert any("附件类型要求" in point for point in refreshed_meta["analysis"]["points"])
|
||||
|
||||
def test_upload_attachment_refreshes_claim_pre_review(monkeypatch, tmp_path) -> None:
|
||||
current_user = CurrentUserContext(
|
||||
username="emp-1",
|
||||
name="submitter",
|
||||
role_codes=[],
|
||||
is_admin=False,
|
||||
)
|
||||
review_calls: list[str] = []
|
||||
|
||||
def fake_recognize(
|
||||
self,
|
||||
files: list[tuple[str, bytes, str | None]],
|
||||
) -> OcrRecognizeBatchRead:
|
||||
return OcrRecognizeBatchRead(
|
||||
total_file_count=1,
|
||||
success_count=1,
|
||||
documents=[
|
||||
OcrRecognizeDocumentRead(
|
||||
filename="receipt.png",
|
||||
media_type="image/png",
|
||||
text="office receipt amount 88 2026-05-13",
|
||||
summary="recognized office receipt",
|
||||
avg_score=0.98,
|
||||
line_count=1,
|
||||
page_count=1,
|
||||
warnings=[],
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def fake_review(self, reviewed_claim):
|
||||
review_calls.append(reviewed_claim.id)
|
||||
return {
|
||||
"risk_flags": [
|
||||
*list(reviewed_claim.risk_flags_json or []),
|
||||
{
|
||||
"source": "submission_review",
|
||||
"severity": "high",
|
||||
"label": "upload-time-risk",
|
||||
"message": "risk generated after attachment upload",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
monkeypatch.setattr(OcrService, "recognize_files", fake_recognize)
|
||||
monkeypatch.setattr(ExpenseClaimAttachmentStorage, "root", lambda self: tmp_path)
|
||||
monkeypatch.setattr(ExpenseClaimService, "_run_ai_submission_review", fake_review)
|
||||
|
||||
with build_session() as db:
|
||||
claim = build_claim(expense_type="office", location="Shanghai")
|
||||
claim.invoice_count = 0
|
||||
claim.items[0].invoice_id = None
|
||||
db.add(claim)
|
||||
db.commit()
|
||||
|
||||
payload = ExpenseClaimService(db).upload_claim_item_attachment(
|
||||
claim_id=claim.id,
|
||||
item_id=claim.items[0].id,
|
||||
filename="receipt.png",
|
||||
content=b"fake-image-bytes",
|
||||
media_type="image/png",
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
flags = payload["claim_risk_flags"]
|
||||
assert review_calls == [claim.id]
|
||||
assert any(flag.get("label") == "upload-time-risk" for flag in flags)
|
||||
pre_review = next(flag for flag in flags if flag.get("source") == "ai_pre_review")
|
||||
assert pre_review["status"] == "failed"
|
||||
assert pre_review["blocking_risk_count"] >= 1
|
||||
|
||||
|
||||
def test_upload_train_ticket_attachment_backfills_item_amount(monkeypatch, tmp_path) -> None:
|
||||
current_user = CurrentUserContext(
|
||||
@@ -2619,6 +2690,60 @@ def test_submit_claim_runs_ai_review_and_routes_to_direct_manager() -> None:
|
||||
assert submitted.approval_stage == "直属领导审批"
|
||||
assert submitted.submitted_at is not None
|
||||
|
||||
def test_submit_claim_reuses_upload_pre_review_without_rerunning_review(monkeypatch) -> None:
|
||||
current_user = CurrentUserContext(
|
||||
username="emp-submit@example.com",
|
||||
name="submitter",
|
||||
role_codes=[],
|
||||
is_admin=False,
|
||||
)
|
||||
|
||||
def fail_review(self, reviewed_claim):
|
||||
raise AssertionError("submit should reuse upload-time pre-review")
|
||||
|
||||
monkeypatch.setattr(ExpenseClaimService, "_run_ai_submission_review", fail_review)
|
||||
|
||||
with build_session() as db:
|
||||
manager = Employee(
|
||||
employee_no="E7010",
|
||||
name="Manager",
|
||||
email="manager-reuse@example.com",
|
||||
)
|
||||
employee = Employee(
|
||||
employee_no="E7011",
|
||||
name="submitter",
|
||||
email="emp-submit@example.com",
|
||||
manager=manager,
|
||||
)
|
||||
claim = build_claim(expense_type="transport", location="Shanghai")
|
||||
claim.employee = employee
|
||||
claim.employee_id = employee.id
|
||||
claim.items[0].invoice_id = "taxi-ticket.png"
|
||||
claim.risk_flags_json = [
|
||||
{
|
||||
"source": "submission_review",
|
||||
"severity": "medium",
|
||||
"label": "upload-time-warning",
|
||||
"message": "generated before submit",
|
||||
},
|
||||
{
|
||||
"source": "ai_pre_review",
|
||||
"status": "passed",
|
||||
"passed": True,
|
||||
"severity": "info",
|
||||
"blocking_risk_count": 0,
|
||||
},
|
||||
]
|
||||
db.add_all([manager, employee, claim])
|
||||
db.commit()
|
||||
|
||||
submitted = ExpenseClaimService(db).submit_claim(claim.id, current_user)
|
||||
|
||||
assert submitted is not None
|
||||
assert submitted.status == "submitted"
|
||||
assert any(flag.get("label") == "upload-time-warning" for flag in submitted.risk_flags_json)
|
||||
assert any(flag.get("source") == "ai_pre_review" for flag in submitted.risk_flags_json)
|
||||
|
||||
|
||||
def test_accept_standard_adjustment_recalculates_claim_amount_and_preserves_on_submit() -> None:
|
||||
current_user = CurrentUserContext(
|
||||
@@ -2669,28 +2794,92 @@ def test_accept_standard_adjustment_recalculates_claim_amount_and_preserves_on_s
|
||||
)
|
||||
|
||||
assert adjusted is not None
|
||||
assert adjusted.amount == Decimal("600.00")
|
||||
assert adjusted.amount == Decimal("450.00")
|
||||
standard_flag = next(
|
||||
flag
|
||||
for flag in adjusted.risk_flags_json
|
||||
if isinstance(flag, dict) and flag.get("source") == "reimbursement_standard_adjustment"
|
||||
)
|
||||
assert standard_flag["original_amount"] == "880.00"
|
||||
assert standard_flag["reimbursable_amount"] == "600.00"
|
||||
assert standard_flag["employee_absorbed_amount"] == "280.00"
|
||||
assert standard_flag["reimbursable_amount"] == "450.00"
|
||||
assert standard_flag["employee_absorbed_amount"] == "430.00"
|
||||
assert standard_flag["visibility_scope"] == "leader"
|
||||
|
||||
submitted = service.submit_claim(claim.id, current_user)
|
||||
|
||||
assert submitted is not None
|
||||
assert submitted.status == "submitted"
|
||||
assert submitted.amount == Decimal("600.00")
|
||||
assert submitted.amount == Decimal("450.00")
|
||||
assert any(
|
||||
isinstance(flag, dict) and flag.get("source") == "reimbursement_standard_adjustment"
|
||||
for flag in submitted.risk_flags_json
|
||||
)
|
||||
|
||||
|
||||
def test_accept_standard_adjustment_uses_policy_amount_when_payload_has_no_downgrade() -> None:
|
||||
current_user = CurrentUserContext(
|
||||
username="emp-policy-standard@example.com",
|
||||
name="张三",
|
||||
role_codes=[],
|
||||
is_admin=False,
|
||||
grade="P4",
|
||||
)
|
||||
|
||||
with build_session() as db:
|
||||
manager = Employee(
|
||||
employee_no="E7032",
|
||||
name="李经理",
|
||||
email="manager-policy-standard@example.com",
|
||||
)
|
||||
employee = Employee(
|
||||
employee_no="E7033",
|
||||
name="张三",
|
||||
email="emp-policy-standard@example.com",
|
||||
grade="P4",
|
||||
manager=manager,
|
||||
)
|
||||
claim = build_claim(expense_type="hotel", location="北京")
|
||||
claim.employee = employee
|
||||
claim.employee_id = employee.id
|
||||
claim.amount = Decimal("1000.00")
|
||||
claim.items[0].item_type = "hotel_ticket"
|
||||
claim.items[0].item_reason = "北京住宿"
|
||||
claim.items[0].item_location = "北京"
|
||||
claim.items[0].item_amount = Decimal("1000.00")
|
||||
db.add_all([manager, employee, claim])
|
||||
db.commit()
|
||||
|
||||
adjusted = ExpenseClaimService(db).accept_standard_adjustment(
|
||||
claim_id=claim.id,
|
||||
payload=ExpenseClaimStandardAdjustmentPayload(
|
||||
risks=[
|
||||
{
|
||||
"risk_id": "risk-hotel-policy-1",
|
||||
"item_id": claim.items[0].id,
|
||||
"title": "住宿超标待说明",
|
||||
"risk": "住宿票据金额超过职级标准。",
|
||||
"application_days": 2,
|
||||
"original_amount": Decimal("1000.00"),
|
||||
"reimbursable_amount": Decimal("1000.00"),
|
||||
}
|
||||
]
|
||||
),
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
assert adjusted is not None
|
||||
assert adjusted.amount == Decimal("900.00")
|
||||
standard_flag = next(
|
||||
flag
|
||||
for flag in adjusted.risk_flags_json
|
||||
if isinstance(flag, dict) and flag.get("source") == "reimbursement_standard_adjustment"
|
||||
)
|
||||
assert standard_flag["original_amount"] == "1000.00"
|
||||
assert standard_flag["reimbursable_amount"] == "900.00"
|
||||
assert standard_flag["employee_absorbed_amount"] == "100.00"
|
||||
assert standard_flag["visibility_scope"] == "leader"
|
||||
|
||||
|
||||
def test_pre_review_claim_records_ai_result_without_submitting() -> None:
|
||||
current_user = CurrentUserContext(
|
||||
username="emp-pre-review@example.com",
|
||||
|
||||
@@ -113,6 +113,53 @@ def seed_claim(db: Session) -> tuple[ExpenseClaim, ExpenseClaimItem]:
|
||||
return claim, item
|
||||
|
||||
|
||||
def test_claim_standard_adjustment_endpoint_recalculates_and_marks_reviewer_notice() -> None:
|
||||
client, session_factory = build_client()
|
||||
with session_factory() as db:
|
||||
claim, item = seed_claim(db)
|
||||
claim.expense_type = "hotel"
|
||||
claim.location = "北京"
|
||||
claim.amount = Decimal("1000.00")
|
||||
item.item_type = "hotel_ticket"
|
||||
item.item_reason = "北京住宿"
|
||||
item.item_location = "北京"
|
||||
item.item_amount = Decimal("1000.00")
|
||||
db.commit()
|
||||
claim_id = claim.id
|
||||
item_id = item.id
|
||||
|
||||
response = client.post(
|
||||
f"/api/v1/reimbursements/claims/{claim_id}/standard-adjustment",
|
||||
json={
|
||||
"risks": [
|
||||
{
|
||||
"risk_id": "risk-hotel-endpoint-1",
|
||||
"item_id": item_id,
|
||||
"title": "住宿超标待说明",
|
||||
"risk": "住宿票据金额超过职级标准。",
|
||||
"application_days": 2,
|
||||
"original_amount": "1000.00",
|
||||
"reimbursable_amount": "1000.00",
|
||||
}
|
||||
]
|
||||
},
|
||||
headers={"x-auth-username": "emp-1", "x-auth-name": "Zhang San", "x-auth-grade": "P4"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["amount"] == "900.00"
|
||||
standard_flag = next(
|
||||
flag
|
||||
for flag in payload["risk_flags_json"]
|
||||
if isinstance(flag, dict) and flag.get("source") == "reimbursement_standard_adjustment"
|
||||
)
|
||||
assert standard_flag["original_amount"] == "1000.00"
|
||||
assert standard_flag["reimbursable_amount"] == "900.00"
|
||||
assert standard_flag["employee_absorbed_amount"] == "100.00"
|
||||
assert standard_flag["visibility_scope"] == "leader"
|
||||
|
||||
|
||||
def test_claim_item_attachment_upload_preview_and_delete(monkeypatch, tmp_path) -> None:
|
||||
def fake_recognize(
|
||||
self,
|
||||
|
||||
@@ -4,6 +4,7 @@ from collections.abc import Generator
|
||||
from datetime import UTC, datetime
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine
|
||||
@@ -124,6 +125,27 @@ def test_platform_rule_flags_are_persisted_as_risk_observations() -> None:
|
||||
assert persisted.contribution_scores_json == {"S_rule": 100}
|
||||
|
||||
|
||||
def test_risk_observation_storage_ready_is_cached_per_bind(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
with _build_session() as db:
|
||||
RiskObservationService._storage_ready_cache.clear()
|
||||
create_all_calls = []
|
||||
original_create_all = Base.metadata.create_all
|
||||
|
||||
def spy_create_all(*args, **kwargs):
|
||||
create_all_calls.append(kwargs.get("bind"))
|
||||
return original_create_all(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(Base.metadata, "create_all", spy_create_all)
|
||||
|
||||
service = RiskObservationService(db)
|
||||
service.ensure_storage_ready()
|
||||
service.ensure_storage_ready()
|
||||
RiskObservationService(db).ensure_storage_ready()
|
||||
|
||||
assert len(create_all_calls) == 1
|
||||
RiskObservationService._storage_ready_cache.clear()
|
||||
|
||||
|
||||
def test_risk_observation_endpoints_return_list_detail_dashboard_and_feedback() -> None:
|
||||
client, session_factory = _build_client()
|
||||
with session_factory() as db:
|
||||
|
||||
@@ -150,6 +150,62 @@ def test_runtime_chat_disables_glm_thinking_for_direct_user_answers(monkeypatch)
|
||||
assert captured["timeout_seconds"] == 17
|
||||
|
||||
|
||||
def test_runtime_chat_openai_compatible_tool_call_payload(monkeypatch) -> None:
|
||||
_clear_runtime_chat_cooldown()
|
||||
session_factory = build_session_factory()
|
||||
with session_factory() as db:
|
||||
service = RuntimeChatService(db)
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def fake_send_json_request(method, url, *, headers, payload, timeout_seconds):
|
||||
captured["method"] = method
|
||||
captured["url"] = url
|
||||
captured["headers"] = headers
|
||||
captured["payload"] = payload
|
||||
captured["timeout_seconds"] = timeout_seconds
|
||||
return 200, {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_001",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "submit_steward_intent_plan",
|
||||
"arguments": "{\"tasks\": []}",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
monkeypatch.setattr("app.services.runtime_chat._send_json_request", fake_send_json_request)
|
||||
|
||||
tool_call = service._request_openai_compatible_tool_call(
|
||||
provider="OpenAI Compatible",
|
||||
endpoint="https://api.example.com/v1",
|
||||
model="gpt-test",
|
||||
api_key="secret",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
tools=[{"type": "function", "function": {"name": "submit_steward_intent_plan"}}],
|
||||
tool_choice={"type": "function", "function": {"name": "submit_steward_intent_plan"}},
|
||||
max_tokens=128,
|
||||
temperature=0.1,
|
||||
timeout_seconds=19,
|
||||
)
|
||||
|
||||
assert tool_call is not None
|
||||
assert tool_call.name == "submit_steward_intent_plan"
|
||||
assert tool_call.arguments == {"tasks": []}
|
||||
assert captured["url"] == "https://api.example.com/v1/chat/completions"
|
||||
assert captured["payload"]["tools"][0]["function"]["name"] == "submit_steward_intent_plan"
|
||||
assert captured["payload"]["tool_choice"]["function"]["name"] == "submit_steward_intent_plan"
|
||||
assert captured["headers"]["Authorization"] == "Bearer secret"
|
||||
|
||||
|
||||
def test_runtime_chat_supports_single_pass_fast_failover(monkeypatch) -> None:
|
||||
_clear_runtime_chat_cooldown()
|
||||
session_factory = build_session_factory()
|
||||
|
||||
214
server/tests/test_steward_planner.py
Normal file
214
server/tests/test_steward_planner.py
Normal file
@@ -0,0 +1,214 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.main import create_app
|
||||
from app.schemas.steward import StewardAttachmentInput, StewardPlanRequest
|
||||
from app.services.steward_intent_agent import StewardIntentAgentResult
|
||||
from app.services.steward_planner import StewardPlannerService
|
||||
|
||||
|
||||
class FakeFunctionCallingIntentAgent:
|
||||
def detect(self, request, *, base_date, canonical_fields):
|
||||
assert "expense_type" in canonical_fields
|
||||
assert base_date.isoformat() == "2026-06-04"
|
||||
return StewardIntentAgentResult(
|
||||
payload={
|
||||
"thinking_events": [
|
||||
{
|
||||
"stage": "task_split",
|
||||
"title": "识别复合报销意图",
|
||||
"content": "模型工具调用识别出 1 个报销任务,并关联本次上传的交通附件。",
|
||||
}
|
||||
],
|
||||
"tasks": [
|
||||
{
|
||||
"task_type": "reimbursement",
|
||||
"title": "费用报销 2026-06-03 交通",
|
||||
"summary": "报销昨天客户现场沟通产生的交通费。",
|
||||
"confidence": 0.91,
|
||||
"ontology_fields": {
|
||||
"occurred_date": "昨天",
|
||||
"transport_type": "出租车",
|
||||
"reason_value": "客户现场沟通",
|
||||
"expense_type": "交通费",
|
||||
"unregistered_field": "不能进入业务字段",
|
||||
},
|
||||
"missing_fields": ["amount", "transport_type"],
|
||||
}
|
||||
],
|
||||
"attachment_groups": [
|
||||
{
|
||||
"target_task_index": 1,
|
||||
"scene": "transport",
|
||||
"scene_label": "交通费用",
|
||||
"attachment_names": ["出租车票.png"],
|
||||
"excluded_attachment_names": ["客户招待发票.jpg"],
|
||||
"confidence": 0.86,
|
||||
"rationale": "出租车票与交通报销任务匹配,招待发票不归入该任务。",
|
||||
}
|
||||
],
|
||||
},
|
||||
model_call_traces=[
|
||||
{
|
||||
"slot": "main",
|
||||
"provider": "OpenAI Compatible",
|
||||
"model": "gpt-test",
|
||||
"attempt": 1,
|
||||
"status": "succeeded",
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class EmptyFunctionCallingIntentAgent:
|
||||
def detect(self, request, *, base_date, canonical_fields):
|
||||
return None
|
||||
|
||||
|
||||
def test_steward_planner_uses_llm_function_calling_plan_when_available() -> None:
|
||||
payload = StewardPlanRequest(
|
||||
message="我要报销昨天客户现场沟通的交通费",
|
||||
client_now_iso="2026-06-04T09:30:00+08:00",
|
||||
attachments=[
|
||||
StewardAttachmentInput(name="出租车票.png"),
|
||||
StewardAttachmentInput(name="客户招待发票.jpg"),
|
||||
],
|
||||
)
|
||||
|
||||
result = StewardPlannerService(intent_agent=FakeFunctionCallingIntentAgent()).build_plan(payload)
|
||||
|
||||
assert result.planning_source == "llm_function_call"
|
||||
assert result.model_call_traces[0]["status"] == "succeeded"
|
||||
assert len(result.tasks) == 1
|
||||
fields = result.tasks[0].ontology_fields
|
||||
assert fields["time_range"] == "2026-06-03"
|
||||
assert fields["transport_mode"] == "taxi"
|
||||
assert fields["reason"] == "客户现场沟通"
|
||||
assert fields["expense_type"] == "transport"
|
||||
assert "occurred_date" not in fields
|
||||
assert "transport_type" not in fields
|
||||
assert "reason_value" not in fields
|
||||
assert "unregistered_field" not in fields
|
||||
assert result.tasks[0].missing_fields == ["amount"]
|
||||
assert result.attachment_groups[0].attachment_names == ["出租车票.png"]
|
||||
assert result.attachment_groups[0].excluded_attachment_names == ["客户招待发票.jpg"]
|
||||
assert result.thinking_events[0].stage == "llm_function_call"
|
||||
|
||||
|
||||
def test_steward_planner_falls_back_to_rules_when_function_calling_is_unavailable() -> None:
|
||||
payload = StewardPlanRequest(
|
||||
message="我要报销昨天的交通费",
|
||||
client_now_iso="2026-06-04T09:30:00+08:00",
|
||||
)
|
||||
|
||||
result = StewardPlannerService(intent_agent=EmptyFunctionCallingIntentAgent()).build_plan(payload)
|
||||
|
||||
assert result.planning_source == "rule_fallback"
|
||||
assert result.tasks[0].ontology_fields["time_range"] == "2026-06-03"
|
||||
assert result.thinking_events[0].stage == "rule_fallback"
|
||||
|
||||
|
||||
def test_steward_planner_splits_application_and_reimbursement_tasks() -> None:
|
||||
payload = StewardPlanRequest(
|
||||
message=(
|
||||
"我想要申请7月2日去北京出差,辅助北京供电局的税务审核任务,"
|
||||
"并且我要报销昨天的交通费,还需要报销6月3日出差去上海的费用"
|
||||
),
|
||||
user_id="u001",
|
||||
client_now_iso="2026-06-04T09:30:00+08:00",
|
||||
)
|
||||
|
||||
result = StewardPlannerService().build_plan(payload)
|
||||
|
||||
assert len(result.tasks) == 3
|
||||
assert [task.task_type for task in result.tasks] == [
|
||||
"expense_application",
|
||||
"reimbursement",
|
||||
"reimbursement",
|
||||
]
|
||||
assert result.tasks[0].assigned_agent == "application_assistant"
|
||||
assert result.tasks[0].ontology_fields["time_range"] == "2026-07-02"
|
||||
assert result.tasks[0].ontology_fields["location"] == "北京"
|
||||
assert result.tasks[0].ontology_fields["reason"] == "辅助北京供电局的税务审核任务"
|
||||
assert result.tasks[1].ontology_fields["time_range"] == "2026-06-03"
|
||||
assert result.tasks[1].ontology_fields["expense_type"] == "transport"
|
||||
assert result.tasks[2].ontology_fields["time_range"] == "2026-06-03"
|
||||
assert result.tasks[2].ontology_fields["location"] == "上海"
|
||||
assert result.tasks[2].ontology_fields["expense_type"] == "travel"
|
||||
assert all(action.status == "pending" for action in result.confirmation_groups)
|
||||
|
||||
|
||||
def test_steward_planner_outputs_only_canonical_ontology_fields() -> None:
|
||||
payload = StewardPlanRequest(
|
||||
message="我要报销昨天的交通费",
|
||||
client_now_iso="2026-06-04T09:30:00+08:00",
|
||||
context_json={
|
||||
"review_form_values": {
|
||||
"occurred_date": "2026-06-03",
|
||||
"transport_type": "taxi",
|
||||
"reason_value": "客户现场沟通",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
result = StewardPlannerService().build_plan(payload)
|
||||
|
||||
fields = result.tasks[0].ontology_fields
|
||||
assert fields["time_range"] == "2026-06-03"
|
||||
assert fields["transport_mode"] == "taxi"
|
||||
assert fields["reason"] == "客户现场沟通"
|
||||
assert "occurred_date" not in fields
|
||||
assert "transport_type" not in fields
|
||||
assert "reason_value" not in fields
|
||||
|
||||
|
||||
def test_steward_planner_builds_travel_attachment_group_with_exclusions() -> None:
|
||||
payload = StewardPlanRequest(
|
||||
message="还需要报销6月3日出差去上海的费用",
|
||||
client_now_iso="2026-06-04T09:30:00+08:00",
|
||||
attachments=[
|
||||
StewardAttachmentInput(name="上海高铁票.jpg"),
|
||||
StewardAttachmentInput(name="上海酒店发票.pdf"),
|
||||
StewardAttachmentInput(name="出租车票.png"),
|
||||
StewardAttachmentInput(name="客户招待发票.jpg"),
|
||||
],
|
||||
)
|
||||
|
||||
result = StewardPlannerService().build_plan(payload)
|
||||
|
||||
assert len(result.attachment_groups) == 1
|
||||
group = result.attachment_groups[0]
|
||||
assert group.scene == "travel"
|
||||
assert group.attachment_names == ["上海高铁票.jpg", "上海酒店发票.pdf", "出租车票.png"]
|
||||
assert group.excluded_attachment_names == ["客户招待发票.jpg"]
|
||||
assert group.confirmation_required is True
|
||||
attachment_actions = [
|
||||
action for action in result.confirmation_groups if action.action_type == "confirm_attachment_group"
|
||||
]
|
||||
assert len(attachment_actions) == 1
|
||||
|
||||
|
||||
def test_steward_stream_endpoint_emits_thinking_before_plan() -> None:
|
||||
client = TestClient(create_app())
|
||||
|
||||
with client.stream(
|
||||
"POST",
|
||||
"/api/v1/steward/plans/stream",
|
||||
json={
|
||||
"message": "我要报销昨天的交通费",
|
||||
"client_now_iso": "2026-06-04T09:30:00+08:00",
|
||||
},
|
||||
) as response:
|
||||
assert response.status_code == 200
|
||||
events = [
|
||||
json.loads(line.decode("utf-8") if isinstance(line, bytes) else line)
|
||||
for line in response.iter_lines()
|
||||
if line
|
||||
]
|
||||
|
||||
assert [event["event"] for event in events][:2] == ["thinking", "thinking"]
|
||||
assert events[-1]["event"] == "plan"
|
||||
assert events[-1]["data"]["tasks"][0]["ontology_fields"]["time_range"] == "2026-06-03"
|
||||
Reference in New Issue
Block a user