feat(server): 新增附件关联/关联报销草稿后台任务与申请位置语义

- attachment_association_jobs:从票据夹批量关联附件到报销单,识别城市/日期并创建明细项,内存态 job 跟踪
- linked_reimbursement_draft_jobs:基于申请单异步生成关联报销草稿,调用 Orchestrator 编排,区分 succeeded/failed 终态
- application_location_semantics:抽取差旅出发/到达城市、判断具体地址/业务动作等位置语义,供申请单校验复用
- router 注册两个 job 端点,新增对应 job/语义单元测试
This commit is contained in:
caoxiaozhu
2026-06-24 10:42:05 +08:00
parent d4ff79f326
commit 332f77389d
10 changed files with 1830 additions and 0 deletions

View File

@@ -0,0 +1,75 @@
from __future__ import annotations
from typing import Annotated
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status
from app.api.deps import CurrentUserContext, get_current_user
from app.db.session import get_session_factory
from app.schemas.attachment_association_job import (
AttachmentAssociationJobCreate,
AttachmentAssociationJobRead,
)
from app.schemas.common import ErrorResponse
from app.services.attachment_association_jobs import (
create_attachment_association_job,
get_attachment_association_job,
run_attachment_association_job,
)
# Router for the attachment auto-association job endpoints; routes below are relative to this prefix.
router = APIRouter(prefix="/reimbursements/attachment-association-jobs")
# Dependency alias: the authenticated caller as resolved by get_current_user.
CurrentUser = Annotated[CurrentUserContext, Depends(get_current_user)]
@router.post(
    "",
    response_model=AttachmentAssociationJobRead,
    status_code=status.HTTP_202_ACCEPTED,
    summary="创建附件自动关联后台任务",
    description="根据已 OCR 入票据夹的 receipt_id在后台自动匹配并归集到报销草稿。",
    responses={
        status.HTTP_400_BAD_REQUEST: {
            "model": ErrorResponse,
            "description": "请求缺少可关联票据。",
        },
    },
)
def create_attachment_association_job_endpoint(
    payload: AttachmentAssociationJobCreate,
    background_tasks: BackgroundTasks,
    current_user: CurrentUser,
) -> AttachmentAssociationJobRead:
    """Register an attachment association job and schedule it as a background task.

    Returns the freshly created job snapshot with HTTP 202. A ``ValueError``
    raised while creating the job is translated into HTTP 400.
    """
    try:
        created = create_attachment_association_job(payload, current_user)
    except ValueError as error:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(error),
        ) from error

    # The background task receives a session factory, not a request-scoped session.
    session_factory = get_session_factory()
    background_tasks.add_task(
        run_attachment_association_job,
        created.job_id,
        current_user,
        session_factory,
    )
    return created
@router.get(
    "/{job_id}",
    response_model=AttachmentAssociationJobRead,
    summary="查询附件自动关联后台任务",
    description="用于前端会话恢复后按 job_id 查询任务状态。",
    responses={
        status.HTTP_404_NOT_FOUND: {
            "model": ErrorResponse,
            "description": "任务不存在或当前用户无权查看。",
        },
    },
)
def get_attachment_association_job_endpoint(
    job_id: str,
    current_user: CurrentUser,
) -> AttachmentAssociationJobRead:
    """Return the job snapshot for ``job_id``, or HTTP 404 when the lookup yields nothing."""
    found = get_attachment_association_job(job_id, current_user)
    if found is not None:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="附件关联任务不存在或已失效。",
    )

View File

@@ -0,0 +1,74 @@
from __future__ import annotations
from typing import Annotated
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status
from app.api.deps import CurrentUserContext, get_current_user
from app.db.session import get_session_factory
from app.schemas.common import ErrorResponse
from app.schemas.linked_reimbursement_draft_job import (
LinkedReimbursementDraftJobCreate,
LinkedReimbursementDraftJobRead,
)
from app.services.linked_reimbursement_draft_jobs import (
create_linked_reimbursement_draft_job,
get_linked_reimbursement_draft_job,
run_linked_reimbursement_draft_job,
)
# Router for the linked-reimbursement-draft job endpoints; routes below are relative to this prefix.
router = APIRouter(prefix="/reimbursements/linked-reimbursement-draft-jobs")
# Dependency alias: the authenticated caller as resolved by get_current_user.
CurrentUser = Annotated[CurrentUserContext, Depends(get_current_user)]
@router.post(
    "",
    response_model=LinkedReimbursementDraftJobRead,
    status_code=status.HTTP_202_ACCEPTED,
    summary="创建关联申请单生成报销草稿后台任务",
    description="用户选择关联申请单后,后台继续生成报销草稿,避免当前会话长时间同步等待。",
    responses={
        status.HTTP_400_BAD_REQUEST: {
            "model": ErrorResponse,
            "description": "请求缺少申请单关联上下文。",
        },
    },
)
def create_linked_reimbursement_draft_job_endpoint(
    payload: LinkedReimbursementDraftJobCreate,
    background_tasks: BackgroundTasks,
    current_user: CurrentUser,
) -> LinkedReimbursementDraftJobRead:
    """Register a draft-generation job and schedule it as a background task.

    Returns the freshly created job snapshot with HTTP 202. A ``ValueError``
    raised while creating the job is translated into HTTP 400.
    """
    try:
        created = create_linked_reimbursement_draft_job(payload, current_user)
    except ValueError as error:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(error),
        ) from error

    # The background task receives a session factory, not a request-scoped session.
    session_factory = get_session_factory()
    background_tasks.add_task(
        run_linked_reimbursement_draft_job,
        created.job_id,
        current_user,
        session_factory,
    )
    return created
@router.get(
    "/{job_id}",
    response_model=LinkedReimbursementDraftJobRead,
    summary="查询关联申请单生成报销草稿后台任务",
    description="用于前端按 job_id 查询草稿生成状态。",
    responses={
        status.HTTP_404_NOT_FOUND: {
            "model": ErrorResponse,
            "description": "任务不存在或当前用户无权查看。",
        },
    },
)
def get_linked_reimbursement_draft_job_endpoint(
    job_id: str,
    current_user: CurrentUser,
) -> LinkedReimbursementDraftJobRead:
    """Return the job snapshot for ``job_id``, or HTTP 404 when the lookup yields nothing."""
    found = get_linked_reimbursement_draft_job(job_id, current_user)
    if found is not None:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="报销草稿生成任务不存在或已失效。",
    )

View File

@@ -6,6 +6,7 @@ from app.api.v1.endpoints.agent_feedback import router as agent_feedback_router
from app.api.v1.endpoints.agent_runs import router as agent_runs_router from app.api.v1.endpoints.agent_runs import router as agent_runs_router
from app.api.v1.endpoints.agent_traces import router as agent_traces_router from app.api.v1.endpoints.agent_traces import router as agent_traces_router
from app.api.v1.endpoints.analytics import router as analytics_router from app.api.v1.endpoints.analytics import router as analytics_router
from app.api.v1.endpoints.attachment_association_jobs import router as attachment_association_jobs_router
from app.api.v1.endpoints.audit_logs import router as audit_logs_router from app.api.v1.endpoints.audit_logs import router as audit_logs_router
from app.api.v1.endpoints.auth import router as auth_router from app.api.v1.endpoints.auth import router as auth_router
from app.api.v1.endpoints.bootstrap import router as bootstrap_router from app.api.v1.endpoints.bootstrap import router as bootstrap_router
@@ -14,6 +15,7 @@ from app.api.v1.endpoints.employees import router as employees_router
from app.api.v1.endpoints.employee_profiles import router as employee_profiles_router from app.api.v1.endpoints.employee_profiles import router as employee_profiles_router
from app.api.v1.endpoints.health import router as health_router from app.api.v1.endpoints.health import router as health_router
from app.api.v1.endpoints.knowledge import router as knowledge_router from app.api.v1.endpoints.knowledge import router as knowledge_router
from app.api.v1.endpoints.linked_reimbursement_draft_jobs import router as linked_reimbursement_draft_jobs_router
from app.api.v1.endpoints.notification_states import router as notification_states_router from app.api.v1.endpoints.notification_states import router as notification_states_router
from app.api.v1.endpoints.ocr import router as ocr_router from app.api.v1.endpoints.ocr import router as ocr_router
from app.api.v1.endpoints.ontology import router as ontology_router from app.api.v1.endpoints.ontology import router as ontology_router
@@ -36,8 +38,10 @@ router.include_router(agent_feedback_router, tags=["agent-feedback"])
router.include_router(agent_runs_router, tags=["agent-runs"]) router.include_router(agent_runs_router, tags=["agent-runs"])
router.include_router(agent_traces_router, tags=["agent-traces"]) router.include_router(agent_traces_router, tags=["agent-traces"])
router.include_router(analytics_router, tags=["analytics"]) router.include_router(analytics_router, tags=["analytics"])
router.include_router(attachment_association_jobs_router, tags=["attachment-association-jobs"])
router.include_router(audit_logs_router, tags=["audit-logs"]) router.include_router(audit_logs_router, tags=["audit-logs"])
router.include_router(knowledge_router, tags=["knowledge"]) router.include_router(knowledge_router, tags=["knowledge"])
router.include_router(linked_reimbursement_draft_jobs_router, tags=["linked-reimbursement-draft-jobs"])
router.include_router(notification_states_router, tags=["notification-states"]) router.include_router(notification_states_router, tags=["notification-states"])
router.include_router(ocr_router, tags=["ocr"]) router.include_router(ocr_router, tags=["ocr"])
router.include_router(ontology_router, tags=["ontology"]) router.include_router(ontology_router, tags=["ontology"])

