- attachment_association_jobs:从票据夹批量关联附件到报销单,识别城市/日期并创建明细项,内存态 job 跟踪 - linked_reimbursement_draft_jobs:基于申请单异步生成关联报销草稿,调用 Orchestrator 编排,区分 succeeded/failed 终态 - application_location_semantics:抽取差旅出发/到达城市、判断具体地址/业务动作等位置语义,供申请单校验复用 - router 注册两个 job 端点,新增对应 job/语义单元测试
550 lines
18 KiB
Python
550 lines
18 KiB
Python
from __future__ import annotations
|
|
|
|
import re
|
|
from dataclasses import dataclass, field
|
|
from datetime import UTC, date, datetime
|
|
from decimal import Decimal
|
|
from threading import Lock
|
|
from typing import Any, Callable
|
|
from uuid import uuid4
|
|
|
|
from sqlalchemy.orm import Session, sessionmaker
|
|
|
|
from app.api.deps import CurrentUserContext
|
|
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
|
|
from app.schemas.attachment_association_job import (
|
|
AttachmentAssociationJobCreate,
|
|
AttachmentAssociationJobRead,
|
|
)
|
|
from app.schemas.receipt_folder import ReceiptFolderDetailRead
|
|
from app.schemas.reimbursement import ExpenseClaimItemCreate
|
|
from app.services.expense_claim_constants import (
|
|
DOCUMENT_TYPE_ITEM_TYPE_MAP,
|
|
EDITABLE_CLAIM_STATUSES,
|
|
)
|
|
from app.services.expense_claims import ExpenseClaimService
|
|
from app.services.receipt_folder import ReceiptFolderService
|
|
|
|
|
|
# City names recognised when scanning receipt and claim text.
# `_extract_city_tokens` reports matches in this tuple's order, so keep it stable.
CITY_NAMES = (
    "北京", "上海", "广州", "深圳", "武汉", "南京",
    "杭州", "成都", "重庆", "西安", "天津", "苏州",
    "长沙", "郑州", "青岛", "厦门", "宁波", "无锡",
    "合肥", "福州", "昆明", "大连", "沈阳", "济南",
    "哈尔滨", "长春", "南昌", "太原", "贵阳", "南宁",
    "石家庄", "兰州", "银川", "西宁", "海口", "拉萨",
)

# Statuses in which a job is finished; such jobs are not run again.
TERMINAL_STATUSES = {"succeeded", "failed"}
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class AttachmentAssociationJobState:
|
|
job_id: str
|
|
owner_username: str
|
|
owner_name: str
|
|
receipt_ids: list[str]
|
|
prompt: str = ""
|
|
conversation_id: str = ""
|
|
status: str = "queued"
|
|
message: str = "已创建附件关联任务,等待后台处理。"
|
|
claim_id: str = ""
|
|
claim_no: str = ""
|
|
uploaded_count: int = 0
|
|
skipped_count: int = 0
|
|
error: str = ""
|
|
created_at: datetime = field(default_factory=lambda: datetime.now(UTC))
|
|
updated_at: datetime = field(default_factory=lambda: datetime.now(UTC))
|
|
|
|
def to_read(self) -> AttachmentAssociationJobRead:
|
|
return AttachmentAssociationJobRead(
|
|
job_id=self.job_id,
|
|
status=self.status,
|
|
message=self.message,
|
|
receipt_ids=list(self.receipt_ids),
|
|
claim_id=self.claim_id,
|
|
claim_no=self.claim_no,
|
|
uploaded_count=self.uploaded_count,
|
|
skipped_count=self.skipped_count,
|
|
error=self.error,
|
|
prompt=self.prompt,
|
|
conversation_id=self.conversation_id,
|
|
created_at=self.created_at,
|
|
updated_at=self.updated_at,
|
|
)
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class AttachmentAssociationCandidate:
|
|
claim: ExpenseClaim
|
|
score: int
|
|
reasons: list[str]
|
|
|
|
|
|
_jobs: dict[str, AttachmentAssociationJobState] = {}
|
|
_jobs_lock = Lock()
|
|
|
|
|
|
def clear_attachment_association_jobs_for_tests() -> None:
    """Remove every tracked job from the in-memory registry (test helper)."""
    _jobs_lock.acquire()
    try:
        _jobs.clear()
    finally:
        _jobs_lock.release()
