feat: 新增风险规则生成引擎与知识图谱可视化

后端新增风险规则自动生成和模板执行服务,支持从规则资产
批量生成并持久化风险规则文件;知识库入库日志增强图谱
查询和本地 RAG 回退,前端审计页面增加风险规则模型和流
程图组件,知识入库面板拆分为图谱可视化子组件,报销创
建页面增加引导式流程模型,更新知识库索引数据。
This commit is contained in:
caoxiaozhu
2026-05-23 19:54:42 +08:00
parent 5b388d08c0
commit 575f093c74
63 changed files with 35497 additions and 1517 deletions

View File

@@ -42,6 +42,7 @@ class AgentAssetJsonRuleMixin:
description=str(payload.get("description") or asset.description or "").strip(),
evaluator=str(payload.get("evaluator") or ""),
ontology_signal=str(payload.get("ontology_signal") or "") or None,
flow_diagram_svg=str(payload.get("flow_diagram_svg") or "") or None,
inputs=payload.get("inputs") if isinstance(payload.get("inputs"), dict) else {},
outcomes=payload.get("outcomes") if isinstance(payload.get("outcomes"), dict) else {},
payload=payload,
@@ -95,4 +96,3 @@ class AgentAssetJsonRuleMixin:
)
self.db.commit()
return self.read_rule_json(asset_id)

View File

