feat(flywheel): few-shot 在线检索注入打通风险规则编译链路

- 新增 FewShotStore:独立 Qdrant collection few_shot_samples,向量 upsert/search/delete,
  全程失败降级不阻塞主链路
- 新增 FewShotIngestionService:RiskObservation confirmed/false_positive → FewShotSample +
  向量,带 sample_key 幂等去重
- 新增 FewShotRetriever:按 case 特征检索相似历史样本,去重 + token 预算 + 单条字符上限裁剪
- risk_observations.create_feedback commit 后挂 hook 自动入库,带 feature flag 和 try/except 兜底
- risk_rule_generation_prompt 新增 few_shot_samples 可选参数,合并进 examples 并标
  source=historical_confirmed;risk_rule_generation 构造 prompt 前调 retriever,失败降级为空
This commit is contained in:
caoxiaozhu
2026-07-03 13:55:52 +08:00
parent 765cfb40f3
commit 3a9d154783
6 changed files with 584 additions and 1 deletions

View File

@@ -0,0 +1,122 @@
"""Few-shot 检索器:按当前 case 特征检索相似历史样本,拼成注入块。
从 :class:`FewShotStore` 取相似样本,转成可供 prompt 构造函数直接使用的结构。
带 token 预算裁剪和去重,确保不撑爆 prompt。
典型用法(在构造 prompt 之前调用)::
retriever = FewShotRetriever.from_session(session)
samples = retriever.retrieve_for_risk_rule_generation(
domain="travel", natural_language="票据城市与申报地不一致"
)
messages = build_risk_rule_compiler_messages(
...,
few_shot_samples=samples,
)
"""
from __future__ import annotations
from typing import Any
from sqlalchemy.orm import Session
from app.core.logging import get_logger
from app.services.embedding_provider import EmbeddingProvider
from app.services.few_shot_store import FewShotStore
logger = get_logger("app.services.few_shot_retrieval")
# 单条 few-shot 样本估算 token 数(用于预算裁剪)
SAMPLE_TOKEN_BUDGET = 1200
# 单条样本最大字符数,超长直接截断结论,避免撑爆 prompt
SINGLE_SAMPLE_MAX_CHARS = 400
# 历史样本最多注入条数(与原内置 examples 合并后总量受限)
MAX_HISTORICAL_SAMPLES = 3
class FewShotRetriever:
"""按 case 特征检索已确认样本,返回 prompt 可直接消费的结构。"""
def __init__(self, store: FewShotStore) -> None:
self._store = store
@classmethod
def from_session(cls, session: Session) -> "FewShotRetriever":
provider = EmbeddingProvider.from_settings(session)
return cls(FewShotStore(provider))
def retrieve_for_risk_rule_generation(
self,
*,
domain: str = "",
risk_type: str = "",
natural_language: str,
top_k: int = MAX_HISTORICAL_SAMPLES,
) -> list[dict[str, Any]]:
"""检索与当前规则需求相似的历史样本,返回注入块列表。"""
case_text = self._build_case_text(
natural_language=natural_language,
domain=domain,
risk_type=risk_type,
)
if not case_text:
return []
hits = self._store.search(
case_text,
scene="risk_rule_generation",
labels=["confirmed", "false_positive"],
top_k=top_k,
)
return self._hits_to_injection_blocks(hits)
def _build_case_text(
self,
*,
natural_language: str,
domain: str = "",
risk_type: str = "",
) -> str:
parts = [natural_language, domain, risk_type]
return "\n".join(p for p in parts if p).strip()
def _hits_to_injection_blocks(
self,
hits: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""把检索命中转成 prompt 可消费的块,做去重和预算裁剪。"""
blocks: list[dict[str, Any]] = []
seen_conclusions: set[str] = set()
budget = SAMPLE_TOKEN_BUDGET
for hit in hits:
conclusion = (hit.get("conclusion_text") or "").strip()
if not conclusion or conclusion in seen_conclusions:
continue
# 超长结论截断到上限,避免单条样本占用过多预算
if len(conclusion) > SINGLE_SAMPLE_MAX_CHARS:
conclusion = conclusion[:SINGLE_SAMPLE_MAX_CHARS]
payload = hit.get("payload_json") or {}
block = {
"source": "historical_confirmed",
"label": hit.get("label"),
"domain": hit.get("domain") or "",
"risk_type": hit.get("risk_type") or "",
"score": round(float(hit.get("score") or 0.0), 4),
"conclusion": conclusion,
"context": {
"risk_signal": payload.get("risk_signal") or "",
"risk_level": payload.get("risk_level") or "",
"ontology": payload.get("ontology") or {},
"feedback_comment": payload.get("feedback_comment") or "",
},
}
# 粗略 token 估算(按字符数 / 1.6 近似中文 token 比)
estimated_tokens = int(len(conclusion) / 1.6) + 40
if estimated_tokens > budget:
break
budget -= estimated_tokens
blocks.append(block)
seen_conclusions.add(conclusion)
return blocks