feat(flywheel): few-shot 在线检索注入打通风险规则编译链路
- 新增 FewShotStore:独立 Qdrant collection few_shot_samples,向量 upsert/search/delete, 全程失败降级不阻塞主链路 - 新增 FewShotIngestionService:RiskObservation confirmed/false_positive → FewShotSample + 向量,带 sample_key 幂等去重 - 新增 FewShotRetriever:按 case 特征检索相似历史样本,去重 + token 预算 + 单条字符上限裁剪 - risk_observations.create_feedback commit 后挂 hook 自动入库,带 feature flag 和 try/except 兜底 - risk_rule_generation_prompt 新增 few_shot_samples 可选参数,合并进 examples 并标 source=historical_confirmed;risk_rule_generation 构造 prompt 前调 retriever,失败降级为空
This commit is contained in:
122
server/src/app/services/few_shot_retrieval.py
Normal file
122
server/src/app/services/few_shot_retrieval.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Few-shot 检索器:按当前 case 特征检索相似历史样本,拼成注入块。
|
||||
|
||||
从 :class:`FewShotStore` 取相似样本,转成可供 prompt 构造函数直接使用的结构。
|
||||
带 token 预算裁剪和去重,确保不撑爆 prompt。
|
||||
|
||||
典型用法(在构造 prompt 之前调用)::
|
||||
|
||||
retriever = FewShotRetriever.from_session(session)
|
||||
samples = retriever.retrieve_for_risk_rule_generation(
|
||||
domain="travel", natural_language="票据城市与申报地不一致"
|
||||
)
|
||||
messages = build_risk_rule_compiler_messages(
|
||||
...,
|
||||
few_shot_samples=samples,
|
||||
)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.services.embedding_provider import EmbeddingProvider
|
||||
from app.services.few_shot_store import FewShotStore
|
||||
|
||||
logger = get_logger("app.services.few_shot_retrieval")
|
||||
|
||||
# 单条 few-shot 样本估算 token 数(用于预算裁剪)
|
||||
SAMPLE_TOKEN_BUDGET = 1200
|
||||
# 单条样本最大字符数,超长直接截断结论,避免撑爆 prompt
|
||||
SINGLE_SAMPLE_MAX_CHARS = 400
|
||||
# 历史样本最多注入条数(与原内置 examples 合并后总量受限)
|
||||
MAX_HISTORICAL_SAMPLES = 3
|
||||
|
||||
|
||||
class FewShotRetriever:
|
||||
"""按 case 特征检索已确认样本,返回 prompt 可直接消费的结构。"""
|
||||
|
||||
def __init__(self, store: FewShotStore) -> None:
|
||||
self._store = store
|
||||
|
||||
@classmethod
|
||||
def from_session(cls, session: Session) -> "FewShotRetriever":
|
||||
provider = EmbeddingProvider.from_settings(session)
|
||||
return cls(FewShotStore(provider))
|
||||
|
||||
def retrieve_for_risk_rule_generation(
|
||||
self,
|
||||
*,
|
||||
domain: str = "",
|
||||
risk_type: str = "",
|
||||
natural_language: str,
|
||||
top_k: int = MAX_HISTORICAL_SAMPLES,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""检索与当前规则需求相似的历史样本,返回注入块列表。"""
|
||||
|
||||
case_text = self._build_case_text(
|
||||
natural_language=natural_language,
|
||||
domain=domain,
|
||||
risk_type=risk_type,
|
||||
)
|
||||
if not case_text:
|
||||
return []
|
||||
hits = self._store.search(
|
||||
case_text,
|
||||
scene="risk_rule_generation",
|
||||
labels=["confirmed", "false_positive"],
|
||||
top_k=top_k,
|
||||
)
|
||||
return self._hits_to_injection_blocks(hits)
|
||||
|
||||
def _build_case_text(
|
||||
self,
|
||||
*,
|
||||
natural_language: str,
|
||||
domain: str = "",
|
||||
risk_type: str = "",
|
||||
) -> str:
|
||||
parts = [natural_language, domain, risk_type]
|
||||
return "\n".join(p for p in parts if p).strip()
|
||||
|
||||
def _hits_to_injection_blocks(
|
||||
self,
|
||||
hits: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""把检索命中转成 prompt 可消费的块,做去重和预算裁剪。"""
|
||||
|
||||
blocks: list[dict[str, Any]] = []
|
||||
seen_conclusions: set[str] = set()
|
||||
budget = SAMPLE_TOKEN_BUDGET
|
||||
for hit in hits:
|
||||
conclusion = (hit.get("conclusion_text") or "").strip()
|
||||
if not conclusion or conclusion in seen_conclusions:
|
||||
continue
|
||||
# 超长结论截断到上限,避免单条样本占用过多预算
|
||||
if len(conclusion) > SINGLE_SAMPLE_MAX_CHARS:
|
||||
conclusion = conclusion[:SINGLE_SAMPLE_MAX_CHARS]
|
||||
payload = hit.get("payload_json") or {}
|
||||
block = {
|
||||
"source": "historical_confirmed",
|
||||
"label": hit.get("label"),
|
||||
"domain": hit.get("domain") or "",
|
||||
"risk_type": hit.get("risk_type") or "",
|
||||
"score": round(float(hit.get("score") or 0.0), 4),
|
||||
"conclusion": conclusion,
|
||||
"context": {
|
||||
"risk_signal": payload.get("risk_signal") or "",
|
||||
"risk_level": payload.get("risk_level") or "",
|
||||
"ontology": payload.get("ontology") or {},
|
||||
"feedback_comment": payload.get("feedback_comment") or "",
|
||||
},
|
||||
}
|
||||
# 粗略 token 估算(按字符数 / 1.6 近似中文 token 比)
|
||||
estimated_tokens = int(len(conclusion) / 1.6) + 40
|
||||
if estimated_tokens > budget:
|
||||
break
|
||||
budget -= estimated_tokens
|
||||
blocks.append(block)
|
||||
seen_conclusions.add(conclusion)
|
||||
return blocks
|
||||
Reference in New Issue
Block a user