Files
YG-Rules/app/utils/rule_generation.py

1901 lines
89 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""异步业务规则生成服务。"""
import json
import os
import random
import re
import threading
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any
from openpyxl import Workbook
from openpyxl.styles import Alignment, Border, Font, PatternFill, Side
from openpyxl.utils import get_column_letter
from app.utils.logger import get_logger
# Module-level logger for the rule generation service.
logger = get_logger("rule_generation")
# Project root: three directory levels above this file.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Default input files and output locations, all resolved under the project root.
_DOMAINS_FILE = os.path.join(_PROJECT_ROOT, "data", "domains.json")
_SCHEMA_FILE = os.path.join(_PROJECT_ROOT, "data", "schema.json")
_OUTPUT_DIR = os.path.join(_PROJECT_ROOT, "output")
_TASK_DIR = os.path.join(_OUTPUT_DIR, "tasks")
# Column headers of the rule sheet (11 entries).
# NOTE(review): the Excel writer is outside this view; presumably the order here
# is the column order of the generated workbook - confirm.
_RULE_COLUMNS = [
"序号",
"风险领域",
"规则名称",
"风险描述",
"制度依据",
"业务规则描述",
"数据来源",
"系统规则文本",
"关联表逻辑",
"系统固化逻辑",
"返回结果",
]
# One width per entry of _RULE_COLUMNS (11 values).
_COLUMN_WIDTHS = [5.13, 14.9, 15.84, 27.08, 35.34, 40.55, 37.23, 107.36, 34.58, 39.58, 27.23]
# NOTE(review): presumably 1-based column indexes to centre / highlight; their
# consumers are outside this view - confirm.
_CENTER_COLUMNS = {1, 2, 3}
_HIGHLIGHT_COLUMNS = {6, 7, 8}
# Shared openpyxl style objects: thin black border, white header fill, yellow highlight fill.
_THIN_SIDE = Side(style="thin", color="000000")
_THIN_BORDER = Border(left=_THIN_SIDE, right=_THIN_SIDE, top=_THIN_SIDE, bottom=_THIN_SIDE)
_HEADER_FILL = PatternFill("solid", fgColor="FFFFFF")
_HIGHLIGHT_FILL = PatternFill("solid", fgColor="FFFF00")
# Risk domain -> table-name hints, in priority order. Read by
# _select_schema_candidates, which ranks schema modules by hint position.
_DOMAIN_TABLE_HINTS = {
"过度负债": ["银行贷款", "应付债券", "银行账户", "担保", "应付票据", "财务公司"],
"无关多元": ["金融投资", "客商信息", "应收款项", "应付款项", "合同"],
"多层架构": ["客商信息", "合同", "应收款项", "应付款项"],
"薪酬乱象": ["资金结算", "银行账户"],
"财务金融风险": ["银行贷款", "应付债券", "担保", "金融投资", "资金结算", "银行账户"],
"控股不控权": ["客商信息", "合同", "应收款项", "应付款项"],
"虚假贸易": ["资金结算", "合同", "增值税发票", "应收款项", "应付款项", "客商信息"],
"捞偏门": ["资金结算", "合同", "客商信息", "增值税发票"],
"靠企吃企": ["客商信息", "资金结算", "合同", "应付款项", "应收款项"],
"资产闲置浪费": ["银行账户", "金融投资", "应收款项", "应付款项"],
"违规境外业务": ["资金结算", "银行账户", "金融投资", "合同"],
}
# Keyword -> table-name hints. When a keyword occurs in the policy text,
# _select_schema_candidates appends these tables after the domain hints.
_KEYWORD_TABLE_HINTS = {
"资产负债率": ["银行贷款", "应付债券", "银行账户", "担保", "财务公司"],
"债务": ["银行贷款", "应付债券", "应付票据", "担保"],
"融资": ["银行贷款", "应付债券", "担保", "供应链金融"],
"担保": ["担保", "银行贷款"],
"票据": ["应付票据", "应收票据"],
"资金流": ["资金结算", "银行账户"],
"账户": ["银行账户", "资金结算"],
"交易": ["资金结算", "合同", "增值税发票", "客商信息"],
"合同": ["合同", "资金结算", "增值税发票"],
"发票": ["增值税发票", "合同"],
"客商": ["客商信息", "应收款项", "应付款项"],
"关联": ["客商信息", "资金结算", "合同"],
"投资": ["金融投资业务", "客商信息"],
"主业": ["金融投资业务", "客商信息"],
}
# Field role -> substrings looked up in a schema field's name (see _field_roles).
# A field whose name contains any substring of a role is assigned that role.
_FIELD_ROLE_KEYWORDS = {
"subject": ("所属集团编码", "所属集团名称", "开户单位编码", "开户单位名称", "单位编码", "单位名称", "贷款单位名称", "客商名称"),
"amount": ("余额", "金额", "资产", "负债", "市值", "成本", "发生额"),
"status": ("状态", "标识", "类型", "是否", "级别", "性质", "分类"),
"date": ("日期", "时间", "期间", "期限", "到期"),
"counterparty": ("承销商", "放款单位", "交易对手", "客商", "合同", "账户", "银行"),
}
class RuleGenerationService:
"""规则生成任务管理器。"""
def __init__(
    self,
    domains_file: str = _DOMAINS_FILE,
    schema_file: str = _SCHEMA_FILE,
    output_dir: str = _OUTPUT_DIR,
    create_sql: bool = False,
):
    """Store the configuration and make sure the output directories exist.

    Args:
        domains_file: Path of the domains JSON file.
        schema_file: Path of the schema JSON file.
        output_dir: Directory that receives generated artifacts; task state
            files live in its ``tasks`` sub-directory.
        create_sql: Flag copied into every task state.
    """
    self.domains_file, self.schema_file = domains_file, schema_file
    self.output_dir = output_dir
    self.task_dir = os.path.join(output_dir, "tasks")
    self.create_sql = create_sql
    # Create both directories up front so later state writes cannot fail on a missing folder.
    for directory in (self.output_dir, self.task_dir):
        os.makedirs(directory, exist_ok=True)
def start(self, limit: int = 30) -> dict[str, Any]:
    """Register a new generation task, start its worker thread and return the initial state.

    Args:
        limit: Requested number of rules per policy point; clamped by
            ``_normalize_limit``.

    Returns:
        The initial task state, which has already been persisted.
    """
    limit = self._normalize_limit(limit)
    task_id = self._new_task_id()
    task_output_dir = os.path.join(self.output_dir, f"rules-{task_id}")
    output_file = os.path.join(task_output_dir, f"rules-{task_id}.xlsx")
    markdown_file = self._markdown_file_for_excel(output_file)
    artifacts = {"excel": output_file, "markdown": markdown_file}
    # Keyword order below is the key order of the persisted JSON document.
    state: dict[str, Any] = dict(
        task_id=task_id,
        status="queued",
        limit=limit,
        rules_per_policy_point=limit,
        create_sql=self.create_sql,
        progress={"current": 0, "total": 0},
        generated_count=0,
        output_dir=task_output_dir,
        output_file=output_file,
        markdown_file=markdown_file,
        files=artifacts,
        started_at=datetime.now().isoformat(),
        finished_at=None,
        error=None,
        markdown_error=None,
        current_domain="",
        skipped_domains=[],
        skipped_rules=[],
    )
    self._write_state(task_id, state)
    # Daemon thread: the worker does not keep the interpreter alive on shutdown.
    worker = threading.Thread(
        target=self._run_task,
        args=(task_id, limit, output_file),
        daemon=True,
    )
    worker.start()
    return state
def get_status(self, task_id: str) -> dict[str, Any] | None:
path = self._state_path(task_id)
if not os.path.exists(path):
return None
return self._read_json(path)
def _run_task(self, task_id: str, limit: int, output_file: str) -> None:
    """Worker body of a generation task: generate rules, write artifacts, persist state.

    Runs in the thread created by ``start``. Progress and results are
    persisted through ``_write_state`` after each step, so ``get_status``
    can be polled while the task is running. Exceptions are not propagated;
    they are recorded in the state as ``status="failed"``.

    Args:
        task_id: Identifier of the task whose state file is updated.
        limit: Rules-per-policy-point value, passed to ``_select_policy_points``.
        output_file: Path of the Excel artifact.
    """
    state = self.get_status(task_id) or {}
    try:
        state["status"] = "processing"
        self._write_state(task_id, state)
        domains = self._collect_policy_basis()
        selected_points = self._select_policy_points(domains, limit)
        schema = self._read_schema()
        state["progress"] = {"current": 0, "total": len(selected_points)}
        self._write_state(task_id, state)
        rules: list[dict[str, Any]] = []
        skipped_domains = self._skipped_domains_without_basis(domains)
        skipped_rules: list[dict[str, Any]] = []
        for index, point in enumerate(selected_points, start=1):
            domain = point["domain"]
            pattern = point["pattern"]
            # Copy the pattern and attach the variant position; the private
            # "_variant_*" keys are read back by _build_variant_guidance.
            pattern = {
                **pattern,
                "_variant_index": point.get("variant_index", 1),
                "_variant_total": point.get("variant_total", limit),
            }
            state["current_domain"] = domain
            state["progress"] = {"current": index, "total": len(selected_points)}
            self._write_state(task_id, state)
            candidates = self._select_schema_candidates(domain, pattern, schema)
            if not candidates:
                # No usable table: record the skip and move on.
                # NOTE(review): unlike the failure branch below, this branch does not
                # assign state["skipped_rules"] or persist the state before continuing.
                skipped_rules.append({
                    "domain": domain,
                    "basis": self._build_policy_basis_text(pattern)[:120],
                    "reason": "未匹配到可用 schema 表",
                })
                continue
            try:
                rule = self._generate_rule(domain, pattern, candidates)
            except Exception as exc:
                # A single failed rule is logged and skipped; the task keeps going.
                reason = str(exc)
                logger.error(
                    "LLM 规则生成失败,跳过该规则 | domain=%s | reason=%s",
                    domain,
                    reason,
                    exc_info=True,
                )
                skipped_rules.append({
                    "domain": domain,
                    "basis": self._build_policy_basis_text(pattern)[:120],
                    "reason": reason,
                })
                state["skipped_rules"] = skipped_rules
                self._write_state(task_id, state)
                continue
            rule.setdefault("risk_domain", domain)
            rules.append(rule)
            state["generated_count"] = len(rules)
            state["skipped_rules"] = skipped_rules
            self._write_state(task_id, state)
        self._write_excel(rules, output_file)
        # Markdown is best effort: a failure is recorded but does not fail the task.
        try:
            self._write_markdown(rules, state.get("markdown_file") or self._markdown_file_for_excel(output_file))
            state["markdown_error"] = None
        except Exception as exc:
            state["markdown_error"] = str(exc)
            logger.warning("Markdown 产物写入失败 | task_id=%s", task_id, exc_info=True)
        state["status"] = "done"
        state["generated_count"] = len(rules)
        state["skipped_domains"] = skipped_domains
        state["skipped_rules"] = skipped_rules
        state["current_domain"] = ""
        state["finished_at"] = datetime.now().isoformat()
        # A run that produced no rule at all is reported as failed.
        if not rules:
            state["status"] = "failed"
            state["error"] = "LLM 未成功生成任何规则,未使用兜底规则。"
        self._write_state(task_id, state)
    except Exception as exc:
        logger.error("规则生成任务失败 | task_id=%s", task_id, exc_info=True)
        state["status"] = "failed"
        state["error"] = str(exc)
        state["finished_at"] = datetime.now().isoformat()
        # Write empty artifacts so the advertised output paths are still produced.
        # NOTE(review): if the failure happened after the Excel file had been written,
        # this call is handed the same path with an empty rule list - confirm intended.
        try:
            self._write_excel([], output_file)
        except Exception:
            logger.warning("规则生成失败后写空 Excel 失败 | task_id=%s", task_id, exc_info=True)
        try:
            self._write_markdown([], state.get("markdown_file") or self._markdown_file_for_excel(output_file))
        except Exception as markdown_exc:
            state["markdown_error"] = str(markdown_exc)
            logger.warning("规则生成失败后写空 Markdown 失败 | task_id=%s", task_id, exc_info=True)
        self._write_state(task_id, state)
@staticmethod
def _normalize_limit(limit: int) -> int:
try:
value = int(limit)
except (TypeError, ValueError):
value = 30
return max(1, min(value, 30))
@staticmethod
def _new_task_id() -> str:
return f"{datetime.now().strftime('%Y%m%d-%H%M%S')}-{uuid.uuid4().hex[:6]}"
def _state_path(self, task_id: str) -> str:
safe_task_id = re.sub(r"[^0-9A-Za-z_-]", "", task_id)
return os.path.join(self.task_dir, f"{safe_task_id}.json")
def _write_state(self, task_id: str, state: dict[str, Any]) -> None:
path = self._state_path(task_id)
temp_path = f"{path}.tmp"
with open(temp_path, "w", encoding="utf-8") as file:
json.dump(state, file, ensure_ascii=False, indent=2)
os.replace(temp_path, path)
@staticmethod
def _read_json(path: str) -> dict[str, Any]:
with open(path, "r", encoding="utf-8") as file:
return json.load(file)
def _collect_policy_basis(self) -> list[dict[str, Any]]:
data = self._read_json(self.domains_file)
result = []
for domain in data.get("domains", []):
patterns = []
for file_record in domain.get("guidance_files", []):
analysis = file_record.get("guidance_analysis") or {}
if analysis.get("status") != "done":
continue
for pattern in analysis.get("description_patterns", []):
if pattern.get("description_pattern") or pattern.get("basis_text"):
patterns.append(pattern)
if patterns:
result.append({
"token": domain.get("token", ""),
"domain": domain.get("domain", ""),
"patterns": patterns,
})
return result
def _skipped_domains_without_basis(self, domains_with_basis: list[dict[str, Any]]) -> list[dict[str, str]]:
data = self._read_json(self.domains_file)
tokens_ready = {item["token"] for item in domains_with_basis}
skipped = []
for domain in data.get("domains", []):
domain_name = domain.get("domain", "")
if domain.get("token", "") not in tokens_ready:
skipped.append({"domain": domain_name, "reason": "未上传指导文件或未执行指导文件解析"})
return skipped
@staticmethod
def _select_policy_points(domains: list[dict[str, Any]], limit: int) -> list[dict[str, Any]]:
selected = []
for domain_item in domains:
patterns = list(domain_item.get("patterns", []))
if not patterns:
continue
if len(patterns) >= limit:
chosen_patterns = random.sample(patterns, limit)
else:
chosen_patterns = patterns[:]
chosen_patterns.extend(random.choice(patterns) for _ in range(limit - len(patterns)))
for variant_index, pattern in enumerate(chosen_patterns, start=1):
selected.append({
"domain": domain_item["domain"],
"pattern": pattern,
"variant_index": variant_index,
"variant_total": limit,
})
return selected
def _read_schema(self) -> list[dict[str, Any]]:
data = self._read_json(self.schema_file)
modules = data.get("modules", [])
return [module for module in modules if module.get("module_name") and module.get("fields")]
def _select_schema_candidates(
    self,
    domain: str,
    pattern: dict[str, Any],
    schema: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Score the schema modules against a policy pattern and summarize the best four.

    Scoring per module:
      * a table hint matching the module name adds ``60 - min(rank, 40)``,
        where rank is the hint's position (domain hints first, then hints
        of keywords found in the policy text);
      * each keyword of the policy text found in the module's name,
        description or field names adds 1;
      * having an amount-like field adds 4, a subject-like field adds 4.

    Args:
        domain: Risk domain name; key into ``_DOMAIN_TABLE_HINTS``.
        pattern: Policy pattern dict (description, basis text, keywords, ...).
        schema: Schema modules as returned by ``_read_schema``.

    Returns:
        Up to four module summaries (see ``_summarize_module``), best first.
        Modules scoring 0 are never returned.
    """
    # ``or ""`` keeps explicit nulls in the JSON from breaking the join.
    text = " ".join([
        domain or "",
        pattern.get("description_pattern") or "",
        pattern.get("basis_text") or "",
        pattern.get("source_sentence") or "",
        " ".join(self._normalize_keywords(pattern)),
    ])
    hint_order: list[str] = []
    for hint in _DOMAIN_TABLE_HINTS.get(domain, []):
        if hint not in hint_order:
            hint_order.append(hint)
    for keyword, table_names in _KEYWORD_TABLE_HINTS.items():
        if keyword in text:
            for table_name in table_names:
                if table_name not in hint_order:
                    hint_order.append(table_name)
    # The policy text is the same for every module, so tokenize it once.
    text_keywords = self._keywords_from_text(text)
    scored = []
    for module in schema:
        module_name = self._clean_module_name(module.get("module_name", ""))
        haystack = " ".join([
            module_name,
            module.get("description") or "",
            " ".join(str(field.get("name") or "") for field in module.get("fields", [])),
        ])
        score = 0
        rank = self._hint_rank(module_name, hint_order)
        if rank <= len(hint_order):
            score += 60 - min(rank, 40)
        score += sum(1 for keyword in text_keywords if keyword and keyword in haystack)
        roles = self._field_roles(module)
        if roles["amount"]:
            score += 4
        if roles["subject"]:
            score += 4
        if score > 0:
            scored.append((score, module))
    # Stable sort: modules with equal scores keep their schema order.
    scored.sort(key=lambda item: item[0], reverse=True)
    return [self._summarize_module(module) for _, module in scored[:4]]
