Files
X-Financial/server/src/app/services/few_shot_retrieval.py

123 lines
4.3 KiB
Python
Raw Normal View History

"""Few-shot 检索器:按当前 case 特征检索相似历史样本,拼成注入块。
:class:`FewShotStore` 取相似样本转成可供 prompt 构造函数直接使用的结构
token 预算裁剪和去重确保不撑爆 prompt
典型用法在构造 prompt 之前调用::
retriever = FewShotRetriever.from_session(session)
samples = retriever.retrieve_for_risk_rule_generation(
domain="travel", natural_language="票据城市与申报地不一致"
)
messages = build_risk_rule_compiler_messages(
...,
few_shot_samples=samples,
)
"""
from __future__ import annotations
from typing import Any
from sqlalchemy.orm import Session
from app.core.logging import get_logger
from app.services.embedding_provider import EmbeddingProvider
from app.services.few_shot_store import FewShotStore
# Module logger and tuning constants for few-shot retrieval.
logger = get_logger("app.services.few_shot_retrieval")
# Total (estimated) token budget shared by ALL injected historical samples: the
# retriever starts from this value and subtracts each accepted sample's estimated
# cost. (Despite the name, this is not a per-sample figure.)
SAMPLE_TOKEN_BUDGET = 1200
# Maximum number of characters kept from a sample's conclusion text; longer
# conclusions are truncated so a single sample cannot consume too much of the budget.
SINGLE_SAMPLE_MAX_CHARS = 400
# Default cap on the number of historical samples requested from the store.
# NOTE(review): per the original note, the combined total with the built-in
# examples is also limited — presumably by the prompt builder; confirm.
MAX_HISTORICAL_SAMPLES = 3
class FewShotRetriever:
    """Retrieve confirmed historical samples for a case, shaped for prompt injection.

    Wraps a :class:`FewShotStore`. Results are deduplicated by conclusion text and
    trimmed to a rough token budget (``SAMPLE_TOKEN_BUDGET``) before being returned.
    """

    def __init__(self, store: FewShotStore) -> None:
        """Wrap an existing store.

        Args:
            store: Store used to search historical few-shot samples.
        """
        self._store = store

    @classmethod
    def from_session(cls, session: Session) -> "FewShotRetriever":
        """Build a retriever whose store uses the embedding provider configured for *session*."""
        provider = EmbeddingProvider.from_settings(session)
        return cls(FewShotStore(provider))

    def retrieve_for_risk_rule_generation(
        self,
        *,
        domain: str = "",
        risk_type: str = "",
        natural_language: str,
        top_k: int = MAX_HISTORICAL_SAMPLES,
    ) -> list[dict[str, Any]]:
        """Return injection blocks for historical samples similar to a rule request.

        Only samples labelled ``"confirmed"`` or ``"false_positive"`` in the
        ``"risk_rule_generation"`` scene are searched.

        Args:
            domain: Optional business domain (e.g. ``"travel"``), appended to the query.
            risk_type: Optional risk type, appended to the query.
            natural_language: The rule requirement in natural language.
            top_k: Maximum number of hits requested from the store.

        Returns:
            Prompt-ready blocks (deduplicated and budget-trimmed); an empty list
            when the combined query text is blank.
        """
        case_text = self._build_case_text(
            natural_language=natural_language,
            domain=domain,
            risk_type=risk_type,
        )
        if not case_text:
            # Nothing meaningful to search for; skip the store round-trip.
            return []
        hits = self._store.search(
            case_text,
            scene="risk_rule_generation",
            labels=["confirmed", "false_positive"],
            top_k=top_k,
        )
        return self._hits_to_injection_blocks(hits)

    def _build_case_text(
        self,
        *,
        natural_language: str,
        domain: str = "",
        risk_type: str = "",
    ) -> str:
        """Join the non-blank query parts with newlines (order: text, domain, risk type)."""
        parts = (natural_language, domain, risk_type)
        # Strip each part individually so whitespace-only values don't add noise lines.
        cleaned = ((part or "").strip() for part in parts)
        return "\n".join(part for part in cleaned if part)

    def _hits_to_injection_blocks(
        self,
        hits: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Convert store hits into prompt-ready blocks, deduplicating and trimming to budget.

        Hits are consumed in the order given. Each hit's conclusion is stripped and
        truncated to ``SINGLE_SAMPLE_MAX_CHARS``; empty or already-seen conclusions
        are skipped. Processing stops at the first hit whose estimated cost exceeds
        the remaining ``SAMPLE_TOKEN_BUDGET``.

        Args:
            hits: Search results; each is read via ``conclusion_text``, ``label``,
                ``domain``, ``risk_type``, ``score`` and ``payload_json`` keys.

        Returns:
            One dict per accepted hit with ``source``, ``label``, ``domain``,
            ``risk_type``, ``score`` (rounded to 4 places), ``conclusion`` and
            ``context`` keys.
        """
        blocks: list[dict[str, Any]] = []
        seen_conclusions: set[str] = set()
        budget = SAMPLE_TOKEN_BUDGET
        for hit in hits:
            # Truncate BEFORE the duplicate check so the lookup key matches what is
            # stored in ``seen_conclusions``; checking the untruncated text let
            # repeated over-long conclusions slip past deduplication.
            conclusion = (hit.get("conclusion_text") or "").strip()[:SINGLE_SAMPLE_MAX_CHARS]
            if not conclusion or conclusion in seen_conclusions:
                continue
            # Rough token estimate: ~1.6 characters per token for Chinese text, plus a
            # fixed allowance for the surrounding metadata.
            # NOTE(review): context fields (feedback_comment, ontology, ...) are not
            # counted and are not truncated; long values can exceed the budget.
            estimated_tokens = int(len(conclusion) / 1.6) + 40
            if estimated_tokens > budget:
                # Stop instead of skipping so the result is always a prefix of the hit
                # order (presumably ranked by similarity — confirm in FewShotStore).
                break
            # Assumes payload_json is a mapping (e.g. a JSON column) — confirm.
            payload = hit.get("payload_json") or {}
            blocks.append(
                {
                    # NOTE(review): tagged "historical_confirmed" even for hits
                    # labelled "false_positive"; the real label is in "label".
                    "source": "historical_confirmed",
                    "label": hit.get("label"),
                    "domain": hit.get("domain") or "",
                    "risk_type": hit.get("risk_type") or "",
                    "score": round(float(hit.get("score") or 0.0), 4),
                    "conclusion": conclusion,
                    "context": {
                        "risk_signal": payload.get("risk_signal") or "",
                        "risk_level": payload.get("risk_level") or "",
                        "ontology": payload.get("ontology") or {},
                        "feedback_comment": payload.get("feedback_comment") or "",
                    },
                }
            )
            budget -= estimated_tokens
            seen_conclusions.add(conclusion)
        return blocks