# X-Financial/server/tests/test_risk_rule_revision_service.py
# (file-viewer header captured with this source: 290 lines, 12 KiB, Python)
from __future__ import annotations

import copy

import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool

from app.core.agent_enums import AgentAssetDomain, AgentAssetStatus
from app.db.base import Base
from app.models.agent_asset import AgentAsset, AgentAssetTestRun, AgentAssetVersion
from app.schemas.agent_asset import (
    AgentAssetRiskRuleDraftUpdate,
    AgentAssetRiskRuleGenerateRequest,
    AgentAssetRiskRuleRegenerateRequest,
    AgentAssetRiskRuleRevisionCreate,
)
from app.services.agent_asset_risk_rule_regeneration import AgentAssetRiskRuleRegenerationService
from app.services.agent_asset_risk_rule_revision import AgentAssetRiskRuleRevisionService
from app.services.agent_asset_rule_library import AgentAssetRuleLibraryManager
from app.services.agent_assets import AgentAssetService
from app.services.risk_rule_generation import RiskRuleGenerationService
class NullRuntimeChatService:
    """Stand-in chat backend whose ``complete`` call never yields a reply.

    Passed to the generation/regeneration services in these tests so that no
    model output is ever available; every argument is accepted and ignored.
    """

    def complete(self, *_args: object, **_kwargs: object) -> None:
        """Discard all positional and keyword arguments and return ``None``."""
def build_session() -> Session:
    """Return a new ORM session bound to a private in-memory SQLite database.

    ``StaticPool`` hands out one shared connection, so every checkout sees the
    same in-memory database. ``check_same_thread=False`` lets sqlite3 use that
    connection from any thread. The full schema from ``Base.metadata`` is
    created before the session is returned.
    """
    engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=engine)
    session_factory = sessionmaker(bind=engine, autocommit=False, autoflush=False)
    return session_factory()
def test_update_unpublished_risk_rule_draft_updates_business_fields(tmp_path) -> None:
    """Editing a never-published draft rewrites its business fields in place."""
    with build_session() as db:
        asset_id = _create_rule(db, tmp_path)
        revision_service = AgentAssetRiskRuleRevisionService(db)
        draft_changes = AgentAssetRiskRuleDraftUpdate(
            rule_title="差旅绕行说明校验",
            expense_category="travel",
            natural_language="差旅报销存在绕行但未说明原因时,进入风险复核。",
            requires_attachment=True,
        )

        updated = revision_service.update_unpublished_draft(
            asset_id,
            draft_changes,
            actor="finance",
        )

        # Title and natural-language text land on the asset's own columns.
        assert updated.name == "差旅绕行说明校验"
        assert updated.description == "差旅报销存在绕行但未说明原因时,进入风险复核。"
        # expense_category="travel" is expected to surface as the "差旅费" scenario label.
        assert updated.scenario_json == ["差旅费"]
        config = updated.config_json
        assert config["requires_attachment"] is True
        assert config["generation_request"]["natural_language"] == updated.description
        assert config["last_operation"]["action"] == "update_draft"
def test_update_published_rule_requires_revision(tmp_path) -> None:
    """A published rule cannot be edited as a draft; the update is refused.

    The asset is marked ACTIVE with its current version published, and
    ``update_unpublished_draft`` must then raise ``PermissionError``.
    """
    with build_session() as db:
        asset_id = _create_rule(db, tmp_path)
        asset = db.get(AgentAsset, asset_id)
        assert asset is not None
        # Simulate publication: the current version becomes the live one.
        asset.status = AgentAssetStatus.ACTIVE.value
        asset.published_version = asset.current_version
        db.add(asset)
        db.flush()

        # Build the service and payload outside ``pytest.raises`` so that only
        # the update call itself can satisfy the expected PermissionError.
        service = AgentAssetRiskRuleRevisionService(db)
        draft_changes = AgentAssetRiskRuleDraftUpdate(natural_language="已上线规则不能直接覆盖。")
        with pytest.raises(PermissionError):
            service.update_unpublished_draft(
                asset_id,
                draft_changes,
                actor="finance",
            )