View File

@@ -0,0 +1,40 @@
from __future__ import annotations
from datetime import datetime
from pydantic import BaseModel, Field, field_validator
class AttachmentAssociationJobCreate(BaseModel):
    # Request body for creating an attachment auto-association job.
    # (Documented with comments rather than a class docstring so the generated
    # OpenAPI schema description is left untouched.)

    # validate_default=True makes Pydantic run validate_receipt_ids on the
    # default value as well. Without it a request that omits ``receipt_ids``
    # silently gets an empty list and bypasses the "no receipts" check.
    receipt_ids: list[str] = Field(
        default_factory=list,
        validate_default=True,
        description="票据夹持久化票据 ID。",
    )
    prompt: str = Field(default="", max_length=1000, description="用户发送时的上下文说明。")
    conversation_id: str = Field(default="", max_length=120, description="前端会话 ID用于状态恢复。")

    @field_validator("receipt_ids")
    @classmethod
    def validate_receipt_ids(cls, value: list[str]) -> list[str]:
        """Strip each id, drop blanks, de-duplicate preserving first-seen order.

        Raises:
            ValueError: when no non-blank id remains.
        """
        stripped = (str(item or "").strip() for item in list(value or []))
        receipt_ids = [item for item in stripped if item]
        if not receipt_ids:
            raise ValueError("请先完成附件 OCR 识别,再发起自动关联。")
        # dict.fromkeys keeps insertion order, so duplicates collapse onto their first occurrence.
        return list(dict.fromkeys(receipt_ids))
class AttachmentAssociationJobRead(BaseModel):
    # Response model for an attachment auto-association job snapshot.
    # Only job_id, status, created_at and updated_at are required; every other field has a default.
    job_id: str
    status: str
    message: str = ""
    receipt_ids: list[str] = Field(default_factory=list)
    claim_id: str = ""
    claim_no: str = ""
    uploaded_count: int = 0
    skipped_count: int = 0
    error: str = ""
    prompt: str = ""
    conversation_id: str = ""
    created_at: datetime
    updated_at: datetime

View File

@@ -0,0 +1,32 @@
from __future__ import annotations
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field, field_validator
class LinkedReimbursementDraftJobCreate(BaseModel):
    # Request body for creating a linked-reimbursement-draft background job.
    message: str = Field(min_length=1, max_length=3000, description="生成报销草稿的原始助手请求。")
    context_json: dict[str, Any] = Field(default_factory=dict, description="复用 Orchestrator 的上下文。")
    conversation_id: str = Field(default="", max_length=120, description="前端会话 ID用于状态恢复。")

    @field_validator("message")
    @classmethod
    def validate_message(cls, value: str) -> str:
        """Return the message with surrounding whitespace removed; reject a blank result."""
        stripped = str(value or "").strip()
        if stripped:
            return stripped
        raise ValueError("请先选择要关联的申请单。")
class LinkedReimbursementDraftJobRead(BaseModel):
    # Response model for a linked-reimbursement-draft job snapshot.
    # Only job_id, status, created_at and updated_at are required; every other field has a default.
    job_id: str
    status: str
    message: str = ""
    error: str = ""
    run_id: str = ""
    conversation_id: str = ""
    # None when no draft payload is attached to the job.
    draft_payload: dict[str, Any] | None = None
    created_at: datetime
    updated_at: datetime

View File

@@ -0,0 +1,189 @@
from __future__ import annotations
import re
from collections.abc import Iterable
from functools import lru_cache
from typing import Any
from app.services.user_agent_application_locations import (
CITY_TO_PROVINCE,
DIRECT_MUNICIPALITY_DISPLAY,
)
# Location strings treated as "not filled in"; callers compare against the lower-cased text.
# NOTE(review): the empty string appears twice in this literal — possibly a character was lost; confirm.
PLACEHOLDER_LOCATION_VALUES = {"", "待补充", "待确认", "未知", "暂无", "", "null", "none"}
# Words describing a business action (support, deploy, go-live, acceptance, project, ...).
BUSINESS_ACTION_PATTERN = re.compile(r"(?:支撑|支持|辅助|部署|上线|实施|验收|项目)")
# Words naming a business object (server, system, staging/production/test environment).
BUSINESS_OBJECT_PATTERN = re.compile(r"(?:服务器|系统|仿生产|生产环境|测试环境)")
# Matches text that ENDS with an address-like suffix (province, city, road, building, station, ...).
# Compiled with re.VERBOSE, so the line breaks inside the literal are ignored.
SPECIFIC_ADDRESS_HINT_PATTERN = re.compile(r"""
(?:省|市|区|县|自治州|州|镇|乡|街道|路|街|大道|园区|大厦|中心|基地|机场|车站|高铁站|火车站|酒店|楼|号)$
""", re.VERBOSE)
# Tags counted as "location" when inspecting LAC tokens.
LOCATION_TAGS = {"LOC", "ns", "s"}
# Tags counted as "location" when inspecting jieba tokens.
JIEBA_LOCATION_TAGS = {"ns"}
# Extra vocabulary registered with jieba before segmentation.
JIEBA_CUSTOM_WORDS = (
    "国网",
    "仿生产",
    "生产环境",
    "测试环境",
    "服务器",
    "部署",
    "辅助",
    "支撑",
    "支持",
    "上线",
    "实施",
    "验收",
)
# Splits "<anything ending in a travel verb><body starting with a CJK character>".
ROUTE_LOCATION_PREFIX_PATTERN = re.compile(
    r"^(?P<prefix>.*?(?:出差|前往|去|到|赴))(?P<body>[\u4e00-\u9fa5].*)$"
)
def compact_application_location_text(value: object) -> str:
text = re.sub(r"\s+", "", str(value or ""))
text = re.sub(r"^(?:地点|业务地点|发生地点|目的地)[:]", "", text)
text = re.sub(r"^(?:去|到|赴|前往)", "", text)
return text.strip(":,。;;、")
def validate_application_location_text(value: object) -> str:
    """Return a user-facing error message when the location mixes in business content.

    An empty string means "no problem found": the value is a placeholder or
    ``location_mixes_business_content`` reports no mix.
    """
    text = compact_application_location_text(value)
    is_placeholder = text.lower() in PLACEHOLDER_LOCATION_VALUES
    # Short-circuit keeps the mix check from running on placeholder values.
    if is_placeholder or not location_mixes_business_content(text):
        return ""
    complaint = (
        f"地点“{text}”混入了业务事项,请填写真实出差地点,例如“上海”;"
        "业务背景请放在申请事由中。"
    )
    return complaint
def location_mixes_business_content(value: object) -> bool:
    """Tell whether a location value also carries business wording.

    Placeholder values never count. The regex heuristics are tried first; the
    LAC-based check runs only when they find nothing.
    """
    text = compact_application_location_text(value)
    if text.lower() in PLACEHOLDER_LOCATION_VALUES:
        return False
    return _matches_business_location_pattern(text) or _lac_detects_business_location_mix(text)
def _matches_business_location_pattern(text: str) -> bool:
    """Regex heuristic: a business action word, or a business object without an address-like ending."""
    if BUSINESS_ACTION_PATTERN.search(text) is not None:
        return True
    mentions_object = BUSINESS_OBJECT_PATTERN.search(text) is not None
    # A business-object word is tolerated when the text still ends like a concrete address.
    return mentions_object and SPECIFIC_ADDRESS_HINT_PATTERN.search(text) is None
def _lac_detects_business_location_mix(text: str) -> bool:
    """Use LAC tokens to test whether the non-location part of ``text`` looks like business content.

    Returns ``False`` when no tokens are available or none carries a location tag.
    """
    tagged = list(resolve_lac_tokens(text))
    if not any(tag in LOCATION_TAGS for _, tag in tagged):
        return False
    remainder_parts = []
    for word, tag in tagged:
        # Skip location tokens and tokens tagged "w".
        if tag == "w" or tag in LOCATION_TAGS:
            continue
        remainder_parts.append(word)
    return _matches_business_location_pattern("".join(remainder_parts))
@lru_cache(maxsize=1)
def _load_lac_analyzer() -> Any:
try:
from LAC import LAC # type: ignore
except Exception:
return None
try:
return LAC(mode="lac")
except Exception:
return None
def resolve_lac_tokens(text: str) -> Iterable[tuple[str, str]]:
    """Return ``(word, tag)`` pairs from LAC, or an empty list when LAC is unavailable or fails."""
    analyzer = _load_lac_analyzer()
    if analyzer is not None:
        try:
            raw_result = analyzer.run(text)
        except Exception:
            return []
        return _parse_lac_result(raw_result)
    return []
@lru_cache(maxsize=1)
def _load_jieba_posseg() -> Any:
    """Import jieba's POS segmenter once and register the custom vocabulary.

    Returns ``None`` when jieba cannot be imported.
    """
    try:
        import jieba
        import jieba.posseg as posseg_module
    except Exception:
        return None
    else:
        for custom_word in _iter_jieba_custom_words():
            jieba.add_word(custom_word, freq=100000)
        return posseg_module
def _iter_jieba_custom_words() -> Iterable[str]:
    """Yield every word to register with jieba: custom words first, then the two imported location collections."""
    for vocabulary in (JIEBA_CUSTOM_WORDS, DIRECT_MUNICIPALITY_DISPLAY, CITY_TO_PROVINCE):
        yield from vocabulary
