123 lines
4.3 KiB
Python
123 lines
4.3 KiB
Python
|
|
"""Few-shot retriever: find historical samples similar to the current case.

Fetches similar samples from :class:`FewShotStore` and converts them into a
structure that prompt-building functions can use directly.
Applies token-budget trimming and de-duplication so the prompt is not blown up.

Typical usage (call before building the prompt)::

    retriever = FewShotRetriever.from_session(session)
    samples = retriever.retrieve_for_risk_rule_generation(
        domain="travel", natural_language="票据城市与申报地不一致"
    )
    messages = build_risk_rule_compiler_messages(
        ...,
        few_shot_samples=samples,
    )
"""
|
|||
|
|
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
from typing import Any
|
|||
|
|
|
|||
|
|
from sqlalchemy.orm import Session
|
|||
|
|
|
|||
|
|
from app.core.logging import get_logger
|
|||
|
|
from app.services.embedding_provider import EmbeddingProvider
|
|||
|
|
from app.services.few_shot_store import FewShotStore
|
|||
|
|
|
|||
|
|
# Module-level logger, registered under the name "app.services.few_shot_retrieval".
logger = get_logger("app.services.few_shot_retrieval")
|
|||
|
|
|
|||
|
|
# Total estimated-token budget shared by all injected few-shot samples; the
# estimate of each accepted sample is subtracted from it during trimming.
SAMPLE_TOKEN_BUDGET: int = 1_200

# Maximum number of characters kept from one sample's conclusion. Longer text
# is truncated so a single sample cannot blow up the prompt.
SINGLE_SAMPLE_MAX_CHARS: int = 400

# Upper bound on the number of historical samples injected (the total after
# merging with the built-in examples is limited).
MAX_HISTORICAL_SAMPLES: int = 3
|
|||
|
|
|
|||
|
|
|
|||
|
|
class FewShotRetriever:
    """Retrieve confirmed historical samples similar to the current case.

    Wraps a :class:`FewShotStore` and converts its search hits into plain
    dictionaries that prompt builders can consume directly.
    """

    def __init__(self, store: FewShotStore) -> None:
        """Keep a reference to the store used for similarity search."""
        self._store = store

    @classmethod
    def from_session(cls, session: Session) -> "FewShotRetriever":
        """Build a retriever from a database session.

        Args:
            session: Session handed to ``EmbeddingProvider.from_settings``;
                the resulting provider backs a new :class:`FewShotStore`.

        Returns:
            A retriever wrapping that store.
        """
        provider = EmbeddingProvider.from_settings(session)
        return cls(FewShotStore(provider))

    def retrieve_for_risk_rule_generation(
        self,
        *,
        domain: str = "",
        risk_type: str = "",
        natural_language: str,
        top_k: int = MAX_HISTORICAL_SAMPLES,
    ) -> list[dict[str, Any]]:
        """Return injection blocks for samples similar to a rule request.

        Args:
            domain: Optional domain text appended to the search query.
            risk_type: Optional risk type text appended to the search query.
            natural_language: Natural-language description of the rule.
            top_k: Number of hits requested from the store.

        Returns:
            A list of injection blocks (see ``_hits_to_injection_blocks``),
            or an empty list when the combined query text is empty.
        """
        case_text = self._build_case_text(
            natural_language=natural_language,
            domain=domain,
            risk_type=risk_type,
        )
        if not case_text:
            # Nothing to search with; skip the store call entirely.
            return []
        hits = self._store.search(
            case_text,
            scene="risk_rule_generation",
            labels=["confirmed", "false_positive"],
            top_k=top_k,
        )
        return self._hits_to_injection_blocks(hits)

    def _build_case_text(
        self,
        *,
        natural_language: str,
        domain: str = "",
        risk_type: str = "",
    ) -> str:
        """Join the non-empty query parts with newlines and strip the result."""
        parts = [natural_language, domain, risk_type]
        return "\n".join(p for p in parts if p).strip()

    def _hits_to_injection_blocks(
        self,
        hits: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Convert search hits into prompt-ready blocks.

        Hits with an empty conclusion are skipped, conclusions are truncated
        to ``SINGLE_SAMPLE_MAX_CHARS``, duplicates are dropped, and iteration
        stops at the first sample whose token estimate no longer fits in the
        remaining ``SAMPLE_TOKEN_BUDGET``.

        Args:
            hits: Search hits as dictionaries; the keys read here are
                ``conclusion_text``, ``payload_json``, ``label``, ``domain``,
                ``risk_type`` and ``score``.

        Returns:
            The accepted blocks, in the order the hits were given.
        """
        blocks: list[dict[str, Any]] = []
        seen_conclusions: set[str] = set()
        budget = SAMPLE_TOKEN_BUDGET

        for hit in hits:
            # Truncate BEFORE the duplicate check so the text being compared
            # is exactly the text that gets injected. Checking the full text
            # while remembering the truncated one would let every over-long
            # conclusion slip past its own duplicate.
            conclusion = (hit.get("conclusion_text") or "").strip()
            conclusion = conclusion[:SINGLE_SAMPLE_MAX_CHARS]
            if not conclusion or conclusion in seen_conclusions:
                continue

            # Rough estimate: characters / 1.6 approximates the Chinese
            # character-to-token ratio, plus a fixed 40-token allowance
            # (presumably for the block's metadata fields - TODO confirm).
            estimated_tokens = int(len(conclusion) / 1.6) + 40
            if estimated_tokens > budget:
                break

            payload = hit.get("payload_json") or {}
            if not isinstance(payload, dict):
                # Defensive: a non-mapping payload (e.g. an undecoded JSON
                # string) carries no usable context; do not crash on it.
                payload = {}

            blocks.append(
                {
                    # NOTE: "source" is constant; consumers that need to tell
                    # confirmed hits from false positives must read "label".
                    "source": "historical_confirmed",
                    "label": hit.get("label"),
                    "domain": hit.get("domain") or "",
                    "risk_type": hit.get("risk_type") or "",
                    "score": round(float(hit.get("score") or 0.0), 4),
                    "conclusion": conclusion,
                    "context": {
                        "risk_signal": payload.get("risk_signal") or "",
                        "risk_level": payload.get("risk_level") or "",
                        "ontology": payload.get("ontology") or {},
                        "feedback_comment": payload.get("feedback_comment") or "",
                    },
                }
            )
            budget -= estimated_tokens
            seen_conclusions.add(conclusion)

        return blocks
|