from __future__ import annotations

from collections.abc import Generator

from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool

from app.api.deps import get_db
from app.core.agent_enums import AgentAssetDomain, AgentAssetStatus
from app.db.base import Base
from app.main import create_app
from app.models.agent_asset import AgentAsset
from app.schemas.agent_asset import AgentAssetRiskRuleGenerateRequest
from app.services.agent_asset_rule_library import AgentAssetRuleLibraryManager
from app.services.agent_asset_risk_rule_regeneration import AgentAssetRiskRuleRegenerationService
from app.services.agent_assets import AgentAssetService
from app.services.risk_rule_generation import RiskRuleGenerationService

# Header names that the request helpers below fill with one and the same value.
_IDENTITY_HEADER_NAMES = (
    "x-auth-username",
    "x-auth-name",
    "x-auth-role-codes",
    "x-actor",
)


class NullRuntimeChatService:
    """Chat-service stub whose ``complete`` accepts anything and returns ``None``."""

    def complete(self, *args, **kwargs) -> None:
        return None


def build_client() -> tuple[TestClient, sessionmaker[Session]]:
    """Create a test client for a fresh app backed by in-memory SQLite.

    Returns:
        A ``(client, session_factory)`` pair. The app's ``get_db`` dependency
        is overridden to hand out sessions from ``session_factory``, so tests
        can use the same factory to seed or modify rows directly.
    """
    memory_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=memory_engine)
    make_session = sessionmaker(bind=memory_engine, autoflush=False, autocommit=False)

    def override_db() -> Generator[Session, None, None]:
        # Leaving the ``with`` block closes the session, whether the request
        # handler finished normally or raised.
        with make_session() as session:
            yield session

    application = create_app()
    application.dependency_overrides[get_db] = override_db
    return TestClient(application), make_session


def test_update_risk_rule_draft_endpoint_updates_unpublished_rule(tmp_path) -> None:
    """PATCHing the draft of a freshly generated rule returns the updated asset."""
    client, session_factory = build_client()
    asset_id = _create_rule(session_factory, tmp_path)
    new_title = "差旅绕行说明校验"
    new_text = "差旅报销存在绕行但未说明原因时,进入风险复核。"

    response = client.patch(
        f"/api/v1/agent-assets/{asset_id}/risk-rules/draft",
        headers=_finance_headers(),
        json={
            "rule_title": new_title,
            "expense_category": "travel",
            "natural_language": new_text,
            "requires_attachment": True,
        },
    )

    assert response.status_code == 200
    updated = response.json()
    assert updated["name"] == new_title
    assert updated["description"] == new_text
    assert updated["scenario_json"] == ["差旅费"]
    config = updated["config_json"]
    assert config["requires_attachment"] is True
    assert config["generation_status"] == "draft_updated"
    assert config["last_operation"]["action"] == "update_draft"


def test_update_published_risk_rule_draft_endpoint_is_blocked(tmp_path) -> None:
    """The draft endpoint answers 400 once the rule has been marked published."""
    client, session_factory = build_client()
    asset_id = _create_rule(session_factory, tmp_path)
    _mark_rule_published(session_factory, asset_id)

    response = client.patch(
        f"/api/v1/agent-assets/{asset_id}/risk-rules/draft",
        headers=_finance_headers(),
        json={"natural_language": "已上线规则不能被草稿接口直接覆盖。"},
    )

    assert response.status_code == 400
    error_detail = response.json()["detail"]
    assert "未上线" in error_detail


def test_create_risk_rule_revision_endpoint_keeps_active_version(tmp_path) -> None:
    """Creating a revision leaves the published version active and adds a draft."""
    client, session_factory = build_client()
    asset_id = _create_rule(session_factory, tmp_path)
    _mark_rule_published(session_factory, asset_id)
    revision_text = "票据城市与申报目的地不一致时,要求补充说明。"

    response = client.post(
        f"/api/v1/agent-assets/{asset_id}/risk-rules/revisions",
        headers=_finance_headers(),
        json={
            "rule_title": "票据城市一致性复核",
            "natural_language": revision_text,
            "requires_attachment": True,
            "change_reason": "补充城市一致性判断。",
        },
    )

    assert response.status_code == 201
    created = response.json()
    revision = created["config_json"]["revision_draft"]
    # The asset itself stays on the published version ...
    assert created["status"] == AgentAssetStatus.ACTIVE.value
    assert created["published_version"] == "v0.1.0"
    assert created["working_version"] == "v0.1.1"
    # ... while the revision draft records the new version and its base.
    assert revision["version"] == "v0.1.1"
    assert revision["base_version"] == "v0.1.0"
    assert revision["generation_request"]["natural_language"] == revision_text
    assert created["config_json"]["last_operation"]["action"] == "create_revision"


