619 lines
30 KiB
Python
619 lines
30 KiB
Python
import json
|
|
import os
|
|
import shutil
|
|
import time
|
|
import unittest
|
|
import uuid
|
|
from unittest.mock import patch
|
|
|
|
from openpyxl import load_workbook
|
|
|
|
from app.utils.rule_generation import RuleGenerationService
|
|
|
|
|
|
class RuleGenerationServiceTest(unittest.TestCase):
    """Tests for ``RuleGenerationService``.

    Covers the ``start``/``get_status`` task flow, the Excel and Markdown exports,
    SQL generation and validation, rule normalization/validation helpers and
    policy-point selection. Every test gets its own directory under ``./output``
    holding the domains/schema fixtures and all generated artifacts; the directory
    is removed again when the test finishes.
    """

    # Polling budget for generation tasks started via start(): 50 * 0.05 s = 2.5 s.
    _POLL_ATTEMPTS = 50
    _POLL_INTERVAL = 0.05
    # Statuses after which a task snapshot is final and can be asserted on.
    _TERMINAL_STATUSES = frozenset({"done", "failed"})

    # Fixture text shared by several tests (non-ASCII text is written as escapes).
    # The excessive-debt domain: the only fixture domain with an analysed guidance file.
    _DOMAIN = "\u8fc7\u5ea6\u8d1f\u503a"
    # Description pattern of that domain; also used as the fake rule's policy basis.
    _DESCRIPTION_PATTERN = "\u6838\u5fc3\u6cd5\u89c4\uff1a\u300a\u6d4b\u8bd5\u529e\u6cd5\u300b\uff1b\u6761\u6b3e\u8981\u70b9\uff1a\u52a8\u6001\u76d1\u6d4b\u8d44\u4ea7\u8d1f\u503a\u7387\u3002"
    # Basis text / source sentence of that pattern.
    _BASIS_TEXT = "\u52a8\u6001\u76d1\u6d4b\u8d44\u4ea7\u8d1f\u503a\u7387\u3002"

    def setUp(self):
        run_id = uuid.uuid4().hex[:8]
        self.output_dir = os.path.join(os.getcwd(), "output", f"test-rule-generation-{run_id}")
        os.makedirs(self.output_dir, exist_ok=True)
        # Registered immediately so the directory is removed even if the rest of
        # setUp fails (tearDown is skipped in that case, cleanups are not).
        self.addCleanup(shutil.rmtree, self.output_dir, ignore_errors=True)

        self.domains_file = os.path.join(self.output_dir, "domains.json")
        self.schema_file = os.path.join(self.output_dir, "schema.json")
        self._write_domains()
        self._write_schema()
        self.service = self._make_service()

    # ------------------------------------------------------------------ helpers

    def _make_service(self, **options):
        """Build a service bound to this test's fixture files and output directory.

        ``options`` are forwarded as extra keyword arguments (e.g. ``create_sql=True``).
        """
        return RuleGenerationService(
            domains_file=self.domains_file,
            schema_file=self.schema_file,
            output_dir=self.output_dir,
            **options,
        )

    @staticmethod
    def _write_json(path, payload):
        """Write ``payload`` to ``path`` as UTF-8 JSON, keeping non-ASCII text readable."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False)

    @staticmethod
    def _read_text(path):
        """Return the UTF-8 text content of ``path``."""
        with open(path, "r", encoding="utf-8") as handle:
            return handle.read()

    def _write_domains(self):
        """Write the domains fixture.

        It contains one domain with an analysed guidance file, one domain without
        guidance files and one whose guidance file carries no analysis yet.
        """
        self._write_json(self.domains_file, {
            "domains": [
                {
                    "token": "token-1",
                    "domain": self._DOMAIN,
                    "guidance_files": [{
                        "file_id": "file-1",
                        "filename": "policy.txt",
                        "guidance_analysis": {
                            "status": "done",
                            "description_patterns": [{
                                "description_pattern": self._DESCRIPTION_PATTERN,
                                "basis_text": self._BASIS_TEXT,
                                "source_sentence": self._BASIS_TEXT,
                                "supervision_dimension": "\u8d44\u91d1\u7a7f\u900f",
                                "keywords": ["\u8d44\u4ea7\u8d1f\u503a\u7387", "\u503a\u52a1"],
                            }],
                        },
                    }],
                },
                {
                    "token": "token-no-guidance",
                    "domain": "\u865a\u5047\u8d38\u6613",
                    "guidance_files": [],
                },
                {
                    "token": "token-unparsed",
                    "domain": "\u591a\u5143\u6295\u8d44",
                    "guidance_files": [{
                        "file_id": "file-2",
                        "filename": "pending.txt",
                    }],
                },
            ],
        })

    def _write_schema(self):
        """Write the default schema: a bank-loan module and a bank-account module."""
        self._write_json(self.schema_file, {
            "modules": [
                {
                    "module_name": "\u94f6\u884c\u8d37\u6b3e",
                    "table_name": "bank_loan",
                    "description": "\u8bb0\u5f55\u8d37\u6b3e\u4f59\u989d\u548c\u503a\u52a1\u878d\u8d44\u4fe1\u606f\u3002",
                    "fields": [
                        {"name": "\u6240\u5c5e\u96c6\u56e2\u7f16\u7801", "marker": "COMCODE"},
                        {"name": "\u6240\u5c5e\u96c6\u56e2\u540d\u79f0", "marker": "COMNAME"},
                        {"name": "\u8d37\u6b3e\u5355\u4f4d\u540d\u79f0", "marker": "COMPANY_CLTNAME"},
                        {"name": "\u8d37\u6b3e\u4f59\u989d", "marker": "STANDARDCURRENCYBALANCE"},
                    ],
                },
                {
                    "module_name": "\u94f6\u884c\u8d26\u6237",
                    "table_name": "bank_account",
                    "description": "\u8bb0\u5f55\u8d26\u6237\u4f59\u989d\u548c\u8d26\u6237\u72b6\u6001\u3002",
                    "fields": [
                        {"name": "\u6240\u5c5e\u96c6\u56e2\u7f16\u7801", "marker": "COMCODE"},
                        {"name": "\u8d26\u6237\u4f59\u989d", "marker": "STANDARDCURRENCYBALANCE"},
                    ],
                },
            ],
        })

    def _write_schema_with_date_like_marker(self):
        """Overwrite the schema with a single module that also has a date-typed field."""
        self._write_json(self.schema_file, {
            "modules": [{
                "module_name": "\u94f6\u884c\u8d37\u6b3e",
                "table_name": "bank_loan",
                "fields": [
                    {"name": "\u6240\u5c5e\u96c6\u56e2\u7f16\u7801", "marker": "COMCODE"},
                    {"name": "\u8d37\u6b3e\u4f59\u989d", "marker": "STANDARDCURRENCYBALANCE"},
                    {"name": "\u8d26\u6237\u4f59\u989d\u65f6\u95f4", "marker": "BANKBALANCEDATE", "type": "\u65e5\u671f"},
                ],
            }],
        })

    def _fake_rule(self):
        """Return a fresh, fully populated rule dict as a stand-in for LLM output."""
        return {
            "rule_name": "\u8d44\u4ea7\u8d1f\u503a\u7387\u8d85\u9650\u9884\u8b66",
            "risk_description": "\u6d4b\u8bd5\u98ce\u9669\u63cf\u8ff0",
            "policy_basis": self._DESCRIPTION_PATTERN,
            "business_rule_description": "\u8d44\u4ea7\u8d1f\u503a\u7387 = \u8d37\u6b3e\u4f59\u989d / \u8d26\u6237\u4f59\u989d\n\u98ce\u9669\u7b49\u7ea7\uff1a\n- \u7ea2\u8272\uff1a\u8d44\u4ea7\u8d1f\u503a\u7387 > 85%\n- \u6a59\u8272\uff1a\u8d44\u4ea7\u8d1f\u503a\u7387 > 75%\n- \u9ec4\u8272\uff1a\u8d44\u4ea7\u8d1f\u503a\u7387 > 65%",
            "data_sources": "\u94f6\u884c\u8d37\u6b3e\uff1a\u6240\u5c5e\u96c6\u56e2\u7f16\u7801\u3001\u8d37\u6b3e\u4f59\u989d\n\u94f6\u884c\u8d26\u6237\uff1a\u8d26\u6237\u4f59\u989d",
            "system_rule_text": "\u5982\u679c\uff1a\u8d44\u4ea7\u8d1f\u503a\u7387 > 0.85\n\u5219\uff1a\u6821\u9a8c\u901a\u8fc7",
            "join_logic": "\u94f6\u884c\u8d37\u6b3e\u5de6\u5173\u8054\u94f6\u884c\u8d26\u6237\u3002",
            "system_logic": "\u503c1\uff1a\u8d37\u6b3e\u4f59\u989d\uff1b\u503c2\uff1a\u8d26\u6237\u4f59\u989d\u3002",
            "return_fields": "\u6240\u5c5e\u96c6\u56e2\u7f16\u7801,\u8d37\u6b3e\u4f59\u989d,\u98ce\u9669\u7b49\u7ea7",
        }

    def _policy_pattern(self):
        """Return a fresh policy point equal to the analysed pattern in the domains fixture."""
        return {
            "description_pattern": self._DESCRIPTION_PATTERN,
            "basis_text": self._BASIS_TEXT,
        }

    def _wait_for_task(self, task_id):
        """Poll ``get_status`` until ``task_id`` reaches a terminal status and return it.

        Call this while any patches the task depends on are still active: generation
        may still be running after ``start`` returns. Fails the test with an explicit
        message if no status exists or the task does not finish within the budget.
        """
        for _ in range(self._POLL_ATTEMPTS):
            current = self.service.get_status(task_id)
            if current and current["status"] in self._TERMINAL_STATUSES:
                return current
            time.sleep(self._POLL_INTERVAL)

        current = self.service.get_status(task_id)
        self.assertIsNotNone(current, f"no status recorded for task {task_id}")
        self.assertIn(
            current["status"],
            self._TERMINAL_STATUSES,
            f"task {task_id} did not finish within "
            f"{self._POLL_ATTEMPTS * self._POLL_INTERVAL:.2f}s",
        )
        return current

    # ------------------------------------------------------------ task lifecycle

    def test_start_generates_status_and_excel(self):
        with patch.object(RuleGenerationService, "_generate_rule", return_value=self._fake_rule()):
            task_id = self.service.start(limit=1)["task_id"]
            current = self._wait_for_task(task_id)

        self.assertEqual(current["status"], "done")
        self.assertEqual(current["generated_count"], 1)
        self.assertTrue(os.path.exists(current["output_file"]))
        self.assertTrue(os.path.exists(current["markdown_file"]))
        self.assertEqual(os.path.dirname(current["output_file"]), os.path.dirname(current["markdown_file"]))
        self.assertEqual(os.path.basename(os.path.dirname(current["output_file"])), f"rules-{task_id}")
        self.assertEqual(current["output_dir"], os.path.dirname(current["output_file"]))
        self.assertEqual(current["files"]["excel"], current["output_file"])
        self.assertEqual(current["files"]["markdown"], current["markdown_file"])

        workbook = load_workbook(current["output_file"])
        try:
            sheet = workbook.active
            # Checked before any cell access: openpyxl creates cells when they are
            # first accessed, which can only grow max_column.
            self.assertEqual(sheet.max_column, 11)
            self.assertEqual(sheet["A2"].value, 1)
            self.assertIn("bank_loan(\u94f6\u884c\u8d37\u6b3e)", sheet["G2"].value)
            self.assertIn("COMCODE(\u6240\u5c5e\u96c6\u56e2\u7f16\u7801)", sheet["G2"].value)
            self.assertIn("STANDARDCURRENCYBALANCE(\u8d37\u6b3e\u4f59\u989d)", sheet["G2"].value)
            self.assertIsNone(sheet["L1"].value)
        finally:
            workbook.close()

        markdown = self._read_text(current["markdown_file"])
        self.assertIn("# \u91cd\u70b9\u98ce\u9669\u9886\u57df\u89c4\u5219\u6a21\u578b\u4e0e\u811a\u672c", markdown)
        self.assertIn("\u4ee5\u5b9e\u9645\u751f\u6210\u7684\u98ce\u9669\u9886\u57df\u4e3a\u51c6", markdown)
        self.assertIn("## 1. \u8fc7\u5ea6\u8d1f\u503a", markdown)
        self.assertIn("### \u89c4\u5219\u6a21\u578b1.1\uff1a\u8d44\u4ea7\u8d1f\u503a\u7387\u8d85\u9650\u9884\u8b66", markdown)
        self.assertIn("bank_loan(\u94f6\u884c\u8d37\u6b3e)", markdown)

    def test_start_generates_limit_rules_per_policy_point(self):
        with patch.object(RuleGenerationService, "_generate_rule", return_value=self._fake_rule()) as generate:
            current = self._wait_for_task(self.service.start(limit=2)["task_id"])

        self.assertEqual(current["status"], "done")
        self.assertEqual(current["generated_count"], 2)
        self.assertEqual(current["progress"]["total"], 2)
        self.assertEqual(generate.call_count, 2)
        self.assertEqual(
            {item["domain"] for item in current["skipped_domains"]},
            {"\u865a\u5047\u8d38\u6613", "\u591a\u5143\u6295\u8d44"},
        )

        workbook = load_workbook(current["output_file"])
        try:
            # Header row plus one row per generated rule.
            self.assertEqual(workbook.active.max_row, 3)
        finally:
            workbook.close()

    def test_markdown_failure_does_not_overwrite_successful_excel(self):
        with patch.object(RuleGenerationService, "_generate_rule", return_value=self._fake_rule()):
            with patch.object(RuleGenerationService, "_write_markdown", side_effect=RuntimeError("markdown failed")):
                current = self._wait_for_task(self.service.start(limit=1)["task_id"])

        self.assertEqual(current["status"], "done")
        self.assertEqual(current["generated_count"], 1)
        self.assertEqual(current["markdown_error"], "markdown failed")

        workbook = load_workbook(current["output_file"])
        try:
            sheet = workbook.active
            self.assertEqual(sheet.max_row, 2)
            self.assertEqual(sheet["A2"].value, 1)
        finally:
            workbook.close()

    def test_start_records_failure_without_fallback_when_llm_fails(self):
        with patch.object(RuleGenerationService, "_generate_rule_with_llm", side_effect=RuntimeError("LLM 529")):
            current = self._wait_for_task(self.service.start(limit=1)["task_id"])

        self.assertEqual(current["status"], "failed")
        self.assertEqual(current["generated_count"], 0)
        self.assertTrue(os.path.exists(current["markdown_file"]))
        self.assertIn("LLM 529", current["skipped_rules"][0]["reason"])

    def test_collect_policy_basis_only_uses_analyzed_guidance_files(self):
        domains = self.service._collect_policy_basis()

        self.assertEqual([item["token"] for item in domains], ["token-1"])
        self.assertEqual(len(domains[0]["patterns"]), 1)

    # ------------------------------------------------------------ Excel / Markdown

    def test_write_excel_appends_sql_column_when_enabled(self):
        service = self._make_service(create_sql=True)
        output_file = os.path.join(self.output_dir, "with-sql.xlsx")
        service._write_excel([self._fake_rule()], output_file)

        workbook = load_workbook(output_file)
        try:
            sheet = workbook.active
            # Must run before reading sheet["L2"]: openpyxl creates a cell on first
            # access, which would make this column-count check pass trivially.
            self.assertEqual(sheet.max_column, 12)
            self.assertIn("SQL", sheet["L1"].value)
            sql = sheet["L2"].value
            self.assertIn("SELECT", sql)
            self.assertIn("FROM bank_loan", sql)
            self.assertIn("STANDARDCURRENCYBALANCE", sql)
            self.assertIn("CASE", sql)
            self.assertNotIn("HAVING", sql)
            self.assertTrue(sql.isascii())
        finally:
            workbook.close()

    def test_write_excel_uses_legacy_output_format(self):
        service = self._make_service(create_sql=True)
        output_file = os.path.join(self.output_dir, "formatted.xlsx")
        service._write_excel([self._fake_rule()], output_file)

        expected_headers = [
            "序号",
            "风险领域",
            "规则名称",
            "风险描述",
            "制度依据",
            "业务规则描述",
            "数据来源",
            "系统规则文本",
            "关联表逻辑",
            "系统固化逻辑",
            "返回结果",
            "SQL查询语句",
        ]

        workbook = load_workbook(output_file)
        try:
            sheet = workbook.active
            self.assertEqual(sheet.title, "11类重点风险领域")
            self.assertEqual(sheet.freeze_panes, "A2")
            self.assertEqual(sheet.auto_filter.ref, "A1:L2")
            self.assertEqual([sheet.cell(1, col).value for col in range(1, 13)], expected_headers)
            self.assertAlmostEqual(sheet.column_dimensions["A"].width, 5.13, places=2)
            self.assertAlmostEqual(sheet.column_dimensions["H"].width, 107.36, places=2)
            self.assertTrue(sheet["A1"].font.bold)
            self.assertEqual(sheet["F1"].fill.fgColor.rgb, "00FFFF00")
            self.assertEqual(sheet["G1"].fill.fgColor.rgb, "00FFFF00")
            self.assertEqual(sheet["H1"].fill.fgColor.rgb, "00FFFF00")
            self.assertEqual(sheet["A1"].alignment.horizontal, "center")
            self.assertTrue(sheet["F2"].alignment.wrap_text)
            self.assertEqual(sheet["F2"].alignment.horizontal, "left")
            self.assertEqual(sheet["A2"].alignment.horizontal, "center")
            self.assertEqual(sheet["A1"].border.left.style, "thin")
            self.assertEqual(sheet["L2"].border.bottom.style, "thin")
            self.assertGreater(sheet.row_dimensions[2].height, 100)
        finally:
            workbook.close()

    def test_write_markdown_includes_sql_when_enabled(self):
        service = self._make_service(create_sql=True)
        markdown_file = os.path.join(self.output_dir, "with-sql.md")
        service._write_markdown([self._fake_rule()], markdown_file)

        markdown = self._read_text(markdown_file)

        self.assertIn("**\u5b9e\u73b0\u811a\u672c\uff08SQL\uff09**", markdown)
        self.assertIn("```sql", markdown)
        self.assertIn("SELECT", markdown)
        self.assertIn("FROM bank_loan", markdown)
        # The SQL code fence must also be closed after it is opened.
        after_opening = markdown[markdown.index("```sql") + len("```sql"):]
        self.assertIn("```", after_opening)

    def test_normalize_rule_converts_nested_llm_values_to_excel_text(self):
        raw = self._fake_rule()
        raw["policy_basis"] = {
            "policy_name": "\u300a\u6d4b\u8bd5\u529e\u6cd5\u300b",
            "clause_summary": "\u52a8\u6001\u76d1\u6d4b\u503a\u52a1\u98ce\u9669\u3002",
        }
        raw["data_sources"] = [
            {"table": "\u94f6\u884c\u8d37\u6b3e", "fields": ["COMCODE", "STANDARDCURRENCYBALANCE"]},
        ]
        raw["return_fields"] = ["COMCODE", "\u98ce\u9669\u7b49\u7ea7"]

        normalized = self.service._normalize_rule(
            raw,
            self._DOMAIN,
            {
                "description_pattern": "\u6838\u5fc3\u6cd5\u89c4\uff1a\u300a\u6d4b\u8bd5\u529e\u6cd5\u300b\uff1b\u6761\u6b3e\u8981\u70b9\uff1a\u52a8\u6001\u76d1\u6d4b\u503a\u52a1\u98ce\u9669\u3002",
                "basis_text": "\u52a8\u6001\u76d1\u6d4b\u503a\u52a1\u98ce\u9669\u3002",
            },
            self.service._read_schema(),
        )
        output_file = os.path.join(self.output_dir, "nested-values.xlsx")
        self.service._write_excel([normalized], output_file)

        workbook = load_workbook(output_file)
        try:
            sheet = workbook.active
            self.assertIn("\u6838\u5fc3\u6cd5\u89c4\uff1a\u300a\u6d4b\u8bd5\u529e\u6cd5\u300b", sheet["E2"].value)
            self.assertIn("bank_loan(\u94f6\u884c\u8d37\u6b3e)", sheet["G2"].value)
            self.assertIn("STANDARDCURRENCYBALANCE(\u8d37\u6b3e\u4f59\u989d)", sheet["G2"].value)
            self.assertEqual(sheet["K2"].value, "COMCODE, \u98ce\u9669\u7b49\u7ea7")
        finally:
            workbook.close()

    # ------------------------------------------------------------ SQL generation

    def test_create_sql_query_uses_schema_markers_only(self):
        sql = self.service._create_sql_query(self._fake_rule())

        self.assertEqual(sql.splitlines()[0], "SELECT")
        self.assertIn("t1.COMCODE", sql)
        self.assertIn("t1.STANDARDCURRENCYBALANCE", sql)
        self.assertIn("FROM bank_loan", sql)
        self.assertIn("FROM bank_account", sql)
        self.assertIn("LEFT JOIN (", sql)
        self.assertIn(") t2 ON t1.COMCODE = t2.COMCODE", sql)
        self.assertNotIn("WITH", sql)
        self.assertIn("(t1.STANDARDCURRENCYBALANCE / NULLIF(t2.STANDARDCURRENCYBALANCE, 0))", sql)
        self.assertIn("t1.STANDARDCURRENCYBALANCE AS VALUE_1", sql)
        self.assertIn("t2.STANDARDCURRENCYBALANCE AS VALUE_2", sql)
        self.assertIn("THEN 'RED'", sql)
        self.assertIn("THEN 'ORANGE'", sql)
        self.assertIn("THEN 'YELLOW'", sql)
        self.assertNotIn("HAVING", sql)
        self.assertNotIn("\u8d37\u6b3e\u4f59\u989d", sql)
        self.assertTrue(sql.isascii())

    def test_sql_does_not_sum_date_like_balance_markers(self):
        self._write_schema_with_date_like_marker()

        sql = self.service._create_sql_query(self._fake_rule())

        self.assertNotIn("WITH", sql)
        self.assertNotIn("SUM(BANKBALANCEDATE)", sql)
        self.assertIn("BANKBALANCEDATE", sql)

    def test_generated_sql_validator_rejects_mysql_incompatible_cte(self):
        with self.assertRaisesRegex(ValueError, "WITH/CTE"):
            self.service._validate_generated_sql("WITH\nt1 AS (SELECT 1)\nSELECT * FROM t1;")

    def test_generated_sql_validator_rejects_date_sum(self):
        with self.assertRaisesRegex(ValueError, "SUM date/time"):
            self.service._validate_generated_sql("SELECT SUM(BANKBALANCEDATE) FROM bank_account;")

    def test_data_sources_display_marker_with_chinese_name(self):
        value = self.service._format_data_sources_with_markers(
            "\u94f6\u884c\u8d37\u6b3e:\u6240\u5c5e\u96c6\u56e2\u7f16\u7801,\u8d37\u6b3e\u4f59\u989d",
            self.service._read_schema(),
        )

        self.assertEqual(
            value,
            "bank_loan(\u94f6\u884c\u8d37\u6b3e): "
            "COMCODE(\u6240\u5c5e\u96c6\u56e2\u7f16\u7801), "
            "STANDARDCURRENCYBALANCE(\u8d37\u6b3e\u4f59\u989d)",
        )

    # ------------------------------------------------------ rule validation

    def test_metric_alignment_accepts_close_indicator_names(self):
        business_rule = (
            "\u503a\u52a1\u878d\u8d44\u7ed3\u6784\u5360\u6bd4 = "
            "\u503a\u52a1\u878d\u8d44\u4f59\u989d / \u603b\u8d44\u4ea7\n"
            "\u98ce\u9669\u7b49\u7ea7\uff1a\n"
            "- \u7ea2\u8272\uff1a\u503a\u52a1\u878d\u8d44\u5360\u6bd4 > 80%\n"
            "- \u6a59\u8272\uff1a\u503a\u52a1\u878d\u8d44\u5360\u6bd4 > 60%\n"
            "- \u9ec4\u8272\uff1a\u503a\u52a1\u878d\u8d44\u5360\u6bd4 > 40%"
        )
        count_rule = (
            "\u5408\u540c\u7b14\u6570 = COUNT(\u5408\u540c\u7f16\u53f7)\n"
            "\u98ce\u9669\u7b49\u7ea7\uff1a\n"
            "- \u7ea2\u8272\uff1a\u5408\u540c\u6570\u91cf > 100\n"
            "- \u6a59\u8272\uff1a\u5408\u540c\u6570\u91cf > 50\n"
            "- \u9ec4\u8272\uff1a\u5408\u540c\u6570\u91cf > 20"
        )

        # The test passes when neither validation call raises.
        RuleGenerationService._validate_business_rule_indicator_alignment(business_rule)
        RuleGenerationService._validate_business_rule_indicator_alignment(count_rule)

    def test_metric_alignment_rejects_wrong_metric_category(self):
        business_rule = (
            "\u5883\u5916\u6295\u8d44\u6536\u76ca\u7387 = "
            "\u6295\u8d44\u6536\u76ca / \u6295\u8d44\u6210\u672c\n"
            "\u98ce\u9669\u7b49\u7ea7\uff1a\n"
            "- \u7ea2\u8272\uff1a\u5883\u5916\u6295\u8d44\u91d1\u989d > 10000000\n"
            "- \u6a59\u8272\uff1a\u5883\u5916\u6295\u8d44\u91d1\u989d > 5000000\n"
            "- \u9ec4\u8272\uff1a\u5883\u5916\u6295\u8d44\u91d1\u989d > 1000000"
        )

        with self.assertRaisesRegex(ValueError, "\u6307\u6807\u9608\u503c\u672a\u5bf9\u5e94\u516c\u5f0f\u6307\u6807"):
            RuleGenerationService._validate_business_rule_indicator_alignment(business_rule)

    def test_rule_shape_allows_single_condition_without_and(self):
        rule = self._fake_rule()
        rule["system_rule_text"] = (
            "\u8d44\u4ea7\u8d1f\u503a\u7387\u8d85\u9650\u9884\u8b66\uff08\u9ec4\u8272\u9608\u503c\uff09:\n"
            "\u5982\u679c\uff1a\n"
            " (\u8d44\u4ea7\u8d1f\u503a\u7387 > 0.65)\n"
            "\u5c31\uff1a\n"
            "\u6821\u9a8c\uff1a\u901a\u8fc7\n"
            "\u8d4b\u503c\uff1a\u3010\u89c4\u5219\u7ed3\u679c-\u98ce\u9669\u7b49\u7ea7\u3011\u7b49\u4e8e \u9ec4\u8272"
        )
        rule["join_logic"] = "\u94f6\u884c\u8d37\u6b3e\u8868\u5de6\u5173\u8054\u94f6\u884c\u8d26\u6237\u8868\uff0c\u6309\u6240\u5c5e\u96c6\u56e2\u7f16\u7801\u6c47\u603b\u3002"

        # The test passes when validation does not raise.
        RuleGenerationService._validate_rule_shape(rule)

    def test_normalize_rule_fills_recoverable_columns(self):
        raw = self._fake_rule()
        raw["business_rule_description"] = (
            "\u503a\u52a1\u878d\u8d44\u7ed3\u6784\u5360\u6bd4 = "
            "\u8d37\u6b3e\u4f59\u989d / \u8d26\u6237\u4f59\u989d\n"
            "\u98ce\u9669\u7b49\u7ea7\uff1a\n"
            "- \u7ea2\u8272\uff1a\u503a\u52a1\u878d\u8d44\u5360\u6bd4 > 80%\n"
            "- \u6a59\u8272\uff1a\u503a\u52a1\u878d\u8d44\u5360\u6bd4 > 60%\n"
            "- \u9ec4\u8272\uff1a\u503a\u52a1\u878d\u8d44\u5360\u6bd4 > 40%"
        )
        # Blank out every column the service is expected to reconstruct.
        for key in ("data_sources", "system_rule_text", "join_logic", "system_logic", "return_fields"):
            raw[key] = ""

        normalized = self.service._normalize_rule(
            raw,
            self._DOMAIN,
            {
                "description_pattern": "\u6838\u5fc3\u6cd5\u89c4\uff1a\u300a\u6d4b\u8bd5\u529e\u6cd5\u300b\uff1b\u6761\u6b3e\u8981\u70b9\uff1a\u52a8\u6001\u76d1\u6d4b\u503a\u52a1\u878d\u8d44\u7ed3\u6784\u3002",
                "basis_text": "\u52a8\u6001\u76d1\u6d4b\u503a\u52a1\u878d\u8d44\u7ed3\u6784\u3002",
            },
            self.service._read_schema(),
        )

        self.assertIn("bank_loan(\u94f6\u884c\u8d37\u6b3e)", normalized["data_sources"])
        self.assertIn("\u5982\u679c", normalized["system_rule_text"])
        self.assertIn("\u8d4b\u503c", normalized["system_rule_text"])
        self.assertIn("\u8868", normalized["join_logic"])
        self.assertIn("\u503c1", normalized["system_logic"])
        self.assertIn("\u503c2", normalized["system_logic"])
        self.assertIn("\u98ce\u9669\u7b49\u7ea7", normalized["return_fields"])

    # ------------------------------------------------------ prompts and LLM calls

    def test_indicator_guidance_prefers_amount_count_and_days_for_overdue_rules(self):
        pattern = {
            "description_pattern": "\u5efa\u7acb\u503a\u52a1\u5230\u671f\u9884\u8b66\u673a\u5236\uff0c\u5173\u6ce8\u903e\u671f\u91d1\u989d\u548c\u903e\u671f\u7b14\u6570\u3002",
            "basis_text": "\u4e25\u63a7\u903e\u671f\u503a\u52a1\u98ce\u9669\u3002",
        }
        modules = [{
            "module_name": "\u5e94\u4ed8\u7968\u636e",
            "fields": [
                {"name": "\u662f\u5426\u903e\u671f", "marker": "OVERDUE"},
                {"name": "\u5230\u671f\u65e5\u671f", "marker": "DUEDATE"},
                {"name": "\u7968\u9762\u91d1\u989d", "marker": "AMOUNT"},
            ],
        }]

        guidance = self.service._build_indicator_guidance(self._DOMAIN, pattern, modules)

        self.assertIn("\u91d1\u989d\u7c7b", guidance)
        self.assertIn("\u6570\u91cf\u7c7b", guidance)
        self.assertIn("\u671f\u9650\u7c7b", guidance)
        self.assertIn("不要默认生成 A/B 比率", guidance)

    def test_rule_prompt_includes_indicator_guidance(self):
        prompt = self.service._build_rule_prompt(
            self._DOMAIN,
            {
                "description_pattern": "\u5173\u6ce8\u903e\u671f\u91d1\u989d\u548c\u903e\u671f\u7b14\u6570\u3002",
                "basis_text": "\u503a\u52a1\u5230\u671f\u9884\u8b66\u3002",
            },
            self.service._read_schema(),
        )

        self.assertIn("\u6307\u6807\u53e3\u5f84\u9009\u62e9\u5efa\u8bae", prompt)
        self.assertIn("\u91d1\u989d\u7c7b", prompt)
        self.assertIn("\u7b14\u6570", prompt)

    def test_rule_prompt_includes_variant_guidance_for_each_policy_point(self):
        pattern = {
            "description_pattern": "\u5173\u6ce8\u903e\u671f\u91d1\u989d\u548c\u903e\u671f\u7b14\u6570\u3002",
            "basis_text": "\u503a\u52a1\u5230\u671f\u9884\u8b66\u3002",
            "_variant_index": 2,
            "_variant_total": 3,
        }

        prompt = self.service._build_rule_prompt(self._DOMAIN, pattern, self.service._read_schema())

        self.assertIn("\u7b2c 2/3 \u6761\u89c4\u5219", prompt)
        self.assertIn("\u4e0d\u5f97\u53ea\u6539\u540d\u79f0\u6216\u9608\u503c", prompt)

    def test_generate_rule_retries_once_after_transient_failure(self):
        with patch.object(
            RuleGenerationService,
            "_generate_rule_with_llm",
            side_effect=[RuntimeError("timeout"), self._fake_rule()],
        ) as generate:
            rule = self.service._generate_rule(self._DOMAIN, self._policy_pattern(), self.service._read_schema())

        self.assertEqual(generate.call_count, 2)
        self.assertIn("\u5982\u679c", rule["system_rule_text"])
        self.assertIn("\u8d4b\u503c", rule["system_rule_text"])

    def test_generate_rule_with_llm_disables_minimax_thinking(self):
        with patch("app.utils.llm.LLMClient") as client_cls:
            client_cls.return_value.chat.return_value = json.dumps(self._fake_rule(), ensure_ascii=False)
            self.service._generate_rule_with_llm(self._DOMAIN, self._policy_pattern(), self.service._read_schema())

        kwargs = client_cls.return_value.chat.call_args.kwargs
        self.assertEqual(kwargs["thinking"], {"type": "disabled"})
        self.assertEqual(kwargs["max_tokens"], 2600)

    # ------------------------------------------------------ policy-point selection

    def test_select_policy_points_generates_limit_rules_per_domain(self):
        domains = [
            {"domain": "A", "patterns": [{"id": "a1"}, {"id": "a2"}]},
            {"domain": "B", "patterns": [{"id": "b1"}]},
        ]

        # Make sampling deterministic: take the first patterns in order.
        with patch("app.utils.rule_generation.random.sample", lambda patterns, limit: patterns[:limit]):
            with patch("app.utils.rule_generation.random.choice", lambda patterns: patterns[0]):
                selected = self.service._select_policy_points(domains, 2)

        self.assertEqual(len(selected), 4)
        self.assertEqual(
            [(item["domain"], item["pattern"]["id"], item["variant_index"], item["variant_total"]) for item in selected],
            [("A", "a1", 1, 2), ("A", "a2", 2, 2), ("B", "b1", 1, 2), ("B", "b1", 2, 2)],
        )

    def test_select_policy_points_does_not_multiply_by_pattern_count(self):
        domains = [{"domain": "A", "patterns": [{"id": f"a{i}"} for i in range(71)]}]

        with patch("app.utils.rule_generation.random.sample", lambda patterns, limit: patterns[:limit]):
            selected = self.service._select_policy_points(domains, 2)

        self.assertEqual(len(selected), 2)

    def test_old_pattern_multiplication_would_have_generated_142_rules(self):
        domains = [{"domain": "A", "patterns": [{"id": f"a{i}"} for i in range(71)]}]

        # Real (unpatched) sampling. The old behaviour multiplied the limit by the
        # pattern count (71 * 2 = 142); exactly ``limit`` rules must be selected now.
        selected = self.service._select_policy_points(domains, 2)

        self.assertEqual(len(selected), 2)

    def test_select_policy_points_keeps_domain_grouped_order(self):
        domains = [
            {"domain": "A", "patterns": [{"id": "a1"}, {"id": "a2"}]},
            {"domain": "B", "patterns": [{"id": "b1"}, {"id": "b2"}]},
        ]

        with patch("app.utils.rule_generation.random.sample", lambda patterns, limit: patterns[:limit]):
            selected = self.service._select_policy_points(domains, 2)

        self.assertEqual([item["domain"] for item in selected], ["A", "A", "B", "B"])
if __name__ == "__main__":
    # Run every TestCase defined in this module when it is executed as a script.
    unittest.main(module="__main__")