from __future__ import annotations from dataclasses import dataclass, field from datetime import UTC, datetime from threading import Lock from typing import Any, Callable from uuid import uuid4 from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from app.api.deps import CurrentUserContext from app.schemas.linked_reimbursement_draft_job import ( LinkedReimbursementDraftJobCreate, LinkedReimbursementDraftJobRead, ) from app.schemas.ontology import OntologyParseResult, OntologyPermission from app.schemas.orchestrator import OrchestratorRequest from app.models.financial_record import ExpenseClaim from app.services.expense_claims import ExpenseClaimService from app.services.orchestrator import OrchestratorService TERMINAL_STATUSES = {"succeeded", "failed"} @dataclass(slots=True) class LinkedReimbursementDraftJobState: job_id: str owner_username: str owner_name: str message: str context_json: dict[str, Any] conversation_id: str = "" status: str = "queued" status_message: str = "已创建报销草稿生成任务,等待后台处理。" error: str = "" run_id: str = "" draft_payload: dict[str, Any] | None = None created_at: datetime = field(default_factory=lambda: datetime.now(UTC)) updated_at: datetime = field(default_factory=lambda: datetime.now(UTC)) def to_read(self) -> LinkedReimbursementDraftJobRead: return LinkedReimbursementDraftJobRead( job_id=self.job_id, status=self.status, message=self.status_message, error=self.error, run_id=self.run_id, conversation_id=self.conversation_id, draft_payload=self.draft_payload, created_at=self.created_at, updated_at=self.updated_at, ) _jobs: dict[str, LinkedReimbursementDraftJobState] = {} _jobs_lock = Lock() def clear_linked_reimbursement_draft_jobs_for_tests() -> None: with _jobs_lock: _jobs.clear() def create_linked_reimbursement_draft_job( payload: LinkedReimbursementDraftJobCreate, current_user: CurrentUserContext, ) -> LinkedReimbursementDraftJobRead: context_json = dict(payload.context_json or {}) context_json["entry_source"] = context_json.get("entry_source") or "workbench-ai" context_json["session_type"] = context_json.get("session_type") or "expense" job_id = f"linked-reimbursement-draft-{uuid4()}" state = LinkedReimbursementDraftJobState( job_id=job_id, owner_username=str(current_user.username or "").strip(), owner_name=str(current_user.name or "").strip(), message=str(payload.message or "").strip(), context_json=context_json, conversation_id=str(payload.conversation_id or "").strip(), ) with _jobs_lock: _jobs[job_id] = state return state.to_read() def get_linked_reimbursement_draft_job( job_id: str, current_user: CurrentUserContext, ) -> LinkedReimbursementDraftJobRead | None: state = _get_authorized_state(job_id, current_user) return state.to_read() if state is not None else None def run_linked_reimbursement_draft_job( job_id: str, current_user: CurrentUserContext, session_factory: sessionmaker[Session] | Callable[[], Session], ) -> None: state = _get_authorized_state(job_id, current_user) if state is None or state.status in TERMINAL_STATUSES: return _update_job(job_id, status="running", status_message="正在后台生成报销草稿...") try: with session_factory() as db: if _can_use_direct_save_path(db, state.context_json): run_id, result, draft_payload = _run_direct_save_path( db=db, state=state, current_user=current_user, ) else: response = OrchestratorService(db).run( OrchestratorRequest( source="user_message", user_id=_resolve_user_id(current_user), conversation_id=None, message=state.message, context_json=dict(state.context_json), ) ) run_id = response.run_id result = response.result if isinstance(response.result, dict) else {} draft_payload = result.get("draft_payload") if isinstance(result.get("draft_payload"), dict) else None if response.status != "succeeded": raise ValueError(str(result.get("message") or "报销草稿生成失败,请稍后重试。").strip()) if draft_payload is None: raise ValueError("报销草稿生成完成,但未返回草稿信息,请刷新单据列表后核对。") _update_job( job_id, status="succeeded", status_message=str(result.get("message") or "报销草稿已生成。").strip(), run_id=run_id, draft_payload=draft_payload, error="", ) except Exception as exc: message = str(exc).strip() or "报销草稿生成失败,请稍后重试。" _update_job( job_id, status="failed", status_message=message, error=message, ) def _get_authorized_state( job_id: str, current_user: CurrentUserContext, ) -> LinkedReimbursementDraftJobState | None: normalized_job_id = str(job_id or "").strip() with _jobs_lock: state = _jobs.get(normalized_job_id) if state is None: return None if current_user.is_admin: return state username = str(current_user.username or "").strip() name = str(current_user.name or "").strip() if username and username == state.owner_username: return state if name and name == state.owner_name: return state return None def _update_job(job_id: str, **updates: Any) -> None: with _jobs_lock: state = _jobs.get(str(job_id or "").strip()) if state is None: return for key, value in updates.items(): if hasattr(state, key): setattr(state, key, value) state.updated_at = datetime.now(UTC) def _resolve_user_id(current_user: CurrentUserContext) -> str: return str(current_user.username or current_user.name or "anonymous").strip() or "anonymous" def _can_use_direct_save_path(db: Session, context_json: dict[str, Any]) -> bool: review_action = str((context_json or {}).get("review_action") or "").strip() if review_action != "save_draft": return False review_values = context_json.get("review_form_values") if not isinstance(review_values, dict): return False application_claim_id = str(review_values.get("application_claim_id") or "").strip() application_claim_no = str(review_values.get("application_claim_no") or "").strip() if not application_claim_no: return False if application_claim_id: return True return _find_application_claim_by_no(db, application_claim_no) is not None def _find_application_claim_by_no(db: Session, claim_no: str) -> ExpenseClaim | None: normalized_claim_no = str(claim_no or "").strip() if not normalized_claim_no: return None claim = db.scalar( select(ExpenseClaim) .where(ExpenseClaim.claim_no == normalized_claim_no) .limit(1) ) if claim is not None and ExpenseClaimService._is_expense_application_claim(claim): return claim return None def _build_direct_context_json(db: Session, context_json: dict[str, Any]) -> dict[str, Any]: direct_context = dict(context_json or {}) review_values = dict(direct_context.get("review_form_values") or {}) scene_selection = dict(direct_context.get("expense_scene_selection") or {}) application_claim_id = str(review_values.get("application_claim_id") or "").strip() application_claim_no = str(review_values.get("application_claim_no") or "").strip() if not application_claim_id and application_claim_no: application_claim = _find_application_claim_by_no(db, application_claim_no) if application_claim is not None: review_values["application_claim_id"] = application_claim.id scene_selection["application_claim_id"] = application_claim.id scene_selection["application_claim_no"] = str( scene_selection.get("application_claim_no") or application_claim.claim_no or application_claim_no ).strip() direct_context["review_form_values"] = review_values if scene_selection: direct_context["expense_scene_selection"] = scene_selection return direct_context def _run_direct_save_path( *, db: Session, state: LinkedReimbursementDraftJobState, current_user: CurrentUserContext, ) -> tuple[str, dict[str, Any], dict[str, Any]]: run_id = state.job_id ontology = OntologyParseResult( scenario="expense", intent="draft", permission=OntologyPermission( level="draft_write", allowed=True, reason="关联申请单生成报销草稿快路径。", ), confidence=1.0, run_id=run_id, ) result = ExpenseClaimService(db).save_or_submit_from_ontology( run_id=run_id, user_id=_resolve_user_id(current_user), message=state.message, ontology=ontology, context_json=_build_direct_context_json(db, state.context_json), ) claim_id = str(result.get("claim_id") or "").strip() claim_no = str(result.get("claim_no") or "").strip() if not claim_id or not claim_no or str(result.get("status") or "").strip() != "draft": raise ValueError(str(result.get("message") or "报销草稿生成失败,请稍后重试。").strip()) claim = db.get(ExpenseClaim, claim_id) return run_id, result, _build_direct_draft_payload(result, claim) def _build_direct_draft_payload( result: dict[str, Any], claim: ExpenseClaim | None, ) -> dict[str, Any]: claim_id = str(result.get("claim_id") or getattr(claim, "id", "") or "").strip() claim_no = str(result.get("claim_no") or getattr(claim, "claim_no", "") or "").strip() status = str(result.get("status") or getattr(claim, "status", "") or "draft").strip() approval_stage = str(getattr(claim, "approval_stage", "") or "待提交").strip() expense_type = str(getattr(claim, "expense_type", "") or "").strip() message = str(result.get("message") or "报销草稿已生成。").strip() return { "draft_type": "expense", "title": f"费用草稿 {claim_no}" if claim_no else "费用草稿", "body": message, "confirmation_required": True, "claim_id": claim_id, "claim_no": claim_no, "status": status, "approval_stage": approval_stage, "expense_type": expense_type, }