"""Background jobs that attach uploaded receipts to a matching expense-claim draft.

Jobs are tracked in an in-process registry guarded by a lock. The runner scores the
current user's editable claims against signals (dates, cities) extracted from the
receipts and uploads each receipt as an attachment on the best-matching claim.
"""

from __future__ import annotations

import logging
import re
from dataclasses import dataclass, field
from datetime import UTC, date, datetime
from decimal import Decimal
from threading import Lock
from typing import Any, Callable, Iterable
from uuid import uuid4

from sqlalchemy.orm import Session, sessionmaker

from app.api.deps import CurrentUserContext
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
from app.schemas.attachment_association_job import (
    AttachmentAssociationJobCreate,
    AttachmentAssociationJobRead,
)
from app.schemas.receipt_folder import ReceiptFolderDetailRead
from app.schemas.reimbursement import ExpenseClaimItemCreate
from app.services.expense_claim_constants import (
    DOCUMENT_TYPE_ITEM_TYPE_MAP,
    EDITABLE_CLAIM_STATUSES,
)
from app.services.expense_claims import ExpenseClaimService
from app.services.receipt_folder import ReceiptFolderService

logger = logging.getLogger(__name__)

CITY_NAMES = (
    "北京", "上海", "广州", "深圳", "武汉", "南京", "杭州", "成都", "重庆",
    "西安", "天津", "苏州", "长沙", "郑州", "青岛", "厦门", "宁波", "无锡",
    "合肥", "福州", "昆明", "大连", "沈阳", "济南", "哈尔滨", "长春", "南昌",
    "太原", "贵阳", "南宁", "石家庄", "兰州", "银川", "西宁", "海口", "拉萨",
)

TERMINAL_STATUSES = {"succeeded", "failed"}

# Thresholds used by AttachmentAssociationJobRunner.run: the best candidate must
# reach AUTO_ASSOCIATION_MIN_SCORE and beat the runner-up by at least
# AUTO_ASSOCIATION_MIN_MARGIN, otherwise the match is treated as too weak/ambiguous.
# (With the scoring in _score_claim, "draft" status alone is worth only 1 point.)
AUTO_ASSOCIATION_MIN_SCORE = 5
AUTO_ASSOCIATION_MIN_MARGIN = 2

# Full dates such as 2024-05-01, 2024/5/1, 2024.5.1 or 2024年5月1(日).
_FULL_DATE_PATTERN = re.compile(r"(20\d{2})[-/.年](\d{1,2})[-/.月](\d{1,2})")
# Month/day dates without a year, such as 5月1(日).
_MONTH_DAY_PATTERN = re.compile(r"(\d{1,2})月(\d{1,2})")
_WHITESPACE_PATTERN = re.compile(r"\s+")

# Receipt field labels whose value describes a trip leg or place.
_LOCATION_LABEL_KEYWORDS = ("行程", "到达", "地点", "城市")
# Receipt field labels whose value is a date or time.
_DATE_LABEL_KEYWORDS = ("日期", "时间")


@dataclass(slots=True)
class AttachmentAssociationJobState:
    """Mutable in-memory record of one attachment-association job."""

    job_id: str
    # Creator identity, checked by _is_job_visible_to.
    owner_username: str
    owner_name: str
    receipt_ids: list[str]
    prompt: str = ""
    conversation_id: str = ""
    # Lifecycle: "queued" -> "running" -> "succeeded" | "failed".
    status: str = "queued"
    message: str = "已创建附件关联任务,等待后台处理。"
    claim_id: str = ""
    claim_no: str = ""
    uploaded_count: int = 0
    skipped_count: int = 0
    error: str = ""
    created_at: datetime = field(default_factory=lambda: datetime.now(UTC))
    updated_at: datetime = field(default_factory=lambda: datetime.now(UTC))

    def to_read(self) -> AttachmentAssociationJobRead:
        """Build the API read model from the current state (receipt ids are copied)."""
        return AttachmentAssociationJobRead(
            job_id=self.job_id,
            status=self.status,
            message=self.message,
            receipt_ids=list(self.receipt_ids),
            claim_id=self.claim_id,
            claim_no=self.claim_no,
            uploaded_count=self.uploaded_count,
            skipped_count=self.skipped_count,
            error=self.error,
            prompt=self.prompt,
            conversation_id=self.conversation_id,
            created_at=self.created_at,
            updated_at=self.updated_at,
        )


