from __future__ import annotations

from datetime import UTC, date, datetime
from decimal import Decimal, InvalidOperation
from typing import Any

from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session, selectinload

from app.api.deps import CurrentUserContext
from app.models.employee import Employee
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
from app.schemas.ontology import OntologyEntity, OntologyParseResult
from app.schemas.reimbursement import ExpenseClaimItemUpdate
from app.services.audit import AuditLogService
from app.services.agent_foundation import AgentFoundationService

# Internal expense-type code -> Chinese display label.
EXPENSE_TYPE_LABELS = {
    "travel": "差旅",
    "hotel": "住宿",
    "transport": "交通",
    "meal": "餐费",
    "meeting": "会务",
    "entertainment": "招待",
}

# Users holding any of these role codes (or admins) see every claim; see _apply_claim_scope.
PRIVILEGED_CLAIM_ROLE_CODES = {"manager", "finance", "approver", "auditor", "executive"}

# A new draft is refused once the owner already has this many claims in "draft" status.
MAX_DRAFT_CLAIMS_PER_USER = 3


def _to_money(raw: str) -> Decimal | None:
    """Parse ``raw`` as a money amount rounded to two decimal places.

    Returns None when the text is not a valid decimal or the result is not
    finite. ``Decimal("NaN").quantize(...)`` returns NaN without raising, so
    non-finite values are rejected explicitly instead of becoming a claim amount.
    """
    try:
        value = Decimal(raw).quantize(Decimal("0.01"))
    except (InvalidOperation, ValueError):
        return None
    return value if value.is_finite() else None