@staticmethod
def _hint_rank(module_name: str, hint_order: list[str]) -> int:
for index, hint in enumerate(hint_order):
if hint in module_name or module_name in hint:
return index
return len(hint_order) + 1
@staticmethod
def _clean_module_name(module_name: str) -> str:
return re.sub(r"^\s*\d+\s*[、.)\-_]*\s*", "", module_name or "").strip()
@staticmethod
def _keywords_from_text(text: str) -> set[str]:
words = set()
for word in re.findall(r"[\u4e00-\u9fa5A-Za-z0-9]{2,}", text or ""):
if len(word) >= 2:
words.add(word)
return words
@staticmethod
def _normalize_keywords(pattern: dict[str, Any]) -> list[str]:
raw_keywords = pattern.get("keywords") or []
if isinstance(raw_keywords, list):
return [str(item).strip() for item in raw_keywords if str(item).strip()]
if isinstance(raw_keywords, str):
return [raw_keywords.strip()] if raw_keywords.strip() else []
return []
def _summarize_module(self, module: dict[str, Any]) -> dict[str, Any]:
    """Condense a schema module to its name, table, description and up to 8 key fields.

    Fields are taken in role order (subject, amount, status, date,
    counterparty). When no field has a role, the first six fields of the
    module are used. Duplicate and empty field names are dropped.
    """
    roles = self._field_roles(module)
    prioritized = [
        field
        for role in ("subject", "amount", "status", "date", "counterparty")
        for field in roles[role]
    ]
    if not prioritized:
        prioritized = module.get("fields", [])[:6]
    # Keyed by field name: insertion order gives the ranking, keys give de-duplication.
    picked: dict[str, dict[str, Any]] = {}
    for field in prioritized:
        if len(picked) >= 8:
            break
        name = field.get("name", "")
        if not name or name in picked:
            continue
        picked[name] = {
            "name": name,
            "marker": field.get("marker", ""),
            "type": field.get("type", ""),
            "rule": field.get("rule", ""),
        }
    return {
        "module_name": self._clean_module_name(module.get("module_name", "")),
        "table_name": module.get("table_name", ""),
        "description": module.get("description", ""),
        "fields": list(picked.values()),
    }
def _field_roles(self, module: dict[str, Any]) -> dict[str, list[dict[str, Any]]]:
    """Group the module's fields by role.

    A field belongs to a role when its name contains any of that role's
    keywords from ``_FIELD_ROLE_KEYWORDS``; one field may appear under
    several roles. Every role is present in the result, possibly empty.
    """
    fields = module.get("fields", [])
    return {
        role: [
            field
            for field in fields
            if any(keyword in field.get("name", "") for keyword in keywords)
        ]
        for role, keywords in _FIELD_ROLE_KEYWORDS.items()
    }
def _generate_rule(
    self,
    domain: str,
    pattern: dict[str, Any],
    schema_candidates: list[dict[str, Any]],
) -> dict[str, Any]:
    """Generate and normalize one rule, retrying once on any failure.

    Raises:
        Exception: whatever the second attempt raised (LLM call, parsing or
            normalization); the first failure is only logged as a warning.
    """
    max_attempts = 2
    failure: Exception | None = None
    attempt = 0
    while attempt < max_attempts:
        try:
            llm_rule = self._generate_rule_with_llm(domain, pattern, schema_candidates)
            if not llm_rule:
                raise ValueError("LLM 未返回可解析的规则 JSON")
            return self._normalize_rule(llm_rule, domain, pattern, schema_candidates)
        except Exception as exc:
            failure = exc
            if attempt > 0:
                # Second failure: give up and let the caller record the skip.
                raise
            logger.warning(
                "LLM 规则生成失败,准备重试 | domain=%s | reason=%s",
                domain,
                exc,
            )
        attempt += 1
    # Defensive fallback; the loop above always returns or raises before this.
    raise ValueError(str(failure) if failure else "LLM 未返回可解析的规则 JSON")
def _generate_rule_with_llm(
    self,
    domain: str,
    pattern: dict[str, Any],
    schema_candidates: list[dict[str, Any]],
) -> dict[str, Any] | None:
    """Ask the LLM for one rule and parse its reply.

    Returns:
        The parsed rule fields, or ``None`` when the reply could not be
        parsed by ``_parse_rule_response``.
    """
    # Imported lazily inside the method, as in the original implementation.
    from app.utils.llm import LLMClient, strip_thinking

    user_message = {
        "role": "user",
        "content": self._build_rule_prompt(domain, pattern, schema_candidates),
    }
    system_message = {
        "role": "system",
        "content": (
            "你是央企监管规则清单整理助手。"
            "你的任务是把制度依据整理成适合 Excel 展示的业务梳理稿,"
            "不是发明新规则,不要写空话。"
            "只输出指定键值块,不输出解释。"
        ),
    }
    client = LLMClient(timeout=120)
    reply = client.chat(
        messages=[system_message, user_message],
        temperature=0.2,
        max_tokens=2600,
        thinking={"type": "disabled"},
    )
    return self._parse_rule_response(strip_thinking(reply))