@dataclass(slots=True)
class AttachmentAssociationCandidate:
    """A claim together with its match score and human-readable score reasons."""

    claim: ExpenseClaim
    score: int
    reasons: list[str]


# Registry of jobs keyed by job id. It lives only in this process's memory and
# entries are never evicted. NOTE(review): consider pruning old terminal jobs.
_jobs: dict[str, AttachmentAssociationJobState] = {}
_jobs_lock = Lock()


def clear_attachment_association_jobs_for_tests() -> None:
    """Drop every registered job (test helper)."""
    with _jobs_lock:
        _jobs.clear()


def create_attachment_association_job(
    payload: AttachmentAssociationJobCreate,
    current_user: CurrentUserContext,
) -> AttachmentAssociationJobRead:
    """Register a new queued job owned by ``current_user`` and return its read model."""
    job_id = f"attachment-association-{uuid4()}"
    state = AttachmentAssociationJobState(
        job_id=job_id,
        owner_username=str(current_user.username or "").strip(),
        owner_name=str(current_user.name or "").strip(),
        receipt_ids=list(payload.receipt_ids),
        prompt=str(payload.prompt or "").strip(),
        conversation_id=str(payload.conversation_id or "").strip(),
    )
    with _jobs_lock:
        _jobs[job_id] = state
    return state.to_read()


def get_attachment_association_job(
    job_id: str,
    current_user: CurrentUserContext,
) -> AttachmentAssociationJobRead | None:
    """Return a snapshot of the job, or ``None`` if it is missing or not visible.

    The snapshot is taken while holding the registry lock so it never mixes
    fields from before and after a concurrent ``_update_job`` call.
    """
    normalized_job_id = str(job_id or "").strip()
    with _jobs_lock:
        state = _jobs.get(normalized_job_id)
        if state is None or not _is_job_visible_to(state, current_user):
            return None
        return state.to_read()


def run_attachment_association_job(
    job_id: str,
    current_user: CurrentUserContext,
    session_factory: sessionmaker[Session] | Callable[[], Session],
) -> None:
    """Execute a queued job and record its outcome on the job state.

    Does nothing if the job is unknown, not visible to ``current_user`` or already
    finished. Any exception is converted into a "failed" state whose message is
    the exception text (or a generic fallback); unexpected (non-ValueError)
    exceptions are additionally logged with their traceback.
    """
    state = _get_authorized_state(job_id, current_user)
    if state is None or state.status in TERMINAL_STATUSES:
        return
    _update_job(job_id, status="running", message="正在匹配可关联的报销草稿...")
    try:
        with session_factory() as db:
            result = AttachmentAssociationJobRunner(db).run(
                receipt_ids=state.receipt_ids,
                current_user=current_user,
            )
        _update_job(
            job_id,
            status="succeeded",
            message=f"已自动关联到 {result['claim_no']},成功归集 {result['uploaded_count']} 份附件。",
            claim_id=str(result["claim_id"]),
            claim_no=str(result["claim_no"]),
            uploaded_count=int(result["uploaded_count"]),
            skipped_count=int(result["skipped_count"]),
            error="",
        )
    except Exception as exc:  # background boundary: record failure instead of raising
        if not isinstance(exc, ValueError):
            # ValueErrors carry expected, user-facing messages; anything else is a bug
            # or infrastructure problem that should leave a trace in the logs.
            logger.exception("Attachment association job %s failed unexpectedly", job_id)
        message = str(exc).strip() or "自动关联任务执行失败,请稍后重试。"
        _update_job(
            job_id,
            status="failed",
            message=message,
            error=message,
        )


def _is_job_visible_to(
    state: AttachmentAssociationJobState,
    current_user: CurrentUserContext,
) -> bool:
    """Return True if ``current_user`` may see ``state``.

    Admins see every job. Otherwise the username must match the recorded owner
    username; the display name is only used as a fallback when the job was
    created without a username (display names are not unique identifiers).
    """
    if current_user.is_admin:
        return True
    if state.owner_username:
        username = str(current_user.username or "").strip()
        return bool(username) and username == state.owner_username
    name = str(current_user.name or "").strip()
    return bool(name) and name == state.owner_name