def resolve_jieba_tokens(text: str) -> list[tuple[str, str]]:
    """Return ``(word, flag)`` pairs from jieba; empty when jieba is unavailable or segmentation fails."""
    posseg = _load_jieba_posseg()
    if posseg is None:
        return []
    pairs: list[tuple[str, str]] = []
    try:
        for piece in posseg.cut(str(text or ""), HMM=True):
            word = str(piece.word or "").strip()
            if word:
                pairs.append((word, str(piece.flag or "").strip()))
    except Exception:
        # A failure part-way through discards the partial result.
        return []
    return pairs
def strip_route_location_prefix_with_jieba(value: object) -> str:
    """Remove a leading "<travel verb><location>" prefix when jieba confirms the location.

    The stripped input is returned unchanged when the route pattern does not
    match, jieba yields no tokens, or the first token after the verb is not a
    location.
    """
    text = str(value or "").strip()
    matched = ROUTE_LOCATION_PREFIX_PATTERN.search(text)
    if matched is None:
        return text
    remainder = matched.group("body").strip()
    tokens = resolve_jieba_tokens(remainder)
    if tokens and _is_jieba_location_token(*tokens[0]):
        leading_word = tokens[0][0]
        return remainder[len(leading_word):].strip(" :,。;;、")
    return text
def _is_jieba_location_token(word: str, tag: str) -> bool:
    """A token is a location when its tag says so or the word is in one of the imported location collections."""
    return (
        tag in JIEBA_LOCATION_TAGS
        or word in DIRECT_MUNICIPALITY_DISPLAY
        or word in CITY_TO_PROVINCE
    )
def _parse_lac_result(result: Any) -> list[tuple[str, str]]:
if (
isinstance(result, (list, tuple))
and len(result) == 2
and isinstance(result[0], list)
and isinstance(result[1], list)
):
return [
(str(word or "").strip(), str(tag or "").strip())
for word, tag in zip(result[0], result[1], strict=False)
if str(word or "").strip()
]
if isinstance(result, list) and all(
isinstance(item, (list, tuple)) and len(item) >= 2
for item in result
):
return [
(str(item[0] or "").strip(), str(item[1] or "").strip())
for item in result
if str(item[0] or "").strip()
]
return []

View File

@@ -0,0 +1,549 @@
from __future__ import annotations
import re
from dataclasses import dataclass, field
from datetime import UTC, date, datetime
from decimal import Decimal
from threading import Lock
from typing import Any, Callable
from uuid import uuid4
from sqlalchemy.orm import Session, sessionmaker
from app.api.deps import CurrentUserContext
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
from app.schemas.attachment_association_job import (
AttachmentAssociationJobCreate,
AttachmentAssociationJobRead,
)
from app.schemas.receipt_folder import ReceiptFolderDetailRead
from app.schemas.reimbursement import ExpenseClaimItemCreate
from app.services.expense_claim_constants import (
DOCUMENT_TYPE_ITEM_TYPE_MAP,
EDITABLE_CLAIM_STATUSES,
)
from app.services.expense_claims import ExpenseClaimService
from app.services.receipt_folder import ReceiptFolderService
# City names recognized by substring search when extracting cities from text.
# The tuple order is also the order in which extracted cities are returned.
CITY_NAMES = (
    "北京",
    "上海",
    "广州",
    "深圳",
    "武汉",
    "南京",
    "杭州",
    "成都",
    "重庆",
    "西安",
    "天津",
    "苏州",
    "长沙",
    "郑州",
    "青岛",
    "厦门",
    "宁波",
    "无锡",
    "合肥",
    "福州",
    "昆明",
    "大连",
    "沈阳",
    "济南",
    "哈尔滨",
    "长春",
    "南昌",
    "太原",
    "贵阳",
    "南宁",
    "石家庄",
    "兰州",
    "银川",
    "西宁",
    "海口",
    "拉萨",
)
# Job statuses after which a job is never run again.
TERMINAL_STATUSES = {"succeeded", "failed"}
@dataclass(slots=True)
class AttachmentAssociationJobState:
job_id: str
owner_username: str
owner_name: str
receipt_ids: list[str]
prompt: str = ""
conversation_id: str = ""
status: str = "queued"
message: str = "已创建附件关联任务,等待后台处理。"
claim_id: str = ""
claim_no: str = ""
uploaded_count: int = 0
skipped_count: int = 0
error: str = ""
created_at: datetime = field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = field(default_factory=lambda: datetime.now(UTC))
def to_read(self) -> AttachmentAssociationJobRead:
return AttachmentAssociationJobRead(
job_id=self.job_id,
status=self.status,
message=self.message,
receipt_ids=list(self.receipt_ids),
claim_id=self.claim_id,
claim_no=self.claim_no,
uploaded_count=self.uploaded_count,
skipped_count=self.skipped_count,
error=self.error,
prompt=self.prompt,
conversation_id=self.conversation_id,
created_at=self.created_at,
updated_at=self.updated_at,
)
@dataclass(slots=True)
class AttachmentAssociationCandidate:
    """A claim considered as an association target, with its score and the reasons behind it."""

    claim: ExpenseClaim
    score: int
    reasons: list[str]
# Process-local job registry keyed by job_id.
_jobs: dict[str, AttachmentAssociationJobState] = {}
# Held while looking up, inserting into, clearing or updating entries of _jobs.
_jobs_lock = Lock()
def clear_attachment_association_jobs_for_tests() -> None:
    """Empty the job registry while holding the registry lock."""
    _jobs_lock.acquire()
    try:
        _jobs.clear()
    finally:
        _jobs_lock.release()
def create_attachment_association_job(
    payload: AttachmentAssociationJobCreate,
    current_user: CurrentUserContext,
) -> AttachmentAssociationJobRead:
    """Create a queued job owned by ``current_user`` and return its snapshot."""

    def _clean(raw: Any) -> str:
        return str(raw or "").strip()

    new_state = AttachmentAssociationJobState(
        job_id=f"attachment-association-{uuid4()}",
        owner_username=_clean(current_user.username),
        owner_name=_clean(current_user.name),
        receipt_ids=list(payload.receipt_ids),
        prompt=_clean(payload.prompt),
        conversation_id=_clean(payload.conversation_id),
    )
    with _jobs_lock:
        _jobs[new_state.job_id] = new_state
    return new_state.to_read()
def get_attachment_association_job(
    job_id: str,
    current_user: CurrentUserContext,
) -> AttachmentAssociationJobRead | None:
    """Return the job snapshot when the job exists and the caller may see it, else ``None``."""
    state = _get_authorized_state(job_id, current_user)
    if state is None:
        return None
    return state.to_read()
def run_attachment_association_job(
    job_id: str,
    current_user: CurrentUserContext,
    session_factory: sessionmaker[Session] | Callable[[], Session],
) -> None:
    """Execute one attachment association job and record the outcome on its state.

    Never raises: any exception from the runner is captured and stored as a
    ``failed`` status together with its message.

    Args:
        job_id: Identifier of the job to run.
        current_user: Caller identity; used for authorization and passed to the runner.
        session_factory: Callable returning a session usable as a context manager.
    """
    state = _get_authorized_state(job_id, current_user)
    # Nothing to do for unknown/unauthorized jobs or jobs that already finished.
    if state is None or state.status in TERMINAL_STATUSES:
        return
    _update_job(job_id, status="running", message="正在匹配可关联的报销草稿...")
    try:
        with session_factory() as db:
            result = AttachmentAssociationJobRunner(db).run(
                receipt_ids=state.receipt_ids,
                current_user=current_user,
            )
        _update_job(
            job_id,
            status="succeeded",
            message=f"已自动关联到 {result['claim_no']},成功归集 {result['uploaded_count']} 份附件。",
            claim_id=str(result["claim_id"]),
            claim_no=str(result["claim_no"]),
            uploaded_count=int(result["uploaded_count"]),
            skipped_count=int(result["skipped_count"]),
            error="",
        )
    except Exception as exc:
        # The exception text doubles as the user-visible message; fall back to a generic one when empty.
        message = str(exc).strip() or "自动关联任务执行失败,请稍后重试。"
        _update_job(
            job_id,
            status="failed",
            message=message,
            error=message,
        )
def _get_authorized_state(
    job_id: str,
    current_user: CurrentUserContext,
) -> AttachmentAssociationJobState | None:
    """Return the job state when it exists and ``current_user`` may access it.

    Access rules, in order:
      1. admins may access every job;
      2. when the job recorded an owner username, only that exact username matches;
      3. only when no owner username was recorded does the display name decide.

    Display names are not unique, so they must not grant access to a job that
    has a recorded username; otherwise users sharing a name could read each
    other's jobs.

    Returns:
        The state, or ``None`` when the job is unknown or access is denied.
    """
    normalized_job_id = str(job_id or "").strip()
    with _jobs_lock:
        state = _jobs.get(normalized_job_id)
    if state is None:
        return None
    if current_user.is_admin:
        return state
    if state.owner_username:
        username = str(current_user.username or "").strip()
        return state if username == state.owner_username else None
    # Legacy fallback for jobs created by a caller without a username.
    name = str(current_user.name or "").strip()
    if name and name == state.owner_name:
        return state
    return None
def _update_job(job_id: str, **updates: Any) -> None:
    """Apply attribute updates to a job under the registry lock and refresh ``updated_at``.

    Unknown job ids and keys that are not attributes of the state are ignored.
    """
    lookup_key = str(job_id or "").strip()
    with _jobs_lock:
        state = _jobs.get(lookup_key)
        if state is not None:
            for attribute, new_value in updates.items():
                if hasattr(state, attribute):
                    setattr(state, attribute, new_value)
            state.updated_at = datetime.now(UTC)
