feat(flywheel): few-shot 在线检索注入打通风险规则编译链路

- 新增 FewShotStore:独立 Qdrant collection few_shot_samples,向量 upsert/search/delete,
  全程失败降级不阻塞主链路
- 新增 FewShotIngestionService:RiskObservation confirmed/false_positive → FewShotSample +
  向量,带 sample_key 幂等去重
- 新增 FewShotRetriever:按 case 特征检索相似历史样本,去重 + token 预算 + 单条字符上限裁剪
- risk_observations.create_feedback commit 后挂 hook 自动入库,带 feature flag 和 try/except 兜底
- risk_rule_generation_prompt 新增 few_shot_samples 可选参数,合并进 examples 并标
  source=historical_confirmed;risk_rule_generation 构造 prompt 前调 retriever,失败降级为空
This commit is contained in:
caoxiaozhu
2026-07-03 13:55:52 +08:00
parent 765cfb40f3
commit 3a9d154783
6 changed files with 584 additions and 1 deletions

View File

@@ -0,0 +1,177 @@
"""Few-shot 样本入库编排RiskObservation → FewShotSample → Qdrant。
只处理人工确认为 confirmed / false_positive 的观测,把它转成一条
:class:`FewShotSample`,持久化到 DB并同步向量到 Qdrant。
入库动作由 :meth:`RiskObservationService.create_feedback` 在 commit 后触发,
本服务全程吞异常(只记日志),保证反馈主流程不被 few-shot 链路拖崩。
"""
from __future__ import annotations
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.core.logging import get_logger
from app.models.few_shot_sample import FewShotSample
from app.models.risk_observation import RiskObservation, RiskObservationFeedback
from app.services.embedding_provider import EmbeddingProvider
from app.services.few_shot_store import FewShotStore
logger = get_logger("app.services.few_shot_ingestion")
# 仅这两个 feedback_status 视为已确认样本,会入库
CONFIRMED_LABELS = {"confirmed", "false_positive"}
# label → 自然语言结论(当 feedback.comment 缺失时兜底)
LABEL_CONCLUSION_FALLBACK = {
"confirmed": "经人工复核确认,该风险线索成立,需按规则拦截或补件。",
"false_positive": "经人工复核判定为误报,相似情形不应触发该风险规则。",
}
class FewShotIngestionService:
"""把已确认的风险观测沉淀为 few-shot 样本。"""
def __init__(self, db: Session) -> None:
self.db = db
def ingest_observation_feedback(
self,
observation: RiskObservation,
feedback: RiskObservationFeedback,
) -> FewShotSample | None:
"""人工确认/误报后调用,写入并同步向量。"""
label = observation.feedback_status
if label not in CONFIRMED_LABELS:
return None
sample_key = f"obs:{observation.id}"
sample = self.db.scalar(
select(FewShotSample).where(FewShotSample.sample_key == sample_key)
)
domain = self._extract_domain(observation)
case_text = self._build_case_text(observation)
conclusion_text = self._build_conclusion_text(observation, feedback, label)
payload = self._build_payload(observation, feedback, label)
if sample is None:
sample = FewShotSample(
sample_key=sample_key,
source_observation_id=observation.id,
scene="risk_rule_generation",
domain=domain,
risk_type=observation.risk_type or "",
risk_level=observation.risk_level or "",
label=label,
case_text=case_text,
conclusion_text=conclusion_text,
payload_json=payload,
status="active",
)
self.db.add(sample)
else:
sample.label = label
sample.domain = domain
sample.risk_type = observation.risk_type or ""
sample.risk_level = observation.risk_level or ""
sample.case_text = case_text
sample.conclusion_text = conclusion_text
sample.payload_json = payload
sample.status = "active"
sample.vector_id = sample.vector_id
try:
self.db.commit()
self.db.refresh(sample)
except Exception:
logger.exception("few-shot 样本持久化失败 observation_id=%s", observation.id)
self.db.rollback()
return None
vector_id = self._store().upsert(sample)
if vector_id:
sample.vector_id = vector_id
try:
self.db.commit()
except Exception:
logger.warning("few-shot vector_id 回写失败 sample_id=%s", sample.id)
return sample
def retract_observation(self, observation_id: str) -> bool:
"""观测被撤销时删掉对应样本及其向量。"""
sample = self.db.scalar(
select(FewShotSample).where(FewShotSample.source_observation_id == observation_id)
)
if sample is None:
return False
if sample.vector_id:
self._store().delete_by_vector_id(sample.vector_id)
try:
self.db.delete(sample)
self.db.commit()
return True
except Exception:
logger.exception("few-shot 样本删除失败 observation_id=%s", observation_id)
self.db.rollback()
return False
def _store(self) -> FewShotStore:
provider = EmbeddingProvider.from_settings(self.db)
return FewShotStore(provider)
def _extract_domain(self, observation: RiskObservation) -> str:
ontology = observation.ontology_json or {}
return str(ontology.get("domain") or "")
def _build_case_text(self, observation: RiskObservation) -> str:
parts = [
observation.title or "",
observation.description or "",
observation.risk_signal or "",
observation.risk_type or "",
]
ontology = observation.ontology_json or {}
scenario = ontology.get("scenario")
if scenario:
parts.append(f"场景:{scenario}")
risk_signals = ontology.get("risk_signals")
if isinstance(risk_signals, list) and risk_signals:
parts.append("信号:" + "|".join(str(s) for s in risk_signals))
return "\n".join(part for part in parts if part).strip()
def _build_conclusion_text(
self,
observation: RiskObservation,
feedback: RiskObservationFeedback,
label: str,
) -> str:
comment = (feedback.comment or "").strip()
if comment:
return f"[{label}] {comment}"
return LABEL_CONCLUSION_FALLBACK.get(label, label)
def _build_payload(
self,
observation: RiskObservation,
feedback: RiskObservationFeedback,
label: str,
) -> dict[str, Any]:
return {
"label": label,
"risk_type": observation.risk_type,
"risk_signal": observation.risk_signal,
"risk_level": observation.risk_level,
"feedback_type": feedback.feedback_type,
"feedback_comment": feedback.comment or "",
"feedback_actor": feedback.actor or "",
"ontology": observation.ontology_json or {},
"policy_refs": observation.policy_refs_json or [],
"evidence": observation.evidence_json or [],
"subject_label": observation.subject_label or "",
"claim_no": observation.claim_no or "",
}

