Files
X-Financial/server/src/app/services/linked_reimbursement_draft_jobs.py
caoxiaozhu 332f77389d feat(server): 新增附件关联/关联报销草稿后台任务与申请位置语义
- attachment_association_jobs:从票据夹批量关联附件到报销单,识别城市/日期并创建明细项,内存态 job 跟踪
- linked_reimbursement_draft_jobs:基于申请单异步生成关联报销草稿,调用 Orchestrator 编排,区分 succeeded/failed 终态
- application_location_semantics:抽取差旅出发/到达城市、判断具体地址/业务动作等位置语义,供申请单校验复用
- router 注册两个 job 端点,新增对应 job/语义单元测试
2026-06-24 10:42:05 +08:00

292 lines
11 KiB
Python

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from datetime import UTC, datetime
from threading import Lock
from typing import Any, Callable
from uuid import uuid4

from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker

from app.api.deps import CurrentUserContext
from app.schemas.linked_reimbursement_draft_job import (
    LinkedReimbursementDraftJobCreate,
    LinkedReimbursementDraftJobRead,
)
from app.schemas.ontology import OntologyParseResult, OntologyPermission
from app.schemas.orchestrator import OrchestratorRequest
from app.models.financial_record import ExpenseClaim
from app.services.expense_claims import ExpenseClaimService
from app.services.orchestrator import OrchestratorService
# Job statuses after which a job is final and will not be run again.
TERMINAL_STATUSES = set(("succeeded", "failed"))
@dataclass(slots=True)
class LinkedReimbursementDraftJobState:
job_id: str
owner_username: str
owner_name: str
message: str
context_json: dict[str, Any]
conversation_id: str = ""
status: str = "queued"
status_message: str = "已创建报销草稿生成任务,等待后台处理。"
error: str = ""
run_id: str = ""
draft_payload: dict[str, Any] | None = None
created_at: datetime = field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = field(default_factory=lambda: datetime.now(UTC))
def to_read(self) -> LinkedReimbursementDraftJobRead:
return LinkedReimbursementDraftJobRead(
job_id=self.job_id,
status=self.status,
message=self.status_message,
error=self.error,
run_id=self.run_id,
conversation_id=self.conversation_id,
draft_payload=self.draft_payload,
created_at=self.created_at,
updated_at=self.updated_at,
)
# Process-local job registry keyed by job id. It lives only in this module's
# memory, so jobs are not persisted and are not shared across processes.
# NOTE(review): nothing in this module evicts finished jobs; confirm whether
# cleanup happens elsewhere, otherwise the registry grows without bound.
_jobs: dict[str, LinkedReimbursementDraftJobState] = dict()
# Lock guarding access to ``_jobs``.
_jobs_lock = Lock()
def clear_linked_reimbursement_draft_jobs_for_tests() -> None:
    """Remove every registered job (test helper); holds the job lock while clearing."""
    _jobs_lock.acquire()
    try:
        _jobs.clear()
    finally:
        _jobs_lock.release()
def create_linked_reimbursement_draft_job(
    payload: LinkedReimbursementDraftJobCreate,
    current_user: CurrentUserContext,
) -> LinkedReimbursementDraftJobRead:
    """Register a new ``queued`` draft job owned by ``current_user``.

    The payload's ``context_json`` is copied. When ``entry_source`` is missing
    or falsy it defaults to ``"workbench-ai"``; when ``session_type`` is missing
    or falsy it defaults to ``"expense"``. Owner and input strings are stripped
    of surrounding whitespace. The job is only stored here and is not executed.

    Returns:
        The read view of the newly created job.
    """

    def _clean(value: Any) -> str:
        return str(value or "").strip()

    job_context = dict(payload.context_json or {})
    for key, fallback in (("entry_source", "workbench-ai"), ("session_type", "expense")):
        job_context[key] = job_context.get(key) or fallback

    new_state = LinkedReimbursementDraftJobState(
        job_id="linked-reimbursement-draft-" + str(uuid4()),
        owner_username=_clean(current_user.username),
        owner_name=_clean(current_user.name),
        message=_clean(payload.message),
        context_json=job_context,
        conversation_id=_clean(payload.conversation_id),
    )
    with _jobs_lock:
        _jobs[new_state.job_id] = new_state
    return new_state.to_read()