|
|
|
|
|
|
def create_attachment_association_job(
    payload: AttachmentAssociationJobCreate,
    current_user: CurrentUserContext,
) -> AttachmentAssociationJobRead:
    """Register a new job for ``payload`` owned by ``current_user``.

    The job is only stored here; it starts with the dataclass default status
    and is executed separately by ``run_attachment_association_job``.

    Returns:
        The read schema of the freshly created job.
    """

    def _clean(value: Any) -> str:
        # Treat falsy values as empty text and trim surrounding whitespace.
        return str(value or "").strip()

    new_state = AttachmentAssociationJobState(
        job_id=f"attachment-association-{uuid4()}",
        owner_username=_clean(current_user.username),
        owner_name=_clean(current_user.name),
        receipt_ids=list(payload.receipt_ids),
        prompt=_clean(payload.prompt),
        conversation_id=_clean(payload.conversation_id),
    )
    with _jobs_lock:
        _jobs[new_state.job_id] = new_state
    return new_state.to_read()
|
|
|
|
|
|
def get_attachment_association_job(
    job_id: str,
    current_user: CurrentUserContext,
) -> AttachmentAssociationJobRead | None:
    """Return the job's read schema, or ``None`` if it is unknown or not accessible."""
    state = _get_authorized_state(job_id, current_user)
    if state is None:
        return None
    return state.to_read()
|
|
|
|
|
|
def run_attachment_association_job(
    job_id: str,
    current_user: CurrentUserContext,
    session_factory: sessionmaker[Session] | Callable[[], Session],
) -> None:
    """Execute a previously created job and record the outcome in its state.

    Does nothing when the job is unknown, not accessible to ``current_user``,
    or already in a terminal status. Otherwise the job is marked ``running``,
    the runner is executed inside a session obtained from ``session_factory``,
    and the job ends up ``succeeded`` or ``failed``. This function never raises
    for runner errors: they are stored in the job's ``message``/``error`` fields.

    Args:
        job_id: Identifier returned when the job was created.
        current_user: The user on whose behalf the job runs.
        session_factory: Callable returning a session usable as a context manager.
    """
    # Function-scope import so this block does not depend on a module-level change.
    import logging

    state = _get_authorized_state(job_id, current_user)
    if state is None or state.status in TERMINAL_STATUSES:
        return

    _update_job(job_id, status="running", message="正在匹配可关联的报销草稿...")
    try:
        with session_factory() as db:
            result = AttachmentAssociationJobRunner(db).run(
                receipt_ids=state.receipt_ids,
                current_user=current_user,
            )
        _update_job(
            job_id,
            status="succeeded",
            message=f"已自动关联到 {result['claim_no']},成功归集 {result['uploaded_count']} 份附件。",
            claim_id=str(result["claim_id"]),
            claim_no=str(result["claim_no"]),
            uploaded_count=int(result["uploaded_count"]),
            skipped_count=int(result["skipped_count"]),
            error="",
        )
    except Exception as exc:
        # Top-level boundary of the background job: the failure is reported
        # through the job state, but the traceback must not be lost.
        logger = logging.getLogger(__name__)
        if isinstance(exc, ValueError):
            # The runner raises ValueError for expected "cannot associate" outcomes.
            logger.warning("attachment association job %s rejected: %s", job_id, exc)
        else:
            logger.exception("attachment association job %s failed", job_id)
        message = str(exc).strip() or "自动关联任务执行失败,请稍后重试。"
        _update_job(
            job_id,
            status="failed",
            message=message,
            error=message,
        )
|
|
|
|
|
|
def _get_authorized_state(
    job_id: str,
    current_user: CurrentUserContext,
) -> AttachmentAssociationJobState | None:
    """Look up a job and return it only if ``current_user`` may access it.

    Admins may access any job. Other users must match the job's recorded owner
    username or, failing that, the recorded owner display name.

    Returns:
        The job state, or ``None`` when the job is unknown or access is denied.
    """
    lookup_key = str(job_id or "").strip()
    with _jobs_lock:
        state = _jobs.get(lookup_key)
    if state is None:
        return None
    if current_user.is_admin:
        return state

    username = str(current_user.username or "").strip()
    name = str(current_user.name or "").strip()
    matches_username = bool(username) and username == state.owner_username
    # NOTE(review): display names may not be unique; confirm the name fallback is intended.
    matches_name = bool(name) and name == state.owner_name
    return state if (matches_username or matches_name) else None
