# 1901 lines · 89 KiB · Python
"""异步业务规则生成服务。"""
|
||
|
||
import json
|
||
import os
|
||
import random
|
||
import re
|
||
import threading
|
||
import uuid
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from typing import Any
|
||
|
||
from openpyxl import Workbook
|
||
from openpyxl.styles import Alignment, Border, Font, PatternFill, Side
|
||
from openpyxl.utils import get_column_letter
|
||
|
||
from app.utils.logger import get_logger
|
||
|
||
# Module-wide logger obtained from the project's logger factory.
logger = get_logger("rule_generation")

# Directory three levels above this file's absolute path.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Default input files; they are the defaults of RuleGenerationService.__init__.
_DOMAINS_FILE = os.path.join(_PROJECT_ROOT, "data", "domains.json")
_SCHEMA_FILE = os.path.join(_PROJECT_ROOT, "data", "schema.json")
# Default output folder and its "tasks" sub-folder.
_OUTPUT_DIR = os.path.join(_PROJECT_ROOT, "output")
_TASK_DIR = os.path.join(_OUTPUT_DIR, "tasks")
|
||
|
||
# The eleven column titles of a rule sheet, in display order.
# NOTE(review): presumably the Excel header row; the writer is outside this view.
_RULE_COLUMNS = [
    "序号",
    "风险领域",
    "规则名称",
    "风险描述",
    "制度依据",
    "业务规则描述",
    "数据来源",
    "系统规则文本",
    "关联表逻辑",
    "系统固化逻辑",
    "返回结果",
]

# Eleven width values, one per entry of _RULE_COLUMNS.
_COLUMN_WIDTHS = [5.13, 14.9, 15.84, 27.08, 35.34, 40.55, 37.23, 107.36, 34.58, 39.58, 27.23]
# NOTE(review): presumably 1-based column indexes to centre / highlight — confirm in the writer.
_CENTER_COLUMNS = {1, 2, 3}
_HIGHLIGHT_COLUMNS = {6, 7, 8}

# Shared openpyxl style objects: a thin black border on all four sides and two solid fills.
_THIN_SIDE = Side(style="thin", color="000000")
_THIN_BORDER = Border(left=_THIN_SIDE, right=_THIN_SIDE, top=_THIN_SIDE, bottom=_THIN_SIDE)
_HEADER_FILL = PatternFill("solid", fgColor="FFFFFF")
_HIGHLIGHT_FILL = PatternFill("solid", fgColor="FFFF00")
|
||
|
||
# Risk-domain name -> ordered table-name hints. _select_schema_candidates reads
# these first, so earlier entries get a better hint rank.
_DOMAIN_TABLE_HINTS = {
    "过度负债": ["银行贷款", "应付债券", "银行账户", "担保", "应付票据", "财务公司"],
    "无关多元": ["金融投资", "客商信息", "应收款项", "应付款项", "合同"],
    "多层架构": ["客商信息", "合同", "应收款项", "应付款项"],
    "薪酬乱象": ["资金结算", "银行账户"],
    "财务金融风险": ["银行贷款", "应付债券", "担保", "金融投资", "资金结算", "银行账户"],
    "控股不控权": ["客商信息", "合同", "应收款项", "应付款项"],
    "虚假贸易": ["资金结算", "合同", "增值税发票", "应收款项", "应付款项", "客商信息"],
    "捞偏门": ["资金结算", "合同", "客商信息", "增值税发票"],
    "靠企吃企": ["客商信息", "资金结算", "合同", "应付款项", "应收款项"],
    "资产闲置浪费": ["银行账户", "金融投资", "应收款项", "应付款项"],
    "违规境外业务": ["资金结算", "银行账户", "金融投资", "合同"],
}

# Keyword -> table-name hints, appended after the domain hints when the keyword
# occurs as a substring of the policy text.
_KEYWORD_TABLE_HINTS = {
    "资产负债率": ["银行贷款", "应付债券", "银行账户", "担保", "财务公司"],
    "债务": ["银行贷款", "应付债券", "应付票据", "担保"],
    "融资": ["银行贷款", "应付债券", "担保", "供应链金融"],
    "担保": ["担保", "银行贷款"],
    "票据": ["应付票据", "应收票据"],
    "资金流": ["资金结算", "银行账户"],
    "账户": ["银行账户", "资金结算"],
    "交易": ["资金结算", "合同", "增值税发票", "客商信息"],
    "合同": ["合同", "资金结算", "增值税发票"],
    "发票": ["增值税发票", "合同"],
    "客商": ["客商信息", "应收款项", "应付款项"],
    "关联": ["客商信息", "资金结算", "合同"],
    "投资": ["金融投资业务", "客商信息"],
    "主业": ["金融投资业务", "客商信息"],
}

# Role name -> substrings; _field_roles assigns a field to every role whose
# substrings occur in the field's "name" (a field may land in several roles).
_FIELD_ROLE_KEYWORDS = {
    "subject": ("所属集团编码", "所属集团名称", "开户单位编码", "开户单位名称", "单位编码", "单位名称", "贷款单位名称", "客商名称"),
    "amount": ("余额", "金额", "资产", "负债", "市值", "成本", "发生额"),
    "status": ("状态", "标识", "类型", "是否", "级别", "性质", "分类"),
    "date": ("日期", "时间", "期间", "期限", "到期"),
    "counterparty": ("承销商", "放款单位", "交易对手", "客商", "合同", "账户", "银行"),
}
|
||
|
||
|
||
class RuleGenerationService:
|
||
"""规则生成任务管理器。"""
|
||
|
||
def __init__(
    self,
    domains_file: str = _DOMAINS_FILE,
    schema_file: str = _SCHEMA_FILE,
    output_dir: str = _OUTPUT_DIR,
    create_sql: bool = False,
):
    """Remember the configured paths and make sure the output folders exist.

    Args:
        domains_file: JSON file holding the risk domains and their guidance analysis.
        schema_file: JSON file holding the data-table schema.
        output_dir: Folder for generated artefacts; task state goes to ``<output_dir>/tasks``.
        create_sql: Flag copied verbatim into every task state.
    """
    self.domains_file, self.schema_file = domains_file, schema_file
    self.output_dir = output_dir
    self.task_dir = os.path.join(output_dir, "tasks")
    self.create_sql = create_sql
    for folder in (self.output_dir, self.task_dir):
        os.makedirs(folder, exist_ok=True)
|
||
|
||
def start(self, limit: int = 30) -> dict[str, Any]:
    """Queue a generation task, launch its worker thread and return the initial state.

    Args:
        limit: Requested rules per policy point; normalised by ``_normalize_limit``.

    Returns:
        The state dictionary that was just written to the task's state file.
    """
    rules_per_point = self._normalize_limit(limit)
    task_id = self._new_task_id()
    folder = os.path.join(self.output_dir, f"rules-{task_id}")
    excel_path = os.path.join(folder, f"rules-{task_id}.xlsx")
    markdown_path = self._markdown_file_for_excel(excel_path)

    # Key order is kept because the dict is serialised to the state file as-is.
    state: dict[str, Any] = dict(
        task_id=task_id,
        status="queued",
        limit=rules_per_point,
        rules_per_policy_point=rules_per_point,
        create_sql=self.create_sql,
        progress={"current": 0, "total": 0},
        generated_count=0,
        output_dir=folder,
        output_file=excel_path,
        markdown_file=markdown_path,
        files={"excel": excel_path, "markdown": markdown_path},
        started_at=datetime.now().isoformat(),
        finished_at=None,
        error=None,
        markdown_error=None,
        current_domain="",
        skipped_domains=[],
        skipped_rules=[],
    )
    self._write_state(task_id, state)

    worker = threading.Thread(
        target=self._run_task,
        args=(task_id, rules_per_point, excel_path),
        daemon=True,
    )
    worker.start()
    return state
|
||
|
||
def get_status(self, task_id: str) -> dict[str, Any] | None:
|
||
path = self._state_path(task_id)
|
||
if not os.path.exists(path):
|
||
return None
|
||
return self._read_json(path)
|
||
|
||
def _run_task(self, task_id: str, limit: int, output_file: str) -> None:
    """Run one generation task end to end, persisting progress after every step.

    Policy points are selected per domain, matched against the schema and sent
    to the LLM one by one. Points that cannot be processed are recorded in
    ``state["skipped_rules"]`` instead of aborting the task. Any other failure
    marks the task as failed and writes empty output files.

    Args:
        task_id: Identifier whose state file is read and updated.
        limit: Rules to generate per policy point.
        output_file: Target path of the Excel artefact.
    """
    state = self.get_status(task_id) or {}
    try:
        state["status"] = "processing"
        self._write_state(task_id, state)

        domains = self._collect_policy_basis()
        selected_points = self._select_policy_points(domains, limit)
        schema = self._read_schema()
        total = len(selected_points)
        state["progress"] = {"current": 0, "total": total}
        self._write_state(task_id, state)

        rules: list[dict[str, Any]] = []
        skipped_domains = self._skipped_domains_without_basis(domains)
        skipped_rules: list[dict[str, Any]] = []
        # Share the list with the state so every later write reflects all skips.
        state["skipped_rules"] = skipped_rules

        def record_skip(domain: str, pattern: dict[str, Any], reason: str) -> None:
            """Remember a skipped policy point and persist it right away."""
            skipped_rules.append({
                "domain": domain,
                "basis": self._build_policy_basis_text(pattern)[:120],
                "reason": reason,
            })
            self._write_state(task_id, state)

        for index, point in enumerate(selected_points, start=1):
            domain = point["domain"]
            pattern = {
                **point["pattern"],
                "_variant_index": point.get("variant_index", 1),
                "_variant_total": point.get("variant_total", limit),
            }
            state["current_domain"] = domain
            state["progress"] = {"current": index, "total": total}
            self._write_state(task_id, state)

            candidates = self._select_schema_candidates(domain, pattern, schema)
            if not candidates:
                # Persisted immediately, consistent with the LLM-failure branch below.
                record_skip(domain, pattern, "未匹配到可用 schema 表")
                continue

            try:
                rule = self._generate_rule(domain, pattern, candidates)
            except Exception as exc:
                reason = str(exc)
                logger.error(
                    "LLM 规则生成失败,跳过该规则 | domain=%s | reason=%s",
                    domain,
                    reason,
                    exc_info=True,
                )
                record_skip(domain, pattern, reason)
                continue

            rule.setdefault("risk_domain", domain)
            rules.append(rule)
            state["generated_count"] = len(rules)
            self._write_state(task_id, state)

        self._write_excel(rules, output_file)
        try:
            self._write_markdown(rules, state.get("markdown_file") or self._markdown_file_for_excel(output_file))
            state["markdown_error"] = None
        except Exception as exc:
            # Markdown is a secondary artefact: record the error, keep the task alive.
            state["markdown_error"] = str(exc)
            logger.warning("Markdown 产物写入失败 | task_id=%s", task_id, exc_info=True)
        state["status"] = "done"
        state["generated_count"] = len(rules)
        state["skipped_domains"] = skipped_domains
        state["skipped_rules"] = skipped_rules
        state["current_domain"] = ""
        state["finished_at"] = datetime.now().isoformat()
        if not rules:
            state["status"] = "failed"
            state["error"] = "LLM 未成功生成任何规则,未使用兜底规则。"
        self._write_state(task_id, state)
    except Exception as exc:
        logger.error("规则生成任务失败 | task_id=%s", task_id, exc_info=True)
        state["status"] = "failed"
        state["error"] = str(exc)
        state["finished_at"] = datetime.now().isoformat()
        # Best effort: leave (empty) artefacts behind so the advertised paths exist.
        try:
            self._write_excel([], output_file)
        except Exception:
            logger.warning("规则生成失败后写空 Excel 失败 | task_id=%s", task_id, exc_info=True)
        try:
            self._write_markdown([], state.get("markdown_file") or self._markdown_file_for_excel(output_file))
        except Exception as markdown_exc:
            state["markdown_error"] = str(markdown_exc)
            logger.warning("规则生成失败后写空 Markdown 失败 | task_id=%s", task_id, exc_info=True)
        self._write_state(task_id, state)
