Files
X-Financial/server/src/app/services/attachment_association_jobs.py
caoxiaozhu 332f77389d feat(server): 新增附件关联/关联报销草稿后台任务与申请位置语义
- attachment_association_jobs:从票据夹批量关联附件到报销单,识别城市/日期并创建明细项,内存态 job 跟踪
- linked_reimbursement_draft_jobs:基于申请单异步生成关联报销草稿,调用 Orchestrator 编排,区分 succeeded/failed 终态
- application_location_semantics:抽取差旅出发/到达城市、判断具体地址/业务动作等位置语义,供申请单校验复用
- router 注册两个 job 端点,新增对应 job/语义单元测试
2026-06-24 10:42:05 +08:00

550 lines
18 KiB
Python

from __future__ import annotations
import re
from dataclasses import dataclass, field
from datetime import UTC, date, datetime
from decimal import Decimal
from threading import Lock
from typing import Any, Callable
from uuid import uuid4
from sqlalchemy.orm import Session, sessionmaker
from app.api.deps import CurrentUserContext
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
from app.schemas.attachment_association_job import (
AttachmentAssociationJobCreate,
AttachmentAssociationJobRead,
)
from app.schemas.receipt_folder import ReceiptFolderDetailRead
from app.schemas.reimbursement import ExpenseClaimItemCreate
from app.services.expense_claim_constants import (
DOCUMENT_TYPE_ITEM_TYPE_MAP,
EDITABLE_CLAIM_STATUSES,
)
from app.services.expense_claims import ExpenseClaimService
from app.services.receipt_folder import ReceiptFolderService
# City names recognised when scanning receipt and claim text for locations.
# Matching elsewhere in this module is plain substring containment, and the
# tuple order determines the order in which matched cities are reported.
CITY_NAMES = (
    "北京", "上海", "广州", "深圳", "武汉", "南京", "杭州", "成都", "重庆", "西安",
    "天津", "苏州", "长沙", "郑州", "青岛", "厦门", "宁波", "无锡", "合肥", "福州",
    "昆明", "大连", "沈阳", "济南", "哈尔滨", "长春", "南昌", "太原", "贵阳", "南宁",
    "石家庄", "兰州", "银川", "西宁", "海口", "拉萨",
)

# Job statuses after which a job is not executed again.
TERMINAL_STATUSES = {"succeeded", "failed"}
@dataclass(slots=True)
class AttachmentAssociationJobState:
job_id: str
owner_username: str
owner_name: str
receipt_ids: list[str]
prompt: str = ""
conversation_id: str = ""
status: str = "queued"
message: str = "已创建附件关联任务,等待后台处理。"
claim_id: str = ""
claim_no: str = ""
uploaded_count: int = 0
skipped_count: int = 0
error: str = ""
created_at: datetime = field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = field(default_factory=lambda: datetime.now(UTC))
def to_read(self) -> AttachmentAssociationJobRead:
return AttachmentAssociationJobRead(
job_id=self.job_id,
status=self.status,
message=self.message,
receipt_ids=list(self.receipt_ids),
claim_id=self.claim_id,
claim_no=self.claim_no,
uploaded_count=self.uploaded_count,
skipped_count=self.skipped_count,
error=self.error,
prompt=self.prompt,
conversation_id=self.conversation_id,
created_at=self.created_at,
updated_at=self.updated_at,
)
@dataclass(slots=True)
class AttachmentAssociationCandidate:
claim: ExpenseClaim
score: int
reasons: list[str]
# Process-local registry of job states keyed by job id. Within this module
# entries are only removed by clear_attachment_association_jobs_for_tests().
_jobs: dict[str, AttachmentAssociationJobState] = {}
# Guards access to _jobs.
_jobs_lock = Lock()
def clear_attachment_association_jobs_for_tests() -> None:
    """Remove every job from the in-memory registry (test helper)."""
    with _jobs_lock:
        _jobs.clear()
def create_attachment_association_job(
    payload: AttachmentAssociationJobCreate,
    current_user: CurrentUserContext,
) -> AttachmentAssociationJobRead:
    """Register a new job owned by ``current_user`` and return its snapshot.

    The job is only stored here; nothing in this function executes it.
    """

    def _clean(value: Any) -> str:
        # Treat None/empty as "" and trim surrounding whitespace.
        return str(value or "").strip()

    new_job_id = f"attachment-association-{uuid4()}"
    new_state = AttachmentAssociationJobState(
        job_id=new_job_id,
        owner_username=_clean(current_user.username),
        owner_name=_clean(current_user.name),
        receipt_ids=list(payload.receipt_ids),
        prompt=_clean(payload.prompt),
        conversation_id=_clean(payload.conversation_id),
    )
    with _jobs_lock:
        _jobs[new_job_id] = new_state
    return new_state.to_read()
