Compare commits
3 Commits
feat/ai-da ... feat/ai-da
| Author | SHA1 | Date |
|---|---|---|
| | c7ba7bb453 | |
| | 67c3f30eb2 | |
| | 73aee622c7 | |
@@ -1,190 +0,0 @@
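# AI Data Flywheel Concept Document

Updated: 2026-07-03

Document path: document/development/2026-07-03/feature/ai-data-flywheel/CONCEPT.md

## Feature in one sentence

Automatically capture user feedback, manual corrections, risk samples, and evaluation results, and feed them back into the next round of LLM inference and rule generation, so that the expense-control system keeps improving its accuracy and lowering its false-positive rate and manual-intervention rate without downtime.
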
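## Background and problems

- Current state: the project already has the skeleton of a "smart" system: RAG (`knowledge_rag_runtime.py` + Qdrant), automatic risk-rule generation (`risk_rule_generation*.py`), feedback sample accumulation (three accumulators such as `skills/domain/false-positive-sample-accumulator`), rule replay evaluation (`risk-algorithm-replay-evaluator`), behavior profiling (`employee_behavior_profile*`), and a user feedback table (`agent_feedback.py` + `AgentOperationFeedback`).
- User pain points: samples keep piling up in the accumulator pool, but **the next LLM inference never retrieves them and feeds them in as few-shot examples**; prompts are scattered across `*_prompt.py` files with no version numbers, no online A/B testing, and no regression gate; manually corrected OCR extraction values are not fed back as evaluation or training data; low-score feedback is only aggregated, never attributed.
- Business impact: every inference starts from the "initial level", so past mistakes are never turned into future capability; there is no way to prove that a prompt or rule change is an improvement, which leaves a risk of silent regressions; operations and algorithm colleagues see no evidence that "the system is improving".
- Why now: the flywheel skeleton is complete; what is missing is the broken arrow "sample pool → retrieval injection + evaluation gate → prompt/rule versions". Once it is connected, the whole loop starts turning. The changes are concentrated in the prompt-construction layer plus a new eval directory and leave the main business path untouched, so the risk is low.
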
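## Goals and non-goals

### Goals

- [G1] Online few-shot retrieval and injection: before inference, run a vector search over the sample pool by case features, take the top-k historical samples (with their human conclusions), and splice them into the system prompt.
- [G2] Golden evaluation set + automatic regression gate: a versioned golden set; after any prompt or rule change, scores are computed automatically on the golden set, and the release is blocked if the score falls short.
- [G3] Prompt versioning + canary A/B: prompts are stored in a table with version numbers, with stable / canary traffic splitting and feedback-score comparison.
- [G4] Extraction correction feedback: manual corrections to attachment and line-item fields are recorded as diffs and accumulated into the extraction evaluation set and the few-shot samples.
- [G5] Automatic attribution of low-score feedback: low-score feedback triggers an attribution agent, which pulls the trace, diagnoses the failing stage, and generates improvement tasks.
- [G6] AI IQ dashboard: run the golden set automatically every week and plot accuracy / recall / false-positive rate / manual-intervention rate over time.

### Non-goals

- [NG1] No model fine-tuning or self-training in this round: only prompt-side in-context learning + rule learning.
- [NG2] No changes in this round to the interface contracts of the existing main business path (applications, reimbursement, approval).
- [NG3] No replacement of the Qdrant / LightRAG foundation in this round; the existing vector store and embedding configuration are reused.
- [NG4] Policy freshness detection (external policy change → automatic rule regeneration) will be evaluated later; this round only reserves an interface on the evaluation-gate side.
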
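## Users and scenarios

- Target users:
  1. Claimants / applicants: notice that the system keeps getting more accurate, with fewer rejections and fewer requests for supplementary documents.
  2. Finance approvers: the false-positive rate drops, so approvals are interrupted less often.
  3. Algorithm / operations staff: can see the IQ curve, roll out prompts gradually, and run regression evaluations.
- Entry points:
  - Automatic few-shot injection at inference time (transparent to users).
  - Back-office canary console (operations shift traffic and view scores).
  - AI IQ dashboard (weekly report / online query).
- Core scenarios:
  1. A user submits a reimbursement claim → the system pre-reviews it → the pre-review prompt is automatically injected with similar historical false-positive samples → the conclusion is more accurate.
  2. An algorithm engineer changes the risk-rule generation prompt → the golden set runs automatically before release → a below-threshold result is blocked by the gate.
  3. A user gives a low score → the attribution agent diagnoses whether "retrieval missed / a rule misjudged / the reply format was wrong" → an improvement task is created automatically and written back to the sample pool.
  4. An approver corrects an amount that OCR extracted incorrectly → the diff is captured automatically → the next extraction prompt for the same kind of receipt gains one more few-shot example.
- Exception scenarios:
  - Sample pool is empty or retrieval fails → fall back to inference without few-shot; the main path is not blocked.
  - Evaluation gate service is unavailable → stable is allowed through by default and canary is paused automatically.
  - The canary candidate prompt's score degrades → automatic rollback to stable.
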
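## Capabilities

- [C1] Input: consumes the accumulator sample pool, `AgentOperationFeedback`, attachment correction diffs, and trace data as raw material for the flywheel.
- [C2] Processing: sample retrieval (vector + metadata filtering), evaluation scoring (accuracy / recall / false-positive rate / F1), canary traffic splitting, low-score attribution.
- [C3] Output: messages with few-shot examples injected, evaluation reports, IQ curves, improvement tasks, attribution conclusions.
- [C4] State and permissions: only samples tagged "human-confirmed" enter the injectable set; prompt versions have stable / canary / pinned states; the evaluation gate can be turned off by operations (visible in the audit log).
- [C5] Boundaries and degradation: retrieval, evaluation, and canary failures all degrade to stable and never block business inference.
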
## Design

### Frontend

- Pages / components:
  - AI IQ dashboard (new page, reusing the `finance-report` dashboard skeleton): curves of accuracy / recall / false-positive rate / manual-intervention rate over time + golden set coverage.
  - Canary console (merged into `SettingsView` / `PoliciesView`): lists the prompt versions of each scenario with traffic ratio, current score, and one-click rollback.
- Interaction states: loading / empty (not enough samples) / error (evaluation failed) / permission (algorithm and operations only).
- Display rules: curves are faceted by scenario (travel / reimbursement / budget); the canary view shows confidence intervals and flags differences that are not significant.
- Degradation: when dashboard data is unavailable, show "data is being generated" instead of an error.

### Backend

- Interfaces / services (new, split by responsibility, each file ≤ 800 lines):
  - `services/few_shot_retrieval.py`: sample retriever that reuses Qdrant; input case features → output top-k samples (with human conclusions).
  - `services/prompt_registry.py`: prompt version registry; fetches a prompt by scenario + strategy (latest/canary/pinned).
  - `services/eval_harness.py`: runs the evaluation on the golden set and outputs metrics; shared by the release gate and the IQ dashboard.
  - `services/feedback_attribution.py`: low-score attribution agent, reusing `AgentTraceCenter` data.
  - `services/extraction_correction_recorder.py`: records manual-correction diffs for OCR-extracted fields.
- Changes to existing code (add inject hooks to the existing prompt-construction files; business interfaces stay unchanged):
  - The prompt-construction sites in `risk_rule_generation_prompt.py`, `user_agent_application.py`, `expense_claim_pre_review.py`, `document_intelligence_rules.py`, and `ontology_extraction.py`.
- Permissions and validation: the canary console is restricted to algorithm / operations roles; turning off the gate requires an audit log entry.
- Persistence (new tables, Alembic migration):
  - `prompt_version`: id / scene / content / version / status(stable/canary/pinned) / eval_score / created_by / created_at.
  - `golden_set`: id / scene / case_payload / expected / source(accumulator/manual) / confirmed / version.
  - `extraction_correction`: id / attachment_id / field / raw_value / corrected_value / operator / created_at.
  - `eval_run`: id / prompt_version_id / scene / metrics_json / started_at / finished_at.
- Degradation: any flywheel component failure degrades to no few-shot + stable prompt; the main path is not blocked.

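The registry behavior described above (strategy-based selection with a fallback to the hard-coded prompt) could look roughly like the sketch below. It is only an illustration under stated assumptions: `PromptVersion`, `in_canary_bucket`, and `resolve_prompt` are hypothetical names, the "pinned wins over canary" precedence is assumed rather than taken from the design, and hashing a request key is just one common way to get a stable traffic split.

```python
import hashlib
from dataclasses import dataclass


@dataclass(frozen=True)
class PromptVersion:
    # Hypothetical in-memory stand-in for a row of the prompt_version table.
    scene: str
    version: str
    status: str  # "stable", "canary", "pinned", or "archived"
    content: str


def in_canary_bucket(request_key: str, canary_percent: int) -> bool:
    """Map a request key to a bucket in [0, 100) deterministically."""
    digest = hashlib.sha256(request_key.encode("utf-8")).digest()
    return int.from_bytes(digest[:4], "big") % 100 < canary_percent


def resolve_prompt(versions, scene, request_key, hardcoded_prompt, canary_percent=10):
    """Pick a prompt for one request; fall back to the hard-coded prompt."""
    by_status = {v.status: v for v in versions if v.scene == scene}
    if "pinned" in by_status:  # assumed precedence: pinned overrides everything
        return by_status["pinned"].content
    if "canary" in by_status and in_canary_bucket(request_key, canary_percent):
        return by_status["canary"].content
    if "stable" in by_status:
        return by_status["stable"].content
    return hardcoded_prompt  # backward compatible: no registered version
```

Because the bucket is derived from the request key, the same request always lands on the same side of the split, which keeps feedback scores attributable to one prompt version.
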
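### Algorithms and rules

- Input: case feature vector (scenario tag + text summary + key fields), golden set, feedback samples, correction diffs.
- Flow:
  1. Before inference: `few_shot_retrieval` retrieves the top-k samples → they are spliced into the system prompt.
  2. After inference: the result + feedback are written to the accumulator.
  3. Before release: `eval_harness` scores the change on the golden set → gate decision.
  4. On a low score: `feedback_attribution` attributes the failure → the improvement task is written back to the sample pool.
- Output: few-shot sample block, evaluation metrics, attribution conclusions, IQ-curve data points.
- Explainability: few-shot injection keeps a "Reference cases (historically confirmed)" paragraph in the prompt so it stays traceable; the evaluation report includes the list of failing cases; attribution outputs a failing-stage label + supporting trace snippets.
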
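Step 1 of the flow, together with the "degrade to no few-shot" rule, can be sketched as follows. This is a minimal illustration, not the project's implementation: `build_messages` and the `retrieve` callable are hypothetical, and the sample keys `case_text` / `conclusion_text` are borrowed from the `FewShotSample` columns purely for readability.

```python
from typing import Callable, Sequence


def build_messages(
    system_prompt: str,
    user_input: str,
    retrieve: Callable[[str], Sequence[dict]],
) -> list[dict]:
    """Build chat messages, appending a few-shot block when retrieval works."""
    try:
        samples = list(retrieve(user_input))
    except Exception:
        # Retrieval must never block inference: degrade to no few-shot.
        samples = []

    if samples:
        lines = ["Reference cases (historically confirmed):"]
        for index, sample in enumerate(samples, start=1):
            lines.append(
                f"{index}. case: {sample['case_text']} | "
                f"conclusion: {sample['conclusion_text']}"
            )
        system_prompt = system_prompt + "\n\n" + "\n".join(lines)

    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_input},
    ]
```

Keeping the fallback inside the message builder means callers do not need their own error handling around retrieval.
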
### Data and contracts

- Core fields: scene, case_signature, few_shot_samples, metrics(acc/recall/fpr/f1), prompt_version_id, status.
- Status enums:
  - prompt: `stable` / `canary` / `pinned` / `archived`.
  - golden case: `draft` / `confirmed` / `deprecated`.
  - eval: `pass` / `fail` / `blocked`.
- Compatibility: when prompt_registry cannot find a version, it falls back to the current hard-coded prompt (guaranteeing backward compatibility).
- Versioning / audit: every prompt / rule / golden set change records an `eval_run`, so history can be replayed.

## Algorithms and formulas

### Few-shot retrieval ranking

```text
score(sample, case) = α * sim(emb(sample), emb(case)) + β * match(meta(sample), meta(case))
```

Variables:

- score: combined similarity between a sample and the current case.
- sim: cosine similarity, reusing the existing Qdrant embeddings.
- match: hard metadata match score (same scenario / same domain / same level), either 0 or 1.
- α, β: weights, default α=0.8, β=0.2; can be overridden per scenario in prompt_registry.
- Applicability: only samples with `confirmed=true` are used; top-k defaults to k=3 and is trimmed dynamically by token budget.

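As a worked instance of the ranking formula, the sketch below computes the score in plain Python. The dict layout (`embedding`, `meta`, `confirmed`) and the metadata keys are assumptions made for the example; in the design the similarity itself would come from Qdrant rather than being computed by hand.

```python
import math


def cosine_similarity(a, b) -> float:
    """Cosine similarity of two equal-length vectors (0.0 for a zero vector)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def metadata_match(sample_meta, case_meta, keys=("scene", "domain", "risk_level")) -> float:
    """Hard match: 1.0 only if every key agrees, otherwise 0.0."""
    return 1.0 if all(sample_meta.get(k) == case_meta.get(k) for k in keys) else 0.0


def score(sample, case, alpha=0.8, beta=0.2) -> float:
    """score = alpha * sim(embeddings) + beta * match(metadata)."""
    return (
        alpha * cosine_similarity(sample["embedding"], case["embedding"])
        + beta * metadata_match(sample["meta"], case["meta"])
    )


def top_k(samples, case, k=3):
    """Rank confirmed samples only and keep the k best."""
    confirmed = [s for s in samples if s.get("confirmed")]
    return sorted(confirmed, key=lambda s: score(s, case), reverse=True)[:k]
```

With the default weights, a sample with matching metadata but an orthogonal embedding scores 0.2, so metadata alone can never outrank a close vector match.
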
### Evaluation metrics

```text
precision = TP / (TP + FP)
recall = TP / (TP + FN)
f1 = 2 * precision * recall / (precision + recall)
```

Variables:

- TP/FP/FN: obtained by comparing the inference conclusion on the golden set against `expected` (there is a separate comparator for each of three scenario types: risk flags, field values, and classification labels).
- Default release gate threshold: recall ≥ the recall of the previous stable version, and f1 drops by no more than 2 percentage points; otherwise the result is `fail`.

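The metrics and the default gate threshold translate into a few lines of Python. This is a sketch under assumptions: the function names and the metrics-dict shape are made up for the example, and zero denominators are mapped to 0.0, a convention the document does not specify.

```python
def precision_recall_f1(tp: int, fp: int, fn: int) -> dict:
    """Compute precision, recall, and F1; empty denominators yield 0.0."""
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    total = precision + recall
    f1 = 2 * precision * recall / total if total else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}


def gate(candidate: dict, stable: dict, max_f1_drop: float = 0.02) -> str:
    """Return "pass" if recall does not drop and F1 drops by at most 2 points."""
    recall_ok = candidate["recall"] >= stable["recall"]
    f1_ok = (stable["f1"] - candidate["f1"]) <= max_f1_drop
    return "pass" if (recall_ok and f1_ok) else "fail"
```

For example, a candidate with recall 0.85 and F1 0.79 passes against a stable version at recall 0.80 and F1 0.80, while any candidate whose recall falls below 0.80 fails regardless of its F1.
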
## Test plan

Backend:

- `few_shot_retrieval` unit tests: empty sample pool / retrieval failure / top-k truncation / confirmed samples only.
- `eval_harness` unit tests: correctness of the metrics computed on the golden set, gate pass / block logic, degradation on an empty golden set.
- `prompt_registry` unit tests: version selection by strategy, fallback to the hard-coded prompt, canary traffic split ratio.
- `feedback_attribution` unit tests: mocked trace data, correctness of attribution labels.

Frontend:

- AI IQ dashboard view model: empty, loading, and error states, curve rendering.
- Canary console: list, traffic shifting, and rollback interactions.

Integration:

- End to end: build a golden set → change a prompt → the release is blocked / passed by the gate → a new data point appears on the IQ dashboard.
- Run inside the container: `docker exec -w /app -e SERVER_VENV_DIR=/tmp/x-financial-server-venv local-x-financial-linux /tmp/x-financial-server-venv/bin/pytest -q server/tests/...`, with a 60 s timeout.

Manual verification:

- Trigger a pre-review in the AI workbench and confirm that the few-shot block appears in the prompt.
- Publish a degraded prompt in the canary console and confirm that the gate blocks it.

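## Metrics and acceptance

- [A1] Functional acceptance: the few-shot block is visible in the prompt at inference time, and its samples come from the confirmed pool.
- [A2] Performance: few-shot retrieval ≤ 200 ms (P95) without noticeably slowing the main path; eval for a single scenario ≤ 60 s.
- [A3] Quality: the golden set covers at least 5 core scenarios; the gate correctly blocks degraded prompts.
- [A4] Security / permissions: the canary console can be operated only by algorithm / operations roles; turning off the gate is recorded in the audit log.
- [A5] Observability: the AI IQ dashboard generates curves weekly; every eval_run can be replayed.
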
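## Risks and open questions

- Risks:
  - Few-shot injection lengthens the prompt and may hit the token limit or slow down inference → mitigated by token-budget trimming + P95 monitoring.
  - Noise in the sample pool (mislabeled samples) pollutes inference → only confirmed samples are used, and the evaluation gate acts as a safeguard.
  - The golden set drifts away from the online distribution → review the golden set quarterly and annotate the degree of drift.
- Dependencies already handled: the existing capabilities of Qdrant / LightRAG / accumulator / AgentTraceCenter / risk_rule_generation are reused.
- To be confirmed:
  - The exact canary traffic ratio (90/10 suggested) needs to be confirmed with the business side.
  - Which top-level menu hosts the IQ dashboard (Settings or a separate "AI Operations" menu).
  - Whether policy freshness detection is included in this round.
- Degradation strategy: any flywheel component failure → no few-shot + stable prompt + skip the gate (stable only), to keep the business running.
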
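## Implementation record for this round

- 2026-07-03: completed the data flywheel concept document and the development TODO breakdown, which serve as the master plan for the follow-up work.
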
@@ -1,98 +0,0 @@
# AI Data Flywheel Development TODO

Updated: 2026-07-03

Document path: document/development/2026-07-03/feature/ai-data-flywheel/TODO.md

## Usage rules

- Every TODO must map to a goal, capability, design item, or acceptance point in `CONCEPT.md`.
- Change `[ ]` to `[x]` only after the task has been completed and verified.
- When checking an item off, add brief evidence after the task, such as a file, an interface, a command, or a verification result.
- If the requirements change, update `CONCEPT.md` first, then adjust this TODO.
- Suggested order: phase 1 → 2 → 3 (flywheels 1+2 are the foundation) → 4/5/6 in parallel → 7.

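## 1. Research and scope

- [x] [CONCEPT: Background and problems] Inventory the existing accumulator / feedback / RAG / rule-generation capabilities and confirm that the flywheel skeleton exists and that the break is at "retrieval injection + evaluation gate".
  Evidence: `server/src/app/skills/domain/{false-positive-sample-accumulator,risk-feedback-sample-accumulator,risk-clue-collector}`, `services/agent_feedback.py`, `services/knowledge_rag_runtime.py`, `services/risk_rule_generation*.py`.
- [x] [CONCEPT: Goals and non-goals] Confirm that the scope of this round = flywheels 1-6 (few-shot injection / golden set gate / prompt versioning / extraction correction feedback / low-score attribution / IQ dashboard), with no model fine-tuning and no changes to business interface contracts.
  Evidence: the "Goals and non-goals" section of CONCEPT.md.
- [ ] [CONCEPT: Risks and open questions] Confirm with the business side the canary traffic ratio, the menu position of the IQ dashboard, and whether policy freshness detection is included in this round.
  Evidence:
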
## 2. Contracts and design

- [ ] [CONCEPT: Capabilities] Define the fields and status enums of the 4 new tables (prompt stable/canary/pinned/archived, golden draft/confirmed/deprecated, eval pass/fail/blocked, correction).
  Evidence:
- [ ] [CONCEPT: Design] Define the responsibility boundaries of the 5 new services and the inject hook points (risk_rule_generation_prompt / user_agent_application / expense_claim_pre_review / document_intelligence_rules / ontology_extraction).
  Evidence:
- [ ] [CONCEPT: Algorithms and formulas] Confirm the weights of the few-shot retrieval ranking formula (default α=0.8, β=0.2) and the gate thresholds (recall does not drop, f1 drops by ≤ 2 pp).
  Evidence:
- [ ] [CONCEPT: Metrics and acceptance] Turn acceptance points A1-A5 into verifiable check items, each with a command and an expected result.
  Evidence:

## 3. Backend implementation

- [x] [CONCEPT: Backend] Add `services/few_shot_retrieval.py`: reuse Qdrant, retrieve the top-k confirmed samples by case features, with token-budget trimming.
  Evidence: `server/src/app/services/few_shot_retrieval.py`; `server/src/app/services/few_shot_store.py` (separate Qdrant collection `few_shot_samples`); `server/src/app/services/embedding_provider.py` (shared EmbeddingProvider that reuses the HTTP calls of knowledge_rag_runtime).
- [ ] [CONCEPT: Backend] Add `services/prompt_registry.py`: prompt version CRUD + strategy-based selection (latest/canary/pinned) + fallback to the hard-coded prompt.
  Evidence: flywheel 3 (prompt versioning + canary) has not started; this round covers flywheel 1 only.
- [ ] [CONCEPT: Backend] Add `services/eval_harness.py`: run the evaluation on the golden set and output precision/recall/f1, shared by the gate and the dashboard.
  Evidence: flywheel 2 (golden set + gate) has not started; this round covers flywheel 1 only.
- [ ] [CONCEPT: Backend] Add `services/feedback_attribution.py`: triggered by low-score feedback, reuses AgentTraceCenter traces for attribution, outputs a failing-stage label + an improvement task.
  Evidence: flywheel 5 (low-score attribution) has not started; this round covers flywheel 1 only.
- [ ] [CONCEPT: Backend] Add `services/extraction_correction_recorder.py`: record raw vs corrected diffs at the places where attachment / line-item fields are updated.
  Evidence: flywheel 4 (extraction correction feedback) has not started; this round covers flywheel 1 only.
- [ ] [CONCEPT: Backend] Alembic migration: the four tables prompt_version / golden_set / extraction_correction / eval_run.
  Evidence: this round adds a single table, FewShotSample (`server/src/app/models/few_shot_sample.py`); the project creates tables via `Base.metadata.create_all()` (there is no alembic versions/ directory), and the model is registered in `db/base.py` and `models/__init__.py`. The other three tables will be created together with their respective flywheels.
- [x] [CONCEPT: Backend] Add `services/few_shot_ingestion.py`: RiskObservation confirmed/false_positive → FewShotSample + Qdrant vector, triggered by a hook after `risk_observations.create_feedback` commits.
  Evidence: `server/src/app/services/few_shot_ingestion.py`; `server/src/app/services/risk_observations.py:324-345` (the `_maybe_ingest_few_shot` hook, with a feature flag + try/except fallback).
- [x] [CONCEPT: Data and contracts] Add few-shot injection to the existing prompt-construction files without changing business interfaces.
  Evidence: `server/src/app/services/risk_rule_generation_prompt.py` (new optional kwarg `few_shot_samples`, merged into the examples field); `server/src/app/services/risk_rule_generation.py:271-292` (`_retrieve_few_shot_samples` is called before the messages are built and degrades to empty on failure).

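The "feature flag + try/except fallback" hook mentioned in the evidence can be sketched as below. This is not the project's `_maybe_ingest_few_shot`: the function names, the `FEW_SHOT_INGESTION_ENABLED` variable, and the dict-shaped observation are all hypothetical and only illustrate the pattern of a post-commit hook that never raises.

```python
import logging
import os

logger = logging.getLogger(__name__)

# Labels that qualify an observation for ingestion, as stated in the TODO.
INGESTIBLE_LABELS = {"confirmed", "false_positive"}


def few_shot_ingestion_enabled(env=None) -> bool:
    """Read the (hypothetical) feature flag; enabled unless explicitly off."""
    env = os.environ if env is None else env
    value = env.get("FEW_SHOT_INGESTION_ENABLED", "true")
    return value.strip().lower() not in {"0", "false", "no"}


def maybe_ingest_few_shot(observation: dict, ingest, env=None) -> bool:
    """Call ingest(observation) after commit; return True only on success."""
    if not few_shot_ingestion_enabled(env):
        return False
    if observation.get("label") not in INGESTIBLE_LABELS:
        return False  # e.g. "ignored" observations are skipped
    try:
        ingest(observation)
    except Exception:
        # Swallow the error: ingestion must not break the feedback request.
        logger.exception("few-shot ingestion failed; continuing")
        return False
    return True
```

Running the hook only after the commit means a failed ingestion cannot roll back the feedback itself.
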
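## 4. Algorithm / rule implementation

- [x] [CONCEPT: Algorithms and rules] Implement few-shot retrieval ranking (vector similarity + hard metadata match), taking confirmed samples only.
  Evidence: `server/src/app/services/few_shot_store.py` (Qdrant cosine similarity + payload filtering on scene/label/status); `few_shot_retrieval.py` handles deduplication + token budget + a per-sample character cap. Retrieval only takes label ∈ {confirmed, false_positive}.
- [ ] [CONCEPT: Algorithms and rules] Implement the evaluation metric comparators (three scenario types: risk flag / field value / classification label).
  Evidence: flywheel 2, not started.
- [ ] [CONCEPT: Algorithms and rules] Wire in the release gate: call eval_harness before `agent_asset_risk_rule_publish` and block when the threshold is not met.
  Evidence: flywheel 2, not started.
- [ ] [CONCEPT: Algorithms and rules] Wire in canary traffic splitting (default 90 stable / 10 canary) + automatic rollback on degradation.
  Evidence: flywheel 3, not started.
- [x] [CONCEPT: Explainability] The few-shot block keeps a `source: "historical_confirmed"` marker in the prompt so it stays traceable.
  Evidence: when `risk_rule_generation_prompt.py` merges examples, every historical sample carries the `source`/`label`/`conclusion` fields.
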
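The deduplication, per-sample character cap, and token-budget trimming listed in the evidence could be combined roughly as follows. The sketch is illustrative only: `trim_samples` and `estimate_tokens` are hypothetical, the four-characters-per-token estimate is a crude assumption rather than a real tokenizer, and the input is assumed to be sorted by score already.

```python
def estimate_tokens(text: str) -> int:
    # Crude assumption: roughly four characters per token, at least one token.
    return max(1, len(text) // 4)


def trim_samples(samples, token_budget: int, max_chars: int = 400):
    """Deduplicate by sample_key, cap each text, and stop at the token budget.

    `samples` is assumed to be ordered best-first, so trimming drops the
    lowest-ranked samples.
    """
    seen = set()
    kept = []
    used = 0
    for sample in samples:
        key = sample["sample_key"]
        if key in seen:
            continue  # duplicate of an already kept (or rejected) sample
        seen.add(key)
        text = sample["case_text"][:max_chars]
        cost = estimate_tokens(text)
        if used + cost > token_budget:
            break  # budget exhausted: drop this and all lower-ranked samples
        kept.append({**sample, "case_text": text})
        used += cost
    return kept
```

Stopping at the first sample that does not fit (rather than skipping it and trying smaller ones) keeps the result a strict prefix of the ranking, which is easier to reason about when debugging a prompt.
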
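## 5. Frontend implementation

- [ ] [CONCEPT: Frontend] New AI IQ dashboard page: curves of accuracy / recall / false-positive rate / manual-intervention rate over time + golden set coverage, reusing the `finance-report` dashboard skeleton.
  Evidence:
- [ ] [CONCEPT: Frontend] Canary console (merged into Settings/Policies): prompt version list, traffic ratio, scores, one-click rollback.
  Evidence:
- [ ] [CONCEPT: Frontend] Implement the loading / empty (not enough samples) / error (evaluation failed) / permission (algorithm and operations only) states.
  Evidence:
- [ ] [CONCEPT: Frontend] Align with the existing enterprise back-office style (see `chat-ui-saas-styling` / `theme-settings-enterprise-ai-style`) and avoid a marketing-page look.
  Evidence:
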
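## 6. Testing and verification

- [x] [CONCEPT: Test plan] Backend unit tests: embedding_provider (GLM/Ollama branches, dimension caching, degradation on HTTP errors), few_shot_ingestion (confirmed/false_positive ingestion, ignored observations skipped, idempotent deduplication, hook trigger, feature flag, exception swallowing), few_shot_retrieval (deduplication, token budget, over-length truncation) + prompt injection (merging examples, backward compatibility).
  Evidence: `server/tests/test_embedding_provider.py`, `server/tests/test_few_shot_ingestion.py`, `server/tests/test_few_shot_retrieval_and_prompt.py`; `pytest -q` inside the container: 20 passed.
- [ ] [CONCEPT: Test plan] Frontend: view models of the IQ dashboard and the canary console + build verification.
  Evidence: flywheel 3/6 frontend, not started.
- [ ] [CONCEPT: Test plan] Integration: golden set → change a prompt → gate blocks / passes → a new data point on the dashboard, running end to end inside the container.
  Evidence: flywheel 2 integration, not started.
- [x] [CONCEPT: Test plan] Regression: the existing RAG / risk_observations / risk_rule_generation tests all pass.
  Evidence: inside the container, `pytest -q server/tests/test_risk_observations_service.py server/tests/test_knowledge_rag_runtime.py server/tests/test_risk_rule_generation.py server/tests/test_risk_rule_generation_failure.py` → 35 passed; extracting EmbeddingProvider caused zero regressions.
- [ ] [CONCEPT: Metrics and acceptance] Record the verification commands and results; confirm P95 retrieval ≤ 200 ms and single-scenario evaluation ≤ 60 s.
  Evidence: performance will be measured together with the golden set once the flywheel 2 evaluation is live.
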
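## 7. Documentation wrap-up

- [x] [CONCEPT: Metrics and acceptance] Flywheel 1 (few-shot injection) meets functional acceptance point A1: at inference time the prompt contains a few-shot block marked `source: "historical_confirmed"`, and its samples come from the confirmed/false_positive pool. A5 (observability) is partially met (the source is traceable). A2/A3/A4 will be completed with flywheels 2/3.
  Evidence: see the checked items in phases 3/4/6.
- [ ] [CONCEPT: Risks and open questions] Update the final conclusions and remaining risks for the canary ratio, the dashboard menu position, and policy freshness detection.
  Evidence: to be finalized once flywheels 2-6 have started.
- [x] [CONCEPT: Feature in one sentence] Confirm that the flywheel 1 implementation has not drifted from the original goal of "making the system smarter the more it is used".
  Evidence: human-confirmed risk observation → automatic ingestion + vectorization → similar historical samples are retrieved and injected at the next rule compilation, closing the loop "more use → richer samples → more accurate inference". Flywheels 2-6 are left for later iterations.
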
@@ -43,6 +43,10 @@ from app.schemas.agent_asset import (
     AgentAssetVersionCreate,
     AgentAssetVersionRead,
     AgentAssetVersionTimelineItemRead,
+    GoldenCaseCreate,
+    GoldenCaseRead,
+    GoldenEvalRead,
+    GoldenEvalRequest,
 )
 from app.schemas.common import ErrorResponse, PaginatedResponse
 from app.services.agent_assets import AgentAssetService
@@ -923,3 +927,110 @@ def get_agent_asset_version_timeline(
         return AgentAssetService(db).list_version_timeline(asset_id)
     except Exception as exc:
         _handle_asset_error(exc)
+
+
+@router.post(
+    "/risk-rules/golden-cases",
+    response_model=GoldenCaseRead,
+    status_code=status.HTTP_201_CREATED,
+    summary="Create a golden set case",
+    description="Create a regression case for a given rule (or a generic scene); it runs as part of the gate set before publishing.",
+)
+def create_golden_case(
+    body: GoldenCaseCreate,
+    _: RuleEditorUser,
+    db: DbSession,
+) -> GoldenCaseRead:
+    from app.models.golden_case import GoldenCase
+    from sqlalchemy import select
+
+    existing = db.scalar(select(GoldenCase).where(GoldenCase.case_key == body.case_key))
+    if existing is not None:
+        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="case_key already exists")
+    case = GoldenCase(
+        case_key=body.case_key,
+        rule_code=body.rule_code,
+        scene=body.scene,
+        name=body.name,
+        values_json=body.values,
+        expected_hit=body.expected_hit,
+        expected_severity=body.expected_severity,
+        note=body.note,
+        status="active",
+        source="manual",
+    )
+    db.add(case)
+    db.commit()
+    db.refresh(case)
+    return _golden_case_read(case)
+
+
+@router.get(
+    "/risk-rules/{rule_code}/golden-cases",
+    response_model=list[GoldenCaseRead],
+    summary="List the golden cases of a rule",
+)
+def list_golden_cases(
+    rule_code: str,
+    _: CurrentUser,
+    db: DbSession,
+) -> list[GoldenCaseRead]:
+    from app.models.golden_case import GoldenCase
+    from sqlalchemy import select
+
+    cases = db.scalars(
+        select(GoldenCase).where(GoldenCase.rule_code == rule_code).order_by(GoldenCase.created_at)
+    ).all()
+    return [_golden_case_read(case) for case in cases]
+
+
+@router.post(
+    "/{asset_id}/golden-eval",
+    response_model=GoldenEvalRead,
+    summary="Manually trigger a golden set evaluation (not part of the gate)",
+    description="Run the golden case set against the current rule version and return the metrics. The gate itself is triggered automatically on publish.",
+)
+def run_golden_eval(
+    asset_id: str,
+    body: GoldenEvalRequest,
+    _: RuleReviewerUser,
+    db: DbSession,
+) -> GoldenEvalRead:
+    from app.services.agent_asset_spreadsheet import RISK_RULES_LIBRARY
+    from app.services.risk_rule_golden_evaluator import RiskRuleGoldenEvaluator
+
+    try:
+        asset = AgentAssetService(db).get_asset(asset_id)
+        if asset is None:
+            raise LookupError("Asset not found")
+        config = asset.config_json if isinstance(asset.config_json, dict) else {}
+        rule_document = config.get("rule_document") if isinstance(config.get("rule_document"), dict) else {}
+        file_name = str(rule_document.get("file_name") or "").strip()
+        if not file_name:
+            raise ValueError("This rule has no executable manifest file.")
+        manager = AgentAssetService(db).rule_library_manager
+        manifest = manager.read_rule_library_json(library=RISK_RULES_LIBRARY, file_name=file_name)
+        rule_code = str(manifest.get("rule_code") or "").strip()
+        if not rule_code:
+            raise ValueError("The manifest is missing rule_code.")
+        version = body.version or asset.working_version or ""
+        report = RiskRuleGoldenEvaluator().evaluate_for_rule(db, manifest, rule_code)
+        return GoldenEvalRead(**report.to_dict())
+    except Exception as exc:
+        _handle_asset_error(exc)
+
+
+def _golden_case_read(case) -> GoldenCaseRead:
+    return GoldenCaseRead(
+        id=case.id,
+        case_key=case.case_key,
+        rule_code=case.rule_code,
+        scene=case.scene or "",
+        name=case.name or "",
+        values=case.values_json or {},
+        expected_hit=bool(case.expected_hit),
+        expected_severity=case.expected_severity,
+        note=case.note,
+        status=case.status,
+        source=case.source,
+    )
@@ -15,13 +15,13 @@ from app.models.budget import BudgetAllocation, BudgetReservation, BudgetTransac
 from app.models.employee_change_log import EmployeeChangeLog
 from app.models.employee_behavior_profile import EmployeeBehaviorProfileSnapshot
 from app.models.employee import Employee
-from app.models.few_shot_sample import FewShotSample
 from app.models.financial_record import (
     AccountsPayableRecord,
     AccountsReceivableRecord,
     ExpenseClaim,
     ExpenseClaimItem,
 )
+from app.models.golden_case import GoldenCase
 from app.models.hermes_config import HermesTaskConfig, HermesTaskExecutionLog
 from app.models.hermes_report import HermesRiskReport
 from app.models.notification_state import NotificationState
@@ -58,8 +58,8 @@ __all__ = [
     "EmployeeBehaviorProfileSnapshot",
     "EmployeeChangeLog",
     "ExpenseClaim",
-    "FewShotSample",
     "ExpenseClaimItem",
+    "GoldenCase",
     "HermesTaskConfig",
     "HermesTaskExecutionLog",
     "HermesRiskReport",
@@ -8,13 +8,13 @@ from app.models.budget import BudgetAllocation, BudgetReservation, BudgetTransac
 from app.models.employee_change_log import EmployeeChangeLog
 from app.models.employee_behavior_profile import EmployeeBehaviorProfileSnapshot
 from app.models.employee import Employee
-from app.models.few_shot_sample import FewShotSample
 from app.models.financial_record import (
     AccountsPayableRecord,
     AccountsReceivableRecord,
     ExpenseClaim,
     ExpenseClaimItem,
 )
+from app.models.golden_case import GoldenCase
 from app.models.hermes_config import HermesTaskConfig, HermesTaskExecutionLog
 from app.models.hermes_report import HermesRiskReport
 from app.models.notification_state import NotificationState
@@ -50,7 +50,7 @@ __all__ = [
     "EmployeeChangeLog",
     "ExpenseClaim",
     "ExpenseClaimItem",
-    "FewShotSample",
+    "GoldenCase",
     "HermesTaskConfig",
     "HermesTaskExecutionLog",
     "HermesRiskReport",
@@ -1,54 +0,0 @@
-from __future__ import annotations
-
-import uuid
-from datetime import datetime
-from typing import Any
-
-from sqlalchemy import DateTime, ForeignKey, Index, String, Text, func
-from sqlalchemy.orm import Mapped, mapped_column
-from sqlalchemy.types import JSON
-
-from app.db.base_class import Base
-
-
-class FewShotSample(Base):
-    """Confirmed risk-observation samples used for few-shot retrieval and injection.
-
-    The data comes from ``RiskObservation`` records manually marked as confirmed / false_positive;
-    on ingestion a vector is also written to the Qdrant ``few_shot_samples`` collection.
-    """
-
-    __tablename__ = "few_shot_samples"
-    __table_args__ = (
-        Index("ix_few_shot_samples_scene_label", "scene", "label"),
-        Index("ix_few_shot_samples_domain_risk_type", "domain", "risk_type"),
-    )
-
-    id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
-    sample_key: Mapped[str] = mapped_column(String(160), unique=True, index=True)
-    source_observation_id: Mapped[str | None] = mapped_column(
-        ForeignKey("risk_observations.id"),
-        nullable=True,
-        index=True,
-    )
-
-    scene: Mapped[str] = mapped_column(String(50), default="risk_rule_generation", index=True)
-    domain: Mapped[str] = mapped_column(String(50), default="", index=True)
-    risk_type: Mapped[str] = mapped_column(String(80), default="", index=True)
-    risk_level: Mapped[str] = mapped_column(String(20), default="")
-
-    label: Mapped[str] = mapped_column(String(30), default="confirmed", index=True)
-    case_text: Mapped[str] = mapped_column(Text(), default="")
-    conclusion_text: Mapped[str] = mapped_column(Text(), default="")
-    payload_json: Mapped[dict[str, Any]] = mapped_column(JSON, default=dict)
-
-    vector_id: Mapped[str | None] = mapped_column(String(100), nullable=True)
-    status: Mapped[str] = mapped_column(String(20), default="active", index=True)
-
-    created_at: Mapped[datetime] = mapped_column(DateTime, default=func.now(), server_default=func.now())
-    updated_at: Mapped[datetime] = mapped_column(
-        DateTime,
-        default=func.now(),
-        onupdate=func.now(),
-        server_default=func.now(),
-    )
48
server/src/app/models/golden_case.py
Normal file
@@ -0,0 +1,48 @@
+from __future__ import annotations
+
+import uuid
+from datetime import datetime
+from typing import Any
+
+from sqlalchemy import Boolean, DateTime, Index, String, Text, func
+from sqlalchemy.orm import Mapped, mapped_column
+from sqlalchemy.types import JSON
+
+from app.db.base_class import Base
+
+
+class GoldenCase(Base):
+    """Golden cases for the risk-rule regression gate.
+
+    Maintained manually by operations (or imported from confirmed risk observations) and run as a
+    regression set before a rule is published; only a 100% pass lets the release through.
+    ``values_json`` reuses the flat dict format of ``AgentAssetRiskRuleSampleCase.values``;
+    ``expected_hit`` / ``expected_severity`` serve as ground truth.
+    """
+
+    __tablename__ = "golden_cases"
+    __table_args__ = (
+        Index("ix_golden_cases_rule_code_status", "rule_code", "status"),
+        Index("ix_golden_cases_scene_status", "scene", "status"),
+    )
+
+    id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
+    case_key: Mapped[str] = mapped_column(String(160), unique=True, index=True)
+    rule_code: Mapped[str | None] = mapped_column(String(120), nullable=True, index=True)
+    scene: Mapped[str] = mapped_column(String(50), default="", index=True)
+
+    name: Mapped[str] = mapped_column(String(120), default="")
+    values_json: Mapped[dict[str, Any]] = mapped_column(JSON, default=dict)
+    expected_hit: Mapped[bool] = mapped_column(Boolean, default=True)
+    expected_severity: Mapped[str | None] = mapped_column(String(20), nullable=True)
+    note: Mapped[str | None] = mapped_column(Text(), nullable=True)
+
+    status: Mapped[str] = mapped_column(String(20), default="active", index=True)
+    source: Mapped[str] = mapped_column(String(30), default="manual")
+
+    created_at: Mapped[datetime] = mapped_column(DateTime, default=func.now(), server_default=func.now())
+    updated_at: Mapped[datetime] = mapped_column(
+        DateTime,
+        default=func.now(),
+        onupdate=func.now(),
+        server_default=func.now(),
+    )
@@ -204,6 +204,46 @@ class AgentAssetRiskRuleReportRequest(BaseModel):
     note: str | None = Field(default=None, max_length=1000)
 
 
+class GoldenCaseCreate(BaseModel):
+    case_key: str = Field(..., max_length=160)
+    rule_code: str | None = Field(default=None, max_length=120)
+    scene: str = Field(default="", max_length=50)
+    name: str = Field(default="", max_length=120)
+    values: dict[str, Any] = Field(default_factory=dict)
+    expected_hit: bool = True
+    expected_severity: str | None = Field(default=None, max_length=20)
+    note: str | None = None
+
+
+class GoldenCaseRead(BaseModel):
+    id: str
+    case_key: str
+    rule_code: str | None = None
+    scene: str = ""
+    name: str = ""
+    values: dict[str, Any] = Field(default_factory=dict)
+    expected_hit: bool = True
+    expected_severity: str | None = None
+    note: str | None = None
+    status: str = "active"
+    source: str = "manual"
+
+
+class GoldenEvalRequest(BaseModel):
+    version: str | None = Field(default=None, max_length=30)
+
+
+class GoldenEvalRead(BaseModel):
+    total: int = 0
+    passed_count: int = 0
+    failed_count: int = 0
+    accuracy: float = 0.0
+    precision: float = 0.0
+    recall: float = 0.0
+    all_passed: bool = True
+    results: list[dict[str, Any]] = Field(default_factory=list)
+
+
 class AgentAssetRiskRuleSimulationAttachment(BaseModel):
     name: str = Field(default="", max_length=240)
     content_type: str | None = Field(default=None, max_length=120)
@@ -39,6 +39,9 @@ class AgentAssetRiskRulePublishMixin:
         if not self.get_latest_risk_rule_test_summary(asset, version=version).test_passed:
             raise PermissionError("The current rule version has not been confirmed as passing its tests and cannot be published.")
 
+        # Golden set regression gate: run the rule on the golden cases; block publishing unless 100% pass.
+        self._require_golden_set_passed(asset, version, actor=actor)
+
         before = self._asset_snapshot(asset)
         self._ensure_approved_review(asset, version=version, actor=actor, note="Review approved before release.")
         asset.reviewer = actor
@@ -176,6 +179,49 @@ class AgentAssetRiskRulePublishMixin:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _require_golden_set_passed(
|
||||||
|
self,
|
||||||
|
asset: AgentAsset,
|
||||||
|
version: str,
|
||||||
|
*,
|
||||||
|
actor: str,
|
||||||
|
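    ) -> None:
        """Run the current rule manifest against the golden set and block publishing unless every case passes.

        Fallback policy: when the feature flag is off, there is no rule_document,
        there are no golden cases, or the evaluator raises, the publish is let
        through so the main publishing path is never blocked.
        """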
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
if os.environ.get("GOLDEN_SET_GATE_ENABLED", "true").strip().lower() in {"0", "false", "no"}:
|
||||||
|
return
|
||||||
|
config = asset.config_json if isinstance(asset.config_json, dict) else {}
|
||||||
|
rule_document = config.get("rule_document") if isinstance(config.get("rule_document"), dict) else {}
|
||||||
|
file_name = str(rule_document.get("file_name") or "").strip()
|
||||||
|
if not file_name:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
manifest = self.rule_library_manager.read_rule_library_json(
|
||||||
|
library=RISK_RULES_LIBRARY,
|
||||||
|
file_name=file_name,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
return
|
||||||
|
rule_code = str(manifest.get("rule_code") or "").strip()
|
||||||
|
if not rule_code:
|
||||||
|
return
|
||||||
|
from app.services.risk_rule_golden_evaluator import RiskRuleGoldenEvaluator
|
||||||
|
|
||||||
|
RiskRuleGoldenEvaluator().require_pass(
|
||||||
|
self.db,
|
||||||
|
asset,
|
||||||
|
version,
|
||||||
|
manifest,
|
||||||
|
rule_code,
|
||||||
|
actor=actor,
|
||||||
|
)
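The environment-flag check used by this gate (and by the `FEW_SHOT_INJECTION_ENABLED` checks elsewhere in this diff) could be factored into one small helper. The sketch below is only an illustration: the name `flag_enabled` is hypothetical, and it assumes the same semantics as the inline check, namely that the feature is on unless the variable is explicitly set to `0`, `false` or `no`.

```python
import os

_FALSY = {"0", "false", "no"}


def flag_enabled(name: str, default: str = "true") -> bool:
    """Return False only when the variable is explicitly set to a falsy token."""
    return os.environ.get(name, default).strip().lower() not in _FALSY


# Whitespace and letter case are ignored, as in the inline check.
os.environ["GOLDEN_SET_GATE_ENABLED"] = " False "
gate_when_disabled = flag_enabled("GOLDEN_SET_GATE_ENABLED")

# With the variable unset, the default ("true") applies and the gate is on.
os.environ.pop("GOLDEN_SET_GATE_ENABLED")
gate_by_default = flag_enabled("GOLDEN_SET_GATE_ENABLED")
```

Because the default is "enabled", an operator has to opt out explicitly; an unrecognised value such as `off` would still leave the gate on under this convention.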
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _config_from_published_manifest(
|
def _config_from_published_manifest(
|
||||||
manifest: dict[str, Any],
|
manifest: dict[str, Any],
|
||||||
|
|||||||
@@ -1,138 +0,0 @@
|
|||||||
"""Shared embedding provider.

Extracts the embedding call logic from ``knowledge_rag_runtime`` so that RAG and
few-shot retrieval can reuse it. This module depends only on existing
module-level pure functions and ``RuntimeModelConfig``; it does not change the
behaviour of ``_LightRagRuntime``, so the RAG path is left untouched.

Typical usage::

    provider = EmbeddingProvider.from_settings(session)
    vectors = provider.embed(["差旅超标", "票单不一致"])
    dim = provider.dimension()
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from app.core.logging import get_logger
|
|
||||||
from app.services.knowledge_rag_runtime import (
|
|
||||||
DEFAULT_EMBEDDING_TIMEOUT_SECONDS,
|
|
||||||
KnowledgeRagError,
|
|
||||||
RuntimeModelConfig,
|
|
||||||
_build_headers,
|
|
||||||
_ensure_path,
|
|
||||||
_extract_embedding_vectors,
|
|
||||||
_normalize_endpoint,
|
|
||||||
_send_json_request,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = get_logger("app.services.embedding_provider")
|
|
||||||
|
|
||||||
|
|
||||||
def _runtime_model_config_from_dict(config: dict[str, str]) -> RuntimeModelConfig:
    """Convert the dict returned by SettingsService.get_runtime_model_config into the RuntimeModelConfig dataclass."""
|
|
||||||
|
|
||||||
return RuntimeModelConfig(
|
|
||||||
slot=str(config.get("slot") or "embedding"),
|
|
||||||
provider=str(config.get("provider") or ""),
|
|
||||||
model=str(config.get("model") or ""),
|
|
||||||
endpoint=str(config.get("endpoint") or ""),
|
|
||||||
api_key=str(config.get("apiKey") or ""),
|
|
||||||
capability=str(config.get("capability") or ""),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
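class EmbeddingProvider:
    """Thin wrapper around the embedding model.

    Design notes:
    - Holds a single ``RuntimeModelConfig`` that is fixed at construction; no LightRAG dependency.
    - Reuses the pure HTTP helper functions from ``knowledge_rag_runtime``, so it behaves the same as the RAG path.
    - The dimension is probed lazily and cached after the first ``dimension()`` call, so constructing a provider never calls the remote endpoint.
    """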
|
|
||||||
|
|
||||||
def __init__(self, config: RuntimeModelConfig) -> None:
|
|
||||||
self.config = config
|
|
||||||
self._dimension: int | None = None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
    def from_settings(cls, session: Any) -> "EmbeddingProvider":
        """Build a provider from the embedding configuration held by SettingsService."""
|
|
||||||
|
|
||||||
from app.services.settings import SettingsService
|
|
||||||
|
|
||||||
raw = SettingsService(session).get_runtime_model_config("embedding")
|
|
||||||
return cls(_runtime_model_config_from_dict(raw))
|
|
||||||
|
|
||||||
    def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed a list of texts and return a list of vectors with the same length as the input."""
|
|
||||||
|
|
||||||
if not texts:
|
|
||||||
return []
|
|
||||||
return _request_embeddings_public(self.config, texts)
|
|
||||||
|
|
||||||
    def dimension(self) -> int:
        """Probe the embedding dimension and cache the result. Raises KnowledgeRagError on failure."""
|
|
||||||
|
|
||||||
if self._dimension is None:
|
|
||||||
vectors = self.embed(["dimension probe"])
|
|
||||||
if not vectors or not isinstance(vectors[0], list):
|
|
||||||
raise KnowledgeRagError("无法从 embedding 模型返回结果中解析向量维度。")
|
|
||||||
self._dimension = len(vectors[0])
|
|
||||||
if self._dimension <= 0:
|
|
||||||
raise KnowledgeRagError("embedding 模型返回了无效的向量维度。")
|
|
||||||
return self._dimension
|
|
||||||
|
|
||||||
|
|
||||||
def _request_embeddings_public(
|
|
||||||
config: RuntimeModelConfig,
|
|
||||||
texts: list[str],
|
|
||||||
) -> list[list[float]]:
    """Build the embedding request according to the provider.

    Mirrors ``_LightRagRuntime._request_embeddings`` so that few-shot retrieval
    and RAG share the same call semantics.
    """
|
|
||||||
|
|
||||||
from app.services.model_connectivity import AZURE_API_VERSION
|
|
||||||
|
|
||||||
if config.provider == "Azure OpenAI":
|
|
||||||
from app.services.knowledge_rag_runtime import _build_azure_deployment_base
|
|
||||||
|
|
||||||
url = f"{_build_azure_deployment_base(config.endpoint, config.model)}/embeddings?api-version={AZURE_API_VERSION}"
|
|
||||||
payload: dict[str, Any] = {"input": texts}
|
|
||||||
status_code, body = _send_json_request(
|
|
||||||
"POST",
|
|
||||||
url,
|
|
||||||
headers=_build_headers(config.api_key, use_bearer=False, use_api_key=True),
|
|
||||||
payload=payload,
|
|
||||||
timeout_seconds=DEFAULT_EMBEDDING_TIMEOUT_SECONDS,
|
|
||||||
)
|
|
||||||
elif config.provider == "Ollama":
|
|
||||||
url = _ensure_path(_normalize_endpoint(config.endpoint), "api/embed")
|
|
||||||
payload = {"model": config.model, "input": texts}
|
|
||||||
status_code, body = _send_json_request(
|
|
||||||
"POST",
|
|
||||||
url,
|
|
||||||
headers={"Content-Type": "application/json", "Accept": "application/json"},
|
|
||||||
payload=payload,
|
|
||||||
timeout_seconds=DEFAULT_EMBEDDING_TIMEOUT_SECONDS,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
url = _ensure_path(_normalize_endpoint(config.endpoint), "embeddings")
|
|
||||||
payload = {"model": config.model, "input": texts}
|
|
||||||
status_code, body = _send_json_request(
|
|
||||||
"POST",
|
|
||||||
url,
|
|
||||||
headers=_build_headers(config.api_key, use_bearer=True),
|
|
||||||
payload=payload,
|
|
||||||
timeout_seconds=DEFAULT_EMBEDDING_TIMEOUT_SECONDS,
|
|
||||||
)
|
|
||||||
|
|
||||||
from http import HTTPStatus
|
|
||||||
|
|
||||||
if status_code >= HTTPStatus.BAD_REQUEST:
|
|
||||||
raise KnowledgeRagError(f"embedding 模型返回异常状态码 {status_code}。")
|
|
||||||
|
|
||||||
return _extract_embedding_vectors(body, provider=config.provider)
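The lazy dimension probe in `EmbeddingProvider.dimension()` can be exercised without a remote model. The sketch below is a stand-in, not the class from this diff: `LazyDimension` and `fake_embed` are hypothetical names, and the stub embedder simply returns 4-dimensional zero vectors so the caching behaviour can be observed.

```python
from typing import Callable, Optional


class LazyDimension:
    """Caches the vector size returned by an embed callable (illustrative stand-in)."""

    def __init__(self, embed: Callable[[list[str]], list[list[float]]]) -> None:
        self._embed = embed
        self._dimension: Optional[int] = None

    def dimension(self) -> int:
        # Only the first call reaches the embedder; later calls reuse the cached size.
        if self._dimension is None:
            vectors = self._embed(["dimension probe"])
            if not vectors or not isinstance(vectors[0], list) or not vectors[0]:
                raise ValueError("could not determine the embedding dimension")
            self._dimension = len(vectors[0])
        return self._dimension


calls: list[list[str]] = []


def fake_embed(texts: list[str]) -> list[list[float]]:
    """Stub embedder: records each call and returns 4-dimensional zero vectors."""
    calls.append(texts)
    return [[0.0, 0.0, 0.0, 0.0] for _ in texts]


probe = LazyDimension(fake_embed)
first = probe.dimension()
second = probe.dimension()
```

Both calls return the same size while the stub is invoked once, which is the property that keeps collection creation from paying for repeated probe requests.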
|
|
||||||
@@ -1,177 +0,0 @@
|
|||||||
"""Few-shot sample ingestion pipeline: RiskObservation → FewShotSample → Qdrant.

Only observations that a human has marked as confirmed / false_positive are
handled. Each one is converted into a :class:`FewShotSample`, persisted to the
database, and its vector is synced to Qdrant.

Ingestion is triggered by :meth:`RiskObservationService.create_feedback` after
the commit. This service swallows exceptions (logging only) so that the
feedback flow cannot be broken by the few-shot path.
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from app.core.logging import get_logger
|
|
||||||
from app.models.few_shot_sample import FewShotSample
|
|
||||||
from app.models.risk_observation import RiskObservation, RiskObservationFeedback
|
|
||||||
from app.services.embedding_provider import EmbeddingProvider
|
|
||||||
from app.services.few_shot_store import FewShotStore
|
|
||||||
|
|
||||||
logger = get_logger("app.services.few_shot_ingestion")
|
|
||||||
|
|
||||||
# Only these two feedback_status values count as confirmed samples and are ingested.
CONFIRMED_LABELS = {"confirmed", "false_positive"}


# label → natural-language conclusion (fallback when feedback.comment is empty)
|
|
||||||
LABEL_CONCLUSION_FALLBACK = {
|
|
||||||
"confirmed": "经人工复核确认,该风险线索成立,需按规则拦截或补件。",
|
|
||||||
"false_positive": "经人工复核判定为误报,相似情形不应触发该风险规则。",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class FewShotIngestionService:
    """Turn confirmed risk observations into few-shot samples."""
|
|
||||||
|
|
||||||
def __init__(self, db: Session) -> None:
|
|
||||||
self.db = db
|
|
||||||
|
|
||||||
def ingest_observation_feedback(
|
|
||||||
self,
|
|
||||||
observation: RiskObservation,
|
|
||||||
feedback: RiskObservationFeedback,
|
|
||||||
    ) -> FewShotSample | None:
        """Called after a human marks an observation as confirmed or false positive; persists the sample and syncs its vector."""
|
|
||||||
|
|
||||||
label = observation.feedback_status
|
|
||||||
if label not in CONFIRMED_LABELS:
|
|
||||||
return None
|
|
||||||
|
|
||||||
sample_key = f"obs:{observation.id}"
|
|
||||||
sample = self.db.scalar(
|
|
||||||
select(FewShotSample).where(FewShotSample.sample_key == sample_key)
|
|
||||||
)
|
|
||||||
|
|
||||||
domain = self._extract_domain(observation)
|
|
||||||
case_text = self._build_case_text(observation)
|
|
||||||
conclusion_text = self._build_conclusion_text(observation, feedback, label)
|
|
||||||
payload = self._build_payload(observation, feedback, label)
|
|
||||||
|
|
||||||
if sample is None:
|
|
||||||
sample = FewShotSample(
|
|
||||||
sample_key=sample_key,
|
|
||||||
source_observation_id=observation.id,
|
|
||||||
scene="risk_rule_generation",
|
|
||||||
domain=domain,
|
|
||||||
risk_type=observation.risk_type or "",
|
|
||||||
risk_level=observation.risk_level or "",
|
|
||||||
label=label,
|
|
||||||
case_text=case_text,
|
|
||||||
conclusion_text=conclusion_text,
|
|
||||||
payload_json=payload,
|
|
||||||
status="active",
|
|
||||||
)
|
|
||||||
self.db.add(sample)
|
|
||||||
else:
|
|
||||||
sample.label = label
|
|
||||||
sample.domain = domain
|
|
||||||
sample.risk_type = observation.risk_type or ""
|
|
||||||
sample.risk_level = observation.risk_level or ""
|
|
||||||
sample.case_text = case_text
|
|
||||||
sample.conclusion_text = conclusion_text
|
|
||||||
sample.payload_json = payload
|
|
||||||
sample.status = "active"
|
|
||||||
            # vector_id is left as-is here; it is replaced below if the Qdrant upsert succeeds (the previous point is not deleted).
|
|
||||||
try:
|
|
||||||
self.db.commit()
|
|
||||||
self.db.refresh(sample)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("few-shot 样本持久化失败 observation_id=%s", observation.id)
|
|
||||||
self.db.rollback()
|
|
||||||
return None
|
|
||||||
|
|
||||||
vector_id = self._store().upsert(sample)
|
|
||||||
if vector_id:
|
|
||||||
sample.vector_id = vector_id
|
|
||||||
try:
|
|
||||||
self.db.commit()
|
|
||||||
except Exception:
|
|
||||||
logger.warning("few-shot vector_id 回写失败 sample_id=%s", sample.id)
|
|
||||||
return sample
|
|
||||||
|
|
||||||
    def retract_observation(self, observation_id: str) -> bool:
        """Delete the sample and its vector when the observation is retracted."""
|
|
||||||
|
|
||||||
sample = self.db.scalar(
|
|
||||||
select(FewShotSample).where(FewShotSample.source_observation_id == observation_id)
|
|
||||||
)
|
|
||||||
if sample is None:
|
|
||||||
return False
|
|
||||||
if sample.vector_id:
|
|
||||||
self._store().delete_by_vector_id(sample.vector_id)
|
|
||||||
try:
|
|
||||||
self.db.delete(sample)
|
|
||||||
self.db.commit()
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
logger.exception("few-shot 样本删除失败 observation_id=%s", observation_id)
|
|
||||||
self.db.rollback()
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _store(self) -> FewShotStore:
|
|
||||||
provider = EmbeddingProvider.from_settings(self.db)
|
|
||||||
return FewShotStore(provider)
|
|
||||||
|
|
||||||
def _extract_domain(self, observation: RiskObservation) -> str:
|
|
||||||
ontology = observation.ontology_json or {}
|
|
||||||
return str(ontology.get("domain") or "")
|
|
||||||
|
|
||||||
def _build_case_text(self, observation: RiskObservation) -> str:
|
|
||||||
parts = [
|
|
||||||
observation.title or "",
|
|
||||||
observation.description or "",
|
|
||||||
observation.risk_signal or "",
|
|
||||||
observation.risk_type or "",
|
|
||||||
]
|
|
||||||
ontology = observation.ontology_json or {}
|
|
||||||
scenario = ontology.get("scenario")
|
|
||||||
if scenario:
|
|
||||||
parts.append(f"场景:{scenario}")
|
|
||||||
risk_signals = ontology.get("risk_signals")
|
|
||||||
if isinstance(risk_signals, list) and risk_signals:
|
|
||||||
parts.append("信号:" + "|".join(str(s) for s in risk_signals))
|
|
||||||
return "\n".join(part for part in parts if part).strip()
|
|
||||||
|
|
||||||
def _build_conclusion_text(
|
|
||||||
self,
|
|
||||||
observation: RiskObservation,
|
|
||||||
feedback: RiskObservationFeedback,
|
|
||||||
label: str,
|
|
||||||
) -> str:
|
|
||||||
comment = (feedback.comment or "").strip()
|
|
||||||
if comment:
|
|
||||||
return f"[{label}] {comment}"
|
|
||||||
return LABEL_CONCLUSION_FALLBACK.get(label, label)
|
|
||||||
|
|
||||||
def _build_payload(
|
|
||||||
self,
|
|
||||||
observation: RiskObservation,
|
|
||||||
feedback: RiskObservationFeedback,
|
|
||||||
label: str,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"label": label,
|
|
||||||
"risk_type": observation.risk_type,
|
|
||||||
"risk_signal": observation.risk_signal,
|
|
||||||
"risk_level": observation.risk_level,
|
|
||||||
"feedback_type": feedback.feedback_type,
|
|
||||||
"feedback_comment": feedback.comment or "",
|
|
||||||
"feedback_actor": feedback.actor or "",
|
|
||||||
"ontology": observation.ontology_json or {},
|
|
||||||
"policy_refs": observation.policy_refs_json or [],
|
|
||||||
"evidence": observation.evidence_json or [],
|
|
||||||
"subject_label": observation.subject_label or "",
|
|
||||||
"claim_no": observation.claim_no or "",
|
|
||||||
}
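To make the embedded text concrete, the sketch below mirrors the `_build_case_text` method above as a standalone function over a plain dict. It is an illustration under assumptions: the dict keys stand in for the `RiskObservation` attributes, and the English prefixes `scenario:` and `signals:` replace the Chinese prefixes used in the method.

```python
from typing import Any


def build_case_text(observation: dict[str, Any]) -> str:
    """Join the non-empty descriptive fields of an observation, one per line."""
    parts = [
        observation.get("title") or "",
        observation.get("description") or "",
        observation.get("risk_signal") or "",
        observation.get("risk_type") or "",
    ]
    ontology = observation.get("ontology_json") or {}
    scenario = ontology.get("scenario")
    if scenario:
        parts.append(f"scenario: {scenario}")
    risk_signals = ontology.get("risk_signals")
    if isinstance(risk_signals, list) and risk_signals:
        # Non-string signals are stringified before joining.
        parts.append("signals: " + "|".join(str(s) for s in risk_signals))
    # Empty fields are dropped so they do not produce blank lines in the embedded text.
    return "\n".join(part for part in parts if part).strip()


text = build_case_text(
    {
        "title": "City mismatch",
        "description": "",
        "risk_signal": "invoice_city_mismatch",
        "risk_type": "travel",
        "ontology_json": {"scenario": "business trip", "risk_signals": ["city", 2]},
    }
)
```

Since this text is what gets embedded at ingestion time, the retrieval side needs to build its query text from comparable fields for the similarity scores to be meaningful.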
|
|
||||||
@@ -1,122 +0,0 @@
|
|||||||
"""Few-shot retriever: find historical samples similar to the current case and build injection blocks.

Fetches similar samples from :class:`FewShotStore` and converts them into a
structure that the prompt builders can consume directly. Results are
deduplicated and trimmed to a token budget so the prompt does not overflow.

Typical usage (call before building the prompt)::

    retriever = FewShotRetriever.from_session(session)
    samples = retriever.retrieve_for_risk_rule_generation(
        domain="travel", natural_language="票据城市与申报地不一致"
    )
    messages = build_risk_rule_compiler_messages(
        ...,
        few_shot_samples=samples,
    )
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from app.core.logging import get_logger
|
|
||||||
from app.services.embedding_provider import EmbeddingProvider
|
|
||||||
from app.services.few_shot_store import FewShotStore
|
|
||||||
|
|
||||||
logger = get_logger("app.services.few_shot_retrieval")
|
|
||||||
|
|
||||||
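# Total estimated-token budget shared by all injected few-shot samples (used for trimming)
SAMPLE_TOKEN_BUDGET = 1200
# Maximum characters per sample; longer conclusions are truncated so the prompt does not overflow
SINGLE_SAMPLE_MAX_CHARS = 400
# Maximum number of historical samples to inject (keeps the total bounded after merging with the built-in examples)
MAX_HISTORICAL_SAMPLES = 3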
|
|
||||||
|
|
||||||
|
|
||||||
class FewShotRetriever:
    """Retrieve confirmed samples by case features and return a structure the prompt can consume directly."""
|
|
||||||
|
|
||||||
def __init__(self, store: FewShotStore) -> None:
|
|
||||||
self._store = store
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_session(cls, session: Session) -> "FewShotRetriever":
|
|
||||||
provider = EmbeddingProvider.from_settings(session)
|
|
||||||
return cls(FewShotStore(provider))
|
|
||||||
|
|
||||||
def retrieve_for_risk_rule_generation(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
domain: str = "",
|
|
||||||
risk_type: str = "",
|
|
||||||
natural_language: str,
|
|
||||||
top_k: int = MAX_HISTORICAL_SAMPLES,
|
|
||||||
    ) -> list[dict[str, Any]]:
        """Retrieve historical samples similar to the current rule request and return a list of injection blocks."""
|
|
||||||
|
|
||||||
case_text = self._build_case_text(
|
|
||||||
natural_language=natural_language,
|
|
||||||
domain=domain,
|
|
||||||
risk_type=risk_type,
|
|
||||||
)
|
|
||||||
if not case_text:
|
|
||||||
return []
|
|
||||||
hits = self._store.search(
|
|
||||||
case_text,
|
|
||||||
scene="risk_rule_generation",
|
|
||||||
labels=["confirmed", "false_positive"],
|
|
||||||
top_k=top_k,
|
|
||||||
)
|
|
||||||
return self._hits_to_injection_blocks(hits)
|
|
||||||
|
|
||||||
def _build_case_text(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
natural_language: str,
|
|
||||||
domain: str = "",
|
|
||||||
risk_type: str = "",
|
|
||||||
) -> str:
|
|
||||||
parts = [natural_language, domain, risk_type]
|
|
||||||
return "\n".join(p for p in parts if p).strip()
|
|
||||||
|
|
||||||
def _hits_to_injection_blocks(
|
|
||||||
self,
|
|
||||||
hits: list[dict[str, Any]],
|
|
||||||
    ) -> list[dict[str, Any]]:
        """Convert search hits into prompt-ready blocks, with deduplication and budget trimming."""

        blocks: list[dict[str, Any]] = []
        seen_conclusions: set[str] = set()
        budget = SAMPLE_TOKEN_BUDGET
        for hit in hits:
            # Truncate before the duplicate check so the compared text is the same text stored in seen_conclusions.
            conclusion = (hit.get("conclusion_text") or "").strip()[:SINGLE_SAMPLE_MAX_CHARS]
            if not conclusion or conclusion in seen_conclusions:
                continue
|
|
||||||
payload = hit.get("payload_json") or {}
|
|
||||||
block = {
|
|
||||||
"source": "historical_confirmed",
|
|
||||||
"label": hit.get("label"),
|
|
||||||
"domain": hit.get("domain") or "",
|
|
||||||
"risk_type": hit.get("risk_type") or "",
|
|
||||||
"score": round(float(hit.get("score") or 0.0), 4),
|
|
||||||
"conclusion": conclusion,
|
|
||||||
"context": {
|
|
||||||
"risk_signal": payload.get("risk_signal") or "",
|
|
||||||
"risk_level": payload.get("risk_level") or "",
|
|
||||||
"ontology": payload.get("ontology") or {},
|
|
||||||
"feedback_comment": payload.get("feedback_comment") or "",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
            # Rough token estimate (characters / 1.6 as an approximate Chinese token ratio, plus a flat 40)
|
|
||||||
estimated_tokens = int(len(conclusion) / 1.6) + 40
|
|
||||||
if estimated_tokens > budget:
|
|
||||||
break
|
|
||||||
budget -= estimated_tokens
|
|
||||||
blocks.append(block)
|
|
||||||
seen_conclusions.add(conclusion)
|
|
||||||
return blocks
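The budget trimming in `_hits_to_injection_blocks` can be followed with a small worked example. The sketch below reduces the method to a function over plain strings; `trim_conclusions` is a hypothetical name, the constants copy the values defined in this file, and the function truncates before the duplicate check so that the compared text is the text that gets stored.

```python
SAMPLE_TOKEN_BUDGET = 1200
SINGLE_SAMPLE_MAX_CHARS = 400


def trim_conclusions(conclusions: list[str], budget: int = SAMPLE_TOKEN_BUDGET) -> list[str]:
    """Deduplicate, truncate and keep conclusions until the estimated token budget runs out."""
    kept: list[str] = []
    seen: set[str] = set()
    for raw in conclusions:
        text = raw.strip()[:SINGLE_SAMPLE_MAX_CHARS]
        if not text or text in seen:
            continue
        # Same rough estimate as the retriever: characters / 1.6 plus a flat 40.
        estimated_tokens = int(len(text) / 1.6) + 40
        if estimated_tokens > budget:
            break  # stop at the first sample that does not fit
        budget -= estimated_tokens
        kept.append(text)
        seen.add(text)
    return kept


# 800 chars -> truncated to 400 -> about 290 tokens, leaving about 110 of the 400 budget.
# The duplicate is skipped; the 160-char sample needs about 140 tokens and ends the loop.
kept = trim_conclusions(["a" * 800, "a" * 800, "b" * 160, ""], budget=400)
```

One consequence of using `break` rather than `continue` is that a sample that does not fit ends the loop even if a later, shorter sample would still fit; whether that is intended is not stated in the diff.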
|
|
||||||
@@ -1,214 +0,0 @@
|
|||||||
"""Qdrant vector store for few-shot samples.

Uses a Qdrant client that is independent of LightRAG and a dedicated collection,
``few_shot_samples``, isolated from the knowledge-base RAG collections. No
operation raises on failure (it logs and returns an empty result), so the main
path is never blocked.

Vectors come from :class:`EmbeddingProvider`. The payload carries the business
filter fields (scene/label/domain/risk_type); retrieval filters on those fields
and ranks by vector similarity.
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import os
|
|
||||||
import uuid
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from app.core.logging import get_logger
|
|
||||||
from app.services.knowledge_rag import _resolve_default_qdrant_url
|
|
||||||
|
|
||||||
logger = get_logger("app.services.few_shot_store")
|
|
||||||
|
|
||||||
FEW_SHOT_COLLECTION = "few_shot_samples"
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_qdrant_config() -> tuple[str, str]:
    """Resolve the Qdrant URL and API key, reusing knowledge_rag's default URL resolution."""
|
|
||||||
|
|
||||||
url = os.environ.get("QDRANT_URL", "").strip() or _resolve_default_qdrant_url()
|
|
||||||
api_key = os.environ.get("QDRANT_API_KEY", "").strip()
|
|
||||||
return url, api_key
|
|
||||||
|
|
||||||
|
|
||||||
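class FewShotStore:
    """Thin wrapper around Qdrant, used only for few-shot sample retrieval.

    Design notes:
    - The client and collection are created lazily, on the first operation.
    - Every public method swallows exceptions (returning an empty result, None or False), so the main path is never broken.
    - Both writes and searches rely on the :class:`EmbeddingProvider` passed in;
      the caller is responsible for keeping it consistent with the configuration.
    """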
|
|
||||||
|
|
||||||
def __init__(self, embedding_provider: Any) -> None:
|
|
||||||
self._embedding_provider = embedding_provider
|
|
||||||
self._client: Any = None
|
|
||||||
self._ensured = False
|
|
||||||
|
|
||||||
    def _client_or_none(self) -> Any:
        """Lazily initialise the QdrantClient; return None on failure."""
|
|
||||||
|
|
||||||
if self._client is not None:
|
|
||||||
return self._client
|
|
||||||
try:
|
|
||||||
from qdrant_client import QdrantClient
|
|
||||||
|
|
||||||
url, api_key = _resolve_qdrant_config()
|
|
||||||
self._client = QdrantClient(url=url, api_key=api_key or None)
|
|
||||||
except Exception:
|
|
||||||
logger.warning("few-shot QdrantClient 初始化失败,本轮操作跳过", exc_info=True)
|
|
||||||
self._client = None
|
|
||||||
return self._client
|
|
||||||
|
|
||||||
    def _ensure_collection(self) -> bool:
        """Ensure the collection exists; return True on success."""
|
|
||||||
|
|
||||||
if self._ensured:
|
|
||||||
return True
|
|
||||||
client = self._client_or_none()
|
|
||||||
if client is None:
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
|
||||||
|
|
||||||
try:
|
|
||||||
client.get_collection(FEW_SHOT_COLLECTION)
|
|
||||||
self._ensured = True
|
|
||||||
return True
|
|
||||||
except UnexpectedResponse as exc:
|
|
||||||
if exc.status_code != 404:
|
|
||||||
raise
|
|
||||||
            # The collection does not exist yet, so create it
|
|
||||||
dim = self._embedding_provider.dimension()
|
|
||||||
from qdrant_client.http.models import (
|
|
||||||
Distance,
|
|
||||||
VectorParams,
|
|
||||||
PayloadSchemaType,
|
|
||||||
)
|
|
||||||
|
|
||||||
client.create_collection(
|
|
||||||
collection_name=FEW_SHOT_COLLECTION,
|
|
||||||
vectors_config=VectorParams(size=dim, distance=Distance.COSINE),
|
|
||||||
)
|
|
||||||
for field, field_type in [
|
|
||||||
("sample_id", PayloadSchemaType.KEYWORD),
|
|
||||||
("scene", PayloadSchemaType.KEYWORD),
|
|
||||||
("label", PayloadSchemaType.KEYWORD),
|
|
||||||
("domain", PayloadSchemaType.KEYWORD),
|
|
||||||
("risk_type", PayloadSchemaType.KEYWORD),
|
|
||||||
("status", PayloadSchemaType.KEYWORD),
|
|
||||||
]:
|
|
||||||
try:
|
|
||||||
client.create_payload_index(
|
|
||||||
collection_name=FEW_SHOT_COLLECTION,
|
|
||||||
field_name=field,
|
|
||||||
field_schema=field_type,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logger.debug("payload index 创建跳过 field=%s", field, exc_info=True)
|
|
||||||
self._ensured = True
|
|
||||||
logger.info("few-shot collection 创建成功 dim=%s", dim)
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
logger.warning("few-shot collection 初始化失败,本轮操作跳过", exc_info=True)
|
|
||||||
return False
|
|
||||||
|
|
||||||
    def upsert(self, sample: Any) -> str | None:
        """Embed one sample and write it to Qdrant; return the vector_id, or None on failure."""
|
|
||||||
|
|
||||||
if not self._ensure_collection():
|
|
||||||
return None
|
|
||||||
client = self._client
|
|
||||||
try:
|
|
||||||
vector = self._embedding_provider.embed([sample.case_text])[0]
|
|
||||||
except Exception:
|
|
||||||
logger.warning("few-shot embedding 失败 sample_key=%s", getattr(sample, "sample_key", ""), exc_info=True)
|
|
||||||
return None
|
|
||||||
vector_id = uuid.uuid4().hex
|
|
||||||
payload = {
|
|
||||||
"sample_id": sample.id,
|
|
||||||
"scene": sample.scene,
|
|
||||||
"label": sample.label,
|
|
||||||
"domain": sample.domain,
|
|
||||||
"risk_type": sample.risk_type,
|
|
||||||
"risk_level": sample.risk_level,
|
|
||||||
"status": getattr(sample, "status", "active"),
|
|
||||||
"conclusion_text": sample.conclusion_text,
|
|
||||||
"payload_json": sample.payload_json,
|
|
||||||
}
|
|
||||||
try:
|
|
||||||
client.upsert(
|
|
||||||
collection_name=FEW_SHOT_COLLECTION,
|
|
||||||
points=[{"id": vector_id, "vector": vector, "payload": payload}],
|
|
||||||
)
|
|
||||||
return vector_id
|
|
||||||
except Exception:
|
|
||||||
logger.warning("few-shot upsert 失败 sample_key=%s", getattr(sample, "sample_key", ""), exc_info=True)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def search(
|
|
||||||
self,
|
|
||||||
case_text: str,
|
|
||||||
*,
|
|
||||||
scene: str | None = None,
|
|
||||||
labels: list[str] | None = None,
|
|
||||||
top_k: int = 3,
|
|
||||||
    ) -> list[dict[str, Any]]:
        """Search for samples similar to case_text, optionally filtered by scene/label. Returns an empty list on failure."""
|
|
||||||
|
|
||||||
if not case_text or not self._ensure_collection():
|
|
||||||
return []
|
|
||||||
client = self._client
|
|
||||||
try:
|
|
||||||
vector = self._embedding_provider.embed([case_text])[0]
|
|
||||||
except Exception:
|
|
||||||
logger.warning("few-shot 检索 embedding 失败", exc_info=True)
|
|
||||||
return []
|
|
||||||
must: list[dict[str, Any]] = [{"key": "status", "match": {"value": "active"}}]
|
|
||||||
if scene:
|
|
||||||
must.append({"key": "scene", "match": {"value": scene}})
|
|
||||||
if labels:
|
|
||||||
must.append({"key": "label", "match": {"any": labels}})
|
|
||||||
try:
|
|
||||||
from qdrant_client.http.models import Filter
|
|
||||||
|
|
||||||
results = client.query_points(
|
|
||||||
collection_name=FEW_SHOT_COLLECTION,
|
|
||||||
query=vector,
|
|
||||||
query_filter=Filter(must=must),
|
|
||||||
limit=top_k,
|
|
||||||
with_payload=True,
|
|
||||||
).points
|
|
||||||
except Exception:
|
|
||||||
logger.warning("few-shot 检索失败", exc_info=True)
|
|
||||||
return []
|
|
||||||
hits: list[dict[str, Any]] = []
|
|
||||||
for point in results:
|
|
||||||
payload = getattr(point, "payload", None) or {}
|
|
||||||
hits.append(
|
|
||||||
{
|
|
||||||
"sample_id": payload.get("sample_id"),
|
|
||||||
"score": float(getattr(point, "score", 0.0)),
|
|
||||||
"label": payload.get("label"),
|
|
||||||
"domain": payload.get("domain"),
|
|
||||||
"risk_type": payload.get("risk_type"),
|
|
||||||
"conclusion_text": payload.get("conclusion_text") or "",
|
|
||||||
"payload_json": payload.get("payload_json") or {},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return hits
|
|
||||||
|
|
||||||
    def delete_by_vector_id(self, vector_id: str) -> bool:
        """Delete a vector by vector_id; return False on failure."""
|
|
||||||
|
|
||||||
if not vector_id or not self._ensure_collection():
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
self._client.delete(
|
|
||||||
collection_name=FEW_SHOT_COLLECTION,
|
|
||||||
points_selector=[vector_id],
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
logger.warning("few-shot 删除失败 vector_id=%s", vector_id, exc_info=True)
|
|
||||||
return False
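As written, `upsert` draws a fresh `uuid.uuid4().hex` on every call, so re-ingesting the same sample appears to write a second point rather than overwrite the first, as far as this diff shows. If one point per `sample_key` is the intent (an assumption, not something the diff states), a deterministic ID is one possible alternative. The names `FEW_SHOT_NAMESPACE` and `vector_id_for` below are hypothetical.

```python
import uuid

# Hypothetical namespace for few-shot point IDs; any fixed UUID would do.
FEW_SHOT_NAMESPACE = uuid.uuid5(uuid.NAMESPACE_URL, "few_shot_samples")


def vector_id_for(sample_key: str) -> str:
    """Derive a stable point ID from the sample key so repeated upserts target one point."""
    # uuid5 is a name-based UUID: the same namespace and name always give the same value.
    return uuid.uuid5(FEW_SHOT_NAMESPACE, sample_key).hex


first_id = vector_id_for("obs:123")
second_id = vector_id_for("obs:123")
other_id = vector_id_for("obs:456")
```

The result has the same 32-character hex shape as the current `uuid4().hex` IDs, so the stored `vector_id` column would not need to change; the trade-off is that the ID is now tied to the `sample_key` format.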
|
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
|
||||||
from datetime import UTC, datetime, timedelta
|
from datetime import UTC, datetime, timedelta
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -9,7 +8,6 @@ from sqlalchemy import func, select
|
|||||||
from sqlalchemy.orm import Session, joinedload
|
from sqlalchemy.orm import Session, joinedload
|
||||||
|
|
||||||
from app.algorithem.risk_graph import RiskHistoryStats, RiskObservationDraft
|
from app.algorithem.risk_graph import RiskHistoryStats, RiskObservationDraft
|
||||||
from app.core.logging import get_logger
|
|
||||||
from app.db.base import Base
|
from app.db.base import Base
|
||||||
from app.models.financial_record import ExpenseClaim
|
from app.models.financial_record import ExpenseClaim
|
||||||
from app.models.risk_observation import RiskObservation, RiskObservationFeedback
|
from app.models.risk_observation import RiskObservation, RiskObservationFeedback
|
||||||
@@ -19,8 +17,6 @@ from app.schemas.risk_observation import (
|
|||||||
)
|
)
|
||||||
from app.services.expense_claim_risk_stage import normalize_risk_business_stage
|
from app.services.expense_claim_risk_stage import normalize_risk_business_stage
|
||||||
|
|
||||||
logger = get_logger("app.services.risk_observations")
|
|
||||||
|
|
||||||
HIGH_LEVELS = {"high", "critical"}
|
HIGH_LEVELS = {"high", "critical"}
|
||||||
SEVERITY_SCORE = {
|
SEVERITY_SCORE = {
|
||||||
"low": 32,
|
"low": 32,
|
||||||
@@ -326,27 +322,8 @@ class RiskObservationService:
|
|||||||
observation.status, observation.feedback_status = mapped
|
observation.status, observation.feedback_status = mapped
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
self.db.refresh(feedback)
|
self.db.refresh(feedback)
|
||||||
self._maybe_ingest_few_shot(observation, feedback)
|
|
||||||
return feedback
|
return feedback
|
||||||
|
|
||||||
def _maybe_ingest_few_shot(
|
|
||||||
self,
|
|
||||||
observation: RiskObservation,
|
|
||||||
feedback: RiskObservationFeedback,
|
|
||||||
    ) -> None:
        """After a human confirmed / false-positive verdict, add the sample to the few-shot pool; failures never affect the main flow."""
|
|
||||||
|
|
||||||
if os.environ.get("FEW_SHOT_INJECTION_ENABLED", "true").strip().lower() in {"0", "false", "no"}:
|
|
||||||
return
|
|
||||||
if observation.feedback_status not in {"confirmed", "false_positive"}:
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
from app.services.few_shot_ingestion import FewShotIngestionService
|
|
||||||
|
|
||||||
FewShotIngestionService(self.db).ingest_observation_feedback(observation, feedback)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("few-shot ingestion failed for observation %s", observation.id)
|
|
||||||
|
|
||||||
def summarize_dashboard(
|
def summarize_dashboard(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
|||||||
@@ -234,10 +234,6 @@ class RiskRuleGenerationService:
|
|||||||
}
|
}
|
||||||
for item in fields
|
for item in fields
|
||||||
]
|
]
|
||||||
few_shot_samples = self._retrieve_few_shot_samples(
|
|
||||||
domain=domain,
|
|
||||||
natural_language=natural_language,
|
|
||||||
)
|
|
||||||
messages = build_risk_rule_compiler_messages(
|
messages = build_risk_rule_compiler_messages(
|
||||||
domain=domain,
|
domain=domain,
|
||||||
domain_label=BUSINESS_DOMAIN_LABELS[domain],
|
domain_label=BUSINESS_DOMAIN_LABELS[domain],
|
||||||
@@ -247,7 +243,6 @@ class RiskRuleGenerationService:
|
|||||||
expense_category_label=expense_category_label,
|
expense_category_label=expense_category_label,
|
||||||
natural_language=natural_language,
|
natural_language=natural_language,
|
||||||
available_fields=field_payload,
|
available_fields=field_payload,
|
||||||
few_shot_samples=few_shot_samples,
|
|
||||||
)
|
)
|
||||||
answer = self.runtime_chat_service.complete(
|
answer = self.runtime_chat_service.complete(
|
||||||
messages,
|
messages,
|
||||||
@@ -268,29 +263,6 @@ class RiskRuleGenerationService:
|
|||||||
payload = unwrap_semantic_plan_payload(payload)
|
payload = unwrap_semantic_plan_payload(payload)
|
||||||
return self._sanitize_model_draft(payload, fields=fields)
|
return self._sanitize_model_draft(payload, fields=fields)
|
||||||
|
|
||||||
def _retrieve_few_shot_samples(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
domain: str,
|
|
||||||
natural_language: str,
|
|
||||||
    ) -> list[dict[str, Any]]:
        """Retrieve confirmed historical samples; fall back to an empty list on failure."""
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
if os.environ.get("FEW_SHOT_INJECTION_ENABLED", "true").strip().lower() in {"0", "false", "no"}:
|
|
||||||
return []
|
|
||||||
try:
|
|
||||||
from app.services.few_shot_retrieval import FewShotRetriever
|
|
||||||
|
|
||||||
retriever = FewShotRetriever.from_session(self.db)
|
|
||||||
return retriever.retrieve_for_risk_rule_generation(
|
|
||||||
domain=domain,
|
|
||||||
natural_language=natural_language,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
return []
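`_retrieve_few_shot_samples` falls back to an empty list without logging, so a broken retrieval path would go unnoticed. The fail-open pattern used throughout this diff could also be written once as a decorator that logs before falling back. The sketch below is a hedged illustration: `fail_open` and `retrieve_samples` are hypothetical names, and the standard `logging` module stands in for the project's `get_logger`.

```python
import functools
import logging
from typing import Any, Callable

logger = logging.getLogger("few_shot_sketch")


def fail_open(default_factory: Callable[[], Any]) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Wrap a function so any exception is logged and replaced by a fresh default value."""

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            try:
                return func(*args, **kwargs)
            except Exception:
                logger.warning("%s failed; using fallback", func.__name__, exc_info=True)
                # A factory avoids sharing one mutable default between calls.
                return default_factory()

        return wrapper

    return decorator


@fail_open(list)
def retrieve_samples(query: str) -> list[str]:
    """Stand-in for a retrieval call that may raise."""
    if not query:
        raise ValueError("empty query")
    return [query.upper()]


ok = retrieve_samples("abc")
fallback = retrieve_samples("")
```

The caller still gets an empty list on failure, so the main flow is unaffected, but the failure leaves a trace in the logs.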
|
|
||||||
|
|
||||||
def _sanitize_model_draft(
|
def _sanitize_model_draft(
|
||||||
self,
|
self,
|
||||||
payload: dict[str, Any],
|
payload: dict[str, Any],
|
||||||
|
|||||||
@@ -14,15 +14,10 @@ def build_risk_rule_compiler_messages(
|
|||||||
expense_category_label: str,
|
expense_category_label: str,
|
||||||
natural_language: str,
|
natural_language: str,
|
||||||
available_fields: list[dict[str, Any]],
|
available_fields: list[dict[str, Any]],
|
||||||
few_shot_samples: list[dict[str, Any]] | None = None,
|
|
||||||
) -> list[dict[str, str]]:
|
) -> list[dict[str, str]]:
|
||||||
"""Build the prompt for compiling natural-language rules.
|
"""Build the prompt for compiling natural-language rules.
|
||||||
|
|
||||||
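The LLM only breaks the business language down into a “semantic plan”; the backend validates fields, operators, and templates.
|
The LLM only breaks the business language down into a “semantic plan”; the backend validates fields, operators, and templates.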
|
||||||
|
|
||||||
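``few_shot_samples`` holds similar cases retrieved from confirmed historical samples; they are merged into the
|
|
||||||
``examples`` field and tagged ``source: "historical_confirmed"`` so that the compiler can draw on
|
|
||||||
past human conclusions. Passing ``None`` or an empty list keeps the previous behavior unchanged (backward compatible).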
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
schema = {
|
schema = {
|
||||||
@@ -166,20 +161,6 @@ def build_risk_rule_compiler_messages(
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
historical_examples: list[dict[str, Any]] = []
|
|
||||||
if few_shot_samples:
|
|
||||||
for sample in few_shot_samples:
|
|
||||||
historical_examples.append(
|
|
||||||
{
|
|
||||||
"source": "historical_confirmed",
|
|
||||||
"label": sample.get("label"),
|
|
||||||
"domain": sample.get("domain") or "",
|
|
||||||
"risk_type": sample.get("risk_type") or "",
|
|
||||||
"conclusion": sample.get("conclusion") or "",
|
|
||||||
"context": sample.get("context") or {},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
merged_examples = historical_examples + examples
|
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
@@ -205,7 +186,7 @@ def build_risk_rule_compiler_messages(
|
|||||||
"natural_language": natural_language,
|
"natural_language": natural_language,
|
||||||
"available_fields": available_fields,
|
"available_fields": available_fields,
|
||||||
"required_json_shape": response_schema,
|
"required_json_shape": response_schema,
|
||||||
"examples": merged_examples,
|
"examples": examples,
|
||||||
},
|
},
|
||||||
ensure_ascii=False,
|
ensure_ascii=False,
|
||||||
),
|
),
|
||||||
|
|||||||
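The hunk above places retrieved historical samples ahead of the built-in examples before both are serialized into the user message. A minimal, self-contained sketch of that merge step is shown below. `merge_examples` is a hypothetical helper name used only for illustration (the diff does the same work inline inside `build_risk_rule_compiler_messages`); the field names mirror the ones in the diff.

```python
from __future__ import annotations

from typing import Any


def merge_examples(
    few_shot_samples: list[dict[str, Any]] | None,
    examples: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Prepend normalized historical samples to the built-in examples.

    Hypothetical helper: it only illustrates the merge order and defaults.
    """
    historical = [
        {
            "source": "historical_confirmed",
            "label": sample.get("label"),
            "domain": sample.get("domain") or "",
            "risk_type": sample.get("risk_type") or "",
            "conclusion": sample.get("conclusion") or "",
            "context": sample.get("context") or {},
        }
        for sample in (few_shot_samples or [])
    ]
    # None or an empty list leaves the built-in examples untouched,
    # which is what keeps the call backward compatible.
    return historical + examples
```

Putting the historical samples first means the model reads the human-confirmed cases before the generic ones; whether that ordering matters in practice would need to be checked against evaluation results.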
329
server/src/app/services/risk_rule_golden_evaluator.py
Normal file
329
server/src/app/services/risk_rule_golden_evaluator.py
Normal file
@@ -0,0 +1,329 @@
|
|||||||
|
"""Golden-set evaluator and release gate for risk rules.

Runs a rule manifest against the versioned golden case set (:class:`GoldenCase`),
computes accuracy/precision/recall, and gates releases on a hard threshold of
100% of cases passing.

The execution path reuses existing capabilities only:
- ``RiskRuleTemplateExecutor.evaluate_with_trace`` runs the rule
- static helpers from ``AgentAssetRiskRuleTestingMixin`` assemble the synthetic claim
- the per-case comparison logic matches ``_run_sample_case``

Gate semantics match the existing ``test_passed`` check: a failed run raises
``PermissionError``, and an ``AgentAssetTestRun(test_type='golden')`` row records the result.
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import UTC, date, datetime
|
||||||
|
from decimal import Decimal, InvalidOperation
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.agent_enums import AgentAssetType
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.models.agent_asset import AgentAsset, AgentAssetTestRun
|
||||||
|
from app.models.employee import Employee
|
||||||
|
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
|
||||||
|
from app.models.golden_case import GoldenCase
|
||||||
|
from app.services.risk_rule_template_executor import RiskRuleTemplateExecutor
|
||||||
|
|
||||||
|
logger = get_logger("app.services.risk_rule_golden_evaluator")
|
||||||
|
|
||||||
|
GOLDEN_GATE_FLAG = "GOLDEN_SET_GATE_ENABLED"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GoldenCaseResult:
|
||||||
|
case_id: str
|
||||||
|
name: str
|
||||||
|
expected_hit: bool
|
||||||
|
actual_hit: bool
|
||||||
|
expected_severity: str
|
||||||
|
actual_severity: str
|
||||||
|
passed: bool
|
||||||
|
message: str = ""
|
||||||
|
evidence: dict[str, Any] = field(default_factory=dict)
|
||||||
|
trace: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GoldenEvalReport:
|
||||||
|
total: int = 0
|
||||||
|
passed_count: int = 0
|
||||||
|
failed_count: int = 0
|
||||||
|
accuracy: float = 0.0
|
||||||
|
precision: float = 0.0
|
||||||
|
recall: float = 0.0
|
||||||
|
all_passed: bool = True
|
||||||
|
results: list[GoldenCaseResult] = field(default_factory=list)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"total": self.total,
|
||||||
|
"passed_count": self.passed_count,
|
||||||
|
"failed_count": self.failed_count,
|
||||||
|
"accuracy": round(self.accuracy, 4),
|
||||||
|
"precision": round(self.precision, 4),
|
||||||
|
"recall": round(self.recall, 4),
|
||||||
|
"all_passed": self.all_passed,
|
||||||
|
"results": [
|
||||||
|
{
|
||||||
|
"case_id": r.case_id,
|
||||||
|
"name": r.name,
|
||||||
|
"expected_hit": r.expected_hit,
|
||||||
|
"actual_hit": r.actual_hit,
|
||||||
|
"expected_severity": r.expected_severity,
|
||||||
|
"actual_severity": r.actual_severity,
|
||||||
|
"passed": r.passed,
|
||||||
|
"message": r.message,
|
||||||
|
}
|
||||||
|
for r in self.results
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _gate_enabled() -> bool:
|
||||||
|
return os.environ.get(GOLDEN_GATE_FLAG, "true").strip().lower() not in {"0", "false", "no"}
|
||||||
|
|
||||||
|
|
||||||
|
# ---- synthetic claim construction (mirrors AgentAssetRiskRuleTestingMixin._build_synthetic_claim) ----
|
||||||
|
|
||||||
|
def _extract_manifest_fields(manifest: dict[str, Any]) -> list[dict[str, str]]:
|
||||||
|
inputs = manifest.get("inputs") if isinstance(manifest.get("inputs"), dict) else {}
|
||||||
|
fields = inputs.get("fields") if isinstance(inputs.get("fields"), list) else []
|
||||||
|
normalized: list[dict[str, str]] = []
|
||||||
|
for item in fields:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
key = str(item.get("key") or "").strip()
|
||||||
|
if key:
|
||||||
|
normalized.append({"key": key, "label": str(item.get("label") or key).strip()})
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_sample_value(field_key: str, value: Any) -> Any:
|
||||||
|
import re
|
||||||
|
|
||||||
|
if field_key.endswith("route_cities") and isinstance(value, str):
|
||||||
|
return [item.strip() for item in re.split(r"[,,、/ ]+", value) if item.strip()]
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def _to_decimal(value: Any) -> Decimal:
|
||||||
|
try:
|
||||||
|
return Decimal(str(value or "0"))
|
||||||
|
except (InvalidOperation, ValueError):
|
||||||
|
return Decimal("0")
|
||||||
|
|
||||||
|
|
||||||
|
def _build_synthetic_claim(
|
||||||
|
values: dict[str, Any],
|
||||||
|
manifest: dict[str, Any],
|
||||||
|
) -> tuple[ExpenseClaim, list[dict[str, Any]]]:
|
||||||
|
claim = ExpenseClaim(
|
||||||
|
claim_no="GOLDEN-RISK-RULE",
|
||||||
|
employee_name=str(values.get("claim.employee_name") or "测试员工"),
|
||||||
|
department_name=str(values.get("claim.department_name") or "测试部门"),
|
||||||
|
expense_type=str(values.get("item.item_type") or "差旅费"),
|
||||||
|
reason=str(values.get("claim.reason") or "测试报销事由"),
|
||||||
|
location=str(values.get("claim.location") or "北京"),
|
||||||
|
amount=_to_decimal(values.get("claim.amount")),
|
||||||
|
currency="CNY",
|
||||||
|
invoice_count=1,
|
||||||
|
occurred_at=datetime.now(UTC),
|
||||||
|
status="draft",
|
||||||
|
)
|
||||||
|
item = ExpenseClaimItem(
|
||||||
|
item_date=date.today(),
|
||||||
|
item_type=str(values.get("item.item_type") or "住宿费"),
|
||||||
|
item_reason=str(values.get("item.item_reason") or claim.reason),
|
||||||
|
item_location=str(values.get("item.item_location") or claim.location),
|
||||||
|
item_amount=_to_decimal(values.get("item.item_amount") or claim.amount),
|
||||||
|
)
|
||||||
|
claim.items = [item]
|
||||||
|
if values.get("employee.location"):
|
||||||
|
claim.employee = Employee(
|
||||||
|
employee_no="GOLDEN-EMPLOYEE",
|
||||||
|
name=claim.employee_name,
|
||||||
|
email="golden-rule-test@example.com",
|
||||||
|
location=str(values.get("employee.location") or ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
attachment_fields: list[dict[str, Any]] = []
|
||||||
|
document_info: dict[str, Any] = {"fields": attachment_fields}
|
||||||
|
for field in _extract_manifest_fields(manifest):
|
||||||
|
key = field["key"]
|
||||||
|
if key not in values:
|
||||||
|
continue
|
||||||
|
value = _coerce_sample_value(key, values.get(key))
|
||||||
|
if key.startswith("claim."):
|
||||||
|
setattr(claim, key.removeprefix("claim."), value)
|
||||||
|
elif key.startswith("item."):
|
||||||
|
setattr(item, key.removeprefix("item."), value)
|
||||||
|
elif key.startswith("attachment."):
|
||||||
|
short_key = key.removeprefix("attachment.")
|
||||||
|
document_info[short_key] = value
|
||||||
|
attachment_fields.append({"key": short_key, "label": field["label"], "value": value})
|
||||||
|
return claim, [{"document_info": document_info, "ocr_text": document_info.get("ocr_text", "")}]
|
||||||
|
|
||||||
|
|
||||||
|
def _run_single_case(
|
||||||
|
manifest: dict[str, Any],
|
||||||
|
values: dict[str, Any],
|
||||||
|
expected_hit: bool,
|
||||||
|
expected_severity: str,
|
||||||
|
) -> GoldenCaseResult:
|
||||||
|
claim, contexts = _build_synthetic_claim(values, manifest)
|
||||||
|
execution = RiskRuleTemplateExecutor().evaluate_with_trace(manifest, claim=claim, contexts=contexts)
|
||||||
|
result = execution["result"]
|
||||||
|
actual_hit = result is not None
|
||||||
|
actual_severity = (
|
||||||
|
str((manifest.get("outcomes") or {}).get("fail", {}).get("severity") or "").strip()
|
||||||
|
if actual_hit
|
||||||
|
else "none"
|
||||||
|
)
|
||||||
|
severity_passed = (
|
||||||
|
not actual_hit or not expected_severity or expected_severity == actual_severity
|
||||||
|
)
|
||||||
|
passed = actual_hit == expected_hit and severity_passed
|
||||||
|
return GoldenCaseResult(
|
||||||
|
case_id="",
|
||||||
|
name="",
|
||||||
|
expected_hit=expected_hit,
|
||||||
|
actual_hit=actual_hit,
|
||||||
|
expected_severity=expected_severity,
|
||||||
|
actual_severity=actual_severity,
|
||||||
|
passed=passed,
|
||||||
|
message=str(result.get("message") or "") if isinstance(result, dict) else "",
|
||||||
|
evidence=result.get("evidence") if isinstance(result, dict) else {},
|
||||||
|
trace=execution.get("trace") if isinstance(execution.get("trace"), dict) else {},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _aggregate(results: list[GoldenCaseResult]) -> GoldenEvalReport:
|
||||||
|
total = len(results)
|
||||||
|
if total == 0:
|
||||||
|
return GoldenEvalReport(total=0, all_passed=True)
|
||||||
|
passed_count = sum(1 for r in results if r.passed)
|
||||||
|
tp = sum(1 for r in results if r.expected_hit and r.actual_hit)
|
||||||
|
fn = sum(1 for r in results if r.expected_hit and not r.actual_hit)  # false negative: expected a hit, rule did not fire
|
||||||
|
fp = sum(1 for r in results if not r.expected_hit and r.actual_hit)  # false positive: rule fired when it should not
|
||||||
|
precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
|
||||||
|
recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
|
||||||
|
return GoldenEvalReport(
|
||||||
|
total=total,
|
||||||
|
passed_count=passed_count,
|
||||||
|
failed_count=total - passed_count,
|
||||||
|
accuracy=passed_count / total,
|
||||||
|
precision=precision,
|
||||||
|
recall=recall,
|
||||||
|
all_passed=passed_count == total,
|
||||||
|
results=results,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RiskRuleGoldenEvaluator:
|
||||||
|
"""Evaluate a rule manifest on the golden set and enforce the release gate."""
|
||||||
|
|
||||||
|
def evaluate(self, manifest: dict[str, Any], cases: list[GoldenCase]) -> GoldenEvalReport:
|
||||||
|
results: list[GoldenCaseResult] = []
|
||||||
|
for case in cases:
|
||||||
|
result = _run_single_case(
|
||||||
|
manifest,
|
||||||
|
values=case.values_json or {},
|
||||||
|
expected_hit=bool(case.expected_hit),
|
||||||
|
expected_severity=str(case.expected_severity or ""),
|
||||||
|
)
|
||||||
|
result.case_id = case.case_key or case.id
|
||||||
|
result.name = case.name
|
||||||
|
results.append(result)
|
||||||
|
return _aggregate(results)
|
||||||
|
|
||||||
|
def evaluate_for_rule(
|
||||||
|
self,
|
||||||
|
db: Session,
|
||||||
|
manifest: dict[str, Any],
|
||||||
|
rule_code: str,
|
||||||
|
) -> GoldenEvalReport:
|
||||||
|
cases = list(
|
||||||
|
db.scalars(
|
||||||
|
select(GoldenCase).where(
|
||||||
|
GoldenCase.rule_code == rule_code,
|
||||||
|
GoldenCase.status == "active",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not cases:
|
||||||
|
return GoldenEvalReport(total=0, all_passed=True)
|
||||||
|
return self.evaluate(manifest, cases)
|
||||||
|
|
||||||
|
def require_pass(
|
||||||
|
self,
|
||||||
|
db: Session,
|
||||||
|
asset: AgentAsset,
|
||||||
|
version: str,
|
||||||
|
manifest: dict[str, Any],
|
||||||
|
rule_code: str,
|
||||||
|
*,
|
||||||
|
actor: str,
|
||||||
|
) -> GoldenEvalReport:
|
||||||
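|
"""Release gate entry point: run the golden set and raise PermissionError unless 100% of cases pass.

The release is let through when the golden set is empty or the gate is disabled, and the gate fails open (with a logged exception) when the evaluator itself raises.
An ``AgentAssetTestRun(test_type='golden')`` row is written whether or not the release is allowed, except when the gate is disabled, in which case the method returns before recording.
"""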
|
||||||
|
|
||||||
|
if not _gate_enabled():
|
||||||
|
return GoldenEvalReport(total=0, all_passed=True)
|
||||||
|
try:
|
||||||
|
report = self.evaluate_for_rule(db, manifest, rule_code)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("golden set 评测异常,降级放行 asset_id=%s", asset.id)
|
||||||
|
report = GoldenEvalReport(total=0, all_passed=True)
|
||||||
|
|
||||||
|
self._record_test_run(db, asset, version, report, actor=actor)
|
||||||
|
|
||||||
|
if report.total > 0 and not report.all_passed:
|
||||||
|
failures = [r for r in report.to_dict()["results"] if not r["passed"]]
|
||||||
|
raise PermissionError(
|
||||||
|
f"golden set 回归未通过({report.passed_count}/{report.total}),"
|
||||||
|
f"发布被拦截。失败用例:{failures}"
|
||||||
|
)
|
||||||
|
return report
|
||||||
|
|
||||||
|
def _record_test_run(
|
||||||
|
self,
|
||||||
|
db: Session,
|
||||||
|
asset: AgentAsset,
|
||||||
|
version: str,
|
||||||
|
report: GoldenEvalReport,
|
||||||
|
*,
|
||||||
|
actor: str,
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
run = AgentAssetTestRun(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
asset_id=asset.id,
|
||||||
|
version=version,
|
||||||
|
test_type="golden",
|
||||||
|
status="completed",
|
||||||
|
passed=report.all_passed,
|
||||||
|
summary=(
|
||||||
|
f"golden set {report.passed_count}/{report.total} passed"
|
||||||
|
if report.total > 0
|
||||||
|
else "golden set empty, gate skipped"
|
||||||
|
),
|
||||||
|
input_json={"rule_code": getattr(asset, "rule_code", "") or ""},
|
||||||
|
result_json=report.to_dict(),
|
||||||
|
created_by=actor,
|
||||||
|
)
|
||||||
|
db.add(run)
|
||||||
|
db.commit()
|
||||||
|
except Exception:
|
||||||
|
logger.warning("golden test run 记录失败 asset_id=%s", asset.id, exc_info=True)
|
||||||
|
db.rollback()
|
||||||
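For reference, precision and recall over hit/no-hit outcomes follow the standard confusion-matrix definitions: a false positive is a case where the rule fired but no hit was expected, and a false negative is a case where a hit was expected but the rule stayed silent. The sketch below is an illustration under assumptions: `Outcome` and `hit_metrics` are hypothetical names, and the 0.0 fallback for an empty denominator copies the convention used in `_aggregate`.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Outcome:
    """Hypothetical stand-in for the hit fields of one golden case result."""

    expected_hit: bool
    actual_hit: bool


def hit_metrics(outcomes: list[Outcome]) -> dict[str, float]:
    """Compute precision and recall from expected vs. actual hits."""
    tp = sum(1 for o in outcomes if o.expected_hit and o.actual_hit)
    fp = sum(1 for o in outcomes if not o.expected_hit and o.actual_hit)
    fn = sum(1 for o in outcomes if o.expected_hit and not o.actual_hit)
    # Fall back to 0.0 when a denominator is empty, as _aggregate does.
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    return {"precision": precision, "recall": recall}
```

Note that these metrics only look at hit/no-hit; a severity mismatch on a correct hit affects `passed` and accuracy but not precision or recall.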
@@ -1,104 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from app.services.embedding_provider import EmbeddingProvider, _runtime_model_config_from_dict
|
|
||||||
from app.services.knowledge_rag_runtime import KnowledgeRagError, RuntimeModelConfig
|
|
||||||
|
|
||||||
|
|
||||||
def _config(provider: str = "GLM") -> RuntimeModelConfig:
|
|
||||||
return RuntimeModelConfig(
|
|
||||||
slot="embedding",
|
|
||||||
provider=provider,
|
|
||||||
model="Embedding-3",
|
|
||||||
endpoint="https://open.bigmodel.cn/api/paas/v4/",
|
|
||||||
api_key="k",
|
|
||||||
capability="embedding",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_runtime_model_config_from_dict_maps_fields() -> None:
|
|
||||||
cfg = _runtime_model_config_from_dict(
|
|
||||||
{
|
|
||||||
"slot": "embedding",
|
|
||||||
"provider": "GLM",
|
|
||||||
"model": "Embedding-3",
|
|
||||||
"endpoint": "https://e",
|
|
||||||
"apiKey": "secret",
|
|
||||||
"capability": "embedding",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
assert cfg.api_key == "secret"
|
|
||||||
assert cfg.model == "Embedding-3"
|
|
||||||
|
|
||||||
|
|
||||||
def test_embed_empty_texts_returns_empty() -> None:
|
|
||||||
provider = EmbeddingProvider(_config())
|
|
||||||
assert provider.embed([]) == []
|
|
||||||
|
|
||||||
|
|
||||||
def test_embed_returns_vectors_and_caches_dimension() -> None:
|
|
||||||
provider = EmbeddingProvider(_config())
|
|
||||||
with patch(
|
|
||||||
"app.services.embedding_provider._request_embeddings_public",
|
|
||||||
return_value=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
|
||||||
) as mock_req:
|
|
||||||
vectors = provider.embed(["a", "b"])
|
|
||||||
assert vectors == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
|
||||||
assert provider.dimension() == 3
|
|
||||||
calls_after_first_dimension = mock_req.call_count
|
|
||||||
# A second dimension() call should not issue another request
|
|
||||||
assert provider.dimension() == 3
|
|
||||||
assert mock_req.call_count == calls_after_first_dimension
|
|
||||||
|
|
||||||
|
|
||||||
def test_dimension_raises_on_invalid_vectors() -> None:
|
|
||||||
provider = EmbeddingProvider(_config())
|
|
||||||
with patch(
|
|
||||||
"app.services.embedding_provider._request_embeddings_public",
|
|
||||||
return_value=[],
|
|
||||||
):
|
|
||||||
with pytest.raises(KnowledgeRagError):
|
|
||||||
provider.dimension()
|
|
||||||
|
|
||||||
|
|
||||||
def test_request_embeddings_public_glm_branch() -> None:
|
|
||||||
cfg = _config("GLM")
|
|
||||||
with patch(
|
|
||||||
"app.services.embedding_provider._send_json_request",
|
|
||||||
return_value=(200, {"data": [{"embedding": [0.1, 0.2]}]}),
|
|
||||||
) as mock_send:
|
|
||||||
from app.services.embedding_provider import _request_embeddings_public
|
|
||||||
|
|
||||||
vectors = _request_embeddings_public(cfg, ["x"])
|
|
||||||
assert vectors == [[0.1, 0.2]]
|
|
||||||
called_url = mock_send.call_args.args[1]
|
|
||||||
assert called_url.endswith("/embeddings")
|
|
||||||
|
|
||||||
|
|
||||||
def test_request_embeddings_public_ollama_branch() -> None:
|
|
||||||
cfg = _config("Ollama")
|
|
||||||
with patch(
|
|
||||||
"app.services.embedding_provider._send_json_request",
|
|
||||||
return_value=(200, {"embeddings": [[0.5, 0.6]]}),
|
|
||||||
) as mock_send:
|
|
||||||
from app.services.embedding_provider import _request_embeddings_public
|
|
||||||
|
|
||||||
vectors = _request_embeddings_public(cfg, ["x"])
|
|
||||||
assert vectors == [[0.5, 0.6]]
|
|
||||||
called_url = mock_send.call_args.args[1]
|
|
||||||
assert called_url.endswith("/api/embed")
|
|
||||||
|
|
||||||
|
|
||||||
def test_request_embeddings_public_raises_on_http_error() -> None:
|
|
||||||
cfg = _config("GLM")
|
|
||||||
with patch(
|
|
||||||
"app.services.embedding_provider._send_json_request",
|
|
||||||
return_value=(500, {"message": "boom"}),
|
|
||||||
):
|
|
||||||
from app.services.embedding_provider import _request_embeddings_public
|
|
||||||
|
|
||||||
with pytest.raises(KnowledgeRagError):
|
|
||||||
_request_embeddings_public(cfg, ["x"])
|
|
||||||
@@ -1,214 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections.abc import Generator
|
|
||||||
from datetime import datetime
|
|
||||||
from decimal import Decimal
|
|
||||||
from types import SimpleNamespace
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from sqlalchemy import create_engine
|
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
|
||||||
from sqlalchemy.pool import StaticPool
|
|
||||||
|
|
||||||
from app.db.base import Base
|
|
||||||
from app.models.employee import Employee
|
|
||||||
from app.models.few_shot_sample import FewShotSample
|
|
||||||
from app.models.financial_record import ExpenseClaim
|
|
||||||
from app.models.risk_observation import RiskObservation
|
|
||||||
from app.schemas.risk_observation import RiskObservationFeedbackCreate
|
|
||||||
from app.services.few_shot_ingestion import FewShotIngestionService
|
|
||||||
from app.services.risk_observations import RiskObservationService
|
|
||||||
|
|
||||||
|
|
||||||
def _build_session() -> Session:
|
|
||||||
engine = create_engine(
|
|
||||||
"sqlite+pysqlite:///:memory:",
|
|
||||||
connect_args={"check_same_thread": False},
|
|
||||||
poolclass=StaticPool,
|
|
||||||
)
|
|
||||||
Base.metadata.create_all(bind=engine)
|
|
||||||
factory = sessionmaker(bind=engine, autoflush=False, autocommit=False)
|
|
||||||
return factory()
|
|
||||||
|
|
||||||
|
|
||||||
def _observation(db: Session, key: str = "risk:c1:dup") -> RiskObservation:
|
|
||||||
db.add(Employee(id="emp-1", employee_no="E1", name="员工", email="e@e.com", grade="P6"))
|
|
||||||
db.add(
|
|
||||||
ExpenseClaim(
|
|
||||||
id="c1",
|
|
||||||
claim_no="BX-001",
|
|
||||||
employee_id="emp-1",
|
|
||||||
employee_name="员工",
|
|
||||||
department_name="风控部",
|
|
||||||
expense_type="travel",
|
|
||||||
reason="客户拜访",
|
|
||||||
location="上海",
|
|
||||||
amount=Decimal("1000"),
|
|
||||||
currency="CNY",
|
|
||||||
occurred_at=datetime(2026, 1, 1),
|
|
||||||
submitted_at=datetime(2026, 1, 1),
|
|
||||||
status="submitted",
|
|
||||||
approval_stage="manager_review",
|
|
||||||
risk_flags_json=[],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
db.flush()
|
|
||||||
obs = RiskObservation(
|
|
||||||
observation_key=key,
|
|
||||||
subject_type="expense_claim",
|
|
||||||
subject_key="claim:c1",
|
|
||||||
claim_id="c1",
|
|
||||||
claim_no="BX-001",
|
|
||||||
risk_type="duplicate_invoice",
|
|
||||||
risk_signal="duplicate_invoice",
|
|
||||||
title="重复发票",
|
|
||||||
description="同一发票出现在多张报销单",
|
|
||||||
risk_score=86,
|
|
||||||
risk_level="high",
|
|
||||||
confidence_score=0.8,
|
|
||||||
source="financial_risk_graph",
|
|
||||||
algorithm_version="v1",
|
|
||||||
ontology_json={"domain": "expense", "scenario": "reimbursement"},
|
|
||||||
)
|
|
||||||
db.add(obs)
|
|
||||||
db.commit()
|
|
||||||
db.refresh(obs)
|
|
||||||
return obs
|
|
||||||
|
|
||||||
|
|
||||||
def test_ingest_confirmed_persists_sample_and_calls_store() -> None:
|
|
||||||
with _build_session() as db:
|
|
||||||
obs = _observation(db)
|
|
||||||
obs.feedback_status = "confirmed"
|
|
||||||
service = FewShotIngestionService(db)
|
|
||||||
fake_store = MagicMock()
|
|
||||||
fake_store.upsert.return_value = "vec-1"
|
|
||||||
with patch.object(service, "_store", return_value=fake_store):
|
|
||||||
sample = service.ingest_observation_feedback(
|
|
||||||
obs,
|
|
||||||
MagicMock(feedback_type="confirm", comment="确认重复发票", actor="audit"),
|
|
||||||
)
|
|
||||||
assert sample is not None
|
|
||||||
assert sample.label == "confirmed"
|
|
||||||
assert sample.sample_key == f"obs:{obs.id}"
|
|
||||||
assert "重复发票" in sample.case_text
|
|
||||||
assert "确认重复发票" in sample.conclusion_text
|
|
||||||
assert sample.vector_id == "vec-1"
|
|
||||||
fake_store.upsert.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
def test_ingest_false_positive_also_persisted() -> None:
|
|
||||||
with _build_session() as db:
|
|
||||||
obs = _observation(db, key="risk:c2:fp")
|
|
||||||
obs.feedback_status = "false_positive"
|
|
||||||
db.commit()
|
|
||||||
service = FewShotIngestionService(db)
|
|
||||||
with patch.object(service, "_store", return_value=MagicMock(upsert=MagicMock(return_value=None))):
|
|
||||||
sample = service.ingest_observation_feedback(
|
|
||||||
obs,
|
|
||||||
MagicMock(feedback_type="false_positive", comment="", actor="audit"),
|
|
||||||
)
|
|
||||||
assert sample is not None
|
|
||||||
assert sample.label == "false_positive"
|
|
||||||
assert "误报" in sample.conclusion_text
|
|
||||||
|
|
||||||
|
|
||||||
def test_ingest_ignored_label_returns_none() -> None:
|
|
||||||
with _build_session() as db:
|
|
||||||
obs = _observation(db)
|
|
||||||
obs.feedback_status = "ignored"
|
|
||||||
service = FewShotIngestionService(db)
|
|
||||||
assert service.ingest_observation_feedback(obs, MagicMock()) is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_ingest_is_idempotent_on_duplicate_sample_key() -> None:
|
|
||||||
with _build_session() as db:
|
|
||||||
obs = _observation(db)
|
|
||||||
service = FewShotIngestionService(db)
|
|
||||||
store = MagicMock()
|
|
||||||
store.upsert.side_effect = ["vec-1", "vec-2"]
|
|
||||||
with patch.object(service, "_store", return_value=store):
|
|
||||||
obs.feedback_status = "confirmed"
|
|
||||||
first = service.ingest_observation_feedback(
|
|
||||||
obs, MagicMock(feedback_type="confirm", comment="第一次", actor="a")
|
|
||||||
)
|
|
||||||
# Simulate a later re-judgement as a false positive
|
|
||||||
obs.feedback_status = "false_positive"
|
|
||||||
second = service.ingest_observation_feedback(
|
|
||||||
obs, MagicMock(feedback_type="false_positive", comment="改判", actor="a")
|
|
||||||
)
|
|
||||||
assert first is not None and second is not None
|
|
||||||
assert first.id == second.id  # the same row is updated in place
|
|
||||||
from sqlalchemy import func, select
|
|
||||||
|
|
||||||
count = db.scalar(select(func.count()).select_from(FewShotSample).where(FewShotSample.sample_key == f"obs:{obs.id}"))
|
|
||||||
assert count == 1
|
|
||||||
assert second.label == "false_positive"
|
|
||||||
|
|
||||||
|
|
||||||
def test_create_feedback_hook_triggers_ingestion() -> None:
|
|
||||||
with _build_session() as db:
|
|
||||||
service = RiskObservationService(db)
|
|
||||||
obs = _observation(db)
|
|
||||||
ingest_calls: list = []
|
|
||||||
|
|
||||||
def _spy_ingest(o, f):
|
|
||||||
ingest_calls.append((o.id, f.feedback_type))
|
|
||||||
return None
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"app.services.few_shot_ingestion.FewShotIngestionService.ingest_observation_feedback",
|
|
||||||
side_effect=_spy_ingest,
|
|
||||||
):
|
|
||||||
service.create_feedback(
|
|
||||||
obs.observation_key,
|
|
||||||
RiskObservationFeedbackCreate(feedback_type="confirm", actor="audit"),
|
|
||||||
)
|
|
||||||
assert len(ingest_calls) == 1
|
|
||||||
assert ingest_calls[0][1] == "confirm"
|
|
||||||
|
|
||||||
|
|
||||||
def test_create_feedback_hook_skipped_for_comment_feedback() -> None:
|
|
||||||
with _build_session() as db:
|
|
||||||
service = RiskObservationService(db)
|
|
||||||
obs = _observation(db)
|
|
||||||
with patch(
|
|
||||||
"app.services.few_shot_ingestion.FewShotIngestionService.ingest_observation_feedback"
|
|
||||||
) as mock_ingest:
|
|
||||||
service.create_feedback(
|
|
||||||
obs.observation_key,
|
|
||||||
RiskObservationFeedbackCreate(feedback_type="comment", action="note", actor="audit"),
|
|
||||||
)
|
|
||||||
mock_ingest.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
def test_create_feedback_hook_swallows_ingestion_failure() -> None:
|
|
||||||
with _build_session() as db:
|
|
||||||
service = RiskObservationService(db)
|
|
||||||
obs = _observation(db)
|
|
||||||
with patch(
|
|
||||||
"app.services.few_shot_ingestion.FewShotIngestionService.ingest_observation_feedback",
|
|
||||||
side_effect=RuntimeError("boom"),
|
|
||||||
):
|
|
||||||
# Should not raise
|
|
||||||
feedback = service.create_feedback(
|
|
||||||
obs.observation_key,
|
|
||||||
RiskObservationFeedbackCreate(feedback_type="confirm", actor="audit"),
|
|
||||||
)
|
|
||||||
assert feedback.feedback_type == "confirm"
|
|
||||||
|
|
||||||
|
|
||||||
def test_create_feedback_hook_respects_feature_flag(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
monkeypatch.setenv("FEW_SHOT_INJECTION_ENABLED", "false")
|
|
||||||
with _build_session() as db:
|
|
||||||
service = RiskObservationService(db)
|
|
||||||
obs = _observation(db)
|
|
||||||
with patch(
|
|
||||||
"app.services.few_shot_ingestion.FewShotIngestionService.ingest_observation_feedback"
|
|
||||||
) as mock_ingest:
|
|
||||||
service.create_feedback(
|
|
||||||
obs.observation_key,
|
|
||||||
RiskObservationFeedbackCreate(feedback_type="confirm", actor="audit"),
|
|
||||||
)
|
|
||||||
mock_ingest.assert_not_called()
|
|
||||||
@@ -1,119 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from app.services.few_shot_retrieval import FewShotRetriever
|
|
||||||
from app.services.few_shot_store import FewShotStore
|
|
||||||
from app.services.risk_rule_generation_prompt import build_risk_rule_compiler_messages
|
|
||||||
|
|
||||||
|
|
||||||
def _hit(score: float, label: str, conclusion: str, risk_type: str = "duplicate_invoice") -> dict:
|
|
||||||
return {
|
|
||||||
"sample_id": "s1",
|
|
||||||
"score": score,
|
|
||||||
"label": label,
|
|
||||||
"domain": "expense",
|
|
||||||
"risk_type": risk_type,
|
|
||||||
"conclusion_text": conclusion,
|
|
||||||
"payload_json": {
|
|
||||||
"risk_signal": risk_type,
|
|
||||||
"risk_level": "high",
|
|
||||||
"ontology": {"scenario": "reimbursement"},
|
|
||||||
"feedback_comment": "",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_retrieve_returns_injection_blocks_with_token_budget() -> None:
|
|
||||||
store = MagicMock(spec=FewShotStore)
|
|
||||||
store.search.return_value = [
|
|
||||||
_hit(0.9, "confirmed", "确认重复发票需拦截"),
|
|
||||||
_hit(0.8, "false_positive", "此情形属于正常拆单不拦截"),
|
|
||||||
_hit(0.7, "confirmed", "确认重复发票需拦截"), # duplicate conclusion, should be deduplicated
|
|
||||||
]
|
|
||||||
retriever = FewShotRetriever(store)
|
|
||||||
blocks = retriever.retrieve_for_risk_rule_generation(
|
|
||||||
domain="expense", natural_language="同一发票重复报销"
|
|
||||||
)
|
|
||||||
assert len(blocks) == 2
|
|
||||||
assert blocks[0]["score"] == 0.9
|
|
||||||
assert blocks[0]["label"] == "confirmed"
|
|
||||||
assert blocks[0]["source"] == "historical_confirmed"
|
|
||||||
assert blocks[1]["label"] == "false_positive"
|
|
||||||
# Deduplication: the third conclusion matches the first, so it should be filtered out
|
|
||||||
conclusions = [b["conclusion"] for b in blocks]
|
|
||||||
assert len(set(conclusions)) == len(conclusions)
|
|
||||||
|
|
||||||
|
|
||||||
def test_retrieve_empty_case_text_returns_empty() -> None:
|
|
||||||
store = MagicMock(spec=FewShotStore)
|
|
||||||
retriever = FewShotRetriever(store)
|
|
||||||
assert retriever.retrieve_for_risk_rule_generation(natural_language="") == []
|
|
||||||
store.search.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
def test_retrieve_truncates_overlong_conclusion() -> None:
|
|
||||||
store = MagicMock(spec=FewShotStore)
|
|
||||||
long_text = "长结论" * 500
|
|
||||||
store.search.return_value = [
|
|
||||||
_hit(0.9, "confirmed", long_text),
|
|
||||||
]
|
|
||||||
retriever = FewShotRetriever(store)
|
|
||||||
blocks = retriever.retrieve_for_risk_rule_generation(natural_language="x")
|
|
||||||
assert len(blocks) == 1
|
|
||||||
# An overlong conclusion should be truncated to the per-sample limit
|
|
||||||
from app.services.few_shot_retrieval import SINGLE_SAMPLE_MAX_CHARS
|
|
||||||
|
|
||||||
assert len(blocks[0]["conclusion"]) <= SINGLE_SAMPLE_MAX_CHARS
|
|
||||||
|
|
||||||
|
|
||||||
def test_build_prompt_merges_few_shot_into_examples() -> None:
    samples = [
        {
            "source": "historical_confirmed",
            "label": "confirmed",
            "domain": "expense",
            "risk_type": "duplicate_invoice",
            "conclusion": "确认重复发票",
            "context": {"risk_signal": "duplicate_invoice"},
        }
    ]
    messages = build_risk_rule_compiler_messages(
        domain="expense",
        domain_label="报销",
        business_stage="reimbursement",
        business_stage_label="报销",
        expense_category=None,
        expense_category_label="",
        natural_language="重复发票规则",
        available_fields=[{"key": "attachment.invoice_no", "label": "发票号", "type": "string", "source": "attachment"}],
        few_shot_samples=samples,
    )
    assert len(messages) == 2
    payload = json.loads(messages[1]["content"])
    examples = payload["examples"]
    # Historical samples come first, followed by the built-in examples
    assert examples[0]["source"] == "historical_confirmed"
    assert examples[0]["conclusion"] == "确认重复发票"
    # Built-in examples are still present (they carry no source field)
    assert any("user_rule" in ex for ex in examples)


def test_build_prompt_without_few_shot_is_backward_compatible() -> None:
    messages = build_risk_rule_compiler_messages(
        domain="expense",
        domain_label="报销",
        business_stage="reimbursement",
        business_stage_label="报销",
        expense_category=None,
        expense_category_label="",
        natural_language="重复发票规则",
        available_fields=[],
    )
    payload = json.loads(messages[1]["content"])
    examples = payload["examples"]
    # Without few_shot_samples, no example should come from historical_confirmed
    assert all(ex.get("source") != "historical_confirmed" for ex in examples)
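The few-shot tests above pin down three behaviours: historical samples are placed ahead of the built-in examples, an overlong conclusion is capped at `SINGLE_SAMPLE_MAX_CHARS`, and omitting `few_shot_samples` leaves the examples free of historical entries. A minimal sketch of such a merge step might look like the following. It is not the project's `build_risk_rule_compiler_messages` or `FewShotRetriever`; the names `build_user_payload`, `truncate_sample`, `BUILTIN_EXAMPLES`, and the limit of 400 are assumptions made for illustration, and the truncation that the tests attribute to the retriever is folded into the same helper for brevity.

```python
from __future__ import annotations

import json
from typing import Any

# Hypothetical limit; the real value is defined in app.services.few_shot_retrieval.
SINGLE_SAMPLE_MAX_CHARS = 400

# Hypothetical stand-in for the built-in examples (they carry no "source" key).
BUILTIN_EXAMPLES: list[dict[str, Any]] = [
    {"user_rule": "flag duplicate invoice numbers", "template_key": "keyword_match_v1"},
]


def truncate_sample(sample: dict[str, Any]) -> dict[str, Any]:
    """Return a copy whose conclusion is capped at SINGLE_SAMPLE_MAX_CHARS."""
    clipped = dict(sample)
    clipped["conclusion"] = str(sample.get("conclusion", ""))[:SINGLE_SAMPLE_MAX_CHARS]
    return clipped


def build_user_payload(
    natural_language: str,
    few_shot_samples: list[dict[str, Any]] | None = None,
) -> str:
    """Serialize the user message: historical samples first, built-ins after."""
    historical = [truncate_sample(s) for s in (few_shot_samples or [])]
    payload = {
        "natural_language": natural_language,
        "examples": historical + BUILTIN_EXAMPLES,
    }
    # ensure_ascii=False keeps Chinese text readable in the prompt.
    return json.dumps(payload, ensure_ascii=False)
```

Placing the historical samples first keeps the built-in examples intact when no samples are retrieved, which is the backward-compatibility property the last test checks.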
262
server/tests/test_risk_rule_golden_evaluator.py
Normal file
@@ -0,0 +1,262 @@
from __future__ import annotations

from collections.abc import Generator
from datetime import datetime
from decimal import Decimal
from unittest.mock import MagicMock, patch

import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool

from app.db.base import Base
from app.models.agent_asset import AgentAsset, AgentAssetTestRun
from app.models.employee import Employee
from app.models.financial_record import ExpenseClaim
from app.models.golden_case import GoldenCase
from app.services.risk_rule_golden_evaluator import (
    GoldenEvalReport,
    RiskRuleGoldenEvaluator,
    _aggregate,
    _run_single_case,
)


def _build_session() -> Session:
    engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    Base.metadata.create_all(bind=engine)
    factory = sessionmaker(bind=engine, autoflush=False, autocommit=False)
    return factory()


def _keyword_manifest() -> dict:
    """A simple keyword_match_v1 manifest: it hits when reason contains "虚假"."""

    return {
        "rule_code": "risk.test.keyword",
        "template_key": "keyword_match_v1",
        "inputs": {
            "fields": [
                {"key": "claim.reason", "label": "事由", "type": "text", "source": "claim"},
            ]
        },
        "params": {
            "keywords": ["虚假"],
            "field_keys": ["claim.reason"],
            "search_fields": ["claim.reason"],
        },
        "outcomes": {"fail": {"severity": "high", "risk_score": 80}},
    }


def _golden_case(
    case_key: str,
    *,
    reason: str,
    expected_hit: bool,
    rule_code: str = "risk.test.keyword",
) -> GoldenCase:
    return GoldenCase(
        case_key=case_key,
        rule_code=rule_code,
        name=f"case-{case_key}",
        values_json={"claim.reason": reason},
        expected_hit=expected_hit,
        status="active",
    )


def test_run_single_case_hit_matches() -> None:
    result = _run_single_case(
        _keyword_manifest(),
        values={"claim.reason": "虚假发票报销"},
        expected_hit=True,
        expected_severity="high",
    )
    assert result.actual_hit is True
    assert result.passed is True
    assert result.actual_severity == "high"


def test_run_single_case_no_hit_matches() -> None:
    result = _run_single_case(
        _keyword_manifest(),
        values={"claim.reason": "正常差旅报销"},
        expected_hit=False,
        expected_severity="",
    )
    assert result.actual_hit is False
    assert result.passed is True


def test_run_single_case_mismatch_fails() -> None:
    result = _run_single_case(
        _keyword_manifest(),
        values={"claim.reason": "虚假发票"},
        expected_hit=False,  # Expected no hit, but the rule actually hits
        expected_severity="",
    )
    assert result.actual_hit is True
    assert result.passed is False


def test_run_single_case_severity_mismatch_fails() -> None:
    result = _run_single_case(
        _keyword_manifest(),
        values={"claim.reason": "虚假发票"},
        expected_hit=True,
        expected_severity="critical",  # The actual severity is high
    )
    assert result.passed is False


def test_aggregate_empty_returns_passed() -> None:
    report = _aggregate([])
    assert report.total == 0
    assert report.all_passed is True
    assert report.accuracy == 0.0


def test_aggregate_all_passed() -> None:
    from app.services.risk_rule_golden_evaluator import GoldenCaseResult

    results = [
        GoldenCaseResult("1", "a", True, True, "high", "high", True),
        GoldenCaseResult("2", "b", False, False, "", "none", True),
    ]
    report = _aggregate(results)
    assert report.total == 2
    assert report.passed_count == 2
    assert report.accuracy == 1.0
    assert report.all_passed is True


def test_aggregate_with_failure() -> None:
    from app.services.risk_rule_golden_evaluator import GoldenCaseResult

    results = [
        GoldenCaseResult("1", "a", True, True, "high", "high", True),
        GoldenCaseResult("2", "b", True, False, "high", "none", False),  # FP
    ]
    report = _aggregate(results)
    assert report.passed_count == 1
    assert report.failed_count == 1
    assert report.accuracy == 0.5
    assert report.all_passed is False
    assert report.precision == 0.5  # 1/(1+1)


def test_evaluate_for_rule_empty_returns_passed() -> None:
    with _build_session() as db:
        report = RiskRuleGoldenEvaluator().evaluate_for_rule(db, _keyword_manifest(), "risk.test.keyword")
        assert report.total == 0
        assert report.all_passed is True


def test_evaluate_for_rule_all_pass() -> None:
    with _build_session() as db:
        db.add(_golden_case("g1", reason="虚假发票", expected_hit=True))
        db.add(_golden_case("g2", reason="正常报销", expected_hit=False))
        db.commit()
        report = RiskRuleGoldenEvaluator().evaluate_for_rule(db, _keyword_manifest(), "risk.test.keyword")
        assert report.total == 2
        assert report.all_passed is True
        assert report.accuracy == 1.0


def test_evaluate_for_rule_with_failure() -> None:
    with _build_session() as db:
        db.add(_golden_case("g1", reason="虚假发票", expected_hit=False))  # Expected no hit, but it actually hits
        db.add(_golden_case("g2", reason="正常报销", expected_hit=True))  # Expected a hit, but it actually does not hit
        db.commit()
        report = RiskRuleGoldenEvaluator().evaluate_for_rule(db, _keyword_manifest(), "risk.test.keyword")
        assert report.total == 2
        assert report.all_passed is False
        assert report.failed_count == 2


def _asset(asset_id: str, code: str) -> AgentAsset:
    return AgentAsset(
        id=asset_id,
        code=code,
        name=code,
        asset_type="rule",
        domain="expense",
        owner="tester",
        status="review",
        working_version="v1",
    )


def test_require_pass_passes_when_all_green() -> None:
    with _build_session() as db:
        asset = _asset("a1", "R1")
        db.add(asset)
        db.add(_golden_case("g1", reason="虚假", expected_hit=True))
        db.commit()
        report = RiskRuleGoldenEvaluator().require_pass(
            db, asset, "v1", _keyword_manifest(), "risk.test.keyword", actor="tester"
        )
        assert report.all_passed is True
        # One record with test_type='golden' should have been written
        run = db.query(AgentAssetTestRun).filter_by(asset_id="a1", test_type="golden").one()
        assert run.passed is True


def test_require_pass_raises_on_failure() -> None:
    with _build_session() as db:
        asset = _asset("a2", "R2")
        db.add(asset)
        db.add(_golden_case("g1", reason="虚假", expected_hit=False))  # Will fail
        db.commit()
        with pytest.raises(PermissionError):
            RiskRuleGoldenEvaluator().require_pass(
                db, asset, "v1", _keyword_manifest(), "risk.test.keyword", actor="tester"
            )
        run = db.query(AgentAssetTestRun).filter_by(asset_id="a2", test_type="golden").one()
        assert run.passed is False


def test_require_pass_empty_golden_set_passes() -> None:
    with _build_session() as db:
        asset = _asset("a3", "R3")
        db.add(asset)
        db.commit()
        report = RiskRuleGoldenEvaluator().require_pass(
            db, asset, "v1", _keyword_manifest(), "risk.test.keyword", actor="tester"
        )
        assert report.total == 0
        assert report.all_passed is True


def test_require_pass_respects_feature_flag(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setenv("GOLDEN_SET_GATE_ENABLED", "false")
    with _build_session() as db:
        asset = _asset("a4", "R4")
        db.add(asset)
        db.add(_golden_case("g1", reason="虚假", expected_hit=False))  # Would normally fail
        db.commit()
        # The gate is disabled, so the call should pass through without raising
        report = RiskRuleGoldenEvaluator().require_pass(
            db, asset, "v1", _keyword_manifest(), "risk.test.keyword", actor="tester"
        )
        assert report.total == 0


def test_require_pass_swallows_evaluator_exception() -> None:
    with _build_session() as db:
        asset = _asset("a5", "R5")
        db.add(asset)
        db.commit()
        evaluator = RiskRuleGoldenEvaluator()
        with patch.object(evaluator, "evaluate_for_rule", side_effect=RuntimeError("boom")):
            report = evaluator.require_pass(
                db, asset, "v1", _keyword_manifest(), "risk.test.keyword", actor="tester"
            )
        assert report.total == 0
        assert report.all_passed is True  # Fails open: degrades to letting the release through
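The aggregate tests above fix a few edge cases: an empty golden set reports `total == 0`, `accuracy == 0.0`, and still counts as passed, while one true positive plus one false positive gives a precision of 0.5. A rough sketch of an aggregation with those properties is shown below. It is not the project's `_aggregate`; `CaseOutcome`, `Report`, `aggregate`, and `_ratio` are hypothetical names, and the field layout of the real `GoldenCaseResult` is not assumed.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class CaseOutcome:
    """Hypothetical per-case record; not the project's GoldenCaseResult."""

    expected_hit: bool
    actual_hit: bool
    passed: bool


@dataclass(frozen=True)
class Report:
    total: int
    passed_count: int
    failed_count: int
    accuracy: float
    precision: float
    recall: float
    all_passed: bool


def _ratio(numerator: int, denominator: int) -> float:
    """Divide safely, returning 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def aggregate(outcomes: list[CaseOutcome]) -> Report:
    total = len(outcomes)
    passed_count = sum(1 for o in outcomes if o.passed)
    failed_count = total - passed_count
    tp = sum(1 for o in outcomes if o.actual_hit and o.expected_hit)
    fp = sum(1 for o in outcomes if o.actual_hit and not o.expected_hit)
    fn = sum(1 for o in outcomes if not o.actual_hit and o.expected_hit)
    return Report(
        total=total,
        passed_count=passed_count,
        failed_count=failed_count,
        accuracy=_ratio(passed_count, total),
        precision=_ratio(tp, tp + fp),
        recall=_ratio(tp, tp + fn),
        # An empty set has no failures, so it counts as passed.
        all_passed=failed_count == 0,
    )
```

Defining `all_passed` as "no failures" rather than "accuracy equals 1.0" is what lets a rule with no golden cases through the gate, matching the empty-set tests.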