feat(server): 新增附件关联/关联报销草稿后台任务与申请位置语义
- attachment_association_jobs:从票据夹批量关联附件到报销单,识别城市/日期并创建明细项,内存态 job 跟踪 - linked_reimbursement_draft_jobs:基于申请单异步生成关联报销草稿,调用 Orchestrator 编排,区分 succeeded/failed 终态 - application_location_semantics:抽取差旅出发/到达城市、判断具体地址/业务动作等位置语义,供申请单校验复用 - router 注册两个 job 端点,新增对应 job/语义单元测试
This commit is contained in:
@@ -0,0 +1,75 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status
|
||||
|
||||
from app.api.deps import CurrentUserContext, get_current_user
|
||||
from app.db.session import get_session_factory
|
||||
from app.schemas.attachment_association_job import (
|
||||
AttachmentAssociationJobCreate,
|
||||
AttachmentAssociationJobRead,
|
||||
)
|
||||
from app.schemas.common import ErrorResponse
|
||||
from app.services.attachment_association_jobs import (
|
||||
create_attachment_association_job,
|
||||
get_attachment_association_job,
|
||||
run_attachment_association_job,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/reimbursements/attachment-association-jobs")
|
||||
CurrentUser = Annotated[CurrentUserContext, Depends(get_current_user)]
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
response_model=AttachmentAssociationJobRead,
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
summary="创建附件自动关联后台任务",
|
||||
description="根据已 OCR 入票据夹的 receipt_id,在后台自动匹配并归集到报销草稿。",
|
||||
responses={
|
||||
status.HTTP_400_BAD_REQUEST: {
|
||||
"model": ErrorResponse,
|
||||
"description": "请求缺少可关联票据。",
|
||||
},
|
||||
},
|
||||
)
|
||||
def create_attachment_association_job_endpoint(
|
||||
payload: AttachmentAssociationJobCreate,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: CurrentUser,
|
||||
) -> AttachmentAssociationJobRead:
|
||||
try:
|
||||
job = create_attachment_association_job(payload, current_user)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
|
||||
background_tasks.add_task(
|
||||
run_attachment_association_job,
|
||||
job.job_id,
|
||||
current_user,
|
||||
get_session_factory(),
|
||||
)
|
||||
return job
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{job_id}",
|
||||
response_model=AttachmentAssociationJobRead,
|
||||
summary="查询附件自动关联后台任务",
|
||||
description="用于前端会话恢复后按 job_id 查询任务状态。",
|
||||
responses={
|
||||
status.HTTP_404_NOT_FOUND: {
|
||||
"model": ErrorResponse,
|
||||
"description": "任务不存在或当前用户无权查看。",
|
||||
},
|
||||
},
|
||||
)
|
||||
def get_attachment_association_job_endpoint(
|
||||
job_id: str,
|
||||
current_user: CurrentUser,
|
||||
) -> AttachmentAssociationJobRead:
|
||||
job = get_attachment_association_job(job_id, current_user)
|
||||
if job is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="附件关联任务不存在或已失效。")
|
||||
return job
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status
|
||||
|
||||
from app.api.deps import CurrentUserContext, get_current_user
|
||||
from app.db.session import get_session_factory
|
||||
from app.schemas.common import ErrorResponse
|
||||
from app.schemas.linked_reimbursement_draft_job import (
|
||||
LinkedReimbursementDraftJobCreate,
|
||||
LinkedReimbursementDraftJobRead,
|
||||
)
|
||||
from app.services.linked_reimbursement_draft_jobs import (
|
||||
create_linked_reimbursement_draft_job,
|
||||
get_linked_reimbursement_draft_job,
|
||||
run_linked_reimbursement_draft_job,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/reimbursements/linked-reimbursement-draft-jobs")
|
||||
CurrentUser = Annotated[CurrentUserContext, Depends(get_current_user)]
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
response_model=LinkedReimbursementDraftJobRead,
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
summary="创建关联申请单生成报销草稿后台任务",
|
||||
description="用户选择关联申请单后,后台继续生成报销草稿,避免当前会话长时间同步等待。",
|
||||
responses={
|
||||
status.HTTP_400_BAD_REQUEST: {
|
||||
"model": ErrorResponse,
|
||||
"description": "请求缺少申请单关联上下文。",
|
||||
},
|
||||
},
|
||||
)
|
||||
def create_linked_reimbursement_draft_job_endpoint(
|
||||
payload: LinkedReimbursementDraftJobCreate,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: CurrentUser,
|
||||
) -> LinkedReimbursementDraftJobRead:
|
||||
try:
|
||||
job = create_linked_reimbursement_draft_job(payload, current_user)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
|
||||
background_tasks.add_task(
|
||||
run_linked_reimbursement_draft_job,
|
||||
job.job_id,
|
||||
current_user,
|
||||
get_session_factory(),
|
||||
)
|
||||
return job
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{job_id}",
|
||||
response_model=LinkedReimbursementDraftJobRead,
|
||||
summary="查询关联申请单生成报销草稿后台任务",
|
||||
description="用于前端按 job_id 查询草稿生成状态。",
|
||||
responses={
|
||||
status.HTTP_404_NOT_FOUND: {
|
||||
"model": ErrorResponse,
|
||||
"description": "任务不存在或当前用户无权查看。",
|
||||
},
|
||||
},
|
||||
)
|
||||
def get_linked_reimbursement_draft_job_endpoint(
|
||||
job_id: str,
|
||||
current_user: CurrentUser,
|
||||
) -> LinkedReimbursementDraftJobRead:
|
||||
job = get_linked_reimbursement_draft_job(job_id, current_user)
|
||||
if job is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="报销草稿生成任务不存在或已失效。")
|
||||
return job
|
||||
@@ -6,6 +6,7 @@ from app.api.v1.endpoints.agent_feedback import router as agent_feedback_router
|
||||
from app.api.v1.endpoints.agent_runs import router as agent_runs_router
|
||||
from app.api.v1.endpoints.agent_traces import router as agent_traces_router
|
||||
from app.api.v1.endpoints.analytics import router as analytics_router
|
||||
from app.api.v1.endpoints.attachment_association_jobs import router as attachment_association_jobs_router
|
||||
from app.api.v1.endpoints.audit_logs import router as audit_logs_router
|
||||
from app.api.v1.endpoints.auth import router as auth_router
|
||||
from app.api.v1.endpoints.bootstrap import router as bootstrap_router
|
||||
@@ -14,6 +15,7 @@ from app.api.v1.endpoints.employees import router as employees_router
|
||||
from app.api.v1.endpoints.employee_profiles import router as employee_profiles_router
|
||||
from app.api.v1.endpoints.health import router as health_router
|
||||
from app.api.v1.endpoints.knowledge import router as knowledge_router
|
||||
from app.api.v1.endpoints.linked_reimbursement_draft_jobs import router as linked_reimbursement_draft_jobs_router
|
||||
from app.api.v1.endpoints.notification_states import router as notification_states_router
|
||||
from app.api.v1.endpoints.ocr import router as ocr_router
|
||||
from app.api.v1.endpoints.ontology import router as ontology_router
|
||||
@@ -36,8 +38,10 @@ router.include_router(agent_feedback_router, tags=["agent-feedback"])
|
||||
router.include_router(agent_runs_router, tags=["agent-runs"])
|
||||
router.include_router(agent_traces_router, tags=["agent-traces"])
|
||||
router.include_router(analytics_router, tags=["analytics"])
|
||||
router.include_router(attachment_association_jobs_router, tags=["attachment-association-jobs"])
|
||||
router.include_router(audit_logs_router, tags=["audit-logs"])
|
||||
router.include_router(knowledge_router, tags=["knowledge"])
|
||||
router.include_router(linked_reimbursement_draft_jobs_router, tags=["linked-reimbursement-draft-jobs"])
|
||||
router.include_router(notification_states_router, tags=["notification-states"])
|
||||
router.include_router(ocr_router, tags=["ocr"])
|
||||
router.include_router(ontology_router, tags=["ontology"])
|
||||
|
||||
40
server/src/app/schemas/attachment_association_job.py
Normal file
40
server/src/app/schemas/attachment_association_job.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class AttachmentAssociationJobCreate(BaseModel):
|
||||
receipt_ids: list[str] = Field(default_factory=list, description="票据夹持久化票据 ID。")
|
||||
prompt: str = Field(default="", max_length=1000, description="用户发送时的上下文说明。")
|
||||
conversation_id: str = Field(default="", max_length=120, description="前端会话 ID,用于状态恢复。")
|
||||
|
||||
@field_validator("receipt_ids")
|
||||
@classmethod
|
||||
def validate_receipt_ids(cls, value: list[str]) -> list[str]:
|
||||
receipt_ids = [
|
||||
str(item or "").strip()
|
||||
for item in list(value or [])
|
||||
if str(item or "").strip()
|
||||
]
|
||||
if not receipt_ids:
|
||||
raise ValueError("请先完成附件 OCR 识别,再发起自动关联。")
|
||||
return list(dict.fromkeys(receipt_ids))
|
||||
|
||||
|
||||
class AttachmentAssociationJobRead(BaseModel):
|
||||
job_id: str
|
||||
status: str
|
||||
message: str = ""
|
||||
receipt_ids: list[str] = Field(default_factory=list)
|
||||
claim_id: str = ""
|
||||
claim_no: str = ""
|
||||
uploaded_count: int = 0
|
||||
skipped_count: int = 0
|
||||
error: str = ""
|
||||
prompt: str = ""
|
||||
conversation_id: str = ""
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
32
server/src/app/schemas/linked_reimbursement_draft_job.py
Normal file
32
server/src/app/schemas/linked_reimbursement_draft_job.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class LinkedReimbursementDraftJobCreate(BaseModel):
|
||||
message: str = Field(min_length=1, max_length=3000, description="生成报销草稿的原始助手请求。")
|
||||
context_json: dict[str, Any] = Field(default_factory=dict, description="复用 Orchestrator 的上下文。")
|
||||
conversation_id: str = Field(default="", max_length=120, description="前端会话 ID,用于状态恢复。")
|
||||
|
||||
@field_validator("message")
|
||||
@classmethod
|
||||
def validate_message(cls, value: str) -> str:
|
||||
normalized = str(value or "").strip()
|
||||
if not normalized:
|
||||
raise ValueError("请先选择要关联的申请单。")
|
||||
return normalized
|
||||
|
||||
|
||||
class LinkedReimbursementDraftJobRead(BaseModel):
|
||||
job_id: str
|
||||
status: str
|
||||
message: str = ""
|
||||
error: str = ""
|
||||
run_id: str = ""
|
||||
conversation_id: str = ""
|
||||
draft_payload: dict[str, Any] | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
189
server/src/app/services/application_location_semantics.py
Normal file
189
server/src/app/services/application_location_semantics.py
Normal file
@@ -0,0 +1,189 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections.abc import Iterable
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
from app.services.user_agent_application_locations import (
|
||||
CITY_TO_PROVINCE,
|
||||
DIRECT_MUNICIPALITY_DISPLAY,
|
||||
)
|
||||
|
||||
PLACEHOLDER_LOCATION_VALUES = {"", "待补充", "待确认", "未知", "暂无", "无", "null", "none"}
|
||||
BUSINESS_ACTION_PATTERN = re.compile(r"(?:支撑|支持|辅助|部署|上线|实施|验收|项目)")
|
||||
BUSINESS_OBJECT_PATTERN = re.compile(r"(?:服务器|系统|仿生产|生产环境|测试环境)")
|
||||
SPECIFIC_ADDRESS_HINT_PATTERN = re.compile(r"""
|
||||
(?:省|市|区|县|自治州|州|镇|乡|街道|路|街|大道|园区|大厦|中心|基地|机场|车站|高铁站|火车站|酒店|楼|号)$
|
||||
""", re.VERBOSE)
|
||||
LOCATION_TAGS = {"LOC", "ns", "s"}
|
||||
JIEBA_LOCATION_TAGS = {"ns"}
|
||||
JIEBA_CUSTOM_WORDS = (
|
||||
"国网",
|
||||
"仿生产",
|
||||
"生产环境",
|
||||
"测试环境",
|
||||
"服务器",
|
||||
"部署",
|
||||
"辅助",
|
||||
"支撑",
|
||||
"支持",
|
||||
"上线",
|
||||
"实施",
|
||||
"验收",
|
||||
)
|
||||
ROUTE_LOCATION_PREFIX_PATTERN = re.compile(
|
||||
r"^(?P<prefix>.*?(?:出差|前往|去|到|赴))(?P<body>[\u4e00-\u9fa5].*)$"
|
||||
)
|
||||
|
||||
|
||||
def compact_application_location_text(value: object) -> str:
|
||||
text = re.sub(r"\s+", "", str(value or ""))
|
||||
text = re.sub(r"^(?:地点|业务地点|发生地点|目的地)[::]", "", text)
|
||||
text = re.sub(r"^(?:去|到|赴|前往)", "", text)
|
||||
return text.strip("::,,。;;、")
|
||||
|
||||
|
||||
def validate_application_location_text(value: object) -> str:
|
||||
text = compact_application_location_text(value)
|
||||
if text.lower() in PLACEHOLDER_LOCATION_VALUES:
|
||||
return ""
|
||||
if not location_mixes_business_content(text):
|
||||
return ""
|
||||
return (
|
||||
f"地点“{text}”混入了业务事项,请填写真实出差地点,例如“上海”;"
|
||||
"业务背景请放在申请事由中。"
|
||||
)
|
||||
|
||||
|
||||
def location_mixes_business_content(value: object) -> bool:
|
||||
text = compact_application_location_text(value)
|
||||
if text.lower() in PLACEHOLDER_LOCATION_VALUES:
|
||||
return False
|
||||
if _matches_business_location_pattern(text):
|
||||
return True
|
||||
return _lac_detects_business_location_mix(text)
|
||||
|
||||
|
||||
def _matches_business_location_pattern(text: str) -> bool:
|
||||
if BUSINESS_ACTION_PATTERN.search(text):
|
||||
return True
|
||||
if BUSINESS_OBJECT_PATTERN.search(text) and not SPECIFIC_ADDRESS_HINT_PATTERN.search(text):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _lac_detects_business_location_mix(text: str) -> bool:
|
||||
tokens = list(resolve_lac_tokens(text))
|
||||
if not tokens:
|
||||
return False
|
||||
has_location = any(tag in LOCATION_TAGS for _, tag in tokens)
|
||||
if not has_location:
|
||||
return False
|
||||
non_location_text = "".join(
|
||||
word
|
||||
for word, tag in tokens
|
||||
if tag not in LOCATION_TAGS and tag != "w"
|
||||
)
|
||||
return _matches_business_location_pattern(non_location_text)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _load_lac_analyzer() -> Any:
|
||||
try:
|
||||
from LAC import LAC # type: ignore
|
||||
except Exception:
|
||||
return None
|
||||
try:
|
||||
return LAC(mode="lac")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def resolve_lac_tokens(text: str) -> Iterable[tuple[str, str]]:
|
||||
analyzer = _load_lac_analyzer()
|
||||
if analyzer is None:
|
||||
return []
|
||||
try:
|
||||
result = analyzer.run(text)
|
||||
except Exception:
|
||||
return []
|
||||
return _parse_lac_result(result)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _load_jieba_posseg() -> Any:
|
||||
try:
|
||||
import jieba
|
||||
import jieba.posseg as pseg
|
||||
except Exception:
|
||||
return None
|
||||
for word in _iter_jieba_custom_words():
|
||||
jieba.add_word(word, freq=100000)
|
||||
return pseg
|
||||
|
||||
|
||||
def _iter_jieba_custom_words() -> Iterable[str]:
|
||||
yield from JIEBA_CUSTOM_WORDS
|
||||
yield from DIRECT_MUNICIPALITY_DISPLAY
|
||||
yield from CITY_TO_PROVINCE
|
||||
|
||||
|
||||
def resolve_jieba_tokens(text: str) -> list[tuple[str, str]]:
|
||||
posseg = _load_jieba_posseg()
|
||||
if posseg is None:
|
||||
return []
|
||||
try:
|
||||
return [
|
||||
(str(item.word or "").strip(), str(item.flag or "").strip())
|
||||
for item in posseg.cut(str(text or ""), HMM=True)
|
||||
if str(item.word or "").strip()
|
||||
]
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
def strip_route_location_prefix_with_jieba(value: object) -> str:
|
||||
text = str(value or "").strip()
|
||||
match = ROUTE_LOCATION_PREFIX_PATTERN.search(text)
|
||||
if not match:
|
||||
return text
|
||||
|
||||
body = match.group("body").strip()
|
||||
tokens = resolve_jieba_tokens(body)
|
||||
if not tokens:
|
||||
return text
|
||||
first_word, first_tag = tokens[0]
|
||||
if not _is_jieba_location_token(first_word, first_tag):
|
||||
return text
|
||||
return body[len(first_word) :].strip(" ::,,。;;、")
|
||||
|
||||
|
||||
def _is_jieba_location_token(word: str, tag: str) -> bool:
|
||||
if tag in JIEBA_LOCATION_TAGS:
|
||||
return True
|
||||
return word in DIRECT_MUNICIPALITY_DISPLAY or word in CITY_TO_PROVINCE
|
||||
|
||||
|
||||
def _parse_lac_result(result: Any) -> list[tuple[str, str]]:
|
||||
if (
|
||||
isinstance(result, (list, tuple))
|
||||
and len(result) == 2
|
||||
and isinstance(result[0], list)
|
||||
and isinstance(result[1], list)
|
||||
):
|
||||
return [
|
||||
(str(word or "").strip(), str(tag or "").strip())
|
||||
for word, tag in zip(result[0], result[1], strict=False)
|
||||
if str(word or "").strip()
|
||||
]
|
||||
if isinstance(result, list) and all(
|
||||
isinstance(item, (list, tuple)) and len(item) >= 2
|
||||
for item in result
|
||||
):
|
||||
return [
|
||||
(str(item[0] or "").strip(), str(item[1] or "").strip())
|
||||
for item in result
|
||||
if str(item[0] or "").strip()
|
||||
]
|
||||
return []
|
||||
549
server/src/app/services/attachment_association_jobs.py
Normal file
549
server/src/app/services/attachment_association_jobs.py
Normal file
@@ -0,0 +1,549 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, date, datetime
|
||||
from decimal import Decimal
|
||||
from threading import Lock
|
||||
from typing import Any, Callable
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from app.api.deps import CurrentUserContext
|
||||
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
|
||||
from app.schemas.attachment_association_job import (
|
||||
AttachmentAssociationJobCreate,
|
||||
AttachmentAssociationJobRead,
|
||||
)
|
||||
from app.schemas.receipt_folder import ReceiptFolderDetailRead
|
||||
from app.schemas.reimbursement import ExpenseClaimItemCreate
|
||||
from app.services.expense_claim_constants import (
|
||||
DOCUMENT_TYPE_ITEM_TYPE_MAP,
|
||||
EDITABLE_CLAIM_STATUSES,
|
||||
)
|
||||
from app.services.expense_claims import ExpenseClaimService
|
||||
from app.services.receipt_folder import ReceiptFolderService
|
||||
|
||||
|
||||
CITY_NAMES = (
|
||||
"北京",
|
||||
"上海",
|
||||
"广州",
|
||||
"深圳",
|
||||
"武汉",
|
||||
"南京",
|
||||
"杭州",
|
||||
"成都",
|
||||
"重庆",
|
||||
"西安",
|
||||
"天津",
|
||||
"苏州",
|
||||
"长沙",
|
||||
"郑州",
|
||||
"青岛",
|
||||
"厦门",
|
||||
"宁波",
|
||||
"无锡",
|
||||
"合肥",
|
||||
"福州",
|
||||
"昆明",
|
||||
"大连",
|
||||
"沈阳",
|
||||
"济南",
|
||||
"哈尔滨",
|
||||
"长春",
|
||||
"南昌",
|
||||
"太原",
|
||||
"贵阳",
|
||||
"南宁",
|
||||
"石家庄",
|
||||
"兰州",
|
||||
"银川",
|
||||
"西宁",
|
||||
"海口",
|
||||
"拉萨",
|
||||
)
|
||||
|
||||
TERMINAL_STATUSES = {"succeeded", "failed"}
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AttachmentAssociationJobState:
|
||||
job_id: str
|
||||
owner_username: str
|
||||
owner_name: str
|
||||
receipt_ids: list[str]
|
||||
prompt: str = ""
|
||||
conversation_id: str = ""
|
||||
status: str = "queued"
|
||||
message: str = "已创建附件关联任务,等待后台处理。"
|
||||
claim_id: str = ""
|
||||
claim_no: str = ""
|
||||
uploaded_count: int = 0
|
||||
skipped_count: int = 0
|
||||
error: str = ""
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
updated_at: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
|
||||
def to_read(self) -> AttachmentAssociationJobRead:
|
||||
return AttachmentAssociationJobRead(
|
||||
job_id=self.job_id,
|
||||
status=self.status,
|
||||
message=self.message,
|
||||
receipt_ids=list(self.receipt_ids),
|
||||
claim_id=self.claim_id,
|
||||
claim_no=self.claim_no,
|
||||
uploaded_count=self.uploaded_count,
|
||||
skipped_count=self.skipped_count,
|
||||
error=self.error,
|
||||
prompt=self.prompt,
|
||||
conversation_id=self.conversation_id,
|
||||
created_at=self.created_at,
|
||||
updated_at=self.updated_at,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AttachmentAssociationCandidate:
|
||||
claim: ExpenseClaim
|
||||
score: int
|
||||
reasons: list[str]
|
||||
|
||||
|
||||
_jobs: dict[str, AttachmentAssociationJobState] = {}
|
||||
_jobs_lock = Lock()
|
||||
|
||||
|
||||
def clear_attachment_association_jobs_for_tests() -> None:
|
||||
with _jobs_lock:
|
||||
_jobs.clear()
|
||||
|
||||
|
||||
def create_attachment_association_job(
|
||||
payload: AttachmentAssociationJobCreate,
|
||||
current_user: CurrentUserContext,
|
||||
) -> AttachmentAssociationJobRead:
|
||||
job_id = f"attachment-association-{uuid4()}"
|
||||
state = AttachmentAssociationJobState(
|
||||
job_id=job_id,
|
||||
owner_username=str(current_user.username or "").strip(),
|
||||
owner_name=str(current_user.name or "").strip(),
|
||||
receipt_ids=list(payload.receipt_ids),
|
||||
prompt=str(payload.prompt or "").strip(),
|
||||
conversation_id=str(payload.conversation_id or "").strip(),
|
||||
)
|
||||
with _jobs_lock:
|
||||
_jobs[job_id] = state
|
||||
return state.to_read()
|
||||
|
||||
|
||||
def get_attachment_association_job(
|
||||
job_id: str,
|
||||
current_user: CurrentUserContext,
|
||||
) -> AttachmentAssociationJobRead | None:
|
||||
state = _get_authorized_state(job_id, current_user)
|
||||
return state.to_read() if state is not None else None
|
||||
|
||||
|
||||
def run_attachment_association_job(
|
||||
job_id: str,
|
||||
current_user: CurrentUserContext,
|
||||
session_factory: sessionmaker[Session] | Callable[[], Session],
|
||||
) -> None:
|
||||
state = _get_authorized_state(job_id, current_user)
|
||||
if state is None or state.status in TERMINAL_STATUSES:
|
||||
return
|
||||
|
||||
_update_job(job_id, status="running", message="正在匹配可关联的报销草稿...")
|
||||
try:
|
||||
with session_factory() as db:
|
||||
result = AttachmentAssociationJobRunner(db).run(
|
||||
receipt_ids=state.receipt_ids,
|
||||
current_user=current_user,
|
||||
)
|
||||
_update_job(
|
||||
job_id,
|
||||
status="succeeded",
|
||||
message=f"已自动关联到 {result['claim_no']},成功归集 {result['uploaded_count']} 份附件。",
|
||||
claim_id=str(result["claim_id"]),
|
||||
claim_no=str(result["claim_no"]),
|
||||
uploaded_count=int(result["uploaded_count"]),
|
||||
skipped_count=int(result["skipped_count"]),
|
||||
error="",
|
||||
)
|
||||
except Exception as exc:
|
||||
message = str(exc).strip() or "自动关联任务执行失败,请稍后重试。"
|
||||
_update_job(
|
||||
job_id,
|
||||
status="failed",
|
||||
message=message,
|
||||
error=message,
|
||||
)
|
||||
|
||||
|
||||
def _get_authorized_state(
|
||||
job_id: str,
|
||||
current_user: CurrentUserContext,
|
||||
) -> AttachmentAssociationJobState | None:
|
||||
normalized_job_id = str(job_id or "").strip()
|
||||
with _jobs_lock:
|
||||
state = _jobs.get(normalized_job_id)
|
||||
if state is None:
|
||||
return None
|
||||
if current_user.is_admin:
|
||||
return state
|
||||
username = str(current_user.username or "").strip()
|
||||
name = str(current_user.name or "").strip()
|
||||
if username and username == state.owner_username:
|
||||
return state
|
||||
if name and name == state.owner_name:
|
||||
return state
|
||||
return None
|
||||
|
||||
|
||||
def _update_job(job_id: str, **updates: Any) -> None:
|
||||
with _jobs_lock:
|
||||
state = _jobs.get(str(job_id or "").strip())
|
||||
if state is None:
|
||||
return
|
||||
for key, value in updates.items():
|
||||
if hasattr(state, key):
|
||||
setattr(state, key, value)
|
||||
state.updated_at = datetime.now(UTC)
|
||||
|
||||
|
||||
class AttachmentAssociationJobRunner:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
self.claim_service = ExpenseClaimService(db)
|
||||
self.receipt_service = ReceiptFolderService()
|
||||
|
||||
def run(
|
||||
self,
|
||||
*,
|
||||
receipt_ids: list[str],
|
||||
current_user: CurrentUserContext,
|
||||
) -> dict[str, Any]:
|
||||
receipts = self._load_receipts(receipt_ids, current_user)
|
||||
candidates = self._rank_claims(receipts, current_user)
|
||||
if not candidates:
|
||||
raise ValueError("没有找到可自动关联的报销草稿,请先新建草稿或补充说明。")
|
||||
|
||||
recommended = candidates[0]
|
||||
runner_up = candidates[1] if len(candidates) > 1 else None
|
||||
if recommended.score < 5 or (runner_up is not None and recommended.score - runner_up.score < 2):
|
||||
raise ValueError("找到多个可能关联的报销草稿,请补充说明或手动选择后再归集。")
|
||||
|
||||
uploaded_count = 0
|
||||
skipped_count = 0
|
||||
for receipt in receipts:
|
||||
if self._is_linked_to_other_claim(receipt, recommended.claim.id):
|
||||
skipped_count += 1
|
||||
continue
|
||||
target_item = self._resolve_target_item(
|
||||
claim_id=recommended.claim.id,
|
||||
receipt=receipt,
|
||||
current_user=current_user,
|
||||
)
|
||||
source_path, media_type, file_name = self.receipt_service.resolve_source(receipt.id, current_user)
|
||||
result = self.claim_service.upload_claim_item_attachment(
|
||||
claim_id=recommended.claim.id,
|
||||
item_id=target_item.id,
|
||||
filename=file_name,
|
||||
content=source_path.read_bytes(),
|
||||
media_type=media_type,
|
||||
current_user=current_user,
|
||||
source_receipt_id=receipt.id,
|
||||
)
|
||||
if result is None:
|
||||
skipped_count += 1
|
||||
else:
|
||||
uploaded_count += 1
|
||||
|
||||
if uploaded_count <= 0:
|
||||
raise ValueError("未能归集任何附件,请进入报销单详情手动核对。")
|
||||
return {
|
||||
"claim_id": recommended.claim.id,
|
||||
"claim_no": recommended.claim.claim_no,
|
||||
"uploaded_count": uploaded_count,
|
||||
"skipped_count": skipped_count,
|
||||
}
|
||||
|
||||
def _load_receipts(
|
||||
self,
|
||||
receipt_ids: list[str],
|
||||
current_user: CurrentUserContext,
|
||||
) -> list[ReceiptFolderDetailRead]:
|
||||
receipts = []
|
||||
for receipt_id in list(dict.fromkeys(str(item or "").strip() for item in receipt_ids if str(item or "").strip())):
|
||||
try:
|
||||
receipts.append(self.receipt_service.get_receipt(receipt_id, current_user))
|
||||
except FileNotFoundError as exc:
|
||||
raise ValueError("当前附件没有持久化票据记录,请重新上传后再试。") from exc
|
||||
if not receipts:
|
||||
raise ValueError("当前附件没有持久化票据记录,请重新上传后再试。")
|
||||
return receipts
|
||||
|
||||
def _rank_claims(
|
||||
self,
|
||||
receipts: list[ReceiptFolderDetailRead],
|
||||
current_user: CurrentUserContext,
|
||||
) -> list[AttachmentAssociationCandidate]:
|
||||
signals = _collect_receipt_signals(receipts)
|
||||
claims = [
|
||||
claim
|
||||
for claim in self.claim_service.list_claims(current_user)
|
||||
if self._is_auto_association_candidate(claim)
|
||||
]
|
||||
ranked = [
|
||||
candidate
|
||||
for candidate in (
|
||||
self._score_claim(claim, signals)
|
||||
for claim in claims
|
||||
)
|
||||
if candidate.score > 0
|
||||
]
|
||||
return sorted(ranked, key=lambda item: item.score, reverse=True)
|
||||
|
||||
def _is_auto_association_candidate(self, claim: ExpenseClaim) -> bool:
|
||||
status = str(claim.status or "").strip().lower()
|
||||
if status not in EDITABLE_CLAIM_STATUSES:
|
||||
return False
|
||||
return not self.claim_service._is_expense_application_claim(claim)
|
||||
|
||||
def _score_claim(
|
||||
self,
|
||||
claim: ExpenseClaim,
|
||||
signals: dict[str, Any],
|
||||
) -> AttachmentAssociationCandidate:
|
||||
claim_text = _build_claim_text(claim)
|
||||
compact_claim_text = _normalize_text(claim_text)
|
||||
claim_dates = _extract_date_tokens(claim_text)
|
||||
claim_cities = _unique([*_extract_city_tokens(claim_text), *_extract_city_tokens(claim.location)])
|
||||
reasons: list[str] = []
|
||||
score = 0
|
||||
|
||||
if _dates_overlap(signals["dates"], claim_dates):
|
||||
score += 4
|
||||
reasons.append("票据日期与报销单日期一致")
|
||||
|
||||
matched_cities = [city for city in signals["cities"] if city in compact_claim_text]
|
||||
if matched_cities:
|
||||
score += min(4, len(matched_cities) * 2)
|
||||
reasons.append(f"地点或行程包含 {'、'.join(matched_cities)}")
|
||||
|
||||
if len(claim_cities) >= 2 and len(matched_cities) >= 2:
|
||||
score += 2
|
||||
reasons.append("票据往返城市与报销事由吻合")
|
||||
|
||||
if str(claim.status or "").strip().lower() == "draft":
|
||||
score += 1
|
||||
reasons.append("当前单据仍是可归集草稿")
|
||||
|
||||
return AttachmentAssociationCandidate(claim=claim, score=score, reasons=reasons)
|
||||
|
||||
@staticmethod
|
||||
def _is_linked_to_other_claim(receipt: ReceiptFolderDetailRead, claim_id: str) -> bool:
|
||||
linked_claim_id = str(receipt.linked_claim_id or "").strip()
|
||||
return bool(str(receipt.status or "").strip() == "linked" and linked_claim_id and linked_claim_id != claim_id)
|
||||
|
||||
def _resolve_target_item(
|
||||
self,
|
||||
*,
|
||||
claim_id: str,
|
||||
receipt: ReceiptFolderDetailRead,
|
||||
current_user: CurrentUserContext,
|
||||
) -> ExpenseClaimItem:
|
||||
claim = self.claim_service.get_claim(claim_id, current_user)
|
||||
if claim is None:
|
||||
raise ValueError("匹配到的报销草稿不存在,请刷新后再试。")
|
||||
|
||||
preferred_type = _resolve_receipt_item_type(receipt)
|
||||
empty_items = [
|
||||
item
|
||||
for item in list(claim.items or [])
|
||||
if not str(item.invoice_id or "").strip() and not item.is_system_generated
|
||||
]
|
||||
for item in empty_items:
|
||||
if preferred_type and str(item.item_type or "").strip() == preferred_type:
|
||||
return item
|
||||
if empty_items:
|
||||
return empty_items[0]
|
||||
|
||||
before_ids = {str(item.id) for item in list(claim.items or [])}
|
||||
created_claim = self.claim_service.create_claim_item(
|
||||
claim_id=claim.id,
|
||||
payload=_build_item_payload_from_receipt(claim, receipt, preferred_type),
|
||||
current_user=current_user,
|
||||
)
|
||||
if created_claim is None:
|
||||
raise ValueError("无法创建票据归集明细,请进入详情页手动处理。")
|
||||
for item in list(created_claim.items or []):
|
||||
if str(item.id) not in before_ids and not str(item.invoice_id or "").strip():
|
||||
return item
|
||||
raise ValueError("无法找到可归集的费用明细,请进入详情页手动处理。")
|
||||
|
||||
|
||||
def _normalize_text(value: Any) -> str:
|
||||
return re.sub(r"\s+", "", str(value or "").strip())
|
||||
|
||||
|
||||
def _unique(values: list[str] | tuple[str, ...]) -> list[str]:
|
||||
return list(dict.fromkeys(str(item or "").strip() for item in values if str(item or "").strip()))
|
||||
|
||||
|
||||
def _extract_date_tokens(text: Any) -> list[str]:
|
||||
source = str(text or "")
|
||||
matches = [
|
||||
*re.finditer(r"20\d{2}[-/.年]\d{1,2}[-/.月]\d{1,2}", source),
|
||||
*re.finditer(r"\d{1,2}月\d{1,2}", source),
|
||||
]
|
||||
return _unique([_normalize_date_token(match.group(0)) for match in matches])
|
||||
|
||||
|
||||
def _normalize_date_token(value: Any) -> str:
|
||||
if isinstance(value, (date, datetime)):
|
||||
return value.isoformat()[:10]
|
||||
text = str(value or "").strip()
|
||||
full_match = re.search(r"(20\d{2})[-/.年](\d{1,2})[-/.月](\d{1,2})", text)
|
||||
if full_match:
|
||||
year, month, day = full_match.groups()
|
||||
return f"{year}-{month.zfill(2)}-{day.zfill(2)}"
|
||||
short_match = re.search(r"(\d{1,2})月(\d{1,2})", text)
|
||||
if short_match:
|
||||
month, day = short_match.groups()
|
||||
return f"{month.zfill(2)}-{day.zfill(2)}"
|
||||
return ""
|
||||
|
||||
|
||||
def _extract_city_tokens(text: Any) -> list[str]:
|
||||
compact = _normalize_text(text)
|
||||
if not compact:
|
||||
return []
|
||||
return [city for city in CITY_NAMES if city in compact]
|
||||
|
||||
|
||||
def _dates_overlap(left: list[str], right: list[str]) -> bool:
|
||||
for left_date in left:
|
||||
if not left_date:
|
||||
continue
|
||||
for right_date in right:
|
||||
if right_date and (left_date == right_date or left_date.endswith(right_date) or right_date.endswith(left_date)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _collect_receipt_signals(receipts: list[ReceiptFolderDetailRead]) -> dict[str, Any]:
|
||||
text = "\n".join(_build_receipt_text(receipt) for receipt in receipts)
|
||||
dates = _unique([
|
||||
*_extract_date_tokens(text),
|
||||
*[str(receipt.document_date or "").strip() for receipt in receipts],
|
||||
])
|
||||
return {
|
||||
"text": text,
|
||||
"dates": dates,
|
||||
"cities": _unique(_extract_city_tokens(text)),
|
||||
}
|
||||
|
||||
|
||||
def _build_receipt_text(receipt: ReceiptFolderDetailRead) -> str:
|
||||
fields_text = "\n".join(
|
||||
f"{field.label} {field.value}"
|
||||
for field in list(receipt.fields or [])
|
||||
if str(field.label or field.value or "").strip()
|
||||
)
|
||||
return "\n".join(
|
||||
value
|
||||
for value in (
|
||||
receipt.file_name,
|
||||
receipt.summary,
|
||||
receipt.ocr_text,
|
||||
receipt.document_date,
|
||||
receipt.merchant_name,
|
||||
fields_text,
|
||||
)
|
||||
if str(value or "").strip()
|
||||
)
|
||||
|
||||
|
||||
def _build_claim_text(claim: ExpenseClaim) -> str:
|
||||
item_text = "\n".join(
|
||||
" ".join(
|
||||
str(value or "").strip()
|
||||
for value in (
|
||||
item.item_date.isoformat() if item.item_date else "",
|
||||
item.item_type,
|
||||
item.item_reason,
|
||||
item.item_location,
|
||||
item.item_note,
|
||||
)
|
||||
if str(value or "").strip()
|
||||
)
|
||||
for item in list(claim.items or [])
|
||||
)
|
||||
occurred_at = claim.occurred_at.isoformat()[:10] if claim.occurred_at else ""
|
||||
return "\n".join(
|
||||
value
|
||||
for value in (
|
||||
claim.claim_no,
|
||||
claim.expense_type,
|
||||
claim.status,
|
||||
claim.reason,
|
||||
claim.location,
|
||||
occurred_at,
|
||||
item_text,
|
||||
)
|
||||
if str(value or "").strip()
|
||||
)
|
||||
|
||||
|
||||
def _resolve_receipt_item_type(receipt: ReceiptFolderDetailRead) -> str:
|
||||
document_type = str(receipt.document_type or "").strip()
|
||||
if document_type in DOCUMENT_TYPE_ITEM_TYPE_MAP:
|
||||
return DOCUMENT_TYPE_ITEM_TYPE_MAP[document_type]
|
||||
scene_code = str(receipt.scene_code or "").strip()
|
||||
if scene_code == "travel":
|
||||
return "travel"
|
||||
return scene_code or "other"
|
||||
|
||||
|
||||
def _build_item_payload_from_receipt(
|
||||
claim: ExpenseClaim,
|
||||
receipt: ReceiptFolderDetailRead,
|
||||
preferred_type: str,
|
||||
) -> ExpenseClaimItemCreate:
|
||||
item_date = _resolve_receipt_item_date(receipt) or (claim.occurred_at.date() if claim.occurred_at else None)
|
||||
return ExpenseClaimItemCreate(
|
||||
item_date=item_date,
|
||||
item_type=preferred_type or str(claim.expense_type or "").strip() or "other",
|
||||
item_reason=str(receipt.summary or receipt.file_name or "").strip(),
|
||||
item_location=_resolve_receipt_item_location(receipt) or str(claim.location or "").strip(),
|
||||
item_amount=Decimal("0.00"),
|
||||
)
|
||||
|
||||
|
||||
def _resolve_receipt_item_date(receipt: ReceiptFolderDetailRead) -> date | None:
|
||||
for value in [
|
||||
*[field.value for field in list(receipt.fields or []) if "日期" in str(field.label or "") or "时间" in str(field.label or "")],
|
||||
receipt.document_date,
|
||||
]:
|
||||
token = _normalize_date_token(value)
|
||||
if len(token) == 10:
|
||||
try:
|
||||
return date.fromisoformat(token)
|
||||
except ValueError:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_receipt_item_location(receipt: ReceiptFolderDetailRead) -> str:
|
||||
for field in list(receipt.fields or []):
|
||||
label = str(field.label or "")
|
||||
value = str(field.value or "").strip()
|
||||
if value and ("行程" in label or "到达" in label or "地点" in label or "城市" in label):
|
||||
cities = _extract_city_tokens(value)
|
||||
return cities[-1] if cities else value[:40]
|
||||
cities = _extract_city_tokens(_build_receipt_text(receipt))
|
||||
return cities[-1] if cities else ""
|
||||
|
||||
291
server/src/app/services/linked_reimbursement_draft_jobs.py
Normal file
291
server/src/app/services/linked_reimbursement_draft_jobs.py
Normal file
@@ -0,0 +1,291 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from threading import Lock
|
||||
from typing import Any, Callable
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from app.api.deps import CurrentUserContext
|
||||
from app.schemas.linked_reimbursement_draft_job import (
|
||||
LinkedReimbursementDraftJobCreate,
|
||||
LinkedReimbursementDraftJobRead,
|
||||
)
|
||||
from app.schemas.ontology import OntologyParseResult, OntologyPermission
|
||||
from app.schemas.orchestrator import OrchestratorRequest
|
||||
from app.models.financial_record import ExpenseClaim
|
||||
from app.services.expense_claims import ExpenseClaimService
|
||||
from app.services.orchestrator import OrchestratorService
|
||||
|
||||
|
||||
TERMINAL_STATUSES = {"succeeded", "failed"}
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class LinkedReimbursementDraftJobState:
|
||||
job_id: str
|
||||
owner_username: str
|
||||
owner_name: str
|
||||
message: str
|
||||
context_json: dict[str, Any]
|
||||
conversation_id: str = ""
|
||||
status: str = "queued"
|
||||
status_message: str = "已创建报销草稿生成任务,等待后台处理。"
|
||||
error: str = ""
|
||||
run_id: str = ""
|
||||
draft_payload: dict[str, Any] | None = None
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
updated_at: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
|
||||
def to_read(self) -> LinkedReimbursementDraftJobRead:
|
||||
return LinkedReimbursementDraftJobRead(
|
||||
job_id=self.job_id,
|
||||
status=self.status,
|
||||
message=self.status_message,
|
||||
error=self.error,
|
||||
run_id=self.run_id,
|
||||
conversation_id=self.conversation_id,
|
||||
draft_payload=self.draft_payload,
|
||||
created_at=self.created_at,
|
||||
updated_at=self.updated_at,
|
||||
)
|
||||
|
||||
|
||||
_jobs: dict[str, LinkedReimbursementDraftJobState] = {}
|
||||
_jobs_lock = Lock()
|
||||
|
||||
|
||||
def clear_linked_reimbursement_draft_jobs_for_tests() -> None:
|
||||
with _jobs_lock:
|
||||
_jobs.clear()
|
||||
|
||||
|
||||
def create_linked_reimbursement_draft_job(
|
||||
payload: LinkedReimbursementDraftJobCreate,
|
||||
current_user: CurrentUserContext,
|
||||
) -> LinkedReimbursementDraftJobRead:
|
||||
context_json = dict(payload.context_json or {})
|
||||
context_json["entry_source"] = context_json.get("entry_source") or "workbench-ai"
|
||||
context_json["session_type"] = context_json.get("session_type") or "expense"
|
||||
job_id = f"linked-reimbursement-draft-{uuid4()}"
|
||||
state = LinkedReimbursementDraftJobState(
|
||||
job_id=job_id,
|
||||
owner_username=str(current_user.username or "").strip(),
|
||||
owner_name=str(current_user.name or "").strip(),
|
||||
message=str(payload.message or "").strip(),
|
||||
context_json=context_json,
|
||||
conversation_id=str(payload.conversation_id or "").strip(),
|
||||
)
|
||||
with _jobs_lock:
|
||||
_jobs[job_id] = state
|
||||
return state.to_read()
|
||||
|
||||
|
||||
def get_linked_reimbursement_draft_job(
|
||||
job_id: str,
|
||||
current_user: CurrentUserContext,
|
||||
) -> LinkedReimbursementDraftJobRead | None:
|
||||
state = _get_authorized_state(job_id, current_user)
|
||||
return state.to_read() if state is not None else None
|
||||
|
||||
|
||||
def run_linked_reimbursement_draft_job(
|
||||
job_id: str,
|
||||
current_user: CurrentUserContext,
|
||||
session_factory: sessionmaker[Session] | Callable[[], Session],
|
||||
) -> None:
|
||||
state = _get_authorized_state(job_id, current_user)
|
||||
if state is None or state.status in TERMINAL_STATUSES:
|
||||
return
|
||||
|
||||
_update_job(job_id, status="running", status_message="正在后台生成报销草稿...")
|
||||
try:
|
||||
with session_factory() as db:
|
||||
if _can_use_direct_save_path(db, state.context_json):
|
||||
run_id, result, draft_payload = _run_direct_save_path(
|
||||
db=db,
|
||||
state=state,
|
||||
current_user=current_user,
|
||||
)
|
||||
else:
|
||||
response = OrchestratorService(db).run(
|
||||
OrchestratorRequest(
|
||||
source="user_message",
|
||||
user_id=_resolve_user_id(current_user),
|
||||
conversation_id=None,
|
||||
message=state.message,
|
||||
context_json=dict(state.context_json),
|
||||
)
|
||||
)
|
||||
run_id = response.run_id
|
||||
result = response.result if isinstance(response.result, dict) else {}
|
||||
draft_payload = result.get("draft_payload") if isinstance(result.get("draft_payload"), dict) else None
|
||||
if response.status != "succeeded":
|
||||
raise ValueError(str(result.get("message") or "报销草稿生成失败,请稍后重试。").strip())
|
||||
|
||||
if draft_payload is None:
|
||||
raise ValueError("报销草稿生成完成,但未返回草稿信息,请刷新单据列表后核对。")
|
||||
|
||||
_update_job(
|
||||
job_id,
|
||||
status="succeeded",
|
||||
status_message=str(result.get("message") or "报销草稿已生成。").strip(),
|
||||
run_id=run_id,
|
||||
draft_payload=draft_payload,
|
||||
error="",
|
||||
)
|
||||
except Exception as exc:
|
||||
message = str(exc).strip() or "报销草稿生成失败,请稍后重试。"
|
||||
_update_job(
|
||||
job_id,
|
||||
status="failed",
|
||||
status_message=message,
|
||||
error=message,
|
||||
)
|
||||
|
||||
|
||||
def _get_authorized_state(
|
||||
job_id: str,
|
||||
current_user: CurrentUserContext,
|
||||
) -> LinkedReimbursementDraftJobState | None:
|
||||
normalized_job_id = str(job_id or "").strip()
|
||||
with _jobs_lock:
|
||||
state = _jobs.get(normalized_job_id)
|
||||
if state is None:
|
||||
return None
|
||||
if current_user.is_admin:
|
||||
return state
|
||||
username = str(current_user.username or "").strip()
|
||||
name = str(current_user.name or "").strip()
|
||||
if username and username == state.owner_username:
|
||||
return state
|
||||
if name and name == state.owner_name:
|
||||
return state
|
||||
return None
|
||||
|
||||
|
||||
def _update_job(job_id: str, **updates: Any) -> None:
|
||||
with _jobs_lock:
|
||||
state = _jobs.get(str(job_id or "").strip())
|
||||
if state is None:
|
||||
return
|
||||
for key, value in updates.items():
|
||||
if hasattr(state, key):
|
||||
setattr(state, key, value)
|
||||
state.updated_at = datetime.now(UTC)
|
||||
|
||||
|
||||
def _resolve_user_id(current_user: CurrentUserContext) -> str:
|
||||
return str(current_user.username or current_user.name or "anonymous").strip() or "anonymous"
|
||||
|
||||
|
||||
def _can_use_direct_save_path(db: Session, context_json: dict[str, Any]) -> bool:
|
||||
review_action = str((context_json or {}).get("review_action") or "").strip()
|
||||
if review_action != "save_draft":
|
||||
return False
|
||||
review_values = context_json.get("review_form_values")
|
||||
if not isinstance(review_values, dict):
|
||||
return False
|
||||
application_claim_id = str(review_values.get("application_claim_id") or "").strip()
|
||||
application_claim_no = str(review_values.get("application_claim_no") or "").strip()
|
||||
if not application_claim_no:
|
||||
return False
|
||||
if application_claim_id:
|
||||
return True
|
||||
return _find_application_claim_by_no(db, application_claim_no) is not None
|
||||
|
||||
|
||||
def _find_application_claim_by_no(db: Session, claim_no: str) -> ExpenseClaim | None:
|
||||
normalized_claim_no = str(claim_no or "").strip()
|
||||
if not normalized_claim_no:
|
||||
return None
|
||||
claim = db.scalar(
|
||||
select(ExpenseClaim)
|
||||
.where(ExpenseClaim.claim_no == normalized_claim_no)
|
||||
.limit(1)
|
||||
)
|
||||
if claim is not None and ExpenseClaimService._is_expense_application_claim(claim):
|
||||
return claim
|
||||
return None
|
||||
|
||||
|
||||
def _build_direct_context_json(db: Session, context_json: dict[str, Any]) -> dict[str, Any]:
|
||||
direct_context = dict(context_json or {})
|
||||
review_values = dict(direct_context.get("review_form_values") or {})
|
||||
scene_selection = dict(direct_context.get("expense_scene_selection") or {})
|
||||
application_claim_id = str(review_values.get("application_claim_id") or "").strip()
|
||||
application_claim_no = str(review_values.get("application_claim_no") or "").strip()
|
||||
if not application_claim_id and application_claim_no:
|
||||
application_claim = _find_application_claim_by_no(db, application_claim_no)
|
||||
if application_claim is not None:
|
||||
review_values["application_claim_id"] = application_claim.id
|
||||
scene_selection["application_claim_id"] = application_claim.id
|
||||
scene_selection["application_claim_no"] = str(
|
||||
scene_selection.get("application_claim_no")
|
||||
or application_claim.claim_no
|
||||
or application_claim_no
|
||||
).strip()
|
||||
direct_context["review_form_values"] = review_values
|
||||
if scene_selection:
|
||||
direct_context["expense_scene_selection"] = scene_selection
|
||||
return direct_context
|
||||
|
||||
|
||||
def _run_direct_save_path(
|
||||
*,
|
||||
db: Session,
|
||||
state: LinkedReimbursementDraftJobState,
|
||||
current_user: CurrentUserContext,
|
||||
) -> tuple[str, dict[str, Any], dict[str, Any]]:
|
||||
run_id = state.job_id
|
||||
ontology = OntologyParseResult(
|
||||
scenario="expense",
|
||||
intent="draft",
|
||||
permission=OntologyPermission(
|
||||
level="draft_write",
|
||||
allowed=True,
|
||||
reason="关联申请单生成报销草稿快路径。",
|
||||
),
|
||||
confidence=1.0,
|
||||
run_id=run_id,
|
||||
)
|
||||
result = ExpenseClaimService(db).save_or_submit_from_ontology(
|
||||
run_id=run_id,
|
||||
user_id=_resolve_user_id(current_user),
|
||||
message=state.message,
|
||||
ontology=ontology,
|
||||
context_json=_build_direct_context_json(db, state.context_json),
|
||||
)
|
||||
claim_id = str(result.get("claim_id") or "").strip()
|
||||
claim_no = str(result.get("claim_no") or "").strip()
|
||||
if not claim_id or not claim_no or str(result.get("status") or "").strip() != "draft":
|
||||
raise ValueError(str(result.get("message") or "报销草稿生成失败,请稍后重试。").strip())
|
||||
|
||||
claim = db.get(ExpenseClaim, claim_id)
|
||||
return run_id, result, _build_direct_draft_payload(result, claim)
|
||||
|
||||
|
||||
def _build_direct_draft_payload(
|
||||
result: dict[str, Any],
|
||||
claim: ExpenseClaim | None,
|
||||
) -> dict[str, Any]:
|
||||
claim_id = str(result.get("claim_id") or getattr(claim, "id", "") or "").strip()
|
||||
claim_no = str(result.get("claim_no") or getattr(claim, "claim_no", "") or "").strip()
|
||||
status = str(result.get("status") or getattr(claim, "status", "") or "draft").strip()
|
||||
approval_stage = str(getattr(claim, "approval_stage", "") or "待提交").strip()
|
||||
expense_type = str(getattr(claim, "expense_type", "") or "").strip()
|
||||
message = str(result.get("message") or "报销草稿已生成。").strip()
|
||||
return {
|
||||
"draft_type": "expense",
|
||||
"title": f"费用草稿 {claim_no}" if claim_no else "费用草稿",
|
||||
"body": message,
|
||||
"confirmation_required": True,
|
||||
"claim_id": claim_id,
|
||||
"claim_no": claim_no,
|
||||
"status": status,
|
||||
"approval_stage": approval_stage,
|
||||
"expense_type": expense_type,
|
||||
}
|
||||
280
server/tests/test_attachment_association_jobs.py
Normal file
280
server/tests/test_attachment_association_jobs.py
Normal file
@@ -0,0 +1,280 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from datetime import UTC, date, datetime
|
||||
from decimal import Decimal
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from app.api.deps import CurrentUserContext, get_db
|
||||
from app.api.v1.endpoints import attachment_association_jobs as attachment_jobs_endpoint
|
||||
from app.core.config import get_settings
|
||||
from app.main import create_app
|
||||
from app.models.employee import Employee
|
||||
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
|
||||
from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeFieldRead
|
||||
from app.services.attachment_association_jobs import clear_attachment_association_jobs_for_tests
|
||||
from app.services.expense_claim_attachment_storage import ExpenseClaimAttachmentStorage
|
||||
from app.services.ocr import OcrService
|
||||
from app.services.receipt_folder import ReceiptFolderService
|
||||
from app.test_helpers.db import build_in_memory_session_factory
|
||||
|
||||
|
||||
def build_client(monkeypatch) -> tuple[TestClient, object]:
|
||||
session_factory = build_in_memory_session_factory()
|
||||
app = create_app()
|
||||
|
||||
def override_db() -> Generator[Session, None, None]:
|
||||
db = session_factory()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
app.dependency_overrides[get_db] = override_db
|
||||
monkeypatch.setattr(attachment_jobs_endpoint, "get_session_factory", lambda: session_factory)
|
||||
return TestClient(app), session_factory
|
||||
|
||||
|
||||
def seed_travel_claim(db: Session) -> ExpenseClaim:
|
||||
employee = Employee(
|
||||
id="emp-bg-association",
|
||||
employee_no="E10001",
|
||||
name="张三",
|
||||
email="zhangsan@example.com",
|
||||
position="实施顾问",
|
||||
grade="P4",
|
||||
)
|
||||
claim = ExpenseClaim(
|
||||
id="claim-bg-association",
|
||||
claim_no="BX-20260220-001",
|
||||
employee_id=employee.id,
|
||||
employee_name=employee.name,
|
||||
department_id="dept-delivery",
|
||||
department_name="交付部",
|
||||
project_code=None,
|
||||
expense_type="travel",
|
||||
reason="辅助国网仿生产服务器部署,武汉往返上海",
|
||||
location="上海",
|
||||
amount=Decimal("0.00"),
|
||||
currency="CNY",
|
||||
invoice_count=0,
|
||||
occurred_at=datetime(2026, 2, 20, tzinfo=UTC),
|
||||
submitted_at=None,
|
||||
status="draft",
|
||||
approval_stage="待提交",
|
||||
risk_flags_json=[],
|
||||
)
|
||||
item = ExpenseClaimItem(
|
||||
id="item-bg-association-1",
|
||||
claim_id=claim.id,
|
||||
item_date=date(2026, 2, 20),
|
||||
item_type="train_ticket",
|
||||
item_reason="武汉至上海高铁",
|
||||
item_location="上海",
|
||||
item_amount=Decimal("0.00"),
|
||||
invoice_id=None,
|
||||
)
|
||||
claim.items = [item]
|
||||
db.add_all([employee, claim])
|
||||
db.commit()
|
||||
return claim
|
||||
|
||||
|
||||
def save_train_receipt(
|
||||
*,
|
||||
service: ReceiptFolderService,
|
||||
current_user: CurrentUserContext,
|
||||
filename: str,
|
||||
route: str,
|
||||
trip_date: str,
|
||||
) -> str:
|
||||
receipt = service.save_receipt(
|
||||
filename=filename,
|
||||
content=f"fake-pdf-{filename}".encode("utf-8"),
|
||||
media_type="application/pdf",
|
||||
current_user=current_user,
|
||||
document=OcrRecognizeDocumentRead(
|
||||
filename=filename,
|
||||
media_type="application/pdf",
|
||||
text=f"电子发票(铁路电子客票) {route} {trip_date} 票价 354 元",
|
||||
summary=f"铁路电子客票,{route},票价 354 元。",
|
||||
avg_score=0.96,
|
||||
line_count=1,
|
||||
page_count=1,
|
||||
document_type="train_ticket",
|
||||
document_type_label="火车/高铁票",
|
||||
scene_code="travel",
|
||||
scene_label="差旅票据",
|
||||
document_fields=[
|
||||
OcrRecognizeFieldRead(key="date", label="列车出发时间", value=trip_date),
|
||||
OcrRecognizeFieldRead(key="route", label="行程", value=route),
|
||||
OcrRecognizeFieldRead(key="amount", label="金额", value="354元"),
|
||||
],
|
||||
),
|
||||
)
|
||||
return receipt.id
|
||||
|
||||
|
||||
def fake_ocr_recognize(
|
||||
self,
|
||||
files: list[tuple[str, bytes, str | None]],
|
||||
) -> OcrRecognizeBatchRead:
|
||||
filename = files[0][0]
|
||||
return OcrRecognizeBatchRead(
|
||||
total_file_count=1,
|
||||
success_count=1,
|
||||
documents=[
|
||||
OcrRecognizeDocumentRead(
|
||||
filename=filename,
|
||||
media_type=files[0][2] or "application/pdf",
|
||||
text="电子发票(铁路电子客票) 武汉 上海 2026-02-20 票价 354 元",
|
||||
summary="铁路电子客票,武汉至上海,票价 354 元。",
|
||||
avg_score=0.96,
|
||||
line_count=1,
|
||||
page_count=1,
|
||||
document_type="train_ticket",
|
||||
document_type_label="火车/高铁票",
|
||||
scene_code="travel",
|
||||
scene_label="差旅票据",
|
||||
document_fields=[
|
||||
OcrRecognizeFieldRead(key="date", label="列车出发时间", value="2026-02-20"),
|
||||
OcrRecognizeFieldRead(key="route", label="行程", value="武汉-上海"),
|
||||
OcrRecognizeFieldRead(key="amount", label="金额", value="354元"),
|
||||
],
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_attachment_association_job_links_receipts_after_conversation_exit(monkeypatch, tmp_path) -> None:
|
||||
monkeypatch.setenv("STORAGE_ROOT_DIR", str(tmp_path / "storage"))
|
||||
get_settings.cache_clear()
|
||||
clear_attachment_association_jobs_for_tests()
|
||||
monkeypatch.setattr(OcrService, "recognize_files", fake_ocr_recognize)
|
||||
monkeypatch.setattr(ExpenseClaimAttachmentStorage, "root", lambda self: tmp_path / "attachments")
|
||||
try:
|
||||
client, session_factory = build_client(monkeypatch)
|
||||
current_user = CurrentUserContext(
|
||||
username="zhangsan@example.com",
|
||||
name="张三",
|
||||
role_codes=["user"],
|
||||
is_admin=False,
|
||||
employee_no="E10001",
|
||||
)
|
||||
with session_factory() as db:
|
||||
seed_travel_claim(db)
|
||||
|
||||
receipt_service = ReceiptFolderService()
|
||||
receipt_ids = [
|
||||
save_train_receipt(
|
||||
service=receipt_service,
|
||||
current_user=current_user,
|
||||
filename="2月20 武汉-上海.pdf",
|
||||
route="武汉-上海",
|
||||
trip_date="2026-02-20",
|
||||
),
|
||||
save_train_receipt(
|
||||
service=receipt_service,
|
||||
current_user=current_user,
|
||||
filename="2月23 上海-武汉.pdf",
|
||||
route="上海-武汉",
|
||||
trip_date="2026-02-23",
|
||||
),
|
||||
]
|
||||
|
||||
headers = {
|
||||
"x-auth-username": "zhangsan@example.com",
|
||||
"x-auth-name": "Zhang San",
|
||||
"x-auth-employee-no": "E10001",
|
||||
"x-auth-role-codes": "user",
|
||||
}
|
||||
response = client.post(
|
||||
"/api/v1/reimbursements/attachment-association-jobs",
|
||||
headers=headers,
|
||||
json={
|
||||
"receipt_ids": receipt_ids,
|
||||
"prompt": "请帮我处理已上传的附件。",
|
||||
"conversation_id": "inline-test",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 202
|
||||
job_id = response.json()["job_id"]
|
||||
|
||||
status_response = client.get(
|
||||
f"/api/v1/reimbursements/attachment-association-jobs/{job_id}",
|
||||
headers=headers,
|
||||
)
|
||||
assert status_response.status_code == 200
|
||||
payload = status_response.json()
|
||||
assert payload["status"] == "succeeded"
|
||||
assert payload["claim_id"] == "claim-bg-association"
|
||||
assert payload["claim_no"] == "BX-20260220-001"
|
||||
assert payload["uploaded_count"] == 2
|
||||
|
||||
with session_factory() as db:
|
||||
claim = db.scalar(
|
||||
select(ExpenseClaim)
|
||||
.options(selectinload(ExpenseClaim.items))
|
||||
.where(ExpenseClaim.id == "claim-bg-association")
|
||||
)
|
||||
assert claim is not None
|
||||
attached_items = [item for item in claim.items if item.invoice_id]
|
||||
assert len(attached_items) == 2
|
||||
|
||||
linked_receipts = receipt_service.list_receipts(current_user=current_user, status_filter="linked")
|
||||
assert {item.id for item in linked_receipts} == set(receipt_ids)
|
||||
assert {item.linked_claim_id for item in linked_receipts} == {"claim-bg-association"}
|
||||
finally:
|
||||
clear_attachment_association_jobs_for_tests()
|
||||
get_settings.cache_clear()
|
||||
|
||||
|
||||
def test_attachment_association_job_fails_without_editable_claim(monkeypatch, tmp_path) -> None:
|
||||
monkeypatch.setenv("STORAGE_ROOT_DIR", str(tmp_path / "storage"))
|
||||
get_settings.cache_clear()
|
||||
clear_attachment_association_jobs_for_tests()
|
||||
try:
|
||||
client, _session_factory = build_client(monkeypatch)
|
||||
current_user = CurrentUserContext(
|
||||
username="zhangsan@example.com",
|
||||
name="张三",
|
||||
role_codes=["user"],
|
||||
is_admin=False,
|
||||
employee_no="E10001",
|
||||
)
|
||||
receipt_id = save_train_receipt(
|
||||
service=ReceiptFolderService(),
|
||||
current_user=current_user,
|
||||
filename="2月20 武汉-上海.pdf",
|
||||
route="武汉-上海",
|
||||
trip_date="2026-02-20",
|
||||
)
|
||||
|
||||
headers = {
|
||||
"x-auth-username": "zhangsan@example.com",
|
||||
"x-auth-name": "Zhang San",
|
||||
"x-auth-employee-no": "E10001",
|
||||
"x-auth-role-codes": "user",
|
||||
}
|
||||
response = client.post(
|
||||
"/api/v1/reimbursements/attachment-association-jobs",
|
||||
headers=headers,
|
||||
json={"receipt_ids": [receipt_id], "conversation_id": "inline-empty"},
|
||||
)
|
||||
assert response.status_code == 202
|
||||
|
||||
status_response = client.get(
|
||||
f"/api/v1/reimbursements/attachment-association-jobs/{response.json()['job_id']}",
|
||||
headers=headers,
|
||||
)
|
||||
assert status_response.status_code == 200
|
||||
payload = status_response.json()
|
||||
assert payload["status"] == "failed"
|
||||
assert "没有找到可自动关联的报销草稿" in payload["message"]
|
||||
finally:
|
||||
clear_attachment_association_jobs_for_tests()
|
||||
get_settings.cache_clear()
|
||||
296
server/tests/test_linked_reimbursement_draft_jobs.py
Normal file
296
server/tests/test_linked_reimbursement_draft_jobs.py
Normal file
@@ -0,0 +1,296 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.deps import get_db
|
||||
from app.api.v1.endpoints import linked_reimbursement_draft_jobs as draft_jobs_endpoint
|
||||
from app.main import create_app
|
||||
from app.models.employee import Employee
|
||||
from app.models.financial_record import ExpenseClaim
|
||||
from app.schemas.orchestrator import OrchestratorResponse, OrchestratorTraceSummary
|
||||
from app.services.linked_reimbursement_draft_jobs import clear_linked_reimbursement_draft_jobs_for_tests
|
||||
from app.services.orchestrator import OrchestratorService
|
||||
from app.test_helpers.db import build_in_memory_session_factory
|
||||
|
||||
|
||||
def seed_employee_and_application(db: Session) -> None:
|
||||
employee = Employee(
|
||||
id="emp-linked-draft-fast",
|
||||
employee_no="E10001",
|
||||
name="张三",
|
||||
email="zhangsan@example.com",
|
||||
position="实施顾问",
|
||||
grade="P5",
|
||||
)
|
||||
application = ExpenseClaim(
|
||||
id="application-linked-draft-fast",
|
||||
claim_no="AP-202606-FAST",
|
||||
employee_id=employee.id,
|
||||
employee_name=employee.name,
|
||||
department_id="dept-delivery",
|
||||
department_name="交付部",
|
||||
project_code=None,
|
||||
expense_type="travel_application",
|
||||
reason="支撑国网仿生产服务器部署",
|
||||
location="上海",
|
||||
amount=3000,
|
||||
currency="CNY",
|
||||
invoice_count=0,
|
||||
occurred_at=datetime(2026, 2, 20, tzinfo=UTC),
|
||||
submitted_at=None,
|
||||
status="approved",
|
||||
approval_stage="已完成",
|
||||
risk_flags_json=[],
|
||||
)
|
||||
db.add_all([employee, application])
|
||||
db.commit()
|
||||
|
||||
|
||||
def build_client(monkeypatch) -> tuple[TestClient, object]:
|
||||
session_factory = build_in_memory_session_factory()
|
||||
app = create_app()
|
||||
|
||||
def override_db() -> Generator[Session, None, None]:
|
||||
db = session_factory()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
app.dependency_overrides[get_db] = override_db
|
||||
monkeypatch.setattr(draft_jobs_endpoint, "get_session_factory", lambda: session_factory)
|
||||
return TestClient(app), session_factory
|
||||
|
||||
|
||||
def test_linked_reimbursement_draft_job_runs_after_conversation_leaves(monkeypatch) -> None:
|
||||
clear_linked_reimbursement_draft_jobs_for_tests()
|
||||
captured_messages = []
|
||||
|
||||
def fake_run(self, payload):
|
||||
captured_messages.append(payload.message)
|
||||
return OrchestratorResponse(
|
||||
run_id="run-linked-draft-job",
|
||||
conversation_id=None,
|
||||
selected_agent="user_agent",
|
||||
route_reason="测试后台生成报销草稿。",
|
||||
permission_level="draft_write",
|
||||
status="succeeded",
|
||||
result={
|
||||
"message": "报销草稿已生成。",
|
||||
"draft_payload": {
|
||||
"claim_id": "draft-linked-1",
|
||||
"claim_no": "RE-202606-009",
|
||||
"status": "draft",
|
||||
"expense_type": "travel",
|
||||
},
|
||||
},
|
||||
requires_confirmation=False,
|
||||
trace_summary=OrchestratorTraceSummary(
|
||||
scenario="expense",
|
||||
intent="draft",
|
||||
tool_count=1,
|
||||
failed_tool_count=0,
|
||||
selected_capability_codes=[],
|
||||
degraded=False,
|
||||
),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(OrchestratorService, "run", fake_run)
|
||||
try:
|
||||
client, _session_factory = build_client(monkeypatch)
|
||||
headers = {
|
||||
"x-auth-username": "zhangsan@example.com",
|
||||
"x-auth-name": "Zhang San",
|
||||
"x-auth-employee-no": "E10001",
|
||||
"x-auth-role-codes": "user",
|
||||
}
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/reimbursements/linked-reimbursement-draft-jobs",
|
||||
headers=headers,
|
||||
json={
|
||||
"message": "我要报销\n用户选择报销场景:差旅费\n关联申请单:AP-202606-001",
|
||||
"conversation_id": "inline-test",
|
||||
"context_json": {
|
||||
"review_action": "save_draft",
|
||||
"expense_scene_selection": {
|
||||
"expense_type": "travel",
|
||||
"expense_type_label": "差旅费",
|
||||
"application_claim_no": "AP-202606-001",
|
||||
},
|
||||
"review_form_values": {
|
||||
"application_claim_no": "AP-202606-001",
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 202
|
||||
job_id = response.json()["job_id"]
|
||||
|
||||
status_response = client.get(
|
||||
f"/api/v1/reimbursements/linked-reimbursement-draft-jobs/{job_id}",
|
||||
headers=headers,
|
||||
)
|
||||
assert status_response.status_code == 200
|
||||
payload = status_response.json()
|
||||
assert payload["status"] == "succeeded"
|
||||
assert payload["draft_payload"]["claim_no"] == "RE-202606-009"
|
||||
assert payload["run_id"] == "run-linked-draft-job"
|
||||
assert captured_messages == ["我要报销\n用户选择报销场景:差旅费\n关联申请单:AP-202606-001"]
|
||||
finally:
|
||||
clear_linked_reimbursement_draft_jobs_for_tests()
|
||||
|
||||
|
||||
def test_linked_reimbursement_draft_job_uses_direct_save_path(monkeypatch) -> None:
|
||||
clear_linked_reimbursement_draft_jobs_for_tests()
|
||||
|
||||
def fail_if_orchestrator_runs(self, payload):
|
||||
raise AssertionError("linked draft job should not run full orchestrator")
|
||||
|
||||
monkeypatch.setattr(OrchestratorService, "run", fail_if_orchestrator_runs)
|
||||
try:
|
||||
client, session_factory = build_client(monkeypatch)
|
||||
with session_factory() as db:
|
||||
seed_employee_and_application(db)
|
||||
|
||||
headers = {
|
||||
"x-auth-username": "zhangsan@example.com",
|
||||
"x-auth-name": "Zhang San",
|
||||
"x-auth-employee-no": "E10001",
|
||||
"x-auth-role-codes": "user",
|
||||
}
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/reimbursements/linked-reimbursement-draft-jobs",
|
||||
headers=headers,
|
||||
json={
|
||||
"message": "我要报销\n用户选择报销场景:差旅费\n关联申请单:AP-202606-FAST",
|
||||
"conversation_id": "inline-fast-test",
|
||||
"context_json": {
|
||||
"name": "张三",
|
||||
"review_action": "save_draft",
|
||||
"expense_scene_selection": {
|
||||
"expense_type": "travel",
|
||||
"expense_type_label": "差旅费",
|
||||
"application_claim_id": "application-linked-draft-fast",
|
||||
"application_claim_no": "AP-202606-FAST",
|
||||
},
|
||||
"review_form_values": {
|
||||
"expense_type": "差旅费",
|
||||
"reason": "支撑国网仿生产服务器部署",
|
||||
"location": "上海",
|
||||
"time_range": "2026-02-20 至 2026-02-23",
|
||||
"application_claim_id": "application-linked-draft-fast",
|
||||
"application_claim_no": "AP-202606-FAST",
|
||||
"application_reason": "支撑国网仿生产服务器部署",
|
||||
"application_location": "上海",
|
||||
"application_amount": "3000",
|
||||
"application_amount_label": "¥3,000",
|
||||
"application_business_time": "2026-02-20 至 2026-02-23",
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 202
|
||||
job_id = response.json()["job_id"]
|
||||
|
||||
status_response = client.get(
|
||||
f"/api/v1/reimbursements/linked-reimbursement-draft-jobs/{job_id}",
|
||||
headers=headers,
|
||||
)
|
||||
assert status_response.status_code == 200
|
||||
payload = status_response.json()
|
||||
assert payload["status"] == "succeeded"
|
||||
assert payload["draft_payload"]["claim_no"]
|
||||
assert payload["draft_payload"]["claim_id"]
|
||||
assert payload["run_id"].startswith("linked-reimbursement-draft-")
|
||||
|
||||
with session_factory() as db:
|
||||
draft = db.get(ExpenseClaim, payload["draft_payload"]["claim_id"])
|
||||
assert draft is not None
|
||||
assert draft.status == "draft"
|
||||
assert draft.expense_type == "travel"
|
||||
assert draft.reason == "支撑国网仿生产服务器部署"
|
||||
assert draft.items == []
|
||||
finally:
|
||||
clear_linked_reimbursement_draft_jobs_for_tests()
|
||||
|
||||
|
||||
def test_linked_reimbursement_draft_job_uses_direct_save_path_with_application_no_only(monkeypatch) -> None:
|
||||
clear_linked_reimbursement_draft_jobs_for_tests()
|
||||
|
||||
def fail_if_orchestrator_runs(self, payload):
|
||||
raise AssertionError("linked draft job should resolve application no without full orchestrator")
|
||||
|
||||
monkeypatch.setattr(OrchestratorService, "run", fail_if_orchestrator_runs)
|
||||
try:
|
||||
client, session_factory = build_client(monkeypatch)
|
||||
with session_factory() as db:
|
||||
seed_employee_and_application(db)
|
||||
|
||||
headers = {
|
||||
"x-auth-username": "zhangsan@example.com",
|
||||
"x-auth-name": "Zhang San",
|
||||
"x-auth-employee-no": "E10001",
|
||||
"x-auth-role-codes": "user",
|
||||
}
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/reimbursements/linked-reimbursement-draft-jobs",
|
||||
headers=headers,
|
||||
json={
|
||||
"message": "我要报销\n用户选择报销场景:差旅费\n关联申请单:AP-202606-FAST",
|
||||
"conversation_id": "inline-fast-no-id-test",
|
||||
"context_json": {
|
||||
"name": "张三",
|
||||
"review_action": "save_draft",
|
||||
"expense_scene_selection": {
|
||||
"expense_type": "travel",
|
||||
"expense_type_label": "差旅费",
|
||||
"application_claim_no": "AP-202606-FAST",
|
||||
},
|
||||
"review_form_values": {
|
||||
"expense_type": "差旅费",
|
||||
"reason": "支撑国网仿生产服务器部署",
|
||||
"location": "上海",
|
||||
"time_range": "2026-02-20 至 2026-02-23",
|
||||
"application_claim_no": "AP-202606-FAST",
|
||||
"application_reason": "支撑国网仿生产服务器部署",
|
||||
"application_location": "上海",
|
||||
"application_amount": "3000",
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 202
|
||||
job_id = response.json()["job_id"]
|
||||
|
||||
status_response = client.get(
|
||||
f"/api/v1/reimbursements/linked-reimbursement-draft-jobs/{job_id}",
|
||||
headers=headers,
|
||||
)
|
||||
assert status_response.status_code == 200
|
||||
payload = status_response.json()
|
||||
assert payload["status"] == "succeeded"
|
||||
assert payload["draft_payload"]["claim_no"]
|
||||
assert payload["draft_payload"]["claim_id"]
|
||||
|
||||
with session_factory() as db:
|
||||
draft = db.get(ExpenseClaim, payload["draft_payload"]["claim_id"])
|
||||
assert draft is not None
|
||||
link_flag = next(
|
||||
flag
|
||||
for flag in draft.risk_flags_json
|
||||
if flag.get("source") == "application_link"
|
||||
)
|
||||
assert link_flag["application_claim_no"] == "AP-202606-FAST"
|
||||
assert link_flag["application_claim_id"] == "application-linked-draft-fast"
|
||||
finally:
|
||||
clear_linked_reimbursement_draft_jobs_for_tests()
|
||||
Reference in New Issue
Block a user