feat: 增强规则资产管理与审计页面运行时调试

后端新增规则资产版本管理和规则文件 CRUD 接口,优化风险
规则生成模板执行和员工数据模型字段,知识库 RAG 增强本
地回退和文档提取能力,清理旧风险规则文件统一由生成引擎
管理,前端审计页面增加运行时调试面板和规则资产编辑交互,
补充单元测试覆盖。
This commit is contained in:
caoxiaozhu
2026-05-24 21:44:17 +08:00
parent 575f093c74
commit 50b1c3f9a9
113 changed files with 13896 additions and 5044 deletions

View File

@@ -0,0 +1,552 @@
from __future__ import annotations
import re
from datetime import UTC, date, datetime
from typing import Any
from app.schemas.agent_asset import (
AgentAssetRiskRuleSimulationAttachment,
AgentAssetRiskRuleSimulationRead,
AgentAssetRiskRuleSimulationRequest,
)
from app.services.risk_rule_template_executor import RiskRuleTemplateExecutor
class AgentAssetRiskRuleSimulationMixin:
def simulate_risk_rule_message(
self,
asset_id: str,
body: AgentAssetRiskRuleSimulationRequest,
) -> AgentAssetRiskRuleSimulationRead:
_, version, manifest = self._load_risk_rule_for_test(asset_id, body.version)
attachments = self._normalize_simulation_attachments(body.attachments)
field_values, source_map, recognized_fields = self._build_simulation_field_values(
manifest,
message=body.message,
explicit_values=body.field_values,
attachments=attachments,
)
recognition_summary = self._build_recognition_summary(attachments)
required_keys = self._extract_execution_field_keys(manifest)
missing_fields = self._build_missing_fields(
manifest,
field_values=field_values,
source_map=source_map,
required_keys=required_keys,
)
block = self._resolve_simulation_block(
manifest,
message=body.message,
attachments=attachments,
missing_fields=missing_fields,
)
if block:
return AgentAssetRiskRuleSimulationRead(
version=version,
ready=False,
stage=block["stage"],
hit=False,
severity="none",
severity_label="待补充",
summary=block["summary"],
blocking_reason=block["reason"],
field_values=field_values,
attachments=attachments,
recognized_fields=recognized_fields,
missing_fields=missing_fields,
recognition_summary=recognition_summary,
created_at=datetime.now(UTC),
)
claim, contexts = self._build_synthetic_claim(field_values, manifest)
result = RiskRuleTemplateExecutor().evaluate(manifest, claim=claim, contexts=contexts)
hit = result is not None
severity = (
str((manifest.get("outcomes") or {}).get("fail", {}).get("severity") or "medium")
if hit
else "none"
)
severity_label = self._risk_severity_label(severity)
message = str(result.get("message") or "") if isinstance(result, dict) else ""
summary = (
f"本次仿真命中{severity_label},仅生成风险识别结果,不创建业务单据。"
if hit
else "本次仿真未命中风险,仅完成规则识别,不创建业务单据。"
)
evidence = result.get("evidence") if isinstance(result, dict) else {}
return AgentAssetRiskRuleSimulationRead(
version=version,
ready=True,
stage="executed",
hit=hit,
severity=severity,
severity_label=severity_label,
summary=summary,
message=message,
field_values=field_values,
evidence=evidence if isinstance(evidence, dict) else {},
attachments=attachments,
recognized_fields=recognized_fields,
missing_fields=[],
recognition_summary=recognition_summary,
created_at=datetime.now(UTC),
)
def _build_simulation_field_values(
self,
manifest: dict[str, Any],
*,
message: str,
explicit_values: dict[str, Any],
attachments: list[dict[str, Any]],
) -> tuple[dict[str, Any], dict[str, str], list[dict[str, Any]]]:
fields = self._extract_manifest_fields(manifest)
values: dict[str, Any] = {}
source_map: dict[str, str] = {}
safe_explicit_values = explicit_values if isinstance(explicit_values, dict) else {}
corpus = self._build_simulation_corpus(message, attachments)
city_mentions = self._extract_city_mentions(corpus)
for field in fields:
key = field["key"]
explicit_value = safe_explicit_values.get(key)
if self._has_meaningful_value(explicit_value):
values[key] = explicit_value
source_map[key] = "manual"
continue
attachment_value = self._find_attachment_field_value(
key,
field.get("label") or key,
attachments,
)
if self._has_meaningful_value(attachment_value):
values[key] = attachment_value
source_map[key] = "ocr"
continue
inferred = self._infer_simulation_value(
key,
field.get("label") or key,
corpus=corpus,
city_mentions=city_mentions,
)
if self._has_meaningful_value(inferred):
values[key] = inferred
source_map[key] = "inferred"
self._apply_compare_city_hints(manifest, values, source_map, city_mentions)
recognized_fields = self._build_recognized_fields(fields, values, source_map)
return values, source_map, recognized_fields
def _infer_simulation_value(
self,
field_key: str,
label: str,
*,
corpus: str,
city_mentions: list[str],
) -> Any:
key_text = f"{field_key} {label}".lower()
if field_key.endswith("route_cities"):
return city_mentions or []
if "city" in field_key or "location" in field_key:
if any(
token in key_text
for token in ("hotel", "invoice", "attachment", "发票", "酒店", "住宿")
):
return city_mentions[0] if city_mentions else ""
if any(token in key_text for token in ("route", "trip", "目的", "行程", "申报")):
return (
city_mentions[1]
if len(city_mentions) > 1
else (city_mentions[0] if city_mentions else "")
)
return city_mentions[0] if city_mentions else ""
if field_key.endswith("amount"):
return self._extract_amount(corpus)
if field_key.endswith("issue_date") or field_key.endswith("item_date"):
return self._extract_iso_date(corpus)
if field_key.endswith("invoice_no"):
return self._extract_invoice_no(corpus)
if field_key.endswith("ocr_text"):
return corpus
if field_key.endswith("goods_name"):
return self._infer_goods_name(corpus)
if field_key.endswith("item_type"):
return self._infer_item_type(corpus)
if field_key.endswith("reason") or field_key.endswith("item_reason"):
return corpus or "仿真测试报销事由"
return None
def _apply_compare_city_hints(
self,
manifest: dict[str, Any],
values: dict[str, Any],
source_map: dict[str, str],
city_mentions: list[str],
) -> None:
if len(city_mentions) < 2:
return
params = manifest.get("params") if isinstance(manifest.get("params"), dict) else {}
conditions = params.get("conditions") if isinstance(params.get("conditions"), list) else []
for condition in conditions:
if not isinstance(condition, dict):
continue
left = str(condition.get("left") or "").strip()
right = str(condition.get("right") or "").strip()
if not left or not right:
continue
if self._looks_like_city_field(left):
values[left] = city_mentions[0]
source_map[left] = source_map.get(left) or "inferred"
if self._looks_like_city_field(right):
values[right] = city_mentions[1]
source_map[right] = source_map.get(right) or "inferred"
@staticmethod
def _normalize_simulation_attachments(
attachments: list[AgentAssetRiskRuleSimulationAttachment],
) -> list[dict[str, Any]]:
normalized: list[dict[str, Any]] = []
for item in list(attachments or [])[:12]:
normalized.append(
{
"name": str(item.name or "").strip(),
"content_type": str(item.content_type or "").strip(),
"size": item.size or 0,
"note": str(item.note or "").strip(),
"ocr_text": str(item.ocr_text or "").strip(),
"summary": str(item.summary or "").strip(),
"document_type": str(item.document_type or "").strip(),
"document_type_label": str(item.document_type_label or "").strip(),
"scene_code": str(item.scene_code or "").strip(),
"scene_label": str(item.scene_label or "").strip(),
"avg_score": float(item.avg_score or 0.0),
"recognition_status": str(item.recognition_status or "").strip(),
"document_fields": AgentAssetRiskRuleSimulationMixin._normalize_document_fields(
item.document_fields
),
}
)
return normalized
@staticmethod
def _build_simulation_corpus(message: str, attachments: list[dict[str, Any]]) -> str:
parts = [str(message or "").strip()]
for item in attachments:
parts.append(str(item.get("name") or "").strip())
parts.append(str(item.get("note") or "").strip())
parts.append(str(item.get("summary") or "").strip())
parts.append(str(item.get("ocr_text") or "").strip())
for field in list(item.get("document_fields") or []):
if isinstance(field, dict):
parts.append(str(field.get("value") or "").strip())
return "\n".join(part for part in parts if part)
@staticmethod
def _normalize_document_fields(fields: list[dict[str, Any]]) -> list[dict[str, Any]]:
normalized: list[dict[str, Any]] = []
for field in list(fields or [])[:30]:
if not isinstance(field, dict):
continue
key = str(field.get("key") or "").strip()
label = str(field.get("label") or "").strip()
value = field.get("value")
if key and label and AgentAssetRiskRuleSimulationMixin._has_meaningful_value(value):
normalized.append({"key": key, "label": label, "value": value})
return normalized
def _find_attachment_field_value(
self,
field_key: str,
label: str,
attachments: list[dict[str, Any]],
) -> Any:
short_key = field_key.removeprefix("attachment.")
for attachment in attachments:
if short_key == "ocr_text":
value = attachment.get("ocr_text") or attachment.get("summary")
if self._has_meaningful_value(value):
return value
for field in list(attachment.get("document_fields") or []):
if not isinstance(field, dict):
continue
candidate_key = str(field.get("key") or "").strip().lower()
candidate_label = str(field.get("label") or "").strip()
if self._field_matches_simulation_key(
candidate_key, candidate_label, short_key, label
):
return field.get("value")
return None
@staticmethod
def _field_matches_simulation_key(
candidate_key: str,
candidate_label: str,
short_key: str,
target_label: str,
) -> bool:
compact_candidate = candidate_key.replace("_", "")
compact_target = short_key.replace("_", "").lower()
if compact_target and compact_target in compact_candidate:
return True
label_text = f"{candidate_label} {target_label}"
label_map = {
"invoice_no": ("发票号", "发票号码", "票号"),
"hotel_city": ("住宿城市", "酒店城市", "酒店地点", "住宿", "酒店"),
"route_cities": ("行程", "路线", "目的地", "出差城市"),
"goods_name": ("品名", "商品", "服务名称"),
"amount": ("金额", "价税合计", "合计"),
"issue_date": ("日期", "开票日期", "发票日期"),
}
return any(token in label_text for token in label_map.get(short_key, ()))
def _extract_execution_field_keys(self, manifest: dict[str, Any]) -> list[str]:
params = manifest.get("params") if isinstance(manifest.get("params"), dict) else {}
template_key = str(manifest.get("template_key") or params.get("template_key") or "").strip()
keys: list[str] = []
if template_key == "field_compare_v1":
conditions = (
params.get("conditions") if isinstance(params.get("conditions"), list) else []
)
for condition in conditions:
if not isinstance(condition, dict):
continue
for side in ("left", "right"):
key = str(condition.get(side) or "").strip()
if key and key not in keys:
keys.append(key)
elif template_key == "keyword_match_v1":
for key in self._read_string_list(
params.get("search_fields") or params.get("field_keys")
):
if key not in keys:
keys.append(key)
elif template_key == "field_required_v1":
return []
return keys
def _build_missing_fields(
self,
manifest: dict[str, Any],
*,
field_values: dict[str, Any],
source_map: dict[str, str],
required_keys: list[str],
) -> list[dict[str, Any]]:
labels = {field["key"]: field["label"] for field in self._extract_manifest_fields(manifest)}
missing: list[dict[str, Any]] = []
for key in required_keys:
value = field_values.get(key)
if key not in source_map or not self._has_meaningful_value(value):
missing.append({"key": key, "label": labels.get(key, key)})
return missing
def _resolve_simulation_block(
self,
manifest: dict[str, Any],
*,
message: str,
attachments: list[dict[str, Any]],
missing_fields: list[dict[str, Any]],
) -> dict[str, str] | None:
has_attachment = bool(attachments)
requires_attachment = self._rule_requires_attachment(manifest)
has_recognition = any(
self._has_meaningful_value(item.get("ocr_text"))
or self._has_meaningful_value(item.get("summary"))
or self._has_meaningful_value(item.get("document_fields"))
for item in attachments
)
has_user_evidence = self._has_meaningful_user_message(message)
if requires_attachment and not has_attachment:
return {
"stage": "needs_attachment",
"summary": "当前规则要求上传附件,暂不能仅凭文字执行风险判断。",
"reason": "请上传测试单据,并填写本次测试意图后再执行仿真。",
}
if requires_attachment and not has_user_evidence:
return {
"stage": "needs_test_intent",
"summary": "当前规则要求附件和测试说明一起进入仿真判断。",
"reason": "请补充本次测试意图或关键业务事实,再执行风险识别。",
}
if has_attachment and not has_recognition and not has_user_evidence:
return {
"stage": "needs_recognition",
"summary": "单据尚未完成识别,暂不能执行风险规则。",
"reason": "请先完成 OCR 识别,或在对话中补充票据城市、金额、发票号等关键信息。",
}
template_key = str(
manifest.get("template_key") or (manifest.get("params") or {}).get("template_key") or ""
).strip()
if template_key != "field_required_v1" and missing_fields:
labels = "".join(
str(item.get("label") or item.get("key")) for item in missing_fields[:4]
)
return {
"stage": "needs_field_confirmation",
"summary": f"还缺少规则执行所需字段:{labels},暂不能判断是否命中。",
"reason": "请补充缺失字段,或上传可识别出这些字段的票据后再执行。",
}
if not has_attachment and not has_user_evidence:
return {
"stage": "needs_input",
"summary": "请先描述测试单据或上传票据,再执行风险识别。",
"reason": "当前没有可用于规则判断的业务事实。",
}
return None
@staticmethod
def _rule_requires_attachment(manifest: dict[str, Any]) -> bool:
if bool(manifest.get("requires_attachment")):
return True
metadata = manifest.get("metadata") if isinstance(manifest.get("metadata"), dict) else {}
return bool(metadata.get("requires_attachment"))
@staticmethod
def _has_meaningful_user_message(message: str) -> bool:
text = str(message or "").strip()
if not text:
return False
generic_prompts = (
"请识别我上传的临时单据是否命中这条风险规则",
"请识别上传单据是否命中风险规则",
)
return not any(prompt in text for prompt in generic_prompts)
@staticmethod
def _build_recognized_fields(
fields: list[dict[str, str]],
values: dict[str, Any],
source_map: dict[str, str],
) -> list[dict[str, Any]]:
labels = {field["key"]: field["label"] for field in fields}
return [
{
"key": key,
"label": labels.get(key, key),
"value": value,
"source": source_map.get(key, ""),
}
for key, value in values.items()
if source_map.get(key)
]
@staticmethod
def _build_recognition_summary(attachments: list[dict[str, Any]]) -> list[dict[str, Any]]:
return [
{
"name": item.get("name") or "",
"status": item.get("recognition_status")
or (
"recognized"
if item.get("ocr_text") or item.get("document_fields")
else "pending"
),
"document_type_label": item.get("document_type_label") or "",
"scene_label": item.get("scene_label") or "",
"summary": item.get("summary") or "",
"field_count": len(list(item.get("document_fields") or [])),
"avg_score": item.get("avg_score") or 0.0,
}
for item in attachments
]
@staticmethod
def _extract_city_mentions(text: str) -> list[str]:
city_names = [
"北京",
"上海",
"广州",
"深圳",
"杭州",
"南京",
"成都",
"武汉",
"重庆",
"天津",
"苏州",
"西安",
]
pattern = "|".join(re.escape(city) for city in city_names)
found: list[str] = []
for match in re.finditer(pattern, text):
city = match.group(0)
if city not in found:
found.append(city)
return found
@staticmethod
def _extract_amount(text: str) -> str:
match = re.search(r"(\d{2,8}(?:\.\d{1,2})?)\s*(?:元|块|人民币|CNY)?", text, re.IGNORECASE)
return match.group(1) if match else ""
@staticmethod
def _extract_iso_date(text: str) -> str:
match = re.search(r"(20\d{2})[-/.年](\d{1,2})[-/.月](\d{1,2})", text)
if not match:
return ""
year, month, day = (int(part) for part in match.groups())
try:
return date(year, month, day).isoformat()
except ValueError:
return ""
@staticmethod
def _extract_invoice_no(text: str) -> str:
match = re.search(r"(?:发票号|发票号码|票号)[:\s]*([A-Z0-9-]{6,32})", text, re.IGNORECASE)
return match.group(1) if match else ""
@staticmethod
def _infer_item_type(text: str) -> str:
if not text:
return ""
if any(keyword in text for keyword in ("酒店", "住宿", "宾馆")):
return "住宿费"
if any(keyword in text for keyword in ("机票", "航班", "火车", "高铁", "打车")):
return "交通费"
if any(keyword in text for keyword in ("餐饮", "餐费", "招待")):
return "餐饮费"
return "差旅费"
@staticmethod
def _infer_goods_name(text: str) -> str:
if not text:
return ""
if any(keyword in text for keyword in ("酒店", "住宿", "宾馆")):
return "住宿服务"
if any(keyword in text for keyword in ("机票", "航班", "火车", "高铁", "打车")):
return "交通服务"
if any(keyword in text for keyword in ("餐饮", "餐费", "招待")):
return "餐饮服务"
return "报销服务"
@staticmethod
def _looks_like_city_field(field_key: str) -> bool:
lowered = field_key.lower()
return "city" in lowered or "location" in lowered or lowered.endswith("route_cities")
@staticmethod
def _has_meaningful_value(value: Any) -> bool:
if value is None:
return False
if isinstance(value, str):
return bool(value.strip())
if isinstance(value, (list, tuple, set, dict)):
return bool(value)
return True
@staticmethod
def _risk_severity_label(severity: str) -> str:
return {
"low": "低风险",
"medium": "中风险",
"high": "高风险",
"none": "未命中",
}.get(str(severity or "").strip().lower(), "风险")
@staticmethod
def _read_string_list(value: Any) -> list[str]:
if not isinstance(value, list):
return []
return [str(item or "").strip() for item in value if str(item or "").strip()]

