diff --git a/server/src/app/api/v1/endpoints/attachment_association_jobs.py b/server/src/app/api/v1/endpoints/attachment_association_jobs.py new file mode 100644 index 0000000..aed3857 --- /dev/null +++ b/server/src/app/api/v1/endpoints/attachment_association_jobs.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from typing import Annotated + +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status + +from app.api.deps import CurrentUserContext, get_current_user +from app.db.session import get_session_factory +from app.schemas.attachment_association_job import ( + AttachmentAssociationJobCreate, + AttachmentAssociationJobRead, +) +from app.schemas.common import ErrorResponse +from app.services.attachment_association_jobs import ( + create_attachment_association_job, + get_attachment_association_job, + run_attachment_association_job, +) + +router = APIRouter(prefix="/reimbursements/attachment-association-jobs") +CurrentUser = Annotated[CurrentUserContext, Depends(get_current_user)] + + +@router.post( + "", + response_model=AttachmentAssociationJobRead, + status_code=status.HTTP_202_ACCEPTED, + summary="创建附件自动关联后台任务", + description="根据已 OCR 入票据夹的 receipt_id,在后台自动匹配并归集到报销草稿。", + responses={ + status.HTTP_400_BAD_REQUEST: { + "model": ErrorResponse, + "description": "请求缺少可关联票据。", + }, + }, +) +def create_attachment_association_job_endpoint( + payload: AttachmentAssociationJobCreate, + background_tasks: BackgroundTasks, + current_user: CurrentUser, +) -> AttachmentAssociationJobRead: + try: + job = create_attachment_association_job(payload, current_user) + except ValueError as exc: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc + background_tasks.add_task( + run_attachment_association_job, + job.job_id, + current_user, + get_session_factory(), + ) + return job + + +@router.get( + "/{job_id}", + response_model=AttachmentAssociationJobRead, + summary="查询附件自动关联后台任务", + description="用于前端会话恢复后按 job_id 
查询任务状态。", + responses={ + status.HTTP_404_NOT_FOUND: { + "model": ErrorResponse, + "description": "任务不存在或当前用户无权查看。", + }, + }, +) +def get_attachment_association_job_endpoint( + job_id: str, + current_user: CurrentUser, +) -> AttachmentAssociationJobRead: + job = get_attachment_association_job(job_id, current_user) + if job is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="附件关联任务不存在或已失效。") + return job + diff --git a/server/src/app/api/v1/endpoints/linked_reimbursement_draft_jobs.py b/server/src/app/api/v1/endpoints/linked_reimbursement_draft_jobs.py new file mode 100644 index 0000000..4fcfad2 --- /dev/null +++ b/server/src/app/api/v1/endpoints/linked_reimbursement_draft_jobs.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from typing import Annotated + +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status + +from app.api.deps import CurrentUserContext, get_current_user +from app.db.session import get_session_factory +from app.schemas.common import ErrorResponse +from app.schemas.linked_reimbursement_draft_job import ( + LinkedReimbursementDraftJobCreate, + LinkedReimbursementDraftJobRead, +) +from app.services.linked_reimbursement_draft_jobs import ( + create_linked_reimbursement_draft_job, + get_linked_reimbursement_draft_job, + run_linked_reimbursement_draft_job, +) + +router = APIRouter(prefix="/reimbursements/linked-reimbursement-draft-jobs") +CurrentUser = Annotated[CurrentUserContext, Depends(get_current_user)] + + +@router.post( + "", + response_model=LinkedReimbursementDraftJobRead, + status_code=status.HTTP_202_ACCEPTED, + summary="创建关联申请单生成报销草稿后台任务", + description="用户选择关联申请单后,后台继续生成报销草稿,避免当前会话长时间同步等待。", + responses={ + status.HTTP_400_BAD_REQUEST: { + "model": ErrorResponse, + "description": "请求缺少申请单关联上下文。", + }, + }, +) +def create_linked_reimbursement_draft_job_endpoint( + payload: LinkedReimbursementDraftJobCreate, + background_tasks: BackgroundTasks, + current_user: CurrentUser, +) -> 
LinkedReimbursementDraftJobRead: + try: + job = create_linked_reimbursement_draft_job(payload, current_user) + except ValueError as exc: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc + background_tasks.add_task( + run_linked_reimbursement_draft_job, + job.job_id, + current_user, + get_session_factory(), + ) + return job + + +@router.get( + "/{job_id}", + response_model=LinkedReimbursementDraftJobRead, + summary="查询关联申请单生成报销草稿后台任务", + description="用于前端按 job_id 查询草稿生成状态。", + responses={ + status.HTTP_404_NOT_FOUND: { + "model": ErrorResponse, + "description": "任务不存在或当前用户无权查看。", + }, + }, +) +def get_linked_reimbursement_draft_job_endpoint( + job_id: str, + current_user: CurrentUser, +) -> LinkedReimbursementDraftJobRead: + job = get_linked_reimbursement_draft_job(job_id, current_user) + if job is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="报销草稿生成任务不存在或已失效。") + return job diff --git a/server/src/app/api/v1/router.py b/server/src/app/api/v1/router.py index e808b87..5892bf3 100644 --- a/server/src/app/api/v1/router.py +++ b/server/src/app/api/v1/router.py @@ -6,6 +6,7 @@ from app.api.v1.endpoints.agent_feedback import router as agent_feedback_router from app.api.v1.endpoints.agent_runs import router as agent_runs_router from app.api.v1.endpoints.agent_traces import router as agent_traces_router from app.api.v1.endpoints.analytics import router as analytics_router +from app.api.v1.endpoints.attachment_association_jobs import router as attachment_association_jobs_router from app.api.v1.endpoints.audit_logs import router as audit_logs_router from app.api.v1.endpoints.auth import router as auth_router from app.api.v1.endpoints.bootstrap import router as bootstrap_router @@ -14,6 +15,7 @@ from app.api.v1.endpoints.employees import router as employees_router from app.api.v1.endpoints.employee_profiles import router as employee_profiles_router from app.api.v1.endpoints.health import router as 
class AttachmentAssociationJobCreate(BaseModel):
    """Request payload for creating an attachment auto-association job.

    ``receipt_ids`` is cleaned by :meth:`validate_receipt_ids`: entries are
    trimmed, blanks are dropped, duplicates are removed (first occurrence
    wins) and an empty result is rejected.
    """

    receipt_ids: list[str] = Field(
        default_factory=list,
        # Pydantic v2 skips field validators for default values unless
        # ``validate_default`` is enabled. Without it, a body that omits
        # ``receipt_ids`` would bypass the "at least one receipt" check below.
        validate_default=True,
        description="票据夹持久化票据 ID。",
    )
    prompt: str = Field(default="", max_length=1000, description="用户发送时的上下文说明。")
    conversation_id: str = Field(default="", max_length=120, description="前端会话 ID,用于状态恢复。")

    @field_validator("receipt_ids")
    @classmethod
    def validate_receipt_ids(cls, value: list[str]) -> list[str]:
        """Trim, de-blank and de-duplicate receipt IDs.

        Raises:
            ValueError: when no usable receipt ID remains.
        """
        trimmed = (str(item or "").strip() for item in list(value or []))
        receipt_ids = [item for item in trimmed if item]
        if not receipt_ids:
            raise ValueError("请先完成附件 OCR 识别,再发起自动关联。")
        # dict.fromkeys keeps first-seen order while removing duplicates.
        return list(dict.fromkeys(receipt_ids))