def _get_authorized_state(
    job_id: str,
    current_user: CurrentUserContext,
) -> AttachmentAssociationJobState | None:
    """Look up a job by (stripped) id and return it only if visible to the user."""
    normalized_job_id = str(job_id or "").strip()
    with _jobs_lock:
        state = _jobs.get(normalized_job_id)
    if state is None or not _is_job_visible_to(state, current_user):
        return None
    return state


def _update_job(job_id: str, **updates: Any) -> None:
    """Apply ``updates`` to the job's state under the registry lock.

    Keys that are not attributes of the state are ignored; ``updated_at`` is
    refreshed on every call. Unknown job ids are silently skipped.
    """
    with _jobs_lock:
        state = _jobs.get(str(job_id or "").strip())
        if state is None:
            return
        for key, value in updates.items():
            if hasattr(state, key):
                setattr(state, key, value)
        state.updated_at = datetime.now(UTC)


class AttachmentAssociationJobRunner:
    """Matches receipts to one editable claim and uploads them as item attachments."""

    def __init__(self, db: Session) -> None:
        self.db = db
        self.claim_service = ExpenseClaimService(db)
        self.receipt_service = ReceiptFolderService()

    def run(
        self,
        *,
        receipt_ids: list[str],
        current_user: CurrentUserContext,
    ) -> dict[str, Any]:
        """Attach the given receipts to the single best-matching editable claim.

        Returns:
            A dict with ``claim_id``, ``claim_no``, ``uploaded_count`` and
            ``skipped_count``.

        Raises:
            ValueError: when receipts cannot be loaded, no claim qualifies, the best
                match is weak or ambiguous, or no attachment could be uploaded.
                The caller surfaces the message to the user.
        """
        receipts = self._load_receipts(receipt_ids, current_user)
        candidates = self._rank_claims(receipts, current_user)
        if not candidates:
            raise ValueError("没有找到可自动关联的报销草稿,请先新建草稿或补充说明。")

        recommended = candidates[0]
        runner_up = candidates[1] if len(candidates) > 1 else None
        too_weak = recommended.score < AUTO_ASSOCIATION_MIN_SCORE
        too_close = (
            runner_up is not None
            and recommended.score - runner_up.score < AUTO_ASSOCIATION_MIN_MARGIN
        )
        if too_weak or too_close:
            raise ValueError("找到多个可能关联的报销草稿,请补充说明或手动选择后再归集。")

        uploaded_count = 0
        skipped_count = 0
        for receipt in receipts:
            # Never steal a receipt that is already linked to a different claim.
            if self._is_linked_to_other_claim(receipt, recommended.claim.id):
                skipped_count += 1
                continue
            target_item = self._resolve_target_item(
                claim_id=recommended.claim.id,
                receipt=receipt,
                current_user=current_user,
            )
            source_path, media_type, file_name = self.receipt_service.resolve_source(receipt.id, current_user)
            result = self.claim_service.upload_claim_item_attachment(
                claim_id=recommended.claim.id,
                item_id=target_item.id,
                filename=file_name,
                content=source_path.read_bytes(),
                media_type=media_type,
                current_user=current_user,
                source_receipt_id=receipt.id,
            )
            if result is None:
                skipped_count += 1
            else:
                uploaded_count += 1

        if uploaded_count <= 0:
            raise ValueError("未能归集任何附件,请进入报销单详情手动核对。")
        return {
            "claim_id": recommended.claim.id,
            "claim_no": recommended.claim.claim_no,
            "uploaded_count": uploaded_count,
            "skipped_count": skipped_count,
        }

    def _load_receipts(
        self,
        receipt_ids: list[str],
        current_user: CurrentUserContext,
    ) -> list[ReceiptFolderDetailRead]:
        """Load each distinct, non-blank receipt id (first-seen order).

        Raises ValueError if any receipt is missing or no id was given.
        """
        receipts: list[ReceiptFolderDetailRead] = []
        for receipt_id in _unique(receipt_ids):
            try:
                receipts.append(self.receipt_service.get_receipt(receipt_id, current_user))
            except FileNotFoundError as exc:
                raise ValueError("当前附件没有持久化票据记录,请重新上传后再试。") from exc
        if not receipts:
            raise ValueError("当前附件没有持久化票据记录,请重新上传后再试。")
        return receipts

    def _rank_claims(
        self,
        receipts: list[ReceiptFolderDetailRead],
        current_user: CurrentUserContext,
    ) -> list[AttachmentAssociationCandidate]:
        """Score the user's eligible claims and return positive scores, best first."""
        signals = _collect_receipt_signals(receipts)
        claims = [
            claim
            for claim in self.claim_service.list_claims(current_user)
            if self._is_auto_association_candidate(claim)
        ]
        ranked = [
            candidate
            for candidate in (self._score_claim(claim, signals) for claim in claims)
            if candidate.score > 0
        ]
        # sorted() is stable, so equal scores keep list_claims order.
        return sorted(ranked, key=lambda item: item.score, reverse=True)

    def _is_auto_association_candidate(self, claim: ExpenseClaim) -> bool:
        """Editable-status claims that are not expense applications are eligible."""
        status = str(claim.status or "").strip().lower()
        if status not in EDITABLE_CLAIM_STATUSES:
            return False
        return not self.claim_service._is_expense_application_claim(claim)

    def _score_claim(
        self,
        claim: ExpenseClaim,
        signals: dict[str, Any],
    ) -> AttachmentAssociationCandidate:
        """Score how well ``claim`` matches the receipt ``signals``.

        Points: +4 for an overlapping date, +2 per matched city (max +4), +2 when
        the claim mentions at least two cities and two receipt cities match, and +1
        for draft status.
        """
        claim_text = _build_claim_text(claim)
        compact_claim_text = _normalize_text(claim_text)
        claim_dates = _extract_date_tokens(claim_text)
        claim_cities = _unique([*_extract_city_tokens(claim_text), *_extract_city_tokens(claim.location)])
        reasons: list[str] = []
        score = 0

        if _dates_overlap(signals["dates"], claim_dates):
            score += 4
            reasons.append("票据日期与报销单日期一致")

        matched_cities = [city for city in signals["cities"] if city in compact_claim_text]
        if matched_cities:
            score += min(4, len(matched_cities) * 2)
            reasons.append(f"地点或行程包含 {'、'.join(matched_cities)}")
        if len(claim_cities) >= 2 and len(matched_cities) >= 2:
            score += 2
            reasons.append("票据往返城市与报销事由吻合")

        if str(claim.status or "").strip().lower() == "draft":
            score += 1
            reasons.append("当前单据仍是可归集草稿")

        return AttachmentAssociationCandidate(claim=claim, score=score, reasons=reasons)

    @staticmethod
    def _is_linked_to_other_claim(receipt: ReceiptFolderDetailRead, claim_id: Any) -> bool:
        """Return True if the receipt is "linked" to a claim other than ``claim_id``.

        Both ids are compared as stripped strings so a non-str claim id (the
        caller passes ``claim.id`` unconverted) still matches its string form.
        """
        if str(receipt.status or "").strip() != "linked":
            return False
        linked_claim_id = str(receipt.linked_claim_id or "").strip()
        target_claim_id = str(claim_id or "").strip()
        return bool(linked_claim_id) and linked_claim_id != target_claim_id

    def _resolve_target_item(
        self,
        *,
        claim_id: str,
        receipt: ReceiptFolderDetailRead,
        current_user: CurrentUserContext,
    ) -> ExpenseClaimItem:
        """Pick (or create) the claim item that should receive the receipt.

        Preference order: an empty, non-system item of the receipt's item type;
        any empty, non-system item; a newly created item built from the receipt.
        "Empty" means the item has no ``invoice_id``.
        """
        claim = self.claim_service.get_claim(claim_id, current_user)
        if claim is None:
            raise ValueError("匹配到的报销草稿不存在,请刷新后再试。")

        preferred_type = _resolve_receipt_item_type(receipt)
        empty_items = [
            item
            for item in list(claim.items or [])
            if not str(item.invoice_id or "").strip() and not item.is_system_generated
        ]
        for item in empty_items:
            if preferred_type and str(item.item_type or "").strip() == preferred_type:
                return item
        if empty_items:
            return empty_items[0]

        before_ids = {str(item.id) for item in list(claim.items or [])}
        created_claim = self.claim_service.create_claim_item(
            claim_id=claim.id,
            payload=_build_item_payload_from_receipt(claim, receipt, preferred_type),
            current_user=current_user,
        )
        if created_claim is None:
            raise ValueError("无法创建票据归集明细,请进入详情页手动处理。")
        for item in list(created_claim.items or []):
            if str(item.id) not in before_ids and not str(item.invoice_id or "").strip():
                return item
        raise ValueError("无法找到可归集的费用明细,请进入详情页手动处理。")


