"""API tests for the agent-asset risk-rule draft and revision endpoints.

Each test builds a fresh application backed by its own in-memory SQLite
database, seeds a generated risk rule through the service layer, and then
drives the HTTP endpoints with headers that present the caller as ``finance``.
"""

from __future__ import annotations

from collections.abc import Generator

from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool

from app.api.deps import get_db
from app.core.agent_enums import AgentAssetDomain, AgentAssetStatus
from app.db.base import Base
from app.main import create_app
from app.models.agent_asset import AgentAsset
from app.schemas.agent_asset import AgentAssetRiskRuleGenerateRequest
from app.services.agent_asset_rule_library import AgentAssetRuleLibraryManager
from app.services.risk_rule_generation import RiskRuleGenerationService


class NullRuntimeChatService:
    """Chat-service stand-in whose ``complete`` never yields a completion."""

    def complete(self, *args, **kwargs) -> None:
        """Accept any call shape and answer ``None``.

        Presumably this keeps rule generation off any real LLM backend —
        confirm against RiskRuleGenerationService.
        """


def build_client() -> tuple[TestClient, sessionmaker[Session]]:
    """Create a TestClient wired to a private in-memory SQLite database.

    Returns:
        The client and the session factory bound to the same engine, so a test
        can seed or inspect rows that the API handlers will also see.
    """
    # StaticPool keeps exactly one connection; an in-memory SQLite database
    # lives only as long as its connection, so every session must share it.
    # check_same_thread is disabled because TestClient may serve requests from
    # a different thread than the one that opened the connection.
    memory_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    Base.metadata.create_all(bind=memory_engine)
    factory = sessionmaker(bind=memory_engine, autoflush=False, autocommit=False)

    def _session_per_request() -> Generator[Session, None, None]:
        # Leaving the ``with`` block closes the session, even if the request
        # handler raised.
        with factory() as session:
            yield session

    application = create_app()
    application.dependency_overrides[get_db] = _session_per_request
    return TestClient(application), factory


def test_update_risk_rule_draft_endpoint_updates_unpublished_rule(tmp_path) -> None:
    """PATCHing the draft of a never-published rule rewrites it in place."""
    client, factory = build_client()
    rule_id = _create_rule(factory, tmp_path)
    draft_changes = {
        "rule_title": "差旅绕行说明校验",
        "expense_category": "travel",
        "natural_language": "差旅报销存在绕行但未说明原因时,进入风险复核。",
        "requires_attachment": True,
    }

    response = client.patch(
        f"/api/v1/agent-assets/{rule_id}/risk-rules/draft",
        headers=_finance_headers(),
        json=draft_changes,
    )

    assert response.status_code == 200
    body = response.json()
    assert body["name"] == draft_changes["rule_title"]
    assert body["description"] == draft_changes["natural_language"]
    # The "travel" expense category is reported back as its scenario label.
    assert body["scenario_json"] == ["差旅费"]
    config = body["config_json"]
    assert config["requires_attachment"] is True
    assert config["generation_status"] == "draft_updated"
    assert config["last_operation"]["action"] == "update_draft"


def test_update_published_risk_rule_draft_endpoint_is_blocked(tmp_path) -> None:
    """The draft endpoint refuses to overwrite a rule that is already live."""
    client, factory = build_client()
    rule_id = _create_rule(factory, tmp_path)
    _mark_rule_published(factory, rule_id)

    response = client.patch(
        f"/api/v1/agent-assets/{rule_id}/risk-rules/draft",
        headers=_finance_headers(),
        json={"natural_language": "已上线规则不能被草稿接口直接覆盖。"},
    )

    assert response.status_code == 400
    # The rejection message is expected to mention "未上线" ("not yet published").
    assert "未上线" in response.json()["detail"]


def test_create_risk_rule_revision_endpoint_keeps_active_version(tmp_path) -> None:
    """Creating a revision stages v0.1.1 while v0.1.0 stays published."""
    client, factory = build_client()
    rule_id = _create_rule(factory, tmp_path)
    _mark_rule_published(factory, rule_id)
    revision_request = {
        "rule_title": "票据城市一致性复核",
        "natural_language": "票据城市与申报目的地不一致时,要求补充说明。",
        "requires_attachment": True,
        "change_reason": "补充城市一致性判断。",
    }

    response = client.post(
        f"/api/v1/agent-assets/{rule_id}/risk-rules/revisions",
        headers=_finance_headers(),
        json=revision_request,
    )

    assert response.status_code == 201
    body = response.json()
    config = body["config_json"]
    staged = config["revision_draft"]
    # The live release must be untouched...
    assert body["status"] == AgentAssetStatus.ACTIVE.value
    assert body["published_version"] == "v0.1.0"
    # ...while the working copy moves on to the next patch version.
    assert body["working_version"] == "v0.1.1"
    assert staged["version"] == "v0.1.1"
    assert staged["base_version"] == "v0.1.0"
    assert staged["generation_request"]["natural_language"] == revision_request["natural_language"]
    assert config["last_operation"]["action"] == "create_revision"


def _create_rule(session_factory: sessionmaker[Session], tmp_path) -> str:
    """Seed a draft travel rule through RiskRuleGenerationService.

    Args:
        session_factory: Factory bound to the test database.
        tmp_path: pytest temporary directory; the rule library is rooted at
            ``tmp_path / "rules"``.

    Returns:
        The result of ``generate_rule_asset``, which the tests use as the
        asset id (annotated ``str``).
    """
    with session_factory() as session:
        generator = RiskRuleGenerationService(
            session,
            rule_library_manager=AgentAssetRuleLibraryManager(rule_root=tmp_path / "rules"),
            runtime_chat_service=NullRuntimeChatService(),
        )
        seed_request = AgentAssetRiskRuleGenerateRequest(
            business_domain=AgentAssetDomain.EXPENSE,
            expense_category="travel",
            rule_title="差旅规则草稿",
            natural_language="差旅报销事由缺失时,提示补充说明。",
        )
        return generator.generate_rule_asset(seed_request, actor="pytest")


def _mark_rule_published(session_factory: sessionmaker[Session], asset_id: str) -> None:
    """Write a seeded rule straight to ACTIVE with every version set to v0.1.0."""
    with session_factory() as session:
        asset = session.get(AgentAsset, asset_id)
        assert asset is not None
        asset.status = AgentAssetStatus.ACTIVE.value
        # The current, published and working versions all point at one release.
        for version_field in ("current_version", "published_version", "working_version"):
            setattr(asset, version_field, "v0.1.0")
        session.add(asset)
        session.commit()


def _finance_headers() -> dict[str, str]:
    """Return auth headers that set every identity field to ``finance``."""
    identity_headers = ("x-auth-username", "x-auth-name", "x-auth-role-codes", "x-actor")
    return dict.fromkeys(identity_headers, "finance")