Files
X-Financial/server/tests/test_risk_rule_revision_endpoints.py

150 lines
5.5 KiB
Python
Raw Normal View History

from __future__ import annotations
from collections.abc import Generator
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from app.api.deps import get_db
from app.core.agent_enums import AgentAssetDomain, AgentAssetStatus
from app.db.base import Base
from app.main import create_app
from app.models.agent_asset import AgentAsset
from app.schemas.agent_asset import AgentAssetRiskRuleGenerateRequest
from app.services.agent_asset_rule_library import AgentAssetRuleLibraryManager
from app.services.risk_rule_generation import RiskRuleGenerationService
class NullRuntimeChatService:
    """Inert runtime chat service used by these tests.

    ``_create_rule`` injects it into ``RiskRuleGenerationService`` as the
    ``runtime_chat_service``, so every completion request yields ``None``.
    """

    def complete(self, *_args, **_kwargs) -> None:
        """Accept any positional/keyword arguments and return ``None``."""
        return None
def build_client() -> tuple[TestClient, sessionmaker[Session]]:
    """Create a TestClient wired to a fresh in-memory SQLite database.

    The app's ``get_db`` dependency is overridden to hand out sessions from
    the returned session factory. Tests can therefore seed and inspect rows
    through the same database the endpoints use.

    Returns:
        ``(client, session_factory)``.
    """
    # StaticPool keeps exactly one connection, so every session shares the
    # same ":memory:" database. check_same_thread=False lets that single
    # connection be used from threads other than the one that created it.
    engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=engine)
    make_session = sessionmaker(bind=engine, autocommit=False, autoflush=False)

    def _session_override() -> Generator[Session, None, None]:
        # One session per request, always closed afterwards.
        session = make_session()
        try:
            yield session
        finally:
            session.close()

    application = create_app()
    application.dependency_overrides[get_db] = _session_override
    return TestClient(application), make_session
def test_update_risk_rule_draft_endpoint_updates_unpublished_rule(tmp_path) -> None:
    """PATCHing the draft endpoint rewrites an unpublished rule's draft."""
    client, session_factory = build_client()
    asset_id = _create_rule(session_factory, tmp_path)
    draft_update = {
        "rule_title": "差旅绕行说明校验",
        "expense_category": "travel",
        "natural_language": "差旅报销存在绕行但未说明原因时,进入风险复核。",
        "requires_attachment": True,
    }

    response = client.patch(
        f"/api/v1/agent-assets/{asset_id}/risk-rules/draft",
        headers=_finance_headers(),
        json=draft_update,
    )

    assert response.status_code == 200
    body = response.json()
    assert body["name"] == draft_update["rule_title"]
    assert body["description"] == draft_update["natural_language"]
    # The "travel" category is expected to surface as this scenario label.
    assert body["scenario_json"] == ["差旅费"]
    config = body["config_json"]
    assert config["requires_attachment"] is True
    assert config["generation_status"] == "draft_updated"
    assert config["last_operation"]["action"] == "update_draft"
def test_update_published_risk_rule_draft_endpoint_is_blocked(tmp_path) -> None:
    """The draft endpoint rejects edits to a rule that is already published."""
    client, session_factory = build_client()
    asset_id = _create_rule(session_factory, tmp_path)
    _mark_rule_published(session_factory, asset_id)

    draft_url = f"/api/v1/agent-assets/{asset_id}/risk-rules/draft"
    response = client.patch(
        draft_url,
        headers=_finance_headers(),
        json={"natural_language": "已上线规则不能被草稿接口直接覆盖。"},
    )

    assert response.status_code == 400
    # The error detail is expected to mention "未上线" ("not yet published").
    error_detail = response.json()["detail"]
    assert "未上线" in error_detail
def test_create_risk_rule_revision_endpoint_keeps_active_version(tmp_path) -> None:
    """A revision of a published rule stages v0.1.1 while v0.1.0 stays published."""
    client, session_factory = build_client()
    asset_id = _create_rule(session_factory, tmp_path)
    _mark_rule_published(session_factory, asset_id)
    revision_request = {
        "rule_title": "票据城市一致性复核",
        "natural_language": "票据城市与申报目的地不一致时,要求补充说明。",
        "requires_attachment": True,
        "change_reason": "补充城市一致性判断。",
    }

    response = client.post(
        f"/api/v1/agent-assets/{asset_id}/risk-rules/revisions",
        headers=_finance_headers(),
        json=revision_request,
    )

    assert response.status_code == 201
    body = response.json()
    config = body["config_json"]
    revision = config["revision_draft"]
    # The asset itself remains active on the published version...
    assert body["status"] == AgentAssetStatus.ACTIVE.value
    assert body["published_version"] == "v0.1.0"
    # ...while the working version advances to the staged revision.
    assert body["working_version"] == "v0.1.1"
    assert revision["version"] == "v0.1.1"
    assert revision["base_version"] == "v0.1.0"
    assert (
        revision["generation_request"]["natural_language"]
        == revision_request["natural_language"]
    )
    assert config["last_operation"]["action"] == "create_revision"
def _create_rule(session_factory: sessionmaker[Session], tmp_path) -> str:
    """Generate a draft travel-expense risk rule and return its identifier.

    The rule library is rooted at ``tmp_path / "rules"``. The runtime chat
    service is ``NullRuntimeChatService``. The returned value is what the
    tests use as the asset id in URLs and ``db.get`` lookups.
    """
    with session_factory() as db:
        generator = RiskRuleGenerationService(
            db,
            rule_library_manager=AgentAssetRuleLibraryManager(rule_root=tmp_path / "rules"),
            runtime_chat_service=NullRuntimeChatService(),
        )
        request = AgentAssetRiskRuleGenerateRequest(
            business_domain=AgentAssetDomain.EXPENSE,
            expense_category="travel",
            rule_title="差旅规则草稿",
            natural_language="差旅报销事由缺失时,提示补充说明。",
        )
        return generator.generate_rule_asset(request, actor="pytest")
def _mark_rule_published(session_factory: sessionmaker[Session], asset_id: str) -> None:
    """Put an existing asset into the active/published state at v0.1.0.

    The columns are written directly through an ORM session rather than via
    the API. Tests can then start from an already-published rule.
    """
    published_fields = {
        "status": AgentAssetStatus.ACTIVE.value,
        "current_version": "v0.1.0",
        "published_version": "v0.1.0",
        "working_version": "v0.1.0",
    }
    with session_factory() as db:
        asset = db.get(AgentAsset, asset_id)
        assert asset is not None
        for field_name, field_value in published_fields.items():
            setattr(asset, field_name, field_value)
        db.add(asset)
        db.commit()
def _finance_headers() -> dict[str, str]:
return {
"x-auth-username": "finance",
"x-auth-name": "finance",
"x-auth-role-codes": "finance",
"x-actor": "finance",
}