|
||
|
||
@staticmethod
|
||
def _normalize_limit(limit: int) -> int:
|
||
try:
|
||
value = int(limit)
|
||
except (TypeError, ValueError):
|
||
value = 30
|
||
return max(1, min(value, 30))
|
||
|
||
@staticmethod
|
||
def _new_task_id() -> str:
|
||
return f"{datetime.now().strftime('%Y%m%d-%H%M%S')}-{uuid.uuid4().hex[:6]}"
|
||
|
||
def _state_path(self, task_id: str) -> str:
|
||
safe_task_id = re.sub(r"[^0-9A-Za-z_-]", "", task_id)
|
||
return os.path.join(self.task_dir, f"{safe_task_id}.json")
|
||
|
||
def _write_state(self, task_id: str, state: dict[str, Any]) -> None:
|
||
path = self._state_path(task_id)
|
||
temp_path = f"{path}.tmp"
|
||
with open(temp_path, "w", encoding="utf-8") as file:
|
||
json.dump(state, file, ensure_ascii=False, indent=2)
|
||
os.replace(temp_path, path)
|
||
|
||
@staticmethod
|
||
def _read_json(path: str) -> dict[str, Any]:
|
||
with open(path, "r", encoding="utf-8") as file:
|
||
return json.load(file)
|
||
|
||
def _collect_policy_basis(self) -> list[dict[str, Any]]:
|
||
data = self._read_json(self.domains_file)
|
||
result = []
|
||
for domain in data.get("domains", []):
|
||
patterns = []
|
||
for file_record in domain.get("guidance_files", []):
|
||
analysis = file_record.get("guidance_analysis") or {}
|
||
if analysis.get("status") != "done":
|
||
continue
|
||
for pattern in analysis.get("description_patterns", []):
|
||
if pattern.get("description_pattern") or pattern.get("basis_text"):
|
||
patterns.append(pattern)
|
||
if patterns:
|
||
result.append({
|
||
"token": domain.get("token", ""),
|
||
"domain": domain.get("domain", ""),
|
||
"patterns": patterns,
|
||
})
|
||
return result
|
||
|
||
def _skipped_domains_without_basis(self, domains_with_basis: list[dict[str, Any]]) -> list[dict[str, str]]:
|
||
data = self._read_json(self.domains_file)
|
||
tokens_ready = {item["token"] for item in domains_with_basis}
|
||
skipped = []
|
||
for domain in data.get("domains", []):
|
||
domain_name = domain.get("domain", "")
|
||
if domain.get("token", "") not in tokens_ready:
|
||
skipped.append({"domain": domain_name, "reason": "未上传指导文件或未执行指导文件解析"})
|
||
return skipped
|
||
|
||
@staticmethod
|
||
def _select_policy_points(domains: list[dict[str, Any]], limit: int) -> list[dict[str, Any]]:
|
||
selected = []
|
||
for domain_item in domains:
|
||
patterns = list(domain_item.get("patterns", []))
|
||
if not patterns:
|
||
continue
|
||
if len(patterns) >= limit:
|
||
chosen_patterns = random.sample(patterns, limit)
|
||
else:
|
||
chosen_patterns = patterns[:]
|
||
chosen_patterns.extend(random.choice(patterns) for _ in range(limit - len(patterns)))
|
||
for variant_index, pattern in enumerate(chosen_patterns, start=1):
|
||
selected.append({
|
||
"domain": domain_item["domain"],
|
||
"pattern": pattern,
|
||
"variant_index": variant_index,
|
||
"variant_total": limit,
|
||
})
|
||
return selected
|
||
|
||
def _read_schema(self) -> list[dict[str, Any]]:
|
||
data = self._read_json(self.schema_file)
|
||
modules = data.get("modules", [])
|
||
return [module for module in modules if module.get("module_name") and module.get("fields")]
|
||
|
||
def _select_schema_candidates(
    self,
    domain: str,
    pattern: dict[str, Any],
    schema: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Rank the schema modules for one policy point and summarise the best four.

    A module scores for (a) matching a table hint derived from the domain or
    from keywords in the policy text, where earlier hints score higher,
    (b) every policy-text keyword found in its name, description or field
    names, and (c) having amount and/or subject fields.

    Args:
        domain: Risk-domain name, used to look up domain table hints.
        pattern: Policy pattern supplying the text fields and keywords.
        schema: Modules as returned by ``_read_schema``.

    Returns:
        Up to four module summaries (see ``_summarize_module``), best first.
    """
    # ``or ""`` keeps a JSON null from breaking the join.
    text = " ".join([
        domain or "",
        pattern.get("description_pattern") or "",
        pattern.get("basis_text") or "",
        pattern.get("source_sentence") or "",
        " ".join(self._normalize_keywords(pattern)),
    ])
    hint_order: list[str] = []
    for hint in _DOMAIN_TABLE_HINTS.get(domain, []):
        if hint not in hint_order:
            hint_order.append(hint)
    for keyword, table_names in _KEYWORD_TABLE_HINTS.items():
        if keyword in text:
            for table_name in table_names:
                if table_name not in hint_order:
                    hint_order.append(table_name)

    # Loop-invariant: extract the policy-text keywords once, not per module.
    text_keywords = self._keywords_from_text(text)

    scored = []
    for module in schema:
        module_name = self._clean_module_name(module.get("module_name", ""))
        haystack = " ".join([
            module_name,
            module.get("description") or "",
            " ".join(field.get("name") or "" for field in module.get("fields", [])),
        ])
        score = 0
        rank = self._hint_rank(module_name, hint_order)
        if rank <= len(hint_order):
            score += 60 - min(rank, 40)
        score += sum(1 for keyword in text_keywords if keyword in haystack)
        roles = self._field_roles(module)
        if roles["amount"]:
            score += 4
        if roles["subject"]:
            score += 4
        if score > 0:
            scored.append((score, module))

    # list.sort is stable, so equally scored modules keep their schema order.
    scored.sort(key=lambda item: item[0], reverse=True)
    return [self._summarize_module(module) for _, module in scored[:4]]
|
||
|
||
@staticmethod
|
||
def _hint_rank(module_name: str, hint_order: list[str]) -> int:
|
||
for index, hint in enumerate(hint_order):
|
||
if hint in module_name or module_name in hint:
|
||
return index
|
||
return len(hint_order) + 1
|
||
|
||
@staticmethod
|
||
def _clean_module_name(module_name: str) -> str:
|
||
return re.sub(r"^\s*\d+\s*[、..))\-_]*\s*", "", module_name or "").strip()
|
||
|
||
@staticmethod
|
||
def _keywords_from_text(text: str) -> set[str]:
|
||
words = set()
|
||
for word in re.findall(r"[\u4e00-\u9fa5A-Za-z0-9]{2,}", text or ""):
|
||
if len(word) >= 2:
|
||
words.add(word)
|
||
return words
|
||
|
||
@staticmethod
|
||
def _normalize_keywords(pattern: dict[str, Any]) -> list[str]:
|
||
raw_keywords = pattern.get("keywords") or []
|
||
if isinstance(raw_keywords, list):
|
||
return [str(item).strip() for item in raw_keywords if str(item).strip()]
|
||
if isinstance(raw_keywords, str):
|
||
return [raw_keywords.strip()] if raw_keywords.strip() else []
|
||
return []
|
||
|
||
def _summarize_module(self, module: dict[str, Any]) -> dict[str, Any]:
    """Condense a schema module to its name, table, description and up to eight fields.

    Fields are taken in role order (subject, amount, status, date,
    counterparty). When no field has a role, the module's first six fields
    are used. Unnamed fields and repeated names are dropped.
    """
    roles = self._field_roles(module)
    candidates = [
        field
        for role in ("subject", "amount", "status", "date", "counterparty")
        for field in roles[role]
    ]
    if not candidates:
        candidates = module.get("fields", [])[:6]

    # Keyed by field name: dicts keep insertion order and give de-duplication.
    picked: dict[str, dict[str, Any]] = {}
    for field in candidates:
        name = field.get("name", "")
        if not name or name in picked:
            continue
        picked[name] = {
            "name": name,
            "marker": field.get("marker", ""),
            "type": field.get("type", ""),
            "rule": field.get("rule", ""),
        }
        if len(picked) >= 8:
            break

    return {
        "module_name": self._clean_module_name(module.get("module_name", "")),
        "table_name": module.get("table_name", ""),
        "description": module.get("description", ""),
        "fields": list(picked.values()),
    }
|
||
|
||
def _field_roles(self, module: dict[str, Any]) -> dict[str, list[dict[str, Any]]]:
    """Group the module's fields by every role whose keywords occur in the field name.

    Returns:
        A dict with one (possibly empty) list per key of ``_FIELD_ROLE_KEYWORDS``.
        A field may appear under several roles.
    """
    grouped: dict[str, list[dict[str, Any]]] = {role: [] for role in _FIELD_ROLE_KEYWORDS}
    for field in module.get("fields", []):
        field_name = field.get("name", "")
        matching_roles = [
            role
            for role, keywords in _FIELD_ROLE_KEYWORDS.items()
            if any(keyword in field_name for keyword in keywords)
        ]
        for role in matching_roles:
            grouped[role].append(field)
    return grouped
|
||
|
||
def _generate_rule(
|
||
self,
|
||
domain: str,
|
||
pattern: dict[str, Any],
|
||
schema_candidates: list[dict[str, Any]],
|
||
) -> dict[str, Any]:
|
||
last_error: Exception | None = None
|
||
for attempt in range(2):
|
||
try:
|
||
llm_rule = self._generate_rule_with_llm(domain, pattern, schema_candidates)
|
||
if not llm_rule:
|
||
raise ValueError("LLM 未返回可解析的规则 JSON")
|
||
return self._normalize_rule(llm_rule, domain, pattern, schema_candidates)
|
||
except Exception as exc:
|
||
last_error = exc
|
||
if attempt == 0:
|
||
logger.warning(
|
||
"LLM 规则生成失败,准备重试 | domain=%s | reason=%s",
|
||
domain,
|
||
exc,
|
||
)
|
||
continue
|
||
raise
|
||
raise ValueError(str(last_error) if last_error else "LLM 未返回可解析的规则 JSON")
|
||
|
||
def _generate_rule_with_llm(
    self,
    domain: str,
    pattern: dict[str, Any],
    schema_candidates: list[dict[str, Any]],
) -> dict[str, Any] | None:
    """Send the rule prompt to the LLM and parse the reply.

    Returns:
        The parsed rule fields, or ``None`` when the reply cannot be parsed.
    """
    # Imported lazily, exactly as the original does.
    from app.utils.llm import LLMClient, strip_thinking

    user_prompt = self._build_rule_prompt(domain, pattern, schema_candidates)
    system_prompt = "".join((
        "你是央企监管规则清单整理助手。",
        "你的任务是把制度依据整理成适合 Excel 展示的业务梳理稿,",
        "不是发明新规则,不要写空话。",
        "只输出指定键值块,不输出解释。",
    ))
    client = LLMClient(timeout=120)
    reply = client.chat(
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0.2,
        max_tokens=2600,
        thinking={"type": "disabled"},
    )
    return self._parse_rule_response(strip_thinking(reply))
|
||
|
||
def _build_rule_prompt(self, domain: str, pattern: dict[str, Any], schema_candidates: list[dict[str, Any]]) -> str:
    """Assemble the user prompt that asks the LLM for one rule.

    The prompt contains the policy basis, the candidate tables as JSON, the
    indicator and variant guidance, the writing requirements, three style
    examples and the key/value output format. Its opening sentence asks for
    the key/value block, in line with the format section (it used to ask for
    a JSON object, contradicting "不要输出 JSON").
    """
    policy_basis = self._build_policy_basis_text(pattern)
    source_sentence = pattern.get("source_sentence", "")
    indicator_guidance = self._build_indicator_guidance(domain, pattern, schema_candidates)
    variant_guidance = self._build_variant_guidance(pattern)
    return f"""请根据制度依据和可用数据表,整理一条适合 Excel 规则清单展示的业务规则说明,并严格按文末“输出格式”的键值块输出。

风险领域:{domain}
制度依据:{policy_basis}
原文依据:{source_sentence}
监管维度:{pattern.get('supervision_dimension', '')}

可用数据表和字段:
{json.dumps(schema_candidates, ensure_ascii=False, indent=2)}

指标口径选择建议:
{indicator_guidance}

同一监管点多规则生成要求:
{variant_guidance}

写作要求:
1. 只能使用上面列出的表和字段,不得编造表名、字段名。
2. 输出要像人工梳理稿,语言克制、具体、可追溯。
3. 模型名称要短,不超过12个字。
4. 模型风险描述只写一句话,先写业务现象,再写风险后果。
5. 制度依据必须原样使用输入中的“制度依据”,不得改写、删减或另写条款要点。
6. 业务规则描述对应《国能poc演示清单》的“业务检查方式”,核心是模拟查询指标。必须由两段组成,并用换行分隔:
第一段写公式/计算口径;如果有多个公式,每个公式必须单独换行。
第二段写指标阈值或红/橙/黄风险等级;阈值必须和上方公式指标一一对应,两个公式就写两个指标的阈值,一个公式就只写一个指标的阈值。
指标类型要根据政策依据和字段动态选择,不限于比例;可以是金额、余额、数量、笔数、次数、天数、期限、集中度、占比、收益率、状态标记等。
示例:逾期票据金额、逾期票据数、合同类型数量、单一类型占比、投资收益率、短期债务比例、偿债能力指标、担保余额、资金回流次数。
7. 数据来源必须动态分析 schema,只按“TABLE_01(表名):MARKER(字段中文名)、MARKER(字段中文名)”输出;每张表最多列 2-4 个关键字段,字段英文名必须来自 schema.marker,不要写解释。多张表必须换行,严禁用分号连在一行。
8. 系统规则文本必须仿照参考文件 H 列,不是自然语言摘要。必须使用“规则名(阈值区间): 如果:...并且...就:校验:通过 赋值:...”的规则伪代码;优先使用“数值求和{{【表名】,【表名-字段名】}}”“大于/小于等于”等表达。
9. 关联表逻辑必须仿照参考文件 I 列,写清楚主表、左关联表、关联键、筛选和分组口径。
10. 系统固化逻辑必须仿照参考文件 J 列,写“值1:...;值2:...;值3:...”中间指标定义和最终公式。
11. 返回结果只列建议输出字段。
12. 不要出现“基于…识别…风险”“多维穿透分析”“结合业务场景综合判断”等套话。
13. 必须完整输出 9 个字段,不能省略任何字段,字段内容未知也要根据已给 schema 和制度依据合理生成:规则名称、风险描述、制度依据、业务规则描述、数据来源、系统规则文本、关联表逻辑、系统固化逻辑、返回结果。

示例只用于仿写格式,不得照抄示例业务内容;实际内容必须来自本次制度依据和可用数据表字段。
参考风格示例:
示例1:
- 模型名称:资产负债率超限预警
- 模型风险描述:债务类金额占可用资金口径偏高时,企业存在偿债压力和杠杆失衡风险。
- 制度依据:核心法规:《关于加强国有企业资产负债约束的指导意见》;条款要点:分行业设置资产负债率预警线和重点监管线,持续监测高负债企业债务风险。
- 业务规则描述:资产负债率 = 贷款余额 / 账户余额\n风险等级:\n- 红色:资产负债率 > 85%\n- 橙色:资产负债率 > 75% 且 ≤ 85%\n- 黄色:资产负债率 > 65% 且 ≤ 75%。
- 数据来源:银行贷款:所属集团编码、所属集团名称、贷款余额\n银行账户:所属集团编码、所属集团名称、账户余额。
- 系统规则文本:资产负债率超限预警(65%<资产负债率<=75%):\n如果:\n (【银行贷款表-贷款单位编码】等于【银行账户表-开户单位编码】)\n 并且 (数值求和{{【银行贷款表】,【银行贷款表-贷款余额】}}/数值求和{{【银行账户表】,【银行账户表-账户余额】}} 大于 0.65)\n 并且 (数值求和{{【银行贷款表】,【银行贷款表-贷款余额】}}/数值求和{{【银行账户表】,【银行账户表-账户余额】}} 小于等于 0.75)\n就:\n校验:通过\n赋值:【规则结果-风险等级】等于 黄色
- 关联表逻辑:银行贷款表左关联银行账户表(贷款单位编码=开户单位编码),按所属集团编码、所属集团名称、贷款单位编码汇总。
- 系统固化逻辑:值1:贷款余额本币=SUM(银行贷款.贷款余额);值2:账户余额本币=SUM(银行账户.账户余额);值3:资产负债率=值1/NULLIF(值2,0);按65%、75%、85%阈值输出风险等级。

示例2:
- 模型名称:债务集中度过高风险
- 模型风险描述:债务集中于单一金融机构时,融资渠道较为单一,存在集中兑付风险。
- 制度依据:核心法规:《中央企业债券发行管理办法》;条款要点:分散融资渠道,防范债务集中到期和单一机构依赖。
- 业务规则描述:债务集中度 = 单一金融机构债务余额 / 总债务余额\n风险等级:\n- 红色:债务集中度 > 40%\n- 橙色:债务集中度 > 30% 且 ≤ 40%\n- 黄色:债务集中度 > 20% 且 ≤ 30%。
- 数据来源:银行贷款:放款单位名称、贷款余额\n应付债券:承销商、债券余额。
- 系统规则文本:债务集中度过高风险(20%<集中度<=30%):\n如果:\n (【银行贷款表-放款单位名称】等于【应付债券表-承销商】)\n 并且 ((数值求和{{【银行贷款表】,【银行贷款表-贷款余额】}}+数值求和{{【应付债券表】,【应付债券表-债券余额】}})/数值求和{{【银行贷款表】,【银行贷款表-贷款余额】}} 大于 0.20)\n 并且 (上述比例 小于等于 0.30)\n就:\n校验:通过\n赋值:【规则结果-风险等级】等于 黄色
- 关联表逻辑:从银行贷款表按放款单位名称分组统计贷款余额;从应付债券表按承销商分组统计债券余额;按金融机构名称口径合并计算单一机构债务余额。
- 系统固化逻辑:值1:单一金融机构债务余额=SUM(银行贷款.贷款余额)+SUM(应付债券.债券余额);值2:总债务余额=SUM(全部贷款余额+全部债券余额);值3:债务集中度=值1/NULLIF(值2,0)。

示例3:
- 模型名称:逾期风险预警
- 模型风险描述:票据或债务出现逾期金额较大、逾期笔数较多时,企业存在偿债困难风险。
- 制度依据:核心法规:《中央企业债券发行管理办法》;条款要点:建立债务到期预警机制,严禁恶意逃废债,逾期债务严肃追责。
- 业务规则描述:逾期票据金额 = 是否逾期为“是”的票据金额合计\n逾期票据数 = 是否逾期为“是”的票据记录数\n风险等级:\n- 红色:逾期票据金额 > 5000万元 或 逾期票据数 > 10笔\n- 橙色:逾期票据金额 > 1000万元 或 逾期票据数 > 5笔\n- 黄色:逾期票据金额 > 500万元 或 逾期票据数 > 3笔。
- 数据来源:应付票据:是否逾期、到期日期、票面金额。
- 系统规则文本:逾期风险预警(逾期票据金额>500万元或逾期票据数>3笔):\n如果:\n (【应付票据表-是否逾期】等于 是)\n 并且 (数值求和{{【应付票据表】,【应付票据表-票面金额】}} 大于 5000000 或 记录计数{{【应付票据表】,【应付票据表-票据编号】}} 大于 3)\n就:\n校验:通过\n赋值:【规则结果-风险等级】等于 黄色
- 关联表逻辑:从应付票据表筛选是否逾期为“是”的记录,按所属集团编码、所属集团名称分组统计逾期票据金额和逾期票据数。
- 系统固化逻辑:值1:逾期票据金额=SUM(应付票据.票面金额);值2:逾期票据数=COUNT(应付票据.票据编号);值3:按金额和笔数阈值输出风险等级。

输出格式:
请严格使用以下键名,每个键独占一行;多行内容写在下一行,并保持到下一个键名前结束。
不要输出 JSON,不要输出 Markdown 代码块。
也可以使用括号内中文键名,但必须 9 个字段全部输出。
rule_name:
risk_description:
policy_basis:
business_rule_description:
data_sources:
system_rule_text:
join_logic:
system_logic:
return_fields:

中文键名等价写法:
规则名称:
风险描述:
制度依据:
业务规则描述:
数据来源:
系统规则文本:
关联表逻辑:
系统固化逻辑:
返回结果:"""
|
||
|
||
def _build_indicator_guidance(
|
||
self,
|
||
domain: str,
|
||
pattern: dict[str, Any],
|
||
schema_candidates: list[dict[str, Any]],
|
||
) -> str:
|
||
text = " ".join([
|
||
domain,
|
||
pattern.get("description_pattern", ""),
|
||
pattern.get("basis_text", ""),
|
||
pattern.get("source_sentence", ""),
|
||
" ".join(self._normalize_keywords(pattern)),
|
||
" ".join(field.get("name", "") for table in schema_candidates for field in table.get("fields", [])),
|
||
])
|
||
roles = self._schema_indicator_roles(schema_candidates)
|
||
guidance = [
|
||
"先判断本条制度最适合的指标类型,不要默认生成 A/B 比率。",
|
||
"参考国能 POC:金额/余额、数量/笔数、期限/天数、状态标记、占比/比例、收益率都应按业务场景混合出现。",
|
||
]
|
||
|
||
preferred: list[str] = []
|
||
if any(word in text for word in ("逾期", "到期", "兑付", "超期")):
|
||
preferred.extend([
|
||
"金额类:逾期金额/到期金额/兑付金额 = SUM(相关金额字段)",
|
||
"数量类:逾期笔数/到期笔数 = COUNT(满足状态或日期条件的记录)",
|
||
"期限类:逾期天数/剩余期限 = 当前日期与到期日期、数据日期的差值",
|
||
])
|
||
if any(word in text for word in ("分散", "集中", "多元", "单一", "占比", "比例")):
|
||
preferred.extend([
|
||
"数量类:不同类型数量/机构数量/客商数量 = COUNT(DISTINCT 分类或对象字段)",
|
||
"占比类:单一对象金额占比/前N对象金额占比 = 分组金额 / 总金额",
|
||
])
|
||
if any(word in text for word in ("投资", "收益", "市值", "成本")):
|
||
preferred.extend([
|
||
"收益类:投资收益金额 = 市值 - 投资成本",
|
||
"收益率类:投资收益率 = (市值 - 投资成本) / NULLIF(投资成本, 0)",
|
||
"期限类:持有天数 = 当前日期 - 投资开始日期",
|
||
])
|
||
if any(word in text for word in ("交易", "资金", "薪酬", "回流", "发放")):
|
||
preferred.extend([
|
||
"金额类:单笔金额/单日总额/资金回流金额 = SUM 或直接取金额字段",
|
||
"次数类:交易次数/回流次数/发放笔数 = COUNT(交易流水或记录)",
|
||
"状态类:按收支标记、渠道、账户状态、内部标识筛选或分组",
|
||
])
|
||
if any(word in text for word in ("担保", "授信", "贷款", "债务", "负债")):
|
||
preferred.extend([
|
||
"余额类:担保余额/贷款余额/债券余额/授信余额 = SUM(余额字段)",
|
||
"金额类:超额金额/敞口金额 = 两个金额或余额口径相减",
|
||
"比例类只在制度明确要求规模占比、负债率、集中度时使用。",
|
||
])
|
||
|
||
if not preferred:
|
||
if roles["amount"]:
|
||
preferred.append("金额/余额类优先:围绕金额、余额、总额字段生成 SUM 指标和阈值。")
|
||
if roles["status"]:
|
||
preferred.append("状态类可用:围绕状态、是否、标识、类型字段生成异常标记或条件计数。")
|
||
if roles["date"]:
|
||
preferred.append("期限/天数类可用:围绕日期、时间、期限字段生成滞后天数、到期天数、持有天数。")
|
||
preferred.append("仅当文本明确出现比例、占比、率、集中度时,才优先使用比率指标。")
|
||
|
||
guidance.extend(self._dedupe_preserve_order(preferred)[:6])
|
||
guidance.append("业务规则描述中可以同时给出 1-2 个指标,例如“金额 + 笔数”“余额 + 天数”“数量 + 占比”,阈值要逐一对应。")
|
||
return "\n".join(f"- {item}" for item in guidance)
|
||
|
||
def _schema_indicator_roles(self, schema_candidates: list[dict[str, Any]]) -> dict[str, bool]:
|
||
names = " ".join(field.get("name", "") for table in schema_candidates for field in table.get("fields", []))
|
||
markers = " ".join(field.get("marker", "") for table in schema_candidates for field in table.get("fields", []))
|
||
return {
|
||
"amount": any(word in names for word in ("金额", "余额", "总额", "资产", "负债", "成本", "市值")) or any(
|
||
word in markers.upper() for word in ("AMOUNT", "BALANCE", "ASSET", "LIABILITY", "COST", "SZ", "ZZYE")
|
||
),
|
||
"status": any(word in names for word in ("状态", "是否", "标识", "类型", "性质", "分类")),
|
||
"date": any(word in names for word in ("日期", "时间", "期限", "到期", "账龄")),
|
||
}
|
||
|
||
@staticmethod
|
||
def _build_variant_guidance(pattern: dict[str, Any]) -> str:
|
||
variant_index = int(pattern.get("_variant_index") or 1)
|
||
variant_total = int(pattern.get("_variant_total") or 1)
|
||
if variant_total <= 1:
|
||
return "- 本监管点只生成 1 条规则,选择最贴近制度依据的一个指标口径。"
|
||
|
||
focus_cycle = [
|
||
"金额/余额/规模类指标,例如 SUM(金额或余额字段)、超额金额、风险敞口。",
|
||
"数量/笔数/次数类指标,例如 COUNT 记录数、COUNT(DISTINCT 对象)、异常笔数。",
|
||
"期限/天数/状态类指标,例如 到期天数、逾期天数、状态标记、是否异常。",
|
||
"占比/比例/收益率类指标,仅当制度或字段明确支持比率口径时使用。",
|
||
]
|
||
focus = focus_cycle[(variant_index - 1) % len(focus_cycle)]
|
||
return (
|
||
f"- 当前是同一监管点的第 {variant_index}/{variant_total} 条规则。\n"
|
||
f"- 本条优先侧重:{focus}\n"
|
||
"- 同一监管点的多条规则不得只改名称或阈值;必须使用不同指标、字段组合或筛选口径。"
|
||
)
|
||
|
||
@staticmethod
|
||
def _dedupe_preserve_order(items: list[str]) -> list[str]:
|
||
result = []
|
||
for item in items:
|
||
if item not in result:
|
||
result.append(item)
|
||
return result
|
||
|
||
@staticmethod
|
||
def _parse_json_object(response: str) -> dict[str, Any] | None:
|
||
content = re.sub(r"^```(?:json)?|```$", "", (response or "").strip(), flags=re.IGNORECASE)
|
||
if not content.startswith("{"):
|
||
return None
|
||
try:
|
||
data = json.loads(content)
|
||
return data if isinstance(data, dict) else None
|
||
except json.JSONDecodeError:
|
||
match = re.search(r"\{.*\}", content, flags=re.DOTALL)
|
||
if not match:
|
||
return None
|
||
try:
|
||
data = json.loads(match.group(0))
|
||
return data if isinstance(data, dict) else None
|
||
except json.JSONDecodeError:
|
||
return None
|
||
|
||
def _parse_rule_response(self, response: str) -> dict[str, Any] | None:
|
||
json_result = self._parse_json_object(response)
|
||
if json_result:
|
||
return self._normalize_response_keys(json_result)
|
||
|
||
content = re.sub(r"^```(?:yaml|text)?|```$", "", (response or "").strip(), flags=re.IGNORECASE)
|
||
key_aliases = self._rule_key_aliases()
|
||
key_pattern = "|".join(re.escape(key) for key in key_aliases)
|
||
matches = list(re.finditer(rf"(?m)^({key_pattern})\s*[::]\s*", content))
|
||
if not matches:
|
||
return None
|
||
|
||
parsed: dict[str, str] = {}
|
||
for index, match in enumerate(matches):
|
||
key = key_aliases[match.group(1)]
|
||
start = match.end()
|
||
end = matches[index + 1].start() if index + 1 < len(matches) else len(content)
|
||
parsed[key] = content[start:end].strip()
|
||
return parsed
|
||
|
||
@staticmethod
|
||
def _rule_key_aliases() -> dict[str, str]:
|
||
return {
|
||
"rule_name": "rule_name",
|
||
"规则名称": "rule_name",
|
||
"risk_description": "risk_description",
|
||
"风险描述": "risk_description",
|
||
"policy_basis": "policy_basis",
|
||
"制度依据": "policy_basis",
|
||
"business_rule_description": "business_rule_description",
|
||
"业务规则描述": "business_rule_description",
|
||
"data_sources": "data_sources",
|
||
"数据来源": "data_sources",
|
||
"system_rule_text": "system_rule_text",
|
||
"系统规则文本": "system_rule_text",
|
||
"join_logic": "join_logic",
|
||
"关联表逻辑": "join_logic",
|
||
"system_logic": "system_logic",
|
||
"系统固化逻辑": "system_logic",
|
||
"return_fields": "return_fields",
|
||
"返回结果": "return_fields",
|
||
}
|
||
|
||
@classmethod
|
||
def _normalize_response_keys(cls, raw: dict[str, Any]) -> dict[str, Any]:
|
||
aliases = cls._rule_key_aliases()
|
||
normalized: dict[str, Any] = {}
|
||
for key, value in raw.items():
|
||
normalized[aliases.get(str(key), str(key))] = value
|
||
return normalized
|
||
|
||
def _normalize_rule(
|
||
self,
|
||
raw: dict[str, Any],
|
||
domain: str,
|
||
pattern: dict[str, Any],
|
||
schema_candidates: list[dict[str, Any]],
|
||
) -> dict[str, Any]:
|
||
required_fields = {
|
||
"rule_name": "规则名称",
|
||
"risk_description": "风险描述",
|
||
"policy_basis": "制度依据",
|
||
"business_rule_description": "业务规则描述",
|
||
"data_sources": "数据来源",
|
||
"system_rule_text": "系统规则文本",
|
||
"join_logic": "关联表逻辑",
|
||
"system_logic": "系统固化逻辑",
|
||
"return_fields": "返回结果",
|
||
}
|
||
core_fields = ("rule_name", "risk_description", "business_rule_description")
|
||
missing = [required_fields[key] for key in core_fields if not str(raw.get(key, "")).strip()]
|
||
if missing:
|
||
raise ValueError(f"LLM 返回规则字段不完整: {', '.join(missing)}")
|
||
|
||
normalized = {
|
||
"risk_domain": domain,
|
||
**{key: self._model_value_to_text(raw.get(key, "")).strip() for key in required_fields},
|
||
}
|
||
normalized["policy_basis"] = self._build_policy_basis_text(pattern)
|
||
normalized["business_rule_description"] = self._normalize_business_rule_description(
|
||
normalized["business_rule_description"]
|
||
)
|
||
if not normalized["data_sources"]:
|
||
normalized["data_sources"] = self._format_data_sources(schema_candidates)
|
||
normalized["data_sources"] = self._format_data_sources_with_markers(
|
||
normalized["data_sources"],
|
||
schema_candidates,
|
||
)
|
||
if not self._looks_like_join_logic(normalized["join_logic"]):
|
||
normalized["join_logic"] = self._format_join_logic(schema_candidates)
|
||
if not self._looks_like_system_logic(normalized["system_logic"]):
|
||
normalized["system_logic"] = self._format_system_logic_from_business_rule(
|
||
normalized["business_rule_description"]
|
||
)
|
||
metric_name = self._primary_metric_name(normalized["business_rule_description"])
|
||
if not normalized["return_fields"]:
|
||
output_fields = self._pick_fields_by_role(schema_candidates, ("subject", "amount", "status", "date"), 6)
|
||
normalized["return_fields"] = self._format_return_fields(output_fields, metric_name)
|
||
normalized["system_rule_text"] = self._normalize_system_rule_text(normalized["system_rule_text"])
|
||
if not self._looks_like_system_rule_text(normalized["system_rule_text"]):
|
||
normalized["system_rule_text"] = self._format_system_rule_text_from_business_rule(
|
||
normalized["rule_name"],
|
||
normalized["business_rule_description"],
|
||
)
|
||
self._validate_rule_shape(normalized)
|
||
return normalized
|
||
|
||
@classmethod
|
||
def _model_value_to_text(cls, value: Any) -> str:
|
||
if value is None:
|
||
return ""
|
||
if isinstance(value, str):
|
||
return value
|
||
if isinstance(value, (int, float, bool)):
|
||
return str(value)
|
||
if isinstance(value, list):
|
||
if all(not isinstance(item, (dict, list)) for item in value):
|
||
return ", ".join(cls._model_value_to_text(item) for item in value)
|
||
return "\n".join(cls._model_value_to_text(item) for item in value if cls._model_value_to_text(item))
|
||
if isinstance(value, dict):
|
||
table = value.get("table") or value.get("table_name") or value.get("module_name")
|
||
fields = value.get("fields")
|
||
if table and fields:
|
||
if isinstance(fields, list):
|
||
field_text = ", ".join(cls._model_value_to_text(item) for item in fields)
|
||
else:
|
||
field_text = cls._model_value_to_text(fields)
|
||
return f"{table}: {field_text}"
|
||
return "\n".join(f"{key}:{cls._model_value_to_text(item)}" for key, item in value.items())
|
||
return str(value)
|
||
|
||
@staticmethod
|
||
def _normalize_business_rule_description(text: str) -> str:
|
||
raw = (text or "").replace("\\n", "\n").strip()
|
||
if "\n" in raw:
|
||
lines = [re.sub(r"\s+", " ", line).strip() for line in raw.splitlines() if line.strip()]
|
||
content = " ".join(lines)
|
||
else:
|
||
content = re.sub(r"\s+", " ", raw).strip()
|
||
if not content:
|
||
return ""
|
||
|
||
split_patterns = [
|
||
r"(风险等级[::])",
|
||
r"(指标阈值[::])",
|
||
r"(预警标准[::])",
|
||
r"((?:^|[。;; ])红色[::])",
|
||
r"((?:^|[。;; ])黄色[::])",
|
||
r"(-\s*红色[::])",
|
||
]
|
||
for pattern in split_patterns:
|
||
match = re.search(pattern, content)
|
||
if match and match.start() > 0:
|
||
formula = content[:match.start()].strip("。;; ")
|
||
thresholds = content[match.start():].strip()
|
||
return (
|
||
f"{RuleGenerationService._format_formula_lines(formula)}\n"
|
||
f"{RuleGenerationService._format_threshold_lines(thresholds)}"
|
||
)
|
||
return content
|
||
|
||
@staticmethod
|
||
def _format_formula_lines(text: str) -> str:
|
||
content = (text or "").strip("。;; ")
|
||
if not content:
|
||
return ""
|
||
parts = re.split(r"[;;]\s*(?=[^;;。\n]{1,40}=)", content)
|
||
if len(parts) == 1:
|
||
parts = re.split(r"(?<=。)\s*(?=[^。;;\n]{1,40}=)", content)
|
||
lines = [part.strip("。;; ") for part in parts if part.strip("。;; ")]
|
||
return "\n".join(lines)
|
||
|
||
@staticmethod
|
||
def _format_threshold_lines(text: str) -> str:
|
||
content = (text or "").strip()
|
||
if not content:
|
||
return ""
|
||
content = re.sub(r"^风险等级[::]\s*", "风险等级:", content)
|
||
for color in ("红色", "橙色", "黄色"):
|
||
content = re.sub(rf"[;;]\s*{color}[::]", f"\n- {color}:", content)
|
||
content = re.sub(rf"\s+-\s*{color}[::]", f"\n- {color}:", content)
|
||
content = re.sub(rf"(?<![-\n]){color}[::]", f"\n- {color}:", content)
|
||
content = re.sub(r"风险等级[::]\s*\n", "风险等级:\n", content)
|
||
content = re.sub(r"风险等级[::]\s*-", "风险等级:\n-", content)
|
||
content = re.sub(r"\n{2,}", "\n", content)
|
||
return content.strip()
|
||
|
||
@staticmethod
|
||
def _normalize_data_sources_text(data_sources: str) -> str:
|
||
text = re.sub(r"\s+", " ", data_sources or "").strip()
|
||
if not text:
|
||
return ""
|
||
text = text.replace(";", ";")
|
||
segments = [segment.strip(" ;") for segment in text.split(";") if segment.strip(" ;")]
|
||
if len(segments) <= 1:
|
||
return data_sources.strip()
|
||
return "\n".join(segments)
|
||
|
||
def _format_data_sources_with_markers(
    self,
    data_sources: str,
    schema_candidates: list[dict[str, Any]],
) -> str:
    """Rewrite data-source lines using schema table and field displays.

    The text is normalised and parsed into ``(table label, field labels)``
    pairs. If nothing can be parsed, the normalised text is returned. A
    table that matches no schema module is echoed back as written. For a
    matched table the line becomes ``"<table display>: <field displays>"``;
    when none of the listed fields produce a display, the displays of the
    module's first four fields are used instead.
    """
    normalized = self._normalize_data_sources_text(data_sources)
    parsed = self._parse_data_source_lines(normalized)
    if not parsed:
        return normalized

    rendered = []
    for table_label, field_labels in parsed:
        module, position = self._match_schema_module(table_label, schema_candidates)
        if module:
            table_display = self._data_source_table_display(module, position)
            shown = self._data_source_field_displays(field_labels, module)
            if not shown:
                shown = [
                    text
                    for field in module.get("fields", [])[:4]
                    if (text := self._data_source_field_display(field))
                ]
            rendered.append(f"{table_display}: {', '.join(shown)}")
        else:
            joined = ", ".join(field_labels)
            rendered.append(f"{table_label}: {joined}" if joined else table_label)
    return "\n".join(rendered)
|
||
|
||
def _data_source_table_display(self, module: dict[str, Any], index: int) -> str:
    """Return ``"<table identifier>(<cleaned module name>)"`` for a schema module."""
    readable_name = self._clean_module_name(module.get("module_name", ""))
    identifier = self._schema_table_identifier(module, index)
    return f"{identifier}({readable_name})"
|
||
|
||
def _data_source_field_displays(self, field_labels: list[str], module: dict[str, Any]) -> list[str]:
    """Map field labels to display strings, keeping first-seen order.

    A label that matches a schema field is replaced by that field's
    display; an unmatched label is kept as written. Empty results and
    duplicates are dropped.
    """
    displays: list[str] = []
    for label in field_labels:
        matched = self._match_schema_field(label, module)
        text = label if not matched else self._data_source_field_display(matched)
        if not text or text in displays:
            continue
        displays.append(text)
    return displays
|
||
|
||
def _match_schema_field(self, label: str, module: dict[str, Any]) -> dict[str, Any] | None:
    """Find the first field of ``module`` that ``label`` refers to.

    A field matches when its sanitised marker equals the sanitised label,
    or when the label and the field name contain one another. Fields are
    checked in schema order; ``None`` is returned when nothing matches.
    """
    wanted_marker = self._safe_sql_identifier(label, "")
    for candidate in module.get("fields", []):
        candidate_marker = self._safe_sql_identifier(candidate.get("marker", ""), "")
        candidate_name = str(candidate.get("name", "") or "")
        if wanted_marker and candidate_marker == wanted_marker:
            return candidate
        if not (label and candidate_name):
            continue
        # Equality is already covered by the substring test.
        if label in candidate_name or candidate_name in label:
            return candidate
    return None
|
||
|
||
def _data_source_field_display(self, field: dict[str, Any]) -> str:
    """Return ``"<marker>(<name>)"``, or whichever of the two is available."""
    marker = self._safe_sql_identifier(field.get("marker", ""), "")
    name = str(field.get("name", "") or "").strip()
    if not (marker and name):
        return marker or name
    return f"{marker}({name})"
|
||
|
||
@staticmethod
|
||
def _normalize_system_rule_text(text: str) -> str:
|
||
raw = (text or "").replace("\\n", "\n").strip()
|
||
if not raw:
|
||
return ""
|
||
return "\n".join(re.sub(r"\s+", " ", line).strip() for line in raw.splitlines() if line.strip())
|
||
|
||
@staticmethod
|
||
def _looks_like_system_rule_text(text: str) -> bool:
|
||
content = text or ""
|
||
return all(keyword in content for keyword in ("如果", "就")) and any(
|
||
keyword in content for keyword in ("赋值", "校验")
|
||
)
|
||
|
||
@staticmethod
|
||
def _looks_like_join_logic(text: str) -> bool:
|
||
content = text or ""
|
||
return ("表" in content or "TABLE_" in content.upper()) and any(
|
||
keyword in content for keyword in ("关联", "筛选", "分组", "汇总", "统计")
|
||
)
|
||
|
||
@staticmethod
|
||
def _looks_like_system_logic(text: str) -> bool:
|
||
content = text or ""
|
||
return "值1" in content and any(keyword in content for keyword in ("值2", "阈值", "风险等级"))
|
||
|
||
@staticmethod
|
||
def _formula_metric_names(business_rule: str) -> list[str]:
|
||
formula_part = (business_rule or "").split("风险等级", 1)[0]
|
||
metric_names = []
|
||
for line in formula_part.splitlines():
|
||
if "=" not in line:
|
||
continue
|
||
metric_name = line.split("=", 1)[0].strip()
|
||
metric_name = re.sub(
|
||
r"^(?:指标[一二三四五六七八九十\d]+|公式[一二三四五六七八九十\d]+|值\d+)\s*[::、.))-]*\s*",
|
||
"",
|
||
metric_name,
|
||
).strip()
|
||
if metric_name and metric_name not in metric_names:
|
||
metric_names.append(metric_name)
|
||
return metric_names
|
||
|
||
@staticmethod
def _primary_metric_name(business_rule: str) -> str:
    """Return the first formula metric name, or ``核心指标`` when there is none."""
    names = RuleGenerationService._formula_metric_names(business_rule)
    if names:
        return names[0]
    return "核心指标"
|
||
|
||
@staticmethod
|
||
def _threshold_part(business_rule: str) -> str:
|
||
return business_rule.split("风险等级", 1)[1] if "风险等级" in business_rule else ""
|
||
|
||
@staticmethod
def _threshold_lines(business_rule: str) -> list[str]:
    """Return the stripped threshold lines that mention 红色, 橙色 or 黄色."""
    colors = ("红色", "橙色", "黄色")
    collected = []
    for raw_line in RuleGenerationService._threshold_part(business_rule).splitlines():
        stripped = raw_line.strip()
        if stripped and any(color in raw_line for color in colors):
            collected.append(stripped)
    return collected
|
||
|
||
@staticmethod
|
||
def _format_system_logic_from_business_rule(business_rule: str) -> str:
|
||
formula_part = (business_rule or "").split("风险等级", 1)[0]
|
||
formula_lines = [line.strip() for line in formula_part.splitlines() if "=" in line]
|
||
segments = []
|
||
for index, line in enumerate(formula_lines[:3], start=1):
|
||
metric_name, expression = [part.strip() for part in line.split("=", 1)]
|
||
segments.append(f"值{index}:{metric_name}={expression}")
|
||
if not segments:
|
||
segments.append("值1:核心指标=按数据来源字段汇总计算")
|
||
if len(segments) == 1:
|
||
segments.append("值2:风险阈值=按红色、橙色、黄色等级阈值判断")
|
||
return ";".join(segments) + ";按阈值输出风险等级。"
|
||
|
||
@staticmethod
def _format_system_rule_text_from_business_rule(rule_name: str, business_rule: str) -> str:
    """Build the 如果/就 rule pseudo-code from a rule's threshold lines.

    Up to three threshold lines (leading ``-``/``*`` bullets removed)
    become parenthesised conditions joined by 或者. When there are none, a
    single generic condition based on the primary metric name is used.
    """
    conditions = []
    for threshold_line in RuleGenerationService._threshold_lines(business_rule)[:3]:
        condition = re.sub(r"^\s*[-*]\s*", "", threshold_line).strip()
        if condition:
            conditions.append(f" ({condition})")
    if not conditions:
        primary_metric = RuleGenerationService._primary_metric_name(business_rule)
        conditions.append(f" ({primary_metric}达到风险阈值)")
    condition_block = "\n 或者 ".join(conditions)
    return "\n".join(
        [
            f"{rule_name}(风险阈值命中):",
            "如果:",
            condition_block,
            "就:",
            "校验:通过",
            "赋值:【规则结果-风险等级】等于 对应风险等级",
        ]
    )
|
||
|
||
@staticmethod
|
||
def _validate_rule_shape(rule: dict[str, str]) -> None:
|
||
business_rule = rule.get("business_rule_description", "")
|
||
system_rule = rule.get("system_rule_text", "")
|
||
join_logic = rule.get("join_logic", "")
|
||
system_logic = rule.get("system_logic", "")
|
||
|
||
if " = " not in business_rule and "=" not in business_rule:
|
||
raise ValueError("LLM 返回的业务规则描述缺少指标公式")
|
||
if "\n" not in business_rule:
|
||
raise ValueError("LLM 返回的业务规则描述未按公式和指标阈值分行")
|
||
if not all(keyword in business_rule for keyword in ("红色", "橙色", "黄色")):
|
||
raise ValueError("LLM 返回的业务规则描述缺少红/橙/黄风险等级")
|
||
RuleGenerationService._validate_business_rule_indicator_alignment(business_rule)
|
||
if not RuleGenerationService._looks_like_system_rule_text(system_rule):
|
||
raise ValueError("LLM 返回的系统规则文本未按 H 列规则伪代码格式生成")
|
||
if not RuleGenerationService._looks_like_join_logic(join_logic):
|
||
raise ValueError("LLM 返回的关联表逻辑缺少表关联、筛选或分组口径")
|
||
if not RuleGenerationService._looks_like_system_logic(system_logic):
|
||
raise ValueError("LLM 返回的系统固化逻辑缺少值1/值2中间指标定义")
|
||
|
||
@staticmethod
def _validate_business_rule_indicator_alignment(business_rule: str) -> None:
    """Ensure every formula metric is referenced by the threshold part.

    Raises:
        ValueError: when no metric name can be extracted, or when one or
            more metrics have no matching mention in the thresholds (the
            message lists them).
    """
    service = RuleGenerationService
    names = service._formula_metric_names(business_rule)
    if not names:
        raise ValueError("LLM 返回的业务规则描述缺少可识别指标名称")
    thresholds = service._threshold_part(business_rule)
    unmatched = []
    for name in names:
        if not service._metric_name_matches_threshold(name, thresholds):
            unmatched.append(name)
    if unmatched:
        raise ValueError(f"业务规则描述中的指标阈值未对应公式指标: {', '.join(unmatched)}")
|
||
|
||
@staticmethod
def _metric_name_matches_threshold(metric_name: str, threshold_part: str) -> bool:
    """Decide whether the threshold text refers to ``metric_name``.

    Both texts are normalised first. An empty metric, or a metric found
    verbatim in the thresholds, matches immediately. If the metric belongs
    to a keyword category and the thresholds mention none of that
    category's keywords, it does not match. Otherwise any token of four or
    more characters found in the thresholds is a match; shorter tokens only
    count when the metric has a category or has no long tokens at all.
    """
    service = RuleGenerationService
    metric = service._normalize_metric_text(metric_name)
    threshold = service._normalize_metric_text(threshold_part)
    if not metric or metric in threshold:
        return True

    categories = service._metric_category_keywords(metric)
    if categories and all(keyword not in threshold for keyword in categories):
        return False

    long_tokens: list[str] = []
    short_tokens: list[str] = []
    for token in service._metric_match_tokens(metric_name):
        bucket = long_tokens if len(token) >= 4 else short_tokens
        bucket.append(token)

    if any(token in threshold for token in long_tokens):
        return True
    if categories or not long_tokens:
        return any(token in threshold for token in short_tokens)
    return False
|
||
|
||
@staticmethod
|
||
def _normalize_metric_text(text: str) -> str:
|
||
return re.sub(r"[^0-9A-Za-z\u4e00-\u9fff]+", "", text or "")
|
||
|
||
@staticmethod
|
||
def _metric_category_keywords(metric: str) -> tuple[str, ...]:
|
||
if any(keyword in metric for keyword in ("收益", "回报")):
|
||
return ("收益", "收益率", "回报")
|
||
if any(keyword in metric for keyword in ("金额", "余额", "总额", "规模", "额度")):
|
||
return ("金额", "余额", "总额", "规模", "额度", "债务", "投资")
|
||
if any(keyword in metric for keyword in ("笔数", "数量", "次数", "个数", "户数")):
|
||
return ("笔数", "数量", "次数", "个数", "户数")
|
||
if any(keyword in metric for keyword in ("占比", "比例", "比率", "负债率", "收益率", "集中度")):
|
||
return ("占比", "比例", "比率", "率", "集中度")
|
||
return ()
|
||
|
||
@staticmethod
def _metric_match_tokens(metric_name: str) -> list[str]:
    """Generate shortened variants of a metric name for fuzzy matching.

    Starting from the normalised name, two passes derive new variants from
    every known variant by dropping a leading 总/集团, removing filler words
    (结构/类别/类型/类) and cutting a trailing unit-like suffix. The result
    is sorted longest first and excludes one-character tokens and a fixed
    set of generic words.
    """
    prefixes = ("总", "集团")
    fillers = ("结构", "类别", "类型", "类")
    suffixes = (
        "收益率",
        "集中度",
        "占比",
        "比例",
        "比率",
        "金额",
        "余额",
        "总额",
        "数量",
        "笔数",
        "次数",
        "天数",
        "规模",
        "额度",
        "水平",
        "指标",
        "风险",
        "率",
    )
    too_generic = {"指标", "风险", "金额", "余额", "数量", "比例", "比率", "占比", "率"}

    def shortened(candidate):
        """Yield every one-step reduction of ``candidate``."""
        for prefix in prefixes:
            if candidate.startswith(prefix) and len(candidate) > len(prefix):
                yield candidate[len(prefix):]
        for word in fillers:
            if word in candidate and len(candidate) > len(word):
                yield candidate.replace(word, "")
        for suffix in suffixes:
            if candidate.endswith(suffix) and len(candidate) > len(suffix):
                yield candidate[:-len(suffix)]

    variants = {RuleGenerationService._normalize_metric_text(metric_name)}
    for _ in range(2):
        # Work from a snapshot so variants added in this pass are only
        # expanded in the next one.
        snapshot = list(variants)
        for candidate in snapshot:
            variants.update(shortened(candidate))

    ordered = sorted(variants, key=len, reverse=True)
    return [token for token in ordered if len(token) >= 2 and token not in too_generic]
|
||
|
||
def _build_policy_basis_text(self, pattern: dict[str, Any]) -> str:
|
||
regulations = [str(item).strip() for item in (pattern.get("core_regulations") or []) if str(item).strip()]
|
||
description = (pattern.get("description_pattern") or "").strip()
|
||
if not regulations:
|
||
regulation_match = re.search(r"核心法规:([^;。]+)", description)
|
||
if regulation_match:
|
||
regulations = [regulation_match.group(1).strip()]
|
||
regulation_text = "、".join(regulations) or "相关监管制度"
|
||
basis_text = (pattern.get("basis_text") or "").strip()
|
||
if "条款要点:" in description:
|
||
clause = description.split("条款要点:", 1)[1].strip(";。 ")
|
||
else:
|
||
clause = basis_text.strip(";。 ")
|
||
if not clause:
|
||
clause = "围绕关键风险指标、主体责任和业务口径持续开展监测。"
|
||
clause = re.sub(r"\s+", "", clause)
|
||
return f"核心法规:{regulation_text};条款要点:{clause}"
|
||
|
||
@staticmethod
def _pick_fields_by_role(schema_candidates: list[dict[str, Any]], roles: tuple[str, ...], limit: int) -> list[str]:
    """Collect ``"<module_name>.<field name>"`` items for the given roles.

    Roles are processed in the order given. For each role, every table's
    fields are scanned and a field is selected when its name contains one
    of that role's keywords in ``_FIELD_ROLE_KEYWORDS``. Duplicates are
    skipped, unsupported role names are ignored, and the scan stops as soon
    as ``limit`` items have been collected.
    """
    supported_roles = ("subject", "amount", "status", "date", "counterparty")
    selected: list[str] = []
    for role in roles:
        if role not in supported_roles:
            continue
        # Looked up once per role; every supported role uses the same rule.
        keywords = _FIELD_ROLE_KEYWORDS[role]
        for table in schema_candidates:
            for field in table.get("fields", []):
                name = field.get("name", "")
                if not name or not any(keyword in name for keyword in keywords):
                    continue
                item = f"{table['module_name']}.{name}"
                if item not in selected:
                    selected.append(item)
                if len(selected) >= limit:
                    return selected
    return selected
|
||
|
||
def _format_data_sources(self, schema_candidates: list[dict[str, Any]]) -> str:
    """Summarise each candidate table as ``"<module_name>:<field>、<field>…"``.

    For every table, the fields whose name contains any keyword from
    ``_FIELD_ROLE_KEYWORDS`` are preferred; if there are none, the table's
    first three fields are used instead. At most four distinct names are
    listed per table, and tables with nothing to list are omitted.

    Fields without a usable name are skipped on both paths, so a nameless
    entry among the first three fields neither raises ``KeyError`` nor
    adds an empty label.
    """
    role_keywords = [keyword for keywords in _FIELD_ROLE_KEYWORDS.values() for keyword in keywords]
    parts = []
    for table in schema_candidates:
        fields = table.get("fields", [])
        names = [
            name
            for field in fields
            if (name := field.get("name", "")) and any(keyword in name for keyword in role_keywords)
        ]
        if not names:
            names = [name for field in fields[:3] if (name := field.get("name", ""))]
        # dict.fromkeys keeps first-seen order while removing duplicates.
        concise = list(dict.fromkeys(names))[:4]
        if concise:
            parts.append(f"{table['module_name']}:{'、'.join(concise)}")
    return "\n".join(parts)
|
||
|
||
def _format_join_logic(self, schema_candidates: list[dict[str, Any]]) -> str:
|
||
if not schema_candidates:
|
||
return "未匹配到可关联表。"
|
||
primary = schema_candidates[0]["module_name"]
|
||
if len(schema_candidates) == 1:
|
||
return f"以{primary}为主表,按所属集团编码、所属集团名称或单位编码汇总。"
|
||
|
||
join_texts = []
|
||
for table in schema_candidates[1:]:
|
||
related_name = table["module_name"]
|
||
join_key = self._infer_join_key(schema_candidates[0], table)
|
||
join_texts.append(f"{related_name}({join_key})")
|
||
return f"以{primary}为主表,关联{'、'.join(join_texts)};按集团或单位维度汇总,必要时结合状态字段做筛选。"
|
||
|
||
def _infer_join_key(self, left: dict[str, Any], right: dict[str, Any]) -> str:
|
||
left_names = {field.get("name", "") for field in left.get("fields", [])}
|
||
right_names = {field.get("name", "") for field in right.get("fields", [])}
|
||
for key in ("所属集团编码", "开户单位编码", "单位编码", "客商编码", "合同编号", "单位账号"):
|
||
if key in left_names and key in right_names:
|
||
return key
|
||
return "所属集团编码/单位编码"
|
||
|
||
@staticmethod
def _format_input_params(schema_candidates: list[dict[str, Any]]) -> str:
    """Return up to six subject/amount/status/date fields, joined by ``,``."""
    wanted_roles = ("subject", "amount", "status", "date")
    picked = RuleGenerationService._pick_fields_by_role(schema_candidates, wanted_roles, 6)
    return ",".join(picked)
|
||
|
||
@staticmethod
|
||
def _format_output_params(fields: list[str], metric_name: str) -> str:
|
||
base = [field for field in fields if field][:4]
|
||
base.extend([metric_name, "风险等级"])
|
||
deduped = []
|
||
for item in base:
|
||
if item not in deduped:
|
||
deduped.append(item)
|
||
return ",".join(deduped)
|
||
|
||
@staticmethod
|
||
def _format_return_fields(fields: list[str], metric_name: str) -> str:
|
||
base = [field for field in fields if field][:6]
|
||
base.extend([metric_name, "风险等级"])
|
||
deduped = []
|
||
for item in base:
|
||
if item not in deduped:
|
||
deduped.append(item)
|
||
return ",".join(deduped)
|
||
|
||
def _build_rule_rows(self, rules: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert raw rule dicts into export rows.

    Each row gets a 1-based ``index``. The business-rule description and
    the data sources are normalised/formatted; the other columns are copied
    as they are (missing keys become ``""``). A ``sql`` entry is added when
    ``self.create_sql`` is truthy.
    """
    schema = self._read_schema()
    rows = []
    for position, rule in enumerate(rules, start=1):
        row: dict[str, Any] = {"index": position}
        for key in ("risk_domain", "rule_name", "risk_description", "policy_basis"):
            row[key] = rule.get(key, "")
        row["business_rule_description"] = self._normalize_business_rule_description(
            rule.get("business_rule_description", "")
        )
        row["data_sources"] = self._format_data_sources_with_markers(rule.get("data_sources", ""), schema)
        for key in ("system_rule_text", "join_logic", "system_logic", "return_fields"):
            row[key] = rule.get(key, "")
        if self.create_sql:
            row["sql"] = self._create_sql_query(rule)
        rows.append(row)
    return rows
|
||
|
||
def _write_excel(self, rules: list[dict[str, Any]], output_file: str) -> None:
    """Write the rules to a single-sheet, styled Excel workbook.

    The sheet has a header row followed by one row per rule (plus a SQL
    column when ``self.create_sql`` is truthy). The header row is frozen,
    an auto-filter covers the used range, column widths come from
    ``_excel_column_widths`` and every cell gets a font, alignment and thin
    border. The parent directory of ``output_file`` is created if needed.
    """
    workbook = Workbook()
    sheet = workbook.active
    sheet.title = "11类重点风险领域"
    sheet.append(self._excel_columns())

    value_keys = (
        "index",
        "risk_domain",
        "rule_name",
        "risk_description",
        "policy_basis",
        "business_rule_description",
        "data_sources",
        "system_rule_text",
        "join_logic",
        "system_logic",
        "return_fields",
    )
    for rule_row in self._build_rule_rows(rules):
        values = [rule_row[key] for key in value_keys]
        if self.create_sql:
            values.append(rule_row.get("sql", ""))
        sheet.append(values)

    sheet.freeze_panes = "A2"
    sheet.auto_filter.ref = sheet.dimensions
    sheet.sheet_view.zoomScale = 90
    sheet.row_dimensions[1].height = 13.5

    for column_index, width in enumerate(self._excel_column_widths(), start=1):
        sheet.column_dimensions[get_column_letter(column_index)].width = width

    # Style objects are built once and assigned to every cell that needs them.
    header_font = Font(name="宋体", size=11, bold=True)
    header_alignment = Alignment(horizontal="center", vertical="center", wrap_text=True)
    for header_cell in sheet[1]:
        header_cell.font = header_font
        header_cell.fill = _HIGHLIGHT_FILL if header_cell.column in _HIGHLIGHT_COLUMNS else _HEADER_FILL
        header_cell.alignment = header_alignment
        header_cell.border = _THIN_BORDER

    body_font = Font(name="宋体", size=11)
    centered = Alignment(horizontal="center", vertical="center", wrap_text=True)
    left_aligned = Alignment(horizontal="left", vertical="center", wrap_text=True)
    for row_number, body_row in enumerate(sheet.iter_rows(min_row=2), start=2):
        sheet.row_dimensions[row_number].height = self._estimate_row_height(body_row)
        for body_cell in body_row:
            body_cell.font = body_font
            body_cell.alignment = centered if body_cell.column in _CENTER_COLUMNS else left_aligned
            body_cell.border = _THIN_BORDER

    Path(output_file).parent.mkdir(parents=True, exist_ok=True)
    workbook.save(output_file)
|
||
|
||
@staticmethod
|
||
def _markdown_file_for_excel(output_file: str) -> str:
|
||
return str(Path(output_file).with_suffix(".md"))
|
||
|
||
def _write_markdown(self, rules: list[dict[str, Any]], markdown_file: str) -> None:
    """Write the rules as a Markdown document grouped by risk domain.

    The document has a title, a table of contents (one entry per domain in
    first-seen order) and one section per domain. Each domain section shows
    the first rule's risk description, the de-duplicated data sources of
    the domain, and then every rule in detail. A SQL code block is added
    per rule when ``self.create_sql`` is truthy. The file is written as
    UTF-8 and its parent directory is created if needed.
    """
    show = self._markdown_value
    by_domain: dict[str, list[dict[str, Any]]] = {}
    for rule_row in self._build_rule_rows(rules):
        domain_name = str(rule_row.get("risk_domain", "") or "未分类")
        by_domain.setdefault(domain_name, []).append(rule_row)

    out: list[str] = [
        "# 重点风险领域规则模型与脚本",
        "",
        "> 根据本次规则生成任务产出的 Excel 内容同步整理,目录和章节以实际生成的风险领域为准。",
        "",
        "---",
        "",
        "## 目录",
        "",
    ]
    if by_domain:
        for number, name in enumerate(by_domain, start=1):
            out.append(f"{number}. [{name}](#{number}-{name})")
    else:
        out.append("- 暂无生成成功的规则")
    out += ["", "---", ""]

    # (label, row key, True when the value goes in its own paragraph)
    sections = (
        ("风险描述", "risk_description", False),
        ("制度依据", "policy_basis", False),
        ("业务规则描述", "business_rule_description", True),
        ("数据来源", "data_sources", True),
        ("系统规则文本", "system_rule_text", True),
        ("关联表逻辑", "join_logic", False),
        ("系统固化逻辑", "system_logic", False),
        ("返回结果", "return_fields", False),
    )

    for number, (name, members) in enumerate(by_domain.items(), start=1):
        out += [
            f"## {number}. {name}",
            "",
            "### 风险描述",
            show(members[0].get("risk_description", "")),
            "",
            "### 数据来源",
        ]
        sources = self._domain_data_source_lines(members)
        if sources:
            out += [f"- {source}" for source in sources]
        else:
            out.append("- 无")
        out += ["", "---", ""]

        for position, rule in enumerate(members, start=1):
            out += [
                f"### 规则模型{number}.{position}:{show(rule.get('rule_name', ''))}",
                "",
                f"**序号**:{rule['index']}",
                "",
            ]
            for label, key, own_paragraph in sections:
                value = show(rule.get(key, ""))
                if own_paragraph:
                    out += [f"**{label}**:", "", value, ""]
                else:
                    out += [f"**{label}**:{value}", ""]
            if self.create_sql:
                out += [
                    "**实现脚本(SQL)**:",
                    "",
                    "```sql",
                    show(rule.get("sql", "")),
                    "```",
                    "",
                ]
            out += ["---", ""]

    target = Path(markdown_file)
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text("\n".join(out).rstrip() + "\n", encoding="utf-8")
|
||
|
||
def _domain_data_source_lines(
|
||
self,
|
||
domain_rules: list[dict[str, Any]],
|
||
) -> list[str]:
|
||
sources: list[str] = []
|
||
seen: set[str] = set()
|
||
for rule in domain_rules:
|
||
formatted = rule.get("data_sources", "")
|
||
for line in str(formatted or "").splitlines():
|
||
source = line.strip()
|
||
if source and source not in seen:
|
||
sources.append(source)
|
||
seen.add(source)
|
||
return sources
|
||
|
||
@staticmethod
|
||
def _markdown_value(value: Any) -> str:
|
||
text = str(value or "").strip()
|
||
return text if text else "无"
|
||
|
||
def _excel_columns(self) -> list[str]:
    """Return the header labels, with a SQL column when ``self.create_sql`` is truthy."""
    if self.create_sql:
        return [*_RULE_COLUMNS, "SQL查询语句"]
    return list(_RULE_COLUMNS)
|
||
|
||
def _excel_column_widths(self) -> list[float]:
    """Return the column widths, with a width of 72 for the SQL column when enabled."""
    sql_width = [72] if self.create_sql else []
    return [*_COLUMN_WIDTHS, *sql_width]
|
||
|
||
def _create_sql_query(self, rule: dict[str, Any]) -> str:
    """Assemble and validate the SELECT statement for one rule.

    Tables come from the rule's data sources, falling back to the first
    schema module and then to a placeholder ``TABLE_1``. The select list
    holds up to eight qualified markers, one ``VALUE_n`` column per metric
    component, the metric itself (unless its alias is already selected) and
    a ``RISK_LEVEL`` column. The first table is read through an aggregate
    sub-query and the remaining tables are LEFT JOINed.

    Raises:
        ValueError: when the generated SQL fails ``_validate_generated_sql``.
    """
    schema = self._read_schema()
    tables = self._sql_tables_from_rule(rule, schema)
    if not tables and schema:
        tables = [self._schema_table_context(schema[0], 1)]
    if not tables:
        tables = [{"sql_name": "TABLE_1", "alias": "t1", "fields": []}]

    primary_table = tables[0]
    # Fallback chain: markers named by the rule, then the primary table's
    # first four markers, then a fixed default trio.
    select_markers = (
        self._sql_field_markers(rule, tables)
        or [field["marker"] for field in primary_table["fields"][:4]]
        or ["COMCODE", "COMNAME", "DAYIDX"]
    )

    metric_alias = self._sql_metric_alias(rule)
    metric_expression = self._sql_metric_expression(rule, tables)
    metric_components = self._sql_metric_components(metric_expression)
    risk_cases = self._sql_risk_cases(rule, metric_expression)
    join_key = self._sql_join_key(tables)

    select_lines = [f" {self._qualified_marker(marker, tables)}" for marker in select_markers[:8]]
    select_lines.extend(
        f" {component} AS VALUE_{position}"
        for position, component in enumerate(metric_components, start=1)
    )
    if metric_alias not in select_markers:
        select_lines.append(f" {metric_expression} AS {metric_alias}")
    select_lines.append(f" {risk_cases} AS RISK_LEVEL")

    joins = self._sql_join_lines(tables, select_markers, join_key)
    primary_source = self._sql_aggregate_subquery(primary_table, select_markers, join_key)
    statement_parts = ["SELECT", ",\n".join(select_lines), f"FROM {primary_source}", *joins]
    return self._validate_generated_sql("\n".join(statement_parts) + ";")
|
||
|
||
@staticmethod
|
||
def _validate_generated_sql(sql: str) -> str:
|
||
if re.search(r"^\s*WITH\b", sql, flags=re.IGNORECASE):
|
||
raise ValueError("Generated SQL uses WITH/CTE, which is not allowed for MySQL compatibility.")
|
||
if re.search(r"[\u4e00-\u9fff]", sql):
|
||
raise ValueError("Generated SQL contains Chinese text.")
|
||
if re.search(r"SUM\s*\([^)]*(DATE|TIME|DAYIDX)[^)]*\)", sql, flags=re.IGNORECASE):
|
||
raise ValueError("Generated SQL tries to SUM date/time fields.")
|
||
if re.search(r"^\s*t\d+\s+AS\s*\(", sql, flags=re.IGNORECASE | re.MULTILINE):
|
||
raise ValueError("Generated SQL contains CTE-style table aliases.")
|
||
return sql
|
||
|
||
def _sql_tables_from_rule(self, rule: dict[str, Any], schema: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Resolve up to four SQL table contexts referenced by a rule.

    Table labels parsed from ``data_sources`` are matched against the
    schema first. Only if that finds nothing, every schema module whose
    cleaned name appears in the rule's data sources, return fields or
    system logic is used. Contexts are unique by ``sql_name``.
    """
    contexts: list[dict[str, Any]] = []
    seen: set[str] = set()

    def remember(module: dict[str, Any], position: int) -> None:
        """Add the module's table context unless its SQL name was already added."""
        context = self._schema_table_context(module, position)
        if context["sql_name"] not in seen:
            contexts.append(context)
            seen.add(context["sql_name"])

    for table_label, _ in self._parse_data_source_lines(str(rule.get("data_sources", "") or "")):
        module, position = self._match_schema_module(table_label, schema)
        if module:
            remember(module, position)
    if contexts:
        return contexts[:4]

    haystack = "\n".join(str(rule.get(key, "") or "") for key in ("data_sources", "return_fields", "system_logic"))
    for position, module in enumerate(schema, start=1):
        module_name = self._clean_module_name(module.get("module_name", ""))
        if module_name and module_name in haystack:
            remember(module, position)
    return contexts[:4]
|
||
|
||
@staticmethod
|
||
def _parse_data_source_lines(data_sources: str) -> list[tuple[str, list[str]]]:
|
||
tables = []
|
||
for line in data_sources.splitlines():
|
||
raw_line = line.strip()
|
||
if not raw_line:
|
||
continue
|
||
split_at = -1
|
||
for separator in (":", ":", "锛"):
|
||
split_at = raw_line.find(separator)
|
||
if split_at >= 0:
|
||
break
|
||
if split_at < 0:
|
||
continue
|
||
table_label = raw_line[:split_at].strip()
|
||
field_text = raw_line[split_at + 1:].strip()
|
||
field_labels = [item.strip() for item in re.split(r"[,,、;;\s]+", field_text) if item.strip()]
|
||
if table_label:
|
||
tables.append((table_label, field_labels))
|
||
return tables
|
||
|
||
def _match_schema_module(
    self,
    table_label: str,
    schema: list[dict[str, Any]],
) -> tuple[dict[str, Any] | None, int]:
    """Find the first schema module whose cleaned name overlaps the label.

    Returns the module with its 1-based position in ``schema``, or
    ``(None, 0)`` when the cleaned label and no cleaned module name contain
    one another.
    """
    wanted = self._clean_module_name(table_label)
    for position, module in enumerate(schema, start=1):
        candidate = self._clean_module_name(module.get("module_name", ""))
        if not (wanted and candidate):
            continue
        # Equality is already covered by the substring test.
        if wanted in candidate or candidate in wanted:
            return module, position
    return None, 0
|
||
|
||
def _schema_table_context(self, module: dict[str, Any], index: int) -> dict[str, Any]:
    """Build the SQL context (``sql_name``, ``alias``, ``fields``) of a module.

    ``fields`` holds one ``{"name", "marker"}`` entry per distinct
    non-empty sanitised marker, in schema order. The alias is ``t<index>``
    with the index clamped to at least 1.
    """
    name_by_marker: dict[str, str] = {}
    for field in module.get("fields", []):
        marker = self._safe_sql_identifier(field.get("marker", ""), "")
        if marker and marker not in name_by_marker:
            name_by_marker[marker] = str(field.get("name", "") or "")
    fields = [{"name": name, "marker": marker} for marker, name in name_by_marker.items()]
    return {
        "sql_name": self._schema_table_identifier(module, index),
        "alias": f"t{max(index, 1)}",
        "fields": fields,
    }
|
||
|
||
def _schema_table_identifier(self, module: dict[str, Any], index: int) -> str:
    """Choose the SQL table identifier for a schema module.

    Preference order: the sanitised ``table_name``, then the sanitised
    cleaned ``module_name``, then ``TABLE_<NN>`` built from the index
    (clamped to at least 1, zero-padded to two digits).
    """
    from_table_name = self._safe_table_identifier(module.get("table_name", ""))
    if from_table_name:
        return from_table_name
    cleaned_module_name = self._clean_module_name(module.get("module_name", ""))
    from_module_name = self._safe_sql_identifier(cleaned_module_name, "")
    if from_module_name:
        return from_module_name
    return f"TABLE_{max(index, 1):02d}"
|
||
|
||
def _sql_field_markers(self, rule: dict[str, Any], tables: list[dict[str, Any]]) -> list[str]:
    """List the distinct column markers a rule refers to.

    Markers are resolved from the fields named in ``data_sources`` (against
    the table named on the same line, or the first table when that label is
    unknown) and from ``return_fields`` (against the first table). Only
    markers that exist in one of ``tables`` are kept, in first-seen order.
    """
    candidates: list[str] = []
    for table_label, field_labels in self._parse_data_source_lines(str(rule.get("data_sources", "") or "")):
        owner = self._table_for_label(table_label, tables) or tables[0]
        candidates += self._markers_for_field_labels(field_labels, owner)

    return_tokens = self._field_tokens(str(rule.get("return_fields", "") or ""))
    candidates += self._markers_for_field_labels(return_tokens, tables[0])

    available = {field["marker"] for table in tables for field in table["fields"]}
    unique: list[str] = []
    for candidate in candidates:
        safe_marker = self._safe_sql_identifier(candidate, "")
        if safe_marker and safe_marker in available and safe_marker not in unique:
            unique.append(safe_marker)
    return unique
|
||
|
||
def _table_for_label(self, table_label: str, tables: list[dict[str, Any]]) -> dict[str, Any] | None:
    """Return the table whose ``sql_name`` equals the sanitised label, if any."""
    wanted = self._safe_sql_identifier(table_label, "")
    if not wanted:
        return None
    return next((table for table in tables if table["sql_name"] == wanted), None)
|
||
|
||
def _markers_for_field_labels(self, field_labels: list[str], table: dict[str, Any]) -> list[str]:
    """Translate field labels into markers of ``table``.

    A label whose sanitised form equals a marker of the table maps to that
    marker. Otherwise the first field whose name and the label contain one
    another supplies the marker. Labels that match nothing are dropped;
    duplicates are not removed here.
    """
    markers = []
    table_fields = table["fields"]
    for label in field_labels:
        safe_label = self._safe_sql_identifier(label, "")
        if safe_label and any(field["marker"] == safe_label for field in table_fields):
            markers.append(safe_label)
            continue
        for field in table_fields:
            field_name = field["name"]
            if not (label and field_name):
                continue
            # Equality is already covered by the substring test.
            if label in field_name or field_name in label:
                markers.append(field["marker"])
                break
    return markers
|
||
|
||
@staticmethod
|
||
def _field_tokens(text: str) -> list[str]:
|
||
return [item.strip() for item in re.split(r"[,,、;;\s]+", text) if item.strip()]
|
||
|
||
def _qualified_marker(self, marker: str, tables: list[dict[str, Any]]) -> str:
|
||
for table in tables:
|
||
if any(field["marker"] == marker for field in table["fields"]):
|
||
return f"{table['alias']}.{marker}"
|
||
return f"{tables[0]['alias']}.{marker}"
|
||
|
||
def _sql_aggregate_subquery(
    self,
    table: dict[str, Any],
    select_markers: list[str],
    join_key: str,
) -> str:
    """Build the aliased ``(SELECT … GROUP BY …) <alias>`` sub-query of a table.

    The sub-query selects the join key, the requested markers that belong
    to this table and the table's numeric markers, each at most once.
    Numeric markers are wrapped in ``SUM(...)``; all other columns are
    selected as they are and used for ``GROUP BY`` (``GROUP BY 1`` when
    there are none).

    When the join key is not among the selected columns, the table's first
    field is added at the front — unless it is already selected, because a
    repeated column would give the derived table two columns with the same
    name.
    """
    table_markers = {field["marker"] for field in table["fields"]}
    selected = [marker for marker in select_markers if marker in table_markers]
    aggregate_markers = self._numeric_markers(table)

    query_markers: list[str] = []
    for marker in [join_key, *selected, *aggregate_markers]:
        if marker and marker in table_markers and marker not in query_markers:
            query_markers.append(marker)
    if join_key not in query_markers and table["fields"]:
        anchor = table["fields"][0]["marker"]
        if anchor not in query_markers:
            query_markers.insert(0, anchor)

    select_lines = []
    group_markers = []
    for marker in query_markers:
        if marker in aggregate_markers:
            select_lines.append(f" SUM({marker}) AS {marker}")
        else:
            select_lines.append(f" {marker}")
            group_markers.append(marker)

    group_by = ", ".join(group_markers) if group_markers else "1"
    joined_select_lines = ",\n".join(select_lines)
    return (
        "(\n"
        " SELECT\n"
        f"{joined_select_lines}\n"
        f" FROM {table['sql_name']}\n"
        f" GROUP BY {group_by}\n"
        f") {table['alias']}"
    )
|
||
|
||
@staticmethod
|
||
def _numeric_markers(table: dict[str, Any]) -> list[str]:
|
||
markers = []
|
||
for field in table["fields"]:
|
||
marker = field["marker"]
|
||
name = str(field.get("name", "") or "")
|
||
field_type = str(field.get("type", "") or "")
|
||
if any(token in marker for token in ("DATE", "TIME", "DAYIDX")):
|
||
continue
|
||
if any(token in name for token in ("日期", "时间")):
|
||
continue
|
||
if any(token in field_type for token in ("日期", "时间")):
|
||
continue
|
||
if any(token in marker for token in ("AMOUNT", "BALANCE", "ASSET", "LIABILITY", "RATE", "ZZJE", "ZZYE")) or any(
|
||
token in name for token in ("金额", "余额", "利率", "市值", "成本", "资产", "负债")
|
||
):
|
||
markers.append(marker)
|
||
return markers
|
||
|
||
@staticmethod
|
||
def _sql_join_key(tables: list[dict[str, Any]]) -> str:
|
||
if not tables:
|
||
return ""
|
||
primary_markers = {field["marker"] for field in tables[0]["fields"]}
|
||
for table in tables[1:]:
|
||
primary_markers &= {field["marker"] for field in table["fields"]}
|
||
return next((marker for marker in ("COMCODE", "COMNAME", "DAYIDX") if marker in primary_markers), None) or (
|
||
sorted(primary_markers)[0] if primary_markers else ""
|
||
)
|
||
|
||
def _sql_join_lines(self, tables: list[dict[str, Any]], select_markers: list[str], join_key: str) -> list[str]:
    """Build ``LEFT JOIN`` clauses attaching secondary tables to the first.

    At most three secondary tables (``tables[1:4]``) are joined, each as an
    aggregate subquery. A table is joined on ``join_key`` when both it and
    the primary table have that marker; otherwise it is joined ``ON 1 = 1``.

    Args:
        tables: Table descriptors; the first one is the primary table.
        select_markers: Markers forwarded to the subquery builder.
        join_key: Marker to join on; may be empty.

    Returns:
        One clause per joined table, or an empty list when there is at most
        one table.
    """
    if len(tables) <= 1:
        return []

    primary = tables[0]
    primary_markers = {field["marker"] for field in primary["fields"]}
    key_on_primary = bool(join_key) and join_key in primary_markers

    clauses = []
    for secondary in tables[1:4]:
        secondary_markers = {field["marker"] for field in secondary["fields"]}
        subquery = self._sql_aggregate_subquery(secondary, select_markers, join_key)
        if key_on_primary and join_key in secondary_markers:
            condition = f"{primary['alias']}.{join_key} = {secondary['alias']}.{join_key}"
        else:
            # No shared key: fall back to an unconditional join.
            condition = "1 = 1"
        clauses.append(f"LEFT JOIN {subquery} ON {condition}")
    return clauses
|
||
|
||
def _sql_metric_alias(self, rule: dict[str, Any]) -> str:
    """Derive a SQL column alias for the rule's metric.

    The alias comes from the text (2 to 30 characters without ``=`` or a
    newline) that precedes the first ``=`` in the rule's
    ``business_rule_description``, sanitised by ``_safe_sql_identifier``
    with ``"RISK_METRIC"`` as the fallback.

    Args:
        rule: Rule record.

    Returns:
        The sanitised alias, or the fallback when nothing usable is found.
    """
    description = str(rule.get("business_rule_description", ""))
    found = re.search(r"([^\n=]{2,30})\s*=", description)
    candidate = "" if found is None else found.group(1).strip()
    return self._safe_sql_identifier(candidate, "RISK_METRIC")
|
||
|
||
def _sql_metric_expression(self, rule: dict[str, Any], tables: list[dict[str, Any]]) -> str:
    """Produce the SQL expression that computes the rule's metric.

    The formula found in ``business_rule_description`` is translated to
    column references first. When that yields nothing, the first field whose
    marker contains an amount/balance/rate token is used, and ``"0"`` is the
    last resort.

    Args:
        rule: Rule record.
        tables: Table descriptors with ``alias`` and ``fields``.

    Returns:
        A SQL expression string.
    """
    description = str(rule.get("business_rule_description", "") or "")
    translated = self._replace_formula_fields(self._metric_formula_rhs(description), tables)
    if translated:
        return translated

    measure_tokens = ("AMOUNT", "BALANCE", "RATE", "ZZJE", "ZZYE")
    measure_columns = (
        f"{table['alias']}.{field['marker']}"
        for table in tables
        for field in table["fields"]
        if any(token in field["marker"] for token in measure_tokens)
    )
    return next(measure_columns, "0")
|
||
|
||
@staticmethod
|
||
def _sql_metric_components(metric_expression: str) -> list[str]:
|
||
match = re.fullmatch(r"\((.+)\s*/\s*NULLIF\((.+),\s*0\)\)", metric_expression.strip())
|
||
if match:
|
||
return [match.group(1).strip(), match.group(2).strip()]
|
||
return []
|
||
|
||
@staticmethod
|
||
def _metric_formula_rhs(text: str) -> str:
|
||
for line in text.splitlines():
|
||
if "=" in line:
|
||
return line.split("=", 1)[1].strip()
|
||
return ""
|
||
|
||
def _replace_formula_fields(self, formula: str, tables: list[dict[str, Any]]) -> str:
|
||
expression = formula.strip()
|
||
if not expression:
|
||
return ""
|
||
|
||
replacements = []
|
||
for table in tables:
|
||
for field in table["fields"]:
|
||
name = str(field.get("name", "") or "").strip()
|
||
if name and name in expression:
|
||
replacements.append((name, f"{table['alias']}.{field['marker']}"))
|
||
|
||
for name, sql_expression in sorted(replacements, key=lambda item: len(item[0]), reverse=True):
|
||
expression = expression.replace(name, sql_expression)
|
||
|
||
expression = expression.replace("(", "(").replace(")", ")")
|
||
expression = re.sub(r"\s+", " ", expression)
|
||
if re.search(r"[\u4e00-\u9fff]", expression):
|
||
return ""
|
||
if "/" in expression:
|
||
left, right = expression.split("/", 1)
|
||
left = left.strip()
|
||
right = right.strip()
|
||
if left and right:
|
||
return f"({left} / NULLIF({right}, 0))"
|
||
return expression
|
||
|
||
def _sql_risk_cases(self, rule: dict[str, Any], metric_expression: str) -> str:
    """Render the ``CASE`` expression that grades the metric.

    One ``WHEN`` branch is emitted per threshold returned by
    ``_risk_thresholds``, followed by ``ELSE 'PASS'``. Without thresholds a
    constant ``CASE`` that always yields ``'HIT'`` is returned.

    Args:
        rule: Rule record.
        metric_expression: SQL expression compared against each threshold.

    Returns:
        The ``CASE ... END`` SQL text.
    """
    thresholds = self._risk_thresholds(rule)
    if not thresholds:
        return "CASE WHEN 1 = 1 THEN 'HIT' ELSE 'PASS' END"

    branches = [
        f" WHEN {metric_expression} {operator} {value} THEN '{label}'"
        for label, operator, value in thresholds
    ]
    return "\n ".join(["CASE", *branches, " ELSE 'PASS'", " END"])
|
||
|
||
def _risk_thresholds(self, rule: dict[str, Any]) -> list[tuple[str, str, str]]:
    """Extract per-level thresholds from the rule's description.

    For each level (RED, ORANGE, YELLOW, in that order) the description
    lines mentioning the level's colour word are scanned; the first such
    line that contains a comparison supplies the level's operator and value.

    Args:
        rule: Rule record; ``business_rule_description`` is read.

    Returns:
        ``(label, operator, value)`` tuples for the levels that were found.
    """
    description_lines = str(rule.get("business_rule_description", "") or "").splitlines()
    levels = (
        ("RED", "\u7ea2\u8272"),
        ("ORANGE", "\u6a59\u8272"),
        ("YELLOW", "\u9ec4\u8272"),
    )

    found = []
    for label, colour_word in levels:
        for line in description_lines:
            if colour_word in line:
                comparison = self._first_sql_comparison(line)
                if comparison:
                    found.append((label, comparison[0], comparison[1]))
                    # One threshold per level: stop at the first usable line.
                    break
    return found
|
||
|
||
def _system_rule_conditions(self, system_rule: str, metric_expression: str) -> list[str]:
    """Turn every comparison found in the rule text into a SQL condition.

    Args:
        system_rule: Text scanned for comparisons by ``_sql_comparisons``.
        metric_expression: SQL expression placed on the left of each one.

    Returns:
        ``"<metric_expression> <operator> <value>"`` strings in text order.
    """
    conditions = []
    for operator, value in self._sql_comparisons(system_rule):
        conditions.append(f"{metric_expression} {operator} {value}")
    return conditions
|
||
|
||
def _first_sql_comparison(self, text: str) -> tuple[str, str] | None:
    """Return the first comparison found in ``text``, if any.

    Args:
        text: Text scanned by ``_sql_comparisons``.

    Returns:
        The first ``(operator, value)`` pair, or ``None`` when there is none.
    """
    return next(iter(self._sql_comparisons(text)), None)
|
||
|
||
@staticmethod
|
||
def _sql_comparisons(text: str) -> list[tuple[str, str]]:
|
||
operator_map = {
|
||
">=": ">=",
|
||
"<=": "<=",
|
||
">": ">",
|
||
"<": "<",
|
||
"\u5927\u4e8e\u7b49\u4e8e": ">=",
|
||
"\u5c0f\u4e8e\u7b49\u4e8e": "<=",
|
||
"\u5927\u4e8e": ">",
|
||
"\u5c0f\u4e8e": "<",
|
||
}
|
||
operator_pattern = "|".join(re.escape(operator) for operator in operator_map)
|
||
comparisons = []
|
||
for match in re.finditer(rf"({operator_pattern})\s*([0-9]+(?:\.[0-9]+)?)\s*(%)?", text):
|
||
value = float(match.group(2))
|
||
if match.group(3):
|
||
value = value / 100
|
||
comparisons.append((operator_map[match.group(1)], f"{value:g}"))
|
||
return comparisons
|
||
|
||
@staticmethod
|
||
def _safe_sql_identifier(value: Any, fallback: str) -> str:
|
||
text = str(value or "").strip().upper()
|
||
text = re.sub(r"[^A-Z0-9_]+", "_", text)
|
||
text = re.sub(r"_+", "_", text).strip("_")
|
||
if not text or not re.search(r"[A-Z]", text):
|
||
return fallback
|
||
if text[0].isdigit():
|
||
text = f"T_{text}"
|
||
return text
|
||
|
||
@staticmethod
|
||
def _safe_table_identifier(value: Any) -> str:
|
||
text = str(value or "").strip()
|
||
parts = []
|
||
for part in text.split("."):
|
||
safe_part = re.sub(r"[^A-Za-z0-9_]+", "_", part)
|
||
safe_part = re.sub(r"_+", "_", safe_part).strip("_")
|
||
if not safe_part or not re.search(r"[A-Za-z]", safe_part):
|
||
continue
|
||
if safe_part[0].isdigit():
|
||
safe_part = f"T_{safe_part}"
|
||
parts.append(safe_part)
|
||
if not parts:
|
||
return ""
|
||
return ".".join(parts)
|
||
|
||
@staticmethod
def _estimate_row_height(row) -> float:
    """Estimate a height for a worksheet row from its cell contents.

    For every cell, the number of explicit lines (newline count plus one)
    and an estimate of wrapped lines (text length against about 1.6
    characters per unit of column width) are computed; the largest count in
    the row decides the height at 18 units per line.

    Args:
        row: Iterable of cells exposing ``value`` and a 1-based ``column``.

    Returns:
        The height, never below 36 and never above 409 (presumably the
        spreadsheet's row-height ceiling; confirm against the writer).
    """
    line_counts = [1]
    for cell in row:
        text = str(cell.value or "")
        column_index = cell.column - 1
        # Columns beyond the configured widths are treated as 72 wide.
        width = _COLUMN_WIDTHS[column_index] if column_index < len(_COLUMN_WIDTHS) else 72
        chars_per_line = max(int(width * 1.6), 1)
        line_counts.append(text.count("\n") + 1)
        line_counts.append(max(1, len(text) // chars_per_line + 1))
    return min(max(36, max(line_counts) * 18), 409)
|