# Values that mean "no location provided yet" (compared after lower-casing).
PLACEHOLDER_LOCATION_VALUES = {"", "待补充", "待确认", "未知", "暂无", "无", "null", "none"}
# Words describing a business activity rather than a place.
BUSINESS_ACTION_PATTERN = re.compile(r"(?:支撑|支持|辅助|部署|上线|实施|验收|项目)")
# Words describing a business object; tolerated only when the text still ends
# like a concrete address (see SPECIFIC_ADDRESS_HINT_PATTERN).
BUSINESS_OBJECT_PATTERN = re.compile(r"(?:服务器|系统|仿生产|生产环境|测试环境)")
SPECIFIC_ADDRESS_HINT_PATTERN = re.compile(r"""
    (?:省|市|区|县|自治州|州|镇|乡|街道|路|街|大道|园区|大厦|中心|基地|机场|车站|高铁站|火车站|酒店|楼|号)$
""", re.VERBOSE)
LOCATION_TAGS = {"LOC", "ns", "s"}
JIEBA_LOCATION_TAGS = {"ns"}
JIEBA_CUSTOM_WORDS = (
    "国网",
    "仿生产",
    "生产环境",
    "测试环境",
    "服务器",
    "部署",
    "辅助",
    "支撑",
    "支持",
    "上线",
    "实施",
    "验收",
)
# Both groups must be named: callers read ``match.group("body")``.
# NOTE(review): the group names were missing from the pattern as received
# ("(?P.*?" does not compile). "body" is confirmed by its use in this module;
# "prefix" is an assumed name for the first group - confirm against history.
ROUTE_LOCATION_PREFIX_PATTERN = re.compile(
    r"^(?P<prefix>.*?(?:出差|前往|去|到|赴))(?P<body>[\u4e00-\u9fa5].*)$"
)


def compact_application_location_text(value: object) -> str:
    """Normalise a free-text location before validation.

    Steps, in order: remove all whitespace, drop a leading field label such as
    "地点:", drop a leading travel verb ("去", "到", "赴", "前往"), then trim
    surrounding punctuation.

    Full-width punctuation is written with ``\\uXXXX`` escapes so it cannot be
    folded into its ASCII look-alike again by Unicode normalisation:
    U+FF1A (colon), U+FF0C (comma), U+FF1B (semicolon).

    Args:
        value: Any object; falsy values are treated as an empty string.

    Returns:
        The compacted location text, possibly empty.
    """
    text = re.sub(r"\s+", "", str(value or ""))
    # Accept both the ASCII and the full-width colon after the label.
    text = re.sub(r"^(?:地点|业务地点|发生地点|目的地)[:\uff1a]", "", text)
    text = re.sub(r"^(?:去|到|赴|前往)", "", text)
    return text.strip(":\uff1a,\uff0c。;\uff1b、")
