feat(server): 重构报销单服务,优化费用报销流程和数据校验逻辑,包含schema定义和服务实现

This commit is contained in:
caoxiaozhu
2026-05-14 15:42:45 +00:00
parent ad16358e71
commit e21f0d82e9
4 changed files with 1187 additions and 78 deletions

View File

@@ -87,6 +87,8 @@ class ExpenseClaimAttachmentRead(BaseModel):
size_bytes: int
uploaded_at: datetime | None = None
previewable: bool = True
preview_kind: str = ""
preview_url: str = ""
analysis: ExpenseClaimAttachmentAnalysisRead | None = None
document_info: ExpenseClaimAttachmentDocumentInfoRead | None = None
requirement_check: ExpenseClaimAttachmentRequirementRead | None = None

View File

@@ -1,5 +1,7 @@
from __future__ import annotations
import base64
import binascii
import json
import mimetypes
import re
@@ -9,8 +11,10 @@ from decimal import Decimal, InvalidOperation
from pathlib import Path
from types import SimpleNamespace
from typing import Any
from urllib.parse import quote
from sqlalchemy import func, or_, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, selectinload
from app.api.deps import CurrentUserContext
@@ -102,6 +106,32 @@ DOCUMENT_SCENE_LABELS = {
"other": "其他票据",
}
DOCUMENT_ASSOCIATION_REVIEW_ACTIONS = {
"link_to_existing_draft",
"create_new_claim_from_documents",
}
MAX_CLAIM_NO_RETRY_ATTEMPTS = 3
DOCUMENT_AMOUNT_PATTERNS = (
re.compile(
r"(?:价税合计|合计金额|费用合计|订单(?:总)?金额|支付(?:金额)?|实付(?:金额)?|实收(?:金额)?|总(?:额|计|价)|票价|金额|车费|消费金额)"
r"[:\s¥¥人民币]*([0-9]+(?:[.,][0-9]{1,2})?)"
),
re.compile(r"[¥¥]\s*([0-9]+(?:[.,][0-9]{1,2})?)"),
re.compile(r"([0-9]+(?:[.,][0-9]{1,2})?)\s*元"),
)
DOCUMENT_DATE_PATTERN = re.compile(r"((?:20\d{2}|19\d{2})[-/年.](?:1[0-2]|0?[1-9])[-/月.](?:3[01]|[12]\d|0?[1-9])日?)")
SYSTEM_GENERATED_REASON_PREFIXES = (
"我上传了",
"请按当前已识别信息",
"请把当前上传的票据",
"请基于当前上传的多张票据",
"我已核对右侧识别结果",
"请同步修正逐票据识别结果",
"我已修改识别信息",
"查看报销草稿",
"请解释一下当前这笔报销的合规风险和待补充项",
)
class ExpenseClaimService:
def __init__(self, db: Session) -> None:
@@ -314,6 +344,10 @@ class ExpenseClaimService:
file_path = attachment_dir / normalized_name
file_path.write_bytes(content)
resolved_media_type = self._resolve_attachment_media_type(
normalized_name,
fallback=media_type,
)
attachment_analysis = self._build_fallback_attachment_analysis(
media_type=media_type,
@@ -353,16 +387,22 @@ class ExpenseClaimService:
)
item.invoice_id = self._to_attachment_storage_key(file_path)
preview_meta = self._build_attachment_preview_meta(
file_path=file_path,
media_type=resolved_media_type,
ocr_document=ocr_document,
)
meta = {
"file_name": normalized_name,
"storage_key": item.invoice_id,
"media_type": self._resolve_attachment_media_type(
normalized_name,
fallback=media_type,
),
"media_type": resolved_media_type,
"size_bytes": len(content),
"uploaded_at": datetime.now(UTC).isoformat(),
"previewable": self._is_previewable_media_type(media_type, normalized_name),
"previewable": bool(preview_meta["previewable"]),
"preview_kind": str(preview_meta["preview_kind"]),
"preview_storage_key": str(preview_meta["preview_storage_key"]),
"preview_media_type": str(preview_meta["preview_media_type"]),
"preview_file_name": str(preview_meta["preview_file_name"]),
"analysis": attachment_analysis,
"document_info": document_info,
"requirement_check": requirement_check,
@@ -438,6 +478,23 @@ class ExpenseClaimService:
return self._resolve_item_attachment_content(item)
def get_claim_item_attachment_preview_content(
self,
*,
claim_id: str,
item_id: str,
current_user: CurrentUserContext,
) -> tuple[Path, str, str] | None:
claim, item = self._get_claim_item_or_raise(
claim_id=claim_id,
item_id=item_id,
current_user=current_user,
)
if claim is None:
return None
return self._resolve_item_attachment_preview_content(item)
def delete_claim_item_attachment(
self,
*,
@@ -609,10 +666,12 @@ class ExpenseClaimService:
context_json: dict[str, Any],
) -> dict[str, Any]:
self._ensure_ready()
context_json = dict(context_json or {})
retry_count = self._resolve_claim_no_retry_count(context_json)
claim = self._find_target_claim(ontology=ontology, context_json=context_json)
is_new_claim = claim is None
before_json = self._serialize_claim(claim) if claim is not None else None
review_action = str(context_json.get("review_action") or "").strip()
attachment_names = self._resolve_attachment_names(context_json)
context_documents = self._resolve_context_documents(context_json)
employee = self._resolve_employee(
ontology=ontology,
@@ -628,6 +687,40 @@ class ExpenseClaimService:
user_id=user_id,
)
)
association_candidate = self._find_association_candidate(
ontology=ontology,
context_json=context_json,
user_id=user_id,
employee=employee,
)
if self._should_defer_multi_document_association(
context_json=context_json,
review_action=review_action,
association_candidate=association_candidate,
context_documents=context_documents,
):
document_count = max(len(context_documents), len(attachment_names), self._resolve_attachment_count(context_json))
return {
"message": (
f"检测到你已有草稿 {association_candidate.claim_no}"
f"当前新上传了 {document_count} 张票据,请先选择关联到现有草稿,或单独建立新的报销单。"
),
"draft_only": False,
"status": "pending_association_decision",
"pending_association_decision": True,
"association_candidate_claim_id": association_candidate.id,
"association_candidate_claim_no": association_candidate.claim_no,
}
claim = self._find_target_claim(
ontology=ontology,
context_json=context_json,
review_action=review_action,
association_candidate=association_candidate,
)
is_new_claim = claim is None
before_json = self._serialize_claim(claim) if claim is not None else None
if is_new_claim:
existing_draft_count = self._count_draft_claims_for_owner(
employee=employee,
@@ -655,7 +748,7 @@ class ExpenseClaimService:
context_json=context_json,
allow_message_fallback=is_new_claim,
)
attachment_count = self._resolve_attachment_count(context_json)
attachment_count = len(attachment_names) or self._resolve_attachment_count(context_json)
final_amount = amount if amount is not None else (claim.amount if claim is not None else Decimal("0.00"))
final_occurred_at = (
@@ -671,6 +764,7 @@ class ExpenseClaimService:
list(claim.risk_flags_json or []) if claim is not None else []
)
try:
if claim is None:
claim = ExpenseClaim(
claim_no=self._generate_claim_no(final_occurred_at),
@@ -724,6 +818,32 @@ class ExpenseClaimService:
claim.risk_flags_json = final_risk_flags
self.db.flush()
if context_documents or attachment_names:
document_specs = self._build_context_item_specs(
context_documents=context_documents,
attachment_names=attachment_names,
occurred_at=final_occurred_at,
expense_type=final_expense_type,
amount=final_amount,
reason=final_reason,
location=final_location,
)
else:
document_specs = []
if document_specs and (is_new_claim or review_action in DOCUMENT_ASSOCIATION_REVIEW_ACTIONS):
if review_action == "link_to_existing_draft" and claim.items:
self._append_document_items(
claim=claim,
item_specs=document_specs,
)
else:
self._replace_claim_items(
claim=claim,
item_specs=document_specs,
)
self._sync_claim_from_items(claim)
else:
self._upsert_primary_item(
claim=claim,
occurred_at=final_occurred_at,
@@ -731,10 +851,31 @@ class ExpenseClaimService:
amount=final_amount,
reason=final_reason,
location=final_location,
attachment_names=self._resolve_attachment_names(context_json),
attachment_names=attachment_names,
)
self._sync_claim_from_items(claim)
self.db.commit()
self.db.refresh(claim)
except IntegrityError as exc:
self.db.rollback()
if (
is_new_claim
and retry_count < MAX_CLAIM_NO_RETRY_ATTEMPTS
and self._is_claim_no_conflict_error(exc)
):
retry_context = dict(context_json)
retry_context["_claim_no_retry_count"] = retry_count + 1
return self.upsert_draft_from_ontology(
run_id=run_id,
user_id=user_id,
message=message,
ontology=ontology,
context_json=retry_context,
)
raise
except Exception:
self.db.rollback()
raise
self.audit_service.log_action(
actor=user_id or claim.employee_name or "anonymous",
@@ -764,10 +905,20 @@ class ExpenseClaimService:
*,
ontology: OntologyParseResult,
context_json: dict[str, Any],
review_action: str = "",
association_candidate: ExpenseClaim | None = None,
) -> ExpenseClaim | None:
if review_action == "create_new_claim_from_documents":
return None
if review_action == "link_to_existing_draft" and association_candidate is not None:
return association_candidate
draft_claim_id = str(context_json.get("draft_claim_id") or "").strip()
if draft_claim_id:
return self.db.get(ExpenseClaim, draft_claim_id)
claim = self.db.get(ExpenseClaim, draft_claim_id)
if claim is not None and str(claim.status or "").strip() == "draft":
return claim
return None
claim_codes = [
item.normalized_value
@@ -777,9 +928,386 @@ class ExpenseClaimService:
if not claim_codes:
return None
stmt = select(ExpenseClaim).where(ExpenseClaim.claim_no.in_(claim_codes)).limit(1)
stmt = (
select(ExpenseClaim)
.where(ExpenseClaim.claim_no.in_(claim_codes))
.where(ExpenseClaim.status == "draft")
.limit(1)
)
return self.db.scalar(stmt)
def _find_association_candidate(
self,
*,
ontology: OntologyParseResult,
context_json: dict[str, Any],
user_id: str | None,
employee: Employee | None,
) -> ExpenseClaim | None:
draft_claim_id = str(context_json.get("draft_claim_id") or "").strip()
if draft_claim_id:
claim = self.db.get(ExpenseClaim, draft_claim_id)
if claim is not None and str(claim.status or "").strip() == "draft":
return claim
owner_filters = self._build_draft_owner_filters(
employee=employee,
user_id=user_id,
)
if not owner_filters:
fallback_name = self._resolve_employee_name(
ontology=ontology,
context_json=context_json,
user_id=user_id,
fallback="",
)
if fallback_name:
owner_filters = [ExpenseClaim.employee_name == fallback_name]
if not owner_filters:
return None
stmt = (
select(ExpenseClaim)
.where(ExpenseClaim.status == "draft")
.where(or_(*owner_filters))
.order_by(ExpenseClaim.updated_at.desc(), ExpenseClaim.created_at.desc())
.limit(1)
)
return self.db.scalar(stmt)
def _should_defer_multi_document_association(
self,
*,
context_json: dict[str, Any],
review_action: str,
association_candidate: ExpenseClaim | None,
context_documents: list[dict[str, Any]],
) -> bool:
if association_candidate is None:
return False
if review_action in DOCUMENT_ASSOCIATION_REVIEW_ACTIONS:
return False
document_count = max(
len(context_documents),
len(self._resolve_attachment_names(context_json)),
self._resolve_attachment_count(context_json),
)
return document_count > 1
def _resolve_context_documents(self, context_json: dict[str, Any]) -> list[dict[str, Any]]:
documents = context_json.get("ocr_documents")
if not isinstance(documents, list):
documents = []
normalized: list[dict[str, Any]] = []
for index, item in enumerate(documents[:10], start=1):
if not isinstance(item, dict):
continue
normalized.append(
{
"index": index,
"filename": str(item.get("filename") or "").strip(),
"summary": str(item.get("summary") or "").strip(),
"text": str(item.get("text") or "").strip(),
"document_type": str(item.get("document_type") or "").strip(),
"scene_code": str(item.get("scene_code") or "").strip(),
"scene_label": str(item.get("scene_label") or "").strip(),
"document_fields": self._normalize_document_fields(item.get("document_fields")),
}
)
overrides = context_json.get("review_document_form_values")
if not isinstance(overrides, list) or not normalized:
return normalized
override_map: dict[tuple[int, str], dict[str, Any]] = {}
for item in overrides:
if not isinstance(item, dict):
continue
filename = str(item.get("filename") or "").strip()
index = int(item.get("index") or 0)
if not filename and index <= 0:
continue
override_map[(index, filename)] = item
for item in normalized:
override = override_map.get((int(item["index"]), str(item["filename"])))
if override is None:
override = override_map.get((int(item["index"]), ""))
if override is None:
continue
summary = str(override.get("summary") or "").strip()
scene_label = str(override.get("scene_label") or "").strip()
fields = override.get("fields")
if summary:
item["summary"] = summary
if scene_label:
item["scene_label"] = scene_label
if isinstance(fields, list):
item["document_fields"] = self._normalize_document_fields(fields)
return normalized
@staticmethod
def _normalize_document_fields(raw_fields: Any) -> list[dict[str, str]]:
if not isinstance(raw_fields, list):
return []
normalized: list[dict[str, str]] = []
for field in raw_fields:
if not isinstance(field, dict):
continue
label = str(field.get("label") or "").strip()
value = str(field.get("value") or "").strip()
key = str(field.get("key") or label or "").strip()
if not label or not value:
continue
normalized.append(
{
"key": key,
"label": label,
"value": value,
}
)
return normalized
def _build_context_item_specs(
self,
*,
context_documents: list[dict[str, Any]],
attachment_names: list[str],
occurred_at: datetime,
expense_type: str,
amount: Decimal,
reason: str,
location: str,
) -> list[dict[str, Any]]:
specs: list[dict[str, Any]] = []
if context_documents:
for document in context_documents:
specs.append(
{
"item_date": self._resolve_document_item_date(document, fallback=occurred_at.date()),
"item_type": self._resolve_document_item_type(document, fallback=expense_type),
"item_reason": reason,
"item_location": location,
"item_amount": self._resolve_document_item_amount(document),
"invoice_id": str(document.get("filename") or "").strip() or None,
}
)
elif attachment_names:
for attachment_name in attachment_names:
specs.append(
{
"item_date": occurred_at.date(),
"item_type": expense_type,
"item_reason": reason,
"item_location": location,
"item_amount": None,
"invoice_id": attachment_name,
}
)
if not specs:
return []
total_recognized = sum(
spec["item_amount"] for spec in specs if isinstance(spec.get("item_amount"), Decimal)
)
missing_specs = [spec for spec in specs if spec.get("item_amount") is None]
if missing_specs:
remaining = (amount - total_recognized).quantize(Decimal("0.01"))
if remaining > Decimal("0.00"):
missing_specs[0]["item_amount"] = remaining
for spec in specs:
if spec.get("item_amount") is None:
spec["item_amount"] = Decimal("0.00")
return specs
def _replace_claim_items(
self,
*,
claim: ExpenseClaim,
item_specs: list[dict[str, Any]],
) -> None:
existing_items = sorted(
list(claim.items),
key=lambda item: (
item.item_date or date.max,
self._normalize_sort_datetime(item.created_at),
),
)
for index, spec in enumerate(item_specs):
item = existing_items[index] if index < len(existing_items) else None
if item is None:
item = ExpenseClaimItem(claim_id=claim.id)
claim.items.append(item)
self.db.add(item)
item.item_date = spec["item_date"]
item.item_type = spec["item_type"]
item.item_reason = spec["item_reason"]
item.item_location = spec["item_location"]
item.item_amount = spec["item_amount"]
item.invoice_id = spec["invoice_id"]
for stale_item in existing_items[len(item_specs) :]:
claim.items.remove(stale_item)
self.db.delete(stale_item)
def _append_document_items(
self,
*,
claim: ExpenseClaim,
item_specs: list[dict[str, Any]],
) -> None:
existing_invoice_ids = {
str(item.invoice_id or "").strip()
for item in claim.items
if str(item.invoice_id or "").strip()
}
for spec in item_specs:
invoice_id = str(spec.get("invoice_id") or "").strip()
if invoice_id and invoice_id in existing_invoice_ids:
continue
claim.items.append(
ExpenseClaimItem(
claim_id=claim.id,
item_date=spec["item_date"],
item_type=spec["item_type"],
item_reason=spec["item_reason"],
item_location=spec["item_location"],
item_amount=spec["item_amount"],
invoice_id=spec["invoice_id"],
)
)
self.db.add(claim.items[-1])
if invoice_id:
existing_invoice_ids.add(invoice_id)
def _resolve_document_item_type(self, document: dict[str, Any], *, fallback: str) -> str:
scene_code = str(document.get("scene_code") or "").strip()
if scene_code in {"travel", "hotel", "transport", "meal", "office", "meeting", "training"}:
return scene_code
document_type = str(document.get("document_type") or "").strip()
if document_type in {"flight_itinerary", "train_ticket"}:
return "travel"
if document_type in {"taxi_receipt", "parking_toll_receipt", "transport_receipt"}:
return "transport"
if document_type == "hotel_invoice":
return "hotel"
if document_type == "meal_receipt":
return "meal"
if document_type == "office_invoice":
return "office"
if document_type == "meeting_invoice":
return "meeting"
if document_type == "training_invoice":
return "training"
scene_label = str(document.get("scene_label") or "").strip()
if "交通" in scene_label:
return "transport"
if "住宿" in scene_label:
return "hotel"
if "" in scene_label:
return "meal"
if "会务" in scene_label or "会议" in scene_label:
return "meeting"
if "培训" in scene_label:
return "training"
return fallback or "other"
def _resolve_document_item_amount(self, document: dict[str, Any]) -> Decimal | None:
for field in list(document.get("document_fields") or []):
if not isinstance(field, dict):
continue
key = str(field.get("key") or "").strip().lower().replace("_", "")
label = str(field.get("label") or "").replace(" ", "")
value = self._parse_document_amount_value(str(field.get("value") or ""))
if value is None:
continue
if key in {
"amount",
"totalamount",
"paymentamount",
"paidamount",
"actualamount",
} or any(
token in label
for token in ("金额", "价税合计", "合计", "总额", "总计", "票价", "支付金额", "实付金额", "实收金额")
):
return value
text = " ".join(
[
str(document.get("summary") or "").strip(),
str(document.get("text") or "").strip(),
]
).strip()
return self._parse_document_amount_value(text)
def _parse_document_amount_value(self, value: str) -> Decimal | None:
raw_value = str(value or "").strip()
if not raw_value:
return None
for pattern in DOCUMENT_AMOUNT_PATTERNS:
match = pattern.search(raw_value)
if not match:
continue
numeric = str(match.group(1) or "").replace(",", ".").strip()
try:
amount = Decimal(numeric).quantize(Decimal("0.01"))
except (InvalidOperation, ValueError):
continue
if amount > Decimal("0.00"):
return amount
return None
def _resolve_document_item_date(self, document: dict[str, Any], *, fallback: date) -> date:
for field in list(document.get("document_fields") or []):
if not isinstance(field, dict):
continue
key = str(field.get("key") or "").strip().lower().replace("_", "")
label = str(field.get("label") or "").replace(" ", "")
value = str(field.get("value") or "").strip()
if not value:
continue
if key in {"date", "time", "issuedat", "invoicedate"} or any(
token in label for token in ("日期", "时间", "开票日期", "发生时间")
):
parsed = self._parse_document_date(value)
if parsed is not None:
return parsed
parsed = self._parse_document_date(
" ".join(
[
str(document.get("summary") or "").strip(),
str(document.get("text") or "").strip(),
]
).strip()
)
return parsed or fallback
@staticmethod
def _parse_document_date(value: str) -> date | None:
match = DOCUMENT_DATE_PATTERN.search(str(value or ""))
if not match:
return None
raw_value = str(match.group(1) or "").strip()
normalized = raw_value.replace("", "-").replace("", "-").replace("", "")
normalized = normalized.replace("/", "-").replace(".", "-")
parts = [part for part in normalized.split("-") if part]
if len(parts) != 3:
return None
try:
return date(int(parts[0]), int(parts[1]), int(parts[2]))
except ValueError:
return None
def _upsert_primary_item(
self,
*,
@@ -816,13 +1344,41 @@ class ExpenseClaimService:
def _generate_claim_no(self, occurred_at: datetime) -> str:
month_code = occurred_at.strftime("%Y%m")
prefix = f"EXP-{month_code}-"
existing = int(
self.db.scalar(
select(func.count()).select_from(ExpenseClaim).where(ExpenseClaim.claim_no.like(f"{prefix}%"))
existing_claim_nos = list(
self.db.scalars(
select(ExpenseClaim.claim_no).where(ExpenseClaim.claim_no.like(f"{prefix}%"))
)
)
max_suffix = 0
for claim_no in existing_claim_nos:
normalized = str(claim_no or "").strip()
if not normalized.startswith(prefix):
continue
suffix = normalized[len(prefix):]
if not suffix.isdigit():
continue
max_suffix = max(max_suffix, int(suffix))
return f"{prefix}{max_suffix + 1:03d}"
@staticmethod
def _resolve_claim_no_retry_count(context_json: dict[str, Any]) -> int:
try:
return max(0, int(context_json.get("_claim_no_retry_count") or 0))
except (TypeError, ValueError):
return 0
@staticmethod
def _is_claim_no_conflict_error(exc: IntegrityError) -> bool:
message = str(exc).lower()
return (
"claim_no" in message
and (
"unique" in message
or "duplicate key" in message
or "ix_expense_claims_claim_no" in message
or "expense_claims.claim_no" in message
)
or 0
)
return f"{prefix}{existing + 1:03d}"
def _count_draft_claims_for_owner(
self,
@@ -1011,6 +1567,13 @@ class ExpenseClaimService:
if value:
return value
explicit_text = context_json.get("user_input_text")
if isinstance(explicit_text, str):
normalized_explicit_text = explicit_text.strip()
if normalized_explicit_text:
return normalized_explicit_text[:500]
return None
request_context = context_json.get("request_context")
if (
isinstance(request_context, dict)
@@ -1022,7 +1585,12 @@ class ExpenseClaimService:
return value
if not allow_message_fallback:
return None
return str(message or "").strip()[:500] or None
normalized_message = str(message or "").strip()
compact_message = re.sub(r"\s+", "", normalized_message)
if compact_message.startswith(SYSTEM_GENERATED_REASON_PREFIXES):
return None
return normalized_message[:500] or None
@staticmethod
def _resolve_location(*, message: str, context_json: dict[str, Any]) -> str | None:
@@ -1210,6 +1778,74 @@ class ExpenseClaimService:
return {}
return payload if isinstance(payload, dict) else {}
def _build_attachment_preview_meta(
self,
*,
file_path: Path,
media_type: str,
ocr_document: Any | None,
) -> dict[str, Any]:
filename = file_path.name
storage_key = self._to_attachment_storage_key(file_path)
preview_kind = self._resolve_preview_kind(media_type, filename)
preview_data_url = str(getattr(ocr_document, "preview_data_url", "") or "").strip()
preview_source_kind = str(getattr(ocr_document, "preview_kind", "") or "").strip()
if preview_source_kind == "image" and preview_data_url:
preview_asset = self._write_preview_asset_from_data_url(
attachment_dir=file_path.parent,
original_filename=filename,
preview_data_url=preview_data_url,
)
if preview_asset is not None:
preview_path, preview_media_type, preview_file_name = preview_asset
return {
"previewable": True,
"preview_kind": "image",
"preview_storage_key": self._to_attachment_storage_key(preview_path),
"preview_media_type": preview_media_type,
"preview_file_name": preview_file_name,
}
if preview_kind:
return {
"previewable": True,
"preview_kind": preview_kind,
"preview_storage_key": storage_key,
"preview_media_type": media_type,
"preview_file_name": filename,
}
return {
"previewable": False,
"preview_kind": "",
"preview_storage_key": "",
"preview_media_type": "",
"preview_file_name": "",
}
def _resolve_item_attachment_preview_content(self, item: ExpenseClaimItem) -> tuple[Path, str, str]:
file_path, media_type, filename = self._resolve_item_attachment_content(item)
metadata = self._read_attachment_meta(file_path)
preview_storage_key = str(metadata.get("preview_storage_key") or "").strip()
preview_file_name = str(metadata.get("preview_file_name") or "").strip()
preview_media_type = str(metadata.get("preview_media_type") or "").strip()
if preview_storage_key:
preview_path = self._resolve_attachment_path(preview_storage_key)
if preview_path is not None and preview_path.exists():
resolved_name = preview_file_name or preview_path.name
resolved_media_type = self._resolve_attachment_media_type(
resolved_name,
fallback=preview_media_type,
)
return preview_path, resolved_media_type, resolved_name
if self._is_previewable_media_type(media_type, filename):
return file_path, media_type, filename
raise FileNotFoundError("Attachment preview not found")
def _build_attachment_payload(self, item: ExpenseClaimItem) -> dict[str, Any]:
file_path, media_type, filename = self._resolve_item_attachment_content(item)
metadata = self._read_attachment_meta(file_path)
@@ -1233,18 +1869,71 @@ class ExpenseClaimService:
if not isinstance(requirement_check, dict):
requirement_check = None
preview_kind = str(metadata.get("preview_kind") or "").strip()
previewable = bool(metadata.get("previewable", self._is_previewable_media_type(media_type, filename)))
preview_url = self._build_attachment_preview_client_path(item.claim_id, item.id) if previewable else ""
return {
"file_name": str(metadata.get("file_name") or filename),
"storage_key": str(item.invoice_id or ""),
"media_type": str(metadata.get("media_type") or media_type),
"size_bytes": int(metadata.get("size_bytes") or file_path.stat().st_size),
"uploaded_at": uploaded_at,
"previewable": bool(metadata.get("previewable", self._is_previewable_media_type(media_type, filename))),
"previewable": previewable,
"preview_kind": preview_kind or self._resolve_preview_kind(media_type, filename),
"preview_url": preview_url,
"analysis": analysis,
"document_info": document_info,
"requirement_check": requirement_check,
}
@staticmethod
def _resolve_preview_kind(media_type: str | None, filename: str) -> str:
resolved = str(media_type or "").strip() or (mimetypes.guess_type(filename)[0] or "")
if resolved.startswith("image/"):
return "image"
if resolved == "application/pdf":
return "pdf"
return ""
@staticmethod
def _decode_data_url(payload: str) -> tuple[str, bytes] | None:
normalized = str(payload or "").strip()
matched = re.match(r"^data:(?P<media>[\w.+-]+/[\w.+-]+);base64,(?P<body>.+)$", normalized, flags=re.DOTALL)
if not matched:
return None
try:
content = base64.b64decode(matched.group("body"), validate=True)
except (binascii.Error, ValueError):
return None
return matched.group("media"), content
def _write_preview_asset_from_data_url(
self,
*,
attachment_dir: Path,
original_filename: str,
preview_data_url: str,
) -> tuple[Path, str, str] | None:
decoded = self._decode_data_url(preview_data_url)
if decoded is None:
return None
preview_media_type, preview_content = decoded
suffix = mimetypes.guess_extension(preview_media_type) or ".bin"
preview_name = f"{Path(original_filename).stem}.preview{suffix}"
preview_path = attachment_dir / preview_name
preview_path.write_bytes(preview_content)
return preview_path, preview_media_type, preview_name
@staticmethod
def _build_attachment_preview_client_path(claim_id: str, item_id: str) -> str:
return (
"/reimbursements/claims/"
f"{quote(str(claim_id or '').strip(), safe='')}"
f"/items/{quote(str(item_id or '').strip(), safe='')}/attachment/preview"
)
@staticmethod
def _resolve_attachment_media_type(filename: str, *, fallback: str | None = None) -> str:
guessed = mimetypes.guess_type(filename)[0]

View File

@@ -11,9 +11,11 @@ from app.api.deps import CurrentUserContext
from app.db.base import Base
from app.models.employee import Employee
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
from app.schemas.ontology import OntologyParseRequest
from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead
from app.schemas.reimbursement import ExpenseClaimItemCreate, ExpenseClaimItemUpdate
from app.services.expense_claims import ExpenseClaimService
from app.services.ontology import SemanticOntologyService
from app.services.ocr import OcrService
@@ -97,6 +99,347 @@ def test_resolve_expense_type_maps_office_supplies_review_value_to_office() -> N
assert expense_type == "office"
def test_upsert_draft_from_ontology_defers_multi_document_association_choice() -> None:
user_id = "zhangsan@example.com"
with build_session() as db:
employee = Employee(
employee_no="E5001",
name="张三",
email=user_id,
)
db.add(employee)
db.flush()
existing_claim = ExpenseClaim(
claim_no="EXP-202605-010",
employee_id=employee.id,
employee_name="张三",
department_name="市场部",
project_code=None,
expense_type="transport",
reason="原有交通报销",
location="深圳",
amount=Decimal("20.00"),
currency="CNY",
invoice_count=1,
occurred_at=datetime(2026, 5, 13, tzinfo=UTC),
status="draft",
approval_stage="待提交",
risk_flags_json=[],
)
existing_claim.items = [
ExpenseClaimItem(
claim_id=existing_claim.id,
item_date=date(2026, 5, 13),
item_type="transport",
item_reason="原有交通报销",
item_location="深圳",
item_amount=Decimal("20.00"),
invoice_id="old-trip.png",
)
]
db.add(existing_claim)
db.commit()
ontology = SemanticOntologyService(db).parse(
OntologyParseRequest(
query="我上传了两张交通票据,帮我生成报销草稿",
user_id=user_id,
)
)
service = ExpenseClaimService(db)
result = service.upsert_draft_from_ontology(
run_id=ontology.run_id,
user_id=user_id,
message="我上传了两张交通票据,帮我生成报销草稿",
ontology=ontology,
context_json={
"name": "张三",
"attachment_names": ["didi-trip.png", "parking-ticket.jpg"],
"attachment_count": 2,
"draft_claim_id": existing_claim.id,
"ocr_documents": [
{
"filename": "didi-trip.png",
"summary": "滴滴出行 支付金额 32 元",
"text": "滴滴出行 支付金额 32 元",
},
{
"filename": "parking-ticket.jpg",
"summary": "停车费 合计 18 元",
"text": "停车费 合计 18 元",
},
],
},
)
db.refresh(existing_claim)
assert result["pending_association_decision"] is True
assert result["association_candidate_claim_id"] == existing_claim.id
assert existing_claim.invoice_count == 1
assert len(existing_claim.items) == 1
assert existing_claim.items[0].invoice_id == "old-trip.png"
def test_upsert_draft_from_ontology_keeps_reason_missing_for_attachment_only_upload() -> None:
user_id = "wangwu@example.com"
with build_session() as db:
employee = Employee(
employee_no="E5003",
name="王五",
email=user_id,
)
db.add(employee)
db.commit()
ontology = SemanticOntologyService(db).parse(
OntologyParseRequest(
query="我上传了 1 份票据,请结合附件名称给出报销建议并尽量生成草稿。",
user_id=user_id,
)
)
service = ExpenseClaimService(db)
result = service.upsert_draft_from_ontology(
run_id=ontology.run_id,
user_id=user_id,
message="我上传了 1 份票据,请结合附件名称给出报销建议并尽量生成草稿。\n附件名称didi-trip.png",
ontology=ontology,
context_json={
"name": "王五",
"user_input_text": "",
"attachment_names": ["didi-trip.png"],
"attachment_count": 1,
"ocr_documents": [
{
"filename": "didi-trip.png",
"summary": "滴滴出行 支付金额 32 元",
"text": "滴滴出行 支付金额 32 元",
"document_type": "taxi_receipt",
"scene_code": "transport",
}
],
},
)
claim = db.get(ExpenseClaim, result["claim_id"])
assert claim is not None
assert claim.reason == "待补充"
def test_upsert_draft_from_ontology_supports_link_or_create_for_multi_documents() -> None:
user_id = "lisi@example.com"
with build_session() as db:
employee = Employee(
employee_no="E5002",
name="李四",
email=user_id,
)
db.add(employee)
db.flush()
existing_claim = ExpenseClaim(
claim_no="EXP-202605-011",
employee_id=employee.id,
employee_name="李四",
department_name="销售部",
project_code=None,
expense_type="transport",
reason="原有交通报销",
location="上海",
amount=Decimal("20.00"),
currency="CNY",
invoice_count=1,
occurred_at=datetime(2026, 5, 13, tzinfo=UTC),
status="draft",
approval_stage="待提交",
risk_flags_json=[],
)
existing_claim.items = [
ExpenseClaimItem(
claim_id=existing_claim.id,
item_date=date(2026, 5, 13),
item_type="transport",
item_reason="原有交通报销",
item_location="上海",
item_amount=Decimal("20.00"),
invoice_id="existing.png",
)
]
db.add(existing_claim)
db.commit()
ontology = SemanticOntologyService(db).parse(
OntologyParseRequest(
query="我上传了两张交通票据,帮我生成报销草稿",
user_id=user_id,
)
)
service = ExpenseClaimService(db)
context_json = {
"name": "李四",
"attachment_names": ["didi-trip.png", "parking-ticket.jpg"],
"attachment_count": 2,
"draft_claim_id": existing_claim.id,
"ocr_documents": [
{
"filename": "didi-trip.png",
"summary": "滴滴出行",
"text": "滴滴出行 支付金额 32.50 元",
"document_type": "taxi_receipt",
"scene_code": "transport",
"document_fields": [{"key": "amount", "label": "支付金额", "value": "32.50"}],
},
{
"filename": "parking-ticket.jpg",
"summary": "停车票",
"text": "停车费 合计 18 元",
"document_type": "parking_toll_receipt",
"scene_code": "transport",
"document_fields": [{"key": "total_amount", "label": "合计金额", "value": "18"}],
},
],
}
link_result = service.upsert_draft_from_ontology(
run_id=ontology.run_id,
user_id=user_id,
message="把这两张票据关联到已有草稿",
ontology=ontology,
context_json={
**context_json,
"review_action": "link_to_existing_draft",
},
)
db.refresh(existing_claim)
assert link_result["claim_id"] == existing_claim.id
assert existing_claim.invoice_count == 3
assert len(existing_claim.items) == 3
assert float(existing_claim.amount) == 70.5
create_result = service.upsert_draft_from_ontology(
run_id=f"{ontology.run_id}-new",
user_id=user_id,
message="单独新建一张报销单",
ontology=ontology,
context_json={
**context_json,
"review_action": "create_new_claim_from_documents",
},
)
assert create_result["claim_id"] != existing_claim.id
new_claim = db.get(ExpenseClaim, create_result["claim_id"])
assert new_claim is not None
assert new_claim.invoice_count == 2
assert len(new_claim.items) == 2
assert float(new_claim.amount) == 50.5
def test_generate_claim_no_uses_max_suffix_instead_of_count() -> None:
with build_session() as db:
db.add_all(
[
ExpenseClaim(
claim_no="EXP-202605-001",
employee_name="张三",
department_name="市场部",
project_code=None,
expense_type="transport",
reason="交通报销",
location="深圳",
amount=Decimal("10.00"),
currency="CNY",
invoice_count=1,
occurred_at=datetime(2026, 5, 10, tzinfo=UTC),
status="draft",
approval_stage="待提交",
risk_flags_json=[],
),
ExpenseClaim(
claim_no="EXP-202605-003",
employee_name="李四",
department_name="销售部",
project_code=None,
expense_type="transport",
reason="交通报销",
location="上海",
amount=Decimal("20.00"),
currency="CNY",
invoice_count=1,
occurred_at=datetime(2026, 5, 11, tzinfo=UTC),
status="submitted",
approval_stage="审批中",
risk_flags_json=[],
),
]
)
db.commit()
service = ExpenseClaimService(db)
assert service._generate_claim_no(datetime(2026, 5, 14, tzinfo=UTC)) == "EXP-202605-004"
def test_upsert_draft_from_ontology_retries_claim_no_conflict() -> None:
user_id = "zhaoliu-claimno@example.com"
with build_session() as db:
employee = Employee(
employee_no="E5006",
name="赵六",
email=user_id,
)
db.add(employee)
db.flush()
db.add(
ExpenseClaim(
claim_no="EXP-202605-004",
employee_name="历史单据",
department_name="财务部",
project_code=None,
expense_type="other",
reason="历史草稿",
location="北京",
amount=Decimal("0.00"),
currency="CNY",
invoice_count=0,
occurred_at=datetime(2026, 5, 12, tzinfo=UTC),
status="submitted",
approval_stage="审批中",
risk_flags_json=[],
)
)
db.commit()
ontology = SemanticOntologyService(db).parse(
OntologyParseRequest(
query="帮我生成报销草稿,我昨天交通费 13.4 元",
user_id=user_id,
)
)
service = ExpenseClaimService(db)
generated_claim_nos = iter(["EXP-202605-004", "EXP-202605-005"])
service._generate_claim_no = lambda occurred_at: next(generated_claim_nos)
result = service.upsert_draft_from_ontology(
run_id=ontology.run_id,
user_id=user_id,
message="帮我生成报销草稿,我昨天交通费 13.4 元",
ontology=ontology,
context_json={
"name": "赵六",
"user_input_text": "帮我生成报销草稿,我昨天交通费 13.4 元",
},
)
created_claim = db.get(ExpenseClaim, result["claim_id"])
assert created_claim is not None
assert created_claim.claim_no == "EXP-202605-005"
assert result["claim_no"] == "EXP-202605-005"
def test_create_claim_item_adds_blank_draft_row_without_forcing_attachment() -> None:
current_user = CurrentUserContext(
username="emp-1",
@@ -186,6 +529,10 @@ def test_update_claim_item_reanalyzes_existing_attachment(monkeypatch, tmp_path)
current_user=current_user,
)
assert uploaded_meta is not None
assert uploaded_meta["preview_kind"] == "image"
assert uploaded_meta["preview_url"].endswith(
f"/reimbursements/claims/{claim.id}/items/{claim.items[0].id}/attachment/preview"
)
assert uploaded_meta["analysis"]["severity"] == "pass"
assert uploaded_meta["document_info"]["document_type"] == "office_invoice"
assert uploaded_meta["requirement_check"]["matches"] is True

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import base64
from collections.abc import Generator
from datetime import UTC, date, datetime
from decimal import Decimal
@@ -165,9 +166,18 @@ def test_claim_item_attachment_upload_preview_and_delete(monkeypatch, tmp_path)
assert meta_response.status_code == 200
meta_payload = meta_response.json()
assert meta_payload["media_type"] == "image/png"
assert meta_payload["preview_kind"] == "image"
assert meta_payload["preview_url"].endswith(f"/reimbursements/claims/{claim_id}/items/{item_id}/attachment/preview")
assert meta_payload["analysis"]["headline"]
assert meta_payload["document_info"]["fields"][0]["label"] == "金额"
preview_response = client.get(
f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment/preview",
headers=headers,
)
assert preview_response.status_code == 200
assert preview_response.content == file_bytes
content_response = client.get(
f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment",
headers=headers,
@@ -279,6 +289,67 @@ def test_claim_item_attachment_upload_flags_non_invoice_image_as_high_risk(monke
assert any("附件内容" in point for point in analysis["points"])
def test_claim_item_pdf_attachment_preview_returns_generated_image(monkeypatch, tmp_path) -> None:
preview_bytes = b"fake-preview-png"
preview_data_url = f"data:image/png;base64,{base64.b64encode(preview_bytes).decode('ascii')}"
def fake_recognize(
self,
files: list[tuple[str, bytes, str | None]],
) -> OcrRecognizeBatchRead:
return OcrRecognizeBatchRead(
total_file_count=1,
success_count=1,
documents=[
OcrRecognizeDocumentRead(
filename="invoice.pdf",
media_type="application/pdf",
text="滴滴出行电子发票 金额13.4元",
summary="识别到交通票据,金额 13.4 元。",
avg_score=0.96,
line_count=1,
page_count=1,
document_type="taxi_receipt",
document_type_label="出租车/网约车票据",
scene_code="transport",
scene_label="交通票据",
preview_kind="image",
preview_data_url=preview_data_url,
warnings=[],
)
],
)
monkeypatch.setattr(OcrService, "recognize_files", fake_recognize)
monkeypatch.setattr(ExpenseClaimService, "_get_attachment_storage_root", lambda self: tmp_path)
client, session_factory = build_client()
with session_factory() as db:
claim, item = seed_claim(db)
claim_id = claim.id
item_id = item.id
headers = {"x-auth-username": "emp-1", "x-auth-name": "Zhang San"}
upload_response = client.post(
f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment",
headers=headers,
files=[("file", ("invoice.pdf", b"%PDF-1.4 fake", "application/pdf"))],
)
assert upload_response.status_code == 200
meta_payload = upload_response.json()["attachment"]
assert meta_payload["preview_kind"] == "image"
assert meta_payload["preview_url"].endswith(f"/reimbursements/claims/{claim_id}/items/{item_id}/attachment/preview")
preview_response = client.get(
f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment/preview",
headers=headers,
)
assert preview_response.status_code == 200
assert preview_response.headers["content-type"].startswith("image/png")
assert preview_response.content == preview_bytes
def test_claim_item_delete_removes_item_and_attachment(monkeypatch, tmp_path) -> None:
def fake_recognize(
self,