refactor(backend): update reimbursement and expense claims

- endpoints/reimbursements.py: update reimbursement API endpoint
- schemas/reimbursement.py: update reimbursement data schemas
- services/expense_claims.py: update expense claims service logic
This commit is contained in:
caoxiaozhu
2026-05-13 03:22:52 +00:00
parent 6147b690b2
commit cea8239370
3 changed files with 448 additions and 6 deletions

View File

@@ -5,13 +5,21 @@ from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from app.api.deps import get_db
from app.api.deps import CurrentUserContext, get_current_user, get_db
from app.schemas.common import ErrorResponse
from app.schemas.reimbursement import ReimbursementCreate, ReimbursementRead
from app.schemas.reimbursement import (
ExpenseClaimActionResponse,
ExpenseClaimItemUpdate,
ExpenseClaimRead,
ReimbursementCreate,
ReimbursementRead,
)
from app.services.expense_claims import ExpenseClaimService
from app.services.reimbursement import ReimbursementService
router = APIRouter()
DbSession = Annotated[Session, Depends(get_db)]
CurrentUser = Annotated[CurrentUserContext, Depends(get_current_user)]
@router.get(
@@ -35,6 +43,137 @@ def create_reimbursement(payload: ReimbursementCreate, db: DbSession) -> Reimbur
return ReimbursementService(db).create_reimbursement(payload)
@router.get(
"/claims",
response_model=list[ExpenseClaimRead],
summary="查询个人报销单列表",
description="返回当前登录用户可见的真实个人报销单据列表。",
)
def list_expense_claims(db: DbSession, current_user: CurrentUser) -> list[ExpenseClaimRead]:
return ExpenseClaimService(db).list_claims(current_user)
@router.get(
"/claims/{claim_id}",
response_model=ExpenseClaimRead,
summary="读取个人报销单详情",
description="根据报销单主键读取真实报销详情与费用明细。",
responses={
status.HTTP_404_NOT_FOUND: {
"model": ErrorResponse,
"description": "报销单不存在。",
}
},
)
def get_expense_claim(claim_id: str, db: DbSession, current_user: CurrentUser) -> ExpenseClaimRead:
claim = ExpenseClaimService(db).get_claim(claim_id, current_user)
if claim is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Claim not found")
return claim
@router.patch(
"/claims/{claim_id}/items/{item_id}",
response_model=ExpenseClaimRead,
summary="更新草稿费用明细",
description="更新草稿报销单中的单条费用明细。",
responses={
status.HTTP_404_NOT_FOUND: {
"model": ErrorResponse,
"description": "报销单或费用明细不存在。",
},
status.HTTP_400_BAD_REQUEST: {
"model": ErrorResponse,
"description": "草稿状态校验失败或字段校验失败。",
},
},
)
def update_expense_claim_item(
claim_id: str,
item_id: str,
payload: ExpenseClaimItemUpdate,
db: DbSession,
current_user: CurrentUser,
) -> ExpenseClaimRead:
service = ExpenseClaimService(db)
try:
claim = service.update_claim_item(
claim_id=claim_id,
item_id=item_id,
payload=payload,
current_user=current_user,
)
except LookupError as error:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(error)) from error
except ValueError as error:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(error)) from error
if claim is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Claim not found")
return claim
@router.post(
"/claims/{claim_id}/submit",
response_model=ExpenseClaimRead,
summary="提交个人报销草稿",
description="校验草稿信息完整性后,将报销单提交到审批流程。",
responses={
status.HTTP_404_NOT_FOUND: {
"model": ErrorResponse,
"description": "报销单不存在。",
},
status.HTTP_400_BAD_REQUEST: {
"model": ErrorResponse,
"description": "草稿信息不完整或状态不允许提交。",
},
},
)
def submit_expense_claim(claim_id: str, db: DbSession, current_user: CurrentUser) -> ExpenseClaimRead:
service = ExpenseClaimService(db)
try:
claim = service.submit_claim(claim_id, current_user)
except ValueError as error:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(error)) from error
if claim is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Claim not found")
return claim
@router.delete(
"/claims/{claim_id}",
response_model=ExpenseClaimActionResponse,
summary="删除个人报销草稿",
description="删除当前登录用户可见的草稿报销单。",
responses={
status.HTTP_404_NOT_FOUND: {
"model": ErrorResponse,
"description": "报销单不存在。",
},
status.HTTP_400_BAD_REQUEST: {
"model": ErrorResponse,
"description": "仅草稿状态允许删除。",
},
},
)
def delete_expense_claim(claim_id: str, db: DbSession, current_user: CurrentUser) -> ExpenseClaimActionResponse:
service = ExpenseClaimService(db)
try:
claim = service.delete_claim(claim_id, current_user)
except ValueError as error:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(error)) from error
if claim is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Claim not found")
return ExpenseClaimActionResponse(
message=f"{claim.claim_no} 草稿已删除。",
claim_id=claim.id,
status="deleted",
)
@router.get(
"/{request_id}",
response_model=ReimbursementRead,

View File

@@ -1,9 +1,10 @@
from __future__ import annotations
from datetime import datetime
from datetime import date, datetime
from decimal import Decimal
from typing import Any
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
class ReimbursementCreate(BaseModel):
@@ -28,3 +29,58 @@ class ReimbursementRead(BaseModel):
reason: str | None
created_at: datetime
updated_at: datetime
class ExpenseClaimItemRead(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
item_date: date
item_type: str
item_reason: str
item_location: str
item_amount: Decimal
invoice_id: str | None
created_at: datetime
updated_at: datetime
class ExpenseClaimItemUpdate(BaseModel):
item_date: date | None = None
item_type: str | None = None
item_reason: str | None = None
item_location: str | None = None
item_amount: Decimal | None = None
invoice_id: str | None = None
class ExpenseClaimRead(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
claim_no: str
employee_id: str | None
employee_name: str
department_id: str | None
department_name: str
project_code: str | None
expense_type: str
reason: str
location: str
amount: Decimal
currency: str
invoice_count: int
occurred_at: datetime
submitted_at: datetime | None
status: str
approval_stage: str | None
risk_flags_json: list[Any] = Field(default_factory=list)
created_at: datetime
updated_at: datetime
items: list[ExpenseClaimItemRead] = Field(default_factory=list)
class ExpenseClaimActionResponse(BaseModel):
message: str
claim_id: str
status: str | None = None

View File

@@ -4,12 +4,14 @@ from datetime import UTC, date, datetime
from decimal import Decimal, InvalidOperation
from typing import Any
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session, selectinload
from app.api.deps import CurrentUserContext
from app.models.employee import Employee
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
from app.schemas.ontology import OntologyEntity, OntologyParseResult
from app.schemas.reimbursement import ExpenseClaimItemUpdate
from app.services.audit import AuditLogService
from app.services.agent_foundation import AgentFoundationService
@@ -22,12 +24,139 @@ EXPENSE_TYPE_LABELS = {
"entertainment": "招待",
}
PRIVILEGED_CLAIM_ROLE_CODES = {"manager", "finance", "approver", "auditor", "executive"}
class ExpenseClaimService:
def __init__(self, db: Session) -> None:
self.db = db
self.audit_service = AuditLogService(db)
def list_claims(self, current_user: CurrentUserContext) -> list[ExpenseClaim]:
stmt = (
select(ExpenseClaim)
.options(selectinload(ExpenseClaim.items))
.order_by(ExpenseClaim.created_at.desc(), ExpenseClaim.occurred_at.desc())
)
stmt = self._apply_claim_scope(stmt, current_user)
return list(self.db.scalars(stmt).all())
def get_claim(self, claim_id: str, current_user: CurrentUserContext) -> ExpenseClaim | None:
stmt = (
select(ExpenseClaim)
.options(selectinload(ExpenseClaim.items))
.where(ExpenseClaim.id == claim_id)
)
stmt = self._apply_claim_scope(stmt, current_user)
return self.db.scalar(stmt)
def update_claim_item(
self,
*,
claim_id: str,
item_id: str,
payload: ExpenseClaimItemUpdate,
current_user: CurrentUserContext,
) -> ExpenseClaim | None:
claim = self.get_claim(claim_id, current_user)
if claim is None:
return None
self._ensure_draft_claim(claim)
item = next((entry for entry in claim.items if entry.id == item_id), None)
if item is None:
raise LookupError("Item not found")
before_json = self._serialize_claim(claim)
if payload.item_date is not None:
item.item_date = payload.item_date
if payload.item_type is not None:
item.item_type = self._normalize_optional_text(payload.item_type, fallback=item.item_type) or item.item_type
if payload.item_reason is not None:
item.item_reason = (
self._normalize_optional_text(payload.item_reason, fallback=item.item_reason) or item.item_reason
)
if payload.item_location is not None:
item.item_location = (
self._normalize_optional_text(payload.item_location, fallback=item.item_location) or item.item_location
)
if payload.item_amount is not None:
amount = payload.item_amount.quantize(Decimal("0.01"))
if amount <= Decimal("0.00"):
raise ValueError("费用金额必须大于 0。")
item.item_amount = amount
if payload.invoice_id is not None:
item.invoice_id = self._normalize_optional_text(payload.invoice_id, allow_empty=True)
self._sync_claim_from_items(claim)
self.db.commit()
self.db.refresh(claim)
self.audit_service.log_action(
actor=current_user.name or current_user.username,
action="expense_claim.item_update",
resource_type="expense_claim",
resource_id=claim.id,
before_json=before_json,
after_json=self._serialize_claim(claim),
)
return claim
def submit_claim(self, claim_id: str, current_user: CurrentUserContext) -> ExpenseClaim | None:
claim = self.get_claim(claim_id, current_user)
if claim is None:
return None
self._ensure_draft_claim(claim)
self._sync_claim_from_items(claim)
missing_fields = self._validate_claim_for_submission(claim)
if missing_fields:
raise ValueError("提交前请先补全信息:" + "".join(missing_fields))
before_json = self._serialize_claim(claim)
claim.status = "submitted"
claim.approval_stage = "审批流转"
claim.submitted_at = datetime.now(UTC)
self.db.commit()
self.db.refresh(claim)
self.audit_service.log_action(
actor=current_user.name or current_user.username,
action="expense_claim.submit",
resource_type="expense_claim",
resource_id=claim.id,
before_json=before_json,
after_json=self._serialize_claim(claim),
)
return claim
def delete_claim(self, claim_id: str, current_user: CurrentUserContext) -> ExpenseClaim | None:
claim = self.get_claim(claim_id, current_user)
if claim is None:
return None
self._ensure_draft_claim(claim)
before_json = self._serialize_claim(claim)
resource_id = claim.id
self.db.delete(claim)
self.db.commit()
self.audit_service.log_action(
actor=current_user.name or current_user.username,
action="expense_claim.delete",
resource_type="expense_claim",
resource_id=resource_id,
before_json=before_json,
after_json=None,
)
return claim
def upsert_draft_from_ontology(
self,
*,
@@ -464,5 +593,123 @@ class ExpenseClaimService:
"risk_flags_json": list(claim.risk_flags_json or []),
}
@staticmethod
def _normalize_optional_text(value: str | None, *, fallback: str = "", allow_empty: bool = False) -> str | None:
normalized = str(value or "").strip()
if normalized:
return normalized
if allow_empty:
return None
return fallback
@staticmethod
def _is_missing_value(value: Any) -> bool:
text = str(value or "").strip()
if not text:
return True
compact = text.replace(" ", "")
return compact in {"待补充", "暂无", "", "未知", "处理中"}
def _ensure_draft_claim(self, claim: ExpenseClaim) -> None:
if str(claim.status or "").strip().lower() != "draft":
raise ValueError("只有草稿状态的报销单才允许执行该操作。")
def _sync_claim_from_items(self, claim: ExpenseClaim) -> None:
if not claim.items:
claim.amount = Decimal("0.00")
claim.invoice_count = 0
return
ordered_items = sorted(
claim.items,
key=lambda item: (
item.item_date or date.max,
item.created_at or datetime.max.replace(tzinfo=UTC),
),
)
primary_item = ordered_items[0]
total_amount = sum((item.item_amount for item in ordered_items), Decimal("0.00"))
claim.amount = total_amount.quantize(Decimal("0.01"))
claim.invoice_count = sum(1 for item in ordered_items if str(item.invoice_id or "").strip())
claim.occurred_at = datetime(
primary_item.item_date.year,
primary_item.item_date.month,
primary_item.item_date.day,
tzinfo=UTC,
)
claim.expense_type = str(primary_item.item_type or claim.expense_type or "other").strip() or "other"
claim.reason = (
self._normalize_optional_text(primary_item.item_reason, fallback=claim.reason or "待补充") or "待补充"
)
claim.location = (
self._normalize_optional_text(primary_item.item_location, fallback=claim.location or "待补充")
or "待补充"
)
if str(claim.status or "").strip().lower() == "draft":
claim.approval_stage = "待提交"
def _validate_claim_for_submission(self, claim: ExpenseClaim) -> list[str]:
issues: list[str] = []
if self._is_missing_value(claim.employee_name):
issues.append("申请人未完善")
if self._is_missing_value(claim.department_name):
issues.append("所属部门未完善")
if self._is_missing_value(claim.expense_type):
issues.append("报销类型未完善")
if self._is_missing_value(claim.reason):
issues.append("报销事由未完善")
if self._is_missing_value(claim.location):
issues.append("业务地点未完善")
if claim.amount is None or claim.amount <= Decimal("0.00"):
issues.append("报销金额未完善")
if claim.occurred_at is None:
issues.append("发生时间未完善")
if not claim.items:
issues.append("费用明细不能为空")
for index, item in enumerate(claim.items, start=1):
prefix = f"费用明细第 {index}"
if item.item_date is None:
issues.append(f"{prefix}缺少日期")
if self._is_missing_value(item.item_type):
issues.append(f"{prefix}缺少费用项目")
if self._is_missing_value(item.item_reason):
issues.append(f"{prefix}缺少说明")
if self._is_missing_value(item.item_location):
issues.append(f"{prefix}缺少地点")
if item.item_amount is None or item.item_amount <= Decimal("0.00"):
issues.append(f"{prefix}缺少金额")
if self._is_missing_value(item.invoice_id):
issues.append(f"{prefix}缺少票据标识")
return issues
@staticmethod
def _has_privileged_claim_access(current_user: CurrentUserContext) -> bool:
if current_user.is_admin:
return True
return bool(set(current_user.role_codes) & PRIVILEGED_CLAIM_ROLE_CODES)
def _apply_claim_scope(self, stmt: Any, current_user: CurrentUserContext) -> Any:
if self._has_privileged_claim_access(current_user):
return stmt
conditions = []
username = str(current_user.username or "").strip()
name = str(current_user.name or "").strip()
if username:
conditions.append(ExpenseClaim.employee_id == username)
conditions.append(ExpenseClaim.employee_name == username)
if name:
conditions.append(ExpenseClaim.employee_name == name)
if not conditions:
return stmt.where(ExpenseClaim.id == "__no_visible_claim__")
return stmt.where(or_(*conditions))
def _ensure_ready(self) -> None:
AgentFoundationService(self.db).ensure_foundation_ready()