def test_create_revision_draft_for_published_rule_does_not_overwrite_active_version(tmp_path) -> None:
    """Opening a revision on a live rule stages v0.1.1 while v0.1.0 stays published."""
    with build_session() as db:
        asset_id = _create_rule(db, tmp_path)
        asset = db.get(AgentAsset, asset_id)
        assert asset is not None
        # Mark the asset live and pin every version pointer to v0.1.0.
        asset.status = AgentAssetStatus.ACTIVE.value
        asset.published_version = "v0.1.0"
        asset.current_version = "v0.1.0"
        asset.working_version = "v0.1.0"
        db.add(asset)
        db.flush()

        revision_service = AgentAssetRiskRuleRevisionService(db)
        revision_request = AgentAssetRiskRuleRevisionCreate(
            rule_title="差旅票据城市复核",
            natural_language="票据城市与申报目的地不一致时,要求补充说明。",
            requires_attachment=True,
            change_reason="补充城市一致性判断。",
        )
        updated = revision_service.create_revision_draft(
            asset_id,
            revision_request,
            actor="manager",
        )

        config = updated.config_json
        revision = config["revision_draft"]
        # The live version is left alone...
        assert updated.status == AgentAssetStatus.ACTIVE.value
        assert updated.published_version == "v0.1.0"
        # ...while the working pointer moves to the staged revision.
        assert updated.working_version == "v0.1.1"
        assert revision["version"] == "v0.1.1"
        assert revision["base_version"] == "v0.1.0"
        assert revision["generation_request"]["natural_language"] == "票据城市与申报目的地不一致时,要求补充说明。"
        assert config["last_operation"]["action"] == "create_revision"
        # ``.one()`` raises unless exactly one v0.1.1 version row exists.
        stored_version = db.query(AgentAssetVersion).filter_by(asset_id=asset_id, version="v0.1.1").one()
        assert stored_version
def test_regenerate_unpublished_draft_updates_dsl_and_score(tmp_path) -> None:
    """Regenerating an unpublished draft refreshes its score and rule-library document."""
    with build_session() as db:
        rule_library = AgentAssetRuleLibraryManager(rule_root=tmp_path / "rules")
        asset_id = _create_rule(db, tmp_path, manager=rule_library)
        regeneration = AgentAssetRiskRuleRegenerationService(
            db,
            rule_library_manager=rule_library,
            runtime_chat_service=NullRuntimeChatService(),
        )
        regenerate_request = AgentAssetRiskRuleRegenerateRequest(
            rule_title="差旅城市一致性复核",
            natural_language="差旅报销票据城市与申报目的地不一致时,要求补充说明。",
            requires_attachment=True,
        )
        updated = regeneration.regenerate(asset_id, regenerate_request, actor="finance")

        config = updated.config_json
        assert updated.status == AgentAssetStatus.DRAFT.value
        assert config["generation_status"] == "completed"
        assert config["risk_score"] is not None
        assert config["last_operation"]["action"] == "regenerate"

        # The document written to the rule library must mirror the new state.
        document_json = rule_library.read_rule_library_json(
            library="risk-rules",
            file_name=config["rule_document"]["file_name"],
        )
        assert document_json["name"] == "差旅城市一致性复核"
        assert document_json["flow_diagram_svg"]
        assert document_json["metadata"]["risk_score"] == config["risk_score"]
def test_regenerate_revision_draft_keeps_active_document_unchanged(tmp_path) -> None:
    """Regenerating a revision writes a separate document and leaves the live one untouched.

    Checks that the asset's ``rule_document`` and its library file stay
    unchanged. The revision gets its own, disabled document file, and the
    detail view (``read_rule_json``) shows the revision document.
    """
    with build_session() as db:
        manager = AgentAssetRuleLibraryManager(rule_root=tmp_path / "rules")
        asset_id = _create_rule(db, tmp_path, manager=manager)
        asset = db.get(AgentAsset, asset_id)
        assert asset is not None
        # Snapshot (not alias) the live document. If a service mutated the
        # nested dict in place, an alias would make the "unchanged" assertion
        # below compare the object with itself and pass trivially.
        active_document = copy.deepcopy(asset.config_json["rule_document"])
        active_payload_before = manager.read_rule_library_json(
            library="risk-rules",
            file_name=active_document["file_name"],
        )
        asset.status = AgentAssetStatus.ACTIVE.value
        asset.published_version = "v0.1.0"
        asset.current_version = "v0.1.0"
        asset.working_version = "v0.1.0"
        db.add(asset)
        db.flush()
        AgentAssetRiskRuleRevisionService(db).create_revision_draft(
            asset_id,
            AgentAssetRiskRuleRevisionCreate(
                rule_title="票据城市一致性复核",
                natural_language="票据城市与申报目的地不一致时,要求补充说明。",
                requires_attachment=True,
                change_reason="补充城市一致性判断。",
            ),
            actor="manager",
        )
        updated = AgentAssetRiskRuleRegenerationService(
            db,
            rule_library_manager=manager,
            runtime_chat_service=NullRuntimeChatService(),
        ).regenerate(
            asset_id,
            AgentAssetRiskRuleRegenerateRequest(),
            actor="manager",
        )
        revision = updated.config_json["revision_draft"]
        # The live asset state is unchanged...
        assert updated.status == AgentAssetStatus.ACTIVE.value
        assert updated.published_version == "v0.1.0"
        assert updated.config_json["rule_document"] == active_document
        # ...and the regeneration result lives only on the revision draft.
        assert revision["generation_status"] == "completed"
        assert revision["risk_score"] is not None
        assert revision["rule_document"]["file_name"] != active_document["file_name"]
        active_payload_after = manager.read_rule_library_json(
            library="risk-rules",
            file_name=active_document["file_name"],
        )
        assert active_payload_after == active_payload_before
        revision_payload = manager.read_rule_library_json(
            library="risk-rules",
            file_name=revision["rule_document"]["file_name"],
        )
        assert revision_payload["rule_code"] == updated.code
        assert revision_payload["enabled"] is False
        # The detail view should surface the revision document, not the live one.
        detail_service = AgentAssetService(db)
        detail_service.rule_library_manager = manager
        displayed = detail_service.read_rule_json(asset_id)
        assert displayed.file_name == revision["rule_document"]["file_name"]
        assert displayed.payload["rule_code"] == updated.code
def test_publish_regenerated_revision_replaces_online_document(tmp_path) -> None:
    """Publishing a regenerated revision promotes its document and archives the old one.

    After ``publish_risk_rule`` the asset is on v0.1.1 and the revision draft
    is gone. The previous ``rule_document`` is recorded in ``revision_history``,
    and the promoted library file is enabled.
    """
    with build_session() as db:
        manager = AgentAssetRuleLibraryManager(rule_root=tmp_path / "rules")
        asset_id = _create_rule(db, tmp_path, manager=manager)
        asset = db.get(AgentAsset, asset_id)
        assert asset is not None
        # Snapshot (not alias) the live document. If a service rebuilt or
        # mutated the nested config dicts in place, an alias would make the
        # history assertion below compare an object with itself.
        old_document = copy.deepcopy(asset.config_json["rule_document"])
        asset.status = AgentAssetStatus.ACTIVE.value
        asset.published_version = "v0.1.0"
        asset.current_version = "v0.1.0"
        asset.working_version = "v0.1.0"
        db.add(asset)
        db.flush()
        AgentAssetRiskRuleRevisionService(db).create_revision_draft(
            asset_id,
            AgentAssetRiskRuleRevisionCreate(
                rule_title="差旅票据城市复核",
                natural_language="票据城市与申报目的地不一致时,要求补充说明。",
                requires_attachment=True,
                change_reason="补充城市一致性判断。",
            ),
            actor="manager",
        )
        regenerated = AgentAssetRiskRuleRegenerationService(
            db,
            rule_library_manager=manager,
            runtime_chat_service=NullRuntimeChatService(),
        ).regenerate(asset_id, AgentAssetRiskRuleRegenerateRequest(), actor="manager")
        # Snapshot the staged revision before publishing, for the same reason.
        revision = copy.deepcopy(regenerated.config_json["revision_draft"])
        # NOTE(review): presumably publish_risk_rule requires a passed report
        # for the staged version; confirm in AgentAssetService.
        db.add(
            AgentAssetTestRun(
                asset_id=asset_id,
                version="v0.1.1",
                test_type="report",
                status="passed",
                passed=True,
                summary="测试报告已确认。",
                input_json={},
                result_json={},
                created_by="manager",
            )
        )
        db.flush()
        service = AgentAssetService(db)
        service.rule_library_manager = manager
        published = service.publish_risk_rule(asset_id, actor="manager")
        assert published.status == AgentAssetStatus.ACTIVE.value
        assert published.current_version == "v0.1.1"
        assert published.published_version == "v0.1.1"
        assert "revision_draft" not in published.config_json
        assert published.config_json["rule_document"] == revision["rule_document"]
        assert published.config_json["revision_history"][0]["previous_rule_document"] == old_document
        assert published.config_json["last_operation"]["action"] == "publish_revision"
        manifest = manager.read_rule_library_json(
            library="risk-rules",
            file_name=published.config_json["rule_document"]["file_name"],
        )
        assert manifest["enabled"] is True
        assert manifest["rule_code"] == published.code
def _create_rule(
    db: Session,
    tmp_path,
    *,
    manager: AgentAssetRuleLibraryManager | None = None,
) -> str:
    """Generate the baseline travel-expense risk rule draft; annotated to return its asset id.

    Pass ``manager`` to share one rule library with later assertions in the
    calling test. Otherwise a new manager rooted at ``tmp_path / "rules"`` is
    built. Generation uses ``NullRuntimeChatService``, so no chat reply is
    available.
    """
    library_manager = manager or AgentAssetRuleLibraryManager(rule_root=tmp_path / "rules")
    generator = RiskRuleGenerationService(
        db,
        rule_library_manager=library_manager,
        runtime_chat_service=NullRuntimeChatService(),
    )
    baseline_request = AgentAssetRiskRuleGenerateRequest(
        business_domain=AgentAssetDomain.EXPENSE,
        expense_category="travel",
        rule_title="差旅规则草稿",
        natural_language="差旅报销事由缺失时,提示补充说明。",
    )
    return generator.generate_rule_asset(baseline_request, actor="pytest")