def get_attachment_association_job(
    job_id: str,
    current_user: CurrentUserContext,
) -> AttachmentAssociationJobRead | None:
    """Return a snapshot of the job, or ``None`` if it is missing or not visible to the user."""
    state = _get_authorized_state(job_id, current_user)
    if state is None:
        return None
    return state.to_read()
def run_attachment_association_job(
    job_id: str,
    current_user: CurrentUserContext,
    session_factory: sessionmaker[Session] | Callable[[], Session],
) -> None:
    """Execute a registered job and record the outcome in the job registry.

    Does nothing when the job is unknown, not visible to ``current_user``, or
    already in a terminal status. Otherwise the job is marked ``running``, the
    runner is executed inside a session obtained from ``session_factory``, and
    the job ends as ``succeeded`` or ``failed``.

    Args:
        job_id: Identifier returned when the job was created.
        current_user: User on whose behalf the association is performed.
        session_factory: Callable returning a session usable as a context
            manager.

    This function never raises for runner failures: any exception is logged
    with its traceback and converted into a ``failed`` job whose message is the
    exception text (or a generic fallback when that text is empty).
    """
    # Local import so this function does not depend on a module-level logger.
    import logging

    state = _get_authorized_state(job_id, current_user)
    if state is None or state.status in TERMINAL_STATUSES:
        return
    _update_job(job_id, status="running", message="正在匹配可关联的报销草稿...")
    try:
        with session_factory() as db:
            result = AttachmentAssociationJobRunner(db).run(
                receipt_ids=state.receipt_ids,
                current_user=current_user,
            )
        _update_job(
            job_id,
            status="succeeded",
            message=f"已自动关联到 {result['claim_no']},成功归集 {result['uploaded_count']} 份附件。",
            claim_id=str(result["claim_id"]),
            claim_no=str(result["claim_no"]),
            uploaded_count=int(result["uploaded_count"]),
            skipped_count=int(result["skipped_count"]),
            error="",
        )
    except Exception as exc:
        # Job boundary: the failure is reported through the job state rather
        # than propagated, so keep the traceback in the log for diagnosis.
        logging.getLogger(__name__).exception(
            "Attachment association job %s failed", job_id
        )
        message = str(exc).strip() or "自动关联任务执行失败,请稍后重试。"
        _update_job(
            job_id,
            status="failed",
            message=message,
            error=message,
        )
def _get_authorized_state(
    job_id: str,
    current_user: CurrentUserContext,
) -> AttachmentAssociationJobState | None:
    """Look up a job and return it only if ``current_user`` may see it.

    Admins see every job. Other users see a job when their non-empty username
    equals the job's owner username, or their non-empty display name equals the
    job's owner name. Returns ``None`` for unknown jobs and for users who match
    neither.
    """
    lookup_key = str(job_id or "").strip()
    with _jobs_lock:
        state = _jobs.get(lookup_key)
        if state is None:
            return None
        if current_user.is_admin:
            return state
        # NOTE(review): the display-name fallback grants access on a name match
        # alone; confirm display names are unique enough for this to be safe.
        identity_pairs = (
            (current_user.username, state.owner_username),
            (current_user.name, state.owner_name),
        )
        for raw_identity, owner_identity in identity_pairs:
            identity = str(raw_identity or "").strip()
            if identity and identity == owner_identity:
                return state
        return None
def _update_job(job_id: str, **updates: Any) -> None:
    """Apply field updates to a registered job and refresh ``updated_at``.

    Unknown job ids are ignored, as are keyword names that are not attributes
    of the job state. ``updated_at`` is always set to the current UTC time
    after the updates are applied.
    """
    lookup_key = str(job_id or "").strip()
    with _jobs_lock:
        state = _jobs.get(lookup_key)
        if state is None:
            return
        for attr_name in updates:
            if hasattr(state, attr_name):
                setattr(state, attr_name, updates[attr_name])
        state.updated_at = datetime.now(UTC)
