Files
X-Financial/server/src/app/services/few_shot_store.py
caoxiaozhu 3a9d154783 feat(flywheel): few-shot 在线检索注入打通风险规则编译链路
- 新增 FewShotStore:独立 Qdrant collection few_shot_samples,向量 upsert/search/delete,
  全程失败降级不阻塞主链路
- 新增 FewShotIngestionService:RiskObservation confirmed/false_positive → FewShotSample +
  向量,带 sample_key 幂等去重
- 新增 FewShotRetriever:按 case 特征检索相似历史样本,去重 + token 预算 + 单条字符上限裁剪
- risk_observations.create_feedback commit 后挂 hook 自动入库,带 feature flag 和 try/except 兜底
- risk_rule_generation_prompt 新增 few_shot_samples 可选参数,合并进 examples 并标
  source=historical_confirmed;risk_rule_generation 构造 prompt 前调 retriever,失败降级为空
2026-07-03 13:55:52 +08:00

215 lines
7.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Few-shot 样本的 Qdrant 向量存储。
独立于 LightRAG 的 Qdrant 客户端,使用专用 collection ``few_shot_samples``
与知识库 RAG 的 collection 隔离。所有操作失败都不抛异常(记日志返回空),
保证主链路不阻塞。
向量来自 :class:`EmbeddingProvider`,payload 带业务过滤字段(scene/label/domain/risk_type),
检索时按这些字段过滤 + 向量相似度排序。
"""
from __future__ import annotations
import os
import uuid
from typing import Any
from app.core.logging import get_logger
from app.services.knowledge_rag import _resolve_default_qdrant_url
# Module-level logger; every failure path in this module reports through it
# instead of raising.
logger = get_logger("app.services.few_shot_store")
# Name of the dedicated Qdrant collection that stores few-shot sample vectors
# (kept separate from the knowledge-base RAG collections).
FEW_SHOT_COLLECTION = "few_shot_samples"
def _resolve_qdrant_config() -> tuple[str, str]:
"""复用 knowledge_rag 的 Qdrant URL/key 解析逻辑。"""
url = os.environ.get("QDRANT_URL", "").strip() or _resolve_default_qdrant_url()
api_key = os.environ.get("QDRANT_API_KEY", "").strip()
return url, api_key
class FewShotStore:
    """Thin wrapper around Qdrant dedicated to few-shot sample retrieval.

    Design notes:

    - The client and the collection are created lazily, on first use.
    - Public methods never raise: failures are logged and reported as
      ``None`` / ``False`` / ``[]`` so the calling pipeline is never blocked.
    - Vectors for both writes and queries come from the embedding provider
      passed to the constructor; the caller is responsible for keeping it
      consistent with the collection's configuration.
    """

    def __init__(self, embedding_provider: Any) -> None:
        """Remember the embedding provider; no network I/O happens here.

        Args:
            embedding_provider: Object exposing ``embed(list[str])`` (returns
                one vector per input text) and ``dimension()`` (vector size
                used when the collection has to be created).
        """
        self._embedding_provider = embedding_provider
        # Lazily created QdrantClient; stays None until first successful init.
        self._client: Any = None
        # True once the collection is known to exist.
        self._ensured = False

    def _client_or_none(self) -> Any:
        """Return the cached Qdrant client, creating it on first use.

        Returns:
            The client instance, or ``None`` when it could not be created
            (missing package, bad configuration, ...). A failure is not
            cached, so the next call tries again.
        """
        if self._client is not None:
            return self._client
        try:
            # Imported lazily so that a missing/broken qdrant_client package
            # degrades this feature instead of breaking module import.
            from qdrant_client import QdrantClient

            url, api_key = _resolve_qdrant_config()
            self._client = QdrantClient(url=url, api_key=api_key or None)
        except Exception:
            logger.warning("few-shot QdrantClient 初始化失败,本轮操作跳过", exc_info=True)
            self._client = None
        return self._client

    def _ensure_collection(self) -> bool:
        """Make sure the few-shot collection exists, creating it if needed.

        Returns:
            ``True`` when the collection exists (the result is cached for
            later calls), ``False`` when the client is unavailable or any
            step failed. Failures are logged and not cached.
        """
        if self._ensured:
            return True
        client = self._client_or_none()
        if client is None:
            return False
        try:
            from qdrant_client.http.exceptions import UnexpectedResponse

            try:
                client.get_collection(FEW_SHOT_COLLECTION)
            except UnexpectedResponse as exc:
                # Only "not found" means the collection must be created;
                # any other status is a real failure handled by the outer
                # except.
                if exc.status_code != 404:
                    raise
            else:
                self._ensured = True
                return True

            # The collection does not exist yet: create it.
            dim = self._embedding_provider.dimension()
            from qdrant_client.http.models import (
                Distance,
                PayloadSchemaType,
                VectorParams,
            )

            client.create_collection(
                collection_name=FEW_SHOT_COLLECTION,
                vectors_config=VectorParams(size=dim, distance=Distance.COSINE),
            )
            indexed_fields = (
                ("sample_id", PayloadSchemaType.KEYWORD),
                ("scene", PayloadSchemaType.KEYWORD),
                ("label", PayloadSchemaType.KEYWORD),
                ("domain", PayloadSchemaType.KEYWORD),
                ("risk_type", PayloadSchemaType.KEYWORD),
                ("status", PayloadSchemaType.KEYWORD),
            )
            for field, field_type in indexed_fields:
                # A failed index must not prevent the collection from being
                # used, so each index is attempted independently.
                try:
                    client.create_payload_index(
                        collection_name=FEW_SHOT_COLLECTION,
                        field_name=field,
                        field_schema=field_type,
                    )
                except Exception:
                    logger.debug("payload index 创建跳过 field=%s", field, exc_info=True)
            self._ensured = True
            logger.info("few-shot collection 创建成功 dim=%s", dim)
            return True
        except Exception:
            logger.warning("few-shot collection 初始化失败,本轮操作跳过", exc_info=True)
            return False

    def upsert(self, sample: Any) -> str | None:
        """Embed one sample and write it to Qdrant.

        Args:
            sample: Object providing ``case_text`` (the text that is
                embedded) plus ``id``, ``scene``, ``label``, ``domain``,
                ``risk_type``, ``risk_level``, ``conclusion_text`` and
                ``payload_json`` (stored as payload). ``status`` is optional
                and defaults to ``"active"``; ``sample_key`` is optional and
                only used in log messages.

        Returns:
            The newly generated vector id (a UUID4 hex string), or ``None``
            when the collection is unavailable, embedding fails, the sample
            is malformed, or the write fails.
        """
        if not self._ensure_collection():
            return None
        sample_key = getattr(sample, "sample_key", "")
        try:
            vector = self._embedding_provider.embed([sample.case_text])[0]
        except Exception:
            logger.warning("few-shot embedding 失败 sample_key=%s", sample_key, exc_info=True)
            return None
        vector_id = uuid.uuid4().hex
        try:
            # Building the payload stays inside the guarded section: a sample
            # missing one of these attributes must degrade, not raise.
            payload = {
                "sample_id": sample.id,
                "scene": sample.scene,
                "label": sample.label,
                "domain": sample.domain,
                "risk_type": sample.risk_type,
                "risk_level": sample.risk_level,
                "status": getattr(sample, "status", "active"),
                "conclusion_text": sample.conclusion_text,
                "payload_json": sample.payload_json,
            }
            self._client.upsert(
                collection_name=FEW_SHOT_COLLECTION,
                points=[{"id": vector_id, "vector": vector, "payload": payload}],
            )
        except Exception:
            logger.warning("few-shot upsert 失败 sample_key=%s", sample_key, exc_info=True)
            return None
        return vector_id

    @staticmethod
    def _point_to_hit(point: Any) -> dict[str, Any]:
        """Convert one Qdrant result point into the plain dict returned by ``search``.

        A missing or ``None`` payload is treated as empty, and a missing,
        ``None`` or non-numeric score becomes ``0.0``.
        """
        payload = getattr(point, "payload", None) or {}
        raw_score = getattr(point, "score", None)
        try:
            score = float(raw_score) if raw_score is not None else 0.0
        except (TypeError, ValueError):
            score = 0.0
        return {
            "sample_id": payload.get("sample_id"),
            "score": score,
            "label": payload.get("label"),
            "domain": payload.get("domain"),
            "risk_type": payload.get("risk_type"),
            "conclusion_text": payload.get("conclusion_text") or "",
            "payload_json": payload.get("payload_json") or {},
        }

    def search(
        self,
        case_text: str,
        *,
        scene: str | None = None,
        labels: list[str] | None = None,
        top_k: int = 3,
    ) -> list[dict[str, Any]]:
        """Find stored samples similar to ``case_text``.

        Only points whose ``status`` payload equals ``"active"`` are
        considered.

        Args:
            case_text: Query text to embed. Empty or whitespace-only text
                yields no results.
            scene: When truthy, restrict hits to this ``scene`` value.
            labels: When non-empty, restrict hits to any of these ``label``
                values.
            top_k: Maximum number of hits; values ``<= 0`` yield no results.

        Returns:
            A list of dicts with keys ``sample_id``, ``score``, ``label``,
            ``domain``, ``risk_type``, ``conclusion_text`` and
            ``payload_json``. Empty on any failure.
        """
        # Nothing meaningful to look up: skip the embedding call and the
        # Qdrant round-trip entirely.
        if not case_text or not case_text.strip() or top_k <= 0:
            return []
        if not self._ensure_collection():
            return []
        client = self._client
        try:
            vector = self._embedding_provider.embed([case_text])[0]
        except Exception:
            logger.warning("few-shot 检索 embedding 失败", exc_info=True)
            return []
        conditions: list[dict[str, Any]] = [
            {"key": "status", "match": {"value": "active"}},
        ]
        if scene:
            conditions.append({"key": "scene", "match": {"value": scene}})
        if labels:
            conditions.append({"key": "label", "match": {"any": labels}})
        try:
            from qdrant_client.http.models import Filter

            response = client.query_points(
                collection_name=FEW_SHOT_COLLECTION,
                query=vector,
                query_filter=Filter(must=conditions),
                limit=top_k,
                with_payload=True,
            )
            # Converting inside the guarded section keeps the "never raise"
            # contract even for unexpectedly shaped results.
            return [self._point_to_hit(point) for point in response.points]
        except Exception:
            logger.warning("few-shot 检索失败", exc_info=True)
            return []

    def delete_by_vector_id(self, vector_id: str) -> bool:
        """Delete the point stored under ``vector_id``.

        Returns:
            ``True`` when the delete call completed without error, ``False``
            for an empty id, an unavailable collection, or a failed call.
        """
        if not vector_id or not self._ensure_collection():
            return False
        try:
            self._client.delete(
                collection_name=FEW_SHOT_COLLECTION,
                points_selector=[vector_id],
            )
        except Exception:
            logger.warning("few-shot 删除失败 vector_id=%s", vector_id, exc_info=True)
            return False
        return True