def _normalize_text(value: Any) -> str:
    """Stringify ``value`` (None -> "") and remove all whitespace."""
    return _WHITESPACE_PATTERN.sub("", str(value or "").strip())


def _unique(values: list[str] | tuple[str, ...]) -> list[str]:
    """Return stripped, non-blank values without duplicates, in first-seen order."""
    return list(dict.fromkeys(str(item or "").strip() for item in values if str(item or "").strip()))


def _join_present(values: Iterable[Any], separator: str = "\n") -> str:
    """Join the string form of every non-blank value (None/""/whitespace skipped).

    Values are passed through ``str()`` so non-string values such as ``date``
    objects do not make ``str.join`` raise.
    """
    return separator.join(str(value) for value in values if str(value or "").strip())


def _extract_date_tokens(text: Any) -> list[str]:
    """Extract normalized dates ("YYYY-MM-DD" or "MM-DD") from free text."""
    source = str(text or "")
    matches = [
        *_FULL_DATE_PATTERN.finditer(source),
        *_MONTH_DAY_PATTERN.finditer(source),
    ]
    return _unique([_normalize_date_token(match.group(0)) for match in matches])


def _normalize_date_token(value: Any) -> str:
    """Normalize a date-like value to "YYYY-MM-DD", "MM-DD", or "" if unparseable.

    ``date``/``datetime`` objects are formatted directly; strings are searched for
    a full date first, then a month/day date.
    """
    if isinstance(value, (date, datetime)):
        return value.isoformat()[:10]
    text = str(value or "").strip()
    full_match = _FULL_DATE_PATTERN.search(text)
    if full_match:
        year, month, day = full_match.groups()
        return f"{year}-{month.zfill(2)}-{day.zfill(2)}"
    short_match = _MONTH_DAY_PATTERN.search(text)
    if short_match:
        month, day = short_match.groups()
        return f"{month.zfill(2)}-{day.zfill(2)}"
    return ""


