292 lines
11 KiB
Python
292 lines
11 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from dataclasses import dataclass, field
|
||
|
|
from datetime import UTC, datetime
|
||
|
|
from threading import Lock
|
||
|
|
from typing import Any, Callable
|
||
|
|
from uuid import uuid4
|
||
|
|
|
||
|
|
from sqlalchemy import select
|
||
|
|
from sqlalchemy.orm import Session, sessionmaker
|
||
|
|
|
||
|
|
from app.api.deps import CurrentUserContext
|
||
|
|
from app.schemas.linked_reimbursement_draft_job import (
|
||
|
|
LinkedReimbursementDraftJobCreate,
|
||
|
|
LinkedReimbursementDraftJobRead,
|
||
|
|
)
|
||
|
|
from app.schemas.ontology import OntologyParseResult, OntologyPermission
|
||
|
|
from app.schemas.orchestrator import OrchestratorRequest
|
||
|
|
from app.models.financial_record import ExpenseClaim
|
||
|
|
from app.services.expense_claims import ExpenseClaimService
|
||
|
|
from app.services.orchestrator import OrchestratorService
|
||
|
|
|
||
|
|
|
||
|
|
# Job statuses that mark a finished job; such jobs are never executed again.
TERMINAL_STATUSES = {"failed", "succeeded"}
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(slots=True)
|
||
|
|
class LinkedReimbursementDraftJobState:
|
||
|
|
job_id: str
|
||
|
|
owner_username: str
|
||
|
|
owner_name: str
|
||
|
|
message: str
|
||
|
|
context_json: dict[str, Any]
|
||
|
|
conversation_id: str = ""
|
||
|
|
status: str = "queued"
|
||
|
|
status_message: str = "已创建报销草稿生成任务,等待后台处理。"
|
||
|
|
error: str = ""
|
||
|
|
run_id: str = ""
|
||
|
|
draft_payload: dict[str, Any] | None = None
|
||
|
|
created_at: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||
|
|
updated_at: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||
|
|
|
||
|
|
def to_read(self) -> LinkedReimbursementDraftJobRead:
|
||
|
|
return LinkedReimbursementDraftJobRead(
|
||
|
|
job_id=self.job_id,
|
||
|
|
status=self.status,
|
||
|
|
message=self.status_message,
|
||
|
|
error=self.error,
|
||
|
|
run_id=self.run_id,
|
||
|
|
conversation_id=self.conversation_id,
|
||
|
|
draft_payload=self.draft_payload,
|
||
|
|
created_at=self.created_at,
|
||
|
|
updated_at=self.updated_at,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
_jobs: dict[str, LinkedReimbursementDraftJobState] = {}
|
||
|
|
_jobs_lock = Lock()
|
||
|
|
|
||
|
|
|
||
|
|
def clear_linked_reimbursement_draft_jobs_for_tests() -> None:
    """Drop every job from the in-memory registry; a test-isolation helper."""
    with _jobs_lock:
        _jobs.clear()
|
||
|
|
|
||
|
|
|
||
|
|
def create_linked_reimbursement_draft_job(
    payload: LinkedReimbursementDraftJobCreate,
    current_user: CurrentUserContext,
) -> LinkedReimbursementDraftJobRead:
    """Register a new queued draft job owned by ``current_user``.

    The payload's ``context_json`` is shallow-copied. Missing or falsy
    ``entry_source``/``session_type`` values are replaced with defaults.
    The job is only stored here; ``run_linked_reimbursement_draft_job``
    performs the work.

    Returns:
        The read view of the freshly queued job.
    """

    def _clean(value: Any) -> str:
        return str(value or "").strip()

    context = dict(payload.context_json or {})
    # Overwrite falsy values too, so ``setdefault`` would not be equivalent.
    for key, fallback in (("entry_source", "workbench-ai"), ("session_type", "expense")):
        context[key] = context.get(key) or fallback

    new_job_id = f"linked-reimbursement-draft-{uuid4()}"
    job = LinkedReimbursementDraftJobState(
        job_id=new_job_id,
        owner_username=_clean(current_user.username),
        owner_name=_clean(current_user.name),
        message=_clean(payload.message),
        context_json=context,
        conversation_id=_clean(payload.conversation_id),
    )
    with _jobs_lock:
        _jobs[new_job_id] = job
    return job.to_read()
|
||
|
|
|
||
|
|
|
||
|
|
def get_linked_reimbursement_draft_job(
    job_id: str,
    current_user: CurrentUserContext,
) -> LinkedReimbursementDraftJobRead | None:
    """Return a consistent snapshot of a job visible to ``current_user``.

    Returns:
        The job's read view, or ``None`` when the job does not exist or the
        caller is not allowed to see it.
    """
    state = _get_authorized_state(job_id, current_user)
    if state is None:
        return None
    # _update_job applies several fields one by one while holding _jobs_lock
    # (status first). Reading under the same lock prevents a poller from
    # observing a half-applied update, e.g. status "succeeded" while
    # draft_payload is still None. _get_authorized_state has already released
    # the (non-reentrant) lock by the time it returns.
    with _jobs_lock:
        return state.to_read()
|
||
|
|
|
||
|
|
|
||
|
|
def run_linked_reimbursement_draft_job(
    job_id: str,
    current_user: CurrentUserContext,
    session_factory: sessionmaker[Session] | Callable[[], Session],
) -> None:
    """Execute a queued draft-generation job and record its outcome on the job.

    Returns immediately when the job is unknown, not visible to
    ``current_user``, already running, or already finished. Otherwise the job
    is moved to ``running`` and the draft is generated by one of two paths:
    the direct-save fast path (when ``_can_use_direct_save_path`` accepts the
    context) or the orchestrator. Any exception is logged and stored on the
    job as ``failed``; nothing propagates to the caller.

    Args:
        job_id: Identifier returned by ``create_linked_reimbursement_draft_job``.
        current_user: Caller context, used for authorization and as the acting user.
        session_factory: Produces a ``Session`` that is used as a context
            manager for the duration of the run.
    """
    import logging

    state = _get_authorized_state(job_id, current_user)
    if state is None:
        return

    # Check and transition atomically. Otherwise two concurrent invocations
    # could both pass the status check, and a second call while the first is
    # still "running" would generate a duplicate draft.
    with _jobs_lock:
        if state.status in TERMINAL_STATUSES or state.status == "running":
            return
        state.status = "running"
        state.status_message = "正在后台生成报销草稿..."
        state.updated_at = datetime.now(UTC)

    try:
        with session_factory() as db:
            if _can_use_direct_save_path(db, state.context_json):
                run_id, result, draft_payload = _run_direct_save_path(
                    db=db,
                    state=state,
                    current_user=current_user,
                )
            else:
                run_id, result, draft_payload = _run_orchestrator_path(db, state, current_user)
        if draft_payload is None:
            raise ValueError("报销草稿生成完成,但未返回草稿信息,请刷新单据列表后核对。")
    except Exception as exc:
        # Background boundary: record the failure on the job for the poller,
        # but keep the traceback in the logs instead of discarding it.
        logging.getLogger(__name__).exception(
            "Linked reimbursement draft job %s failed", job_id
        )
        message = str(exc).strip() or "报销草稿生成失败,请稍后重试。"
        _update_job(
            job_id,
            status="failed",
            status_message=message,
            error=message,
        )
    else:
        _update_job(
            job_id,
            status="succeeded",
            status_message=str(result.get("message") or "报销草稿已生成。").strip(),
            run_id=run_id,
            draft_payload=draft_payload,
            error="",
        )


def _run_orchestrator_path(
    db: Session,
    state: LinkedReimbursementDraftJobState,
    current_user: CurrentUserContext,
) -> tuple[str, dict[str, Any], dict[str, Any] | None]:
    """Generate the draft through the full orchestrator pipeline.

    Returns:
        ``(run_id, result, draft_payload)``. ``result`` is ``{}`` when the
        orchestrator result is not a dict. ``draft_payload`` is ``None`` when
        the result carries no dict ``draft_payload``.

    Raises:
        ValueError: if the orchestrator run did not report ``succeeded``.
    """
    response = OrchestratorService(db).run(
        OrchestratorRequest(
            source="user_message",
            user_id=_resolve_user_id(current_user),
            # NOTE(review): the job's stored conversation_id is not forwarded
            # here; confirm that starting without a conversation is intended.
            conversation_id=None,
            message=state.message,
            context_json=dict(state.context_json),
        )
    )
    result = response.result if isinstance(response.result, dict) else {}
    raw_payload = result.get("draft_payload")
    draft_payload = raw_payload if isinstance(raw_payload, dict) else None
    if response.status != "succeeded":
        raise ValueError(str(result.get("message") or "报销草稿生成失败,请稍后重试。").strip())
    return response.run_id, result, draft_payload
|
||
|
|
|
||
|
|
|
||
|
|
def _get_authorized_state(
    job_id: str,
    current_user: CurrentUserContext,
) -> LinkedReimbursementDraftJobState | None:
    """Return the stored job if ``current_user`` may access it, else ``None``.

    Rules:
    - Admins may access every job.
    - When both the caller and the job carry a username, access is decided
      by username alone.
    - Otherwise a non-empty display name equal to the job owner's name grants
      access.

    The lock is held only for the registry lookup and is released before
    returning.
    """
    normalized_job_id = str(job_id or "").strip()
    with _jobs_lock:
        state = _jobs.get(normalized_job_id)
    if state is None:
        return None
    if current_user.is_admin:
        return state

    username = str(current_user.username or "").strip()
    if username and state.owner_username:
        # Display names need not be unique (usernames presumably are; confirm).
        # Once both sides have a username, a matching display name must not
        # grant access to another user's job.
        return state if username == state.owner_username else None

    name = str(current_user.name or "").strip()
    if name and name == state.owner_name:
        return state
    return None
|
||
|
|
|
||
|
|
|
||
|
|
def _update_job(job_id: str, **updates: Any) -> None:
    """Apply attribute updates to a stored job and bump its ``updated_at``.

    Keyword names that are not attributes of the job state are silently
    ignored. Unknown job ids are a no-op. All mutation happens while holding
    ``_jobs_lock``.
    """
    key = str(job_id or "").strip()
    with _jobs_lock:
        target = _jobs.get(key)
        if target is None:
            return
        applicable = {name: value for name, value in updates.items() if hasattr(target, name)}
        for name, value in applicable.items():
            setattr(target, name, value)
        target.updated_at = datetime.now(UTC)
|
||
|
|
|
||
|
|
|
||
|
|
def _resolve_user_id(current_user: CurrentUserContext) -> str:
|
||
|
|
return str(current_user.username or current_user.name or "anonymous").strip() or "anonymous"
|
||
|
|
|
||
|
|
|
||
|
|
def _can_use_direct_save_path(db: Session, context_json: dict[str, Any]) -> bool:
|
||
|
|
review_action = str((context_json or {}).get("review_action") or "").strip()
|
||
|
|
if review_action != "save_draft":
|
||
|
|
return False
|
||
|
|
review_values = context_json.get("review_form_values")
|
||
|
|
if not isinstance(review_values, dict):
|
||
|
|
return False
|
||
|
|
application_claim_id = str(review_values.get("application_claim_id") or "").strip()
|
||
|
|
application_claim_no = str(review_values.get("application_claim_no") or "").strip()
|
||
|
|
if not application_claim_no:
|
||
|
|
return False
|
||
|
|
if application_claim_id:
|
||
|
|
return True
|
||
|
|
return _find_application_claim_by_no(db, application_claim_no) is not None
|
||
|
|
|
||
|
|
|
||
|
|
def _find_application_claim_by_no(db: Session, claim_no: str) -> ExpenseClaim | None:
|
||
|
|
normalized_claim_no = str(claim_no or "").strip()
|
||
|
|
if not normalized_claim_no:
|
||
|
|
return None
|
||
|
|
claim = db.scalar(
|
||
|
|
select(ExpenseClaim)
|
||
|
|
.where(ExpenseClaim.claim_no == normalized_claim_no)
|
||
|
|
.limit(1)
|
||
|
|
)
|
||
|
|
if claim is not None and ExpenseClaimService._is_expense_application_claim(claim):
|
||
|
|
return claim
|
||
|
|
return None
|
||
|
|
|
||
|
|
|
||
|
|
def _build_direct_context_json(db: Session, context_json: dict[str, Any]) -> dict[str, Any]:
|
||
|
|
direct_context = dict(context_json or {})
|
||
|
|
review_values = dict(direct_context.get("review_form_values") or {})
|
||
|
|
scene_selection = dict(direct_context.get("expense_scene_selection") or {})
|
||
|
|
application_claim_id = str(review_values.get("application_claim_id") or "").strip()
|
||
|
|
application_claim_no = str(review_values.get("application_claim_no") or "").strip()
|
||
|
|
if not application_claim_id and application_claim_no:
|
||
|
|
application_claim = _find_application_claim_by_no(db, application_claim_no)
|
||
|
|
if application_claim is not None:
|
||
|
|
review_values["application_claim_id"] = application_claim.id
|
||
|
|
scene_selection["application_claim_id"] = application_claim.id
|
||
|
|
scene_selection["application_claim_no"] = str(
|
||
|
|
scene_selection.get("application_claim_no")
|
||
|
|
or application_claim.claim_no
|
||
|
|
or application_claim_no
|
||
|
|
).strip()
|
||
|
|
direct_context["review_form_values"] = review_values
|
||
|
|
if scene_selection:
|
||
|
|
direct_context["expense_scene_selection"] = scene_selection
|
||
|
|
return direct_context
|
||
|
|
|
||
|
|
|
||
|
|
def _run_direct_save_path(
    *,
    db: Session,
    state: LinkedReimbursementDraftJobState,
    current_user: CurrentUserContext,
) -> tuple[str, dict[str, Any], dict[str, Any]]:
    """Save the draft directly through ``ExpenseClaimService``.

    A synthetic ontology result (expense scenario, draft intent,
    ``draft_write`` permission allowed, confidence 1.0) is built in place of
    a parsing step. The job id doubles as the run id.

    Returns:
        ``(run_id, result, draft_payload)``, where ``draft_payload`` is built
        from the service result and the reloaded claim row.

    Raises:
        ValueError: if the service result lacks a claim id or number, or its
            status is not ``"draft"``.
    """
    run_id = state.job_id
    fast_path_permission = OntologyPermission(
        level="draft_write",
        allowed=True,
        reason="关联申请单生成报销草稿快路径。",
    )
    ontology = OntologyParseResult(
        scenario="expense",
        intent="draft",
        permission=fast_path_permission,
        confidence=1.0,
        run_id=run_id,
    )
    outcome = ExpenseClaimService(db).save_or_submit_from_ontology(
        run_id=run_id,
        user_id=_resolve_user_id(current_user),
        message=state.message,
        ontology=ontology,
        context_json=_build_direct_context_json(db, state.context_json),
    )

    new_claim_id = str(outcome.get("claim_id") or "").strip()
    new_claim_no = str(outcome.get("claim_no") or "").strip()
    saved_as_draft = bool(new_claim_id) and bool(new_claim_no) and (
        str(outcome.get("status") or "").strip() == "draft"
    )
    if not saved_as_draft:
        raise ValueError(str(outcome.get("message") or "报销草稿生成失败,请稍后重试。").strip())

    saved_claim = db.get(ExpenseClaim, new_claim_id)
    return run_id, outcome, _build_direct_draft_payload(outcome, saved_claim)
|
||
|
|
|
||
|
|
|
||
|
|
def _build_direct_draft_payload(
|
||
|
|
result: dict[str, Any],
|
||
|
|
claim: ExpenseClaim | None,
|
||
|
|
) -> dict[str, Any]:
|
||
|
|
claim_id = str(result.get("claim_id") or getattr(claim, "id", "") or "").strip()
|
||
|
|
claim_no = str(result.get("claim_no") or getattr(claim, "claim_no", "") or "").strip()
|
||
|
|
status = str(result.get("status") or getattr(claim, "status", "") or "draft").strip()
|
||
|
|
approval_stage = str(getattr(claim, "approval_stage", "") or "待提交").strip()
|
||
|
|
expense_type = str(getattr(claim, "expense_type", "") or "").strip()
|
||
|
|
message = str(result.get("message") or "报销草稿已生成。").strip()
|
||
|
|
return {
|
||
|
|
"draft_type": "expense",
|
||
|
|
"title": f"费用草稿 {claim_no}" if claim_no else "费用草稿",
|
||
|
|
"body": message,
|
||
|
|
"confirmation_required": True,
|
||
|
|
"claim_id": claim_id,
|
||
|
|
"claim_no": claim_no,
|
||
|
|
"status": status,
|
||
|
|
"approval_stage": approval_stage,
|
||
|
|
"expense_type": expense_type,
|
||
|
|
}
|