|
|
|
|
|
|
def _update_job(job_id: str, **updates: Any) -> None:
    """Apply ``updates`` to the stored job and refresh its ``updated_at``.

    Unknown job ids are ignored, as are keys that are not attributes of the
    job state.
    """
    lookup_key = str(job_id or "").strip()
    with _jobs_lock:
        state = _jobs.get(lookup_key)
        if state is None:
            return
        for attr_name in updates:
            if hasattr(state, attr_name):
                setattr(state, attr_name, updates[attr_name])
        # Stamped last so it reflects the moment of this update.
        state.updated_at = datetime.now(UTC)
|
|
|
|
|
|
class AttachmentAssociationJobRunner:
    """Attach a batch of receipts to the best-matching editable expense claim.

    The runner loads the receipts, scores the claims visible to the user
    against the dates and cities found in the receipt text, and uploads each
    receipt as an item attachment on the winning claim.
    """

    def __init__(self, db: Session) -> None:
        self.db = db
        self.claim_service = ExpenseClaimService(db)
        self.receipt_service = ReceiptFolderService()

    def run(
        self,
        *,
        receipt_ids: list[str],
        current_user: CurrentUserContext,
    ) -> dict[str, Any]:
        """Associate the given receipts with the single best-matching claim.

        Returns:
            A dict with ``claim_id``, ``claim_no``, ``uploaded_count`` and
            ``skipped_count``.

        Raises:
            ValueError: If no receipts can be loaded, no claim matches, the
                match is not confident enough, or nothing could be uploaded.
        """
        receipts = self._load_receipts(receipt_ids, current_user)
        candidates = self._rank_claims(receipts, current_user)
        if not candidates:
            raise ValueError("没有找到可自动关联的报销草稿,请先新建草稿或补充说明。")

        recommended = candidates[0]
        runner_up = candidates[1] if len(candidates) > 1 else None
        # Auto-associate only when the best claim scores at least 5 and leads
        # the next candidate by at least 2; the same message covers both cases.
        if recommended.score < 5 or (runner_up is not None and recommended.score - runner_up.score < 2):
            raise ValueError("找到多个可能关联的报销草稿,请补充说明或手动选择后再归集。")

        uploaded_count = 0
        skipped_count = 0
        for receipt in receipts:
            # Never move a receipt that is already linked to a different claim.
            if self._is_linked_to_other_claim(receipt, recommended.claim.id):
                skipped_count += 1
                continue
            target_item = self._resolve_target_item(
                claim_id=recommended.claim.id,
                receipt=receipt,
                current_user=current_user,
            )
            source_path, media_type, file_name = self.receipt_service.resolve_source(receipt.id, current_user)
            result = self.claim_service.upload_claim_item_attachment(
                claim_id=recommended.claim.id,
                item_id=target_item.id,
                filename=file_name,
                content=source_path.read_bytes(),
                media_type=media_type,
                current_user=current_user,
                source_receipt_id=receipt.id,
            )
            if result is None:
                skipped_count += 1
            else:
                uploaded_count += 1

        if uploaded_count <= 0:
            raise ValueError("未能归集任何附件,请进入报销单详情手动核对。")
        return {
            "claim_id": recommended.claim.id,
            "claim_no": recommended.claim.claim_no,
            "uploaded_count": uploaded_count,
            "skipped_count": skipped_count,
        }

    def _load_receipts(
        self,
        receipt_ids: list[str],
        current_user: CurrentUserContext,
    ) -> list[ReceiptFolderDetailRead]:
        """Load each distinct, non-blank receipt id.

        Raises:
            ValueError: If a receipt is missing or no usable id was supplied.
        """
        receipts = []
        # `_unique` trims, drops blanks and de-duplicates while keeping order.
        for receipt_id in _unique(receipt_ids):
            try:
                receipts.append(self.receipt_service.get_receipt(receipt_id, current_user))
            except FileNotFoundError as exc:
                raise ValueError("当前附件没有持久化票据记录,请重新上传后再试。") from exc
        if not receipts:
            raise ValueError("当前附件没有持久化票据记录,请重新上传后再试。")
        return receipts

    def _rank_claims(
        self,
        receipts: list[ReceiptFolderDetailRead],
        current_user: CurrentUserContext,
    ) -> list[AttachmentAssociationCandidate]:
        """Score eligible claims and return those with a positive score, best first."""
        signals = _collect_receipt_signals(receipts)
        claims = [
            claim
            for claim in self.claim_service.list_claims(current_user)
            if self._is_auto_association_candidate(claim)
        ]
        ranked = [
            candidate
            for candidate in (self._score_claim(claim, signals) for claim in claims)
            if candidate.score > 0
        ]
        # sorted() is stable, so equally scored claims keep their listing order.
        return sorted(ranked, key=lambda item: item.score, reverse=True)

    def _is_auto_association_candidate(self, claim: ExpenseClaim) -> bool:
        """Return True for editable claims that are not expense applications."""
        status = str(claim.status or "").strip().lower()
        if status not in EDITABLE_CLAIM_STATUSES:
            return False
        return not self.claim_service._is_expense_application_claim(claim)

    def _score_claim(
        self,
        claim: ExpenseClaim,
        signals: dict[str, Any],
    ) -> AttachmentAssociationCandidate:
        """Score how well ``claim`` matches the receipt ``signals``.

        Scoring: +4 for a shared date, +2 per matched city (capped at 4),
        +2 when both sides mention at least two cities, +1 for draft status.
        """
        claim_text = _build_claim_text(claim)
        compact_claim_text = _normalize_text(claim_text)
        claim_dates = _extract_date_tokens(claim_text)
        claim_cities = _unique([*_extract_city_tokens(claim_text), *_extract_city_tokens(claim.location)])
        reasons: list[str] = []
        score = 0

        if _dates_overlap(signals["dates"], claim_dates):
            score += 4
            reasons.append("票据日期与报销单日期一致")

        matched_cities = [city for city in signals["cities"] if city in compact_claim_text]
        if matched_cities:
            score += min(4, len(matched_cities) * 2)
            reasons.append(f"地点或行程包含 {'、'.join(matched_cities)}")

        if len(claim_cities) >= 2 and len(matched_cities) >= 2:
            score += 2
            reasons.append("票据往返城市与报销事由吻合")

        if str(claim.status or "").strip().lower() == "draft":
            score += 1
            reasons.append("当前单据仍是可归集草稿")

        return AttachmentAssociationCandidate(claim=claim, score=score, reasons=reasons)

    @staticmethod
    def _is_linked_to_other_claim(receipt: ReceiptFolderDetailRead, claim_id: str) -> bool:
        """Return True when the receipt is already linked to a different claim.

        Both ids are compared as trimmed strings, so a non-str ``claim_id``
        is not mistaken for a different claim.
        """
        if str(receipt.status or "").strip() != "linked":
            return False
        linked_claim_id = str(receipt.linked_claim_id or "").strip()
        target_claim_id = str(claim_id or "").strip()
        return bool(linked_claim_id) and linked_claim_id != target_claim_id

    def _resolve_target_item(
        self,
        *,
        claim_id: str,
        receipt: ReceiptFolderDetailRead,
        current_user: CurrentUserContext,
    ) -> ExpenseClaimItem:
        """Pick, or create, the claim item that should receive the receipt.

        An existing user-created item without an invoice is reused, preferring
        one whose type matches the receipt. Otherwise a new item is created
        from the receipt.

        Raises:
            ValueError: If the claim is gone or no target item can be obtained.
        """
        claim = self.claim_service.get_claim(claim_id, current_user)
        if claim is None:
            raise ValueError("匹配到的报销草稿不存在,请刷新后再试。")

        preferred_type = _resolve_receipt_item_type(receipt)
        empty_items = [
            item
            for item in list(claim.items or [])
            if not str(item.invoice_id or "").strip() and not item.is_system_generated
        ]
        for item in empty_items:
            if preferred_type and str(item.item_type or "").strip() == preferred_type:
                return item
        if empty_items:
            return empty_items[0]

        # Remember the existing ids so the newly created item can be identified.
        before_ids = {str(item.id) for item in list(claim.items or [])}
        created_claim = self.claim_service.create_claim_item(
            claim_id=claim.id,
            payload=_build_item_payload_from_receipt(claim, receipt, preferred_type),
            current_user=current_user,
        )
        if created_claim is None:
            raise ValueError("无法创建票据归集明细,请进入详情页手动处理。")
        for item in list(created_claim.items or []):
            if str(item.id) not in before_ids and not str(item.invoice_id or "").strip():
                return item
        raise ValueError("无法找到可归集的费用明细,请进入详情页手动处理。")
