"""Unit tests for ``RuleGenerationService``.

Each test works against throw-away ``domains.json`` / ``schema.json``
fixtures written into a per-test output directory, and LLM-backed methods
are replaced with ``unittest.mock`` patches.
"""

import json
import os
import shutil
import time
import unittest
import uuid
from unittest.mock import patch

from openpyxl import load_workbook

from app.utils.rule_generation import RuleGenerationService


class RuleGenerationServiceTest(unittest.TestCase):
    """Exercise rule generation, Excel/Markdown output, SQL and validators."""

    def setUp(self):
        """Create an isolated output directory, fixtures and the service."""
        self.output_dir = os.path.join(
            os.getcwd(), "output", f"test-rule-generation-{uuid.uuid4().hex[:8]}"
        )
        os.makedirs(self.output_dir, exist_ok=True)
        # tearDown is skipped when setUp itself raises, so also register a
        # cleanup to avoid leaving the directory behind in that case.
        self.addCleanup(shutil.rmtree, self.output_dir, ignore_errors=True)
        self.domains_file = os.path.join(self.output_dir, "domains.json")
        self.schema_file = os.path.join(self.output_dir, "schema.json")
        self._write_domains()
        self._write_schema()
        self.service = RuleGenerationService(
            domains_file=self.domains_file,
            schema_file=self.schema_file,
            output_dir=self.output_dir,
        )

    def tearDown(self):
        """Remove the per-test output directory if it still exists."""
        if os.path.isdir(self.output_dir):
            shutil.rmtree(self.output_dir)

    def _write_domains(self):
        """Write the domains fixture.

        It contains three domains: one with an analysed guidance file
        (``status == "done"``), one with no guidance files, and one whose
        guidance file has no ``guidance_analysis`` entry.
        """
        with open(self.domains_file, "w", encoding="utf-8") as f:
            json.dump({
                "domains": [
                    {
                        "token": "token-1",
                        "domain": "\u8fc7\u5ea6\u8d1f\u503a",
                        "guidance_files": [{
                            "file_id": "file-1",
                            "filename": "policy.txt",
                            "guidance_analysis": {
                                "status": "done",
                                "description_patterns": [{
                                    "description_pattern": "\u6838\u5fc3\u6cd5\u89c4\uff1a\u300a\u6d4b\u8bd5\u529e\u6cd5\u300b\uff1b\u6761\u6b3e\u8981\u70b9\uff1a\u52a8\u6001\u76d1\u6d4b\u8d44\u4ea7\u8d1f\u503a\u7387\u3002",
                                    "basis_text": "\u52a8\u6001\u76d1\u6d4b\u8d44\u4ea7\u8d1f\u503a\u7387\u3002",
                                    "source_sentence": "\u52a8\u6001\u76d1\u6d4b\u8d44\u4ea7\u8d1f\u503a\u7387\u3002",
                                    "supervision_dimension": "\u8d44\u91d1\u7a7f\u900f",
                                    "keywords": ["\u8d44\u4ea7\u8d1f\u503a\u7387", "\u503a\u52a1"],
                                }],
                            },
                        }],
                    },
                    {
                        "token": "token-no-guidance",
                        "domain": "\u865a\u5047\u8d38\u6613",
                        "guidance_files": [],
                    },
                    {
                        "token": "token-unparsed",
                        "domain": "\u591a\u5143\u6295\u8d44",
                        "guidance_files": [{
                            "file_id": "file-2",
                            "filename": "pending.txt",
                        }],
                    },
                ],
            }, f, ensure_ascii=False)

    def _write_schema(self):
        """Write the default schema fixture with two modules.

        The modules are ``bank_loan`` and ``bank_account``; both expose the
        ``COMCODE`` and ``STANDARDCURRENCYBALANCE`` markers.
        """
        with open(self.schema_file, "w", encoding="utf-8") as f:
            json.dump({
                "modules": [
                    {
                        "module_name": "\u94f6\u884c\u8d37\u6b3e",
                        "table_name": "bank_loan",
                        "description": "\u8bb0\u5f55\u8d37\u6b3e\u4f59\u989d\u548c\u503a\u52a1\u878d\u8d44\u4fe1\u606f\u3002",
                        "fields": [
                            {"name": "\u6240\u5c5e\u96c6\u56e2\u7f16\u7801", "marker": "COMCODE"},
                            {"name": "\u6240\u5c5e\u96c6\u56e2\u540d\u79f0", "marker": "COMNAME"},
                            {"name": "\u8d37\u6b3e\u5355\u4f4d\u540d\u79f0", "marker": "COMPANY_CLTNAME"},
                            {"name": "\u8d37\u6b3e\u4f59\u989d", "marker": "STANDARDCURRENCYBALANCE"},
                        ],
                    },
                    {
                        "module_name": "\u94f6\u884c\u8d26\u6237",
                        "table_name": "bank_account",
                        "description": "\u8bb0\u5f55\u8d26\u6237\u4f59\u989d\u548c\u8d26\u6237\u72b6\u6001\u3002",
                        "fields": [
                            {"name": "\u6240\u5c5e\u96c6\u56e2\u7f16\u7801", "marker": "COMCODE"},
                            {"name": "\u8d26\u6237\u4f59\u989d", "marker": "STANDARDCURRENCYBALANCE"},
                        ],
                    },
                ],
            }, f, ensure_ascii=False)

    def _fake_rule(self):
        """Return a fresh rule dict used as the canned generation result."""
        return {
            "rule_name": "\u8d44\u4ea7\u8d1f\u503a\u7387\u8d85\u9650\u9884\u8b66",
            "risk_description": "\u6d4b\u8bd5\u98ce\u9669\u63cf\u8ff0",
            "policy_basis": "\u6838\u5fc3\u6cd5\u89c4\uff1a\u300a\u6d4b\u8bd5\u529e\u6cd5\u300b\uff1b\u6761\u6b3e\u8981\u70b9\uff1a\u52a8\u6001\u76d1\u6d4b\u8d44\u4ea7\u8d1f\u503a\u7387\u3002",
            "business_rule_description": "\u8d44\u4ea7\u8d1f\u503a\u7387 = \u8d37\u6b3e\u4f59\u989d / \u8d26\u6237\u4f59\u989d\n\u98ce\u9669\u7b49\u7ea7\uff1a\n- \u7ea2\u8272\uff1a\u8d44\u4ea7\u8d1f\u503a\u7387 > 85%\n- \u6a59\u8272\uff1a\u8d44\u4ea7\u8d1f\u503a\u7387 > 75%\n- \u9ec4\u8272\uff1a\u8d44\u4ea7\u8d1f\u503a\u7387 > 65%",
            "data_sources": "\u94f6\u884c\u8d37\u6b3e\uff1a\u6240\u5c5e\u96c6\u56e2\u7f16\u7801\u3001\u8d37\u6b3e\u4f59\u989d\n\u94f6\u884c\u8d26\u6237\uff1a\u8d26\u6237\u4f59\u989d",
            "system_rule_text": "\u5982\u679c\uff1a\u8d44\u4ea7\u8d1f\u503a\u7387 > 0.85\n\u5219\uff1a\u6821\u9a8c\u901a\u8fc7",
            "join_logic": "\u94f6\u884c\u8d37\u6b3e\u5de6\u5173\u8054\u94f6\u884c\u8d26\u6237\u3002",
            "system_logic": "\u503c1\uff1a\u8d37\u6b3e\u4f59\u989d\uff1b\u503c2\uff1a\u8d26\u6237\u4f59\u989d\u3002",
            "return_fields": "\u6240\u5c5e\u96c6\u56e2\u7f16\u7801,\u8d37\u6b3e\u4f59\u989d,\u98ce\u9669\u7b49\u7ea7",
        }

    def _make_sql_service(self):
        """Return a service on the same fixtures with ``create_sql=True``."""
        return RuleGenerationService(
            domains_file=self.domains_file,
            schema_file=self.schema_file,
            output_dir=self.output_dir,
            create_sql=True,
        )

    def _wait_for_task(self, task_id, attempts=50, interval=0.05):
        """Poll ``get_status`` until the task reports ``done`` or ``failed``.

        Args:
            task_id: Identifier returned by ``self.service.start``.
            attempts: Maximum number of polls before giving up.
            interval: Seconds to sleep between polls.

        Returns:
            The status read after polling stops. If the task never reached a
            terminal status this is simply the latest state, so the caller's
            status assertion reports the failure.
        """
        for _ in range(attempts):
            current = self.service.get_status(task_id)
            if current and current["status"] in {"done", "failed"}:
                break
            time.sleep(interval)
        return self.service.get_status(task_id)

    def test_start_generates_status_and_excel(self):
        """``start(limit=1)`` finishes and writes Excel and Markdown files."""
        with patch.object(RuleGenerationService, "_generate_rule", return_value=self._fake_rule()):
            state = self.service.start(limit=1)
            task_id = state["task_id"]
            # Keep the patch active while the task is still being processed.
            current = self._wait_for_task(task_id)

        self.assertEqual(current["status"], "done")
        self.assertEqual(current["generated_count"], 1)
        self.assertTrue(os.path.exists(current["output_file"]))
        self.assertTrue(os.path.exists(current["markdown_file"]))
        self.assertEqual(os.path.dirname(current["output_file"]), os.path.dirname(current["markdown_file"]))
        self.assertEqual(os.path.basename(os.path.dirname(current["output_file"])), f"rules-{task_id}")
        self.assertEqual(current["output_dir"], os.path.dirname(current["output_file"]))
        self.assertEqual(current["files"]["excel"], current["output_file"])
        self.assertEqual(current["files"]["markdown"], current["markdown_file"])

        workbook = load_workbook(current["output_file"])
        try:
            sheet = workbook.active
            self.assertEqual(sheet.max_column, 11)
            self.assertEqual(sheet["A2"].value, 1)
            self.assertIn("bank_loan(\u94f6\u884c\u8d37\u6b3e)", sheet["G2"].value)
            self.assertIn("COMCODE(\u6240\u5c5e\u96c6\u56e2\u7f16\u7801)", sheet["G2"].value)
            self.assertIn("STANDARDCURRENCYBALANCE(\u8d37\u6b3e\u4f59\u989d)", sheet["G2"].value)
            # The default service was built without create_sql, so column L
            # (the 12th) must stay empty.
            self.assertIsNone(sheet["L1"].value)
        finally:
            workbook.close()

        with open(current["markdown_file"], "r", encoding="utf-8") as file:
            markdown = file.read()
        self.assertIn("# \u91cd\u70b9\u98ce\u9669\u9886\u57df\u89c4\u5219\u6a21\u578b\u4e0e\u811a\u672c", markdown)
        self.assertIn("\u4ee5\u5b9e\u9645\u751f\u6210\u7684\u98ce\u9669\u9886\u57df\u4e3a\u51c6", markdown)
        self.assertIn("## 1. \u8fc7\u5ea6\u8d1f\u503a", markdown)
        self.assertIn("### \u89c4\u5219\u6a21\u578b1.1\uff1a\u8d44\u4ea7\u8d1f\u503a\u7387\u8d85\u9650\u9884\u8b66", markdown)
        self.assertIn("bank_loan(\u94f6\u884c\u8d37\u6b3e)", markdown)

    def test_start_generates_limit_rules_per_policy_point(self):
        """``start(limit=2)`` generates two rules and reports skipped domains."""
        with patch.object(RuleGenerationService, "_generate_rule", return_value=self._fake_rule()) as generate:
            state = self.service.start(limit=2)
            task_id = state["task_id"]
            current = self._wait_for_task(task_id)

        self.assertEqual(current["status"], "done")
        self.assertEqual(current["generated_count"], 2)
        self.assertEqual(current["progress"]["total"], 2)
        self.assertEqual(generate.call_count, 2)
        self.assertEqual(
            {item["domain"] for item in current["skipped_domains"]},
            {"\u865a\u5047\u8d38\u6613", "\u591a\u5143\u6295\u8d44"},
        )

        workbook = load_workbook(current["output_file"])
        try:
            # Header row plus two rule rows.
            self.assertEqual(workbook.active.max_row, 3)
        finally:
            workbook.close()

    def test_collect_policy_basis_only_uses_analyzed_guidance_files(self):
        """Only the domain with an analysed guidance file is collected."""
        domains = self.service._collect_policy_basis()

        self.assertEqual([item["token"] for item in domains], ["token-1"])
        self.assertEqual(len(domains[0]["patterns"]), 1)

    def test_write_excel_appends_sql_column_when_enabled(self):
        """With ``create_sql=True`` a 12th, ASCII-only SQL column is written."""
        service = self._make_sql_service()
        output_file = os.path.join(self.output_dir, "with-sql.xlsx")

        service._write_excel([self._fake_rule()], output_file)

        workbook = load_workbook(output_file)
        try:
            sheet = workbook.active
            sql = sheet["L2"].value
            self.assertEqual(sheet.max_column, 12)
            self.assertIn("SQL", sheet["L1"].value)
            self.assertIn("SELECT", sql)
            self.assertIn("FROM bank_loan", sql)
            self.assertIn("STANDARDCURRENCYBALANCE", sql)
            self.assertIn("CASE", sql)
            self.assertNotIn("HAVING", sql)
            self.assertTrue(all(ord(char) < 128 for char in sql))
        finally:
            workbook.close()

    def test_write_excel_uses_legacy_output_format(self):
        """The workbook keeps the expected title, headers and cell styling."""
        service = self._make_sql_service()
        output_file = os.path.join(self.output_dir, "formatted.xlsx")

        service._write_excel([self._fake_rule()], output_file)

        workbook = load_workbook(output_file)
        try:
            sheet = workbook.active
            self.assertEqual(sheet.title, "11类重点风险领域")
            self.assertEqual(sheet.freeze_panes, "A2")
            self.assertEqual(sheet.auto_filter.ref, "A1:L2")
            self.assertEqual(
                [sheet.cell(1, col).value for col in range(1, 13)],
                [
                    "序号",
                    "风险领域",
                    "规则名称",
                    "风险描述",
                    "制度依据",
                    "业务规则描述",
                    "数据来源",
                    "系统规则文本",
                    "关联表逻辑",
                    "系统固化逻辑",
                    "返回结果",
                    "SQL查询语句",
                ],
            )
            self.assertAlmostEqual(sheet.column_dimensions["A"].width, 5.13, places=2)
            self.assertAlmostEqual(sheet.column_dimensions["H"].width, 107.36, places=2)
            self.assertTrue(sheet["A1"].font.bold)
            self.assertEqual(sheet["F1"].fill.fgColor.rgb, "00FFFF00")
            self.assertEqual(sheet["G1"].fill.fgColor.rgb, "00FFFF00")
            self.assertEqual(sheet["H1"].fill.fgColor.rgb, "00FFFF00")
            self.assertEqual(sheet["A1"].alignment.horizontal, "center")
            self.assertTrue(sheet["F2"].alignment.wrap_text)
            self.assertEqual(sheet["F2"].alignment.horizontal, "left")
            self.assertEqual(sheet["A2"].alignment.horizontal, "center")
            self.assertEqual(sheet["A1"].border.left.style, "thin")
            self.assertEqual(sheet["L2"].border.bottom.style, "thin")
            self.assertGreater(sheet.row_dimensions[2].height, 100)
        finally:
            workbook.close()

    def test_write_markdown_includes_sql_when_enabled(self):
        """With ``create_sql=True`` the Markdown contains a fenced SQL block."""
        service = self._make_sql_service()
        markdown_file = os.path.join(self.output_dir, "with-sql.md")

        service._write_markdown([self._fake_rule()], markdown_file)

        with open(markdown_file, "r", encoding="utf-8") as file:
            markdown = file.read()
        self.assertIn("**\u5b9e\u73b0\u811a\u672c\uff08SQL\uff09**", markdown)
        self.assertIn("```sql", markdown)
        self.assertIn("SELECT", markdown)
        self.assertIn("FROM bank_loan", markdown)
        # A bare `assertIn("```", ...)` is already implied by the "```sql"
        # check above; require a second fence so the block is really closed.
        self.assertGreaterEqual(markdown.count("```"), 2)

    def test_markdown_failure_does_not_overwrite_successful_excel(self):
        """A failing Markdown writer is reported but the Excel result stays."""
        with patch.object(RuleGenerationService, "_generate_rule", return_value=self._fake_rule()):
            with patch.object(RuleGenerationService, "_write_markdown", side_effect=RuntimeError("markdown failed")):
                state = self.service.start(limit=1)
                task_id = state["task_id"]
                current = self._wait_for_task(task_id)

        self.assertEqual(current["status"], "done")
        self.assertEqual(current["generated_count"], 1)
        self.assertEqual(current["markdown_error"], "markdown failed")

        workbook = load_workbook(current["output_file"])
        try:
            self.assertEqual(workbook.active.max_row, 2)
            self.assertEqual(workbook.active["A2"].value, 1)
        finally:
            workbook.close()

    def test_create_sql_query_uses_schema_markers_only(self):
        """Generated SQL uses schema markers, a LEFT JOIN and no CTE/HAVING."""
        sql = self.service._create_sql_query(self._fake_rule())

        self.assertEqual(sql.splitlines()[0], "SELECT")
        self.assertIn("t1.COMCODE", sql)
        self.assertIn("t1.STANDARDCURRENCYBALANCE", sql)
        self.assertIn("FROM bank_loan", sql)
        self.assertIn("FROM bank_account", sql)
        self.assertIn("LEFT JOIN (", sql)
        self.assertIn(") t2 ON t1.COMCODE = t2.COMCODE", sql)
        self.assertNotIn("WITH", sql)
        self.assertIn("(t1.STANDARDCURRENCYBALANCE / NULLIF(t2.STANDARDCURRENCYBALANCE, 0))", sql)
        self.assertIn("t1.STANDARDCURRENCYBALANCE AS VALUE_1", sql)
        self.assertIn("t2.STANDARDCURRENCYBALANCE AS VALUE_2", sql)
        self.assertIn("THEN 'RED'", sql)
        self.assertIn("THEN 'ORANGE'", sql)
        self.assertIn("THEN 'YELLOW'", sql)
        self.assertNotIn("HAVING", sql)
        self.assertNotIn("\u8d37\u6b3e\u4f59\u989d", sql)
        self.assertTrue(all(ord(char) < 128 for char in sql))

    def test_sql_does_not_sum_date_like_balance_markers(self):
        """A date-typed marker may appear in the SQL but is never summed."""
        self._write_schema_with_date_like_marker()

        sql = self.service._create_sql_query(self._fake_rule())

        self.assertNotIn("WITH", sql)
        self.assertNotIn("SUM(BANKBALANCEDATE)", sql)
        self.assertIn("BANKBALANCEDATE", sql)

    def test_generated_sql_validator_rejects_mysql_incompatible_cte(self):
        """The SQL validator raises ``ValueError`` for a ``WITH`` query."""
        with self.assertRaisesRegex(ValueError, "WITH/CTE"):
            self.service._validate_generated_sql("WITH\nt1 AS (SELECT 1)\nSELECT * FROM t1;")

    def test_generated_sql_validator_rejects_date_sum(self):
        """The SQL validator raises ``ValueError`` for ``SUM`` over a date."""
        with self.assertRaisesRegex(ValueError, "SUM date/time"):
            self.service._validate_generated_sql("SELECT SUM(BANKBALANCEDATE) FROM bank_account;")

    def _write_schema_with_date_like_marker(self):
        """Overwrite the schema with one module that has a date-typed field."""
        with open(self.schema_file, "w", encoding="utf-8") as f:
            json.dump({
                "modules": [{
                    "module_name": "\u94f6\u884c\u8d37\u6b3e",
                    "table_name": "bank_loan",
                    "fields": [
                        {"name": "\u6240\u5c5e\u96c6\u56e2\u7f16\u7801", "marker": "COMCODE"},
                        {"name": "\u8d37\u6b3e\u4f59\u989d", "marker": "STANDARDCURRENCYBALANCE"},
                        {"name": "\u8d26\u6237\u4f59\u989d\u65f6\u95f4", "marker": "BANKBALANCEDATE", "type": "\u65e5\u671f"},
                    ],
                }],
            }, f, ensure_ascii=False)

    def test_data_sources_display_marker_with_chinese_name(self):
        """Data sources are rendered as ``marker(name)`` pairs per table."""
        value = self.service._format_data_sources_with_markers(
            "\u94f6\u884c\u8d37\u6b3e:\u6240\u5c5e\u96c6\u56e2\u7f16\u7801,\u8d37\u6b3e\u4f59\u989d",
            self.service._read_schema(),
        )

        self.assertEqual(
            value,
            "bank_loan(\u94f6\u884c\u8d37\u6b3e): "
            "COMCODE(\u6240\u5c5e\u96c6\u56e2\u7f16\u7801), "
            "STANDARDCURRENCYBALANCE(\u8d37\u6b3e\u4f59\u989d)",
        )

    def test_metric_alignment_accepts_close_indicator_names(self):
        """The alignment validator does not raise for near-matching names."""
        business_rule = (
            "\u503a\u52a1\u878d\u8d44\u7ed3\u6784\u5360\u6bd4 = "
            "\u503a\u52a1\u878d\u8d44\u4f59\u989d / \u603b\u8d44\u4ea7\n"
            "\u98ce\u9669\u7b49\u7ea7\uff1a\n"
            "- \u7ea2\u8272\uff1a\u503a\u52a1\u878d\u8d44\u5360\u6bd4 > 80%\n"
            "- \u6a59\u8272\uff1a\u503a\u52a1\u878d\u8d44\u5360\u6bd4 > 60%\n"
            "- \u9ec4\u8272\uff1a\u503a\u52a1\u878d\u8d44\u5360\u6bd4 > 40%"
        )
        RuleGenerationService._validate_business_rule_indicator_alignment(business_rule)

        count_rule = (
            "\u5408\u540c\u7b14\u6570 = COUNT(\u5408\u540c\u7f16\u53f7)\n"
            "\u98ce\u9669\u7b49\u7ea7\uff1a\n"
            "- \u7ea2\u8272\uff1a\u5408\u540c\u6570\u91cf > 100\n"
            "- \u6a59\u8272\uff1a\u5408\u540c\u6570\u91cf > 50\n"
            "- \u9ec4\u8272\uff1a\u5408\u540c\u6570\u91cf > 20"
        )
        RuleGenerationService._validate_business_rule_indicator_alignment(count_rule)

    def test_metric_alignment_rejects_wrong_metric_category(self):
        """The alignment validator raises when thresholds use another metric."""
        business_rule = (
            "\u5883\u5916\u6295\u8d44\u6536\u76ca\u7387 = "
            "\u6295\u8d44\u6536\u76ca / \u6295\u8d44\u6210\u672c\n"
            "\u98ce\u9669\u7b49\u7ea7\uff1a\n"
            "- \u7ea2\u8272\uff1a\u5883\u5916\u6295\u8d44\u91d1\u989d > 10000000\n"
            "- \u6a59\u8272\uff1a\u5883\u5916\u6295\u8d44\u91d1\u989d > 5000000\n"
            "- \u9ec4\u8272\uff1a\u5883\u5916\u6295\u8d44\u91d1\u989d > 1000000"
        )

        with self.assertRaisesRegex(ValueError, "\u6307\u6807\u9608\u503c\u672a\u5bf9\u5e94\u516c\u5f0f\u6307\u6807"):
            RuleGenerationService._validate_business_rule_indicator_alignment(business_rule)

    def test_rule_shape_allows_single_condition_without_and(self):
        """The shape validator does not raise for a single-condition rule."""
        rule = self._fake_rule()
        rule["system_rule_text"] = (
            "\u8d44\u4ea7\u8d1f\u503a\u7387\u8d85\u9650\u9884\u8b66\uff08\u9ec4\u8272\u9608\u503c\uff09:\n"
            "\u5982\u679c\uff1a\n"
            " (\u8d44\u4ea7\u8d1f\u503a\u7387 > 0.65)\n"
            "\u5c31\uff1a\n"
            "\u6821\u9a8c\uff1a\u901a\u8fc7\n"
            "\u8d4b\u503c\uff1a\u3010\u89c4\u5219\u7ed3\u679c-\u98ce\u9669\u7b49\u7ea7\u3011\u7b49\u4e8e \u9ec4\u8272"
        )
        rule["join_logic"] = "\u94f6\u884c\u8d37\u6b3e\u8868\u5de6\u5173\u8054\u94f6\u884c\u8d26\u6237\u8868\uff0c\u6309\u6240\u5c5e\u96c6\u56e2\u7f16\u7801\u6c47\u603b\u3002"

        RuleGenerationService._validate_rule_shape(rule)

    def test_normalize_rule_fills_recoverable_columns(self):
        """``_normalize_rule`` fills the five columns that were left empty."""
        raw = self._fake_rule()
        raw["business_rule_description"] = (
            "\u503a\u52a1\u878d\u8d44\u7ed3\u6784\u5360\u6bd4 = "
            "\u8d37\u6b3e\u4f59\u989d / \u8d26\u6237\u4f59\u989d\n"
            "\u98ce\u9669\u7b49\u7ea7\uff1a\n"
            "- \u7ea2\u8272\uff1a\u503a\u52a1\u878d\u8d44\u5360\u6bd4 > 80%\n"
            "- \u6a59\u8272\uff1a\u503a\u52a1\u878d\u8d44\u5360\u6bd4 > 60%\n"
            "- \u9ec4\u8272\uff1a\u503a\u52a1\u878d\u8d44\u5360\u6bd4 > 40%"
        )
        for key in ("data_sources", "system_rule_text", "join_logic", "system_logic", "return_fields"):
            raw[key] = ""

        normalized = self.service._normalize_rule(
            raw,
            "\u8fc7\u5ea6\u8d1f\u503a",
            {
                "description_pattern": "\u6838\u5fc3\u6cd5\u89c4\uff1a\u300a\u6d4b\u8bd5\u529e\u6cd5\u300b\uff1b\u6761\u6b3e\u8981\u70b9\uff1a\u52a8\u6001\u76d1\u6d4b\u503a\u52a1\u878d\u8d44\u7ed3\u6784\u3002",
                "basis_text": "\u52a8\u6001\u76d1\u6d4b\u503a\u52a1\u878d\u8d44\u7ed3\u6784\u3002",
            },
            self.service._read_schema(),
        )

        self.assertIn("bank_loan(\u94f6\u884c\u8d37\u6b3e)", normalized["data_sources"])
        self.assertIn("\u5982\u679c", normalized["system_rule_text"])
        self.assertIn("\u8d4b\u503c", normalized["system_rule_text"])
        self.assertIn("\u8868", normalized["join_logic"])
        self.assertIn("\u503c1", normalized["system_logic"])
        self.assertIn("\u503c2", normalized["system_logic"])
        self.assertIn("\u98ce\u9669\u7b49\u7ea7", normalized["return_fields"])

    def test_normalize_rule_converts_nested_llm_values_to_excel_text(self):
        """Dict/list rule values are written to Excel as plain text cells."""
        raw = self._fake_rule()
        raw["policy_basis"] = {
            "policy_name": "\u300a\u6d4b\u8bd5\u529e\u6cd5\u300b",
            "clause_summary": "\u52a8\u6001\u76d1\u6d4b\u503a\u52a1\u98ce\u9669\u3002",
        }
        raw["data_sources"] = [
            {"table": "\u94f6\u884c\u8d37\u6b3e", "fields": ["COMCODE", "STANDARDCURRENCYBALANCE"]},
        ]
        raw["return_fields"] = ["COMCODE", "\u98ce\u9669\u7b49\u7ea7"]

        normalized = self.service._normalize_rule(
            raw,
            "\u8fc7\u5ea6\u8d1f\u503a",
            {
                "description_pattern": "\u6838\u5fc3\u6cd5\u89c4\uff1a\u300a\u6d4b\u8bd5\u529e\u6cd5\u300b\uff1b\u6761\u6b3e\u8981\u70b9\uff1a\u52a8\u6001\u76d1\u6d4b\u503a\u52a1\u98ce\u9669\u3002",
                "basis_text": "\u52a8\u6001\u76d1\u6d4b\u503a\u52a1\u98ce\u9669\u3002",
            },
            self.service._read_schema(),
        )
        output_file = os.path.join(self.output_dir, "nested-values.xlsx")
        self.service._write_excel([normalized], output_file)

        workbook = load_workbook(output_file)
        try:
            sheet = workbook.active
            self.assertIn("\u6838\u5fc3\u6cd5\u89c4\uff1a\u300a\u6d4b\u8bd5\u529e\u6cd5\u300b", sheet["E2"].value)
            self.assertIn("bank_loan(\u94f6\u884c\u8d37\u6b3e)", sheet["G2"].value)
            self.assertIn("STANDARDCURRENCYBALANCE(\u8d37\u6b3e\u4f59\u989d)", sheet["G2"].value)
            self.assertEqual(sheet["K2"].value, "COMCODE, \u98ce\u9669\u7b49\u7ea7")
        finally:
            workbook.close()

    def test_indicator_guidance_prefers_amount_count_and_days_for_overdue_rules(self):
        """Indicator guidance contains the expected category hints."""
        pattern = {
            "description_pattern": "\u5efa\u7acb\u503a\u52a1\u5230\u671f\u9884\u8b66\u673a\u5236\uff0c\u5173\u6ce8\u903e\u671f\u91d1\u989d\u548c\u903e\u671f\u7b14\u6570\u3002",
            "basis_text": "\u4e25\u63a7\u903e\u671f\u503a\u52a1\u98ce\u9669\u3002",
        }

        guidance = self.service._build_indicator_guidance(
            "\u8fc7\u5ea6\u8d1f\u503a",
            pattern,
            [{
                "module_name": "\u5e94\u4ed8\u7968\u636e",
                "fields": [
                    {"name": "\u662f\u5426\u903e\u671f", "marker": "OVERDUE"},
                    {"name": "\u5230\u671f\u65e5\u671f", "marker": "DUEDATE"},
                    {"name": "\u7968\u9762\u91d1\u989d", "marker": "AMOUNT"},
                ],
            }],
        )

        self.assertIn("\u91d1\u989d\u7c7b", guidance)
        self.assertIn("\u6570\u91cf\u7c7b", guidance)
        self.assertIn("\u671f\u9650\u7c7b", guidance)
        self.assertIn("不要默认生成 A/B 比率", guidance)

    def test_rule_prompt_includes_indicator_guidance(self):
        """The rule prompt embeds the indicator guidance section."""
        prompt = self.service._build_rule_prompt(
            "\u8fc7\u5ea6\u8d1f\u503a",
            {
                "description_pattern": "\u5173\u6ce8\u903e\u671f\u91d1\u989d\u548c\u903e\u671f\u7b14\u6570\u3002",
                "basis_text": "\u503a\u52a1\u5230\u671f\u9884\u8b66\u3002",
            },
            self.service._read_schema(),
        )

        self.assertIn("\u6307\u6807\u53e3\u5f84\u9009\u62e9\u5efa\u8bae", prompt)
        self.assertIn("\u91d1\u989d\u7c7b", prompt)
        self.assertIn("\u7b14\u6570", prompt)

    def test_generate_rule_retries_once_after_transient_failure(self):
        """``_generate_rule`` calls the LLM again after one raised error."""
        pattern = {
            "description_pattern": "\u6838\u5fc3\u6cd5\u89c4\uff1a\u300a\u6d4b\u8bd5\u529e\u6cd5\u300b\uff1b\u6761\u6b3e\u8981\u70b9\uff1a\u52a8\u6001\u76d1\u6d4b\u8d44\u4ea7\u8d1f\u503a\u7387\u3002",
            "basis_text": "\u52a8\u6001\u76d1\u6d4b\u8d44\u4ea7\u8d1f\u503a\u7387\u3002",
        }

        with patch.object(
            RuleGenerationService,
            "_generate_rule_with_llm",
            side_effect=[RuntimeError("timeout"), self._fake_rule()],
        ) as generate:
            rule = self.service._generate_rule("\u8fc7\u5ea6\u8d1f\u503a", pattern, self.service._read_schema())

        self.assertEqual(generate.call_count, 2)
        self.assertIn("\u5982\u679c", rule["system_rule_text"])
        self.assertIn("\u8d4b\u503c", rule["system_rule_text"])

    def test_generate_rule_with_llm_disables_minimax_thinking(self):
        """The LLM chat call receives ``thinking`` disabled and 2600 tokens."""
        pattern = {
            "description_pattern": "\u6838\u5fc3\u6cd5\u89c4\uff1a\u300a\u6d4b\u8bd5\u529e\u6cd5\u300b\uff1b\u6761\u6b3e\u8981\u70b9\uff1a\u52a8\u6001\u76d1\u6d4b\u8d44\u4ea7\u8d1f\u503a\u7387\u3002",
            "basis_text": "\u52a8\u6001\u76d1\u6d4b\u8d44\u4ea7\u8d1f\u503a\u7387\u3002",
        }

        with patch("app.utils.llm.LLMClient") as client_cls:
            client_cls.return_value.chat.return_value = json.dumps(self._fake_rule(), ensure_ascii=False)
            self.service._generate_rule_with_llm("\u8fc7\u5ea6\u8d1f\u503a", pattern, self.service._read_schema())
            kwargs = client_cls.return_value.chat.call_args.kwargs

        self.assertEqual(kwargs["thinking"], {"type": "disabled"})
        self.assertEqual(kwargs["max_tokens"], 2600)

    def test_start_records_failure_without_fallback_when_llm_fails(self):
        """When every LLM call raises, the task fails and records the reason."""
        with patch.object(RuleGenerationService, "_generate_rule_with_llm", side_effect=RuntimeError("LLM 529")):
            state = self.service.start(limit=1)
            task_id = state["task_id"]
            current = self._wait_for_task(task_id)

        self.assertEqual(current["status"], "failed")
        self.assertEqual(current["generated_count"], 0)
        self.assertTrue(os.path.exists(current["markdown_file"]))
        self.assertIn("LLM 529", current["skipped_rules"][0]["reason"])

    def test_select_policy_points_generates_limit_rules_per_domain(self):
        """Each domain yields ``limit`` points, reusing patterns if needed."""
        domains = [
            {"domain": "A", "patterns": [{"id": "a1"}, {"id": "a2"}]},
            {"domain": "B", "patterns": [{"id": "b1"}]},
        ]

        # Replace the random helpers with deterministic stand-ins.
        with patch("app.utils.rule_generation.random.sample", lambda patterns, limit: patterns[:limit]):
            with patch("app.utils.rule_generation.random.choice", lambda patterns: patterns[0]):
                selected = self.service._select_policy_points(domains, 2)

        self.assertEqual(len(selected), 4)
        self.assertEqual(
            [
                (item["domain"], item["pattern"]["id"], item["variant_index"], item["variant_total"])
                for item in selected
            ],
            [("A", "a1", 1, 2), ("A", "a2", 2, 2), ("B", "b1", 1, 2), ("B", "b1", 2, 2)],
        )

    def test_select_policy_points_does_not_multiply_by_pattern_count(self):
        """A domain with 71 patterns still yields only ``limit`` points."""
        domains = [{"domain": "A", "patterns": [{"id": f"a{i}"} for i in range(71)]}]

        with patch("app.utils.rule_generation.random.sample", lambda patterns, limit: patterns[:limit]):
            selected = self.service._select_policy_points(domains, 2)

        self.assertEqual(len(selected), 2)

    def test_old_pattern_multiplication_would_have_generated_142_rules(self):
        """Regression guard: 71 patterns with limit 2 must not give 142."""
        domains = [{"domain": "A", "patterns": [{"id": f"a{i}"} for i in range(71)]}]

        selected = self.service._select_policy_points(domains, 2)

        self.assertNotEqual(len(selected), 142)

    def test_select_policy_points_keeps_domain_grouped_order(self):
        """Selected points stay grouped by domain in input order."""
        domains = [
            {"domain": "A", "patterns": [{"id": "a1"}, {"id": "a2"}]},
            {"domain": "B", "patterns": [{"id": "b1"}, {"id": "b2"}]},
        ]

        with patch("app.utils.rule_generation.random.sample", lambda patterns, limit: patterns[:limit]):
            selected = self.service._select_policy_points(domains, 2)

        self.assertEqual([item["domain"] for item in selected], ["A", "A", "B", "B"])

    def test_rule_prompt_includes_variant_guidance_for_each_policy_point(self):
        """The prompt mentions the variant position taken from the pattern."""
        pattern = {
            "description_pattern": "\u5173\u6ce8\u903e\u671f\u91d1\u989d\u548c\u903e\u671f\u7b14\u6570\u3002",
            "basis_text": "\u503a\u52a1\u5230\u671f\u9884\u8b66\u3002",
            "_variant_index": 2,
            "_variant_total": 3,
        }

        prompt = self.service._build_rule_prompt("\u8fc7\u5ea6\u8d1f\u503a", pattern, self.service._read_schema())

        self.assertIn("\u7b2c 2/3 \u6761\u89c4\u5219", prompt)
        self.assertIn("\u4e0d\u5f97\u53ea\u6539\u540d\u79f0\u6216\u9608\u503c", prompt)


if __name__ == "__main__":
    unittest.main()