BUSINESS_OBJECT_PATTERN.search(text) and not SPECIFIC_ADDRESS_HINT_PATTERN.search(text): + return True + return False + + +def _lac_detects_business_location_mix(text: str) -> bool: + tokens = list(resolve_lac_tokens(text)) + if not tokens: + return False + has_location = any(tag in LOCATION_TAGS for _, tag in tokens) + if not has_location: + return False + non_location_text = "".join( + word + for word, tag in tokens + if tag not in LOCATION_TAGS and tag != "w" + ) + return _matches_business_location_pattern(non_location_text) + + +@lru_cache(maxsize=1) +def _load_lac_analyzer() -> Any: + try: + from LAC import LAC # type: ignore + except Exception: + return None + try: + return LAC(mode="lac") + except Exception: + return None + + +def resolve_lac_tokens(text: str) -> Iterable[tuple[str, str]]: + analyzer = _load_lac_analyzer() + if analyzer is None: + return [] + try: + result = analyzer.run(text) + except Exception: + return [] + return _parse_lac_result(result) + + +@lru_cache(maxsize=1) +def _load_jieba_posseg() -> Any: + try: + import jieba + import jieba.posseg as pseg + except Exception: + return None + for word in _iter_jieba_custom_words(): + jieba.add_word(word, freq=100000) + return pseg + + +def _iter_jieba_custom_words() -> Iterable[str]: + yield from JIEBA_CUSTOM_WORDS + yield from DIRECT_MUNICIPALITY_DISPLAY + yield from CITY_TO_PROVINCE + + +def resolve_jieba_tokens(text: str) -> list[tuple[str, str]]: + posseg = _load_jieba_posseg() + if posseg is None: + return [] + try: + return [ + (str(item.word or "").strip(), str(item.flag or "").strip()) + for item in posseg.cut(str(text or ""), HMM=True) + if str(item.word or "").strip() + ] + except Exception: + return [] + + +def strip_route_location_prefix_with_jieba(value: object) -> str: + text = str(value or "").strip() + match = ROUTE_LOCATION_PREFIX_PATTERN.search(text) + if not match: + return text + + body = match.group("body").strip() + tokens = resolve_jieba_tokens(body) + if not tokens: + 
return text + first_word, first_tag = tokens[0] + if not _is_jieba_location_token(first_word, first_tag): + return text + return body[len(first_word) :].strip(" ::,,。;;、") + + +def _is_jieba_location_token(word: str, tag: str) -> bool: + if tag in JIEBA_LOCATION_TAGS: + return True + return word in DIRECT_MUNICIPALITY_DISPLAY or word in CITY_TO_PROVINCE + + +def _parse_lac_result(result: Any) -> list[tuple[str, str]]: + if ( + isinstance(result, (list, tuple)) + and len(result) == 2 + and isinstance(result[0], list) + and isinstance(result[1], list) + ): + return [ + (str(word or "").strip(), str(tag or "").strip()) + for word, tag in zip(result[0], result[1], strict=False) + if str(word or "").strip() + ] + if isinstance(result, list) and all( + isinstance(item, (list, tuple)) and len(item) >= 2 + for item in result + ): + return [ + (str(item[0] or "").strip(), str(item[1] or "").strip()) + for item in result + if str(item[0] or "").strip() + ] + return [] diff --git a/server/src/app/services/attachment_association_jobs.py b/server/src/app/services/attachment_association_jobs.py new file mode 100644 index 0000000..a25b4ea --- /dev/null +++ b/server/src/app/services/attachment_association_jobs.py @@ -0,0 +1,549 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass, field +from datetime import UTC, date, datetime +from decimal import Decimal +from threading import Lock +from typing import Any, Callable +from uuid import uuid4 + +from sqlalchemy.orm import Session, sessionmaker + +from app.api.deps import CurrentUserContext +from app.models.financial_record import ExpenseClaim, ExpenseClaimItem +from app.schemas.attachment_association_job import ( + AttachmentAssociationJobCreate, + AttachmentAssociationJobRead, +) +from app.schemas.receipt_folder import ReceiptFolderDetailRead +from app.schemas.reimbursement import ExpenseClaimItemCreate +from app.services.expense_claim_constants import ( + DOCUMENT_TYPE_ITEM_TYPE_MAP, + 
EDITABLE_CLAIM_STATUSES, +) +from app.services.expense_claims import ExpenseClaimService +from app.services.receipt_folder import ReceiptFolderService + + +CITY_NAMES = ( + "北京", + "上海", + "广州", + "深圳", + "武汉", + "南京", + "杭州", + "成都", + "重庆", + "西安", + "天津", + "苏州", + "长沙", + "郑州", + "青岛", + "厦门", + "宁波", + "无锡", + "合肥", + "福州", + "昆明", + "大连", + "沈阳", + "济南", + "哈尔滨", + "长春", + "南昌", + "太原", + "贵阳", + "南宁", + "石家庄", + "兰州", + "银川", + "西宁", + "海口", + "拉萨", +) + +TERMINAL_STATUSES = {"succeeded", "failed"} + + +@dataclass(slots=True) +class AttachmentAssociationJobState: + job_id: str + owner_username: str + owner_name: str + receipt_ids: list[str] + prompt: str = "" + conversation_id: str = "" + status: str = "queued" + message: str = "已创建附件关联任务,等待后台处理。" + claim_id: str = "" + claim_no: str = "" + uploaded_count: int = 0 + skipped_count: int = 0 + error: str = "" + created_at: datetime = field(default_factory=lambda: datetime.now(UTC)) + updated_at: datetime = field(default_factory=lambda: datetime.now(UTC)) + + def to_read(self) -> AttachmentAssociationJobRead: + return AttachmentAssociationJobRead( + job_id=self.job_id, + status=self.status, + message=self.message, + receipt_ids=list(self.receipt_ids), + claim_id=self.claim_id, + claim_no=self.claim_no, + uploaded_count=self.uploaded_count, + skipped_count=self.skipped_count, + error=self.error, + prompt=self.prompt, + conversation_id=self.conversation_id, + created_at=self.created_at, + updated_at=self.updated_at, + ) + + +@dataclass(slots=True) +class AttachmentAssociationCandidate: + claim: ExpenseClaim + score: int + reasons: list[str] + + +_jobs: dict[str, AttachmentAssociationJobState] = {} +_jobs_lock = Lock() + + +def clear_attachment_association_jobs_for_tests() -> None: + with _jobs_lock: + _jobs.clear() + + +def create_attachment_association_job( + payload: AttachmentAssociationJobCreate, + current_user: CurrentUserContext, +) -> AttachmentAssociationJobRead: + job_id = 
def _get_authorized_state(
    job_id: str,
    current_user: CurrentUserContext,
) -> AttachmentAssociationJobState | None:
    """Return the job state when ``current_user`` may access it.

    Access rules:

    * admins may access every job;
    * when the job recorded an owner username, only that username matches;
    * the display name is compared only for jobs created without an owner
      username.

    The display-name fallback used to apply to every job, so a caller whose
    name equalled the owner's name was authorised even with a different
    username.
    NOTE(review): display names are presumably not unique - confirm.

    Args:
        job_id: Job identifier; surrounding whitespace is ignored.
        current_user: The caller's identity (``is_admin``, ``username``,
            ``name`` are read).

    Returns:
        The stored job state, or ``None`` when the job is unknown or the
        caller is not allowed to see it.
    """
    normalized_job_id = str(job_id or "").strip()
    with _jobs_lock:
        state = _jobs.get(normalized_job_id)
    if state is None:
        return None
    if current_user.is_admin:
        return state

    if state.owner_username:
        # The job has a recorded owner username; never fall back to the name.
        username = str(current_user.username or "").strip()
        return state if username == state.owner_username else None

    # No owner username was recorded for this job: the name is all we have.
    name = str(current_user.name or "").strip()
    if name and name == state.owner_name:
        return state
    return None