|
|
|
|
|
|
def _normalize_text(value: Any) -> str:
|
|
return re.sub(r"\s+", "", str(value or "").strip())
|
|
|
|
|
|
def _unique(values: list[str] | tuple[str, ...]) -> list[str]:
|
|
return list(dict.fromkeys(str(item or "").strip() for item in values if str(item or "").strip()))
|
|
|
|
|
|
def _extract_date_tokens(text: Any) -> list[str]:
    """Find date mentions in ``text`` and return them as normalized tokens.

    Full dates (``20YY`` followed by month and day) are collected first, then
    short ``M月D`` mentions; duplicates are removed keeping first occurrence.
    """
    source = str(text or "")
    # With no capture groups, findall returns the whole matched text.
    full_mentions = re.findall(r"20\d{2}[-/.年]\d{1,2}[-/.月]\d{1,2}", source)
    short_mentions = re.findall(r"\d{1,2}月\d{1,2}", source)
    tokens = [_normalize_date_token(mention) for mention in (*full_mentions, *short_mentions)]
    return _unique(tokens)
|
|
|
|
|
|
def _normalize_date_token(value: Any) -> str:
|
|
if isinstance(value, (date, datetime)):
|
|
return value.isoformat()[:10]
|
|
text = str(value or "").strip()
|
|
full_match = re.search(r"(20\d{2})[-/.年](\d{1,2})[-/.月](\d{1,2})", text)
|
|
if full_match:
|
|
year, month, day = full_match.groups()
|
|
return f"{year}-{month.zfill(2)}-{day.zfill(2)}"
|
|
short_match = re.search(r"(\d{1,2})月(\d{1,2})", text)
|
|
if short_match:
|
|
month, day = short_match.groups()
|
|
return f"{month.zfill(2)}-{day.zfill(2)}"
|
|
return ""
|
|
|
|
|
|
def _extract_city_tokens(text: Any) -> list[str]:
    """Return the known city names contained in ``text``.

    Whitespace is ignored when matching. Results follow the order of
    ``CITY_NAMES``, not the order of appearance in the text.
    """
    compact = _normalize_text(text)
    if compact:
        return list(filter(compact.__contains__, CITY_NAMES))
    return []
