3 Commits

Author SHA1 Message Date
caoxiaozhu
52d57c3be7 test(flywheel): 补 few-shot 飞轮单测并沉淀开发文档
- embedding_provider:GLM/Ollama 分支、维度缓存、HTTP 错误降级
- few_shot_ingestion:confirmed/false_positive 入库、ignored 跳过、幂等去重、
  create_feedback hook 触发、feature flag、吞异常
- few_shot_retrieval:去重、token 预算、超长截断;prompt 注入合并 examples + 向后兼容
- 容器内新增测试 20 passed;回归测试 35 passed(RAG/risk_observations/rule_generation)
- 沉淀 document/development/2026-07-03/feature/ai-data-flywheel 概念文档与 TODO,
  飞轮 1 已勾选证据,飞轮 2-6 待后续迭代
2026-07-03 13:56:21 +08:00
caoxiaozhu
3a9d154783 feat(flywheel): few-shot 在线检索注入打通风险规则编译链路
- 新增 FewShotStore:独立 Qdrant collection few_shot_samples,向量 upsert/search/delete,
  全程失败降级不阻塞主链路
- 新增 FewShotIngestionService:RiskObservation confirmed/false_positive → FewShotSample +
  向量,带 sample_key 幂等去重
- 新增 FewShotRetriever:按 case 特征检索相似历史样本,去重 + token 预算 + 单条字符上限裁剪
- risk_observations.create_feedback commit 后挂 hook 自动入库,带 feature flag 和 try/except 兜底
- risk_rule_generation_prompt 新增 few_shot_samples 可选参数,合并进 examples 并标
  source=historical_confirmed;risk_rule_generation 构造 prompt 前调 retriever,失败降级为空
2026-07-03 13:55:52 +08:00
caoxiaozhu
765cfb40f3 feat(flywheel): 抽公共 EmbeddingProvider 并新增 FewShotSample 模型
- 从 knowledge_rag_runtime 抽出 embedding 调用逻辑为独立 EmbeddingProvider,
  复用现有 HTTP 纯函数,RAG 路径零回归
- 新增 FewShotSample 表模型(样本池),注册到 db/base.py 和 models/__init__.py
  供 few-shot 飞轮沉淀已确认风险观测
2026-07-03 13:55:39 +08:00
21 changed files with 1505 additions and 841 deletions

View File