uploaded_count <= 0: + raise ValueError("未能归集任何附件,请进入报销单详情手动核对。") + return { + "claim_id": recommended.claim.id, + "claim_no": recommended.claim.claim_no, + "uploaded_count": uploaded_count, + "skipped_count": skipped_count, + } + + def _load_receipts( + self, + receipt_ids: list[str], + current_user: CurrentUserContext, + ) -> list[ReceiptFolderDetailRead]: + receipts = [] + for receipt_id in list(dict.fromkeys(str(item or "").strip() for item in receipt_ids if str(item or "").strip())): + try: + receipts.append(self.receipt_service.get_receipt(receipt_id, current_user)) + except FileNotFoundError as exc: + raise ValueError("当前附件没有持久化票据记录,请重新上传后再试。") from exc + if not receipts: + raise ValueError("当前附件没有持久化票据记录,请重新上传后再试。") + return receipts + + def _rank_claims( + self, + receipts: list[ReceiptFolderDetailRead], + current_user: CurrentUserContext, + ) -> list[AttachmentAssociationCandidate]: + signals = _collect_receipt_signals(receipts) + claims = [ + claim + for claim in self.claim_service.list_claims(current_user) + if self._is_auto_association_candidate(claim) + ] + ranked = [ + candidate + for candidate in ( + self._score_claim(claim, signals) + for claim in claims + ) + if candidate.score > 0 + ] + return sorted(ranked, key=lambda item: item.score, reverse=True) + + def _is_auto_association_candidate(self, claim: ExpenseClaim) -> bool: + status = str(claim.status or "").strip().lower() + if status not in EDITABLE_CLAIM_STATUSES: + return False + return not self.claim_service._is_expense_application_claim(claim) + + def _score_claim( + self, + claim: ExpenseClaim, + signals: dict[str, Any], + ) -> AttachmentAssociationCandidate: + claim_text = _build_claim_text(claim) + compact_claim_text = _normalize_text(claim_text) + claim_dates = _extract_date_tokens(claim_text) + claim_cities = _unique([*_extract_city_tokens(claim_text), *_extract_city_tokens(claim.location)]) + reasons: list[str] = [] + score = 0 + + if _dates_overlap(signals["dates"], claim_dates): + 
score += 4 + reasons.append("票据日期与报销单日期一致") + + matched_cities = [city for city in signals["cities"] if city in compact_claim_text] + if matched_cities: + score += min(4, len(matched_cities) * 2) + reasons.append(f"地点或行程包含 {'、'.join(matched_cities)}") + + if len(claim_cities) >= 2 and len(matched_cities) >= 2: + score += 2 + reasons.append("票据往返城市与报销事由吻合") + + if str(claim.status or "").strip().lower() == "draft": + score += 1 + reasons.append("当前单据仍是可归集草稿") + + return AttachmentAssociationCandidate(claim=claim, score=score, reasons=reasons) + + @staticmethod + def _is_linked_to_other_claim(receipt: ReceiptFolderDetailRead, claim_id: str) -> bool: + linked_claim_id = str(receipt.linked_claim_id or "").strip() + return bool(str(receipt.status or "").strip() == "linked" and linked_claim_id and linked_claim_id != claim_id) + + def _resolve_target_item( + self, + *, + claim_id: str, + receipt: ReceiptFolderDetailRead, + current_user: CurrentUserContext, + ) -> ExpenseClaimItem: + claim = self.claim_service.get_claim(claim_id, current_user) + if claim is None: + raise ValueError("匹配到的报销草稿不存在,请刷新后再试。") + + preferred_type = _resolve_receipt_item_type(receipt) + empty_items = [ + item + for item in list(claim.items or []) + if not str(item.invoice_id or "").strip() and not item.is_system_generated + ] + for item in empty_items: + if preferred_type and str(item.item_type or "").strip() == preferred_type: + return item + if empty_items: + return empty_items[0] + + before_ids = {str(item.id) for item in list(claim.items or [])} + created_claim = self.claim_service.create_claim_item( + claim_id=claim.id, + payload=_build_item_payload_from_receipt(claim, receipt, preferred_type), + current_user=current_user, + ) + if created_claim is None: + raise ValueError("无法创建票据归集明细,请进入详情页手动处理。") + for item in list(created_claim.items or []): + if str(item.id) not in before_ids and not str(item.invoice_id or "").strip(): + return item + raise ValueError("无法找到可归集的费用明细,请进入详情页手动处理。") + + +def 
def _unique(values: list[str] | tuple[str, ...]) -> list[str]:
    """Return trimmed, non-empty values without duplicates, in first-seen order."""
    trimmed = (str(item or "").strip() for item in values)
    return list(dict.fromkeys(item for item in trimmed if item))


def _extract_date_tokens(text: Any) -> list[str]:
    """Collect normalised date tokens found in ``text``.

    Full dates ("2024-05-12", "2024/5/12", "2024年5月12") become
    ``YYYY-MM-DD``; year-less dates ("5月12") become ``MM-DD``. Full dates are
    listed before year-less ones, and duplicates are removed.

    The digit look-arounds stop a date from being cut out of a longer digit
    run (for example a serial number such as "2024-05-12345").
    """
    source = str(text or "")
    matches = [
        *re.finditer(r"(?<!\d)20\d{2}[-/.年]\d{1,2}[-/.月]\d{1,2}(?!\d)", source),
        *re.finditer(r"(?<!\d)\d{1,2}月\d{1,2}(?!\d)", source),
    ]
    # Tokens that fail normalisation come back as "" and are dropped by _unique.
    return _unique([_normalize_date_token(match.group(0)) for match in matches])


