Files
X-Financial/server/src/app/services/expense_claims.py

810 lines
30 KiB
Python
Raw Normal View History

from __future__ import annotations
from datetime import UTC, date, datetime
from decimal import Decimal, InvalidOperation
from typing import Any
from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session, selectinload
from app.api.deps import CurrentUserContext
from app.models.employee import Employee
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
from app.schemas.ontology import OntologyEntity, OntologyParseResult
from app.schemas.reimbursement import ExpenseClaimItemUpdate
from app.services.audit import AuditLogService
from app.services.agent_foundation import AgentFoundationService
# Chinese display labels keyed by the internal expense-type codes that
# ExpenseClaimService._resolve_expense_type can return.
EXPENSE_TYPE_LABELS = dict(
    travel="差旅",
    hotel="住宿",
    transport="交通",
    meal="餐费",
    meeting="会务",
    entertainment="招待",
)
# Role codes whose holders bypass the per-user claim visibility filter
# (see ExpenseClaimService._apply_claim_scope).
PRIVILEGED_CLAIM_ROLE_CODES = set("manager finance approver auditor executive".split())
# A new draft is refused once the owner already holds this many drafts.
MAX_DRAFT_CLAIMS_PER_USER = 3
class ExpenseClaimService:
def __init__(self, db: Session) -> None:
self.db = db
self.audit_service = AuditLogService(db)
def list_claims(self, current_user: CurrentUserContext) -> list[ExpenseClaim]:
stmt = (
select(ExpenseClaim)
.options(selectinload(ExpenseClaim.items))
.order_by(ExpenseClaim.created_at.desc(), ExpenseClaim.occurred_at.desc())
)
stmt = self._apply_claim_scope(stmt, current_user)
return list(self.db.scalars(stmt).all())
def get_claim(self, claim_id: str, current_user: CurrentUserContext) -> ExpenseClaim | None:
stmt = (
select(ExpenseClaim)
.options(selectinload(ExpenseClaim.items))
.where(ExpenseClaim.id == claim_id)
)
stmt = self._apply_claim_scope(stmt, current_user)
return self.db.scalar(stmt)
def update_claim_item(
self,
*,
claim_id: str,
item_id: str,
payload: ExpenseClaimItemUpdate,
current_user: CurrentUserContext,
) -> ExpenseClaim | None:
claim = self.get_claim(claim_id, current_user)
if claim is None:
return None
self._ensure_draft_claim(claim)
item = next((entry for entry in claim.items if entry.id == item_id), None)
if item is None:
raise LookupError("Item not found")
before_json = self._serialize_claim(claim)
if payload.item_date is not None:
item.item_date = payload.item_date
if payload.item_type is not None:
item.item_type = self._normalize_optional_text(payload.item_type, fallback=item.item_type) or item.item_type
if payload.item_reason is not None:
item.item_reason = (
self._normalize_optional_text(payload.item_reason, fallback=item.item_reason) or item.item_reason
)
if payload.item_location is not None:
item.item_location = (
self._normalize_optional_text(payload.item_location, fallback=item.item_location) or item.item_location
)
if payload.item_amount is not None:
amount = payload.item_amount.quantize(Decimal("0.01"))
if amount <= Decimal("0.00"):
raise ValueError("费用金额必须大于 0。")
item.item_amount = amount
if payload.invoice_id is not None:
item.invoice_id = self._normalize_optional_text(payload.invoice_id, allow_empty=True)
self._sync_claim_from_items(claim)
self.db.commit()
self.db.refresh(claim)
self.audit_service.log_action(
actor=current_user.name or current_user.username,
action="expense_claim.item_update",
resource_type="expense_claim",
resource_id=claim.id,
before_json=before_json,
after_json=self._serialize_claim(claim),
)
return claim
def submit_claim(self, claim_id: str, current_user: CurrentUserContext) -> ExpenseClaim | None:
claim = self.get_claim(claim_id, current_user)
if claim is None:
return None
self._ensure_draft_claim(claim)
self._sync_claim_from_items(claim)
missing_fields = self._validate_claim_for_submission(claim)
if missing_fields:
raise ValueError("提交前请先补全信息:" + "".join(missing_fields))
before_json = self._serialize_claim(claim)
claim.status = "submitted"
claim.approval_stage = "AI验审"
claim.submitted_at = datetime.now(UTC)
self.db.commit()
self.db.refresh(claim)
self.audit_service.log_action(
actor=current_user.name or current_user.username,
action="expense_claim.submit",
resource_type="expense_claim",
resource_id=claim.id,
before_json=before_json,
after_json=self._serialize_claim(claim),
)
return claim
def delete_claim(self, claim_id: str, current_user: CurrentUserContext) -> ExpenseClaim | None:
claim = self.get_claim(claim_id, current_user)
if claim is None:
return None
self._ensure_draft_claim(claim)
before_json = self._serialize_claim(claim)
resource_id = claim.id
self.db.delete(claim)
self.db.commit()
self.audit_service.log_action(
actor=current_user.name or current_user.username,
action="expense_claim.delete",
resource_type="expense_claim",
resource_id=resource_id,
before_json=before_json,
after_json=None,
)
return claim
def upsert_draft_from_ontology(
self,
*,
run_id: str,
user_id: str | None,
message: str,
ontology: OntologyParseResult,
context_json: dict[str, Any],
) -> dict[str, Any]:
self._ensure_ready()
claim = self._find_target_claim(ontology=ontology, context_json=context_json)
is_new_claim = claim is None
before_json = self._serialize_claim(claim) if claim is not None else None
employee = self._resolve_employee(
ontology=ontology,
context_json=context_json,
user_id=user_id,
)
draft_owner_name = (
employee.name
if employee is not None
else self._resolve_employee_name(
ontology=ontology,
context_json=context_json,
user_id=user_id,
)
)
if is_new_claim:
existing_draft_count = self._count_draft_claims_for_owner(
employee=employee,
employee_name=draft_owner_name,
user_id=user_id,
)
if existing_draft_count >= MAX_DRAFT_CLAIMS_PER_USER:
return {
"message": (
f"你当前已保存 {MAX_DRAFT_CLAIMS_PER_USER} 个草稿,请先完成已保存的草稿,"
"才能再次新建草稿。"
),
"draft_limit_reached": True,
"draft_only": False,
"status": "blocked",
"draft_count": existing_draft_count,
"max_draft_count": MAX_DRAFT_CLAIMS_PER_USER,
}
amount = self._resolve_amount(ontology.entities, context_json=context_json)
occurred_at = self._resolve_occurred_at(ontology, context_json=context_json)
expense_type = self._resolve_expense_type(ontology.entities, context_json=context_json)
location = self._resolve_location(message=message, context_json=context_json)
reason = self._resolve_reason(
message=message,
context_json=context_json,
allow_message_fallback=is_new_claim,
)
attachment_count = self._resolve_attachment_count(context_json)
final_amount = amount if amount is not None else (claim.amount if claim is not None else Decimal("0.00"))
final_occurred_at = (
occurred_at if occurred_at is not None else (claim.occurred_at if claim is not None else datetime.now(UTC))
)
final_expense_type = expense_type or (claim.expense_type if claim is not None else "other")
final_location = location or (claim.location if claim is not None else "待补充")
final_reason = reason or (claim.reason if claim is not None else "待补充")
final_attachment_count = (
attachment_count if attachment_count > 0 else int(claim.invoice_count or 0) if claim is not None else 0
)
final_risk_flags = list(ontology.risk_flags) or (
list(claim.risk_flags_json or []) if claim is not None else []
)
if claim is None:
claim = ExpenseClaim(
claim_no=self._generate_claim_no(final_occurred_at),
employee_id=employee.id if employee is not None else None,
employee_name=draft_owner_name,
department_id=employee.organization_unit_id if employee is not None else None,
department_name=self._resolve_department_name(
employee=employee,
context_json=context_json,
),
project_code=self._resolve_project_code(ontology.entities),
expense_type=final_expense_type,
reason=final_reason,
location=final_location,
amount=final_amount,
currency="CNY",
invoice_count=final_attachment_count,
occurred_at=final_occurred_at,
status="draft",
approval_stage="待提交",
risk_flags_json=final_risk_flags,
)
self.db.add(claim)
else:
claim.employee_id = employee.id if employee is not None else claim.employee_id
claim.employee_name = (
employee.name
if employee is not None
else self._resolve_employee_name(
ontology=ontology,
context_json=context_json,
user_id=user_id,
fallback=claim.employee_name,
)
)
claim.department_id = employee.organization_unit_id if employee is not None else claim.department_id
claim.department_name = self._resolve_department_name(
employee=employee,
context_json=context_json,
fallback=claim.department_name,
)
claim.project_code = self._resolve_project_code(ontology.entities) or claim.project_code
claim.expense_type = final_expense_type
claim.reason = final_reason
claim.location = final_location
claim.amount = final_amount
claim.invoice_count = final_attachment_count
claim.occurred_at = final_occurred_at
claim.status = "draft"
claim.approval_stage = "待提交"
claim.risk_flags_json = final_risk_flags
self.db.flush()
self._upsert_primary_item(
claim=claim,
occurred_at=final_occurred_at,
expense_type=final_expense_type,
amount=final_amount,
reason=final_reason,
location=final_location,
attachment_names=self._resolve_attachment_names(context_json),
)
self.db.commit()
self.db.refresh(claim)
self.audit_service.log_action(
actor=user_id or claim.employee_name or "anonymous",
action="expense_claim.draft_upsert",
resource_type="expense_claim",
resource_id=claim.id,
before_json=before_json,
after_json=self._serialize_claim(claim),
request_id=run_id,
)
return {
"message": (
f"{'创建' if is_new_claim else '更新'}报销草稿 {claim.claim_no},当前状态为 draft。"
"你可以继续补充费用明细、客户单位和票据附件。"
),
"draft_only": True,
"claim_id": claim.id,
"claim_no": claim.claim_no,
"status": claim.status,
"amount": float(claim.amount),
"invoice_count": int(claim.invoice_count or 0),
}
def _find_target_claim(
self,
*,
ontology: OntologyParseResult,
context_json: dict[str, Any],
) -> ExpenseClaim | None:
draft_claim_id = str(context_json.get("draft_claim_id") or "").strip()
if draft_claim_id:
return self.db.get(ExpenseClaim, draft_claim_id)
claim_codes = [
item.normalized_value
for item in ontology.entities
if item.type == "expense_claim" and item.normalized_value
]
if not claim_codes:
return None
stmt = select(ExpenseClaim).where(ExpenseClaim.claim_no.in_(claim_codes)).limit(1)
return self.db.scalar(stmt)
def _upsert_primary_item(
self,
*,
claim: ExpenseClaim,
occurred_at: datetime,
expense_type: str,
amount: Decimal,
reason: str,
location: str,
attachment_names: list[str],
) -> None:
item = claim.items[0] if claim.items else None
if item is None:
item = ExpenseClaimItem(
claim_id=claim.id,
item_date=occurred_at.date(),
item_type=expense_type,
item_reason=reason,
item_location=location,
item_amount=amount,
invoice_id=attachment_names[0] if attachment_names else None,
)
claim.items.append(item)
self.db.add(item)
return
item.item_date = occurred_at.date()
item.item_type = expense_type
item.item_reason = reason
item.item_location = location
item.item_amount = amount
item.invoice_id = attachment_names[0] if attachment_names else item.invoice_id
def _generate_claim_no(self, occurred_at: datetime) -> str:
month_code = occurred_at.strftime("%Y%m")
prefix = f"EXP-{month_code}-"
existing = int(
self.db.scalar(
select(func.count()).select_from(ExpenseClaim).where(ExpenseClaim.claim_no.like(f"{prefix}%"))
)
or 0
)
return f"{prefix}{existing + 1:03d}"
def _count_draft_claims_for_owner(
self,
*,
employee: Employee | None,
employee_name: str,
user_id: str | None,
) -> int:
owner_filters = self._build_draft_owner_filters(
employee=employee,
employee_name=employee_name,
user_id=user_id,
)
if not owner_filters:
return 0
stmt = (
select(func.count())
.select_from(ExpenseClaim)
.where(ExpenseClaim.status == "draft")
.where(or_(*owner_filters))
)
return int(self.db.scalar(stmt) or 0)
@staticmethod
def _build_draft_owner_filters(
*,
employee: Employee | None,
employee_name: str,
user_id: str | None,
) -> list[Any]:
conditions: list[Any] = []
seen: set[tuple[str, str]] = set()
def add_condition(field_name: str, value: str | None) -> None:
normalized = str(value or "").strip()
if not normalized or normalized == "待补充":
return
marker = (field_name, normalized.lower())
if marker in seen:
return
seen.add(marker)
if field_name == "employee_id":
conditions.append(ExpenseClaim.employee_id == normalized)
return
conditions.append(ExpenseClaim.employee_name == normalized)
if employee is not None:
add_condition("employee_id", employee.id)
add_condition("employee_name", employee.name)
add_condition("employee_name", employee.email)
add_condition("employee_name", employee_name)
add_condition("employee_name", user_id)
return conditions
def _resolve_employee(
self,
*,
ontology: OntologyParseResult,
context_json: dict[str, Any],
user_id: str | None,
) -> Employee | None:
normalized_user_id = str(user_id or "").strip()
if normalized_user_id:
stmt = select(Employee).where(func.lower(Employee.email) == normalized_user_id.lower()).limit(1)
employee = self.db.scalar(stmt)
if employee is not None:
return employee
employee_name = self._resolve_employee_name(
ontology=ontology,
context_json=context_json,
user_id=None,
)
if not employee_name:
return None
stmt = select(Employee).where(Employee.name == employee_name).limit(1)
return self.db.scalar(stmt)
@staticmethod
def _resolve_employee_name(
*,
ontology: OntologyParseResult,
context_json: dict[str, Any],
user_id: str | None,
fallback: str = "待补充",
) -> str:
review_form_values = context_json.get("review_form_values")
if isinstance(review_form_values, dict):
for key in ("reporter_name", "employee_name", "claimant_name"):
value = str(review_form_values.get(key) or "").strip()
if value:
return value
for item in ontology.entities:
if item.type == "employee" and item.value.strip():
return item.value.strip()
for key in ("name", "user_name", "employee_name"):
value = str(context_json.get(key) or "").strip()
if value:
return value
return str(user_id or fallback).strip() or fallback
@staticmethod
def _resolve_department_name(
*,
employee: Employee | None,
context_json: dict[str, Any],
fallback: str = "待补充",
) -> str:
if employee is not None and employee.organization_unit is not None:
return employee.organization_unit.name
request_context = context_json.get("request_context")
if isinstance(request_context, dict):
for key in ("department", "department_name", "deptName"):
value = str(request_context.get(key) or "").strip()
if value:
return value
for key in ("department_name", "department"):
value = str(context_json.get(key) or "").strip()
if value:
return value
return fallback
@staticmethod
def _resolve_project_code(entities: list[OntologyEntity]) -> str | None:
for item in entities:
if item.type == "project" and item.normalized_value.strip():
return item.normalized_value.strip()
return None
@staticmethod
def _resolve_expense_type(
entities: list[OntologyEntity],
*,
context_json: dict[str, Any],
) -> str | None:
review_form_values = context_json.get("review_form_values")
if isinstance(review_form_values, dict):
compact = str(
review_form_values.get("expense_type")
or review_form_values.get("reimbursement_type")
or ""
).replace(" ", "")
if compact:
if "招待" in compact or ("客户" in compact and any(word in compact for word in ("吃饭", "宴请", "请客", "用餐"))):
return "entertainment"
if any(word in compact for word in ("差旅", "出差", "机票", "行程")):
return "travel"
if any(word in compact for word in ("住宿", "酒店", "宾馆")):
return "hotel"
if any(word in compact for word in ("交通", "打车", "网约车", "出租车", "停车", "车费")):
return "transport"
if any(word in compact for word in ("餐费", "用餐", "午餐", "晚餐", "早餐", "伙食")):
return "meal"
if "会务" in compact:
return "meeting"
for item in entities:
if item.type == "expense_type":
normalized = item.normalized_value.strip()
if normalized:
return normalized
return None
@staticmethod
def _resolve_reason(
*,
message: str,
context_json: dict[str, Any],
allow_message_fallback: bool,
) -> str | None:
review_form_values = context_json.get("review_form_values")
if isinstance(review_form_values, dict):
for key in ("reason", "business_reason"):
value = str(review_form_values.get(key) or "").strip()
if value:
return value
request_context = context_json.get("request_context")
if (
isinstance(request_context, dict)
and str(context_json.get("entry_source") or "").strip() == "detail"
):
for key in ("reason", "title"):
value = str(request_context.get(key) or "").strip()
if value:
return value
if not allow_message_fallback:
return None
return str(message or "").strip()[:500] or None
@staticmethod
def _resolve_location(*, message: str, context_json: dict[str, Any]) -> str | None:
review_form_values = context_json.get("review_form_values")
if isinstance(review_form_values, dict):
for key in ("business_location", "location"):
value = str(review_form_values.get(key) or "").strip()
if value:
return value
request_context = context_json.get("request_context")
if (
isinstance(request_context, dict)
and str(context_json.get("entry_source") or "").strip() == "detail"
):
for key in ("city", "location"):
value = str(request_context.get(key) or "").strip()
if value:
return value
compact = str(message or "").replace(" ", "")
if "客户现场" in compact:
return "客户现场"
return None
@staticmethod
def _resolve_occurred_at(
ontology: OntologyParseResult,
*,
context_json: dict[str, Any],
) -> datetime | None:
review_form_values = context_json.get("review_form_values")
if isinstance(review_form_values, dict):
for key in ("occurred_date", "time_range", "business_time"):
value = str(review_form_values.get(key) or "").strip()
if not value:
continue
try:
parsed = date.fromisoformat(value)
return datetime(parsed.year, parsed.month, parsed.day, tzinfo=UTC)
except ValueError:
continue
start_date = ontology.time_range.start_date
if start_date:
try:
parsed = date.fromisoformat(start_date)
return datetime(parsed.year, parsed.month, parsed.day, tzinfo=UTC)
except ValueError:
pass
return None
@staticmethod
def _resolve_amount(
entities: list[OntologyEntity],
*,
context_json: dict[str, Any],
) -> Decimal | None:
review_form_values = context_json.get("review_form_values")
if isinstance(review_form_values, dict):
raw_value = str(review_form_values.get("amount") or "").strip()
if raw_value:
compact = raw_value.replace("", "").replace(",", "").strip()
try:
return Decimal(compact).quantize(Decimal("0.01"))
except (InvalidOperation, ValueError):
pass
for item in entities:
if item.type != "amount" or item.role == "threshold":
continue
try:
return Decimal(item.normalized_value).quantize(Decimal("0.01"))
except (InvalidOperation, ValueError):
continue
return None
@staticmethod
def _resolve_attachment_names(context_json: dict[str, Any]) -> list[str]:
names = context_json.get("attachment_names")
if not isinstance(names, list):
return []
return [str(name).strip() for name in names if str(name).strip()]
def _resolve_attachment_count(self, context_json: dict[str, Any]) -> int:
names = self._resolve_attachment_names(context_json)
if names:
return len(names)
try:
return max(0, int(context_json.get("attachment_count") or 0))
except (TypeError, ValueError):
return 0
@staticmethod
def _serialize_claim(claim: ExpenseClaim) -> dict[str, Any]:
return {
"id": claim.id,
"claim_no": claim.claim_no,
"employee_name": claim.employee_name,
"department_name": claim.department_name,
"project_code": claim.project_code,
"expense_type": claim.expense_type,
"reason": claim.reason,
"location": claim.location,
"amount": float(claim.amount),
"invoice_count": int(claim.invoice_count or 0),
"status": claim.status,
"approval_stage": claim.approval_stage,
"risk_flags_json": list(claim.risk_flags_json or []),
}
@staticmethod
def _normalize_optional_text(value: str | None, *, fallback: str = "", allow_empty: bool = False) -> str | None:
normalized = str(value or "").strip()
if normalized:
return normalized
if allow_empty:
return None
return fallback
@staticmethod
def _is_missing_value(value: Any) -> bool:
text = str(value or "").strip()
if not text:
return True
compact = text.replace(" ", "")
return compact in {"待补充", "暂无", "", "未知", "处理中"}
def _ensure_draft_claim(self, claim: ExpenseClaim) -> None:
if str(claim.status or "").strip().lower() != "draft":
raise ValueError("只有草稿状态的报销单才允许执行该操作。")
def _sync_claim_from_items(self, claim: ExpenseClaim) -> None:
if not claim.items:
claim.amount = Decimal("0.00")
claim.invoice_count = 0
return
ordered_items = sorted(
claim.items,
key=lambda item: (
item.item_date or date.max,
item.created_at or datetime.max.replace(tzinfo=UTC),
),
)
primary_item = ordered_items[0]
total_amount = sum((item.item_amount for item in ordered_items), Decimal("0.00"))
claim.amount = total_amount.quantize(Decimal("0.01"))
claim.invoice_count = sum(1 for item in ordered_items if str(item.invoice_id or "").strip())
claim.occurred_at = datetime(
primary_item.item_date.year,
primary_item.item_date.month,
primary_item.item_date.day,
tzinfo=UTC,
)
claim.expense_type = str(primary_item.item_type or claim.expense_type or "other").strip() or "other"
claim.reason = (
self._normalize_optional_text(primary_item.item_reason, fallback=claim.reason or "待补充") or "待补充"
)
claim.location = (
self._normalize_optional_text(primary_item.item_location, fallback=claim.location or "待补充")
or "待补充"
)
if str(claim.status or "").strip().lower() == "draft":
claim.approval_stage = "待提交"
def _validate_claim_for_submission(self, claim: ExpenseClaim) -> list[str]:
issues: list[str] = []
if self._is_missing_value(claim.employee_name):
issues.append("申请人未完善")
if self._is_missing_value(claim.department_name):
issues.append("所属部门未完善")
if self._is_missing_value(claim.expense_type):
issues.append("报销类型未完善")
if self._is_missing_value(claim.reason):
issues.append("报销事由未完善")
if self._is_missing_value(claim.location):
issues.append("业务地点未完善")
if claim.amount is None or claim.amount <= Decimal("0.00"):
issues.append("报销金额未完善")
if claim.occurred_at is None:
issues.append("发生时间未完善")
if not claim.items:
issues.append("费用明细不能为空")
for index, item in enumerate(claim.items, start=1):
prefix = f"费用明细第 {index}"
if item.item_date is None:
issues.append(f"{prefix}缺少日期")
if self._is_missing_value(item.item_type):
issues.append(f"{prefix}缺少费用项目")
if self._is_missing_value(item.item_reason):
issues.append(f"{prefix}缺少说明")
if self._is_missing_value(item.item_location):
issues.append(f"{prefix}缺少地点")
if item.item_amount is None or item.item_amount <= Decimal("0.00"):
issues.append(f"{prefix}缺少金额")
if self._is_missing_value(item.invoice_id):
issues.append(f"{prefix}缺少票据标识")
return issues
@staticmethod
def _has_privileged_claim_access(current_user: CurrentUserContext) -> bool:
if current_user.is_admin:
return True
return bool(set(current_user.role_codes) & PRIVILEGED_CLAIM_ROLE_CODES)
def _apply_claim_scope(self, stmt: Any, current_user: CurrentUserContext) -> Any:
if self._has_privileged_claim_access(current_user):
return stmt
conditions = []
username = str(current_user.username or "").strip()
name = str(current_user.name or "").strip()
if username:
conditions.append(ExpenseClaim.employee_id == username)
conditions.append(ExpenseClaim.employee_name == username)
if name:
conditions.append(ExpenseClaim.employee_name == name)
if not conditions:
return stmt.where(ExpenseClaim.id == "__no_visible_claim__")
return stmt.where(or_(*conditions))
def _ensure_ready(self) -> None:
AgentFoundationService(self.db).ensure_foundation_ready()