class AttachmentAssociationJobRunner:
    """Pick the best-matching editable claim for a set of receipts and attach them to it.

    Works against one database session; claim access goes through
    ``ExpenseClaimService`` and receipt access through ``ReceiptFolderService``.
    """

    def __init__(self, db: Session) -> None:
        self.db = db
        self.claim_service = ExpenseClaimService(db)
        self.receipt_service = ReceiptFolderService()

    def run(
        self,
        *,
        receipt_ids: list[str],
        current_user: CurrentUserContext,
    ) -> dict[str, Any]:
        """Associate the receipts with the top-ranked claim.

        Returns:
            Dict with ``claim_id``, ``claim_no``, ``uploaded_count`` and ``skipped_count``.

        Raises:
            ValueError: no candidate claim, an ambiguous/weak match, a receipt
                that cannot be loaded, or nothing uploaded at all.
        """
        receipts = self._load_receipts(receipt_ids, current_user)
        candidates = self._rank_claims(receipts, current_user)
        if not candidates:
            raise ValueError("没有找到可自动关联的报销草稿,请先新建草稿或补充说明。")
        recommended = candidates[0]
        runner_up = candidates[1] if len(candidates) > 1 else None
        # Refuse to guess: the best score must reach 5 and lead the runner-up by at least 2.
        if recommended.score < 5 or (runner_up is not None and recommended.score - runner_up.score < 2):
            raise ValueError("找到多个可能关联的报销草稿,请补充说明或手动选择后再归集。")
        uploaded_count = 0
        skipped_count = 0
        for receipt in receipts:
            # Receipts already linked to a different claim are left alone.
            if self._is_linked_to_other_claim(receipt, recommended.claim.id):
                skipped_count += 1
                continue
            target_item = self._resolve_target_item(
                claim_id=recommended.claim.id,
                receipt=receipt,
                current_user=current_user,
            )
            source_path, media_type, file_name = self.receipt_service.resolve_source(receipt.id, current_user)
            result = self.claim_service.upload_claim_item_attachment(
                claim_id=recommended.claim.id,
                item_id=target_item.id,
                filename=file_name,
                content=source_path.read_bytes(),
                media_type=media_type,
                current_user=current_user,
                source_receipt_id=receipt.id,
            )
            # A None result from the upload call is counted as skipped.
            if result is None:
                skipped_count += 1
            else:
                uploaded_count += 1
        if uploaded_count <= 0:
            raise ValueError("未能归集任何附件,请进入报销单详情手动核对。")
        return {
            "claim_id": recommended.claim.id,
            "claim_no": recommended.claim.claim_no,
            "uploaded_count": uploaded_count,
            "skipped_count": skipped_count,
        }

    def _load_receipts(
        self,
        receipt_ids: list[str],
        current_user: CurrentUserContext,
    ) -> list[ReceiptFolderDetailRead]:
        """Load each distinct, non-blank receipt id; a missing receipt aborts with ``ValueError``."""
        receipts = []
        # Ids are stripped, blanks dropped and duplicates removed (first occurrence wins).
        for receipt_id in list(dict.fromkeys(str(item or "").strip() for item in receipt_ids if str(item or "").strip())):
            try:
                receipts.append(self.receipt_service.get_receipt(receipt_id, current_user))
            except FileNotFoundError as exc:
                raise ValueError("当前附件没有持久化票据记录,请重新上传后再试。") from exc
        if not receipts:
            raise ValueError("当前附件没有持久化票据记录,请重新上传后再试。")
        return receipts

    def _rank_claims(
        self,
        receipts: list[ReceiptFolderDetailRead],
        current_user: CurrentUserContext,
    ) -> list[AttachmentAssociationCandidate]:
        """Score every eligible claim and return those with a positive score, best first."""
        signals = _collect_receipt_signals(receipts)
        claims = [
            claim
            for claim in self.claim_service.list_claims(current_user)
            if self._is_auto_association_candidate(claim)
        ]
        ranked = [
            candidate
            for candidate in (
                self._score_claim(claim, signals)
                for claim in claims
            )
            if candidate.score > 0
        ]
        # sorted() is stable, so equal scores keep the order returned by list_claims.
        return sorted(ranked, key=lambda item: item.score, reverse=True)

    def _is_auto_association_candidate(self, claim: ExpenseClaim) -> bool:
        """Eligible when the status is editable and the claim service does not flag it as an expense application."""
        status = str(claim.status or "").strip().lower()
        if status not in EDITABLE_CLAIM_STATUSES:
            return False
        return not self.claim_service._is_expense_application_claim(claim)

    def _score_claim(
        self,
        claim: ExpenseClaim,
        signals: dict[str, Any],
    ) -> AttachmentAssociationCandidate:
        """Score one claim against the receipt signals.

        Points: +4 date overlap; +2 per matched city (capped at 4); +2 when the
        claim names at least two cities and at least two matched; +1 for status
        ``draft``.
        """
        claim_text = _build_claim_text(claim)
        compact_claim_text = _normalize_text(claim_text)
        claim_dates = _extract_date_tokens(claim_text)
        claim_cities = _unique([*_extract_city_tokens(claim_text), *_extract_city_tokens(claim.location)])
        reasons: list[str] = []
        score = 0
        if _dates_overlap(signals["dates"], claim_dates):
            score += 4
            reasons.append("票据日期与报销单日期一致")
        matched_cities = [city for city in signals["cities"] if city in compact_claim_text]
        if matched_cities:
            score += min(4, len(matched_cities) * 2)
            reasons.append(f"地点或行程包含 {''.join(matched_cities)}")
        if len(claim_cities) >= 2 and len(matched_cities) >= 2:
            score += 2
            reasons.append("票据往返城市与报销事由吻合")
        if str(claim.status or "").strip().lower() == "draft":
            score += 1
            reasons.append("当前单据仍是可归集草稿")
        return AttachmentAssociationCandidate(claim=claim, score=score, reasons=reasons)

    @staticmethod
    def _is_linked_to_other_claim(receipt: ReceiptFolderDetailRead, claim_id: str) -> bool:
        """True when the receipt status is ``linked`` and it points at a claim other than ``claim_id``."""
        linked_claim_id = str(receipt.linked_claim_id or "").strip()
        return bool(str(receipt.status or "").strip() == "linked" and linked_claim_id and linked_claim_id != claim_id)

    def _resolve_target_item(
        self,
        *,
        claim_id: str,
        receipt: ReceiptFolderDetailRead,
        current_user: CurrentUserContext,
    ) -> ExpenseClaimItem:
        """Find or create the claim item that should receive the receipt.

        Preference order: an item without invoice_id (and not system generated)
        of the preferred type, then any such item, then a newly created item.

        Raises:
            ValueError: the claim is gone, item creation failed, or the new item cannot be located.
        """
        claim = self.claim_service.get_claim(claim_id, current_user)
        if claim is None:
            raise ValueError("匹配到的报销草稿不存在,请刷新后再试。")
        preferred_type = _resolve_receipt_item_type(receipt)
        empty_items = [
            item
            for item in list(claim.items or [])
            if not str(item.invoice_id or "").strip() and not item.is_system_generated
        ]
        for item in empty_items:
            if preferred_type and str(item.item_type or "").strip() == preferred_type:
                return item
        if empty_items:
            return empty_items[0]
        # Remember existing ids so the newly created item can be told apart afterwards.
        before_ids = {str(item.id) for item in list(claim.items or [])}
        created_claim = self.claim_service.create_claim_item(
            claim_id=claim.id,
            payload=_build_item_payload_from_receipt(claim, receipt, preferred_type),
            current_user=current_user,
        )
        if created_claim is None:
            raise ValueError("无法创建票据归集明细,请进入详情页手动处理。")
        for item in list(created_claim.items or []):
            if str(item.id) not in before_ids and not str(item.invoice_id or "").strip():
                return item
        raise ValueError("无法找到可归集的费用明细,请进入详情页手动处理。")
def _normalize_text(value: Any) -> str:
return re.sub(r"\s+", "", str(value or "").strip())
def _unique(values: list[str] | tuple[str, ...]) -> list[str]:
return list(dict.fromkeys(str(item or "").strip() for item in values if str(item or "").strip()))
def _extract_date_tokens(text: Any) -> list[str]:
    """Find date mentions in ``text`` and return them normalized and de-duplicated.

    Full dates (year-month-day) are collected before month-day mentions.
    """
    source = str(text or "")
    date_patterns = (
        r"20\d{2}[-/.年]\d{1,2}[-/.月]\d{1,2}",
        r"\d{1,2}月\d{1,2}",
    )
    found = [
        _normalize_date_token(hit.group(0))
        for pattern in date_patterns
        for hit in re.finditer(pattern, source)
    ]
    return _unique(found)
def _normalize_date_token(value: Any) -> str:
if isinstance(value, (date, datetime)):
return value.isoformat()[:10]
text = str(value or "").strip()
full_match = re.search(r"(20\d{2})[-/.年](\d{1,2})[-/.月](\d{1,2})", text)
if full_match:
year, month, day = full_match.groups()
return f"{year}-{month.zfill(2)}-{day.zfill(2)}"
short_match = re.search(r"(\d{1,2})月(\d{1,2})", text)
if short_match:
month, day = short_match.groups()
return f"{month.zfill(2)}-{day.zfill(2)}"
return ""
def _extract_city_tokens(text: Any) -> list[str]:
    """Return the known city names that occur in ``text``, in ``CITY_NAMES`` order."""
    compact = _normalize_text(text)
    hits: list[str] = []
    if compact:
        for city in CITY_NAMES:
            if city in compact:
                hits.append(city)
    return hits