@@ -0,0 +1,190 @@
# AI 数据飞轮 概念文档
更新时间2026-07-03
文档路径document/development/2026-07-03/feature/ai-data-flywheel/CONCEPT.md
## 功能一句话
把用户反馈、人工修正、风险样本、评测结果自动沉淀并回流到下一次 LLM 推理与规则生成,让费控系统在不停机的情况下持续提升准确率、降低误报率和人工干预率。
## 背景与问题
- 当前现状:项目已具备"聪明"的骨架——RAG`knowledge_rag_runtime.py` + Qdrant、风险规则自动生成`risk_rule_generation*.py`)、反馈样本沉淀(`skills/domain/false-positive-sample-accumulator` 等 3 个 accumulator、规则回放评测`risk-algorithm-replay-evaluator`)、行为画像(`employee_behavior_profile*`)、用户反馈表(`agent_feedback.py` + `AgentOperationFeedback`)。
- 用户痛点:样本在往 accumulator 池子里堆,但**下一次 LLM 推理时并没有把这些样本检索出来当 few-shot 喂进去**prompt 散落在各 `*_prompt.py`,没有版本号、没有在线 A/B、没有回归门禁OCR 抽取的人工修正值没有回流成评测/训练数据;低分反馈只汇总不归因。
- 业务影响:系统每次推理都从"初始水平"出发,无法把历史踩过的坑转化成下一次的能力;改 prompt / 改规则无法证明是否变好,存在隐性回归风险;运营和算法同事看不到"系统在进步"的证据。
- 为什么现在需要做:飞轮骨架已齐,缺的是把"样本池 → 检索注入 + 评测门禁 → prompt/规则版本"这段断开的箭头接上。补上后整张图就转起来,且改动集中在 prompt 构造层 + 新增 eval 目录,不动业务主链路,风险低。
## 目标与非目标
### 目标
- [G1] few-shot 在线检索注入:推理前从样本池按 case 特征做向量检索,取 top-k 历史样本(含人工结论)拼进 system prompt。
- [G2] 黄金评测集 + 自动回归门禁:版本化 golden setprompt/规则变更后在 golden set 上自动跑分,分数不达标禁止发布。
- [G3] Prompt 版本化 + Canary A/Bprompt 进表带版本号,支持 stable / canary 流量切分,反馈分数对比。
- [G4] 抽取修正回流:附件/明细字段的人工修正值记录为 diff沉淀为抽取评测集与 few-shot 样本。
- [G5] 低分反馈自动归因:低分反馈触发归因 agent拉 trace 诊断错误环节并生成改进任务。
- [G6] AI 智商看板:每周自动跑 golden set输出准确率/召回率/误报率/人工干预率随时间的曲线。
### 非目标
- [NG1] 本轮不做模型微调 / 自训练:只走 prompt 侧的 in-context learning + 规则学习。
- [NG2] 本轮不改变现有业务主链路(申请单、报销、审批)的接口契约。
- [NG3] 本轮不替换 Qdrant / LightRAG 底座,复用现有向量存储与 embedding 配置。
- [NG4] 政策新鲜度检测(外部政策变更 → 自动重生成规则)后续再评估,本轮只在评测门禁侧预留接口。
## 用户与场景
- 目标用户:
1. 报销人 / 申请人:感知到系统越来越准,少打回、少补件。
2. 财务审批人:误报率下降,审批被打断的次数减少。
3. 算法/运营同学:能看到智商曲线、能灰度上线 prompt、能跑回归评测。
- 使用入口:
- 推理时自动注入 few-shot对用户透明
- 后台 Canary 控制台(运营切流量、看分数)。
- AI 智商看板(周报 / 在线查询)。
- 核心场景:
1. 用户提交报销 → 系统预审 → 预审 prompt 自动注入相似历史误报样本 → 给出更准结论。
2. 算法同学改了 risk rule 生成 prompt → 发布前自动跑 golden set → 不达标被门禁拦下。
3. 用户给低分 → 归因 agent 诊断"是检索没召回 / 规则误判 / 回复格式问题"→ 自动建改进任务并回写样本池。
4. 审批人改了 OCR 抽错的金额 → diff 自动沉淀 → 下次同类票据抽取 prompt 多一条 few-shot。
- 异常场景:
- 样本池为空或检索失败 → 退化为无 few-shot 推理,不阻塞主链路。
- 评测门禁服务不可用 → 默认放行 stablecanary 自动暂停。
- Canary 候选 prompt 分数劣化 → 自动回滚到 stable。
## 功能能力
- [C1] 输入能力:消费 accumulator 样本池、`AgentOperationFeedback`、附件修正 diff、trace 数据作为飞轮原料。
- [C2] 处理能力:样本检索(向量 + 元数据过滤)、评测打分(准确率/召回率/误报率/F1、Canary 流量切分、低分归因。
- [C3] 输出能力few-shot 注入后的 messages、评测报告、智商曲线、改进任务、归因结论。
- [C4] 状态与权限:样本带"人工已确认"标签才进可注入集合prompt 版本有 stable/canary/pinned 状态;评测门禁可由运营关闭(审计可见)。
- [C5] 边界与降级检索失败、评测失败、Canary 失败均降级到 stable不阻塞业务推理。
## 方案设计
### 前端
- 页面/组件:
- AI 智商看板(新页面,复用 `finance-report` 看板骨架):准确率/召回率/误报率/人工干预率随时间曲线 + golden set 覆盖度。
- Canary 控制台(并入 `SettingsView` / `PoliciesView`):列出各场景 prompt 版本、流量比例、当前分数、一键回滚。
- 交互状态:加载/空态(样本不足)/错误态(评测失败)/权限态(仅算法运营)。
- 展示规则:曲线按场景(差旅/报销/预算分面Canary 显示置信区间,差异不显著时标注。
- 降级处理:看板数据不可用时提示"数据生成中",不报错。
### 后端
- 接口/服务(新增,按职责拆分,单文件 ≤ 800 行):
- `services/few_shot_retrieval.py`:样本检索器,复用 Qdrant输入 case 特征 → 输出 top-k 样本(带人工结论)。
- `services/prompt_registry.py`prompt 版本注册中心,按场景 + 策略latest/canary/pinned取 prompt。
- `services/eval_harness.py`:在 golden set 上跑评测,输出指标;被发布门禁和智商看板共用。
- `services/feedback_attribution.py`:低分归因 agent复用 `AgentTraceCenter` 数据。
- `services/extraction_correction_recorder.py`:记录 OCR 抽取字段的人工修正 diff。
- 改造点(在现有 prompt 构造文件加 inject 钩子,不改业务接口):
- `risk_rule_generation_prompt.py``user_agent_application.py``expense_claim_pre_review.py``document_intelligence_rules.py``ontology_extraction.py` 的 prompt 构造处。
- 权限与校验Canary 控制台仅算法/运营角色;门禁关闭需审计日志。
- 持久化新表Alembic 迁移):
- `prompt_version`id / scene / content / version / status(stable/canary/pinned) / eval_score / created_by / created_at。
- `golden_set`id / scene / case_payload / expected / source(accumulator/manual) / confirmed / version。
- `extraction_correction`id / attachment_id / field / raw_value / corrected_value / operator / created_at。
- `eval_run`id / prompt_version_id / scene / metrics_json / started_at / finished_at。
- 降级处理:所有飞轮组件故障均降级到无 few-shot + stable prompt主链路不阻塞。
### 算法与规则
- 输入case 特征向量(场景标签 + 文本摘要 + 关键字段、golden set、反馈样本、修正 diff。
- 流程:
1. 推理前:`few_shot_retrieval` 检索 top-k → 拼 system prompt。
2. 推理后:结果 + 反馈写入 accumulator。
3. 发布前:`eval_harness` 在 golden set 上跑分 → 门禁判定。
4. 低分触发:`feedback_attribution` 归因 → 改进任务回写样本池。
- 输出few-shot 样本块、评测指标、归因结论、智商曲线数据点。
- 解释few-shot 注入在 prompt 中保留"参考案例(历史已确认)"段落,可追溯;评测报告附错误 case 列表;归因输出错误环节标签 + 证据 trace 片段。
### 数据与契约
- 核心字段scene、case_signature、few_shot_samples、metrics(acc/recall/fpr/f1)、prompt_version_id、status。
- 状态枚举:
- prompt: `stable` / `canary` / `pinned` / `archived`
- golden case: `draft` / `confirmed` / `deprecated`
- eval: `pass` / `fail` / `blocked`
- 兼容策略prompt_registry 找不到版本时回退到当前硬编码 prompt保证向后兼容
- 版本/审计:每次 prompt / 规则 / golden set 变更记 `eval_run`,可回放历史。
## 算法与公式
### few-shot 检索排序
```text
score(sample, case) = α * sim(emb(sample), emb(case)) + β * match(meta(sample), meta(case))
```
变量说明:
- score样本与当前 case 的综合相似度。
- sim余弦相似度复用 Qdrant 现有 embedding。
- match元数据硬匹配得分场景同 / 域同 / 级别同),取 0 或 1。
- α、β:权重,默认 α=0.8、β=0.2,可在 prompt_registry 中按场景覆盖。
- 适用边界:仅取 `confirmed=true` 的样本top-k 默认 k=3按 token 预算动态裁剪。
### 评测指标
```text
precision = TP / (TP + FP)
recall = TP / (TP + FN)
f1 = 2 * precision * recall / (precision + recall)
```
变量说明:
- TP/FP/FN在 golden set 上推理结论与 expected 比对得出(结论为风险标记/字段值/分类标签三类场景各有比对器)。
- 发布门禁默认阈值recall ≥ 上一版 stable 的 recall 且 f1 不下降超过 2 个百分点,否则 `fail`
## 测试方案
后端:
- `few_shot_retrieval` 单测:样本池空 / 检索失败 / top-k 截断 / 仅取 confirmed 样本。
- `eval_harness` 单测golden set 跑分指标正确性、门禁通过/拦截逻辑、空 golden set 降级。
- `prompt_registry` 单测按策略取版本、回退到硬编码、Canary 流量切分比例。
- `feedback_attribution` 单测mock trace 数据,归因标签正确性。
前端:
- AI 智商看板视图模型:空态、加载态、错误态、曲线渲染。
- Canary 控制台:列表、切流量、回滚交互。
集成:
- 端到端:构造一份 golden set → 改 prompt → 发布被门禁拦截 / 通过 → 智商看板出现新数据点。
- 容器内运行:`docker exec -w /app -e SERVER_VENV_DIR=/tmp/x-financial-server-venv local-x-financial-linux /tmp/x-financial-server-venv/bin/pytest -q server/tests/...`,超时 60s。
手工验证:
- 在 AI 工作台触发一次预审,确认 prompt 中出现 few-shot 块。
- 在 Canary 控制台发布一版劣化 prompt确认被门禁拦下。
## 指标与验收
- [A1] 功能验收:推理时 prompt 中可见 few-shot 块,且样本来自 confirmed 池。
- [A2] 性能指标few-shot 检索 ≤ 200msP95不显著拖慢主链路eval 单场景 ≤ 60s。
- [A3] 质量指标golden set 覆盖至少 5 个核心场景;门禁能正确拦截劣化 prompt。
- [A4] 安全/权限指标Canary 控制台仅算法/运营可操作;门禁关闭记审计日志。
- [A5] 可观测性AI 智商看板按周生成曲线;每次 eval_run 可回放。
## 风险与开放问题
- 风险:
- few-shot 注入增加 prompt 长度,可能触发 token 上限或拖慢推理 → 用 token 预算裁剪 + P95 监控兜底。
- 样本池噪音(错误标注)污染推理 → 只取 confirmed 样本 + 评测门禁把关。
- 评测 golden set 与线上分布漂移 → 季度复审 golden set标注漂移度。
- 已处理依赖:复用 Qdrant / LightRAG / accumulator / AgentTraceCenter / risk_rule_generation 现有能力。
- 待确认:
- Canary 流量切分的具体比例(建议 90/10需与业务确认。
- 智商看板放哪个一级菜单Settings 还是独立"AI 运营"菜单)。
- 政策新鲜度检测是否本轮接入。
- 降级策略:任何飞轮组件故障 → 无 few-shot + stable prompt + 跳过门禁(仅 stable保证业务连续。
## 本轮实现记录
- 2026-07-03完成数据飞轮概念文档与开发 TODO 拆分,作为后续改造的总纲。

View File

@@ -0,0 +1,98 @@
# AI 数据飞轮 开发 TODO
更新时间2026-07-03
文档路径document/development/2026-07-03/feature/ai-data-flywheel/TODO.md
## 使用规则
- 每个 TODO 必须对应 `CONCEPT.md` 中的目标、能力、方案或验收点。
- 只有完成并验证后,才能把 `[ ]` 改成 `[x]`
- 勾选时在任务后补充简短证据,例如文件、接口、命令或验证结果。
- 如果需求发生变化,先更新 `CONCEPT.md`,再调整本 TODO。
- 实施顺序建议:阶段 1 → 2 → 3飞轮 1+2 是地基)→ 4/5/6 并行 → 7。
## 1. 调研与边界
- [x] [CONCEPT: 背景与问题] 盘点现有 accumulator / feedback / RAG / 规则生成能力,确认飞轮骨架已存在、断点在"检索注入 + 评测门禁"。
证据:`server/src/app/skills/domain/{false-positive-sample-accumulator,risk-feedback-sample-accumulator,risk-clue-collector}``services/agent_feedback.py``services/knowledge_rag_runtime.py``services/risk_rule_generation*.py`
- [x] [CONCEPT: 目标与非目标] 确认本轮范围 = 飞轮 1-6few-shot 注入 / golden set 门禁 / prompt 版本化 / 抽取修正回流 / 低分归因 / 智商看板),不做模型微调、不改业务接口契约。
证据CONCEPT.md「目标与非目标」章节。
- [ ] [CONCEPT: 风险与开放问题] 与业务确认 Canary 流量比例、智商看板菜单位置、政策新鲜度检测是否本轮接入。
证据:
## 2. 契约与设计
- [ ] [CONCEPT: 功能能力] 定义 4 张新表的字段、状态枚举prompt stable/canary/pinned/archived、golden draft/confirmed/deprecated、eval pass/fail/blocked、correction
证据:
- [ ] [CONCEPT: 方案设计] 明确 5 个新 service 的职责边界与 inject 钩子点risk_rule_generation_prompt / user_agent_application / expense_claim_pre_review / document_intelligence_rules / ontology_extraction
证据:
- [ ] [CONCEPT: 算法与公式] 确认 few-shot 检索排序公式权重(默认 α=0.8 β=0.2与门禁阈值recall 不降、f1 下降 ≤ 2pp
证据:
- [ ] [CONCEPT: 指标与验收] 把验收点 A1-A5 转成可验证检查项,附命令与期望结果。
证据:
## 3. 后端实现
- [x] [CONCEPT: 后端] 新增 `services/few_shot_retrieval.py`:复用 Qdrant按 case 特征检索 top-k confirmed 样本,带 token 预算裁剪。
证据:`server/src/app/services/few_shot_retrieval.py``server/src/app/services/few_shot_store.py`(独立 Qdrant collection `few_shot_samples``server/src/app/services/embedding_provider.py`(公共 EmbeddingProvider复用 knowledge_rag_runtime 的 HTTP 调用)。
- [ ] [CONCEPT: 后端] 新增 `services/prompt_registry.py`prompt 版本 CRUD + 策略取版latest/canary/pinned+ 回退硬编码。
证据:飞轮 3prompt 版本化 + Canary未启动本轮只做飞轮 1。
- [ ] [CONCEPT: 后端] 新增 `services/eval_harness.py`:在 golden set 上跑评测,输出 precision/recall/f1供门禁与看板共用。
证据:飞轮 2golden set + 门禁)未启动,本轮只做飞轮 1。
- [ ] [CONCEPT: 后端] 新增 `services/feedback_attribution.py`:低分反馈触发,复用 AgentTraceCenter trace 做归因,输出错误环节标签 + 改进任务。
证据:飞轮 5低分归因未启动本轮只做飞轮 1。
- [ ] [CONCEPT: 后端] 新增 `services/extraction_correction_recorder.py`:在附件/明细字段更新处记录 raw vs corrected diff。
证据:飞轮 4抽取修正回流未启动本轮只做飞轮 1。
- [ ] [CONCEPT: 后端] Alembic 迁移prompt_version / golden_set / extraction_correction / eval_run 四张表。
证据:本轮新增的是 FewShotSample 一张表(`server/src/app/models/few_shot_sample.py`),项目靠 `Base.metadata.create_all()` 建表(无 alembic versions/ 目录),已注册到 `db/base.py``models/__init__.py`。其余三表随对应飞轮再建。
- [x] [CONCEPT: 后端] 新增 `services/few_shot_ingestion.py`RiskObservation confirmed/false_positive → FewShotSample + Qdrant 向量,在 `risk_observations.create_feedback` commit 后 hook 触发。
证据:`server/src/app/services/few_shot_ingestion.py``server/src/app/services/risk_observations.py:324-345``_maybe_ingest_few_shot` hook带 feature flag + try/except 兜底)。
- [x] [CONCEPT: 数据与契约] 在现有 prompt 构造文件加 few-shot 注入,不改业务接口。
证据:`server/src/app/services/risk_rule_generation_prompt.py`(新增 `few_shot_samples` 可选 kwarg合并进 examples 字段);`server/src/app/services/risk_rule_generation.py:271-292``_retrieve_few_shot_samples` 在构造 messages 前调用,失败降级为空)。
## 4. 算法/规则实现
- [x] [CONCEPT: 算法与规则] 实现few-shot 检索排序(向量相似度 + 元数据硬匹配),只取 confirmed 样本。
证据:`server/src/app/services/few_shot_store.py`Qdrant 余弦相似度 + payload 过滤 scene/label/status`few_shot_retrieval.py` 去重 + token 预算 + 单条字符上限裁剪。检索仅取 label ∈ {confirmed, false_positive}。
- [ ] [CONCEPT: 算法与规则] 实现评测指标比对器(风险标记 / 字段值 / 分类标签 三类场景)。
证据:飞轮 2未启动。
- [ ] [CONCEPT: 算法与规则] 接入发布门禁:`agent_asset_risk_rule_publish` 前调 eval_harness不达标 block。
证据:飞轮 2未启动。
- [ ] [CONCEPT: 算法与规则] 接入 Canary 流量切分(默认 90 stable / 10 canary+ 劣化自动回滚。
证据:飞轮 3未启动。
- [x] [CONCEPT: 结果解释] few-shot 块在 prompt 中保留 `source: "historical_confirmed"` 标记,可追溯。
证据:`risk_rule_generation_prompt.py` 合并 examples 时每条历史样本带 `source`/`label`/`conclusion` 字段。
## 5. 前端实现
- [ ] [CONCEPT: 前端] AI 智商看板新页面:准确率/召回率/误报率/人工干预率随时间曲线 + golden set 覆盖度,复用 `finance-report` 看板骨架。
证据:
- [ ] [CONCEPT: 前端] Canary 控制台(并入 Settings/Policiesprompt 版本列表、流量比例、分数、一键回滚。
证据:
- [ ] [CONCEPT: 前端] 实现加载/空态(样本不足)/错误态(评测失败)/权限态(仅算法运营)。
证据:
- [ ] [CONCEPT: 前端] 对齐现有企业后台风格(参考 `chat-ui-saas-styling` / `theme-settings-enterprise-ai-style`),避免营销页观感。
证据:
## 6. 测试与验证
- [x] [CONCEPT: 测试方案] 后端单测embedding_providerGLM/Ollama 分支、维度缓存、HTTP 错误降级、few_shot_ingestionconfirmed/false_positive 入库、ignored 跳过、幂等去重、hook 触发、feature flag、吞异常、few_shot_retrieval去重、token 预算、超长截断)+ prompt 注入(合并 examples、向后兼容
证据:`server/tests/test_embedding_provider.py``server/tests/test_few_shot_ingestion.py``server/tests/test_few_shot_retrieval_and_prompt.py`,容器内 `pytest -q` 20 passed。
- [ ] [CONCEPT: 测试方案] 前端:智商看板与 Canary 控制台视图模型 + 构建验证。
证据:飞轮 3/6 前端,未启动。
- [ ] [CONCEPT: 测试方案] 集成golden set → 改 prompt → 门禁拦截/通过 → 看板新增数据点,容器内跑通。
证据:飞轮 2 集成,未启动。
- [x] [CONCEPT: 测试方案] 回归:现有 RAG / risk_observations / risk_rule_generation 测试全过。
证据:容器内 `pytest -q server/tests/test_risk_observations_service.py server/tests/test_knowledge_rag_runtime.py server/tests/test_risk_rule_generation.py server/tests/test_risk_rule_generation_failure.py` → 35 passedEmbeddingProvider 抽离零回归。
- [ ] [CONCEPT: 指标与验收] 记录验证命令与结果,确认 P95 检索 ≤ 200ms、单场景评测 ≤ 60s。
证据:性能指标待飞轮 2 评测上线后连同 golden set 一起量。
## 7. 文档收尾
- [x] [CONCEPT: 指标与验收] 飞轮 1few-shot 注入A1 功能验收已达成:推理时 prompt 中可见带 `source: "historical_confirmed"` 的 few-shot 块,且样本来自 confirmed/false_positive 池。A5 可观测性部分达成(可追溯 source。A2/A3/A4 随飞轮 2/3 补齐。
证据:见阶段 3/4/6 已勾选项。
- [ ] [CONCEPT: 风险与开放问题] 更新 Canary 比例、看板菜单位置、政策新鲜度检测的最终结论与剩余风险。
证据:飞轮 2-6 启动后再定稿。
- [x] [CONCEPT: 功能一句话] 确认飞轮 1 实现没有偏离"让系统越用越聪明"的原始目标。
证据:人工确认风险观测 → 自动入库 + 向量化 → 下次规则编译时检索注入相似历史样本,形成"用得越多 → 样本越丰富 → 推理越准"的闭环。飞轮 2-6 待后续迭代。

View File

@@ -43,10 +43,6 @@ from app.schemas.agent_asset import (
AgentAssetVersionCreate, AgentAssetVersionCreate,
AgentAssetVersionRead, AgentAssetVersionRead,
AgentAssetVersionTimelineItemRead, AgentAssetVersionTimelineItemRead,
GoldenCaseCreate,
GoldenCaseRead,
GoldenEvalRead,
GoldenEvalRequest,
) )
from app.schemas.common import ErrorResponse, PaginatedResponse from app.schemas.common import ErrorResponse, PaginatedResponse
from app.services.agent_assets import AgentAssetService from app.services.agent_assets import AgentAssetService
@@ -927,110 +923,3 @@ def get_agent_asset_version_timeline(
return AgentAssetService(db).list_version_timeline(asset_id) return AgentAssetService(db).list_version_timeline(asset_id)
except Exception as exc: except Exception as exc:
_handle_asset_error(exc) _handle_asset_error(exc)
@router.post(
"/risk-rules/golden-cases",
response_model=GoldenCaseRead,
status_code=status.HTTP_201_CREATED,
summary="创建 golden set 黄金用例",
description="为指定规则(或通用场景)创建一条回归用例,发布前作为门禁集执行。",
)
def create_golden_case(
body: GoldenCaseCreate,
_: RuleEditorUser,
db: DbSession,
) -> GoldenCaseRead:
from app.models.golden_case import GoldenCase
from sqlalchemy import select
existing = db.scalar(select(GoldenCase).where(GoldenCase.case_key == body.case_key))
if existing is not None:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="case_key 已存在")
case = GoldenCase(
case_key=body.case_key,
rule_code=body.rule_code,
scene=body.scene,
name=body.name,
values_json=body.values,
expected_hit=body.expected_hit,
expected_severity=body.expected_severity,
note=body.note,
status="active",
source="manual",
)
db.add(case)
db.commit()
db.refresh(case)
return _golden_case_read(case)
@router.get(
"/risk-rules/{rule_code}/golden-cases",
response_model=list[GoldenCaseRead],
summary="列出规则的 golden 用例",
)
def list_golden_cases(
rule_code: str,
_: CurrentUser,
db: DbSession,
) -> list[GoldenCaseRead]:
from app.models.golden_case import GoldenCase
from sqlalchemy import select
cases = db.scalars(
select(GoldenCase).where(GoldenCase.rule_code == rule_code).order_by(GoldenCase.created_at)
).all()
return [_golden_case_read(case) for case in cases]
@router.post(
"/{asset_id}/golden-eval",
response_model=GoldenEvalRead,
summary="手动触发 golden set 评测(不入门禁)",
description="在当前规则版本上跑 golden 用例集,返回指标。门禁由 publish 时自动触发。",
)
def run_golden_eval(
asset_id: str,
body: GoldenEvalRequest,
_: RuleReviewerUser,
db: DbSession,
) -> GoldenEvalRead:
from app.services.agent_asset_spreadsheet import RISK_RULES_LIBRARY
from app.services.risk_rule_golden_evaluator import RiskRuleGoldenEvaluator
try:
asset = AgentAssetService(db).get_asset(asset_id)
if asset is None:
raise LookupError("Asset not found")
config = asset.config_json if isinstance(asset.config_json, dict) else {}
rule_document = config.get("rule_document") if isinstance(config.get("rule_document"), dict) else {}
file_name = str(rule_document.get("file_name") or "").strip()
if not file_name:
raise ValueError("该规则没有可执行的 manifest 文件。")
manager = AgentAssetService(db).rule_library_manager
manifest = manager.read_rule_library_json(library=RISK_RULES_LIBRARY, file_name=file_name)
rule_code = str(manifest.get("rule_code") or "").strip()
if not rule_code:
raise ValueError("manifest 缺少 rule_code。")
version = body.version or asset.working_version or ""
report = RiskRuleGoldenEvaluator().evaluate_for_rule(db, manifest, rule_code)
return GoldenEvalRead(**report.to_dict())
except Exception as exc:
_handle_asset_error(exc)
def _golden_case_read(case) -> GoldenCaseRead:
return GoldenCaseRead(
id=case.id,
case_key=case.case_key,
rule_code=case.rule_code,
scene=case.scene or "",
name=case.name or "",
values=case.values_json or {},
expected_hit=bool(case.expected_hit),
expected_severity=case.expected_severity,
note=case.note,
status=case.status,
source=case.source,
)

View File

@@ -15,13 +15,13 @@ from app.models.budget import BudgetAllocation, BudgetReservation, BudgetTransac
from app.models.employee_change_log import EmployeeChangeLog from app.models.employee_change_log import EmployeeChangeLog
from app.models.employee_behavior_profile import EmployeeBehaviorProfileSnapshot from app.models.employee_behavior_profile import EmployeeBehaviorProfileSnapshot
from app.models.employee import Employee from app.models.employee import Employee
from app.models.few_shot_sample import FewShotSample
from app.models.financial_record import ( from app.models.financial_record import (
AccountsPayableRecord, AccountsPayableRecord,
AccountsReceivableRecord, AccountsReceivableRecord,
ExpenseClaim, ExpenseClaim,
ExpenseClaimItem, ExpenseClaimItem,
) )
from app.models.golden_case import GoldenCase
from app.models.hermes_config import HermesTaskConfig, HermesTaskExecutionLog from app.models.hermes_config import HermesTaskConfig, HermesTaskExecutionLog
from app.models.hermes_report import HermesRiskReport from app.models.hermes_report import HermesRiskReport
from app.models.notification_state import NotificationState from app.models.notification_state import NotificationState
@@ -58,8 +58,8 @@ __all__ = [
"EmployeeBehaviorProfileSnapshot", "EmployeeBehaviorProfileSnapshot",
"EmployeeChangeLog", "EmployeeChangeLog",
"ExpenseClaim", "ExpenseClaim",
"FewShotSample",
"ExpenseClaimItem", "ExpenseClaimItem",
"GoldenCase",
"HermesTaskConfig", "HermesTaskConfig",
"HermesTaskExecutionLog", "HermesTaskExecutionLog",
"HermesRiskReport", "HermesRiskReport",

View File

@@ -8,13 +8,13 @@ from app.models.budget import BudgetAllocation, BudgetReservation, BudgetTransac
from app.models.employee_change_log import EmployeeChangeLog from app.models.employee_change_log import EmployeeChangeLog
from app.models.employee_behavior_profile import EmployeeBehaviorProfileSnapshot from app.models.employee_behavior_profile import EmployeeBehaviorProfileSnapshot
from app.models.employee import Employee from app.models.employee import Employee
from app.models.few_shot_sample import FewShotSample
from app.models.financial_record import ( from app.models.financial_record import (
AccountsPayableRecord, AccountsPayableRecord,
AccountsReceivableRecord, AccountsReceivableRecord,
ExpenseClaim, ExpenseClaim,
ExpenseClaimItem, ExpenseClaimItem,
) )
from app.models.golden_case import GoldenCase
from app.models.hermes_config import HermesTaskConfig, HermesTaskExecutionLog from app.models.hermes_config import HermesTaskConfig, HermesTaskExecutionLog
from app.models.hermes_report import HermesRiskReport from app.models.hermes_report import HermesRiskReport
from app.models.notification_state import NotificationState from app.models.notification_state import NotificationState
@@ -50,7 +50,7 @@ __all__ = [
"EmployeeChangeLog", "EmployeeChangeLog",
"ExpenseClaim", "ExpenseClaim",
"ExpenseClaimItem", "ExpenseClaimItem",
"GoldenCase", "FewShotSample",
"HermesTaskConfig", "HermesTaskConfig",
"HermesTaskExecutionLog", "HermesTaskExecutionLog",
"HermesRiskReport", "HermesRiskReport",

View File

@@ -0,0 +1,54 @@
from __future__ import annotations
import uuid
from datetime import datetime
from typing import Any
from sqlalchemy import DateTime, ForeignKey, Index, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.types import JSON
from app.db.base_class import Base
class FewShotSample(Base):
"""已确认的风险观测样本,供 few-shot 检索注入使用。
数据来源是 ``RiskObservation`` 上人工确认为 confirmed / false_positive 的观测,
入库后同时写一份向量到 Qdrant 的 ``few_shot_samples`` collection。
"""
__tablename__ = "few_shot_samples"
__table_args__ = (
Index("ix_few_shot_samples_scene_label", "scene", "label"),
Index("ix_few_shot_samples_domain_risk_type", "domain", "risk_type"),
)
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
sample_key: Mapped[str] = mapped_column(String(160), unique=True, index=True)
source_observation_id: Mapped[str | None] = mapped_column(
ForeignKey("risk_observations.id"),
nullable=True,
index=True,
)
scene: Mapped[str] = mapped_column(String(50), default="risk_rule_generation", index=True)
domain: Mapped[str] = mapped_column(String(50), default="", index=True)
risk_type: Mapped[str] = mapped_column(String(80), default="", index=True)
risk_level: Mapped[str] = mapped_column(String(20), default="")
label: Mapped[str] = mapped_column(String(30), default="confirmed", index=True)
case_text: Mapped[str] = mapped_column(Text(), default="")
conclusion_text: Mapped[str] = mapped_column(Text(), default="")
payload_json: Mapped[dict[str, Any]] = mapped_column(JSON, default=dict)
vector_id: Mapped[str | None] = mapped_column(String(100), nullable=True)
status: Mapped[str] = mapped_column(String(20), default="active", index=True)
created_at: Mapped[datetime] = mapped_column(DateTime, default=func.now(), server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(
DateTime,
default=func.now(),
onupdate=func.now(),
server_default=func.now(),
)

View File

@@ -1,48 +0,0 @@
from __future__ import annotations
import uuid
from datetime import datetime
from typing import Any
from sqlalchemy import Boolean, DateTime, Index, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.types import JSON
from app.db.base_class import Base
class GoldenCase(Base):
"""风险规则回归门禁用的黄金用例。
由运营手动维护(或从已确认风险观测导入),在规则发布前作为回归集执行,
100% 通过才放行。``values_json`` 复用 ``AgentAssetRiskRuleSampleCase.values``
的扁平字典格式,``expected_hit`` / ``expected_severity`` 作为 ground truth。
"""
__tablename__ = "golden_cases"
__table_args__ = (
Index("ix_golden_cases_rule_code_status", "rule_code", "status"),
Index("ix_golden_cases_scene_status", "scene", "status"),
)
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
case_key: Mapped[str] = mapped_column(String(160), unique=True, index=True)
rule_code: Mapped[str | None] = mapped_column(String(120), nullable=True, index=True)
scene: Mapped[str] = mapped_column(String(50), default="", index=True)
name: Mapped[str] = mapped_column(String(120), default="")
values_json: Mapped[dict[str, Any]] = mapped_column(JSON, default=dict)
expected_hit: Mapped[bool] = mapped_column(Boolean, default=True)
expected_severity: Mapped[str | None] = mapped_column(String(20), nullable=True)
note: Mapped[str | None] = mapped_column(Text(), nullable=True)
status: Mapped[str] = mapped_column(String(20), default="active", index=True)
source: Mapped[str] = mapped_column(String(30), default="manual")
created_at: Mapped[datetime] = mapped_column(DateTime, default=func.now(), server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(
DateTime,
default=func.now(),
onupdate=func.now(),
server_default=func.now(),
)

View File

@@ -204,46 +204,6 @@ class AgentAssetRiskRuleReportRequest(BaseModel):
note: str | None = Field(default=None, max_length=1000) note: str | None = Field(default=None, max_length=1000)
class GoldenCaseCreate(BaseModel):
case_key: str = Field(..., max_length=160)
rule_code: str | None = Field(default=None, max_length=120)
scene: str = Field(default="", max_length=50)
name: str = Field(default="", max_length=120)
values: dict[str, Any] = Field(default_factory=dict)
expected_hit: bool = True
expected_severity: str | None = Field(default=None, max_length=20)
note: str | None = None
class GoldenCaseRead(BaseModel):
id: str
case_key: str
rule_code: str | None = None
scene: str = ""
name: str = ""
values: dict[str, Any] = Field(default_factory=dict)
expected_hit: bool = True
expected_severity: str | None = None
note: str | None = None
status: str = "active"
source: str = "manual"
class GoldenEvalRequest(BaseModel):
version: str | None = Field(default=None, max_length=30)
class GoldenEvalRead(BaseModel):
total: int = 0
passed_count: int = 0
failed_count: int = 0
accuracy: float = 0.0
precision: float = 0.0
recall: float = 0.0
all_passed: bool = True
results: list[dict[str, Any]] = Field(default_factory=list)
class AgentAssetRiskRuleSimulationAttachment(BaseModel): class AgentAssetRiskRuleSimulationAttachment(BaseModel):
name: str = Field(default="", max_length=240) name: str = Field(default="", max_length=240)
content_type: str | None = Field(default=None, max_length=120) content_type: str | None = Field(default=None, max_length=120)

View File

@@ -39,9 +39,6 @@ class AgentAssetRiskRulePublishMixin:
if not self.get_latest_risk_rule_test_summary(asset, version=version).test_passed: if not self.get_latest_risk_rule_test_summary(asset, version=version).test_passed:
raise PermissionError("当前规则版本尚未完成测试通过确认,不能发布。") raise PermissionError("当前规则版本尚未完成测试通过确认,不能发布。")
# golden set 回归门禁:在 golden 用例集上跑规则,未 100% 通过则拦截发布。
self._require_golden_set_passed(asset, version, actor=actor)
before = self._asset_snapshot(asset) before = self._asset_snapshot(asset)
self._ensure_approved_review(asset, version=version, actor=actor, note="发布上线前审核通过。") self._ensure_approved_review(asset, version=version, actor=actor, note="发布上线前审核通过。")
asset.reviewer = actor asset.reviewer = actor
@@ -179,49 +176,6 @@ class AgentAssetRiskRulePublishMixin:
) )
) )
def _require_golden_set_passed(
self,
asset: AgentAsset,
version: str,
*,
actor: str,
) -> None:
"""在 golden set 上跑当前规则 manifest未 100% 通过则拦截发布。
降级策略feature flag 关闭 / 无 rule_document / 无 golden case /
evaluator 异常 → 一律放行,不阻塞发布主链路。
"""
import os
if os.environ.get("GOLDEN_SET_GATE_ENABLED", "true").strip().lower() in {"0", "false", "no"}:
return
config = asset.config_json if isinstance(asset.config_json, dict) else {}
rule_document = config.get("rule_document") if isinstance(config.get("rule_document"), dict) else {}
file_name = str(rule_document.get("file_name") or "").strip()
if not file_name:
return
try:
manifest = self.rule_library_manager.read_rule_library_json(
library=RISK_RULES_LIBRARY,
file_name=file_name,
)
except Exception:
return
rule_code = str(manifest.get("rule_code") or "").strip()
if not rule_code:
return
from app.services.risk_rule_golden_evaluator import RiskRuleGoldenEvaluator
RiskRuleGoldenEvaluator().require_pass(
self.db,
asset,
version,
manifest,
rule_code,
actor=actor,
)
@staticmethod @staticmethod
def _config_from_published_manifest( def _config_from_published_manifest(
manifest: dict[str, Any], manifest: dict[str, Any],

View File

@@ -0,0 +1,138 @@
"""公共 Embedding 提供者。
把 ``knowledge_rag_runtime`` 里 embedding 调用逻辑抽出来,供 RAG 和
few-shot 检索复用。本模块只依赖现有模块级纯函数和 ``RuntimeModelConfig``
不改动 ``_LightRagRuntime`` 的行为RAG 路径保持零回归风险。
典型用法::
provider = EmbeddingProvider.from_settings(session)
vectors = provider.embed(["差旅超标", "票单不一致"])
dim = provider.dimension()
"""
from __future__ import annotations
from typing import Any
from app.core.logging import get_logger
from app.services.knowledge_rag_runtime import (
DEFAULT_EMBEDDING_TIMEOUT_SECONDS,
KnowledgeRagError,
RuntimeModelConfig,
_build_headers,
_ensure_path,
_extract_embedding_vectors,
_normalize_endpoint,
_send_json_request,
)
logger = get_logger("app.services.embedding_provider")
def _runtime_model_config_from_dict(config: dict[str, str]) -> RuntimeModelConfig:
"""把 SettingsService.get_runtime_model_config 返回的 dict 转成 dataclass。"""
return RuntimeModelConfig(
slot=str(config.get("slot") or "embedding"),
provider=str(config.get("provider") or ""),
model=str(config.get("model") or ""),
endpoint=str(config.get("endpoint") or ""),
api_key=str(config.get("apiKey") or ""),
capability=str(config.get("capability") or ""),
)
class EmbeddingProvider:
"""对 embedding 模型的轻量封装。
设计要点:
- 持有一个 ``RuntimeModelConfig``,构造即固定,不依赖 LightRAG。
- 复用 ``knowledge_rag_runtime`` 的 HTTP 调用纯函数,行为与 RAG 完全一致。
- 维度采用惰性探测(首次 embed 后缓存),避免空构造就打远端。
"""
def __init__(self, config: RuntimeModelConfig) -> None:
self.config = config
self._dimension: int | None = None
@classmethod
def from_settings(cls, session: Any) -> "EmbeddingProvider":
"""从 SettingsService 取 embedding 配置构造 provider。"""
from app.services.settings import SettingsService
raw = SettingsService(session).get_runtime_model_config("embedding")
return cls(_runtime_model_config_from_dict(raw))
def embed(self, texts: list[str]) -> list[list[float]]:
"""对一组文本做 embedding返回与输入等长的向量列表。"""
if not texts:
return []
return _request_embeddings_public(self.config, texts)
def dimension(self) -> int:
"""探测 embedding 维度,结果缓存。失败抛 KnowledgeRagError。"""
if self._dimension is None:
vectors = self.embed(["dimension probe"])
if not vectors or not isinstance(vectors[0], list):
raise KnowledgeRagError("无法从 embedding 模型返回结果中解析向量维度。")
self._dimension = len(vectors[0])
if self._dimension <= 0:
raise KnowledgeRagError("embedding 模型返回了无效的向量维度。")
return self._dimension
def _request_embeddings_public(
config: RuntimeModelConfig,
texts: list[str],
) -> list[list[float]]:
"""按 provider 分支构造 embedding 请求。
与 ``_LightRagRuntime._request_embeddings`` 实现保持一致,
保证 few-shot 检索与 RAG 走同一套调用语义。
"""
from app.services.model_connectivity import AZURE_API_VERSION
if config.provider == "Azure OpenAI":
from app.services.knowledge_rag_runtime import _build_azure_deployment_base
url = f"{_build_azure_deployment_base(config.endpoint, config.model)}/embeddings?api-version={AZURE_API_VERSION}"
payload: dict[str, Any] = {"input": texts}
status_code, body = _send_json_request(
"POST",
url,
headers=_build_headers(config.api_key, use_bearer=False, use_api_key=True),
payload=payload,
timeout_seconds=DEFAULT_EMBEDDING_TIMEOUT_SECONDS,
)
elif config.provider == "Ollama":
url = _ensure_path(_normalize_endpoint(config.endpoint), "api/embed")
payload = {"model": config.model, "input": texts}
status_code, body = _send_json_request(
"POST",
url,
headers={"Content-Type": "application/json", "Accept": "application/json"},
payload=payload,
timeout_seconds=DEFAULT_EMBEDDING_TIMEOUT_SECONDS,
)
else:
url = _ensure_path(_normalize_endpoint(config.endpoint), "embeddings")
payload = {"model": config.model, "input": texts}
status_code, body = _send_json_request(
"POST",
url,
headers=_build_headers(config.api_key, use_bearer=True),
payload=payload,
timeout_seconds=DEFAULT_EMBEDDING_TIMEOUT_SECONDS,
)
from http import HTTPStatus
if status_code >= HTTPStatus.BAD_REQUEST:
raise KnowledgeRagError(f"embedding 模型返回异常状态码 {status_code}")
return _extract_embedding_vectors(body, provider=config.provider)

View File

@@ -0,0 +1,177 @@
"""Few-shot 样本入库编排RiskObservation → FewShotSample → Qdrant。
只处理人工确认为 confirmed / false_positive 的观测,把它转成一条
:class:`FewShotSample`,持久化到 DB并同步向量到 Qdrant。
入库动作由 :meth:`RiskObservationService.create_feedback` 在 commit 后触发,
本服务全程吞异常(只记日志),保证反馈主流程不被 few-shot 链路拖崩。
"""
from __future__ import annotations
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.core.logging import get_logger
from app.models.few_shot_sample import FewShotSample
from app.models.risk_observation import RiskObservation, RiskObservationFeedback
from app.services.embedding_provider import EmbeddingProvider
from app.services.few_shot_store import FewShotStore
logger = get_logger("app.services.few_shot_ingestion")
# 仅这两个 feedback_status 视为已确认样本,会入库
CONFIRMED_LABELS = {"confirmed", "false_positive"}
# label → 自然语言结论(当 feedback.comment 缺失时兜底)
LABEL_CONCLUSION_FALLBACK = {
"confirmed": "经人工复核确认,该风险线索成立,需按规则拦截或补件。",
"false_positive": "经人工复核判定为误报,相似情形不应触发该风险规则。",
}
class FewShotIngestionService:
"""把已确认的风险观测沉淀为 few-shot 样本。"""
def __init__(self, db: Session) -> None:
self.db = db
def ingest_observation_feedback(
self,
observation: RiskObservation,
feedback: RiskObservationFeedback,
) -> FewShotSample | None:
"""人工确认/误报后调用,写入并同步向量。"""
label = observation.feedback_status
if label not in CONFIRMED_LABELS:
return None
sample_key = f"obs:{observation.id}"
sample = self.db.scalar(
select(FewShotSample).where(FewShotSample.sample_key == sample_key)
)
domain = self._extract_domain(observation)
case_text = self._build_case_text(observation)
conclusion_text = self._build_conclusion_text(observation, feedback, label)
payload = self._build_payload(observation, feedback, label)
if sample is None:
sample = FewShotSample(
sample_key=sample_key,
source_observation_id=observation.id,
scene="risk_rule_generation",
domain=domain,
risk_type=observation.risk_type or "",
risk_level=observation.risk_level or "",
label=label,
case_text=case_text,
conclusion_text=conclusion_text,
payload_json=payload,
status="active",
)
self.db.add(sample)
else:
sample.label = label
sample.domain = domain
sample.risk_type = observation.risk_type or ""
sample.risk_level = observation.risk_level or ""
sample.case_text = case_text
sample.conclusion_text = conclusion_text
sample.payload_json = payload
sample.status = "active"
sample.vector_id = sample.vector_id
try:
self.db.commit()
self.db.refresh(sample)
except Exception:
logger.exception("few-shot 样本持久化失败 observation_id=%s", observation.id)
self.db.rollback()
return None
vector_id = self._store().upsert(sample)
if vector_id:
sample.vector_id = vector_id
try:
self.db.commit()
except Exception:
logger.warning("few-shot vector_id 回写失败 sample_id=%s", sample.id)
return sample
def retract_observation(self, observation_id: str) -> bool:
"""观测被撤销时删掉对应样本及其向量。"""
sample = self.db.scalar(
select(FewShotSample).where(FewShotSample.source_observation_id == observation_id)
)
if sample is None:
return False
if sample.vector_id:
self._store().delete_by_vector_id(sample.vector_id)
try:
self.db.delete(sample)
self.db.commit()
return True
except Exception:
logger.exception("few-shot 样本删除失败 observation_id=%s", observation_id)
self.db.rollback()
return False
def _store(self) -> FewShotStore:
provider = EmbeddingProvider.from_settings(self.db)
return FewShotStore(provider)
def _extract_domain(self, observation: RiskObservation) -> str:
ontology = observation.ontology_json or {}
return str(ontology.get("domain") or "")
def _build_case_text(self, observation: RiskObservation) -> str:
parts = [
observation.title or "",
observation.description or "",
observation.risk_signal or "",
observation.risk_type or "",
]
ontology = observation.ontology_json or {}
scenario = ontology.get("scenario")
if scenario:
parts.append(f"场景:{scenario}")
risk_signals = ontology.get("risk_signals")
if isinstance(risk_signals, list) and risk_signals:
parts.append("信号:" + "|".join(str(s) for s in risk_signals))
return "\n".join(part for part in parts if part).strip()
def _build_conclusion_text(
self,
observation: RiskObservation,
feedback: RiskObservationFeedback,
label: str,
) -> str:
comment = (feedback.comment or "").strip()
if comment:
return f"[{label}] {comment}"
return LABEL_CONCLUSION_FALLBACK.get(label, label)
def _build_payload(
self,
observation: RiskObservation,
feedback: RiskObservationFeedback,
label: str,
) -> dict[str, Any]:
return {
"label": label,
"risk_type": observation.risk_type,
"risk_signal": observation.risk_signal,
"risk_level": observation.risk_level,
"feedback_type": feedback.feedback_type,
"feedback_comment": feedback.comment or "",
"feedback_actor": feedback.actor or "",
"ontology": observation.ontology_json or {},
"policy_refs": observation.policy_refs_json or [],
"evidence": observation.evidence_json or [],
"subject_label": observation.subject_label or "",
"claim_no": observation.claim_no or "",
}

View File

@@ -0,0 +1,122 @@
"""Few-shot 检索器:按当前 case 特征检索相似历史样本,拼成注入块。
从 :class:`FewShotStore` 取相似样本,转成可供 prompt 构造函数直接使用的结构。
带 token 预算裁剪和去重,确保不撑爆 prompt。
典型用法(在构造 prompt 之前调用)::
retriever = FewShotRetriever.from_session(session)
samples = retriever.retrieve_for_risk_rule_generation(
domain="travel", natural_language="票据城市与申报地不一致"
)
messages = build_risk_rule_compiler_messages(
...,
few_shot_samples=samples,
)
"""
from __future__ import annotations
from typing import Any
from sqlalchemy.orm import Session
from app.core.logging import get_logger
from app.services.embedding_provider import EmbeddingProvider
from app.services.few_shot_store import FewShotStore
logger = get_logger("app.services.few_shot_retrieval")
# 单条 few-shot 样本估算 token 数(用于预算裁剪)
SAMPLE_TOKEN_BUDGET = 1200
# 单条样本最大字符数,超长直接截断结论,避免撑爆 prompt
SINGLE_SAMPLE_MAX_CHARS = 400
# 历史样本最多注入条数(与原内置 examples 合并后总量受限)
MAX_HISTORICAL_SAMPLES = 3
class FewShotRetriever:
"""按 case 特征检索已确认样本,返回 prompt 可直接消费的结构。"""
def __init__(self, store: FewShotStore) -> None:
self._store = store
@classmethod
def from_session(cls, session: Session) -> "FewShotRetriever":
provider = EmbeddingProvider.from_settings(session)
return cls(FewShotStore(provider))
def retrieve_for_risk_rule_generation(
self,
*,
domain: str = "",
risk_type: str = "",
natural_language: str,
top_k: int = MAX_HISTORICAL_SAMPLES,
) -> list[dict[str, Any]]:
"""检索与当前规则需求相似的历史样本,返回注入块列表。"""
case_text = self._build_case_text(
natural_language=natural_language,
domain=domain,
risk_type=risk_type,
)
if not case_text:
return []
hits = self._store.search(
case_text,
scene="risk_rule_generation",
labels=["confirmed", "false_positive"],
top_k=top_k,
)
return self._hits_to_injection_blocks(hits)
def _build_case_text(
self,
*,
natural_language: str,
domain: str = "",
risk_type: str = "",
) -> str:
parts = [natural_language, domain, risk_type]
return "\n".join(p for p in parts if p).strip()
def _hits_to_injection_blocks(
self,
hits: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""把检索命中转成 prompt 可消费的块,做去重和预算裁剪。"""
blocks: list[dict[str, Any]] = []
seen_conclusions: set[str] = set()
budget = SAMPLE_TOKEN_BUDGET
for hit in hits:
conclusion = (hit.get("conclusion_text") or "").strip()
if not conclusion or conclusion in seen_conclusions:
continue
# 超长结论截断到上限,避免单条样本占用过多预算
if len(conclusion) > SINGLE_SAMPLE_MAX_CHARS:
conclusion = conclusion[:SINGLE_SAMPLE_MAX_CHARS]
payload = hit.get("payload_json") or {}
block = {
"source": "historical_confirmed",
"label": hit.get("label"),
"domain": hit.get("domain") or "",
"risk_type": hit.get("risk_type") or "",
"score": round(float(hit.get("score") or 0.0), 4),
"conclusion": conclusion,
"context": {
"risk_signal": payload.get("risk_signal") or "",
"risk_level": payload.get("risk_level") or "",
"ontology": payload.get("ontology") or {},
"feedback_comment": payload.get("feedback_comment") or "",
},
}
# 粗略 token 估算(按字符数 / 1.6 近似中文 token 比)
estimated_tokens = int(len(conclusion) / 1.6) + 40
if estimated_tokens > budget:
break
budget -= estimated_tokens
blocks.append(block)
seen_conclusions.add(conclusion)
return blocks

View File

@@ -0,0 +1,214 @@
"""Few-shot 样本的 Qdrant 向量存储。
独立于 LightRAG 的 Qdrant 客户端,使用专用 collection ``few_shot_samples``
与知识库 RAG 的 collection 隔离。所有操作失败都不抛异常(记日志返回空),
保证主链路不阻塞。
向量来自 :class:`EmbeddingProvider`payload 带业务过滤字段scene/label/domain/risk_type
检索时按这些字段过滤 + 向量相似度排序。
"""
from __future__ import annotations
import os
import uuid
from typing import Any
from app.core.logging import get_logger
from app.services.knowledge_rag import _resolve_default_qdrant_url
logger = get_logger("app.services.few_shot_store")
FEW_SHOT_COLLECTION = "few_shot_samples"
def _resolve_qdrant_config() -> tuple[str, str]:
"""复用 knowledge_rag 的 Qdrant URL/key 解析逻辑。"""
url = os.environ.get("QDRANT_URL", "").strip() or _resolve_default_qdrant_url()
api_key = os.environ.get("QDRANT_API_KEY", "").strip()
return url, api_key
class FewShotStore:
"""对 Qdrant 的轻量封装,专供 few-shot 样本检索使用。
设计要点:
- 惰性创建 client 和 collection首次操作时初始化。
- 所有公共方法吞异常(返回空/False主链路永远不被拖崩。
- 向量写入和检索都依赖外部传入的 :class:`EmbeddingProvider`
由调用方保证与配置一致。
"""
def __init__(self, embedding_provider: Any) -> None:
self._embedding_provider = embedding_provider
self._client: Any = None
self._ensured = False
def _client_or_none(self) -> Any:
"""惰性初始化 QdrantClient失败返回 None。"""
if self._client is not None:
return self._client
try:
from qdrant_client import QdrantClient
url, api_key = _resolve_qdrant_config()
self._client = QdrantClient(url=url, api_key=api_key or None)
except Exception:
logger.warning("few-shot QdrantClient 初始化失败,本轮操作跳过", exc_info=True)
self._client = None
return self._client
def _ensure_collection(self) -> bool:
"""确保 collection 存在,成功返回 True。"""
if self._ensured:
return True
client = self._client_or_none()
if client is None:
return False
try:
from qdrant_client.http.exceptions import UnexpectedResponse
try:
client.get_collection(FEW_SHOT_COLLECTION)
self._ensured = True
return True
except UnexpectedResponse as exc:
if exc.status_code != 404:
raise
# collection 不存在则创建
dim = self._embedding_provider.dimension()
from qdrant_client.http.models import (
Distance,
VectorParams,
PayloadSchemaType,
)
client.create_collection(
collection_name=FEW_SHOT_COLLECTION,
vectors_config=VectorParams(size=dim, distance=Distance.COSINE),
)
for field, field_type in [
("sample_id", PayloadSchemaType.KEYWORD),
("scene", PayloadSchemaType.KEYWORD),
("label", PayloadSchemaType.KEYWORD),
("domain", PayloadSchemaType.KEYWORD),
("risk_type", PayloadSchemaType.KEYWORD),
("status", PayloadSchemaType.KEYWORD),
]:
try:
client.create_payload_index(
collection_name=FEW_SHOT_COLLECTION,
field_name=field,
field_schema=field_type,
)
except Exception:
logger.debug("payload index 创建跳过 field=%s", field, exc_info=True)
self._ensured = True
logger.info("few-shot collection 创建成功 dim=%s", dim)
return True
except Exception:
logger.warning("few-shot collection 初始化失败,本轮操作跳过", exc_info=True)
return False
def upsert(self, sample: Any) -> str | None:
"""把一条样本向量化并写入 Qdrant返回 vector_id失败返回 None。"""
if not self._ensure_collection():
return None
client = self._client
try:
vector = self._embedding_provider.embed([sample.case_text])[0]
except Exception:
logger.warning("few-shot embedding 失败 sample_key=%s", getattr(sample, "sample_key", ""), exc_info=True)
return None
vector_id = uuid.uuid4().hex
payload = {
"sample_id": sample.id,
"scene": sample.scene,
"label": sample.label,
"domain": sample.domain,
"risk_type": sample.risk_type,
"risk_level": sample.risk_level,
"status": getattr(sample, "status", "active"),
"conclusion_text": sample.conclusion_text,
"payload_json": sample.payload_json,
}
try:
client.upsert(
collection_name=FEW_SHOT_COLLECTION,
points=[{"id": vector_id, "vector": vector, "payload": payload}],
)
return vector_id
except Exception:
logger.warning("few-shot upsert 失败 sample_key=%s", getattr(sample, "sample_key", ""), exc_info=True)
return None
def search(
self,
case_text: str,
*,
scene: str | None = None,
labels: list[str] | None = None,
top_k: int = 3,
) -> list[dict[str, Any]]:
"""按 case_text 检索相似样本,可按 scene/label 过滤。失败返回空列表。"""
if not case_text or not self._ensure_collection():
return []
client = self._client
try:
vector = self._embedding_provider.embed([case_text])[0]
except Exception:
logger.warning("few-shot 检索 embedding 失败", exc_info=True)
return []
must: list[dict[str, Any]] = [{"key": "status", "match": {"value": "active"}}]
if scene:
must.append({"key": "scene", "match": {"value": scene}})
if labels:
must.append({"key": "label", "match": {"any": labels}})
try:
from qdrant_client.http.models import Filter
results = client.query_points(
collection_name=FEW_SHOT_COLLECTION,
query=vector,
query_filter=Filter(must=must),
limit=top_k,
with_payload=True,
).points
except Exception:
logger.warning("few-shot 检索失败", exc_info=True)
return []
hits: list[dict[str, Any]] = []
for point in results:
payload = getattr(point, "payload", None) or {}
hits.append(
{
"sample_id": payload.get("sample_id"),
"score": float(getattr(point, "score", 0.0)),
"label": payload.get("label"),
"domain": payload.get("domain"),
"risk_type": payload.get("risk_type"),
"conclusion_text": payload.get("conclusion_text") or "",
"payload_json": payload.get("payload_json") or {},
}
)
return hits
def delete_by_vector_id(self, vector_id: str) -> bool:
"""按 vector_id 删除向量,失败返回 False。"""
if not vector_id or not self._ensure_collection():
return False
try:
self._client.delete(
collection_name=FEW_SHOT_COLLECTION,
points_selector=[vector_id],
)
return True
except Exception:
logger.warning("few-shot 删除失败 vector_id=%s", vector_id, exc_info=True)
return False

View File

@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import os
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
from decimal import Decimal from decimal import Decimal
from typing import Any from typing import Any
@@ -8,6 +9,7 @@ from sqlalchemy import func, select
from sqlalchemy.orm import Session, joinedload from sqlalchemy.orm import Session, joinedload
from app.algorithem.risk_graph import RiskHistoryStats, RiskObservationDraft from app.algorithem.risk_graph import RiskHistoryStats, RiskObservationDraft
from app.core.logging import get_logger
from app.db.base import Base from app.db.base import Base
from app.models.financial_record import ExpenseClaim from app.models.financial_record import ExpenseClaim
from app.models.risk_observation import RiskObservation, RiskObservationFeedback from app.models.risk_observation import RiskObservation, RiskObservationFeedback
@@ -17,6 +19,8 @@ from app.schemas.risk_observation import (
) )
from app.services.expense_claim_risk_stage import normalize_risk_business_stage from app.services.expense_claim_risk_stage import normalize_risk_business_stage
logger = get_logger("app.services.risk_observations")
HIGH_LEVELS = {"high", "critical"} HIGH_LEVELS = {"high", "critical"}
SEVERITY_SCORE = { SEVERITY_SCORE = {
"low": 32, "low": 32,
@@ -322,8 +326,27 @@ class RiskObservationService:
observation.status, observation.feedback_status = mapped observation.status, observation.feedback_status = mapped
self.db.commit() self.db.commit()
self.db.refresh(feedback) self.db.refresh(feedback)
self._maybe_ingest_few_shot(observation, feedback)
return feedback return feedback
def _maybe_ingest_few_shot(
self,
observation: RiskObservation,
feedback: RiskObservationFeedback,
) -> None:
"""人工确认/误报后把样本沉淀进 few-shot 池,任何失败都不影响主流程。"""
if os.environ.get("FEW_SHOT_INJECTION_ENABLED", "true").strip().lower() in {"0", "false", "no"}:
return
if observation.feedback_status not in {"confirmed", "false_positive"}:
return
try:
from app.services.few_shot_ingestion import FewShotIngestionService
FewShotIngestionService(self.db).ingest_observation_feedback(observation, feedback)
except Exception:
logger.exception("few-shot ingestion failed for observation %s", observation.id)
def summarize_dashboard( def summarize_dashboard(
self, self,
*, *,

View File

@@ -234,6 +234,10 @@ class RiskRuleGenerationService:
} }
for item in fields for item in fields
] ]
few_shot_samples = self._retrieve_few_shot_samples(
domain=domain,
natural_language=natural_language,
)
messages = build_risk_rule_compiler_messages( messages = build_risk_rule_compiler_messages(
domain=domain, domain=domain,
domain_label=BUSINESS_DOMAIN_LABELS[domain], domain_label=BUSINESS_DOMAIN_LABELS[domain],
@@ -243,6 +247,7 @@ class RiskRuleGenerationService:
expense_category_label=expense_category_label, expense_category_label=expense_category_label,
natural_language=natural_language, natural_language=natural_language,
available_fields=field_payload, available_fields=field_payload,
few_shot_samples=few_shot_samples,
) )
answer = self.runtime_chat_service.complete( answer = self.runtime_chat_service.complete(
messages, messages,
@@ -263,6 +268,29 @@ class RiskRuleGenerationService:
payload = unwrap_semantic_plan_payload(payload) payload = unwrap_semantic_plan_payload(payload)
return self._sanitize_model_draft(payload, fields=fields) return self._sanitize_model_draft(payload, fields=fields)
def _retrieve_few_shot_samples(
self,
*,
domain: str,
natural_language: str,
) -> list[dict[str, Any]]:
"""检索已确认历史样本,失败降级为空列表。"""
import os
if os.environ.get("FEW_SHOT_INJECTION_ENABLED", "true").strip().lower() in {"0", "false", "no"}:
return []
try:
from app.services.few_shot_retrieval import FewShotRetriever
retriever = FewShotRetriever.from_session(self.db)
return retriever.retrieve_for_risk_rule_generation(
domain=domain,
natural_language=natural_language,
)
except Exception:
return []
def _sanitize_model_draft( def _sanitize_model_draft(
self, self,
payload: dict[str, Any], payload: dict[str, Any],

View File

@@ -14,10 +14,15 @@ def build_risk_rule_compiler_messages(
expense_category_label: str, expense_category_label: str,
natural_language: str, natural_language: str,
available_fields: list[dict[str, Any]], available_fields: list[dict[str, Any]],
few_shot_samples: list[dict[str, Any]] | None = None,
) -> list[dict[str, str]]: ) -> list[dict[str, str]]:
"""构造自然语言规则编译提示词。 """构造自然语言规则编译提示词。
大模型只负责把业务语言拆成“语义计划”,后端会校验字段、操作符和模板。 大模型只负责把业务语言拆成“语义计划”,后端会校验字段、操作符和模板。
``few_shot_samples`` 是从已确认历史样本中检索出来的相似案例,会被合并进
``examples`` 字段并标注 ``source: "historical_confirmed"``,让编译器参考
过往人工结论。传 ``None`` 或空列表时行为与历史完全一致(向后兼容)。
""" """
schema = { schema = {
@@ -161,6 +166,20 @@ def build_risk_rule_compiler_messages(
}, },
} }
] ]
historical_examples: list[dict[str, Any]] = []
if few_shot_samples:
for sample in few_shot_samples:
historical_examples.append(
{
"source": "historical_confirmed",
"label": sample.get("label"),
"domain": sample.get("domain") or "",
"risk_type": sample.get("risk_type") or "",
"conclusion": sample.get("conclusion") or "",
"context": sample.get("context") or {},
}
)
merged_examples = historical_examples + examples
return [ return [
{ {
"role": "system", "role": "system",
@@ -186,7 +205,7 @@ def build_risk_rule_compiler_messages(
"natural_language": natural_language, "natural_language": natural_language,
"available_fields": available_fields, "available_fields": available_fields,
"required_json_shape": response_schema, "required_json_shape": response_schema,
"examples": examples, "examples": merged_examples,
}, },
ensure_ascii=False, ensure_ascii=False,
), ),

View File

@@ -1,329 +0,0 @@
"""风险规则 golden set 评测器与发布门禁。
在版本化的黄金用例集(:class:`GoldenCase`)上跑规则 manifest计算
accuracy/precision/recall并按"100% 通过"的硬阈值做发布门禁。
执行链路完全复用现有能力:
- ``RiskRuleTemplateExecutor.evaluate_with_trace`` 跑规则
- ``AgentAssetRiskRuleTestingMixin`` 的 static helpers 组装 synthetic claim
- 单条比对逻辑与 ``_run_sample_case`` 保持一致
门禁语义与现有 ``test_passed`` 一致:未通过抛 ``PermissionError``
同时写一条 ``AgentAssetTestRun(test_type='golden')`` 记录结果。
"""
from __future__ import annotations
import os
import uuid
from dataclasses import dataclass, field
from datetime import UTC, date, datetime
from decimal import Decimal, InvalidOperation
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.core.agent_enums import AgentAssetType
from app.core.logging import get_logger
from app.models.agent_asset import AgentAsset, AgentAssetTestRun
from app.models.employee import Employee
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
from app.models.golden_case import GoldenCase
from app.services.risk_rule_template_executor import RiskRuleTemplateExecutor
logger = get_logger("app.services.risk_rule_golden_evaluator")
GOLDEN_GATE_FLAG = "GOLDEN_SET_GATE_ENABLED"
@dataclass
class GoldenCaseResult:
case_id: str
name: str
expected_hit: bool
actual_hit: bool
expected_severity: str
actual_severity: str
passed: bool
message: str = ""
evidence: dict[str, Any] = field(default_factory=dict)
trace: dict[str, Any] = field(default_factory=dict)
@dataclass
class GoldenEvalReport:
total: int = 0
passed_count: int = 0
failed_count: int = 0
accuracy: float = 0.0
precision: float = 0.0
recall: float = 0.0
all_passed: bool = True
results: list[GoldenCaseResult] = field(default_factory=list)
def to_dict(self) -> dict[str, Any]:
return {
"total": self.total,
"passed_count": self.passed_count,
"failed_count": self.failed_count,
"accuracy": round(self.accuracy, 4),
"precision": round(self.precision, 4),
"recall": round(self.recall, 4),
"all_passed": self.all_passed,
"results": [
{
"case_id": r.case_id,
"name": r.name,
"expected_hit": r.expected_hit,
"actual_hit": r.actual_hit,
"expected_severity": r.expected_severity,
"actual_severity": r.actual_severity,
"passed": r.passed,
"message": r.message,
}
for r in self.results
],
}
def _gate_enabled() -> bool:
return os.environ.get(GOLDEN_GATE_FLAG, "true").strip().lower() not in {"0", "false", "no"}
# ---- synthetic claim 构建(与 AgentAssetRiskRuleTestingMixin._build_synthetic_claim 一致)----
def _extract_manifest_fields(manifest: dict[str, Any]) -> list[dict[str, str]]:
inputs = manifest.get("inputs") if isinstance(manifest.get("inputs"), dict) else {}
fields = inputs.get("fields") if isinstance(inputs.get("fields"), list) else []
normalized: list[dict[str, str]] = []
for item in fields:
if not isinstance(item, dict):
continue
key = str(item.get("key") or "").strip()
if key:
normalized.append({"key": key, "label": str(item.get("label") or key).strip()})
return normalized
def _coerce_sample_value(field_key: str, value: Any) -> Any:
import re
if field_key.endswith("route_cities") and isinstance(value, str):
return [item.strip() for item in re.split(r"[,,、/ ]+", value) if item.strip()]
return value
def _to_decimal(value: Any) -> Decimal:
try:
return Decimal(str(value or "0"))
except (InvalidOperation, ValueError):
return Decimal("0")
def _build_synthetic_claim(
values: dict[str, Any],
manifest: dict[str, Any],
) -> tuple[ExpenseClaim, list[dict[str, Any]]]:
claim = ExpenseClaim(
claim_no="GOLDEN-RISK-RULE",
employee_name=str(values.get("claim.employee_name") or "测试员工"),
department_name=str(values.get("claim.department_name") or "测试部门"),
expense_type=str(values.get("item.item_type") or "差旅费"),
reason=str(values.get("claim.reason") or "测试报销事由"),
location=str(values.get("claim.location") or "北京"),
amount=_to_decimal(values.get("claim.amount")),
currency="CNY",
invoice_count=1,
occurred_at=datetime.now(UTC),
status="draft",
)
item = ExpenseClaimItem(
item_date=date.today(),
item_type=str(values.get("item.item_type") or "住宿费"),
item_reason=str(values.get("item.item_reason") or claim.reason),
item_location=str(values.get("item.item_location") or claim.location),
item_amount=_to_decimal(values.get("item.item_amount") or claim.amount),
)
claim.items = [item]
if values.get("employee.location"):
claim.employee = Employee(
employee_no="GOLDEN-EMPLOYEE",
name=claim.employee_name,
email="golden-rule-test@example.com",
location=str(values.get("employee.location") or ""),
)
attachment_fields: list[dict[str, Any]] = []
document_info: dict[str, Any] = {"fields": attachment_fields}
for field in _extract_manifest_fields(manifest):
key = field["key"]
if key not in values:
continue
value = _coerce_sample_value(key, values.get(key))
if key.startswith("claim."):
setattr(claim, key.removeprefix("claim."), value)
elif key.startswith("item."):
setattr(item, key.removeprefix("item."), value)
elif key.startswith("attachment."):
short_key = key.removeprefix("attachment.")
document_info[short_key] = value
attachment_fields.append({"key": short_key, "label": field["label"], "value": value})
return claim, [{"document_info": document_info, "ocr_text": document_info.get("ocr_text", "")}]
def _run_single_case(
manifest: dict[str, Any],
values: dict[str, Any],
expected_hit: bool,
expected_severity: str,
) -> GoldenCaseResult:
claim, contexts = _build_synthetic_claim(values, manifest)
execution = RiskRuleTemplateExecutor().evaluate_with_trace(manifest, claim=claim, contexts=contexts)
result = execution["result"]
actual_hit = result is not None
actual_severity = (
str((manifest.get("outcomes") or {}).get("fail", {}).get("severity") or "").strip()
if actual_hit
else "none"
)
severity_passed = (
not actual_hit or not expected_severity or expected_severity == actual_severity
)
passed = actual_hit == expected_hit and severity_passed
return GoldenCaseResult(
case_id="",
name="",
expected_hit=expected_hit,
actual_hit=actual_hit,
expected_severity=expected_severity,
actual_severity=actual_severity,
passed=passed,
message=str(result.get("message") or "") if isinstance(result, dict) else "",
evidence=result.get("evidence") if isinstance(result, dict) else {},
trace=execution.get("trace") if isinstance(execution.get("trace"), dict) else {},
)
def _aggregate(results: list[GoldenCaseResult]) -> GoldenEvalReport:
total = len(results)
if total == 0:
return GoldenEvalReport(total=0, all_passed=True)
passed_count = sum(1 for r in results if r.passed)
tp = sum(1 for r in results if r.expected_hit and r.actual_hit)
fp = sum(1 for r in results if r.expected_hit and not r.actual_hit) # 应命中未命中
fn = sum(1 for r in results if not r.expected_hit and r.actual_hit) # 不应命中却命中
precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
return GoldenEvalReport(
total=total,
passed_count=passed_count,
failed_count=total - passed_count,
accuracy=passed_count / total,
precision=precision,
recall=recall,
all_passed=passed_count == total,
results=results,
)
class RiskRuleGoldenEvaluator:
"""在 golden set 上评测规则 manifest 并执行发布门禁。"""
def evaluate(self, manifest: dict[str, Any], cases: list[GoldenCase]) -> GoldenEvalReport:
results: list[GoldenCaseResult] = []
for case in cases:
result = _run_single_case(
manifest,
values=case.values_json or {},
expected_hit=bool(case.expected_hit),
expected_severity=str(case.expected_severity or ""),
)
result.case_id = case.case_key or case.id
result.name = case.name
results.append(result)
return _aggregate(results)
def evaluate_for_rule(
self,
db: Session,
manifest: dict[str, Any],
rule_code: str,
) -> GoldenEvalReport:
cases = list(
db.scalars(
select(GoldenCase).where(
GoldenCase.rule_code == rule_code,
GoldenCase.status == "active",
)
)
)
if not cases:
return GoldenEvalReport(total=0, all_passed=True)
return self.evaluate(manifest, cases)
def require_pass(
self,
db: Session,
asset: AgentAsset,
version: str,
manifest: dict[str, Any],
rule_code: str,
*,
actor: str,
) -> GoldenEvalReport:
"""发布门禁入口:跑 golden set未 100% 通过抛 PermissionError。
golden set 为空或门禁关闭时放行; evaluator 异常时降级放行(记日志)。
无论放行与否,都写一条 ``AgentAssetTestRun(test_type='golden')`` 记录。
"""
if not _gate_enabled():
return GoldenEvalReport(total=0, all_passed=True)
try:
report = self.evaluate_for_rule(db, manifest, rule_code)
except Exception:
logger.exception("golden set 评测异常,降级放行 asset_id=%s", asset.id)
report = GoldenEvalReport(total=0, all_passed=True)
self._record_test_run(db, asset, version, report, actor=actor)
if report.total > 0 and not report.all_passed:
failures = report.to_dict()["results"]
raise PermissionError(
f"golden set 回归未通过({report.passed_count}/{report.total}"
f"发布被拦截。失败用例:{failures}"
)
return report
def _record_test_run(
self,
db: Session,
asset: AgentAsset,
version: str,
report: GoldenEvalReport,
*,
actor: str,
) -> None:
try:
run = AgentAssetTestRun(
id=str(uuid.uuid4()),
asset_id=asset.id,
version=version,
test_type="golden",
status="completed",
passed=report.all_passed,
summary=(
f"golden set {report.passed_count}/{report.total} passed"
if report.total > 0
else "golden set empty, gate skipped"
),
input_json={"rule_code": getattr(asset, "rule_code", "") or ""},
result_json=report.to_dict(),
created_by=actor,
)
db.add(run)
db.commit()
except Exception:
logger.warning("golden test run 记录失败 asset_id=%s", asset.id, exc_info=True)
db.rollback()

View File

@@ -0,0 +1,104 @@
from __future__ import annotations
from unittest.mock import patch
import pytest
from app.services.embedding_provider import EmbeddingProvider, _runtime_model_config_from_dict
from app.services.knowledge_rag_runtime import KnowledgeRagError, RuntimeModelConfig
def _config(provider: str = "GLM") -> RuntimeModelConfig:
return RuntimeModelConfig(
slot="embedding",
provider=provider,
model="Embedding-3",
endpoint="https://open.bigmodel.cn/api/paas/v4/",
api_key="k",
capability="embedding",
)
def test_runtime_model_config_from_dict_maps_fields() -> None:
cfg = _runtime_model_config_from_dict(
{
"slot": "embedding",
"provider": "GLM",
"model": "Embedding-3",
"endpoint": "https://e",
"apiKey": "secret",
"capability": "embedding",
}
)
assert cfg.api_key == "secret"
assert cfg.model == "Embedding-3"
def test_embed_empty_texts_returns_empty() -> None:
provider = EmbeddingProvider(_config())
assert provider.embed([]) == []
def test_embed_returns_vectors_and_caches_dimension() -> None:
provider = EmbeddingProvider(_config())
with patch(
"app.services.embedding_provider._request_embeddings_public",
return_value=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
) as mock_req:
vectors = provider.embed(["a", "b"])
assert vectors == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
assert provider.dimension() == 3
calls_after_first_dimension = mock_req.call_count
# 第二次 dimension 不应再次请求
assert provider.dimension() == 3
assert mock_req.call_count == calls_after_first_dimension
def test_dimension_raises_on_invalid_vectors() -> None:
provider = EmbeddingProvider(_config())
with patch(
"app.services.embedding_provider._request_embeddings_public",
return_value=[],
):
with pytest.raises(KnowledgeRagError):
provider.dimension()
def test_request_embeddings_public_glm_branch() -> None:
cfg = _config("GLM")
with patch(
"app.services.embedding_provider._send_json_request",
return_value=(200, {"data": [{"embedding": [0.1, 0.2]}]}),
) as mock_send:
from app.services.embedding_provider import _request_embeddings_public
vectors = _request_embeddings_public(cfg, ["x"])
assert vectors == [[0.1, 0.2]]
called_url = mock_send.call_args.args[1]
assert called_url.endswith("/embeddings")
def test_request_embeddings_public_ollama_branch() -> None:
cfg = _config("Ollama")
with patch(
"app.services.embedding_provider._send_json_request",
return_value=(200, {"embeddings": [[0.5, 0.6]]}),
) as mock_send:
from app.services.embedding_provider import _request_embeddings_public
vectors = _request_embeddings_public(cfg, ["x"])
assert vectors == [[0.5, 0.6]]
called_url = mock_send.call_args.args[1]
assert called_url.endswith("/api/embed")
def test_request_embeddings_public_raises_on_http_error() -> None:
cfg = _config("GLM")
with patch(
"app.services.embedding_provider._send_json_request",
return_value=(500, {"message": "boom"}),
):
from app.services.embedding_provider import _request_embeddings_public
with pytest.raises(KnowledgeRagError):
_request_embeddings_public(cfg, ["x"])

View File

@@ -0,0 +1,214 @@
from __future__ import annotations
from collections.abc import Generator
from datetime import datetime
from decimal import Decimal
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from app.db.base import Base
from app.models.employee import Employee
from app.models.few_shot_sample import FewShotSample
from app.models.financial_record import ExpenseClaim
from app.models.risk_observation import RiskObservation
from app.schemas.risk_observation import RiskObservationFeedbackCreate
from app.services.few_shot_ingestion import FewShotIngestionService
from app.services.risk_observations import RiskObservationService
def _build_session() -> Session:
engine = create_engine(
"sqlite+pysqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
Base.metadata.create_all(bind=engine)
factory = sessionmaker(bind=engine, autoflush=False, autocommit=False)
return factory()
def _observation(db: Session, key: str = "risk:c1:dup") -> RiskObservation:
db.add(Employee(id="emp-1", employee_no="E1", name="员工", email="e@e.com", grade="P6"))
db.add(
ExpenseClaim(
id="c1",
claim_no="BX-001",
employee_id="emp-1",
employee_name="员工",
department_name="风控部",
expense_type="travel",
reason="客户拜访",
location="上海",
amount=Decimal("1000"),
currency="CNY",
occurred_at=datetime(2026, 1, 1),
submitted_at=datetime(2026, 1, 1),
status="submitted",
approval_stage="manager_review",
risk_flags_json=[],
)
)
db.flush()
obs = RiskObservation(
observation_key=key,
subject_type="expense_claim",
subject_key="claim:c1",
claim_id="c1",
claim_no="BX-001",
risk_type="duplicate_invoice",
risk_signal="duplicate_invoice",
title="重复发票",
description="同一发票出现在多张报销单",
risk_score=86,
risk_level="high",
confidence_score=0.8,
source="financial_risk_graph",
algorithm_version="v1",
ontology_json={"domain": "expense", "scenario": "reimbursement"},
)
db.add(obs)
db.commit()
db.refresh(obs)
return obs
def test_ingest_confirmed_persists_sample_and_calls_store() -> None:
with _build_session() as db:
obs = _observation(db)
obs.feedback_status = "confirmed"
service = FewShotIngestionService(db)
fake_store = MagicMock()
fake_store.upsert.return_value = "vec-1"
with patch.object(service, "_store", return_value=fake_store):
sample = service.ingest_observation_feedback(
obs,
MagicMock(feedback_type="confirm", comment="确认重复发票", actor="audit"),
)
assert sample is not None
assert sample.label == "confirmed"
assert sample.sample_key == f"obs:{obs.id}"
assert "重复发票" in sample.case_text
assert "确认重复发票" in sample.conclusion_text
assert sample.vector_id == "vec-1"
fake_store.upsert.assert_called_once()
def test_ingest_false_positive_also_persisted() -> None:
with _build_session() as db:
obs = _observation(db, key="risk:c2:fp")
obs.feedback_status = "false_positive"
db.commit()
service = FewShotIngestionService(db)
with patch.object(service, "_store", return_value=MagicMock(upsert=MagicMock(return_value=None))):
sample = service.ingest_observation_feedback(
obs,
MagicMock(feedback_type="false_positive", comment="", actor="audit"),
)
assert sample is not None
assert sample.label == "false_positive"
assert "误报" in sample.conclusion_text
def test_ingest_ignored_label_returns_none() -> None:
with _build_session() as db:
obs = _observation(db)
obs.feedback_status = "ignored"
service = FewShotIngestionService(db)
assert service.ingest_observation_feedback(obs, MagicMock()) is None
def test_ingest_is_idempotent_on_duplicate_sample_key() -> None:
with _build_session() as db:
obs = _observation(db)
service = FewShotIngestionService(db)
store = MagicMock()
store.upsert.side_effect = ["vec-1", "vec-2"]
with patch.object(service, "_store", return_value=store):
obs.feedback_status = "confirmed"
first = service.ingest_observation_feedback(
obs, MagicMock(feedback_type="confirm", comment="第一次", actor="a")
)
# 模拟后续被改判为误报
obs.feedback_status = "false_positive"
second = service.ingest_observation_feedback(
obs, MagicMock(feedback_type="false_positive", comment="改判", actor="a")
)
assert first is not None and second is not None
assert first.id == second.id # 同一行更新
from sqlalchemy import select
count = db.scalar(select(FewShotSample).where(FewShotSample.sample_key == f"obs:{obs.id}"))
assert count is not None
assert second.label == "false_positive"
def test_create_feedback_hook_triggers_ingestion() -> None:
with _build_session() as db:
service = RiskObservationService(db)
obs = _observation(db)
ingest_calls: list = []
def _spy_ingest(o, f):
ingest_calls.append((o.id, f.feedback_type))
return None
with patch(
"app.services.few_shot_ingestion.FewShotIngestionService.ingest_observation_feedback",
side_effect=_spy_ingest,
):
service.create_feedback(
obs.observation_key,
RiskObservationFeedbackCreate(feedback_type="confirm", actor="audit"),
)
assert len(ingest_calls) == 1
assert ingest_calls[0][1] == "confirm"
def test_create_feedback_hook_skipped_for_comment_feedback() -> None:
with _build_session() as db:
service = RiskObservationService(db)
obs = _observation(db)
with patch(
"app.services.few_shot_ingestion.FewShotIngestionService.ingest_observation_feedback"
) as mock_ingest:
service.create_feedback(
obs.observation_key,
RiskObservationFeedbackCreate(feedback_type="comment", action="note", actor="audit"),
)
mock_ingest.assert_not_called()
def test_create_feedback_hook_swallows_ingestion_failure() -> None:
with _build_session() as db:
service = RiskObservationService(db)
obs = _observation(db)
with patch(
"app.services.few_shot_ingestion.FewShotIngestionService.ingest_observation_feedback",
side_effect=RuntimeError("boom"),
):
# 不应抛异常
feedback = service.create_feedback(
obs.observation_key,
RiskObservationFeedbackCreate(feedback_type="confirm", actor="audit"),
)
assert feedback.feedback_type == "confirm"
def test_create_feedback_hook_respects_feature_flag(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("FEW_SHOT_INJECTION_ENABLED", "false")
with _build_session() as db:
service = RiskObservationService(db)
obs = _observation(db)
with patch(
"app.services.few_shot_ingestion.FewShotIngestionService.ingest_observation_feedback"
) as mock_ingest:
service.create_feedback(
obs.observation_key,
RiskObservationFeedbackCreate(feedback_type="confirm", actor="audit"),
)
mock_ingest.assert_not_called()

View File

@@ -0,0 +1,119 @@
from __future__ import annotations
import json
from unittest.mock import MagicMock, patch
import pytest
from app.services.few_shot_retrieval import FewShotRetriever
from app.services.few_shot_store import FewShotStore
from app.services.risk_rule_generation_prompt import build_risk_rule_compiler_messages
def _hit(score: float, label: str, conclusion: str, risk_type: str = "duplicate_invoice") -> dict:
return {
"sample_id": "s1",
"score": score,
"label": label,
"domain": "expense",
"risk_type": risk_type,
"conclusion_text": conclusion,
"payload_json": {
"risk_signal": risk_type,
"risk_level": "high",
"ontology": {"scenario": "reimbursement"},
"feedback_comment": "",
},
}
def test_retrieve_returns_injection_blocks_with_token_budget() -> None:
store = MagicMock(spec=FewShotStore)
store.search.return_value = [
_hit(0.9, "confirmed", "确认重复发票需拦截"),
_hit(0.8, "false_positive", "此情形属于正常拆单不拦截"),
_hit(0.7, "confirmed", "确认重复发票需拦截"), # 重复结论应被去重
]
retriever = FewShotRetriever(store)
blocks = retriever.retrieve_for_risk_rule_generation(
domain="expense", natural_language="同一发票重复报销"
)
assert len(blocks) == 2
assert blocks[0]["score"] == 0.9
assert blocks[0]["label"] == "confirmed"
assert blocks[0]["source"] == "historical_confirmed"
assert blocks[1]["label"] == "false_positive"
# 去重:第三条结论与第一条相同,应被过滤
conclusions = [b["conclusion"] for b in blocks]
assert len(set(conclusions)) == len(conclusions)
def test_retrieve_empty_case_text_returns_empty() -> None:
store = MagicMock(spec=FewShotStore)
retriever = FewShotRetriever(store)
assert retriever.retrieve_for_risk_rule_generation(natural_language="") == []
store.search.assert_not_called()
def test_retrieve_truncates_overlong_conclusion() -> None:
store = MagicMock(spec=FewShotStore)
long_text = "长结论" * 500
store.search.return_value = [
_hit(0.9, "confirmed", long_text),
]
retriever = FewShotRetriever(store)
blocks = retriever.retrieve_for_risk_rule_generation(natural_language="x")
assert len(blocks) == 1
# 超长结论应被截断到单条上限
from app.services.few_shot_retrieval import SINGLE_SAMPLE_MAX_CHARS
assert len(blocks[0]["conclusion"]) <= SINGLE_SAMPLE_MAX_CHARS
def test_build_prompt_merges_few_shot_into_examples() -> None:
samples = [
{
"source": "historical_confirmed",
"label": "confirmed",
"domain": "expense",
"risk_type": "duplicate_invoice",
"conclusion": "确认重复发票",
"context": {"risk_signal": "duplicate_invoice"},
}
]
messages = build_risk_rule_compiler_messages(
domain="expense",
domain_label="报销",
business_stage="reimbursement",
business_stage_label="报销",
expense_category=None,
expense_category_label="",
natural_language="重复发票规则",
available_fields=[{"key": "attachment.invoice_no", "label": "发票号", "type": "string", "source": "attachment"}],
few_shot_samples=samples,
)
assert len(messages) == 2
payload = json.loads(messages[1]["content"])
examples = payload["examples"]
# 前两条是历史样本,后面是内置 examples
assert examples[0]["source"] == "historical_confirmed"
assert examples[0]["conclusion"] == "确认重复发票"
# 内置 example 仍存在(无 source 字段)
assert any("user_rule" in ex for ex in examples)
def test_build_prompt_without_few_shot_is_backward_compatible() -> None:
messages = build_risk_rule_compiler_messages(
domain="expense",
domain_label="报销",
business_stage="reimbursement",
business_stage_label="报销",
expense_category=None,
expense_category_label="",
natural_language="重复发票规则",
available_fields=[],
)
payload = json.loads(messages[1]["content"])
examples = payload["examples"]
# 无 few_shot_samples 时 examples 里不应有 historical_confirmed 来源
assert all(ex.get("source") != "historical_confirmed" for ex in examples)

View File

@@ -1,262 +0,0 @@
from __future__ import annotations
from collections.abc import Generator
from datetime import datetime
from decimal import Decimal
from unittest.mock import MagicMock, patch
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from app.db.base import Base
from app.models.agent_asset import AgentAsset, AgentAssetTestRun
from app.models.employee import Employee
from app.models.financial_record import ExpenseClaim
from app.models.golden_case import GoldenCase
from app.services.risk_rule_golden_evaluator import (
GoldenEvalReport,
RiskRuleGoldenEvaluator,
_aggregate,
_run_single_case,
)
def _build_session() -> Session:
engine = create_engine(
"sqlite+pysqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
Base.metadata.create_all(bind=engine)
factory = sessionmaker(bind=engine, autoflush=False, autocommit=False)
return factory()
def _keyword_manifest() -> dict:
"""一个简单的 keyword_match_v1 manifestreason 含"虚假"则命中。"""
return {
"rule_code": "risk.test.keyword",
"template_key": "keyword_match_v1",
"inputs": {
"fields": [
{"key": "claim.reason", "label": "事由", "type": "text", "source": "claim"},
]
},
"params": {
"keywords": ["虚假"],
"field_keys": ["claim.reason"],
"search_fields": ["claim.reason"],
},
"outcomes": {"fail": {"severity": "high", "risk_score": 80}},
}
def _golden_case(
case_key: str,
*,
reason: str,
expected_hit: bool,
rule_code: str = "risk.test.keyword",
) -> GoldenCase:
return GoldenCase(
case_key=case_key,
rule_code=rule_code,
name=f"case-{case_key}",
values_json={"claim.reason": reason},
expected_hit=expected_hit,
status="active",
)
def test_run_single_case_hit_matches() -> None:
result = _run_single_case(
_keyword_manifest(),
values={"claim.reason": "虚假发票报销"},
expected_hit=True,
expected_severity="high",
)
assert result.actual_hit is True
assert result.passed is True
assert result.actual_severity == "high"
def test_run_single_case_no_hit_matches() -> None:
result = _run_single_case(
_keyword_manifest(),
values={"claim.reason": "正常差旅报销"},
expected_hit=False,
expected_severity="",
)
assert result.actual_hit is False
assert result.passed is True
def test_run_single_case_mismatch_fails() -> None:
result = _run_single_case(
_keyword_manifest(),
values={"claim.reason": "虚假发票"},
expected_hit=False, # 期望不命中,但实际命中
expected_severity="",
)
assert result.actual_hit is True
assert result.passed is False
def test_run_single_case_severity_mismatch_fails() -> None:
result = _run_single_case(
_keyword_manifest(),
values={"claim.reason": "虚假发票"},
expected_hit=True,
expected_severity="critical", # 实际是 high
)
assert result.passed is False
def test_aggregate_empty_returns_passed() -> None:
report = _aggregate([])
assert report.total == 0
assert report.all_passed is True
assert report.accuracy == 0.0
def test_aggregate_all_passed() -> None:
from app.services.risk_rule_golden_evaluator import GoldenCaseResult
results = [
GoldenCaseResult("1", "a", True, True, "high", "high", True),
GoldenCaseResult("2", "b", False, False, "", "none", True),
]
report = _aggregate(results)
assert report.total == 2
assert report.passed_count == 2
assert report.accuracy == 1.0
assert report.all_passed is True
def test_aggregate_with_failure() -> None:
from app.services.risk_rule_golden_evaluator import GoldenCaseResult
results = [
GoldenCaseResult("1", "a", True, True, "high", "high", True),
GoldenCaseResult("2", "b", True, False, "high", "none", False), # FP
]
report = _aggregate(results)
assert report.passed_count == 1
assert report.failed_count == 1
assert report.accuracy == 0.5
assert report.all_passed is False
assert report.precision == 0.5 # 1/(1+1)
def test_evaluate_for_rule_empty_returns_passed() -> None:
with _build_session() as db:
report = RiskRuleGoldenEvaluator().evaluate_for_rule(db, _keyword_manifest(), "risk.test.keyword")
assert report.total == 0
assert report.all_passed is True
def test_evaluate_for_rule_all_pass() -> None:
with _build_session() as db:
db.add(_golden_case("g1", reason="虚假发票", expected_hit=True))
db.add(_golden_case("g2", reason="正常报销", expected_hit=False))
db.commit()
report = RiskRuleGoldenEvaluator().evaluate_for_rule(db, _keyword_manifest(), "risk.test.keyword")
assert report.total == 2
assert report.all_passed is True
assert report.accuracy == 1.0
def test_evaluate_for_rule_with_failure() -> None:
with _build_session() as db:
db.add(_golden_case("g1", reason="虚假发票", expected_hit=False)) # 期望不命中但实际命中
db.add(_golden_case("g2", reason="正常报销", expected_hit=True)) # 期望命中但实际不命中
db.commit()
report = RiskRuleGoldenEvaluator().evaluate_for_rule(db, _keyword_manifest(), "risk.test.keyword")
assert report.total == 2
assert report.all_passed is False
assert report.failed_count == 2
def _asset(asset_id: str, code: str) -> AgentAsset:
return AgentAsset(
id=asset_id,
code=code,
name=code,
asset_type="rule",
domain="expense",
owner="tester",
status="review",
working_version="v1",
)
def test_require_pass_passes_when_all_green() -> None:
with _build_session() as db:
asset = _asset("a1", "R1")
db.add(asset)
db.add(_golden_case("g1", reason="虚假", expected_hit=True))
db.commit()
report = RiskRuleGoldenEvaluator().require_pass(
db, asset, "v1", _keyword_manifest(), "risk.test.keyword", actor="tester"
)
assert report.all_passed is True
# 应写一条 test_type='golden' 记录
run = db.query(AgentAssetTestRun).filter_by(asset_id="a1", test_type="golden").one()
assert run.passed is True
def test_require_pass_raises_on_failure() -> None:
with _build_session() as db:
asset = _asset("a2", "R2")
db.add(asset)
db.add(_golden_case("g1", reason="虚假", expected_hit=False)) # 会失败
db.commit()
with pytest.raises(PermissionError):
RiskRuleGoldenEvaluator().require_pass(
db, asset, "v1", _keyword_manifest(), "risk.test.keyword", actor="tester"
)
run = db.query(AgentAssetTestRun).filter_by(asset_id="a2", test_type="golden").one()
assert run.passed is False
def test_require_pass_empty_golden_set_passes() -> None:
with _build_session() as db:
asset = _asset("a3", "R3")
db.add(asset)
db.commit()
report = RiskRuleGoldenEvaluator().require_pass(
db, asset, "v1", _keyword_manifest(), "risk.test.keyword", actor="tester"
)
assert report.total == 0
assert report.all_passed is True
def test_require_pass_respects_feature_flag(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("GOLDEN_SET_GATE_ENABLED", "false")
with _build_session() as db:
asset = _asset("a4", "R4")
db.add(asset)
db.add(_golden_case("g1", reason="虚假", expected_hit=False)) # 本应失败
db.commit()
# 门禁关闭,应放行不抛异常
report = RiskRuleGoldenEvaluator().require_pass(
db, asset, "v1", _keyword_manifest(), "risk.test.keyword", actor="tester"
)
assert report.total == 0
def test_require_pass_swallows_evaluator_exception() -> None:
with _build_session() as db:
asset = _asset("a5", "R5")
db.add(asset)
db.commit()
evaluator = RiskRuleGoldenEvaluator()
with patch.object(evaluator, "evaluate_for_rule", side_effect=RuntimeError("boom")):
report = evaluator.require_pass(
db, asset, "v1", _keyword_manifest(), "risk.test.keyword", actor="tester"
)
assert report.total == 0
assert report.all_passed is True # 降级放行