View File

@@ -0,0 +1,723 @@
from __future__ import annotations
import re
from datetime import UTC, date, datetime, timedelta
from decimal import Decimal, InvalidOperation
from typing import Any
from sqlalchemy import or_, select
from app.core.agent_enums import (
AgentAssetDomain,
AgentAssetStatus,
AgentAssetType,
AgentReviewStatus,
)
from app.models.agent_asset import AgentAsset, AgentAssetReview, AgentAssetTestRun
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
from app.schemas.agent_asset import (
AgentAssetRiskRuleLatestTestSummary,
AgentAssetRiskRuleReportRequest,
AgentAssetRiskRuleSampleCase,
AgentAssetRiskRuleSampleTestRequest,
AgentAssetRiskRuleScenarioTestRequest,
AgentAssetRiskRuleTestRunRead,
)
from app.services.expense_claims import ExpenseClaimService
from app.services.risk_rule_template_executor import RiskRuleTemplateExecutor
class AgentAssetRiskRuleTestingMixin:
def get_latest_risk_rule_test_summary(
self,
asset_or_id: AgentAsset | str,
*,
version: str | None = None,
) -> AgentAssetRiskRuleLatestTestSummary:
asset = self._resolve_asset(asset_or_id)
target_version = self._resolve_target_version(asset, version)
sample = self.repository.get_latest_test_run(
asset.id, version=target_version, test_type="sample"
)
scenario = self.repository.get_latest_test_run(
asset.id, version=target_version, test_type="scenario"
)
report = self.repository.get_latest_test_run(
asset.id, version=target_version, test_type="report", status="passed"
)
return AgentAssetRiskRuleLatestTestSummary(
version=target_version,
sample=self._serialize_test_run(sample),
scenario=self._serialize_test_run(scenario),
report=self._serialize_test_run(report),
test_passed=bool(report and report.passed),
)
def run_risk_rule_sample_test(
self,
asset_id: str,
body: AgentAssetRiskRuleSampleTestRequest,
*,
actor: str,
request_id: str | None = None,
) -> AgentAssetRiskRuleTestRunRead:
asset, version, manifest = self._load_risk_rule_for_test(asset_id, body.version)
cases = body.cases or self._build_default_sample_cases(manifest)
results = [self._run_sample_case(manifest, case) for case in cases]
passed = bool(results) and all(item["passed"] for item in results)
summary = f"快速样例测试 {'通过' if passed else '未通过'},共 {len(results)} 条。"
return self._create_test_run(
asset,
version=version,
test_type="sample",
passed=passed,
summary=summary,
input_json={"cases": [case.model_dump() for case in cases]},
result_json={"cases": results, "case_count": len(results)},
actor=actor,
request_id=request_id,
)
def run_risk_rule_scenario_test(
self,
asset_id: str,
body: AgentAssetRiskRuleScenarioTestRequest,
*,
actor: str,
request_id: str | None = None,
) -> AgentAssetRiskRuleTestRunRead:
asset, version, manifest = self._load_risk_rule_for_test(asset_id, body.version)
if asset.domain != AgentAssetDomain.EXPENSE.value:
raise ValueError("一期真实场景试运行仅支持报销业务域。")
parsed_scope = self._parse_scenario_scope(body.intent, body.filters)
claims = self._query_expense_claim_samples(parsed_scope)
claim_results = [self._run_claim_scenario(manifest, claim) for claim in claims]
hit_items = [item for item in claim_results if item["hit"]]
severity_counts: dict[str, int] = {}
for item in hit_items:
severity = str(item.get("severity") or "unknown")
severity_counts[severity] = severity_counts.get(severity, 0) + 1
passed = bool(claim_results)
summary = (
f"真实场景试运行完成,样本 {len(claim_results)} 条,命中 {len(hit_items)} 条。"
if passed
else "真实场景试运行未找到可测样本。"
)
return self._create_test_run(
asset,
version=version,
test_type="scenario",
passed=passed,
summary=summary,
input_json={
"intent": body.intent,
"filters": body.filters,
"parsed_scope": parsed_scope,
},
result_json={
"total_count": len(claim_results),
"hit_count": len(hit_items),
"severity_counts": severity_counts,
"items": claim_results[:50],
},
actor=actor,
request_id=request_id,
)
def confirm_risk_rule_test_report(
self,
asset_id: str,
body: AgentAssetRiskRuleReportRequest,
*,
actor: str,
request_id: str | None = None,
) -> AgentAssetRiskRuleTestRunRead:
asset, version, _ = self._load_risk_rule_for_test(asset_id, body.version)
sample = self.repository.get_latest_test_run(
asset.id, version=version, test_type="sample", status="passed"
)
scenario = self.repository.get_latest_test_run(
asset.id, version=version, test_type="scenario"
)
if sample is None:
raise ValueError("提交审核前必须先完成快速样例测试。")
if not body.confirm_passed:
raise ValueError("请确认测试通过后再保存测试报告。")
summary = "测试报告已确认,当前版本可提交审核。"
if scenario is None:
summary = "快速样例测试已确认通过,真实场景试运行未执行。"
elif not scenario.passed:
summary = "快速样例测试已确认通过,真实场景试运行未找到可测样本。"
return self._create_test_run(
asset,
version=version,
test_type="report",
passed=True,
summary=summary,
input_json={"confirm_passed": True, "note": body.note or ""},
result_json={
"sample_test_run_id": sample.id,
"scenario_test_run_id": scenario.id,
"sample_summary": sample.summary,
"scenario_summary": scenario.summary,
},
actor=actor,
request_id=request_id,
)
def delete_unpublished_asset(
self,
asset_id: str,
*,
actor: str,
request_id: str | None = None,
) -> None:
asset = self._resolve_asset(asset_id)
self._require_json_risk_asset(asset)
if str(asset.published_version or "").strip():
raise PermissionError("已发布过的风险规则不能删除。")
before = self._asset_snapshot(asset)
self._delete_risk_rule_json_file(asset)
self.repository.delete_asset(asset)
self.audit_service.log_action(
actor=actor,
action="delete_agent_asset",
resource_type=AgentAssetType.RULE.value,
resource_id=asset_id,
before_json=before,
after_json={"deleted": True},
request_id=request_id,
)
def return_risk_rule(
self,
asset_id: str,
*,
note: str,
actor: str,
request_id: str | None = None,
) -> AgentAssetRiskRuleLatestTestSummary:
asset = self._resolve_asset(asset_id)
self._require_json_risk_asset(asset)
version = self._resolve_target_version(asset, None)
if asset.status != AgentAssetStatus.REVIEW.value:
raise ValueError("只有待审核风险规则可以回退。")
before = self._asset_snapshot(asset)
review = AgentAssetReview(
asset_id=asset.id,
version=version,
reviewer=actor,
review_status=AgentReviewStatus.REJECTED.value,
review_note=str(note or "审核回退").strip() or "审核回退",
reviewed_at=datetime.now(UTC),
)
self.db.add(review)
asset.reviewer = actor
asset.status = AgentAssetStatus.DRAFT.value
self.db.add(asset)
self.db.commit()
self.audit_service.log_action(
actor=actor,
action="return_agent_asset",
resource_type=AgentAssetType.RULE.value,
resource_id=asset.id,
before_json=before,
after_json={"version": version, "status": asset.status, "note": note},
request_id=request_id,
)
return self.get_latest_risk_rule_test_summary(asset)
def publish_risk_rule(
self,
asset_id: str,
*,
actor: str,
request_id: str | None = None,
) -> AgentAsset:
asset = self._resolve_asset(asset_id)
self._require_json_risk_asset(asset)
version = self._resolve_target_version(asset, None)
if asset.status != AgentAssetStatus.REVIEW.value:
raise ValueError("只有待审核风险规则可以发布上线。")
if not self.get_latest_risk_rule_test_summary(asset, version=version).test_passed:
raise PermissionError("当前规则版本尚未完成测试通过确认,不能发布。")
before = self._asset_snapshot(asset)
approved_review = self.repository.get_review(
asset.id, version, AgentReviewStatus.APPROVED.value
)
if approved_review is None:
self.db.add(
AgentAssetReview(
asset_id=asset.id,
version=version,
reviewer=actor,
review_status=AgentReviewStatus.APPROVED.value,
review_note="发布上线前审核通过。",
reviewed_at=datetime.now(UTC),
)
)
asset.reviewer = actor
asset.published_version = version
asset.status = AgentAssetStatus.ACTIVE.value
self.db.add(asset)
self.db.commit()
self.audit_service.log_action(
actor=actor,
action="publish_agent_asset",
resource_type=AgentAssetType.RULE.value,
resource_id=asset.id,
before_json=before,
after_json=self._asset_snapshot(asset),
request_id=request_id,
)
refreshed = self.repository.get(asset.id)
if refreshed is None:
raise LookupError("Asset not found")
return refreshed
def set_risk_rule_enabled(
self,
asset_id: str,
*,
enabled: bool,
actor: str,
request_id: str | None = None,
) -> AgentAsset:
asset = self._resolve_asset(asset_id)
self._require_json_risk_asset(asset)
before = self._asset_snapshot(asset)
rule_library, file_name = self._resolve_json_risk_rule_document(asset)
manifest = self.rule_library_manager.read_rule_library_json(
library=rule_library,
file_name=file_name,
)
manifest["enabled"] = bool(enabled)
self.rule_library_manager.write_rule_library_json(
library=rule_library,
file_name=file_name,
payload=manifest,
)
config_json = dict(asset.config_json or {})
config_json["enabled"] = bool(enabled)
asset.config_json = config_json
updated = self.repository.save_asset(asset)
self.audit_service.log_action(
actor=actor,
action="set_risk_rule_enabled",
resource_type=AgentAssetType.RULE.value,
resource_id=asset.id,
before_json=before,
after_json=self._asset_snapshot(updated),
request_id=request_id,
)
return updated
def _load_risk_rule_for_test(
self, asset_id: str, version: str | None
) -> tuple[AgentAsset, str, dict[str, Any]]:
asset = self._resolve_asset(asset_id)
self._require_json_risk_asset(asset)
target_version = self._resolve_target_version(asset, version)
if self.repository.get_version(asset.id, target_version) is None:
raise LookupError(f"版本 {target_version} 不存在")
rule_library, file_name = self._resolve_json_risk_rule_document(asset)
manifest = self.rule_library_manager.read_rule_library_json(
library=rule_library,
file_name=file_name,
)
return asset, target_version, manifest
def _create_test_run(
self,
asset: AgentAsset,
*,
version: str,
test_type: str,
passed: bool,
summary: str,
input_json: dict[str, Any],
result_json: dict[str, Any],
actor: str,
request_id: str | None,
) -> AgentAssetRiskRuleTestRunRead:
status = "passed" if passed else "failed"
created = self.repository.create_test_run(
AgentAssetTestRun(
asset_id=asset.id,
version=version,
test_type=test_type,
status=status,
passed=passed,
summary=summary,
input_json=input_json,
result_json=result_json,
created_by=actor,
)
)
self.audit_service.log_action(
actor=actor,
action=f"risk_rule_test_{test_type}",
resource_type=AgentAssetType.RULE.value,
resource_id=asset.id,
before_json=None,
after_json={"version": version, "status": status, "summary": summary},
request_id=request_id,
)
return AgentAssetRiskRuleTestRunRead.model_validate(created)
def _run_sample_case(
self,
manifest: dict[str, Any],
case: AgentAssetRiskRuleSampleCase,
) -> dict[str, Any]:
claim, contexts = self._build_synthetic_claim(case.values, manifest)
result = RiskRuleTemplateExecutor().evaluate(manifest, claim=claim, contexts=contexts)
actual_hit = result is not None
actual_severity = (
str((manifest.get("outcomes") or {}).get("fail", {}).get("severity") or "").strip()
if actual_hit
else "none"
)
expected_severity = str(case.expected_severity or "").strip()
severity_passed = (
not actual_hit or not expected_severity or expected_severity == actual_severity
)
passed = actual_hit == case.expected_hit and severity_passed
return {
"case_id": case.case_id or "",
"name": case.name,
"values": case.values,
"expected_hit": case.expected_hit,
"expected_severity": expected_severity,
"actual_hit": actual_hit,
"actual_severity": actual_severity,
"passed": passed,
"message": str(result.get("message") or "") if isinstance(result, dict) else "",
"evidence": result.get("evidence") if isinstance(result, dict) else {},
}
def _run_claim_scenario(self, manifest: dict[str, Any], claim: ExpenseClaim) -> dict[str, Any]:
contexts = ExpenseClaimService(self.db)._build_claim_attachment_contexts(claim)
result = RiskRuleTemplateExecutor().evaluate(manifest, claim=claim, contexts=contexts)
hit = result is not None
return {
"claim_id": claim.id,
"claim_no": claim.claim_no,
"employee_name": claim.employee_name,
"department_name": claim.department_name,
"expense_type": claim.expense_type,
"amount": float(claim.amount or 0),
"status": claim.status,
"occurred_at": claim.occurred_at.isoformat() if claim.occurred_at else "",
"hit": hit,
"severity": str((manifest.get("outcomes") or {}).get("fail", {}).get("severity") or "")
if hit
else "none",
"message": str(result.get("message") or "") if isinstance(result, dict) else "",
"evidence": result.get("evidence") if isinstance(result, dict) else {},
}
def _build_synthetic_claim(
self,
values: dict[str, Any],
manifest: dict[str, Any],
) -> tuple[ExpenseClaim, list[dict[str, Any]]]:
claim = ExpenseClaim(
claim_no="TEST-RISK-RULE",
employee_name=str(values.get("claim.employee_name") or "测试员工"),
department_name=str(values.get("claim.department_name") or "测试部门"),
expense_type=str(values.get("item.item_type") or "差旅费"),
reason=str(values.get("claim.reason") or "测试报销事由"),
location=str(values.get("claim.location") or "北京"),
amount=self._to_decimal(values.get("claim.amount")),
currency="CNY",
invoice_count=1,
occurred_at=datetime.now(UTC),
status="draft",
)
item = ExpenseClaimItem(
item_date=date.today(),
item_type=str(values.get("item.item_type") or "住宿费"),
item_reason=str(values.get("item.item_reason") or claim.reason),
item_location=str(values.get("item.item_location") or claim.location),
item_amount=self._to_decimal(values.get("item.item_amount") or claim.amount),
)
claim.items = [item]
attachment_fields = []
document_info: dict[str, Any] = {"fields": attachment_fields}
for field in self._extract_manifest_fields(manifest):
key = field["key"]
if key not in values:
continue
value = self._coerce_sample_value(key, values.get(key))
if key.startswith("claim."):
setattr(claim, key.removeprefix("claim."), value)
elif key.startswith("item."):
setattr(item, key.removeprefix("item."), value)
elif key.startswith("attachment."):
short_key = key.removeprefix("attachment.")
document_info[short_key] = value
attachment_fields.append(
{"key": short_key, "label": field["label"], "value": value}
)
return claim, [
{
"document_info": document_info,
"ocr_text": document_info.get("ocr_text", ""),
}
]
def _build_default_sample_cases(
self,
manifest: dict[str, Any],
) -> list[AgentAssetRiskRuleSampleCase]:
fields = self._extract_manifest_fields(manifest)
severity = str((manifest.get("outcomes") or {}).get("fail", {}).get("severity") or "")
template_key = str(manifest.get("template_key") or "").strip()
hit_values = self._find_case_values_for_expected(manifest, fields, expected_hit=True)
pass_values = self._find_case_values_for_expected(manifest, fields, expected_hit=False)
cases = [
AgentAssetRiskRuleSampleCase(
case_id="hit",
name="应该命中风险",
values=hit_values,
expected_hit=True,
expected_severity=severity,
note="验证规则能识别异常样本。",
),
AgentAssetRiskRuleSampleCase(
case_id="pass",
name="应该不命中",
values=pass_values,
expected_hit=False,
expected_severity="none",
note="验证正常样本不会误触发。",
),
]
if template_key == "field_required_v1":
cases.append(
AgentAssetRiskRuleSampleCase(
case_id="missing",
name="关键字段缺失",
values={key: "" for key in hit_values},
expected_hit=True,
expected_severity=severity,
note="验证缺字段时会进入复核。",
)
)
return cases
def _find_case_values_for_expected(
self,
manifest: dict[str, Any],
fields: list[dict[str, str]],
*,
expected_hit: bool,
) -> dict[str, Any]:
candidates = [
self._build_case_values(manifest, fields, hit=expected_hit),
{field["key"]: self._default_value_for_field(field["key"]) for field in fields},
{
field["key"]: ("上海" if index == 0 else "北京")
for index, field in enumerate(fields)
},
{field["key"]: "北京" for field in fields},
{field["key"]: "" for field in fields},
]
severity = str((manifest.get("outcomes") or {}).get("fail", {}).get("severity") or "")
for values in candidates:
probe = AgentAssetRiskRuleSampleCase(
name="默认样例探测",
values=values,
expected_hit=expected_hit,
expected_severity=severity if expected_hit else "none",
)
result = self._run_sample_case(manifest, probe)
if bool(result["actual_hit"]) == expected_hit:
return values
return candidates[0]
def _build_case_values(
self,
manifest: dict[str, Any],
fields: list[dict[str, str]],
*,
hit: bool,
) -> dict[str, Any]:
values = {field["key"]: self._default_value_for_field(field["key"]) for field in fields}
template_key = str(manifest.get("template_key") or "").strip()
params = manifest.get("params") if isinstance(manifest.get("params"), dict) else {}
if template_key == "field_compare_v1":
condition = next(
(item for item in params.get("conditions", []) if isinstance(item, dict)),
{},
)
left = str(condition.get("left") or "").strip()
right = str(condition.get("right") or "").strip()
operator = str(condition.get("operator") or "overlap").strip()
if left and operator == "is_empty":
values[left] = "测试值" if hit else ""
elif left and right and operator in {"not_equals", "not_in", "not_overlap"}:
values[left] = "北京" if hit else "上海"
values[right] = "北京"
elif left and right:
values[left] = "上海" if hit else "北京"
values[right] = "北京"
elif template_key == "field_required_v1" and hit and fields:
values[fields[0]["key"]] = ""
elif template_key == "keyword_match_v1":
keywords = params.get("keywords") if isinstance(params.get("keywords"), list) else []
keyword = str(next(iter(keywords), "咨询费") or "咨询费")
target_key = fields[0]["key"] if fields else "claim.reason"
values[target_key] = f"本次报销包含{keyword}" if hit else "正常差旅报销"
return values
@staticmethod
def _default_value_for_field(field_key: str) -> Any:
if field_key.endswith("amount"):
return "100.00"
if field_key.endswith("issue_date"):
return date.today().isoformat()
if field_key.endswith("route_cities"):
return ["北京"]
if field_key.endswith("ocr_text"):
return "正常发票内容"
if "city" in field_key or "location" in field_key:
return "北京"
if field_key.endswith("item_type"):
return "住宿费"
return "测试值"
def _query_expense_claim_samples(self, parsed_scope: dict[str, Any]) -> list[ExpenseClaim]:
days = int(parsed_scope.get("days") or 30)
limit = min(max(int(parsed_scope.get("limit") or 50), 1), 200)
since = datetime.now(UTC) - timedelta(days=days)
stmt = select(ExpenseClaim).where(ExpenseClaim.created_at >= since)
expense_keyword = str(parsed_scope.get("expense_keyword") or "").strip()
if expense_keyword:
like_keyword = f"%{expense_keyword}%"
stmt = stmt.where(
or_(
ExpenseClaim.expense_type.ilike(like_keyword),
ExpenseClaim.reason.ilike(like_keyword),
)
)
cities = [str(item or "").strip() for item in parsed_scope.get("cities", []) if item]
if cities:
city_filters = []
for city in cities[:8]:
like_city = f"%{city}%"
city_filters.extend(
[
ExpenseClaim.location.ilike(like_city),
ExpenseClaim.reason.ilike(like_city),
]
)
stmt = stmt.where(or_(*city_filters))
stmt = stmt.order_by(ExpenseClaim.created_at.desc()).limit(limit)
return list(self.db.scalars(stmt).all())
@staticmethod
def _parse_scenario_scope(intent: str, filters: dict[str, Any]) -> dict[str, Any]:
text = str(intent or "")
raw_days = filters.get("days") or filters.get("recent_days")
days = int(raw_days) if str(raw_days or "").isdigit() else 30
match = re.search(r"最近\s*(\d{1,3})\s*天", text)
if match:
days = int(match.group(1))
limit = filters.get("limit") if str(filters.get("limit") or "").isdigit() else 50
expense_keyword = str(filters.get("expense_keyword") or "").strip()
if not expense_keyword and any(keyword in text for keyword in ("酒店", "住宿")):
expense_keyword = "住宿"
city_candidates = ("北京", "上海", "广州", "深圳", "武汉", "杭州", "成都", "南京")
cities = [
city
for city in city_candidates
if city in text or city in [str(item) for item in filters.get("cities", []) or []]
]
return {
"business_domain": "expense",
"days": max(1, min(days, 365)),
"limit": max(1, min(int(limit), 200)),
"expense_keyword": expense_keyword,
"cities": cities,
"execution_mode": "dry_run",
}
@staticmethod
def _extract_manifest_fields(manifest: dict[str, Any]) -> list[dict[str, str]]:
inputs = manifest.get("inputs") if isinstance(manifest.get("inputs"), dict) else {}
fields = inputs.get("fields") if isinstance(inputs.get("fields"), list) else []
normalized = []
for item in fields:
if not isinstance(item, dict):
continue
key = str(item.get("key") or "").strip()
if key:
normalized.append({"key": key, "label": str(item.get("label") or key).strip()})
return normalized
@staticmethod
def _coerce_sample_value(field_key: str, value: Any) -> Any:
if field_key.endswith("route_cities") and isinstance(value, str):
return [item.strip() for item in re.split(r"[,,、/ ]+", value) if item.strip()]
return value
@staticmethod
def _to_decimal(value: Any) -> Decimal:
try:
return Decimal(str(value or "0"))
except (InvalidOperation, ValueError):
return Decimal("0")
def _resolve_asset(self, asset_or_id: AgentAsset | str) -> AgentAsset:
if isinstance(asset_or_id, AgentAsset):
return asset_or_id
asset = self.repository.get(str(asset_or_id))
if asset is None:
raise LookupError("Asset not found")
return asset
@staticmethod
def _require_json_risk_asset(asset: AgentAsset) -> None:
config_json = asset.config_json if isinstance(asset.config_json, dict) else {}
if asset.asset_type != AgentAssetType.RULE.value:
raise ValueError("仅规则资产支持风险规则操作。")
if str(config_json.get("detail_mode") or "").strip().lower() != "json_risk":
raise ValueError("仅 JSON 风险规则支持该操作。")
def _resolve_target_version(self, asset: AgentAsset, version: str | None) -> str:
target = str(version or self._resolve_working_version(asset) or "").strip()
if not target:
raise ValueError("当前规则尚未配置工作版本。")
return target
def _delete_risk_rule_json_file(self, asset: AgentAsset) -> None:
try:
rule_library, file_name = self._resolve_json_risk_rule_document(asset)
target = self.rule_library_manager.resolve_rule_library_path(
library=rule_library,
file_name=file_name,
)
target.unlink(missing_ok=True)
except (FileNotFoundError, ValueError):
return
@staticmethod
def _serialize_test_run(
run: AgentAssetTestRun | None,
) -> AgentAssetRiskRuleTestRunRead | None:
return AgentAssetRiskRuleTestRunRead.model_validate(run) if run is not None else None