|
|
|
|
|
|
def _dates_overlap(left: list[str], right: list[str]) -> bool:
|
|
for left_date in left:
|
|
if not left_date:
|
|
continue
|
|
for right_date in right:
|
|
if right_date and (left_date == right_date or left_date.endswith(right_date) or right_date.endswith(left_date)):
|
|
return True
|
|
return False
|
|
|
|
|
|
def _collect_receipt_signals(receipts: list[ReceiptFolderDetailRead]) -> dict[str, Any]:
    """Gather the matching signals for a batch of receipts.

    Returns:
        A dict with ``text`` (all receipt text joined by newlines), ``dates``
        (date tokens found in the text plus each receipt's raw document date)
        and ``cities`` (known cities found in the text).
    """
    combined_text = "\n".join(map(_build_receipt_text, receipts))
    declared_dates = [str(receipt.document_date or "").strip() for receipt in receipts]
    date_tokens = _unique([*_extract_date_tokens(combined_text), *declared_dates])
    city_tokens = _unique(_extract_city_tokens(combined_text))
    return {
        "text": combined_text,
        "dates": date_tokens,
        "cities": city_tokens,
    }
|
|
|
|
|
|
def _build_receipt_text(receipt: ReceiptFolderDetailRead) -> str:
|
|
fields_text = "\n".join(
|
|
f"{field.label} {field.value}"
|
|
for field in list(receipt.fields or [])
|
|
if str(field.label or field.value or "").strip()
|
|
)
|
|
return "\n".join(
|
|
value
|
|
for value in (
|
|
receipt.file_name,
|
|
receipt.summary,
|
|
receipt.ocr_text,
|
|
receipt.document_date,
|
|
receipt.merchant_name,
|
|
fields_text,
|
|
)
|
|
if str(value or "").strip()
|
|
)
|
|
|
|
|
|
def _build_claim_text(claim: ExpenseClaim) -> str:
|
|
item_text = "\n".join(
|
|
" ".join(
|
|
str(value or "").strip()
|
|
for value in (
|
|
item.item_date.isoformat() if item.item_date else "",
|
|
item.item_type,
|
|
item.item_reason,
|
|
item.item_location,
|
|
item.item_note,
|
|
)
|
|
if str(value or "").strip()
|
|
)
|
|
for item in list(claim.items or [])
|
|
)
|
|
occurred_at = claim.occurred_at.isoformat()[:10] if claim.occurred_at else ""
|
|
return "\n".join(
|
|
value
|
|
for value in (
|
|
claim.claim_no,
|
|
claim.expense_type,
|
|
claim.status,
|
|
claim.reason,
|
|
claim.location,
|
|
occurred_at,
|
|
item_text,
|
|
)
|
|
if str(value or "").strip()
|
|
)
|
|
|
|
|
|
def _resolve_receipt_item_type(receipt: ReceiptFolderDetailRead) -> str:
    """Choose the claim item type for a receipt.

    A document type listed in ``DOCUMENT_TYPE_ITEM_TYPE_MAP`` wins; otherwise
    the receipt's scene code is used, falling back to ``"other"``.
    """
    document_type = str(receipt.document_type or "").strip()
    if document_type in DOCUMENT_TYPE_ITEM_TYPE_MAP:
        return DOCUMENT_TYPE_ITEM_TYPE_MAP[document_type]

    scene_code = str(receipt.scene_code or "").strip()
    # A "travel" scene code maps to itself, so it needs no separate branch.
    return scene_code if scene_code else "other"