def _build_rule_prompt(self, domain: str, pattern: dict[str, Any], schema_candidates: list[dict[str, Any]]) -> str:
    """Build the user prompt that asks the LLM for one rule.

    The prompt contains the risk domain, the policy basis text, the source
    sentence, the supervision dimension, the candidate tables serialized as
    JSON, indicator and variant guidance, thirteen writing requirements,
    three style examples and the list of output keys (English and Chinese).

    NOTE(review): the first sentence of the prompt asks for a JSON object,
    while the output-format section forbids JSON and asks for "key:" blocks.
    ``_parse_rule_response`` tries JSON first and then key blocks, so both
    forms are handled - confirm which instruction is intended.
    NOTE(review): the examples label their first two entries "模型名称" and
    "模型风险描述", which are not among the keys in ``_rule_key_aliases``.
    """
    policy_basis = self._build_policy_basis_text(pattern)
    source_sentence = pattern.get("source_sentence", "")
    indicator_guidance = self._build_indicator_guidance(domain, pattern, schema_candidates)
    variant_guidance = self._build_variant_guidance(pattern)
    # Literal braces in the examples are doubled ({{ }}) because this is an f-string.
    return f"""请根据制度依据和可用数据表,整理一条适合 Excel 规则清单展示的业务规则说明,输出 JSON 对象。
风险领域:{domain}
制度依据:{policy_basis}
原文依据:{source_sentence}
监管维度:{pattern.get('supervision_dimension', '')}
可用数据表和字段:
{json.dumps(schema_candidates, ensure_ascii=False, indent=2)}
指标口径选择建议:
{indicator_guidance}
同一监管点多规则生成要求:
{variant_guidance}
写作要求:
1. 只能使用上面列出的表和字段,不得编造表名、字段名。
2. 输出要像人工梳理稿,语言克制、具体、可追溯。
3. 模型名称要短不超过12个字。
4. 模型风险描述只写一句话,先写业务现象,再写风险后果。
5. 制度依据必须原样使用输入中的“制度依据”,不得改写、删减或另写条款要点。
6. 业务规则描述对应《国能poc演示清单》的“业务检查方式”核心是模拟查询指标。必须由两段组成并用换行分隔
第一段写公式/计算口径;如果有多个公式,每个公式必须单独换行。
第二段写指标阈值或红/橙/黄风险等级;阈值必须和上方公式指标一一对应,两个公式就写两个指标的阈值,一个公式就只写一个指标的阈值。
指标类型要根据政策依据和字段动态选择,不限于比例;可以是金额、余额、数量、笔数、次数、天数、期限、集中度、占比、收益率、状态标记等。
示例:逾期票据金额、逾期票据数、合同类型数量、单一类型占比、投资收益率、短期债务比例、偿债能力指标、担保余额、资金回流次数。
7. 数据来源必须动态分析 schema只按“TABLE_01(表名)MARKER(字段中文名)、MARKER(字段中文名)”输出;每张表最多列 2-4 个关键字段,字段英文名必须来自 schema.marker不要写解释。多张表必须换行严禁用分号连在一行。
8. 系统规则文本必须仿照参考文件 H 列,不是自然语言摘要。必须使用“规则名(阈值区间): 如果:...并且...就:校验:通过 赋值:...”的规则伪代码;优先使用“数值求和{{【表名】,【表名-字段名】}}”“大于/小于等于”等表达。
9. 关联表逻辑必须仿照参考文件 I 列,写清楚主表、左关联表、关联键、筛选和分组口径。
10. 系统固化逻辑必须仿照参考文件 J 列写“值1...值2...值3...”中间指标定义和最终公式。
11. 返回结果只列建议输出字段。
12. 不要出现“基于…识别…风险”“多维穿透分析”“结合业务场景综合判断”等套话。
13. 必须完整输出 9 个字段,不能省略任何字段,字段内容未知也要根据已给 schema 和制度依据合理生成:规则名称、风险描述、制度依据、业务规则描述、数据来源、系统规则文本、关联表逻辑、系统固化逻辑、返回结果。
示例只用于仿写格式,不得照抄示例业务内容;实际内容必须来自本次制度依据和可用数据表字段。
参考风格示例:
示例1
- 模型名称:资产负债率超限预警
- 模型风险描述:债务类金额占可用资金口径偏高时,企业存在偿债压力和杠杆失衡风险。
- 制度依据:核心法规:《关于加强国有企业资产负债约束的指导意见》;条款要点:分行业设置资产负债率预警线和重点监管线,持续监测高负债企业债务风险。
- 业务规则描述:资产负债率 = 贷款余额 / 账户余额\n风险等级:\n- 红色:资产负债率 > 85%\n- 橙色:资产负债率 > 75% 且 ≤ 85%\n- 黄色:资产负债率 > 65% 且 ≤ 75%。
- 数据来源:银行贷款:所属集团编码、所属集团名称、贷款余额\n银行账户:所属集团编码、所属集团名称、账户余额。
- 系统规则文本资产负债率超限预警65%<资产负债率<=75%:\n如果:\n (【银行贷款表-贷款单位编码】等于【银行账户表-开户单位编码】)\n 并且 (数值求和{{【银行贷款表】,【银行贷款表-贷款余额】}}/数值求和{{【银行账户表】,【银行账户表-账户余额】}} 大于 0.65)\n 并且 (数值求和{{【银行贷款表】,【银行贷款表-贷款余额】}}/数值求和{{【银行账户表】,【银行账户表-账户余额】}} 小于等于 0.75)\n就:\n校验:通过\n赋值:【规则结果-风险等级】等于 黄色
- 关联表逻辑:银行贷款表左关联银行账户表(贷款单位编码=开户单位编码),按所属集团编码、所属集团名称、贷款单位编码汇总。
- 系统固化逻辑值1贷款余额本币=SUM(银行贷款.贷款余额)值2账户余额本币=SUM(银行账户.账户余额)值3资产负债率=值1/NULLIF(值2,0)按65%、75%、85%阈值输出风险等级。
示例2
- 模型名称:债务集中度过高风险
- 模型风险描述:债务集中于单一金融机构时,融资渠道较为单一,存在集中兑付风险。
- 制度依据:核心法规:《中央企业债券发行管理办法》;条款要点:分散融资渠道,防范债务集中到期和单一机构依赖。
- 业务规则描述:债务集中度 = 单一金融机构债务余额 / 总债务余额\n风险等级:\n- 红色:债务集中度 > 40%\n- 橙色:债务集中度 > 30% 且 ≤ 40%\n- 黄色:债务集中度 > 20% 且 ≤ 30%。
- 数据来源:银行贷款:放款单位名称、贷款余额\n应付债券:承销商、债券余额。
- 系统规则文本债务集中度过高风险20%<集中度<=30%:\n如果:\n (【银行贷款表-放款单位名称】等于【应付债券表-承销商】)\n 并且 ((数值求和{{【银行贷款表】,【银行贷款表-贷款余额】}}+数值求和{{【应付债券表】,【应付债券表-债券余额】}})/数值求和{{【银行贷款表】,【银行贷款表-贷款余额】}} 大于 0.20)\n 并且 (上述比例 小于等于 0.30)\n就:\n校验:通过\n赋值:【规则结果-风险等级】等于 黄色
- 关联表逻辑:从银行贷款表按放款单位名称分组统计贷款余额;从应付债券表按承销商分组统计债券余额;按金融机构名称口径合并计算单一机构债务余额。
- 系统固化逻辑值1单一金融机构债务余额=SUM(银行贷款.贷款余额)+SUM(应付债券.债券余额)值2总债务余额=SUM(全部贷款余额+全部债券余额)值3债务集中度=值1/NULLIF(值2,0)。
示例3
- 模型名称:逾期风险预警
- 模型风险描述:票据或债务出现逾期金额较大、逾期笔数较多时,企业存在偿债困难风险。
- 制度依据:核心法规:《中央企业债券发行管理办法》;条款要点:建立债务到期预警机制,严禁恶意逃废债,逾期债务严肃追责。
- 业务规则描述:逾期票据金额 = 是否逾期为“是”的票据金额合计\n逾期票据数 = 是否逾期为“是”的票据记录数\n风险等级:\n- 红色:逾期票据金额 > 5000万元 或 逾期票据数 > 10笔\n- 橙色:逾期票据金额 > 1000万元 或 逾期票据数 > 5笔\n- 黄色:逾期票据金额 > 500万元 或 逾期票据数 > 3笔。
- 数据来源:应付票据:是否逾期、到期日期、票面金额。
- 系统规则文本:逾期风险预警(逾期票据金额>500万元或逾期票据数>3笔:\n如果:\n (【应付票据表-是否逾期】等于 是)\n 并且 (数值求和{{【应付票据表】,【应付票据表-票面金额】}} 大于 5000000 或 记录计数{{【应付票据表】,【应付票据表-票据编号】}} 大于 3)\n就:\n校验:通过\n赋值:【规则结果-风险等级】等于 黄色
- 关联表逻辑:从应付票据表筛选是否逾期为“是”的记录,按所属集团编码、所属集团名称分组统计逾期票据金额和逾期票据数。
- 系统固化逻辑值1逾期票据金额=SUM(应付票据.票面金额)值2逾期票据数=COUNT(应付票据.票据编号)值3按金额和笔数阈值输出风险等级。
输出格式:
请严格使用以下键名,每个键独占一行;多行内容写在下一行,并保持到下一个键名前结束。
不要输出 JSON不要输出 Markdown 代码块。
也可以使用括号内中文键名,但必须 9 个字段全部输出。
rule_name:
risk_description:
policy_basis:
business_rule_description:
data_sources:
system_rule_text:
join_logic:
system_logic:
return_fields:
中文键名等价写法:
规则名称:
风险描述:
制度依据:
业务规则描述:
数据来源:
系统规则文本:
关联表逻辑:
系统固化逻辑:
返回结果:"""
def _build_indicator_guidance(
self,
domain: str,
pattern: dict[str, Any],
schema_candidates: list[dict[str, Any]],
) -> str:
text = " ".join([
domain,
pattern.get("description_pattern", ""),
pattern.get("basis_text", ""),
pattern.get("source_sentence", ""),
" ".join(self._normalize_keywords(pattern)),
" ".join(field.get("name", "") for table in schema_candidates for field in table.get("fields", [])),
])
roles = self._schema_indicator_roles(schema_candidates)
guidance = [
"先判断本条制度最适合的指标类型,不要默认生成 A/B 比率。",
"参考国能 POC金额/余额、数量/笔数、期限/天数、状态标记、占比/比例、收益率都应按业务场景混合出现。",
]
preferred: list[str] = []
if any(word in text for word in ("逾期", "到期", "兑付", "超期")):
preferred.extend([
"金额类:逾期金额/到期金额/兑付金额 = SUM(相关金额字段)",
"数量类:逾期笔数/到期笔数 = COUNT(满足状态或日期条件的记录)",
"期限类:逾期天数/剩余期限 = 当前日期与到期日期、数据日期的差值",
])
if any(word in text for word in ("分散", "集中", "多元", "单一", "占比", "比例")):
preferred.extend([
"数量类:不同类型数量/机构数量/客商数量 = COUNT(DISTINCT 分类或对象字段)",
"占比类:单一对象金额占比/前N对象金额占比 = 分组金额 / 总金额",
])
if any(word in text for word in ("投资", "收益", "市值", "成本")):
preferred.extend([
"收益类:投资收益金额 = 市值 - 投资成本",
"收益率类:投资收益率 = (市值 - 投资成本) / NULLIF(投资成本, 0)",
"期限类:持有天数 = 当前日期 - 投资开始日期",
])
if any(word in text for word in ("交易", "资金", "薪酬", "回流", "发放")):
preferred.extend([
"金额类:单笔金额/单日总额/资金回流金额 = SUM 或直接取金额字段",
"次数类:交易次数/回流次数/发放笔数 = COUNT(交易流水或记录)",
"状态类:按收支标记、渠道、账户状态、内部标识筛选或分组",
])
if any(word in text for word in ("担保", "授信", "贷款", "债务", "负债")):
preferred.extend([
"余额类:担保余额/贷款余额/债券余额/授信余额 = SUM(余额字段)",
"金额类:超额金额/敞口金额 = 两个金额或余额口径相减",
"比例类只在制度明确要求规模占比、负债率、集中度时使用。",
])
if not preferred:
if roles["amount"]:
preferred.append("金额/余额类优先:围绕金额、余额、总额字段生成 SUM 指标和阈值。")
if roles["status"]:
preferred.append("状态类可用:围绕状态、是否、标识、类型字段生成异常标记或条件计数。")
if roles["date"]:
preferred.append("期限/天数类可用:围绕日期、时间、期限字段生成滞后天数、到期天数、持有天数。")
preferred.append("仅当文本明确出现比例、占比、率、集中度时,才优先使用比率指标。")
guidance.extend(self._dedupe_preserve_order(preferred)[:6])
guidance.append("业务规则描述中可以同时给出 1-2 个指标,例如“金额 + 笔数”“余额 + 天数”“数量 + 占比”,阈值要逐一对应。")
return "\n".join(f"- {item}" for item in guidance)
def _schema_indicator_roles(self, schema_candidates: list[dict[str, Any]]) -> dict[str, bool]:
names = " ".join(field.get("name", "") for table in schema_candidates for field in table.get("fields", []))
markers = " ".join(field.get("marker", "") for table in schema_candidates for field in table.get("fields", []))
return {
"amount": any(word in names for word in ("金额", "余额", "总额", "资产", "负债", "成本", "市值")) or any(
word in markers.upper() for word in ("AMOUNT", "BALANCE", "ASSET", "LIABILITY", "COST", "SZ", "ZZYE")
),
"status": any(word in names for word in ("状态", "是否", "标识", "类型", "性质", "分类")),
"date": any(word in names for word in ("日期", "时间", "期限", "到期", "账龄")),
}
@staticmethod
def _build_variant_guidance(pattern: dict[str, Any]) -> str:
variant_index = int(pattern.get("_variant_index") or 1)
variant_total = int(pattern.get("_variant_total") or 1)
if variant_total <= 1:
return "- 本监管点只生成 1 条规则,选择最贴近制度依据的一个指标口径。"
focus_cycle = [
"金额/余额/规模类指标,例如 SUM(金额或余额字段)、超额金额、风险敞口。",
"数量/笔数/次数类指标,例如 COUNT 记录数、COUNT(DISTINCT 对象)、异常笔数。",
"期限/天数/状态类指标,例如 到期天数、逾期天数、状态标记、是否异常。",
"占比/比例/收益率类指标,仅当制度或字段明确支持比率口径时使用。",
]
focus = focus_cycle[(variant_index - 1) % len(focus_cycle)]
return (
f"- 当前是同一监管点的第 {variant_index}/{variant_total} 条规则。\n"
f"- 本条优先侧重:{focus}\n"
"- 同一监管点的多条规则不得只改名称或阈值;必须使用不同指标、字段组合或筛选口径。"
)
@staticmethod
def _dedupe_preserve_order(items: list[str]) -> list[str]:
result = []
for item in items:
if item not in result:
result.append(item)
return result
@staticmethod
def _parse_json_object(response: str) -> dict[str, Any] | None:
content = re.sub(r"^```(?:json)?|```$", "", (response or "").strip(), flags=re.IGNORECASE)
if not content.startswith("{"):
return None
try:
data = json.loads(content)
return data if isinstance(data, dict) else None
except json.JSONDecodeError:
match = re.search(r"\{.*\}", content, flags=re.DOTALL)
if not match:
return None
try:
data = json.loads(match.group(0))
return data if isinstance(data, dict) else None
except json.JSONDecodeError:
return None
def _parse_rule_response(self, response: str) -> dict[str, Any] | None:
json_result = self._parse_json_object(response)
if json_result:
return self._normalize_response_keys(json_result)
content = re.sub(r"^```(?:yaml|text)?|```$", "", (response or "").strip(), flags=re.IGNORECASE)
key_aliases = self._rule_key_aliases()
key_pattern = "|".join(re.escape(key) for key in key_aliases)
matches = list(re.finditer(rf"(?m)^({key_pattern})\s*[:]\s*", content))
if not matches:
return None
parsed: dict[str, str] = {}
for index, match in enumerate(matches):
key = key_aliases[match.group(1)]
start = match.end()
end = matches[index + 1].start() if index + 1 < len(matches) else len(content)
parsed[key] = content[start:end].strip()
return parsed
@staticmethod
def _rule_key_aliases() -> dict[str, str]:
return {
"rule_name": "rule_name",
"规则名称": "rule_name",
"risk_description": "risk_description",
"风险描述": "risk_description",
"policy_basis": "policy_basis",
"制度依据": "policy_basis",
"business_rule_description": "business_rule_description",
"业务规则描述": "business_rule_description",
"data_sources": "data_sources",
"数据来源": "data_sources",
"system_rule_text": "system_rule_text",
"系统规则文本": "system_rule_text",
"join_logic": "join_logic",
"关联表逻辑": "join_logic",
"system_logic": "system_logic",
"系统固化逻辑": "system_logic",
"return_fields": "return_fields",
"返回结果": "return_fields",
}
@classmethod
def _normalize_response_keys(cls, raw: dict[str, Any]) -> dict[str, Any]:
aliases = cls._rule_key_aliases()
normalized: dict[str, Any] = {}
for key, value in raw.items():
normalized[aliases.get(str(key), str(key))] = value
return normalized
def _normalize_rule(
self,
raw: dict[str, Any],
domain: str,
pattern: dict[str, Any],
schema_candidates: list[dict[str, Any]],
) -> dict[str, Any]:
required_fields = {
"rule_name": "规则名称",
"risk_description": "风险描述",
"policy_basis": "制度依据",
"business_rule_description": "业务规则描述",
"data_sources": "数据来源",
"system_rule_text": "系统规则文本",
"join_logic": "关联表逻辑",
"system_logic": "系统固化逻辑",
"return_fields": "返回结果",
}
core_fields = ("rule_name", "risk_description", "business_rule_description")
missing = [required_fields[key] for key in core_fields if not str(raw.get(key, "")).strip()]
if missing:
raise ValueError(f"LLM 返回规则字段不完整: {', '.join(missing)}")
normalized = {
"risk_domain": domain,
**{key: self._model_value_to_text(raw.get(key, "")).strip() for key in required_fields},
}
normalized["policy_basis"] = self._build_policy_basis_text(pattern)
normalized["business_rule_description"] = self._normalize_business_rule_description(
normalized["business_rule_description"]
)
if not normalized["data_sources"]:
normalized["data_sources"] = self._format_data_sources(schema_candidates)
normalized["data_sources"] = self._format_data_sources_with_markers(
normalized["data_sources"],
schema_candidates,
)
if not self._looks_like_join_logic(normalized["join_logic"]):
normalized["join_logic"] = self._format_join_logic(schema_candidates)
if not self._looks_like_system_logic(normalized["system_logic"]):
normalized["system_logic"] = self._format_system_logic_from_business_rule(
normalized["business_rule_description"]
)
metric_name = self._primary_metric_name(normalized["business_rule_description"])
if not normalized["return_fields"]:
output_fields = self._pick_fields_by_role(schema_candidates, ("subject", "amount", "status", "date"), 6)
normalized["return_fields"] = self._format_return_fields(output_fields, metric_name)
normalized["system_rule_text"] = self._normalize_system_rule_text(normalized["system_rule_text"])
if not self._looks_like_system_rule_text(normalized["system_rule_text"]):
normalized["system_rule_text"] = self._format_system_rule_text_from_business_rule(
normalized["rule_name"],
normalized["business_rule_description"],
)
self._validate_rule_shape(normalized)
return normalized
@classmethod
def _model_value_to_text(cls, value: Any) -> str:
if value is None:
return ""
if isinstance(value, str):
return value
if isinstance(value, (int, float, bool)):
return str(value)
if isinstance(value, list):
if all(not isinstance(item, (dict, list)) for item in value):
return ", ".join(cls._model_value_to_text(item) for item in value)
return "\n".join(cls._model_value_to_text(item) for item in value if cls._model_value_to_text(item))
if isinstance(value, dict):
table = value.get("table") or value.get("table_name") or value.get("module_name")
fields = value.get("fields")
if table and fields:
if isinstance(fields, list):
field_text = ", ".join(cls._model_value_to_text(item) for item in fields)
else:
field_text = cls._model_value_to_text(fields)
return f"{table}: {field_text}"
return "\n".join(f"{key}{cls._model_value_to_text(item)}" for key, item in value.items())
return str(value)
@staticmethod
def _normalize_business_rule_description(text: str) -> str:
raw = (text or "").replace("\\n", "\n").strip()
if "\n" in raw:
lines = [re.sub(r"\s+", " ", line).strip() for line in raw.splitlines() if line.strip()]
content = " ".join(lines)
else:
content = re.sub(r"\s+", " ", raw).strip()
if not content:
return ""
split_patterns = [
r"(风险等级[:])",
r"(指标阈值[:])",
r"(预警标准[:])",
r"((?:^|[。;; ])红色[:])",
r"((?:^|[。;; ])黄色[:])",
r"(-\s*红色[:])",
]
for pattern in split_patterns:
match = re.search(pattern, content)
if match and match.start() > 0:
formula = content[:match.start()].strip("。;; ")
thresholds = content[match.start():].strip()
return (
f"{RuleGenerationService._format_formula_lines(formula)}\n"
f"{RuleGenerationService._format_threshold_lines(thresholds)}"
)
return content
@staticmethod
def _format_formula_lines(text: str) -> str:
content = (text or "").strip("。;; ")
if not content:
return ""
parts = re.split(r"[;]\s*(?=[^;。\n]{1,40}=)", content)
if len(parts) == 1:
parts = re.split(r"(?<=。)\s*(?=[^。;;\n]{1,40}=)", content)
lines = [part.strip("。;; ") for part in parts if part.strip("。;; ")]
return "\n".join(lines)
@staticmethod
def _format_threshold_lines(text: str) -> str:
    """Reformat the threshold part of a business rule so each colour level starts a new line.

    The substitutions run in a fixed order: header colon, then three
    substitutions per colour (red, orange, yellow), then header line-break
    fix-ups, then collapsing of repeated newlines.

    NOTE(review): as written, every colon class here matches the ASCII
    colon only, and the per-colour replacements re-emit the colour label
    without its colon. The first substitution writes a fullwidth colon
    after "风险等级", which the later ``风险等级[:]`` patterns cannot match.
    The file is flagged as containing ambiguous Unicode characters, so
    fullwidth punctuation may have been lost from these patterns - verify
    against the repository copy before changing behavior.
    """
    content = (text or "").strip()
    if not content:
        return ""
    # Normalize the header to "风险等级:" and drop whitespace after it.
    content = re.sub(r"^风险等级[:]\s*", "风险等级:", content)
    for color in ("红色", "橙色", "黄色"):
        # Colour label after a semicolon, after a dash bullet, or bare: start a new bullet line.
        content = re.sub(rf"[;]\s*{color}[:]", f"\n- {color}", content)
        content = re.sub(rf"\s+-\s*{color}[:]", f"\n- {color}", content)
        content = re.sub(rf"(?<![-\n]){color}[:]", f"\n- {color}", content)
    content = re.sub(r"风险等级[:]\s*\n", "风险等级:\n", content)
    content = re.sub(r"风险等级[:]\s*-", "风险等级:\n-", content)
    # Collapse blank lines produced by the substitutions above.
    content = re.sub(r"\n{2,}", "\n", content)
    return content.strip()
