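# Source file: X-Financial/server/src/app/services/expense_claims.py
# (Repository viewer metadata: 469 lines, 18 KiB, Python.)
# Expense-claim draft service: builds and updates draft claims from ontology parse results.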

from __future__ import annotations
from datetime import UTC, date, datetime
from decimal import Decimal, InvalidOperation
from typing import Any
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from app.models.employee import Employee
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
from app.schemas.ontology import OntologyEntity, OntologyParseResult
from app.services.audit import AuditLogService
from app.services.agent_foundation import AgentFoundationService
# Display labels (Chinese) for the expense-type codes produced by
# ExpenseClaimService._resolve_expense_type.
EXPENSE_TYPE_LABELS: dict[str, str] = {
    "travel": "差旅",
    "hotel": "住宿",
    "transport": "交通",
    "meal": "餐费",
    "meeting": "会务",
    "entertainment": "招待",
}
class ExpenseClaimService:
def __init__(self, db: Session) -> None:
self.db = db
self.audit_service = AuditLogService(db)
def upsert_draft_from_ontology(
self,
*,
run_id: str,
user_id: str | None,
message: str,
ontology: OntologyParseResult,
context_json: dict[str, Any],
) -> dict[str, Any]:
self._ensure_ready()
claim = self._find_target_claim(ontology=ontology, context_json=context_json)
is_new_claim = claim is None
before_json = self._serialize_claim(claim) if claim is not None else None
employee = self._resolve_employee(ontology=ontology, context_json=context_json)
amount = self._resolve_amount(ontology.entities, context_json=context_json)
occurred_at = self._resolve_occurred_at(ontology, context_json=context_json)
expense_type = self._resolve_expense_type(ontology.entities, context_json=context_json)
location = self._resolve_location(message=message, context_json=context_json)
reason = self._resolve_reason(
message=message,
context_json=context_json,
allow_message_fallback=is_new_claim,
)
attachment_count = self._resolve_attachment_count(context_json)
final_amount = amount if amount is not None else (claim.amount if claim is not None else Decimal("0.00"))
final_occurred_at = (
occurred_at if occurred_at is not None else (claim.occurred_at if claim is not None else datetime.now(UTC))
)
final_expense_type = expense_type or (claim.expense_type if claim is not None else "other")
final_location = location or (claim.location if claim is not None else "待补充")
final_reason = reason or (claim.reason if claim is not None else "待补充")
final_attachment_count = (
attachment_count if attachment_count > 0 else int(claim.invoice_count or 0) if claim is not None else 0
)
final_risk_flags = list(ontology.risk_flags) or (
list(claim.risk_flags_json or []) if claim is not None else []
)
if claim is None:
claim = ExpenseClaim(
claim_no=self._generate_claim_no(final_occurred_at),
employee_id=employee.id if employee is not None else None,
employee_name=employee.name if employee is not None else self._resolve_employee_name(
ontology=ontology,
context_json=context_json,
user_id=user_id,
),
department_id=employee.organization_unit_id if employee is not None else None,
department_name=self._resolve_department_name(
employee=employee,
context_json=context_json,
),
project_code=self._resolve_project_code(ontology.entities),
expense_type=final_expense_type,
reason=final_reason,
location=final_location,
amount=final_amount,
currency="CNY",
invoice_count=final_attachment_count,
occurred_at=final_occurred_at,
status="draft",
approval_stage="待补充",
risk_flags_json=final_risk_flags,
)
self.db.add(claim)
else:
claim.employee_id = employee.id if employee is not None else claim.employee_id
claim.employee_name = (
employee.name
if employee is not None
else self._resolve_employee_name(
ontology=ontology,
context_json=context_json,
user_id=user_id,
fallback=claim.employee_name,
)
)
claim.department_id = employee.organization_unit_id if employee is not None else claim.department_id
claim.department_name = self._resolve_department_name(
employee=employee,
context_json=context_json,
fallback=claim.department_name,
)
claim.project_code = self._resolve_project_code(ontology.entities) or claim.project_code
claim.expense_type = final_expense_type
claim.reason = final_reason
claim.location = final_location
claim.amount = final_amount
claim.invoice_count = final_attachment_count
claim.occurred_at = final_occurred_at
claim.status = "draft"
claim.approval_stage = "待补充"
claim.risk_flags_json = final_risk_flags
self.db.flush()
self._upsert_primary_item(
claim=claim,
occurred_at=final_occurred_at,
expense_type=final_expense_type,
amount=final_amount,
reason=final_reason,
location=final_location,
attachment_names=self._resolve_attachment_names(context_json),
)
self.db.commit()
self.db.refresh(claim)
self.audit_service.log_action(
actor=user_id or claim.employee_name or "anonymous",
action="expense_claim.draft_upsert",
resource_type="expense_claim",
resource_id=claim.id,
before_json=before_json,
after_json=self._serialize_claim(claim),
request_id=run_id,
)
return {
"message": (
f"{'创建' if is_new_claim else '更新'}报销草稿 {claim.claim_no},当前状态为 draft。"
"你可以继续补充费用明细、客户单位和票据附件。"
),
"draft_only": True,
"claim_id": claim.id,
"claim_no": claim.claim_no,
"status": claim.status,
"amount": float(claim.amount),
"invoice_count": int(claim.invoice_count or 0),
}
def _find_target_claim(
self,
*,
ontology: OntologyParseResult,
context_json: dict[str, Any],
) -> ExpenseClaim | None:
draft_claim_id = str(context_json.get("draft_claim_id") or "").strip()
if draft_claim_id:
return self.db.get(ExpenseClaim, draft_claim_id)
claim_codes = [
item.normalized_value
for item in ontology.entities
if item.type == "expense_claim" and item.normalized_value
]
if not claim_codes:
return None
stmt = select(ExpenseClaim).where(ExpenseClaim.claim_no.in_(claim_codes)).limit(1)
return self.db.scalar(stmt)
def _upsert_primary_item(
self,
*,
claim: ExpenseClaim,
occurred_at: datetime,
expense_type: str,
amount: Decimal,
reason: str,
location: str,
attachment_names: list[str],
) -> None:
item = claim.items[0] if claim.items else None
if item is None:
item = ExpenseClaimItem(
claim_id=claim.id,
item_date=occurred_at.date(),
item_type=expense_type,
item_reason=reason,
item_location=location,
item_amount=amount,
invoice_id=attachment_names[0] if attachment_names else None,
)
claim.items.append(item)
self.db.add(item)
return
item.item_date = occurred_at.date()
item.item_type = expense_type
item.item_reason = reason
item.item_location = location
item.item_amount = amount
item.invoice_id = attachment_names[0] if attachment_names else item.invoice_id
def _generate_claim_no(self, occurred_at: datetime) -> str:
month_code = occurred_at.strftime("%Y%m")
prefix = f"EXP-{month_code}-"
existing = int(
self.db.scalar(
select(func.count()).select_from(ExpenseClaim).where(ExpenseClaim.claim_no.like(f"{prefix}%"))
)
or 0
)
return f"{prefix}{existing + 1:03d}"
def _resolve_employee(
self,
*,
ontology: OntologyParseResult,
context_json: dict[str, Any],
) -> Employee | None:
employee_name = self._resolve_employee_name(
ontology=ontology,
context_json=context_json,
user_id=None,
)
if not employee_name:
return None
stmt = select(Employee).where(Employee.name == employee_name).limit(1)
return self.db.scalar(stmt)
@staticmethod
def _resolve_employee_name(
*,
ontology: OntologyParseResult,
context_json: dict[str, Any],
user_id: str | None,
fallback: str = "待补充",
) -> str:
review_form_values = context_json.get("review_form_values")
if isinstance(review_form_values, dict):
for key in ("reporter_name", "employee_name", "claimant_name"):
value = str(review_form_values.get(key) or "").strip()
if value:
return value
for item in ontology.entities:
if item.type == "employee" and item.value.strip():
return item.value.strip()
for key in ("name", "user_name", "employee_name"):
value = str(context_json.get(key) or "").strip()
if value:
return value
return str(user_id or fallback).strip() or fallback
@staticmethod
def _resolve_department_name(
*,
employee: Employee | None,
context_json: dict[str, Any],
fallback: str = "待补充",
) -> str:
if employee is not None and employee.organization_unit is not None:
return employee.organization_unit.name
request_context = context_json.get("request_context")
if isinstance(request_context, dict):
for key in ("department", "department_name", "deptName"):
value = str(request_context.get(key) or "").strip()
if value:
return value
for key in ("department_name", "department"):
value = str(context_json.get(key) or "").strip()
if value:
return value
return fallback
@staticmethod
def _resolve_project_code(entities: list[OntologyEntity]) -> str | None:
for item in entities:
if item.type == "project" and item.normalized_value.strip():
return item.normalized_value.strip()
return None
@staticmethod
def _resolve_expense_type(
entities: list[OntologyEntity],
*,
context_json: dict[str, Any],
) -> str | None:
review_form_values = context_json.get("review_form_values")
if isinstance(review_form_values, dict):
compact = str(
review_form_values.get("expense_type")
or review_form_values.get("reimbursement_type")
or ""
).replace(" ", "")
if compact:
if "招待" in compact or ("客户" in compact and any(word in compact for word in ("吃饭", "宴请", "请客", "用餐"))):
return "entertainment"
if any(word in compact for word in ("差旅", "出差", "机票", "行程")):
return "travel"
if any(word in compact for word in ("住宿", "酒店", "宾馆")):
return "hotel"
if any(word in compact for word in ("交通", "打车", "网约车", "出租车", "停车", "车费")):
return "transport"
if any(word in compact for word in ("餐费", "用餐", "午餐", "晚餐", "早餐", "伙食")):
return "meal"
if "会务" in compact:
return "meeting"
for item in entities:
if item.type == "expense_type":
normalized = item.normalized_value.strip()
if normalized:
return normalized
return None
@staticmethod
def _resolve_reason(
*,
message: str,
context_json: dict[str, Any],
allow_message_fallback: bool,
) -> str | None:
review_form_values = context_json.get("review_form_values")
if isinstance(review_form_values, dict):
for key in ("reason", "business_reason"):
value = str(review_form_values.get(key) or "").strip()
if value:
return value
request_context = context_json.get("request_context")
if (
isinstance(request_context, dict)
and str(context_json.get("entry_source") or "").strip() == "detail"
):
for key in ("reason", "title"):
value = str(request_context.get(key) or "").strip()
if value:
return value
if not allow_message_fallback:
return None
return str(message or "").strip()[:500] or None
@staticmethod
def _resolve_location(*, message: str, context_json: dict[str, Any]) -> str | None:
review_form_values = context_json.get("review_form_values")
if isinstance(review_form_values, dict):
for key in ("business_location", "location"):
value = str(review_form_values.get(key) or "").strip()
if value:
return value
request_context = context_json.get("request_context")
if (
isinstance(request_context, dict)
and str(context_json.get("entry_source") or "").strip() == "detail"
):
for key in ("city", "location"):
value = str(request_context.get(key) or "").strip()
if value:
return value
compact = str(message or "").replace(" ", "")
if "客户现场" in compact:
return "客户现场"
return None
@staticmethod
def _resolve_occurred_at(
ontology: OntologyParseResult,
*,
context_json: dict[str, Any],
) -> datetime | None:
review_form_values = context_json.get("review_form_values")
if isinstance(review_form_values, dict):
for key in ("occurred_date", "time_range", "business_time"):
value = str(review_form_values.get(key) or "").strip()
if not value:
continue
try:
parsed = date.fromisoformat(value)
return datetime(parsed.year, parsed.month, parsed.day, tzinfo=UTC)
except ValueError:
continue
start_date = ontology.time_range.start_date
if start_date:
try:
parsed = date.fromisoformat(start_date)
return datetime(parsed.year, parsed.month, parsed.day, tzinfo=UTC)
except ValueError:
pass
return None
@staticmethod
def _resolve_amount(
entities: list[OntologyEntity],
*,
context_json: dict[str, Any],
) -> Decimal | None:
review_form_values = context_json.get("review_form_values")
if isinstance(review_form_values, dict):
raw_value = str(review_form_values.get("amount") or "").strip()
if raw_value:
compact = raw_value.replace("", "").replace(",", "").strip()
try:
return Decimal(compact).quantize(Decimal("0.01"))
except (InvalidOperation, ValueError):
pass
for item in entities:
if item.type != "amount" or item.role == "threshold":
continue
try:
return Decimal(item.normalized_value).quantize(Decimal("0.01"))
except (InvalidOperation, ValueError):
continue
return None
@staticmethod
def _resolve_attachment_names(context_json: dict[str, Any]) -> list[str]:
names = context_json.get("attachment_names")
if not isinstance(names, list):
return []
return [str(name).strip() for name in names if str(name).strip()]
def _resolve_attachment_count(self, context_json: dict[str, Any]) -> int:
names = self._resolve_attachment_names(context_json)
if names:
return len(names)
try:
return max(0, int(context_json.get("attachment_count") or 0))
except (TypeError, ValueError):
return 0
@staticmethod
def _serialize_claim(claim: ExpenseClaim) -> dict[str, Any]:
return {
"id": claim.id,
"claim_no": claim.claim_no,
"employee_name": claim.employee_name,
"department_name": claim.department_name,
"project_code": claim.project_code,
"expense_type": claim.expense_type,
"reason": claim.reason,
"location": claim.location,
"amount": float(claim.amount),
"invoice_count": int(claim.invoice_count or 0),
"status": claim.status,
"approval_stage": claim.approval_stage,
"risk_flags_json": list(claim.risk_flags_json or []),
}
def _ensure_ready(self) -> None:
AgentFoundationService(self.db).ensure_foundation_ready()