from __future__ import annotations import json from sqlalchemy import create_engine from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.pool import StaticPool from app.core.agent_enums import AgentAssetDomain, AgentAssetStatus from app.db.base import Base from app.models.agent_asset import AgentAsset from app.schemas.agent_asset import AgentAssetRiskRuleGenerateRequest from app.services.agent_asset_rule_library import AgentAssetRuleLibraryManager from app.services.agent_asset_spreadsheet import RISK_RULES_LIBRARY from app.services.risk_rule_flow_diagram import RiskRuleFlowDiagramRenderer, RiskRuleFlowDiagramSpec from app.services.risk_rule_generation import RiskRuleGenerationService class NullRuntimeChatService: def complete(self, *args, **kwargs) -> None: return None def build_session() -> Session: engine = create_engine( "sqlite+pysqlite:///:memory:", connect_args={"check_same_thread": False}, poolclass=StaticPool, ) Base.metadata.create_all(bind=engine) session_factory = sessionmaker(bind=engine, autoflush=False, autocommit=False) return session_factory() def test_generate_risk_rule_asset_creates_draft_json_rule(tmp_path) -> None: with build_session() as db: service = RiskRuleGenerationService( db, rule_library_manager=AgentAssetRuleLibraryManager(rule_root=tmp_path / "rules"), runtime_chat_service=NullRuntimeChatService(), ) asset_id = service.generate_rule_asset( AgentAssetRiskRuleGenerateRequest( business_domain=AgentAssetDomain.EXPENSE, risk_level="high", natural_language="住宿城市必须出现在本次差旅行程城市中,否则提示高风险。", ), actor="pytest", ) asset = db.get(AgentAsset, asset_id) assert asset is not None assert asset.status == AgentAssetStatus.DRAFT.value assert asset.config_json["detail_mode"] == "json_risk" assert asset.config_json["evaluator"] == "template_rule" assert asset.current_version == "v0.1.0" file_name = asset.config_json["rule_document"]["file_name"] rule_path = tmp_path / "rules" / RISK_RULES_LIBRARY / file_name payload = json.loads(rule_path.read_text(encoding="utf-8")) assert payload["rule_code"] == asset.code assert payload["outcomes"]["fail"]["severity"] == "high" assert payload["template_key"] == "field_compare_v1" assert payload["metadata"]["natural_language"].startswith("住宿城市") assert payload["inputs"]["fields"] assert payload["flow_diagram_svg"].startswith(" None: renderer = RiskRuleFlowDiagramRenderer() def render(severity: str, label: str) -> str: return renderer.render( RiskRuleFlowDiagramSpec( title="测试规则", domain_label="报销", severity=severity, severity_label=label, fields=(), start="业务单据提交", evidence="读取规则字段", decision="判断是否命中风险", basis="根据规则字段判断", pass_text="未命中风险,继续流转", fail_text=f"命中{label},进入复核", ) ) assert "#2563eb" in render("low", "低风险") assert "#f97316" in render("medium", "中风险") high_svg = render("high", "高风险") assert "#dc2626" in high_svg assert high_svg.count("#dc2626") == 1 assert "#10a37f" not in high_svg