def _normalize_date_token(value: Any) -> str:
    """Normalise one date-like value to ``YYYY-MM-DD`` or ``MM-DD``.

    Args:
        value: A ``date``/``datetime`` instance or any text containing a date.

    Returns:
        The normalised token, or ``""`` when no date is found or the digits do
        not form a real calendar date (for example month 13).
    """
    if isinstance(value, (date, datetime)):
        return value.isoformat()[:10]
    text = str(value or "").strip()
    full_match = re.search(r"(20\d{2})[-/.年](\d{1,2})[-/.月](\d{1,2})", text)
    if full_match:
        year, month, day = (int(part) for part in full_match.groups())
        try:
            return date(year, month, day).isoformat()
        except ValueError:
            return ""
    short_match = re.search(r"(\d{1,2})月(\d{1,2})", text)
    if short_match:
        month, day = (int(part) for part in short_match.groups())
        try:
            # The year is unknown; a leap year keeps 02-29 acceptable.
            date(2000, month, day)
        except ValueError:
            return ""
        return f"{month:02d}-{day:02d}"
    return ""


def _extract_city_tokens(
    text: Any,
    city_names: tuple[str, ...] | list[str] | None = None,
) -> list[str]:
    """Return the known city names mentioned in ``text``, in order of appearance.

    Whitespace inside ``text`` is ignored. Ordering by first appearance means
    that taking the last element yields the city that shows up last among
    first mentions (for "A到B" that is B). Previously the result followed the
    order of the city table, unrelated to the text.

    Args:
        text: Any value; converted with ``str``.
        city_names: Optional city vocabulary. Defaults to the module-level
            ``CITY_NAMES`` table.
    """
    # Same whitespace removal as the module's _normalize_text helper.
    compact = re.sub(r"\s+", "", str(text or ""))
    if not compact:
        return []
    vocabulary = CITY_NAMES if city_names is None else city_names
    first_seen: dict[str, int] = {}
    for city in vocabulary:
        if not city or city in first_seen:
            continue
        index = compact.find(city)
        if index >= 0:
            first_seen[city] = index
    # sorted() is stable, so equal positions keep vocabulary order.
    return sorted(first_seen, key=first_seen.__getitem__)


def _dates_overlap(left: list[str], right: list[str]) -> bool:
    """Tell whether any token in ``left`` matches any token in ``right``.

    Two tokens match when they are equal or when one is a suffix of the
    other, which lets a year-less ``MM-DD`` match a full ``YYYY-MM-DD``.
    Empty tokens never match.
    """
    right_tokens = [token for token in right if token]
    return any(
        left_token == right_token
        or left_token.endswith(right_token)
        or right_token.endswith(left_token)
        for left_token in left
        if left_token
        for right_token in right_tokens
    )
def _build_claim_text(claim: ExpenseClaim) -> str:
    """Flatten a claim and its items into newline-separated text.

    The output contains, in order: claim number, expense type, status,
    reason, location, occurrence date (``YYYY-MM-DD``) and one line per item
    (date, type, reason, location and note, separated by spaces). Blank
    values are skipped.

    Every emitted value is passed through ``str``. The previous version
    filtered with ``str(value or "")`` but joined the raw value, so a truthy
    non-string attribute raised ``TypeError`` inside ``str.join``.
    """

    def _item_line(item: Any) -> str:
        item_date = item.item_date.isoformat() if item.item_date else ""
        parts = (
            str(value or "").strip()
            for value in (
                item_date,
                item.item_type,
                item.item_reason,
                item.item_location,
                item.item_note,
            )
        )
        return " ".join(part for part in parts if part)

    item_text = "\n".join(_item_line(item) for item in list(claim.items or []))
    occurred_at = claim.occurred_at.isoformat()[:10] if claim.occurred_at else ""
    header_values = (
        claim.claim_no,
        claim.expense_type,
        claim.status,
        claim.reason,
        claim.location,
        occurred_at,
        item_text,
    )
    # Values are coerced but deliberately not trimmed, matching what string
    # inputs produced before.
    return "\n".join(
        str(value)
        for value in header_values
        if str(value or "").strip()
    )