|
|
|
|
|
|
def _build_item_payload_from_receipt(
    claim: ExpenseClaim,
    receipt: ReceiptFolderDetailRead,
    preferred_type: str,
) -> ExpenseClaimItemCreate:
    """Build the payload for a new claim item that will hold ``receipt``.

    Date and location come from the receipt when it provides them and from the
    claim otherwise. The amount is always created as ``0.00``.
    """
    item_date = _resolve_receipt_item_date(receipt)
    if item_date is None:
        item_date = claim.occurred_at.date() if claim.occurred_at else None

    item_type = preferred_type or str(claim.expense_type or "").strip() or "other"
    item_reason = str(receipt.summary or receipt.file_name or "").strip()
    item_location = _resolve_receipt_item_location(receipt) or str(claim.location or "").strip()

    return ExpenseClaimItemCreate(
        item_date=item_date,
        item_type=item_type,
        item_reason=item_reason,
        item_location=item_location,
        item_amount=Decimal("0.00"),
    )
|
|
|
|
|
|
def _resolve_receipt_item_date(receipt: ReceiptFolderDetailRead) -> date | None:
    """Return the first valid full date found on the receipt, or ``None``.

    Fields whose label mentions a date or time are tried in order, followed by
    the receipt's document date. Short ``MM-DD`` tokens and impossible calendar
    dates are skipped.
    """
    date_markers = ("日期", "时间")
    labelled_values = [
        field.value
        for field in list(receipt.fields or [])
        if any(marker in str(field.label or "") for marker in date_markers)
    ]
    for candidate in (*labelled_values, receipt.document_date):
        token = _normalize_date_token(candidate)
        # Only a full YYYY-MM-DD token has ten characters.
        if len(token) != 10:
            continue
        try:
            return date.fromisoformat(token)
        except ValueError:
            continue
    return None
|
|
|
|
|
|
def _last_mentioned_city(text: Any) -> str:
    """Return the known city whose first mention comes latest in ``text``.

    Returns ``""`` when no known city is present.
    """
    compact = _normalize_text(text)
    cities = _extract_city_tokens(compact)
    if not cities:
        return ""
    # `_extract_city_tokens` orders results by CITY_NAMES, not by position in
    # the text, so the position has to be looked up explicitly.
    return max(cities, key=compact.find)


def _resolve_receipt_item_location(receipt: ReceiptFolderDetailRead) -> str:
    """Derive the item location for a receipt.

    The first non-empty field whose label mentions an itinerary, arrival,
    place or city decides the result: the city mentioned last in its value
    (e.g. the destination of "上海-北京"), or the first 40 characters of the
    value when it contains no known city. Without such a field, the city
    mentioned last in the whole receipt text is used; ``""`` if there is none.
    """
    location_markers = ("行程", "到达", "地点", "城市")
    for field in list(receipt.fields or []):
        label = str(field.label or "")
        value = str(field.value or "").strip()
        if value and any(marker in label for marker in location_markers):
            return _last_mentioned_city(value) or value[:40]
    return _last_mentioned_city(_build_receipt_text(receipt))
|
|
|