def get_linked_reimbursement_draft_job(
    job_id: str,
    current_user: CurrentUserContext,
) -> LinkedReimbursementDraftJobRead | None:
    """Return a job's read view, or ``None`` if it is unknown or not visible to ``current_user``."""
    state = _get_authorized_state(job_id, current_user)
    if state is None:
        return None
    return state.to_read()
def run_linked_reimbursement_draft_job(
    job_id: str,
    current_user: CurrentUserContext,
    session_factory: sessionmaker[Session] | Callable[[], Session],
) -> None:
    """Execute a draft job and record its outcome.

    The job is silently skipped when it is unknown, not visible to
    ``current_user``, already ``running``, or already in a terminal status.
    Otherwise it is claimed (moved to ``running``) atomically under the job
    lock, so concurrent invocations for the same job cannot both generate a
    draft. The draft is then produced in one of two ways:

    * the direct ``ExpenseClaimService`` save path, when the review context allows it; or
    * ``OrchestratorService``.

    On success the job becomes ``succeeded``, with ``run_id`` and
    ``draft_payload`` filled in. Any exception is logged with its traceback
    and the job becomes ``failed``. The exception text, or a generic fallback,
    is stored as both ``status_message`` and ``error``. Nothing is raised to
    the caller.

    Args:
        job_id: Identifier returned by ``create_linked_reimbursement_draft_job``.
        current_user: User the job runs for. Used for authorization and as the
            downstream ``user_id``.
        session_factory: Callable returning a new SQLAlchemy ``Session``. The
            session is used as a context manager for the duration of the run.
    """
    state = _get_authorized_state(job_id, current_user)
    if state is None or not _claim_job_for_run(state.job_id):
        return
    try:
        with session_factory() as db:
            if _can_use_direct_save_path(db, state.context_json):
                run_id, result, draft_payload = _run_direct_save_path(
                    db=db,
                    state=state,
                    current_user=current_user,
                )
            else:
                run_id, result, draft_payload = _run_orchestrator_path(
                    db=db,
                    state=state,
                    current_user=current_user,
                )
            if draft_payload is None:
                raise ValueError("报销草稿生成完成,但未返回草稿信息,请刷新单据列表后核对。")
            _update_job(
                state.job_id,
                status="succeeded",
                status_message=str(result.get("message") or "报销草稿已生成。").strip(),
                run_id=run_id,
                draft_payload=draft_payload,
                error="",
            )
    except Exception as exc:
        # Background-task boundary: record the failure on the job instead of
        # propagating, but keep the traceback in the logs for diagnosis.
        logging.getLogger(__name__).exception(
            "Linked reimbursement draft job %s failed", state.job_id
        )
        message = str(exc).strip() or "报销草稿生成失败,请稍后重试。"
        _update_job(
            state.job_id,
            status="failed",
            status_message=message,
            error=message,
        )


def _claim_job_for_run(job_id: str) -> bool:
    """Atomically mark a job ``running``; return False if it is missing, running, or finished."""
    with _jobs_lock:
        state = _jobs.get(str(job_id or "").strip())
        if state is None or state.status == "running" or state.status in TERMINAL_STATUSES:
            return False
        state.status = "running"
        state.status_message = "正在后台生成报销草稿..."
        state.updated_at = datetime.now(UTC)
        return True


def _run_orchestrator_path(
    *,
    db: Session,
    state: LinkedReimbursementDraftJobState,
    current_user: CurrentUserContext,
) -> tuple[str, dict[str, Any], dict[str, Any] | None]:
    """Generate the draft via ``OrchestratorService``; raise ValueError unless it reports success.

    Returns ``(run_id, result, draft_payload)``. ``draft_payload`` is ``None``
    when the result carries no dict under ``"draft_payload"``.
    """
    # NOTE(review): state.conversation_id is not forwarded (conversation_id=None);
    # confirm this is intentional.
    response = OrchestratorService(db).run(
        OrchestratorRequest(
            source="user_message",
            user_id=_resolve_user_id(current_user),
            conversation_id=None,
            message=state.message,
            context_json=dict(state.context_json),
        )
    )
    result = response.result if isinstance(response.result, dict) else {}
    if response.status != "succeeded":
        raise ValueError(str(result.get("message") or "报销草稿生成失败,请稍后重试。").strip())
    candidate = result.get("draft_payload")
    draft_payload = candidate if isinstance(candidate, dict) else None
    return response.run_id, result, draft_payload
