feat(server): 重构报销单服务,优化费用报销流程和数据校验逻辑,包含schema定义和服务实现

This commit is contained in:
caoxiaozhu
2026-05-14 15:42:45 +00:00
parent ad16358e71
commit e21f0d82e9
4 changed files with 1187 additions and 78 deletions

View File

@@ -87,6 +87,8 @@ class ExpenseClaimAttachmentRead(BaseModel):
size_bytes: int size_bytes: int
uploaded_at: datetime | None = None uploaded_at: datetime | None = None
previewable: bool = True previewable: bool = True
preview_kind: str = ""
preview_url: str = ""
analysis: ExpenseClaimAttachmentAnalysisRead | None = None analysis: ExpenseClaimAttachmentAnalysisRead | None = None
document_info: ExpenseClaimAttachmentDocumentInfoRead | None = None document_info: ExpenseClaimAttachmentDocumentInfoRead | None = None
requirement_check: ExpenseClaimAttachmentRequirementRead | None = None requirement_check: ExpenseClaimAttachmentRequirementRead | None = None

View File

@@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
import base64
import binascii
import json import json
import mimetypes import mimetypes
import re import re
@@ -9,8 +11,10 @@ from decimal import Decimal, InvalidOperation
from pathlib import Path from pathlib import Path
from types import SimpleNamespace from types import SimpleNamespace
from typing import Any from typing import Any
from urllib.parse import quote
from sqlalchemy import func, or_, select from sqlalchemy import func, or_, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, selectinload from sqlalchemy.orm import Session, selectinload
from app.api.deps import CurrentUserContext from app.api.deps import CurrentUserContext
@@ -102,6 +106,32 @@ DOCUMENT_SCENE_LABELS = {
"other": "其他票据", "other": "其他票据",
} }
DOCUMENT_ASSOCIATION_REVIEW_ACTIONS = {
"link_to_existing_draft",
"create_new_claim_from_documents",
}
MAX_CLAIM_NO_RETRY_ATTEMPTS = 3
DOCUMENT_AMOUNT_PATTERNS = (
re.compile(
r"(?:价税合计|合计金额|费用合计|订单(?:总)?金额|支付(?:金额)?|实付(?:金额)?|实收(?:金额)?|总(?:额|计|价)|票价|金额|车费|消费金额)"
r"[:\s¥¥人民币]*([0-9]+(?:[.,][0-9]{1,2})?)"
),
re.compile(r"[¥¥]\s*([0-9]+(?:[.,][0-9]{1,2})?)"),
re.compile(r"([0-9]+(?:[.,][0-9]{1,2})?)\s*元"),
)
DOCUMENT_DATE_PATTERN = re.compile(r"((?:20\d{2}|19\d{2})[-/年.](?:1[0-2]|0?[1-9])[-/月.](?:3[01]|[12]\d|0?[1-9])日?)")
SYSTEM_GENERATED_REASON_PREFIXES = (
"我上传了",
"请按当前已识别信息",
"请把当前上传的票据",
"请基于当前上传的多张票据",
"我已核对右侧识别结果",
"请同步修正逐票据识别结果",
"我已修改识别信息",
"查看报销草稿",
"请解释一下当前这笔报销的合规风险和待补充项",
)
class ExpenseClaimService: class ExpenseClaimService:
def __init__(self, db: Session) -> None: def __init__(self, db: Session) -> None:
@@ -314,6 +344,10 @@ class ExpenseClaimService:
file_path = attachment_dir / normalized_name file_path = attachment_dir / normalized_name
file_path.write_bytes(content) file_path.write_bytes(content)
resolved_media_type = self._resolve_attachment_media_type(
normalized_name,
fallback=media_type,
)
attachment_analysis = self._build_fallback_attachment_analysis( attachment_analysis = self._build_fallback_attachment_analysis(
media_type=media_type, media_type=media_type,
@@ -353,16 +387,22 @@ class ExpenseClaimService:
) )
item.invoice_id = self._to_attachment_storage_key(file_path) item.invoice_id = self._to_attachment_storage_key(file_path)
preview_meta = self._build_attachment_preview_meta(
file_path=file_path,
media_type=resolved_media_type,
ocr_document=ocr_document,
)
meta = { meta = {
"file_name": normalized_name, "file_name": normalized_name,
"storage_key": item.invoice_id, "storage_key": item.invoice_id,
"media_type": self._resolve_attachment_media_type( "media_type": resolved_media_type,
normalized_name,
fallback=media_type,
),
"size_bytes": len(content), "size_bytes": len(content),
"uploaded_at": datetime.now(UTC).isoformat(), "uploaded_at": datetime.now(UTC).isoformat(),
"previewable": self._is_previewable_media_type(media_type, normalized_name), "previewable": bool(preview_meta["previewable"]),
"preview_kind": str(preview_meta["preview_kind"]),
"preview_storage_key": str(preview_meta["preview_storage_key"]),
"preview_media_type": str(preview_meta["preview_media_type"]),
"preview_file_name": str(preview_meta["preview_file_name"]),
"analysis": attachment_analysis, "analysis": attachment_analysis,
"document_info": document_info, "document_info": document_info,
"requirement_check": requirement_check, "requirement_check": requirement_check,
@@ -438,6 +478,23 @@ class ExpenseClaimService:
return self._resolve_item_attachment_content(item) return self._resolve_item_attachment_content(item)
def get_claim_item_attachment_preview_content(
self,
*,
claim_id: str,
item_id: str,
current_user: CurrentUserContext,
) -> tuple[Path, str, str] | None:
claim, item = self._get_claim_item_or_raise(
claim_id=claim_id,
item_id=item_id,
current_user=current_user,
)
if claim is None:
return None
return self._resolve_item_attachment_preview_content(item)
def delete_claim_item_attachment( def delete_claim_item_attachment(
self, self,
*, *,
@@ -609,10 +666,12 @@ class ExpenseClaimService:
context_json: dict[str, Any], context_json: dict[str, Any],
) -> dict[str, Any]: ) -> dict[str, Any]:
self._ensure_ready() self._ensure_ready()
context_json = dict(context_json or {})
retry_count = self._resolve_claim_no_retry_count(context_json)
claim = self._find_target_claim(ontology=ontology, context_json=context_json) review_action = str(context_json.get("review_action") or "").strip()
is_new_claim = claim is None attachment_names = self._resolve_attachment_names(context_json)
before_json = self._serialize_claim(claim) if claim is not None else None context_documents = self._resolve_context_documents(context_json)
employee = self._resolve_employee( employee = self._resolve_employee(
ontology=ontology, ontology=ontology,
@@ -628,6 +687,40 @@ class ExpenseClaimService:
user_id=user_id, user_id=user_id,
) )
) )
association_candidate = self._find_association_candidate(
ontology=ontology,
context_json=context_json,
user_id=user_id,
employee=employee,
)
if self._should_defer_multi_document_association(
context_json=context_json,
review_action=review_action,
association_candidate=association_candidate,
context_documents=context_documents,
):
document_count = max(len(context_documents), len(attachment_names), self._resolve_attachment_count(context_json))
return {
"message": (
f"检测到你已有草稿 {association_candidate.claim_no}"
f"当前新上传了 {document_count} 张票据,请先选择关联到现有草稿,或单独建立新的报销单。"
),
"draft_only": False,
"status": "pending_association_decision",
"pending_association_decision": True,
"association_candidate_claim_id": association_candidate.id,
"association_candidate_claim_no": association_candidate.claim_no,
}
claim = self._find_target_claim(
ontology=ontology,
context_json=context_json,
review_action=review_action,
association_candidate=association_candidate,
)
is_new_claim = claim is None
before_json = self._serialize_claim(claim) if claim is not None else None
if is_new_claim: if is_new_claim:
existing_draft_count = self._count_draft_claims_for_owner( existing_draft_count = self._count_draft_claims_for_owner(
employee=employee, employee=employee,
@@ -655,7 +748,7 @@ class ExpenseClaimService:
context_json=context_json, context_json=context_json,
allow_message_fallback=is_new_claim, allow_message_fallback=is_new_claim,
) )
attachment_count = self._resolve_attachment_count(context_json) attachment_count = len(attachment_names) or self._resolve_attachment_count(context_json)
final_amount = amount if amount is not None else (claim.amount if claim is not None else Decimal("0.00")) final_amount = amount if amount is not None else (claim.amount if claim is not None else Decimal("0.00"))
final_occurred_at = ( final_occurred_at = (
@@ -671,70 +764,118 @@ class ExpenseClaimService:
list(claim.risk_flags_json or []) if claim is not None else [] list(claim.risk_flags_json or []) if claim is not None else []
) )
if claim is None: try:
claim = ExpenseClaim( if claim is None:
claim_no=self._generate_claim_no(final_occurred_at), claim = ExpenseClaim(
employee_id=employee.id if employee is not None else None, claim_no=self._generate_claim_no(final_occurred_at),
employee_name=draft_owner_name, employee_id=employee.id if employee is not None else None,
department_id=employee.organization_unit_id if employee is not None else None, employee_name=draft_owner_name,
department_name=self._resolve_department_name( department_id=employee.organization_unit_id if employee is not None else None,
department_name=self._resolve_department_name(
employee=employee,
context_json=context_json,
),
project_code=self._resolve_project_code(ontology.entities),
expense_type=final_expense_type,
reason=final_reason,
location=final_location,
amount=final_amount,
currency="CNY",
invoice_count=final_attachment_count,
occurred_at=final_occurred_at,
status="draft",
approval_stage="待提交",
risk_flags_json=final_risk_flags,
)
self.db.add(claim)
else:
claim.employee_id = employee.id if employee is not None else claim.employee_id
claim.employee_name = (
employee.name
if employee is not None
else self._resolve_employee_name(
ontology=ontology,
context_json=context_json,
user_id=user_id,
fallback=claim.employee_name,
)
)
claim.department_id = employee.organization_unit_id if employee is not None else claim.department_id
claim.department_name = self._resolve_department_name(
employee=employee, employee=employee,
context_json=context_json, context_json=context_json,
), fallback=claim.department_name,
project_code=self._resolve_project_code(ontology.entities),
expense_type=final_expense_type,
reason=final_reason,
location=final_location,
amount=final_amount,
currency="CNY",
invoice_count=final_attachment_count,
occurred_at=final_occurred_at,
status="draft",
approval_stage="待提交",
risk_flags_json=final_risk_flags,
)
self.db.add(claim)
else:
claim.employee_id = employee.id if employee is not None else claim.employee_id
claim.employee_name = (
employee.name
if employee is not None
else self._resolve_employee_name(
ontology=ontology,
context_json=context_json,
user_id=user_id,
fallback=claim.employee_name,
) )
) claim.project_code = self._resolve_project_code(ontology.entities) or claim.project_code
claim.department_id = employee.organization_unit_id if employee is not None else claim.department_id claim.expense_type = final_expense_type
claim.department_name = self._resolve_department_name( claim.reason = final_reason
employee=employee, claim.location = final_location
context_json=context_json, claim.amount = final_amount
fallback=claim.department_name, claim.invoice_count = final_attachment_count
) claim.occurred_at = final_occurred_at
claim.project_code = self._resolve_project_code(ontology.entities) or claim.project_code claim.status = "draft"
claim.expense_type = final_expense_type claim.approval_stage = "待提交"
claim.reason = final_reason claim.risk_flags_json = final_risk_flags
claim.location = final_location
claim.amount = final_amount
claim.invoice_count = final_attachment_count
claim.occurred_at = final_occurred_at
claim.status = "draft"
claim.approval_stage = "待提交"
claim.risk_flags_json = final_risk_flags
self.db.flush() self.db.flush()
self._upsert_primary_item( if context_documents or attachment_names:
claim=claim, document_specs = self._build_context_item_specs(
occurred_at=final_occurred_at, context_documents=context_documents,
expense_type=final_expense_type, attachment_names=attachment_names,
amount=final_amount, occurred_at=final_occurred_at,
reason=final_reason, expense_type=final_expense_type,
location=final_location, amount=final_amount,
attachment_names=self._resolve_attachment_names(context_json), reason=final_reason,
) location=final_location,
self.db.commit() )
self.db.refresh(claim) else:
document_specs = []
if document_specs and (is_new_claim or review_action in DOCUMENT_ASSOCIATION_REVIEW_ACTIONS):
if review_action == "link_to_existing_draft" and claim.items:
self._append_document_items(
claim=claim,
item_specs=document_specs,
)
else:
self._replace_claim_items(
claim=claim,
item_specs=document_specs,
)
self._sync_claim_from_items(claim)
else:
self._upsert_primary_item(
claim=claim,
occurred_at=final_occurred_at,
expense_type=final_expense_type,
amount=final_amount,
reason=final_reason,
location=final_location,
attachment_names=attachment_names,
)
self._sync_claim_from_items(claim)
self.db.commit()
self.db.refresh(claim)
except IntegrityError as exc:
self.db.rollback()
if (
is_new_claim
and retry_count < MAX_CLAIM_NO_RETRY_ATTEMPTS
and self._is_claim_no_conflict_error(exc)
):
retry_context = dict(context_json)
retry_context["_claim_no_retry_count"] = retry_count + 1
return self.upsert_draft_from_ontology(
run_id=run_id,
user_id=user_id,
message=message,
ontology=ontology,
context_json=retry_context,
)
raise
except Exception:
self.db.rollback()
raise
self.audit_service.log_action( self.audit_service.log_action(
actor=user_id or claim.employee_name or "anonymous", actor=user_id or claim.employee_name or "anonymous",
@@ -764,10 +905,20 @@ class ExpenseClaimService:
*, *,
ontology: OntologyParseResult, ontology: OntologyParseResult,
context_json: dict[str, Any], context_json: dict[str, Any],
review_action: str = "",
association_candidate: ExpenseClaim | None = None,
) -> ExpenseClaim | None: ) -> ExpenseClaim | None:
if review_action == "create_new_claim_from_documents":
return None
if review_action == "link_to_existing_draft" and association_candidate is not None:
return association_candidate
draft_claim_id = str(context_json.get("draft_claim_id") or "").strip() draft_claim_id = str(context_json.get("draft_claim_id") or "").strip()
if draft_claim_id: if draft_claim_id:
return self.db.get(ExpenseClaim, draft_claim_id) claim = self.db.get(ExpenseClaim, draft_claim_id)
if claim is not None and str(claim.status or "").strip() == "draft":
return claim
return None
claim_codes = [ claim_codes = [
item.normalized_value item.normalized_value
@@ -777,9 +928,386 @@ class ExpenseClaimService:
if not claim_codes: if not claim_codes:
return None return None
stmt = select(ExpenseClaim).where(ExpenseClaim.claim_no.in_(claim_codes)).limit(1) stmt = (
select(ExpenseClaim)
.where(ExpenseClaim.claim_no.in_(claim_codes))
.where(ExpenseClaim.status == "draft")
.limit(1)
)
return self.db.scalar(stmt) return self.db.scalar(stmt)
def _find_association_candidate(
self,
*,
ontology: OntologyParseResult,
context_json: dict[str, Any],
user_id: str | None,
employee: Employee | None,
) -> ExpenseClaim | None:
draft_claim_id = str(context_json.get("draft_claim_id") or "").strip()
if draft_claim_id:
claim = self.db.get(ExpenseClaim, draft_claim_id)
if claim is not None and str(claim.status or "").strip() == "draft":
return claim
owner_filters = self._build_draft_owner_filters(
employee=employee,
user_id=user_id,
)
if not owner_filters:
fallback_name = self._resolve_employee_name(
ontology=ontology,
context_json=context_json,
user_id=user_id,
fallback="",
)
if fallback_name:
owner_filters = [ExpenseClaim.employee_name == fallback_name]
if not owner_filters:
return None
stmt = (
select(ExpenseClaim)
.where(ExpenseClaim.status == "draft")
.where(or_(*owner_filters))
.order_by(ExpenseClaim.updated_at.desc(), ExpenseClaim.created_at.desc())
.limit(1)
)
return self.db.scalar(stmt)
def _should_defer_multi_document_association(
self,
*,
context_json: dict[str, Any],
review_action: str,
association_candidate: ExpenseClaim | None,
context_documents: list[dict[str, Any]],
) -> bool:
if association_candidate is None:
return False
if review_action in DOCUMENT_ASSOCIATION_REVIEW_ACTIONS:
return False
document_count = max(
len(context_documents),
len(self._resolve_attachment_names(context_json)),
self._resolve_attachment_count(context_json),
)
return document_count > 1
def _resolve_context_documents(self, context_json: dict[str, Any]) -> list[dict[str, Any]]:
documents = context_json.get("ocr_documents")
if not isinstance(documents, list):
documents = []
normalized: list[dict[str, Any]] = []
for index, item in enumerate(documents[:10], start=1):
if not isinstance(item, dict):
continue
normalized.append(
{
"index": index,
"filename": str(item.get("filename") or "").strip(),
"summary": str(item.get("summary") or "").strip(),
"text": str(item.get("text") or "").strip(),
"document_type": str(item.get("document_type") or "").strip(),
"scene_code": str(item.get("scene_code") or "").strip(),
"scene_label": str(item.get("scene_label") or "").strip(),
"document_fields": self._normalize_document_fields(item.get("document_fields")),
}
)
overrides = context_json.get("review_document_form_values")
if not isinstance(overrides, list) or not normalized:
return normalized
override_map: dict[tuple[int, str], dict[str, Any]] = {}
for item in overrides:
if not isinstance(item, dict):
continue
filename = str(item.get("filename") or "").strip()
index = int(item.get("index") or 0)
if not filename and index <= 0:
continue
override_map[(index, filename)] = item
for item in normalized:
override = override_map.get((int(item["index"]), str(item["filename"])))
if override is None:
override = override_map.get((int(item["index"]), ""))
if override is None:
continue
summary = str(override.get("summary") or "").strip()
scene_label = str(override.get("scene_label") or "").strip()
fields = override.get("fields")
if summary:
item["summary"] = summary
if scene_label:
item["scene_label"] = scene_label
if isinstance(fields, list):
item["document_fields"] = self._normalize_document_fields(fields)
return normalized
@staticmethod
def _normalize_document_fields(raw_fields: Any) -> list[dict[str, str]]:
if not isinstance(raw_fields, list):
return []
normalized: list[dict[str, str]] = []
for field in raw_fields:
if not isinstance(field, dict):
continue
label = str(field.get("label") or "").strip()
value = str(field.get("value") or "").strip()
key = str(field.get("key") or label or "").strip()
if not label or not value:
continue
normalized.append(
{
"key": key,
"label": label,
"value": value,
}
)
return normalized
def _build_context_item_specs(
self,
*,
context_documents: list[dict[str, Any]],
attachment_names: list[str],
occurred_at: datetime,
expense_type: str,
amount: Decimal,
reason: str,
location: str,
) -> list[dict[str, Any]]:
specs: list[dict[str, Any]] = []
if context_documents:
for document in context_documents:
specs.append(
{
"item_date": self._resolve_document_item_date(document, fallback=occurred_at.date()),
"item_type": self._resolve_document_item_type(document, fallback=expense_type),
"item_reason": reason,
"item_location": location,
"item_amount": self._resolve_document_item_amount(document),
"invoice_id": str(document.get("filename") or "").strip() or None,
}
)
elif attachment_names:
for attachment_name in attachment_names:
specs.append(
{
"item_date": occurred_at.date(),
"item_type": expense_type,
"item_reason": reason,
"item_location": location,
"item_amount": None,
"invoice_id": attachment_name,
}
)
if not specs:
return []
total_recognized = sum(
spec["item_amount"] for spec in specs if isinstance(spec.get("item_amount"), Decimal)
)
missing_specs = [spec for spec in specs if spec.get("item_amount") is None]
if missing_specs:
remaining = (amount - total_recognized).quantize(Decimal("0.01"))
if remaining > Decimal("0.00"):
missing_specs[0]["item_amount"] = remaining
for spec in specs:
if spec.get("item_amount") is None:
spec["item_amount"] = Decimal("0.00")
return specs
def _replace_claim_items(
self,
*,
claim: ExpenseClaim,
item_specs: list[dict[str, Any]],
) -> None:
existing_items = sorted(
list(claim.items),
key=lambda item: (
item.item_date or date.max,
self._normalize_sort_datetime(item.created_at),
),
)
for index, spec in enumerate(item_specs):
item = existing_items[index] if index < len(existing_items) else None
if item is None:
item = ExpenseClaimItem(claim_id=claim.id)
claim.items.append(item)
self.db.add(item)
item.item_date = spec["item_date"]
item.item_type = spec["item_type"]
item.item_reason = spec["item_reason"]
item.item_location = spec["item_location"]
item.item_amount = spec["item_amount"]
item.invoice_id = spec["invoice_id"]
for stale_item in existing_items[len(item_specs) :]:
claim.items.remove(stale_item)
self.db.delete(stale_item)
def _append_document_items(
self,
*,
claim: ExpenseClaim,
item_specs: list[dict[str, Any]],
) -> None:
existing_invoice_ids = {
str(item.invoice_id or "").strip()
for item in claim.items
if str(item.invoice_id or "").strip()
}
for spec in item_specs:
invoice_id = str(spec.get("invoice_id") or "").strip()
if invoice_id and invoice_id in existing_invoice_ids:
continue
claim.items.append(
ExpenseClaimItem(
claim_id=claim.id,
item_date=spec["item_date"],
item_type=spec["item_type"],
item_reason=spec["item_reason"],
item_location=spec["item_location"],
item_amount=spec["item_amount"],
invoice_id=spec["invoice_id"],
)
)
self.db.add(claim.items[-1])
if invoice_id:
existing_invoice_ids.add(invoice_id)
def _resolve_document_item_type(self, document: dict[str, Any], *, fallback: str) -> str:
scene_code = str(document.get("scene_code") or "").strip()
if scene_code in {"travel", "hotel", "transport", "meal", "office", "meeting", "training"}:
return scene_code
document_type = str(document.get("document_type") or "").strip()
if document_type in {"flight_itinerary", "train_ticket"}:
return "travel"
if document_type in {"taxi_receipt", "parking_toll_receipt", "transport_receipt"}:
return "transport"
if document_type == "hotel_invoice":
return "hotel"
if document_type == "meal_receipt":
return "meal"
if document_type == "office_invoice":
return "office"
if document_type == "meeting_invoice":
return "meeting"
if document_type == "training_invoice":
return "training"
scene_label = str(document.get("scene_label") or "").strip()
if "交通" in scene_label:
return "transport"
if "住宿" in scene_label:
return "hotel"
if "" in scene_label:
return "meal"
if "会务" in scene_label or "会议" in scene_label:
return "meeting"
if "培训" in scene_label:
return "training"
return fallback or "other"
def _resolve_document_item_amount(self, document: dict[str, Any]) -> Decimal | None:
for field in list(document.get("document_fields") or []):
if not isinstance(field, dict):
continue
key = str(field.get("key") or "").strip().lower().replace("_", "")
label = str(field.get("label") or "").replace(" ", "")
value = self._parse_document_amount_value(str(field.get("value") or ""))
if value is None:
continue
if key in {
"amount",
"totalamount",
"paymentamount",
"paidamount",
"actualamount",
} or any(
token in label
for token in ("金额", "价税合计", "合计", "总额", "总计", "票价", "支付金额", "实付金额", "实收金额")
):
return value
text = " ".join(
[
str(document.get("summary") or "").strip(),
str(document.get("text") or "").strip(),
]
).strip()
return self._parse_document_amount_value(text)
def _parse_document_amount_value(self, value: str) -> Decimal | None:
raw_value = str(value or "").strip()
if not raw_value:
return None
for pattern in DOCUMENT_AMOUNT_PATTERNS:
match = pattern.search(raw_value)
if not match:
continue
numeric = str(match.group(1) or "").replace(",", ".").strip()
try:
amount = Decimal(numeric).quantize(Decimal("0.01"))
except (InvalidOperation, ValueError):
continue
if amount > Decimal("0.00"):
return amount
return None
def _resolve_document_item_date(self, document: dict[str, Any], *, fallback: date) -> date:
for field in list(document.get("document_fields") or []):
if not isinstance(field, dict):
continue
key = str(field.get("key") or "").strip().lower().replace("_", "")
label = str(field.get("label") or "").replace(" ", "")
value = str(field.get("value") or "").strip()
if not value:
continue
if key in {"date", "time", "issuedat", "invoicedate"} or any(
token in label for token in ("日期", "时间", "开票日期", "发生时间")
):
parsed = self._parse_document_date(value)
if parsed is not None:
return parsed
parsed = self._parse_document_date(
" ".join(
[
str(document.get("summary") or "").strip(),
str(document.get("text") or "").strip(),
]
).strip()
)
return parsed or fallback
@staticmethod
def _parse_document_date(value: str) -> date | None:
match = DOCUMENT_DATE_PATTERN.search(str(value or ""))
if not match:
return None
raw_value = str(match.group(1) or "").strip()
normalized = raw_value.replace("", "-").replace("", "-").replace("", "")
normalized = normalized.replace("/", "-").replace(".", "-")
parts = [part for part in normalized.split("-") if part]
if len(parts) != 3:
return None
try:
return date(int(parts[0]), int(parts[1]), int(parts[2]))
except ValueError:
return None
def _upsert_primary_item( def _upsert_primary_item(
self, self,
*, *,
@@ -816,13 +1344,41 @@ class ExpenseClaimService:
def _generate_claim_no(self, occurred_at: datetime) -> str: def _generate_claim_no(self, occurred_at: datetime) -> str:
month_code = occurred_at.strftime("%Y%m") month_code = occurred_at.strftime("%Y%m")
prefix = f"EXP-{month_code}-" prefix = f"EXP-{month_code}-"
existing = int( existing_claim_nos = list(
self.db.scalar( self.db.scalars(
select(func.count()).select_from(ExpenseClaim).where(ExpenseClaim.claim_no.like(f"{prefix}%")) select(ExpenseClaim.claim_no).where(ExpenseClaim.claim_no.like(f"{prefix}%"))
)
)
max_suffix = 0
for claim_no in existing_claim_nos:
normalized = str(claim_no or "").strip()
if not normalized.startswith(prefix):
continue
suffix = normalized[len(prefix):]
if not suffix.isdigit():
continue
max_suffix = max(max_suffix, int(suffix))
return f"{prefix}{max_suffix + 1:03d}"
@staticmethod
def _resolve_claim_no_retry_count(context_json: dict[str, Any]) -> int:
try:
return max(0, int(context_json.get("_claim_no_retry_count") or 0))
except (TypeError, ValueError):
return 0
@staticmethod
def _is_claim_no_conflict_error(exc: IntegrityError) -> bool:
message = str(exc).lower()
return (
"claim_no" in message
and (
"unique" in message
or "duplicate key" in message
or "ix_expense_claims_claim_no" in message
or "expense_claims.claim_no" in message
) )
or 0
) )
return f"{prefix}{existing + 1:03d}"
def _count_draft_claims_for_owner( def _count_draft_claims_for_owner(
self, self,
@@ -1011,6 +1567,13 @@ class ExpenseClaimService:
if value: if value:
return value return value
explicit_text = context_json.get("user_input_text")
if isinstance(explicit_text, str):
normalized_explicit_text = explicit_text.strip()
if normalized_explicit_text:
return normalized_explicit_text[:500]
return None
request_context = context_json.get("request_context") request_context = context_json.get("request_context")
if ( if (
isinstance(request_context, dict) isinstance(request_context, dict)
@@ -1022,7 +1585,12 @@ class ExpenseClaimService:
return value return value
if not allow_message_fallback: if not allow_message_fallback:
return None return None
return str(message or "").strip()[:500] or None
normalized_message = str(message or "").strip()
compact_message = re.sub(r"\s+", "", normalized_message)
if compact_message.startswith(SYSTEM_GENERATED_REASON_PREFIXES):
return None
return normalized_message[:500] or None
@staticmethod @staticmethod
def _resolve_location(*, message: str, context_json: dict[str, Any]) -> str | None: def _resolve_location(*, message: str, context_json: dict[str, Any]) -> str | None:
@@ -1210,6 +1778,74 @@ class ExpenseClaimService:
return {} return {}
return payload if isinstance(payload, dict) else {} return payload if isinstance(payload, dict) else {}
def _build_attachment_preview_meta(
self,
*,
file_path: Path,
media_type: str,
ocr_document: Any | None,
) -> dict[str, Any]:
filename = file_path.name
storage_key = self._to_attachment_storage_key(file_path)
preview_kind = self._resolve_preview_kind(media_type, filename)
preview_data_url = str(getattr(ocr_document, "preview_data_url", "") or "").strip()
preview_source_kind = str(getattr(ocr_document, "preview_kind", "") or "").strip()
if preview_source_kind == "image" and preview_data_url:
preview_asset = self._write_preview_asset_from_data_url(
attachment_dir=file_path.parent,
original_filename=filename,
preview_data_url=preview_data_url,
)
if preview_asset is not None:
preview_path, preview_media_type, preview_file_name = preview_asset
return {
"previewable": True,
"preview_kind": "image",
"preview_storage_key": self._to_attachment_storage_key(preview_path),
"preview_media_type": preview_media_type,
"preview_file_name": preview_file_name,
}
if preview_kind:
return {
"previewable": True,
"preview_kind": preview_kind,
"preview_storage_key": storage_key,
"preview_media_type": media_type,
"preview_file_name": filename,
}
return {
"previewable": False,
"preview_kind": "",
"preview_storage_key": "",
"preview_media_type": "",
"preview_file_name": "",
}
def _resolve_item_attachment_preview_content(self, item: ExpenseClaimItem) -> tuple[Path, str, str]:
file_path, media_type, filename = self._resolve_item_attachment_content(item)
metadata = self._read_attachment_meta(file_path)
preview_storage_key = str(metadata.get("preview_storage_key") or "").strip()
preview_file_name = str(metadata.get("preview_file_name") or "").strip()
preview_media_type = str(metadata.get("preview_media_type") or "").strip()
if preview_storage_key:
preview_path = self._resolve_attachment_path(preview_storage_key)
if preview_path is not None and preview_path.exists():
resolved_name = preview_file_name or preview_path.name
resolved_media_type = self._resolve_attachment_media_type(
resolved_name,
fallback=preview_media_type,
)
return preview_path, resolved_media_type, resolved_name
if self._is_previewable_media_type(media_type, filename):
return file_path, media_type, filename
raise FileNotFoundError("Attachment preview not found")
def _build_attachment_payload(self, item: ExpenseClaimItem) -> dict[str, Any]: def _build_attachment_payload(self, item: ExpenseClaimItem) -> dict[str, Any]:
file_path, media_type, filename = self._resolve_item_attachment_content(item) file_path, media_type, filename = self._resolve_item_attachment_content(item)
metadata = self._read_attachment_meta(file_path) metadata = self._read_attachment_meta(file_path)
@@ -1233,18 +1869,71 @@ class ExpenseClaimService:
if not isinstance(requirement_check, dict): if not isinstance(requirement_check, dict):
requirement_check = None requirement_check = None
preview_kind = str(metadata.get("preview_kind") or "").strip()
previewable = bool(metadata.get("previewable", self._is_previewable_media_type(media_type, filename)))
preview_url = self._build_attachment_preview_client_path(item.claim_id, item.id) if previewable else ""
return { return {
"file_name": str(metadata.get("file_name") or filename), "file_name": str(metadata.get("file_name") or filename),
"storage_key": str(item.invoice_id or ""), "storage_key": str(item.invoice_id or ""),
"media_type": str(metadata.get("media_type") or media_type), "media_type": str(metadata.get("media_type") or media_type),
"size_bytes": int(metadata.get("size_bytes") or file_path.stat().st_size), "size_bytes": int(metadata.get("size_bytes") or file_path.stat().st_size),
"uploaded_at": uploaded_at, "uploaded_at": uploaded_at,
"previewable": bool(metadata.get("previewable", self._is_previewable_media_type(media_type, filename))), "previewable": previewable,
"preview_kind": preview_kind or self._resolve_preview_kind(media_type, filename),
"preview_url": preview_url,
"analysis": analysis, "analysis": analysis,
"document_info": document_info, "document_info": document_info,
"requirement_check": requirement_check, "requirement_check": requirement_check,
} }
@staticmethod
def _resolve_preview_kind(media_type: str | None, filename: str) -> str:
resolved = str(media_type or "").strip() or (mimetypes.guess_type(filename)[0] or "")
if resolved.startswith("image/"):
return "image"
if resolved == "application/pdf":
return "pdf"
return ""
@staticmethod
def _decode_data_url(payload: str) -> tuple[str, bytes] | None:
normalized = str(payload or "").strip()
matched = re.match(r"^data:(?P<media>[\w.+-]+/[\w.+-]+);base64,(?P<body>.+)$", normalized, flags=re.DOTALL)
if not matched:
return None
try:
content = base64.b64decode(matched.group("body"), validate=True)
except (binascii.Error, ValueError):
return None
return matched.group("media"), content
def _write_preview_asset_from_data_url(
self,
*,
attachment_dir: Path,
original_filename: str,
preview_data_url: str,
) -> tuple[Path, str, str] | None:
decoded = self._decode_data_url(preview_data_url)
if decoded is None:
return None
preview_media_type, preview_content = decoded
suffix = mimetypes.guess_extension(preview_media_type) or ".bin"
preview_name = f"{Path(original_filename).stem}.preview{suffix}"
preview_path = attachment_dir / preview_name
preview_path.write_bytes(preview_content)
return preview_path, preview_media_type, preview_name
@staticmethod
def _build_attachment_preview_client_path(claim_id: str, item_id: str) -> str:
return (
"/reimbursements/claims/"
f"{quote(str(claim_id or '').strip(), safe='')}"
f"/items/{quote(str(item_id or '').strip(), safe='')}/attachment/preview"
)
@staticmethod @staticmethod
def _resolve_attachment_media_type(filename: str, *, fallback: str | None = None) -> str: def _resolve_attachment_media_type(filename: str, *, fallback: str | None = None) -> str:
guessed = mimetypes.guess_type(filename)[0] guessed = mimetypes.guess_type(filename)[0]

View File

@@ -11,9 +11,11 @@ from app.api.deps import CurrentUserContext
from app.db.base import Base from app.db.base import Base
from app.models.employee import Employee from app.models.employee import Employee
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
from app.schemas.ontology import OntologyParseRequest
from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead
from app.schemas.reimbursement import ExpenseClaimItemCreate, ExpenseClaimItemUpdate from app.schemas.reimbursement import ExpenseClaimItemCreate, ExpenseClaimItemUpdate
from app.services.expense_claims import ExpenseClaimService from app.services.expense_claims import ExpenseClaimService
from app.services.ontology import SemanticOntologyService
from app.services.ocr import OcrService from app.services.ocr import OcrService
@@ -97,6 +99,347 @@ def test_resolve_expense_type_maps_office_supplies_review_value_to_office() -> N
assert expense_type == "office" assert expense_type == "office"
def test_upsert_draft_from_ontology_defers_multi_document_association_choice() -> None:
user_id = "zhangsan@example.com"
with build_session() as db:
employee = Employee(
employee_no="E5001",
name="张三",
email=user_id,
)
db.add(employee)
db.flush()
existing_claim = ExpenseClaim(
claim_no="EXP-202605-010",
employee_id=employee.id,
employee_name="张三",
department_name="市场部",
project_code=None,
expense_type="transport",
reason="原有交通报销",
location="深圳",
amount=Decimal("20.00"),
currency="CNY",
invoice_count=1,
occurred_at=datetime(2026, 5, 13, tzinfo=UTC),
status="draft",
approval_stage="待提交",
risk_flags_json=[],
)
existing_claim.items = [
ExpenseClaimItem(
claim_id=existing_claim.id,
item_date=date(2026, 5, 13),
item_type="transport",
item_reason="原有交通报销",
item_location="深圳",
item_amount=Decimal("20.00"),
invoice_id="old-trip.png",
)
]
db.add(existing_claim)
db.commit()
ontology = SemanticOntologyService(db).parse(
OntologyParseRequest(
query="我上传了两张交通票据,帮我生成报销草稿",
user_id=user_id,
)
)
service = ExpenseClaimService(db)
result = service.upsert_draft_from_ontology(
run_id=ontology.run_id,
user_id=user_id,
message="我上传了两张交通票据,帮我生成报销草稿",
ontology=ontology,
context_json={
"name": "张三",
"attachment_names": ["didi-trip.png", "parking-ticket.jpg"],
"attachment_count": 2,
"draft_claim_id": existing_claim.id,
"ocr_documents": [
{
"filename": "didi-trip.png",
"summary": "滴滴出行 支付金额 32 元",
"text": "滴滴出行 支付金额 32 元",
},
{
"filename": "parking-ticket.jpg",
"summary": "停车费 合计 18 元",
"text": "停车费 合计 18 元",
},
],
},
)
db.refresh(existing_claim)
assert result["pending_association_decision"] is True
assert result["association_candidate_claim_id"] == existing_claim.id
assert existing_claim.invoice_count == 1
assert len(existing_claim.items) == 1
assert existing_claim.items[0].invoice_id == "old-trip.png"
def test_upsert_draft_from_ontology_keeps_reason_missing_for_attachment_only_upload() -> None:
user_id = "wangwu@example.com"
with build_session() as db:
employee = Employee(
employee_no="E5003",
name="王五",
email=user_id,
)
db.add(employee)
db.commit()
ontology = SemanticOntologyService(db).parse(
OntologyParseRequest(
query="我上传了 1 份票据,请结合附件名称给出报销建议并尽量生成草稿。",
user_id=user_id,
)
)
service = ExpenseClaimService(db)
result = service.upsert_draft_from_ontology(
run_id=ontology.run_id,
user_id=user_id,
message="我上传了 1 份票据,请结合附件名称给出报销建议并尽量生成草稿。\n附件名称didi-trip.png",
ontology=ontology,
context_json={
"name": "王五",
"user_input_text": "",
"attachment_names": ["didi-trip.png"],
"attachment_count": 1,
"ocr_documents": [
{
"filename": "didi-trip.png",
"summary": "滴滴出行 支付金额 32 元",
"text": "滴滴出行 支付金额 32 元",
"document_type": "taxi_receipt",
"scene_code": "transport",
}
],
},
)
claim = db.get(ExpenseClaim, result["claim_id"])
assert claim is not None
assert claim.reason == "待补充"
def test_upsert_draft_from_ontology_supports_link_or_create_for_multi_documents() -> None:
user_id = "lisi@example.com"
with build_session() as db:
employee = Employee(
employee_no="E5002",
name="李四",
email=user_id,
)
db.add(employee)
db.flush()
existing_claim = ExpenseClaim(
claim_no="EXP-202605-011",
employee_id=employee.id,
employee_name="李四",
department_name="销售部",
project_code=None,
expense_type="transport",
reason="原有交通报销",
location="上海",
amount=Decimal("20.00"),
currency="CNY",
invoice_count=1,
occurred_at=datetime(2026, 5, 13, tzinfo=UTC),
status="draft",
approval_stage="待提交",
risk_flags_json=[],
)
existing_claim.items = [
ExpenseClaimItem(
claim_id=existing_claim.id,
item_date=date(2026, 5, 13),
item_type="transport",
item_reason="原有交通报销",
item_location="上海",
item_amount=Decimal("20.00"),
invoice_id="existing.png",
)
]
db.add(existing_claim)
db.commit()
ontology = SemanticOntologyService(db).parse(
OntologyParseRequest(
query="我上传了两张交通票据,帮我生成报销草稿",
user_id=user_id,
)
)
service = ExpenseClaimService(db)
context_json = {
"name": "李四",
"attachment_names": ["didi-trip.png", "parking-ticket.jpg"],
"attachment_count": 2,
"draft_claim_id": existing_claim.id,
"ocr_documents": [
{
"filename": "didi-trip.png",
"summary": "滴滴出行",
"text": "滴滴出行 支付金额 32.50 元",
"document_type": "taxi_receipt",
"scene_code": "transport",
"document_fields": [{"key": "amount", "label": "支付金额", "value": "32.50"}],
},
{
"filename": "parking-ticket.jpg",
"summary": "停车票",
"text": "停车费 合计 18 元",
"document_type": "parking_toll_receipt",
"scene_code": "transport",
"document_fields": [{"key": "total_amount", "label": "合计金额", "value": "18"}],
},
],
}
link_result = service.upsert_draft_from_ontology(
run_id=ontology.run_id,
user_id=user_id,
message="把这两张票据关联到已有草稿",
ontology=ontology,
context_json={
**context_json,
"review_action": "link_to_existing_draft",
},
)
db.refresh(existing_claim)
assert link_result["claim_id"] == existing_claim.id
assert existing_claim.invoice_count == 3
assert len(existing_claim.items) == 3
assert float(existing_claim.amount) == 70.5
create_result = service.upsert_draft_from_ontology(
run_id=f"{ontology.run_id}-new",
user_id=user_id,
message="单独新建一张报销单",
ontology=ontology,
context_json={
**context_json,
"review_action": "create_new_claim_from_documents",
},
)
assert create_result["claim_id"] != existing_claim.id
new_claim = db.get(ExpenseClaim, create_result["claim_id"])
assert new_claim is not None
assert new_claim.invoice_count == 2
assert len(new_claim.items) == 2
assert float(new_claim.amount) == 50.5
def test_generate_claim_no_uses_max_suffix_instead_of_count() -> None:
with build_session() as db:
db.add_all(
[
ExpenseClaim(
claim_no="EXP-202605-001",
employee_name="张三",
department_name="市场部",
project_code=None,
expense_type="transport",
reason="交通报销",
location="深圳",
amount=Decimal("10.00"),
currency="CNY",
invoice_count=1,
occurred_at=datetime(2026, 5, 10, tzinfo=UTC),
status="draft",
approval_stage="待提交",
risk_flags_json=[],
),
ExpenseClaim(
claim_no="EXP-202605-003",
employee_name="李四",
department_name="销售部",
project_code=None,
expense_type="transport",
reason="交通报销",
location="上海",
amount=Decimal("20.00"),
currency="CNY",
invoice_count=1,
occurred_at=datetime(2026, 5, 11, tzinfo=UTC),
status="submitted",
approval_stage="审批中",
risk_flags_json=[],
),
]
)
db.commit()
service = ExpenseClaimService(db)
assert service._generate_claim_no(datetime(2026, 5, 14, tzinfo=UTC)) == "EXP-202605-004"
def test_upsert_draft_from_ontology_retries_claim_no_conflict() -> None:
user_id = "zhaoliu-claimno@example.com"
with build_session() as db:
employee = Employee(
employee_no="E5006",
name="赵六",
email=user_id,
)
db.add(employee)
db.flush()
db.add(
ExpenseClaim(
claim_no="EXP-202605-004",
employee_name="历史单据",
department_name="财务部",
project_code=None,
expense_type="other",
reason="历史草稿",
location="北京",
amount=Decimal("0.00"),
currency="CNY",
invoice_count=0,
occurred_at=datetime(2026, 5, 12, tzinfo=UTC),
status="submitted",
approval_stage="审批中",
risk_flags_json=[],
)
)
db.commit()
ontology = SemanticOntologyService(db).parse(
OntologyParseRequest(
query="帮我生成报销草稿,我昨天交通费 13.4 元",
user_id=user_id,
)
)
service = ExpenseClaimService(db)
generated_claim_nos = iter(["EXP-202605-004", "EXP-202605-005"])
service._generate_claim_no = lambda occurred_at: next(generated_claim_nos)
result = service.upsert_draft_from_ontology(
run_id=ontology.run_id,
user_id=user_id,
message="帮我生成报销草稿,我昨天交通费 13.4 元",
ontology=ontology,
context_json={
"name": "赵六",
"user_input_text": "帮我生成报销草稿,我昨天交通费 13.4 元",
},
)
created_claim = db.get(ExpenseClaim, result["claim_id"])
assert created_claim is not None
assert created_claim.claim_no == "EXP-202605-005"
assert result["claim_no"] == "EXP-202605-005"
def test_create_claim_item_adds_blank_draft_row_without_forcing_attachment() -> None: def test_create_claim_item_adds_blank_draft_row_without_forcing_attachment() -> None:
current_user = CurrentUserContext( current_user = CurrentUserContext(
username="emp-1", username="emp-1",
@@ -186,6 +529,10 @@ def test_update_claim_item_reanalyzes_existing_attachment(monkeypatch, tmp_path)
current_user=current_user, current_user=current_user,
) )
assert uploaded_meta is not None assert uploaded_meta is not None
assert uploaded_meta["preview_kind"] == "image"
assert uploaded_meta["preview_url"].endswith(
f"/reimbursements/claims/{claim.id}/items/{claim.items[0].id}/attachment/preview"
)
assert uploaded_meta["analysis"]["severity"] == "pass" assert uploaded_meta["analysis"]["severity"] == "pass"
assert uploaded_meta["document_info"]["document_type"] == "office_invoice" assert uploaded_meta["document_info"]["document_type"] == "office_invoice"
assert uploaded_meta["requirement_check"]["matches"] is True assert uploaded_meta["requirement_check"]["matches"] is True

View File

@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import base64
from collections.abc import Generator from collections.abc import Generator
from datetime import UTC, date, datetime from datetime import UTC, date, datetime
from decimal import Decimal from decimal import Decimal
@@ -165,9 +166,18 @@ def test_claim_item_attachment_upload_preview_and_delete(monkeypatch, tmp_path)
assert meta_response.status_code == 200 assert meta_response.status_code == 200
meta_payload = meta_response.json() meta_payload = meta_response.json()
assert meta_payload["media_type"] == "image/png" assert meta_payload["media_type"] == "image/png"
assert meta_payload["preview_kind"] == "image"
assert meta_payload["preview_url"].endswith(f"/reimbursements/claims/{claim_id}/items/{item_id}/attachment/preview")
assert meta_payload["analysis"]["headline"] assert meta_payload["analysis"]["headline"]
assert meta_payload["document_info"]["fields"][0]["label"] == "金额" assert meta_payload["document_info"]["fields"][0]["label"] == "金额"
preview_response = client.get(
f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment/preview",
headers=headers,
)
assert preview_response.status_code == 200
assert preview_response.content == file_bytes
content_response = client.get( content_response = client.get(
f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment", f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment",
headers=headers, headers=headers,
@@ -279,6 +289,67 @@ def test_claim_item_attachment_upload_flags_non_invoice_image_as_high_risk(monke
assert any("附件内容" in point for point in analysis["points"]) assert any("附件内容" in point for point in analysis["points"])
def test_claim_item_pdf_attachment_preview_returns_generated_image(monkeypatch, tmp_path) -> None:
preview_bytes = b"fake-preview-png"
preview_data_url = f"data:image/png;base64,{base64.b64encode(preview_bytes).decode('ascii')}"
def fake_recognize(
self,
files: list[tuple[str, bytes, str | None]],
) -> OcrRecognizeBatchRead:
return OcrRecognizeBatchRead(
total_file_count=1,
success_count=1,
documents=[
OcrRecognizeDocumentRead(
filename="invoice.pdf",
media_type="application/pdf",
text="滴滴出行电子发票 金额13.4元",
summary="识别到交通票据,金额 13.4 元。",
avg_score=0.96,
line_count=1,
page_count=1,
document_type="taxi_receipt",
document_type_label="出租车/网约车票据",
scene_code="transport",
scene_label="交通票据",
preview_kind="image",
preview_data_url=preview_data_url,
warnings=[],
)
],
)
monkeypatch.setattr(OcrService, "recognize_files", fake_recognize)
monkeypatch.setattr(ExpenseClaimService, "_get_attachment_storage_root", lambda self: tmp_path)
client, session_factory = build_client()
with session_factory() as db:
claim, item = seed_claim(db)
claim_id = claim.id
item_id = item.id
headers = {"x-auth-username": "emp-1", "x-auth-name": "Zhang San"}
upload_response = client.post(
f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment",
headers=headers,
files=[("file", ("invoice.pdf", b"%PDF-1.4 fake", "application/pdf"))],
)
assert upload_response.status_code == 200
meta_payload = upload_response.json()["attachment"]
assert meta_payload["preview_kind"] == "image"
assert meta_payload["preview_url"].endswith(f"/reimbursements/claims/{claim_id}/items/{item_id}/attachment/preview")
preview_response = client.get(
f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment/preview",
headers=headers,
)
assert preview_response.status_code == 200
assert preview_response.headers["content-type"].startswith("image/png")
assert preview_response.content == preview_bytes
def test_claim_item_delete_removes_item_and_attachment(monkeypatch, tmp_path) -> None: def test_claim_item_delete_removes_item_and_attachment(monkeypatch, tmp_path) -> None:
def fake_recognize( def fake_recognize(
self, self,