diff --git a/document/development/2026-07-03/feature/ai-data-flywheel/CONCEPT.md b/document/development/2026-07-03/feature/ai-data-flywheel/CONCEPT.md new file mode 100644 index 0000000..525127a --- /dev/null +++ b/document/development/2026-07-03/feature/ai-data-flywheel/CONCEPT.md @@ -0,0 +1,190 @@ +# AI 数据飞轮 概念文档 + +更新时间:2026-07-03 + +文档路径:document/development/2026-07-03/feature/ai-data-flywheel/CONCEPT.md + +## 功能一句话 + +把用户反馈、人工修正、风险样本、评测结果自动沉淀并回流到下一次 LLM 推理与规则生成,让费控系统在不停机的情况下持续提升准确率、降低误报率和人工干预率。 + +## 背景与问题 + +- 当前现状:项目已具备"聪明"的骨架——RAG(`knowledge_rag_runtime.py` + Qdrant)、风险规则自动生成(`risk_rule_generation*.py`)、反馈样本沉淀(`skills/domain/false-positive-sample-accumulator` 等 3 个 accumulator)、规则回放评测(`risk-algorithm-replay-evaluator`)、行为画像(`employee_behavior_profile*`)、用户反馈表(`agent_feedback.py` + `AgentOperationFeedback`)。 +- 用户痛点:样本在往 accumulator 池子里堆,但**下一次 LLM 推理时并没有把这些样本检索出来当 few-shot 喂进去**;prompt 散落在各 `*_prompt.py`,没有版本号、没有在线 A/B、没有回归门禁;OCR 抽取的人工修正值没有回流成评测/训练数据;低分反馈只汇总不归因。 +- 业务影响:系统每次推理都从"初始水平"出发,无法把历史踩过的坑转化成下一次的能力;改 prompt / 改规则无法证明是否变好,存在隐性回归风险;运营和算法同事看不到"系统在进步"的证据。 +- 为什么现在需要做:飞轮骨架已齐,缺的是把"样本池 → 检索注入 + 评测门禁 → prompt/规则版本"这段断开的箭头接上。补上后整张图就转起来,且改动集中在 prompt 构造层 + 新增 eval 目录,不动业务主链路,风险低。 + +## 目标与非目标 + +### 目标 + +- [G1] few-shot 在线检索注入:推理前从样本池按 case 特征做向量检索,取 top-k 历史样本(含人工结论)拼进 system prompt。 +- [G2] 黄金评测集 + 自动回归门禁:版本化 golden set,prompt/规则变更后在 golden set 上自动跑分,分数不达标禁止发布。 +- [G3] Prompt 版本化 + Canary A/B:prompt 进表带版本号,支持 stable / canary 流量切分,反馈分数对比。 +- [G4] 抽取修正回流:附件/明细字段的人工修正值记录为 diff,沉淀为抽取评测集与 few-shot 样本。 +- [G5] 低分反馈自动归因:低分反馈触发归因 agent,拉 trace 诊断错误环节并生成改进任务。 +- [G6] AI 智商看板:每周自动跑 golden set,输出准确率/召回率/误报率/人工干预率随时间的曲线。 + +### 非目标 + +- [NG1] 本轮不做模型微调 / 自训练:只走 prompt 侧的 in-context learning + 规则学习。 +- [NG2] 本轮不改变现有业务主链路(申请单、报销、审批)的接口契约。 +- [NG3] 本轮不替换 Qdrant / LightRAG 底座,复用现有向量存储与 embedding 配置。 +- [NG4] 政策新鲜度检测(外部政策变更 → 自动重生成规则)后续再评估,本轮只在评测门禁侧预留接口。 + +## 用户与场景 + +- 目标用户: + 1. 报销人 / 申请人:感知到系统越来越准,少打回、少补件。 + 2. 财务审批人:误报率下降,审批被打断的次数减少。 + 3. 算法/运营同学:能看到智商曲线、能灰度上线 prompt、能跑回归评测。 +- 使用入口: + - 推理时自动注入 few-shot(对用户透明)。 + - 后台 Canary 控制台(运营切流量、看分数)。 + - AI 智商看板(周报 / 在线查询)。 +- 核心场景: + 1. 用户提交报销 → 系统预审 → 预审 prompt 自动注入相似历史误报样本 → 给出更准结论。 + 2. 算法同学改了 risk rule 生成 prompt → 发布前自动跑 golden set → 不达标被门禁拦下。 + 3. 用户给低分 → 归因 agent 诊断"是检索没召回 / 规则误判 / 回复格式问题"→ 自动建改进任务并回写样本池。 + 4. 审批人改了 OCR 抽错的金额 → diff 自动沉淀 → 下次同类票据抽取 prompt 多一条 few-shot。 +- 异常场景: + - 样本池为空或检索失败 → 退化为无 few-shot 推理,不阻塞主链路。 + - 评测门禁服务不可用 → 默认放行 stable,canary 自动暂停。 + - Canary 候选 prompt 分数劣化 → 自动回滚到 stable。 + +## 功能能力 + +- [C1] 输入能力:消费 accumulator 样本池、`AgentOperationFeedback`、附件修正 diff、trace 数据作为飞轮原料。 +- [C2] 处理能力:样本检索(向量 + 元数据过滤)、评测打分(准确率/召回率/误报率/F1)、Canary 流量切分、低分归因。 +- [C3] 输出能力:few-shot 注入后的 messages、评测报告、智商曲线、改进任务、归因结论。 +- [C4] 状态与权限:样本带"人工已确认"标签才进可注入集合;prompt 版本有 stable/canary/pinned 状态;评测门禁可由运营关闭(审计可见)。 +- [C5] 边界与降级:检索失败、评测失败、Canary 失败均降级到 stable,不阻塞业务推理。 + +## 方案设计 + +### 前端 + +- 页面/组件: + - AI 智商看板(新页面,复用 `finance-report` 看板骨架):准确率/召回率/误报率/人工干预率随时间曲线 + golden set 覆盖度。 + - Canary 控制台(并入 `SettingsView` / `PoliciesView`):列出各场景 prompt 版本、流量比例、当前分数、一键回滚。 +- 交互状态:加载/空态(样本不足)/错误态(评测失败)/权限态(仅算法运营)。 +- 展示规则:曲线按场景(差旅/报销/预算)分面;Canary 显示置信区间,差异不显著时标注。 +- 降级处理:看板数据不可用时提示"数据生成中",不报错。 + +### 后端 + +- 接口/服务(新增,按职责拆分,单文件 ≤ 800 行): + - `services/few_shot_retrieval.py`:样本检索器,复用 Qdrant,输入 case 特征 → 输出 top-k 样本(带人工结论)。 + - `services/prompt_registry.py`:prompt 版本注册中心,按场景 + 策略(latest/canary/pinned)取 prompt。 + - `services/eval_harness.py`:在 golden set 上跑评测,输出指标;被发布门禁和智商看板共用。 + - `services/feedback_attribution.py`:低分归因 agent,复用 `AgentTraceCenter` 数据。 + - `services/extraction_correction_recorder.py`:记录 OCR 抽取字段的人工修正 diff。 +- 改造点(在现有 prompt 构造文件加 inject 钩子,不改业务接口): + - `risk_rule_generation_prompt.py`、`user_agent_application.py`、`expense_claim_pre_review.py`、`document_intelligence_rules.py`、`ontology_extraction.py` 的 prompt 构造处。 +- 权限与校验:Canary 控制台仅算法/运营角色;门禁关闭需审计日志。 +- 持久化(新表,Alembic 迁移): + - `prompt_version`:id / scene / content / version / status(stable/canary/pinned) / eval_score / created_by / created_at。 + - `golden_set`:id / scene / case_payload / expected / source(accumulator/manual) / confirmed / version。 + - `extraction_correction`:id / attachment_id / field / raw_value / corrected_value / operator / created_at。 + - `eval_run`:id / prompt_version_id / scene / metrics_json / started_at / finished_at。 +- 降级处理:所有飞轮组件故障均降级到无 few-shot + stable prompt,主链路不阻塞。 + +### 算法与规则 + +- 输入:case 特征向量(场景标签 + 文本摘要 + 关键字段)、golden set、反馈样本、修正 diff。 +- 流程: + 1. 推理前:`few_shot_retrieval` 检索 top-k → 拼 system prompt。 + 2. 推理后:结果 + 反馈写入 accumulator。 + 3. 发布前:`eval_harness` 在 golden set 上跑分 → 门禁判定。 + 4. 低分触发:`feedback_attribution` 归因 → 改进任务回写样本池。 +- 输出:few-shot 样本块、评测指标、归因结论、智商曲线数据点。 +- 解释:few-shot 注入在 prompt 中保留"参考案例(历史已确认)"段落,可追溯;评测报告附错误 case 列表;归因输出错误环节标签 + 证据 trace 片段。 + +### 数据与契约 + +- 核心字段:scene、case_signature、few_shot_samples、metrics(acc/recall/fpr/f1)、prompt_version_id、status。 +- 状态枚举: + - prompt: `stable` / `canary` / `pinned` / `archived`。 + - golden case: `draft` / `confirmed` / `deprecated`。 + - eval: `pass` / `fail` / `blocked`。 +- 兼容策略:prompt_registry 找不到版本时回退到当前硬编码 prompt(保证向后兼容)。 +- 版本/审计:每次 prompt / 规则 / golden set 变更记 `eval_run`,可回放历史。 + +## 算法与公式 + +### few-shot 检索排序 + +```text +score(sample, case) = α * sim(emb(sample), emb(case)) + β * match(meta(sample), meta(case)) +``` + +变量说明: + +- score:样本与当前 case 的综合相似度。 +- sim:余弦相似度,复用 Qdrant 现有 embedding。 +- match:元数据硬匹配得分(场景同 / 域同 / 级别同),取 0 或 1。 +- α、β:权重,默认 α=0.8、β=0.2,可在 prompt_registry 中按场景覆盖。 +- 适用边界:仅取 `confirmed=true` 的样本;top-k 默认 k=3,按 token 预算动态裁剪。 + +### 评测指标 + +```text +precision = TP / (TP + FP) +recall = TP / (TP + FN) +f1 = 2 * precision * recall / (precision + recall) +``` + +变量说明: + +- TP/FP/FN:在 golden set 上推理结论与 expected 比对得出(结论为风险标记/字段值/分类标签三类场景各有比对器)。 +- 发布门禁默认阈值:recall ≥ 上一版 stable 的 recall 且 f1 不下降超过 2 个百分点,否则 `fail`。 + +## 测试方案 + +后端: + +- `few_shot_retrieval` 单测:样本池空 / 检索失败 / top-k 截断 / 仅取 confirmed 样本。 +- `eval_harness` 单测:golden set 跑分指标正确性、门禁通过/拦截逻辑、空 golden set 降级。 +- `prompt_registry` 单测:按策略取版本、回退到硬编码、Canary 流量切分比例。 +- `feedback_attribution` 单测:mock trace 数据,归因标签正确性。 + +前端: + +- AI 智商看板视图模型:空态、加载态、错误态、曲线渲染。 +- Canary 控制台:列表、切流量、回滚交互。 + +集成: + +- 端到端:构造一份 golden set → 改 prompt → 发布被门禁拦截 / 通过 → 智商看板出现新数据点。 +- 容器内运行:`docker exec -w /app -e SERVER_VENV_DIR=/tmp/x-financial-server-venv local-x-financial-linux /tmp/x-financial-server-venv/bin/pytest -q server/tests/...`,超时 60s。 + +手工验证: + +- 在 AI 工作台触发一次预审,确认 prompt 中出现 few-shot 块。 +- 在 Canary 控制台发布一版劣化 prompt,确认被门禁拦下。 + +## 指标与验收 + +- [A1] 功能验收:推理时 prompt 中可见 few-shot 块,且样本来自 confirmed 池。 +- [A2] 性能指标:few-shot 检索 ≤ 200ms(P95),不显著拖慢主链路;eval 单场景 ≤ 60s。 +- [A3] 质量指标:golden set 覆盖至少 5 个核心场景;门禁能正确拦截劣化 prompt。 +- [A4] 安全/权限指标:Canary 控制台仅算法/运营可操作;门禁关闭记审计日志。 +- [A5] 可观测性:AI 智商看板按周生成曲线;每次 eval_run 可回放。 + +## 风险与开放问题 + +- 风险: + - few-shot 注入增加 prompt 长度,可能触发 token 上限或拖慢推理 → 用 token 预算裁剪 + P95 监控兜底。 + - 样本池噪音(错误标注)污染推理 → 只取 confirmed 样本 + 评测门禁把关。 + - 评测 golden set 与线上分布漂移 → 季度复审 golden set,标注漂移度。 +- 已处理依赖:复用 Qdrant / LightRAG / accumulator / AgentTraceCenter / risk_rule_generation 现有能力。 +- 待确认: + - Canary 流量切分的具体比例(建议 90/10)需与业务确认。 + - 智商看板放哪个一级菜单(Settings 还是独立"AI 运营"菜单)。 + - 政策新鲜度检测是否本轮接入。 +- 降级策略:任何飞轮组件故障 → 无 few-shot + stable prompt + 跳过门禁(仅 stable),保证业务连续。 + +## 本轮实现记录 + +- 2026-07-03:完成数据飞轮概念文档与开发 TODO 拆分,作为后续改造的总纲。 diff --git a/document/development/2026-07-03/feature/ai-data-flywheel/TODO.md b/document/development/2026-07-03/feature/ai-data-flywheel/TODO.md new file mode 100644 index 0000000..e81121e --- /dev/null +++ b/document/development/2026-07-03/feature/ai-data-flywheel/TODO.md @@ -0,0 +1,98 @@ +# AI 数据飞轮 开发 TODO + +更新时间:2026-07-03 + +文档路径:document/development/2026-07-03/feature/ai-data-flywheel/TODO.md + +## 使用规则 + +- 每个 TODO 必须对应 `CONCEPT.md` 中的目标、能力、方案或验收点。 +- 只有完成并验证后,才能把 `[ ]` 改成 `[x]`。 +- 勾选时在任务后补充简短证据,例如文件、接口、命令或验证结果。 +- 如果需求发生变化,先更新 `CONCEPT.md`,再调整本 TODO。 +- 实施顺序建议:阶段 1 → 2 → 3(飞轮 1+2 是地基)→ 4/5/6 并行 → 7。 + +## 1. 调研与边界 + +- [x] [CONCEPT: 背景与问题] 盘点现有 accumulator / feedback / RAG / 规则生成能力,确认飞轮骨架已存在、断点在"检索注入 + 评测门禁"。 + 证据:`server/src/app/skills/domain/{false-positive-sample-accumulator,risk-feedback-sample-accumulator,risk-clue-collector}`、`services/agent_feedback.py`、`services/knowledge_rag_runtime.py`、`services/risk_rule_generation*.py`。 +- [x] [CONCEPT: 目标与非目标] 确认本轮范围 = 飞轮 1-6(few-shot 注入 / golden set 门禁 / prompt 版本化 / 抽取修正回流 / 低分归因 / 智商看板),不做模型微调、不改业务接口契约。 + 证据:CONCEPT.md「目标与非目标」章节。 +- [ ] [CONCEPT: 风险与开放问题] 与业务确认 Canary 流量比例、智商看板菜单位置、政策新鲜度检测是否本轮接入。 + 证据: + +## 2. 契约与设计 + +- [ ] [CONCEPT: 功能能力] 定义 4 张新表的字段、状态枚举(prompt stable/canary/pinned/archived、golden draft/confirmed/deprecated、eval pass/fail/blocked、correction)。 + 证据: +- [ ] [CONCEPT: 方案设计] 明确 5 个新 service 的职责边界与 inject 钩子点(risk_rule_generation_prompt / user_agent_application / expense_claim_pre_review / document_intelligence_rules / ontology_extraction)。 + 证据: +- [ ] [CONCEPT: 算法与公式] 确认 few-shot 检索排序公式权重(默认 α=0.8 β=0.2)与门禁阈值(recall 不降、f1 下降 ≤ 2pp)。 + 证据: +- [ ] [CONCEPT: 指标与验收] 把验收点 A1-A5 转成可验证检查项,附命令与期望结果。 + 证据: + +## 3. 后端实现 + +- [x] [CONCEPT: 后端] 新增 `services/few_shot_retrieval.py`:复用 Qdrant,按 case 特征检索 top-k confirmed 样本,带 token 预算裁剪。 + 证据:`server/src/app/services/few_shot_retrieval.py`;`server/src/app/services/few_shot_store.py`(独立 Qdrant collection `few_shot_samples`);`server/src/app/services/embedding_provider.py`(公共 EmbeddingProvider,复用 knowledge_rag_runtime 的 HTTP 调用)。 +- [ ] [CONCEPT: 后端] 新增 `services/prompt_registry.py`:prompt 版本 CRUD + 策略取版(latest/canary/pinned)+ 回退硬编码。 + 证据:飞轮 3(prompt 版本化 + Canary)未启动,本轮只做飞轮 1。 +- [ ] [CONCEPT: 后端] 新增 `services/eval_harness.py`:在 golden set 上跑评测,输出 precision/recall/f1,供门禁与看板共用。 + 证据:飞轮 2(golden set + 门禁)未启动,本轮只做飞轮 1。 +- [ ] [CONCEPT: 后端] 新增 `services/feedback_attribution.py`:低分反馈触发,复用 AgentTraceCenter trace 做归因,输出错误环节标签 + 改进任务。 + 证据:飞轮 5(低分归因)未启动,本轮只做飞轮 1。 +- [ ] [CONCEPT: 后端] 新增 `services/extraction_correction_recorder.py`:在附件/明细字段更新处记录 raw vs corrected diff。 + 证据:飞轮 4(抽取修正回流)未启动,本轮只做飞轮 1。 +- [ ] [CONCEPT: 后端] Alembic 迁移:prompt_version / golden_set / extraction_correction / eval_run 四张表。 + 证据:本轮新增的是 FewShotSample 一张表(`server/src/app/models/few_shot_sample.py`),项目靠 `Base.metadata.create_all()` 建表(无 alembic versions/ 目录),已注册到 `db/base.py` 和 `models/__init__.py`。其余三表随对应飞轮再建。 +- [x] [CONCEPT: 后端] 新增 `services/few_shot_ingestion.py`:RiskObservation confirmed/false_positive → FewShotSample + Qdrant 向量,在 `risk_observations.create_feedback` commit 后 hook 触发。 + 证据:`server/src/app/services/few_shot_ingestion.py`;`server/src/app/services/risk_observations.py:324-345`(`_maybe_ingest_few_shot` hook,带 feature flag + try/except 兜底)。 +- [x] [CONCEPT: 数据与契约] 在现有 prompt 构造文件加 few-shot 注入,不改业务接口。 + 证据:`server/src/app/services/risk_rule_generation_prompt.py`(新增 `few_shot_samples` 可选 kwarg,合并进 examples 字段);`server/src/app/services/risk_rule_generation.py:271-292`(`_retrieve_few_shot_samples` 在构造 messages 前调用,失败降级为空)。 + +## 4. 算法/规则实现 + +- [x] [CONCEPT: 算法与规则] 实现few-shot 检索排序(向量相似度 + 元数据硬匹配),只取 confirmed 样本。 + 证据:`server/src/app/services/few_shot_store.py`(Qdrant 余弦相似度 + payload 过滤 scene/label/status);`few_shot_retrieval.py` 去重 + token 预算 + 单条字符上限裁剪。检索仅取 label ∈ {confirmed, false_positive}。 +- [ ] [CONCEPT: 算法与规则] 实现评测指标比对器(风险标记 / 字段值 / 分类标签 三类场景)。 + 证据:飞轮 2,未启动。 +- [ ] [CONCEPT: 算法与规则] 接入发布门禁:`agent_asset_risk_rule_publish` 前调 eval_harness,不达标 block。 + 证据:飞轮 2,未启动。 +- [ ] [CONCEPT: 算法与规则] 接入 Canary 流量切分(默认 90 stable / 10 canary)+ 劣化自动回滚。 + 证据:飞轮 3,未启动。 +- [x] [CONCEPT: 结果解释] few-shot 块在 prompt 中保留 `source: "historical_confirmed"` 标记,可追溯。 + 证据:`risk_rule_generation_prompt.py` 合并 examples 时每条历史样本带 `source`/`label`/`conclusion` 字段。 + +## 5. 前端实现 + +- [ ] [CONCEPT: 前端] AI 智商看板新页面:准确率/召回率/误报率/人工干预率随时间曲线 + golden set 覆盖度,复用 `finance-report` 看板骨架。 + 证据: +- [ ] [CONCEPT: 前端] Canary 控制台(并入 Settings/Policies):prompt 版本列表、流量比例、分数、一键回滚。 + 证据: +- [ ] [CONCEPT: 前端] 实现加载/空态(样本不足)/错误态(评测失败)/权限态(仅算法运营)。 + 证据: +- [ ] [CONCEPT: 前端] 对齐现有企业后台风格(参考 `chat-ui-saas-styling` / `theme-settings-enterprise-ai-style`),避免营销页观感。 + 证据: + +## 6. 测试与验证 + +- [x] [CONCEPT: 测试方案] 后端单测:embedding_provider(GLM/Ollama 分支、维度缓存、HTTP 错误降级)、few_shot_ingestion(confirmed/false_positive 入库、ignored 跳过、幂等去重、hook 触发、feature flag、吞异常)、few_shot_retrieval(去重、token 预算、超长截断)+ prompt 注入(合并 examples、向后兼容)。 + 证据:`server/tests/test_embedding_provider.py`、`server/tests/test_few_shot_ingestion.py`、`server/tests/test_few_shot_retrieval_and_prompt.py`,容器内 `pytest -q` 20 passed。 +- [ ] [CONCEPT: 测试方案] 前端:智商看板与 Canary 控制台视图模型 + 构建验证。 + 证据:飞轮 3/6 前端,未启动。 +- [ ] [CONCEPT: 测试方案] 集成:golden set → 改 prompt → 门禁拦截/通过 → 看板新增数据点,容器内跑通。 + 证据:飞轮 2 集成,未启动。 +- [x] [CONCEPT: 测试方案] 回归:现有 RAG / risk_observations / risk_rule_generation 测试全过。 + 证据:容器内 `pytest -q server/tests/test_risk_observations_service.py server/tests/test_knowledge_rag_runtime.py server/tests/test_risk_rule_generation.py server/tests/test_risk_rule_generation_failure.py` → 35 passed,EmbeddingProvider 抽离零回归。 +- [ ] [CONCEPT: 指标与验收] 记录验证命令与结果,确认 P95 检索 ≤ 200ms、单场景评测 ≤ 60s。 + 证据:性能指标待飞轮 2 评测上线后连同 golden set 一起量。 + +## 7. 文档收尾 + +- [x] [CONCEPT: 指标与验收] 飞轮 1(few-shot 注入)A1 功能验收已达成:推理时 prompt 中可见带 `source: "historical_confirmed"` 的 few-shot 块,且样本来自 confirmed/false_positive 池。A5 可观测性部分达成(可追溯 source)。A2/A3/A4 随飞轮 2/3 补齐。 + 证据:见阶段 3/4/6 已勾选项。 +- [ ] [CONCEPT: 风险与开放问题] 更新 Canary 比例、看板菜单位置、政策新鲜度检测的最终结论与剩余风险。 + 证据:飞轮 2-6 启动后再定稿。 +- [x] [CONCEPT: 功能一句话] 确认飞轮 1 实现没有偏离"让系统越用越聪明"的原始目标。 + 证据:人工确认风险观测 → 自动入库 + 向量化 → 下次规则编译时检索注入相似历史样本,形成"用得越多 → 样本越丰富 → 推理越准"的闭环。飞轮 2-6 待后续迭代。 diff --git a/server/src/app/db/base.py b/server/src/app/db/base.py index b9ee3c7..2af5c13 100644 --- a/server/src/app/db/base.py +++ b/server/src/app/db/base.py @@ -15,6 +15,7 @@ from app.models.budget import BudgetAllocation, BudgetReservation, BudgetTransac from app.models.employee_change_log import EmployeeChangeLog from app.models.employee_behavior_profile import EmployeeBehaviorProfileSnapshot from app.models.employee import Employee +from app.models.few_shot_sample import FewShotSample from app.models.financial_record import ( AccountsPayableRecord, AccountsReceivableRecord, @@ -57,6 +58,7 @@ __all__ = [ "EmployeeBehaviorProfileSnapshot", "EmployeeChangeLog", "ExpenseClaim", + "FewShotSample", "ExpenseClaimItem", "HermesTaskConfig", "HermesTaskExecutionLog", diff --git a/server/src/app/models/__init__.py b/server/src/app/models/__init__.py index b3549eb..8c4796a 100644 --- a/server/src/app/models/__init__.py +++ b/server/src/app/models/__init__.py @@ -8,6 +8,7 @@ from app.models.budget import BudgetAllocation, BudgetReservation, BudgetTransac from app.models.employee_change_log import EmployeeChangeLog from app.models.employee_behavior_profile import EmployeeBehaviorProfileSnapshot from app.models.employee import Employee +from app.models.few_shot_sample import FewShotSample from app.models.financial_record import ( AccountsPayableRecord, AccountsReceivableRecord, @@ -49,6 +50,7 @@ __all__ = [ "EmployeeChangeLog", "ExpenseClaim", "ExpenseClaimItem", + "FewShotSample", "HermesTaskConfig", "HermesTaskExecutionLog", "HermesRiskReport", diff --git a/server/src/app/models/few_shot_sample.py b/server/src/app/models/few_shot_sample.py new file mode 100644 index 0000000..79ee503 --- /dev/null +++ b/server/src/app/models/few_shot_sample.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import uuid +from datetime import datetime +from typing import Any + +from sqlalchemy import DateTime, ForeignKey, Index, String, Text, func +from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.types import JSON + +from app.db.base_class import Base + + +class FewShotSample(Base): + """已确认的风险观测样本,供 few-shot 检索注入使用。 + + 数据来源是 ``RiskObservation`` 上人工确认为 confirmed / false_positive 的观测, + 入库后同时写一份向量到 Qdrant 的 ``few_shot_samples`` collection。 + """ + + __tablename__ = "few_shot_samples" + __table_args__ = ( + Index("ix_few_shot_samples_scene_label", "scene", "label"), + Index("ix_few_shot_samples_domain_risk_type", "domain", "risk_type"), + ) + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + sample_key: Mapped[str] = mapped_column(String(160), unique=True, index=True) + source_observation_id: Mapped[str | None] = mapped_column( + ForeignKey("risk_observations.id"), + nullable=True, + index=True, + ) + + scene: Mapped[str] = mapped_column(String(50), default="risk_rule_generation", index=True) + domain: Mapped[str] = mapped_column(String(50), default="", index=True) + risk_type: Mapped[str] = mapped_column(String(80), default="", index=True) + risk_level: Mapped[str] = mapped_column(String(20), default="") + + label: Mapped[str] = mapped_column(String(30), default="confirmed", index=True) + case_text: Mapped[str] = mapped_column(Text(), default="") + conclusion_text: Mapped[str] = mapped_column(Text(), default="") + payload_json: Mapped[dict[str, Any]] = mapped_column(JSON, default=dict) + + vector_id: Mapped[str | None] = mapped_column(String(100), nullable=True) + status: Mapped[str] = mapped_column(String(20), default="active", index=True) + + created_at: Mapped[datetime] = mapped_column(DateTime, default=func.now(), server_default=func.now()) + updated_at: Mapped[datetime] = mapped_column( + DateTime, + default=func.now(), + onupdate=func.now(), + server_default=func.now(), + ) diff --git a/server/src/app/services/embedding_provider.py b/server/src/app/services/embedding_provider.py new file mode 100644 index 0000000..5cce2d2 --- /dev/null +++ b/server/src/app/services/embedding_provider.py @@ -0,0 +1,138 @@ +"""公共 Embedding 提供者。 + +把 ``knowledge_rag_runtime`` 里 embedding 调用逻辑抽出来,供 RAG 和 +few-shot 检索复用。本模块只依赖现有模块级纯函数和 ``RuntimeModelConfig``, +不改动 ``_LightRagRuntime`` 的行为,RAG 路径保持零回归风险。 + +典型用法:: + + provider = EmbeddingProvider.from_settings(session) + vectors = provider.embed(["差旅超标", "票单不一致"]) + dim = provider.dimension() +""" + +from __future__ import annotations + +from typing import Any + +from app.core.logging import get_logger +from app.services.knowledge_rag_runtime import ( + DEFAULT_EMBEDDING_TIMEOUT_SECONDS, + KnowledgeRagError, + RuntimeModelConfig, + _build_headers, + _ensure_path, + _extract_embedding_vectors, + _normalize_endpoint, + _send_json_request, +) + +logger = get_logger("app.services.embedding_provider") + + +def _runtime_model_config_from_dict(config: dict[str, str]) -> RuntimeModelConfig: + """把 SettingsService.get_runtime_model_config 返回的 dict 转成 dataclass。""" + + return RuntimeModelConfig( + slot=str(config.get("slot") or "embedding"), + provider=str(config.get("provider") or ""), + model=str(config.get("model") or ""), + endpoint=str(config.get("endpoint") or ""), + api_key=str(config.get("apiKey") or ""), + capability=str(config.get("capability") or ""), + ) + + +class EmbeddingProvider: + """对 embedding 模型的轻量封装。 + + 设计要点: + - 持有一个 ``RuntimeModelConfig``,构造即固定,不依赖 LightRAG。 + - 复用 ``knowledge_rag_runtime`` 的 HTTP 调用纯函数,行为与 RAG 完全一致。 + - 维度采用惰性探测(首次 embed 后缓存),避免空构造就打远端。 + """ + + def __init__(self, config: RuntimeModelConfig) -> None: + self.config = config + self._dimension: int | None = None + + @classmethod + def from_settings(cls, session: Any) -> "EmbeddingProvider": + """从 SettingsService 取 embedding 配置构造 provider。""" + + from app.services.settings import SettingsService + + raw = SettingsService(session).get_runtime_model_config("embedding") + return cls(_runtime_model_config_from_dict(raw)) + + def embed(self, texts: list[str]) -> list[list[float]]: + """对一组文本做 embedding,返回与输入等长的向量列表。""" + + if not texts: + return [] + return _request_embeddings_public(self.config, texts) + + def dimension(self) -> int: + """探测 embedding 维度,结果缓存。失败抛 KnowledgeRagError。""" + + if self._dimension is None: + vectors = self.embed(["dimension probe"]) + if not vectors or not isinstance(vectors[0], list): + raise KnowledgeRagError("无法从 embedding 模型返回结果中解析向量维度。") + self._dimension = len(vectors[0]) + if self._dimension <= 0: + raise KnowledgeRagError("embedding 模型返回了无效的向量维度。") + return self._dimension + + +def _request_embeddings_public( + config: RuntimeModelConfig, + texts: list[str], +) -> list[list[float]]: + """按 provider 分支构造 embedding 请求。 + + 与 ``_LightRagRuntime._request_embeddings`` 实现保持一致, + 保证 few-shot 检索与 RAG 走同一套调用语义。 + """ + + from app.services.model_connectivity import AZURE_API_VERSION + + if config.provider == "Azure OpenAI": + from app.services.knowledge_rag_runtime import _build_azure_deployment_base + + url = f"{_build_azure_deployment_base(config.endpoint, config.model)}/embeddings?api-version={AZURE_API_VERSION}" + payload: dict[str, Any] = {"input": texts} + status_code, body = _send_json_request( + "POST", + url, + headers=_build_headers(config.api_key, use_bearer=False, use_api_key=True), + payload=payload, + timeout_seconds=DEFAULT_EMBEDDING_TIMEOUT_SECONDS, + ) + elif config.provider == "Ollama": + url = _ensure_path(_normalize_endpoint(config.endpoint), "api/embed") + payload = {"model": config.model, "input": texts} + status_code, body = _send_json_request( + "POST", + url, + headers={"Content-Type": "application/json", "Accept": "application/json"}, + payload=payload, + timeout_seconds=DEFAULT_EMBEDDING_TIMEOUT_SECONDS, + ) + else: + url = _ensure_path(_normalize_endpoint(config.endpoint), "embeddings") + payload = {"model": config.model, "input": texts} + status_code, body = _send_json_request( + "POST", + url, + headers=_build_headers(config.api_key, use_bearer=True), + payload=payload, + timeout_seconds=DEFAULT_EMBEDDING_TIMEOUT_SECONDS, + ) + + from http import HTTPStatus + + if status_code >= HTTPStatus.BAD_REQUEST: + raise KnowledgeRagError(f"embedding 模型返回异常状态码 {status_code}。") + + return _extract_embedding_vectors(body, provider=config.provider) diff --git a/server/src/app/services/few_shot_ingestion.py b/server/src/app/services/few_shot_ingestion.py new file mode 100644 index 0000000..d8914f9 --- /dev/null +++ b/server/src/app/services/few_shot_ingestion.py @@ -0,0 +1,177 @@ +"""Few-shot 样本入库编排:RiskObservation → FewShotSample → Qdrant。 + +只处理人工确认为 confirmed / false_positive 的观测,把它转成一条 +:class:`FewShotSample`,持久化到 DB,并同步向量到 Qdrant。 + +入库动作由 :meth:`RiskObservationService.create_feedback` 在 commit 后触发, +本服务全程吞异常(只记日志),保证反馈主流程不被 few-shot 链路拖崩。 +""" + +from __future__ import annotations + +from typing import Any + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from app.core.logging import get_logger +from app.models.few_shot_sample import FewShotSample +from app.models.risk_observation import RiskObservation, RiskObservationFeedback +from app.services.embedding_provider import EmbeddingProvider +from app.services.few_shot_store import FewShotStore + +logger = get_logger("app.services.few_shot_ingestion") + +# 仅这两个 feedback_status 视为已确认样本,会入库 +CONFIRMED_LABELS = {"confirmed", "false_positive"} + +# label → 自然语言结论(当 feedback.comment 缺失时兜底) +LABEL_CONCLUSION_FALLBACK = { + "confirmed": "经人工复核确认,该风险线索成立,需按规则拦截或补件。", + "false_positive": "经人工复核判定为误报,相似情形不应触发该风险规则。", +} + + +class FewShotIngestionService: + """把已确认的风险观测沉淀为 few-shot 样本。""" + + def __init__(self, db: Session) -> None: + self.db = db + + def ingest_observation_feedback( + self, + observation: RiskObservation, + feedback: RiskObservationFeedback, + ) -> FewShotSample | None: + """人工确认/误报后调用,写入并同步向量。""" + + label = observation.feedback_status + if label not in CONFIRMED_LABELS: + return None + + sample_key = f"obs:{observation.id}" + sample = self.db.scalar( + select(FewShotSample).where(FewShotSample.sample_key == sample_key) + ) + + domain = self._extract_domain(observation) + case_text = self._build_case_text(observation) + conclusion_text = self._build_conclusion_text(observation, feedback, label) + payload = self._build_payload(observation, feedback, label) + + if sample is None: + sample = FewShotSample( + sample_key=sample_key, + source_observation_id=observation.id, + scene="risk_rule_generation", + domain=domain, + risk_type=observation.risk_type or "", + risk_level=observation.risk_level or "", + label=label, + case_text=case_text, + conclusion_text=conclusion_text, + payload_json=payload, + status="active", + ) + self.db.add(sample) + else: + sample.label = label + sample.domain = domain + sample.risk_type = observation.risk_type or "" + sample.risk_level = observation.risk_level or "" + sample.case_text = case_text + sample.conclusion_text = conclusion_text + sample.payload_json = payload + sample.status = "active" + sample.vector_id = sample.vector_id + try: + self.db.commit() + self.db.refresh(sample) + except Exception: + logger.exception("few-shot 样本持久化失败 observation_id=%s", observation.id) + self.db.rollback() + return None + + vector_id = self._store().upsert(sample) + if vector_id: + sample.vector_id = vector_id + try: + self.db.commit() + except Exception: + logger.warning("few-shot vector_id 回写失败 sample_id=%s", sample.id) + return sample + + def retract_observation(self, observation_id: str) -> bool: + """观测被撤销时删掉对应样本及其向量。""" + + sample = self.db.scalar( + select(FewShotSample).where(FewShotSample.source_observation_id == observation_id) + ) + if sample is None: + return False + if sample.vector_id: + self._store().delete_by_vector_id(sample.vector_id) + try: + self.db.delete(sample) + self.db.commit() + return True + except Exception: + logger.exception("few-shot 样本删除失败 observation_id=%s", observation_id) + self.db.rollback() + return False + + def _store(self) -> FewShotStore: + provider = EmbeddingProvider.from_settings(self.db) + return FewShotStore(provider) + + def _extract_domain(self, observation: RiskObservation) -> str: + ontology = observation.ontology_json or {} + return str(ontology.get("domain") or "") + + def _build_case_text(self, observation: RiskObservation) -> str: + parts = [ + observation.title or "", + observation.description or "", + observation.risk_signal or "", + observation.risk_type or "", + ] + ontology = observation.ontology_json or {} + scenario = ontology.get("scenario") + if scenario: + parts.append(f"场景:{scenario}") + risk_signals = ontology.get("risk_signals") + if isinstance(risk_signals, list) and risk_signals: + parts.append("信号:" + "|".join(str(s) for s in risk_signals)) + return "\n".join(part for part in parts if part).strip() + + def _build_conclusion_text( + self, + observation: RiskObservation, + feedback: RiskObservationFeedback, + label: str, + ) -> str: + comment = (feedback.comment or "").strip() + if comment: + return f"[{label}] {comment}" + return LABEL_CONCLUSION_FALLBACK.get(label, label) + + def _build_payload( + self, + observation: RiskObservation, + feedback: RiskObservationFeedback, + label: str, + ) -> dict[str, Any]: + return { + "label": label, + "risk_type": observation.risk_type, + "risk_signal": observation.risk_signal, + "risk_level": observation.risk_level, + "feedback_type": feedback.feedback_type, + "feedback_comment": feedback.comment or "", + "feedback_actor": feedback.actor or "", + "ontology": observation.ontology_json or {}, + "policy_refs": observation.policy_refs_json or [], + "evidence": observation.evidence_json or [], + "subject_label": observation.subject_label or "", + "claim_no": observation.claim_no or "", + } diff --git a/server/src/app/services/few_shot_retrieval.py b/server/src/app/services/few_shot_retrieval.py new file mode 100644 index 0000000..d08f5f8 --- /dev/null +++ b/server/src/app/services/few_shot_retrieval.py @@ -0,0 +1,122 @@ +"""Few-shot 检索器:按当前 case 特征检索相似历史样本,拼成注入块。 + +从 :class:`FewShotStore` 取相似样本,转成可供 prompt 构造函数直接使用的结构。 +带 token 预算裁剪和去重,确保不撑爆 prompt。 + +典型用法(在构造 prompt 之前调用):: + + retriever = FewShotRetriever.from_session(session) + samples = retriever.retrieve_for_risk_rule_generation( + domain="travel", natural_language="票据城市与申报地不一致" + ) + messages = build_risk_rule_compiler_messages( + ..., + few_shot_samples=samples, + ) +""" + +from __future__ import annotations + +from typing import Any + +from sqlalchemy.orm import Session + +from app.core.logging import get_logger +from app.services.embedding_provider import EmbeddingProvider +from app.services.few_shot_store import FewShotStore + +logger = get_logger("app.services.few_shot_retrieval") + +# 单条 few-shot 样本估算 token 数(用于预算裁剪) +SAMPLE_TOKEN_BUDGET = 1200 +# 单条样本最大字符数,超长直接截断结论,避免撑爆 prompt +SINGLE_SAMPLE_MAX_CHARS = 400 +# 历史样本最多注入条数(与原内置 examples 合并后总量受限) +MAX_HISTORICAL_SAMPLES = 3 + + +class FewShotRetriever: + """按 case 特征检索已确认样本,返回 prompt 可直接消费的结构。""" + + def __init__(self, store: FewShotStore) -> None: + self._store = store + + @classmethod + def from_session(cls, session: Session) -> "FewShotRetriever": + provider = EmbeddingProvider.from_settings(session) + return cls(FewShotStore(provider)) + + def retrieve_for_risk_rule_generation( + self, + *, + domain: str = "", + risk_type: str = "", + natural_language: str, + top_k: int = MAX_HISTORICAL_SAMPLES, + ) -> list[dict[str, Any]]: + """检索与当前规则需求相似的历史样本,返回注入块列表。""" + + case_text = self._build_case_text( + natural_language=natural_language, + domain=domain, + risk_type=risk_type, + ) + if not case_text: + return [] + hits = self._store.search( + case_text, + scene="risk_rule_generation", + labels=["confirmed", "false_positive"], + top_k=top_k, + ) + return self._hits_to_injection_blocks(hits) + + def _build_case_text( + self, + *, + natural_language: str, + domain: str = "", + risk_type: str = "", + ) -> str: + parts = [natural_language, domain, risk_type] + return "\n".join(p for p in parts if p).strip() + + def _hits_to_injection_blocks( + self, + hits: list[dict[str, Any]], + ) -> list[dict[str, Any]]: + """把检索命中转成 prompt 可消费的块,做去重和预算裁剪。""" + + blocks: list[dict[str, Any]] = [] + seen_conclusions: set[str] = set() + budget = SAMPLE_TOKEN_BUDGET + for hit in hits: + conclusion = (hit.get("conclusion_text") or "").strip() + if not conclusion or conclusion in seen_conclusions: + continue + # 超长结论截断到上限,避免单条样本占用过多预算 + if len(conclusion) > SINGLE_SAMPLE_MAX_CHARS: + conclusion = conclusion[:SINGLE_SAMPLE_MAX_CHARS] + payload = hit.get("payload_json") or {} + block = { + "source": "historical_confirmed", + "label": hit.get("label"), + "domain": hit.get("domain") or "", + "risk_type": hit.get("risk_type") or "", + "score": round(float(hit.get("score") or 0.0), 4), + "conclusion": conclusion, + "context": { + "risk_signal": payload.get("risk_signal") or "", + "risk_level": payload.get("risk_level") or "", + "ontology": payload.get("ontology") or {}, + "feedback_comment": payload.get("feedback_comment") or "", + }, + } + # 粗略 token 估算(按字符数 / 1.6 近似中文 token 比) + estimated_tokens = int(len(conclusion) / 1.6) + 40 + if estimated_tokens > budget: + break + budget -= estimated_tokens + blocks.append(block) + seen_conclusions.add(conclusion) + return blocks diff --git a/server/src/app/services/few_shot_store.py b/server/src/app/services/few_shot_store.py new file mode 100644 index 0000000..071b28f --- /dev/null +++ b/server/src/app/services/few_shot_store.py @@ -0,0 +1,214 @@ +"""Few-shot 样本的 Qdrant 向量存储。 + +独立于 LightRAG 的 Qdrant 客户端,使用专用 collection ``few_shot_samples``, +与知识库 RAG 的 collection 隔离。所有操作失败都不抛异常(记日志返回空), +保证主链路不阻塞。 + +向量来自 :class:`EmbeddingProvider`,payload 带业务过滤字段(scene/label/domain/risk_type), +检索时按这些字段过滤 + 向量相似度排序。 +""" + +from __future__ import annotations + +import os +import uuid +from typing import Any + +from app.core.logging import get_logger +from app.services.knowledge_rag import _resolve_default_qdrant_url + +logger = get_logger("app.services.few_shot_store") + +FEW_SHOT_COLLECTION = "few_shot_samples" + + +def _resolve_qdrant_config() -> tuple[str, str]: + """复用 knowledge_rag 的 Qdrant URL/key 解析逻辑。""" + + url = os.environ.get("QDRANT_URL", "").strip() or _resolve_default_qdrant_url() + api_key = os.environ.get("QDRANT_API_KEY", "").strip() + return url, api_key + + +class FewShotStore: + """对 Qdrant 的轻量封装,专供 few-shot 样本检索使用。 + + 设计要点: + - 惰性创建 client 和 collection,首次操作时初始化。 + - 所有公共方法吞异常(返回空/False),主链路永远不被拖崩。 + - 向量写入和检索都依赖外部传入的 :class:`EmbeddingProvider`, + 由调用方保证与配置一致。 + """ + + def __init__(self, embedding_provider: Any) -> None: + self._embedding_provider = embedding_provider + self._client: Any = None + self._ensured = False + + def _client_or_none(self) -> Any: + """惰性初始化 QdrantClient,失败返回 None。""" + + if self._client is not None: + return self._client + try: + from qdrant_client import QdrantClient + + url, api_key = _resolve_qdrant_config() + self._client = QdrantClient(url=url, api_key=api_key or None) + except Exception: + logger.warning("few-shot QdrantClient 初始化失败,本轮操作跳过", exc_info=True) + self._client = None + return self._client + + def _ensure_collection(self) -> bool: + """确保 collection 存在,成功返回 True。""" + + if self._ensured: + return True + client = self._client_or_none() + if client is None: + return False + try: + from qdrant_client.http.exceptions import UnexpectedResponse + + try: + client.get_collection(FEW_SHOT_COLLECTION) + self._ensured = True + return True + except UnexpectedResponse as exc: + if exc.status_code != 404: + raise + # collection 不存在则创建 + dim = self._embedding_provider.dimension() + from qdrant_client.http.models import ( + Distance, + VectorParams, + PayloadSchemaType, + ) + + client.create_collection( + collection_name=FEW_SHOT_COLLECTION, + vectors_config=VectorParams(size=dim, distance=Distance.COSINE), + ) + for field, field_type in [ + ("sample_id", PayloadSchemaType.KEYWORD), + ("scene", PayloadSchemaType.KEYWORD), + ("label", PayloadSchemaType.KEYWORD), + ("domain", PayloadSchemaType.KEYWORD), + ("risk_type", PayloadSchemaType.KEYWORD), + ("status", PayloadSchemaType.KEYWORD), + ]: + try: + client.create_payload_index( + collection_name=FEW_SHOT_COLLECTION, + field_name=field, + field_schema=field_type, + ) + except Exception: + logger.debug("payload index 创建跳过 field=%s", field, exc_info=True) + self._ensured = True + logger.info("few-shot collection 创建成功 dim=%s", dim) + return True + except Exception: + logger.warning("few-shot collection 初始化失败,本轮操作跳过", exc_info=True) + return False + + def upsert(self, sample: Any) -> str | None: + """把一条样本向量化并写入 Qdrant,返回 vector_id,失败返回 None。""" + + if not self._ensure_collection(): + return None + client = self._client + try: + vector = self._embedding_provider.embed([sample.case_text])[0] + except Exception: + logger.warning("few-shot embedding 失败 sample_key=%s", getattr(sample, "sample_key", ""), exc_info=True) + return None + vector_id = uuid.uuid4().hex + payload = { + "sample_id": sample.id, + "scene": sample.scene, + "label": sample.label, + "domain": sample.domain, + "risk_type": sample.risk_type, + "risk_level": sample.risk_level, + "status": getattr(sample, "status", "active"), + "conclusion_text": sample.conclusion_text, + "payload_json": sample.payload_json, + } + try: + client.upsert( + collection_name=FEW_SHOT_COLLECTION, + points=[{"id": vector_id, "vector": vector, "payload": payload}], + ) + return vector_id + except Exception: + logger.warning("few-shot upsert 失败 sample_key=%s", getattr(sample, "sample_key", ""), exc_info=True) + return None + + def search( + self, + case_text: str, + *, + scene: str | None = None, + labels: list[str] | None = None, + top_k: int = 3, + ) -> list[dict[str, Any]]: + """按 case_text 检索相似样本,可按 scene/label 过滤。失败返回空列表。""" + + if not case_text or not self._ensure_collection(): + return [] + client = self._client + try: + vector = self._embedding_provider.embed([case_text])[0] + except Exception: + logger.warning("few-shot 检索 embedding 失败", exc_info=True) + return [] + must: list[dict[str, Any]] = [{"key": "status", "match": {"value": "active"}}] + if scene: + must.append({"key": "scene", "match": {"value": scene}}) + if labels: + must.append({"key": "label", "match": {"any": labels}}) + try: + from qdrant_client.http.models import Filter + + results = client.query_points( + collection_name=FEW_SHOT_COLLECTION, + query=vector, + query_filter=Filter(must=must), + limit=top_k, + with_payload=True, + ).points + except Exception: + logger.warning("few-shot 检索失败", exc_info=True) + return [] + hits: list[dict[str, Any]] = [] + for point in results: + payload = getattr(point, "payload", None) or {} + hits.append( + { + "sample_id": payload.get("sample_id"), + "score": float(getattr(point, "score", 0.0)), + "label": payload.get("label"), + "domain": payload.get("domain"), + "risk_type": payload.get("risk_type"), + "conclusion_text": payload.get("conclusion_text") or "", + "payload_json": payload.get("payload_json") or {}, + } + ) + return hits + + def delete_by_vector_id(self, vector_id: str) -> bool: + """按 vector_id 删除向量,失败返回 False。""" + + if not vector_id or not self._ensure_collection(): + return False + try: + self._client.delete( + collection_name=FEW_SHOT_COLLECTION, + points_selector=[vector_id], + ) + return True + except Exception: + logger.warning("few-shot 删除失败 vector_id=%s", vector_id, exc_info=True) + return False diff --git a/server/src/app/services/risk_observations.py b/server/src/app/services/risk_observations.py index 74d294b..1481a0c 100644 --- a/server/src/app/services/risk_observations.py +++ b/server/src/app/services/risk_observations.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os from datetime import UTC, datetime, timedelta from decimal import Decimal from typing import Any @@ -8,6 +9,7 @@ from sqlalchemy import func, select from sqlalchemy.orm import Session, joinedload from app.algorithem.risk_graph import RiskHistoryStats, RiskObservationDraft +from app.core.logging import get_logger from app.db.base import Base from app.models.financial_record import ExpenseClaim from app.models.risk_observation import RiskObservation, RiskObservationFeedback @@ -17,6 +19,8 @@ from app.schemas.risk_observation import ( ) from app.services.expense_claim_risk_stage import normalize_risk_business_stage +logger = get_logger("app.services.risk_observations") + HIGH_LEVELS = {"high", "critical"} SEVERITY_SCORE = { "low": 32, @@ -322,8 +326,27 @@ class RiskObservationService: observation.status, observation.feedback_status = mapped self.db.commit() self.db.refresh(feedback) + self._maybe_ingest_few_shot(observation, feedback) return feedback + def _maybe_ingest_few_shot( + self, + observation: RiskObservation, + feedback: RiskObservationFeedback, + ) -> None: + """人工确认/误报后把样本沉淀进 few-shot 池,任何失败都不影响主流程。""" + + if os.environ.get("FEW_SHOT_INJECTION_ENABLED", "true").strip().lower() in {"0", "false", "no"}: + return + if observation.feedback_status not in {"confirmed", "false_positive"}: + return + try: + from app.services.few_shot_ingestion import FewShotIngestionService + + FewShotIngestionService(self.db).ingest_observation_feedback(observation, feedback) + except Exception: + logger.exception("few-shot ingestion failed for observation %s", observation.id) + def summarize_dashboard( self, *, diff --git a/server/src/app/services/risk_rule_generation.py b/server/src/app/services/risk_rule_generation.py index 05088d8..c481d26 100644 --- a/server/src/app/services/risk_rule_generation.py +++ b/server/src/app/services/risk_rule_generation.py @@ -234,6 +234,10 @@ class RiskRuleGenerationService: } for item in fields ] + few_shot_samples = self._retrieve_few_shot_samples( + domain=domain, + natural_language=natural_language, + ) messages = build_risk_rule_compiler_messages( domain=domain, domain_label=BUSINESS_DOMAIN_LABELS[domain], @@ -243,6 +247,7 @@ class RiskRuleGenerationService: expense_category_label=expense_category_label, natural_language=natural_language, available_fields=field_payload, + few_shot_samples=few_shot_samples, ) answer = self.runtime_chat_service.complete( messages, @@ -263,6 +268,29 @@ class RiskRuleGenerationService: payload = unwrap_semantic_plan_payload(payload) return self._sanitize_model_draft(payload, fields=fields) + def _retrieve_few_shot_samples( + self, + *, + domain: str, + natural_language: str, + ) -> list[dict[str, Any]]: + """检索已确认历史样本,失败降级为空列表。""" + + import os + + if os.environ.get("FEW_SHOT_INJECTION_ENABLED", "true").strip().lower() in {"0", "false", "no"}: + return [] + try: + from app.services.few_shot_retrieval import FewShotRetriever + + retriever = FewShotRetriever.from_session(self.db) + return retriever.retrieve_for_risk_rule_generation( + domain=domain, + natural_language=natural_language, + ) + except Exception: + return [] + def _sanitize_model_draft( self, payload: dict[str, Any], diff --git a/server/src/app/services/risk_rule_generation_prompt.py b/server/src/app/services/risk_rule_generation_prompt.py index f38b92c..7707eb4 100644 --- a/server/src/app/services/risk_rule_generation_prompt.py +++ b/server/src/app/services/risk_rule_generation_prompt.py @@ -14,10 +14,15 @@ def build_risk_rule_compiler_messages( expense_category_label: str, natural_language: str, available_fields: list[dict[str, Any]], + few_shot_samples: list[dict[str, Any]] | None = None, ) -> list[dict[str, str]]: """构造自然语言规则编译提示词。 大模型只负责把业务语言拆成“语义计划”,后端会校验字段、操作符和模板。 + + ``few_shot_samples`` 是从已确认历史样本中检索出来的相似案例,会被合并进 + ``examples`` 字段并标注 ``source: "historical_confirmed"``,让编译器参考 + 过往人工结论。传 ``None`` 或空列表时行为与历史完全一致(向后兼容)。 """ schema = { @@ -161,6 +166,20 @@ def build_risk_rule_compiler_messages( }, } ] + historical_examples: list[dict[str, Any]] = [] + if few_shot_samples: + for sample in few_shot_samples: + historical_examples.append( + { + "source": "historical_confirmed", + "label": sample.get("label"), + "domain": sample.get("domain") or "", + "risk_type": sample.get("risk_type") or "", + "conclusion": sample.get("conclusion") or "", + "context": sample.get("context") or {}, + } + ) + merged_examples = historical_examples + examples return [ { "role": "system", @@ -186,7 +205,7 @@ def build_risk_rule_compiler_messages( "natural_language": natural_language, "available_fields": available_fields, "required_json_shape": response_schema, - "examples": examples, + "examples": merged_examples, }, ensure_ascii=False, ), diff --git a/server/tests/test_embedding_provider.py b/server/tests/test_embedding_provider.py new file mode 100644 index 0000000..0949a69 --- /dev/null +++ b/server/tests/test_embedding_provider.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +from unittest.mock import patch + +import pytest + +from app.services.embedding_provider import EmbeddingProvider, _runtime_model_config_from_dict +from app.services.knowledge_rag_runtime import KnowledgeRagError, RuntimeModelConfig + + +def _config(provider: str = "GLM") -> RuntimeModelConfig: + return RuntimeModelConfig( + slot="embedding", + provider=provider, + model="Embedding-3", + endpoint="https://open.bigmodel.cn/api/paas/v4/", + api_key="k", + capability="embedding", + ) + + +def test_runtime_model_config_from_dict_maps_fields() -> None: + cfg = _runtime_model_config_from_dict( + { + "slot": "embedding", + "provider": "GLM", + "model": "Embedding-3", + "endpoint": "https://e", + "apiKey": "secret", + "capability": "embedding", + } + ) + assert cfg.api_key == "secret" + assert cfg.model == "Embedding-3" + + +def test_embed_empty_texts_returns_empty() -> None: + provider = EmbeddingProvider(_config()) + assert provider.embed([]) == [] + + +def test_embed_returns_vectors_and_caches_dimension() -> None: + provider = EmbeddingProvider(_config()) + with patch( + "app.services.embedding_provider._request_embeddings_public", + return_value=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + ) as mock_req: + vectors = provider.embed(["a", "b"]) + assert vectors == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + assert provider.dimension() == 3 + calls_after_first_dimension = mock_req.call_count + # 第二次 dimension 不应再次请求 + assert provider.dimension() == 3 + assert mock_req.call_count == calls_after_first_dimension + + +def test_dimension_raises_on_invalid_vectors() -> None: + provider = EmbeddingProvider(_config()) + with patch( + "app.services.embedding_provider._request_embeddings_public", + return_value=[], + ): + with pytest.raises(KnowledgeRagError): + provider.dimension() + + +def test_request_embeddings_public_glm_branch() -> None: + cfg = _config("GLM") + with patch( + "app.services.embedding_provider._send_json_request", + return_value=(200, {"data": [{"embedding": [0.1, 0.2]}]}), + ) as mock_send: + from app.services.embedding_provider import _request_embeddings_public + + vectors = _request_embeddings_public(cfg, ["x"]) + assert vectors == [[0.1, 0.2]] + called_url = mock_send.call_args.args[1] + assert called_url.endswith("/embeddings") + + +def test_request_embeddings_public_ollama_branch() -> None: + cfg = _config("Ollama") + with patch( + "app.services.embedding_provider._send_json_request", + return_value=(200, {"embeddings": [[0.5, 0.6]]}), + ) as mock_send: + from app.services.embedding_provider import _request_embeddings_public + + vectors = _request_embeddings_public(cfg, ["x"]) + assert vectors == [[0.5, 0.6]] + called_url = mock_send.call_args.args[1] + assert called_url.endswith("/api/embed") + + +def test_request_embeddings_public_raises_on_http_error() -> None: + cfg = _config("GLM") + with patch( + "app.services.embedding_provider._send_json_request", + return_value=(500, {"message": "boom"}), + ): + from app.services.embedding_provider import _request_embeddings_public + + with pytest.raises(KnowledgeRagError): + _request_embeddings_public(cfg, ["x"]) diff --git a/server/tests/test_few_shot_ingestion.py b/server/tests/test_few_shot_ingestion.py new file mode 100644 index 0000000..244455c --- /dev/null +++ b/server/tests/test_few_shot_ingestion.py @@ -0,0 +1,214 @@ +from __future__ import annotations + +from collections.abc import Generator +from datetime import datetime +from decimal import Decimal +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.pool import StaticPool + +from app.db.base import Base +from app.models.employee import Employee +from app.models.few_shot_sample import FewShotSample +from app.models.financial_record import ExpenseClaim +from app.models.risk_observation import RiskObservation +from app.schemas.risk_observation import RiskObservationFeedbackCreate +from app.services.few_shot_ingestion import FewShotIngestionService +from app.services.risk_observations import RiskObservationService + + +def _build_session() -> Session: + engine = create_engine( + "sqlite+pysqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + Base.metadata.create_all(bind=engine) + factory = sessionmaker(bind=engine, autoflush=False, autocommit=False) + return factory() + + +def _observation(db: Session, key: str = "risk:c1:dup") -> RiskObservation: + db.add(Employee(id="emp-1", employee_no="E1", name="员工", email="e@e.com", grade="P6")) + db.add( + ExpenseClaim( + id="c1", + claim_no="BX-001", + employee_id="emp-1", + employee_name="员工", + department_name="风控部", + expense_type="travel", + reason="客户拜访", + location="上海", + amount=Decimal("1000"), + currency="CNY", + occurred_at=datetime(2026, 1, 1), + submitted_at=datetime(2026, 1, 1), + status="submitted", + approval_stage="manager_review", + risk_flags_json=[], + ) + ) + db.flush() + obs = RiskObservation( + observation_key=key, + subject_type="expense_claim", + subject_key="claim:c1", + claim_id="c1", + claim_no="BX-001", + risk_type="duplicate_invoice", + risk_signal="duplicate_invoice", + title="重复发票", + description="同一发票出现在多张报销单", + risk_score=86, + risk_level="high", + confidence_score=0.8, + source="financial_risk_graph", + algorithm_version="v1", + ontology_json={"domain": "expense", "scenario": "reimbursement"}, + ) + db.add(obs) + db.commit() + db.refresh(obs) + return obs + + +def test_ingest_confirmed_persists_sample_and_calls_store() -> None: + with _build_session() as db: + obs = _observation(db) + obs.feedback_status = "confirmed" + service = FewShotIngestionService(db) + fake_store = MagicMock() + fake_store.upsert.return_value = "vec-1" + with patch.object(service, "_store", return_value=fake_store): + sample = service.ingest_observation_feedback( + obs, + MagicMock(feedback_type="confirm", comment="确认重复发票", actor="audit"), + ) + assert sample is not None + assert sample.label == "confirmed" + assert sample.sample_key == f"obs:{obs.id}" + assert "重复发票" in sample.case_text + assert "确认重复发票" in sample.conclusion_text + assert sample.vector_id == "vec-1" + fake_store.upsert.assert_called_once() + + +def test_ingest_false_positive_also_persisted() -> None: + with _build_session() as db: + obs = _observation(db, key="risk:c2:fp") + obs.feedback_status = "false_positive" + db.commit() + service = FewShotIngestionService(db) + with patch.object(service, "_store", return_value=MagicMock(upsert=MagicMock(return_value=None))): + sample = service.ingest_observation_feedback( + obs, + MagicMock(feedback_type="false_positive", comment="", actor="audit"), + ) + assert sample is not None + assert sample.label == "false_positive" + assert "误报" in sample.conclusion_text + + +def test_ingest_ignored_label_returns_none() -> None: + with _build_session() as db: + obs = _observation(db) + obs.feedback_status = "ignored" + service = FewShotIngestionService(db) + assert service.ingest_observation_feedback(obs, MagicMock()) is None + + +def test_ingest_is_idempotent_on_duplicate_sample_key() -> None: + with _build_session() as db: + obs = _observation(db) + service = FewShotIngestionService(db) + store = MagicMock() + store.upsert.side_effect = ["vec-1", "vec-2"] + with patch.object(service, "_store", return_value=store): + obs.feedback_status = "confirmed" + first = service.ingest_observation_feedback( + obs, MagicMock(feedback_type="confirm", comment="第一次", actor="a") + ) + # 模拟后续被改判为误报 + obs.feedback_status = "false_positive" + second = service.ingest_observation_feedback( + obs, MagicMock(feedback_type="false_positive", comment="改判", actor="a") + ) + assert first is not None and second is not None + assert first.id == second.id # 同一行更新 + from sqlalchemy import select + + count = db.scalar(select(FewShotSample).where(FewShotSample.sample_key == f"obs:{obs.id}")) + assert count is not None + assert second.label == "false_positive" + + +def test_create_feedback_hook_triggers_ingestion() -> None: + with _build_session() as db: + service = RiskObservationService(db) + obs = _observation(db) + ingest_calls: list = [] + + def _spy_ingest(o, f): + ingest_calls.append((o.id, f.feedback_type)) + return None + + with patch( + "app.services.few_shot_ingestion.FewShotIngestionService.ingest_observation_feedback", + side_effect=_spy_ingest, + ): + service.create_feedback( + obs.observation_key, + RiskObservationFeedbackCreate(feedback_type="confirm", actor="audit"), + ) + assert len(ingest_calls) == 1 + assert ingest_calls[0][1] == "confirm" + + +def test_create_feedback_hook_skipped_for_comment_feedback() -> None: + with _build_session() as db: + service = RiskObservationService(db) + obs = _observation(db) + with patch( + "app.services.few_shot_ingestion.FewShotIngestionService.ingest_observation_feedback" + ) as mock_ingest: + service.create_feedback( + obs.observation_key, + RiskObservationFeedbackCreate(feedback_type="comment", action="note", actor="audit"), + ) + mock_ingest.assert_not_called() + + +def test_create_feedback_hook_swallows_ingestion_failure() -> None: + with _build_session() as db: + service = RiskObservationService(db) + obs = _observation(db) + with patch( + "app.services.few_shot_ingestion.FewShotIngestionService.ingest_observation_feedback", + side_effect=RuntimeError("boom"), + ): + # 不应抛异常 + feedback = service.create_feedback( + obs.observation_key, + RiskObservationFeedbackCreate(feedback_type="confirm", actor="audit"), + ) + assert feedback.feedback_type == "confirm" + + +def test_create_feedback_hook_respects_feature_flag(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("FEW_SHOT_INJECTION_ENABLED", "false") + with _build_session() as db: + service = RiskObservationService(db) + obs = _observation(db) + with patch( + "app.services.few_shot_ingestion.FewShotIngestionService.ingest_observation_feedback" + ) as mock_ingest: + service.create_feedback( + obs.observation_key, + RiskObservationFeedbackCreate(feedback_type="confirm", actor="audit"), + ) + mock_ingest.assert_not_called() diff --git a/server/tests/test_few_shot_retrieval_and_prompt.py b/server/tests/test_few_shot_retrieval_and_prompt.py new file mode 100644 index 0000000..df99f54 --- /dev/null +++ b/server/tests/test_few_shot_retrieval_and_prompt.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from app.services.few_shot_retrieval import FewShotRetriever +from app.services.few_shot_store import FewShotStore +from app.services.risk_rule_generation_prompt import build_risk_rule_compiler_messages + + +def _hit(score: float, label: str, conclusion: str, risk_type: str = "duplicate_invoice") -> dict: + return { + "sample_id": "s1", + "score": score, + "label": label, + "domain": "expense", + "risk_type": risk_type, + "conclusion_text": conclusion, + "payload_json": { + "risk_signal": risk_type, + "risk_level": "high", + "ontology": {"scenario": "reimbursement"}, + "feedback_comment": "", + }, + } + + +def test_retrieve_returns_injection_blocks_with_token_budget() -> None: + store = MagicMock(spec=FewShotStore) + store.search.return_value = [ + _hit(0.9, "confirmed", "确认重复发票需拦截"), + _hit(0.8, "false_positive", "此情形属于正常拆单不拦截"), + _hit(0.7, "confirmed", "确认重复发票需拦截"), # 重复结论应被去重 + ] + retriever = FewShotRetriever(store) + blocks = retriever.retrieve_for_risk_rule_generation( + domain="expense", natural_language="同一发票重复报销" + ) + assert len(blocks) == 2 + assert blocks[0]["score"] == 0.9 + assert blocks[0]["label"] == "confirmed" + assert blocks[0]["source"] == "historical_confirmed" + assert blocks[1]["label"] == "false_positive" + # 去重:第三条结论与第一条相同,应被过滤 + conclusions = [b["conclusion"] for b in blocks] + assert len(set(conclusions)) == len(conclusions) + + +def test_retrieve_empty_case_text_returns_empty() -> None: + store = MagicMock(spec=FewShotStore) + retriever = FewShotRetriever(store) + assert retriever.retrieve_for_risk_rule_generation(natural_language="") == [] + store.search.assert_not_called() + + +def test_retrieve_truncates_overlong_conclusion() -> None: + store = MagicMock(spec=FewShotStore) + long_text = "长结论" * 500 + store.search.return_value = [ + _hit(0.9, "confirmed", long_text), + ] + retriever = FewShotRetriever(store) + blocks = retriever.retrieve_for_risk_rule_generation(natural_language="x") + assert len(blocks) == 1 + # 超长结论应被截断到单条上限 + from app.services.few_shot_retrieval import SINGLE_SAMPLE_MAX_CHARS + + assert len(blocks[0]["conclusion"]) <= SINGLE_SAMPLE_MAX_CHARS + + +def test_build_prompt_merges_few_shot_into_examples() -> None: + samples = [ + { + "source": "historical_confirmed", + "label": "confirmed", + "domain": "expense", + "risk_type": "duplicate_invoice", + "conclusion": "确认重复发票", + "context": {"risk_signal": "duplicate_invoice"}, + } + ] + messages = build_risk_rule_compiler_messages( + domain="expense", + domain_label="报销", + business_stage="reimbursement", + business_stage_label="报销", + expense_category=None, + expense_category_label="", + natural_language="重复发票规则", + available_fields=[{"key": "attachment.invoice_no", "label": "发票号", "type": "string", "source": "attachment"}], + few_shot_samples=samples, + ) + assert len(messages) == 2 + payload = json.loads(messages[1]["content"]) + examples = payload["examples"] + # 前两条是历史样本,后面是内置 examples + assert examples[0]["source"] == "historical_confirmed" + assert examples[0]["conclusion"] == "确认重复发票" + # 内置 example 仍存在(无 source 字段) + assert any("user_rule" in ex for ex in examples) + + +def test_build_prompt_without_few_shot_is_backward_compatible() -> None: + messages = build_risk_rule_compiler_messages( + domain="expense", + domain_label="报销", + business_stage="reimbursement", + business_stage_label="报销", + expense_category=None, + expense_category_label="", + natural_language="重复发票规则", + available_fields=[], + ) + payload = json.loads(messages[1]["content"]) + examples = payload["examples"] + # 无 few_shot_samples 时 examples 里不应有 historical_confirmed 来源 + assert all(ex.get("source") != "historical_confirmed" for ex in examples)