_resolve_receipt_item_date(receipt: ReceiptFolderDetailRead) -> date | None: + for value in [ + *[field.value for field in list(receipt.fields or []) if "日期" in str(field.label or "") or "时间" in str(field.label or "")], + receipt.document_date, + ]: + token = _normalize_date_token(value) + if len(token) == 10: + try: + return date.fromisoformat(token) + except ValueError: + continue + return None + + +def _resolve_receipt_item_location(receipt: ReceiptFolderDetailRead) -> str: + for field in list(receipt.fields or []): + label = str(field.label or "") + value = str(field.value or "").strip() + if value and ("行程" in label or "到达" in label or "地点" in label or "城市" in label): + cities = _extract_city_tokens(value) + return cities[-1] if cities else value[:40] + cities = _extract_city_tokens(_build_receipt_text(receipt)) + return cities[-1] if cities else "" + diff --git a/server/src/app/services/linked_reimbursement_draft_jobs.py b/server/src/app/services/linked_reimbursement_draft_jobs.py new file mode 100644 index 0000000..dd68be4 --- /dev/null +++ b/server/src/app/services/linked_reimbursement_draft_jobs.py @@ -0,0 +1,291 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import UTC, datetime +from threading import Lock +from typing import Any, Callable +from uuid import uuid4 + +from sqlalchemy import select +from sqlalchemy.orm import Session, sessionmaker + +from app.api.deps import CurrentUserContext +from app.schemas.linked_reimbursement_draft_job import ( + LinkedReimbursementDraftJobCreate, + LinkedReimbursementDraftJobRead, +) +from app.schemas.ontology import OntologyParseResult, OntologyPermission +from app.schemas.orchestrator import OrchestratorRequest +from app.models.financial_record import ExpenseClaim +from app.services.expense_claims import ExpenseClaimService +from app.services.orchestrator import OrchestratorService + + +TERMINAL_STATUSES = {"succeeded", "failed"} + + +@dataclass(slots=True) +class 
LinkedReimbursementDraftJobState: + job_id: str + owner_username: str + owner_name: str + message: str + context_json: dict[str, Any] + conversation_id: str = "" + status: str = "queued" + status_message: str = "已创建报销草稿生成任务,等待后台处理。" + error: str = "" + run_id: str = "" + draft_payload: dict[str, Any] | None = None + created_at: datetime = field(default_factory=lambda: datetime.now(UTC)) + updated_at: datetime = field(default_factory=lambda: datetime.now(UTC)) + + def to_read(self) -> LinkedReimbursementDraftJobRead: + return LinkedReimbursementDraftJobRead( + job_id=self.job_id, + status=self.status, + message=self.status_message, + error=self.error, + run_id=self.run_id, + conversation_id=self.conversation_id, + draft_payload=self.draft_payload, + created_at=self.created_at, + updated_at=self.updated_at, + ) + + +_jobs: dict[str, LinkedReimbursementDraftJobState] = {} +_jobs_lock = Lock() + + +def clear_linked_reimbursement_draft_jobs_for_tests() -> None: + with _jobs_lock: + _jobs.clear() + + +def create_linked_reimbursement_draft_job( + payload: LinkedReimbursementDraftJobCreate, + current_user: CurrentUserContext, +) -> LinkedReimbursementDraftJobRead: + context_json = dict(payload.context_json or {}) + context_json["entry_source"] = context_json.get("entry_source") or "workbench-ai" + context_json["session_type"] = context_json.get("session_type") or "expense" + job_id = f"linked-reimbursement-draft-{uuid4()}" + state = LinkedReimbursementDraftJobState( + job_id=job_id, + owner_username=str(current_user.username or "").strip(), + owner_name=str(current_user.name or "").strip(), + message=str(payload.message or "").strip(), + context_json=context_json, + conversation_id=str(payload.conversation_id or "").strip(), + ) + with _jobs_lock: + _jobs[job_id] = state + return state.to_read() + + +def get_linked_reimbursement_draft_job( + job_id: str, + current_user: CurrentUserContext, +) -> LinkedReimbursementDraftJobRead | None: + state = 
def run_linked_reimbursement_draft_job(
    job_id: str,
    current_user: CurrentUserContext,
    session_factory: sessionmaker[Session] | Callable[[], Session],
) -> None:
    """Execute a queued linked-reimbursement draft job and record its outcome.

    The job is resolved through ``_get_authorized_state``; unknown,
    unauthorized, or already-terminal jobs are ignored. The job is marked
    ``running``, the draft is produced either by the direct save path or by
    the orchestrator, and the job ends as ``succeeded`` or ``failed``.

    No exception escapes this function: every failure is logged with its
    traceback and its message is stored on the job so pollers can read it.

    Args:
        job_id: Identifier returned when the job was created.
        current_user: User on whose behalf the job runs.
        session_factory: Callable returning a session usable as a context manager.
    """
    # Function-scope import keeps this block self-contained.
    import logging

    state = _get_authorized_state(job_id, current_user)
    if state is None or state.status in TERMINAL_STATUSES:
        return

    _update_job(job_id, status="running", status_message="正在后台生成报销草稿...")
    try:
        with session_factory() as db:
            if _can_use_direct_save_path(db, state.context_json):
                run_id, result, draft_payload = _run_direct_save_path(
                    db=db,
                    state=state,
                    current_user=current_user,
                )
            else:
                response = OrchestratorService(db).run(
                    OrchestratorRequest(
                        source="user_message",
                        user_id=_resolve_user_id(current_user),
                        # NOTE(review): state.conversation_id is not forwarded
                        # here; confirm that detaching the run is intentional.
                        conversation_id=None,
                        message=state.message,
                        context_json=dict(state.context_json),
                    )
                )
                run_id = response.run_id
                result = response.result if isinstance(response.result, dict) else {}
                draft_payload = result.get("draft_payload") if isinstance(result.get("draft_payload"), dict) else None
                if response.status != "succeeded":
                    raise ValueError(str(result.get("message") or "报销草稿生成失败,请稍后重试。").strip())

        if draft_payload is None:
            raise ValueError("报销草稿生成完成,但未返回草稿信息,请刷新单据列表后核对。")

        _update_job(
            job_id,
            status="succeeded",
            status_message=str(result.get("message") or "报销草稿已生成。").strip(),
            run_id=run_id,
            draft_payload=draft_payload,
            error="",
        )
    except Exception as exc:
        # The failure is kept on the job for the client, but without a log
        # entry the traceback would be lost on the server side.
        logging.getLogger(__name__).exception(
            "Linked reimbursement draft job %s failed", job_id
        )
        message = str(exc).strip() or "报销草稿生成失败,请稍后重试。"
        _update_job(
            job_id,
            status="failed",
            status_message=message,
            error=message,
        )
def _resolve_user_id(current_user: CurrentUserContext) -> str:
    """Return the identifier used to attribute work to ``current_user``.

    The username is preferred, then the display name; each candidate is
    stripped before it is tested. A whitespace-only username therefore falls
    through to the name instead of collapsing to ``"anonymous"``.

    Returns:
        The first non-blank candidate, or ``"anonymous"`` when both are blank.
    """
    for candidate in (current_user.username, current_user.name):
        cleaned = str(candidate or "").strip()
        if cleaned:
            return cleaned
    return "anonymous"
def _build_direct_draft_payload(
    result: dict[str, Any],
    claim: ExpenseClaim | None,
) -> dict[str, Any]:
    """Assemble the draft payload stored on a job that used the direct save path.

    Values from ``result`` take precedence; attributes of ``claim`` fill the
    gaps, and ``claim`` may be ``None``. Every text value is stripped.

    Args:
        result: Mapping returned by the save call.
        claim: The persisted claim, if it could be loaded.

    Returns:
        A dict describing the expense draft for the client.
    """

    def from_claim(attribute: str, fallback: str = "") -> Any:
        # getattr with a default also covers claim being None.
        return getattr(claim, attribute, "") or fallback

    def as_text(value: Any) -> str:
        return str(value).strip()

    number = as_text(result.get("claim_no") or from_claim("claim_no"))
    title = f"费用草稿 {number}" if number else "费用草稿"

    payload: dict[str, Any] = {
        "draft_type": "expense",
        "title": title,
        "body": as_text(result.get("message") or "报销草稿已生成。"),
        "confirmation_required": True,
    }
    payload["claim_id"] = as_text(result.get("claim_id") or from_claim("id"))
    payload["claim_no"] = number
    payload["status"] = as_text(result.get("status") or from_claim("status", "draft"))
    payload["approval_stage"] = as_text(from_claim("approval_stage", "待提交"))
    payload["expense_type"] = as_text(from_claim("expense_type"))
    return payload