@staticmethod
def _normalize_data_sources_text(data_sources: str) -> str:
text = re.sub(r"\s+", " ", data_sources or "").strip()
if not text:
return ""
text = text.replace("", ";")
segments = [segment.strip(" ;") for segment in text.split(";") if segment.strip(" ;")]
if len(segments) <= 1:
return data_sources.strip()
return "\n".join(segments)
def _format_data_sources_with_markers(
    self,
    data_sources: str,
    schema_candidates: list[dict[str, Any]],
) -> str:
    """Rewrite a data-source description as one ``table: field, field`` line per table.

    Tables that match a schema candidate are shown with their schema
    identifier and fields in ``marker(name)`` form; when none of the listed
    fields can be rendered, the module's first four fields are used.
    Tables without a schema match are echoed with their original labels.
    If the text cannot be parsed into tables, the normalized text is returned.
    """
    cleaned = self._normalize_data_sources_text(data_sources)
    entries = self._parse_data_source_lines(cleaned)
    if not entries:
        return cleaned
    rendered = []
    for table_label, field_labels in entries:
        module, index = self._match_schema_module(table_label, schema_candidates)
        if module:
            table_display = self._data_source_table_display(module, index)
            field_displays = self._data_source_field_displays(field_labels, module)
            if not field_displays:
                field_displays = []
                for field in module.get("fields", [])[:4]:
                    display = self._data_source_field_display(field)
                    if display:
                        field_displays.append(display)
            rendered.append(f"{table_display}: {', '.join(field_displays)}")
        else:
            joined = ", ".join(field_labels)
            rendered.append(f"{table_label}: {joined}" if joined else table_label)
    return "\n".join(rendered)