View File

@@ -4,6 +4,7 @@ import json
from collections import defaultdict
from datetime import UTC, datetime
from typing import Any
from sqlalchemy.orm import Session
from app.core.agent_enums import (
@@ -27,13 +28,14 @@ from app.schemas.agent_asset import (
)
from app.services.agent_asset_json_rules import AgentAssetJsonRuleMixin
from app.services.agent_asset_onlyoffice import AgentAssetOnlyOfficeMixin
from app.services.agent_asset_risk_rule_simulation import AgentAssetRiskRuleSimulationMixin
from app.services.agent_asset_risk_rule_testing import AgentAssetRiskRuleTestingMixin
from app.services.agent_asset_rule_library import AgentAssetRuleLibraryManager
from app.services.agent_asset_spreadsheet import AgentAssetSpreadsheetManager
from app.services.agent_asset_spreadsheet_helpers import AgentAssetSpreadsheetHelperMixin
from app.services.agent_asset_timeline import AgentAssetTimelineMixin
from app.services.agent_asset_spreadsheet import AgentAssetSpreadsheetManager
from app.services.agent_foundation import AgentFoundationService
from app.services.audit import AuditLogService
from app.services.settings import resolve_onlyoffice_settings
logger = get_logger("app.services.agent_assets")
@@ -41,6 +43,8 @@ logger = get_logger("app.services.agent_assets")
class AgentAssetService(
AgentAssetOnlyOfficeMixin,
AgentAssetSpreadsheetHelperMixin,
AgentAssetRiskRuleTestingMixin,
AgentAssetRiskRuleSimulationMixin,
AgentAssetTimelineMixin,
AgentAssetJsonRuleMixin,
):
@@ -66,10 +70,7 @@ class AgentAssetService(
asset_type=asset_type, status=status, domain=domain, keyword=keyword
)
version_stats = self._collect_version_stats(assets)
return [
self._serialize_list_item(asset, version_stats.get(asset.id))
for asset in assets
]
return [self._serialize_list_item(asset, version_stats.get(asset.id)) for asset in assets]
def get_asset(self, asset_id: str) -> AgentAssetRead | None:
self._ensure_ready()
@@ -88,9 +89,7 @@ class AgentAssetService(
else next(iter(self.repository.list_reviews(asset_id, limit=1)), None)
)
current_version = (
self.repository.get_version(asset_id, working_version)
if working_version
else None
self.repository.get_version(asset_id, working_version) if working_version else None
)
version_stats = self._collect_version_stats([asset]).get(asset.id)
return AgentAssetRead(
@@ -100,12 +99,14 @@ class AgentAssetService(
else None,
current_version_content_type=current_version.content_type if current_version else None,
current_version_change_note=current_version.change_note if current_version else None,
recent_versions=[
self._serialize_version(item, asset) for item in recent_versions
],
recent_versions=[self._serialize_version(item, asset) for item in recent_versions],
latest_review=AgentAssetReviewRead.model_validate(latest_review)
if latest_review
else None,
latest_test_summary=self.get_latest_risk_rule_test_summary(asset)
if str((asset.config_json or {}).get("detail_mode") or "").strip().lower()
== "json_risk"
else None,
)
def create_asset(
@@ -301,6 +302,13 @@ class AgentAssetService(
if self.repository.get_version(asset_id, payload.version) is None:
raise LookupError(f"版本 {payload.version} 不存在")
if asset.asset_type == AgentAssetType.RULE.value:
if (
str((asset.config_json or {}).get("detail_mode") or "").strip().lower()
== "json_risk"
and payload.review_status == AgentReviewStatus.PENDING
and not self.get_latest_risk_rule_test_summary(asset).test_passed
):
raise PermissionError("当前规则版本尚未完成测试通过确认,不能提交审核。")
working_version = self._resolve_working_version(asset)
if payload.version != working_version:
raise ValueError("只能对当前工作版本发起审核。")
@@ -594,11 +602,10 @@ class AgentAssetService(
),
)
def _collect_version_stats(
self, assets: list[AgentAsset]
) -> dict[str, dict[str, int | str | None]]:
def _collect_version_stats(self, assets: list[AgentAsset]) -> dict[str, dict[str, Any]]:
asset_ids = [item.id for item in assets]
versions = self.repository.list_versions_for_assets(asset_ids)
reviews = self.repository.list_reviews_for_assets(asset_ids)
spreadsheet_logs = self.audit_service.repository.list_for_resources(
resource_type=AgentAssetType.RULE.value,
resource_ids=[
@@ -610,23 +617,33 @@ class AgentAssetService(
],
action="edit_rule_spreadsheet",
)
working_versions = {
item.id: self._resolve_working_version(item) for item in assets
}
working_versions = {item.id: self._resolve_working_version(item) for item in assets}
version_counts: dict[str, int] = defaultdict(int)
modified_by: dict[str, str | None] = {item.id: None for item in assets}
published_versions = {item.id: self._resolve_published_version(item) for item in assets}
published_by: dict[str, str | None] = {}
published_at: dict[str, datetime | None] = {}
spreadsheet_edit_counts: dict[str, int] = defaultdict(int)
spreadsheet_last_actor: dict[str, str | None] = {}
spreadsheet_last_changed_at: dict[str, datetime] = {}
for version in versions:
version_counts[version.asset_id] += 1
if (
modified_by.get(version.asset_id) is None
and version.version == working_versions.get(version.asset_id)
):
if modified_by.get(
version.asset_id
) is None and version.version == working_versions.get(version.asset_id):
modified_by[version.asset_id] = version.created_by
for review in reviews:
if review.asset_id in published_at:
continue
if review.version != published_versions.get(review.asset_id):
continue
if review.review_status != AgentReviewStatus.APPROVED.value:
continue
published_by[review.asset_id] = review.reviewer
published_at[review.asset_id] = review.reviewed_at or review.created_at
for log in spreadsheet_logs:
spreadsheet_edit_counts[log.resource_id] += 1
last_changed_at = spreadsheet_last_changed_at.get(log.resource_id)
@@ -652,6 +669,8 @@ class AgentAssetService(
and spreadsheet_last_actor.get(item.id)
else modified_by.get(item.id)
),
"published_by": published_by.get(item.id),
"published_at": published_at.get(item.id),
}
for item in assets
}
@@ -663,9 +682,11 @@ class AgentAssetService(
) -> AgentAssetListItem:
payload = AgentAssetListItem.model_validate(asset).model_dump()
payload["change_count"] = int((version_stats or {}).get("change_count") or 0)
payload["modified_by"] = (
str((version_stats or {}).get("modified_by") or "").strip() or None
payload["modified_by"] = str((version_stats or {}).get("modified_by") or "").strip() or None
payload["published_by"] = (
str((version_stats or {}).get("published_by") or "").strip() or None
)
payload["published_at"] = (version_stats or {}).get("published_at")
return AgentAssetListItem.model_validate(payload)
@staticmethod

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import threading
from sqlalchemy import select
from sqlalchemy import inspect, select, text
from sqlalchemy.orm import Session
from app.core.config import get_settings
@@ -75,6 +75,7 @@ class AgentFoundationService(
try:
Base.metadata.create_all(bind=self.db.get_bind())
self._ensure_agent_asset_schema()
self._ensure_financial_record_schema()
self._seed_agent_assets()
self._sync_demo_financial_records()
self._seed_runs_and_logs()
@@ -88,6 +89,36 @@ class AgentFoundationService(
bind = self.db.get_bind()
return str(getattr(bind, "url", "") or id(bind))
def _ensure_financial_record_schema(self) -> None:
bind = self.db.get_bind()
inspector = inspect(bind)
if "expense_claims" not in inspector.get_table_names():
return
column_names = {column["name"] for column in inspector.get_columns("expense_claims")}
dialect_name = bind.dialect.name
timestamp_type = "TIMESTAMP WITH TIME ZONE" if dialect_name == "postgresql" else "DATETIME"
boolean_default = "FALSE" if dialect_name == "postgresql" else "0"
if "hermes_scanned_at" not in column_names:
self.db.execute(
text(f"ALTER TABLE expense_claims ADD COLUMN hermes_scanned_at {timestamp_type}")
)
if "hermes_risk_flag" not in column_names:
self.db.execute(
text(
"ALTER TABLE expense_claims "
f"ADD COLUMN hermes_risk_flag BOOLEAN DEFAULT {boolean_default} NOT NULL"
)
)
self.db.execute(
text(
"CREATE INDEX IF NOT EXISTS ix_expense_claims_hermes_risk_flag "
"ON expense_claims (hermes_risk_flag)"
)
)
self.db.flush()
def _sync_demo_financial_records(self) -> None:
if get_settings().seed_demo_financial_records:
self._seed_financial_records()

View File

@@ -651,7 +651,11 @@ class EmployeeService:
column_names = {column["name"] for column in inspector.get_columns("employees")}
if "password_hash" not in column_names:
self.db.execute(text("ALTER TABLE employees ADD COLUMN password_hash VARCHAR(255)"))
self.db.flush()
if "compliance_score" not in column_names:
self.db.execute(
text("ALTER TABLE employees ADD COLUMN compliance_score INTEGER DEFAULT 100 NOT NULL")
)
self.db.flush()
def _seed_employee_history(self, employee: Employee, definition: dict[str, Any]) -> None:
existing_keys = {

View File

@@ -141,6 +141,10 @@ EXPENSE_TYPE_KEYWORD_GROUPS: tuple[tuple[str, str, tuple[str, ...]], ...] = (
"办公用品",
"办公耗材",
"办公设备",
"采购",
"集中采购",
"物资采购",
"办公采购",
"办公",
"文具",
"耗材",

View File

@@ -0,0 +1,104 @@
from __future__ import annotations
import json
from datetime import datetime, timedelta, timezone
from typing import Any
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from app.core.logging import get_logger
from app.models.financial_record import ExpenseClaim
from app.services.runtime_chat import RuntimeChatService
logger = get_logger("app.services.hermes_expense_report")
class HermesExpenseReportService:
def __init__(self, db: Session) -> None:
self.db = db
self.chat_service = RuntimeChatService(db)
def generate_weekly_report(self, log_id: str | None = None) -> None:
logger.info("Starting Hermes weekly expense report generation...")
# 1. 聚合数据
aggregated_data = self._aggregate_recent_expenses(days=7)
if not aggregated_data.get("total_amount"):
logger.info("No expense data in the last 7 days. Skipping report.")
return
# 2. 传入大模型分析
report_markdown = self._generate_insights_with_llm(aggregated_data)
if not report_markdown:
logger.warning("Failed to generate expense report from LLM.")
return
# 3. 模拟发送报告
self._deliver_report(report_markdown, log_id)
logger.info("Hermes weekly expense report generation completed.")
def _aggregate_recent_expenses(self, days: int = 7) -> dict[str, Any]:
target_date = datetime.now(timezone.utc) - timedelta(days=days)
# 基础过滤最近N天且不是驳回状态的单据
base_filter = [
ExpenseClaim.occurred_at >= target_date,
ExpenseClaim.status != "rejected"
]
# 1. 按部门汇总
dept_stmt = select(
ExpenseClaim.department_name,
func.sum(ExpenseClaim.amount).label("total")
).where(*base_filter).group_by(ExpenseClaim.department_name)
dept_results = self.db.execute(dept_stmt).all()
by_department = {row.department_name or "Unknown": float(row.total or 0) for row in dept_results}
# 2. 按类目汇总
type_stmt = select(
ExpenseClaim.expense_type,
func.sum(ExpenseClaim.amount).label("total")
).where(*base_filter).group_by(ExpenseClaim.expense_type)
type_results = self.db.execute(type_stmt).all()
by_expense_type = {row.expense_type or "Unknown": float(row.total or 0) for row in type_results}
# 3. 总花费
total_amount = sum(by_department.values())
return {
"period": f"Last {days} days",
"total_amount": total_amount,
"by_department": by_department,
"by_expense_type": by_expense_type
}
def _generate_insights_with_llm(self, data: dict[str, Any]) -> str | None:
system_prompt = (
"你是公司的财务分析专家。请根据提供的最近期业务开销数据,撰写一份简洁有力的【高管费控洞察周报】。\n"
"要求:\n"
"1. 不要机械地罗列数字,要像人一样指出异常(例如:哪个部门花钱最多?打车费是不是异常高?)。\n"
"2. 给出 1 条削减成本的实操建议。\n"
"3. 纯 Markdown 格式输出,不超过 300 字。"
)
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": f"开销统计数据:\n{json.dumps(data, ensure_ascii=False, indent=2)}"}
]
response = self.chat_service.complete(
messages,
max_tokens=800,
temperature=0.4
)
return response
def _deliver_report(self, report_markdown: str, log_id: str | None) -> None:
# TODO: 未来在这里接入企微/钉钉机器人或邮件发送接口
logger.info(f"\n================ Hermes Weekly Report [LogID: {log_id}] ================\n"
f"{report_markdown}\n"
f"==========================================================================")

View File

@@ -0,0 +1,135 @@
from __future__ import annotations
import json
from datetime import datetime, timezone
from typing import Any
from sqlalchemy import or_, select
from sqlalchemy.orm import Session
from app.core.logging import get_logger
from app.models.financial_record import ExpenseClaim
from app.models.hermes_config import HermesTaskExecutionLog
from app.models.hermes_report import HermesRiskReport
from app.services.runtime_chat import RuntimeChatService
logger = get_logger("app.services.hermes_risk_scanner")
class HermesRiskScannerService:
def __init__(self, db: Session) -> None:
self.db = db
self.chat_service = RuntimeChatService(db)
def scan_global_risks(self, log_id: str | None = None) -> None:
logger.info("Starting global risk scan for Hermes...")
# 1. Fetch unscanned claims
claims = self._fetch_unscanned_claims()
if not claims:
logger.info("No unscanned claims found. Aborting scan.")
return
logger.info(f"Fetched {len(claims)} claims to analyze.")
# 2. Extract context for LLM
claims_context = []
for c in claims:
claims_context.append({
"claim_id": c.id,
"claim_no": c.claim_no,
"employee_name": c.employee_name,
"department_name": c.department_name,
"expense_type": c.expense_type,
"location": c.location,
"amount": float(c.amount),
"occurred_at": str(c.occurred_at) if c.occurred_at else None,
"reason": c.reason,
})
# 3. Analyze with LLM
risk_results = self._analyze_claims_with_llm(claims_context)
# 4. Process and persist results
detected_risk_count = 0
if risk_results:
for risk in risk_results:
claim_ids = risk.get("claim_ids", [])
if not claim_ids:
continue
detected_risk_count += 1
for cid in claim_ids:
report = HermesRiskReport(
claim_id=cid,
execution_log_id=log_id,
risk_level=risk.get("risk_level", "medium"),
risk_type=risk.get("risk_type", "unknown"),
risk_description=risk.get("description", "No description provided"),
related_claim_ids=claim_ids,
)
self.db.add(report)
# Update claim flags
claim_obj = next((c for c in claims if c.id == cid), None)
if claim_obj:
claim_obj.hermes_risk_flag = True
# 5. Mark all as scanned
now = datetime.now(timezone.utc)
for c in claims:
c.hermes_scanned_at = now
self.db.commit()
logger.info(f"Hermes risk scan completed. Found {detected_risk_count} risks.")
def _fetch_unscanned_claims(self) -> list[ExpenseClaim]:
stmt = select(ExpenseClaim).where(
ExpenseClaim.status.in_(["draft", "submitted", "review"]),
or_(
ExpenseClaim.hermes_scanned_at.is_(None),
ExpenseClaim.hermes_risk_flag.is_(False) # only rescan if it has no flags yet
)
).limit(50) # Batch size to prevent Token overflow
return list(self.db.scalars(stmt).all())
def _analyze_claims_with_llm(self, claims_context: list[dict[str, Any]]) -> list[dict[str, Any]]:
system_prompt = (
"你是 X-Financial 的 Hermes 内控审计智能体。请分析以下近期的报销单数据集合,寻找以下潜在风险:\n"
"1. 拆单行为 (split_billing):同一人在相邻日期针对同一类目/商户提交多笔恰好贴近免审额度的小额单据。\n"
"2. 群体合谋 (collusion):不同部门的员工在同一天去同一家非标准酒店类偏僻商户高额消费。\n"
"3. 异常频次 (frequency_anomaly):某员工在短时间内的打车或招待频次极度不合理。\n"
"请严格以 JSON 数组格式返回结果,如果没有风险返回空数组 `[]`。\n"
"JSON 格式要求:\n"
"[\n"
" {\n"
' "risk_type": "split_billing",\n'
' "risk_level": "high",\n'
' "claim_ids": ["uuid-1", "uuid-2"],\n'
' "description": "详细推理过程,为什么判定为拆单。"\n'
" }\n"
"]\n"
)
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": json.dumps(claims_context, ensure_ascii=False, indent=2)}
]
response_text = self.chat_service.complete(
messages,
max_tokens=1500,
temperature=0.1
)
if not response_text:
logger.warning("LLM returned empty response for risk scan.")
return []
# Clean markdown formatting if present
cleaned_text = response_text.replace("```json", "").replace("```", "").strip()
try:
return json.loads(cleaned_text)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse LLM risk scan response as JSON: {e}\nResponse: {response_text}")
return []

View File

@@ -0,0 +1,131 @@
import logging
import threading
import time
from datetime import datetime, timezone
import traceback
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.core.logging import get_logger
from app.db.session import get_session_factory
from app.models.hermes_config import HermesTaskConfig, HermesTaskExecutionLog
from app.services.hermes_risk_scanner import HermesRiskScannerService
from app.services.hermes_expense_report import HermesExpenseReportService
logger = get_logger("app.services.hermes_scheduler")
class HermesScheduler:
def __init__(self) -> None:
self._stop_event = threading.Event()
self._thread: threading.Thread | None = None
self._lock = threading.Lock()
self.session_factory = get_session_factory()
def start(self) -> None:
with self._lock:
if self._thread is not None and self._thread.is_alive():
return
self._stop_event.clear()
self._thread = threading.Thread(
target=self._run_loop,
name="hermes-agent-scheduler",
daemon=True,
)
self._thread.start()
logger.info("Hermes Agent Scheduler started.")
def shutdown(self) -> None:
with self._lock:
thread = self._thread
self._thread = None
self._stop_event.set()
if thread is not None and thread.is_alive():
thread.join(timeout=3)
logger.info("Hermes Agent Scheduler stopped.")
def _run_loop(self) -> None:
logger.info("Hermes background loop is now active. Polling interval: 60s.")
while not self._stop_event.is_set():
try:
self._check_and_run_tasks()
except Exception as e:
logger.error(f"Error in Hermes run loop: {e}", exc_info=True)
# 睡眠一分钟,每分钟轮询一次
if self._stop_event.wait(60.0):
break
def _check_and_run_tasks(self) -> None:
db = self.session_factory()
try:
# 获取所有启用的任务配置
stmt = select(HermesTaskConfig).where(HermesTaskConfig.is_enabled == True)
configs = db.scalars(stmt).all()
for config in configs:
if self._should_run_now(db, config):
self._execute_task(db, config)
finally:
db.close()
def _should_run_now(self, db: Session, config: HermesTaskConfig) -> bool:
# 简单策略检查是否在过去24小时内运行过。
# 如果没有 croniter 库,我们暂时采用按天执行的简化逻辑
stmt = select(HermesTaskExecutionLog).where(
HermesTaskExecutionLog.config_id == config.id,
HermesTaskExecutionLog.status.in_(["success", "running"])
).order_by(HermesTaskExecutionLog.started_at.desc()).limit(1)
last_log = db.scalars(stmt).first()
if not last_log:
return True # 从未执行过,立即执行
now = datetime.now(timezone.utc)
elapsed_hours = (now - last_log.started_at).total_seconds() / 3600
# 简化:只要距离上次成功执行超过了 23.5 小时,就认为该跑了(模拟每天跑一次)
if elapsed_hours >= 23.5:
return True
return False
def _execute_task(self, db: Session, config: HermesTaskConfig) -> None:
logger.info(f"Triggering Hermes task: {config.task_type} (Config ID: {config.id})")
# 创建执行日志,标记为 running
log_record = HermesTaskExecutionLog(
config_id=config.id,
status="running"
)
db.add(log_record)
db.commit()
db.refresh(log_record)
try:
if config.task_type == "global_risk_scan":
scanner = HermesRiskScannerService(db)
scanner.scan_global_risks(log_id=log_record.id)
elif config.task_type == "weekly_expense_report":
reporter = HermesExpenseReportService(db)
reporter.generate_weekly_report(log_id=log_record.id)
log_record.status = "success"
log_record.completed_at = datetime.now(timezone.utc)
log_record.result_summary = "Task executed successfully."
except Exception as e:
logger.error(f"Failed to execute Hermes task {config.task_type}: {e}")
log_record.status = "failed"
log_record.completed_at = datetime.now(timezone.utc)
log_record.error_trace = traceback.format_exc()
finally:
db.commit()
# 全局单例
hermes_scheduler = HermesScheduler()

View File

@@ -34,10 +34,104 @@ def _extract_docx_text(file_path: Path) -> str:
return "当前 Word 文件解析失败。"
root = ElementTree.fromstring(xml_content)
body = next((node for node in root.iter() if node.tag.endswith("}body")), root)
blocks: list[str] = []
for child in body:
if child.tag.endswith("}p"):
paragraph = _extract_docx_paragraph_text(child)
if paragraph:
blocks.append(paragraph)
continue
if child.tag.endswith("}tbl"):
table = _extract_docx_table_rows(child)
rendered = _build_docx_table_markdown(table)
if rendered:
blocks.append(rendered)
if blocks:
return "\n\n".join(blocks)
texts = [node.text.strip() for node in root.iter() if node.tag.endswith("}t") and node.text]
return "\n".join(texts)
def _extract_docx_paragraph_text(node: ElementTree.Element) -> str:
parts: list[str] = []
for child in node.iter():
if child.tag.endswith("}t") and child.text:
parts.append(child.text)
elif child.tag.endswith("}tab"):
parts.append("\t")
elif child.tag.endswith("}br"):
parts.append("\n")
return _normalize_docx_cell_text("".join(parts))
def _extract_docx_table_rows(table_node: ElementTree.Element) -> list[list[str]]:
rows: list[list[str]] = []
for row_node in table_node:
if not row_node.tag.endswith("}tr"):
continue
row: list[str] = []
for cell_node in row_node:
if not cell_node.tag.endswith("}tc"):
continue
cell_parts = [
_extract_docx_paragraph_text(paragraph)
for paragraph in cell_node
if paragraph.tag.endswith("}p")
]
row.append(_normalize_docx_cell_text(" ".join(part for part in cell_parts if part)))
if any(row):
rows.append(row)
return rows
def _build_docx_table_markdown(rows: list[list[str]]) -> str:
visible_rows = [
[_escape_markdown_cell(cell) for cell in row]
for row in rows
if any(str(cell or "").strip() for cell in row)
]
if len(visible_rows) < 2:
return ""
column_count = max(len(row) for row in visible_rows)
normalized_rows = [row + [""] * (column_count - len(row)) for row in visible_rows]
header = [
cell or f"{column_index + 1}" for column_index, cell in enumerate(normalized_rows[0])
]
body_rows = normalized_rows[1:]
parts = [_format_markdown_table(header, body_rows)]
row_clues: list[str] = []
for row_number, row in enumerate(body_rows, start=2):
pairs = [
f"{header[column_index]}={value}"
for column_index, value in enumerate(row)
if value
]
if pairs:
row_clues.append(f"- 表格第 {row_number} 行:" + "".join(pairs))
if row_clues:
parts.append("### 表格行级检索线索")
parts.extend(row_clues)
return "\n\n".join(parts)
def _normalize_docx_cell_text(value: str) -> str:
normalized = str(value or "").replace("\r\n", "\n").replace("\r", "\n")
normalized = re.sub(r"[ \t]*\n[ \t]*", " ", normalized)
normalized = re.sub(r"\s+", " ", normalized)
return normalized.strip()
def _extract_document_text_from_path(
*,
file_path: Path,

View File

@@ -12,7 +12,7 @@ logger = get_logger("app.services.knowledge_normalizer")
TABLE_MARKER_PATTERN = re.compile(r"\s*(\d+)")
SECTION_HEADING_PATTERN = re.compile(
r"^(第[一二三四五六七八九十百零0-9]+[章节]\s*.*|[一二三四五六七八九十]+、.*|[一二三四五六七八九十]+.*|\([一二三四五六七八九十]+\).*)$"
r"^(第[一二三四五六七八九十百零0-9]+[部分章节]\s*.*|[一二三四五六七八九十]+、.*|[一二三四五六七八九十]+.*|\([一二三四五六七八九十]+\).*)$"
)
LIST_ITEM_PATTERN = re.compile(r"^[-*•]\s+.+$")
NUMBERED_ITEM_PATTERN = re.compile(r"^(?:\d+[.)、]|[①②③④⑤⑥⑦⑧⑨⑩])\s*.+$")

View File

@@ -50,6 +50,12 @@ QUERY_TERM_STOPWORDS = {
"哪些人",
}
TABLE_OR_STANDARD_QUERY_HINTS = (
"",
"表格",
"清单",
"明细",
"目录",
"科目",
"标准",
"金额",
"限额",
@@ -61,6 +67,20 @@ TABLE_OR_STANDARD_QUERY_HINTS = (
"档位",
"额度",
)
QUERY_ANCHOR_TERMS = (
"财务基础知识手册",
"基础知识手册",
"会计科目",
"常用会计科目",
"财务报表",
"主要税种",
"税种",
"标准",
"清单",
"明细",
"流程",
)
GENERIC_TITLE_TERMS = {"远光软件", "股份有限", "有限公司"}
STRUCTURED_APPENDIX_LEADING_MARKERS = (
"# 章节导航",
"# 重点章节摘录",
@@ -96,6 +116,10 @@ class KnowledgeRagService:
"message": "请先输入要检索的知识库问题。",
}
rewritten_query = normalized_query
if conversation_history:
rewritten_query = self._rewrite_query(normalized_query, conversation_history)
workspace = (
os.environ.get("LIGHTRAG_WORKSPACE", DEFAULT_LIGHTRAG_WORKSPACE).strip()
or DEFAULT_LIGHTRAG_WORKSPACE
@@ -103,81 +127,102 @@ class KnowledgeRagService:
local_result = query_local_text_chunks(
lightrag_root=(self.storage_root / "knowledge" / ".lightrag").resolve(),
workspace=workspace,
query=normalized_query,
query=rewritten_query,
limit=limit,
)
if local_result.confident:
return {
"result_type": "knowledge_search",
"query": normalized_query,
"record_count": len(local_result.hits),
"hits": local_result.hits,
"references": [
str(item.get("code") or "").strip()
for item in local_result.hits
if str(item.get("code") or "").strip()
],
"raw_references": [],
"metadata": {
"retrieval_strategy": "local_text_chunks",
"elapsed_seconds": round(local_result.elapsed_seconds, 4),
"total_chunks": local_result.total_chunks,
"best_score": local_result.best_score,
},
"message": f"已从本地知识块中检索到 {len(local_result.hits)} 条相关内容。",
}
runtime_hits: list[dict[str, Any]] = []
runtime_references: list[str] = []
try:
runtime = self._get_runtime()
raw = runtime.query_data(normalized_query, conversation_history=conversation_history)
raw = runtime.query_data(rewritten_query, conversation_history=conversation_history)
data = raw.get("data") if isinstance(raw, dict) else {}
chunks = list(data.get("chunks") or []) if isinstance(data, dict) else []
entities = list(data.get("entities") or []) if isinstance(data, dict) else []
runtime_references = list(data.get("references") or []) if isinstance(data, dict) else []
runtime_hits = self._build_hits_from_query_data(
query=rewritten_query,
chunks=chunks,
entities=entities,
limit=limit,
)
except Exception as exc:
logger.warning("Knowledge query failed: %s", exc)
all_hits: dict[str, dict[str, Any]] = {}
for hit in local_result.hits:
hit["score"] = int(hit.get("score") or 0)
all_hits[hit["code"]] = hit
for hit in runtime_hits:
code = hit["code"]
if code in all_hits:
all_hits[code]["score"] = max(all_hits[code]["score"], int(hit.get("score") or 0) + 20)
if not all_hits[code].get("tags") and hit.get("tags"):
all_hits[code]["tags"] = hit["tags"]
else:
hit["score"] = int(hit.get("score") or 0)
all_hits[code] = hit
merged_hits = sorted(all_hits.values(), key=lambda x: int(x.get("score") or 0), reverse=True)[:max(1, limit)]
if not merged_hits:
return {
"result_type": "knowledge_search",
"query": normalized_query,
"query": rewritten_query,
"record_count": 0,
"hits": [],
"references": [],
"message": f"知识库检索暂不可用:{exc}",
}
data = raw.get("data") if isinstance(raw, dict) else {}
chunks = list(data.get("chunks") or []) if isinstance(data, dict) else []
entities = list(data.get("entities") or []) if isinstance(data, dict) else []
references = list(data.get("references") or []) if isinstance(data, dict) else []
hits = self._build_hits_from_query_data(
query=normalized_query,
chunks=chunks,
entities=entities,
limit=limit,
)
if not hits:
return {
"result_type": "knowledge_search",
"query": normalized_query,
"record_count": 0,
"hits": [],
"references": [],
"raw_references": references,
"raw_references": runtime_references,
"message": "当前知识库中没有检索到与本次问题直接匹配的内容。",
}
return {
"result_type": "knowledge_search",
"query": normalized_query,
"record_count": len(hits),
"hits": hits,
"query": rewritten_query,
"record_count": len(merged_hits),
"hits": merged_hits,
"references": [
str(item.get("code") or "").strip()
for item in hits
for item in merged_hits
if str(item.get("code") or "").strip()
],
"raw_references": references,
"metadata": raw.get("metadata") if isinstance(raw, dict) else {},
"message": f"已从知识库中检索到 {len(hits)} 条相关内容。",
"raw_references": runtime_references,
"metadata": {
"retrieval_strategy": "fusion",
"local_total_chunks": local_result.total_chunks,
"local_best_score": local_result.best_score,
},
"message": f"已从知识库中联合检索到 {len(merged_hits)} 条相关内容。",
}
def _rewrite_query(self, query: str, conversation_history: list[dict[str, str]]) -> str:
if not self.db:
return query
from app.services.runtime_chat import RuntimeChatService
try:
chat_service = RuntimeChatService(self.db)
messages: list[dict[str, Any]] = [{"role": "system", "content": "你是一个查询重写助手。你的任务是根据用户的多轮对话历史,将用户的最后一次提问重写为一句独立、完整的查询语句,以便于在知识库中进行向量检索。只输出重写后的句子,不要任何解释。"}]
for msg in conversation_history[-6:]:
messages.append({"role": msg.get("role", "user"), "content": msg.get("content", "")})
messages.append({"role": "user", "content": f"当前提问:{query}\n\n请重写当前提问。"})
rewritten = chat_service.complete(
messages,
max_tokens=60,
temperature=0.1,
timeout_seconds=10,
)
if rewritten and len(rewritten) > 2 and len(rewritten) < 80:
logger.info("Query rewritten: '%s' -> '%s'", query, rewritten)
return rewritten
except Exception as exc:
logger.warning("Query rewrite failed: %s", exc)
return query
def index_documents(
self,
*,
@@ -686,6 +731,24 @@ def _extract_query_terms(query: str) -> list[str]:
remember(item)
for block in re.findall(r"[\u4e00-\u9fff]{2,20}", normalized_query):
for marker in ("标准", "金额", "限额", "额度"):
marker_index = block.find(marker)
if marker_index <= 0:
continue
subject = block[:marker_index]
for width in (6, 4, 3, 2):
remember(subject[-width:])
for anchor in QUERY_ANCHOR_TERMS:
if anchor in block:
remember(anchor)
tail = block[-14:]
for size in (8, 7, 6, 5, 4):
for start in range(0, len(tail) - size + 1):
piece = tail[start : start + size]
if any(anchor in piece for anchor in QUERY_ANCHOR_TERMS):
remember(piece)
if len(terms) >= MAX_QUERY_TERMS:
return terms
if len(block) <= 4:
remember(block)
continue
@@ -715,6 +778,11 @@ def _score_knowledge_hit(
matched_terms = [term for term in query_terms if term in haystack]
score += len(matched_terms) * 8
score += sum(1 for term in matched_terms if term in title) * 6
score += sum(
(len(term) - 3) * 12
for term in matched_terms
if len(term) >= 4 and term in title and term not in GENERIC_TITLE_TERMS
)
leading_appendix_marker = _leading_structured_appendix_marker(content)
if leading_appendix_marker == "# 章节导航":

View File

@@ -42,6 +42,12 @@ LOCAL_QUERY_STOPWORDS = {
"问题",
}
LOCAL_TABLE_QUERY_HINTS = (
"",
"表格",
"清单",
"明细",
"目录",
"科目",
"标准",
"金额",
"限额",
@@ -53,6 +59,20 @@ LOCAL_TABLE_QUERY_HINTS = (
"档位",
"额度",
)
LOCAL_QUERY_ANCHOR_TERMS = (
"财务基础知识手册",
"基础知识手册",
"会计科目",
"常用会计科目",
"财务报表",
"主要税种",
"税种",
"标准",
"清单",
"明细",
"流程",
)
LOCAL_GENERIC_TITLE_TERMS = {"远光软件", "股份有限", "有限公司"}
LOCAL_DOMAIN_TERMS = (
"报销",
"费用",
@@ -253,6 +273,8 @@ def _score_local_chunk(
score += weight
if term in lowered_title:
score += max(4, weight)
if len(term) >= 4 and term not in LOCAL_GENERIC_TITLE_TERMS:
score += (len(term) - 3) * 12
occurrences = lowered_content.count(term)
if occurrences > 1:
score += min(8, occurrences * 2)
@@ -299,6 +321,24 @@ def _extract_local_query_terms(query: str) -> list[str]:
remember(item)
for block in re.findall(r"[\u4e00-\u9fff]{2,24}", normalized_query):
for marker in ("标准", "金额", "限额", "额度"):
marker_index = block.find(marker)
if marker_index <= 0:
continue
subject = block[:marker_index]
for width in (6, 4, 3, 2):
remember(subject[-width:])
for anchor in LOCAL_QUERY_ANCHOR_TERMS:
if anchor in block:
remember(anchor)
tail = block[-14:]
for size in (8, 7, 6, 5, 4):
for start in range(0, len(tail) - size + 1):
piece = tail[start : start + size]
if any(anchor in piece for anchor in LOCAL_QUERY_ANCHOR_TERMS):
remember(piece)
if len(terms) >= MAX_LOCAL_QUERY_TERMS:
return terms
if len(block) <= 4:
remember(block)
continue

View File

@@ -102,7 +102,7 @@ class SemanticOntologyService(
context_json = payload.context_json or {}
reference = self._load_reference_catalog()
compact_query = self._compact(query)
entities = self._extract_entities(query, compact_query, reference)
entities = self._extract_entities(query, compact_query, reference, context_json=context_json)
rule_scenario, scenario_score = self._detect_scenario(compact_query)
time_range, _time_score = self._extract_time_range(
query,
@@ -111,9 +111,14 @@ class SemanticOntologyService(
)
session_scenario = self._resolve_session_type_scenario(context_json)
context_scenario = self._resolve_context_scenario(context_json)
application_context = self._is_expense_application_context(context_json)
application_query = self._looks_like_expense_application(compact_query)
if session_scenario == "knowledge":
rule_scenario = "knowledge"
scenario_score = max(scenario_score, 0.34)
if session_scenario != "knowledge" and (application_context or application_query):
rule_scenario = "expense"
scenario_score = max(scenario_score, 0.22)
if rule_scenario == "unknown" and context_scenario is not None:
rule_scenario = context_scenario
scenario_score = max(scenario_score, 0.14)
@@ -138,6 +143,9 @@ class SemanticOntologyService(
entities=entities,
time_range=time_range,
)
if session_scenario != "knowledge" and (application_context or application_query):
rule_intent = "draft"
intent_score = max(intent_score, 0.22)
if session_scenario != "knowledge" and self._should_inherit_expense_draft(
compact_query,
scenario=rule_scenario,

View File

@@ -20,6 +20,8 @@ from app.services.ontology_rules import (
COMPARE_KEYWORDS,
DRAFT_FOLLOW_UP_KEYWORDS,
DRAFT_KEYWORDS,
EXPENSE_APPLICATION_CONTEXT_TYPES,
EXPENSE_APPLICATION_KEYWORDS,
EXPENSE_NARRATIVE_KEYWORDS,
EXPENSE_REVIEW_ACTIONS,
EXPLAIN_KEYWORDS,
@@ -71,6 +73,21 @@ EXPLICIT_ENTERTAINMENT_KEYWORDS = (
class OntologyDetectionMixin:
@staticmethod
def _is_expense_application_context(context_json: dict[str, Any]) -> bool:
document_type = str(context_json.get("document_type") or "").strip()
application_stage = str(context_json.get("application_stage") or "").strip()
entry_source = str(context_json.get("entry_source") or "").strip()
return (
document_type in EXPENSE_APPLICATION_CONTEXT_TYPES
or application_stage in EXPENSE_APPLICATION_CONTEXT_TYPES
or entry_source in {"documents_application", "expense_application"}
)
@staticmethod
def _looks_like_expense_application(compact_query: str) -> bool:
return any(keyword in compact_query for keyword in EXPENSE_APPLICATION_KEYWORDS)
def _detect_scenario(self, compact_query: str) -> tuple[str, float]:
scores = {key: 0.0 for key in SCENARIO_KEYWORDS}
for scenario, keywords in SCENARIO_KEYWORDS.items():
@@ -341,6 +358,9 @@ class OntologyDetectionMixin:
"conversation_id": payload.context_json.get("conversation_id"),
"conversation_scenario": payload.context_json.get("conversation_scenario"),
"conversation_intent": payload.context_json.get("conversation_intent"),
"document_type": payload.context_json.get("document_type"),
"application_stage": payload.context_json.get("application_stage"),
"application_fields": payload.context_json.get("application_fields"),
"draft_claim_id": payload.context_json.get("draft_claim_id"),
"review_action": payload.context_json.get("review_action"),
"review_form_values": payload.context_json.get("review_form_values"),

View File

@@ -18,7 +18,12 @@ from app.services.ontology_rules import (
DATE_RANGE_PATTERN,
EXPLICIT_DATE_PATTERN,
EXPLICIT_MONTH_PATTERN,
EXPENSE_APPLICATION_ATTACHMENT_REQUIRED_TYPES,
EXPENSE_APPLICATION_CONTEXT_TYPES,
EXPENSE_APPLICATION_KEYWORDS,
EXPENSE_APPLICATION_REQUIRED_SLOT_KEYS,
EXPENSE_TYPE_KEYWORDS,
GENERIC_EXPENSE_APPLICATION_PROMPTS,
GENERIC_EXPENSE_PROMPTS,
LOCATION_KEYWORDS,
MONTH_DAY_PATTERN,
@@ -30,6 +35,21 @@ from app.services.ontology_rules import (
class OntologyExtractionMixin:
@staticmethod
def _is_expense_application_context_value(context_json: dict[str, Any]) -> bool:
document_type = str(context_json.get("document_type") or "").strip()
application_stage = str(context_json.get("application_stage") or "").strip()
entry_source = str(context_json.get("entry_source") or "").strip()
return (
document_type in EXPENSE_APPLICATION_CONTEXT_TYPES
or application_stage in EXPENSE_APPLICATION_CONTEXT_TYPES
or entry_source in {"documents_application", "expense_application"}
)
@staticmethod
def _has_expense_application_signal(compact_query: str) -> bool:
return any(keyword in compact_query for keyword in EXPENSE_APPLICATION_KEYWORDS)
def _infer_default_missing_slots(
self,
compact_query: str,
@@ -46,6 +66,44 @@ class OntologyExtractionMixin:
entity_types = {item.type for item in entities}
attachment_count = int(context_json.get("attachment_count") or 0)
missing_slots: list[str] = []
application_mode = (
self._is_expense_application_context_value(context_json)
or self._has_expense_application_signal(compact_query)
or any(
item.type == "document_type" and item.normalized_value == "expense_application"
for item in entities
)
)
if application_mode:
form_values = context_json.get("review_form_values")
if not isinstance(form_values, dict):
form_values = {}
expense_type_codes = {
str(item.normalized_value or item.value or "").strip()
for item in entities
if item.type == "expense_type"
}
if "expense_type" not in entity_types and not str(form_values.get("expense_type") or "").strip():
missing_slots.append("expense_type")
if "amount" not in entity_types and not str(form_values.get("amount") or "").strip():
missing_slots.append("amount")
if not time_range.start_date and not (
str(form_values.get("time_range") or form_values.get("business_time") or "").strip()
):
missing_slots.append("time_range")
reason_value = str(
form_values.get("reason")
or form_values.get("business_reason")
or form_values.get("reason_value")
or ""
).strip()
if not reason_value and compact_query in GENERIC_EXPENSE_APPLICATION_PROMPTS:
missing_slots.append("reason")
if attachment_count <= 0 and expense_type_codes & EXPENSE_APPLICATION_ATTACHMENT_REQUIRED_TYPES:
missing_slots.append("attachments")
ordered_keys = [*EXPENSE_APPLICATION_REQUIRED_SLOT_KEYS, "attachments"]
return [item for item in ordered_keys if item in missing_slots]
if self._is_generic_expense_prompt(compact_query):
if "expense_type" not in entity_types:
@@ -98,14 +156,40 @@ class OntologyExtractionMixin:
query: str,
compact_query: str,
reference: ReferenceCatalog,
*,
context_json: dict[str, Any] | None = None,
) -> list[OntologyEntity]:
entities: dict[tuple[str, str], OntologyEntity] = {}
context_json = context_json or {}
def upsert(entity: OntologyEntity) -> None:
key = (entity.type, entity.normalized_value)
if key not in entities:
entities[key] = entity
if (
self._is_expense_application_context_value(context_json)
or self._has_expense_application_signal(compact_query)
):
upsert(
self._make_entity(
"document_type",
"费用申请",
"expense_application",
role="target",
confidence=0.94,
)
)
upsert(
self._make_entity(
"workflow_stage",
"前置申请",
"pre_approval",
role="target",
confidence=0.9,
)
)
for match in re.finditer(r"客户\s*([A-Za-z0-9一二三四五六七八九十]+)", query):
suffix = match.group(1).strip()
normalized = f"客户{suffix}".replace(" ", "")
@@ -510,6 +594,8 @@ class OntologyExtractionMixin:
"project",
"location",
"expense_type",
"document_type",
"workflow_stage",
}:
upsert(
OntologyConstraint(

View File

@@ -173,6 +173,49 @@ GENERIC_EXPENSE_PROMPTS = {
"发起报销",
"提交报销",
}
EXPENSE_APPLICATION_CONTEXT_TYPES = {
"expense_application",
"application",
"pre_approval",
"preapproval",
}
EXPENSE_APPLICATION_KEYWORDS = (
"费用申请",
"申请单",
"发起申请",
"提交申请",
"提出申请",
"前置申请",
"报销申请",
"申请报销",
"差旅申请",
"出差申请",
"会务申请",
"会议申请",
"采购申请",
"培训申请",
"预算申请",
)
GENERIC_EXPENSE_APPLICATION_PROMPTS = {
"申请",
"费用申请",
"发起申请",
"提交申请",
"提出申请",
"申请报销",
"报销申请",
}
EXPENSE_APPLICATION_REQUIRED_SLOT_KEYS = (
"expense_type",
"amount",
"time_range",
"reason",
)
EXPENSE_APPLICATION_ATTACHMENT_REQUIRED_TYPES = {
"meeting",
"office",
"training",
}
MISSING_SLOT_LABELS = {
"expense_type": "费用类型",
"amount": "金额",

View File

@@ -14,6 +14,7 @@ from app.schemas.agent_asset import AgentAssetRiskRuleGenerateRequest
from app.services.agent_asset_rule_library import AgentAssetRuleLibraryManager
from app.services.agent_asset_spreadsheet import RISK_RULES_LIBRARY
from app.services.audit import AuditLogService
from app.services.expense_type_keywords import EXPENSE_TYPE_LABEL_BY_CODE
from app.services.risk_rule_flow_diagram import (
RiskRuleFlowDiagramField,
RiskRuleFlowDiagramRenderer,
@@ -43,6 +44,24 @@ RISK_LEVEL_LABELS: dict[str, str] = {
"high": "高风险",
}
EXPENSE_RISK_CATEGORY_CODES: tuple[str, ...] = (
"travel",
"hotel",
"transport",
"meal",
"meeting",
"office",
"training",
"communication",
"welfare",
)
EXPENSE_RISK_CATEGORY_LABELS: dict[str, str] = {
code: EXPENSE_TYPE_LABEL_BY_CODE[code] for code in EXPENSE_RISK_CATEGORY_CODES
}
EXPENSE_RISK_CATEGORY_ALIASES = {
"entertainment": "meal",
}
FIELD_ONTOLOGY: tuple[RiskRuleField, ...] = (
RiskRuleField("claim.reason", "报销事由", "text", "claim", ("事由", "说明", "理由", "用途")),
RiskRuleField(
@@ -156,17 +175,23 @@ class RiskRuleGenerationService:
risk_level = str(body.risk_level or "medium").strip().lower()
if risk_level not in RISK_LEVEL_LABELS:
raise ValueError("风险等级仅支持 low、medium、high。")
requires_attachment = bool(body.requires_attachment)
expense_category = self._normalize_expense_category(body.expense_category, domain)
expense_category_label = EXPENSE_RISK_CATEGORY_LABELS.get(expense_category or "", "")
created_at = datetime.now(UTC)
fields = self._resolve_fields(natural_language, domain=domain)
draft = self._compile_with_model(
natural_language=natural_language,
domain=domain,
expense_category=expense_category,
expense_category_label=expense_category_label,
risk_level=risk_level,
fields=fields,
) or self._build_fallback_draft(
natural_language=natural_language,
domain=domain,
expense_category_label=expense_category_label,
risk_level=risk_level,
fields=fields,
)
@@ -179,10 +204,13 @@ class RiskRuleGenerationService:
draft,
natural_language=natural_language,
domain=domain,
expense_category=expense_category,
expense_category_label=expense_category_label,
risk_level=risk_level,
fields=fields,
created_at=created_at,
actor=actor,
requires_attachment=requires_attachment,
)
rule_code = str(payload["rule_code"])
file_name = f"{rule_code}.json"
@@ -209,8 +237,11 @@ class RiskRuleGenerationService:
config_json={
"severity": risk_level,
"enabled": True,
"requires_attachment": requires_attachment,
"tag": "风险规则",
"detail_mode": "json_risk",
"expense_category": expense_category,
"expense_category_label": expense_category_label,
"risk_category": payload.get("risk_category"),
"rule_library": RISK_RULES_LIBRARY,
"rule_document": {
@@ -241,7 +272,13 @@ class RiskRuleGenerationService:
resource_type=AgentAssetType.RULE.value,
resource_id=asset.id,
before_json=None,
after_json={"rule_code": rule_code, "risk_level": risk_level, "domain": domain},
after_json={
"rule_code": rule_code,
"risk_level": risk_level,
"domain": domain,
"expense_category": expense_category,
"requires_attachment": requires_attachment,
},
request_id=request_id,
)
self.db.refresh(asset)
@@ -252,6 +289,8 @@ class RiskRuleGenerationService:
*,
natural_language: str,
domain: str,
expense_category: str | None,
expense_category_label: str,
risk_level: str,
fields: list[RiskRuleField],
) -> dict[str, Any] | None:
@@ -279,6 +318,8 @@ class RiskRuleGenerationService:
{
"business_domain": domain,
"business_domain_label": BUSINESS_DOMAIN_LABELS[domain],
"expense_category": expense_category,
"expense_category_label": expense_category_label,
"risk_level": risk_level,
"risk_level_label": RISK_LEVEL_LABELS[risk_level],
"natural_language": natural_language,
@@ -370,6 +411,7 @@ class RiskRuleGenerationService:
*,
natural_language: str,
domain: str,
expense_category_label: str,
risk_level: str,
fields: list[RiskRuleField],
) -> dict[str, Any]:
@@ -381,8 +423,9 @@ class RiskRuleGenerationService:
fields=fields,
)
name = self._infer_rule_name(natural_language)
business_label = expense_category_label or BUSINESS_DOMAIN_LABELS[domain]
description = (
f"{BUSINESS_DOMAIN_LABELS[domain]}业务满足“{natural_language}”时,系统会按"
f"{business_label}业务满足“{natural_language}”时,系统会按"
f"{RISK_LEVEL_LABELS[risk_level]}进行提示,并要求经办人或审核人补充核对依据。"
)
return {
@@ -393,7 +436,7 @@ class RiskRuleGenerationService:
"condition_summary": condition_summary,
"keywords": self._infer_keywords(natural_language),
"flow": {
"start": f"{BUSINESS_DOMAIN_LABELS[domain]}单据提交",
"start": f"{business_label}单据提交",
"evidence": "读取" + "".join(item.label for item in fields[:3]),
"decision": condition_summary,
"pass": "未命中风险,继续业务流转",
@@ -407,14 +450,18 @@ class RiskRuleGenerationService:
*,
natural_language: str,
domain: str,
expense_category: str | None,
expense_category_label: str,
risk_level: str,
fields: list[RiskRuleField],
created_at: datetime,
actor: str,
requires_attachment: bool,
) -> dict[str, Any]:
created_stamp = created_at.strftime("%Y%m%d%H%M%S")
created_stamp = created_at.strftime("%Y%m%d%H%M%S%f")
domain_slug = {"expense": "expense", "ar": "ar", "ap": "ap"}[domain]
rule_code = f"risk.{domain_slug}.generated_{created_stamp}"
category_slug = f".{expense_category}" if expense_category else ""
rule_code = f"risk.{domain_slug}{category_slug}.generated_{created_stamp}"
template_key = str(draft.get("template_key") or "field_required_v1").strip()
field_keys = [
str(item or "").strip()
@@ -424,7 +471,7 @@ class RiskRuleGenerationService:
condition_summary = (
self._clean_text(draft.get("condition_summary")) or "判断是否符合自然语言规则描述"
)
risk_category = BUSINESS_DOMAIN_LABELS[domain]
risk_category = expense_category_label or BUSINESS_DOMAIN_LABELS[domain]
keywords = list(draft.get("keywords") or [])
field_by_key = {item.key: item for item in fields}
params: dict[str, Any] = {
@@ -440,6 +487,9 @@ class RiskRuleGenerationService:
if template_key == "keyword_match_v1":
params["keywords"] = keywords
params["search_fields"] = field_keys
applies_to: dict[str, Any] = {"domains": [domain]}
if expense_category:
applies_to["expense_categories"] = [expense_category]
payload = {
"schema_version": "2.0",
@@ -447,12 +497,13 @@ class RiskRuleGenerationService:
"name": self._clean_text(draft.get("name")) or self._infer_rule_name(natural_language),
"description": self._clean_text(draft.get("description")) or natural_language,
"enabled": True,
"requires_attachment": requires_attachment,
"risk_dimension": "natural_language_rule",
"risk_category": risk_category,
"ontology_signal": "natural_language_risk",
"evaluator": "template_rule",
"template_key": template_key,
"applies_to": {"domains": [domain]},
"applies_to": applies_to,
"inputs": {
"fields": [
{
@@ -478,6 +529,9 @@ class RiskRuleGenerationService:
"source_ref": "自然语言风险规则",
"created_at": created_at.isoformat(),
"created_by": actor,
"requires_attachment": requires_attachment,
"expense_category": expense_category,
"expense_category_label": expense_category_label,
"natural_language": natural_language,
"business_explanation": self._clean_text(draft.get("description")),
"condition_summary": condition_summary,
@@ -488,6 +542,7 @@ class RiskRuleGenerationService:
payload,
fields=[field_by_key[key] for key in field_keys if key in field_by_key],
domain=domain,
domain_label=risk_category,
risk_level=risk_level,
)
return payload
@@ -498,6 +553,7 @@ class RiskRuleGenerationService:
*,
fields: list[RiskRuleField],
domain: str,
domain_label: str | None = None,
risk_level: str,
) -> str:
metadata = payload.get("metadata") if isinstance(payload.get("metadata"), dict) else {}
@@ -506,7 +562,7 @@ class RiskRuleGenerationService:
return self.flow_diagram_renderer.render(
RiskRuleFlowDiagramSpec(
title=self._clean_text(payload.get("name")) or "风险规则判断流程",
domain_label=BUSINESS_DOMAIN_LABELS.get(domain, "业务"),
domain_label=domain_label or BUSINESS_DOMAIN_LABELS.get(domain, "业务"),
severity=risk_level,
severity_label=RISK_LEVEL_LABELS.get(risk_level, "中风险"),
fields=tuple(
@@ -528,6 +584,21 @@ class RiskRuleGenerationService:
)
)
@staticmethod
def _normalize_expense_category(value: str | None, domain: str) -> str | None:
if domain != AgentAssetDomain.EXPENSE.value:
return None
normalized = str(value or "").strip().lower()
if not normalized:
return None
normalized = EXPENSE_RISK_CATEGORY_ALIASES.get(normalized, normalized)
if normalized not in EXPENSE_RISK_CATEGORY_LABELS:
allowed = "".join(EXPENSE_RISK_CATEGORY_LABELS.values())
raise ValueError(f"费用领域仅支持:{allowed}")
return normalized
def _resolve_fields(self, text: str, *, domain: str) -> list[RiskRuleField]:
prefixes = DOMAIN_FIELD_PREFIXES.get(domain, ())
candidates = [field for field in FIELD_ONTOLOGY if field.key.startswith(prefixes)]

View File

@@ -172,8 +172,12 @@ class RiskRuleTemplateExecutor:
if field_key == "ocr_text":
values.extend([context.get("ocr_text"), context.get("ocr_summary")])
if field_key in {"hotel_city", "route_cities"}:
values.extend(self._scan_document_values(document_info, field_key))
values.extend(self._scan_document_values(document_info, "city"))
specific_values = self._scan_document_values(document_info, field_key)
values.extend(
specific_values
if specific_values
else self._scan_document_values(document_info, "city")
)
else:
values.extend(self._scan_document_values(document_info, field_key))
return self._normalize_values(values)
@@ -203,8 +207,8 @@ class RiskRuleTemplateExecutor:
"buyer_name": ("购买方", "抬头", "买方"),
"goods_name": ("品名", "商品", "服务名称"),
"issue_date": ("日期", "开票日期", "发票日期"),
"hotel_city": ("住宿城市", "酒店城市", "酒店地点"),
"route_cities": ("行程", "路线", "城市"),
"hotel_city": ("住宿城市", "酒店城市", "酒店地点", "住宿", "酒店"),
"route_cities": ("行程", "路线", "目的地", "出差城市"),
"city": ("城市", "地点"),
}
return any(item in label for item in label_map.get(field_key, ()))

View File

@@ -16,6 +16,7 @@ from app.db.session import get_session_factory
from app.models.system_model_setting import SystemModelSetting
from app.models.system_setting import SystemSetting
from app.models.system_setting_secret import SystemSettingSecret
from app.models.hermes_config import HermesTaskConfig
from app.repositories.settings import SETTINGS_ROW_ID, SettingsRepository
from app.schemas.settings import SettingsRead, SettingsWrite
from app.services.hermes_sync import (
@@ -183,28 +184,30 @@ class SettingsService:
capability=config.capability,
priority=config.priority,
enabled=True,
api_key_encrypted=str(getattr(secrets_row, config.legacy_secret_attr, "") or ""),
)
self.db.add(model_row)
model_rows[slot] = model_row
should_commit = True
if should_commit:
self.db.commit()
for model_row in model_rows.values():
self.db.refresh(model_row)
return model_rows
def get_settings_snapshot(self) -> SettingsRead:
settings_row, secrets_row = self.ensure_settings_ready()
model_rows = self.ensure_model_settings_ready(settings_row, secrets_row)
return self._serialize(settings_row, secrets_row, model_rows)
api_key_encrypted=str(getattr(secrets_row, config.legacy_secret_attr, "") or ""),
)
self.db.add(model_row)
model_rows[slot] = model_row
should_commit = True
if should_commit:
self.db.commit()
for model_row in model_rows.values():
self.db.refresh(model_row)
return model_rows
def get_settings_snapshot(self) -> SettingsRead:
settings_row, secrets_row = self.ensure_settings_ready()
model_rows = self.ensure_model_settings_ready(settings_row, secrets_row)
hermes_form = self._build_hermes_form_snapshot()
return self._serialize(settings_row, secrets_row, model_rows, hermes_form)
def save_settings_snapshot(self, payload: SettingsWrite) -> SettingsRead:
settings_row, secrets_row = self.ensure_settings_ready()
model_rows = self.ensure_model_settings_ready(settings_row, secrets_row)
if payload.adminForm.newPassword:
if len(payload.adminForm.newPassword) < 5:
raise ValueError("管理员密码至少需要 5 位。")
@@ -308,6 +311,8 @@ class SettingsService:
self._replace_secret_if_present(secrets_row, "smtp_password_encrypted", payload.mailForm.password)
hermes_snapshot = capture_hermes_config_snapshot()
self._save_hermes_form_snapshot(payload.hermesForm)
try:
sync_hermes_model_settings(
@@ -642,46 +647,107 @@ class SettingsService:
return should_commit
def _build_hermes_form_snapshot(self) -> dict:
configs = self.db.query(HermesTaskConfig).all()
capabilities = {}
schedules = {}
master_enabled = True # 这里假设只要有一个开启,主开关就是开启的(为简单起见)
for config in configs:
task_type = config.task_type
capabilities[task_type] = config.is_enabled
# 简化解析 cron_expression 到 time (假设 cron 为 "0 9 * * 1" 这种形式)
time_str = "00:00"
if config.cron_expression:
parts = config.cron_expression.split(" ")
if len(parts) >= 2:
minute, hour = parts[0], parts[1]
try:
time_str = f"{int(hour):02d}:{int(minute):02d}"
except ValueError:
pass
schedules[task_type] = {
"enabled": config.is_enabled,
"time": time_str
}
return {
"masterEnabled": master_enabled,
"notifyOnFailure": True,
"capabilities": capabilities,
"schedules": schedules
}
def _save_hermes_form_snapshot(self, hermes_form: dict) -> None:
if not hermes_form:
return
schedules = hermes_form.get("schedules", {})
capabilities = hermes_form.get("capabilities", {})
master_enabled = hermes_form.get("masterEnabled", True)
for task_type, schedule in schedules.items():
config = self.db.query(HermesTaskConfig).filter_by(task_type=task_type).first()
if not config:
config = HermesTaskConfig(task_type=task_type)
self.db.add(config)
task_enabled = schedule.get("enabled", False) and capabilities.get(task_type, False) and master_enabled
config.is_enabled = task_enabled
# 从 time 构建简单的 cron expression
time_str = schedule.get("time", "00:00")
parts = time_str.split(":")
if len(parts) == 2:
# 简单映射:把时分放进去,后面保留为 * * * (或者保留旧的后半段)
# 这里偷个懒,风险扫描每天跑,周报每周一跑
if task_type == "global_risk_scan":
config.cron_expression = f"{int(parts[1])} {int(parts[0])} * * *"
elif task_type == "weekly_expense_report":
config.cron_expression = f"{int(parts[1])} {int(parts[0])} * * 1"
else:
config.cron_expression = f"{int(parts[1])} {int(parts[0])} * * *"
@staticmethod
def _serialize(
settings_row: SystemSetting,
secrets_row: SystemSettingSecret,
model_rows: dict[str, SystemModelSetting],
hermes_form: dict,
) -> SettingsRead:
main_model = model_rows["main"]
backup_model = model_rows["backup"]
embedding_model = model_rows["embedding"]
reranker_model = model_rows["reranker"]
return SettingsRead(
companyForm={
"companyName": settings_row.company_name,
"displayName": settings_row.display_name,
"companyCode": settings_row.company_code,
"recordNumber": settings_row.record_number,
"copyright": settings_row.copyright_text,
},
return SettingsRead(
companyForm={
"companyName": settings_row.company_name,
"displayName": settings_row.display_name,
"companyCode": settings_row.company_code,
"recordNumber": settings_row.record_number,
"copyright": settings_row.copyright_text,
},
adminForm={
"adminAccount": settings_row.admin_account,
"adminEmail": settings_row.admin_email,
"newPassword": "",
"confirmPassword": "",
"sessionTimeout": settings_row.session_timeout,
"noticeEmail": settings_row.notice_email,
"mfaEnabled": settings_row.mfa_enabled,
"strongPassword": settings_row.strong_password,
"sessionTimeout": settings_row.session_timeout,
"noticeEmail": settings_row.notice_email,
"mfaEnabled": settings_row.mfa_enabled,
"strongPassword": settings_row.strong_password,
"loginAlertEnabled": settings_row.login_alert_enabled,
"adminPasswordConfigured": bool(secrets_row.admin_password_hash),
},
sessionForm={
"conversationRetentionDays": settings_row.conversation_retention_days,
},
hermesForm=hermes_form,
llmForm={
"mainProvider": main_model.provider,
"mainModel": main_model.model_name,
"mainEndpoint": main_model.endpoint,
"mainApiKey": "",
"mainApiKeyConfigured": bool(main_model.api_key_encrypted),
"backupProvider": backup_model.provider,
"backupModel": backup_model.model_name,
"backupEndpoint": backup_model.endpoint,

View File

@@ -71,8 +71,8 @@ EXPENSE_SCENE_SELECTION_OPTIONS = (
("other", "其他费用", "暂不属于以上分类的报销场景。"),
)
KNOWLEDGE_MODEL_MAIN_TIMEOUT_SECONDS = 3
KNOWLEDGE_MODEL_BACKUP_TIMEOUT_SECONDS = 5
KNOWLEDGE_MODEL_MAIN_TIMEOUT_SECONDS = 20
KNOWLEDGE_MODEL_BACKUP_TIMEOUT_SECONDS = 30
KNOWLEDGE_MODEL_TIMEOUT_SECONDS = KNOWLEDGE_MODEL_BACKUP_TIMEOUT_SECONDS
EXPENSE_STATUS_LABELS = {

View File

@@ -86,6 +86,7 @@ class UserAgentKnowledgeMixin(UserAgentKnowledgeHelpersMixin):
*,
citations: list[UserAgentCitation],
) -> str | None:
return None
if payload.ontology.scenario != "knowledge":
return None
if str(payload.tool_payload.get("result_type") or "").strip() != "knowledge_search":
@@ -583,20 +584,23 @@ class UserAgentKnowledgeMixin(UserAgentKnowledgeHelpersMixin):
evidence_lines: list[str] = []
for item in evidence_items[:3]:
heading = str(item.get("heading") or "").strip()
heading_text = f" > {heading}" if heading else ""
if "表格行级检索线索" in heading:
heading = heading.replace("表格行级检索线索", "").strip(" >")
heading_text = f"{heading}" if heading else ""
item_title = item.get("title") or title
if str(item.get("kind") or "") == "table":
preview = self._extract_relevant_table_preview(
str(item.get("content") or ""),
self._extract_knowledge_query_terms(self._resolve_knowledge_question(payload)),
)
evidence_lines.append(f"- 《{item.get('title') or title}{heading_text}\n{preview}")
evidence_lines.append(f"- **{item_title}** {heading_text}\n{preview}")
continue
rendered = self._render_knowledge_evidence_text(item)
if rendered:
if "\n" in rendered:
evidence_lines.append(f"- 《{item.get('title') or title}{heading_text}\n{rendered}")
evidence_lines.append(f"- **{item_title}** {heading_text}\n{rendered}")
else:
evidence_lines.append(f"- 《{item.get('title') or title}{heading_text}{rendered}")
evidence_lines.append(f"- **{item_title}** {heading_text}\n {rendered}")
if not evidence_lines:
for item in hits[:2]:
@@ -607,21 +611,22 @@ class UserAgentKnowledgeMixin(UserAgentKnowledgeHelpersMixin):
)
if not excerpt:
continue
evidence_lines.append(f"- 《{item_title}》:{excerpt}")
evidence_lines.append(f"- **{item_title}**{excerpt}")
if not evidence_lines:
return (
f"{prefix}我已经从{title}中检索到与你这次问题相关的制度依据,"
"但本次答案生成环节暂时没有成功返回。请稍后重试一次;如果仍然失败,"
"建议先检查主对话模型的连通性。"
f"{prefix}当前{title}里可用于回答的关键条款还不够明确。"
"请补充费用类型、适用地区、职级或具体业务场景,我再继续帮你缩小范围。"
)
return "\n".join(
[
f"{prefix}已经命中与你这次问题最相关的制度依据,但答案整理阶段本轮没有及时返回",
"先给你当前最直接的依据:",
f"{prefix}先根据当前制度依据给出可以确认的部分",
"",
"**依据**",
*evidence_lines,
"如果你希望我继续把这些依据整理成更完整的结论、步骤或对比说明,可以继续缩小问题范围后再问一次。",
"",
"**说明**:以上只使用当前命中的知识库证据;没有在证据中出现的适用条件或金额,我不会替你默认补齐。",
]
).strip()

View File

@@ -4,6 +4,9 @@ import re
KNOWLEDGE_DIRECT_ANSWER_HINTS = (
"是什么",
"介绍",
"说明",
"概述",
"标准",
"限额",
"流程",
@@ -45,7 +48,7 @@ MAX_KNOWLEDGE_QUERY_TERMS = 12
MAX_KNOWLEDGE_DIRECT_EVIDENCE = 4
MAX_KNOWLEDGE_MODEL_HITS = 5
KNOWLEDGE_SECTION_HEADING_PATTERN = re.compile(
r"^(#\s*.+|##\s*.+|###\s*.+|第[一二三四五六七八九十百零0-9]+[章节条]\s*.*|[一二三四五六七八九十]+、.*|[一二三四五六七八九十]+.*|\([一二三四五六七八九十]+\).*)$"
r"^(#\s*.+|##\s*.+|###\s*.+|第[一二三四五六七八九十百零0-9]+[部分章节条]\s*.*|[一二三四五六七八九十]+、.*|[一二三四五六七八九十]+.*|\([一二三四五六七八九十]+\).*)$"
)
KNOWLEDGE_LIST_ITEM_PATTERN = re.compile(r"^[-*•]\s+.+$")
KNOWLEDGE_NUMBERED_ITEM_PATTERN = re.compile(

View File

@@ -15,6 +15,20 @@ from app.services.user_agent_knowledge_constants import (
class UserAgentKnowledgeHelpersMixin:
GENERIC_KNOWLEDGE_TITLE_TERMS = {"远光软件", "股份有限", "有限公司"}
KNOWLEDGE_QUERY_ANCHOR_TERMS = (
"财务基础知识手册",
"基础知识手册",
"会计科目",
"常用会计科目",
"财务报表",
"主要税种",
"税种",
"标准",
"清单",
"明细",
"流程",
)
@staticmethod
def _select_knowledge_model_hits(
@@ -26,7 +40,7 @@ class UserAgentKnowledgeHelpersMixin:
item
for item in list(tool_payload.get("hits") or [])
if isinstance(item, dict)
][: max(MAX_KNOWLEDGE_MODEL_HITS + 1, 6)]
][: max(MAX_KNOWLEDGE_MODEL_HITS + 3, 8)]
if not raw_hits:
return []
@@ -64,7 +78,16 @@ class UserAgentKnowledgeHelpersMixin:
matched_terms = [term for term in query_terms if term in haystack]
score = max(1, 48 - rank_index * 4)
score += len(matched_terms) * 10
score += sum(max(0, len(term) - 4) * 8 for term in matched_terms)
score += sum(1 for term in matched_terms if term in title) * 8
score += sum(max(0, len(term) - 4) * 6 for term in matched_terms if term in title)
score += sum(
(len(term) - 3) * 10
for term in matched_terms
if len(term) >= 4
and term in title
and term not in UserAgentKnowledgeHelpersMixin.GENERIC_KNOWLEDGE_TITLE_TERMS
)
leading_marker = UserAgentKnowledgeHelpersMixin._leading_knowledge_appendix_marker(content)
if leading_marker == "# 章节导航":
@@ -149,6 +172,40 @@ class UserAgentKnowledgeHelpersMixin:
return ""
@staticmethod
def _knowledge_list_marker_sort_key(content: str) -> int:
normalized = str(content or "").strip()
match = re.match(r"^[(]([一二三四五六七八九十百零0-9]+)[)]", normalized)
if not match:
return 999
marker = match.group(1)
if marker.isdigit():
return int(marker)
values = {
"": 0,
"": 1,
"": 2,
"": 3,
"": 4,
"": 5,
"": 6,
"": 7,
"": 8,
"": 9,
"": 10,
}
if marker in values:
return values[marker]
if marker.startswith("") and len(marker) == 2:
return 10 + values.get(marker[1], 0)
if marker.endswith("") and len(marker) == 2:
return values.get(marker[0], 0) * 10
if "" in marker:
left, right = marker.split("", 1)
return values.get(left, 1) * 10 + values.get(right, 0)
return 999
@staticmethod
def _format_knowledge_heading_label(heading: str) -> str:
@@ -156,6 +213,169 @@ class UserAgentKnowledgeHelpersMixin:
return " / ".join(parts)
@staticmethod
def _has_inline_numbered_knowledge_items(content: str) -> bool:
return len(
re.findall(
r"[(][一二三四五六七八九十百零0-9]+[)]",
str(content or ""),
)
) >= 2
@staticmethod
def _split_inline_numbered_knowledge_items(content: str) -> list[str]:
normalized = str(content or "").strip()
if not UserAgentKnowledgeHelpersMixin._has_inline_numbered_knowledge_items(normalized):
return [normalized] if normalized else []
marker_pattern = r"[(][一二三四五六七八九十百零0-9]+[)]"
first_marker = re.search(marker_pattern, normalized)
if first_marker is None:
return [normalized] if normalized else []
prefix = normalized[: first_marker.start()].strip(" :")
tail = normalized[first_marker.start() :].strip()
item_pattern = (
r"([(][一二三四五六七八九十百零0-9]+[)]\s*.*?"
r"(?=\s*[(][一二三四五六七八九十百零0-9]+[)]|\s*$))"
)
items = [item.strip() for item in re.findall(item_pattern, tail) if item.strip()]
if prefix:
return [prefix, *items]
return items or [normalized]
@staticmethod
def _focus_knowledge_segment_content(content: str, query_terms: list[str]) -> str:
normalized = re.sub(r"\s+", " ", str(content or "").strip())
if not normalized:
return ""
anchor_terms = sorted(
{
str(term or "").strip()
for term in query_terms
if len(str(term or "").strip()) >= 3
},
key=len,
reverse=True,
)
anchor_index = -1
for term in anchor_terms:
anchor_index = normalized.lower().find(term.lower())
if anchor_index >= 0:
break
if anchor_index < 0:
return normalized
prefix_window = normalized[max(0, anchor_index - 40) : anchor_index]
marker_match = None
for match in re.finditer(
r"(?:第[一二三四五六七八九十百零0-9]+[部分章节条]|[一二三四五六七八九十]+、|[(][一二三四五六七八九十百零0-9]+[)])",
prefix_window,
):
marker_match = match
start = anchor_index
if marker_match is not None:
start = max(0, anchor_index - len(prefix_window) + marker_match.start())
return normalized[start : start + 700].strip()
@staticmethod
def _split_markdown_table_cells(line: str) -> list[str]:
stripped = str(line or "").strip()
if stripped.startswith("|"):
stripped = stripped[1:]
if stripped.endswith("|"):
stripped = stripped[:-1]
return [
re.sub(r"\s+", " ", cell.replace("**", "").strip())
for cell in stripped.split("|")
]
@classmethod
def _summarize_knowledge_table_preview(cls, preview: str) -> str:
rows: list[list[str]] = []
for line in str(preview or "").splitlines():
if line.count("|") < 2:
continue
cells = cls._split_markdown_table_cells(line)
if not cells or all(re.fullmatch(r":?-{2,}:?", cell.replace(" ", "")) for cell in cells):
continue
rows.append(cells)
if len(rows) < 2:
return "可直接参考的标准表如下。"
header = rows[0]
data_rows = [row for row in rows[1:] if len(row) == len(header)]
if len(data_rows) == 1 and len(header) >= 2:
row = data_rows[0]
subject = row[0] or "该项目"
pairs = [
f"{label}{value}"
for label, value in zip(header[1:], row[1:])
if label and value and value not in {"-", ""}
]
if pairs:
return f"{subject}的标准为:{''.join(pairs)}"
return "相关标准项如下,请按表头和行内容对应使用。"
def _summarize_knowledge_lines_conclusion(
self,
lines: list[str],
*,
heading: str = "",
) -> str:
clean_lines = [
self._clean_knowledge_segment_text(line)
for line in lines
if self._clean_knowledge_segment_text(line)
]
if not clean_lines:
return ""
clean_heading = str(heading or "").strip()
if not clean_heading and clean_lines and "" not in clean_lines[0] and ":" not in clean_lines[0]:
clean_heading = clean_lines[0]
clean_heading = re.sub(
r"^[一二三四五六七八九十百零0-9]+、\s*",
"",
clean_heading,
)
item_labels: list[str] = []
for line in clean_lines:
if "" not in line and ":" not in line:
continue
label = re.split(r"[:]", line, maxsplit=1)[0].strip()
if 1 <= len(label) <= 24:
item_labels.append(label)
if clean_heading and len(item_labels) >= 2:
return f"{clean_heading}包括:{''.join(item_labels[:6])}"
if item_labels:
return f"{item_labels[0]}{clean_lines[0].split('', 1)[-1].strip()}"
return clean_lines[0]
@staticmethod
def _knowledge_lines_have_multiple_labeled_items(lines: list[str]) -> bool:
labeled_count = 0
for line in lines:
normalized = str(line or "").strip()
if "" not in normalized and ":" not in normalized:
continue
label = re.split(r"[:]", normalized, maxsplit=1)[0].strip()
if 1 <= len(label) <= 24:
labeled_count += 1
return labeled_count >= 2
def _score_knowledge_evidence_candidate(
self,
@@ -169,10 +389,14 @@ class UserAgentKnowledgeHelpersMixin:
matched_terms = [term for term in query_terms if term in haystack]
score = len(matched_terms) * 10
score += sum(max(0, len(term) - 4) * 8 for term in matched_terms)
score += sum(1 for term in matched_terms if term in heading) * 6
score += sum(max(0, len(term) - 4) * 6 for term in matched_terms if term in heading)
if kind == "table":
score += 10
if content.count("\n") < 2:
score -= 24
elif kind in {"kv", "clause", "list"}:
score += 8
elif kind == "paragraph":
@@ -220,6 +444,30 @@ class UserAgentKnowledgeHelpersMixin:
remember(item)
for block in re.findall(r"[\u4e00-\u9fff]{2,20}", normalized_question):
remember(block)
if len(terms) >= MAX_KNOWLEDGE_QUERY_TERMS:
return terms
for marker in ("标准", "金额", "限额", "额度"):
marker_index = block.find(marker)
if marker_index <= 0:
continue
subject = block[:marker_index]
for width in (6, 4, 3, 2):
remember(subject[-width:])
for anchor in UserAgentKnowledgeHelpersMixin.KNOWLEDGE_QUERY_ANCHOR_TERMS:
if anchor in block:
remember(anchor)
tail = block[-14:]
for size in (8, 7, 6, 5, 4):
for start in range(0, len(tail) - size + 1):
piece = tail[start : start + size]
if any(
anchor in piece
for anchor in UserAgentKnowledgeHelpersMixin.KNOWLEDGE_QUERY_ANCHOR_TERMS
):
remember(piece)
if len(terms) >= MAX_KNOWLEDGE_QUERY_TERMS:
return terms
if len(block) <= 4:
remember(block)
continue
@@ -276,7 +524,14 @@ class UserAgentKnowledgeHelpersMixin:
@staticmethod
def _extract_relevant_table_preview(content: str, query_terms: list[str]) -> str:
def _extract_relevant_table_preview(
content: str,
query_terms: list[str],
*,
preferred_terms: list[str] | None = None,
max_rows: int = 3,
fallback_rows: int = 2,
) -> str:
lines = [line.strip() for line in str(content or "").splitlines() if line.strip()]
if len(lines) <= 3:
return "\n".join(lines)
@@ -285,12 +540,39 @@ class UserAgentKnowledgeHelpersMixin:
divider = lines[1] if len(lines) > 1 else ""
body = lines[2:] if divider.count("|") >= 2 else lines[1:]
preferred = [
str(term or "").strip().lower()
for term in list(preferred_terms or [])
if str(term or "").strip()
]
base_terms = preferred + [
str(term or "").strip().lower()
for term in query_terms
if str(term or "").strip().lower() not in preferred
]
derived_terms: list[str] = []
for term in base_terms:
for marker in ("标准", "金额", "限额", "额度", "是多少"):
marker_index = term.find(marker)
if marker_index <= 0:
continue
subject = term[:marker_index].strip()
if len(subject) < 2:
continue
for width in (6, 4, 3, 2):
derived_terms.append(subject[-width:])
search_terms: list[str] = []
for term in [*preferred, *derived_terms, *base_terms]:
if term and term not in search_terms:
search_terms.append(term)
matched_rows = [
row
for row in body
if any(term in row.lower() for term in query_terms)
if any(term in row.lower() for term in search_terms)
]
selected_rows = matched_rows[:3] or body[:2]
selected_rows = matched_rows[:max_rows] or body[:fallback_rows]
preview_lines = [header]
if divider:
preview_lines.append(divider)
@@ -298,6 +580,18 @@ class UserAgentKnowledgeHelpersMixin:
return "\n".join(preview_lines).strip()
@staticmethod
def _question_requests_broad_knowledge_table(question: str) -> bool:
normalized = str(question or "").strip()
if not normalized:
return False
broad_hints = ("有哪些", "是什么", "介绍", "说明", "列表", "清单", "全部", "完整")
table_subject_hints = ("科目", "目录", "清单", "列表", "", "明细")
return any(hint in normalized for hint in broad_hints) and any(
hint in normalized for hint in table_subject_hints
)
@staticmethod
def _question_requires_explicit_condition(question: str) -> bool:

View File

@@ -261,7 +261,6 @@ class UserAgentResponseMixin:
"draft_payload": draft_payload.model_dump(mode="json") if draft_payload is not None else None,
"selected_capability_codes": payload.selected_capability_codes,
"requires_confirmation": payload.requires_confirmation,
"fallback_answer": fallback_answer,
}
if payload.ontology.scenario == "knowledge":
facts["knowledge_evidence_blocks"] = self._build_knowledge_evidence_blocks(