def build_client(monkeypatch) -> tuple[TestClient, object]:
    """Create a test client wired to a fresh in-memory database.

    The ``get_db`` dependency and the endpoint module's
    ``get_session_factory`` are both pointed at the same session factory.

    Returns:
        The test client and the session factory backing it.
    """
    factory = build_in_memory_session_factory()
    application = create_app()

    def provide_session() -> Generator[Session, None, None]:
        # Leaving the with-block closes the session once the request is done.
        with factory() as session:
            yield session

    application.dependency_overrides[get_db] = provide_session
    monkeypatch.setattr(attachment_jobs_endpoint, "get_session_factory", lambda: factory)
    return TestClient(application), factory
def fake_ocr_recognize(
    self,
    files: list[tuple[str, bytes, str | None]],
) -> OcrRecognizeBatchRead:
    """Stand-in for ``OcrService.recognize_files`` used by the tests.

    Only the first uploaded file is inspected; the result is a single canned
    train-ticket document carrying that file's name and media type.
    """
    first_upload = files[0]
    ticket_fields = [
        OcrRecognizeFieldRead(key=key, label=label, value=value)
        for key, label, value in (
            ("date", "列车出发时间", "2026-02-20"),
            ("route", "行程", "武汉-上海"),
            ("amount", "金额", "354元"),
        )
    ]
    ticket = OcrRecognizeDocumentRead(
        filename=first_upload[0],
        media_type=first_upload[2] or "application/pdf",
        text="电子发票(铁路电子客票) 武汉 上海 2026-02-20 票价 354 元",
        summary="铁路电子客票,武汉至上海,票价 354 元。",
        avg_score=0.96,
        line_count=1,
        page_count=1,
        document_type="train_ticket",
        document_type_label="火车/高铁票",
        scene_code="travel",
        scene_label="差旅票据",
        document_fields=ticket_fields,
    )
    return OcrRecognizeBatchRead(
        total_file_count=1,
        success_count=1,
        documents=[ticket],
    )
def test_attachment_association_job_fails_without_editable_claim(monkeypatch, tmp_path) -> None:
    """With no claim seeded, the association job is expected to end as failed."""
    jobs_url = "/api/v1/reimbursements/attachment-association-jobs"
    auth_headers = {
        "x-auth-username": "zhangsan@example.com",
        "x-auth-name": "Zhang San",
        "x-auth-employee-no": "E10001",
        "x-auth-role-codes": "user",
    }

    monkeypatch.setenv("STORAGE_ROOT_DIR", str(tmp_path / "storage"))
    get_settings.cache_clear()
    clear_attachment_association_jobs_for_tests()
    try:
        client, _ = build_client(monkeypatch)
        owner = CurrentUserContext(
            username="zhangsan@example.com",
            name="张三",
            role_codes=["user"],
            is_admin=False,
            employee_no="E10001",
        )
        saved_receipt_id = save_train_receipt(
            service=ReceiptFolderService(),
            current_user=owner,
            filename="2月20 武汉-上海.pdf",
            route="武汉-上海",
            trip_date="2026-02-20",
        )

        created = client.post(
            jobs_url,
            headers=auth_headers,
            json={"receipt_ids": [saved_receipt_id], "conversation_id": "inline-empty"},
        )
        assert created.status_code == 202

        polled = client.get(f"{jobs_url}/{created.json()['job_id']}", headers=auth_headers)
        assert polled.status_code == 200
        job = polled.json()
        assert job["status"] == "failed"
        assert "没有找到可自动关联的报销草稿" in job["message"]
    finally:
        # Reset shared state so later tests start from a clean registry/settings.
        clear_attachment_association_jobs_for_tests()
        get_settings.cache_clear()
def seed_employee_and_application(db: Session) -> None:
    """Insert one employee and one approved travel application claim, then commit."""
    applicant = Employee(
        id="emp-linked-draft-fast",
        employee_no="E10001",
        name="张三",
        email="zhangsan@example.com",
        position="实施顾问",
        grade="P5",
    )
    db.add(applicant)
    db.add(
        ExpenseClaim(
            id="application-linked-draft-fast",
            claim_no="AP-202606-FAST",
            employee_id=applicant.id,
            employee_name=applicant.name,
            department_id="dept-delivery",
            department_name="交付部",
            project_code=None,
            expense_type="travel_application",
            reason="支撑国网仿生产服务器部署",
            location="上海",
            amount=3000,
            currency="CNY",
            invoice_count=0,
            occurred_at=datetime(2026, 2, 20, tzinfo=UTC),
            submitted_at=None,
            status="approved",
            approval_stage="已完成",
            risk_flags_json=[],
        )
    )
    db.commit()