class ExpenseClaimService:
    """List, edit, submit and delete expense claims, and maintain agent-created drafts.

    Every mutating operation commits the session and then writes an audit-log entry
    through ``AuditLogService``.
    """

    def __init__(self, db: Session) -> None:
        self.db = db
        self.audit_service = AuditLogService(db)

    def list_claims(self, current_user: CurrentUserContext) -> list[ExpenseClaim]:
        """Return the claims visible to ``current_user``, newest first, with items eager-loaded."""
        stmt = (
            select(ExpenseClaim)
            .options(selectinload(ExpenseClaim.items))
            .order_by(ExpenseClaim.created_at.desc(), ExpenseClaim.occurred_at.desc())
        )
        stmt = self._apply_claim_scope(stmt, current_user)
        return list(self.db.scalars(stmt).all())

    def get_claim(self, claim_id: str, current_user: CurrentUserContext) -> ExpenseClaim | None:
        """Return one claim (with items) if it exists and is visible to ``current_user``, else None."""
        stmt = (
            select(ExpenseClaim)
            .options(selectinload(ExpenseClaim.items))
            .where(ExpenseClaim.id == claim_id)
        )
        stmt = self._apply_claim_scope(stmt, current_user)
        return self.db.scalar(stmt)

    def update_claim_item(
        self,
        *,
        claim_id: str,
        item_id: str,
        payload: ExpenseClaimItemUpdate,
        current_user: CurrentUserContext,
    ) -> ExpenseClaim | None:
        """Apply the non-None fields of ``payload`` to one item of a draft claim.

        Claim-level totals are re-derived from the items afterwards.

        Returns:
            The refreshed claim, or None when the claim is not found or not visible.

        Raises:
            ValueError: the claim is not a draft, or the new amount is not positive.
            LookupError: the claim has no item with ``item_id``.
        """
        claim = self.get_claim(claim_id, current_user)
        if claim is None:
            return None
        self._ensure_draft_claim(claim)
        item = next((entry for entry in claim.items if entry.id == item_id), None)
        if item is None:
            raise LookupError("Item not found")
        before_json = self._serialize_claim(claim)

        # Validate the amount before touching any field so a rejected request
        # leaves the ORM objects unmodified.
        new_amount: Decimal | None = None
        if payload.item_amount is not None:
            new_amount = payload.item_amount.quantize(Decimal("0.01"))
            if new_amount <= Decimal("0.00"):
                raise ValueError("费用金额必须大于 0。")

        if payload.item_date is not None:
            item.item_date = payload.item_date
        # Blank text inputs fall back to the current value rather than clearing it.
        if payload.item_type is not None:
            item.item_type = (
                self._normalize_optional_text(payload.item_type, fallback=item.item_type) or item.item_type
            )
        if payload.item_reason is not None:
            item.item_reason = (
                self._normalize_optional_text(payload.item_reason, fallback=item.item_reason)
                or item.item_reason
            )
        if payload.item_location is not None:
            item.item_location = (
                self._normalize_optional_text(payload.item_location, fallback=item.item_location)
                or item.item_location
            )
        if new_amount is not None:
            item.item_amount = new_amount
        if payload.invoice_id is not None:
            # A blank invoice id clears the field (allow_empty=True -> None).
            item.invoice_id = self._normalize_optional_text(payload.invoice_id, allow_empty=True)

        self._sync_claim_from_items(claim)
        self.db.commit()
        self.db.refresh(claim)
        self.audit_service.log_action(
            actor=current_user.name or current_user.username,
            action="expense_claim.item_update",
            resource_type="expense_claim",
            resource_id=claim.id,
            before_json=before_json,
            after_json=self._serialize_claim(claim),
        )
        return claim

    def submit_claim(self, claim_id: str, current_user: CurrentUserContext) -> ExpenseClaim | None:
        """Validate a draft claim and move it to ``submitted`` / ``AI验审``.

        Returns:
            The refreshed claim, or None when the claim is not found or not visible.

        Raises:
            ValueError: the claim is not a draft, or required information is missing.
                The missing-field descriptions are joined into the error message.
        """
        claim = self.get_claim(claim_id, current_user)
        if claim is None:
            return None
        self._ensure_draft_claim(claim)
        self._sync_claim_from_items(claim)
        missing_fields = self._validate_claim_for_submission(claim)
        if missing_fields:
            raise ValueError("提交前请先补全信息:" + ";".join(missing_fields))
        before_json = self._serialize_claim(claim)
        claim.status = "submitted"
        claim.approval_stage = "AI验审"
        claim.submitted_at = datetime.now(UTC)
        self.db.commit()
        self.db.refresh(claim)
        self.audit_service.log_action(
            actor=current_user.name or current_user.username,
            action="expense_claim.submit",
            resource_type="expense_claim",
            resource_id=claim.id,
            before_json=before_json,
            after_json=self._serialize_claim(claim),
        )
        return claim

    def delete_claim(self, claim_id: str, current_user: CurrentUserContext) -> ExpenseClaim | None:
        """Delete a draft claim and audit the deletion.

        Returns:
            The deleted ORM instance (already deleted and committed), or None when
            the claim is not found or not visible.

        Raises:
            ValueError: the claim is not a draft.
        """
        claim = self.get_claim(claim_id, current_user)
        if claim is None:
            return None
        self._ensure_draft_claim(claim)
        before_json = self._serialize_claim(claim)
        # Capture the id before deletion; the instance is no longer usable after commit.
        resource_id = claim.id
        self.db.delete(claim)
        self.db.commit()
        self.audit_service.log_action(
            actor=current_user.name or current_user.username,
            action="expense_claim.delete",
            resource_type="expense_claim",
            resource_id=resource_id,
            before_json=before_json,
            after_json=None,
        )
        return claim

    def upsert_draft_from_ontology(
        self,
        *,
        run_id: str,
        user_id: str | None,
        message: str,
        ontology: OntologyParseResult,
        context_json: dict[str, Any],
    ) -> dict[str, Any]:
        """Create or update a draft claim from a parsed agent message.

        The target claim comes from ``context_json["draft_claim_id"]`` or from an
        ``expense_claim`` entity in ``ontology``. When neither matches, a new draft
        is created, subject to the ``MAX_DRAFT_CLAIMS_PER_USER`` limit. Values not
        resolvable from the message or context keep the existing claim's values,
        or fall back to placeholders for a new claim.

        Returns:
            A response dict. It carries ``"status": "blocked"`` and
            ``"draft_limit_reached": True`` when the draft limit is hit. Otherwise
            it carries ``"draft_only": True`` plus the claim's id, number, status,
            amount and invoice count.
        """
        self._ensure_ready()
        claim = self._find_target_claim(ontology=ontology, context_json=context_json)
        is_new_claim = claim is None
        before_json = self._serialize_claim(claim) if claim is not None else None
        employee = self._resolve_employee(
            ontology=ontology,
            context_json=context_json,
            user_id=user_id,
        )
        draft_owner_name = (
            employee.name
            if employee is not None
            else self._resolve_employee_name(
                ontology=ontology,
                context_json=context_json,
                user_id=user_id,
            )
        )
        if is_new_claim:
            existing_draft_count = self._count_draft_claims_for_owner(
                employee=employee,
                employee_name=draft_owner_name,
                user_id=user_id,
            )
            if existing_draft_count >= MAX_DRAFT_CLAIMS_PER_USER:
                return {
                    "message": (
                        f"你当前已保存 {MAX_DRAFT_CLAIMS_PER_USER} 个草稿,请先完成已保存的草稿,"
                        "才能再次新建草稿。"
                    ),
                    "draft_limit_reached": True,
                    "draft_only": False,
                    "status": "blocked",
                    "draft_count": existing_draft_count,
                    "max_draft_count": MAX_DRAFT_CLAIMS_PER_USER,
                }

        amount = self._resolve_amount(ontology.entities, context_json=context_json)
        occurred_at = self._resolve_occurred_at(ontology, context_json=context_json)
        expense_type = self._resolve_expense_type(ontology.entities, context_json=context_json)
        location = self._resolve_location(message=message, context_json=context_json)
        reason = self._resolve_reason(
            message=message,
            context_json=context_json,
            # Only a brand-new draft may use the raw message text as its reason.
            allow_message_fallback=is_new_claim,
        )
        attachment_count = self._resolve_attachment_count(context_json)

        # Resolved value -> existing claim value -> placeholder for a new claim.
        final_amount = amount if amount is not None else (claim.amount if claim is not None else Decimal("0.00"))
        final_occurred_at = (
            occurred_at
            if occurred_at is not None
            else (claim.occurred_at if claim is not None else datetime.now(UTC))
        )
        final_expense_type = expense_type or (claim.expense_type if claim is not None else "other")
        final_location = location or (claim.location if claim is not None else "待补充")
        final_reason = reason or (claim.reason if claim is not None else "待补充")
        final_attachment_count = (
            attachment_count
            if attachment_count > 0
            else int(claim.invoice_count or 0) if claim is not None else 0
        )
        final_risk_flags = list(ontology.risk_flags) or (
            list(claim.risk_flags_json or []) if claim is not None else []
        )

        if claim is None:
            claim = ExpenseClaim(
                claim_no=self._generate_claim_no(final_occurred_at),
                employee_id=employee.id if employee is not None else None,
                employee_name=draft_owner_name,
                department_id=employee.organization_unit_id if employee is not None else None,
                department_name=self._resolve_department_name(
                    employee=employee,
                    context_json=context_json,
                ),
                project_code=self._resolve_project_code(ontology.entities),
                expense_type=final_expense_type,
                reason=final_reason,
                location=final_location,
                amount=final_amount,
                currency="CNY",
                invoice_count=final_attachment_count,
                occurred_at=final_occurred_at,
                status="draft",
                approval_stage="待提交",
                risk_flags_json=final_risk_flags,
            )
            self.db.add(claim)
        else:
            claim.employee_id = employee.id if employee is not None else claim.employee_id
            claim.employee_name = (
                employee.name
                if employee is not None
                else self._resolve_employee_name(
                    ontology=ontology,
                    context_json=context_json,
                    user_id=user_id,
                    fallback=claim.employee_name,
                )
            )
            claim.department_id = employee.organization_unit_id if employee is not None else claim.department_id
            claim.department_name = self._resolve_department_name(
                employee=employee,
                context_json=context_json,
                fallback=claim.department_name,
            )
            claim.project_code = self._resolve_project_code(ontology.entities) or claim.project_code
            claim.expense_type = final_expense_type
            claim.reason = final_reason
            claim.location = final_location
            claim.amount = final_amount
            claim.invoice_count = final_attachment_count
            claim.occurred_at = final_occurred_at
            # NOTE(review): this resets ANY targeted claim (including one already
            # submitted) back to "draft", whereas every other mutation path calls
            # _ensure_draft_claim first — confirm this is intended.
            claim.status = "draft"
            claim.approval_stage = "待提交"
            claim.risk_flags_json = final_risk_flags

        # Flush so a new claim has its primary key before the item references it.
        self.db.flush()
        self._upsert_primary_item(
            claim=claim,
            occurred_at=final_occurred_at,
            expense_type=final_expense_type,
            amount=final_amount,
            reason=final_reason,
            location=final_location,
            attachment_names=self._resolve_attachment_names(context_json),
        )
        self.db.commit()
        self.db.refresh(claim)
        self.audit_service.log_action(
            actor=user_id or claim.employee_name or "anonymous",
            action="expense_claim.draft_upsert",
            resource_type="expense_claim",
            resource_id=claim.id,
            before_json=before_json,
            after_json=self._serialize_claim(claim),
            request_id=run_id,
        )
        return {
            "message": (
                f"已{'创建' if is_new_claim else '更新'}报销草稿 {claim.claim_no},当前状态为 draft。"
                "你可以继续补充费用明细、客户单位和票据附件。"
            ),
            "draft_only": True,
            "claim_id": claim.id,
            "claim_no": claim.claim_no,
            "status": claim.status,
            "amount": float(claim.amount),
            "invoice_count": int(claim.invoice_count or 0),
        }

    def _find_target_claim(
        self,
        *,
        ontology: OntologyParseResult,
        context_json: dict[str, Any],
    ) -> ExpenseClaim | None:
        """Locate the claim an agent message refers to.

        An explicit ``draft_claim_id`` in the context wins. Otherwise the first
        claim whose number matches an ``expense_claim`` entity is used.
        Returns None when nothing matches.
        """
        # NOTE(review): neither lookup is restricted to the requesting user's
        # claims or to draft status — confirm callers enforce ownership.
        draft_claim_id = str(context_json.get("draft_claim_id") or "").strip()
        if draft_claim_id:
            return self.db.get(ExpenseClaim, draft_claim_id)
        claim_codes = [
            item.normalized_value
            for item in ontology.entities
            if item.type == "expense_claim" and item.normalized_value
        ]
        if not claim_codes:
            return None
        stmt = select(ExpenseClaim).where(ExpenseClaim.claim_no.in_(claim_codes)).limit(1)
        return self.db.scalar(stmt)

    def _upsert_primary_item(
        self,
        *,
        claim: ExpenseClaim,
        occurred_at: datetime,
        expense_type: str,
        amount: Decimal,
        reason: str,
        location: str,
        attachment_names: list[str],
    ) -> None:
        """Create the claim's first item, or overwrite it with the given values.

        The invoice id is the first attachment name. It keeps its current value
        when no attachments are supplied.
        """
        # NOTE(review): items[0] is relationship order, not the earliest-dated item
        # that _sync_claim_from_items treats as primary, and ``amount`` here is the
        # claim-level total — for multi-item claims this may double count. Confirm.
        item = claim.items[0] if claim.items else None
        if item is None:
            item = ExpenseClaimItem(
                claim_id=claim.id,
                item_date=occurred_at.date(),
                item_type=expense_type,
                item_reason=reason,
                item_location=location,
                item_amount=amount,
                invoice_id=attachment_names[0] if attachment_names else None,
            )
            claim.items.append(item)
            self.db.add(item)
            return
        item.item_date = occurred_at.date()
        item.item_type = expense_type
        item.item_reason = reason
        item.item_location = location
        item.item_amount = amount
        item.invoice_id = attachment_names[0] if attachment_names else item.invoice_id

    def _generate_claim_no(self, occurred_at: datetime) -> str:
        """Return the next claim number for the month of ``occurred_at``.

        Numbers have the form ``EXP-YYYYMM-NNN``. The sequence continues from the
        highest numeric suffix already used that month. Counting rows instead
        would re-issue an existing number after a claim is removed by
        ``delete_claim``.
        """
        month_code = occurred_at.strftime("%Y%m")
        prefix = f"EXP-{month_code}-"
        existing_numbers = self.db.scalars(
            select(ExpenseClaim.claim_no).where(ExpenseClaim.claim_no.like(f"{prefix}%"))
        ).all()
        highest = 0
        for claim_no in existing_numbers:
            suffix = str(claim_no or "")[len(prefix):]
            # Ignore numbers whose suffix is not purely numeric.
            if suffix.isdecimal():
                highest = max(highest, int(suffix))
        return f"{prefix}{highest + 1:03d}"

    def _count_draft_claims_for_owner(
        self,
        *,
        employee: Employee | None,
        employee_name: str,
        user_id: str | None,
    ) -> int:
        """Count claims in "draft" status matching any of the owner's identifiers (0 if none are known)."""
        owner_filters = self._build_draft_owner_filters(
            employee=employee,
            employee_name=employee_name,
            user_id=user_id,
        )
        if not owner_filters:
            return 0
        stmt = (
            select(func.count())
            .select_from(ExpenseClaim)
            .where(ExpenseClaim.status == "draft")
            .where(or_(*owner_filters))
        )
        return int(self.db.scalar(stmt) or 0)

    @staticmethod
    def _build_draft_owner_filters(
        *,
        employee: Employee | None,
        employee_name: str,
        user_id: str | None,
    ) -> list[Any]:
        """Build OR-able SQL conditions that identify claims owned by this person.

        Blank values and the "待补充" placeholder are skipped. Duplicates are
        dropped, compared case-insensitively per field.
        """
        conditions: list[Any] = []
        seen: set[tuple[str, str]] = set()

        def add_condition(field_name: str, value: str | None) -> None:
            normalized = str(value or "").strip()
            if not normalized or normalized == "待补充":
                return
            marker = (field_name, normalized.lower())
            if marker in seen:
                return
            seen.add(marker)
            if field_name == "employee_id":
                conditions.append(ExpenseClaim.employee_id == normalized)
                return
            conditions.append(ExpenseClaim.employee_name == normalized)

        if employee is not None:
            add_condition("employee_id", employee.id)
            add_condition("employee_name", employee.name)
            add_condition("employee_name", employee.email)
        add_condition("employee_name", employee_name)
        add_condition("employee_name", user_id)
        return conditions

    def _resolve_employee(
        self,
        *,
        ontology: OntologyParseResult,
        context_json: dict[str, Any],
        user_id: str | None,
    ) -> Employee | None:
        """Find the Employee for this request.

        Tries a case-insensitive email match on ``user_id`` first, then an exact
        match on a name resolved from the form, ontology or context.
        """
        normalized_user_id = str(user_id or "").strip()
        if normalized_user_id:
            stmt = select(Employee).where(func.lower(Employee.email) == normalized_user_id.lower()).limit(1)
            employee = self.db.scalar(stmt)
            if employee is not None:
                return employee
        employee_name = self._resolve_employee_name(
            ontology=ontology,
            context_json=context_json,
            user_id=None,
        )
        # The resolver falls back to the "待补充" placeholder, never to "", so the
        # placeholder must be treated as "no name" explicitly.
        if not employee_name or employee_name == "待补充":
            return None
        stmt = select(Employee).where(Employee.name == employee_name).limit(1)
        return self.db.scalar(stmt)

    @staticmethod
    def _resolve_employee_name(
        *,
        ontology: OntologyParseResult,
        context_json: dict[str, Any],
        user_id: str | None,
        fallback: str = "待补充",
    ) -> str:
        """Return the claimant name, never an empty string.

        Lookup order: review form values, then ``employee`` ontology entities,
        then top-level context keys, then ``user_id``, then ``fallback``.
        """
        review_form_values = context_json.get("review_form_values")
        if isinstance(review_form_values, dict):
            for key in ("reporter_name", "employee_name", "claimant_name"):
                value = str(review_form_values.get(key) or "").strip()
                if value:
                    return value
        for item in ontology.entities:
            if item.type == "employee" and item.value.strip():
                return item.value.strip()
        for key in ("name", "user_name", "employee_name"):
            value = str(context_json.get(key) or "").strip()
            if value:
                return value
        return str(user_id or fallback).strip() or fallback

    @staticmethod
    def _resolve_department_name(
        *,
        employee: Employee | None,
        context_json: dict[str, Any],
        fallback: str = "待补充",
    ) -> str:
        """Return the department name.

        Lookup order: the employee's organization unit, then ``request_context``,
        then top-level context keys, then ``fallback``.
        """
        if employee is not None and employee.organization_unit is not None:
            return employee.organization_unit.name
        request_context = context_json.get("request_context")
        if isinstance(request_context, dict):
            for key in ("department", "department_name", "deptName"):
                value = str(request_context.get(key) or "").strip()
                if value:
                    return value
        for key in ("department_name", "department"):
            value = str(context_json.get(key) or "").strip()
            if value:
                return value
        return fallback

    @staticmethod
    def _resolve_project_code(entities: list[OntologyEntity]) -> str | None:
        """Return the first non-blank ``project`` entity value, or None."""
        for item in entities:
            if item.type == "project" and item.normalized_value.strip():
                return item.normalized_value.strip()
        return None

    @staticmethod
    def _resolve_expense_type(
        entities: list[OntologyEntity],
        *,
        context_json: dict[str, Any],
    ) -> str | None:
        """Map the review-form expense type to an internal code, or use an ``expense_type`` entity.

        Form text is matched by Chinese keywords in the order below. Form text
        that matches no keyword falls through to the ontology entities.
        """
        review_form_values = context_json.get("review_form_values")
        if isinstance(review_form_values, dict):
            compact = str(
                review_form_values.get("expense_type") or review_form_values.get("reimbursement_type") or ""
            ).replace(" ", "")
            if compact:
                if "招待" in compact or (
                    "客户" in compact and any(word in compact for word in ("吃饭", "宴请", "请客", "用餐"))
                ):
                    return "entertainment"
                if any(word in compact for word in ("差旅", "出差", "机票", "行程")):
                    return "travel"
                if any(word in compact for word in ("住宿", "酒店", "宾馆")):
                    return "hotel"
                if any(word in compact for word in ("交通", "打车", "网约车", "出租车", "停车", "车费")):
                    return "transport"
                if any(word in compact for word in ("餐费", "用餐", "午餐", "晚餐", "早餐", "伙食")):
                    return "meal"
                if "会务" in compact:
                    return "meeting"
        for item in entities:
            if item.type == "expense_type":
                normalized = item.normalized_value.strip()
                if normalized:
                    return normalized
        return None

    @staticmethod
    def _resolve_reason(
        *,
        message: str,
        context_json: dict[str, Any],
        allow_message_fallback: bool,
    ) -> str | None:
        """Return the business reason.

        Lookup order: the review form, then the request context (only when
        ``entry_source == "detail"``), then, if allowed, the message text
        truncated to 500 characters.
        """
        review_form_values = context_json.get("review_form_values")
        if isinstance(review_form_values, dict):
            for key in ("reason", "business_reason"):
                value = str(review_form_values.get(key) or "").strip()
                if value:
                    return value
        request_context = context_json.get("request_context")
        if (
            isinstance(request_context, dict)
            and str(context_json.get("entry_source") or "").strip() == "detail"
        ):
            for key in ("reason", "title"):
                value = str(request_context.get(key) or "").strip()
                if value:
                    return value
        if not allow_message_fallback:
            return None
        return str(message or "").strip()[:500] or None

    @staticmethod
    def _resolve_location(*, message: str, context_json: dict[str, Any]) -> str | None:
        """Return the business location.

        Lookup order: the review form, then the request context (only when
        ``entry_source == "detail"``), then "客户现场" if the message mentions it.
        """
        review_form_values = context_json.get("review_form_values")
        if isinstance(review_form_values, dict):
            for key in ("business_location", "location"):
                value = str(review_form_values.get(key) or "").strip()
                if value:
                    return value
        request_context = context_json.get("request_context")
        if (
            isinstance(request_context, dict)
            and str(context_json.get("entry_source") or "").strip() == "detail"
        ):
            for key in ("city", "location"):
                value = str(request_context.get(key) or "").strip()
                if value:
                    return value
        compact = str(message or "").replace(" ", "")
        if "客户现场" in compact:
            return "客户现场"
        return None

    @staticmethod
    def _resolve_occurred_at(
        ontology: OntologyParseResult,
        *,
        context_json: dict[str, Any],
    ) -> datetime | None:
        """Return the occurrence date as UTC midnight.

        Uses the first ISO date (YYYY-MM-DD) found in the review form, then the
        ontology's ``time_range.start_date``. Unparsable values are skipped.
        """
        review_form_values = context_json.get("review_form_values")
        if isinstance(review_form_values, dict):
            for key in ("occurred_date", "time_range", "business_time"):
                value = str(review_form_values.get(key) or "").strip()
                if not value:
                    continue
                try:
                    parsed = date.fromisoformat(value)
                    return datetime(parsed.year, parsed.month, parsed.day, tzinfo=UTC)
                except ValueError:
                    continue
        start_date = ontology.time_range.start_date
        if start_date:
            try:
                parsed = date.fromisoformat(start_date)
                return datetime(parsed.year, parsed.month, parsed.day, tzinfo=UTC)
            except ValueError:
                pass
        return None

    @staticmethod
    def _resolve_amount(
        entities: list[OntologyEntity],
        *,
        context_json: dict[str, Any],
    ) -> Decimal | None:
        """Return the claim amount rounded to two decimals, or None.

        The review-form ``amount`` is used first, with "元" and "," stripped.
        Otherwise the first parsable ``amount`` entity that is not a threshold is
        used. Unparsable and non-finite values are skipped.
        """
        review_form_values = context_json.get("review_form_values")
        if isinstance(review_form_values, dict):
            raw_value = str(review_form_values.get("amount") or "").strip()
            if raw_value:
                compact = raw_value.replace("元", "").replace(",", "").strip()
                form_amount = _to_money(compact)
                if form_amount is not None:
                    return form_amount
        for item in entities:
            if item.type != "amount" or item.role == "threshold":
                continue
            entity_amount = _to_money(item.normalized_value)
            if entity_amount is not None:
                return entity_amount
        return None

    @staticmethod
    def _resolve_attachment_names(context_json: dict[str, Any]) -> list[str]:
        """Return the non-blank, stripped entries of ``context_json["attachment_names"]`` ([] if not a list)."""
        names = context_json.get("attachment_names")
        if not isinstance(names, list):
            return []
        return [str(name).strip() for name in names if str(name).strip()]

    def _resolve_attachment_count(self, context_json: dict[str, Any]) -> int:
        """Return the number of attachment names, else a non-negative ``attachment_count``, else 0."""
        names = self._resolve_attachment_names(context_json)
        if names:
            return len(names)
        try:
            return max(0, int(context_json.get("attachment_count") or 0))
        except (TypeError, ValueError):
            return 0

    @staticmethod
    def _serialize_claim(claim: ExpenseClaim) -> dict[str, Any]:
        """Return a plain-dict snapshot of the claim header, used for audit before/after JSON."""
        return {
            "id": claim.id,
            "claim_no": claim.claim_no,
            "employee_name": claim.employee_name,
            "department_name": claim.department_name,
            "project_code": claim.project_code,
            "expense_type": claim.expense_type,
            "reason": claim.reason,
            "location": claim.location,
            "amount": float(claim.amount),
            "invoice_count": int(claim.invoice_count or 0),
            "status": claim.status,
            "approval_stage": claim.approval_stage,
            "risk_flags_json": list(claim.risk_flags_json or []),
        }

    @staticmethod
    def _normalize_optional_text(value: str | None, *, fallback: str = "", allow_empty: bool = False) -> str | None:
        """Strip ``value``.

        Blank input yields None when ``allow_empty`` is set, otherwise ``fallback``.
        """
        normalized = str(value or "").strip()
        if normalized:
            return normalized
        if allow_empty:
            return None
        return fallback

    @staticmethod
    def _is_missing_value(value: Any) -> bool:
        """True when ``value`` is blank or one of the placeholder strings (待补充, 暂无, 无, 未知, 处理中)."""
        text = str(value or "").strip()
        if not text:
            return True
        compact = text.replace(" ", "")
        return compact in {"待补充", "暂无", "无", "未知", "处理中"}

    def _ensure_draft_claim(self, claim: ExpenseClaim) -> None:
        """Raise ValueError unless the claim's status is "draft" (case-insensitive)."""
        if str(claim.status or "").strip().lower() != "draft":
            raise ValueError("只有草稿状态的报销单才允许执行该操作。")

    def _sync_claim_from_items(self, claim: ExpenseClaim) -> None:
        """Recompute claim-level fields from the claim's items.

        Sets the amount total and the count of items that have an invoice id.
        The earliest-dated item (ties broken by ``created_at``) supplies the date,
        type, reason and location. Items with a missing date or amount are
        tolerated here so that ``_validate_claim_for_submission`` can report them
        instead of this method crashing.
        """
        if not claim.items:
            claim.amount = Decimal("0.00")
            claim.invoice_count = 0
            return
        ordered_items = sorted(
            claim.items,
            key=lambda item: (
                item.item_date or date.max,
                item.created_at or datetime.max.replace(tzinfo=UTC),
            ),
        )
        primary_item = ordered_items[0]
        total_amount = sum(
            (item.item_amount if item.item_amount is not None else Decimal("0.00") for item in ordered_items),
            Decimal("0.00"),
        )
        claim.amount = total_amount.quantize(Decimal("0.01"))
        claim.invoice_count = sum(1 for item in ordered_items if str(item.invoice_id or "").strip())
        primary_date = primary_item.item_date
        if primary_date is not None:
            claim.occurred_at = datetime(
                primary_date.year,
                primary_date.month,
                primary_date.day,
                tzinfo=UTC,
            )
        claim.expense_type = str(primary_item.item_type or claim.expense_type or "other").strip() or "other"
        claim.reason = (
            self._normalize_optional_text(primary_item.item_reason, fallback=claim.reason or "待补充") or "待补充"
        )
        claim.location = (
            self._normalize_optional_text(primary_item.item_location, fallback=claim.location or "待补充")
            or "待补充"
        )
        if str(claim.status or "").strip().lower() == "draft":
            claim.approval_stage = "待提交"

    def _validate_claim_for_submission(self, claim: ExpenseClaim) -> list[str]:
        """Return human-readable descriptions of missing claim/item fields (empty when submittable)."""
        issues: list[str] = []
        if self._is_missing_value(claim.employee_name):
            issues.append("申请人未完善")
        if self._is_missing_value(claim.department_name):
            issues.append("所属部门未完善")
        if self._is_missing_value(claim.expense_type):
            issues.append("报销类型未完善")
        if self._is_missing_value(claim.reason):
            issues.append("报销事由未完善")
        if self._is_missing_value(claim.location):
            issues.append("业务地点未完善")
        if claim.amount is None or claim.amount <= Decimal("0.00"):
            issues.append("报销金额未完善")
        if claim.occurred_at is None:
            issues.append("发生时间未完善")
        if not claim.items:
            issues.append("费用明细不能为空")
        for index, item in enumerate(claim.items, start=1):
            prefix = f"费用明细第 {index} 条"
            if item.item_date is None:
                issues.append(f"{prefix}缺少日期")
            if self._is_missing_value(item.item_type):
                issues.append(f"{prefix}缺少费用项目")
            if self._is_missing_value(item.item_reason):
                issues.append(f"{prefix}缺少说明")
            if self._is_missing_value(item.item_location):
                issues.append(f"{prefix}缺少地点")
            if item.item_amount is None or item.item_amount <= Decimal("0.00"):
                issues.append(f"{prefix}缺少金额")
            if self._is_missing_value(item.invoice_id):
                issues.append(f"{prefix}缺少票据标识")
        return issues

    @staticmethod
    def _has_privileged_claim_access(current_user: CurrentUserContext) -> bool:
        """True for admins and for users holding any role in PRIVILEGED_CLAIM_ROLE_CODES."""
        if current_user.is_admin:
            return True
        return bool(set(current_user.role_codes) & PRIVILEGED_CLAIM_ROLE_CODES)

    def _apply_claim_scope(self, stmt: Any, current_user: CurrentUserContext) -> Any:
        """Restrict ``stmt`` to claims the user may see.

        Privileged users are unrestricted. Other users only see claims whose
        employee id/name matches their username or name. A user with neither
        sees nothing.
        """
        if self._has_privileged_claim_access(current_user):
            return stmt
        conditions = []
        username = str(current_user.username or "").strip()
        name = str(current_user.name or "").strip()
        if username:
            conditions.append(ExpenseClaim.employee_id == username)
            conditions.append(ExpenseClaim.employee_name == username)
        if name:
            conditions.append(ExpenseClaim.employee_name == name)
        if not conditions:
            # Impossible id so the query returns no rows.
            return stmt.where(ExpenseClaim.id == "__no_visible_claim__")
        return stmt.where(or_(*conditions))

    def _ensure_ready(self) -> None:
        """Delegate to AgentFoundationService.ensure_foundation_ready before agent-driven writes."""
        AgentFoundationService(self.db).ensure_foundation_ready()