From 68a3907920db175383aa59f920880cd3b0779b2b Mon Sep 17 00:00:00 2001 From: caoxiaozhu Date: Wed, 13 May 2026 15:33:35 +0000 Subject: [PATCH] refactor(backend): update expense claims service and tests - services/expense_claims.py: update expense claims service - tests/test_orchestrator_service.py: update orchestrator service tests --- server/src/app/services/expense_claims.py | 32 +- server/tests/test_orchestrator_service.py | 341 +++++++++++++++++++++- 2 files changed, 360 insertions(+), 13 deletions(-) diff --git a/server/src/app/services/expense_claims.py b/server/src/app/services/expense_claims.py index dfed538..1089552 100644 --- a/server/src/app/services/expense_claims.py +++ b/server/src/app/services/expense_claims.py @@ -36,7 +36,7 @@ EXPENSE_TYPE_LABELS = { "welfare": "福利", } -PRIVILEGED_CLAIM_ROLE_CODES = {"manager", "finance", "approver", "auditor", "executive"} +PRIVILEGED_CLAIM_ROLE_CODES = {"finance"} MAX_DRAFT_CLAIMS_PER_USER = 3 LOCATION_REQUIRED_EXPENSE_TYPES = { "travel", @@ -1607,8 +1607,6 @@ class ExpenseClaimService: @staticmethod def _has_privileged_claim_access(current_user: CurrentUserContext) -> bool: - if current_user.is_admin: - return True return bool(set(current_user.role_codes) & PRIVILEGED_CLAIM_ROLE_CODES) def _apply_claim_scope(self, stmt: Any, current_user: CurrentUserContext) -> Any: @@ -1617,13 +1615,31 @@ class ExpenseClaimService: conditions = [] username = str(current_user.username or "").strip() - name = str(current_user.name or "").strip() + employee = None if username: - conditions.append(ExpenseClaim.employee_id == username) - conditions.append(ExpenseClaim.employee_name == username) - if name: - conditions.append(ExpenseClaim.employee_name == name) + employee = self.db.scalar( + select(Employee) + .where(func.lower(Employee.email) == username.lower()) + .limit(1) + ) + + def add_condition(field_name: str, value: str | None) -> None: + normalized = str(value or "").strip() + if not normalized: + return + if field_name == "employee_id": + conditions.append(ExpenseClaim.employee_id == normalized) + return + conditions.append(ExpenseClaim.employee_name == normalized) + + if employee is not None: + add_condition("employee_id", employee.id) + add_condition("employee_name", employee.name) + add_condition("employee_name", employee.email) + else: + add_condition("employee_id", username) + add_condition("employee_name", username) if not conditions: return stmt.where(ExpenseClaim.id == "__no_visible_claim__") diff --git a/server/tests/test_orchestrator_service.py b/server/tests/test_orchestrator_service.py index 59f8d2d..b75e3a5 100644 --- a/server/tests/test_orchestrator_service.py +++ b/server/tests/test_orchestrator_service.py @@ -14,7 +14,11 @@ from app.db.base import Base from app.main import create_app from app.models.agent_conversation import AgentConversation, AgentConversationMessage from app.models.employee import Employee -from app.models.financial_record import ExpenseClaim +from app.models.financial_record import ( + AccountsPayableRecord, + AccountsReceivableRecord, + ExpenseClaim, +) from app.schemas.settings import SettingsWrite from app.services.agent_assets import AgentAssetService from app.services.settings import SettingsService @@ -71,6 +75,29 @@ def test_orchestrator_routes_user_query_to_user_agent() -> None: assert run_detail["tool_calls"][0]["tool_type"] == "database" +def test_orchestrator_does_not_auto_seed_demo_financial_records() -> None: + client, session_factory = build_client() + + response = client.post( + "/api/v1/orchestrator/run", + json={ + "source": "user_message", + "user_id": "pytest", + "message": "请查询我的报销单", + "context_json": {"role_codes": ["employee"], "name": "测试用户"}, + }, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["result"]["query_payload"]["record_count"] == 0 + + with session_factory() as db: + assert db.scalar(select(func.count()).select_from(ExpenseClaim)) == 0 + assert db.scalar(select(func.count()).select_from(AccountsReceivableRecord)) == 0 + assert db.scalar(select(func.count()).select_from(AccountsPayableRecord)) == 0 + + def test_orchestrator_scopes_my_expense_query_to_current_user() -> None: client, session_factory = build_client() user_id = "zhaoliu@example.com" @@ -122,6 +149,23 @@ def test_orchestrator_scopes_my_expense_query_to_current_user() -> None: ), ExpenseClaim( claim_no="EXP-TEST-003", + employee_name="赵六", + department_name="测试部", + project_code="PRJ-TEST-03", + expense_type="hotel", + reason="历史住宿报销", + location="南京", + amount=Decimal("888.00"), + currency="CNY", + invoice_count=1, + occurred_at=datetime(2026, 4, 20, 8, 30, tzinfo=UTC), + submitted_at=datetime(2026, 4, 20, 9, 30, tzinfo=UTC), + status="approved", + approval_stage="completed", + risk_flags_json=[], + ), + ExpenseClaim( + claim_no="EXP-TEST-004", employee_name="张三", department_name="财务部", project_code="PRJ-OTHER-01", @@ -150,6 +194,7 @@ def test_orchestrator_scopes_my_expense_query_to_current_user() -> None: "context_json": { "role_codes": ["employee"], "name": "赵六", + "client_now_iso": "2026-05-13T08:00:00+00:00", }, }, ) @@ -158,17 +203,228 @@ def test_orchestrator_scopes_my_expense_query_to_current_user() -> None: payload = response.json() assert payload["selected_agent"] == "user_agent" assert payload["status"] == "succeeded" - assert "查到你的报销单共 2 笔" in payload["result"]["answer"] - assert "EXP-TEST-001" in payload["result"]["answer"] - assert "EXP-TEST-002" in payload["result"]["answer"] - assert "EXP-TEST-003" not in payload["result"]["answer"] + assert "2026-05-04 至 2026-05-13的你的报销单" in payload["result"]["answer"] + assert "共 2 笔" in payload["result"]["answer"] + assert "超过 10 日的单据" in payload["result"]["answer"] + assert payload["result"]["query_payload"]["record_count"] == 2 + assert payload["result"]["query_payload"]["older_record_count"] == 1 + assert payload["result"]["query_payload"]["window_start_date"] == "2026-05-04" + assert payload["result"]["query_payload"]["window_end_date"] == "2026-05-13" + assert [item["claim_no"] for item in payload["result"]["query_payload"]["records"]] == [ + "EXP-TEST-002", + "EXP-TEST-001", + ] run_detail = client.get(f"/api/v1/agent-runs/{payload['run_id']}").json() tool_response = run_detail["tool_calls"][0]["response_json"] assert tool_response["record_count"] == 2 assert tool_response["total_amount"] == 420.0 + assert tool_response["recent_window_applied"] is True + assert tool_response["window_start_date"] == "2026-05-04" + assert tool_response["window_end_date"] == "2026-05-13" + assert tool_response["older_record_count"] == 1 assert tool_response["scoped_to_current_user"] is True assert tool_response["scope_label"] == "你的报销单" + assert [item["claim_no"] for item in tool_response["records"]] == [ + "EXP-TEST-002", + "EXP-TEST-001", + ] + + +def test_orchestrator_non_finance_cannot_query_other_users_expense_claims() -> None: + client, session_factory = build_client() + user_id = "manager1@example.com" + + with session_factory() as db: + owner = Employee( + employee_no="E9101", + name="李经理", + email=user_id, + ) + other = Employee( + employee_no="E9102", + name="王同学", + email="other@example.com", + ) + db.add_all([owner, other]) + db.flush() + db.add_all( + [ + ExpenseClaim( + claim_no="EXP-MGR-001", + employee_id=owner.id, + employee_name="李经理", + department_name="管理部", + project_code="PRJ-MGR-01", + expense_type="travel", + reason="本人出差", + location="上海", + amount=Decimal("100.00"), + currency="CNY", + invoice_count=1, + occurred_at=datetime(2026, 5, 11, 9, 0, tzinfo=UTC), + submitted_at=datetime(2026, 5, 11, 10, 0, tzinfo=UTC), + status="submitted", + approval_stage="finance_review", + risk_flags_json=[], + ), + ExpenseClaim( + claim_no="EXP-MGR-002", + employee_id=other.id, + employee_name="王同学", + department_name="销售部", + project_code="PRJ-SALES-02", + expense_type="meal", + reason="他人报销", + location="杭州", + amount=Decimal("300.00"), + currency="CNY", + invoice_count=1, + occurred_at=datetime(2026, 5, 11, 12, 0, tzinfo=UTC), + submitted_at=datetime(2026, 5, 11, 13, 0, tzinfo=UTC), + status="approved", + approval_stage="completed", + risk_flags_json=[], + ), + ] + ) + db.commit() + + response = client.post( + "/api/v1/orchestrator/run", + json={ + "source": "user_message", + "user_id": user_id, + "message": "请查询王同学的报销单", + "context_json": { + "role_codes": ["manager"], + "name": "李经理", + "client_now_iso": "2026-05-13T08:00:00+00:00", + }, + }, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["result"]["query_payload"]["record_count"] == 1 + assert [item["claim_no"] for item in payload["result"]["query_payload"]["records"]] == [ + "EXP-MGR-001", + ] + + +def test_orchestrator_finance_can_query_all_expense_claims() -> None: + client, session_factory = build_client() + + with session_factory() as db: + db.add_all( + [ + ExpenseClaim( + claim_no="EXP-FIN-001", + employee_name="甲", + department_name="A部", + project_code="PRJ-A", + expense_type="travel", + reason="A 报销", + location="上海", + amount=Decimal("120.00"), + currency="CNY", + invoice_count=1, + occurred_at=datetime(2026, 5, 11, 9, 0, tzinfo=UTC), + submitted_at=datetime(2026, 5, 11, 10, 0, tzinfo=UTC), + status="submitted", + approval_stage="finance_review", + risk_flags_json=[], + ), + ExpenseClaim( + claim_no="EXP-FIN-002", + employee_name="乙", + department_name="B部", + project_code="PRJ-B", + expense_type="meal", + reason="B 报销", + location="杭州", + amount=Decimal("300.00"), + currency="CNY", + invoice_count=1, + occurred_at=datetime(2026, 5, 11, 12, 0, tzinfo=UTC), + submitted_at=datetime(2026, 5, 11, 13, 0, tzinfo=UTC), + status="approved", + approval_stage="completed", + risk_flags_json=[], + ), + ] + ) + db.commit() + + response = client.post( + "/api/v1/orchestrator/run", + json={ + "source": "user_message", + "user_id": "finance@example.com", + "message": "请查询所有报销单", + "context_json": { + "role_codes": ["finance"], + "name": "财务", + "client_now_iso": "2026-05-13T08:00:00+00:00", + }, + }, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["result"]["query_payload"]["record_count"] == 2 + assert {item["claim_no"] for item in payload["result"]["query_payload"]["records"]} == { + "EXP-FIN-001", + "EXP-FIN-002", + } + + +def test_orchestrator_expense_query_claim_no_bypasses_recent_window() -> None: + client, session_factory = build_client() + user_id = "zhaoliu@example.com" + + with session_factory() as db: + db.add( + ExpenseClaim( + claim_no="EXP-202604-001", + employee_name="赵六", + department_name="测试部", + project_code="PRJ-OLD-01", + expense_type="travel", + reason="上月差旅", + location="北京", + amount=Decimal("560.00"), + currency="CNY", + invoice_count=1, + occurred_at=datetime(2026, 4, 1, 9, 0, tzinfo=UTC), + submitted_at=datetime(2026, 4, 1, 18, 0, tzinfo=UTC), + status="approved", + approval_stage="completed", + risk_flags_json=[], + ) + ) + db.commit() + + response = client.post( + "/api/v1/orchestrator/run", + json={ + "source": "user_message", + "user_id": user_id, + "message": "请查询报销单 EXP-202604-001", + "context_json": { + "role_codes": ["employee"], + "name": "赵六", + "client_now_iso": "2026-05-13T08:00:00+00:00", + }, + }, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["result"]["query_payload"]["recent_window_applied"] is False + assert payload["result"]["query_payload"]["record_count"] == 1 + assert payload["result"]["query_payload"]["older_record_count"] == 0 + assert payload["result"]["query_payload"]["records"][0]["claim_no"] == "EXP-202604-001" def test_orchestrator_routes_schedule_to_hermes() -> None: @@ -691,6 +947,81 @@ def test_orchestrator_can_restore_latest_user_conversation() -> None: assert restore_payload["conversation"]["messages"][1]["message_json"]["orchestrator_payload"]["run_id"] +def test_orchestrator_restores_conversation_messages_in_sequence_order() -> None: + client, session_factory = build_client() + conversation_id = "conv_restore_sequence" + created_at = datetime(2026, 5, 13, 13, 20, tzinfo=UTC) + + with session_factory() as db: + conversation = AgentConversation( + conversation_id=conversation_id, + user_id="sequence_user", + source="user_message", + state_json={"session_type": "expense"}, + message_count=4, + created_at=created_at, + updated_at=created_at, + ) + db.add(conversation) + db.flush() + db.add_all( + [ + AgentConversationMessage( + id="msg-z-assistant", + conversation_id=conversation_id, + run_id="run-a", + role="assistant", + content="第二条:助手回复", + message_json={"sequence": 2}, + created_at=created_at, + ), + AgentConversationMessage( + id="msg-b-user", + conversation_id=conversation_id, + run_id="run-b", + role="user", + content="第三条:用户追问", + message_json={"sequence": 3}, + created_at=created_at, + ), + AgentConversationMessage( + id="msg-a-user", + conversation_id=conversation_id, + run_id="run-a", + role="user", + content="第一条:用户发起", + message_json={"sequence": 1}, + created_at=created_at, + ), + AgentConversationMessage( + id="msg-c-assistant", + conversation_id=conversation_id, + run_id="run-b", + role="assistant", + content="第四条:助手总结", + message_json={"sequence": 4}, + created_at=created_at, + ), + ] + ) + db.commit() + + restore_response = client.get( + "/api/v1/orchestrator/conversations/latest", + params={"user_id": "sequence_user", "session_type": "expense"}, + ) + + assert restore_response.status_code == 200 + restore_payload = restore_response.json() + assert restore_payload["found"] is True + assert [item["content"] for item in restore_payload["conversation"]["messages"]] == [ + "第一条:用户发起", + "第二条:助手回复", + "第三条:用户追问", + "第四条:助手总结", + ] + + def test_orchestrator_can_delete_all_user_conversations() -> None: client, session_factory = build_client()