test(backend): update service tests
- test_orchestrator_service.py: update orchestrator service tests - test_settings_persistence.py: update settings persistence tests - test_user_agent_service.py: update user agent service tests
This commit is contained in:
@@ -1,17 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine, select
|
||||
from sqlalchemy import create_engine, func, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.api.deps import get_db
|
||||
from app.db.base import Base
|
||||
from app.main import create_app
|
||||
from app.models.agent_conversation import AgentConversation, AgentConversationMessage
|
||||
from app.models.financial_record import ExpenseClaim
|
||||
from app.schemas.settings import SettingsWrite
|
||||
from app.services.agent_assets import AgentAssetService
|
||||
from app.services.settings import SettingsService
|
||||
|
||||
|
||||
def build_client() -> tuple[TestClient, sessionmaker[Session]]:
|
||||
@@ -53,6 +57,7 @@ def test_orchestrator_routes_user_query_to_user_agent() -> None:
|
||||
assert payload["selected_agent"] == "user_agent"
|
||||
assert payload["permission_level"] == "read"
|
||||
assert payload["status"] == "succeeded"
|
||||
assert payload["conversation_id"]
|
||||
assert payload["result"]["answer"]
|
||||
assert payload["result"]["suggested_actions"]
|
||||
assert payload["trace_summary"]["tool_count"] >= 1
|
||||
@@ -160,6 +165,7 @@ def test_orchestrator_user_agent_draft_returns_structured_payload() -> None:
|
||||
assert payload["selected_agent"] == "user_agent"
|
||||
assert payload["status"] == "succeeded"
|
||||
assert payload["result"]["draft_payload"]["confirmation_required"] is True
|
||||
assert payload["result"]["review_payload"]["slot_cards"]
|
||||
assert payload["result"]["draft_payload"]["claim_id"]
|
||||
assert payload["result"]["draft_payload"]["claim_no"].startswith("EXP-")
|
||||
assert payload["result"]["draft_payload"]["status"] == "draft"
|
||||
@@ -177,6 +183,182 @@ def test_orchestrator_user_agent_draft_returns_structured_payload() -> None:
|
||||
assert claim.items
|
||||
|
||||
|
||||
def test_orchestrator_persists_conversation_and_reuses_expense_draft_context() -> None:
|
||||
client, session_factory = build_client()
|
||||
|
||||
first_response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "user_message",
|
||||
"user_id": "pytest",
|
||||
"message": "帮我生成一份差旅报销草稿,我昨天去上海出差,交通费680元",
|
||||
"context_json": {
|
||||
"role_codes": ["finance"],
|
||||
"attachment_names": ["行程单.png"],
|
||||
"attachment_count": 1,
|
||||
"ocr_summary": "行程单金额680元",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert first_response.status_code == 200
|
||||
first_payload = first_response.json()
|
||||
conversation_id = first_payload["conversation_id"]
|
||||
claim_id = first_payload["result"]["draft_payload"]["claim_id"]
|
||||
assert conversation_id
|
||||
assert claim_id
|
||||
|
||||
second_response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "user_message",
|
||||
"user_id": "pytest",
|
||||
"conversation_id": conversation_id,
|
||||
"message": "金额改成800元",
|
||||
"context_json": {
|
||||
"role_codes": ["finance"],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert second_response.status_code == 200
|
||||
second_payload = second_response.json()
|
||||
assert second_payload["conversation_id"] == conversation_id
|
||||
assert second_payload["trace_summary"]["scenario"] == "expense"
|
||||
assert second_payload["trace_summary"]["intent"] == "draft"
|
||||
assert second_payload["result"]["draft_payload"]["claim_id"] == claim_id
|
||||
|
||||
with session_factory() as db:
|
||||
claim = db.scalar(select(ExpenseClaim).where(ExpenseClaim.id == claim_id))
|
||||
assert claim is not None
|
||||
assert float(claim.amount) == 800.0
|
||||
|
||||
conversation = db.scalar(
|
||||
select(AgentConversation).where(AgentConversation.conversation_id == conversation_id)
|
||||
)
|
||||
assert conversation is not None
|
||||
assert conversation.draft_claim_id == claim_id
|
||||
assert conversation.last_scenario == "expense"
|
||||
assert conversation.last_intent == "draft"
|
||||
|
||||
message_count = db.scalar(
|
||||
select(func.count())
|
||||
.select_from(AgentConversationMessage)
|
||||
.where(AgentConversationMessage.conversation_id == conversation_id)
|
||||
)
|
||||
assert message_count == 4
|
||||
|
||||
|
||||
def test_orchestrator_does_not_reuse_conversation_when_user_changes() -> None:
|
||||
client, session_factory = build_client()
|
||||
|
||||
first_response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "user_message",
|
||||
"user_id": "user_a",
|
||||
"message": "帮我生成一份差旅报销草稿,我昨天去上海出差,交通费680元",
|
||||
"context_json": {"role_codes": ["finance"]},
|
||||
},
|
||||
)
|
||||
|
||||
assert first_response.status_code == 200
|
||||
first_payload = first_response.json()
|
||||
first_conversation_id = first_payload["conversation_id"]
|
||||
assert first_conversation_id
|
||||
|
||||
second_response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "user_message",
|
||||
"user_id": "user_b",
|
||||
"conversation_id": first_conversation_id,
|
||||
"message": "查一下本周报销金额",
|
||||
"context_json": {"role_codes": ["finance"]},
|
||||
},
|
||||
)
|
||||
|
||||
assert second_response.status_code == 200
|
||||
second_payload = second_response.json()
|
||||
assert second_payload["conversation_id"]
|
||||
assert second_payload["conversation_id"] != first_conversation_id
|
||||
|
||||
with session_factory() as db:
|
||||
first_conversation = db.scalar(
|
||||
select(AgentConversation).where(
|
||||
AgentConversation.conversation_id == first_conversation_id
|
||||
)
|
||||
)
|
||||
second_conversation = db.scalar(
|
||||
select(AgentConversation).where(
|
||||
AgentConversation.conversation_id == second_payload["conversation_id"]
|
||||
)
|
||||
)
|
||||
assert first_conversation is not None
|
||||
assert second_conversation is not None
|
||||
assert first_conversation.user_id == "user_a"
|
||||
assert second_conversation.user_id == "user_b"
|
||||
|
||||
|
||||
def test_orchestrator_prunes_conversations_older_than_configured_retention_days() -> None:
|
||||
client, session_factory = build_client()
|
||||
expired_conversation_id = "conv_expired"
|
||||
expired_at = datetime.now(UTC) - timedelta(days=2)
|
||||
|
||||
with session_factory() as db:
|
||||
settings_service = SettingsService(db)
|
||||
settings_payload = settings_service.get_settings_snapshot().model_dump()
|
||||
settings_payload["sessionForm"]["conversationRetentionDays"] = 1
|
||||
settings_service.save_settings_snapshot(SettingsWrite(**settings_payload))
|
||||
|
||||
conversation = AgentConversation(
|
||||
conversation_id=expired_conversation_id,
|
||||
user_id="expired_user",
|
||||
source="user_message",
|
||||
state_json={},
|
||||
message_count=1,
|
||||
created_at=expired_at,
|
||||
updated_at=expired_at,
|
||||
)
|
||||
db.add(conversation)
|
||||
db.flush()
|
||||
db.add(
|
||||
AgentConversationMessage(
|
||||
conversation_id=expired_conversation_id,
|
||||
role="user",
|
||||
content="旧会话消息",
|
||||
created_at=expired_at,
|
||||
)
|
||||
)
|
||||
db.commit()
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "user_message",
|
||||
"user_id": "fresh_user",
|
||||
"message": "查一下本周报销金额",
|
||||
"context_json": {"role_codes": ["finance"]},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
with session_factory() as db:
|
||||
conversation = db.scalar(
|
||||
select(AgentConversation).where(
|
||||
AgentConversation.conversation_id == expired_conversation_id
|
||||
)
|
||||
)
|
||||
message_count = db.scalar(
|
||||
select(func.count())
|
||||
.select_from(AgentConversationMessage)
|
||||
.where(AgentConversationMessage.conversation_id == expired_conversation_id)
|
||||
)
|
||||
assert conversation is None
|
||||
assert message_count == 0
|
||||
|
||||
|
||||
def test_orchestrator_treats_expense_narrative_as_draft_instead_of_ar_query() -> None:
|
||||
client, _ = build_client()
|
||||
|
||||
@@ -203,6 +385,106 @@ def test_orchestrator_treats_expense_narrative_as_draft_instead_of_ar_query() ->
|
||||
assert "请补充" in payload["result"]["message"]
|
||||
|
||||
|
||||
def test_orchestrator_can_restore_latest_user_conversation() -> None:
|
||||
client, _ = build_client()
|
||||
|
||||
first_response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "user_message",
|
||||
"user_id": "restore_user",
|
||||
"message": "帮我生成一份差旅报销草稿,我昨天去上海出差,交通费680元",
|
||||
"context_json": {
|
||||
"role_codes": ["finance"],
|
||||
"attachment_names": ["行程单.png"],
|
||||
"attachment_count": 1,
|
||||
"ocr_summary": "行程单金额680元",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert first_response.status_code == 200
|
||||
first_payload = first_response.json()
|
||||
|
||||
second_response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "user_message",
|
||||
"user_id": "restore_user",
|
||||
"conversation_id": first_payload["conversation_id"],
|
||||
"message": "金额改成800元",
|
||||
"context_json": {"role_codes": ["finance"]},
|
||||
},
|
||||
)
|
||||
|
||||
assert second_response.status_code == 200
|
||||
|
||||
restore_response = client.get(
|
||||
"/api/v1/orchestrator/conversations/latest",
|
||||
params={"user_id": "restore_user"},
|
||||
)
|
||||
|
||||
assert restore_response.status_code == 200
|
||||
restore_payload = restore_response.json()
|
||||
assert restore_payload["found"] is True
|
||||
assert restore_payload["conversation"]["conversation_id"] == first_payload["conversation_id"]
|
||||
assert restore_payload["conversation"]["draft_claim_id"] == first_payload["result"]["draft_payload"]["claim_id"]
|
||||
assert len(restore_payload["conversation"]["messages"]) == 4
|
||||
assert restore_payload["conversation"]["messages"][0]["role"] == "user"
|
||||
assert restore_payload["conversation"]["messages"][0]["message_json"]["attachment_names"] == ["行程单.png"]
|
||||
assert restore_payload["conversation"]["messages"][1]["message_json"]["orchestrator_payload"]["run_id"]
|
||||
|
||||
|
||||
def test_orchestrator_can_delete_all_user_conversations() -> None:
|
||||
client, session_factory = build_client()
|
||||
|
||||
for message in ("查一下本周报销金额", "帮我生成差旅报销草稿"):
|
||||
response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "user_message",
|
||||
"user_id": "delete_user",
|
||||
"message": message,
|
||||
"context_json": {"role_codes": ["finance"]},
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
other_response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "user_message",
|
||||
"user_id": "other_user",
|
||||
"message": "查一下供应商待付款",
|
||||
"context_json": {"role_codes": ["finance"]},
|
||||
},
|
||||
)
|
||||
assert other_response.status_code == 200
|
||||
|
||||
delete_response = client.delete(
|
||||
"/api/v1/orchestrator/conversations",
|
||||
params={"user_id": "delete_user"},
|
||||
)
|
||||
|
||||
assert delete_response.status_code == 200
|
||||
delete_payload = delete_response.json()
|
||||
assert delete_payload["deleted_count"] == 2
|
||||
|
||||
with session_factory() as db:
|
||||
deleted_count = db.scalar(
|
||||
select(func.count())
|
||||
.select_from(AgentConversation)
|
||||
.where(AgentConversation.user_id == "delete_user")
|
||||
)
|
||||
other_count = db.scalar(
|
||||
select(func.count())
|
||||
.select_from(AgentConversation)
|
||||
.where(AgentConversation.user_id == "other_user")
|
||||
)
|
||||
assert deleted_count == 0
|
||||
assert other_count == 1
|
||||
|
||||
|
||||
def test_orchestrator_tool_failure_is_logged_and_degraded() -> None:
|
||||
client, _ = build_client()
|
||||
|
||||
|
||||
@@ -50,10 +50,11 @@ def test_settings_service_persists_non_secret_and_secret_fields(monkeypatch) ->
|
||||
|
||||
payload["companyForm"]["companyName"] = "YGSOFT"
|
||||
payload["companyForm"]["displayName"] = "云广软件"
|
||||
payload["adminForm"]["adminAccount"] = "admin-root"
|
||||
payload["adminForm"]["adminEmail"] = "admin@example.com"
|
||||
payload["adminForm"]["newPassword"] = "54321"
|
||||
payload["adminForm"]["confirmPassword"] = "54321"
|
||||
payload["adminForm"]["adminAccount"] = "admin-root"
|
||||
payload["adminForm"]["adminEmail"] = "admin@example.com"
|
||||
payload["adminForm"]["newPassword"] = "54321"
|
||||
payload["adminForm"]["confirmPassword"] = "54321"
|
||||
payload["sessionForm"]["conversationRetentionDays"] = 7
|
||||
payload["llmForm"]["mainModel"] = "glm-4.5"
|
||||
payload["llmForm"]["mainApiKey"] = "main-secret"
|
||||
payload["renderForm"]["enabled"] = True
|
||||
@@ -63,8 +64,9 @@ def test_settings_service_persists_non_secret_and_secret_fields(monkeypatch) ->
|
||||
|
||||
saved_snapshot = service.save_settings_snapshot(SettingsWrite(**payload))
|
||||
|
||||
assert saved_snapshot.companyForm.companyName == "YGSOFT"
|
||||
assert saved_snapshot.companyForm.displayName == "云广软件"
|
||||
assert saved_snapshot.companyForm.companyName == "YGSOFT"
|
||||
assert saved_snapshot.companyForm.displayName == "云广软件"
|
||||
assert saved_snapshot.sessionForm.conversationRetentionDays == 7
|
||||
assert saved_snapshot.llmForm.mainModel == "glm-4.5"
|
||||
assert saved_snapshot.llmForm.mainApiKey == ""
|
||||
assert saved_snapshot.llmForm.mainApiKeyConfigured is True
|
||||
@@ -84,6 +86,7 @@ def test_settings_service_persists_non_secret_and_secret_fields(monkeypatch) ->
|
||||
assert model_row.model_name == "glm-4.5"
|
||||
assert model_row.api_key_encrypted
|
||||
assert settings_row is not None
|
||||
assert settings_row.conversation_retention_days == 7
|
||||
assert settings_row.onlyoffice_enabled is True
|
||||
assert settings_row.onlyoffice_public_url == "http://10.10.10.122:8082"
|
||||
assert secrets_row is not None
|
||||
|
||||
@@ -177,3 +177,72 @@ def test_user_agent_draft_returns_structured_payload() -> None:
|
||||
assert response.draft_payload is not None
|
||||
assert response.draft_payload.confirmation_required is True
|
||||
assert "待人工确认" in response.answer
|
||||
|
||||
|
||||
def test_user_agent_builds_review_payload_for_multi_document_expense_flow() -> None:
|
||||
session_factory = build_session_factory()
|
||||
with session_factory() as db:
|
||||
ontology = SemanticOntologyService(db).parse(
|
||||
OntologyParseRequest(
|
||||
query="我昨天去上海出差,还请客户A吃饭,帮我生成报销草稿",
|
||||
user_id="pytest",
|
||||
context_json={
|
||||
"attachment_names": ["机票行程单.png", "餐饮发票.jpg"],
|
||||
"attachment_count": 2,
|
||||
"ocr_documents": [
|
||||
{
|
||||
"filename": "机票行程单.png",
|
||||
"summary": "机票行程单 上海-北京 金额 680 元",
|
||||
"text": "机票行程单 上海-北京 金额 680 元",
|
||||
"avg_score": 0.93,
|
||||
"warnings": [],
|
||||
},
|
||||
{
|
||||
"filename": "餐饮发票.jpg",
|
||||
"summary": "餐饮发票 客户招待 金额 320 元",
|
||||
"text": "餐饮发票 客户招待 金额 320 元",
|
||||
"avg_score": 0.91,
|
||||
"warnings": [],
|
||||
},
|
||||
],
|
||||
},
|
||||
)
|
||||
)
|
||||
response = UserAgentService(db).respond(
|
||||
UserAgentRequest(
|
||||
run_id=ontology.run_id,
|
||||
user_id="pytest",
|
||||
message="我昨天去上海出差,还请客户A吃饭,帮我生成报销草稿",
|
||||
ontology=ontology,
|
||||
context_json={
|
||||
"name": "张三",
|
||||
"attachment_names": ["机票行程单.png", "餐饮发票.jpg"],
|
||||
"attachment_count": 2,
|
||||
"ocr_documents": [
|
||||
{
|
||||
"filename": "机票行程单.png",
|
||||
"summary": "机票行程单 上海-北京 金额 680 元",
|
||||
"text": "机票行程单 上海-北京 金额 680 元",
|
||||
"avg_score": 0.93,
|
||||
"warnings": [],
|
||||
},
|
||||
{
|
||||
"filename": "餐饮发票.jpg",
|
||||
"summary": "餐饮发票 客户招待 金额 320 元",
|
||||
"text": "餐饮发票 客户招待 金额 320 元",
|
||||
"avg_score": 0.91,
|
||||
"warnings": [],
|
||||
},
|
||||
],
|
||||
},
|
||||
tool_payload={"draft_only": True, "claim_no": "EXP-202605-009", "status": "draft"},
|
||||
)
|
||||
)
|
||||
|
||||
assert response.review_payload is not None
|
||||
assert len(response.review_payload.document_cards) == 2
|
||||
assert len(response.review_payload.claim_groups) == 2
|
||||
assert any(
|
||||
item.action_type == "split_claims"
|
||||
for item in response.review_payload.confirmation_actions
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user