def _data_source_table_display(self, module: dict[str, Any], index: int) -> str:
    """Return ``identifier(clean module name)`` for a schema module."""
    identifier = self._schema_table_identifier(module, index)
    display_name = self._clean_module_name(module.get("module_name", ""))
    return f"{identifier}({display_name})"
def _data_source_field_displays(self, field_labels: list[str], module: dict[str, Any]) -> list[str]:
    """Render each field label, de-duplicated and in input order.

    A label that matches a schema field is rendered through
    ``_data_source_field_display``; an unmatched label is kept verbatim.
    """
    displays: list[str] = []
    for label in field_labels:
        matched = self._match_schema_field(label, module)
        display = label if not matched else self._data_source_field_display(matched)
        if not display or display in displays:
            continue
        displays.append(display)
    return displays
def _match_schema_field(self, label: str, module: dict[str, Any]) -> dict[str, Any] | None:
    """Find the first field of *module* that *label* refers to.

    A field matches when its sanitized marker equals the sanitized label,
    or when the label and the field name are equal or one contains the other.

    Returns:
        The matching field dict, or ``None``.
    """
    wanted_marker = self._safe_sql_identifier(label, "")
    for field in module.get("fields", []):
        field_marker = self._safe_sql_identifier(field.get("marker", ""), "")
        if wanted_marker and field_marker == wanted_marker:
            return field
        field_name = str(field.get("name", "") or "")
        if not (label and field_name):
            continue
        if label == field_name or label in field_name or field_name in label:
            return field
    return None
def _data_source_field_display(self, field: dict[str, Any]) -> str:
    """Return ``marker(name)`` for a field, or whichever of the two is available."""
    marker = self._safe_sql_identifier(field.get("marker", ""), "")
    name = str(field.get("name", "") or "").strip()
    return f"{marker}({name})" if marker and name else (marker or name)
@staticmethod
def _normalize_system_rule_text(text: str) -> str:
raw = (text or "").replace("\\n", "\n").strip()
if not raw:
return ""
return "\n".join(re.sub(r"\s+", " ", line).strip() for line in raw.splitlines() if line.strip())
@staticmethod
def _looks_like_system_rule_text(text: str) -> bool:
content = text or ""
return all(keyword in content for keyword in ("如果", "")) and any(
keyword in content for keyword in ("赋值", "校验")
)
@staticmethod
def _looks_like_join_logic(text: str) -> bool:
content = text or ""
return ("" in content or "TABLE_" in content.upper()) and any(
keyword in content for keyword in ("关联", "筛选", "分组", "汇总", "统计")
)
@staticmethod
def _looks_like_system_logic(text: str) -> bool:
content = text or ""
return "值1" in content and any(keyword in content for keyword in ("值2", "阈值", "风险等级"))
@staticmethod
def _formula_metric_names(business_rule: str) -> list[str]:
formula_part = (business_rule or "").split("风险等级", 1)[0]
metric_names = []
for line in formula_part.splitlines():
if "=" not in line:
continue
metric_name = line.split("=", 1)[0].strip()
metric_name = re.sub(
r"^(?:指标[一二三四五六七八九十\d]+|公式[一二三四五六七八九十\d]+|值\d+)\s*[::、.)-]*\s*",
"",
metric_name,
).strip()
if metric_name and metric_name not in metric_names:
metric_names.append(metric_name)
return metric_names
@staticmethod
def _primary_metric_name(business_rule: str) -> str:
    """Return the first formula metric name, or "核心指标" when none is found."""
    names = RuleGenerationService._formula_metric_names(business_rule)
    if names:
        return names[0]
    return "核心指标"
@staticmethod
def _threshold_part(business_rule: str) -> str:
return business_rule.split("风险等级", 1)[1] if "风险等级" in business_rule else ""
@staticmethod
def _threshold_lines(business_rule: str) -> list[str]:
    """List the stripped threshold lines that name a risk color.

    Only lines after the "风险等级" marker are considered; a line qualifies
    when it is non-blank and mentions "红色", "橙色" or "黄色".
    """
    section = RuleGenerationService._threshold_part(business_rule)
    level_names = ("红色", "橙色", "黄色")
    collected = []
    for raw_line in section.splitlines():
        stripped = raw_line.strip()
        if not stripped:
            continue
        if any(level in raw_line for level in level_names):
            collected.append(stripped)
    return collected
@staticmethod
def _format_system_logic_from_business_rule(business_rule: str) -> str:
formula_part = (business_rule or "").split("风险等级", 1)[0]
formula_lines = [line.strip() for line in formula_part.splitlines() if "=" in line]
segments = []
for index, line in enumerate(formula_lines[:3], start=1):
metric_name, expression = [part.strip() for part in line.split("=", 1)]
segments.append(f"{index}{metric_name}={expression}")
if not segments:
segments.append("值1核心指标=按数据来源字段汇总计算")
if len(segments) == 1:
segments.append("值2风险阈值=按红色、橙色、黄色等级阈值判断")
return "".join(segments) + ";按阈值输出风险等级。"
@staticmethod
def _format_system_rule_text_from_business_rule(rule_name: str, business_rule: str) -> str:
    """Render the rule pseudo-code text from the thresholds of a rule.

    At most three threshold lines become parenthesised conditions joined
    by "或者" (a leading "-" or "*" bullet is removed). When no threshold
    line exists, one generic condition on the primary metric is used.
    """
    conditions = []
    for line in RuleGenerationService._threshold_lines(business_rule)[:3]:
        condition = re.sub(r"^\s*[-*]\s*", "", line).strip()
        if condition:
            conditions.append(f" ({condition})")
    if not conditions:
        fallback_metric = RuleGenerationService._primary_metric_name(business_rule)
        conditions = [f" ({fallback_metric}达到风险阈值)"]

    pieces = [
        f"{rule_name}(风险阈值命中):\n",
        "如果:\n",
        "\n 或者 ".join(conditions) + "\n",
        "就:\n",
        "校验:通过\n",
        "赋值:【规则结果-风险等级】等于 对应风险等级",
    ]
    return "".join(pieces)
@staticmethod
def _validate_rule_shape(rule: dict[str, str]) -> None:
    """Validate the textual shape of an LLM-generated rule.

    Checks run in order: formula present, multi-line layout, all three risk
    colors present, metric/threshold alignment, then the shape of the system
    rule text, join logic and system logic.

    Raises:
        ValueError: on the first check that fails.
    """
    service = RuleGenerationService
    business_rule = rule.get("business_rule_description", "")

    # A bare "=" test also covers the " = " spelling.
    if "=" not in business_rule:
        raise ValueError("LLM 返回的业务规则描述缺少指标公式")
    if "\n" not in business_rule:
        raise ValueError("LLM 返回的业务规则描述未按公式和指标阈值分行")
    if any(level not in business_rule for level in ("红色", "橙色", "黄色")):
        raise ValueError("LLM 返回的业务规则描述缺少红/橙/黄风险等级")
    service._validate_business_rule_indicator_alignment(business_rule)

    text_checks = (
        (
            "system_rule_text",
            service._looks_like_system_rule_text,
            "LLM 返回的系统规则文本未按 H 列规则伪代码格式生成",
        ),
        (
            "join_logic",
            service._looks_like_join_logic,
            "LLM 返回的关联表逻辑缺少表关联、筛选或分组口径",
        ),
        (
            "system_logic",
            service._looks_like_system_logic,
            "LLM 返回的系统固化逻辑缺少值1/值2中间指标定义",
        ),
    )
    for key, predicate, message in text_checks:
        if not predicate(rule.get(key, "")):
            raise ValueError(message)
@staticmethod
def _validate_business_rule_indicator_alignment(business_rule: str) -> None:
    """Ensure every formula metric is referenced by the threshold section.

    Raises:
        ValueError: when no metric name can be extracted, or when one or
            more metrics have no matching text in the threshold section.
    """
    service = RuleGenerationService
    metric_names = service._formula_metric_names(business_rule)
    if not metric_names:
        raise ValueError("LLM 返回的业务规则描述缺少可识别指标名称")

    threshold_part = service._threshold_part(business_rule)
    unmatched = []
    for name in metric_names:
        if service._metric_name_matches_threshold(name, threshold_part):
            continue
        unmatched.append(name)
    if unmatched:
        raise ValueError(f"业务规则描述中的指标阈值未对应公式指标: {', '.join(unmatched)}")