class AttachmentAssociationJobRunner:
    """Match receipts to an editable claim and upload them as item attachments."""

    # A recommended claim must reach this score to be used automatically.
    _MIN_RECOMMENDED_SCORE = 5
    # The recommended claim must beat the runner-up by at least this margin.
    _MIN_SCORE_LEAD = 2

    def __init__(self, db: Session) -> None:
        self.db = db
        self.claim_service = ExpenseClaimService(db)
        self.receipt_service = ReceiptFolderService()

    def run(
        self,
        *,
        receipt_ids: list[str],
        current_user: CurrentUserContext,
    ) -> dict[str, Any]:
        """Associate the given receipts with the best-matching claim.

        Args:
            receipt_ids: Receipt identifiers; blanks and duplicates are dropped.
            current_user: User whose receipts and claims are considered.

        Returns:
            A dict with ``claim_id``, ``claim_no``, ``uploaded_count`` and
            ``skipped_count``.

        Raises:
            ValueError: When no receipt can be loaded, no candidate claim
                scores above zero, the best candidate is not a confident
                enough match, or no attachment ends up uploaded.
        """
        receipts = self._load_receipts(receipt_ids, current_user)
        candidates = self._rank_claims(receipts, current_user)
        if not candidates:
            raise ValueError("没有找到可自动关联的报销草稿,请先新建草稿或补充说明。")

        recommended = candidates[0]
        runner_up = candidates[1] if len(candidates) > 1 else None
        too_weak = recommended.score < self._MIN_RECOMMENDED_SCORE
        too_close = (
            runner_up is not None
            and recommended.score - runner_up.score < self._MIN_SCORE_LEAD
        )
        if too_weak or too_close:
            raise ValueError("找到多个可能关联的报销草稿,请补充说明或手动选择后再归集。")

        uploaded_count = 0
        skipped_count = 0
        for receipt in receipts:
            # Receipts already linked to a different claim are left alone.
            if self._is_linked_to_other_claim(receipt, recommended.claim.id):
                skipped_count += 1
                continue
            target_item = self._resolve_target_item(
                claim_id=recommended.claim.id,
                receipt=receipt,
                current_user=current_user,
            )
            source_path, media_type, file_name = self.receipt_service.resolve_source(
                receipt.id, current_user
            )
            result = self.claim_service.upload_claim_item_attachment(
                claim_id=recommended.claim.id,
                item_id=target_item.id,
                filename=file_name,
                content=source_path.read_bytes(),
                media_type=media_type,
                current_user=current_user,
                source_receipt_id=receipt.id,
            )
            # A None result from the upload is counted as a skip.
            if result is None:
                skipped_count += 1
            else:
                uploaded_count += 1

        if uploaded_count <= 0:
            raise ValueError("未能归集任何附件,请进入报销单详情手动核对。")
        return {
            "claim_id": recommended.claim.id,
            "claim_no": recommended.claim.claim_no,
            "uploaded_count": uploaded_count,
            "skipped_count": skipped_count,
        }

    def _load_receipts(
        self,
        receipt_ids: list[str],
        current_user: CurrentUserContext,
    ) -> list[ReceiptFolderDetailRead]:
        """Load each distinct, non-blank receipt id in first-seen order.

        Raises:
            ValueError: When a receipt lookup raises ``FileNotFoundError`` or
                when no ids remain after normalisation.
        """
        # dict.fromkeys de-duplicates while preserving first-seen order.
        normalized_ids = dict.fromkeys(
            cleaned
            for cleaned in (str(item or "").strip() for item in receipt_ids)
            if cleaned
        )
        receipts = []
        for receipt_id in normalized_ids:
            try:
                receipts.append(self.receipt_service.get_receipt(receipt_id, current_user))
            except FileNotFoundError as exc:
                raise ValueError("当前附件没有持久化票据记录,请重新上传后再试。") from exc
        if not receipts:
            raise ValueError("当前附件没有持久化票据记录,请重新上传后再试。")
        return receipts

    def _rank_claims(
        self,
        receipts: list[ReceiptFolderDetailRead],
        current_user: CurrentUserContext,
    ) -> list[AttachmentAssociationCandidate]:
        """Score eligible claims against the receipts, best first.

        Candidates with a score of zero or less are dropped. The sort is
        stable, so equally scored claims keep the order of ``list_claims``.
        """
        signals = _collect_receipt_signals(receipts)
        claims = [
            claim
            for claim in self.claim_service.list_claims(current_user)
            if self._is_auto_association_candidate(claim)
        ]
        scored = [self._score_claim(claim, signals) for claim in claims]
        ranked = [candidate for candidate in scored if candidate.score > 0]
        return sorted(ranked, key=lambda item: item.score, reverse=True)

    def _is_auto_association_candidate(self, claim: ExpenseClaim) -> bool:
        """Return True for claims in an editable status that are not expense applications."""
        status = str(claim.status or "").strip().lower()
        if status not in EDITABLE_CLAIM_STATUSES:
            return False
        return not self.claim_service._is_expense_application_claim(claim)

    def _score_claim(
        self,
        claim: ExpenseClaim,
        signals: dict[str, Any],
    ) -> AttachmentAssociationCandidate:
        """Score one claim against receipt signals.

        Scoring: +4 when any receipt date overlaps a claim date; +2 per receipt
        city found in the claim text (capped at +4); +2 more when the claim
        mentions at least two cities and at least two receipt cities matched;
        +1 when the claim status is ``draft``.
        """
        claim_text = _build_claim_text(claim)
        compact_claim_text = _normalize_text(claim_text)
        claim_dates = _extract_date_tokens(claim_text)
        claim_cities = _unique(
            [*_extract_city_tokens(claim_text), *_extract_city_tokens(claim.location)]
        )
        reasons: list[str] = []
        score = 0

        if _dates_overlap(signals["dates"], claim_dates):
            score += 4
            reasons.append("票据日期与报销单日期一致")

        matched_cities = [city for city in signals["cities"] if city in compact_claim_text]
        if matched_cities:
            score += min(4, len(matched_cities) * 2)
            reasons.append(f"地点或行程包含 {''.join(matched_cities)}")
            if len(claim_cities) >= 2 and len(matched_cities) >= 2:
                score += 2
                reasons.append("票据往返城市与报销事由吻合")

        if str(claim.status or "").strip().lower() == "draft":
            score += 1
            reasons.append("当前单据仍是可归集草稿")

        return AttachmentAssociationCandidate(claim=claim, score=score, reasons=reasons)

    @staticmethod
    def _is_linked_to_other_claim(receipt: ReceiptFolderDetailRead, claim_id: str) -> bool:
        """Return True when the receipt is linked to a claim other than ``claim_id``.

        Both ids are compared in stripped string form. Other code in this module
        stringifies ids before use, so a non-str ``claim_id`` must not make a
        receipt already linked to the target claim look linked elsewhere.
        """
        if str(receipt.status or "").strip() != "linked":
            return False
        linked_claim_id = str(receipt.linked_claim_id or "").strip()
        target_claim_id = str(claim_id or "").strip()
        return bool(linked_claim_id) and linked_claim_id != target_claim_id

    def _resolve_target_item(
        self,
        *,
        claim_id: str,
        receipt: ReceiptFolderDetailRead,
        current_user: CurrentUserContext,
    ) -> ExpenseClaimItem:
        """Pick, or create, the claim item that will receive the receipt.

        Preference order: an existing item with no invoice id, not system
        generated, whose type matches the receipt's item type; then any such
        empty item; otherwise a newly created item built from the receipt.

        Raises:
            ValueError: When the claim cannot be loaded, the item cannot be
                created, or the created item cannot be found afterwards.
        """
        claim = self.claim_service.get_claim(claim_id, current_user)
        if claim is None:
            raise ValueError("匹配到的报销草稿不存在,请刷新后再试。")

        preferred_type = _resolve_receipt_item_type(receipt)
        empty_items = [
            item
            for item in list(claim.items or [])
            if not str(item.invoice_id or "").strip() and not item.is_system_generated
        ]
        for item in empty_items:
            if preferred_type and str(item.item_type or "").strip() == preferred_type:
                return item
        if empty_items:
            return empty_items[0]

        # Remember existing ids so the newly created item can be told apart.
        before_ids = {str(item.id) for item in list(claim.items or [])}
        created_claim = self.claim_service.create_claim_item(
            claim_id=claim.id,
            payload=_build_item_payload_from_receipt(claim, receipt, preferred_type),
            current_user=current_user,
        )
        if created_claim is None:
            raise ValueError("无法创建票据归集明细,请进入详情页手动处理。")
        for item in list(created_claim.items or []):
            if str(item.id) not in before_ids and not str(item.invoice_id or "").strip():
                return item
        raise ValueError("无法找到可归集的费用明细,请进入详情页手动处理。")