@@ -1,5 +1,7 @@
from __future__ import annotations
import threading
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -31,6 +33,8 @@ from app.services.agent_foundation_risk_rules import AgentFoundationRiskRuleMixi
from app.services.agent_foundation_spreadsheets import AgentFoundationSpreadsheetMixin
logger = get_logger("app.services.agent_foundation")
_foundation_ready_lock = threading.RLock()
_foundation_ready_keys: set[str] = set()
def prepare_agent_foundation() -> None:
@@ -57,6 +61,17 @@ class AgentFoundationService(
self.db = db
def ensure_foundation_ready(self) -> None:
cache_key = self._foundation_cache_key()
if cache_key in _foundation_ready_keys:
return
with _foundation_ready_lock:
if cache_key in _foundation_ready_keys:
return
self._prepare_foundation()
_foundation_ready_keys.add(cache_key)
def _prepare_foundation(self) -> None:
try:
Base.metadata.create_all(bind=self.db.get_bind())
self._ensure_agent_asset_schema()
@@ -69,6 +84,10 @@ class AgentFoundationService(
logger.exception("Failed to prepare agent foundation")
raise
def _foundation_cache_key(self) -> str:
bind = self.db.get_bind()
return str(getattr(bind, "url", "") or id(bind))
def _sync_demo_financial_records(self) -> None:
if get_settings().seed_demo_financial_records:
self._seed_financial_records()

View File

@@ -6,12 +6,14 @@ from typing import Any
from sqlalchemy.orm import Session
from app.core.config import get_settings
from app.core.agent_enums import AgentName, AgentPermissionLevel, AgentRunStatus
from app.core.logging import get_logger
from app.models.agent_run import AgentRun, AgentToolCall, SemanticParseLog
from app.repositories.agent_run import AgentRunRepository
from app.schemas.agent_run import AgentRunRead, AgentToolCallRead, SemanticParseRead
from app.services.agent_foundation import AgentFoundationService
from app.services.knowledge_ingest_log import enrich_knowledge_ingest_route_json
logger = get_logger("app.services.agent_runs")
@@ -42,7 +44,7 @@ class AgentRunService:
run = self.repository.get_by_run_id(run_id)
if run is None:
return None
return self._serialize_run(run)
return self._serialize_run(run, enrich_knowledge_ingest=True)
def create_run(
self,
@@ -314,9 +316,19 @@ class AgentRunService:
except ValueError:
return None
@staticmethod
def _serialize_run(run: AgentRun) -> AgentRunRead:
def _serialize_run(
self,
run: AgentRun,
*,
enrich_knowledge_ingest: bool = False,
) -> AgentRunRead:
semantic_parse = run.semantic_parse_logs[0] if run.semantic_parse_logs else None
route_json = run.route_json
if enrich_knowledge_ingest:
route_json = enrich_knowledge_ingest_route_json(
dict(run.route_json or {}),
storage_root=get_settings().resolved_storage_root_dir,
)
return AgentRunRead(
id=run.id,
run_id=run.run_id,
@@ -325,7 +337,7 @@ class AgentRunService:
user_id=run.user_id,
task_id=run.task_id,
ontology_json=run.ontology_json,
route_json=run.route_json,
route_json=route_json,
permission_level=run.permission_level,
status=run.status,
result_summary=run.result_summary,

View File

@@ -1,36 +1,19 @@
from __future__ import annotations
import re
from datetime import UTC, date, datetime, timedelta
from decimal import Decimal
from types import SimpleNamespace
from typing import Any
from sqlalchemy import or_, select
from sqlalchemy import inspect as sqlalchemy_inspect
from sqlalchemy import select
from app.api.deps import CurrentUserContext
from app.core.agent_enums import AgentAssetDomain, AgentAssetStatus, AgentAssetType
from app.models.agent_asset import AgentAsset
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
from app.schemas.reimbursement import TravelReimbursementCalculatorRequest
from app.services.agent_asset_rule_library import AgentAssetRuleLibraryManager
from app.services.agent_asset_spreadsheet import RISK_RULES_LIBRARY
from app.services.expense_claim_constants import (
AI_REVIEW_LOOKBACK_DAYS,
AI_REVIEW_REPEAT_RISK_BLOCK_COUNT,
AI_REVIEW_REPEAT_RISK_WARNING_COUNT,
DOCUMENT_FACT_ITEM_TYPES,
LOCATION_REQUIRED_EXPENSE_TYPES,
SYSTEM_GENERATED_ITEM_TYPES,
TRAVEL_ALLOWANCE_TRIGGER_ITEM_TYPES,
TRAVEL_POLICY_HOTEL_NIGHT_PATTERN,
)
from app.services.expense_rule_runtime import (
ExpenseRuleRuntimeService,
RuntimeTravelPolicy,
build_default_expense_rule_catalog,
)
from app.services.risk_rule_template_executor import RiskRuleTemplateExecutor
class ExpenseClaimPlatformRiskMixin:
@@ -66,9 +49,7 @@ class ExpenseClaimPlatformRiskMixin:
if severity == "high" or action == "block":
blocking_reasons.append(str(flag.get("message") or flag.get("label") or "").strip())
deduplicated_reasons = list(
dict.fromkeys(reason for reason in blocking_reasons if reason)
)
deduplicated_reasons = list(dict.fromkeys(reason for reason in blocking_reasons if reason))
return {"flags": flags, "blocking_reasons": deduplicated_reasons}
def _load_platform_risk_rule_manifests(
@@ -77,9 +58,7 @@ class ExpenseClaimPlatformRiskMixin:
rule_codes: list[str] | None,
) -> list[dict[str, Any]]:
code_filter = {
str(code or "").strip()
for code in list(rule_codes or [])
if str(code or "").strip()
str(code or "").strip() for code in list(rule_codes or []) if str(code or "").strip()
}
manifests_by_code: dict[str, dict[str, Any]] = {}
@@ -224,12 +203,10 @@ class ExpenseClaimPlatformRiskMixin:
normalized_contexts.append(
{
"scene_code": str(document_info.get("scene_code") or "").strip().lower(),
"document_type": str(
document_info.get("document_type") or ""
).strip().lower(),
"item_type": str(
getattr(context.get("item"), "item_type", "") or ""
).strip().lower(),
"document_type": str(document_info.get("document_type") or "").strip().lower(),
"item_type": str(getattr(context.get("item"), "item_type", "") or "")
.strip()
.lower(),
}
)
@@ -312,6 +289,19 @@ class ExpenseClaimPlatformRiskMixin:
claim=claim,
contexts=contexts,
)
if evaluator == "template_rule":
result = RiskRuleTemplateExecutor().evaluate(
manifest,
claim=claim,
contexts=contexts,
)
if result is None:
return None
return self._build_platform_risk_flag(
manifest,
message=str(result.get("message") or "自然语言风险规则命中。"),
evidence=result.get("evidence") if isinstance(result.get("evidence"), dict) else {},
)
return None
def _evaluate_reason_too_brief_risk(
@@ -347,9 +337,7 @@ class ExpenseClaimPlatformRiskMixin:
reason_corpus = self._build_scene_reason_corpus(claim)
compact_reason = re.sub(r"\s+", "", reason_corpus)
looks_like_entertainment = (
"entertainment" in expense_types
or "招待" in compact_reason
or "客户" in compact_reason
"entertainment" in expense_types or "招待" in compact_reason or "客户" in compact_reason
)
if not looks_like_entertainment:
return None
@@ -374,32 +362,28 @@ class ExpenseClaimPlatformRiskMixin:
for context in contexts:
item = context["item"]
item_type = (
str(item.item_type or claim.expense_type or "other").strip().lower()
or "other"
str(item.item_type or claim.expense_type or "other").strip().lower() or "other"
)
policy = self._get_expense_scene_policy(item_type)
if policy is None:
continue
document_info = context.get("document_info") or {}
recognized_scene_code = (
str(document_info.get("scene_code") or "other").strip().lower()
or "other"
str(document_info.get("scene_code") or "other").strip().lower() or "other"
)
recognized_document_type = (
str(document_info.get("document_type") or "other").strip().lower()
or "other"
str(document_info.get("document_type") or "other").strip().lower() or "other"
)
if (
recognized_scene_code in set(policy.allowed_scene_codes)
or recognized_document_type in set(policy.allowed_document_types)
):
if recognized_scene_code in set(
policy.allowed_scene_codes
) or recognized_document_type in set(policy.allowed_document_types):
continue
recognized_label = str(
document_info.get("document_type_label")
or recognized_document_type
or "未知票据"
document_info.get("document_type_label") or recognized_document_type or "未知票据"
)
mismatches.append(
f"{context['index']} 条明细为{policy.label},附件识别为{recognized_label}"
)
mismatches.append(f"{context['index']} 条明细为{policy.label},附件识别为{recognized_label}")
if not mismatches:
return None
@@ -437,7 +421,10 @@ class ExpenseClaimPlatformRiskMixin:
evidence_text = "".join(evidence_cities[:5])
return self._build_platform_risk_flag(
manifest,
message=f"申报地点 {declared_text} 与票据识别地点 {evidence_text} 不一致,建议补充异地说明或更换附件。",
message=(
f"申报地点 {declared_text} 与票据识别地点 {evidence_text} 不一致,"
"建议补充异地说明或更换附件。"
),
evidence={"declared_cities": declared_cities, "evidence_cities": evidence_cities},
)
@@ -450,9 +437,7 @@ class ExpenseClaimPlatformRiskMixin:
) -> dict[str, Any] | None:
invoice_keys = self._collect_invoice_keys_from_contexts(contexts)
duplicate_keys = [
key
for key, count in self._count_values(invoice_keys).items()
if count > 1
key for key, count in self._count_values(invoice_keys).items() if count > 1
]
if duplicate_keys:
return self._build_platform_risk_flag(
@@ -504,9 +489,7 @@ class ExpenseClaimPlatformRiskMixin:
) -> dict[str, Any] | None:
params = manifest.get("params") if isinstance(manifest.get("params"), dict) else {}
allow_keywords = [
str(value)
for value in list(params.get("allow_keywords") or [])
if str(value).strip()
str(value) for value in list(params.get("allow_keywords") or []) if str(value).strip()
]
claimant = str(claim.employee_name or "").strip()
if not claimant:
@@ -564,7 +547,10 @@ class ExpenseClaimPlatformRiskMixin:
return None
return self._build_platform_risk_flag(
manifest,
message=f"票据年份 {mismatch_years[0]} 与费用发生年份 {claim_year} 不一致,建议确认是否跨年报销。",
message=(
f"票据年份 {mismatch_years[0]} 与费用发生年份 {claim_year} 不一致,"
"建议确认是否跨年报销。"
),
evidence={"claim_year": claim_year, "invoice_years": mismatch_years},
)

View File

@@ -480,6 +480,7 @@ def _build_initial_knowledge_ingest_document(
"entity_count": 0,
"relation_count": 0,
"entities": [],
"entity_chunks": [],
"relations": [],
"events": [
{
@@ -677,7 +678,7 @@ def _build_ingest_graph(knowledge_ingest: dict[str, Any]) -> dict[str, Any]:
documents = [
item for item in list(knowledge_ingest.get("documents") or []) if isinstance(item, dict)
]
entities = _dedupe_text_items(
entities = _dedupe_entities(
entity for document in documents for entity in list(document.get("entities") or [])
)
relations = _dedupe_relations(
@@ -692,20 +693,52 @@ def _build_ingest_graph(knowledge_ingest: dict[str, Any]) -> dict[str, Any]:
}
def _dedupe_text_items(items: Any) -> list[str]:
deduped: list[str] = []
def _dedupe_entities(items: Any) -> list[dict[str, Any]]:
deduped: list[dict[str, Any]] = []
seen: set[str] = set()
for item in items:
text = str(item or "").strip()
if not text or text in seen:
if isinstance(item, dict):
name = str(
item.get("name")
or item.get("entity")
or item.get("entity_id")
or item.get("title")
or item.get("id")
or ""
).strip()
entity = dict(item)
else:
name = str(item or "").strip()
entity = {}
if not name or name in seen:
continue
seen.add(text)
deduped.append(text)
seen.add(name)
entity["name"] = name
entity["type"] = str(
entity.get("type")
or entity.get("entity_type")
or entity.get("category")
or entity.get("kind")
or "实体"
).strip()
description = str(entity.get("description") or "").strip()
descriptions = entity.get("descriptions")
if not isinstance(descriptions, list):
descriptions = [description] if description else []
entity["description"] = description
entity["descriptions"] = [
str(description_item or "").strip()
for description_item in descriptions
if str(description_item or "").strip()
][:5]
if not isinstance(entity.get("properties"), dict):
entity["properties"] = {}
deduped.append(entity)
return deduped
def _dedupe_relations(items: Any) -> list[dict[str, str]]:
deduped: list[dict[str, str]] = []
def _dedupe_relations(items: Any) -> list[dict[str, Any]]:
deduped: list[dict[str, Any]] = []
seen: set[tuple[str, str, str]] = set()
for item in items:
if not isinstance(item, dict):
@@ -717,7 +750,7 @@ def _dedupe_relations(items: Any) -> list[dict[str, str]]:
if not source or not target or key in seen:
continue
seen.add(key)
deduped.append({"source": source, "target": target, "type": relation_type})
deduped.append({**item, "source": source, "target": target, "type": relation_type})
return deduped

View File

@@ -1,15 +1,21 @@
from __future__ import annotations
import json
import os
import re
from pathlib import Path
from typing import Any
from xml.etree import ElementTree
MAX_INGEST_LOG_CHUNKS = 24
MAX_INGEST_LOG_ENTITIES = 24
MAX_INGEST_LOG_ENTITY_CHUNKS = 48
MAX_INGEST_LOG_RELATIONS = 24
MAX_INGEST_LOG_SECTIONS = 12
MAX_INGEST_LOG_TEXT_PREVIEW = 180
MAX_INGEST_LOG_ENTITY_DESCRIPTIONS = 5
GRAPHML_NAMESPACE = {"graphml": "http://graphml.graphdrawing.org/xmlns"}
GRAPH_PROPERTY_SEPARATOR = "<SEP>"
INGEST_SECTION_HEADING_PATTERN = re.compile(
r"^(?:#{1,4}\s+.+|第[一二三四五六七八九十百零0-9]+[章节条]\s*.*)$"
@@ -42,6 +48,7 @@ def build_ingest_document_summary(
"entity_count": 0,
"relation_count": 0,
"entities": [],
"entity_chunks": [],
"relations": [],
}
@@ -62,6 +69,33 @@ def build_ingest_status_summary(
}
def enrich_knowledge_ingest_route_json(
route_json: dict[str, Any],
*,
storage_root: Path,
) -> dict[str, Any]:
if not isinstance(route_json, dict):
return route_json
ingest = route_json.get("knowledge_ingest")
if not isinstance(ingest, dict):
return route_json
graph = ingest.get("graph")
if not isinstance(graph, dict):
return route_json
workspace = _resolve_lightrag_workspace(route_json)
graph_snapshot = _load_lightrag_graph_snapshot(storage_root, workspace=workspace)
if not graph_snapshot["entities"] and not graph_snapshot["relations"]:
return route_json
next_route = dict(route_json)
next_ingest = dict(ingest)
next_graph = _enrich_graph_payload(graph, graph_snapshot)
next_ingest["graph"] = next_graph
next_route["knowledge_ingest"] = next_ingest
return next_route
def build_document_graph_summary(
storage_root: Path,
*,
@@ -74,19 +108,264 @@ def build_document_graph_summary(
entities_payload = _load_json_file(workspace_dir / "kv_store_full_entities.json")
relations_payload = _load_json_file(workspace_dir / "kv_store_full_relations.json")
chunks_payload = _load_json_file(workspace_dir / "kv_store_text_chunks.json")
entity_chunks_payload = _load_json_file(workspace_dir / "kv_store_entity_chunks.json")
graph_snapshot = _load_lightrag_graph_snapshot(storage_root, workspace=workspace)
entities = _normalize_document_entities(entities_payload, document_id)
relations = _normalize_document_relations(relations_payload, document_id)
chunks = _normalize_document_chunks(chunks_payload, document_id)
entity_chunks = _normalize_document_entity_chunks(
entity_chunks_payload,
entities,
chunk_ids={str(item.get("id") or "").strip() for item in chunks},
)
return {
"entity_count": len(entities),
"relation_count": len(relations),
"entities": entities[:MAX_INGEST_LOG_ENTITIES],
"relations": relations[:MAX_INGEST_LOG_RELATIONS],
"entities": _enrich_entity_list(entities, graph_snapshot)[:MAX_INGEST_LOG_ENTITIES],
"relations": _enrich_relation_list(relations, graph_snapshot)[:MAX_INGEST_LOG_RELATIONS],
"chunks": chunks[:MAX_INGEST_LOG_CHUNKS],
"entity_chunks": entity_chunks[:MAX_INGEST_LOG_ENTITY_CHUNKS],
}
def _resolve_lightrag_workspace(route_json: dict[str, Any]) -> str:
explicit_workspace = str(
route_json.get("lightrag_workspace") or route_json.get("workspace") or ""
).strip()
if explicit_workspace:
return explicit_workspace
return os.environ.get("LIGHTRAG_WORKSPACE", "x_financial_knowledge").strip() or "x_financial_knowledge"
def _enrich_graph_payload(
graph: dict[str, Any],
graph_snapshot: dict[str, Any],
) -> dict[str, Any]:
next_graph = dict(graph)
relation_items = _extract_relation_items(graph.get("relations"))
relation_entity_names = [
name
for relation in relation_items
for name in (relation.get("source"), relation.get("target"))
]
next_graph["entities"] = _enrich_entity_list(
_dedupe_text_items(
_extract_entity_names(graph.get("entities")) + relation_entity_names
),
graph_snapshot,
)
next_graph["relations"] = _enrich_relation_list(relation_items, graph_snapshot)
return next_graph
def _enrich_entity_list(
entity_names: list[str],
graph_snapshot: dict[str, Any],
) -> list[dict[str, Any]]:
graph_entities = graph_snapshot.get("entities") or {}
return [
graph_entities.get(entity_name)
or {
"name": entity_name,
"type": "实体",
"description": "",
"descriptions": [],
"properties": {},
}
for entity_name in entity_names
]
def _enrich_relation_list(
relations: list[dict[str, Any]],
graph_snapshot: dict[str, Any],
) -> list[dict[str, Any]]:
graph_relations = graph_snapshot.get("relations") or {}
enriched_relations: list[dict[str, Any]] = []
for relation in relations:
source = str(relation.get("source") or "").strip()
target = str(relation.get("target") or "").strip()
relation_type = str(relation.get("type") or "关联").strip()
graph_relation = (
graph_relations.get((source, target))
or graph_relations.get((target, source))
or {}
)
enriched_relations.append(
{
**relation,
"source": source,
"target": target,
"type": relation_type,
"description": graph_relation.get("description", ""),
"keywords": graph_relation.get("keywords", []),
"weight": graph_relation.get("weight", relation.get("weight", 1)),
"properties": graph_relation.get("properties", {}),
}
)
return enriched_relations
def _load_lightrag_graph_snapshot(storage_root: Path, *, workspace: str) -> dict[str, Any]:
graphml_path = (
Path(storage_root)
/ "knowledge"
/ ".lightrag"
/ str(workspace).strip()
/ "graph_chunk_entity_relation.graphml"
)
if not graphml_path.exists():
return {"entities": {}, "relations": {}}
try:
root = ElementTree.parse(graphml_path).getroot()
except (ElementTree.ParseError, OSError):
return {"entities": {}, "relations": {}}
key_names = {
str(key.attrib.get("id") or ""): str(key.attrib.get("attr.name") or "")
for key in root.findall("graphml:key", GRAPHML_NAMESPACE)
}
return {
"entities": _load_graphml_entities(root, key_names),
"relations": _load_graphml_relations(root, key_names),
}
def _load_graphml_entities(
root: ElementTree.Element,
key_names: dict[str, str],
) -> dict[str, dict[str, Any]]:
entities: dict[str, dict[str, Any]] = {}
for node in root.findall(".//graphml:node", GRAPHML_NAMESPACE):
properties = _read_graphml_data(node, key_names)
name = str(properties.get("entity_id") or node.attrib.get("id") or "").strip()
if not name:
continue
descriptions = _split_graph_property(properties.get("description"))
visible_properties = _filter_graph_properties(properties)
entities[name] = {
"name": name,
"type": str(properties.get("entity_type") or "实体").strip(),
"description": descriptions[0] if descriptions else "",
"descriptions": descriptions[:MAX_INGEST_LOG_ENTITY_DESCRIPTIONS],
"properties": visible_properties,
}
return entities
def _load_graphml_relations(
root: ElementTree.Element,
key_names: dict[str, str],
) -> dict[tuple[str, str], dict[str, Any]]:
relations: dict[tuple[str, str], dict[str, Any]] = {}
for edge in root.findall(".//graphml:edge", GRAPHML_NAMESPACE):
source = str(edge.attrib.get("source") or "").strip()
target = str(edge.attrib.get("target") or "").strip()
if not source or not target:
continue
properties = _read_graphml_data(edge, key_names)
description_parts = _split_graph_property(properties.get("description"))
relations[(source, target)] = {
"description": "; ".join(description_parts[:2]),
"keywords": _split_graph_keywords(properties.get("keywords"))[:6],
"weight": _to_float(properties.get("weight"), default=1.0),
"properties": _filter_graph_properties(properties),
}
return relations
def _read_graphml_data(
element: ElementTree.Element,
key_names: dict[str, str],
) -> dict[str, str]:
data: dict[str, str] = {}
for item in element.findall("graphml:data", GRAPHML_NAMESPACE):
key = str(item.attrib.get("key") or "")
name = key_names.get(key) or key
if not name:
continue
data[name] = str(item.text or "").strip()
return data
def _split_graph_property(value: Any) -> list[str]:
return [
_truncate_text(part, max_length=MAX_INGEST_LOG_TEXT_PREVIEW)
for part in str(value or "").split(GRAPH_PROPERTY_SEPARATOR)
if str(part or "").strip()
]
def _split_graph_keywords(value: Any) -> list[str]:
keywords: list[str] = []
for part in str(value or "").split(GRAPH_PROPERTY_SEPARATOR):
keywords.extend(part.split(","))
return [
_truncate_text(keyword, max_length=60)
for keyword in keywords
if str(keyword or "").strip()
]
def _filter_graph_properties(properties: dict[str, Any]) -> dict[str, Any]:
hidden_fields = {
"source_id",
"file_path",
"truncate",
"description",
"keywords",
}
return {
key: value
for key, value in properties.items()
if key not in hidden_fields and str(value or "").strip()
}
def _extract_entity_names(raw_entities: Any) -> list[str]:
if not isinstance(raw_entities, list):
return []
names: list[str] = []
for entity in raw_entities:
if isinstance(entity, dict):
name = str(
entity.get("name")
or entity.get("entity")
or entity.get("entity_id")
or entity.get("id")
or ""
).strip()
else:
name = str(entity or "").strip()
if name:
names.append(name)
return _dedupe_text_items(names)
def _extract_relation_items(raw_relations: Any) -> list[dict[str, Any]]:
if not isinstance(raw_relations, list):
return []
relations: list[dict[str, Any]] = []
for relation in raw_relations:
if not isinstance(relation, dict):
continue
source = str(relation.get("source") or relation.get("from") or "").strip()
target = str(relation.get("target") or relation.get("to") or "").strip()
if not source or not target:
continue
relations.append(
{
**relation,
"source": source,
"target": target,
"type": str(relation.get("type") or "关联").strip(),
}
)
return relations
def _extract_ingest_sections(text: str) -> list[dict[str, str]]:
sections: list[dict[str, str]] = []
lines = [line.strip() for line in str(text or "").splitlines()]
@@ -187,11 +466,46 @@ def _normalize_document_chunks(payload: dict[str, Any], document_id: str) -> lis
"order": _to_int(raw_chunk.get("chunk_order_index")),
"tokens": _to_int(raw_chunk.get("tokens")),
"summary": _build_chunk_summary(content),
"excerpt": _truncate_text(
content,
max_length=MAX_INGEST_LOG_TEXT_PREVIEW,
),
}
)
return sorted(chunks, key=lambda item: (item["order"], item["id"]))
def _normalize_document_entity_chunks(
payload: dict[str, Any],
entities: list[str],
*,
chunk_ids: set[str],
) -> list[dict[str, Any]]:
if not entities or not chunk_ids:
return []
entity_chunks: list[dict[str, Any]] = []
for entity in entities:
raw_entry = payload.get(entity) if isinstance(payload, dict) else {}
raw_chunk_ids = raw_entry.get("chunk_ids") if isinstance(raw_entry, dict) else []
if not isinstance(raw_chunk_ids, list):
continue
matched_chunk_ids = [
str(item or "").strip()
for item in raw_chunk_ids
if str(item or "").strip() in chunk_ids
]
if not matched_chunk_ids:
continue
entity_chunks.append(
{
"entity": entity,
"chunk_ids": _dedupe_text_items(matched_chunk_ids),
}
)
return entity_chunks
def _build_chunk_summary(content: str) -> str:
lines = [line.strip() for line in str(content or "").splitlines() if line.strip()]
text = next((line for line in lines if len(line) >= 12), lines[0] if lines else "")
@@ -217,6 +531,13 @@ def _to_int(value: Any) -> int:
return 0
def _to_float(value: Any, *, default: float = 0.0) -> float:
try:
return float(value)
except (TypeError, ValueError):
return default
def _truncate_text(text: str, *, max_length: int) -> str:
normalized = " ".join(str(text or "").split()).strip()
if len(normalized) <= max_length:

View File

@@ -17,6 +17,7 @@ from app.services.knowledge_ingest_log import (
build_ingest_document_summary,
build_ingest_status_summary,
)
from app.services.knowledge_rag_local import query_local_text_chunks
from app.services.knowledge_rag_runtime import (
KnowledgeRagError,
RuntimeModelConfig,
@@ -95,6 +96,37 @@ class KnowledgeRagService:
"message": "请先输入要检索的知识库问题。",
}
workspace = (
os.environ.get("LIGHTRAG_WORKSPACE", DEFAULT_LIGHTRAG_WORKSPACE).strip()
or DEFAULT_LIGHTRAG_WORKSPACE
)
local_result = query_local_text_chunks(
lightrag_root=(self.storage_root / "knowledge" / ".lightrag").resolve(),
workspace=workspace,
query=normalized_query,
limit=limit,
)
if local_result.confident:
return {
"result_type": "knowledge_search",
"query": normalized_query,
"record_count": len(local_result.hits),
"hits": local_result.hits,
"references": [
str(item.get("code") or "").strip()
for item in local_result.hits
if str(item.get("code") or "").strip()
],
"raw_references": [],
"metadata": {
"retrieval_strategy": "local_text_chunks",
"elapsed_seconds": round(local_result.elapsed_seconds, 4),
"total_chunks": local_result.total_chunks,
"best_score": local_result.best_score,
},
"message": f"已从本地知识块中检索到 {len(local_result.hits)} 条相关内容。",
}
try:
runtime = self._get_runtime()
raw = runtime.query_data(normalized_query, conversation_history=conversation_history)

View File

@@ -0,0 +1,353 @@
from __future__ import annotations
import json
import re
import threading
from dataclasses import dataclass
from pathlib import Path
from time import perf_counter
from typing import Any
MAX_LOCAL_QUERY_TERMS = 14
MAX_LOCAL_HIT_CONTENT_LENGTH = 2200
MAX_LOCAL_HIT_EXCERPT_LENGTH = 240
LOCAL_CONFIDENCE_SCORE = 24
LOCAL_CONFIDENCE_MATCHES = 2
LOCAL_QUERY_STOPWORDS = {
"什么",
"多少",
"哪些",
"怎么",
"如何",
"请问",
"一下",
"关于",
"规定",
"标准",
"可以",
"是否",
"一个",
"根据",
"依据",
"给出",
"说明",
"公司",
"远光",
"软件",
"股份",
"有限",
"员工",
"当前",
"详细",
"问题",
}
LOCAL_TABLE_QUERY_HINTS = (
"标准",
"金额",
"限额",
"补贴",
"住宿",
"餐费",
"交通",
"报销",
"档位",
"额度",
)
LOCAL_DOMAIN_TERMS = (
"报销",
"费用",
"报销时限",
"申请时限",
"三个月",
"逾期",
"住宿费",
"住宿",
"差旅费",
"差旅",
"出差",
"超标",
"超过",
"审批",
"分管领导",
"部门负责人",
"业务招待",
"招待费",
"发票",
"票据",
"预算外",
"预算",
"补贴",
"餐补",
"交通费",
"会议费",
"培训费",
"通信费",
)
_index_lock = threading.RLock()
_index_cache: dict[Path, tuple[tuple[int, int], list[dict[str, Any]]]] = {}
@dataclass(frozen=True, slots=True)
class LocalKnowledgeSearchResult:
hits: list[dict[str, Any]]
confident: bool
elapsed_seconds: float
total_chunks: int
best_score: int
def query_local_text_chunks(
*,
lightrag_root: Path,
workspace: str,
query: str,
limit: int,
) -> LocalKnowledgeSearchResult:
started_at = perf_counter()
chunks = _load_text_chunks(lightrag_root / workspace / "kv_store_text_chunks.json")
query_terms = _extract_local_query_terms(query)
if not chunks or not query_terms:
return LocalKnowledgeSearchResult(
hits=[],
confident=False,
elapsed_seconds=perf_counter() - started_at,
total_chunks=len(chunks),
best_score=0,
)
prefers_tabular_evidence = any(hint in query for hint in LOCAL_TABLE_QUERY_HINTS)
candidates: list[dict[str, Any]] = []
for rank, chunk in enumerate(chunks, start=1):
content = str(chunk.get("content") or "").strip()
if not content:
continue
file_path = str(chunk.get("file_path") or "").strip()
document_id, document_name = _parse_document_identity(
file_path,
fallback_doc_id=str(chunk.get("full_doc_id") or "").strip(),
)
score, matched_terms = _score_local_chunk(
content=content,
title=document_name,
query_terms=query_terms,
rank=rank,
prefers_tabular_evidence=prefers_tabular_evidence,
)
if score <= 0:
continue
chunk_id = str(chunk.get("_id") or chunk.get("chunk_id") or "").strip()
normalized_content = _truncate_text(
content,
max_length=MAX_LOCAL_HIT_CONTENT_LENGTH,
)
candidates.append(
{
"code": f"knowledge.{document_id or 'unknown'}.{chunk_id or rank}",
"candidate_id": chunk_id or f"local-{rank}",
"title": document_name or "知识库文档",
"content": normalized_content,
"excerpt": _build_query_focused_excerpt(
normalized_content,
query_terms=query_terms,
max_length=MAX_LOCAL_HIT_EXCERPT_LENGTH,
),
"document_id": document_id,
"document_name": document_name or Path(file_path).name,
"version": None,
"updated_at": None,
"score": score,
"tags": [],
"evidence": [chunk_id] if chunk_id else [],
"file_path": file_path,
"_matched_terms": matched_terms,
}
)
ranked = sorted(
candidates,
key=lambda item: (
int(item.get("score") or 0),
len(list(item.get("_matched_terms") or [])),
-len(str(item.get("content") or "")),
),
reverse=True,
)
hits: list[dict[str, Any]] = []
for item in ranked[: max(1, limit)]:
normalized = dict(item)
normalized.pop("_matched_terms", None)
hits.append(normalized)
best_score = int(ranked[0].get("score") or 0) if ranked else 0
best_match_count = len(list(ranked[0].get("_matched_terms") or [])) if ranked else 0
confident = bool(
hits
and best_score >= LOCAL_CONFIDENCE_SCORE
and best_match_count >= LOCAL_CONFIDENCE_MATCHES
)
return LocalKnowledgeSearchResult(
hits=hits,
confident=confident,
elapsed_seconds=perf_counter() - started_at,
total_chunks=len(chunks),
best_score=best_score,
)
def _load_text_chunks(path: Path) -> list[dict[str, Any]]:
try:
stat = path.stat()
except OSError:
return []
signature = (int(stat.st_mtime_ns), int(stat.st_size))
resolved_path = path.resolve()
with _index_lock:
cached = _index_cache.get(resolved_path)
if cached is not None and cached[0] == signature:
return cached[1]
try:
payload = json.loads(path.read_text(encoding="utf-8"))
except (OSError, json.JSONDecodeError):
chunks: list[dict[str, Any]] = []
else:
chunks = [
dict(value)
for value in payload.values()
if isinstance(value, dict) and str(value.get("content") or "").strip()
]
chunks.sort(
key=lambda item: (
str(item.get("full_doc_id") or ""),
int(item.get("chunk_order_index") or 0),
str(item.get("_id") or ""),
)
)
_index_cache[resolved_path] = (signature, chunks)
return chunks
def _score_local_chunk(
*,
content: str,
title: str,
query_terms: list[str],
rank: int,
prefers_tabular_evidence: bool,
) -> tuple[int, list[str]]:
lowered_content = content.lower()
lowered_title = title.lower()
matched_terms = [
term for term in query_terms if term in lowered_content or term in lowered_title
]
if not matched_terms:
return 0, []
score = max(1, 32 - min(rank, 20))
for term in matched_terms:
weight = 8 if len(term) >= 4 else 5 if len(term) == 3 else 2
score += weight
if term in lowered_title:
score += max(4, weight)
occurrences = lowered_content.count(term)
if occurrences > 1:
score += min(8, occurrences * 2)
if prefers_tabular_evidence and ("|" in content or "" in content):
score += 12
if "# 结构化表格补充" in content:
score += 10 if prefers_tabular_evidence else 4
if "# 问答线索补充" in content:
score += 8
if "# 章节导航" in content[:260]:
score -= 20
if any(marker in content for marker in ("", "", "", "", "-", "")):
score += 4
return score, matched_terms
def _extract_local_query_terms(query: str) -> list[str]:
normalized_query = str(query or "").strip().lower()
if not normalized_query:
return []
terms: list[str] = []
seen: set[str] = set()
def remember(term: str) -> None:
normalized_term = str(term or "").strip().lower()
if (
not normalized_term
or normalized_term in seen
or normalized_term in LOCAL_QUERY_STOPWORDS
or len(normalized_term) < 2
):
return
seen.add(normalized_term)
terms.append(normalized_term)
for item in re.findall(r"[a-z0-9][a-z0-9_\-]{1,}", normalized_query):
remember(item)
for item in LOCAL_DOMAIN_TERMS:
if item in normalized_query:
remember(item)
for block in re.findall(r"[\u4e00-\u9fff]{2,24}", normalized_query):
if len(block) <= 4:
remember(block)
continue
for size in (4, 3, 2, 5):
for start in range(0, len(block) - size + 1):
remember(block[start : start + size])
if len(terms) >= MAX_LOCAL_QUERY_TERMS:
return terms
return terms[:MAX_LOCAL_QUERY_TERMS]
def _parse_document_identity(file_path: str, *, fallback_doc_id: str = "") -> tuple[str, str]:
path = Path(str(file_path or "").strip())
name = path.name
if "__" not in name:
return fallback_doc_id, name
document_id, document_name = name.split("__", maxsplit=1)
return document_id.strip() or fallback_doc_id, document_name.strip()
def _build_query_focused_excerpt(
text: str,
*,
query_terms: list[str],
max_length: int,
) -> str:
normalized = " ".join(str(text or "").split()).strip()
if not normalized:
return ""
lowered = normalized.lower()
match_positions = [
lowered.find(term) for term in query_terms if term and lowered.find(term) >= 0
]
if not match_positions:
return _truncate_text(normalized, max_length=max_length)
start = max(0, min(match_positions) - max_length // 3)
end = min(len(normalized), start + max_length)
snippet = normalized[start:end].strip()
if start > 0:
snippet = f"...{snippet.lstrip()}"
if end < len(normalized):
snippet = f"{snippet.rstrip()}..."
return snippet
def _truncate_text(text: str, *, max_length: int) -> str:
normalized = str(text or "").strip()
if len(normalized) <= max_length:
return normalized
return f"{normalized[: max_length - 3].rstrip()}..."

View File

@@ -404,6 +404,14 @@ class OrchestratorService:
ontology: OntologyParseResult,
task_asset: AgentAssetRead | None,
) -> dict[str, list[AgentAssetListItem | AgentAssetRead]]:
if ontology.scenario == "knowledge" and payload.source == AgentRunSource.USER_MESSAGE.value:
return {
"rules": [],
"skills": [],
"mcps": [],
"tasks": [],
}
domain_value = SCENARIO_TO_DOMAIN.get(ontology.scenario)
rules = self.execution_engine._rank_assets(
self.asset_service.list_assets(

View File

@@ -0,0 +1,191 @@
# ruff: noqa: E501
from __future__ import annotations
import html
from dataclasses import dataclass
@dataclass(frozen=True)
class RiskRuleFlowDiagramField:
key: str
label: str
@dataclass(frozen=True)
class RiskRuleFlowDiagramSpec:
title: str
domain_label: str
severity: str
severity_label: str
fields: tuple[RiskRuleFlowDiagramField, ...]
start: str
evidence: str
decision: str
basis: str
pass_text: str
fail_text: str
@dataclass(frozen=True)
class RiskRuleFlowDiagramPalette:
accent: str
accent_dark: str
border: str
surface: str
class RiskRuleFlowDiagramRenderer:
"""按 fireworks-tech-graph Style 7 OpenAI Official 生成只读流程 SVG。"""
_FONT = (
"-apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica Neue, "
"'PingFang SC', 'Microsoft YaHei', 'Microsoft JhengHei', 'SimHei', sans-serif"
)
_TEXT = "#0d0d0d"
_MUTED = "#6e6e80"
_NEUTRAL_LINE = "#cbd5e1"
_NEUTRAL_BORDER = "#e2e8f0"
_NEUTRAL_SURFACE = "#ffffff"
_PALETTES = {
"low": RiskRuleFlowDiagramPalette(
accent="#2563eb",
accent_dark="#1d4ed8",
border="#bfdbfe",
surface="#eff6ff",
),
"medium": RiskRuleFlowDiagramPalette(
accent="#f97316",
accent_dark="#c2410c",
border="#fed7aa",
surface="#fff7ed",
),
"high": RiskRuleFlowDiagramPalette(
accent="#dc2626",
accent_dark="#b91c1c",
border="#fecaca",
surface="#fef2f2",
),
}
def render(self, spec: RiskRuleFlowDiagramSpec) -> str:
title = self._truncate(spec.title, 26)
palette = self._palette(spec.severity)
return f"""<svg xmlns="http://www.w3.org/2000/svg" width="760" height="280" viewBox="0 0 760 280" data-risk-flow-style="review-node-only" role="img" aria-labelledby="risk-flow-title risk-flow-desc">
<title id="risk-flow-title">{self._escape(title)}流程说明</title>
<desc id="risk-flow-desc">风险规则只读流程图,展示从业务单据提交到风险复核的判断路径。</desc>
<defs>
<marker id="arrow-neutral" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
<polygon points="0 0, 10 3.5, 0 7" fill="{self._NEUTRAL_LINE}"/>
</marker>
</defs>
<rect width="760" height="280" fill="#ffffff"/>
<rect x="18" y="18" width="724" height="244" rx="8" ry="8" fill="none" stroke="{self._NEUTRAL_BORDER}" stroke-width="1" stroke-dasharray="4,3"/>
<text x="34" y="43" fill="{self._MUTED}" font-family="{self._FONT}" font-size="11" font-weight="500">RULE FLOW</text>
{self._node("业务输入", spec.start, 48, 118, 124, 60)}
{self._node("字段取数", "读取字段证据", 214, 118, 132, 60)}
{self._diamond("判断依据", spec.decision, 392, 92, 112, 112)}
{self._node("继续流转", spec.pass_text, 562, 74, 126, 60)}
{self._node("进入复核", spec.fail_text, 562, 190, 126, 62, palette=palette)}
{self._note(spec.basis, 214, 218, 290, 36)}
<line x1="172" y1="148" x2="214" y2="148" stroke="{self._NEUTRAL_LINE}" stroke-width="1.45" marker-end="url(#arrow-neutral)"/>
<line x1="346" y1="148" x2="392" y2="148" stroke="{self._NEUTRAL_LINE}" stroke-width="1.45" marker-end="url(#arrow-neutral)"/>
<path d="M 504 127 L 532 127 L 532 104 L 562 104" fill="none" stroke="{self._NEUTRAL_LINE}" stroke-width="1.35" marker-end="url(#arrow-neutral)"/>
<text x="534" y="119" text-anchor="middle" fill="{self._MUTED}" font-family="{self._FONT}" font-size="10.5" font-weight="400">否</text>
<path d="M 504 169 L 532 169 L 532 221 L 562 221" fill="none" stroke="{self._NEUTRAL_LINE}" stroke-width="1.8" marker-end="url(#arrow-neutral)"/>
<text x="534" y="195" text-anchor="middle" fill="{self._MUTED}" font-family="{self._FONT}" font-size="10.5" font-weight="600">是</text>
</svg>"""
def _node(
self,
title: str,
body: str,
x: int,
y: int,
width: int,
height: int,
palette: RiskRuleFlowDiagramPalette | None = None,
) -> str:
body_lines = self._wrap(body, 10 if width <= 126 else 11, 1)
border = palette.border if palette else self._NEUTRAL_BORDER
stripe = palette.accent if palette else self._NEUTRAL_LINE
surface = palette.surface if palette else self._NEUTRAL_SURFACE
return f"""<g>
<rect x="{x}" y="{y}" width="{width}" height="{height}" rx="7" ry="7" fill="{surface}" stroke="{border}" stroke-width="1.2"/>
<rect x="{x}" y="{y}" width="3.5" height="{height}" rx="1.75" ry="1.75" fill="{stripe}"/>
<text x="{x + width / 2:.0f}" y="{y + 24}" text-anchor="middle" fill="{self._TEXT}" font-family="{self._FONT}" font-size="13" font-weight="600">{self._escape(title)}</text>
{self._text_lines(body_lines, x + width / 2, y + 43, "middle", self._MUTED, 11)}
</g>"""
def _diamond(
self,
title: str,
body: str,
x: int,
y: int,
width: int,
height: int,
) -> str:
cx = x + width / 2
cy = y + height / 2
points = f"{cx},{y} {x + width},{cy} {cx},{y + height} {x},{cy}"
body_lines = self._wrap(body, 8, 2)
return f"""<g>
<polygon points="{points}" fill="#ffffff" stroke="{self._NEUTRAL_BORDER}" stroke-width="1.25"/>
<text x="{cx:.0f}" y="{cy - 10:.0f}" text-anchor="middle" fill="{self._TEXT}" font-family="{self._FONT}" font-size="12.5" font-weight="600">{self._escape(title)}</text>
{self._text_lines(body_lines, cx, cy + 11, "middle", self._MUTED, 10.2)}
</g>"""
def _note(
self,
body: str,
x: int,
y: int,
width: int,
height: int,
) -> str:
lines = self._wrap(body, 22, 1)
return f"""<g>
<rect x="{x}" y="{y}" width="{width}" height="{height}" rx="7" ry="7" fill="#ffffff" stroke="{self._NEUTRAL_BORDER}" stroke-width="1" stroke-dasharray="4,3"/>
<text x="{x + 12}" y="{y + 22}" fill="{self._MUTED}" font-family="{self._FONT}" font-size="10" font-weight="500">BASIS</text>
{self._text_lines(lines, x + 54, y + 22, "start", self._TEXT, 10.2)}
</g>"""
def _text_lines(
self,
lines: list[str],
x: float,
y: float,
anchor: str,
color: str,
font_size: float,
) -> str:
return "\n ".join(
f'<text x="{x:.0f}" y="{y + index * (font_size + 5):.0f}" text-anchor="{anchor}" fill="{color}" font-family="{self._FONT}" font-size="{font_size}" font-weight="400">{self._escape(line)}</text>'
for index, line in enumerate(lines)
)
@staticmethod
def _wrap(value: str, width: int, max_lines: int) -> list[str]:
text = str(value or "").strip()
if not text:
return [""]
lines = [text[index : index + width] for index in range(0, len(text), width)]
if len(lines) > max_lines:
lines = lines[:max_lines]
lines[-1] = f"{lines[-1][: max(0, width - 1)]}"
return lines
@staticmethod
def _truncate(value: str, length: int) -> str:
text = str(value or "").strip()
return text if len(text) <= length else f"{text[: length - 1]}"
@staticmethod
def _escape(value: str) -> str:
return html.escape(str(value or ""), quote=True)
@classmethod
def _palette(cls, severity: str) -> RiskRuleFlowDiagramPalette:
return cls._PALETTES.get(str(severity or "").strip().lower(), cls._PALETTES["medium"])

View File

@@ -0,0 +1,751 @@
from __future__ import annotations
import json
import re
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any
from sqlalchemy.orm import Session
from app.core.agent_enums import AgentAssetDomain, AgentAssetStatus, AgentAssetType
from app.models.agent_asset import AgentAsset, AgentAssetVersion
from app.schemas.agent_asset import AgentAssetRiskRuleGenerateRequest
from app.services.agent_asset_rule_library import AgentAssetRuleLibraryManager
from app.services.agent_asset_spreadsheet import RISK_RULES_LIBRARY
from app.services.audit import AuditLogService
from app.services.risk_rule_flow_diagram import (
RiskRuleFlowDiagramField,
RiskRuleFlowDiagramRenderer,
RiskRuleFlowDiagramSpec,
)
from app.services.runtime_chat import RuntimeChatService
@dataclass(frozen=True)
class RiskRuleField:
key: str
label: str
field_type: str
source: str
aliases: tuple[str, ...]
BUSINESS_DOMAIN_LABELS: dict[str, str] = {
AgentAssetDomain.EXPENSE.value: "报销",
AgentAssetDomain.AR.value: "应收",
AgentAssetDomain.AP.value: "应付",
}
RISK_LEVEL_LABELS: dict[str, str] = {
"low": "低风险",
"medium": "中风险",
"high": "高风险",
}
FIELD_ONTOLOGY: tuple[RiskRuleField, ...] = (
RiskRuleField("claim.reason", "报销事由", "text", "claim", ("事由", "说明", "理由", "用途")),
RiskRuleField(
"claim.location",
"申报地点",
"text",
"claim",
("地点", "城市", "出差地", "申报地点", "申报目的地", "目的地"),
),
RiskRuleField("claim.amount", "申报金额", "number", "claim", ("金额", "费用", "超额", "额度")),
RiskRuleField("claim.employee_name", "报销人", "text", "claim", ("报销人", "员工", "申请人")),
RiskRuleField("claim.department_name", "部门", "text", "claim", ("部门", "组织")),
RiskRuleField("item.item_type", "费用类型", "enum", "item", ("费用类型", "科目", "类型")),
RiskRuleField("item.item_reason", "明细事由", "text", "item", ("明细事由", "明细说明")),
RiskRuleField("item.item_location", "明细地点", "text", "item", ("明细地点", "发生地点")),
RiskRuleField(
"attachment.invoice_no", "发票号码", "text", "attachment", ("发票号", "发票号码", "票号")
),
RiskRuleField(
"attachment.buyer_name", "购买方名称", "text", "attachment", ("抬头", "购买方", "开票单位")
),
RiskRuleField(
"attachment.goods_name",
"商品服务名称",
"text",
"attachment",
("品名", "商品", "服务名称", "摘要"),
),
RiskRuleField(
"attachment.issue_date",
"开票日期",
"date",
"attachment",
("开票日期", "发票日期", "票据日期"),
),
RiskRuleField(
"attachment.hotel_city",
"住宿城市",
"text",
"attachment",
("住宿城市", "酒店城市", "酒店地点", "酒店发票城市", "酒店票城市", "住宿发票城市"),
),
RiskRuleField(
"attachment.route_cities",
"行程城市",
"list",
"attachment",
("行程", "路线", "途经城市", "出差城市", "交通票行程", "交通票城市"),
),
RiskRuleField(
"attachment.ocr_text",
"票据全文",
"text",
"attachment",
("票据内容", "OCR", "全文", "关键字", "关键词"),
),
RiskRuleField(
"receivable.aging_days", "应收账龄", "number", "receivable", ("账龄", "逾期", "应收逾期")
),
RiskRuleField(
"receivable.amount_outstanding",
"应收未收金额",
"number",
"receivable",
("未收金额", "欠款", "应收余额"),
),
RiskRuleField(
"payable.vendor_name", "供应商名称", "text", "payable", ("供应商", "付款方", "往来单位")
),
RiskRuleField(
"payable.amount_outstanding", "应付未付金额", "number", "payable", ("未付金额", "应付余额")
),
)
DOMAIN_FIELD_PREFIXES: dict[str, tuple[str, ...]] = {
AgentAssetDomain.EXPENSE.value: ("claim.", "item.", "attachment."),
AgentAssetDomain.AR.value: ("receivable.",),
AgentAssetDomain.AP.value: ("payable.",),
}
class RiskRuleGenerationService:
def __init__(
self,
db: Session,
*,
rule_library_manager: AgentAssetRuleLibraryManager | None = None,
runtime_chat_service: RuntimeChatService | None = None,
) -> None:
self.db = db
self.rule_library_manager = rule_library_manager or AgentAssetRuleLibraryManager()
self.runtime_chat_service = runtime_chat_service or RuntimeChatService(db)
self.audit_service = AuditLogService(db)
self.flow_diagram_renderer = RiskRuleFlowDiagramRenderer()
def generate_rule_asset(
self,
body: AgentAssetRiskRuleGenerateRequest,
*,
actor: str,
request_id: str | None = None,
) -> str:
domain = body.business_domain.value
if domain not in BUSINESS_DOMAIN_LABELS:
raise ValueError("当前仅支持报销、应收、应付业务域的新建风险规则。")
natural_language = self._clean_text(body.natural_language)
if len(natural_language) < 8:
raise ValueError("请至少输入 8 个字的风险规则描述。")
risk_level = str(body.risk_level or "medium").strip().lower()
if risk_level not in RISK_LEVEL_LABELS:
raise ValueError("风险等级仅支持 low、medium、high。")
created_at = datetime.now(UTC)
fields = self._resolve_fields(natural_language, domain=domain)
draft = self._compile_with_model(
natural_language=natural_language,
domain=domain,
risk_level=risk_level,
fields=fields,
) or self._build_fallback_draft(
natural_language=natural_language,
domain=domain,
risk_level=risk_level,
fields=fields,
)
draft = self._align_draft_fields(
draft,
natural_language=natural_language,
fields=fields,
)
payload = self._build_rule_payload(
draft,
natural_language=natural_language,
domain=domain,
risk_level=risk_level,
fields=fields,
created_at=created_at,
actor=actor,
)
rule_code = str(payload["rule_code"])
file_name = f"{rule_code}.json"
self.rule_library_manager.write_rule_library_json(
library=RISK_RULES_LIBRARY,
file_name=file_name,
payload=payload,
)
asset = AgentAsset(
asset_type=AgentAssetType.RULE.value,
code=rule_code,
name=str(payload["name"]),
description=str(payload["description"]),
domain=domain,
scenario_json=[str(payload.get("risk_category") or BUSINESS_DOMAIN_LABELS[domain])],
owner=actor,
reviewer=None,
status=AgentAssetStatus.DRAFT.value,
current_version="v0.1.0",
published_version=None,
working_version="v0.1.0",
config_json={
"severity": risk_level,
"enabled": True,
"tag": "风险规则",
"detail_mode": "json_risk",
"risk_category": payload.get("risk_category"),
"rule_library": RISK_RULES_LIBRARY,
"rule_document": {
"file_name": file_name,
"storage_key": f"rules/{RISK_RULES_LIBRARY}/{file_name}",
},
"ontology_signal": payload.get("ontology_signal"),
"evaluator": payload.get("evaluator"),
"generated_by": "natural_language",
"source_ref": "自然语言风险规则",
},
)
self.db.add(asset)
self.db.flush()
self.db.add(
AgentAssetVersion(
asset_id=asset.id,
version="v0.1.0",
content=self._build_version_markdown(payload),
content_type="markdown",
change_note="通过自然语言新建风险规则草稿。",
created_by=actor,
)
)
self.audit_service.log_action(
actor=actor,
action="generate_agent_asset_risk_rule",
resource_type=AgentAssetType.RULE.value,
resource_id=asset.id,
before_json=None,
after_json={"rule_code": rule_code, "risk_level": risk_level, "domain": domain},
request_id=request_id,
)
self.db.refresh(asset)
return asset.id
def _compile_with_model(
self,
*,
natural_language: str,
domain: str,
risk_level: str,
fields: list[RiskRuleField],
) -> dict[str, Any] | None:
field_payload = [
{
"key": item.key,
"label": item.label,
"type": item.field_type,
"source": item.source,
}
for item in fields
]
messages = [
{
"role": "system",
"content": (
"你是 X-Financial 风险规则编译器。只能输出 JSON 对象,不要解释。"
"必须从给定字段本体中选择字段,不允许编造字段。"
"template_key 只能是 field_required_v1、field_compare_v1、keyword_match_v1。"
),
},
{
"role": "user",
"content": json.dumps(
{
"business_domain": domain,
"business_domain_label": BUSINESS_DOMAIN_LABELS[domain],
"risk_level": risk_level,
"risk_level_label": RISK_LEVEL_LABELS[risk_level],
"natural_language": natural_language,
"available_fields": field_payload,
"required_json_shape": {
"name": "规则名称",
"description": "面向业务用户的说明",
"template_key": "field_required_v1",
"field_keys": ["claim.reason"],
"condition_summary": "判断依据",
"keywords": [],
"flow": {
"start": "提交业务单据",
"evidence": "读取字段",
"decision": "判断依据",
"pass": "继续流转",
"fail": "提示风险",
},
},
},
ensure_ascii=False,
),
},
]
answer = self.runtime_chat_service.complete(
messages,
max_tokens=700,
temperature=0.1,
timeout_seconds=12,
max_attempts=1,
)
if not answer:
return None
try:
payload = json.loads(self._extract_json_object(answer))
except (json.JSONDecodeError, ValueError):
return None
if not isinstance(payload, dict):
return None
return self._sanitize_model_draft(payload, fields=fields)
def _sanitize_model_draft(
self,
payload: dict[str, Any],
*,
fields: list[RiskRuleField],
) -> dict[str, Any]:
allowed_fields = {item.key for item in fields}
template_key = str(payload.get("template_key") or "").strip()
if template_key not in {"field_required_v1", "field_compare_v1", "keyword_match_v1"}:
template_key = "field_required_v1"
raw_field_keys = payload.get("field_keys")
field_keys = [
str(item or "").strip()
for item in (raw_field_keys if isinstance(raw_field_keys, list) else [])
if str(item or "").strip() in allowed_fields
]
if not field_keys and fields:
field_keys = [fields[0].key]
keywords = [
str(item or "").strip()
for item in (
payload.get("keywords") if isinstance(payload.get("keywords"), list) else []
)
if str(item or "").strip()
]
flow = payload.get("flow") if isinstance(payload.get("flow"), dict) else {}
return {
"name": self._clean_text(payload.get("name"))[:80],
"description": self._clean_text(payload.get("description")),
"template_key": template_key,
"field_keys": field_keys,
"condition_summary": self._clean_text(payload.get("condition_summary")),
"keywords": keywords[:12],
"flow": {
"start": self._clean_text(flow.get("start")) or "提交业务单据",
"evidence": self._clean_text(flow.get("evidence")) or "读取规则字段",
"decision": self._clean_text(flow.get("decision")) or "判断是否命中风险",
"pass": self._clean_text(flow.get("pass")) or "继续流转",
"fail": self._clean_text(flow.get("fail")) or "提示风险并进入复核",
},
}
def _build_fallback_draft(
self,
*,
natural_language: str,
domain: str,
risk_level: str,
fields: list[RiskRuleField],
) -> dict[str, Any]:
field_keys = [item.key for item in fields[:4]]
template_key = self._infer_template_key(natural_language)
condition_summary = self._build_condition_summary(
natural_language,
template_key=template_key,
fields=fields,
)
name = self._infer_rule_name(natural_language)
description = (
f"{BUSINESS_DOMAIN_LABELS[domain]}业务满足“{natural_language}”时,系统会按"
f"{RISK_LEVEL_LABELS[risk_level]}进行提示,并要求经办人或审核人补充核对依据。"
)
return {
"name": name,
"description": description,
"template_key": template_key,
"field_keys": field_keys,
"condition_summary": condition_summary,
"keywords": self._infer_keywords(natural_language),
"flow": {
"start": f"{BUSINESS_DOMAIN_LABELS[domain]}单据提交",
"evidence": "读取" + "".join(item.label for item in fields[:3]),
"decision": condition_summary,
"pass": "未命中风险,继续业务流转",
"fail": f"命中{RISK_LEVEL_LABELS[risk_level]},提示复核",
},
}
def _build_rule_payload(
self,
draft: dict[str, Any],
*,
natural_language: str,
domain: str,
risk_level: str,
fields: list[RiskRuleField],
created_at: datetime,
actor: str,
) -> dict[str, Any]:
created_stamp = created_at.strftime("%Y%m%d%H%M%S")
domain_slug = {"expense": "expense", "ar": "ar", "ap": "ap"}[domain]
rule_code = f"risk.{domain_slug}.generated_{created_stamp}"
template_key = str(draft.get("template_key") or "field_required_v1").strip()
field_keys = [
str(item or "").strip()
for item in list(draft.get("field_keys") or [])
if str(item or "").strip()
]
condition_summary = (
self._clean_text(draft.get("condition_summary")) or "判断是否符合自然语言规则描述"
)
risk_category = BUSINESS_DOMAIN_LABELS[domain]
keywords = list(draft.get("keywords") or [])
field_by_key = {item.key: item for item in fields}
params: dict[str, Any] = {
"template_key": template_key,
"field_keys": field_keys,
"condition_summary": condition_summary,
"natural_language": natural_language,
}
if template_key == "field_required_v1":
params["required_fields"] = field_keys
if template_key == "field_compare_v1":
params["conditions"] = self._build_compare_conditions(field_keys)
if template_key == "keyword_match_v1":
params["keywords"] = keywords
params["search_fields"] = field_keys
payload = {
"schema_version": "2.0",
"rule_code": rule_code,
"name": self._clean_text(draft.get("name")) or self._infer_rule_name(natural_language),
"description": self._clean_text(draft.get("description")) or natural_language,
"enabled": True,
"risk_dimension": "natural_language_rule",
"risk_category": risk_category,
"ontology_signal": "natural_language_risk",
"evaluator": "template_rule",
"template_key": template_key,
"applies_to": {"domains": [domain]},
"inputs": {
"fields": [
{
"key": item.key,
"label": item.label,
"type": item.field_type,
"source": item.source,
}
for item in [field_by_key[key] for key in field_keys if key in field_by_key]
],
},
"params": params,
"outcomes": {
"pass": {"severity": "none", "action": "continue"},
"fail": {
"severity": risk_level,
"action": "manual_review",
},
},
"metadata": {
"owner": actor,
"stability": "generated_draft",
"source_ref": "自然语言风险规则",
"created_at": created_at.isoformat(),
"created_by": actor,
"natural_language": natural_language,
"business_explanation": self._clean_text(draft.get("description")),
"condition_summary": condition_summary,
"flow": draft.get("flow") if isinstance(draft.get("flow"), dict) else {},
},
}
payload["flow_diagram_svg"] = self._build_flow_diagram_svg(
payload,
fields=[field_by_key[key] for key in field_keys if key in field_by_key],
domain=domain,
risk_level=risk_level,
)
return payload
def _build_flow_diagram_svg(
self,
payload: dict[str, Any],
*,
fields: list[RiskRuleField],
domain: str,
risk_level: str,
) -> str:
metadata = payload.get("metadata") if isinstance(payload.get("metadata"), dict) else {}
flow = metadata.get("flow") if isinstance(metadata.get("flow"), dict) else {}
condition_summary = self._clean_text(metadata.get("condition_summary"))
return self.flow_diagram_renderer.render(
RiskRuleFlowDiagramSpec(
title=self._clean_text(payload.get("name")) or "风险规则判断流程",
domain_label=BUSINESS_DOMAIN_LABELS.get(domain, "业务"),
severity=risk_level,
severity_label=RISK_LEVEL_LABELS.get(risk_level, "中风险"),
fields=tuple(
RiskRuleFlowDiagramField(key=field.key, label=field.label) for field in fields
),
start=self._clean_text(flow.get("start")) or "业务单据提交",
evidence=self._clean_text(flow.get("evidence")) or "读取规则字段",
decision=self._clean_text(flow.get("decision"))
or condition_summary
or "判断是否命中风险",
basis=(
condition_summary
or self._clean_text(flow.get("decision"))
or "根据规则字段判断"
),
pass_text=self._clean_text(flow.get("pass")) or "未命中风险,继续流转",
fail_text=self._clean_text(flow.get("fail"))
or f"命中{RISK_LEVEL_LABELS.get(risk_level, '风险')},进入人工复核",
)
)
def _resolve_fields(self, text: str, *, domain: str) -> list[RiskRuleField]:
prefixes = DOMAIN_FIELD_PREFIXES.get(domain, ())
candidates = [field for field in FIELD_ONTOLOGY if field.key.startswith(prefixes)]
normalized = text.lower()
matched: list[tuple[int, RiskRuleField]] = []
for field in candidates:
score = self._score_field_match(field, text, normalized)
if score > 0:
matched.append((score, field))
if domain == AgentAssetDomain.EXPENSE.value:
if any(keyword in text for keyword in ("住宿", "酒店", "行程", "城市", "出差")):
matched.extend(
(10, field)
for field in candidates
if field.key
in {"claim.location", "attachment.hotel_city", "attachment.route_cities"}
)
if any(keyword in text for keyword in ("发票", "票据", "品名", "抬头", "开票")):
matched.extend(
(6, field)
for field in candidates
if field.key
in {
"attachment.invoice_no",
"attachment.buyer_name",
"attachment.goods_name",
"attachment.ocr_text",
}
)
matched.sort(key=lambda item: item[0], reverse=True)
deduped: list[RiskRuleField] = []
seen: set[str] = set()
for _, field in matched:
if field.key in seen:
continue
seen.add(field.key)
deduped.append(field)
if deduped:
return deduped[:8]
return candidates[:4]
@staticmethod
def _score_field_match(field: RiskRuleField, text: str, normalized: str) -> int:
score = 0
if field.label in text:
score += 8
for alias in field.aliases:
if alias.lower() in normalized:
score += 4 + min(len(alias), 6)
if field.key == "attachment.hotel_city" and any(term in text for term in ("酒店", "住宿")):
score += 12
if field.key == "attachment.route_cities" and any(
term in text for term in ("行程", "交通票", "路线", "途经")
):
score += 10
if field.key == "claim.location" and any(
term in text for term in ("申报目的地", "申报地点", "目的地", "出差地")
):
score += 10
if field.key.startswith("attachment.") and any(term in text for term in ("发票", "票据")):
score += 2
return score
def _align_draft_fields(
self,
draft: dict[str, Any],
*,
natural_language: str,
fields: list[RiskRuleField],
) -> dict[str, Any]:
field_by_key = {field.key: field for field in fields}
original_keys = [
str(item or "").strip()
for item in list(draft.get("field_keys") or [])
if str(item or "").strip() in field_by_key
]
preferred_keys: list[str] = []
def add_preferred(key: str, *terms: str) -> None:
if key in field_by_key and any(term in natural_language for term in terms):
preferred_keys.append(key)
add_preferred("attachment.hotel_city", "酒店", "住宿")
add_preferred("claim.location", "申报目的地", "申报地点", "目的地", "出差地")
add_preferred("attachment.route_cities", "行程", "交通票", "路线", "途经")
merged_keys: list[str] = []
for key in [*preferred_keys, *original_keys, *[field.key for field in fields]]:
if key in field_by_key and key not in merged_keys:
merged_keys.append(key)
if len(merged_keys) >= 4:
break
if draft.get("template_key") == "field_compare_v1" and len(merged_keys) < 2:
for field in fields:
if field.key not in merged_keys:
merged_keys.append(field.key)
if len(merged_keys) >= 2:
break
aligned = {**draft, "field_keys": merged_keys}
selected_fields = [field_by_key[key] for key in merged_keys if key in field_by_key]
if selected_fields:
aligned["condition_summary"] = self._build_condition_summary(
natural_language,
template_key=str(aligned.get("template_key") or "field_required_v1"),
fields=selected_fields,
)
flow = aligned.get("flow") if isinstance(aligned.get("flow"), dict) else {}
aligned["flow"] = {
**flow,
"evidence": "读取" + "".join(field.label for field in selected_fields[:3]),
"decision": aligned["condition_summary"],
}
return aligned
@staticmethod
def _build_compare_conditions(field_keys: list[str]) -> list[dict[str, str]]:
if len(field_keys) >= 2:
return [{"left": field_keys[0], "operator": "overlap", "right": field_keys[1]}]
if field_keys:
return [{"left": field_keys[0], "operator": "is_empty", "right": ""}]
return []
@staticmethod
def _infer_template_key(text: str) -> str:
if any(
keyword in text
for keyword in ("一致", "匹配", "相同", "不一致", "不符", "对应", "出现在")
):
return "field_compare_v1"
if any(
keyword in text
for keyword in ("关键词", "包含", "出现", "品名", "摘要", "服务费", "咨询费")
):
return "keyword_match_v1"
return "field_required_v1"
@staticmethod
def _infer_keywords(text: str) -> list[str]:
quoted = re.findall(r"[“\"']([^“”\"']{2,20})[”\"']", text)
keywords = [item.strip() for item in quoted if item.strip()]
for candidate in ("咨询费", "服务费", "其他", "办公用品", "招待", "红冲", "作废"):
if candidate in text and candidate not in keywords:
keywords.append(candidate)
return keywords[:8]
@staticmethod
def _infer_rule_name(text: str) -> str:
normalized = re.sub(r"\s+", "", str(text or ""))
normalized = re.sub(r"[,。;;:、,.!?]", "", normalized)
if not normalized:
return "自然语言风险规则"
return f"{normalized[:18]}风险规则"
@staticmethod
def _build_condition_summary(
natural_language: str,
*,
template_key: str,
fields: list[RiskRuleField],
) -> str:
field_text = "".join(item.label for item in fields[:3]) or "业务字段"
if template_key == "field_compare_v1":
return f"对比{field_text}之间是否一致或存在交集"
if template_key == "keyword_match_v1":
return f"检查{field_text}是否出现规则描述中的风险关键词"
return f"检查{field_text}是否满足必填和完整性要求"
@staticmethod
def _clean_text(value: Any) -> str:
return re.sub(r"\s+", " ", str(value or "")).strip()
@staticmethod
def _extract_json_object(text: str) -> str:
normalized = re.sub(r"^```(?:json)?|```$", "", str(text or "").strip(), flags=re.IGNORECASE)
start = normalized.find("{")
end = normalized.rfind("}")
if start < 0 or end <= start:
raise ValueError("JSON object not found.")
return normalized[start : end + 1]
@staticmethod
def _build_version_markdown(payload: dict[str, Any]) -> str:
metadata = payload.get("metadata") if isinstance(payload.get("metadata"), dict) else {}
fields = (
payload.get("inputs", {}).get("fields")
if isinstance(payload.get("inputs"), dict)
else []
)
field_labels = [
str(item.get("label") or item.get("key") or "").strip()
for item in fields
if isinstance(item, dict) and str(item.get("label") or item.get("key") or "").strip()
]
return "\n".join(
[
f"# {payload.get('name')}",
"",
"## 业务说明",
"",
str(payload.get("description") or ""),
"",
"## 自然语言原文",
"",
str(metadata.get("natural_language") or ""),
"",
"## 使用字段",
"",
"".join(field_labels) or "未识别字段",
"",
"## 运行时 JSON",
"",
"```json",
json.dumps(payload, ensure_ascii=False, indent=2),
"```",
]
)

View File

@@ -0,0 +1,259 @@
from __future__ import annotations
import re
from typing import Any
from app.models.financial_record import ExpenseClaim
class RiskRuleTemplateExecutor:
def evaluate(
self,
manifest: dict[str, Any],
*,
claim: ExpenseClaim,
contexts: list[dict[str, Any]],
) -> dict[str, Any] | None:
params = manifest.get("params") if isinstance(manifest.get("params"), dict) else {}
template_key = str(manifest.get("template_key") or params.get("template_key") or "").strip()
if template_key == "field_required_v1":
return self._evaluate_required_fields(params, claim=claim, contexts=contexts)
if template_key == "field_compare_v1":
return self._evaluate_compare_conditions(params, claim=claim, contexts=contexts)
if template_key == "keyword_match_v1":
return self._evaluate_keyword_match(params, claim=claim, contexts=contexts)
return None
def _evaluate_required_fields(
self,
params: dict[str, Any],
*,
claim: ExpenseClaim,
contexts: list[dict[str, Any]],
) -> dict[str, Any] | None:
required_fields = self._read_string_list(
params.get("required_fields") or params.get("field_keys")
)
missing = [
field_key
for field_key in required_fields
if not self._has_resolved_value(field_key, claim=claim, contexts=contexts)
]
if not missing:
return None
return {
"message": self._resolve_message(
params,
fallback=f"规则要求的字段未完整提供:{''.join(missing[:4])}",
),
"evidence": {
"missing_fields": missing,
"condition_summary": params.get("condition_summary"),
},
}
def _evaluate_compare_conditions(
self,
params: dict[str, Any],
*,
claim: ExpenseClaim,
contexts: list[dict[str, Any]],
) -> dict[str, Any] | None:
conditions = params.get("conditions") if isinstance(params.get("conditions"), list) else []
failures: list[dict[str, Any]] = []
for condition in conditions:
if not isinstance(condition, dict):
continue
left_key = str(condition.get("left") or "").strip()
right_key = str(condition.get("right") or "").strip()
operator = str(condition.get("operator") or "not_overlap").strip()
left_values = self._resolve_values(left_key, claim=claim, contexts=contexts)
right_values = self._resolve_values(right_key, claim=claim, contexts=contexts)
if self._condition_passes(operator, left_values, right_values):
continue
failures.append(
{
"left": left_key,
"operator": operator,
"right": right_key,
"left_values": left_values[:5],
"right_values": right_values[:5],
}
)
if not failures:
return None
return {
"message": self._resolve_message(
params,
fallback=(
"规则字段对比未通过:"
f"{params.get('condition_summary') or '字段关系不符合要求'}"
),
),
"evidence": {
"failed_conditions": failures[:5],
"condition_summary": params.get("condition_summary"),
},
}
def _evaluate_keyword_match(
self,
params: dict[str, Any],
*,
claim: ExpenseClaim,
contexts: list[dict[str, Any]],
) -> dict[str, Any] | None:
keywords = self._read_string_list(params.get("keywords"))
search_fields = self._read_string_list(
params.get("search_fields") or params.get("field_keys")
)
if not keywords:
return None
corpus_parts: list[str] = []
for field_key in search_fields:
corpus_parts.extend(self._resolve_values(field_key, claim=claim, contexts=contexts))
if not corpus_parts:
corpus_parts.extend(
[
str(claim.reason or ""),
str(claim.location or ""),
*[str(item.item_reason or "") for item in list(claim.items or [])],
*[str(context.get("ocr_text") or "") for context in contexts],
]
)
corpus = "\n".join(corpus_parts)
hits = [keyword for keyword in keywords if keyword and keyword in corpus]
if not hits:
return None
return {
"message": self._resolve_message(
params,
fallback=f"识别到风险关键词:{''.join(hits[:5])}",
),
"evidence": {
"keyword_hits": hits[:8],
"search_fields": search_fields,
"condition_summary": params.get("condition_summary"),
},
}
def _resolve_values(
self,
field_key: str,
*,
claim: ExpenseClaim,
contexts: list[dict[str, Any]],
) -> list[str]:
normalized = str(field_key or "").strip()
if not normalized:
return []
if normalized.startswith("claim."):
return self._normalize_values([getattr(claim, normalized.removeprefix("claim."), "")])
if normalized.startswith("item."):
attr = normalized.removeprefix("item.")
return self._normalize_values(
[getattr(item, attr, "") for item in list(claim.items or [])]
)
if normalized.startswith("attachment."):
return self._resolve_attachment_values(normalized.removeprefix("attachment."), contexts)
return []
def _resolve_attachment_values(
self, field_key: str, contexts: list[dict[str, Any]]
) -> list[str]:
values: list[Any] = []
for context in contexts:
document_info = context.get("document_info") if isinstance(context, dict) else {}
if not isinstance(document_info, dict):
document_info = {}
if field_key == "ocr_text":
values.extend([context.get("ocr_text"), context.get("ocr_summary")])
if field_key in {"hotel_city", "route_cities"}:
values.extend(self._scan_document_values(document_info, field_key))
values.extend(self._scan_document_values(document_info, "city"))
else:
values.extend(self._scan_document_values(document_info, field_key))
return self._normalize_values(values)
def _scan_document_values(self, document_info: dict[str, Any], field_key: str) -> list[Any]:
values: list[Any] = []
for key in {field_key, field_key.replace("_", ""), field_key.replace("_", "-")}:
if key in document_info:
values.append(document_info.get(key))
for field in list(document_info.get("fields") or []):
if not isinstance(field, dict):
continue
key = str(field.get("key") or "").strip().lower()
label = str(field.get("label") or "").strip()
if self._field_matches(key, label, field_key):
values.append(field.get("value"))
return values
@staticmethod
def _field_matches(key: str, label: str, field_key: str) -> bool:
compact_key = key.replace("_", "")
compact_target = field_key.replace("_", "")
if compact_target in compact_key:
return True
label_map = {
"invoice_no": ("发票号", "发票号码", "票号"),
"buyer_name": ("购买方", "抬头", "买方"),
"goods_name": ("品名", "商品", "服务名称"),
"issue_date": ("日期", "开票日期", "发票日期"),
"hotel_city": ("住宿城市", "酒店城市", "酒店地点"),
"route_cities": ("行程", "路线", "城市"),
"city": ("城市", "地点"),
}
return any(item in label for item in label_map.get(field_key, ()))
def _has_resolved_value(
self,
field_key: str,
*,
claim: ExpenseClaim,
contexts: list[dict[str, Any]],
) -> bool:
return bool(self._resolve_values(field_key, claim=claim, contexts=contexts))
@staticmethod
def _condition_passes(operator: str, left_values: list[str], right_values: list[str]) -> bool:
if operator == "is_empty":
return not left_values
if not left_values or not right_values:
return False
left_set = {value.lower() for value in left_values}
right_set = {value.lower() for value in right_values}
if operator in {"equals", "in", "overlap"}:
return bool(left_set & right_set)
if operator in {"not_equals", "not_in", "not_overlap"}:
return not bool(left_set & right_set)
if operator == "contains_any":
return any(any(right in left for right in right_set) for left in left_set)
return bool(left_set & right_set)
@staticmethod
def _normalize_values(values: list[Any]) -> list[str]:
normalized: list[str] = []
for value in values:
if isinstance(value, (list, tuple, set)):
normalized.extend(RiskRuleTemplateExecutor._normalize_values(list(value)))
continue
text = re.sub(r"\s+", " ", str(value or "")).strip()
if text and text not in normalized:
normalized.append(text)
return normalized
@staticmethod
def _read_string_list(value: Any) -> list[str]:
if not isinstance(value, list):
return []
return [str(item or "").strip() for item in value if str(item or "").strip()]
@staticmethod
def _resolve_message(params: dict[str, Any], *, fallback: str) -> str:
template = str(params.get("message_template") or "").strip()
return template or fallback