"""Few-shot 检索器:按当前 case 特征检索相似历史样本,拼成注入块。 从 :class:`FewShotStore` 取相似样本,转成可供 prompt 构造函数直接使用的结构。 带 token 预算裁剪和去重,确保不撑爆 prompt。 典型用法(在构造 prompt 之前调用):: retriever = FewShotRetriever.from_session(session) samples = retriever.retrieve_for_risk_rule_generation( domain="travel", natural_language="票据城市与申报地不一致" ) messages = build_risk_rule_compiler_messages( ..., few_shot_samples=samples, ) """ from __future__ import annotations from typing import Any from sqlalchemy.orm import Session from app.core.logging import get_logger from app.services.embedding_provider import EmbeddingProvider from app.services.few_shot_store import FewShotStore logger = get_logger("app.services.few_shot_retrieval") # 单条 few-shot 样本估算 token 数(用于预算裁剪) SAMPLE_TOKEN_BUDGET = 1200 # 单条样本最大字符数,超长直接截断结论,避免撑爆 prompt SINGLE_SAMPLE_MAX_CHARS = 400 # 历史样本最多注入条数(与原内置 examples 合并后总量受限) MAX_HISTORICAL_SAMPLES = 3 class FewShotRetriever: """按 case 特征检索已确认样本,返回 prompt 可直接消费的结构。""" def __init__(self, store: FewShotStore) -> None: self._store = store @classmethod def from_session(cls, session: Session) -> "FewShotRetriever": provider = EmbeddingProvider.from_settings(session) return cls(FewShotStore(provider)) def retrieve_for_risk_rule_generation( self, *, domain: str = "", risk_type: str = "", natural_language: str, top_k: int = MAX_HISTORICAL_SAMPLES, ) -> list[dict[str, Any]]: """检索与当前规则需求相似的历史样本,返回注入块列表。""" case_text = self._build_case_text( natural_language=natural_language, domain=domain, risk_type=risk_type, ) if not case_text: return [] hits = self._store.search( case_text, scene="risk_rule_generation", labels=["confirmed", "false_positive"], top_k=top_k, ) return self._hits_to_injection_blocks(hits) def _build_case_text( self, *, natural_language: str, domain: str = "", risk_type: str = "", ) -> str: parts = [natural_language, domain, risk_type] return "\n".join(p for p in parts if p).strip() def _hits_to_injection_blocks( self, hits: list[dict[str, Any]], ) -> list[dict[str, Any]]: """把检索命中转成 prompt 可消费的块,做去重和预算裁剪。""" blocks: list[dict[str, Any]] = [] seen_conclusions: set[str] = set() budget = SAMPLE_TOKEN_BUDGET for hit in hits: conclusion = (hit.get("conclusion_text") or "").strip() if not conclusion or conclusion in seen_conclusions: continue # 超长结论截断到上限,避免单条样本占用过多预算 if len(conclusion) > SINGLE_SAMPLE_MAX_CHARS: conclusion = conclusion[:SINGLE_SAMPLE_MAX_CHARS] payload = hit.get("payload_json") or {} block = { "source": "historical_confirmed", "label": hit.get("label"), "domain": hit.get("domain") or "", "risk_type": hit.get("risk_type") or "", "score": round(float(hit.get("score") or 0.0), 4), "conclusion": conclusion, "context": { "risk_signal": payload.get("risk_signal") or "", "risk_level": payload.get("risk_level") or "", "ontology": payload.get("ontology") or {}, "feedback_comment": payload.get("feedback_comment") or "", }, } # 粗略 token 估算(按字符数 / 1.6 近似中文 token 比) estimated_tokens = int(len(conclusion) / 1.6) + 40 if estimated_tokens > budget: break budget -= estimated_tokens blocks.append(block) seen_conclusions.add(conclusion) return blocks