def _normalize_text(value: Any) -> str:
    """Return ``value`` as a string with all whitespace removed.

    Falsy values (``None``, ``""``, ``0``) yield an empty string.
    """
    raw = str(value or "")
    trimmed = raw.strip()
    return re.sub(r"\s+", "", trimmed)
def _unique(values: list[str] | tuple[str, ...]) -> list[str]:
return list(dict.fromkeys(str(item or "").strip() for item in values if str(item or "").strip()))
def _extract_date_tokens(text: Any) -> list[str]:
    """Find date-like fragments in ``text`` and return them normalised.

    Full dates (year, month, day) are collected first, then month/day
    fragments written with ``月``. Duplicates and empty results are removed.
    """
    source = str(text or "")
    patterns = (
        r"20\d{2}[-/.年]\d{1,2}[-/.月]\d{1,2}",
        r"\d{1,2}月\d{1,2}",
    )
    raw_tokens = [
        found.group(0)
        for pattern in patterns
        for found in re.finditer(pattern, source)
    ]
    return _unique([_normalize_date_token(token) for token in raw_tokens])
def _normalize_date_token(value: Any) -> str:
    """Normalise a date-like value to ``YYYY-MM-DD`` or ``MM-DD``.

    ``date``/``datetime`` instances yield the first ten characters of their
    ISO form. Strings are searched for a full date first, then for a
    month/day fragment written with ``月``. Returns ``""`` when nothing matches.
    Month and day values are zero-padded but not range-checked.
    """
    if isinstance(value, (date, datetime)):
        return value.isoformat()[:10]

    text = str(value or "").strip()

    if (full := re.search(r"(20\d{2})[-/.年](\d{1,2})[-/.月](\d{1,2})", text)) is not None:
        return f"{full.group(1)}-{full.group(2).zfill(2)}-{full.group(3).zfill(2)}"

    if (short := re.search(r"(\d{1,2})月(\d{1,2})", text)) is not None:
        return f"{short.group(1).zfill(2)}-{short.group(2).zfill(2)}"

    return ""