View File

@@ -0,0 +1,122 @@
"""Few-shot 检索器:按当前 case 特征检索相似历史样本,拼成注入块。
从 :class:`FewShotStore` 取相似样本,转成可供 prompt 构造函数直接使用的结构。
带 token 预算裁剪和去重,确保不撑爆 prompt。
典型用法(在构造 prompt 之前调用)::
retriever = FewShotRetriever.from_session(session)
samples = retriever.retrieve_for_risk_rule_generation(
domain="travel", natural_language="票据城市与申报地不一致"
)
messages = build_risk_rule_compiler_messages(
...,
few_shot_samples=samples,
)
"""
from __future__ import annotations
from typing import Any
from sqlalchemy.orm import Session
from app.core.logging import get_logger
from app.services.embedding_provider import EmbeddingProvider
from app.services.few_shot_store import FewShotStore
logger = get_logger("app.services.few_shot_retrieval")
# 单条 few-shot 样本估算 token 数(用于预算裁剪)
SAMPLE_TOKEN_BUDGET = 1200
# 单条样本最大字符数,超长直接截断结论,避免撑爆 prompt
SINGLE_SAMPLE_MAX_CHARS = 400
# 历史样本最多注入条数(与原内置 examples 合并后总量受限)
MAX_HISTORICAL_SAMPLES = 3
class FewShotRetriever:
"""按 case 特征检索已确认样本,返回 prompt 可直接消费的结构。"""
def __init__(self, store: FewShotStore) -> None:
self._store = store
@classmethod
def from_session(cls, session: Session) -> "FewShotRetriever":
provider = EmbeddingProvider.from_settings(session)
return cls(FewShotStore(provider))
def retrieve_for_risk_rule_generation(
self,
*,
domain: str = "",
risk_type: str = "",
natural_language: str,
top_k: int = MAX_HISTORICAL_SAMPLES,
) -> list[dict[str, Any]]:
"""检索与当前规则需求相似的历史样本,返回注入块列表。"""
case_text = self._build_case_text(
natural_language=natural_language,
domain=domain,
risk_type=risk_type,
)
if not case_text:
return []
hits = self._store.search(
case_text,
scene="risk_rule_generation",
labels=["confirmed", "false_positive"],
top_k=top_k,
)
return self._hits_to_injection_blocks(hits)
def _build_case_text(
self,
*,
natural_language: str,
domain: str = "",
risk_type: str = "",
) -> str:
parts = [natural_language, domain, risk_type]
return "\n".join(p for p in parts if p).strip()
def _hits_to_injection_blocks(
self,
hits: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""把检索命中转成 prompt 可消费的块,做去重和预算裁剪。"""
blocks: list[dict[str, Any]] = []
seen_conclusions: set[str] = set()
budget = SAMPLE_TOKEN_BUDGET
for hit in hits:
conclusion = (hit.get("conclusion_text") or "").strip()
if not conclusion or conclusion in seen_conclusions:
continue
# 超长结论截断到上限,避免单条样本占用过多预算
if len(conclusion) > SINGLE_SAMPLE_MAX_CHARS:
conclusion = conclusion[:SINGLE_SAMPLE_MAX_CHARS]
payload = hit.get("payload_json") or {}
block = {
"source": "historical_confirmed",
"label": hit.get("label"),
"domain": hit.get("domain") or "",
"risk_type": hit.get("risk_type") or "",
"score": round(float(hit.get("score") or 0.0), 4),
"conclusion": conclusion,
"context": {
"risk_signal": payload.get("risk_signal") or "",
"risk_level": payload.get("risk_level") or "",
"ontology": payload.get("ontology") or {},
"feedback_comment": payload.get("feedback_comment") or "",
},
}
# 粗略 token 估算(按字符数 / 1.6 近似中文 token 比)
estimated_tokens = int(len(conclusion) / 1.6) + 40
if estimated_tokens > budget:
break
budget -= estimated_tokens
blocks.append(block)
seen_conclusions.add(conclusion)
return blocks

View File

@@ -0,0 +1,214 @@
"""Few-shot 样本的 Qdrant 向量存储。
独立于 LightRAG 的 Qdrant 客户端,使用专用 collection ``few_shot_samples``
与知识库 RAG 的 collection 隔离。所有操作失败都不抛异常(记日志返回空),
保证主链路不阻塞。
向量来自 :class:`EmbeddingProvider`payload 带业务过滤字段scene/label/domain/risk_type
检索时按这些字段过滤 + 向量相似度排序。
"""
from __future__ import annotations
import os
import uuid
from typing import Any
from app.core.logging import get_logger
from app.services.knowledge_rag import _resolve_default_qdrant_url
logger = get_logger("app.services.few_shot_store")
FEW_SHOT_COLLECTION = "few_shot_samples"
def _resolve_qdrant_config() -> tuple[str, str]:
"""复用 knowledge_rag 的 Qdrant URL/key 解析逻辑。"""
url = os.environ.get("QDRANT_URL", "").strip() or _resolve_default_qdrant_url()
api_key = os.environ.get("QDRANT_API_KEY", "").strip()
return url, api_key
class FewShotStore:
"""对 Qdrant 的轻量封装,专供 few-shot 样本检索使用。
设计要点:
- 惰性创建 client 和 collection首次操作时初始化。
- 所有公共方法吞异常(返回空/False主链路永远不被拖崩。
- 向量写入和检索都依赖外部传入的 :class:`EmbeddingProvider`
由调用方保证与配置一致。
"""
def __init__(self, embedding_provider: Any) -> None:
self._embedding_provider = embedding_provider
self._client: Any = None
self._ensured = False
def _client_or_none(self) -> Any:
"""惰性初始化 QdrantClient失败返回 None。"""
if self._client is not None:
return self._client
try:
from qdrant_client import QdrantClient
url, api_key = _resolve_qdrant_config()
self._client = QdrantClient(url=url, api_key=api_key or None)
except Exception:
logger.warning("few-shot QdrantClient 初始化失败,本轮操作跳过", exc_info=True)
self._client = None
return self._client
def _ensure_collection(self) -> bool:
"""确保 collection 存在,成功返回 True。"""
if self._ensured:
return True
client = self._client_or_none()
if client is None:
return False
try:
from qdrant_client.http.exceptions import UnexpectedResponse
try:
client.get_collection(FEW_SHOT_COLLECTION)
self._ensured = True
return True
except UnexpectedResponse as exc:
if exc.status_code != 404:
raise
# collection 不存在则创建
dim = self._embedding_provider.dimension()
from qdrant_client.http.models import (
Distance,
VectorParams,
PayloadSchemaType,
)
client.create_collection(
collection_name=FEW_SHOT_COLLECTION,
vectors_config=VectorParams(size=dim, distance=Distance.COSINE),
)
for field, field_type in [
("sample_id", PayloadSchemaType.KEYWORD),
("scene", PayloadSchemaType.KEYWORD),
("label", PayloadSchemaType.KEYWORD),
("domain", PayloadSchemaType.KEYWORD),
("risk_type", PayloadSchemaType.KEYWORD),
("status", PayloadSchemaType.KEYWORD),
]:
try:
client.create_payload_index(
collection_name=FEW_SHOT_COLLECTION,
field_name=field,
field_schema=field_type,
)
except Exception:
logger.debug("payload index 创建跳过 field=%s", field, exc_info=True)
self._ensured = True
logger.info("few-shot collection 创建成功 dim=%s", dim)
return True
except Exception:
logger.warning("few-shot collection 初始化失败,本轮操作跳过", exc_info=True)
return False
def upsert(self, sample: Any) -> str | None:
"""把一条样本向量化并写入 Qdrant返回 vector_id失败返回 None。"""
if not self._ensure_collection():
return None
client = self._client
try:
vector = self._embedding_provider.embed([sample.case_text])[0]
except Exception:
logger.warning("few-shot embedding 失败 sample_key=%s", getattr(sample, "sample_key", ""), exc_info=True)
return None
vector_id = uuid.uuid4().hex
payload = {
"sample_id": sample.id,
"scene": sample.scene,
"label": sample.label,
"domain": sample.domain,
"risk_type": sample.risk_type,
"risk_level": sample.risk_level,
"status": getattr(sample, "status", "active"),
"conclusion_text": sample.conclusion_text,
"payload_json": sample.payload_json,
}
try:
client.upsert(
collection_name=FEW_SHOT_COLLECTION,
points=[{"id": vector_id, "vector": vector, "payload": payload}],
)
return vector_id
except Exception:
logger.warning("few-shot upsert 失败 sample_key=%s", getattr(sample, "sample_key", ""), exc_info=True)
return None
def search(
self,
case_text: str,
*,
scene: str | None = None,
labels: list[str] | None = None,
top_k: int = 3,
) -> list[dict[str, Any]]:
"""按 case_text 检索相似样本,可按 scene/label 过滤。失败返回空列表。"""
if not case_text or not self._ensure_collection():
return []
client = self._client
try:
vector = self._embedding_provider.embed([case_text])[0]
except Exception:
logger.warning("few-shot 检索 embedding 失败", exc_info=True)
return []
must: list[dict[str, Any]] = [{"key": "status", "match": {"value": "active"}}]
if scene:
must.append({"key": "scene", "match": {"value": scene}})
if labels:
must.append({"key": "label", "match": {"any": labels}})
try:
from qdrant_client.http.models import Filter
results = client.query_points(
collection_name=FEW_SHOT_COLLECTION,
query=vector,
query_filter=Filter(must=must),
limit=top_k,
with_payload=True,
).points
except Exception:
logger.warning("few-shot 检索失败", exc_info=True)
return []
hits: list[dict[str, Any]] = []
for point in results:
payload = getattr(point, "payload", None) or {}
hits.append(
{
"sample_id": payload.get("sample_id"),
"score": float(getattr(point, "score", 0.0)),
"label": payload.get("label"),
"domain": payload.get("domain"),
"risk_type": payload.get("risk_type"),
"conclusion_text": payload.get("conclusion_text") or "",
"payload_json": payload.get("payload_json") or {},
}
)
return hits
def delete_by_vector_id(self, vector_id: str) -> bool:
"""按 vector_id 删除向量,失败返回 False。"""
if not vector_id or not self._ensure_collection():
return False
try:
self._client.delete(
collection_name=FEW_SHOT_COLLECTION,
points_selector=[vector_id],
)
return True
except Exception:
logger.warning("few-shot 删除失败 vector_id=%s", vector_id, exc_info=True)
return False

View File

@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import os
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
from decimal import Decimal from decimal import Decimal
from typing import Any from typing import Any
@@ -8,6 +9,7 @@ from sqlalchemy import func, select
from sqlalchemy.orm import Session, joinedload from sqlalchemy.orm import Session, joinedload
from app.algorithem.risk_graph import RiskHistoryStats, RiskObservationDraft from app.algorithem.risk_graph import RiskHistoryStats, RiskObservationDraft
from app.core.logging import get_logger
from app.db.base import Base from app.db.base import Base
from app.models.financial_record import ExpenseClaim from app.models.financial_record import ExpenseClaim
from app.models.risk_observation import RiskObservation, RiskObservationFeedback from app.models.risk_observation import RiskObservation, RiskObservationFeedback
@@ -17,6 +19,8 @@ from app.schemas.risk_observation import (
) )
from app.services.expense_claim_risk_stage import normalize_risk_business_stage from app.services.expense_claim_risk_stage import normalize_risk_business_stage
logger = get_logger("app.services.risk_observations")
HIGH_LEVELS = {"high", "critical"} HIGH_LEVELS = {"high", "critical"}
SEVERITY_SCORE = { SEVERITY_SCORE = {
"low": 32, "low": 32,
@@ -322,8 +326,27 @@ class RiskObservationService:
observation.status, observation.feedback_status = mapped observation.status, observation.feedback_status = mapped
self.db.commit() self.db.commit()
self.db.refresh(feedback) self.db.refresh(feedback)
self._maybe_ingest_few_shot(observation, feedback)
return feedback return feedback
def _maybe_ingest_few_shot(
self,
observation: RiskObservation,
feedback: RiskObservationFeedback,
) -> None:
"""人工确认/误报后把样本沉淀进 few-shot 池,任何失败都不影响主流程。"""
if os.environ.get("FEW_SHOT_INJECTION_ENABLED", "true").strip().lower() in {"0", "false", "no"}:
return
if observation.feedback_status not in {"confirmed", "false_positive"}:
return
try:
from app.services.few_shot_ingestion import FewShotIngestionService
FewShotIngestionService(self.db).ingest_observation_feedback(observation, feedback)
except Exception:
logger.exception("few-shot ingestion failed for observation %s", observation.id)
def summarize_dashboard( def summarize_dashboard(
self, self,
*, *,

View File

@@ -234,6 +234,10 @@ class RiskRuleGenerationService:
} }
for item in fields for item in fields
] ]
few_shot_samples = self._retrieve_few_shot_samples(
domain=domain,
natural_language=natural_language,
)
messages = build_risk_rule_compiler_messages( messages = build_risk_rule_compiler_messages(
domain=domain, domain=domain,
domain_label=BUSINESS_DOMAIN_LABELS[domain], domain_label=BUSINESS_DOMAIN_LABELS[domain],
@@ -243,6 +247,7 @@ class RiskRuleGenerationService:
expense_category_label=expense_category_label, expense_category_label=expense_category_label,
natural_language=natural_language, natural_language=natural_language,
available_fields=field_payload, available_fields=field_payload,
few_shot_samples=few_shot_samples,
) )
answer = self.runtime_chat_service.complete( answer = self.runtime_chat_service.complete(
messages, messages,
@@ -263,6 +268,29 @@ class RiskRuleGenerationService:
payload = unwrap_semantic_plan_payload(payload) payload = unwrap_semantic_plan_payload(payload)
return self._sanitize_model_draft(payload, fields=fields) return self._sanitize_model_draft(payload, fields=fields)
def _retrieve_few_shot_samples(
self,
*,
domain: str,
natural_language: str,
) -> list[dict[str, Any]]:
"""检索已确认历史样本,失败降级为空列表。"""
import os
if os.environ.get("FEW_SHOT_INJECTION_ENABLED", "true").strip().lower() in {"0", "false", "no"}:
return []
try:
from app.services.few_shot_retrieval import FewShotRetriever
retriever = FewShotRetriever.from_session(self.db)
return retriever.retrieve_for_risk_rule_generation(
domain=domain,
natural_language=natural_language,
)
except Exception:
return []
def _sanitize_model_draft( def _sanitize_model_draft(
self, self,
payload: dict[str, Any], payload: dict[str, Any],

View File

@@ -14,10 +14,15 @@ def build_risk_rule_compiler_messages(
expense_category_label: str, expense_category_label: str,
natural_language: str, natural_language: str,
available_fields: list[dict[str, Any]], available_fields: list[dict[str, Any]],
few_shot_samples: list[dict[str, Any]] | None = None,
) -> list[dict[str, str]]: ) -> list[dict[str, str]]:
"""构造自然语言规则编译提示词。 """构造自然语言规则编译提示词。
大模型只负责把业务语言拆成“语义计划”,后端会校验字段、操作符和模板。 大模型只负责把业务语言拆成“语义计划”,后端会校验字段、操作符和模板。
``few_shot_samples`` 是从已确认历史样本中检索出来的相似案例,会被合并进
``examples`` 字段并标注 ``source: "historical_confirmed"``,让编译器参考
过往人工结论。传 ``None`` 或空列表时行为与历史完全一致(向后兼容)。
""" """
schema = { schema = {
@@ -161,6 +166,20 @@ def build_risk_rule_compiler_messages(
}, },
} }
] ]
historical_examples: list[dict[str, Any]] = []
if few_shot_samples:
for sample in few_shot_samples:
historical_examples.append(
{
"source": "historical_confirmed",
"label": sample.get("label"),
"domain": sample.get("domain") or "",
"risk_type": sample.get("risk_type") or "",
"conclusion": sample.get("conclusion") or "",
"context": sample.get("context") or {},
}
)
merged_examples = historical_examples + examples
return [ return [
{ {
"role": "system", "role": "system",
@@ -186,7 +205,7 @@ def build_risk_rule_compiler_messages(
"natural_language": natural_language, "natural_language": natural_language,
"available_fields": available_fields, "available_fields": available_fields,
"required_json_shape": response_schema, "required_json_shape": response_schema,
"examples": examples, "examples": merged_examples,
}, },
ensure_ascii=False, ensure_ascii=False,
), ),