def _extract_city_tokens(text: Any) -> list[str]:
    """Return the known cities mentioned in ``text`` in order of first appearance.

    Ordering by position (rather than by CITY_NAMES order) matters to callers
    that treat the last city of an itinerary such as "上海-北京" as the destination.
    """
    compact = _normalize_text(text)
    if not compact:
        return []
    found = [city for city in CITY_NAMES if city in compact]
    return sorted(found, key=compact.index)


def _dates_overlap(left: list[str], right: list[str]) -> bool:
    """True if any pair of non-empty tokens is equal or one is a suffix of the other.

    The suffix rule lets a "MM-DD" token match a full "YYYY-MM-DD" token.
    """
    return any(
        left_date == right_date or left_date.endswith(right_date) or right_date.endswith(left_date)
        for left_date in left
        if left_date
        for right_date in right
        if right_date
    )


def _collect_receipt_signals(receipts: list[ReceiptFolderDetailRead]) -> dict[str, Any]:
    """Gather the combined text, normalized dates and cities of all receipts."""
    text = "\n".join(_build_receipt_text(receipt) for receipt in receipts)
    dates = _unique([
        *_extract_date_tokens(text),
        # Normalize so document dates compare on the same format as claim tokens.
        *[_normalize_date_token(receipt.document_date) for receipt in receipts],
    ])
    return {
        "text": text,
        "dates": dates,
        "cities": _unique(_extract_city_tokens(text)),
    }


def _format_receipt_field(receipt_field: Any) -> str:
    """Render one receipt field as "label value", omitting blank parts."""
    parts = (str(receipt_field.label or "").strip(), str(receipt_field.value or "").strip())
    return " ".join(part for part in parts if part)