def _get_authorized_state(
    job_id: str,
    current_user: CurrentUserContext,
) -> LinkedReimbursementDraftJobState | None:
    """Return the job state if ``current_user`` may access it, else ``None``.

    Admins can access every job. For everyone else, ownership is decided by
    username whenever both the job owner and the current user have one.
    The display name is only a fallback when a username is missing on
    either side. Display names are not unique, and matching on them alone
    would let a different user with the same name read and run someone
    else's job.
    """
    normalized_job_id = str(job_id or "").strip()
    with _jobs_lock:
        state = _jobs.get(normalized_job_id)
    if state is None:
        return None
    if current_user.is_admin:
        return state
    username = str(current_user.username or "").strip()
    if username and state.owner_username:
        # Both sides have a username: it is authoritative.
        return state if username == state.owner_username else None
    name = str(current_user.name or "").strip()
    if name and name == state.owner_name:
        return state
    return None
def _update_job(job_id: str, **updates: Any) -> None:
with _jobs_lock:
state = _jobs.get(str(job_id or "").strip())
if state is None:
return
for key, value in updates.items():
if hasattr(state, key):
setattr(state, key, value)
state.updated_at = datetime.now(UTC)
def _resolve_user_id(current_user: CurrentUserContext) -> str:
    """Return the identifier passed downstream as ``user_id``.

    The first non-blank value among ``username`` and ``name`` is used, after
    stripping whitespace. If neither is set, ``"anonymous"`` is returned. A
    whitespace-only username falls through to ``name`` instead of collapsing
    to ``"anonymous"``.
    """
    for candidate in (current_user.username, current_user.name):
        cleaned = str(candidate or "").strip()
        if cleaned:
            return cleaned
    return "anonymous"
def _can_use_direct_save_path(db: Session, context_json: dict[str, Any]) -> bool:
    """Decide whether the job can bypass the orchestrator and save the draft directly.

    All of the following must hold:

    * ``review_action`` is ``"save_draft"``;
    * ``review_form_values`` is a dict with a non-blank ``application_claim_no``;
    * either ``application_claim_id`` is non-blank, or the claim number
      resolves to an expense-application claim in the database.
    """
    if str((context_json or {}).get("review_action") or "").strip() != "save_draft":
        return False
    form_values = context_json.get("review_form_values")
    if not isinstance(form_values, dict):
        return False
    claim_no = str(form_values.get("application_claim_no") or "").strip()
    if not claim_no:
        return False
    claim_id = str(form_values.get("application_claim_id") or "").strip()
    # An explicit id is trusted as-is; only a bare claim number needs a lookup.
    return bool(claim_id) or _find_application_claim_by_no(db, claim_no) is not None
def _find_application_claim_by_no(db: Session, claim_no: str) -> ExpenseClaim | None:
    """Look up an expense-application claim by claim number.

    Blank numbers return ``None`` without querying. Only the first row
    matching ``claim_no`` is fetched. It is returned only if
    ``ExpenseClaimService._is_expense_application_claim`` accepts it,
    otherwise ``None``.
    """
    wanted = str(claim_no or "").strip()
    if not wanted:
        return None
    statement = select(ExpenseClaim).where(ExpenseClaim.claim_no == wanted).limit(1)
    candidate = db.scalar(statement)
    is_application = candidate is not None and ExpenseClaimService._is_expense_application_claim(candidate)
    return candidate if is_application else None