def _extract_city_tokens(text: Any) -> list[str]:
    """Return the known city names contained in ``text``.

    Whitespace is removed before matching, and results follow the order of
    ``CITY_NAMES`` rather than the order of appearance in the text.
    """
    compact = _normalize_text(text)
    return [city for city in CITY_NAMES if city in compact] if compact else []
def _dates_overlap(left: list[str], right: list[str]) -> bool:
    """Return True when any non-empty token of ``left`` matches one of ``right``.

    Two tokens match when they are equal or one is a suffix of the other, so a
    month/day token such as ``06-24`` matches ``2026-06-24``.
    """
    left_tokens = [token for token in left if token]
    right_tokens = [token for token in right if token]
    return any(
        a == b or a.endswith(b) or b.endswith(a)
        for a in left_tokens
        for b in right_tokens
    )
def _collect_receipt_signals(receipts: list[ReceiptFolderDetailRead]) -> dict[str, Any]:
    """Gather the text, date tokens and city names found across ``receipts``.

    Returns a dict with keys ``text`` (all receipt texts joined by newlines),
    ``dates`` and ``cities`` (both de-duplicated lists).
    """
    combined_text = "\n".join(map(_build_receipt_text, receipts))

    date_tokens = list(_extract_date_tokens(combined_text))
    # Each receipt's document date is appended in its plain string form, in
    # addition to whatever was extracted from the combined text.
    date_tokens.extend(str(receipt.document_date or "").strip() for receipt in receipts)

    city_tokens = _extract_city_tokens(combined_text)
    return {
        "text": combined_text,
        "dates": _unique(date_tokens),
        "cities": _unique(city_tokens),
    }
def _build_receipt_text(receipt: ReceiptFolderDetailRead) -> str:
    """Join a receipt's descriptive attributes into one newline-separated text.

    The text is built from, in order: file name, summary, OCR text, document
    date, merchant name and one ``"<label> <value>"`` line per extracted field.
    Attributes that are falsy or blank once stringified are omitted.

    Values are converted with ``str`` before joining so that a non-string
    attribute (for example a ``date`` document date) does not make
    ``str.join`` raise ``TypeError``. A ``None`` field label or value is
    rendered as an empty string rather than the literal ``"None"``.
    """

    def _as_text(value: Any) -> str:
        return "" if value is None else str(value)

    fields_text = "\n".join(
        f"{_as_text(entry.label)} {_as_text(entry.value)}"
        for entry in (receipt.fields or [])
        if str(entry.label or entry.value or "").strip()
    )
    parts = (
        receipt.file_name,
        receipt.summary,
        receipt.ocr_text,
        receipt.document_date,
        receipt.merchant_name,
        fields_text,
    )
    return "\n".join(str(value) for value in parts if str(value or "").strip())