def test_regenerate_risk_rule_endpoint_returns_updated_detail(tmp_path, monkeypatch) -> None:
    """The regenerate endpoint returns the asset produced by the (patched) service."""
    client, session_factory = build_client()
    asset_id = _create_rule(session_factory, tmp_path)

    def fake_regenerate(self, target_asset_id, body, *, actor, request_id=None):
        # Replacement for the real regeneration: only stamps the config.
        del body, request_id
        asset = self.db.get(AgentAsset, target_asset_id)
        assert asset is not None
        stamped = dict(asset.config_json or {})
        stamped.update(
            generation_status="completed",
            last_operation={
                "action": "regenerate",
                "actor": actor,
                "at": "2026-05-30T00:00:00+00:00",
            },
        )
        asset.config_json = stamped
        self.db.add(asset)
        self.db.flush()
        return asset

    monkeypatch.setattr(AgentAssetRiskRuleRegenerationService, "regenerate", fake_regenerate)

    response = client.post(
        f"/api/v1/agent-assets/{asset_id}/risk-rules/regenerate",
        headers=_finance_headers(),
        json={"natural_language": "差旅票据城市与申报目的地不一致时要求补充说明。"},
    )

    assert response.status_code == 200
    regenerated = response.json()
    assert regenerated["config_json"]["generation_status"] == "completed"
    assert regenerated["config_json"]["last_operation"]["action"] == "regenerate"


def test_risk_rule_admin_only_actions_block_non_admin_users(tmp_path) -> None:
    """Generate/simulate with finance headers and delete with manager headers get 403."""
    client, session_factory = build_client()
    asset_id = _create_rule(session_factory, tmp_path)

    generate_response = client.post(
        "/api/v1/agent-assets/risk-rules/generate",
        headers=_finance_headers(),
        json={
            "business_domain": "expense",
            "expense_category": "travel",
            "rule_title": "普通财务新建规则",
            "natural_language": "差旅票据城市与申报目的地不一致时提示风险。",
        },
    )
    assert generate_response.status_code == 403

    simulate_response = client.post(
        f"/api/v1/agent-assets/{asset_id}/risk-rule-tests/simulate",
        headers=_finance_headers(),
        json={"message": "测试一张差旅票据。"},
    )
    assert simulate_response.status_code == 403

    delete_response = client.delete(
        f"/api/v1/agent-assets/{asset_id}",
        headers=_manager_headers(),
    )
    assert delete_response.status_code == 403


def test_manager_can_toggle_risk_rule_enabled_endpoint(tmp_path, monkeypatch) -> None:
    """A request with manager headers can reach the risk-rule-enabled endpoint."""
    client, session_factory = build_client()
    asset_id = _create_rule(session_factory, tmp_path)

    def fake_toggle(self, target_asset_id, *, enabled, actor, request_id=None):
        # Replacement for the real toggle: records the flag and the actor.
        del request_id
        asset = self.db.get(AgentAsset, target_asset_id)
        assert asset is not None
        toggled = dict(asset.config_json or {})
        toggled.update(
            enabled=bool(enabled),
            last_operation={"action": "offline", "actor": actor},
        )
        asset.config_json = toggled
        self.db.add(asset)
        self.db.flush()
        return asset

    monkeypatch.setattr(AgentAssetService, "set_risk_rule_enabled", fake_toggle)

    response = client.post(
        f"/api/v1/agent-assets/{asset_id}/risk-rule-enabled",
        headers=_manager_headers(),
        json={"enabled": False},
    )

    assert response.status_code == 200
    assert response.json()["config_json"]["enabled"] is False
    assert response.json()["config_json"]["last_operation"]["actor"] == "manager"


def _create_rule(session_factory: sessionmaker[Session], tmp_path) -> str:
    """Generate a travel risk rule and return the result of ``generate_rule_asset``.

    The rule library is rooted at ``tmp_path / "rules"`` and the chat service
    is replaced by ``NullRuntimeChatService``. Callers interpolate the return
    value into URLs as the asset id.
    """
    with session_factory() as db:
        rule_library = AgentAssetRuleLibraryManager(rule_root=tmp_path / "rules")
        generation_service = RiskRuleGenerationService(
            db,
            rule_library_manager=rule_library,
            runtime_chat_service=NullRuntimeChatService(),
        )
        generation_request = AgentAssetRiskRuleGenerateRequest(
            business_domain=AgentAssetDomain.EXPENSE,
            expense_category="travel",
            rule_title="差旅规则草稿",
            natural_language="差旅报销事由缺失时,提示补充说明。",
        )
        return generation_service.generate_rule_asset(generation_request, actor="pytest")


def _mark_rule_published(session_factory: sessionmaker[Session], asset_id: str) -> None:
    """Set the asset to ACTIVE with every version field at ``v0.1.0`` and commit."""
    with session_factory() as db:
        asset = db.get(AgentAsset, asset_id)
        assert asset is not None
        asset.status = AgentAssetStatus.ACTIVE.value
        for version_field in ("current_version", "published_version", "working_version"):
            setattr(asset, version_field, "v0.1.0")
        db.add(asset)
        db.commit()


def _identity_headers(identity: str) -> dict[str, str]:
    """Return the identity headers with ``identity`` as the value of each one."""
    return dict.fromkeys(_IDENTITY_HEADER_NAMES, identity)


def _finance_headers() -> dict[str, str]:
    """Request headers identifying the caller as ``finance``."""
    return _identity_headers("finance")


def _manager_headers() -> dict[str, str]:
    """Request headers identifying the caller as ``manager``."""
    return _identity_headers("manager")