"""Draft expense-claim persistence driven by agent ontology parse results."""

from __future__ import annotations

from datetime import UTC, date, datetime
from decimal import Decimal, InvalidOperation
from typing import Any

from sqlalchemy import func, select
from sqlalchemy.orm import Session

from app.models.employee import Employee
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
from app.schemas.ontology import OntologyEntity, OntologyParseResult
from app.services.audit import AuditLogService
from app.services.agent_foundation import AgentFoundationService

# Expense-type code -> Chinese display label. Not referenced by ExpenseClaimService
# itself; presumably consumed by callers that render claims (verify before removing).
EXPENSE_TYPE_LABELS = {
    "travel": "差旅",
    "hotel": "住宿",
    "transport": "交通",
    "meal": "餐费",
    "meeting": "会务",
    "entertainment": "招待",
}

# Value the resolvers below return for text fields they could not determine
# ("待补充" reads "to be supplemented").
_PLACEHOLDER = "待补充"
# Expense type returned when no ``expense_type`` entity was recognised.
_DEFAULT_EXPENSE_TYPE = "other"
# Amount returned when no usable ``amount`` entity was recognised.
_ZERO_AMOUNT = Decimal("0.00")


class ExpenseClaimService:
    """Create or update draft ``ExpenseClaim`` rows from ontology parse results.

    Each upsert is committed on the injected session and then recorded through
    ``AuditLogService``.
    """

    def __init__(self, db: Session) -> None:
        self.db = db
        self.audit_service = AuditLogService(db)

    def upsert_draft_from_ontology(
        self,
        *,
        run_id: str,
        user_id: str | None,
        message: str,
        ontology: OntologyParseResult,
        context_json: dict[str, Any],
    ) -> dict[str, Any]:
        """Create a new draft claim, or update the one this turn refers to.

        The target claim is ``context_json["draft_claim_id"]`` or the first claim
        whose ``claim_no`` matches an ``expense_claim`` entity; if neither resolves,
        a new claim is created. On an update, fields this turn did not supply
        (no amount, expense type, date, location, attachments or employee name)
        keep their stored values instead of being reset to the resolver defaults.
        The reason is still replaced by this turn's request-context reason/title
        or message text.

        Args:
            run_id: Agent run identifier, recorded as the audit ``request_id``.
            user_id: Acting user; used as the audit actor and as the last-resort
                employee name.
            message: Raw user message; fallback source for reason and location.
            ontology: Parsed entities, time range and risk flags for the message.
            context_json: Request context (draft id, names, department, attachment
                names/count and an optional ``request_context`` sub-dict).

        Returns:
            A summary dict with ``message``, ``draft_only``, ``claim_id``,
            ``claim_no``, ``status``, ``amount`` and ``invoice_count``.
        """
        self._ensure_ready()
        claim = self._find_target_claim(ontology=ontology, context_json=context_json)
        before_json = self._serialize_claim(claim) if claim is not None else None

        employee = self._resolve_employee(ontology=ontology, context_json=context_json)
        amount = self._resolve_amount(ontology.entities)
        # Keep "no date given" distinguishable from "today" so updates can preserve
        # a previously captured date.
        parsed_occurred_at = self._parse_occurred_at(ontology)
        occurred_at = parsed_occurred_at or datetime.now(UTC)
        expense_type = self._resolve_expense_type(ontology.entities)
        location = self._resolve_location(message=message, context_json=context_json)
        reason = self._resolve_reason(message=message, context_json=context_json)
        attachment_count = self._resolve_attachment_count(context_json)
        project_code = self._resolve_project_code(ontology.entities)

        if claim is None:
            claim = ExpenseClaim(
                claim_no=self._generate_claim_no(occurred_at),
                employee_id=employee.id if employee is not None else None,
                employee_name=(
                    employee.name
                    if employee is not None
                    else self._resolve_employee_name(
                        ontology=ontology,
                        context_json=context_json,
                        user_id=user_id,
                    )
                ),
                department_id=employee.organization_unit_id if employee is not None else None,
                department_name=self._resolve_department_name(
                    employee=employee,
                    context_json=context_json,
                ),
                project_code=project_code,
                expense_type=expense_type,
                reason=reason,
                location=location,
                amount=amount,
                currency="CNY",
                invoice_count=attachment_count,
                occurred_at=occurred_at,
                status="draft",
                approval_stage="待补充",
                risk_flags_json=list(ontology.risk_flags),
            )
            self.db.add(claim)
        else:
            # NOTE(review): the target's current status is not checked; a claim found
            # via claim_no is reset to draft whatever state it was in. Confirm policy.
            if employee is not None:
                claim.employee_id = employee.id
                claim.employee_name = employee.name
                claim.department_id = employee.organization_unit_id
            else:
                claim.employee_name = self._follow_up_employee_name(
                    claim=claim,
                    ontology=ontology,
                    context_json=context_json,
                    user_id=user_id,
                )
            claim.department_name = self._resolve_department_name(
                employee=employee,
                context_json=context_json,
                fallback=claim.department_name,
            )
            claim.project_code = project_code or claim.project_code
            # A resolver default means "not mentioned this turn": keep the stored value
            # unless nothing is stored yet.
            if expense_type != _DEFAULT_EXPENSE_TYPE or not claim.expense_type:
                claim.expense_type = expense_type
            if reason != _PLACEHOLDER or not claim.reason:
                claim.reason = reason
            if location != _PLACEHOLDER or not claim.location:
                claim.location = location
            if amount != _ZERO_AMOUNT or claim.amount is None:
                claim.amount = amount
            if attachment_count or claim.invoice_count is None:
                claim.invoice_count = attachment_count
            if parsed_occurred_at is not None or claim.occurred_at is None:
                claim.occurred_at = occurred_at
            claim.status = "draft"
            claim.approval_stage = "待补充"
            claim.risk_flags_json = list(ontology.risk_flags)

        # Flush so a newly created claim has its primary key before the item is linked.
        self.db.flush()
        # The primary item mirrors the (merged) claim header values.
        self._upsert_primary_item(
            claim=claim,
            occurred_at=claim.occurred_at,
            expense_type=claim.expense_type,
            amount=claim.amount,
            reason=claim.reason,
            location=claim.location,
            attachment_names=self._resolve_attachment_names(context_json),
        )
        self.db.commit()
        self.db.refresh(claim)

        self.audit_service.log_action(
            actor=user_id or claim.employee_name or "anonymous",
            action="expense_claim.draft_upsert",
            resource_type="expense_claim",
            resource_id=claim.id,
            before_json=before_json,
            after_json=self._serialize_claim(claim),
            request_id=run_id,
        )
        return {
            "message": (
                f"已创建报销草稿 {claim.claim_no},当前状态为 draft。"
                "你可以继续补充费用明细、客户单位和票据附件。"
            ),
            "draft_only": True,
            "claim_id": claim.id,
            "claim_no": claim.claim_no,
            "status": claim.status,
            "amount": float(claim.amount),
            "invoice_count": int(claim.invoice_count or 0),
        }

    def _follow_up_employee_name(
        self,
        *,
        claim: ExpenseClaim,
        ontology: OntologyParseResult,
        context_json: dict[str, Any],
        user_id: str | None,
    ) -> str:
        """Choose the employee name for an existing claim when no Employee row matched.

        An explicitly supplied name (entity or context key) wins. Otherwise a real
        stored name is kept rather than being replaced by ``user_id``.
        """
        explicit = self._resolve_employee_name(
            ontology=ontology,
            context_json=context_json,
            user_id=None,
        )
        if explicit != _PLACEHOLDER:
            return explicit
        if claim.employee_name and claim.employee_name != _PLACEHOLDER:
            return claim.employee_name
        return self._resolve_employee_name(
            ontology=ontology,
            context_json=context_json,
            user_id=user_id,
        )

    def _find_target_claim(
        self,
        *,
        ontology: OntologyParseResult,
        context_json: dict[str, Any],
    ) -> ExpenseClaim | None:
        """Return the claim this turn refers to, or ``None`` to create a new one.

        An explicit ``draft_claim_id`` takes precedence; if that id does not exist,
        ``None`` is returned (claim_no entities are not consulted in that case).
        """
        draft_claim_id = str(context_json.get("draft_claim_id") or "").strip()
        if draft_claim_id:
            return self.db.get(ExpenseClaim, draft_claim_id)
        claim_codes = [
            item.normalized_value
            for item in ontology.entities
            if item.type == "expense_claim" and item.normalized_value
        ]
        if not claim_codes:
            return None
        stmt = select(ExpenseClaim).where(ExpenseClaim.claim_no.in_(claim_codes)).limit(1)
        return self.db.scalar(stmt)

    def _upsert_primary_item(
        self,
        *,
        claim: ExpenseClaim,
        occurred_at: datetime | date,
        expense_type: str,
        amount: Decimal,
        reason: str,
        location: str,
        attachment_names: list[str],
    ) -> None:
        """Create or update the claim's first line item from the header values.

        The "primary" item is ``claim.items[0]``; whether that order is stable
        depends on the relationship definition (not visible here — verify).
        The first attachment name is stored as ``invoice_id``; with no attachments
        an existing item keeps its ``invoice_id``.
        """
        # Values read back from the DB may be a date or a datetime; accept both.
        item_date = occurred_at.date() if isinstance(occurred_at, datetime) else occurred_at
        item = claim.items[0] if claim.items else None
        if item is None:
            item = ExpenseClaimItem(
                claim_id=claim.id,
                item_date=item_date,
                item_type=expense_type,
                item_reason=reason,
                item_location=location,
                item_amount=amount,
                invoice_id=attachment_names[0] if attachment_names else None,
            )
            claim.items.append(item)
            self.db.add(item)
            return
        item.item_date = item_date
        item.item_type = expense_type
        item.item_reason = reason
        item.item_location = location
        item.item_amount = amount
        item.invoice_id = attachment_names[0] if attachment_names else item.invoice_id

    def _generate_claim_no(self, occurred_at: datetime) -> str:
        """Return the next ``EXP-YYYYMM-NNN`` claim number for *occurred_at*'s month.

        The sequence continues from the highest numeric suffix already used for
        that month. Counting rows instead would reissue an existing number after
        a deletion or any gap.

        NOTE(review): two concurrent transactions can still compute the same
        number; only a unique constraint on claim_no plus a retry (not visible
        here) would close that gap.
        """
        prefix = f"EXP-{occurred_at.strftime('%Y%m')}-"
        existing_numbers = self.db.execute(
            select(ExpenseClaim.claim_no).where(ExpenseClaim.claim_no.like(f"{prefix}%"))
        ).scalars()
        highest = 0
        for claim_no in existing_numbers:
            suffix = (claim_no or "")[len(prefix):]
            # isascii() guards against Unicode digits that int() would reject.
            if suffix.isascii() and suffix.isdigit():
                highest = max(highest, int(suffix))
        return f"{prefix}{highest + 1:03d}"

    def _resolve_employee(
        self,
        *,
        ontology: OntologyParseResult,
        context_json: dict[str, Any],
    ) -> Employee | None:
        """Look up an Employee by the explicitly supplied name, if any."""
        employee_name = self._resolve_employee_name(
            ontology=ontology,
            context_json=context_json,
            user_id=None,
        )
        # The resolver returns the placeholder (truthy) when no name was found;
        # querying for it would be pointless.
        if not employee_name or employee_name == _PLACEHOLDER:
            return None
        stmt = select(Employee).where(Employee.name == employee_name).limit(1)
        return self.db.scalar(stmt)

    @staticmethod
    def _resolve_employee_name(
        *,
        ontology: OntologyParseResult,
        context_json: dict[str, Any],
        user_id: str | None,
    ) -> str:
        """Return the first employee entity, then context name keys, then user_id, then the placeholder."""
        for item in ontology.entities:
            if item.type == "employee" and item.value.strip():
                return item.value.strip()
        for key in ("name", "user_name", "employee_name"):
            value = str(context_json.get(key) or "").strip()
            if value:
                return value
        return str(user_id or _PLACEHOLDER).strip() or _PLACEHOLDER

    @staticmethod
    def _resolve_department_name(
        *,
        employee: Employee | None,
        context_json: dict[str, Any],
        fallback: str = _PLACEHOLDER,
    ) -> str:
        """Return the employee's unit name, else a request/context department, else *fallback*."""
        if employee is not None and employee.organization_unit is not None:
            return employee.organization_unit.name
        request_context = context_json.get("request_context")
        if isinstance(request_context, dict):
            for key in ("department", "department_name", "deptName"):
                value = str(request_context.get(key) or "").strip()
                if value:
                    return value
        for key in ("department_name", "department"):
            value = str(context_json.get(key) or "").strip()
            if value:
                return value
        return fallback

    @staticmethod
    def _resolve_project_code(entities: list[OntologyEntity]) -> str | None:
        """Return the first non-blank ``project`` entity's normalized value, or ``None``."""
        for item in entities:
            if item.type == "project" and item.normalized_value.strip():
                return item.normalized_value.strip()
        return None

    @staticmethod
    def _resolve_expense_type(entities: list[OntologyEntity]) -> str:
        """Return the first non-blank ``expense_type`` entity value, or ``"other"``."""
        for item in entities:
            if item.type == "expense_type":
                normalized = item.normalized_value.strip()
                if normalized:
                    return normalized
        return _DEFAULT_EXPENSE_TYPE

    @staticmethod
    def _resolve_reason(*, message: str, context_json: dict[str, Any]) -> str:
        """Return request_context reason/title, else the first 500 chars of *message*."""
        request_context = context_json.get("request_context")
        if isinstance(request_context, dict):
            for key in ("reason", "title"):
                value = str(request_context.get(key) or "").strip()
                if value:
                    return value
        return str(message or "").strip()[:500] or _PLACEHOLDER

    @staticmethod
    def _resolve_location(*, message: str, context_json: dict[str, Any]) -> str:
        """Return request_context city/location, "客户现场" if mentioned, else the placeholder."""
        request_context = context_json.get("request_context")
        if isinstance(request_context, dict):
            for key in ("city", "location"):
                value = str(request_context.get(key) or "").strip()
                if value:
                    return value
        # Drop ASCII spaces so "客户 现场" still matches.
        compact = str(message or "").replace(" ", "")
        if "客户现场" in compact:
            return "客户现场"
        return _PLACEHOLDER

    @staticmethod
    def _parse_occurred_at(ontology: OntologyParseResult) -> datetime | None:
        """Return the time range's start date as midnight UTC, or ``None`` if absent/invalid."""
        start_date = ontology.time_range.start_date
        if not start_date:
            return None
        if isinstance(start_date, date):
            parsed = start_date
        else:
            try:
                parsed = date.fromisoformat(str(start_date))
            except ValueError:
                return None
        return datetime(parsed.year, parsed.month, parsed.day, tzinfo=UTC)

    @staticmethod
    def _resolve_occurred_at(ontology: OntologyParseResult) -> datetime:
        """Return the parsed start date, falling back to the current UTC time."""
        return ExpenseClaimService._parse_occurred_at(ontology) or datetime.now(UTC)

    @staticmethod
    def _resolve_amount(entities: list[OntologyEntity]) -> Decimal:
        """Return the first parseable, finite non-threshold amount rounded to cents, else 0.00."""
        for item in entities:
            if item.type != "amount" or item.role == "threshold":
                continue
            try:
                value = Decimal(item.normalized_value)
                # NaN/Infinity would otherwise leak into the stored amount.
                if value.is_finite():
                    return value.quantize(Decimal("0.01"))
            except (InvalidOperation, TypeError, ValueError):
                continue
        return _ZERO_AMOUNT

    @staticmethod
    def _resolve_attachment_names(context_json: dict[str, Any]) -> list[str]:
        """Return non-blank, stripped ``attachment_names`` from the context (list only)."""
        names = context_json.get("attachment_names")
        if not isinstance(names, list):
            return []
        return [str(name).strip() for name in names if str(name).strip()]

    def _resolve_attachment_count(self, context_json: dict[str, Any]) -> int:
        """Return the attachment-name count, else a non-negative ``attachment_count``, else 0."""
        names = self._resolve_attachment_names(context_json)
        if names:
            return len(names)
        try:
            return max(0, int(context_json.get("attachment_count") or 0))
        except (TypeError, ValueError):
            return 0

    @staticmethod
    def _serialize_claim(claim: ExpenseClaim) -> dict[str, Any]:
        """Return the JSON-friendly snapshot of a claim used for audit before/after."""
        return {
            "id": claim.id,
            "claim_no": claim.claim_no,
            "employee_name": claim.employee_name,
            "department_name": claim.department_name,
            "project_code": claim.project_code,
            "expense_type": claim.expense_type,
            "reason": claim.reason,
            "location": claim.location,
            "amount": float(claim.amount),
            "invoice_count": int(claim.invoice_count or 0),
            "status": claim.status,
            "approval_stage": claim.approval_stage,
            "risk_flags_json": list(claim.risk_flags_json or []),
        }

    def _ensure_ready(self) -> None:
        """Delegate readiness checks to ``AgentFoundationService`` (behavior defined there)."""
        AgentFoundationService(self.db).ensure_foundation_ready()