From 4db5e8ec1604b7a10299967da28ed5fc51b91db5 Mon Sep 17 00:00:00 2001 From: caoxiaozhu Date: Wed, 13 May 2026 03:39:41 +0000 Subject: [PATCH] refactor(backend): update services and tests - services/expense_claims.py: update expense claims service - services/user_agent.py: update user agent service - tests/test_orchestrator_service.py: update orchestrator service tests - tests/test_user_agent_service.py: update user agent service tests --- server/src/app/services/expense_claims.py | 112 ++++++++++++++-- server/src/app/services/user_agent.py | 7 + server/tests/test_orchestrator_service.py | 148 ++++++++++++++++++++++ server/tests/test_user_agent_service.py | 30 +++++ 4 files changed, 288 insertions(+), 9 deletions(-) diff --git a/server/src/app/services/expense_claims.py b/server/src/app/services/expense_claims.py index 75e7b30..0a383d4 100644 --- a/server/src/app/services/expense_claims.py +++ b/server/src/app/services/expense_claims.py @@ -25,6 +25,7 @@ EXPENSE_TYPE_LABELS = { } PRIVILEGED_CLAIM_ROLE_CODES = {"manager", "finance", "approver", "auditor", "executive"} +MAX_DRAFT_CLAIMS_PER_USER = 3 class ExpenseClaimService: @@ -117,7 +118,7 @@ class ExpenseClaimService: before_json = self._serialize_claim(claim) claim.status = "submitted" - claim.approval_stage = "审批流转" + claim.approval_stage = "AI验审" claim.submitted_at = datetime.now(UTC) self.db.commit() @@ -172,7 +173,39 @@ class ExpenseClaimService: is_new_claim = claim is None before_json = self._serialize_claim(claim) if claim is not None else None - employee = self._resolve_employee(ontology=ontology, context_json=context_json) + employee = self._resolve_employee( + ontology=ontology, + context_json=context_json, + user_id=user_id, + ) + draft_owner_name = ( + employee.name + if employee is not None + else self._resolve_employee_name( + ontology=ontology, + context_json=context_json, + user_id=user_id, + ) + ) + if is_new_claim: + existing_draft_count = self._count_draft_claims_for_owner( + employee=employee, + employee_name=draft_owner_name, + user_id=user_id, + ) + if existing_draft_count >= MAX_DRAFT_CLAIMS_PER_USER: + return { + "message": ( + f"你当前已保存 {MAX_DRAFT_CLAIMS_PER_USER} 个草稿,请先完成已保存的草稿," + "才能再次新建草稿。" + ), + "draft_limit_reached": True, + "draft_only": False, + "status": "blocked", + "draft_count": existing_draft_count, + "max_draft_count": MAX_DRAFT_CLAIMS_PER_USER, + } + amount = self._resolve_amount(ontology.entities, context_json=context_json) occurred_at = self._resolve_occurred_at(ontology, context_json=context_json) expense_type = self._resolve_expense_type(ontology.entities, context_json=context_json) @@ -202,11 +235,7 @@ class ExpenseClaimService: claim = ExpenseClaim( claim_no=self._generate_claim_no(final_occurred_at), employee_id=employee.id if employee is not None else None, - employee_name=employee.name if employee is not None else self._resolve_employee_name( - ontology=ontology, - context_json=context_json, - user_id=user_id, - ), + employee_name=draft_owner_name, department_id=employee.organization_unit_id if employee is not None else None, department_name=self._resolve_department_name( employee=employee, @@ -221,7 +250,7 @@ class ExpenseClaimService: invoice_count=final_attachment_count, occurred_at=final_occurred_at, status="draft", - approval_stage="待补充", + approval_stage="待提交", risk_flags_json=final_risk_flags, ) self.db.add(claim) @@ -251,7 +280,7 @@ class ExpenseClaimService: claim.invoice_count = final_attachment_count claim.occurred_at = final_occurred_at claim.status = "draft" - claim.approval_stage = "待补充" + claim.approval_stage = "待提交" claim.risk_flags_json = final_risk_flags self.db.flush() @@ -355,12 +384,77 @@ class ExpenseClaimService: ) return f"{prefix}{existing + 1:03d}" + def _count_draft_claims_for_owner( + self, + *, + employee: Employee | None, + employee_name: str, + user_id: str | None, + ) -> int: + owner_filters = self._build_draft_owner_filters( + employee=employee, + employee_name=employee_name, + user_id=user_id, + ) + if not owner_filters: + return 0 + + stmt = ( + select(func.count()) + .select_from(ExpenseClaim) + .where(ExpenseClaim.status == "draft") + .where(or_(*owner_filters)) + ) + return int(self.db.scalar(stmt) or 0) + + @staticmethod + def _build_draft_owner_filters( + *, + employee: Employee | None, + employee_name: str, + user_id: str | None, + ) -> list[Any]: + conditions: list[Any] = [] + seen: set[tuple[str, str]] = set() + + def add_condition(field_name: str, value: str | None) -> None: + normalized = str(value or "").strip() + if not normalized or normalized == "待补充": + return + + marker = (field_name, normalized.lower()) + if marker in seen: + return + seen.add(marker) + + if field_name == "employee_id": + conditions.append(ExpenseClaim.employee_id == normalized) + return + conditions.append(ExpenseClaim.employee_name == normalized) + + if employee is not None: + add_condition("employee_id", employee.id) + add_condition("employee_name", employee.name) + add_condition("employee_name", employee.email) + + add_condition("employee_name", employee_name) + add_condition("employee_name", user_id) + return conditions + def _resolve_employee( self, *, ontology: OntologyParseResult, context_json: dict[str, Any], + user_id: str | None, ) -> Employee | None: + normalized_user_id = str(user_id or "").strip() + if normalized_user_id: + stmt = select(Employee).where(func.lower(Employee.email) == normalized_user_id.lower()).limit(1) + employee = self.db.scalar(stmt) + if employee is not None: + return employee + employee_name = self._resolve_employee_name( ontology=ontology, context_json=context_json, diff --git a/server/src/app/services/user_agent.py b/server/src/app/services/user_agent.py index 4f3a5e2..ae4f303 100644 --- a/server/src/app/services/user_agent.py +++ b/server/src/app/services/user_agent.py @@ -207,6 +207,8 @@ class UserAgentService: if payload.ontology.intent == "draft": tool_message = str(payload.tool_payload.get("message") or "").strip() + if payload.tool_payload.get("draft_limit_reached"): + return tool_message or "你当前已保存 3 个草稿,请先完成已保存的草稿,才能再次新建草稿。" if tool_message and ( str(payload.tool_payload.get("claim_id") or "").strip() or str(payload.tool_payload.get("claim_no") or "").strip() @@ -988,6 +990,11 @@ class UserAgentService: return None if payload.ontology.intent not in {"draft", "operate"}: return None + if payload.tool_payload.get("draft_limit_reached"): + return ( + str(payload.tool_payload.get("message") or "").strip() + or "你当前已保存 3 个草稿,请先完成已保存的草稿,才能再次新建草稿。" + ) review_action = str(payload.context_json.get("review_action") or "").strip() if review_action == "save_draft": diff --git a/server/tests/test_orchestrator_service.py b/server/tests/test_orchestrator_service.py index bb5e9f6..c035798 100644 --- a/server/tests/test_orchestrator_service.py +++ b/server/tests/test_orchestrator_service.py @@ -12,6 +12,7 @@ from app.api.deps import get_db from app.db.base import Base from app.main import create_app from app.models.agent_conversation import AgentConversation, AgentConversationMessage +from app.models.employee import Employee from app.models.financial_record import ExpenseClaim from app.schemas.settings import SettingsWrite from app.services.agent_assets import AgentAssetService @@ -183,6 +184,153 @@ def test_orchestrator_user_agent_draft_returns_structured_payload() -> None: assert claim.items +def test_orchestrator_blocks_fourth_expense_draft_for_same_user() -> None: + client, session_factory = build_client() + user_id = "zhangsan@example.com" + + with session_factory() as db: + db.add( + Employee( + employee_no="E1001", + name="张三", + email=user_id, + ) + ) + db.commit() + + for amount, city in ((120, "上海"), (240, "北京"), (360, "深圳")): + response = client.post( + "/api/v1/orchestrator/run", + json={ + "source": "user_message", + "user_id": user_id, + "message": f"帮我生成报销草稿,我昨天去{city}出差,交通费{amount}元", + "context_json": { + "role_codes": ["finance"], + "name": "张三", + }, + }, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["result"]["draft_payload"]["claim_no"].startswith("EXP-") + + blocked_response = client.post( + "/api/v1/orchestrator/run", + json={ + "source": "user_message", + "user_id": user_id, + "message": "帮我生成报销草稿,我昨天去杭州出差,交通费480元", + "context_json": { + "role_codes": ["finance"], + "name": "张三", + "review_action": "save_draft", + }, + }, + ) + + assert blocked_response.status_code == 200 + blocked_payload = blocked_response.json() + assert blocked_payload["status"] == "succeeded" + assert "你当前已保存 3 个草稿" in blocked_payload["result"]["answer"] + assert blocked_payload["result"]["draft_payload"]["claim_id"] is None + assert blocked_payload["result"]["draft_payload"]["claim_no"] is None + assert blocked_payload["result"]["draft_payload"]["status"] == "blocked" + + with session_factory() as db: + draft_count = db.scalar( + select(func.count()) + .select_from(ExpenseClaim) + .where(ExpenseClaim.status == "draft") + ) + assert draft_count == 3 + + +def test_orchestrator_allows_existing_draft_update_when_user_already_has_three_drafts() -> None: + client, session_factory = build_client() + user_id = "lisi@example.com" + + with session_factory() as db: + db.add( + Employee( + employee_no="E1002", + name="李四", + email=user_id, + ) + ) + db.commit() + + first_response = client.post( + "/api/v1/orchestrator/run", + json={ + "source": "user_message", + "user_id": user_id, + "message": "帮我生成报销草稿,我昨天去上海出差,交通费120元", + "context_json": { + "role_codes": ["finance"], + "name": "李四", + }, + }, + ) + + assert first_response.status_code == 200 + first_payload = first_response.json() + claim_id = first_payload["result"]["draft_payload"]["claim_id"] + conversation_id = first_payload["conversation_id"] + assert claim_id + assert conversation_id + + for amount, city in ((240, "北京"), (360, "深圳")): + response = client.post( + "/api/v1/orchestrator/run", + json={ + "source": "user_message", + "user_id": user_id, + "message": f"帮我生成报销草稿,我昨天去{city}出差,交通费{amount}元", + "context_json": { + "role_codes": ["finance"], + "name": "李四", + }, + }, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["result"]["draft_payload"]["claim_no"].startswith("EXP-") + + update_response = client.post( + "/api/v1/orchestrator/run", + json={ + "source": "user_message", + "user_id": user_id, + "conversation_id": conversation_id, + "message": "金额改成888元", + "context_json": { + "role_codes": ["finance"], + "name": "李四", + }, + }, + ) + + assert update_response.status_code == 200 + update_payload = update_response.json() + assert update_payload["result"]["draft_payload"]["claim_id"] == claim_id + + with session_factory() as db: + claim = db.scalar(select(ExpenseClaim).where(ExpenseClaim.id == claim_id)) + assert claim is not None + assert float(claim.amount) == 888.0 + + draft_count = db.scalar( + select(func.count()) + .select_from(ExpenseClaim) + .where(ExpenseClaim.employee_id == claim.employee_id) + .where(ExpenseClaim.status == "draft") + ) + assert draft_count == 3 + + def test_orchestrator_persists_conversation_and_reuses_expense_draft_context() -> None: client, session_factory = build_client() diff --git a/server/tests/test_user_agent_service.py b/server/tests/test_user_agent_service.py index 2a762e0..b398e60 100644 --- a/server/tests/test_user_agent_service.py +++ b/server/tests/test_user_agent_service.py @@ -246,6 +246,36 @@ def test_user_agent_draft_returns_structured_payload() -> None: assert response.answer == response.review_payload.body_message +def test_user_agent_returns_draft_limit_message_when_save_is_blocked() -> None: + session_factory = build_session_factory() + with session_factory() as db: + ontology = SemanticOntologyService(db).parse( + OntologyParseRequest( + query="请按当前识别信息保存报销草稿", + user_id="pytest", + ) + ) + response = UserAgentService(db).respond( + UserAgentRequest( + run_id=ontology.run_id, + user_id="pytest", + message="请按当前识别信息保存报销草稿", + ontology=ontology, + context_json={"review_action": "save_draft"}, + tool_payload={ + "draft_limit_reached": True, + "message": "你当前已保存 3 个草稿,请先完成已保存的草稿,才能再次新建草稿。", + "status": "blocked", + }, + ) + ) + + assert ( + response.answer + == "你当前已保存 3 个草稿,请先完成已保存的草稿,才能再次新建草稿。" + ) + + def test_user_agent_builds_review_payload_for_multi_document_expense_flow() -> None: session_factory = build_session_factory() with session_factory() as db: