Files
X-Financial/server/tests/test_risk_rule_semantic_plan_generation.py

196 lines
8.1 KiB
Python
Raw Permalink Normal View History

from __future__ import annotations
import json
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from app.core.agent_enums import AgentAssetDomain
from app.db.base import Base
from app.models.agent_asset import AgentAsset
from app.schemas.agent_asset import AgentAssetRiskRuleGenerateRequest
from app.services.agent_asset_rule_library import AgentAssetRuleLibraryManager
from app.services.agent_asset_spreadsheet import RISK_RULES_LIBRARY
from app.services.risk_rule_generation import RiskRuleGenerationService
from app.services.risk_rule_generation_prompt import build_risk_rule_compiler_messages
from app.services.risk_rule_generation_semantic_plan import unwrap_semantic_plan_payload
class SemanticPlanEnvelopeRuntimeChatService:
    """Fake runtime chat service whose reply wraps a semantic plan *and* a DSL.

    ``complete`` ignores whatever it is called with and always returns the
    same JSON object with two top-level keys: ``semantic_plan`` (intent plus
    judgment steps) and ``dsl`` (a ``composite_rule_v1`` rule with a single
    ``numeric_compare`` condition of ``claim.amount`` > ``budget.remaining_amount``).
    """

    def complete(self, *args, **kwargs) -> str:
        """Return the canned envelope reply as a JSON string (non-ASCII kept verbatim)."""
        semantic_plan = {
            "rule_intent": "费用申请金额不得超过可用预算",
            "judgment_steps": [
                "读取申请金额",
                "读取可用预算余额",
                "比较申请金额是否大于可用预算",
            ],
        }
        exceeds_budget = {
            "id": "amount_exceeds_budget",
            "operator": "numeric_compare",
            "left_fields": ["claim.amount"],
            "right_fields": ["budget.remaining_amount"],
            "compare": "gt",
        }
        # Key insertion order below is what json.dumps emits, so it mirrors the
        # order the reply has always been serialized in.
        dsl = {
            "name": "预算余额超额校验",
            "description": "申请金额超过当前可用预算时提示风险。",
            "template_key": "composite_rule_v1",
            "semantic_type": "budget_available_balance_check",
            "field_keys": ["claim.amount", "budget.remaining_amount"],
            "condition_summary": "claim.amount > budget.remaining_amount",
            "conditions": [exceeds_budget],
            # hit_logic references the condition by its id.
            "hit_logic": {"all": [exceeds_budget["id"]]},
            "message_template": "申请金额超过当前可用预算余额。",
            "keywords": [],
        }
        envelope = {"semantic_plan": semantic_plan, "dsl": dsl}
        return json.dumps(envelope, ensure_ascii=False)
class SemanticPlanOnlyRuntimeChatService:
    """Fake runtime chat service whose reply contains only a ``semantic_plan``.

    There is no ``dsl`` key in the reply, so it exercises the path where a DSL
    has to be derived from the plan's ``required_fields`` and
    ``judgment_steps`` alone. Call arguments are ignored.
    """

    def complete(self, *args, **kwargs) -> str:
        """Return the canned plan-only reply as a JSON string (non-ASCII kept verbatim)."""
        required_fields = [
            {"field": "claim.amount", "label": "申请金额"},
            {"field": "budget.remaining_amount", "label": "可用预算余额"},
        ]
        judgment_steps = [
            "读取申请金额 claim.amount",
            "读取可用预算余额 budget.remaining_amount",
            "若申请金额超过可用预算余额则命中预算风险",
        ]
        # Insertion order fixes the serialized key order.
        plan = {
            "rule_intent": "费用申请金额超过可用预算余额时提示风险",
            "required_fields": required_fields,
            "judgment_steps": judgment_steps,
            "risk_action": {"message": "要求补充预算审批说明"},
        }
        return json.dumps({"semantic_plan": plan}, ensure_ascii=False)
def build_session() -> Session:
    """Create a ``Session`` bound to a brand-new in-memory SQLite database.

    All tables registered on ``Base.metadata`` are created before the session
    is returned. ``StaticPool`` together with ``check_same_thread=False`` is
    the SQLAlchemy-documented way to keep one shared connection for an
    in-memory SQLite database, so the tables created here stay visible to the
    returned session.
    """
    engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=engine)
    session_factory = sessionmaker(bind=engine, autocommit=False, autoflush=False)
    return session_factory()
def test_prompt_requires_semantic_plan_then_dsl() -> None:
    """The compiler prompt must ask for both a ``semantic_plan`` and a ``dsl``.

    Checks the ``required_json_shape`` carried in the JSON content of the
    second message, and the instruction wording in the first message.
    """
    amount_field = {"key": "claim.amount", "label": "申请金额", "type": "number", "source": "claim"}
    messages = build_risk_rule_compiler_messages(
        domain="expense",
        domain_label="报销",
        business_stage="expense_application",
        business_stage_label="费用申请",
        expense_category="travel",
        expense_category_label="差旅费",
        natural_language="申请金额超过预算余额时提示风险。",
        available_fields=[amount_field],
    )

    required_shape = json.loads(messages[1]["content"])["required_json_shape"]
    for section in ("semantic_plan", "dsl"):
        assert section in required_shape
    assert "semantic_plan 和 dsl" in messages[0]["content"]
def test_semantic_plan_envelope_is_unwrapped_and_persisted(tmp_path) -> None:
    """An envelope reply (``semantic_plan`` + ``dsl``) yields a stored rule.

    The persisted rule document must carry the DSL's template and condition,
    and keep the model's plan under ``metadata.model_semantic_plan``.
    """
    request = AgentAssetRiskRuleGenerateRequest(
        business_domain=AgentAssetDomain.EXPENSE,
        business_stage="expense_application",
        expense_category="travel",
        rule_title="预算余额超额校验",
        natural_language="费用申请时,如果申请金额超过当前可用预算余额,则提示预算风险。",
    )
    with build_session() as db:
        manager = AgentAssetRuleLibraryManager(rule_root=tmp_path / "rules")
        service = RiskRuleGenerationService(
            db,
            rule_library_manager=manager,
            runtime_chat_service=SemanticPlanEnvelopeRuntimeChatService(),
        )

        asset_id = service.generate_rule_asset(request, actor="pytest")
        asset = db.get(AgentAsset, asset_id)
        assert asset is not None

        rule_file_name = asset.config_json["rule_document"]["file_name"]
        payload = manager.read_rule_library_json(library=RISK_RULES_LIBRARY, file_name=rule_file_name)

        assert payload["template_key"] == "composite_rule_v1"
        first_condition = payload["params"]["conditions"][0]
        assert first_condition["operator"] == "numeric_compare"
        model_plan = payload["metadata"]["model_semantic_plan"]
        assert model_plan["rule_intent"] == "费用申请金额不得超过可用预算"
        assert payload["semantic_plan"]["judgment_steps"]
def test_semantic_plan_only_response_can_generate_standard_dsl(tmp_path) -> None:
    """A plan-only reply (no ``dsl``) still produces a ``numeric_compare`` rule.

    The generated condition must compare ``claim.amount`` against
    ``budget.remaining_amount``. The plan's ``required_fields`` must survive
    under ``metadata.model_semantic_plan``.
    """
    request = AgentAssetRiskRuleGenerateRequest(
        business_domain=AgentAssetDomain.EXPENSE,
        business_stage="expense_application",
        expense_category="travel",
        rule_title="预算余额语义计划校验",
        natural_language="费用申请金额超过可用预算余额时提示风险,并要求补充审批说明。",
    )
    with build_session() as db:
        manager = AgentAssetRuleLibraryManager(rule_root=tmp_path / "rules")
        service = RiskRuleGenerationService(
            db,
            rule_library_manager=manager,
            runtime_chat_service=SemanticPlanOnlyRuntimeChatService(),
        )

        asset_id = service.generate_rule_asset(request, actor="pytest")
        asset = db.get(AgentAsset, asset_id)
        assert asset is not None

        rule_file_name = asset.config_json["rule_document"]["file_name"]
        payload = manager.read_rule_library_json(library=RISK_RULES_LIBRARY, file_name=rule_file_name)

        first_condition = payload["params"]["conditions"][0]
        assert first_condition["operator"] == "numeric_compare"
        assert first_condition["left_fields"] == ["claim.amount"]
        assert first_condition["right_fields"] == ["budget.remaining_amount"]
        assert payload["metadata"]["model_semantic_plan"]["required_fields"]
def test_unwrap_semantic_plan_payload_keeps_legacy_payload_compatible() -> None:
    """``unwrap_semantic_plan_payload`` handles all three reply shapes.

    - legacy (bare DSL, no ``semantic_plan``): returned equal to the input;
    - envelope (``semantic_plan`` + ``dsl``): the DSL comes back with the plan
      attached as ``model_semantic_plan``;
    - plan-only: a ``composite_rule_v1`` DSL is produced whose ``field_keys``
      are taken from the plan's ``required_fields``.
    """
    legacy_payload = {"template_key": "field_required_v1", "field_keys": ["claim.reason"]}
    assert unwrap_semantic_plan_payload(legacy_payload) == legacy_payload

    envelope_payload = {
        "semantic_plan": {"rule_intent": "预算校验"},
        "dsl": {"template_key": "composite_rule_v1", "field_keys": ["claim.amount"]},
    }
    wrapped = unwrap_semantic_plan_payload(envelope_payload)
    assert wrapped["template_key"] == "composite_rule_v1"
    assert wrapped["model_semantic_plan"]["rule_intent"] == "预算校验"

    plan_only_payload = {
        "semantic_plan": {
            "rule_intent": "预算校验",
            "required_fields": [{"field": "claim.amount"}],
            "judgment_steps": ["申请金额超过预算余额"],
        }
    }
    plan_only = unwrap_semantic_plan_payload(plan_only_payload)
    assert plan_only["template_key"] == "composite_rule_v1"
    assert plan_only["field_keys"] == ["claim.amount"]