def _dates_overlap(left: list[str], right: list[str]) -> bool:
for left_date in left:
if not left_date:
continue
for right_date in right:
if right_date and (left_date == right_date or left_date.endswith(right_date) or right_date.endswith(left_date)):
return True
return False
def _collect_receipt_signals(receipts: list[ReceiptFolderDetailRead]) -> dict[str, Any]:
    """Gather the combined text, date tokens and city names of all receipts.

    Returns a dict with keys ``text``, ``dates`` and ``cities``.
    """
    combined = "\n".join(map(_build_receipt_text, receipts))
    document_dates = [str(receipt.document_date or "").strip() for receipt in receipts]
    return {
        "text": combined,
        # Dates found in the text come first, then each receipt's own document_date.
        "dates": _unique([*_extract_date_tokens(combined), *document_dates]),
        "cities": _unique(_extract_city_tokens(combined)),
    }
def _build_receipt_text(receipt: ReceiptFolderDetailRead) -> str:
fields_text = "\n".join(
f"{field.label} {field.value}"
for field in list(receipt.fields or [])
if str(field.label or field.value or "").strip()
)
return "\n".join(
value
for value in (
receipt.file_name,
receipt.summary,
receipt.ocr_text,
receipt.document_date,
receipt.merchant_name,
fields_text,
)
if str(value or "").strip()
)
def _build_claim_text(claim: ExpenseClaim) -> str:
item_text = "\n".join(
" ".join(
str(value or "").strip()
for value in (
item.item_date.isoformat() if item.item_date else "",
item.item_type,
item.item_reason,
item.item_location,
item.item_note,
)
if str(value or "").strip()
)
for item in list(claim.items or [])
)
occurred_at = claim.occurred_at.isoformat()[:10] if claim.occurred_at else ""
return "\n".join(
value
for value in (
claim.claim_no,
claim.expense_type,
claim.status,
claim.reason,
claim.location,
occurred_at,
item_text,
)
if str(value or "").strip()
)
def _resolve_receipt_item_type(receipt: ReceiptFolderDetailRead) -> str:
    """Choose the claim item type for a receipt.

    A mapped document type wins; otherwise the receipt's scene code is used,
    falling back to ``"other"`` when that is blank.
    """
    document_type = str(receipt.document_type or "").strip()
    if document_type not in DOCUMENT_TYPE_ITEM_TYPE_MAP:
        return str(receipt.scene_code or "").strip() or "other"
    return DOCUMENT_TYPE_ITEM_TYPE_MAP[document_type]
def _build_item_payload_from_receipt(
    claim: ExpenseClaim,
    receipt: ReceiptFolderDetailRead,
    preferred_type: str,
) -> ExpenseClaimItemCreate:
    """Build the creation payload for a new claim item derived from a receipt.

    Date and location come from the receipt when resolvable, otherwise from
    the claim. The amount is always created as ``0.00``.
    """
    item_date = _resolve_receipt_item_date(receipt)
    if not item_date:
        item_date = claim.occurred_at.date() if claim.occurred_at else None

    item_type = preferred_type or str(claim.expense_type or "").strip() or "other"
    item_location = _resolve_receipt_item_location(receipt) or str(claim.location or "").strip()
    return ExpenseClaimItemCreate(
        item_date=item_date,
        item_type=item_type,
        item_reason=str(receipt.summary or receipt.file_name or "").strip(),
        item_location=item_location,
        item_amount=Decimal("0.00"),
    )
def _resolve_receipt_item_date(receipt: ReceiptFolderDetailRead) -> date | None:
    """Return the first parseable full date from date/time-labelled fields, then ``document_date``.

    Only tokens normalized to ten characters are tried; values that fail ISO
    parsing are skipped. Returns ``None`` when nothing parses.
    """
    labelled_values = [
        entry.value
        for entry in list(receipt.fields or [])
        if any(marker in str(entry.label or "") for marker in ("日期", "时间"))
    ]
    for raw in (*labelled_values, receipt.document_date):
        token = _normalize_date_token(raw)
        if len(token) != 10:
            continue
        try:
            return date.fromisoformat(token)
        except ValueError:
            pass
    return None
def _resolve_receipt_item_location(receipt: ReceiptFolderDetailRead) -> str:
    """Derive a location string for a receipt.

    The first non-blank field whose label mentions 行程/到达/地点/城市 decides:
    the last known city found in its value, or the value's first 40 characters
    when no city is found. Without such a field, the last known city found in
    the whole receipt text is used, or ``""``.
    """
    location_markers = ("行程", "到达", "地点", "城市")
    for entry in list(receipt.fields or []):
        label = str(entry.label or "")
        value = str(entry.value or "").strip()
        if not value or not any(marker in label for marker in location_markers):
            continue
        matched = _extract_city_tokens(value)
        return matched[-1] if matched else value[:40]

    fallback = _extract_city_tokens(_build_receipt_text(receipt))
    if fallback:
        return fallback[-1]
    return ""

View File

@@ -0,0 +1,291 @@
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import UTC, datetime
from threading import Lock
from typing import Any, Callable
from uuid import uuid4
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from app.api.deps import CurrentUserContext
from app.schemas.linked_reimbursement_draft_job import (
LinkedReimbursementDraftJobCreate,
LinkedReimbursementDraftJobRead,
)
from app.schemas.ontology import OntologyParseResult, OntologyPermission
from app.schemas.orchestrator import OrchestratorRequest
from app.models.financial_record import ExpenseClaim
from app.services.expense_claims import ExpenseClaimService
from app.services.orchestrator import OrchestratorService
# Job statuses after which a job is never run again.
TERMINAL_STATUSES = {"succeeded", "failed"}
@dataclass(slots=True)
class LinkedReimbursementDraftJobState:
job_id: str
owner_username: str
owner_name: str
message: str
context_json: dict[str, Any]
conversation_id: str = ""
status: str = "queued"
status_message: str = "已创建报销草稿生成任务,等待后台处理。"
error: str = ""
run_id: str = ""
draft_payload: dict[str, Any] | None = None
created_at: datetime = field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = field(default_factory=lambda: datetime.now(UTC))
def to_read(self) -> LinkedReimbursementDraftJobRead:
return LinkedReimbursementDraftJobRead(
job_id=self.job_id,
status=self.status,
message=self.status_message,
error=self.error,
run_id=self.run_id,
conversation_id=self.conversation_id,
draft_payload=self.draft_payload,
created_at=self.created_at,
updated_at=self.updated_at,
)
# Process-local job registry keyed by job_id.
_jobs: dict[str, LinkedReimbursementDraftJobState] = {}
# Held while looking up, inserting into or clearing _jobs.
_jobs_lock = Lock()
def clear_linked_reimbursement_draft_jobs_for_tests() -> None:
    """Empty the job registry while holding the registry lock."""
    _jobs_lock.acquire()
    try:
        _jobs.clear()
    finally:
        _jobs_lock.release()
def create_linked_reimbursement_draft_job(
    payload: LinkedReimbursementDraftJobCreate,
    current_user: CurrentUserContext,
) -> LinkedReimbursementDraftJobRead:
    """Create a queued draft-generation job owned by ``current_user`` and return its snapshot.

    The request context is copied, and ``entry_source`` / ``session_type`` are
    filled with defaults when missing or falsy.
    """

    def _clean(raw: Any) -> str:
        return str(raw or "").strip()

    context_json = dict(payload.context_json or {})
    context_defaults = (("entry_source", "workbench-ai"), ("session_type", "expense"))
    for context_key, fallback in context_defaults:
        if not context_json.get(context_key):
            context_json[context_key] = fallback

    new_state = LinkedReimbursementDraftJobState(
        job_id=f"linked-reimbursement-draft-{uuid4()}",
        owner_username=_clean(current_user.username),
        owner_name=_clean(current_user.name),
        message=_clean(payload.message),
        context_json=context_json,
        conversation_id=_clean(payload.conversation_id),
    )
    with _jobs_lock:
        _jobs[new_state.job_id] = new_state
    return new_state.to_read()
def get_linked_reimbursement_draft_job(
    job_id: str,
    current_user: CurrentUserContext,
) -> LinkedReimbursementDraftJobRead | None:
    """Return the job snapshot when the job exists and the caller may see it, else ``None``."""
    state = _get_authorized_state(job_id, current_user)
    if state is None:
        return None
    return state.to_read()