@staticmethod
def _metric_name_matches_threshold(metric_name: str, threshold_part: str) -> bool:
    """Decide whether the threshold text refers to ``metric_name``.

    Both texts are normalized first. An empty metric or a literal hit
    matches at once. Otherwise, if the metric belongs to a keyword category,
    the threshold must mention that category. Then any long token
    (>= 4 chars) found in the threshold matches; short tokens are accepted
    only when a category applies or no long token exists.
    """
    service = RuleGenerationService
    metric = service._normalize_metric_text(metric_name)
    threshold = service._normalize_metric_text(threshold_part)
    if not metric or metric in threshold:
        return True

    category_keywords = service._metric_category_keywords(metric)
    if category_keywords and not any(keyword in threshold for keyword in category_keywords):
        return False

    long_tokens, short_tokens = [], []
    for token in service._metric_match_tokens(metric_name):
        (long_tokens if len(token) >= 4 else short_tokens).append(token)

    if any(token in threshold for token in long_tokens):
        return True
    if not category_keywords and long_tokens:
        # Short tokens alone are too weak when long tokens exist but miss.
        return False
    return any(token in threshold for token in short_tokens)
@staticmethod
def _normalize_metric_text(text: str) -> str:
return re.sub(r"[^0-9A-Za-z\u4e00-\u9fff]+", "", text or "")
@staticmethod
def _metric_category_keywords(metric: str) -> tuple[str, ...]:
if any(keyword in metric for keyword in ("收益", "回报")):
return ("收益", "收益率", "回报")
if any(keyword in metric for keyword in ("金额", "余额", "总额", "规模", "额度")):
return ("金额", "余额", "总额", "规模", "额度", "债务", "投资")
if any(keyword in metric for keyword in ("笔数", "数量", "次数", "个数", "户数")):
return ("笔数", "数量", "次数", "个数", "户数")
if any(keyword in metric for keyword in ("占比", "比例", "比率", "负债率", "收益率", "集中度")):
return ("占比", "比例", "比率", "", "集中度")
return ()
@staticmethod
def _metric_match_tokens(metric_name: str) -> list[str]:
    """Generate candidate tokens for matching a metric name loosely.

    Starting from the normalized name, two rounds of derivation strip known
    leading words, remove filler words and strip known trailing words.
    The result is ordered longest first and excludes tokens shorter than
    two characters as well as overly generic words.
    """
    # NOTE(review): the empty-string entries below never produce a usable
    # token (they yield the candidate itself or ""); they look like
    # characters lost in transit - confirm the intended values.
    leading_words = ("", "集团")
    filler_words = ("结构", "类别", "类型", "")
    trailing_words = (
        "收益率",
        "集中度",
        "占比",
        "比例",
        "比率",
        "金额",
        "余额",
        "总额",
        "数量",
        "笔数",
        "次数",
        "天数",
        "规模",
        "额度",
        "水平",
        "指标",
        "风险",
        "",
    )
    too_generic = {"指标", "风险", "金额", "余额", "数量", "比例", "比率", "占比", ""}

    def derive(text):
        """Yield the variants obtainable from ``text`` in one step."""
        for prefix in leading_words:
            if text.startswith(prefix) and len(text) > len(prefix):
                yield text[len(prefix):]
        for word in filler_words:
            if word in text and len(text) > len(word):
                yield text.replace(word, "")
        for suffix in trailing_words:
            if text.endswith(suffix) and len(text) > len(suffix):
                yield text[:-len(suffix)]

    variants = {RuleGenerationService._normalize_metric_text(metric_name)}
    for _ in range(2):
        # Iterate over a snapshot: variants found now are expanded next round.
        for candidate in tuple(variants):
            variants.update(derive(candidate))

    ranked = sorted(variants, key=len, reverse=True)
    return [token for token in ranked if len(token) >= 2 and token not in too_generic]
def _build_policy_basis_text(self, pattern: dict[str, Any]) -> str:
regulations = [str(item).strip() for item in (pattern.get("core_regulations") or []) if str(item).strip()]
description = (pattern.get("description_pattern") or "").strip()
if not regulations:
regulation_match = re.search(r"核心法规:([^;。]+)", description)
if regulation_match:
regulations = [regulation_match.group(1).strip()]
regulation_text = "".join(regulations) or "相关监管制度"
basis_text = (pattern.get("basis_text") or "").strip()
if "条款要点:" in description:
clause = description.split("条款要点:", 1)[1].strip(";。 ")
else:
clause = basis_text.strip(";。 ")
if not clause:
clause = "围绕关键风险指标、主体责任和业务口径持续开展监测。"
clause = re.sub(r"\s+", "", clause)
return f"核心法规:{regulation_text};条款要点:{clause}"
@staticmethod
def _pick_fields_by_role(schema_candidates: list[dict[str, Any]], roles: tuple[str, ...], limit: int) -> list[str]:
    """Collect "module.field" labels whose field name matches a role keyword.

    Roles are processed in the given order, tables and fields in schema
    order. Only the roles subject / amount / status / date / counterparty
    are recognised; any other role selects nothing. Collection stops as
    soon as ``limit`` unique labels have been gathered.
    """
    known_roles = ("subject", "amount", "status", "date", "counterparty")
    selected: list[str] = []
    for role in roles:
        if role not in known_roles:
            continue
        for table in schema_candidates:
            for field in table.get("fields", []):
                name = field.get("name", "")
                if not name:
                    continue
                if not any(keyword in name for keyword in _FIELD_ROLE_KEYWORDS[role]):
                    continue
                label = f"{table['module_name']}.{name}"
                if label not in selected:
                    selected.append(label)
                if len(selected) >= limit:
                    return selected
    return selected
def _format_data_sources(self, schema_candidates: list[dict[str, Any]]) -> str:
    """Summarise candidate tables as one "module: field, field" line each.

    For every table, up to four unique field names matching any role
    keyword in ``_FIELD_ROLE_KEYWORDS`` are listed; when none match, the
    names of the first three fields are used. Tables with no usable field
    name are omitted.

    Fixes over the previous version:
    * the module name and field names are separated (": " and ", ")
      instead of being run together into one string, so each line can be
      split again on the colon by ``_parse_data_source_lines``;
    * fields without a "name" key no longer raise ``KeyError`` on the
      fallback path; nameless fields are skipped.
    """
    role_keywords = [keyword for keywords in _FIELD_ROLE_KEYWORDS.values() for keyword in keywords]

    parts = []
    for table in schema_candidates:
        names = [field.get("name", "") for field in table.get("fields", [])]
        important = [
            name
            for name in names
            if name and any(keyword in name for keyword in role_keywords)
        ]
        if not important:
            important = [name for name in names[:3] if name]
        # dict.fromkeys de-duplicates while keeping first-seen order.
        concise = list(dict.fromkeys(important))[:4]
        if concise:
            parts.append(f"{table['module_name']}: {', '.join(concise)}")
    return "\n".join(parts)
def _format_join_logic(self, schema_candidates: list[dict[str, Any]]) -> str:
if not schema_candidates:
return "未匹配到可关联表。"
primary = schema_candidates[0]["module_name"]
if len(schema_candidates) == 1:
return f"{primary}为主表,按所属集团编码、所属集团名称或单位编码汇总。"
join_texts = []
for table in schema_candidates[1:]:
related_name = table["module_name"]
join_key = self._infer_join_key(schema_candidates[0], table)
join_texts.append(f"{related_name}{join_key}")
return f"{primary}为主表,关联{''.join(join_texts)};按集团或单位维度汇总,必要时结合状态字段做筛选。"
def _infer_join_key(self, left: dict[str, Any], right: dict[str, Any]) -> str:
left_names = {field.get("name", "") for field in left.get("fields", [])}
right_names = {field.get("name", "") for field in right.get("fields", [])}
for key in ("所属集团编码", "开户单位编码", "单位编码", "客商编码", "合同编号", "单位账号"):
if key in left_names and key in right_names:
return key
return "所属集团编码/单位编码"
@staticmethod
def _format_input_params(schema_candidates: list[dict[str, Any]]) -> str:
    """List up to six subject/amount/status/date fields as input parameters.

    Fix over the previous version: items are joined with "、" instead of an
    empty string, which fused all "module.field" labels into one token.
    """
    selected = RuleGenerationService._pick_fields_by_role(
        schema_candidates, ("subject", "amount", "status", "date"), 6
    )
    return "、".join(selected)
@staticmethod
def _format_output_params(fields: list[str], metric_name: str) -> str:
base = [field for field in fields if field][:4]
base.extend([metric_name, "风险等级"])
deduped = []
for item in base:
if item not in deduped:
deduped.append(item)
return "".join(deduped)
@staticmethod
def _format_return_fields(fields: list[str], metric_name: str) -> str:
base = [field for field in fields if field][:6]
base.extend([metric_name, "风险等级"])
deduped = []
for item in base:
if item not in deduped:
deduped.append(item)
return "".join(deduped)
def _build_rule_rows(self, rules: list[dict[str, Any]]) -> list[dict[str, Any]]:
schema = self._read_schema()
rows = []
for index, rule in enumerate(rules, start=1):
row = {
"index": index,
"risk_domain": rule.get("risk_domain", ""),
"rule_name": rule.get("rule_name", ""),
"risk_description": rule.get("risk_description", ""),
"policy_basis": rule.get("policy_basis", ""),
"business_rule_description": self._normalize_business_rule_description(
rule.get("business_rule_description", "")
),
"data_sources": self._format_data_sources_with_markers(rule.get("data_sources", ""), schema),
"system_rule_text": rule.get("system_rule_text", ""),
"join_logic": rule.get("join_logic", ""),
"system_logic": rule.get("system_logic", ""),
"return_fields": rule.get("return_fields", ""),
}
if self.create_sql:
row["sql"] = self._create_sql_query(rule)
rows.append(row)
return rows
def _write_excel(self, rules: list[dict[str, Any]], output_file: str) -> None:
    """Write the rules to a single-sheet, styled Excel workbook.

    The sheet has a header row plus one row per rule (with a trailing SQL
    column when ``self.create_sql`` is set), a frozen header, an auto
    filter, fixed column widths, estimated row heights and thin borders.
    The parent directory of ``output_file`` is created when missing.
    """
    workbook = Workbook()
    sheet = workbook.active
    sheet.title = "11类重点风险领域"
    sheet.append(self._excel_columns())

    value_keys = (
        "index",
        "risk_domain",
        "rule_name",
        "risk_description",
        "policy_basis",
        "business_rule_description",
        "data_sources",
        "system_rule_text",
        "join_logic",
        "system_logic",
        "return_fields",
    )
    for rule_row in self._build_rule_rows(rules):
        cells = [rule_row[key] for key in value_keys]
        if self.create_sql:
            cells.append(rule_row.get("sql", ""))
        sheet.append(cells)

    # Sheet-level presentation.
    sheet.freeze_panes = "A2"
    sheet.auto_filter.ref = sheet.dimensions
    sheet.sheet_view.zoomScale = 90
    sheet.row_dimensions[1].height = 13.5
    for column_index, width in enumerate(self._excel_column_widths(), start=1):
        sheet.column_dimensions[get_column_letter(column_index)].width = width

    # Header row: bold, centered, selected columns highlighted.
    for header_cell in sheet[1]:
        highlighted = header_cell.column in _HIGHLIGHT_COLUMNS
        header_cell.font = Font(name="宋体", size=11, bold=True)
        header_cell.fill = _HIGHLIGHT_FILL if highlighted else _HEADER_FILL
        header_cell.alignment = Alignment(horizontal="center", vertical="center", wrap_text=True)
        header_cell.border = _THIN_BORDER

    # Data rows: wrapped text, leading columns centered.
    for row_number, data_row in enumerate(sheet.iter_rows(min_row=2), start=2):
        sheet.row_dimensions[row_number].height = self._estimate_row_height(data_row)
        for data_cell in data_row:
            if data_cell.column in _CENTER_COLUMNS:
                horizontal = "center"
            else:
                horizontal = "left"
            data_cell.font = Font(name="宋体", size=11)
            data_cell.alignment = Alignment(horizontal=horizontal, vertical="center", wrap_text=True)
            data_cell.border = _THIN_BORDER

    Path(output_file).parent.mkdir(parents=True, exist_ok=True)
    workbook.save(output_file)