status="succeeded", + result={ + "message": "报销草稿已生成。", + "draft_payload": { + "claim_id": "draft-linked-1", + "claim_no": "RE-202606-009", + "status": "draft", + "expense_type": "travel", + }, + }, + requires_confirmation=False, + trace_summary=OrchestratorTraceSummary( + scenario="expense", + intent="draft", + tool_count=1, + failed_tool_count=0, + selected_capability_codes=[], + degraded=False, + ), + ) + + monkeypatch.setattr(OrchestratorService, "run", fake_run) + try: + client, _session_factory = build_client(monkeypatch) + headers = { + "x-auth-username": "zhangsan@example.com", + "x-auth-name": "Zhang San", + "x-auth-employee-no": "E10001", + "x-auth-role-codes": "user", + } + + response = client.post( + "/api/v1/reimbursements/linked-reimbursement-draft-jobs", + headers=headers, + json={ + "message": "我要报销\n用户选择报销场景:差旅费\n关联申请单:AP-202606-001", + "conversation_id": "inline-test", + "context_json": { + "review_action": "save_draft", + "expense_scene_selection": { + "expense_type": "travel", + "expense_type_label": "差旅费", + "application_claim_no": "AP-202606-001", + }, + "review_form_values": { + "application_claim_no": "AP-202606-001", + }, + }, + }, + ) + + assert response.status_code == 202 + job_id = response.json()["job_id"] + + status_response = client.get( + f"/api/v1/reimbursements/linked-reimbursement-draft-jobs/{job_id}", + headers=headers, + ) + assert status_response.status_code == 200 + payload = status_response.json() + assert payload["status"] == "succeeded" + assert payload["draft_payload"]["claim_no"] == "RE-202606-009" + assert payload["run_id"] == "run-linked-draft-job" + assert captured_messages == ["我要报销\n用户选择报销场景:差旅费\n关联申请单:AP-202606-001"] + finally: + clear_linked_reimbursement_draft_jobs_for_tests() + + +def test_linked_reimbursement_draft_job_uses_direct_save_path(monkeypatch) -> None: + clear_linked_reimbursement_draft_jobs_for_tests() + + def fail_if_orchestrator_runs(self, payload): + raise AssertionError("linked draft job should 
not run full orchestrator") + + monkeypatch.setattr(OrchestratorService, "run", fail_if_orchestrator_runs) + try: + client, session_factory = build_client(monkeypatch) + with session_factory() as db: + seed_employee_and_application(db) + + headers = { + "x-auth-username": "zhangsan@example.com", + "x-auth-name": "Zhang San", + "x-auth-employee-no": "E10001", + "x-auth-role-codes": "user", + } + + response = client.post( + "/api/v1/reimbursements/linked-reimbursement-draft-jobs", + headers=headers, + json={ + "message": "我要报销\n用户选择报销场景:差旅费\n关联申请单:AP-202606-FAST", + "conversation_id": "inline-fast-test", + "context_json": { + "name": "张三", + "review_action": "save_draft", + "expense_scene_selection": { + "expense_type": "travel", + "expense_type_label": "差旅费", + "application_claim_id": "application-linked-draft-fast", + "application_claim_no": "AP-202606-FAST", + }, + "review_form_values": { + "expense_type": "差旅费", + "reason": "支撑国网仿生产服务器部署", + "location": "上海", + "time_range": "2026-02-20 至 2026-02-23", + "application_claim_id": "application-linked-draft-fast", + "application_claim_no": "AP-202606-FAST", + "application_reason": "支撑国网仿生产服务器部署", + "application_location": "上海", + "application_amount": "3000", + "application_amount_label": "¥3,000", + "application_business_time": "2026-02-20 至 2026-02-23", + }, + }, + }, + ) + + assert response.status_code == 202 + job_id = response.json()["job_id"] + + status_response = client.get( + f"/api/v1/reimbursements/linked-reimbursement-draft-jobs/{job_id}", + headers=headers, + ) + assert status_response.status_code == 200 + payload = status_response.json() + assert payload["status"] == "succeeded" + assert payload["draft_payload"]["claim_no"] + assert payload["draft_payload"]["claim_id"] + assert payload["run_id"].startswith("linked-reimbursement-draft-") + + with session_factory() as db: + draft = db.get(ExpenseClaim, payload["draft_payload"]["claim_id"]) + assert draft is not None + assert draft.status == "draft" + assert 
def test_linked_reimbursement_draft_job_uses_direct_save_path_with_application_no_only(monkeypatch) -> None:
    """A request carrying only the application number must still avoid the orchestrator.

    The orchestrator's ``run`` is replaced with a function that raises, so the
    job can only succeed through the direct save path. The saved draft is then
    checked for an ``application_link`` flag naming the seeded application.
    """
    jobs_url = "/api/v1/reimbursements/linked-reimbursement-draft-jobs"
    auth_headers = {
        "x-auth-username": "zhangsan@example.com",
        "x-auth-name": "Zhang San",
        "x-auth-employee-no": "E10001",
        "x-auth-role-codes": "user",
    }
    request_body = {
        "message": "我要报销\n用户选择报销场景:差旅费\n关联申请单:AP-202606-FAST",
        "conversation_id": "inline-fast-no-id-test",
        "context_json": {
            "name": "张三",
            "review_action": "save_draft",
            "expense_scene_selection": {
                "expense_type": "travel",
                "expense_type_label": "差旅费",
                "application_claim_no": "AP-202606-FAST",
            },
            "review_form_values": {
                "expense_type": "差旅费",
                "reason": "支撑国网仿生产服务器部署",
                "location": "上海",
                "time_range": "2026-02-20 至 2026-02-23",
                "application_claim_no": "AP-202606-FAST",
                "application_reason": "支撑国网仿生产服务器部署",
                "application_location": "上海",
                "application_amount": "3000",
            },
        },
    }

    clear_linked_reimbursement_draft_jobs_for_tests()

    def fail_if_orchestrator_runs(self, payload):
        raise AssertionError("linked draft job should resolve application no without full orchestrator")

    monkeypatch.setattr(OrchestratorService, "run", fail_if_orchestrator_runs)
    try:
        client, session_factory = build_client(monkeypatch)
        with session_factory() as db:
            seed_employee_and_application(db)

        created = client.post(jobs_url, headers=auth_headers, json=request_body)
        assert created.status_code == 202
        created_job_id = created.json()["job_id"]

        polled = client.get(f"{jobs_url}/{created_job_id}", headers=auth_headers)
        assert polled.status_code == 200
        job = polled.json()
        assert job["status"] == "succeeded"
        assert job["draft_payload"]["claim_no"]
        assert job["draft_payload"]["claim_id"]

        with session_factory() as db:
            saved_draft = db.get(ExpenseClaim, job["draft_payload"]["claim_id"])
            assert saved_draft is not None
            application_link = next(
                entry
                for entry in saved_draft.risk_flags_json
                if entry.get("source") == "application_link"
            )
            assert application_link["application_claim_no"] == "AP-202606-FAST"
            assert application_link["application_claim_id"] == "application-linked-draft-fast"
    finally:
        clear_linked_reimbursement_draft_jobs_for_tests()