def _build_claim_text(claim: ExpenseClaim) -> str:
    """Join a claim's descriptive attributes and its items into one text.

    Each item contributes one line made of its date, type, reason, location
    and note (blank parts omitted, separated by spaces). The claim-level
    attributes come first, each on its own line, with blank ones omitted.
    """

    def _item_line(item: Any) -> str:
        item_day = item.item_date.isoformat() if item.item_date else ""
        raw_parts = (
            item_day,
            item.item_type,
            item.item_reason,
            item.item_location,
            item.item_note,
        )
        cleaned = (str(part or "").strip() for part in raw_parts)
        return " ".join(part for part in cleaned if part)

    # Every item yields a line, even when that line is empty.
    item_text = "\n".join(_item_line(item) for item in list(claim.items or []))

    occurred_at = ""
    if claim.occurred_at:
        occurred_at = claim.occurred_at.isoformat()[:10]

    sections = (
        claim.claim_no,
        claim.expense_type,
        claim.status,
        claim.reason,
        claim.location,
        occurred_at,
        item_text,
    )
    return "\n".join(section for section in sections if str(section or "").strip())
def _resolve_receipt_item_type(receipt: ReceiptFolderDetailRead) -> str:
    """Choose the claim item type for a receipt.

    A document type present in ``DOCUMENT_TYPE_ITEM_TYPE_MAP`` wins; otherwise
    the receipt's scene code is used, falling back to ``"other"`` when blank.
    """
    doc_type = str(receipt.document_type or "").strip()
    if doc_type in DOCUMENT_TYPE_ITEM_TYPE_MAP:
        return DOCUMENT_TYPE_ITEM_TYPE_MAP[doc_type]

    scene = str(receipt.scene_code or "").strip()
    return "travel" if scene == "travel" else (scene or "other")
def _build_item_payload_from_receipt(
    claim: ExpenseClaim,
    receipt: ReceiptFolderDetailRead,
    preferred_type: str,
) -> ExpenseClaimItemCreate:
    """Build the creation payload for a claim item that will hold ``receipt``.

    Date and location come from the receipt when it provides them and from the
    claim otherwise. The amount is always created as ``0.00``.
    """
    item_date = _resolve_receipt_item_date(receipt)
    if item_date is None and claim.occurred_at:
        # Only fall back to the claim's date when the receipt has none.
        item_date = claim.occurred_at.date()

    item_type = preferred_type or str(claim.expense_type or "").strip() or "other"
    item_reason = str(receipt.summary or receipt.file_name or "").strip()
    item_location = _resolve_receipt_item_location(receipt) or str(claim.location or "").strip()

    return ExpenseClaimItemCreate(
        item_date=item_date,
        item_type=item_type,
        item_reason=item_reason,
        item_location=item_location,
        item_amount=Decimal("0.00"),
    )
def _resolve_receipt_item_date(receipt: ReceiptFolderDetailRead) -> date | None:
    """Return the first valid full date found on the receipt, or ``None``.

    Fields whose label contains ``日期`` or ``时间`` are tried first, in order,
    followed by the receipt's document date. Only values that normalise to a
    ten-character token and parse as an ISO date are accepted.
    """
    labelled_values = [
        entry.value
        for entry in (receipt.fields or [])
        if "日期" in str(entry.label or "") or "时间" in str(entry.label or "")
    ]
    for raw in (*labelled_values, receipt.document_date):
        token = _normalize_date_token(raw)
        if len(token) != 10:
            continue
        try:
            return date.fromisoformat(token)
        except ValueError:
            # Well-formed but impossible date (e.g. month 13): try the next one.
            continue
    return None
def _resolve_receipt_item_location(receipt: ReceiptFolderDetailRead) -> str:
    """Derive a location string for the receipt.

    The first non-blank field whose label mentions an itinerary, arrival,
    place or city decides the result: the last known city in its value, or the
    first 40 characters of the value when no known city is present. Without
    such a field, the last known city in the receipt's full text is used, or
    ``""`` when there is none.
    """
    label_keywords = ("行程", "到达", "地点", "城市")
    for entry in receipt.fields or []:
        label = str(entry.label or "")
        value = str(entry.value or "").strip()
        if not value or not any(keyword in label for keyword in label_keywords):
            continue
        field_cities = _extract_city_tokens(value)
        return field_cities[-1] if field_cities else value[:40]

    text_cities = _extract_city_tokens(_build_receipt_text(receipt))
    if text_cities:
        return text_cities[-1]
    return ""