@staticmethod
def _markdown_file_for_excel(output_file: str) -> str:
return str(Path(output_file).with_suffix(".md"))
def _write_markdown(self, rules: list[dict[str, Any]], markdown_file: str) -> None:
    """Write the rules as a Markdown document grouped by risk domain.

    The document has a title, a table of contents, and one section per
    domain (risk description, merged data sources, then one block per
    rule). When ``self.create_sql`` is set each rule block also carries its
    SQL in a fenced code block. The file is written as UTF-8 and its parent
    directory is created when missing.
    """
    text_of = self._markdown_value

    # Group rows by domain, keeping first-seen domain order.
    rules_by_domain: dict[str, list[dict[str, Any]]] = {}
    for rule_row in self._build_rule_rows(rules):
        domain_name = str(rule_row.get("risk_domain", "") or "未分类")
        rules_by_domain.setdefault(domain_name, []).append(rule_row)

    lines = [
        "# 重点风险领域规则模型与脚本",
        "",
        "> 根据本次规则生成任务产出的 Excel 内容同步整理,目录和章节以实际生成的风险领域为准。",
        "",
        "---",
        "",
        "## 目录",
        "",
    ]

    toc_entries = [
        f"{domain_index}. [{domain}](#{domain_index}-{domain})"
        for domain_index, domain in enumerate(rules_by_domain, start=1)
    ]
    lines += toc_entries if toc_entries else ["- 暂无生成成功的规则"]
    lines += ["", "---", ""]

    for domain_index, (domain, domain_rules) in enumerate(rules_by_domain.items(), start=1):
        lines += [
            f"## {domain_index}. {domain}",
            "",
            "### 风险描述",
            text_of(domain_rules[0].get("risk_description", "")),
            "",
            "### 数据来源",
        ]
        sources = self._domain_data_source_lines(domain_rules)
        lines += [f"- {source}" for source in sources] if sources else ["- 无"]
        lines += ["", "---", ""]

        for rule_position, rule in enumerate(domain_rules, start=1):
            lines += [
                f"### 规则模型{domain_index}.{rule_position}{text_of(rule.get('rule_name', ''))}",
                "",
                f"**序号**{rule['index']}",
                "",
                f"**风险描述**{text_of(rule.get('risk_description', ''))}",
                "",
                f"**制度依据**{text_of(rule.get('policy_basis', ''))}",
                "",
                "**业务规则描述**",
                "",
                text_of(rule.get("business_rule_description", "")),
                "",
                "**数据来源**",
                "",
                text_of(rule.get("data_sources", "")),
                "",
                "**系统规则文本**",
                "",
                text_of(rule.get("system_rule_text", "")),
                "",
                f"**关联表逻辑**{text_of(rule.get('join_logic', ''))}",
                "",
                f"**系统固化逻辑**{text_of(rule.get('system_logic', ''))}",
                "",
                f"**返回结果**{text_of(rule.get('return_fields', ''))}",
                "",
            ]
            if self.create_sql:
                lines += [
                    "**实现脚本SQL**",
                    "",
                    "```sql",
                    text_of(rule.get("sql", "")),
                    "```",
                    "",
                ]
            lines += ["---", ""]

    target = Path(markdown_file)
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text("\n".join(lines).rstrip() + "\n", encoding="utf-8")
def _domain_data_source_lines(
self,
domain_rules: list[dict[str, Any]],
) -> list[str]:
sources: list[str] = []
seen: set[str] = set()
for rule in domain_rules:
formatted = rule.get("data_sources", "")
for line in str(formatted or "").splitlines():
source = line.strip()
if source and source not in seen:
sources.append(source)
seen.add(source)
return sources
@staticmethod
def _markdown_value(value: Any) -> str:
text = str(value or "").strip()
return text if text else ""
def _excel_columns(self) -> list[str]:
    """Return the header titles, plus the SQL column when SQL is enabled."""
    sql_column = ["SQL查询语句"] if self.create_sql else []
    return [*_RULE_COLUMNS, *sql_column]
def _excel_column_widths(self) -> list[float]:
    """Return the column widths, plus 72 for the SQL column when enabled."""
    sql_width = [72] if self.create_sql else []
    return [*_COLUMN_WIDTHS, *sql_width]
def _create_sql_query(self, rule: dict[str, Any]) -> str:
    """Generate a SELECT statement that evaluates ``rule``.

    Tables come from the rule's data sources, falling back to the first
    schema module and finally to a placeholder table. The statement selects
    up to eight plain columns, the metric components (``VALUE_n``), the
    metric itself and a ``RISK_LEVEL`` CASE expression, reading from
    aggregated derived tables joined with LEFT JOIN.

    Raises:
        ValueError: when the SQL fails ``_validate_generated_sql``.
    """
    schema = self._read_schema()

    tables = self._sql_tables_from_rule(rule, schema)
    if not tables and schema:
        tables = [self._schema_table_context(schema[0], 1)]
    if not tables:
        tables = [{"sql_name": "TABLE_1", "alias": "t1", "fields": []}]
    primary_table = tables[0]

    # Rule-derived markers first, then the primary table's leading markers,
    # then a fixed default set.
    select_markers = (
        self._sql_field_markers(rule, tables)
        or [field["marker"] for field in primary_table["fields"][:4]]
        or ["COMCODE", "COMNAME", "DAYIDX"]
    )

    metric_alias = self._sql_metric_alias(rule)
    metric_expression = self._sql_metric_expression(rule, tables)
    metric_components = self._sql_metric_components(metric_expression)
    risk_cases = self._sql_risk_cases(rule, metric_expression)
    join_key = self._sql_join_key(tables)

    select_lines = [f" {self._qualified_marker(marker, tables)}" for marker in select_markers[:8]]
    select_lines += [
        f" {component} AS VALUE_{position}"
        for position, component in enumerate(metric_components, start=1)
    ]
    if metric_alias not in select_markers:
        select_lines.append(f" {metric_expression} AS {metric_alias}")
    select_lines.append(f" {risk_cases} AS RISK_LEVEL")

    joins = self._sql_join_lines(tables, select_markers, join_key)
    from_clause = self._sql_aggregate_subquery(primary_table, select_markers, join_key)

    statement = "SELECT\n" + ",\n".join(select_lines)
    statement += f"\nFROM {from_clause}"
    if joins:
        statement += "\n" + "\n".join(joins)
    return self._validate_generated_sql(statement + ";")
