diff --git a/server/src/app/schemas/reimbursement.py b/server/src/app/schemas/reimbursement.py index 7d9179a..a71efe4 100644 --- a/server/src/app/schemas/reimbursement.py +++ b/server/src/app/schemas/reimbursement.py @@ -87,6 +87,8 @@ class ExpenseClaimAttachmentRead(BaseModel): size_bytes: int uploaded_at: datetime | None = None previewable: bool = True + preview_kind: str = "" + preview_url: str = "" analysis: ExpenseClaimAttachmentAnalysisRead | None = None document_info: ExpenseClaimAttachmentDocumentInfoRead | None = None requirement_check: ExpenseClaimAttachmentRequirementRead | None = None diff --git a/server/src/app/services/expense_claims.py b/server/src/app/services/expense_claims.py index 81c4b10..2be6061 100644 --- a/server/src/app/services/expense_claims.py +++ b/server/src/app/services/expense_claims.py @@ -1,5 +1,7 @@ from __future__ import annotations +import base64 +import binascii import json import mimetypes import re @@ -9,8 +11,10 @@ from decimal import Decimal, InvalidOperation from pathlib import Path from types import SimpleNamespace from typing import Any +from urllib.parse import quote from sqlalchemy import func, or_, select +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session, selectinload from app.api.deps import CurrentUserContext @@ -102,6 +106,32 @@ DOCUMENT_SCENE_LABELS = { "other": "其他票据", } +DOCUMENT_ASSOCIATION_REVIEW_ACTIONS = { + "link_to_existing_draft", + "create_new_claim_from_documents", +} +MAX_CLAIM_NO_RETRY_ATTEMPTS = 3 +DOCUMENT_AMOUNT_PATTERNS = ( + re.compile( + r"(?:价税合计|合计金额|费用合计|订单(?:总)?金额|支付(?:金额)?|实付(?:金额)?|实收(?:金额)?|总(?:额|计|价)|票价|金额|车费|消费金额)" + r"[::\s¥¥人民币]*([0-9]+(?:[.,][0-9]{1,2})?)" + ), + re.compile(r"[¥¥]\s*([0-9]+(?:[.,][0-9]{1,2})?)"), + re.compile(r"([0-9]+(?:[.,][0-9]{1,2})?)\s*元"), +) +DOCUMENT_DATE_PATTERN = re.compile(r"((?:20\d{2}|19\d{2})[-/年.](?:1[0-2]|0?[1-9])[-/月.](?:3[01]|[12]\d|0?[1-9])日?)") +SYSTEM_GENERATED_REASON_PREFIXES = ( + "我上传了", + "请按当前已识别信息", + "请把当前上传的票据", 
+ "请基于当前上传的多张票据", + "我已核对右侧识别结果", + "请同步修正逐票据识别结果", + "我已修改识别信息", + "查看报销草稿", + "请解释一下当前这笔报销的合规风险和待补充项", +) + class ExpenseClaimService: def __init__(self, db: Session) -> None: @@ -314,6 +344,10 @@ class ExpenseClaimService: file_path = attachment_dir / normalized_name file_path.write_bytes(content) + resolved_media_type = self._resolve_attachment_media_type( + normalized_name, + fallback=media_type, + ) attachment_analysis = self._build_fallback_attachment_analysis( media_type=media_type, @@ -353,16 +387,22 @@ class ExpenseClaimService: ) item.invoice_id = self._to_attachment_storage_key(file_path) + preview_meta = self._build_attachment_preview_meta( + file_path=file_path, + media_type=resolved_media_type, + ocr_document=ocr_document, + ) meta = { "file_name": normalized_name, "storage_key": item.invoice_id, - "media_type": self._resolve_attachment_media_type( - normalized_name, - fallback=media_type, - ), + "media_type": resolved_media_type, "size_bytes": len(content), "uploaded_at": datetime.now(UTC).isoformat(), - "previewable": self._is_previewable_media_type(media_type, normalized_name), + "previewable": bool(preview_meta["previewable"]), + "preview_kind": str(preview_meta["preview_kind"]), + "preview_storage_key": str(preview_meta["preview_storage_key"]), + "preview_media_type": str(preview_meta["preview_media_type"]), + "preview_file_name": str(preview_meta["preview_file_name"]), "analysis": attachment_analysis, "document_info": document_info, "requirement_check": requirement_check, @@ -438,6 +478,23 @@ class ExpenseClaimService: return self._resolve_item_attachment_content(item) + def get_claim_item_attachment_preview_content( + self, + *, + claim_id: str, + item_id: str, + current_user: CurrentUserContext, + ) -> tuple[Path, str, str] | None: + claim, item = self._get_claim_item_or_raise( + claim_id=claim_id, + item_id=item_id, + current_user=current_user, + ) + if claim is None: + return None + + return 
self._resolve_item_attachment_preview_content(item) + def delete_claim_item_attachment( self, *, @@ -609,10 +666,12 @@ class ExpenseClaimService: context_json: dict[str, Any], ) -> dict[str, Any]: self._ensure_ready() + context_json = dict(context_json or {}) + retry_count = self._resolve_claim_no_retry_count(context_json) - claim = self._find_target_claim(ontology=ontology, context_json=context_json) - is_new_claim = claim is None - before_json = self._serialize_claim(claim) if claim is not None else None + review_action = str(context_json.get("review_action") or "").strip() + attachment_names = self._resolve_attachment_names(context_json) + context_documents = self._resolve_context_documents(context_json) employee = self._resolve_employee( ontology=ontology, @@ -628,6 +687,40 @@ class ExpenseClaimService: user_id=user_id, ) ) + + association_candidate = self._find_association_candidate( + ontology=ontology, + context_json=context_json, + user_id=user_id, + employee=employee, + ) + if self._should_defer_multi_document_association( + context_json=context_json, + review_action=review_action, + association_candidate=association_candidate, + context_documents=context_documents, + ): + document_count = max(len(context_documents), len(attachment_names), self._resolve_attachment_count(context_json)) + return { + "message": ( + f"检测到你已有草稿 {association_candidate.claim_no}," + f"当前新上传了 {document_count} 张票据,请先选择关联到现有草稿,或单独建立新的报销单。" + ), + "draft_only": False, + "status": "pending_association_decision", + "pending_association_decision": True, + "association_candidate_claim_id": association_candidate.id, + "association_candidate_claim_no": association_candidate.claim_no, + } + + claim = self._find_target_claim( + ontology=ontology, + context_json=context_json, + review_action=review_action, + association_candidate=association_candidate, + ) + is_new_claim = claim is None + before_json = self._serialize_claim(claim) if claim is not None else None if is_new_claim: 
existing_draft_count = self._count_draft_claims_for_owner( employee=employee, @@ -655,7 +748,7 @@ class ExpenseClaimService: context_json=context_json, allow_message_fallback=is_new_claim, ) - attachment_count = self._resolve_attachment_count(context_json) + attachment_count = len(attachment_names) or self._resolve_attachment_count(context_json) final_amount = amount if amount is not None else (claim.amount if claim is not None else Decimal("0.00")) final_occurred_at = ( @@ -671,70 +764,118 @@ class ExpenseClaimService: list(claim.risk_flags_json or []) if claim is not None else [] ) - if claim is None: - claim = ExpenseClaim( - claim_no=self._generate_claim_no(final_occurred_at), - employee_id=employee.id if employee is not None else None, - employee_name=draft_owner_name, - department_id=employee.organization_unit_id if employee is not None else None, - department_name=self._resolve_department_name( + try: + if claim is None: + claim = ExpenseClaim( + claim_no=self._generate_claim_no(final_occurred_at), + employee_id=employee.id if employee is not None else None, + employee_name=draft_owner_name, + department_id=employee.organization_unit_id if employee is not None else None, + department_name=self._resolve_department_name( + employee=employee, + context_json=context_json, + ), + project_code=self._resolve_project_code(ontology.entities), + expense_type=final_expense_type, + reason=final_reason, + location=final_location, + amount=final_amount, + currency="CNY", + invoice_count=final_attachment_count, + occurred_at=final_occurred_at, + status="draft", + approval_stage="待提交", + risk_flags_json=final_risk_flags, + ) + self.db.add(claim) + else: + claim.employee_id = employee.id if employee is not None else claim.employee_id + claim.employee_name = ( + employee.name + if employee is not None + else self._resolve_employee_name( + ontology=ontology, + context_json=context_json, + user_id=user_id, + fallback=claim.employee_name, + ) + ) + claim.department_id = 
employee.organization_unit_id if employee is not None else claim.department_id + claim.department_name = self._resolve_department_name( employee=employee, context_json=context_json, - ), - project_code=self._resolve_project_code(ontology.entities), - expense_type=final_expense_type, - reason=final_reason, - location=final_location, - amount=final_amount, - currency="CNY", - invoice_count=final_attachment_count, - occurred_at=final_occurred_at, - status="draft", - approval_stage="待提交", - risk_flags_json=final_risk_flags, - ) - self.db.add(claim) - else: - claim.employee_id = employee.id if employee is not None else claim.employee_id - claim.employee_name = ( - employee.name - if employee is not None - else self._resolve_employee_name( - ontology=ontology, - context_json=context_json, - user_id=user_id, - fallback=claim.employee_name, + fallback=claim.department_name, ) - ) - claim.department_id = employee.organization_unit_id if employee is not None else claim.department_id - claim.department_name = self._resolve_department_name( - employee=employee, - context_json=context_json, - fallback=claim.department_name, - ) - claim.project_code = self._resolve_project_code(ontology.entities) or claim.project_code - claim.expense_type = final_expense_type - claim.reason = final_reason - claim.location = final_location - claim.amount = final_amount - claim.invoice_count = final_attachment_count - claim.occurred_at = final_occurred_at - claim.status = "draft" - claim.approval_stage = "待提交" - claim.risk_flags_json = final_risk_flags + claim.project_code = self._resolve_project_code(ontology.entities) or claim.project_code + claim.expense_type = final_expense_type + claim.reason = final_reason + claim.location = final_location + claim.amount = final_amount + claim.invoice_count = final_attachment_count + claim.occurred_at = final_occurred_at + claim.status = "draft" + claim.approval_stage = "待提交" + claim.risk_flags_json = final_risk_flags - self.db.flush() - 
self._upsert_primary_item( - claim=claim, - occurred_at=final_occurred_at, - expense_type=final_expense_type, - amount=final_amount, - reason=final_reason, - location=final_location, - attachment_names=self._resolve_attachment_names(context_json), - ) - self.db.commit() - self.db.refresh(claim) + self.db.flush() + if context_documents or attachment_names: + document_specs = self._build_context_item_specs( + context_documents=context_documents, + attachment_names=attachment_names, + occurred_at=final_occurred_at, + expense_type=final_expense_type, + amount=final_amount, + reason=final_reason, + location=final_location, + ) + else: + document_specs = [] + + if document_specs and (is_new_claim or review_action in DOCUMENT_ASSOCIATION_REVIEW_ACTIONS): + if review_action == "link_to_existing_draft" and claim.items: + self._append_document_items( + claim=claim, + item_specs=document_specs, + ) + else: + self._replace_claim_items( + claim=claim, + item_specs=document_specs, + ) + self._sync_claim_from_items(claim) + else: + self._upsert_primary_item( + claim=claim, + occurred_at=final_occurred_at, + expense_type=final_expense_type, + amount=final_amount, + reason=final_reason, + location=final_location, + attachment_names=attachment_names, + ) + self._sync_claim_from_items(claim) + self.db.commit() + self.db.refresh(claim) + except IntegrityError as exc: + self.db.rollback() + if ( + is_new_claim + and retry_count < MAX_CLAIM_NO_RETRY_ATTEMPTS + and self._is_claim_no_conflict_error(exc) + ): + retry_context = dict(context_json) + retry_context["_claim_no_retry_count"] = retry_count + 1 + return self.upsert_draft_from_ontology( + run_id=run_id, + user_id=user_id, + message=message, + ontology=ontology, + context_json=retry_context, + ) + raise + except Exception: + self.db.rollback() + raise self.audit_service.log_action( actor=user_id or claim.employee_name or "anonymous", @@ -764,10 +905,20 @@ class ExpenseClaimService: *, ontology: OntologyParseResult, context_json: 
dict[str, Any], + review_action: str = "", + association_candidate: ExpenseClaim | None = None, ) -> ExpenseClaim | None: + if review_action == "create_new_claim_from_documents": + return None + if review_action == "link_to_existing_draft" and association_candidate is not None: + return association_candidate + draft_claim_id = str(context_json.get("draft_claim_id") or "").strip() if draft_claim_id: - return self.db.get(ExpenseClaim, draft_claim_id) + claim = self.db.get(ExpenseClaim, draft_claim_id) + if claim is not None and str(claim.status or "").strip() == "draft": + return claim + return None claim_codes = [ item.normalized_value @@ -777,9 +928,386 @@ class ExpenseClaimService: if not claim_codes: return None - stmt = select(ExpenseClaim).where(ExpenseClaim.claim_no.in_(claim_codes)).limit(1) + stmt = ( + select(ExpenseClaim) + .where(ExpenseClaim.claim_no.in_(claim_codes)) + .where(ExpenseClaim.status == "draft") + .limit(1) + ) return self.db.scalar(stmt) + def _find_association_candidate( + self, + *, + ontology: OntologyParseResult, + context_json: dict[str, Any], + user_id: str | None, + employee: Employee | None, + ) -> ExpenseClaim | None: + draft_claim_id = str(context_json.get("draft_claim_id") or "").strip() + if draft_claim_id: + claim = self.db.get(ExpenseClaim, draft_claim_id) + if claim is not None and str(claim.status or "").strip() == "draft": + return claim + + owner_filters = self._build_draft_owner_filters( + employee=employee, + user_id=user_id, + ) + if not owner_filters: + fallback_name = self._resolve_employee_name( + ontology=ontology, + context_json=context_json, + user_id=user_id, + fallback="", + ) + if fallback_name: + owner_filters = [ExpenseClaim.employee_name == fallback_name] + + if not owner_filters: + return None + + stmt = ( + select(ExpenseClaim) + .where(ExpenseClaim.status == "draft") + .where(or_(*owner_filters)) + .order_by(ExpenseClaim.updated_at.desc(), ExpenseClaim.created_at.desc()) + .limit(1) + ) + return 
self.db.scalar(stmt) + + def _should_defer_multi_document_association( + self, + *, + context_json: dict[str, Any], + review_action: str, + association_candidate: ExpenseClaim | None, + context_documents: list[dict[str, Any]], + ) -> bool: + if association_candidate is None: + return False + if review_action in DOCUMENT_ASSOCIATION_REVIEW_ACTIONS: + return False + document_count = max( + len(context_documents), + len(self._resolve_attachment_names(context_json)), + self._resolve_attachment_count(context_json), + ) + return document_count > 1 + + def _resolve_context_documents(self, context_json: dict[str, Any]) -> list[dict[str, Any]]: + documents = context_json.get("ocr_documents") + if not isinstance(documents, list): + documents = [] + + normalized: list[dict[str, Any]] = [] + for index, item in enumerate(documents[:10], start=1): + if not isinstance(item, dict): + continue + normalized.append( + { + "index": index, + "filename": str(item.get("filename") or "").strip(), + "summary": str(item.get("summary") or "").strip(), + "text": str(item.get("text") or "").strip(), + "document_type": str(item.get("document_type") or "").strip(), + "scene_code": str(item.get("scene_code") or "").strip(), + "scene_label": str(item.get("scene_label") or "").strip(), + "document_fields": self._normalize_document_fields(item.get("document_fields")), + } + ) + + overrides = context_json.get("review_document_form_values") + if not isinstance(overrides, list) or not normalized: + return normalized + + override_map: dict[tuple[int, str], dict[str, Any]] = {} + for item in overrides: + if not isinstance(item, dict): + continue + filename = str(item.get("filename") or "").strip() + index = int(item.get("index") or 0) + if not filename and index <= 0: + continue + override_map[(index, filename)] = item + + for item in normalized: + override = override_map.get((int(item["index"]), str(item["filename"]))) + if override is None: + override = override_map.get((int(item["index"]), "")) + if 
override is None: + continue + summary = str(override.get("summary") or "").strip() + scene_label = str(override.get("scene_label") or "").strip() + fields = override.get("fields") + if summary: + item["summary"] = summary + if scene_label: + item["scene_label"] = scene_label + if isinstance(fields, list): + item["document_fields"] = self._normalize_document_fields(fields) + + return normalized + + @staticmethod + def _normalize_document_fields(raw_fields: Any) -> list[dict[str, str]]: + if not isinstance(raw_fields, list): + return [] + normalized: list[dict[str, str]] = [] + for field in raw_fields: + if not isinstance(field, dict): + continue + label = str(field.get("label") or "").strip() + value = str(field.get("value") or "").strip() + key = str(field.get("key") or label or "").strip() + if not label or not value: + continue + normalized.append( + { + "key": key, + "label": label, + "value": value, + } + ) + return normalized + + def _build_context_item_specs( + self, + *, + context_documents: list[dict[str, Any]], + attachment_names: list[str], + occurred_at: datetime, + expense_type: str, + amount: Decimal, + reason: str, + location: str, + ) -> list[dict[str, Any]]: + specs: list[dict[str, Any]] = [] + if context_documents: + for document in context_documents: + specs.append( + { + "item_date": self._resolve_document_item_date(document, fallback=occurred_at.date()), + "item_type": self._resolve_document_item_type(document, fallback=expense_type), + "item_reason": reason, + "item_location": location, + "item_amount": self._resolve_document_item_amount(document), + "invoice_id": str(document.get("filename") or "").strip() or None, + } + ) + elif attachment_names: + for attachment_name in attachment_names: + specs.append( + { + "item_date": occurred_at.date(), + "item_type": expense_type, + "item_reason": reason, + "item_location": location, + "item_amount": None, + "invoice_id": attachment_name, + } + ) + + if not specs: + return [] + + total_recognized = 
sum( + spec["item_amount"] for spec in specs if isinstance(spec.get("item_amount"), Decimal) + ) + missing_specs = [spec for spec in specs if spec.get("item_amount") is None] + if missing_specs: + remaining = (amount - total_recognized).quantize(Decimal("0.01")) + if remaining > Decimal("0.00"): + missing_specs[0]["item_amount"] = remaining + + for spec in specs: + if spec.get("item_amount") is None: + spec["item_amount"] = Decimal("0.00") + + return specs + + def _replace_claim_items( + self, + *, + claim: ExpenseClaim, + item_specs: list[dict[str, Any]], + ) -> None: + existing_items = sorted( + list(claim.items), + key=lambda item: ( + item.item_date or date.max, + self._normalize_sort_datetime(item.created_at), + ), + ) + for index, spec in enumerate(item_specs): + item = existing_items[index] if index < len(existing_items) else None + if item is None: + item = ExpenseClaimItem(claim_id=claim.id) + claim.items.append(item) + self.db.add(item) + item.item_date = spec["item_date"] + item.item_type = spec["item_type"] + item.item_reason = spec["item_reason"] + item.item_location = spec["item_location"] + item.item_amount = spec["item_amount"] + item.invoice_id = spec["invoice_id"] + + for stale_item in existing_items[len(item_specs) :]: + claim.items.remove(stale_item) + self.db.delete(stale_item) + + def _append_document_items( + self, + *, + claim: ExpenseClaim, + item_specs: list[dict[str, Any]], + ) -> None: + existing_invoice_ids = { + str(item.invoice_id or "").strip() + for item in claim.items + if str(item.invoice_id or "").strip() + } + for spec in item_specs: + invoice_id = str(spec.get("invoice_id") or "").strip() + if invoice_id and invoice_id in existing_invoice_ids: + continue + claim.items.append( + ExpenseClaimItem( + claim_id=claim.id, + item_date=spec["item_date"], + item_type=spec["item_type"], + item_reason=spec["item_reason"], + item_location=spec["item_location"], + item_amount=spec["item_amount"], + invoice_id=spec["invoice_id"], + ) + ) + 
self.db.add(claim.items[-1]) + if invoice_id: + existing_invoice_ids.add(invoice_id) + + def _resolve_document_item_type(self, document: dict[str, Any], *, fallback: str) -> str: + scene_code = str(document.get("scene_code") or "").strip() + if scene_code in {"travel", "hotel", "transport", "meal", "office", "meeting", "training"}: + return scene_code + + document_type = str(document.get("document_type") or "").strip() + if document_type in {"flight_itinerary", "train_ticket"}: + return "travel" + if document_type in {"taxi_receipt", "parking_toll_receipt", "transport_receipt"}: + return "transport" + if document_type == "hotel_invoice": + return "hotel" + if document_type == "meal_receipt": + return "meal" + if document_type == "office_invoice": + return "office" + if document_type == "meeting_invoice": + return "meeting" + if document_type == "training_invoice": + return "training" + + scene_label = str(document.get("scene_label") or "").strip() + if "交通" in scene_label: + return "transport" + if "住宿" in scene_label: + return "hotel" + if "餐" in scene_label: + return "meal" + if "会务" in scene_label or "会议" in scene_label: + return "meeting" + if "培训" in scene_label: + return "training" + return fallback or "other" + + def _resolve_document_item_amount(self, document: dict[str, Any]) -> Decimal | None: + for field in list(document.get("document_fields") or []): + if not isinstance(field, dict): + continue + key = str(field.get("key") or "").strip().lower().replace("_", "") + label = str(field.get("label") or "").replace(" ", "") + value = self._parse_document_amount_value(str(field.get("value") or "")) + if value is None: + continue + if key in { + "amount", + "totalamount", + "paymentamount", + "paidamount", + "actualamount", + } or any( + token in label + for token in ("金额", "价税合计", "合计", "总额", "总计", "票价", "支付金额", "实付金额", "实收金额") + ): + return value + + text = " ".join( + [ + str(document.get("summary") or "").strip(), + str(document.get("text") or "").strip(), 
+ ] + ).strip() + return self._parse_document_amount_value(text) + + def _parse_document_amount_value(self, value: str) -> Decimal | None: + raw_value = str(value or "").strip() + if not raw_value: + return None + for pattern in DOCUMENT_AMOUNT_PATTERNS: + match = pattern.search(raw_value) + if not match: + continue + numeric = str(match.group(1) or "").replace(",", ".").strip() + try: + amount = Decimal(numeric).quantize(Decimal("0.01")) + except (InvalidOperation, ValueError): + continue + if amount > Decimal("0.00"): + return amount + return None + + def _resolve_document_item_date(self, document: dict[str, Any], *, fallback: date) -> date: + for field in list(document.get("document_fields") or []): + if not isinstance(field, dict): + continue + key = str(field.get("key") or "").strip().lower().replace("_", "") + label = str(field.get("label") or "").replace(" ", "") + value = str(field.get("value") or "").strip() + if not value: + continue + if key in {"date", "time", "issuedat", "invoicedate"} or any( + token in label for token in ("日期", "时间", "开票日期", "发生时间") + ): + parsed = self._parse_document_date(value) + if parsed is not None: + return parsed + + parsed = self._parse_document_date( + " ".join( + [ + str(document.get("summary") or "").strip(), + str(document.get("text") or "").strip(), + ] + ).strip() + ) + return parsed or fallback + + @staticmethod + def _parse_document_date(value: str) -> date | None: + match = DOCUMENT_DATE_PATTERN.search(str(value or "")) + if not match: + return None + raw_value = str(match.group(1) or "").strip() + normalized = raw_value.replace("年", "-").replace("月", "-").replace("日", "") + normalized = normalized.replace("/", "-").replace(".", "-") + parts = [part for part in normalized.split("-") if part] + if len(parts) != 3: + return None + try: + return date(int(parts[0]), int(parts[1]), int(parts[2])) + except ValueError: + return None + def _upsert_primary_item( self, *, @@ -816,13 +1344,41 @@ class ExpenseClaimService: def 
_generate_claim_no(self, occurred_at: datetime) -> str: month_code = occurred_at.strftime("%Y%m") prefix = f"EXP-{month_code}-" - existing = int( - self.db.scalar( - select(func.count()).select_from(ExpenseClaim).where(ExpenseClaim.claim_no.like(f"{prefix}%")) + existing_claim_nos = list( + self.db.scalars( + select(ExpenseClaim.claim_no).where(ExpenseClaim.claim_no.like(f"{prefix}%")) + ) + ) + max_suffix = 0 + for claim_no in existing_claim_nos: + normalized = str(claim_no or "").strip() + if not normalized.startswith(prefix): + continue + suffix = normalized[len(prefix):] + if not suffix.isdigit(): + continue + max_suffix = max(max_suffix, int(suffix)) + return f"{prefix}{max_suffix + 1:03d}" + + @staticmethod + def _resolve_claim_no_retry_count(context_json: dict[str, Any]) -> int: + try: + return max(0, int(context_json.get("_claim_no_retry_count") or 0)) + except (TypeError, ValueError): + return 0 + + @staticmethod + def _is_claim_no_conflict_error(exc: IntegrityError) -> bool: + message = str(exc).lower() + return ( + "claim_no" in message + and ( + "unique" in message + or "duplicate key" in message + or "ix_expense_claims_claim_no" in message + or "expense_claims.claim_no" in message ) - or 0 ) - return f"{prefix}{existing + 1:03d}" def _count_draft_claims_for_owner( self, @@ -1011,6 +1567,13 @@ class ExpenseClaimService: if value: return value + explicit_text = context_json.get("user_input_text") + if isinstance(explicit_text, str): + normalized_explicit_text = explicit_text.strip() + if normalized_explicit_text: + return normalized_explicit_text[:500] + return None + request_context = context_json.get("request_context") if ( isinstance(request_context, dict) @@ -1022,7 +1585,12 @@ class ExpenseClaimService: return value if not allow_message_fallback: return None - return str(message or "").strip()[:500] or None + + normalized_message = str(message or "").strip() + compact_message = re.sub(r"\s+", "", normalized_message) + if 
compact_message.startswith(SYSTEM_GENERATED_REASON_PREFIXES): + return None + return normalized_message[:500] or None @staticmethod def _resolve_location(*, message: str, context_json: dict[str, Any]) -> str | None: @@ -1210,6 +1778,74 @@ class ExpenseClaimService: return {} return payload if isinstance(payload, dict) else {} + def _build_attachment_preview_meta( + self, + *, + file_path: Path, + media_type: str, + ocr_document: Any | None, + ) -> dict[str, Any]: + filename = file_path.name + storage_key = self._to_attachment_storage_key(file_path) + preview_kind = self._resolve_preview_kind(media_type, filename) + + preview_data_url = str(getattr(ocr_document, "preview_data_url", "") or "").strip() + preview_source_kind = str(getattr(ocr_document, "preview_kind", "") or "").strip() + if preview_source_kind == "image" and preview_data_url: + preview_asset = self._write_preview_asset_from_data_url( + attachment_dir=file_path.parent, + original_filename=filename, + preview_data_url=preview_data_url, + ) + if preview_asset is not None: + preview_path, preview_media_type, preview_file_name = preview_asset + return { + "previewable": True, + "preview_kind": "image", + "preview_storage_key": self._to_attachment_storage_key(preview_path), + "preview_media_type": preview_media_type, + "preview_file_name": preview_file_name, + } + + if preview_kind: + return { + "previewable": True, + "preview_kind": preview_kind, + "preview_storage_key": storage_key, + "preview_media_type": media_type, + "preview_file_name": filename, + } + + return { + "previewable": False, + "preview_kind": "", + "preview_storage_key": "", + "preview_media_type": "", + "preview_file_name": "", + } + + def _resolve_item_attachment_preview_content(self, item: ExpenseClaimItem) -> tuple[Path, str, str]: + file_path, media_type, filename = self._resolve_item_attachment_content(item) + metadata = self._read_attachment_meta(file_path) + preview_storage_key = str(metadata.get("preview_storage_key") or 
"").strip() + preview_file_name = str(metadata.get("preview_file_name") or "").strip() + preview_media_type = str(metadata.get("preview_media_type") or "").strip() + + if preview_storage_key: + preview_path = self._resolve_attachment_path(preview_storage_key) + if preview_path is not None and preview_path.exists(): + resolved_name = preview_file_name or preview_path.name + resolved_media_type = self._resolve_attachment_media_type( + resolved_name, + fallback=preview_media_type, + ) + return preview_path, resolved_media_type, resolved_name + + if self._is_previewable_media_type(media_type, filename): + return file_path, media_type, filename + + raise FileNotFoundError("Attachment preview not found") + def _build_attachment_payload(self, item: ExpenseClaimItem) -> dict[str, Any]: file_path, media_type, filename = self._resolve_item_attachment_content(item) metadata = self._read_attachment_meta(file_path) @@ -1233,18 +1869,71 @@ class ExpenseClaimService: if not isinstance(requirement_check, dict): requirement_check = None + preview_kind = str(metadata.get("preview_kind") or "").strip() + previewable = bool(metadata.get("previewable", self._is_previewable_media_type(media_type, filename))) + preview_url = self._build_attachment_preview_client_path(item.claim_id, item.id) if previewable else "" + return { "file_name": str(metadata.get("file_name") or filename), "storage_key": str(item.invoice_id or ""), "media_type": str(metadata.get("media_type") or media_type), "size_bytes": int(metadata.get("size_bytes") or file_path.stat().st_size), "uploaded_at": uploaded_at, - "previewable": bool(metadata.get("previewable", self._is_previewable_media_type(media_type, filename))), + "previewable": previewable, + "preview_kind": preview_kind or self._resolve_preview_kind(media_type, filename), + "preview_url": preview_url, "analysis": analysis, "document_info": document_info, "requirement_check": requirement_check, } + @staticmethod + def _resolve_preview_kind(media_type: str | 
 None, filename: str) -> str: + resolved = str(media_type or "").strip() or (mimetypes.guess_type(filename)[0] or "") + if resolved.startswith("image/"): + return "image" + if resolved == "application/pdf": + return "pdf" + return "" + + @staticmethod + def _decode_data_url(payload: str) -> tuple[str, bytes] | None: + normalized = str(payload or "").strip() + matched = re.match(r"^data:(?P<media>[\w.+-]+/[\w.+-]+);base64,(?P<body>.+)$", normalized, flags=re.DOTALL) + if not matched: + return None + try: + content = base64.b64decode(matched.group("body"), validate=True) + except (binascii.Error, ValueError): + return None + return matched.group("media"), content + + def _write_preview_asset_from_data_url( + self, + *, + attachment_dir: Path, + original_filename: str, + preview_data_url: str, + ) -> tuple[Path, str, str] | None: + decoded = self._decode_data_url(preview_data_url) + if decoded is None: + return None + + preview_media_type, preview_content = decoded + suffix = mimetypes.guess_extension(preview_media_type) or ".bin" + preview_name = f"{Path(original_filename).stem}.preview{suffix}" + preview_path = attachment_dir / preview_name + preview_path.write_bytes(preview_content) + return preview_path, preview_media_type, preview_name + + @staticmethod + def _build_attachment_preview_client_path(claim_id: str, item_id: str) -> str: + return ( + "/reimbursements/claims/" + f"{quote(str(claim_id or '').strip(), safe='')}" + f"/items/{quote(str(item_id or '').strip(), safe='')}/attachment/preview" + ) + @staticmethod def _resolve_attachment_media_type(filename: str, *, fallback: str | None = None) -> str: guessed = mimetypes.guess_type(filename)[0] diff --git a/server/tests/test_expense_claim_service.py b/server/tests/test_expense_claim_service.py index b0a5f27..93053f5 100644 --- a/server/tests/test_expense_claim_service.py +++ b/server/tests/test_expense_claim_service.py @@ -11,9 +11,11 @@ from app.api.deps import CurrentUserContext from app.db.base import Base from
app.models.employee import Employee from app.models.financial_record import ExpenseClaim, ExpenseClaimItem +from app.schemas.ontology import OntologyParseRequest from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead from app.schemas.reimbursement import ExpenseClaimItemCreate, ExpenseClaimItemUpdate from app.services.expense_claims import ExpenseClaimService +from app.services.ontology import SemanticOntologyService from app.services.ocr import OcrService @@ -97,6 +99,347 @@ def test_resolve_expense_type_maps_office_supplies_review_value_to_office() -> N assert expense_type == "office" +def test_upsert_draft_from_ontology_defers_multi_document_association_choice() -> None: + user_id = "zhangsan@example.com" + + with build_session() as db: + employee = Employee( + employee_no="E5001", + name="张三", + email=user_id, + ) + db.add(employee) + db.flush() + existing_claim = ExpenseClaim( + claim_no="EXP-202605-010", + employee_id=employee.id, + employee_name="张三", + department_name="市场部", + project_code=None, + expense_type="transport", + reason="原有交通报销", + location="深圳", + amount=Decimal("20.00"), + currency="CNY", + invoice_count=1, + occurred_at=datetime(2026, 5, 13, tzinfo=UTC), + status="draft", + approval_stage="待提交", + risk_flags_json=[], + ) + existing_claim.items = [ + ExpenseClaimItem( + claim_id=existing_claim.id, + item_date=date(2026, 5, 13), + item_type="transport", + item_reason="原有交通报销", + item_location="深圳", + item_amount=Decimal("20.00"), + invoice_id="old-trip.png", + ) + ] + db.add(existing_claim) + db.commit() + + ontology = SemanticOntologyService(db).parse( + OntologyParseRequest( + query="我上传了两张交通票据,帮我生成报销草稿", + user_id=user_id, + ) + ) + service = ExpenseClaimService(db) + result = service.upsert_draft_from_ontology( + run_id=ontology.run_id, + user_id=user_id, + message="我上传了两张交通票据,帮我生成报销草稿", + ontology=ontology, + context_json={ + "name": "张三", + "attachment_names": ["didi-trip.png", "parking-ticket.jpg"], + 
"attachment_count": 2, + "draft_claim_id": existing_claim.id, + "ocr_documents": [ + { + "filename": "didi-trip.png", + "summary": "滴滴出行 支付金额 32 元", + "text": "滴滴出行 支付金额 32 元", + }, + { + "filename": "parking-ticket.jpg", + "summary": "停车费 合计 18 元", + "text": "停车费 合计 18 元", + }, + ], + }, + ) + + db.refresh(existing_claim) + assert result["pending_association_decision"] is True + assert result["association_candidate_claim_id"] == existing_claim.id + assert existing_claim.invoice_count == 1 + assert len(existing_claim.items) == 1 + assert existing_claim.items[0].invoice_id == "old-trip.png" + + +def test_upsert_draft_from_ontology_keeps_reason_missing_for_attachment_only_upload() -> None: + user_id = "wangwu@example.com" + + with build_session() as db: + employee = Employee( + employee_no="E5003", + name="王五", + email=user_id, + ) + db.add(employee) + db.commit() + + ontology = SemanticOntologyService(db).parse( + OntologyParseRequest( + query="我上传了 1 份票据,请结合附件名称给出报销建议并尽量生成草稿。", + user_id=user_id, + ) + ) + service = ExpenseClaimService(db) + result = service.upsert_draft_from_ontology( + run_id=ontology.run_id, + user_id=user_id, + message="我上传了 1 份票据,请结合附件名称给出报销建议并尽量生成草稿。\n附件名称:didi-trip.png", + ontology=ontology, + context_json={ + "name": "王五", + "user_input_text": "", + "attachment_names": ["didi-trip.png"], + "attachment_count": 1, + "ocr_documents": [ + { + "filename": "didi-trip.png", + "summary": "滴滴出行 支付金额 32 元", + "text": "滴滴出行 支付金额 32 元", + "document_type": "taxi_receipt", + "scene_code": "transport", + } + ], + }, + ) + + claim = db.get(ExpenseClaim, result["claim_id"]) + assert claim is not None + assert claim.reason == "待补充" + + +def test_upsert_draft_from_ontology_supports_link_or_create_for_multi_documents() -> None: + user_id = "lisi@example.com" + + with build_session() as db: + employee = Employee( + employee_no="E5002", + name="李四", + email=user_id, + ) + db.add(employee) + db.flush() + existing_claim = ExpenseClaim( + claim_no="EXP-202605-011", 
+ employee_id=employee.id, + employee_name="李四", + department_name="销售部", + project_code=None, + expense_type="transport", + reason="原有交通报销", + location="上海", + amount=Decimal("20.00"), + currency="CNY", + invoice_count=1, + occurred_at=datetime(2026, 5, 13, tzinfo=UTC), + status="draft", + approval_stage="待提交", + risk_flags_json=[], + ) + existing_claim.items = [ + ExpenseClaimItem( + claim_id=existing_claim.id, + item_date=date(2026, 5, 13), + item_type="transport", + item_reason="原有交通报销", + item_location="上海", + item_amount=Decimal("20.00"), + invoice_id="existing.png", + ) + ] + db.add(existing_claim) + db.commit() + + ontology = SemanticOntologyService(db).parse( + OntologyParseRequest( + query="我上传了两张交通票据,帮我生成报销草稿", + user_id=user_id, + ) + ) + service = ExpenseClaimService(db) + context_json = { + "name": "李四", + "attachment_names": ["didi-trip.png", "parking-ticket.jpg"], + "attachment_count": 2, + "draft_claim_id": existing_claim.id, + "ocr_documents": [ + { + "filename": "didi-trip.png", + "summary": "滴滴出行", + "text": "滴滴出行 支付金额 32.50 元", + "document_type": "taxi_receipt", + "scene_code": "transport", + "document_fields": [{"key": "amount", "label": "支付金额", "value": "32.50"}], + }, + { + "filename": "parking-ticket.jpg", + "summary": "停车票", + "text": "停车费 合计 18 元", + "document_type": "parking_toll_receipt", + "scene_code": "transport", + "document_fields": [{"key": "total_amount", "label": "合计金额", "value": "18"}], + }, + ], + } + + link_result = service.upsert_draft_from_ontology( + run_id=ontology.run_id, + user_id=user_id, + message="把这两张票据关联到已有草稿", + ontology=ontology, + context_json={ + **context_json, + "review_action": "link_to_existing_draft", + }, + ) + + db.refresh(existing_claim) + assert link_result["claim_id"] == existing_claim.id + assert existing_claim.invoice_count == 3 + assert len(existing_claim.items) == 3 + assert float(existing_claim.amount) == 70.5 + + create_result = service.upsert_draft_from_ontology( + 
run_id=f"{ontology.run_id}-new", + user_id=user_id, + message="单独新建一张报销单", + ontology=ontology, + context_json={ + **context_json, + "review_action": "create_new_claim_from_documents", + }, + ) + + assert create_result["claim_id"] != existing_claim.id + new_claim = db.get(ExpenseClaim, create_result["claim_id"]) + assert new_claim is not None + assert new_claim.invoice_count == 2 + assert len(new_claim.items) == 2 + assert float(new_claim.amount) == 50.5 + + +def test_generate_claim_no_uses_max_suffix_instead_of_count() -> None: + with build_session() as db: + db.add_all( + [ + ExpenseClaim( + claim_no="EXP-202605-001", + employee_name="张三", + department_name="市场部", + project_code=None, + expense_type="transport", + reason="交通报销", + location="深圳", + amount=Decimal("10.00"), + currency="CNY", + invoice_count=1, + occurred_at=datetime(2026, 5, 10, tzinfo=UTC), + status="draft", + approval_stage="待提交", + risk_flags_json=[], + ), + ExpenseClaim( + claim_no="EXP-202605-003", + employee_name="李四", + department_name="销售部", + project_code=None, + expense_type="transport", + reason="交通报销", + location="上海", + amount=Decimal("20.00"), + currency="CNY", + invoice_count=1, + occurred_at=datetime(2026, 5, 11, tzinfo=UTC), + status="submitted", + approval_stage="审批中", + risk_flags_json=[], + ), + ] + ) + db.commit() + + service = ExpenseClaimService(db) + + assert service._generate_claim_no(datetime(2026, 5, 14, tzinfo=UTC)) == "EXP-202605-004" + + +def test_upsert_draft_from_ontology_retries_claim_no_conflict() -> None: + user_id = "zhaoliu-claimno@example.com" + + with build_session() as db: + employee = Employee( + employee_no="E5006", + name="赵六", + email=user_id, + ) + db.add(employee) + db.flush() + db.add( + ExpenseClaim( + claim_no="EXP-202605-004", + employee_name="历史单据", + department_name="财务部", + project_code=None, + expense_type="other", + reason="历史草稿", + location="北京", + amount=Decimal("0.00"), + currency="CNY", + invoice_count=0, + occurred_at=datetime(2026, 5, 12, 
tzinfo=UTC), + status="submitted", + approval_stage="审批中", + risk_flags_json=[], + ) + ) + db.commit() + + ontology = SemanticOntologyService(db).parse( + OntologyParseRequest( + query="帮我生成报销草稿,我昨天交通费 13.4 元", + user_id=user_id, + ) + ) + service = ExpenseClaimService(db) + generated_claim_nos = iter(["EXP-202605-004", "EXP-202605-005"]) + service._generate_claim_no = lambda occurred_at: next(generated_claim_nos) + + result = service.upsert_draft_from_ontology( + run_id=ontology.run_id, + user_id=user_id, + message="帮我生成报销草稿,我昨天交通费 13.4 元", + ontology=ontology, + context_json={ + "name": "赵六", + "user_input_text": "帮我生成报销草稿,我昨天交通费 13.4 元", + }, + ) + + created_claim = db.get(ExpenseClaim, result["claim_id"]) + assert created_claim is not None + assert created_claim.claim_no == "EXP-202605-005" + assert result["claim_no"] == "EXP-202605-005" + + def test_create_claim_item_adds_blank_draft_row_without_forcing_attachment() -> None: current_user = CurrentUserContext( username="emp-1", @@ -186,6 +529,10 @@ def test_update_claim_item_reanalyzes_existing_attachment(monkeypatch, tmp_path) current_user=current_user, ) assert uploaded_meta is not None + assert uploaded_meta["preview_kind"] == "image" + assert uploaded_meta["preview_url"].endswith( + f"/reimbursements/claims/{claim.id}/items/{claim.items[0].id}/attachment/preview" + ) assert uploaded_meta["analysis"]["severity"] == "pass" assert uploaded_meta["document_info"]["document_type"] == "office_invoice" assert uploaded_meta["requirement_check"]["matches"] is True diff --git a/server/tests/test_reimbursement_endpoints.py b/server/tests/test_reimbursement_endpoints.py index be05879..ac9144b 100644 --- a/server/tests/test_reimbursement_endpoints.py +++ b/server/tests/test_reimbursement_endpoints.py @@ -1,5 +1,6 @@ from __future__ import annotations +import base64 from collections.abc import Generator from datetime import UTC, date, datetime from decimal import Decimal @@ -165,9 +166,18 @@ def 
test_claim_item_attachment_upload_preview_and_delete(monkeypatch, tmp_path) assert meta_response.status_code == 200 meta_payload = meta_response.json() assert meta_payload["media_type"] == "image/png" + assert meta_payload["preview_kind"] == "image" + assert meta_payload["preview_url"].endswith(f"/reimbursements/claims/{claim_id}/items/{item_id}/attachment/preview") assert meta_payload["analysis"]["headline"] assert meta_payload["document_info"]["fields"][0]["label"] == "金额" + preview_response = client.get( + f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment/preview", + headers=headers, + ) + assert preview_response.status_code == 200 + assert preview_response.content == file_bytes + content_response = client.get( f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment", headers=headers, @@ -279,6 +289,67 @@ def test_claim_item_attachment_upload_flags_non_invoice_image_as_high_risk(monke assert any("附件内容" in point for point in analysis["points"]) +def test_claim_item_pdf_attachment_preview_returns_generated_image(monkeypatch, tmp_path) -> None: + preview_bytes = b"fake-preview-png" + preview_data_url = f"data:image/png;base64,{base64.b64encode(preview_bytes).decode('ascii')}" + + def fake_recognize( + self, + files: list[tuple[str, bytes, str | None]], + ) -> OcrRecognizeBatchRead: + return OcrRecognizeBatchRead( + total_file_count=1, + success_count=1, + documents=[ + OcrRecognizeDocumentRead( + filename="invoice.pdf", + media_type="application/pdf", + text="滴滴出行电子发票 金额13.4元", + summary="识别到交通票据,金额 13.4 元。", + avg_score=0.96, + line_count=1, + page_count=1, + document_type="taxi_receipt", + document_type_label="出租车/网约车票据", + scene_code="transport", + scene_label="交通票据", + preview_kind="image", + preview_data_url=preview_data_url, + warnings=[], + ) + ], + ) + + monkeypatch.setattr(OcrService, "recognize_files", fake_recognize) + monkeypatch.setattr(ExpenseClaimService, "_get_attachment_storage_root", lambda self: tmp_path) + + 
client, session_factory = build_client() + with session_factory() as db: + claim, item = seed_claim(db) + claim_id = claim.id + item_id = item.id + + headers = {"x-auth-username": "emp-1", "x-auth-name": "Zhang San"} + upload_response = client.post( + f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment", + headers=headers, + files=[("file", ("invoice.pdf", b"%PDF-1.4 fake", "application/pdf"))], + ) + + assert upload_response.status_code == 200 + meta_payload = upload_response.json()["attachment"] + assert meta_payload["preview_kind"] == "image" + assert meta_payload["preview_url"].endswith(f"/reimbursements/claims/{claim_id}/items/{item_id}/attachment/preview") + + preview_response = client.get( + f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment/preview", + headers=headers, + ) + assert preview_response.status_code == 200 + assert preview_response.headers["content-type"].startswith("image/png") + assert preview_response.content == preview_bytes + + def test_claim_item_delete_removes_item_and_attachment(monkeypatch, tmp_path) -> None: def fake_recognize( self,