feat: 新增风险规则生成引擎与知识图谱可视化
后端新增风险规则自动生成和模板执行服务,支持从规则资产 批量生成并持久化风险规则文件;知识库入库日志增强图谱 查询和本地 RAG 回退,前端审计页面增加风险规则模型和流 程图组件,知识入库面板拆分为图谱可视化子组件,报销创 建页面增加引导式流程模型,更新知识库索引数据。
This commit is contained in:
@@ -42,6 +42,7 @@ class AgentAssetJsonRuleMixin:
|
||||
description=str(payload.get("description") or asset.description or "").strip(),
|
||||
evaluator=str(payload.get("evaluator") or ""),
|
||||
ontology_signal=str(payload.get("ontology_signal") or "") or None,
|
||||
flow_diagram_svg=str(payload.get("flow_diagram_svg") or "") or None,
|
||||
inputs=payload.get("inputs") if isinstance(payload.get("inputs"), dict) else {},
|
||||
outcomes=payload.get("outcomes") if isinstance(payload.get("outcomes"), dict) else {},
|
||||
payload=payload,
|
||||
@@ -95,4 +96,3 @@ class AgentAssetJsonRuleMixin:
|
||||
)
|
||||
self.db.commit()
|
||||
return self.read_rule_json(asset_id)
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -31,6 +33,8 @@ from app.services.agent_foundation_risk_rules import AgentFoundationRiskRuleMixi
|
||||
from app.services.agent_foundation_spreadsheets import AgentFoundationSpreadsheetMixin
|
||||
|
||||
logger = get_logger("app.services.agent_foundation")
|
||||
_foundation_ready_lock = threading.RLock()
|
||||
_foundation_ready_keys: set[str] = set()
|
||||
|
||||
|
||||
def prepare_agent_foundation() -> None:
|
||||
@@ -57,6 +61,17 @@ class AgentFoundationService(
|
||||
self.db = db
|
||||
|
||||
def ensure_foundation_ready(self) -> None:
|
||||
cache_key = self._foundation_cache_key()
|
||||
if cache_key in _foundation_ready_keys:
|
||||
return
|
||||
|
||||
with _foundation_ready_lock:
|
||||
if cache_key in _foundation_ready_keys:
|
||||
return
|
||||
self._prepare_foundation()
|
||||
_foundation_ready_keys.add(cache_key)
|
||||
|
||||
def _prepare_foundation(self) -> None:
|
||||
try:
|
||||
Base.metadata.create_all(bind=self.db.get_bind())
|
||||
self._ensure_agent_asset_schema()
|
||||
@@ -69,6 +84,10 @@ class AgentFoundationService(
|
||||
logger.exception("Failed to prepare agent foundation")
|
||||
raise
|
||||
|
||||
def _foundation_cache_key(self) -> str:
|
||||
bind = self.db.get_bind()
|
||||
return str(getattr(bind, "url", "") or id(bind))
|
||||
|
||||
def _sync_demo_financial_records(self) -> None:
|
||||
if get_settings().seed_demo_financial_records:
|
||||
self._seed_financial_records()
|
||||
|
||||
@@ -6,12 +6,14 @@ from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.agent_enums import AgentName, AgentPermissionLevel, AgentRunStatus
|
||||
from app.core.logging import get_logger
|
||||
from app.models.agent_run import AgentRun, AgentToolCall, SemanticParseLog
|
||||
from app.repositories.agent_run import AgentRunRepository
|
||||
from app.schemas.agent_run import AgentRunRead, AgentToolCallRead, SemanticParseRead
|
||||
from app.services.agent_foundation import AgentFoundationService
|
||||
from app.services.knowledge_ingest_log import enrich_knowledge_ingest_route_json
|
||||
|
||||
logger = get_logger("app.services.agent_runs")
|
||||
|
||||
@@ -42,7 +44,7 @@ class AgentRunService:
|
||||
run = self.repository.get_by_run_id(run_id)
|
||||
if run is None:
|
||||
return None
|
||||
return self._serialize_run(run)
|
||||
return self._serialize_run(run, enrich_knowledge_ingest=True)
|
||||
|
||||
def create_run(
|
||||
self,
|
||||
@@ -314,9 +316,19 @@ class AgentRunService:
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _serialize_run(run: AgentRun) -> AgentRunRead:
|
||||
def _serialize_run(
|
||||
self,
|
||||
run: AgentRun,
|
||||
*,
|
||||
enrich_knowledge_ingest: bool = False,
|
||||
) -> AgentRunRead:
|
||||
semantic_parse = run.semantic_parse_logs[0] if run.semantic_parse_logs else None
|
||||
route_json = run.route_json
|
||||
if enrich_knowledge_ingest:
|
||||
route_json = enrich_knowledge_ingest_route_json(
|
||||
dict(run.route_json or {}),
|
||||
storage_root=get_settings().resolved_storage_root_dir,
|
||||
)
|
||||
return AgentRunRead(
|
||||
id=run.id,
|
||||
run_id=run.run_id,
|
||||
@@ -325,7 +337,7 @@ class AgentRunService:
|
||||
user_id=run.user_id,
|
||||
task_id=run.task_id,
|
||||
ontology_json=run.ontology_json,
|
||||
route_json=run.route_json,
|
||||
route_json=route_json,
|
||||
permission_level=run.permission_level,
|
||||
status=run.status,
|
||||
result_summary=run.result_summary,
|
||||
|
||||
@@ -1,36 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from datetime import UTC, date, datetime, timedelta
|
||||
from decimal import Decimal
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import or_, select
|
||||
from sqlalchemy import inspect as sqlalchemy_inspect
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.api.deps import CurrentUserContext
|
||||
from app.core.agent_enums import AgentAssetDomain, AgentAssetStatus, AgentAssetType
|
||||
from app.models.agent_asset import AgentAsset
|
||||
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
|
||||
from app.schemas.reimbursement import TravelReimbursementCalculatorRequest
|
||||
from app.services.agent_asset_rule_library import AgentAssetRuleLibraryManager
|
||||
from app.services.agent_asset_spreadsheet import RISK_RULES_LIBRARY
|
||||
from app.services.expense_claim_constants import (
|
||||
AI_REVIEW_LOOKBACK_DAYS,
|
||||
AI_REVIEW_REPEAT_RISK_BLOCK_COUNT,
|
||||
AI_REVIEW_REPEAT_RISK_WARNING_COUNT,
|
||||
DOCUMENT_FACT_ITEM_TYPES,
|
||||
LOCATION_REQUIRED_EXPENSE_TYPES,
|
||||
SYSTEM_GENERATED_ITEM_TYPES,
|
||||
TRAVEL_ALLOWANCE_TRIGGER_ITEM_TYPES,
|
||||
TRAVEL_POLICY_HOTEL_NIGHT_PATTERN,
|
||||
)
|
||||
from app.services.expense_rule_runtime import (
|
||||
ExpenseRuleRuntimeService,
|
||||
RuntimeTravelPolicy,
|
||||
build_default_expense_rule_catalog,
|
||||
)
|
||||
from app.services.risk_rule_template_executor import RiskRuleTemplateExecutor
|
||||
|
||||
|
||||
class ExpenseClaimPlatformRiskMixin:
|
||||
@@ -66,9 +49,7 @@ class ExpenseClaimPlatformRiskMixin:
|
||||
if severity == "high" or action == "block":
|
||||
blocking_reasons.append(str(flag.get("message") or flag.get("label") or "").strip())
|
||||
|
||||
deduplicated_reasons = list(
|
||||
dict.fromkeys(reason for reason in blocking_reasons if reason)
|
||||
)
|
||||
deduplicated_reasons = list(dict.fromkeys(reason for reason in blocking_reasons if reason))
|
||||
return {"flags": flags, "blocking_reasons": deduplicated_reasons}
|
||||
|
||||
def _load_platform_risk_rule_manifests(
|
||||
@@ -77,9 +58,7 @@ class ExpenseClaimPlatformRiskMixin:
|
||||
rule_codes: list[str] | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
code_filter = {
|
||||
str(code or "").strip()
|
||||
for code in list(rule_codes or [])
|
||||
if str(code or "").strip()
|
||||
str(code or "").strip() for code in list(rule_codes or []) if str(code or "").strip()
|
||||
}
|
||||
manifests_by_code: dict[str, dict[str, Any]] = {}
|
||||
|
||||
@@ -224,12 +203,10 @@ class ExpenseClaimPlatformRiskMixin:
|
||||
normalized_contexts.append(
|
||||
{
|
||||
"scene_code": str(document_info.get("scene_code") or "").strip().lower(),
|
||||
"document_type": str(
|
||||
document_info.get("document_type") or ""
|
||||
).strip().lower(),
|
||||
"item_type": str(
|
||||
getattr(context.get("item"), "item_type", "") or ""
|
||||
).strip().lower(),
|
||||
"document_type": str(document_info.get("document_type") or "").strip().lower(),
|
||||
"item_type": str(getattr(context.get("item"), "item_type", "") or "")
|
||||
.strip()
|
||||
.lower(),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -312,6 +289,19 @@ class ExpenseClaimPlatformRiskMixin:
|
||||
claim=claim,
|
||||
contexts=contexts,
|
||||
)
|
||||
if evaluator == "template_rule":
|
||||
result = RiskRuleTemplateExecutor().evaluate(
|
||||
manifest,
|
||||
claim=claim,
|
||||
contexts=contexts,
|
||||
)
|
||||
if result is None:
|
||||
return None
|
||||
return self._build_platform_risk_flag(
|
||||
manifest,
|
||||
message=str(result.get("message") or "自然语言风险规则命中。"),
|
||||
evidence=result.get("evidence") if isinstance(result.get("evidence"), dict) else {},
|
||||
)
|
||||
return None
|
||||
|
||||
def _evaluate_reason_too_brief_risk(
|
||||
@@ -347,9 +337,7 @@ class ExpenseClaimPlatformRiskMixin:
|
||||
reason_corpus = self._build_scene_reason_corpus(claim)
|
||||
compact_reason = re.sub(r"\s+", "", reason_corpus)
|
||||
looks_like_entertainment = (
|
||||
"entertainment" in expense_types
|
||||
or "招待" in compact_reason
|
||||
or "客户" in compact_reason
|
||||
"entertainment" in expense_types or "招待" in compact_reason or "客户" in compact_reason
|
||||
)
|
||||
if not looks_like_entertainment:
|
||||
return None
|
||||
@@ -374,32 +362,28 @@ class ExpenseClaimPlatformRiskMixin:
|
||||
for context in contexts:
|
||||
item = context["item"]
|
||||
item_type = (
|
||||
str(item.item_type or claim.expense_type or "other").strip().lower()
|
||||
or "other"
|
||||
str(item.item_type or claim.expense_type or "other").strip().lower() or "other"
|
||||
)
|
||||
policy = self._get_expense_scene_policy(item_type)
|
||||
if policy is None:
|
||||
continue
|
||||
document_info = context.get("document_info") or {}
|
||||
recognized_scene_code = (
|
||||
str(document_info.get("scene_code") or "other").strip().lower()
|
||||
or "other"
|
||||
str(document_info.get("scene_code") or "other").strip().lower() or "other"
|
||||
)
|
||||
recognized_document_type = (
|
||||
str(document_info.get("document_type") or "other").strip().lower()
|
||||
or "other"
|
||||
str(document_info.get("document_type") or "other").strip().lower() or "other"
|
||||
)
|
||||
if (
|
||||
recognized_scene_code in set(policy.allowed_scene_codes)
|
||||
or recognized_document_type in set(policy.allowed_document_types)
|
||||
):
|
||||
if recognized_scene_code in set(
|
||||
policy.allowed_scene_codes
|
||||
) or recognized_document_type in set(policy.allowed_document_types):
|
||||
continue
|
||||
recognized_label = str(
|
||||
document_info.get("document_type_label")
|
||||
or recognized_document_type
|
||||
or "未知票据"
|
||||
document_info.get("document_type_label") or recognized_document_type or "未知票据"
|
||||
)
|
||||
mismatches.append(
|
||||
f"第 {context['index']} 条明细为{policy.label},附件识别为{recognized_label}"
|
||||
)
|
||||
mismatches.append(f"第 {context['index']} 条明细为{policy.label},附件识别为{recognized_label}")
|
||||
|
||||
if not mismatches:
|
||||
return None
|
||||
@@ -437,7 +421,10 @@ class ExpenseClaimPlatformRiskMixin:
|
||||
evidence_text = "、".join(evidence_cities[:5])
|
||||
return self._build_platform_risk_flag(
|
||||
manifest,
|
||||
message=f"申报地点 {declared_text} 与票据识别地点 {evidence_text} 不一致,建议补充异地说明或更换附件。",
|
||||
message=(
|
||||
f"申报地点 {declared_text} 与票据识别地点 {evidence_text} 不一致,"
|
||||
"建议补充异地说明或更换附件。"
|
||||
),
|
||||
evidence={"declared_cities": declared_cities, "evidence_cities": evidence_cities},
|
||||
)
|
||||
|
||||
@@ -450,9 +437,7 @@ class ExpenseClaimPlatformRiskMixin:
|
||||
) -> dict[str, Any] | None:
|
||||
invoice_keys = self._collect_invoice_keys_from_contexts(contexts)
|
||||
duplicate_keys = [
|
||||
key
|
||||
for key, count in self._count_values(invoice_keys).items()
|
||||
if count > 1
|
||||
key for key, count in self._count_values(invoice_keys).items() if count > 1
|
||||
]
|
||||
if duplicate_keys:
|
||||
return self._build_platform_risk_flag(
|
||||
@@ -504,9 +489,7 @@ class ExpenseClaimPlatformRiskMixin:
|
||||
) -> dict[str, Any] | None:
|
||||
params = manifest.get("params") if isinstance(manifest.get("params"), dict) else {}
|
||||
allow_keywords = [
|
||||
str(value)
|
||||
for value in list(params.get("allow_keywords") or [])
|
||||
if str(value).strip()
|
||||
str(value) for value in list(params.get("allow_keywords") or []) if str(value).strip()
|
||||
]
|
||||
claimant = str(claim.employee_name or "").strip()
|
||||
if not claimant:
|
||||
@@ -564,7 +547,10 @@ class ExpenseClaimPlatformRiskMixin:
|
||||
return None
|
||||
return self._build_platform_risk_flag(
|
||||
manifest,
|
||||
message=f"票据年份 {mismatch_years[0]} 与费用发生年份 {claim_year} 不一致,建议确认是否跨年报销。",
|
||||
message=(
|
||||
f"票据年份 {mismatch_years[0]} 与费用发生年份 {claim_year} 不一致,"
|
||||
"建议确认是否跨年报销。"
|
||||
),
|
||||
evidence={"claim_year": claim_year, "invoice_years": mismatch_years},
|
||||
)
|
||||
|
||||
|
||||
@@ -480,6 +480,7 @@ def _build_initial_knowledge_ingest_document(
|
||||
"entity_count": 0,
|
||||
"relation_count": 0,
|
||||
"entities": [],
|
||||
"entity_chunks": [],
|
||||
"relations": [],
|
||||
"events": [
|
||||
{
|
||||
@@ -677,7 +678,7 @@ def _build_ingest_graph(knowledge_ingest: dict[str, Any]) -> dict[str, Any]:
|
||||
documents = [
|
||||
item for item in list(knowledge_ingest.get("documents") or []) if isinstance(item, dict)
|
||||
]
|
||||
entities = _dedupe_text_items(
|
||||
entities = _dedupe_entities(
|
||||
entity for document in documents for entity in list(document.get("entities") or [])
|
||||
)
|
||||
relations = _dedupe_relations(
|
||||
@@ -692,20 +693,52 @@ def _build_ingest_graph(knowledge_ingest: dict[str, Any]) -> dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def _dedupe_text_items(items: Any) -> list[str]:
|
||||
deduped: list[str] = []
|
||||
def _dedupe_entities(items: Any) -> list[dict[str, Any]]:
|
||||
deduped: list[dict[str, Any]] = []
|
||||
seen: set[str] = set()
|
||||
for item in items:
|
||||
text = str(item or "").strip()
|
||||
if not text or text in seen:
|
||||
if isinstance(item, dict):
|
||||
name = str(
|
||||
item.get("name")
|
||||
or item.get("entity")
|
||||
or item.get("entity_id")
|
||||
or item.get("title")
|
||||
or item.get("id")
|
||||
or ""
|
||||
).strip()
|
||||
entity = dict(item)
|
||||
else:
|
||||
name = str(item or "").strip()
|
||||
entity = {}
|
||||
if not name or name in seen:
|
||||
continue
|
||||
seen.add(text)
|
||||
deduped.append(text)
|
||||
seen.add(name)
|
||||
entity["name"] = name
|
||||
entity["type"] = str(
|
||||
entity.get("type")
|
||||
or entity.get("entity_type")
|
||||
or entity.get("category")
|
||||
or entity.get("kind")
|
||||
or "实体"
|
||||
).strip()
|
||||
description = str(entity.get("description") or "").strip()
|
||||
descriptions = entity.get("descriptions")
|
||||
if not isinstance(descriptions, list):
|
||||
descriptions = [description] if description else []
|
||||
entity["description"] = description
|
||||
entity["descriptions"] = [
|
||||
str(description_item or "").strip()
|
||||
for description_item in descriptions
|
||||
if str(description_item or "").strip()
|
||||
][:5]
|
||||
if not isinstance(entity.get("properties"), dict):
|
||||
entity["properties"] = {}
|
||||
deduped.append(entity)
|
||||
return deduped
|
||||
|
||||
|
||||
def _dedupe_relations(items: Any) -> list[dict[str, str]]:
|
||||
deduped: list[dict[str, str]] = []
|
||||
def _dedupe_relations(items: Any) -> list[dict[str, Any]]:
|
||||
deduped: list[dict[str, Any]] = []
|
||||
seen: set[tuple[str, str, str]] = set()
|
||||
for item in items:
|
||||
if not isinstance(item, dict):
|
||||
@@ -717,7 +750,7 @@ def _dedupe_relations(items: Any) -> list[dict[str, str]]:
|
||||
if not source or not target or key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
deduped.append({"source": source, "target": target, "type": relation_type})
|
||||
deduped.append({**item, "source": source, "target": target, "type": relation_type})
|
||||
return deduped
|
||||
|
||||
|
||||
|
||||
@@ -1,15 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from xml.etree import ElementTree
|
||||
|
||||
MAX_INGEST_LOG_CHUNKS = 24
|
||||
MAX_INGEST_LOG_ENTITIES = 24
|
||||
MAX_INGEST_LOG_ENTITY_CHUNKS = 48
|
||||
MAX_INGEST_LOG_RELATIONS = 24
|
||||
MAX_INGEST_LOG_SECTIONS = 12
|
||||
MAX_INGEST_LOG_TEXT_PREVIEW = 180
|
||||
MAX_INGEST_LOG_ENTITY_DESCRIPTIONS = 5
|
||||
GRAPHML_NAMESPACE = {"graphml": "http://graphml.graphdrawing.org/xmlns"}
|
||||
GRAPH_PROPERTY_SEPARATOR = "<SEP>"
|
||||
|
||||
INGEST_SECTION_HEADING_PATTERN = re.compile(
|
||||
r"^(?:#{1,4}\s+.+|第[一二三四五六七八九十百零0-9]+[章节条]\s*.*)$"
|
||||
@@ -42,6 +48,7 @@ def build_ingest_document_summary(
|
||||
"entity_count": 0,
|
||||
"relation_count": 0,
|
||||
"entities": [],
|
||||
"entity_chunks": [],
|
||||
"relations": [],
|
||||
}
|
||||
|
||||
@@ -62,6 +69,33 @@ def build_ingest_status_summary(
|
||||
}
|
||||
|
||||
|
||||
def enrich_knowledge_ingest_route_json(
|
||||
route_json: dict[str, Any],
|
||||
*,
|
||||
storage_root: Path,
|
||||
) -> dict[str, Any]:
|
||||
if not isinstance(route_json, dict):
|
||||
return route_json
|
||||
ingest = route_json.get("knowledge_ingest")
|
||||
if not isinstance(ingest, dict):
|
||||
return route_json
|
||||
graph = ingest.get("graph")
|
||||
if not isinstance(graph, dict):
|
||||
return route_json
|
||||
|
||||
workspace = _resolve_lightrag_workspace(route_json)
|
||||
graph_snapshot = _load_lightrag_graph_snapshot(storage_root, workspace=workspace)
|
||||
if not graph_snapshot["entities"] and not graph_snapshot["relations"]:
|
||||
return route_json
|
||||
|
||||
next_route = dict(route_json)
|
||||
next_ingest = dict(ingest)
|
||||
next_graph = _enrich_graph_payload(graph, graph_snapshot)
|
||||
next_ingest["graph"] = next_graph
|
||||
next_route["knowledge_ingest"] = next_ingest
|
||||
return next_route
|
||||
|
||||
|
||||
def build_document_graph_summary(
|
||||
storage_root: Path,
|
||||
*,
|
||||
@@ -74,19 +108,264 @@ def build_document_graph_summary(
|
||||
entities_payload = _load_json_file(workspace_dir / "kv_store_full_entities.json")
|
||||
relations_payload = _load_json_file(workspace_dir / "kv_store_full_relations.json")
|
||||
chunks_payload = _load_json_file(workspace_dir / "kv_store_text_chunks.json")
|
||||
entity_chunks_payload = _load_json_file(workspace_dir / "kv_store_entity_chunks.json")
|
||||
graph_snapshot = _load_lightrag_graph_snapshot(storage_root, workspace=workspace)
|
||||
|
||||
entities = _normalize_document_entities(entities_payload, document_id)
|
||||
relations = _normalize_document_relations(relations_payload, document_id)
|
||||
chunks = _normalize_document_chunks(chunks_payload, document_id)
|
||||
entity_chunks = _normalize_document_entity_chunks(
|
||||
entity_chunks_payload,
|
||||
entities,
|
||||
chunk_ids={str(item.get("id") or "").strip() for item in chunks},
|
||||
)
|
||||
return {
|
||||
"entity_count": len(entities),
|
||||
"relation_count": len(relations),
|
||||
"entities": entities[:MAX_INGEST_LOG_ENTITIES],
|
||||
"relations": relations[:MAX_INGEST_LOG_RELATIONS],
|
||||
"entities": _enrich_entity_list(entities, graph_snapshot)[:MAX_INGEST_LOG_ENTITIES],
|
||||
"relations": _enrich_relation_list(relations, graph_snapshot)[:MAX_INGEST_LOG_RELATIONS],
|
||||
"chunks": chunks[:MAX_INGEST_LOG_CHUNKS],
|
||||
"entity_chunks": entity_chunks[:MAX_INGEST_LOG_ENTITY_CHUNKS],
|
||||
}
|
||||
|
||||
|
||||
def _resolve_lightrag_workspace(route_json: dict[str, Any]) -> str:
|
||||
explicit_workspace = str(
|
||||
route_json.get("lightrag_workspace") or route_json.get("workspace") or ""
|
||||
).strip()
|
||||
if explicit_workspace:
|
||||
return explicit_workspace
|
||||
return os.environ.get("LIGHTRAG_WORKSPACE", "x_financial_knowledge").strip() or "x_financial_knowledge"
|
||||
|
||||
|
||||
def _enrich_graph_payload(
|
||||
graph: dict[str, Any],
|
||||
graph_snapshot: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
next_graph = dict(graph)
|
||||
relation_items = _extract_relation_items(graph.get("relations"))
|
||||
relation_entity_names = [
|
||||
name
|
||||
for relation in relation_items
|
||||
for name in (relation.get("source"), relation.get("target"))
|
||||
]
|
||||
next_graph["entities"] = _enrich_entity_list(
|
||||
_dedupe_text_items(
|
||||
_extract_entity_names(graph.get("entities")) + relation_entity_names
|
||||
),
|
||||
graph_snapshot,
|
||||
)
|
||||
next_graph["relations"] = _enrich_relation_list(relation_items, graph_snapshot)
|
||||
return next_graph
|
||||
|
||||
|
||||
def _enrich_entity_list(
|
||||
entity_names: list[str],
|
||||
graph_snapshot: dict[str, Any],
|
||||
) -> list[dict[str, Any]]:
|
||||
graph_entities = graph_snapshot.get("entities") or {}
|
||||
return [
|
||||
graph_entities.get(entity_name)
|
||||
or {
|
||||
"name": entity_name,
|
||||
"type": "实体",
|
||||
"description": "",
|
||||
"descriptions": [],
|
||||
"properties": {},
|
||||
}
|
||||
for entity_name in entity_names
|
||||
]
|
||||
|
||||
|
||||
def _enrich_relation_list(
|
||||
relations: list[dict[str, Any]],
|
||||
graph_snapshot: dict[str, Any],
|
||||
) -> list[dict[str, Any]]:
|
||||
graph_relations = graph_snapshot.get("relations") or {}
|
||||
enriched_relations: list[dict[str, Any]] = []
|
||||
for relation in relations:
|
||||
source = str(relation.get("source") or "").strip()
|
||||
target = str(relation.get("target") or "").strip()
|
||||
relation_type = str(relation.get("type") or "关联").strip()
|
||||
graph_relation = (
|
||||
graph_relations.get((source, target))
|
||||
or graph_relations.get((target, source))
|
||||
or {}
|
||||
)
|
||||
enriched_relations.append(
|
||||
{
|
||||
**relation,
|
||||
"source": source,
|
||||
"target": target,
|
||||
"type": relation_type,
|
||||
"description": graph_relation.get("description", ""),
|
||||
"keywords": graph_relation.get("keywords", []),
|
||||
"weight": graph_relation.get("weight", relation.get("weight", 1)),
|
||||
"properties": graph_relation.get("properties", {}),
|
||||
}
|
||||
)
|
||||
return enriched_relations
|
||||
|
||||
|
||||
def _load_lightrag_graph_snapshot(storage_root: Path, *, workspace: str) -> dict[str, Any]:
|
||||
graphml_path = (
|
||||
Path(storage_root)
|
||||
/ "knowledge"
|
||||
/ ".lightrag"
|
||||
/ str(workspace).strip()
|
||||
/ "graph_chunk_entity_relation.graphml"
|
||||
)
|
||||
if not graphml_path.exists():
|
||||
return {"entities": {}, "relations": {}}
|
||||
|
||||
try:
|
||||
root = ElementTree.parse(graphml_path).getroot()
|
||||
except (ElementTree.ParseError, OSError):
|
||||
return {"entities": {}, "relations": {}}
|
||||
|
||||
key_names = {
|
||||
str(key.attrib.get("id") or ""): str(key.attrib.get("attr.name") or "")
|
||||
for key in root.findall("graphml:key", GRAPHML_NAMESPACE)
|
||||
}
|
||||
return {
|
||||
"entities": _load_graphml_entities(root, key_names),
|
||||
"relations": _load_graphml_relations(root, key_names),
|
||||
}
|
||||
|
||||
|
||||
def _load_graphml_entities(
|
||||
root: ElementTree.Element,
|
||||
key_names: dict[str, str],
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
entities: dict[str, dict[str, Any]] = {}
|
||||
for node in root.findall(".//graphml:node", GRAPHML_NAMESPACE):
|
||||
properties = _read_graphml_data(node, key_names)
|
||||
name = str(properties.get("entity_id") or node.attrib.get("id") or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
descriptions = _split_graph_property(properties.get("description"))
|
||||
visible_properties = _filter_graph_properties(properties)
|
||||
entities[name] = {
|
||||
"name": name,
|
||||
"type": str(properties.get("entity_type") or "实体").strip(),
|
||||
"description": descriptions[0] if descriptions else "",
|
||||
"descriptions": descriptions[:MAX_INGEST_LOG_ENTITY_DESCRIPTIONS],
|
||||
"properties": visible_properties,
|
||||
}
|
||||
return entities
|
||||
|
||||
|
||||
def _load_graphml_relations(
|
||||
root: ElementTree.Element,
|
||||
key_names: dict[str, str],
|
||||
) -> dict[tuple[str, str], dict[str, Any]]:
|
||||
relations: dict[tuple[str, str], dict[str, Any]] = {}
|
||||
for edge in root.findall(".//graphml:edge", GRAPHML_NAMESPACE):
|
||||
source = str(edge.attrib.get("source") or "").strip()
|
||||
target = str(edge.attrib.get("target") or "").strip()
|
||||
if not source or not target:
|
||||
continue
|
||||
properties = _read_graphml_data(edge, key_names)
|
||||
description_parts = _split_graph_property(properties.get("description"))
|
||||
relations[(source, target)] = {
|
||||
"description": "; ".join(description_parts[:2]),
|
||||
"keywords": _split_graph_keywords(properties.get("keywords"))[:6],
|
||||
"weight": _to_float(properties.get("weight"), default=1.0),
|
||||
"properties": _filter_graph_properties(properties),
|
||||
}
|
||||
return relations
|
||||
|
||||
|
||||
def _read_graphml_data(
|
||||
element: ElementTree.Element,
|
||||
key_names: dict[str, str],
|
||||
) -> dict[str, str]:
|
||||
data: dict[str, str] = {}
|
||||
for item in element.findall("graphml:data", GRAPHML_NAMESPACE):
|
||||
key = str(item.attrib.get("key") or "")
|
||||
name = key_names.get(key) or key
|
||||
if not name:
|
||||
continue
|
||||
data[name] = str(item.text or "").strip()
|
||||
return data
|
||||
|
||||
|
||||
def _split_graph_property(value: Any) -> list[str]:
|
||||
return [
|
||||
_truncate_text(part, max_length=MAX_INGEST_LOG_TEXT_PREVIEW)
|
||||
for part in str(value or "").split(GRAPH_PROPERTY_SEPARATOR)
|
||||
if str(part or "").strip()
|
||||
]
|
||||
|
||||
|
||||
def _split_graph_keywords(value: Any) -> list[str]:
|
||||
keywords: list[str] = []
|
||||
for part in str(value or "").split(GRAPH_PROPERTY_SEPARATOR):
|
||||
keywords.extend(part.split(","))
|
||||
return [
|
||||
_truncate_text(keyword, max_length=60)
|
||||
for keyword in keywords
|
||||
if str(keyword or "").strip()
|
||||
]
|
||||
|
||||
|
||||
def _filter_graph_properties(properties: dict[str, Any]) -> dict[str, Any]:
|
||||
hidden_fields = {
|
||||
"source_id",
|
||||
"file_path",
|
||||
"truncate",
|
||||
"description",
|
||||
"keywords",
|
||||
}
|
||||
return {
|
||||
key: value
|
||||
for key, value in properties.items()
|
||||
if key not in hidden_fields and str(value or "").strip()
|
||||
}
|
||||
|
||||
|
||||
def _extract_entity_names(raw_entities: Any) -> list[str]:
|
||||
if not isinstance(raw_entities, list):
|
||||
return []
|
||||
names: list[str] = []
|
||||
for entity in raw_entities:
|
||||
if isinstance(entity, dict):
|
||||
name = str(
|
||||
entity.get("name")
|
||||
or entity.get("entity")
|
||||
or entity.get("entity_id")
|
||||
or entity.get("id")
|
||||
or ""
|
||||
).strip()
|
||||
else:
|
||||
name = str(entity or "").strip()
|
||||
if name:
|
||||
names.append(name)
|
||||
return _dedupe_text_items(names)
|
||||
|
||||
|
||||
def _extract_relation_items(raw_relations: Any) -> list[dict[str, Any]]:
|
||||
if not isinstance(raw_relations, list):
|
||||
return []
|
||||
relations: list[dict[str, Any]] = []
|
||||
for relation in raw_relations:
|
||||
if not isinstance(relation, dict):
|
||||
continue
|
||||
source = str(relation.get("source") or relation.get("from") or "").strip()
|
||||
target = str(relation.get("target") or relation.get("to") or "").strip()
|
||||
if not source or not target:
|
||||
continue
|
||||
relations.append(
|
||||
{
|
||||
**relation,
|
||||
"source": source,
|
||||
"target": target,
|
||||
"type": str(relation.get("type") or "关联").strip(),
|
||||
}
|
||||
)
|
||||
return relations
|
||||
|
||||
|
||||
def _extract_ingest_sections(text: str) -> list[dict[str, str]]:
|
||||
sections: list[dict[str, str]] = []
|
||||
lines = [line.strip() for line in str(text or "").splitlines()]
|
||||
@@ -187,11 +466,46 @@ def _normalize_document_chunks(payload: dict[str, Any], document_id: str) -> lis
|
||||
"order": _to_int(raw_chunk.get("chunk_order_index")),
|
||||
"tokens": _to_int(raw_chunk.get("tokens")),
|
||||
"summary": _build_chunk_summary(content),
|
||||
"excerpt": _truncate_text(
|
||||
content,
|
||||
max_length=MAX_INGEST_LOG_TEXT_PREVIEW,
|
||||
),
|
||||
}
|
||||
)
|
||||
return sorted(chunks, key=lambda item: (item["order"], item["id"]))
|
||||
|
||||
|
||||
def _normalize_document_entity_chunks(
|
||||
payload: dict[str, Any],
|
||||
entities: list[str],
|
||||
*,
|
||||
chunk_ids: set[str],
|
||||
) -> list[dict[str, Any]]:
|
||||
if not entities or not chunk_ids:
|
||||
return []
|
||||
|
||||
entity_chunks: list[dict[str, Any]] = []
|
||||
for entity in entities:
|
||||
raw_entry = payload.get(entity) if isinstance(payload, dict) else {}
|
||||
raw_chunk_ids = raw_entry.get("chunk_ids") if isinstance(raw_entry, dict) else []
|
||||
if not isinstance(raw_chunk_ids, list):
|
||||
continue
|
||||
matched_chunk_ids = [
|
||||
str(item or "").strip()
|
||||
for item in raw_chunk_ids
|
||||
if str(item or "").strip() in chunk_ids
|
||||
]
|
||||
if not matched_chunk_ids:
|
||||
continue
|
||||
entity_chunks.append(
|
||||
{
|
||||
"entity": entity,
|
||||
"chunk_ids": _dedupe_text_items(matched_chunk_ids),
|
||||
}
|
||||
)
|
||||
return entity_chunks
|
||||
|
||||
|
||||
def _build_chunk_summary(content: str) -> str:
|
||||
lines = [line.strip() for line in str(content or "").splitlines() if line.strip()]
|
||||
text = next((line for line in lines if len(line) >= 12), lines[0] if lines else "")
|
||||
@@ -217,6 +531,13 @@ def _to_int(value: Any) -> int:
|
||||
return 0
|
||||
|
||||
|
||||
def _to_float(value: Any, *, default: float = 0.0) -> float:
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
def _truncate_text(text: str, *, max_length: int) -> str:
|
||||
normalized = " ".join(str(text or "").split()).strip()
|
||||
if len(normalized) <= max_length:
|
||||
|
||||
@@ -17,6 +17,7 @@ from app.services.knowledge_ingest_log import (
|
||||
build_ingest_document_summary,
|
||||
build_ingest_status_summary,
|
||||
)
|
||||
from app.services.knowledge_rag_local import query_local_text_chunks
|
||||
from app.services.knowledge_rag_runtime import (
|
||||
KnowledgeRagError,
|
||||
RuntimeModelConfig,
|
||||
@@ -95,6 +96,37 @@ class KnowledgeRagService:
|
||||
"message": "请先输入要检索的知识库问题。",
|
||||
}
|
||||
|
||||
workspace = (
|
||||
os.environ.get("LIGHTRAG_WORKSPACE", DEFAULT_LIGHTRAG_WORKSPACE).strip()
|
||||
or DEFAULT_LIGHTRAG_WORKSPACE
|
||||
)
|
||||
local_result = query_local_text_chunks(
|
||||
lightrag_root=(self.storage_root / "knowledge" / ".lightrag").resolve(),
|
||||
workspace=workspace,
|
||||
query=normalized_query,
|
||||
limit=limit,
|
||||
)
|
||||
if local_result.confident:
|
||||
return {
|
||||
"result_type": "knowledge_search",
|
||||
"query": normalized_query,
|
||||
"record_count": len(local_result.hits),
|
||||
"hits": local_result.hits,
|
||||
"references": [
|
||||
str(item.get("code") or "").strip()
|
||||
for item in local_result.hits
|
||||
if str(item.get("code") or "").strip()
|
||||
],
|
||||
"raw_references": [],
|
||||
"metadata": {
|
||||
"retrieval_strategy": "local_text_chunks",
|
||||
"elapsed_seconds": round(local_result.elapsed_seconds, 4),
|
||||
"total_chunks": local_result.total_chunks,
|
||||
"best_score": local_result.best_score,
|
||||
},
|
||||
"message": f"已从本地知识块中检索到 {len(local_result.hits)} 条相关内容。",
|
||||
}
|
||||
|
||||
try:
|
||||
runtime = self._get_runtime()
|
||||
raw = runtime.query_data(normalized_query, conversation_history=conversation_history)
|
||||
|
||||
353
server/src/app/services/knowledge_rag_local.py
Normal file
353
server/src/app/services/knowledge_rag_local.py
Normal file
@@ -0,0 +1,353 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from time import perf_counter
|
||||
from typing import Any
|
||||
|
||||
MAX_LOCAL_QUERY_TERMS = 14
|
||||
MAX_LOCAL_HIT_CONTENT_LENGTH = 2200
|
||||
MAX_LOCAL_HIT_EXCERPT_LENGTH = 240
|
||||
LOCAL_CONFIDENCE_SCORE = 24
|
||||
LOCAL_CONFIDENCE_MATCHES = 2
|
||||
LOCAL_QUERY_STOPWORDS = {
|
||||
"什么",
|
||||
"多少",
|
||||
"哪些",
|
||||
"怎么",
|
||||
"如何",
|
||||
"请问",
|
||||
"一下",
|
||||
"关于",
|
||||
"规定",
|
||||
"标准",
|
||||
"可以",
|
||||
"是否",
|
||||
"一个",
|
||||
"根据",
|
||||
"依据",
|
||||
"给出",
|
||||
"说明",
|
||||
"公司",
|
||||
"远光",
|
||||
"软件",
|
||||
"股份",
|
||||
"有限",
|
||||
"员工",
|
||||
"当前",
|
||||
"详细",
|
||||
"问题",
|
||||
}
|
||||
LOCAL_TABLE_QUERY_HINTS = (
|
||||
"标准",
|
||||
"金额",
|
||||
"限额",
|
||||
"补贴",
|
||||
"住宿",
|
||||
"餐费",
|
||||
"交通",
|
||||
"报销",
|
||||
"档位",
|
||||
"额度",
|
||||
)
|
||||
LOCAL_DOMAIN_TERMS = (
|
||||
"报销",
|
||||
"费用",
|
||||
"报销时限",
|
||||
"申请时限",
|
||||
"三个月",
|
||||
"逾期",
|
||||
"住宿费",
|
||||
"住宿",
|
||||
"差旅费",
|
||||
"差旅",
|
||||
"出差",
|
||||
"超标",
|
||||
"超过",
|
||||
"审批",
|
||||
"分管领导",
|
||||
"部门负责人",
|
||||
"业务招待",
|
||||
"招待费",
|
||||
"发票",
|
||||
"票据",
|
||||
"预算外",
|
||||
"预算",
|
||||
"补贴",
|
||||
"餐补",
|
||||
"交通费",
|
||||
"会议费",
|
||||
"培训费",
|
||||
"通信费",
|
||||
)
|
||||
|
||||
_index_lock = threading.RLock()
|
||||
_index_cache: dict[Path, tuple[tuple[int, int], list[dict[str, Any]]]] = {}
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class LocalKnowledgeSearchResult:
|
||||
hits: list[dict[str, Any]]
|
||||
confident: bool
|
||||
elapsed_seconds: float
|
||||
total_chunks: int
|
||||
best_score: int
|
||||
|
||||
|
||||
def query_local_text_chunks(
|
||||
*,
|
||||
lightrag_root: Path,
|
||||
workspace: str,
|
||||
query: str,
|
||||
limit: int,
|
||||
) -> LocalKnowledgeSearchResult:
|
||||
started_at = perf_counter()
|
||||
chunks = _load_text_chunks(lightrag_root / workspace / "kv_store_text_chunks.json")
|
||||
query_terms = _extract_local_query_terms(query)
|
||||
if not chunks or not query_terms:
|
||||
return LocalKnowledgeSearchResult(
|
||||
hits=[],
|
||||
confident=False,
|
||||
elapsed_seconds=perf_counter() - started_at,
|
||||
total_chunks=len(chunks),
|
||||
best_score=0,
|
||||
)
|
||||
|
||||
prefers_tabular_evidence = any(hint in query for hint in LOCAL_TABLE_QUERY_HINTS)
|
||||
candidates: list[dict[str, Any]] = []
|
||||
for rank, chunk in enumerate(chunks, start=1):
|
||||
content = str(chunk.get("content") or "").strip()
|
||||
if not content:
|
||||
continue
|
||||
file_path = str(chunk.get("file_path") or "").strip()
|
||||
document_id, document_name = _parse_document_identity(
|
||||
file_path,
|
||||
fallback_doc_id=str(chunk.get("full_doc_id") or "").strip(),
|
||||
)
|
||||
score, matched_terms = _score_local_chunk(
|
||||
content=content,
|
||||
title=document_name,
|
||||
query_terms=query_terms,
|
||||
rank=rank,
|
||||
prefers_tabular_evidence=prefers_tabular_evidence,
|
||||
)
|
||||
if score <= 0:
|
||||
continue
|
||||
|
||||
chunk_id = str(chunk.get("_id") or chunk.get("chunk_id") or "").strip()
|
||||
normalized_content = _truncate_text(
|
||||
content,
|
||||
max_length=MAX_LOCAL_HIT_CONTENT_LENGTH,
|
||||
)
|
||||
candidates.append(
|
||||
{
|
||||
"code": f"knowledge.{document_id or 'unknown'}.{chunk_id or rank}",
|
||||
"candidate_id": chunk_id or f"local-{rank}",
|
||||
"title": document_name or "知识库文档",
|
||||
"content": normalized_content,
|
||||
"excerpt": _build_query_focused_excerpt(
|
||||
normalized_content,
|
||||
query_terms=query_terms,
|
||||
max_length=MAX_LOCAL_HIT_EXCERPT_LENGTH,
|
||||
),
|
||||
"document_id": document_id,
|
||||
"document_name": document_name or Path(file_path).name,
|
||||
"version": None,
|
||||
"updated_at": None,
|
||||
"score": score,
|
||||
"tags": [],
|
||||
"evidence": [chunk_id] if chunk_id else [],
|
||||
"file_path": file_path,
|
||||
"_matched_terms": matched_terms,
|
||||
}
|
||||
)
|
||||
|
||||
ranked = sorted(
|
||||
candidates,
|
||||
key=lambda item: (
|
||||
int(item.get("score") or 0),
|
||||
len(list(item.get("_matched_terms") or [])),
|
||||
-len(str(item.get("content") or "")),
|
||||
),
|
||||
reverse=True,
|
||||
)
|
||||
hits: list[dict[str, Any]] = []
|
||||
for item in ranked[: max(1, limit)]:
|
||||
normalized = dict(item)
|
||||
normalized.pop("_matched_terms", None)
|
||||
hits.append(normalized)
|
||||
|
||||
best_score = int(ranked[0].get("score") or 0) if ranked else 0
|
||||
best_match_count = len(list(ranked[0].get("_matched_terms") or [])) if ranked else 0
|
||||
confident = bool(
|
||||
hits
|
||||
and best_score >= LOCAL_CONFIDENCE_SCORE
|
||||
and best_match_count >= LOCAL_CONFIDENCE_MATCHES
|
||||
)
|
||||
return LocalKnowledgeSearchResult(
|
||||
hits=hits,
|
||||
confident=confident,
|
||||
elapsed_seconds=perf_counter() - started_at,
|
||||
total_chunks=len(chunks),
|
||||
best_score=best_score,
|
||||
)
|
||||
|
||||
|
||||
def _load_text_chunks(path: Path) -> list[dict[str, Any]]:
|
||||
try:
|
||||
stat = path.stat()
|
||||
except OSError:
|
||||
return []
|
||||
|
||||
signature = (int(stat.st_mtime_ns), int(stat.st_size))
|
||||
resolved_path = path.resolve()
|
||||
with _index_lock:
|
||||
cached = _index_cache.get(resolved_path)
|
||||
if cached is not None and cached[0] == signature:
|
||||
return cached[1]
|
||||
|
||||
try:
|
||||
payload = json.loads(path.read_text(encoding="utf-8"))
|
||||
except (OSError, json.JSONDecodeError):
|
||||
chunks: list[dict[str, Any]] = []
|
||||
else:
|
||||
chunks = [
|
||||
dict(value)
|
||||
for value in payload.values()
|
||||
if isinstance(value, dict) and str(value.get("content") or "").strip()
|
||||
]
|
||||
chunks.sort(
|
||||
key=lambda item: (
|
||||
str(item.get("full_doc_id") or ""),
|
||||
int(item.get("chunk_order_index") or 0),
|
||||
str(item.get("_id") or ""),
|
||||
)
|
||||
)
|
||||
|
||||
_index_cache[resolved_path] = (signature, chunks)
|
||||
return chunks
|
||||
|
||||
|
||||
def _score_local_chunk(
|
||||
*,
|
||||
content: str,
|
||||
title: str,
|
||||
query_terms: list[str],
|
||||
rank: int,
|
||||
prefers_tabular_evidence: bool,
|
||||
) -> tuple[int, list[str]]:
|
||||
lowered_content = content.lower()
|
||||
lowered_title = title.lower()
|
||||
matched_terms = [
|
||||
term for term in query_terms if term in lowered_content or term in lowered_title
|
||||
]
|
||||
if not matched_terms:
|
||||
return 0, []
|
||||
|
||||
score = max(1, 32 - min(rank, 20))
|
||||
for term in matched_terms:
|
||||
weight = 8 if len(term) >= 4 else 5 if len(term) == 3 else 2
|
||||
score += weight
|
||||
if term in lowered_title:
|
||||
score += max(4, weight)
|
||||
occurrences = lowered_content.count(term)
|
||||
if occurrences > 1:
|
||||
score += min(8, occurrences * 2)
|
||||
|
||||
if prefers_tabular_evidence and ("|" in content or "表" in content):
|
||||
score += 12
|
||||
if "# 结构化表格补充" in content:
|
||||
score += 10 if prefers_tabular_evidence else 4
|
||||
if "# 问答线索补充" in content:
|
||||
score += 8
|
||||
if "# 章节导航" in content[:260]:
|
||||
score -= 20
|
||||
if any(marker in content for marker in ("第", "条", ":", ";", "-", "•")):
|
||||
score += 4
|
||||
|
||||
return score, matched_terms
|
||||
|
||||
|
||||
def _extract_local_query_terms(query: str) -> list[str]:
|
||||
normalized_query = str(query or "").strip().lower()
|
||||
if not normalized_query:
|
||||
return []
|
||||
|
||||
terms: list[str] = []
|
||||
seen: set[str] = set()
|
||||
|
||||
def remember(term: str) -> None:
|
||||
normalized_term = str(term or "").strip().lower()
|
||||
if (
|
||||
not normalized_term
|
||||
or normalized_term in seen
|
||||
or normalized_term in LOCAL_QUERY_STOPWORDS
|
||||
or len(normalized_term) < 2
|
||||
):
|
||||
return
|
||||
seen.add(normalized_term)
|
||||
terms.append(normalized_term)
|
||||
|
||||
for item in re.findall(r"[a-z0-9][a-z0-9_\-]{1,}", normalized_query):
|
||||
remember(item)
|
||||
|
||||
for item in LOCAL_DOMAIN_TERMS:
|
||||
if item in normalized_query:
|
||||
remember(item)
|
||||
|
||||
for block in re.findall(r"[\u4e00-\u9fff]{2,24}", normalized_query):
|
||||
if len(block) <= 4:
|
||||
remember(block)
|
||||
continue
|
||||
for size in (4, 3, 2, 5):
|
||||
for start in range(0, len(block) - size + 1):
|
||||
remember(block[start : start + size])
|
||||
if len(terms) >= MAX_LOCAL_QUERY_TERMS:
|
||||
return terms
|
||||
|
||||
return terms[:MAX_LOCAL_QUERY_TERMS]
|
||||
|
||||
|
||||
def _parse_document_identity(file_path: str, *, fallback_doc_id: str = "") -> tuple[str, str]:
|
||||
path = Path(str(file_path or "").strip())
|
||||
name = path.name
|
||||
if "__" not in name:
|
||||
return fallback_doc_id, name
|
||||
document_id, document_name = name.split("__", maxsplit=1)
|
||||
return document_id.strip() or fallback_doc_id, document_name.strip()
|
||||
|
||||
|
||||
def _build_query_focused_excerpt(
|
||||
text: str,
|
||||
*,
|
||||
query_terms: list[str],
|
||||
max_length: int,
|
||||
) -> str:
|
||||
normalized = " ".join(str(text or "").split()).strip()
|
||||
if not normalized:
|
||||
return ""
|
||||
lowered = normalized.lower()
|
||||
match_positions = [
|
||||
lowered.find(term) for term in query_terms if term and lowered.find(term) >= 0
|
||||
]
|
||||
if not match_positions:
|
||||
return _truncate_text(normalized, max_length=max_length)
|
||||
|
||||
start = max(0, min(match_positions) - max_length // 3)
|
||||
end = min(len(normalized), start + max_length)
|
||||
snippet = normalized[start:end].strip()
|
||||
if start > 0:
|
||||
snippet = f"...{snippet.lstrip()}"
|
||||
if end < len(normalized):
|
||||
snippet = f"{snippet.rstrip()}..."
|
||||
return snippet
|
||||
|
||||
|
||||
def _truncate_text(text: str, *, max_length: int) -> str:
|
||||
normalized = str(text or "").strip()
|
||||
if len(normalized) <= max_length:
|
||||
return normalized
|
||||
return f"{normalized[: max_length - 3].rstrip()}..."
|
||||
@@ -404,6 +404,14 @@ class OrchestratorService:
|
||||
ontology: OntologyParseResult,
|
||||
task_asset: AgentAssetRead | None,
|
||||
) -> dict[str, list[AgentAssetListItem | AgentAssetRead]]:
|
||||
if ontology.scenario == "knowledge" and payload.source == AgentRunSource.USER_MESSAGE.value:
|
||||
return {
|
||||
"rules": [],
|
||||
"skills": [],
|
||||
"mcps": [],
|
||||
"tasks": [],
|
||||
}
|
||||
|
||||
domain_value = SCENARIO_TO_DOMAIN.get(ontology.scenario)
|
||||
rules = self.execution_engine._rank_assets(
|
||||
self.asset_service.list_assets(
|
||||
|
||||
191
server/src/app/services/risk_rule_flow_diagram.py
Normal file
191
server/src/app/services/risk_rule_flow_diagram.py
Normal file
@@ -0,0 +1,191 @@
|
||||
# ruff: noqa: E501
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import html
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RiskRuleFlowDiagramField:
|
||||
key: str
|
||||
label: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RiskRuleFlowDiagramSpec:
|
||||
title: str
|
||||
domain_label: str
|
||||
severity: str
|
||||
severity_label: str
|
||||
fields: tuple[RiskRuleFlowDiagramField, ...]
|
||||
start: str
|
||||
evidence: str
|
||||
decision: str
|
||||
basis: str
|
||||
pass_text: str
|
||||
fail_text: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RiskRuleFlowDiagramPalette:
|
||||
accent: str
|
||||
accent_dark: str
|
||||
border: str
|
||||
surface: str
|
||||
|
||||
|
||||
class RiskRuleFlowDiagramRenderer:
|
||||
"""按 fireworks-tech-graph Style 7 OpenAI Official 生成只读流程 SVG。"""
|
||||
|
||||
_FONT = (
|
||||
"-apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica Neue, "
|
||||
"'PingFang SC', 'Microsoft YaHei', 'Microsoft JhengHei', 'SimHei', sans-serif"
|
||||
)
|
||||
_TEXT = "#0d0d0d"
|
||||
_MUTED = "#6e6e80"
|
||||
_NEUTRAL_LINE = "#cbd5e1"
|
||||
_NEUTRAL_BORDER = "#e2e8f0"
|
||||
_NEUTRAL_SURFACE = "#ffffff"
|
||||
_PALETTES = {
|
||||
"low": RiskRuleFlowDiagramPalette(
|
||||
accent="#2563eb",
|
||||
accent_dark="#1d4ed8",
|
||||
border="#bfdbfe",
|
||||
surface="#eff6ff",
|
||||
),
|
||||
"medium": RiskRuleFlowDiagramPalette(
|
||||
accent="#f97316",
|
||||
accent_dark="#c2410c",
|
||||
border="#fed7aa",
|
||||
surface="#fff7ed",
|
||||
),
|
||||
"high": RiskRuleFlowDiagramPalette(
|
||||
accent="#dc2626",
|
||||
accent_dark="#b91c1c",
|
||||
border="#fecaca",
|
||||
surface="#fef2f2",
|
||||
),
|
||||
}
|
||||
|
||||
def render(self, spec: RiskRuleFlowDiagramSpec) -> str:
|
||||
title = self._truncate(spec.title, 26)
|
||||
palette = self._palette(spec.severity)
|
||||
|
||||
return f"""<svg xmlns="http://www.w3.org/2000/svg" width="760" height="280" viewBox="0 0 760 280" data-risk-flow-style="review-node-only" role="img" aria-labelledby="risk-flow-title risk-flow-desc">
|
||||
<title id="risk-flow-title">{self._escape(title)}流程说明</title>
|
||||
<desc id="risk-flow-desc">风险规则只读流程图,展示从业务单据提交到风险复核的判断路径。</desc>
|
||||
<defs>
|
||||
<marker id="arrow-neutral" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
|
||||
<polygon points="0 0, 10 3.5, 0 7" fill="{self._NEUTRAL_LINE}"/>
|
||||
</marker>
|
||||
</defs>
|
||||
<rect width="760" height="280" fill="#ffffff"/>
|
||||
<rect x="18" y="18" width="724" height="244" rx="8" ry="8" fill="none" stroke="{self._NEUTRAL_BORDER}" stroke-width="1" stroke-dasharray="4,3"/>
|
||||
<text x="34" y="43" fill="{self._MUTED}" font-family="{self._FONT}" font-size="11" font-weight="500">RULE FLOW</text>
|
||||
{self._node("业务输入", spec.start, 48, 118, 124, 60)}
|
||||
{self._node("字段取数", "读取字段证据", 214, 118, 132, 60)}
|
||||
{self._diamond("判断依据", spec.decision, 392, 92, 112, 112)}
|
||||
{self._node("继续流转", spec.pass_text, 562, 74, 126, 60)}
|
||||
{self._node("进入复核", spec.fail_text, 562, 190, 126, 62, palette=palette)}
|
||||
{self._note(spec.basis, 214, 218, 290, 36)}
|
||||
<line x1="172" y1="148" x2="214" y2="148" stroke="{self._NEUTRAL_LINE}" stroke-width="1.45" marker-end="url(#arrow-neutral)"/>
|
||||
<line x1="346" y1="148" x2="392" y2="148" stroke="{self._NEUTRAL_LINE}" stroke-width="1.45" marker-end="url(#arrow-neutral)"/>
|
||||
<path d="M 504 127 L 532 127 L 532 104 L 562 104" fill="none" stroke="{self._NEUTRAL_LINE}" stroke-width="1.35" marker-end="url(#arrow-neutral)"/>
|
||||
<text x="534" y="119" text-anchor="middle" fill="{self._MUTED}" font-family="{self._FONT}" font-size="10.5" font-weight="400">否</text>
|
||||
<path d="M 504 169 L 532 169 L 532 221 L 562 221" fill="none" stroke="{self._NEUTRAL_LINE}" stroke-width="1.8" marker-end="url(#arrow-neutral)"/>
|
||||
<text x="534" y="195" text-anchor="middle" fill="{self._MUTED}" font-family="{self._FONT}" font-size="10.5" font-weight="600">是</text>
|
||||
</svg>"""
|
||||
|
||||
def _node(
|
||||
self,
|
||||
title: str,
|
||||
body: str,
|
||||
x: int,
|
||||
y: int,
|
||||
width: int,
|
||||
height: int,
|
||||
palette: RiskRuleFlowDiagramPalette | None = None,
|
||||
) -> str:
|
||||
body_lines = self._wrap(body, 10 if width <= 126 else 11, 1)
|
||||
border = palette.border if palette else self._NEUTRAL_BORDER
|
||||
stripe = palette.accent if palette else self._NEUTRAL_LINE
|
||||
surface = palette.surface if palette else self._NEUTRAL_SURFACE
|
||||
return f"""<g>
|
||||
<rect x="{x}" y="{y}" width="{width}" height="{height}" rx="7" ry="7" fill="{surface}" stroke="{border}" stroke-width="1.2"/>
|
||||
<rect x="{x}" y="{y}" width="3.5" height="{height}" rx="1.75" ry="1.75" fill="{stripe}"/>
|
||||
<text x="{x + width / 2:.0f}" y="{y + 24}" text-anchor="middle" fill="{self._TEXT}" font-family="{self._FONT}" font-size="13" font-weight="600">{self._escape(title)}</text>
|
||||
{self._text_lines(body_lines, x + width / 2, y + 43, "middle", self._MUTED, 11)}
|
||||
</g>"""
|
||||
|
||||
def _diamond(
|
||||
self,
|
||||
title: str,
|
||||
body: str,
|
||||
x: int,
|
||||
y: int,
|
||||
width: int,
|
||||
height: int,
|
||||
) -> str:
|
||||
cx = x + width / 2
|
||||
cy = y + height / 2
|
||||
points = f"{cx},{y} {x + width},{cy} {cx},{y + height} {x},{cy}"
|
||||
body_lines = self._wrap(body, 8, 2)
|
||||
return f"""<g>
|
||||
<polygon points="{points}" fill="#ffffff" stroke="{self._NEUTRAL_BORDER}" stroke-width="1.25"/>
|
||||
<text x="{cx:.0f}" y="{cy - 10:.0f}" text-anchor="middle" fill="{self._TEXT}" font-family="{self._FONT}" font-size="12.5" font-weight="600">{self._escape(title)}</text>
|
||||
{self._text_lines(body_lines, cx, cy + 11, "middle", self._MUTED, 10.2)}
|
||||
</g>"""
|
||||
|
||||
def _note(
|
||||
self,
|
||||
body: str,
|
||||
x: int,
|
||||
y: int,
|
||||
width: int,
|
||||
height: int,
|
||||
) -> str:
|
||||
lines = self._wrap(body, 22, 1)
|
||||
return f"""<g>
|
||||
<rect x="{x}" y="{y}" width="{width}" height="{height}" rx="7" ry="7" fill="#ffffff" stroke="{self._NEUTRAL_BORDER}" stroke-width="1" stroke-dasharray="4,3"/>
|
||||
<text x="{x + 12}" y="{y + 22}" fill="{self._MUTED}" font-family="{self._FONT}" font-size="10" font-weight="500">BASIS</text>
|
||||
{self._text_lines(lines, x + 54, y + 22, "start", self._TEXT, 10.2)}
|
||||
</g>"""
|
||||
|
||||
def _text_lines(
|
||||
self,
|
||||
lines: list[str],
|
||||
x: float,
|
||||
y: float,
|
||||
anchor: str,
|
||||
color: str,
|
||||
font_size: float,
|
||||
) -> str:
|
||||
return "\n ".join(
|
||||
f'<text x="{x:.0f}" y="{y + index * (font_size + 5):.0f}" text-anchor="{anchor}" fill="{color}" font-family="{self._FONT}" font-size="{font_size}" font-weight="400">{self._escape(line)}</text>'
|
||||
for index, line in enumerate(lines)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _wrap(value: str, width: int, max_lines: int) -> list[str]:
|
||||
text = str(value or "").strip()
|
||||
if not text:
|
||||
return [""]
|
||||
lines = [text[index : index + width] for index in range(0, len(text), width)]
|
||||
if len(lines) > max_lines:
|
||||
lines = lines[:max_lines]
|
||||
lines[-1] = f"{lines[-1][: max(0, width - 1)]}…"
|
||||
return lines
|
||||
|
||||
@staticmethod
|
||||
def _truncate(value: str, length: int) -> str:
|
||||
text = str(value or "").strip()
|
||||
return text if len(text) <= length else f"{text[: length - 1]}…"
|
||||
|
||||
@staticmethod
|
||||
def _escape(value: str) -> str:
|
||||
return html.escape(str(value or ""), quote=True)
|
||||
|
||||
@classmethod
|
||||
def _palette(cls, severity: str) -> RiskRuleFlowDiagramPalette:
|
||||
return cls._PALETTES.get(str(severity or "").strip().lower(), cls._PALETTES["medium"])
|
||||
751
server/src/app/services/risk_rule_generation.py
Normal file
751
server/src/app/services/risk_rule_generation.py
Normal file
@@ -0,0 +1,751 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.agent_enums import AgentAssetDomain, AgentAssetStatus, AgentAssetType
|
||||
from app.models.agent_asset import AgentAsset, AgentAssetVersion
|
||||
from app.schemas.agent_asset import AgentAssetRiskRuleGenerateRequest
|
||||
from app.services.agent_asset_rule_library import AgentAssetRuleLibraryManager
|
||||
from app.services.agent_asset_spreadsheet import RISK_RULES_LIBRARY
|
||||
from app.services.audit import AuditLogService
|
||||
from app.services.risk_rule_flow_diagram import (
|
||||
RiskRuleFlowDiagramField,
|
||||
RiskRuleFlowDiagramRenderer,
|
||||
RiskRuleFlowDiagramSpec,
|
||||
)
|
||||
from app.services.runtime_chat import RuntimeChatService
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RiskRuleField:
|
||||
key: str
|
||||
label: str
|
||||
field_type: str
|
||||
source: str
|
||||
aliases: tuple[str, ...]
|
||||
|
||||
|
||||
BUSINESS_DOMAIN_LABELS: dict[str, str] = {
|
||||
AgentAssetDomain.EXPENSE.value: "报销",
|
||||
AgentAssetDomain.AR.value: "应收",
|
||||
AgentAssetDomain.AP.value: "应付",
|
||||
}
|
||||
|
||||
RISK_LEVEL_LABELS: dict[str, str] = {
|
||||
"low": "低风险",
|
||||
"medium": "中风险",
|
||||
"high": "高风险",
|
||||
}
|
||||
|
||||
FIELD_ONTOLOGY: tuple[RiskRuleField, ...] = (
|
||||
RiskRuleField("claim.reason", "报销事由", "text", "claim", ("事由", "说明", "理由", "用途")),
|
||||
RiskRuleField(
|
||||
"claim.location",
|
||||
"申报地点",
|
||||
"text",
|
||||
"claim",
|
||||
("地点", "城市", "出差地", "申报地点", "申报目的地", "目的地"),
|
||||
),
|
||||
RiskRuleField("claim.amount", "申报金额", "number", "claim", ("金额", "费用", "超额", "额度")),
|
||||
RiskRuleField("claim.employee_name", "报销人", "text", "claim", ("报销人", "员工", "申请人")),
|
||||
RiskRuleField("claim.department_name", "部门", "text", "claim", ("部门", "组织")),
|
||||
RiskRuleField("item.item_type", "费用类型", "enum", "item", ("费用类型", "科目", "类型")),
|
||||
RiskRuleField("item.item_reason", "明细事由", "text", "item", ("明细事由", "明细说明")),
|
||||
RiskRuleField("item.item_location", "明细地点", "text", "item", ("明细地点", "发生地点")),
|
||||
RiskRuleField(
|
||||
"attachment.invoice_no", "发票号码", "text", "attachment", ("发票号", "发票号码", "票号")
|
||||
),
|
||||
RiskRuleField(
|
||||
"attachment.buyer_name", "购买方名称", "text", "attachment", ("抬头", "购买方", "开票单位")
|
||||
),
|
||||
RiskRuleField(
|
||||
"attachment.goods_name",
|
||||
"商品服务名称",
|
||||
"text",
|
||||
"attachment",
|
||||
("品名", "商品", "服务名称", "摘要"),
|
||||
),
|
||||
RiskRuleField(
|
||||
"attachment.issue_date",
|
||||
"开票日期",
|
||||
"date",
|
||||
"attachment",
|
||||
("开票日期", "发票日期", "票据日期"),
|
||||
),
|
||||
RiskRuleField(
|
||||
"attachment.hotel_city",
|
||||
"住宿城市",
|
||||
"text",
|
||||
"attachment",
|
||||
("住宿城市", "酒店城市", "酒店地点", "酒店发票城市", "酒店票城市", "住宿发票城市"),
|
||||
),
|
||||
RiskRuleField(
|
||||
"attachment.route_cities",
|
||||
"行程城市",
|
||||
"list",
|
||||
"attachment",
|
||||
("行程", "路线", "途经城市", "出差城市", "交通票行程", "交通票城市"),
|
||||
),
|
||||
RiskRuleField(
|
||||
"attachment.ocr_text",
|
||||
"票据全文",
|
||||
"text",
|
||||
"attachment",
|
||||
("票据内容", "OCR", "全文", "关键字", "关键词"),
|
||||
),
|
||||
RiskRuleField(
|
||||
"receivable.aging_days", "应收账龄", "number", "receivable", ("账龄", "逾期", "应收逾期")
|
||||
),
|
||||
RiskRuleField(
|
||||
"receivable.amount_outstanding",
|
||||
"应收未收金额",
|
||||
"number",
|
||||
"receivable",
|
||||
("未收金额", "欠款", "应收余额"),
|
||||
),
|
||||
RiskRuleField(
|
||||
"payable.vendor_name", "供应商名称", "text", "payable", ("供应商", "付款方", "往来单位")
|
||||
),
|
||||
RiskRuleField(
|
||||
"payable.amount_outstanding", "应付未付金额", "number", "payable", ("未付金额", "应付余额")
|
||||
),
|
||||
)
|
||||
|
||||
DOMAIN_FIELD_PREFIXES: dict[str, tuple[str, ...]] = {
|
||||
AgentAssetDomain.EXPENSE.value: ("claim.", "item.", "attachment."),
|
||||
AgentAssetDomain.AR.value: ("receivable.",),
|
||||
AgentAssetDomain.AP.value: ("payable.",),
|
||||
}
|
||||
|
||||
|
||||
class RiskRuleGenerationService:
|
||||
def __init__(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
rule_library_manager: AgentAssetRuleLibraryManager | None = None,
|
||||
runtime_chat_service: RuntimeChatService | None = None,
|
||||
) -> None:
|
||||
self.db = db
|
||||
self.rule_library_manager = rule_library_manager or AgentAssetRuleLibraryManager()
|
||||
self.runtime_chat_service = runtime_chat_service or RuntimeChatService(db)
|
||||
self.audit_service = AuditLogService(db)
|
||||
self.flow_diagram_renderer = RiskRuleFlowDiagramRenderer()
|
||||
|
||||
def generate_rule_asset(
|
||||
self,
|
||||
body: AgentAssetRiskRuleGenerateRequest,
|
||||
*,
|
||||
actor: str,
|
||||
request_id: str | None = None,
|
||||
) -> str:
|
||||
domain = body.business_domain.value
|
||||
if domain not in BUSINESS_DOMAIN_LABELS:
|
||||
raise ValueError("当前仅支持报销、应收、应付业务域的新建风险规则。")
|
||||
|
||||
natural_language = self._clean_text(body.natural_language)
|
||||
if len(natural_language) < 8:
|
||||
raise ValueError("请至少输入 8 个字的风险规则描述。")
|
||||
|
||||
risk_level = str(body.risk_level or "medium").strip().lower()
|
||||
if risk_level not in RISK_LEVEL_LABELS:
|
||||
raise ValueError("风险等级仅支持 low、medium、high。")
|
||||
|
||||
created_at = datetime.now(UTC)
|
||||
fields = self._resolve_fields(natural_language, domain=domain)
|
||||
draft = self._compile_with_model(
|
||||
natural_language=natural_language,
|
||||
domain=domain,
|
||||
risk_level=risk_level,
|
||||
fields=fields,
|
||||
) or self._build_fallback_draft(
|
||||
natural_language=natural_language,
|
||||
domain=domain,
|
||||
risk_level=risk_level,
|
||||
fields=fields,
|
||||
)
|
||||
draft = self._align_draft_fields(
|
||||
draft,
|
||||
natural_language=natural_language,
|
||||
fields=fields,
|
||||
)
|
||||
payload = self._build_rule_payload(
|
||||
draft,
|
||||
natural_language=natural_language,
|
||||
domain=domain,
|
||||
risk_level=risk_level,
|
||||
fields=fields,
|
||||
created_at=created_at,
|
||||
actor=actor,
|
||||
)
|
||||
rule_code = str(payload["rule_code"])
|
||||
file_name = f"{rule_code}.json"
|
||||
|
||||
self.rule_library_manager.write_rule_library_json(
|
||||
library=RISK_RULES_LIBRARY,
|
||||
file_name=file_name,
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
asset = AgentAsset(
|
||||
asset_type=AgentAssetType.RULE.value,
|
||||
code=rule_code,
|
||||
name=str(payload["name"]),
|
||||
description=str(payload["description"]),
|
||||
domain=domain,
|
||||
scenario_json=[str(payload.get("risk_category") or BUSINESS_DOMAIN_LABELS[domain])],
|
||||
owner=actor,
|
||||
reviewer=None,
|
||||
status=AgentAssetStatus.DRAFT.value,
|
||||
current_version="v0.1.0",
|
||||
published_version=None,
|
||||
working_version="v0.1.0",
|
||||
config_json={
|
||||
"severity": risk_level,
|
||||
"enabled": True,
|
||||
"tag": "风险规则",
|
||||
"detail_mode": "json_risk",
|
||||
"risk_category": payload.get("risk_category"),
|
||||
"rule_library": RISK_RULES_LIBRARY,
|
||||
"rule_document": {
|
||||
"file_name": file_name,
|
||||
"storage_key": f"rules/{RISK_RULES_LIBRARY}/{file_name}",
|
||||
},
|
||||
"ontology_signal": payload.get("ontology_signal"),
|
||||
"evaluator": payload.get("evaluator"),
|
||||
"generated_by": "natural_language",
|
||||
"source_ref": "自然语言风险规则",
|
||||
},
|
||||
)
|
||||
self.db.add(asset)
|
||||
self.db.flush()
|
||||
self.db.add(
|
||||
AgentAssetVersion(
|
||||
asset_id=asset.id,
|
||||
version="v0.1.0",
|
||||
content=self._build_version_markdown(payload),
|
||||
content_type="markdown",
|
||||
change_note="通过自然语言新建风险规则草稿。",
|
||||
created_by=actor,
|
||||
)
|
||||
)
|
||||
self.audit_service.log_action(
|
||||
actor=actor,
|
||||
action="generate_agent_asset_risk_rule",
|
||||
resource_type=AgentAssetType.RULE.value,
|
||||
resource_id=asset.id,
|
||||
before_json=None,
|
||||
after_json={"rule_code": rule_code, "risk_level": risk_level, "domain": domain},
|
||||
request_id=request_id,
|
||||
)
|
||||
self.db.refresh(asset)
|
||||
return asset.id
|
||||
|
||||
def _compile_with_model(
|
||||
self,
|
||||
*,
|
||||
natural_language: str,
|
||||
domain: str,
|
||||
risk_level: str,
|
||||
fields: list[RiskRuleField],
|
||||
) -> dict[str, Any] | None:
|
||||
field_payload = [
|
||||
{
|
||||
"key": item.key,
|
||||
"label": item.label,
|
||||
"type": item.field_type,
|
||||
"source": item.source,
|
||||
}
|
||||
for item in fields
|
||||
]
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"你是 X-Financial 风险规则编译器。只能输出 JSON 对象,不要解释。"
|
||||
"必须从给定字段本体中选择字段,不允许编造字段。"
|
||||
"template_key 只能是 field_required_v1、field_compare_v1、keyword_match_v1。"
|
||||
),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": json.dumps(
|
||||
{
|
||||
"business_domain": domain,
|
||||
"business_domain_label": BUSINESS_DOMAIN_LABELS[domain],
|
||||
"risk_level": risk_level,
|
||||
"risk_level_label": RISK_LEVEL_LABELS[risk_level],
|
||||
"natural_language": natural_language,
|
||||
"available_fields": field_payload,
|
||||
"required_json_shape": {
|
||||
"name": "规则名称",
|
||||
"description": "面向业务用户的说明",
|
||||
"template_key": "field_required_v1",
|
||||
"field_keys": ["claim.reason"],
|
||||
"condition_summary": "判断依据",
|
||||
"keywords": [],
|
||||
"flow": {
|
||||
"start": "提交业务单据",
|
||||
"evidence": "读取字段",
|
||||
"decision": "判断依据",
|
||||
"pass": "继续流转",
|
||||
"fail": "提示风险",
|
||||
},
|
||||
},
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
},
|
||||
]
|
||||
answer = self.runtime_chat_service.complete(
|
||||
messages,
|
||||
max_tokens=700,
|
||||
temperature=0.1,
|
||||
timeout_seconds=12,
|
||||
max_attempts=1,
|
||||
)
|
||||
if not answer:
|
||||
return None
|
||||
|
||||
try:
|
||||
payload = json.loads(self._extract_json_object(answer))
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
return None
|
||||
if not isinstance(payload, dict):
|
||||
return None
|
||||
return self._sanitize_model_draft(payload, fields=fields)
|
||||
|
||||
def _sanitize_model_draft(
|
||||
self,
|
||||
payload: dict[str, Any],
|
||||
*,
|
||||
fields: list[RiskRuleField],
|
||||
) -> dict[str, Any]:
|
||||
allowed_fields = {item.key for item in fields}
|
||||
template_key = str(payload.get("template_key") or "").strip()
|
||||
if template_key not in {"field_required_v1", "field_compare_v1", "keyword_match_v1"}:
|
||||
template_key = "field_required_v1"
|
||||
|
||||
raw_field_keys = payload.get("field_keys")
|
||||
field_keys = [
|
||||
str(item or "").strip()
|
||||
for item in (raw_field_keys if isinstance(raw_field_keys, list) else [])
|
||||
if str(item or "").strip() in allowed_fields
|
||||
]
|
||||
if not field_keys and fields:
|
||||
field_keys = [fields[0].key]
|
||||
|
||||
keywords = [
|
||||
str(item or "").strip()
|
||||
for item in (
|
||||
payload.get("keywords") if isinstance(payload.get("keywords"), list) else []
|
||||
)
|
||||
if str(item or "").strip()
|
||||
]
|
||||
flow = payload.get("flow") if isinstance(payload.get("flow"), dict) else {}
|
||||
return {
|
||||
"name": self._clean_text(payload.get("name"))[:80],
|
||||
"description": self._clean_text(payload.get("description")),
|
||||
"template_key": template_key,
|
||||
"field_keys": field_keys,
|
||||
"condition_summary": self._clean_text(payload.get("condition_summary")),
|
||||
"keywords": keywords[:12],
|
||||
"flow": {
|
||||
"start": self._clean_text(flow.get("start")) or "提交业务单据",
|
||||
"evidence": self._clean_text(flow.get("evidence")) or "读取规则字段",
|
||||
"decision": self._clean_text(flow.get("decision")) or "判断是否命中风险",
|
||||
"pass": self._clean_text(flow.get("pass")) or "继续流转",
|
||||
"fail": self._clean_text(flow.get("fail")) or "提示风险并进入复核",
|
||||
},
|
||||
}
|
||||
|
||||
def _build_fallback_draft(
|
||||
self,
|
||||
*,
|
||||
natural_language: str,
|
||||
domain: str,
|
||||
risk_level: str,
|
||||
fields: list[RiskRuleField],
|
||||
) -> dict[str, Any]:
|
||||
field_keys = [item.key for item in fields[:4]]
|
||||
template_key = self._infer_template_key(natural_language)
|
||||
condition_summary = self._build_condition_summary(
|
||||
natural_language,
|
||||
template_key=template_key,
|
||||
fields=fields,
|
||||
)
|
||||
name = self._infer_rule_name(natural_language)
|
||||
description = (
|
||||
f"当{BUSINESS_DOMAIN_LABELS[domain]}业务满足“{natural_language}”时,系统会按"
|
||||
f"{RISK_LEVEL_LABELS[risk_level]}进行提示,并要求经办人或审核人补充核对依据。"
|
||||
)
|
||||
return {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"template_key": template_key,
|
||||
"field_keys": field_keys,
|
||||
"condition_summary": condition_summary,
|
||||
"keywords": self._infer_keywords(natural_language),
|
||||
"flow": {
|
||||
"start": f"{BUSINESS_DOMAIN_LABELS[domain]}单据提交",
|
||||
"evidence": "读取" + "、".join(item.label for item in fields[:3]),
|
||||
"decision": condition_summary,
|
||||
"pass": "未命中风险,继续业务流转",
|
||||
"fail": f"命中{RISK_LEVEL_LABELS[risk_level]},提示复核",
|
||||
},
|
||||
}
|
||||
|
||||
def _build_rule_payload(
|
||||
self,
|
||||
draft: dict[str, Any],
|
||||
*,
|
||||
natural_language: str,
|
||||
domain: str,
|
||||
risk_level: str,
|
||||
fields: list[RiskRuleField],
|
||||
created_at: datetime,
|
||||
actor: str,
|
||||
) -> dict[str, Any]:
|
||||
created_stamp = created_at.strftime("%Y%m%d%H%M%S")
|
||||
domain_slug = {"expense": "expense", "ar": "ar", "ap": "ap"}[domain]
|
||||
rule_code = f"risk.{domain_slug}.generated_{created_stamp}"
|
||||
template_key = str(draft.get("template_key") or "field_required_v1").strip()
|
||||
field_keys = [
|
||||
str(item or "").strip()
|
||||
for item in list(draft.get("field_keys") or [])
|
||||
if str(item or "").strip()
|
||||
]
|
||||
condition_summary = (
|
||||
self._clean_text(draft.get("condition_summary")) or "判断是否符合自然语言规则描述"
|
||||
)
|
||||
risk_category = BUSINESS_DOMAIN_LABELS[domain]
|
||||
keywords = list(draft.get("keywords") or [])
|
||||
field_by_key = {item.key: item for item in fields}
|
||||
params: dict[str, Any] = {
|
||||
"template_key": template_key,
|
||||
"field_keys": field_keys,
|
||||
"condition_summary": condition_summary,
|
||||
"natural_language": natural_language,
|
||||
}
|
||||
if template_key == "field_required_v1":
|
||||
params["required_fields"] = field_keys
|
||||
if template_key == "field_compare_v1":
|
||||
params["conditions"] = self._build_compare_conditions(field_keys)
|
||||
if template_key == "keyword_match_v1":
|
||||
params["keywords"] = keywords
|
||||
params["search_fields"] = field_keys
|
||||
|
||||
payload = {
|
||||
"schema_version": "2.0",
|
||||
"rule_code": rule_code,
|
||||
"name": self._clean_text(draft.get("name")) or self._infer_rule_name(natural_language),
|
||||
"description": self._clean_text(draft.get("description")) or natural_language,
|
||||
"enabled": True,
|
||||
"risk_dimension": "natural_language_rule",
|
||||
"risk_category": risk_category,
|
||||
"ontology_signal": "natural_language_risk",
|
||||
"evaluator": "template_rule",
|
||||
"template_key": template_key,
|
||||
"applies_to": {"domains": [domain]},
|
||||
"inputs": {
|
||||
"fields": [
|
||||
{
|
||||
"key": item.key,
|
||||
"label": item.label,
|
||||
"type": item.field_type,
|
||||
"source": item.source,
|
||||
}
|
||||
for item in [field_by_key[key] for key in field_keys if key in field_by_key]
|
||||
],
|
||||
},
|
||||
"params": params,
|
||||
"outcomes": {
|
||||
"pass": {"severity": "none", "action": "continue"},
|
||||
"fail": {
|
||||
"severity": risk_level,
|
||||
"action": "manual_review",
|
||||
},
|
||||
},
|
||||
"metadata": {
|
||||
"owner": actor,
|
||||
"stability": "generated_draft",
|
||||
"source_ref": "自然语言风险规则",
|
||||
"created_at": created_at.isoformat(),
|
||||
"created_by": actor,
|
||||
"natural_language": natural_language,
|
||||
"business_explanation": self._clean_text(draft.get("description")),
|
||||
"condition_summary": condition_summary,
|
||||
"flow": draft.get("flow") if isinstance(draft.get("flow"), dict) else {},
|
||||
},
|
||||
}
|
||||
payload["flow_diagram_svg"] = self._build_flow_diagram_svg(
|
||||
payload,
|
||||
fields=[field_by_key[key] for key in field_keys if key in field_by_key],
|
||||
domain=domain,
|
||||
risk_level=risk_level,
|
||||
)
|
||||
return payload
|
||||
|
||||
def _build_flow_diagram_svg(
|
||||
self,
|
||||
payload: dict[str, Any],
|
||||
*,
|
||||
fields: list[RiskRuleField],
|
||||
domain: str,
|
||||
risk_level: str,
|
||||
) -> str:
|
||||
metadata = payload.get("metadata") if isinstance(payload.get("metadata"), dict) else {}
|
||||
flow = metadata.get("flow") if isinstance(metadata.get("flow"), dict) else {}
|
||||
condition_summary = self._clean_text(metadata.get("condition_summary"))
|
||||
return self.flow_diagram_renderer.render(
|
||||
RiskRuleFlowDiagramSpec(
|
||||
title=self._clean_text(payload.get("name")) or "风险规则判断流程",
|
||||
domain_label=BUSINESS_DOMAIN_LABELS.get(domain, "业务"),
|
||||
severity=risk_level,
|
||||
severity_label=RISK_LEVEL_LABELS.get(risk_level, "中风险"),
|
||||
fields=tuple(
|
||||
RiskRuleFlowDiagramField(key=field.key, label=field.label) for field in fields
|
||||
),
|
||||
start=self._clean_text(flow.get("start")) or "业务单据提交",
|
||||
evidence=self._clean_text(flow.get("evidence")) or "读取规则字段",
|
||||
decision=self._clean_text(flow.get("decision"))
|
||||
or condition_summary
|
||||
or "判断是否命中风险",
|
||||
basis=(
|
||||
condition_summary
|
||||
or self._clean_text(flow.get("decision"))
|
||||
or "根据规则字段判断"
|
||||
),
|
||||
pass_text=self._clean_text(flow.get("pass")) or "未命中风险,继续流转",
|
||||
fail_text=self._clean_text(flow.get("fail"))
|
||||
or f"命中{RISK_LEVEL_LABELS.get(risk_level, '风险')},进入人工复核",
|
||||
)
|
||||
)
|
||||
|
||||
def _resolve_fields(self, text: str, *, domain: str) -> list[RiskRuleField]:
|
||||
prefixes = DOMAIN_FIELD_PREFIXES.get(domain, ())
|
||||
candidates = [field for field in FIELD_ONTOLOGY if field.key.startswith(prefixes)]
|
||||
normalized = text.lower()
|
||||
matched: list[tuple[int, RiskRuleField]] = []
|
||||
for field in candidates:
|
||||
score = self._score_field_match(field, text, normalized)
|
||||
if score > 0:
|
||||
matched.append((score, field))
|
||||
|
||||
if domain == AgentAssetDomain.EXPENSE.value:
|
||||
if any(keyword in text for keyword in ("住宿", "酒店", "行程", "城市", "出差")):
|
||||
matched.extend(
|
||||
(10, field)
|
||||
for field in candidates
|
||||
if field.key
|
||||
in {"claim.location", "attachment.hotel_city", "attachment.route_cities"}
|
||||
)
|
||||
if any(keyword in text for keyword in ("发票", "票据", "品名", "抬头", "开票")):
|
||||
matched.extend(
|
||||
(6, field)
|
||||
for field in candidates
|
||||
if field.key
|
||||
in {
|
||||
"attachment.invoice_no",
|
||||
"attachment.buyer_name",
|
||||
"attachment.goods_name",
|
||||
"attachment.ocr_text",
|
||||
}
|
||||
)
|
||||
|
||||
matched.sort(key=lambda item: item[0], reverse=True)
|
||||
deduped: list[RiskRuleField] = []
|
||||
seen: set[str] = set()
|
||||
for _, field in matched:
|
||||
if field.key in seen:
|
||||
continue
|
||||
seen.add(field.key)
|
||||
deduped.append(field)
|
||||
if deduped:
|
||||
return deduped[:8]
|
||||
return candidates[:4]
|
||||
|
||||
@staticmethod
|
||||
def _score_field_match(field: RiskRuleField, text: str, normalized: str) -> int:
|
||||
score = 0
|
||||
if field.label in text:
|
||||
score += 8
|
||||
for alias in field.aliases:
|
||||
if alias.lower() in normalized:
|
||||
score += 4 + min(len(alias), 6)
|
||||
|
||||
if field.key == "attachment.hotel_city" and any(term in text for term in ("酒店", "住宿")):
|
||||
score += 12
|
||||
if field.key == "attachment.route_cities" and any(
|
||||
term in text for term in ("行程", "交通票", "路线", "途经")
|
||||
):
|
||||
score += 10
|
||||
if field.key == "claim.location" and any(
|
||||
term in text for term in ("申报目的地", "申报地点", "目的地", "出差地")
|
||||
):
|
||||
score += 10
|
||||
if field.key.startswith("attachment.") and any(term in text for term in ("发票", "票据")):
|
||||
score += 2
|
||||
return score
|
||||
|
||||
def _align_draft_fields(
|
||||
self,
|
||||
draft: dict[str, Any],
|
||||
*,
|
||||
natural_language: str,
|
||||
fields: list[RiskRuleField],
|
||||
) -> dict[str, Any]:
|
||||
field_by_key = {field.key: field for field in fields}
|
||||
original_keys = [
|
||||
str(item or "").strip()
|
||||
for item in list(draft.get("field_keys") or [])
|
||||
if str(item or "").strip() in field_by_key
|
||||
]
|
||||
preferred_keys: list[str] = []
|
||||
|
||||
def add_preferred(key: str, *terms: str) -> None:
|
||||
if key in field_by_key and any(term in natural_language for term in terms):
|
||||
preferred_keys.append(key)
|
||||
|
||||
add_preferred("attachment.hotel_city", "酒店", "住宿")
|
||||
add_preferred("claim.location", "申报目的地", "申报地点", "目的地", "出差地")
|
||||
add_preferred("attachment.route_cities", "行程", "交通票", "路线", "途经")
|
||||
|
||||
merged_keys: list[str] = []
|
||||
for key in [*preferred_keys, *original_keys, *[field.key for field in fields]]:
|
||||
if key in field_by_key and key not in merged_keys:
|
||||
merged_keys.append(key)
|
||||
if len(merged_keys) >= 4:
|
||||
break
|
||||
|
||||
if draft.get("template_key") == "field_compare_v1" and len(merged_keys) < 2:
|
||||
for field in fields:
|
||||
if field.key not in merged_keys:
|
||||
merged_keys.append(field.key)
|
||||
if len(merged_keys) >= 2:
|
||||
break
|
||||
|
||||
aligned = {**draft, "field_keys": merged_keys}
|
||||
selected_fields = [field_by_key[key] for key in merged_keys if key in field_by_key]
|
||||
if selected_fields:
|
||||
aligned["condition_summary"] = self._build_condition_summary(
|
||||
natural_language,
|
||||
template_key=str(aligned.get("template_key") or "field_required_v1"),
|
||||
fields=selected_fields,
|
||||
)
|
||||
flow = aligned.get("flow") if isinstance(aligned.get("flow"), dict) else {}
|
||||
aligned["flow"] = {
|
||||
**flow,
|
||||
"evidence": "读取" + "、".join(field.label for field in selected_fields[:3]),
|
||||
"decision": aligned["condition_summary"],
|
||||
}
|
||||
return aligned
|
||||
|
||||
@staticmethod
|
||||
def _build_compare_conditions(field_keys: list[str]) -> list[dict[str, str]]:
|
||||
if len(field_keys) >= 2:
|
||||
return [{"left": field_keys[0], "operator": "overlap", "right": field_keys[1]}]
|
||||
if field_keys:
|
||||
return [{"left": field_keys[0], "operator": "is_empty", "right": ""}]
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _infer_template_key(text: str) -> str:
|
||||
if any(
|
||||
keyword in text
|
||||
for keyword in ("一致", "匹配", "相同", "不一致", "不符", "对应", "出现在")
|
||||
):
|
||||
return "field_compare_v1"
|
||||
if any(
|
||||
keyword in text
|
||||
for keyword in ("关键词", "包含", "出现", "品名", "摘要", "服务费", "咨询费")
|
||||
):
|
||||
return "keyword_match_v1"
|
||||
return "field_required_v1"
|
||||
|
||||
@staticmethod
|
||||
def _infer_keywords(text: str) -> list[str]:
|
||||
quoted = re.findall(r"[“\"']([^“”\"']{2,20})[”\"']", text)
|
||||
keywords = [item.strip() for item in quoted if item.strip()]
|
||||
for candidate in ("咨询费", "服务费", "其他", "办公用品", "招待", "红冲", "作废"):
|
||||
if candidate in text and candidate not in keywords:
|
||||
keywords.append(candidate)
|
||||
return keywords[:8]
|
||||
|
||||
@staticmethod
|
||||
def _infer_rule_name(text: str) -> str:
|
||||
normalized = re.sub(r"\s+", "", str(text or ""))
|
||||
normalized = re.sub(r"[,。;;::、,.!?!?]", "", normalized)
|
||||
if not normalized:
|
||||
return "自然语言风险规则"
|
||||
return f"{normalized[:18]}风险规则"
|
||||
|
||||
@staticmethod
|
||||
def _build_condition_summary(
|
||||
natural_language: str,
|
||||
*,
|
||||
template_key: str,
|
||||
fields: list[RiskRuleField],
|
||||
) -> str:
|
||||
field_text = "、".join(item.label for item in fields[:3]) or "业务字段"
|
||||
if template_key == "field_compare_v1":
|
||||
return f"对比{field_text}之间是否一致或存在交集"
|
||||
if template_key == "keyword_match_v1":
|
||||
return f"检查{field_text}是否出现规则描述中的风险关键词"
|
||||
return f"检查{field_text}是否满足必填和完整性要求"
|
||||
|
||||
@staticmethod
|
||||
def _clean_text(value: Any) -> str:
|
||||
return re.sub(r"\s+", " ", str(value or "")).strip()
|
||||
|
||||
@staticmethod
|
||||
def _extract_json_object(text: str) -> str:
|
||||
normalized = re.sub(r"^```(?:json)?|```$", "", str(text or "").strip(), flags=re.IGNORECASE)
|
||||
start = normalized.find("{")
|
||||
end = normalized.rfind("}")
|
||||
if start < 0 or end <= start:
|
||||
raise ValueError("JSON object not found.")
|
||||
return normalized[start : end + 1]
|
||||
|
||||
@staticmethod
|
||||
def _build_version_markdown(payload: dict[str, Any]) -> str:
|
||||
metadata = payload.get("metadata") if isinstance(payload.get("metadata"), dict) else {}
|
||||
fields = (
|
||||
payload.get("inputs", {}).get("fields")
|
||||
if isinstance(payload.get("inputs"), dict)
|
||||
else []
|
||||
)
|
||||
field_labels = [
|
||||
str(item.get("label") or item.get("key") or "").strip()
|
||||
for item in fields
|
||||
if isinstance(item, dict) and str(item.get("label") or item.get("key") or "").strip()
|
||||
]
|
||||
return "\n".join(
|
||||
[
|
||||
f"# {payload.get('name')}",
|
||||
"",
|
||||
"## 业务说明",
|
||||
"",
|
||||
str(payload.get("description") or ""),
|
||||
"",
|
||||
"## 自然语言原文",
|
||||
"",
|
||||
str(metadata.get("natural_language") or ""),
|
||||
"",
|
||||
"## 使用字段",
|
||||
"",
|
||||
"、".join(field_labels) or "未识别字段",
|
||||
"",
|
||||
"## 运行时 JSON",
|
||||
"",
|
||||
"```json",
|
||||
json.dumps(payload, ensure_ascii=False, indent=2),
|
||||
"```",
|
||||
]
|
||||
)
|
||||
259
server/src/app/services/risk_rule_template_executor.py
Normal file
259
server/src/app/services/risk_rule_template_executor.py
Normal file
@@ -0,0 +1,259 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.models.financial_record import ExpenseClaim
|
||||
|
||||
|
||||
class RiskRuleTemplateExecutor:
|
||||
def evaluate(
|
||||
self,
|
||||
manifest: dict[str, Any],
|
||||
*,
|
||||
claim: ExpenseClaim,
|
||||
contexts: list[dict[str, Any]],
|
||||
) -> dict[str, Any] | None:
|
||||
params = manifest.get("params") if isinstance(manifest.get("params"), dict) else {}
|
||||
template_key = str(manifest.get("template_key") or params.get("template_key") or "").strip()
|
||||
|
||||
if template_key == "field_required_v1":
|
||||
return self._evaluate_required_fields(params, claim=claim, contexts=contexts)
|
||||
if template_key == "field_compare_v1":
|
||||
return self._evaluate_compare_conditions(params, claim=claim, contexts=contexts)
|
||||
if template_key == "keyword_match_v1":
|
||||
return self._evaluate_keyword_match(params, claim=claim, contexts=contexts)
|
||||
return None
|
||||
|
||||
def _evaluate_required_fields(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
*,
|
||||
claim: ExpenseClaim,
|
||||
contexts: list[dict[str, Any]],
|
||||
) -> dict[str, Any] | None:
|
||||
required_fields = self._read_string_list(
|
||||
params.get("required_fields") or params.get("field_keys")
|
||||
)
|
||||
missing = [
|
||||
field_key
|
||||
for field_key in required_fields
|
||||
if not self._has_resolved_value(field_key, claim=claim, contexts=contexts)
|
||||
]
|
||||
if not missing:
|
||||
return None
|
||||
return {
|
||||
"message": self._resolve_message(
|
||||
params,
|
||||
fallback=f"规则要求的字段未完整提供:{'、'.join(missing[:4])}。",
|
||||
),
|
||||
"evidence": {
|
||||
"missing_fields": missing,
|
||||
"condition_summary": params.get("condition_summary"),
|
||||
},
|
||||
}
|
||||
|
||||
def _evaluate_compare_conditions(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
*,
|
||||
claim: ExpenseClaim,
|
||||
contexts: list[dict[str, Any]],
|
||||
) -> dict[str, Any] | None:
|
||||
conditions = params.get("conditions") if isinstance(params.get("conditions"), list) else []
|
||||
failures: list[dict[str, Any]] = []
|
||||
for condition in conditions:
|
||||
if not isinstance(condition, dict):
|
||||
continue
|
||||
left_key = str(condition.get("left") or "").strip()
|
||||
right_key = str(condition.get("right") or "").strip()
|
||||
operator = str(condition.get("operator") or "not_overlap").strip()
|
||||
left_values = self._resolve_values(left_key, claim=claim, contexts=contexts)
|
||||
right_values = self._resolve_values(right_key, claim=claim, contexts=contexts)
|
||||
if self._condition_passes(operator, left_values, right_values):
|
||||
continue
|
||||
failures.append(
|
||||
{
|
||||
"left": left_key,
|
||||
"operator": operator,
|
||||
"right": right_key,
|
||||
"left_values": left_values[:5],
|
||||
"right_values": right_values[:5],
|
||||
}
|
||||
)
|
||||
|
||||
if not failures:
|
||||
return None
|
||||
return {
|
||||
"message": self._resolve_message(
|
||||
params,
|
||||
fallback=(
|
||||
"规则字段对比未通过:"
|
||||
f"{params.get('condition_summary') or '字段关系不符合要求'}。"
|
||||
),
|
||||
),
|
||||
"evidence": {
|
||||
"failed_conditions": failures[:5],
|
||||
"condition_summary": params.get("condition_summary"),
|
||||
},
|
||||
}
|
||||
|
||||
def _evaluate_keyword_match(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
*,
|
||||
claim: ExpenseClaim,
|
||||
contexts: list[dict[str, Any]],
|
||||
) -> dict[str, Any] | None:
|
||||
keywords = self._read_string_list(params.get("keywords"))
|
||||
search_fields = self._read_string_list(
|
||||
params.get("search_fields") or params.get("field_keys")
|
||||
)
|
||||
if not keywords:
|
||||
return None
|
||||
|
||||
corpus_parts: list[str] = []
|
||||
for field_key in search_fields:
|
||||
corpus_parts.extend(self._resolve_values(field_key, claim=claim, contexts=contexts))
|
||||
if not corpus_parts:
|
||||
corpus_parts.extend(
|
||||
[
|
||||
str(claim.reason or ""),
|
||||
str(claim.location or ""),
|
||||
*[str(item.item_reason or "") for item in list(claim.items or [])],
|
||||
*[str(context.get("ocr_text") or "") for context in contexts],
|
||||
]
|
||||
)
|
||||
corpus = "\n".join(corpus_parts)
|
||||
hits = [keyword for keyword in keywords if keyword and keyword in corpus]
|
||||
if not hits:
|
||||
return None
|
||||
return {
|
||||
"message": self._resolve_message(
|
||||
params,
|
||||
fallback=f"识别到风险关键词:{'、'.join(hits[:5])}。",
|
||||
),
|
||||
"evidence": {
|
||||
"keyword_hits": hits[:8],
|
||||
"search_fields": search_fields,
|
||||
"condition_summary": params.get("condition_summary"),
|
||||
},
|
||||
}
|
||||
|
||||
def _resolve_values(
|
||||
self,
|
||||
field_key: str,
|
||||
*,
|
||||
claim: ExpenseClaim,
|
||||
contexts: list[dict[str, Any]],
|
||||
) -> list[str]:
|
||||
normalized = str(field_key or "").strip()
|
||||
if not normalized:
|
||||
return []
|
||||
if normalized.startswith("claim."):
|
||||
return self._normalize_values([getattr(claim, normalized.removeprefix("claim."), "")])
|
||||
if normalized.startswith("item."):
|
||||
attr = normalized.removeprefix("item.")
|
||||
return self._normalize_values(
|
||||
[getattr(item, attr, "") for item in list(claim.items or [])]
|
||||
)
|
||||
if normalized.startswith("attachment."):
|
||||
return self._resolve_attachment_values(normalized.removeprefix("attachment."), contexts)
|
||||
return []
|
||||
|
||||
def _resolve_attachment_values(
|
||||
self, field_key: str, contexts: list[dict[str, Any]]
|
||||
) -> list[str]:
|
||||
values: list[Any] = []
|
||||
for context in contexts:
|
||||
document_info = context.get("document_info") if isinstance(context, dict) else {}
|
||||
if not isinstance(document_info, dict):
|
||||
document_info = {}
|
||||
if field_key == "ocr_text":
|
||||
values.extend([context.get("ocr_text"), context.get("ocr_summary")])
|
||||
if field_key in {"hotel_city", "route_cities"}:
|
||||
values.extend(self._scan_document_values(document_info, field_key))
|
||||
values.extend(self._scan_document_values(document_info, "city"))
|
||||
else:
|
||||
values.extend(self._scan_document_values(document_info, field_key))
|
||||
return self._normalize_values(values)
|
||||
|
||||
def _scan_document_values(self, document_info: dict[str, Any], field_key: str) -> list[Any]:
|
||||
values: list[Any] = []
|
||||
for key in {field_key, field_key.replace("_", ""), field_key.replace("_", "-")}:
|
||||
if key in document_info:
|
||||
values.append(document_info.get(key))
|
||||
for field in list(document_info.get("fields") or []):
|
||||
if not isinstance(field, dict):
|
||||
continue
|
||||
key = str(field.get("key") or "").strip().lower()
|
||||
label = str(field.get("label") or "").strip()
|
||||
if self._field_matches(key, label, field_key):
|
||||
values.append(field.get("value"))
|
||||
return values
|
||||
|
||||
@staticmethod
|
||||
def _field_matches(key: str, label: str, field_key: str) -> bool:
|
||||
compact_key = key.replace("_", "")
|
||||
compact_target = field_key.replace("_", "")
|
||||
if compact_target in compact_key:
|
||||
return True
|
||||
label_map = {
|
||||
"invoice_no": ("发票号", "发票号码", "票号"),
|
||||
"buyer_name": ("购买方", "抬头", "买方"),
|
||||
"goods_name": ("品名", "商品", "服务名称"),
|
||||
"issue_date": ("日期", "开票日期", "发票日期"),
|
||||
"hotel_city": ("住宿城市", "酒店城市", "酒店地点"),
|
||||
"route_cities": ("行程", "路线", "城市"),
|
||||
"city": ("城市", "地点"),
|
||||
}
|
||||
return any(item in label for item in label_map.get(field_key, ()))
|
||||
|
||||
def _has_resolved_value(
|
||||
self,
|
||||
field_key: str,
|
||||
*,
|
||||
claim: ExpenseClaim,
|
||||
contexts: list[dict[str, Any]],
|
||||
) -> bool:
|
||||
return bool(self._resolve_values(field_key, claim=claim, contexts=contexts))
|
||||
|
||||
@staticmethod
|
||||
def _condition_passes(operator: str, left_values: list[str], right_values: list[str]) -> bool:
|
||||
if operator == "is_empty":
|
||||
return not left_values
|
||||
if not left_values or not right_values:
|
||||
return False
|
||||
|
||||
left_set = {value.lower() for value in left_values}
|
||||
right_set = {value.lower() for value in right_values}
|
||||
if operator in {"equals", "in", "overlap"}:
|
||||
return bool(left_set & right_set)
|
||||
if operator in {"not_equals", "not_in", "not_overlap"}:
|
||||
return not bool(left_set & right_set)
|
||||
if operator == "contains_any":
|
||||
return any(any(right in left for right in right_set) for left in left_set)
|
||||
return bool(left_set & right_set)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_values(values: list[Any]) -> list[str]:
|
||||
normalized: list[str] = []
|
||||
for value in values:
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
normalized.extend(RiskRuleTemplateExecutor._normalize_values(list(value)))
|
||||
continue
|
||||
text = re.sub(r"\s+", " ", str(value or "")).strip()
|
||||
if text and text not in normalized:
|
||||
normalized.append(text)
|
||||
return normalized
|
||||
|
||||
@staticmethod
|
||||
def _read_string_list(value: Any) -> list[str]:
|
||||
if not isinstance(value, list):
|
||||
return []
|
||||
return [str(item or "").strip() for item in value if str(item or "").strip()]
|
||||
|
||||
@staticmethod
|
||||
def _resolve_message(params: dict[str, Any], *, fallback: str) -> str:
|
||||
template = str(params.get("message_template") or "").strip()
|
||||
return template or fallback
|
||||
Reference in New Issue
Block a user