diff --git a/server/src/app/api/v1/endpoints/settings.py b/server/src/app/api/v1/endpoints/settings.py index 6c4526e..069819e 100644 --- a/server/src/app/api/v1/endpoints/settings.py +++ b/server/src/app/api/v1/endpoints/settings.py @@ -5,18 +5,20 @@ from typing import Annotated from fastapi import APIRouter, Depends, Header, HTTPException, status from sqlalchemy.orm import Session -from app.api.deps import get_db +from app.api.deps import CurrentUserContext, get_db, require_admin_user from app.core.config import get_settings as get_runtime_settings from app.schemas.common import ErrorResponse from app.schemas.settings import ( ModelConnectivityTestRead, ModelConnectivityTestRequest, RuntimeModelConfigRead, + SettingsCacheClearRead, SettingsRead, SettingsWrite, ) from app.services.model_connectivity import probe_model_connectivity from app.services.settings import SettingsService +from app.services.system_cache import SystemCacheService router = APIRouter(prefix="/settings") DbSession = Annotated[Session, Depends(get_db)] @@ -93,6 +95,24 @@ def test_model_connectivity( return probe_model_connectivity(resolved_payload) +@router.post( + "/cache/clear", + response_model=SettingsCacheClearRead, + summary="清理系统缓存", + description="清理 OCR、模型失败冷却、知识库索引和运行时配置等进程内缓存,不删除业务文件或数据库记录。", + responses={ + status.HTTP_403_FORBIDDEN: { + "model": ErrorResponse, + "description": "只有管理员可以清理系统缓存。", + } + }, +) +def clear_system_cache( + _: Annotated[CurrentUserContext, Depends(require_admin_user)], +) -> SettingsCacheClearRead: + return SystemCacheService().clear_all() + + @router.get( "/runtime-models/{slot}", response_model=RuntimeModelConfigRead, diff --git a/server/src/app/core/config.py b/server/src/app/core/config.py index c9398c3..7c8dd82 100644 --- a/server/src/app/core/config.py +++ b/server/src/app/core/config.py @@ -169,6 +169,12 @@ def _clear_settings_cache() -> None: _settings_cache_signature = None +def clear_runtime_settings_cache() -> int: + cleared_count = 
int(_settings_cache is not None) + _clear_settings_cache() + return cleared_count + + def get_settings() -> Settings: global _settings_cache, _settings_cache_signature diff --git a/server/src/app/schemas/settings.py b/server/src/app/schemas/settings.py index 9ec888b..ebfd375 100644 --- a/server/src/app/schemas/settings.py +++ b/server/src/app/schemas/settings.py @@ -222,6 +222,17 @@ class ModelConnectivityTestRead(BaseModel): checked_at: datetime +class SettingsCacheClearItemRead(BaseModel): + cacheKey: str + label: str + clearedCount: int = Field(default=0, ge=0) + + +class SettingsCacheClearRead(BaseModel): + totalCleared: int = Field(default=0, ge=0) + items: list[SettingsCacheClearItemRead] = Field(default_factory=list) + + class RuntimeModelConfigRead(BaseModel): slot: Literal["main", "backup", "embedding", "reranker"] provider: str diff --git a/server/src/app/services/application_location_semantics.py b/server/src/app/services/application_location_semantics.py index 510deda..f7ae60e 100644 --- a/server/src/app/services/application_location_semantics.py +++ b/server/src/app/services/application_location_semantics.py @@ -123,6 +123,14 @@ def _load_jieba_posseg() -> Any: return pseg +def clear_application_location_semantic_caches() -> int: + cleared_count = _load_lac_analyzer.cache_info().currsize + cleared_count += _load_jieba_posseg.cache_info().currsize + _load_lac_analyzer.cache_clear() + _load_jieba_posseg.cache_clear() + return cleared_count + + def _iter_jieba_custom_words() -> Iterable[str]: yield from JIEBA_CUSTOM_WORDS yield from DIRECT_MUNICIPALITY_DISPLAY diff --git a/server/src/app/services/expense_claim_attachment_document.py b/server/src/app/services/expense_claim_attachment_document.py index e26e021..faf2b53 100644 --- a/server/src/app/services/expense_claim_attachment_document.py +++ b/server/src/app/services/expense_claim_attachment_document.py @@ -111,9 +111,20 @@ from app.services.ocr import OcrService class 
ExpenseClaimAttachmentDocumentMixin: - def _build_attachment_payload(self, item: ExpenseClaimItem) -> dict[str, Any]: + def _build_attachment_payload( + self, + item: ExpenseClaimItem, + *, + current_user: CurrentUserContext | None = None, + ) -> dict[str, Any]: file_path, media_type, filename = self._resolve_item_attachment_content(item) metadata = self._attachment_storage.read_meta(file_path) + metadata = self._repair_attachment_metadata_from_source_receipt_if_needed( + file_path=file_path, + metadata=metadata, + item=item, + current_user=current_user, + ) metadata = self._repair_pdf_text_layer_metadata_if_needed( file_path=file_path, metadata=metadata, @@ -164,6 +175,108 @@ class ExpenseClaimAttachmentDocumentMixin: "requirement_check": requirement_check, } + def _repair_attachment_metadata_from_source_receipt_if_needed( + self, + *, + file_path: Path, + metadata: dict[str, Any], + item: ExpenseClaimItem, + current_user: CurrentUserContext | None, + ) -> dict[str, Any]: + if not metadata or current_user is None: + return metadata + + source_receipt_id = str(metadata.get("source_receipt_id") or "").strip() + if not source_receipt_id: + return metadata + if not self._attachment_metadata_needs_source_receipt_repair(metadata): + return metadata + + source_document = self._resolve_source_receipt_document( + source_receipt_id=source_receipt_id, + current_user=current_user, + fallback_filename=str(metadata.get("file_name") or file_path.name), + fallback_media_type=str(metadata.get("media_type") or ""), + ) + if source_document is None: + return metadata + + document_info = self._build_attachment_document_info(source_document) + requirement_check = self._build_attachment_requirement_check( + item=item, + document_info=document_info, + ) + preview_meta = self._attachment_presentation.build_preview_meta( + file_path=file_path, + media_type=str( + metadata.get("media_type") + or self._attachment_presentation.resolve_media_type(file_path.name) + ), + 
ocr_document=source_document, + ) + + metadata.update( + { + "previewable": bool(preview_meta["previewable"]), + "preview_kind": str(preview_meta["preview_kind"]), + "preview_storage_key": str(preview_meta["preview_storage_key"]), + "preview_media_type": str(preview_meta["preview_media_type"]), + "preview_file_name": str(preview_meta["preview_file_name"]), + "preview_rendered_with": str(preview_meta.get("preview_rendered_with") or ""), + "analysis": self._build_attachment_analysis( + document=source_document, + item=item, + claim=getattr(item, "claim", None), + document_info=document_info, + requirement_check=requirement_check, + ), + "document_info": document_info, + "requirement_check": requirement_check, + "ocr_status": "recognized", + "ocr_error": "", + "ocr_text": str(getattr(source_document, "text", "") or ""), + "ocr_summary": str(getattr(source_document, "summary", "") or ""), + "ocr_avg_score": float(getattr(source_document, "avg_score", 0.0) or 0.0), + "ocr_line_count": int(getattr(source_document, "line_count", 0) or 0), + "ocr_classification_source": str( + getattr(source_document, "classification_source", "") or "" + ), + "ocr_classification_confidence": float( + getattr(source_document, "classification_confidence", 0.0) or 0.0 + ), + "ocr_classification_evidence": [ + str(value) + for value in list(getattr(source_document, "classification_evidence", []) or []) + if str(value).strip() + ], + "ocr_warnings": [ + str(value) + for value in list(getattr(source_document, "warnings", []) or []) + if str(value).strip() + ], + } + ) + self._attachment_storage.write_meta(file_path, metadata) + return metadata + + @classmethod + def _attachment_metadata_needs_source_receipt_repair(cls, metadata: dict[str, Any]) -> bool: + document_info = metadata.get("document_info") + document_type = "" + fields: list[Any] = [] + if isinstance(document_info, dict): + document_type = str(document_info.get("document_type") or "").strip() + fields = 
list(document_info.get("fields") or []) + + return ( + str(metadata.get("preview_kind") or "").strip() != "image" + or document_type in {"", "other"} + or not any( + isinstance(field, dict) and str(field.get("value") or "").strip() + for field in fields + ) + ) + @classmethod def _attachment_metadata_needs_analysis_refresh(cls, metadata: dict[str, Any]) -> bool: analysis = metadata.get("analysis") diff --git a/server/src/app/services/expense_claim_attachment_operations.py b/server/src/app/services/expense_claim_attachment_operations.py index 1603dc5..99df4b6 100644 --- a/server/src/app/services/expense_claim_attachment_operations.py +++ b/server/src/app/services/expense_claim_attachment_operations.py @@ -313,8 +313,9 @@ class ExpenseClaimAttachmentOperationsMixin: if not normalized_receipt_id: return None + receipt_service = ReceiptFolderService() try: - receipt = ReceiptFolderService().get_receipt(normalized_receipt_id, current_user) + receipt = receipt_service.get_receipt(normalized_receipt_id, current_user) except FileNotFoundError: return None @@ -325,6 +326,20 @@ class ExpenseClaimAttachmentOperationsMixin: if not fields: fields = self._normalize_receipt_document_fields(raw_meta.get("document_fields")) + preview_source_path = None + preview_media_type = "" + preview_file_name = "" + if str(raw_meta.get("preview_kind") or "").strip() == "image": + try: + preview_source_path, preview_media_type, preview_file_name = receipt_service.resolve_preview( + normalized_receipt_id, + current_user, + ) + except FileNotFoundError: + preview_source_path = None + preview_media_type = "" + preview_file_name = "" + document = SimpleNamespace( filename=str(receipt.file_name or fallback_filename or "").strip(), media_type=str(receipt.media_type or fallback_media_type or "application/octet-stream").strip(), @@ -359,6 +374,9 @@ class ExpenseClaimAttachmentOperationsMixin: document_fields=fields, preview_kind=str(raw_meta.get("preview_kind") or ""), preview_data_url="", + 
preview_source_path=str(preview_source_path or ""), + preview_media_type=preview_media_type, + preview_file_name=preview_file_name, warnings=[ str(value) for value in list(receipt.warnings or raw_meta.get("ocr_warnings") or []) @@ -399,8 +417,16 @@ class ExpenseClaimAttachmentOperationsMixin: source_type = cls._attachment_document_type(source_receipt_document) upload_type = cls._attachment_document_type(upload_ocr_document) + if source_type in {"", "other"} and upload_type not in {"", "other"}: + return upload_ocr_document if source_type not in {"", "other"} and upload_type in {"", "other"}: return source_receipt_document + if ( + cls._attachment_has_image_preview(source_receipt_document) + and not cls._attachment_has_image_preview(upload_ocr_document) + and source_score >= upload_score + ): + return source_receipt_document if ( source_type == upload_type and cls._attachment_document_field_count(source_receipt_document) @@ -438,6 +464,15 @@ class ExpenseClaimAttachmentOperationsMixin: return 0 return len(list(getattr(document, "document_fields", []) or [])) + @staticmethod + def _attachment_has_image_preview(document: Any | None) -> bool: + if document is None: + return False + return str(getattr(document, "preview_kind", "") or "").strip() == "image" and bool( + str(getattr(document, "preview_data_url", "") or "").strip() + or str(getattr(document, "preview_source_path", "") or "").strip() + ) + def get_claim_item_attachment_meta( self, *, @@ -453,7 +488,7 @@ class ExpenseClaimAttachmentOperationsMixin: if claim is None: return None - return self._build_attachment_payload(item) + return self._build_attachment_payload(item, current_user=current_user) def get_claim_item_attachment_content( self, @@ -487,7 +522,7 @@ class ExpenseClaimAttachmentOperationsMixin: if claim is None: return None - return self._resolve_item_attachment_preview_content(item) + return self._resolve_item_attachment_preview_content(item, current_user=current_user) def 
delete_claim_item_attachment( self, @@ -740,9 +775,20 @@ class ExpenseClaimAttachmentOperationsMixin: self._attachment_storage.write_meta(file_path, metadata) return metadata - def _resolve_item_attachment_preview_content(self, item: ExpenseClaimItem) -> tuple[Path, str, str]: + def _resolve_item_attachment_preview_content( + self, + item: ExpenseClaimItem, + *, + current_user: CurrentUserContext | None = None, + ) -> tuple[Path, str, str]: file_path, media_type, filename = self._resolve_item_attachment_content(item) metadata = self._attachment_storage.read_meta(file_path) + metadata = self._repair_attachment_metadata_from_source_receipt_if_needed( + file_path=file_path, + metadata=metadata, + item=item, + current_user=current_user, + ) metadata = self._repair_pdf_text_layer_metadata_if_needed( file_path=file_path, metadata=metadata, diff --git a/server/src/app/services/expense_claim_attachment_presentation.py b/server/src/app/services/expense_claim_attachment_presentation.py index 3f88d51..3ae7f08 100644 --- a/server/src/app/services/expense_claim_attachment_presentation.py +++ b/server/src/app/services/expense_claim_attachment_presentation.py @@ -1,6 +1,7 @@ from __future__ import annotations import mimetypes +import shutil from pathlib import Path from typing import Any from urllib.parse import quote @@ -43,6 +44,25 @@ class ExpenseClaimAttachmentPresentation: "preview_rendered_with": DocumentPreviewAssets.renderer_id_for_source(media_type), } + preview_source_path = getattr(ocr_document, "preview_source_path", None) + if preview_source_kind == "image" and preview_source_path: + preview_asset = self._copy_preview_asset_from_source( + attachment_dir=file_path.parent, + original_filename=filename, + preview_source_path=Path(preview_source_path), + preview_media_type=str(getattr(ocr_document, "preview_media_type", "") or ""), + ) + if preview_asset is not None: + preview_path, preview_media_type, preview_file_name = preview_asset + return { + "previewable": True, + 
"preview_kind": "image", + "preview_storage_key": self.storage.to_storage_key(preview_path), + "preview_media_type": preview_media_type, + "preview_file_name": preview_file_name, + "preview_rendered_with": DocumentPreviewAssets.renderer_id_for_source(media_type), + } + if preview_kind: return { "previewable": True, @@ -88,6 +108,28 @@ class ExpenseClaimAttachmentPresentation: preview_data_url=preview_data_url, ) + def _copy_preview_asset_from_source( + self, + *, + attachment_dir: Path, + original_filename: str, + preview_source_path: Path, + preview_media_type: str, + ) -> tuple[Path, str, str] | None: + if not preview_source_path.exists() or not preview_source_path.is_file(): + return None + + suffix = preview_source_path.suffix or DocumentPreviewAssets.PDF_PREVIEW_SUFFIX + preview_name = f"{Path(original_filename).stem}.preview{suffix}" + preview_path = attachment_dir / preview_name + shutil.copyfile(preview_source_path, preview_path) + resolved_media_type = ( + preview_media_type + or mimetypes.guess_type(preview_source_path.name)[0] + or DocumentPreviewAssets.PDF_PREVIEW_MEDIA_TYPE + ) + return preview_path, resolved_media_type, preview_name + @staticmethod def build_preview_client_path(claim_id: str, item_id: str) -> str: return ( diff --git a/server/src/app/services/knowledge_rag_local.py b/server/src/app/services/knowledge_rag_local.py index 089997a..120b84e 100644 --- a/server/src/app/services/knowledge_rag_local.py +++ b/server/src/app/services/knowledge_rag_local.py @@ -108,6 +108,13 @@ _index_lock = threading.RLock() _index_cache: dict[Path, tuple[tuple[int, int], list[dict[str, Any]]]] = {} +def clear_local_knowledge_index_cache() -> int: + with _index_lock: + cleared_count = len(_index_cache) + _index_cache.clear() + return cleared_count + + @dataclass(frozen=True, slots=True) class LocalKnowledgeSearchResult: hits: list[dict[str, Any]] diff --git a/server/src/app/services/ocr.py b/server/src/app/services/ocr.py index 704e3f6..1367972 100644 --- 
a/server/src/app/services/ocr.py +++ b/server/src/app/services/ocr.py @@ -148,13 +148,23 @@ class OcrService: for item in pdf_inputs: cache_keys_by_source.setdefault(item.source_key, cache_key) except RuntimeError as exc: - documents.append( - OcrRecognizeDocumentRead( - filename=normalized_name, - media_type=resolved_media_type, - warnings=[str(exc)], - ) + fallback_document = self._build_pdf_text_layer_fallback_document( + filename=normalized_name, + media_type=resolved_media_type, + text_layer=text_layer, + render_warning=str(exc), ) + if fallback_document is not None: + documents.append(fallback_document) + self._write_cached_document(cache_key, fallback_document) + else: + documents.append( + OcrRecognizeDocumentRead( + filename=normalized_name, + media_type=resolved_media_type, + warnings=[str(exc)], + ) + ) continue source_key = uuid4().hex @@ -328,6 +338,13 @@ class OcrService: while len(cls._result_cache) > OCR_RESULT_CACHE_LIMIT: cls._result_cache.popitem(last=False) + @classmethod + def clear_result_cache(cls) -> int: + with cls._cache_lock: + cleared_count = len(cls._result_cache) + cls._result_cache.clear() + return cleared_count + @classmethod def _resolve_worker_semaphore(cls, limit: int) -> threading.Semaphore: normalized_limit = max(1, int(limit or 1)) @@ -425,6 +442,36 @@ class OcrService: ) return descriptors + def _build_pdf_text_layer_fallback_document( + self, + *, + filename: str, + media_type: str, + text_layer: str, + render_warning: str, + ) -> OcrRecognizeDocumentRead | None: + normalized_text = self._normalize_extracted_text(text_layer) + if self._meaningful_char_count(normalized_text) < 8: + return None + + aggregated = AggregatedOcrDocument( + filename=filename, + media_type=media_type, + source_key=uuid4().hex, + page_count=1, + warnings=[ + str(render_warning or "").strip() or "PDF 转图片失败。", + "PDF 转图片失败,已使用 PDF 文本层继续抽取识别信息。", + ], + lines=[ + OcrRecognizeLineRead(text=line, page_index=0) + for line in normalized_text.splitlines() + 
if line.strip() + ], + ) + aggregated.text_layer_fragments.append(normalized_text) + return self._finalize_document(aggregated) + def _extract_pdf_text_layer(self, pdf_path: Path) -> str: try: completed = subprocess.run( diff --git a/server/src/app/services/receipt_folder.py b/server/src/app/services/receipt_folder.py index b43f9b0..82201f9 100644 --- a/server/src/app/services/receipt_folder.py +++ b/server/src/app/services/receipt_folder.py @@ -889,6 +889,8 @@ class ReceiptFolderTrainTicketMixin: "无效", "二维码", "座席", + "身份", + "身份证号", "证件", ) ): @@ -993,6 +995,11 @@ class ReceiptFolderService(ReceiptFolderStorageMixin, ReceiptFolderItemMixin, Re current_user=current_user, ) if duplicate_receipt is not None: + duplicate_receipt = self._refresh_duplicate_receipt_from_document_if_stronger( + receipt=duplicate_receipt, + document=document, + current_user=current_user, + ) warning = "已上传过同样的单据,请不要重复上传。" existing_warnings = [str(item) for item in list(document.warnings or []) if str(item).strip()] enriched.append( @@ -1061,6 +1068,7 @@ class ReceiptFolderService(ReceiptFolderStorageMixin, ReceiptFolderItemMixin, Re if str(value).strip() ], "document_fields": self._build_ocr_document_fields_from_meta(meta), + "preview_kind": str(meta.get("preview_kind") or document.preview_kind or ""), } ) @@ -1073,6 +1081,62 @@ class ReceiptFolderService(ReceiptFolderStorageMixin, ReceiptFolderItemMixin, Re update["warnings"] = list(dict.fromkeys(warnings)) return document.model_copy(update=update) + def _refresh_duplicate_receipt_from_document_if_stronger( + self, + *, + receipt: ReceiptFolderItemRead, + document: OcrRecognizeDocumentRead, + current_user: CurrentUserContext, + ) -> ReceiptFolderItemRead: + try: + meta = self._read_receipt_meta(receipt.id, current_user) + except FileNotFoundError: + return receipt + + incoming_meta = self._build_document_meta(document) + if not self._is_incoming_document_meta_stronger(meta, incoming_meta): + return receipt + + for key in ( + "engine", + 
"model", + "ocr_text", + "summary", + "ocr_avg_score", + "ocr_line_count", + "page_count", + "document_type", + "document_type_label", + "scene_code", + "scene_label", + "ocr_classification_source", + "ocr_classification_confidence", + "ocr_classification_evidence", + "document_fields", + "ocr_warnings", + ): + meta[key] = incoming_meta[key] + meta["updated_at"] = datetime.now(UTC).isoformat() + self._write_meta(self._receipt_dir(self._owner_key(current_user), receipt.id), meta) + return self._build_item(meta) + + @staticmethod + def _is_incoming_document_meta_stronger(existing_meta: dict[str, Any], incoming_meta: dict[str, Any]) -> bool: + existing_type = str(existing_meta.get("document_type") or "other").strip() or "other" + incoming_type = str(incoming_meta.get("document_type") or "other").strip() or "other" + existing_fields = [field for field in list(existing_meta.get("document_fields") or []) if isinstance(field, dict)] + incoming_fields = [field for field in list(incoming_meta.get("document_fields") or []) if isinstance(field, dict)] + existing_text = str(existing_meta.get("ocr_text") or "").strip() + incoming_text = str(incoming_meta.get("ocr_text") or "").strip() + + if incoming_type != "other" and existing_type == "other": + return True + if incoming_fields and not existing_fields: + return True + if incoming_text and not existing_text: + return True + return False + def _build_ocr_document_fields_from_meta(self, meta: dict[str, Any]) -> list[OcrRecognizeFieldRead]: return [ OcrRecognizeFieldRead( diff --git a/server/src/app/services/runtime_chat.py b/server/src/app/services/runtime_chat.py index 326ef5c..c4a5cb2 100644 --- a/server/src/app/services/runtime_chat.py +++ b/server/src/app/services/runtime_chat.py @@ -29,6 +29,12 @@ DEFAULT_RUNTIME_CHAT_FAILURE_COOLDOWN_SECONDS = 90 _slot_failure_until: dict[str, float] = {} +def clear_runtime_chat_failure_cache() -> int: + cleared_count = len(_slot_failure_until) + _slot_failure_until.clear() + return 
cleared_count + + @dataclass(slots=True) class RuntimeChatCallTrace: slot: str diff --git a/server/src/app/services/system_cache.py b/server/src/app/services/system_cache.py new file mode 100644 index 0000000..8432234 --- /dev/null +++ b/server/src/app/services/system_cache.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from app.core.config import clear_runtime_settings_cache +from app.schemas.settings import SettingsCacheClearItemRead, SettingsCacheClearRead +from app.services.application_location_semantics import clear_application_location_semantic_caches +from app.services.knowledge_rag_local import clear_local_knowledge_index_cache +from app.services.ocr import OcrService +from app.services.runtime_chat import clear_runtime_chat_failure_cache + + +class SystemCacheService: + def clear_all(self) -> SettingsCacheClearRead: + items = [ + SettingsCacheClearItemRead( + cacheKey="ocr_result_cache", + label="OCR 识别结果缓存", + clearedCount=OcrService.clear_result_cache(), + ), + SettingsCacheClearItemRead( + cacheKey="runtime_settings_cache", + label="运行时配置缓存", + clearedCount=clear_runtime_settings_cache(), + ), + SettingsCacheClearItemRead( + cacheKey="runtime_chat_failure_cache", + label="模型调用失败冷却缓存", + clearedCount=clear_runtime_chat_failure_cache(), + ), + SettingsCacheClearItemRead( + cacheKey="knowledge_local_index_cache", + label="知识库本地索引缓存", + clearedCount=clear_local_knowledge_index_cache(), + ), + SettingsCacheClearItemRead( + cacheKey="application_location_semantic_cache", + label="地点语义分析缓存", + clearedCount=clear_application_location_semantic_caches(), + ), + ] + total_cleared = sum(item.clearedCount for item in items) + return SettingsCacheClearRead(totalCleared=total_cleared, items=items) diff --git a/server/tests/test_attachment_association_jobs.py b/server/tests/test_attachment_association_jobs.py index a7a3362..cf286db 100644 --- a/server/tests/test_attachment_association_jobs.py +++ b/server/tests/test_attachment_association_jobs.py @@ -1,5 
+1,6 @@ from __future__ import annotations +import base64 from collections.abc import Generator from datetime import UTC, date, datetime from decimal import Decimal @@ -16,6 +17,7 @@ from app.models.employee import Employee from app.models.financial_record import ExpenseClaim, ExpenseClaimItem from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeFieldRead from app.services.attachment_association_jobs import clear_attachment_association_jobs_for_tests +from app.services.expense_claims import ExpenseClaimService from app.services.expense_claim_attachment_storage import ExpenseClaimAttachmentStorage from app.services.ocr import OcrService from app.services.receipt_folder import ReceiptFolderService @@ -149,6 +151,13 @@ def fake_ocr_recognize( ) +def fake_ocr_recognize_without_preview( + self, + files: list[tuple[str, bytes, str | None]], +) -> OcrRecognizeBatchRead: + return fake_ocr_recognize(self, files) + + def test_attachment_association_job_links_receipts_after_conversation_exit(monkeypatch, tmp_path) -> None: monkeypatch.setenv("STORAGE_ROOT_DIR", str(tmp_path / "storage")) get_settings.cache_clear() @@ -233,6 +242,233 @@ def test_attachment_association_job_links_receipts_after_conversation_exit(monke get_settings.cache_clear() +def test_attachment_association_keeps_receipt_folder_preview_and_fields_after_cache_clear( + monkeypatch, + tmp_path, +) -> None: + preview_bytes = b"receipt-folder-preview-png" + preview_data_url = f"data:image/png;base64,{base64.b64encode(preview_bytes).decode('ascii')}" + + monkeypatch.setenv("STORAGE_ROOT_DIR", str(tmp_path / "storage")) + get_settings.cache_clear() + clear_attachment_association_jobs_for_tests() + monkeypatch.setattr(OcrService, "recognize_files", fake_ocr_recognize_without_preview) + monkeypatch.setattr(ExpenseClaimAttachmentStorage, "root", lambda self: tmp_path / "attachments") + try: + client, session_factory = build_client(monkeypatch) + current_user = CurrentUserContext( + 
username="zhangsan@example.com", + name="张三", + role_codes=["user"], + is_admin=False, + employee_no="E10001", + ) + with session_factory() as db: + seed_travel_claim(db) + + receipt = ReceiptFolderService().save_receipt( + filename="2月20 武汉-上海.pdf", + content=b"%PDF-1.7 fake-ticket", + media_type="application/pdf", + current_user=current_user, + document=OcrRecognizeDocumentRead( + filename="2月20 武汉-上海.pdf", + media_type="application/pdf", + text="电子发票(铁路电子客票) 武汉站 G458 上海虹桥站 2026年02月20日 07:55开 二等座 票价 354.00", + summary="铁路电子客票,武汉-上海,票价 354 元。", + avg_score=0.96, + line_count=1, + page_count=1, + document_type="train_ticket", + document_type_label="火车/高铁票", + scene_code="travel", + scene_label="差旅票据", + preview_kind="image", + preview_data_url=preview_data_url, + document_fields=[ + OcrRecognizeFieldRead(key="date", label="列车出发时间", value="2026-02-20 07:55"), + OcrRecognizeFieldRead(key="route", label="行程", value="武汉-上海"), + OcrRecognizeFieldRead(key="amount", label="金额", value="354元"), + ], + ), + ) + OcrService.clear_result_cache() + + headers = { + "x-auth-username": "zhangsan@example.com", + "x-auth-name": "Zhang San", + "x-auth-employee-no": "E10001", + "x-auth-role-codes": "user", + } + response = client.post( + "/api/v1/reimbursements/attachment-association-jobs", + headers=headers, + json={ + "receipt_ids": [receipt.id], + "prompt": "请帮我处理已上传的附件。", + "conversation_id": "inline-test", + }, + ) + assert response.status_code == 202 + job_id = response.json()["job_id"] + + status_response = client.get( + f"/api/v1/reimbursements/attachment-association-jobs/{job_id}", + headers=headers, + ) + assert status_response.status_code == 200 + assert status_response.json()["status"] == "succeeded" + + with session_factory() as db: + claim = db.scalar( + select(ExpenseClaim) + .options(selectinload(ExpenseClaim.items)) + .where(ExpenseClaim.id == "claim-bg-association") + ) + assert claim is not None + attached_item = next(item for item in claim.items if item.invoice_id) 
+ metadata = ExpenseClaimService(db).get_claim_item_attachment_meta( + claim_id=claim.id, + item_id=attached_item.id, + current_user=current_user, + ) + assert metadata is not None + assert metadata["preview_kind"] == "image" + assert metadata["document_info"]["document_type"] == "train_ticket" + assert metadata["document_info"]["document_type_label"] == "火车/高铁票" + assert { + (field["label"], field["value"]) + for field in metadata["document_info"]["fields"] + } >= { + ("列车出发时间", "2026-02-20 07:55"), + ("行程", "武汉-上海"), + ("金额", "354元"), + } + + preview_path, media_type, filename = ExpenseClaimService(db).get_claim_item_attachment_preview_content( + claim_id=claim.id, + item_id=attached_item.id, + current_user=current_user, + ) + assert media_type == "image/png" + assert filename.endswith(".png") + assert preview_path.read_bytes() == preview_bytes + finally: + clear_attachment_association_jobs_for_tests() + get_settings.cache_clear() + + +def test_attachment_meta_repairs_existing_pdf_fallback_from_source_receipt( + monkeypatch, + tmp_path, +) -> None: + preview_bytes = b"legacy-repaired-preview-png" + preview_data_url = f"data:image/png;base64,{base64.b64encode(preview_bytes).decode('ascii')}" + monkeypatch.setenv("STORAGE_ROOT_DIR", str(tmp_path / "storage")) + get_settings.cache_clear() + monkeypatch.setattr(ExpenseClaimAttachmentStorage, "root", lambda self: tmp_path / "attachments") + try: + current_user = CurrentUserContext( + username="zhangsan@example.com", + name="张三", + role_codes=["user"], + is_admin=False, + employee_no="E10001", + ) + client, session_factory = build_client(monkeypatch) + client.close() + + with session_factory() as db: + claim = seed_travel_claim(db) + item = claim.items[0] + receipt = ReceiptFolderService().save_receipt( + filename="2月20 武汉-上海.pdf", + content=b"%PDF-1.7 fake-ticket", + media_type="application/pdf", + current_user=current_user, + document=OcrRecognizeDocumentRead( + filename="2月20 武汉-上海.pdf", + 
media_type="application/pdf", + text="电子发票(铁路电子客票) 武汉站 G458 上海虹桥站 2026年02月20日 07:55开 二等座 票价 354.00", + summary="铁路电子客票,武汉-上海,票价 354 元。", + avg_score=0.96, + line_count=1, + page_count=1, + document_type="train_ticket", + document_type_label="火车/高铁票", + scene_code="travel", + scene_label="差旅票据", + preview_kind="image", + preview_data_url=preview_data_url, + document_fields=[ + OcrRecognizeFieldRead(key="date", label="列车出发时间", value="2026-02-20 07:55"), + OcrRecognizeFieldRead(key="route", label="行程", value="武汉-上海"), + OcrRecognizeFieldRead(key="amount", label="金额", value="354元"), + ], + ), + ) + + attachment_dir = tmp_path / "attachments" / claim.id / item.id + attachment_dir.mkdir(parents=True) + file_path = attachment_dir / "2月20_武汉-上海.pdf" + file_path.write_bytes(b"%PDF-1.7 persisted-but-bad-meta") + storage = ExpenseClaimAttachmentStorage() + item.invoice_id = storage.to_storage_key(file_path) + storage.write_meta( + file_path, + { + "file_name": file_path.name, + "storage_key": storage.to_storage_key(file_path), + "media_type": "application/pdf", + "size_bytes": file_path.stat().st_size, + "previewable": True, + "preview_kind": "pdf", + "preview_storage_key": storage.to_storage_key(file_path), + "preview_media_type": "application/pdf", + "preview_file_name": file_path.name, + "document_info": { + "document_type": "other", + "document_type_label": "其他单据", + "scene_code": "other", + "scene_label": "其他票据", + "fields": [], + }, + "source_receipt_id": receipt.id, + }, + ) + db.commit() + + service = ExpenseClaimService(db) + metadata = service.get_claim_item_attachment_meta( + claim_id=claim.id, + item_id=item.id, + current_user=current_user, + ) + assert metadata is not None + assert metadata["preview_kind"] == "image" + assert metadata["document_info"]["document_type"] == "train_ticket" + assert metadata["document_info"]["document_type_label"] == "火车/高铁票" + assert { + (field["label"], field["value"]) + for field in metadata["document_info"]["fields"] + } >= { + 
("列车出发时间", "2026-02-20 07:55"), + ("行程", "武汉-上海"), + ("金额", "354元"), + } + + preview_path, media_type, filename = service.get_claim_item_attachment_preview_content( + claim_id=claim.id, + item_id=item.id, + current_user=current_user, + ) + assert media_type == "image/png" + assert filename.endswith(".png") + assert preview_path.read_bytes() == preview_bytes + finally: + get_settings.cache_clear() + + def test_attachment_association_job_fails_without_editable_claim(monkeypatch, tmp_path) -> None: monkeypatch.setenv("STORAGE_ROOT_DIR", str(tmp_path / "storage")) get_settings.cache_clear() diff --git a/server/tests/test_ocr_service.py b/server/tests/test_ocr_service.py index a3d2222..c9c0064 100644 --- a/server/tests/test_ocr_service.py +++ b/server/tests/test_ocr_service.py @@ -308,6 +308,7 @@ def test_ocr_service_rejects_pdf_ocr_when_rendered_image_fonts_are_broken( monkeypatch.setattr(OcrService, "_convert_pdf_to_images", fake_convert_pdf_to_images) monkeypatch.setattr(OcrService, "_invoke_worker", fake_invoke_worker) get_settings.cache_clear() + OcrService._result_cache.clear() try: result = OcrService().recognize_files( [ @@ -315,6 +316,7 @@ def test_ocr_service_rejects_pdf_ocr_when_rendered_image_fonts_are_broken( ] ) finally: + OcrService._result_cache.clear() get_settings.cache_clear() failed = result.documents[0] @@ -324,6 +326,63 @@ def test_ocr_service_rejects_pdf_ocr_when_rendered_image_fonts_are_broken( assert failed.warnings == ["PDF 转图片失败:检测到中文字体映射缺失,未生成可 OCR 的图片。"] +def test_ocr_service_uses_pdf_text_layer_when_rendering_fails( + monkeypatch, + tmp_path: Path, +) -> None: + def fake_convert_pdf_to_images(self, *, pdf_path: Path, output_dir: Path) -> tuple[list[Path], bool]: + raise RuntimeError("PDF 转图片失败:Missing language pack for Adobe-GB1") + + def fake_invoke_worker( + self, + *, + python_bin: str, + worker_path: str, + input_paths: list[Path], + ) -> dict: + raise AssertionError("PDF 转图失败但文本层可用时,不应调用 OCR worker。") + + 
monkeypatch.setenv("STORAGE_ROOT_DIR", str(tmp_path / "storage")) + monkeypatch.setattr(OcrService, "_resolve_python_bin", lambda self: "python") + monkeypatch.setattr(OcrService, "_resolve_worker_path", lambda self: "worker.py") + monkeypatch.setattr(OcrService, "_convert_pdf_to_images", fake_convert_pdf_to_images) + monkeypatch.setattr(OcrService, "_invoke_worker", fake_invoke_worker) + monkeypatch.setattr( + OcrService, + "_extract_pdf_text_layer", + lambda self, pdf_path: ( + "G458\n" + "Wuhan Shanghaihongqiao\n" + "2026 02 20 07:55\n" + "票价: 354.00\n" + "12306 95306" + ), + ) + get_settings.cache_clear() + OcrService._result_cache.clear() + try: + result = OcrService().recognize_files( + [ + ("2月20_武汉-上海.pdf", b"%PDF-1.7 text-layer-fallback", "application/pdf"), + ] + ) + finally: + OcrService._result_cache.clear() + get_settings.cache_clear() + + recovered = result.documents[0] + assert result.success_count == 1 + assert recovered.document_type == "train_ticket" + assert recovered.document_type_label == "火车/高铁票" + assert recovered.preview_kind == "" + assert recovered.preview_data_url == "" + assert any(field.label == "金额" and field.value == "354元" for field in recovered.document_fields) + assert any(field.label == "车次/航班" and field.value == "G458" for field in recovered.document_fields) + assert any(field.label == "行程" and field.value == "武汉-上海" for field in recovered.document_fields) + assert "PDF 转图片失败" in recovered.warnings[0] + assert "已使用 PDF 文本层" in recovered.warnings[1] + + def test_ocr_pdf_conversion_tries_next_renderer_when_poppler_font_mapping_fails( monkeypatch, tmp_path: Path, @@ -339,6 +398,7 @@ def test_ocr_pdf_conversion_tries_next_renderer_when_poppler_font_mapping_fails( text: bool, timeout: int, check: bool, + env: dict[str, str] | None = None, ) -> subprocess.CompletedProcess[str]: calls.append(Path(command[0]).name) if Path(command[0]).name == "pdftoppm": @@ -437,6 +497,7 @@ def 
test_ocr_service_invokes_worker_even_when_pdf_text_layer_is_usable( ), ) get_settings.cache_clear() + OcrService._result_cache.clear() try: result = OcrService().recognize_files( [ @@ -444,6 +505,7 @@ def test_ocr_service_invokes_worker_even_when_pdf_text_layer_is_usable( ] ) finally: + OcrService._result_cache.clear() get_settings.cache_clear() recognized = result.documents[0] diff --git a/server/tests/test_openapi_schema.py b/server/tests/test_openapi_schema.py index 9bc987e..46b83a9 100644 --- a/server/tests/test_openapi_schema.py +++ b/server/tests/test_openapi_schema.py @@ -49,5 +49,8 @@ def test_openapi_schema_includes_documented_backend_routes() -> None: analytics_get = schema["paths"]["/api/v1/analytics/system-dashboard"]["get"] assert analytics_get["summary"] == "查询系统看板真实指标" + settings_cache_clear_post = schema["paths"]["/api/v1/settings/cache/clear"]["post"] + assert settings_cache_clear_post["summary"] == "清理系统缓存" + root_get = schema["paths"]["/"]["get"] assert root_get["summary"] == "服务根检查" diff --git a/server/tests/test_receipt_folder_service.py b/server/tests/test_receipt_folder_service.py index 62cf2e8..c515fc5 100644 --- a/server/tests/test_receipt_folder_service.py +++ b/server/tests/test_receipt_folder_service.py @@ -4,7 +4,7 @@ import base64 from app.api.deps import CurrentUserContext from app.core.config import get_settings -from app.schemas.ocr import OcrRecognizeDocumentRead, OcrRecognizeFieldRead +from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeFieldRead from app.services.document_preview import DocumentPreviewAssets from app.services.receipt_folder import ReceiptFolderService @@ -121,6 +121,53 @@ def test_receipt_folder_pdf_save_eagerly_renders_image_preview(monkeypatch, tmp_ get_settings.cache_clear() +def test_receipt_folder_persist_enriches_pdf_ocr_document_with_image_preview(monkeypatch, tmp_path) -> None: + monkeypatch.setenv("STORAGE_ROOT_DIR", str(tmp_path / "storage")) + 
get_settings.cache_clear() + try: + current_user = CurrentUserContext( + username="pytest", + name="Py Test", + role_codes=[], + is_admin=False, + ) + + def fake_render_pdf_first_page(*, pdf_path, preview_path, timeout_seconds): + preview_path.write_bytes(b"rendered-preview") + return preview_path + + monkeypatch.setattr(DocumentPreviewAssets, "render_pdf_first_page", fake_render_pdf_first_page) + + service = ReceiptFolderService() + result = service.persist_ocr_batch( + files=[("2月23_上海-武汉.pdf", b"%PDF-1.4 fake", "application/pdf")], + result=OcrRecognizeBatchRead( + total_file_count=1, + success_count=1, + documents=[ + OcrRecognizeDocumentRead( + filename="2月23_上海-武汉.pdf", + media_type="application/pdf", + text="铁路电子客票 上海虹桥 武汉 G456 354.00", + summary="铁路电子客票,上海虹桥至武汉。", + document_type="train_ticket", + document_type_label="火车/高铁票", + scene_code="travel", + scene_label="差旅票据", + ), + ], + ), + current_user=current_user, + ) + + document = result.documents[0] + assert document.receipt_id + assert document.receipt_preview_url.endswith(f"/receipt-folder/{document.receipt_id}/preview") + assert document.preview_kind == "image" + finally: + get_settings.cache_clear() + + def test_receipt_folder_pdf_preview_regenerates_stale_cached_image(monkeypatch, tmp_path) -> None: monkeypatch.setenv("STORAGE_ROOT_DIR", str(tmp_path / "storage")) get_settings.cache_clear() @@ -433,6 +480,75 @@ def test_receipt_folder_delete_removes_duplicate_marker(monkeypatch, tmp_path) - get_settings.cache_clear() +def test_receipt_folder_duplicate_uses_newer_ocr_when_existing_meta_is_weaker(monkeypatch, tmp_path) -> None: + monkeypatch.setenv("STORAGE_ROOT_DIR", str(tmp_path / "storage")) + get_settings.cache_clear() + try: + current_user = CurrentUserContext( + username="pytest", + name="Py Test", + role_codes=[], + is_admin=False, + ) + service = ReceiptFolderService() + content = b"%PDF-1.7 same train ticket" + stale_receipt = service.save_receipt( + filename="2月20_武汉-上海.pdf", + 
content=content, + media_type="application/pdf", + current_user=current_user, + document=OcrRecognizeDocumentRead( + filename="2月20_武汉-上海.pdf", + media_type="application/pdf", + document_type="other", + document_type_label="其他单据", + scene_code="other", + scene_label="其他票据", + warnings=["PDF 转图片失败:Missing language pack for Adobe-GB1"], + ), + ) + + result = service.persist_ocr_batch( + files=[("2月20_武汉-上海.pdf", content, "application/pdf")], + result=OcrRecognizeBatchRead( + total_file_count=1, + success_count=1, + documents=[ + OcrRecognizeDocumentRead( + filename="2月20_武汉-上海.pdf", + media_type="application/pdf", + text="G458 Wuhan Shanghaihongqiao 2026 02 20 07:55 票价: 354.00 12306", + summary="Wuhan Shanghaihongqiao G458 354.00", + document_type="train_ticket", + document_type_label="火车/高铁票", + scene_code="travel", + scene_label="差旅票据", + document_fields=[ + OcrRecognizeFieldRead(key="amount", label="金额", value="354元"), + OcrRecognizeFieldRead(key="trip_no", label="车次/航班", value="G458"), + OcrRecognizeFieldRead(key="route", label="行程", value="武汉-上海"), + ], + ), + ], + ), + current_user=current_user, + ) + + document = result.documents[0] + assert document.receipt_id == stale_receipt.id + assert document.document_type == "train_ticket" + assert document.document_type_label == "火车/高铁票" + assert any(field.label == "金额" and field.value == "354元" for field in document.document_fields) + assert any("重复上传" in warning for warning in document.warnings) + + repaired = service.get_receipt(stale_receipt.id, current_user) + assert repaired.document_type == "train_ticket" + assert repaired.document_type_label == "火车/高铁票" + assert {field.label: field.value for field in repaired.fields}["金额"] == "354元" + finally: + get_settings.cache_clear() + + def test_receipt_folder_recovers_train_ticket_detail_from_other_english_ocr(monkeypatch, tmp_path) -> None: monkeypatch.setenv("STORAGE_ROOT_DIR", str(tmp_path / "storage")) get_settings.cache_clear() diff --git 
a/server/tests/test_system_cache_endpoints.py b/server/tests/test_system_cache_endpoints.py
new file mode 100644
index 0000000..be05477
--- /dev/null
+++ b/server/tests/test_system_cache_endpoints.py
@@ -0,0 +1,105 @@
+from __future__ import annotations
+
+from collections.abc import Generator
+
+from fastapi.testclient import TestClient
+from sqlalchemy import create_engine
+from sqlalchemy.orm import Session, sessionmaker
+from sqlalchemy.pool import StaticPool
+
+from app.api.deps import get_db
+from app.db.base import Base
+from app.main import create_app
+from app.schemas.ocr import OcrRecognizeDocumentRead
+from app.services.ocr import OcrService
+
+
+def build_client() -> TestClient:
+    """Return a TestClient whose ``get_db`` dependency uses in-memory SQLite."""
+    # StaticPool + check_same_thread=False: every session shares the single
+    # in-memory connection, whichever thread the request handler runs on.
+    engine = create_engine(
+        "sqlite+pysqlite:///:memory:",
+        connect_args={"check_same_thread": False},
+        poolclass=StaticPool,
+    )
+    Base.metadata.create_all(bind=engine)
+    session_factory = sessionmaker(bind=engine, autoflush=False, autocommit=False)
+    app = create_app()
+
+    def override_db() -> Generator[Session, None, None]:
+        db = session_factory()
+        try:
+            yield db
+        finally:
+            db.close()
+
+    app.dependency_overrides[get_db] = override_db
+    return TestClient(app)
+
+
+def _seed_ocr_cache() -> None:
+    """Store one placeholder document in the OCR result cache."""
+    OcrService._write_cached_document(
+        "pytest-cache-key",
+        OcrRecognizeDocumentRead(
+            filename="receipt.pdf",
+            media_type="application/pdf",
+            text="旧 OCR 缓存",
+            summary="旧 OCR 缓存",
+        ),
+    )
+
+
+def test_clear_settings_cache_endpoint_clears_ocr_result_cache() -> None:
+    """An admin POST empties the OCR result cache and reports one cleared entry."""
+    OcrService.clear_result_cache()
+    try:
+        _seed_ocr_cache()
+        assert len(OcrService._result_cache) == 1
+
+        client = build_client()
+        response = client.post(
+            "/api/v1/settings/cache/clear",
+            headers={
+                "x-auth-username": "admin",
+                "x-auth-name": "Admin",
+                "x-auth-is-admin": "true",
+            },
+        )
+
+        assert response.status_code == 200
+        payload = response.json()
+        assert payload["totalCleared"] >= 1
+        assert {
+            "cacheKey": "ocr_result_cache",
+            "label": "OCR 识别结果缓存",
+            "clearedCount": 1,
+        } in payload["items"]
+        assert len(OcrService._result_cache) == 0
+    finally:
+        # The cache lives on the class, so a failed assertion must not leave
+        # the seeded entry behind for later tests.
+        OcrService.clear_result_cache()
+
+
+def test_clear_settings_cache_endpoint_requires_admin() -> None:
+    """A request without the admin header gets 403 and the cache is untouched."""
+    OcrService.clear_result_cache()
+    try:
+        _seed_ocr_cache()
+
+        client = build_client()
+        response = client.post(
+            "/api/v1/settings/cache/clear",
+            headers={
+                "x-auth-username": "ordinary-user",
+                "x-auth-name": "Ordinary User",
+            },
+        )
+
+        assert response.status_code == 403
+        assert len(OcrService._result_cache) == 1
+    finally:
+        # Cleanup runs even when an assertion above fails.
+        OcrService.clear_result_cache()