def _build_direct_context_json(db: Session, context_json: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of ``context_json`` prepared for the direct save path.

    If the review form has an ``application_claim_no`` but no
    ``application_claim_id``, the claim is looked up by number. When it is
    found, its id is written into both ``review_form_values`` and
    ``expense_scene_selection``. The selection's ``application_claim_no`` is
    also filled in when it is absent. ``review_form_values`` is always present
    in the result. ``expense_scene_selection`` is only set when it is non-empty.
    The input mappings are not mutated.
    """
    enriched = dict(context_json or {})
    form_values = dict(enriched.get("review_form_values") or {})
    selection = dict(enriched.get("expense_scene_selection") or {})
    known_id = str(form_values.get("application_claim_id") or "").strip()
    known_no = str(form_values.get("application_claim_no") or "").strip()

    matched = _find_application_claim_by_no(db, known_no) if (not known_id and known_no) else None
    if matched is not None:
        form_values["application_claim_id"] = matched.id
        selection["application_claim_id"] = matched.id
        selection["application_claim_no"] = str(
            selection.get("application_claim_no") or matched.claim_no or known_no
        ).strip()

    enriched["review_form_values"] = form_values
    if selection:
        enriched["expense_scene_selection"] = selection
    return enriched
def _run_direct_save_path(
    *,
    db: Session,
    state: LinkedReimbursementDraftJobState,
    current_user: CurrentUserContext,
) -> tuple[str, dict[str, Any], dict[str, Any]]:
    """Save the reimbursement draft directly through ``ExpenseClaimService``.

    This bypasses the orchestrator. A fixed ``expense``/``draft`` ontology with
    ``draft_write`` permission is built, and the job id is reused as the run id.

    Returns:
        ``(run_id, service_result, draft_payload)``.

    Raises:
        ValueError: if the service result lacks a claim id or claim number, or
            its status is not ``"draft"``. The message comes from the result
            when available.
    """
    job_run_id = state.job_id
    draft_permission = OntologyPermission(
        level="draft_write",
        allowed=True,
        reason="关联申请单生成报销草稿快路径。",
    )
    draft_ontology = OntologyParseResult(
        scenario="expense",
        intent="draft",
        permission=draft_permission,
        confidence=1.0,
        run_id=job_run_id,
    )
    claim_service = ExpenseClaimService(db)
    outcome = claim_service.save_or_submit_from_ontology(
        run_id=job_run_id,
        user_id=_resolve_user_id(current_user),
        message=state.message,
        ontology=draft_ontology,
        context_json=_build_direct_context_json(db, state.context_json),
    )
    saved_claim_id = str(outcome.get("claim_id") or "").strip()
    saved_claim_no = str(outcome.get("claim_no") or "").strip()
    saved_status = str(outcome.get("status") or "").strip()
    if not (saved_claim_id and saved_claim_no and saved_status == "draft"):
        raise ValueError(str(outcome.get("message") or "报销草稿生成失败,请稍后重试。").strip())
    saved_claim = db.get(ExpenseClaim, saved_claim_id)
    return job_run_id, outcome, _build_direct_draft_payload(outcome, saved_claim)
def _build_direct_draft_payload(
    result: dict[str, Any],
    claim: ExpenseClaim | None,
) -> dict[str, Any]:
    """Build the expense draft payload returned to the client for the direct save path.

    ``claim_id``, ``claim_no`` and ``status`` prefer the service ``result`` and
    fall back to attributes of ``claim``, if one is given. ``approval_stage``
    and ``expense_type`` come only from ``claim``. Every value is stripped, and
    blank values fall back to the defaults shown below.
    """

    def _text(value: Any, default: str = "") -> str:
        return str(value or default).strip()

    claim_no = _text(result.get("claim_no") or getattr(claim, "claim_no", ""))
    title = f"费用草稿 {claim_no}" if claim_no else "费用草稿"
    return {
        "draft_type": "expense",
        "title": title,
        "body": _text(result.get("message"), "报销草稿已生成。"),
        "confirmation_required": True,
        "claim_id": _text(result.get("claim_id") or getattr(claim, "id", "")),
        "claim_no": claim_no,
        "status": _text(result.get("status") or getattr(claim, "status", ""), "draft"),
        "approval_stage": _text(getattr(claim, "approval_stage", ""), "待提交"),
        "expense_type": _text(getattr(claim, "expense_type", "")),
    }