def run_linked_reimbursement_draft_job(
    job_id: str,
    current_user: CurrentUserContext,
    session_factory: sessionmaker[Session] | Callable[[], Session],
) -> None:
    """Execute one draft job and record its outcome on the job state.

    Returns immediately when the job cannot be resolved for ``current_user``
    or is already in a terminal status. Otherwise the job is marked
    ``running``, a session is opened from ``session_factory`` and the draft is
    produced either by the direct save path or by ``OrchestratorService``.

    The job ends as ``succeeded`` (with ``run_id`` and ``draft_payload``) or as
    ``failed`` (with the error text in both ``status_message`` and ``error``).
    Exceptions are never propagated to the caller; they are logged with their
    traceback and converted into the ``failed`` status.
    """
    # Function-scope import so this block does not depend on the file's import header.
    import logging

    state = _get_authorized_state(job_id, current_user)
    if state is None or state.status in TERMINAL_STATUSES:
        return
    _update_job(job_id, status="running", status_message="正在后台生成报销草稿...")
    try:
        with session_factory() as db:
            if _can_use_direct_save_path(db, state.context_json):
                run_id, result, draft_payload = _run_direct_save_path(
                    db=db,
                    state=state,
                    current_user=current_user,
                )
            else:
                response = OrchestratorService(db).run(
                    OrchestratorRequest(
                        source="user_message",
                        user_id=_resolve_user_id(current_user),
                        conversation_id=None,
                        message=state.message,
                        context_json=dict(state.context_json),
                    )
                )
                run_id = response.run_id
                result = response.result if isinstance(response.result, dict) else {}
                draft_payload = result.get("draft_payload") if isinstance(result.get("draft_payload"), dict) else None
                if response.status != "succeeded":
                    raise ValueError(str(result.get("message") or "报销草稿生成失败,请稍后重试。").strip())
        # A run that reports success without a draft is still treated as a failure.
        if draft_payload is None:
            raise ValueError("报销草稿生成完成,但未返回草稿信息,请刷新单据列表后核对。")
        _update_job(
            job_id,
            status="succeeded",
            status_message=str(result.get("message") or "报销草稿已生成。").strip(),
            run_id=run_id,
            draft_payload=draft_payload,
            error="",
        )
    except Exception as exc:
        # Catch-all boundary of the job: nothing is re-raised, so keep the
        # traceback in the log before reducing the failure to a message.
        logging.getLogger(__name__).exception(
            "Linked reimbursement draft job %s failed", job_id
        )
        message = str(exc).strip() or "报销草稿生成失败,请稍后重试。"
        _update_job(
            job_id,
            status="failed",
            status_message=message,
            error=message,
        )
def _get_authorized_state(
    job_id: str,
    current_user: CurrentUserContext,
) -> LinkedReimbursementDraftJobState | None:
    """Return the job state when ``current_user`` may access it, else ``None``.

    Admins may access every job. Other users must match the stored owner by
    non-empty username or by non-empty display name.
    """
    lookup_key = str(job_id or "").strip()
    with _jobs_lock:
        candidate = _jobs.get(lookup_key)
    if candidate is None:
        return None
    if current_user.is_admin:
        return candidate

    requester_username = str(current_user.username or "").strip()
    requester_name = str(current_user.name or "").strip()
    matches_username = bool(requester_username) and requester_username == candidate.owner_username
    # NOTE(review): a display-name match alone grants access, so two users who
    # share a name can read each other's jobs — confirm this fallback is intended.
    matches_name = bool(requester_name) and requester_name == candidate.owner_name
    return candidate if (matches_username or matches_name) else None
def _update_job(job_id: str, **updates: Any) -> None:
    """Apply attribute updates to a stored job and refresh ``updated_at``.

    Unknown job ids are ignored, as are update keys that are not attributes of
    the job state.
    """
    with _jobs_lock:
        target = _jobs.get(str(job_id or "").strip())
        if target is not None:
            for attr_name, new_value in updates.items():
                if not hasattr(target, attr_name):
                    continue
                setattr(target, attr_name, new_value)
            target.updated_at = datetime.now(UTC)
def _resolve_user_id(current_user: CurrentUserContext) -> str:
return str(current_user.username or current_user.name or "anonymous").strip() or "anonymous"
def _can_use_direct_save_path(db: Session, context_json: dict[str, Any]) -> bool:
review_action = str((context_json or {}).get("review_action") or "").strip()
if review_action != "save_draft":
return False
review_values = context_json.get("review_form_values")
if not isinstance(review_values, dict):
return False
application_claim_id = str(review_values.get("application_claim_id") or "").strip()
application_claim_no = str(review_values.get("application_claim_no") or "").strip()
if not application_claim_no:
return False
if application_claim_id:
return True
return _find_application_claim_by_no(db, application_claim_no) is not None
def _find_application_claim_by_no(db: Session, claim_no: str) -> ExpenseClaim | None:
normalized_claim_no = str(claim_no or "").strip()
if not normalized_claim_no:
return None
claim = db.scalar(
select(ExpenseClaim)
.where(ExpenseClaim.claim_no == normalized_claim_no)
.limit(1)
)
if claim is not None and ExpenseClaimService._is_expense_application_claim(claim):
return claim
return None
def _build_direct_context_json(db: Session, context_json: dict[str, Any]) -> dict[str, Any]:
direct_context = dict(context_json or {})
review_values = dict(direct_context.get("review_form_values") or {})
scene_selection = dict(direct_context.get("expense_scene_selection") or {})
application_claim_id = str(review_values.get("application_claim_id") or "").strip()
application_claim_no = str(review_values.get("application_claim_no") or "").strip()
if not application_claim_id and application_claim_no:
application_claim = _find_application_claim_by_no(db, application_claim_no)
if application_claim is not None:
review_values["application_claim_id"] = application_claim.id
scene_selection["application_claim_id"] = application_claim.id
scene_selection["application_claim_no"] = str(
scene_selection.get("application_claim_no")
or application_claim.claim_no
or application_claim_no
).strip()
direct_context["review_form_values"] = review_values
if scene_selection:
direct_context["expense_scene_selection"] = scene_selection
return direct_context
def _run_direct_save_path(
    *,
    db: Session,
    state: LinkedReimbursementDraftJobState,
    current_user: CurrentUserContext,
) -> tuple[str, dict[str, Any], dict[str, Any]]:
    """Save the draft through ``ExpenseClaimService`` using a fixed ontology result.

    The job id doubles as the run id. Returns ``(run_id, service_result,
    draft_payload)``.

    Raises:
        ValueError: when the service result lacks a claim id or claim number,
            or its status is not ``"draft"``.
    """
    run_id = state.job_id
    permission = OntologyPermission(
        level="draft_write",
        allowed=True,
        reason="关联申请单生成报销草稿快路径。",
    )
    ontology = OntologyParseResult(
        scenario="expense",
        intent="draft",
        permission=permission,
        confidence=1.0,
        run_id=run_id,
    )
    service = ExpenseClaimService(db)
    result = service.save_or_submit_from_ontology(
        run_id=run_id,
        user_id=_resolve_user_id(current_user),
        message=state.message,
        ontology=ontology,
        context_json=_build_direct_context_json(db, state.context_json),
    )

    claim_id = str(result.get("claim_id") or "").strip()
    claim_no = str(result.get("claim_no") or "").strip()
    saved_as_draft = str(result.get("status") or "").strip() == "draft"
    if not (claim_id and claim_no and saved_as_draft):
        raise ValueError(str(result.get("message") or "报销草稿生成失败,请稍后重试。").strip())

    saved_claim = db.get(ExpenseClaim, claim_id)
    return run_id, result, _build_direct_draft_payload(result, saved_claim)
def _build_direct_draft_payload(
result: dict[str, Any],
claim: ExpenseClaim | None,
) -> dict[str, Any]:
claim_id = str(result.get("claim_id") or getattr(claim, "id", "") or "").strip()
claim_no = str(result.get("claim_no") or getattr(claim, "claim_no", "") or "").strip()
status = str(result.get("status") or getattr(claim, "status", "") or "draft").strip()
approval_stage = str(getattr(claim, "approval_stage", "") or "待提交").strip()
expense_type = str(getattr(claim, "expense_type", "") or "").strip()
message = str(result.get("message") or "报销草稿已生成。").strip()
return {
"draft_type": "expense",
"title": f"费用草稿 {claim_no}" if claim_no else "费用草稿",
"body": message,
"confirmation_required": True,
"claim_id": claim_id,
"claim_no": claim_no,
"status": status,
"approval_stage": approval_stage,
"expense_type": expense_type,
}

View File

@@ -0,0 +1,280 @@
from __future__ import annotations
from collections.abc import Generator
from datetime import UTC, date, datetime
from decimal import Decimal
from fastapi.testclient import TestClient
from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload
from app.api.deps import CurrentUserContext, get_db
from app.api.v1.endpoints import attachment_association_jobs as attachment_jobs_endpoint
from app.core.config import get_settings
from app.main import create_app
from app.models.employee import Employee
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeFieldRead
from app.services.attachment_association_jobs import clear_attachment_association_jobs_for_tests
from app.services.expense_claim_attachment_storage import ExpenseClaimAttachmentStorage
from app.services.ocr import OcrService
from app.services.receipt_folder import ReceiptFolderService
from app.test_helpers.db import build_in_memory_session_factory
def build_client(monkeypatch) -> tuple[TestClient, object]:
    """Create a test client whose requests and background jobs share one in-memory database.

    Returns the client together with the session factory so tests can seed and
    inspect data directly.
    """
    session_factory = build_in_memory_session_factory()

    def override_db() -> Generator[Session, None, None]:
        session = session_factory()
        try:
            yield session
        finally:
            session.close()

    def provide_session_factory():
        return session_factory

    application = create_app()
    application.dependency_overrides[get_db] = override_db
    # The endpoint hands this factory to the background task.
    monkeypatch.setattr(attachment_jobs_endpoint, "get_session_factory", provide_session_factory)
    return TestClient(application), session_factory
