diff --git a/server/tests/test_orchestrator_service.py b/server/tests/test_orchestrator_service.py index af227ec..2109a39 100644 --- a/server/tests/test_orchestrator_service.py +++ b/server/tests/test_orchestrator_service.py @@ -1,17 +1,21 @@ from __future__ import annotations from collections.abc import Generator +from datetime import UTC, datetime, timedelta from fastapi.testclient import TestClient -from sqlalchemy import create_engine, select +from sqlalchemy import create_engine, func, select from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.pool import StaticPool from app.api.deps import get_db from app.db.base import Base from app.main import create_app +from app.models.agent_conversation import AgentConversation, AgentConversationMessage from app.models.financial_record import ExpenseClaim +from app.schemas.settings import SettingsWrite from app.services.agent_assets import AgentAssetService +from app.services.settings import SettingsService def build_client() -> tuple[TestClient, sessionmaker[Session]]: @@ -53,6 +57,7 @@ def test_orchestrator_routes_user_query_to_user_agent() -> None: assert payload["selected_agent"] == "user_agent" assert payload["permission_level"] == "read" assert payload["status"] == "succeeded" + assert payload["conversation_id"] assert payload["result"]["answer"] assert payload["result"]["suggested_actions"] assert payload["trace_summary"]["tool_count"] >= 1 @@ -160,6 +165,7 @@ def test_orchestrator_user_agent_draft_returns_structured_payload() -> None: assert payload["selected_agent"] == "user_agent" assert payload["status"] == "succeeded" assert payload["result"]["draft_payload"]["confirmation_required"] is True + assert payload["result"]["review_payload"]["slot_cards"] assert payload["result"]["draft_payload"]["claim_id"] assert payload["result"]["draft_payload"]["claim_no"].startswith("EXP-") assert payload["result"]["draft_payload"]["status"] == "draft" @@ -177,6 +183,182 @@ def test_orchestrator_user_agent_draft_returns_structured_payload() -> None: assert claim.items +def test_orchestrator_persists_conversation_and_reuses_expense_draft_context() -> None: + client, session_factory = build_client() + + first_response = client.post( + "/api/v1/orchestrator/run", + json={ + "source": "user_message", + "user_id": "pytest", + "message": "帮我生成一份差旅报销草稿,我昨天去上海出差,交通费680元", + "context_json": { + "role_codes": ["finance"], + "attachment_names": ["行程单.png"], + "attachment_count": 1, + "ocr_summary": "行程单金额680元", + }, + }, + ) + + assert first_response.status_code == 200 + first_payload = first_response.json() + conversation_id = first_payload["conversation_id"] + claim_id = first_payload["result"]["draft_payload"]["claim_id"] + assert conversation_id + assert claim_id + + second_response = client.post( + "/api/v1/orchestrator/run", + json={ + "source": "user_message", + "user_id": "pytest", + "conversation_id": conversation_id, + "message": "金额改成800元", + "context_json": { + "role_codes": ["finance"], + }, + }, + ) + + assert second_response.status_code == 200 + second_payload = second_response.json() + assert second_payload["conversation_id"] == conversation_id + assert second_payload["trace_summary"]["scenario"] == "expense" + assert second_payload["trace_summary"]["intent"] == "draft" + assert second_payload["result"]["draft_payload"]["claim_id"] == claim_id + + with session_factory() as db: + claim = db.scalar(select(ExpenseClaim).where(ExpenseClaim.id == claim_id)) + assert claim is not None + assert float(claim.amount) == 800.0 + + conversation = db.scalar( + select(AgentConversation).where(AgentConversation.conversation_id == conversation_id) + ) + assert conversation is not None + assert conversation.draft_claim_id == claim_id + assert conversation.last_scenario == "expense" + assert conversation.last_intent == "draft" + + message_count = db.scalar( + select(func.count()) + .select_from(AgentConversationMessage) + .where(AgentConversationMessage.conversation_id == conversation_id) + ) + assert message_count == 4 + + +def test_orchestrator_does_not_reuse_conversation_when_user_changes() -> None: + client, session_factory = build_client() + + first_response = client.post( + "/api/v1/orchestrator/run", + json={ + "source": "user_message", + "user_id": "user_a", + "message": "帮我生成一份差旅报销草稿,我昨天去上海出差,交通费680元", + "context_json": {"role_codes": ["finance"]}, + }, + ) + + assert first_response.status_code == 200 + first_payload = first_response.json() + first_conversation_id = first_payload["conversation_id"] + assert first_conversation_id + + second_response = client.post( + "/api/v1/orchestrator/run", + json={ + "source": "user_message", + "user_id": "user_b", + "conversation_id": first_conversation_id, + "message": "查一下本周报销金额", + "context_json": {"role_codes": ["finance"]}, + }, + ) + + assert second_response.status_code == 200 + second_payload = second_response.json() + assert second_payload["conversation_id"] + assert second_payload["conversation_id"] != first_conversation_id + + with session_factory() as db: + first_conversation = db.scalar( + select(AgentConversation).where( + AgentConversation.conversation_id == first_conversation_id + ) + ) + second_conversation = db.scalar( + select(AgentConversation).where( + AgentConversation.conversation_id == second_payload["conversation_id"] + ) + ) + assert first_conversation is not None + assert second_conversation is not None + assert first_conversation.user_id == "user_a" + assert second_conversation.user_id == "user_b" + + +def test_orchestrator_prunes_conversations_older_than_configured_retention_days() -> None: + client, session_factory = build_client() + expired_conversation_id = "conv_expired" + expired_at = datetime.now(UTC) - timedelta(days=2) + + with session_factory() as db: + settings_service = SettingsService(db) + settings_payload = settings_service.get_settings_snapshot().model_dump() + settings_payload["sessionForm"]["conversationRetentionDays"] = 1 + settings_service.save_settings_snapshot(SettingsWrite(**settings_payload)) + + conversation = AgentConversation( + conversation_id=expired_conversation_id, + user_id="expired_user", + source="user_message", + state_json={}, + message_count=1, + created_at=expired_at, + updated_at=expired_at, + ) + db.add(conversation) + db.flush() + db.add( + AgentConversationMessage( + conversation_id=expired_conversation_id, + role="user", + content="旧会话消息", + created_at=expired_at, + ) + ) + db.commit() + + response = client.post( + "/api/v1/orchestrator/run", + json={ + "source": "user_message", + "user_id": "fresh_user", + "message": "查一下本周报销金额", + "context_json": {"role_codes": ["finance"]}, + }, + ) + + assert response.status_code == 200 + + with session_factory() as db: + conversation = db.scalar( + select(AgentConversation).where( + AgentConversation.conversation_id == expired_conversation_id + ) + ) + message_count = db.scalar( + select(func.count()) + .select_from(AgentConversationMessage) + .where(AgentConversationMessage.conversation_id == expired_conversation_id) + ) + assert conversation is None + assert message_count == 0 + + def test_orchestrator_treats_expense_narrative_as_draft_instead_of_ar_query() -> None: client, _ = build_client() @@ -203,6 +385,106 @@ def test_orchestrator_treats_expense_narrative_as_draft_instead_of_ar_query() -> assert "请补充" in payload["result"]["message"] +def test_orchestrator_can_restore_latest_user_conversation() -> None: + client, _ = build_client() + + first_response = client.post( + "/api/v1/orchestrator/run", + json={ + "source": "user_message", + "user_id": "restore_user", + "message": "帮我生成一份差旅报销草稿,我昨天去上海出差,交通费680元", + "context_json": { + "role_codes": ["finance"], + "attachment_names": ["行程单.png"], + "attachment_count": 1, + "ocr_summary": "行程单金额680元", + }, + }, + ) + + assert first_response.status_code == 200 + first_payload = first_response.json() + + second_response = client.post( + "/api/v1/orchestrator/run", + json={ + "source": "user_message", + "user_id": "restore_user", + "conversation_id": first_payload["conversation_id"], + "message": "金额改成800元", + "context_json": {"role_codes": ["finance"]}, + }, + ) + + assert second_response.status_code == 200 + + restore_response = client.get( + "/api/v1/orchestrator/conversations/latest", + params={"user_id": "restore_user"}, + ) + + assert restore_response.status_code == 200 + restore_payload = restore_response.json() + assert restore_payload["found"] is True + assert restore_payload["conversation"]["conversation_id"] == first_payload["conversation_id"] + assert restore_payload["conversation"]["draft_claim_id"] == first_payload["result"]["draft_payload"]["claim_id"] + assert len(restore_payload["conversation"]["messages"]) == 4 + assert restore_payload["conversation"]["messages"][0]["role"] == "user" + assert restore_payload["conversation"]["messages"][0]["message_json"]["attachment_names"] == ["行程单.png"] + assert restore_payload["conversation"]["messages"][1]["message_json"]["orchestrator_payload"]["run_id"] + + +def test_orchestrator_can_delete_all_user_conversations() -> None: + client, session_factory = build_client() + + for message in ("查一下本周报销金额", "帮我生成差旅报销草稿"): + response = client.post( + "/api/v1/orchestrator/run", + json={ + "source": "user_message", + "user_id": "delete_user", + "message": message, + "context_json": {"role_codes": ["finance"]}, + }, + ) + assert response.status_code == 200 + + other_response = client.post( + "/api/v1/orchestrator/run", + json={ + "source": "user_message", + "user_id": "other_user", + "message": "查一下供应商待付款", + "context_json": {"role_codes": ["finance"]}, + }, + ) + assert other_response.status_code == 200 + + delete_response = client.delete( + "/api/v1/orchestrator/conversations", + params={"user_id": "delete_user"}, + ) + + assert delete_response.status_code == 200 + delete_payload = delete_response.json() + assert delete_payload["deleted_count"] == 2 + + with session_factory() as db: + deleted_count = db.scalar( + select(func.count()) + .select_from(AgentConversation) + .where(AgentConversation.user_id == "delete_user") + ) + other_count = db.scalar( + select(func.count()) + .select_from(AgentConversation) + .where(AgentConversation.user_id == "other_user") + ) + assert deleted_count == 0 + assert other_count == 1 + + def test_orchestrator_tool_failure_is_logged_and_degraded() -> None: client, _ = build_client() diff --git a/server/tests/test_settings_persistence.py b/server/tests/test_settings_persistence.py index 97bdff5..518221a 100644 --- a/server/tests/test_settings_persistence.py +++ b/server/tests/test_settings_persistence.py @@ -50,10 +50,11 @@ def test_settings_service_persists_non_secret_and_secret_fields(monkeypatch) -> payload["companyForm"]["companyName"] = "YGSOFT" payload["companyForm"]["displayName"] = "云广软件" - payload["adminForm"]["adminAccount"] = "admin-root" - payload["adminForm"]["adminEmail"] = "admin@example.com" - payload["adminForm"]["newPassword"] = "54321" - payload["adminForm"]["confirmPassword"] = "54321" + payload["adminForm"]["adminAccount"] = "admin-root" + payload["adminForm"]["adminEmail"] = "admin@example.com" + payload["adminForm"]["newPassword"] = "54321" + payload["adminForm"]["confirmPassword"] = "54321" + payload["sessionForm"]["conversationRetentionDays"] = 7 payload["llmForm"]["mainModel"] = "glm-4.5" payload["llmForm"]["mainApiKey"] = "main-secret" payload["renderForm"]["enabled"] = True @@ -63,8 +64,9 @@ def test_settings_service_persists_non_secret_and_secret_fields(monkeypatch) -> saved_snapshot = service.save_settings_snapshot(SettingsWrite(**payload)) - assert saved_snapshot.companyForm.companyName == "YGSOFT" - assert saved_snapshot.companyForm.displayName == "云广软件" + assert saved_snapshot.companyForm.companyName == "YGSOFT" + assert saved_snapshot.companyForm.displayName == "云广软件" + assert saved_snapshot.sessionForm.conversationRetentionDays == 7 assert saved_snapshot.llmForm.mainModel == "glm-4.5" assert saved_snapshot.llmForm.mainApiKey == "" assert saved_snapshot.llmForm.mainApiKeyConfigured is True @@ -84,6 +86,7 @@ def test_settings_service_persists_non_secret_and_secret_fields(monkeypatch) -> assert model_row.model_name == "glm-4.5" assert model_row.api_key_encrypted assert settings_row is not None + assert settings_row.conversation_retention_days == 7 assert settings_row.onlyoffice_enabled is True assert settings_row.onlyoffice_public_url == "http://10.10.10.122:8082" assert secrets_row is not None diff --git a/server/tests/test_user_agent_service.py b/server/tests/test_user_agent_service.py index 3c838ed..706213e 100644 --- a/server/tests/test_user_agent_service.py +++ b/server/tests/test_user_agent_service.py @@ -177,3 +177,72 @@ def test_user_agent_draft_returns_structured_payload() -> None: assert response.draft_payload is not None assert response.draft_payload.confirmation_required is True assert "待人工确认" in response.answer + + +def test_user_agent_builds_review_payload_for_multi_document_expense_flow() -> None: + session_factory = build_session_factory() + with session_factory() as db: + ontology = SemanticOntologyService(db).parse( + OntologyParseRequest( + query="我昨天去上海出差,还请客户A吃饭,帮我生成报销草稿", + user_id="pytest", + context_json={ + "attachment_names": ["机票行程单.png", "餐饮发票.jpg"], + "attachment_count": 2, + "ocr_documents": [ + { + "filename": "机票行程单.png", + "summary": "机票行程单 上海-北京 金额 680 元", + "text": "机票行程单 上海-北京 金额 680 元", + "avg_score": 0.93, + "warnings": [], + }, + { + "filename": "餐饮发票.jpg", + "summary": "餐饮发票 客户招待 金额 320 元", + "text": "餐饮发票 客户招待 金额 320 元", + "avg_score": 0.91, + "warnings": [], + }, + ], + }, + ) + ) + response = UserAgentService(db).respond( + UserAgentRequest( + run_id=ontology.run_id, + user_id="pytest", + message="我昨天去上海出差,还请客户A吃饭,帮我生成报销草稿", + ontology=ontology, + context_json={ + "name": "张三", + "attachment_names": ["机票行程单.png", "餐饮发票.jpg"], + "attachment_count": 2, + "ocr_documents": [ + { + "filename": "机票行程单.png", + "summary": "机票行程单 上海-北京 金额 680 元", + "text": "机票行程单 上海-北京 金额 680 元", + "avg_score": 0.93, + "warnings": [], + }, + { + "filename": "餐饮发票.jpg", + "summary": "餐饮发票 客户招待 金额 320 元", + "text": "餐饮发票 客户招待 金额 320 元", + "avg_score": 0.91, + "warnings": [], + }, + ], + }, + tool_payload={"draft_only": True, "claim_no": "EXP-202605-009", "status": "draft"}, + ) + ) + + assert response.review_payload is not None + assert len(response.review_payload.document_cards) == 2 + assert len(response.review_payload.claim_groups) == 2 + assert any( + item.action_type == "split_claims" + for item in response.review_payload.confirmation_actions + )