@staticmethod
def _validate_generated_sql(sql: str) -> str:
if re.search(r"^\s*WITH\b", sql, flags=re.IGNORECASE):
raise ValueError("Generated SQL uses WITH/CTE, which is not allowed for MySQL compatibility.")
if re.search(r"[\u4e00-\u9fff]", sql):
raise ValueError("Generated SQL contains Chinese text.")
if re.search(r"SUM\s*\([^)]*(DATE|TIME|DAYIDX)[^)]*\)", sql, flags=re.IGNORECASE):
raise ValueError("Generated SQL tries to SUM date/time fields.")
if re.search(r"^\s*t\d+\s+AS\s*\(", sql, flags=re.IGNORECASE | re.MULTILINE):
raise ValueError("Generated SQL contains CTE-style table aliases.")
return sql
def _sql_tables_from_rule(self, rule: dict[str, Any], schema: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Resolve the schema tables a rule refers to (at most four).

    First the table labels parsed from ``data_sources`` are matched against
    the schema. Only when that yields nothing are schema modules searched
    by name in the combined data-source, return-field and system-logic
    text. Tables are unique by SQL name.
    """
    contexts: list[dict[str, Any]] = []
    seen_sql_names: set[str] = set()

    def remember(module: dict[str, Any], index: int) -> None:
        """Add the table context for ``module`` unless already collected."""
        context = self._schema_table_context(module, index)
        if context["sql_name"] in seen_sql_names:
            return
        contexts.append(context)
        seen_sql_names.add(context["sql_name"])

    parsed_sources = self._parse_data_source_lines(str(rule.get("data_sources", "") or ""))
    for table_label, _field_labels in parsed_sources:
        module, index = self._match_schema_module(table_label, schema)
        if module:
            remember(module, index)
    if contexts:
        return contexts[:4]

    searchable_text = "\n".join(
        str(rule.get(key, "") or "") for key in ("data_sources", "return_fields", "system_logic")
    )
    for index, module in enumerate(schema, start=1):
        module_name = self._clean_module_name(module.get("module_name", ""))
        if module_name and module_name in searchable_text:
            remember(module, index)
    return contexts[:4]
@staticmethod
def _parse_data_source_lines(data_sources: str) -> list[tuple[str, list[str]]]:
tables = []
for line in data_sources.splitlines():
raw_line = line.strip()
if not raw_line:
continue
split_at = -1
for separator in (":", "", ""):
split_at = raw_line.find(separator)
if split_at >= 0:
break
if split_at < 0:
continue
table_label = raw_line[:split_at].strip()
field_text = raw_line[split_at + 1:].strip()
field_labels = [item.strip() for item in re.split(r"[,,、;\s]+", field_text) if item.strip()]
if table_label:
tables.append((table_label, field_labels))
return tables
def _match_schema_module(
self,
table_label: str,
schema: list[dict[str, Any]],
) -> tuple[dict[str, Any] | None, int]:
clean_label = self._clean_module_name(table_label)
for index, module in enumerate(schema, start=1):
module_name = self._clean_module_name(module.get("module_name", ""))
if clean_label and module_name and (
clean_label == module_name or clean_label in module_name or module_name in clean_label
):
return module, index
return None, 0
def _schema_table_context(self, module: dict[str, Any], index: int) -> dict[str, Any]:
fields = []
seen = set()
for field in module.get("fields", []):
marker = self._safe_sql_identifier(field.get("marker", ""), "")
if marker and marker not in seen:
fields.append({"name": str(field.get("name", "") or ""), "marker": marker})
seen.add(marker)
return {
"sql_name": self._schema_table_identifier(module, index),
"alias": f"t{max(index, 1)}",
"fields": fields,
}
def _schema_table_identifier(self, module: dict[str, Any], index: int) -> str:
table_name = self._safe_table_identifier(module.get("table_name", ""))
if table_name:
return table_name
raw_name = self._clean_module_name(module.get("module_name", ""))
sql_name = self._safe_sql_identifier(raw_name, "")
return sql_name or f"TABLE_{max(index, 1):02d}"
def _sql_field_markers(self, rule: dict[str, Any], tables: list[dict[str, Any]]) -> list[str]:
    """Collect the column markers a rule refers to, in first-seen order.

    Markers come from the field labels of each data-source line (resolved
    against the labelled table, or the first table when the label is
    unknown) and from the rule's return fields (resolved against the first
    table). Only markers that exist in one of ``tables`` are kept.
    """
    collected: list[str] = []
    parsed_sources = self._parse_data_source_lines(str(rule.get("data_sources", "") or ""))
    for table_label, field_labels in parsed_sources:
        target_table = self._table_for_label(table_label, tables) or tables[0]
        collected += self._markers_for_field_labels(field_labels, target_table)

    return_tokens = self._field_tokens(str(rule.get("return_fields", "") or ""))
    collected += self._markers_for_field_labels(return_tokens, tables[0])

    available = {field["marker"] for table in tables for field in table["fields"]}
    unique_markers: list[str] = []
    for marker in collected:
        safe_marker = self._safe_sql_identifier(marker, "")
        if not safe_marker or safe_marker not in available:
            continue
        if safe_marker not in unique_markers:
            unique_markers.append(safe_marker)
    return unique_markers
def _table_for_label(self, table_label: str, tables: list[dict[str, Any]]) -> dict[str, Any] | None:
safe_label = self._safe_sql_identifier(table_label, "")
for table in tables:
if safe_label and safe_label == table["sql_name"]:
return table
return None
def _markers_for_field_labels(self, field_labels: list[str], table: dict[str, Any]) -> list[str]:
markers = []
for label in field_labels:
safe_label = self._safe_sql_identifier(label, "")
if safe_label and any(field["marker"] == safe_label for field in table["fields"]):
markers.append(safe_label)
continue
for field in table["fields"]:
field_name = field["name"]
if label and field_name and (label == field_name or label in field_name or field_name in label):
markers.append(field["marker"])
break
return markers
@staticmethod
def _field_tokens(text: str) -> list[str]:
return [item.strip() for item in re.split(r"[,,、;\s]+", text) if item.strip()]
def _qualified_marker(self, marker: str, tables: list[dict[str, Any]]) -> str:
for table in tables:
if any(field["marker"] == marker for field in table["fields"]):
return f"{table['alias']}.{marker}"
return f"{tables[0]['alias']}.{marker}"
def _sql_aggregate_subquery(
self,
table: dict[str, Any],
select_markers: list[str],
join_key: str,
) -> str:
table_markers = {field["marker"] for field in table["fields"]}
selected = [marker for marker in select_markers if marker in table_markers]
aggregate_markers = self._numeric_markers(table)
query_markers = []
for marker in [join_key, *selected, *aggregate_markers]:
if marker and marker in table_markers and marker not in query_markers:
query_markers.append(marker)
if join_key not in query_markers and table["fields"]:
query_markers.insert(0, table["fields"][0]["marker"])
select_lines = []
group_markers = []
for marker in query_markers:
if marker in aggregate_markers:
select_lines.append(f" SUM({marker}) AS {marker}")
else:
select_lines.append(f" {marker}")
group_markers.append(marker)
group_by = ", ".join(group_markers) if group_markers else "1"
joined_select_lines = ",\n".join(select_lines)
return (
"(\n"
" SELECT\n"
f"{joined_select_lines}\n"
f" FROM {table['sql_name']}\n"
f" GROUP BY {group_by}\n"
f") {table['alias']}"
)
@staticmethod
def _numeric_markers(table: dict[str, Any]) -> list[str]:
markers = []
for field in table["fields"]:
marker = field["marker"]
name = str(field.get("name", "") or "")
field_type = str(field.get("type", "") or "")
if any(token in marker for token in ("DATE", "TIME", "DAYIDX")):
continue
if any(token in name for token in ("日期", "时间")):
continue
if any(token in field_type for token in ("日期", "时间")):
continue
if any(token in marker for token in ("AMOUNT", "BALANCE", "ASSET", "LIABILITY", "RATE", "ZZJE", "ZZYE")) or any(
token in name for token in ("金额", "余额", "利率", "市值", "成本", "资产", "负债")
):
markers.append(marker)
return markers
@staticmethod
def _sql_join_key(tables: list[dict[str, Any]]) -> str:
if not tables:
return ""
primary_markers = {field["marker"] for field in tables[0]["fields"]}
for table in tables[1:]:
primary_markers &= {field["marker"] for field in table["fields"]}
return next((marker for marker in ("COMCODE", "COMNAME", "DAYIDX") if marker in primary_markers), None) or (
sorted(primary_markers)[0] if primary_markers else ""
)
def _sql_join_lines(self, tables: list[dict[str, Any]], select_markers: list[str], join_key: str) -> list[str]:
if len(tables) <= 1:
return []
primary = tables[0]
joins = []
for table in tables[1:4]:
table_markers = {field["marker"] for field in table["fields"]}
primary_markers = {field["marker"] for field in primary["fields"]}
if join_key and join_key in primary_markers and join_key in table_markers:
joins.append(
f"LEFT JOIN {self._sql_aggregate_subquery(table, select_markers, join_key)} "
f"ON {primary['alias']}.{join_key} = {table['alias']}.{join_key}"
)
else:
joins.append(f"LEFT JOIN {self._sql_aggregate_subquery(table, select_markers, join_key)} ON 1 = 1")
return joins
def _sql_metric_alias(self, rule: dict[str, Any]) -> str:
metric_match = re.search(r"([^\n=]{2,30})\s*=", str(rule.get("business_rule_description", "")))
metric_name = metric_match.group(1).strip() if metric_match else ""
return self._safe_sql_identifier(metric_name, "RISK_METRIC")
def _sql_metric_expression(self, rule: dict[str, Any], tables: list[dict[str, Any]]) -> str:
formula = self._metric_formula_rhs(str(rule.get("business_rule_description", "") or ""))
expression = self._replace_formula_fields(formula, tables)
if expression:
return expression
for table in tables:
for field in table["fields"]:
marker = field["marker"]
if any(token in marker for token in ("AMOUNT", "BALANCE", "RATE", "ZZJE", "ZZYE")):
return f"{table['alias']}.{marker}"
return "0"
@staticmethod
def _sql_metric_components(metric_expression: str) -> list[str]:
match = re.fullmatch(r"\((.+)\s*/\s*NULLIF\((.+),\s*0\)\)", metric_expression.strip())
if match:
return [match.group(1).strip(), match.group(2).strip()]
return []
@staticmethod
def _metric_formula_rhs(text: str) -> str:
for line in text.splitlines():
if "=" in line:
return line.split("=", 1)[1].strip()
return ""
def _replace_formula_fields(self, formula: str, tables: list[dict[str, Any]]) -> str:
expression = formula.strip()
if not expression:
return ""
replacements = []
for table in tables:
for field in table["fields"]:
name = str(field.get("name", "") or "").strip()
if name and name in expression:
replacements.append((name, f"{table['alias']}.{field['marker']}"))
for name, sql_expression in sorted(replacements, key=lambda item: len(item[0]), reverse=True):
expression = expression.replace(name, sql_expression)
expression = expression.replace("", "(").replace("", ")")
expression = re.sub(r"\s+", " ", expression)
if re.search(r"[\u4e00-\u9fff]", expression):
return ""
if "/" in expression:
left, right = expression.split("/", 1)
left = left.strip()
right = right.strip()
if left and right:
return f"({left} / NULLIF({right}, 0))"
return expression
def _sql_risk_cases(self, rule: dict[str, Any], metric_expression: str) -> str:
thresholds = self._risk_thresholds(rule)
if not thresholds:
return "CASE WHEN 1 = 1 THEN 'HIT' ELSE 'PASS' END"
case_lines = ["CASE"]
for label, operator, value in thresholds:
case_lines.append(f" WHEN {metric_expression} {operator} {value} THEN '{label}'")
case_lines.append(" ELSE 'PASS'")
case_lines.append(" END")
return "\n ".join(case_lines)
def _risk_thresholds(self, rule: dict[str, Any]) -> list[tuple[str, str, str]]:
business_rule = str(rule.get("business_rule_description", "") or "")
colors = [
("RED", "\u7ea2\u8272"),
("ORANGE", "\u6a59\u8272"),
("YELLOW", "\u9ec4\u8272"),
]
thresholds = []
for label, color in colors:
for line in business_rule.splitlines():
if color not in line:
continue
comparison = self._first_sql_comparison(line)
if comparison:
thresholds.append((label, comparison[0], comparison[1]))
break
return thresholds
def _system_rule_conditions(self, system_rule: str, metric_expression: str) -> list[str]:
comparisons = self._sql_comparisons(system_rule)
return [f"{metric_expression} {operator} {value}" for operator, value in comparisons]
def _first_sql_comparison(self, text: str) -> tuple[str, str] | None:
comparisons = self._sql_comparisons(text)
return comparisons[0] if comparisons else None
@staticmethod
def _sql_comparisons(text: str) -> list[tuple[str, str]]:
operator_map = {
">=": ">=",
"<=": "<=",
">": ">",
"<": "<",
"\u5927\u4e8e\u7b49\u4e8e": ">=",
"\u5c0f\u4e8e\u7b49\u4e8e": "<=",
"\u5927\u4e8e": ">",
"\u5c0f\u4e8e": "<",
}
operator_pattern = "|".join(re.escape(operator) for operator in operator_map)
comparisons = []
for match in re.finditer(rf"({operator_pattern})\s*([0-9]+(?:\.[0-9]+)?)\s*(%)?", text):
value = float(match.group(2))
if match.group(3):
value = value / 100
comparisons.append((operator_map[match.group(1)], f"{value:g}"))
return comparisons
@staticmethod
def _safe_sql_identifier(value: Any, fallback: str) -> str:
text = str(value or "").strip().upper()
text = re.sub(r"[^A-Z0-9_]+", "_", text)
text = re.sub(r"_+", "_", text).strip("_")
if not text or not re.search(r"[A-Z]", text):
return fallback
if text[0].isdigit():
text = f"T_{text}"
return text
@staticmethod
def _safe_table_identifier(value: Any) -> str:
text = str(value or "").strip()
parts = []
for part in text.split("."):
safe_part = re.sub(r"[^A-Za-z0-9_]+", "_", part)
safe_part = re.sub(r"_+", "_", safe_part).strip("_")
if not safe_part or not re.search(r"[A-Za-z]", safe_part):
continue
if safe_part[0].isdigit():
safe_part = f"T_{safe_part}"
parts.append(safe_part)
if not parts:
return ""
return ".".join(parts)
@staticmethod
def _estimate_row_height(row) -> float:
    """Estimate the height of a worksheet row from its cell contents.

    For each cell the number of lines is the larger of its explicit line
    count and a wrap estimate based on the column width (about 1.6
    characters per width unit; width 72 for columns beyond the predefined
    widths). The height is 18 per line, kept within 36 and 409.
    """
    needed_lines = 1
    for cell in row:
        text = str(cell.value or "")
        column_offset = cell.column - 1
        if column_offset < len(_COLUMN_WIDTHS):
            width = _COLUMN_WIDTHS[column_offset]
        else:
            width = 72
        chars_per_line = max(int(width * 1.6), 1)
        wrapped_lines = max(1, len(text) // chars_per_line + 1)
        explicit_lines = text.count("\n") + 1
        needed_lines = max(needed_lines, explicit_lines, wrapped_lines)
    return min(max(36, needed_lines * 18), 409)