def seed_travel_claim(db: Session) -> ExpenseClaim:
    """Insert one employee and a draft travel claim with a single train-ticket item.

    The claim and its item start with zero amount and no invoice. The
    committed claim is returned.
    """
    employee_id = "emp-bg-association"
    claim_id = "claim-bg-association"

    traveller = Employee(
        id=employee_id,
        employee_no="E10001",
        name="张三",
        email="zhangsan@example.com",
        position="实施顾问",
        grade="P4",
    )
    travel_claim = ExpenseClaim(
        id=claim_id,
        claim_no="BX-20260220-001",
        employee_id=traveller.id,
        employee_name=traveller.name,
        department_id="dept-delivery",
        department_name="交付部",
        project_code=None,
        expense_type="travel",
        reason="辅助国网仿生产服务器部署,武汉往返上海",
        location="上海",
        amount=Decimal("0.00"),
        currency="CNY",
        invoice_count=0,
        occurred_at=datetime(2026, 2, 20, tzinfo=UTC),
        submitted_at=None,
        status="draft",
        approval_stage="待提交",
        risk_flags_json=[],
    )
    outbound_item = ExpenseClaimItem(
        id="item-bg-association-1",
        claim_id=travel_claim.id,
        item_date=date(2026, 2, 20),
        item_type="train_ticket",
        item_reason="武汉至上海高铁",
        item_location="上海",
        item_amount=Decimal("0.00"),
        invoice_id=None,
    )
    travel_claim.items = [outbound_item]

    db.add_all([traveller, travel_claim])
    db.commit()
    return travel_claim
def save_train_receipt(
    *,
    service: ReceiptFolderService,
    current_user: CurrentUserContext,
    filename: str,
    route: str,
    trip_date: str,
) -> str:
    """Store a fake train-ticket PDF with a prepared OCR document and return the receipt id."""
    pdf_media_type = "application/pdf"
    recognized_fields = [
        OcrRecognizeFieldRead(key="date", label="列车出发时间", value=trip_date),
        OcrRecognizeFieldRead(key="route", label="行程", value=route),
        OcrRecognizeFieldRead(key="amount", label="金额", value="354元"),
    ]
    ocr_document = OcrRecognizeDocumentRead(
        filename=filename,
        media_type=pdf_media_type,
        text=f"电子发票(铁路电子客票) {route} {trip_date} 票价 354 元",
        summary=f"铁路电子客票,{route},票价 354 元。",
        avg_score=0.96,
        line_count=1,
        page_count=1,
        document_type="train_ticket",
        document_type_label="火车/高铁票",
        scene_code="travel",
        scene_label="差旅票据",
        document_fields=recognized_fields,
    )
    saved = service.save_receipt(
        filename=filename,
        content=f"fake-pdf-{filename}".encode("utf-8"),
        media_type=pdf_media_type,
        current_user=current_user,
        document=ocr_document,
    )
    return saved.id
def fake_ocr_recognize(
    self,
    files: list[tuple[str, bytes, str | None]],
) -> OcrRecognizeBatchRead:
    """Stand-in for ``OcrService.recognize_files`` returning one fixed train-ticket result.

    Only the first entry of ``files`` is consulted, for its filename and media
    type; the batch always reports a single successful document.
    """
    first_file = files[0]
    filename = first_file[0]
    recognized_fields = [
        OcrRecognizeFieldRead(key="date", label="列车出发时间", value="2026-02-20"),
        OcrRecognizeFieldRead(key="route", label="行程", value="武汉-上海"),
        OcrRecognizeFieldRead(key="amount", label="金额", value="354元"),
    ]
    ticket = OcrRecognizeDocumentRead(
        filename=filename,
        media_type=first_file[2] or "application/pdf",
        text="电子发票(铁路电子客票) 武汉 上海 2026-02-20 票价 354 元",
        summary="铁路电子客票,武汉至上海,票价 354 元。",
        avg_score=0.96,
        line_count=1,
        page_count=1,
        document_type="train_ticket",
        document_type_label="火车/高铁票",
        scene_code="travel",
        scene_label="差旅票据",
        document_fields=recognized_fields,
    )
    return OcrRecognizeBatchRead(
        total_file_count=1,
        success_count=1,
        documents=[ticket],
    )
def test_attachment_association_job_links_receipts_after_conversation_exit(monkeypatch, tmp_path) -> None:
    """Two folder receipts end up attached to the seeded travel draft and marked as linked."""
    monkeypatch.setenv("STORAGE_ROOT_DIR", str(tmp_path / "storage"))
    get_settings.cache_clear()
    clear_attachment_association_jobs_for_tests()
    monkeypatch.setattr(OcrService, "recognize_files", fake_ocr_recognize)
    monkeypatch.setattr(ExpenseClaimAttachmentStorage, "root", lambda self: tmp_path / "attachments")

    jobs_url = "/api/v1/reimbursements/attachment-association-jobs"
    auth_headers = {
        "x-auth-username": "zhangsan@example.com",
        "x-auth-name": "Zhang San",
        "x-auth-employee-no": "E10001",
        "x-auth-role-codes": "user",
    }
    # (filename, route, trip date) for the outbound and return tickets.
    trips = (
        ("2月20 武汉-上海.pdf", "武汉-上海", "2026-02-20"),
        ("2月23 上海-武汉.pdf", "上海-武汉", "2026-02-23"),
    )
    try:
        client, session_factory = build_client(monkeypatch)
        current_user = CurrentUserContext(
            username="zhangsan@example.com",
            name="张三",
            role_codes=["user"],
            is_admin=False,
            employee_no="E10001",
        )
        with session_factory() as db:
            seed_travel_claim(db)

        receipt_service = ReceiptFolderService()
        receipt_ids = []
        for ticket_filename, ticket_route, ticket_date in trips:
            receipt_ids.append(
                save_train_receipt(
                    service=receipt_service,
                    current_user=current_user,
                    filename=ticket_filename,
                    route=ticket_route,
                    trip_date=ticket_date,
                )
            )

        request_body = {
            "receipt_ids": receipt_ids,
            "prompt": "请帮我处理已上传的附件。",
            "conversation_id": "inline-test",
        }
        response = client.post(jobs_url, headers=auth_headers, json=request_body)
        assert response.status_code == 202
        job_id = response.json()["job_id"]

        status_response = client.get(f"{jobs_url}/{job_id}", headers=auth_headers)
        assert status_response.status_code == 200
        payload = status_response.json()
        assert payload["status"] == "succeeded"
        assert payload["claim_id"] == "claim-bg-association"
        assert payload["claim_no"] == "BX-20260220-001"
        assert payload["uploaded_count"] == 2

        with session_factory() as db:
            claim = db.scalar(
                select(ExpenseClaim)
                .options(selectinload(ExpenseClaim.items))
                .where(ExpenseClaim.id == "claim-bg-association")
            )
            assert claim is not None
            attached_items = [item for item in claim.items if item.invoice_id]
            assert len(attached_items) == 2

        linked_receipts = receipt_service.list_receipts(current_user=current_user, status_filter="linked")
        linked_ids = {receipt.id for receipt in linked_receipts}
        linked_claim_ids = {receipt.linked_claim_id for receipt in linked_receipts}
        assert linked_ids == set(receipt_ids)
        assert linked_claim_ids == {"claim-bg-association"}
    finally:
        clear_attachment_association_jobs_for_tests()
        get_settings.cache_clear()
def test_attachment_association_job_fails_without_editable_claim(monkeypatch, tmp_path) -> None:
    """With no claim seeded, the job is reported as failed with the no-draft-found message."""
    monkeypatch.setenv("STORAGE_ROOT_DIR", str(tmp_path / "storage"))
    get_settings.cache_clear()
    clear_attachment_association_jobs_for_tests()

    jobs_url = "/api/v1/reimbursements/attachment-association-jobs"
    auth_headers = {
        "x-auth-username": "zhangsan@example.com",
        "x-auth-name": "Zhang San",
        "x-auth-employee-no": "E10001",
        "x-auth-role-codes": "user",
    }
    try:
        client, _session_factory = build_client(monkeypatch)
        current_user = CurrentUserContext(
            username="zhangsan@example.com",
            name="张三",
            role_codes=["user"],
            is_admin=False,
            employee_no="E10001",
        )
        receipt_id = save_train_receipt(
            service=ReceiptFolderService(),
            current_user=current_user,
            filename="2月20 武汉-上海.pdf",
            route="武汉-上海",
            trip_date="2026-02-20",
        )

        request_body = {"receipt_ids": [receipt_id], "conversation_id": "inline-empty"}
        response = client.post(jobs_url, headers=auth_headers, json=request_body)
        assert response.status_code == 202

        created_job_id = response.json()["job_id"]
        status_response = client.get(f"{jobs_url}/{created_job_id}", headers=auth_headers)
        assert status_response.status_code == 200
        payload = status_response.json()
        assert payload["status"] == "failed"
        assert "没有找到可自动关联的报销草稿" in payload["message"]
    finally:
        clear_attachment_association_jobs_for_tests()
        get_settings.cache_clear()

View File

@@ -0,0 +1,296 @@
from __future__ import annotations
from collections.abc import Generator
from datetime import UTC, datetime
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session
from app.api.deps import get_db
from app.api.v1.endpoints import linked_reimbursement_draft_jobs as draft_jobs_endpoint
from app.main import create_app
from app.models.employee import Employee
from app.models.financial_record import ExpenseClaim
from app.schemas.orchestrator import OrchestratorResponse, OrchestratorTraceSummary
from app.services.linked_reimbursement_draft_jobs import clear_linked_reimbursement_draft_jobs_for_tests
from app.services.orchestrator import OrchestratorService
from app.test_helpers.db import build_in_memory_session_factory
def seed_employee_and_application(db: Session) -> None:
    """Insert one employee and an approved travel application claim (``AP-202606-FAST``)."""
    applicant = Employee(
        id="emp-linked-draft-fast",
        employee_no="E10001",
        name="张三",
        email="zhangsan@example.com",
        position="实施顾问",
        grade="P5",
    )
    approved_application = ExpenseClaim(
        id="application-linked-draft-fast",
        claim_no="AP-202606-FAST",
        employee_id=applicant.id,
        employee_name=applicant.name,
        department_id="dept-delivery",
        department_name="交付部",
        project_code=None,
        expense_type="travel_application",
        reason="支撑国网仿生产服务器部署",
        location="上海",
        amount=3000,
        currency="CNY",
        invoice_count=0,
        occurred_at=datetime(2026, 2, 20, tzinfo=UTC),
        submitted_at=None,
        status="approved",
        approval_stage="已完成",
        risk_flags_json=[],
    )
    records = [applicant, approved_application]
    db.add_all(records)
    db.commit()