def _build_receipt_text(receipt: ReceiptFolderDetailRead) -> str:
    """Concatenate the receipt's descriptive attributes and fields into one text."""
    fields_text = "\n".join(
        line
        for line in (_format_receipt_field(receipt_field) for receipt_field in list(receipt.fields or []))
        if line
    )
    return _join_present((
        receipt.file_name,
        receipt.summary,
        receipt.ocr_text,
        receipt.document_date,
        receipt.merchant_name,
        fields_text,
    ))


def _build_claim_text(claim: ExpenseClaim) -> str:
    """Concatenate the claim's header attributes and per-item details into one text."""
    item_text = "\n".join(
        " ".join(
            str(value or "").strip()
            for value in (
                item.item_date.isoformat() if item.item_date else "",
                item.item_type,
                item.item_reason,
                item.item_location,
                item.item_note,
            )
            if str(value or "").strip()
        )
        for item in list(claim.items or [])
    )
    occurred_at = claim.occurred_at.isoformat()[:10] if claim.occurred_at else ""
    return _join_present((
        claim.claim_no,
        claim.expense_type,
        claim.status,
        claim.reason,
        claim.location,
        occurred_at,
        item_text,
    ))


def _resolve_receipt_item_type(receipt: ReceiptFolderDetailRead) -> str:
    """Map the receipt's document type to an item type, else its scene code, else "other"."""
    document_type = str(receipt.document_type or "").strip()
    if document_type in DOCUMENT_TYPE_ITEM_TYPE_MAP:
        return DOCUMENT_TYPE_ITEM_TYPE_MAP[document_type]
    return str(receipt.scene_code or "").strip() or "other"


def _claim_occurred_date(claim: ExpenseClaim) -> date | None:
    """Return the claim's ``occurred_at`` as a ``date`` (accepts date or datetime)."""
    occurred_at = claim.occurred_at
    if not occurred_at:
        return None
    if isinstance(occurred_at, datetime):
        return occurred_at.date()
    if isinstance(occurred_at, date):
        return occurred_at
    return None


def _build_item_payload_from_receipt(
    claim: ExpenseClaim,
    receipt: ReceiptFolderDetailRead,
    preferred_type: str,
) -> ExpenseClaimItemCreate:
    """Build a zero-amount item payload describing ``receipt``.

    Missing receipt details fall back to the claim's date, type and location.
    """
    item_date = _resolve_receipt_item_date(receipt) or _claim_occurred_date(claim)
    return ExpenseClaimItemCreate(
        item_date=item_date,
        item_type=preferred_type or str(claim.expense_type or "").strip() or "other",
        item_reason=str(receipt.summary or receipt.file_name or "").strip(),
        item_location=_resolve_receipt_item_location(receipt) or str(claim.location or "").strip(),
        item_amount=Decimal("0.00"),
    )


def _resolve_receipt_item_date(receipt: ReceiptFolderDetailRead) -> date | None:
    """Return the first valid full date from date/time fields, then ``document_date``."""
    candidates = [
        receipt_field.value
        for receipt_field in list(receipt.fields or [])
        if any(keyword in str(receipt_field.label or "") for keyword in _DATE_LABEL_KEYWORDS)
    ]
    candidates.append(receipt.document_date)
    for value in candidates:
        token = _normalize_date_token(value)
        if len(token) == 10:  # only full "YYYY-MM-DD" tokens carry a year
            try:
                return date.fromisoformat(token)
            except ValueError:
                continue
    return None


def _resolve_receipt_item_location(receipt: ReceiptFolderDetailRead) -> str:
    """Guess the receipt's destination city.

    Uses the first location-like field (last city mentioned, else the first 40
    characters of its value); otherwise the last city mentioned in the receipt text.
    """
    for receipt_field in list(receipt.fields or []):
        label = str(receipt_field.label or "")
        value = str(receipt_field.value or "").strip()
        if value and any(keyword in label for keyword in _LOCATION_LABEL_KEYWORDS):
            cities = _extract_city_tokens(value)
            return cities[-1] if cities else value[:40]
    cities = _extract_city_tokens(_build_receipt_text(receipt))
    return cities[-1] if cities else ""