# Source file: YG-Rules/tests/test_rule_generation.py
# (619 lines, 30 KiB, Python)

import json
import os
import shutil
import time
import unittest
import uuid
from unittest.mock import patch
from openpyxl import load_workbook
from app.utils.rule_generation import RuleGenerationService
class RuleGenerationServiceTest(unittest.TestCase):
    """Tests for RuleGenerationService.

    Covers the background task lifecycle, Excel/Markdown output, SQL
    generation, rule validation/normalization and policy-point selection.

    Each test gets its own scratch directory under ``./output`` (removed in
    tearDown). The directory holds a ``domains.json`` with per-domain
    guidance analysis and a ``schema.json`` with data modules and their
    field markers.
    """

    # Risk domain that has an analyzed guidance file in the fixture data.
    _DOMAIN = "\u8fc7\u5ea6\u8d1f\u503a"
    # Policy description pattern shared by the fixture guidance and _fake_rule.
    _POLICY_PATTERN = "\u6838\u5fc3\u6cd5\u89c4\uff1a\u300a\u6d4b\u8bd5\u529e\u6cd5\u300b\uff1b\u6761\u6b3e\u8981\u70b9\uff1a\u52a8\u6001\u76d1\u6d4b\u8d44\u4ea7\u8d1f\u503a\u7387\u3002"
    # Basis text / source sentence paired with _POLICY_PATTERN.
    _POLICY_BASIS = "\u52a8\u6001\u76d1\u6d4b\u8d44\u4ea7\u8d1f\u503a\u7387\u3002"
    # Task statuses after which get_status no longer changes.
    _TERMINAL_STATUSES = frozenset({"done", "failed"})

    def setUp(self):
        """Write fixture domains/schema files and build the service under test."""
        self.output_dir = os.path.join(os.getcwd(), "output", f"test-rule-generation-{uuid.uuid4().hex[:8]}")
        os.makedirs(self.output_dir, exist_ok=True)
        self.domains_file = os.path.join(self.output_dir, "domains.json")
        self.schema_file = os.path.join(self.output_dir, "schema.json")
        # Three domains:
        # - one with an analyzed guidance file,
        # - one with no guidance files,
        # - one whose only guidance file has no analysis yet.
        with open(self.domains_file, "w", encoding="utf-8") as f:
            json.dump({
                "domains": [
                    {
                        "token": "token-1",
                        "domain": self._DOMAIN,
                        "guidance_files": [{
                            "file_id": "file-1",
                            "filename": "policy.txt",
                            "guidance_analysis": {
                                "status": "done",
                                "description_patterns": [{
                                    "description_pattern": self._POLICY_PATTERN,
                                    "basis_text": self._POLICY_BASIS,
                                    "source_sentence": self._POLICY_BASIS,
                                    "supervision_dimension": "\u8d44\u91d1\u7a7f\u900f",
                                    "keywords": ["\u8d44\u4ea7\u8d1f\u503a\u7387", "\u503a\u52a1"],
                                }],
                            },
                        }],
                    },
                    {
                        "token": "token-no-guidance",
                        "domain": "\u865a\u5047\u8d38\u6613",
                        "guidance_files": [],
                    },
                    {
                        "token": "token-unparsed",
                        "domain": "\u591a\u5143\u6295\u8d44",
                        "guidance_files": [{
                            "file_id": "file-2",
                            "filename": "pending.txt",
                        }],
                    },
                ],
            }, f, ensure_ascii=False)
        self._write_schema()
        self.service = self._make_service()

    def tearDown(self):
        """Remove this test's scratch directory."""
        if os.path.isdir(self.output_dir):
            shutil.rmtree(self.output_dir)

    def _make_service(self, **options):
        """Build a RuleGenerationService over this test's fixture files.

        Extra keyword arguments (e.g. ``create_sql=True``) are forwarded to the
        constructor.
        """
        return RuleGenerationService(
            domains_file=self.domains_file,
            schema_file=self.schema_file,
            output_dir=self.output_dir,
            **options,
        )

    def _wait_for_task(self, task_id, timeout=2.5, interval=0.05):
        """Poll ``self.service.get_status`` until the task reaches a terminal status.

        These tests treat ``start`` as asynchronous, so callers must keep any
        ``patch`` active until this returns.

        Returns the final status dict. Fails the test, reporting the last
        observed status, if no terminal status is reached within ``timeout``
        seconds.
        """
        deadline = time.monotonic() + timeout
        current = self.service.get_status(task_id)
        while not (current and current["status"] in self._TERMINAL_STATUSES):
            if time.monotonic() >= deadline:
                self.fail(f"task {task_id} did not finish within {timeout}s; last status: {current!r}")
            time.sleep(interval)
            current = self.service.get_status(task_id)
        return current

    def _write_schema(self):
        """Write a two-module schema: bank loans and bank accounts."""
        with open(self.schema_file, "w", encoding="utf-8") as f:
            json.dump({
                "modules": [
                    {
                        "module_name": "\u94f6\u884c\u8d37\u6b3e",
                        "table_name": "bank_loan",
                        "description": "\u8bb0\u5f55\u8d37\u6b3e\u4f59\u989d\u548c\u503a\u52a1\u878d\u8d44\u4fe1\u606f\u3002",
                        "fields": [
                            {"name": "\u6240\u5c5e\u96c6\u56e2\u7f16\u7801", "marker": "COMCODE"},
                            {"name": "\u6240\u5c5e\u96c6\u56e2\u540d\u79f0", "marker": "COMNAME"},
                            {"name": "\u8d37\u6b3e\u5355\u4f4d\u540d\u79f0", "marker": "COMPANY_CLTNAME"},
                            {"name": "\u8d37\u6b3e\u4f59\u989d", "marker": "STANDARDCURRENCYBALANCE"},
                        ],
                    },
                    {
                        "module_name": "\u94f6\u884c\u8d26\u6237",
                        "table_name": "bank_account",
                        "description": "\u8bb0\u5f55\u8d26\u6237\u4f59\u989d\u548c\u8d26\u6237\u72b6\u6001\u3002",
                        "fields": [
                            {"name": "\u6240\u5c5e\u96c6\u56e2\u7f16\u7801", "marker": "COMCODE"},
                            {"name": "\u8d26\u6237\u4f59\u989d", "marker": "STANDARDCURRENCYBALANCE"},
                        ],
                    },
                ],
            }, f, ensure_ascii=False)

    def _write_schema_with_date_like_marker(self):
        """Overwrite the schema with one module that also has a date-typed balance field."""
        with open(self.schema_file, "w", encoding="utf-8") as f:
            json.dump({
                "modules": [{
                    "module_name": "\u94f6\u884c\u8d37\u6b3e",
                    "table_name": "bank_loan",
                    "fields": [
                        {"name": "\u6240\u5c5e\u96c6\u56e2\u7f16\u7801", "marker": "COMCODE"},
                        {"name": "\u8d37\u6b3e\u4f59\u989d", "marker": "STANDARDCURRENCYBALANCE"},
                        {"name": "\u8d26\u6237\u4f59\u989d\u65f6\u95f4", "marker": "BANKBALANCEDATE", "type": "\u65e5\u671f"},
                    ],
                }],
            }, f, ensure_ascii=False)

    def _fake_rule(self):
        """Return a fresh, fully populated rule dict as the LLM would produce it."""
        return {
            "rule_name": "\u8d44\u4ea7\u8d1f\u503a\u7387\u8d85\u9650\u9884\u8b66",
            "risk_description": "\u6d4b\u8bd5\u98ce\u9669\u63cf\u8ff0",
            "policy_basis": self._POLICY_PATTERN,
            "business_rule_description": "\u8d44\u4ea7\u8d1f\u503a\u7387 = \u8d37\u6b3e\u4f59\u989d / \u8d26\u6237\u4f59\u989d\n\u98ce\u9669\u7b49\u7ea7\uff1a\n- \u7ea2\u8272\uff1a\u8d44\u4ea7\u8d1f\u503a\u7387 > 85%\n- \u6a59\u8272\uff1a\u8d44\u4ea7\u8d1f\u503a\u7387 > 75%\n- \u9ec4\u8272\uff1a\u8d44\u4ea7\u8d1f\u503a\u7387 > 65%",
            "data_sources": "\u94f6\u884c\u8d37\u6b3e\uff1a\u6240\u5c5e\u96c6\u56e2\u7f16\u7801\u3001\u8d37\u6b3e\u4f59\u989d\n\u94f6\u884c\u8d26\u6237\uff1a\u8d26\u6237\u4f59\u989d",
            "system_rule_text": "\u5982\u679c\uff1a\u8d44\u4ea7\u8d1f\u503a\u7387 > 0.85\n\u5219\uff1a\u6821\u9a8c\u901a\u8fc7",
            "join_logic": "\u94f6\u884c\u8d37\u6b3e\u5de6\u5173\u8054\u94f6\u884c\u8d26\u6237\u3002",
            "system_logic": "\u503c1\uff1a\u8d37\u6b3e\u4f59\u989d\uff1b\u503c2\uff1a\u8d26\u6237\u4f59\u989d\u3002",
            "return_fields": "\u6240\u5c5e\u96c6\u56e2\u7f16\u7801,\u8d37\u6b3e\u4f59\u989d,\u98ce\u9669\u7b49\u7ea7",
        }

    def _policy_pattern(self):
        """Return the policy-point dict matching the fixture guidance."""
        return {
            "description_pattern": self._POLICY_PATTERN,
            "basis_text": self._POLICY_BASIS,
        }

    def test_start_generates_status_and_excel(self):
        with patch.object(RuleGenerationService, "_generate_rule", return_value=self._fake_rule()):
            task_id = self.service.start(limit=1)["task_id"]
            current = self._wait_for_task(task_id)
        self.assertEqual(current["status"], "done")
        self.assertEqual(current["generated_count"], 1)
        self.assertTrue(os.path.exists(current["output_file"]))
        self.assertTrue(os.path.exists(current["markdown_file"]))
        self.assertEqual(os.path.dirname(current["output_file"]), os.path.dirname(current["markdown_file"]))
        self.assertEqual(os.path.basename(os.path.dirname(current["output_file"])), f"rules-{task_id}")
        self.assertEqual(current["output_dir"], os.path.dirname(current["output_file"]))
        self.assertEqual(current["files"]["excel"], current["output_file"])
        self.assertEqual(current["files"]["markdown"], current["markdown_file"])
        workbook = load_workbook(current["output_file"])
        try:
            sheet = workbook.active
            self.assertEqual(sheet.max_column, 11)
            self.assertEqual(sheet["A2"].value, 1)
            self.assertIn("bank_loan(\u94f6\u884c\u8d37\u6b3e)", sheet["G2"].value)
            self.assertIn("COMCODE(\u6240\u5c5e\u96c6\u56e2\u7f16\u7801)", sheet["G2"].value)
            self.assertIn("STANDARDCURRENCYBALANCE(\u8d37\u6b3e\u4f59\u989d)", sheet["G2"].value)
            # Without create_sql there is no 12th (SQL) column.
            self.assertIsNone(sheet["L1"].value)
        finally:
            workbook.close()
        with open(current["markdown_file"], "r", encoding="utf-8") as file:
            markdown = file.read()
        self.assertIn("# \u91cd\u70b9\u98ce\u9669\u9886\u57df\u89c4\u5219\u6a21\u578b\u4e0e\u811a\u672c", markdown)
        self.assertIn("\u4ee5\u5b9e\u9645\u751f\u6210\u7684\u98ce\u9669\u9886\u57df\u4e3a\u51c6", markdown)
        self.assertIn("## 1. \u8fc7\u5ea6\u8d1f\u503a", markdown)
        self.assertIn("### \u89c4\u5219\u6a21\u578b1.1\uff1a\u8d44\u4ea7\u8d1f\u503a\u7387\u8d85\u9650\u9884\u8b66", markdown)
        self.assertIn("bank_loan(\u94f6\u884c\u8d37\u6b3e)", markdown)

    def test_start_generates_limit_rules_per_policy_point(self):
        with patch.object(RuleGenerationService, "_generate_rule", return_value=self._fake_rule()) as generate:
            task_id = self.service.start(limit=2)["task_id"]
            current = self._wait_for_task(task_id)
        self.assertEqual(current["status"], "done")
        self.assertEqual(current["generated_count"], 2)
        self.assertEqual(current["progress"]["total"], 2)
        self.assertEqual(generate.call_count, 2)
        self.assertEqual(
            {item["domain"] for item in current["skipped_domains"]},
            {"\u865a\u5047\u8d38\u6613", "\u591a\u5143\u6295\u8d44"},
        )
        workbook = load_workbook(current["output_file"])
        try:
            # Header row plus one row per generated rule.
            self.assertEqual(workbook.active.max_row, 3)
        finally:
            workbook.close()

    def test_collect_policy_basis_only_uses_analyzed_guidance_files(self):
        domains = self.service._collect_policy_basis()
        self.assertEqual([item["token"] for item in domains], ["token-1"])
        self.assertEqual(len(domains[0]["patterns"]), 1)

    def test_write_excel_appends_sql_column_when_enabled(self):
        service = self._make_service(create_sql=True)
        output_file = os.path.join(self.output_dir, "with-sql.xlsx")
        service._write_excel([self._fake_rule()], output_file)
        workbook = load_workbook(output_file)
        try:
            sheet = workbook.active
            sql = sheet["L2"].value
            self.assertEqual(sheet.max_column, 12)
            self.assertIn("SQL", sheet["L1"].value)
            self.assertIn("SELECT", sql)
            self.assertIn("FROM bank_loan", sql)
            self.assertIn("STANDARDCURRENCYBALANCE", sql)
            self.assertIn("CASE", sql)
            self.assertNotIn("HAVING", sql)
            self.assertTrue(all(ord(char) < 128 for char in sql))
        finally:
            workbook.close()

    def test_write_excel_uses_legacy_output_format(self):
        service = self._make_service(create_sql=True)
        output_file = os.path.join(self.output_dir, "formatted.xlsx")
        service._write_excel([self._fake_rule()], output_file)
        workbook = load_workbook(output_file)
        try:
            sheet = workbook.active
            self.assertEqual(sheet.title, "11类重点风险领域")
            self.assertEqual(sheet.freeze_panes, "A2")
            self.assertEqual(sheet.auto_filter.ref, "A1:L2")
            self.assertEqual(
                [sheet.cell(1, col).value for col in range(1, 13)],
                [
                    "序号",
                    "风险领域",
                    "规则名称",
                    "风险描述",
                    "制度依据",
                    "业务规则描述",
                    "数据来源",
                    "系统规则文本",
                    "关联表逻辑",
                    "系统固化逻辑",
                    "返回结果",
                    "SQL查询语句",
                ],
            )
            self.assertAlmostEqual(sheet.column_dimensions["A"].width, 5.13, places=2)
            self.assertAlmostEqual(sheet.column_dimensions["H"].width, 107.36, places=2)
            self.assertTrue(sheet["A1"].font.bold)
            self.assertEqual(sheet["F1"].fill.fgColor.rgb, "00FFFF00")
            self.assertEqual(sheet["G1"].fill.fgColor.rgb, "00FFFF00")
            self.assertEqual(sheet["H1"].fill.fgColor.rgb, "00FFFF00")
            self.assertEqual(sheet["A1"].alignment.horizontal, "center")
            self.assertTrue(sheet["F2"].alignment.wrap_text)
            self.assertEqual(sheet["F2"].alignment.horizontal, "left")
            self.assertEqual(sheet["A2"].alignment.horizontal, "center")
            self.assertEqual(sheet["A1"].border.left.style, "thin")
            self.assertEqual(sheet["L2"].border.bottom.style, "thin")
            self.assertGreater(sheet.row_dimensions[2].height, 100)
        finally:
            workbook.close()

    def test_write_markdown_includes_sql_when_enabled(self):
        service = self._make_service(create_sql=True)
        markdown_file = os.path.join(self.output_dir, "with-sql.md")
        service._write_markdown([self._fake_rule()], markdown_file)
        with open(markdown_file, "r", encoding="utf-8") as file:
            markdown = file.read()
        self.assertIn("**\u5b9e\u73b0\u811a\u672c\uff08SQL\uff09**", markdown)
        self.assertIn("```sql", markdown)
        self.assertIn("SELECT", markdown)
        self.assertIn("FROM bank_loan", markdown)
        self.assertIn("```", markdown)

    def test_markdown_failure_does_not_overwrite_successful_excel(self):
        with patch.object(RuleGenerationService, "_generate_rule", return_value=self._fake_rule()):
            with patch.object(RuleGenerationService, "_write_markdown", side_effect=RuntimeError("markdown failed")):
                task_id = self.service.start(limit=1)["task_id"]
                current = self._wait_for_task(task_id)
        self.assertEqual(current["status"], "done")
        self.assertEqual(current["generated_count"], 1)
        self.assertEqual(current["markdown_error"], "markdown failed")
        workbook = load_workbook(current["output_file"])
        try:
            self.assertEqual(workbook.active.max_row, 2)
            self.assertEqual(workbook.active["A2"].value, 1)
        finally:
            workbook.close()

    def test_create_sql_query_uses_schema_markers_only(self):
        sql = self.service._create_sql_query(self._fake_rule())
        self.assertEqual(sql.splitlines()[0], "SELECT")
        self.assertIn("t1.COMCODE", sql)
        self.assertIn("t1.STANDARDCURRENCYBALANCE", sql)
        self.assertIn("FROM bank_loan", sql)
        self.assertIn("FROM bank_account", sql)
        self.assertIn("LEFT JOIN (", sql)
        self.assertIn(") t2 ON t1.COMCODE = t2.COMCODE", sql)
        self.assertNotIn("WITH", sql)
        self.assertIn("(t1.STANDARDCURRENCYBALANCE / NULLIF(t2.STANDARDCURRENCYBALANCE, 0))", sql)
        self.assertIn("t1.STANDARDCURRENCYBALANCE AS VALUE_1", sql)
        self.assertIn("t2.STANDARDCURRENCYBALANCE AS VALUE_2", sql)
        self.assertIn("THEN 'RED'", sql)
        self.assertIn("THEN 'ORANGE'", sql)
        self.assertIn("THEN 'YELLOW'", sql)
        self.assertNotIn("HAVING", sql)
        self.assertNotIn("\u8d37\u6b3e\u4f59\u989d", sql)
        self.assertTrue(all(ord(char) < 128 for char in sql))

    def test_sql_does_not_sum_date_like_balance_markers(self):
        self._write_schema_with_date_like_marker()
        sql = self.service._create_sql_query(self._fake_rule())
        self.assertNotIn("WITH", sql)
        self.assertNotIn("SUM(BANKBALANCEDATE)", sql)
        self.assertIn("BANKBALANCEDATE", sql)

    def test_generated_sql_validator_rejects_mysql_incompatible_cte(self):
        with self.assertRaisesRegex(ValueError, "WITH/CTE"):
            self.service._validate_generated_sql("WITH\nt1 AS (SELECT 1)\nSELECT * FROM t1;")

    def test_generated_sql_validator_rejects_date_sum(self):
        with self.assertRaisesRegex(ValueError, "SUM date/time"):
            self.service._validate_generated_sql("SELECT SUM(BANKBALANCEDATE) FROM bank_account;")

    def test_data_sources_display_marker_with_chinese_name(self):
        value = self.service._format_data_sources_with_markers(
            "\u94f6\u884c\u8d37\u6b3e:\u6240\u5c5e\u96c6\u56e2\u7f16\u7801,\u8d37\u6b3e\u4f59\u989d",
            self.service._read_schema(),
        )
        self.assertEqual(
            value,
            "bank_loan(\u94f6\u884c\u8d37\u6b3e): "
            "COMCODE(\u6240\u5c5e\u96c6\u56e2\u7f16\u7801), "
            "STANDARDCURRENCYBALANCE(\u8d37\u6b3e\u4f59\u989d)",
        )

    def test_metric_alignment_accepts_close_indicator_names(self):
        business_rule = (
            "\u503a\u52a1\u878d\u8d44\u7ed3\u6784\u5360\u6bd4 = "
            "\u503a\u52a1\u878d\u8d44\u4f59\u989d / \u603b\u8d44\u4ea7\n"
            "\u98ce\u9669\u7b49\u7ea7\uff1a\n"
            "- \u7ea2\u8272\uff1a\u503a\u52a1\u878d\u8d44\u5360\u6bd4 > 80%\n"
            "- \u6a59\u8272\uff1a\u503a\u52a1\u878d\u8d44\u5360\u6bd4 > 60%\n"
            "- \u9ec4\u8272\uff1a\u503a\u52a1\u878d\u8d44\u5360\u6bd4 > 40%"
        )
        RuleGenerationService._validate_business_rule_indicator_alignment(business_rule)
        count_rule = (
            "\u5408\u540c\u7b14\u6570 = COUNT(\u5408\u540c\u7f16\u53f7)\n"
            "\u98ce\u9669\u7b49\u7ea7\uff1a\n"
            "- \u7ea2\u8272\uff1a\u5408\u540c\u6570\u91cf > 100\n"
            "- \u6a59\u8272\uff1a\u5408\u540c\u6570\u91cf > 50\n"
            "- \u9ec4\u8272\uff1a\u5408\u540c\u6570\u91cf > 20"
        )
        RuleGenerationService._validate_business_rule_indicator_alignment(count_rule)

    def test_metric_alignment_rejects_wrong_metric_category(self):
        business_rule = (
            "\u5883\u5916\u6295\u8d44\u6536\u76ca\u7387 = "
            "\u6295\u8d44\u6536\u76ca / \u6295\u8d44\u6210\u672c\n"
            "\u98ce\u9669\u7b49\u7ea7\uff1a\n"
            "- \u7ea2\u8272\uff1a\u5883\u5916\u6295\u8d44\u91d1\u989d > 10000000\n"
            "- \u6a59\u8272\uff1a\u5883\u5916\u6295\u8d44\u91d1\u989d > 5000000\n"
            "- \u9ec4\u8272\uff1a\u5883\u5916\u6295\u8d44\u91d1\u989d > 1000000"
        )
        with self.assertRaisesRegex(ValueError, "\u6307\u6807\u9608\u503c\u672a\u5bf9\u5e94\u516c\u5f0f\u6307\u6807"):
            RuleGenerationService._validate_business_rule_indicator_alignment(business_rule)

    def test_rule_shape_allows_single_condition_without_and(self):
        rule = self._fake_rule()
        rule["system_rule_text"] = (
            "\u8d44\u4ea7\u8d1f\u503a\u7387\u8d85\u9650\u9884\u8b66\uff08\u9ec4\u8272\u9608\u503c\uff09:\n"
            "\u5982\u679c\uff1a\n"
            " (\u8d44\u4ea7\u8d1f\u503a\u7387 > 0.65)\n"
            "\u5c31\uff1a\n"
            "\u6821\u9a8c\uff1a\u901a\u8fc7\n"
            "\u8d4b\u503c\uff1a\u3010\u89c4\u5219\u7ed3\u679c-\u98ce\u9669\u7b49\u7ea7\u3011\u7b49\u4e8e \u9ec4\u8272"
        )
        rule["join_logic"] = "\u94f6\u884c\u8d37\u6b3e\u8868\u5de6\u5173\u8054\u94f6\u884c\u8d26\u6237\u8868\uff0c\u6309\u6240\u5c5e\u96c6\u56e2\u7f16\u7801\u6c47\u603b\u3002"
        RuleGenerationService._validate_rule_shape(rule)

    def test_normalize_rule_fills_recoverable_columns(self):
        raw = self._fake_rule()
        raw["business_rule_description"] = (
            "\u503a\u52a1\u878d\u8d44\u7ed3\u6784\u5360\u6bd4 = "
            "\u8d37\u6b3e\u4f59\u989d / \u8d26\u6237\u4f59\u989d\n"
            "\u98ce\u9669\u7b49\u7ea7\uff1a\n"
            "- \u7ea2\u8272\uff1a\u503a\u52a1\u878d\u8d44\u5360\u6bd4 > 80%\n"
            "- \u6a59\u8272\uff1a\u503a\u52a1\u878d\u8d44\u5360\u6bd4 > 60%\n"
            "- \u9ec4\u8272\uff1a\u503a\u52a1\u878d\u8d44\u5360\u6bd4 > 40%"
        )
        # Blank every column the normalizer is expected to reconstruct.
        for key in ("data_sources", "system_rule_text", "join_logic", "system_logic", "return_fields"):
            raw[key] = ""
        normalized = self.service._normalize_rule(
            raw,
            self._DOMAIN,
            {
                "description_pattern": "\u6838\u5fc3\u6cd5\u89c4\uff1a\u300a\u6d4b\u8bd5\u529e\u6cd5\u300b\uff1b\u6761\u6b3e\u8981\u70b9\uff1a\u52a8\u6001\u76d1\u6d4b\u503a\u52a1\u878d\u8d44\u7ed3\u6784\u3002",
                "basis_text": "\u52a8\u6001\u76d1\u6d4b\u503a\u52a1\u878d\u8d44\u7ed3\u6784\u3002",
            },
            self.service._read_schema(),
        )
        self.assertIn("bank_loan(\u94f6\u884c\u8d37\u6b3e)", normalized["data_sources"])
        self.assertIn("\u5982\u679c", normalized["system_rule_text"])
        self.assertIn("\u8d4b\u503c", normalized["system_rule_text"])
        self.assertIn("\u8868", normalized["join_logic"])
        self.assertIn("\u503c1", normalized["system_logic"])
        self.assertIn("\u503c2", normalized["system_logic"])
        self.assertIn("\u98ce\u9669\u7b49\u7ea7", normalized["return_fields"])

    def test_normalize_rule_converts_nested_llm_values_to_excel_text(self):
        raw = self._fake_rule()
        raw["policy_basis"] = {
            "policy_name": "\u300a\u6d4b\u8bd5\u529e\u6cd5\u300b",
            "clause_summary": "\u52a8\u6001\u76d1\u6d4b\u503a\u52a1\u98ce\u9669\u3002",
        }
        raw["data_sources"] = [
            {"table": "\u94f6\u884c\u8d37\u6b3e", "fields": ["COMCODE", "STANDARDCURRENCYBALANCE"]},
        ]
        raw["return_fields"] = ["COMCODE", "\u98ce\u9669\u7b49\u7ea7"]
        normalized = self.service._normalize_rule(
            raw,
            self._DOMAIN,
            {
                "description_pattern": "\u6838\u5fc3\u6cd5\u89c4\uff1a\u300a\u6d4b\u8bd5\u529e\u6cd5\u300b\uff1b\u6761\u6b3e\u8981\u70b9\uff1a\u52a8\u6001\u76d1\u6d4b\u503a\u52a1\u98ce\u9669\u3002",
                "basis_text": "\u52a8\u6001\u76d1\u6d4b\u503a\u52a1\u98ce\u9669\u3002",
            },
            self.service._read_schema(),
        )
        output_file = os.path.join(self.output_dir, "nested-values.xlsx")
        self.service._write_excel([normalized], output_file)
        workbook = load_workbook(output_file)
        try:
            sheet = workbook.active
            self.assertIn("\u6838\u5fc3\u6cd5\u89c4\uff1a\u300a\u6d4b\u8bd5\u529e\u6cd5\u300b", sheet["E2"].value)
            self.assertIn("bank_loan(\u94f6\u884c\u8d37\u6b3e)", sheet["G2"].value)
            self.assertIn("STANDARDCURRENCYBALANCE(\u8d37\u6b3e\u4f59\u989d)", sheet["G2"].value)
            self.assertEqual(sheet["K2"].value, "COMCODE, \u98ce\u9669\u7b49\u7ea7")
        finally:
            workbook.close()

    def test_indicator_guidance_prefers_amount_count_and_days_for_overdue_rules(self):
        pattern = {
            "description_pattern": "\u5efa\u7acb\u503a\u52a1\u5230\u671f\u9884\u8b66\u673a\u5236\uff0c\u5173\u6ce8\u903e\u671f\u91d1\u989d\u548c\u903e\u671f\u7b14\u6570\u3002",
            "basis_text": "\u4e25\u63a7\u903e\u671f\u503a\u52a1\u98ce\u9669\u3002",
        }
        guidance = self.service._build_indicator_guidance(
            self._DOMAIN,
            pattern,
            [{
                "module_name": "\u5e94\u4ed8\u7968\u636e",
                "fields": [
                    {"name": "\u662f\u5426\u903e\u671f", "marker": "OVERDUE"},
                    {"name": "\u5230\u671f\u65e5\u671f", "marker": "DUEDATE"},
                    {"name": "\u7968\u9762\u91d1\u989d", "marker": "AMOUNT"},
                ],
            }],
        )
        self.assertIn("\u91d1\u989d\u7c7b", guidance)
        self.assertIn("\u6570\u91cf\u7c7b", guidance)
        self.assertIn("\u671f\u9650\u7c7b", guidance)
        self.assertIn("不要默认生成 A/B 比率", guidance)

    def test_rule_prompt_includes_indicator_guidance(self):
        prompt = self.service._build_rule_prompt(
            self._DOMAIN,
            {
                "description_pattern": "\u5173\u6ce8\u903e\u671f\u91d1\u989d\u548c\u903e\u671f\u7b14\u6570\u3002",
                "basis_text": "\u503a\u52a1\u5230\u671f\u9884\u8b66\u3002",
            },
            self.service._read_schema(),
        )
        self.assertIn("\u6307\u6807\u53e3\u5f84\u9009\u62e9\u5efa\u8bae", prompt)
        self.assertIn("\u91d1\u989d\u7c7b", prompt)
        self.assertIn("\u7b14\u6570", prompt)

    def test_generate_rule_retries_once_after_transient_failure(self):
        with patch.object(
            RuleGenerationService,
            "_generate_rule_with_llm",
            side_effect=[RuntimeError("timeout"), self._fake_rule()],
        ) as generate:
            rule = self.service._generate_rule(self._DOMAIN, self._policy_pattern(), self.service._read_schema())
        self.assertEqual(generate.call_count, 2)
        self.assertIn("\u5982\u679c", rule["system_rule_text"])
        self.assertIn("\u8d4b\u503c", rule["system_rule_text"])

    def test_generate_rule_with_llm_disables_minimax_thinking(self):
        with patch("app.utils.llm.LLMClient") as client_cls:
            client_cls.return_value.chat.return_value = json.dumps(self._fake_rule(), ensure_ascii=False)
            self.service._generate_rule_with_llm(self._DOMAIN, self._policy_pattern(), self.service._read_schema())
        kwargs = client_cls.return_value.chat.call_args.kwargs
        self.assertEqual(kwargs["thinking"], {"type": "disabled"})
        self.assertEqual(kwargs["max_tokens"], 2600)

    def test_start_records_failure_without_fallback_when_llm_fails(self):
        with patch.object(RuleGenerationService, "_generate_rule_with_llm", side_effect=RuntimeError("LLM 529")):
            task_id = self.service.start(limit=1)["task_id"]
            current = self._wait_for_task(task_id)
        self.assertEqual(current["status"], "failed")
        self.assertEqual(current["generated_count"], 0)
        self.assertTrue(os.path.exists(current["markdown_file"]))
        self.assertIn("LLM 529", current["skipped_rules"][0]["reason"])

    def test_select_policy_points_generates_limit_rules_per_domain(self):
        domains = [
            {"domain": "A", "patterns": [{"id": "a1"}, {"id": "a2"}]},
            {"domain": "B", "patterns": [{"id": "b1"}]},
        ]
        # Make sampling deterministic: take the first `limit` / first pattern.
        with patch("app.utils.rule_generation.random.sample", lambda patterns, limit: patterns[:limit]):
            with patch("app.utils.rule_generation.random.choice", lambda patterns: patterns[0]):
                selected = self.service._select_policy_points(domains, 2)
        self.assertEqual(len(selected), 4)
        self.assertEqual(
            [(item["domain"], item["pattern"]["id"], item["variant_index"], item["variant_total"]) for item in selected],
            [("A", "a1", 1, 2), ("A", "a2", 2, 2), ("B", "b1", 1, 2), ("B", "b1", 2, 2)],
        )

    def test_select_policy_points_does_not_multiply_by_pattern_count(self):
        domains = [{"domain": "A", "patterns": [{"id": f"a{i}"} for i in range(71)]}]
        with patch("app.utils.rule_generation.random.sample", lambda patterns, limit: patterns[:limit]):
            selected = self.service._select_policy_points(domains, 2)
        self.assertEqual(len(selected), 2)

    def test_old_pattern_multiplication_would_have_generated_142_rules(self):
        # Regression guard: multiplying limit by the pattern count would give
        # 71 * 2 = 142. Pin the exact expected count rather than only ruling out
        # 142, and leave `random` unpatched so the real sampler is exercised.
        domains = [{"domain": "A", "patterns": [{"id": f"a{i}"} for i in range(71)]}]
        selected = self.service._select_policy_points(domains, 2)
        self.assertEqual(len(selected), 2, "expected `limit` rules per domain, not limit * pattern count (142)")

    def test_select_policy_points_keeps_domain_grouped_order(self):
        domains = [
            {"domain": "A", "patterns": [{"id": "a1"}, {"id": "a2"}]},
            {"domain": "B", "patterns": [{"id": "b1"}, {"id": "b2"}]},
        ]
        with patch("app.utils.rule_generation.random.sample", lambda patterns, limit: patterns[:limit]):
            selected = self.service._select_policy_points(domains, 2)
        self.assertEqual([item["domain"] for item in selected], ["A", "A", "B", "B"])

    def test_rule_prompt_includes_variant_guidance_for_each_policy_point(self):
        pattern = {
            "description_pattern": "\u5173\u6ce8\u903e\u671f\u91d1\u989d\u548c\u903e\u671f\u7b14\u6570\u3002",
            "basis_text": "\u503a\u52a1\u5230\u671f\u9884\u8b66\u3002",
            "_variant_index": 2,
            "_variant_total": 3,
        }
        prompt = self.service._build_rule_prompt(self._DOMAIN, pattern, self.service._read_schema())
        self.assertIn("\u7b2c 2/3 \u6761\u89c4\u5219", prompt)
        self.assertIn("\u4e0d\u5f97\u53ea\u6539\u540d\u79f0\u6216\u9608\u503c", prompt)
if __name__ == "__main__":
    # Run this module directly with the stock unittest runner. verbosity=1 is
    # unittest's default; passing -v on the command line still raises it.
    unittest.main(verbosity=1)