def build_client(monkeypatch) -> tuple[TestClient, object]:
    """Create a test client whose requests and draft jobs share one in-memory database.

    Returns the client together with the session factory so tests can seed and
    inspect data directly.
    """
    session_factory = build_in_memory_session_factory()

    def override_db() -> Generator[Session, None, None]:
        session = session_factory()
        try:
            yield session
        finally:
            session.close()

    def provide_session_factory():
        return session_factory

    application = create_app()
    application.dependency_overrides[get_db] = override_db
    # Patch the factory getter on the draft-jobs endpoint module.
    monkeypatch.setattr(draft_jobs_endpoint, "get_session_factory", provide_session_factory)
    return TestClient(application), session_factory
def test_linked_reimbursement_draft_job_runs_after_conversation_leaves(monkeypatch) -> None:
    """Without a seeded application the job goes through the (stubbed) orchestrator and succeeds."""
    clear_linked_reimbursement_draft_jobs_for_tests()
    captured_messages = []

    def fake_run(self, payload):
        # Record the message the job forwards, then answer with a canned success.
        captured_messages.append(payload.message)
        trace = OrchestratorTraceSummary(
            scenario="expense",
            intent="draft",
            tool_count=1,
            failed_tool_count=0,
            selected_capability_codes=[],
            degraded=False,
        )
        return OrchestratorResponse(
            run_id="run-linked-draft-job",
            conversation_id=None,
            selected_agent="user_agent",
            route_reason="测试后台生成报销草稿。",
            permission_level="draft_write",
            status="succeeded",
            result={
                "message": "报销草稿已生成。",
                "draft_payload": {
                    "claim_id": "draft-linked-1",
                    "claim_no": "RE-202606-009",
                    "status": "draft",
                    "expense_type": "travel",
                },
            },
            requires_confirmation=False,
            trace_summary=trace,
        )

    monkeypatch.setattr(OrchestratorService, "run", fake_run)

    jobs_url = "/api/v1/reimbursements/linked-reimbursement-draft-jobs"
    user_message = "我要报销\n用户选择报销场景:差旅费\n关联申请单AP-202606-001"
    auth_headers = {
        "x-auth-username": "zhangsan@example.com",
        "x-auth-name": "Zhang San",
        "x-auth-employee-no": "E10001",
        "x-auth-role-codes": "user",
    }
    request_body = {
        "message": user_message,
        "conversation_id": "inline-test",
        "context_json": {
            "review_action": "save_draft",
            "expense_scene_selection": {
                "expense_type": "travel",
                "expense_type_label": "差旅费",
                "application_claim_no": "AP-202606-001",
            },
            "review_form_values": {
                "application_claim_no": "AP-202606-001",
            },
        },
    }
    try:
        client, _session_factory = build_client(monkeypatch)
        response = client.post(jobs_url, headers=auth_headers, json=request_body)
        assert response.status_code == 202
        job_id = response.json()["job_id"]

        status_response = client.get(f"{jobs_url}/{job_id}", headers=auth_headers)
        assert status_response.status_code == 200
        payload = status_response.json()
        assert payload["status"] == "succeeded"
        assert payload["draft_payload"]["claim_no"] == "RE-202606-009"
        assert payload["run_id"] == "run-linked-draft-job"
        assert captured_messages == ["我要报销\n用户选择报销场景:差旅费\n关联申请单AP-202606-001"]
    finally:
        clear_linked_reimbursement_draft_jobs_for_tests()
def test_linked_reimbursement_draft_job_uses_direct_save_path(monkeypatch) -> None:
    """With application id and number supplied, the draft is saved without calling the orchestrator."""
    clear_linked_reimbursement_draft_jobs_for_tests()

    def fail_if_orchestrator_runs(self, payload):
        raise AssertionError("linked draft job should not run full orchestrator")

    monkeypatch.setattr(OrchestratorService, "run", fail_if_orchestrator_runs)

    jobs_url = "/api/v1/reimbursements/linked-reimbursement-draft-jobs"
    auth_headers = {
        "x-auth-username": "zhangsan@example.com",
        "x-auth-name": "Zhang San",
        "x-auth-employee-no": "E10001",
        "x-auth-role-codes": "user",
    }
    scene_selection = {
        "expense_type": "travel",
        "expense_type_label": "差旅费",
        "application_claim_id": "application-linked-draft-fast",
        "application_claim_no": "AP-202606-FAST",
    }
    review_form_values = {
        "expense_type": "差旅费",
        "reason": "支撑国网仿生产服务器部署",
        "location": "上海",
        "time_range": "2026-02-20 至 2026-02-23",
        "application_claim_id": "application-linked-draft-fast",
        "application_claim_no": "AP-202606-FAST",
        "application_reason": "支撑国网仿生产服务器部署",
        "application_location": "上海",
        "application_amount": "3000",
        "application_amount_label": "¥3,000",
        "application_business_time": "2026-02-20 至 2026-02-23",
    }
    request_body = {
        "message": "我要报销\n用户选择报销场景:差旅费\n关联申请单AP-202606-FAST",
        "conversation_id": "inline-fast-test",
        "context_json": {
            "name": "张三",
            "review_action": "save_draft",
            "expense_scene_selection": scene_selection,
            "review_form_values": review_form_values,
        },
    }
    try:
        client, session_factory = build_client(monkeypatch)
        with session_factory() as db:
            seed_employee_and_application(db)

        response = client.post(jobs_url, headers=auth_headers, json=request_body)
        assert response.status_code == 202
        job_id = response.json()["job_id"]

        status_response = client.get(f"{jobs_url}/{job_id}", headers=auth_headers)
        assert status_response.status_code == 200
        payload = status_response.json()
        assert payload["status"] == "succeeded"
        assert payload["draft_payload"]["claim_no"]
        assert payload["draft_payload"]["claim_id"]
        # On the direct path the run id is the job id itself.
        assert payload["run_id"].startswith("linked-reimbursement-draft-")

        with session_factory() as db:
            draft = db.get(ExpenseClaim, payload["draft_payload"]["claim_id"])
            assert draft is not None
            assert draft.status == "draft"
            assert draft.expense_type == "travel"
            assert draft.reason == "支撑国网仿生产服务器部署"
            assert draft.items == []
    finally:
        clear_linked_reimbursement_draft_jobs_for_tests()
def test_linked_reimbursement_draft_job_uses_direct_save_path_with_application_no_only(monkeypatch) -> None:
    """With only the application number supplied, the id is resolved and stored on the draft's link flag."""
    clear_linked_reimbursement_draft_jobs_for_tests()

    def fail_if_orchestrator_runs(self, payload):
        raise AssertionError("linked draft job should resolve application no without full orchestrator")

    monkeypatch.setattr(OrchestratorService, "run", fail_if_orchestrator_runs)

    jobs_url = "/api/v1/reimbursements/linked-reimbursement-draft-jobs"
    auth_headers = {
        "x-auth-username": "zhangsan@example.com",
        "x-auth-name": "Zhang San",
        "x-auth-employee-no": "E10001",
        "x-auth-role-codes": "user",
    }
    scene_selection = {
        "expense_type": "travel",
        "expense_type_label": "差旅费",
        "application_claim_no": "AP-202606-FAST",
    }
    review_form_values = {
        "expense_type": "差旅费",
        "reason": "支撑国网仿生产服务器部署",
        "location": "上海",
        "time_range": "2026-02-20 至 2026-02-23",
        "application_claim_no": "AP-202606-FAST",
        "application_reason": "支撑国网仿生产服务器部署",
        "application_location": "上海",
        "application_amount": "3000",
    }
    request_body = {
        "message": "我要报销\n用户选择报销场景:差旅费\n关联申请单AP-202606-FAST",
        "conversation_id": "inline-fast-no-id-test",
        "context_json": {
            "name": "张三",
            "review_action": "save_draft",
            "expense_scene_selection": scene_selection,
            "review_form_values": review_form_values,
        },
    }
    try:
        client, session_factory = build_client(monkeypatch)
        with session_factory() as db:
            seed_employee_and_application(db)

        response = client.post(jobs_url, headers=auth_headers, json=request_body)
        assert response.status_code == 202
        job_id = response.json()["job_id"]

        status_response = client.get(f"{jobs_url}/{job_id}", headers=auth_headers)
        assert status_response.status_code == 200
        payload = status_response.json()
        assert payload["status"] == "succeeded"
        assert payload["draft_payload"]["claim_no"]
        assert payload["draft_payload"]["claim_id"]

        with session_factory() as db:
            draft = db.get(ExpenseClaim, payload["draft_payload"]["claim_id"])
            assert draft is not None
            application_links = (
                risk_flag
                for risk_flag in draft.risk_flags_json
                if risk_flag.get("source") == "application_link"
            )
            link_flag = next(application_links)
            assert link_flag["application_claim_no"] == "AP-202606-FAST"
            assert link_flag["application_claim_id"] == "application-linked-draft-fast"
    finally:
        clear_linked_reimbursement_draft_jobs_for_tests()