"""Few-shot 样本的 Qdrant 向量存储。 独立于 LightRAG 的 Qdrant 客户端,使用专用 collection ``few_shot_samples``, 与知识库 RAG 的 collection 隔离。所有操作失败都不抛异常(记日志返回空), 保证主链路不阻塞。 向量来自 :class:`EmbeddingProvider`,payload 带业务过滤字段(scene/label/domain/risk_type), 检索时按这些字段过滤 + 向量相似度排序。 """ from __future__ import annotations import os import uuid from typing import Any from app.core.logging import get_logger from app.services.knowledge_rag import _resolve_default_qdrant_url logger = get_logger("app.services.few_shot_store") FEW_SHOT_COLLECTION = "few_shot_samples" def _resolve_qdrant_config() -> tuple[str, str]: """复用 knowledge_rag 的 Qdrant URL/key 解析逻辑。""" url = os.environ.get("QDRANT_URL", "").strip() or _resolve_default_qdrant_url() api_key = os.environ.get("QDRANT_API_KEY", "").strip() return url, api_key class FewShotStore: """对 Qdrant 的轻量封装,专供 few-shot 样本检索使用。 设计要点: - 惰性创建 client 和 collection,首次操作时初始化。 - 所有公共方法吞异常(返回空/False),主链路永远不被拖崩。 - 向量写入和检索都依赖外部传入的 :class:`EmbeddingProvider`, 由调用方保证与配置一致。 """ def __init__(self, embedding_provider: Any) -> None: self._embedding_provider = embedding_provider self._client: Any = None self._ensured = False def _client_or_none(self) -> Any: """惰性初始化 QdrantClient,失败返回 None。""" if self._client is not None: return self._client try: from qdrant_client import QdrantClient url, api_key = _resolve_qdrant_config() self._client = QdrantClient(url=url, api_key=api_key or None) except Exception: logger.warning("few-shot QdrantClient 初始化失败,本轮操作跳过", exc_info=True) self._client = None return self._client def _ensure_collection(self) -> bool: """确保 collection 存在,成功返回 True。""" if self._ensured: return True client = self._client_or_none() if client is None: return False try: from qdrant_client.http.exceptions import UnexpectedResponse try: client.get_collection(FEW_SHOT_COLLECTION) self._ensured = True return True except UnexpectedResponse as exc: if exc.status_code != 404: raise # collection 不存在则创建 dim = self._embedding_provider.dimension() from qdrant_client.http.models import ( Distance, VectorParams, PayloadSchemaType, ) client.create_collection( collection_name=FEW_SHOT_COLLECTION, vectors_config=VectorParams(size=dim, distance=Distance.COSINE), ) for field, field_type in [ ("sample_id", PayloadSchemaType.KEYWORD), ("scene", PayloadSchemaType.KEYWORD), ("label", PayloadSchemaType.KEYWORD), ("domain", PayloadSchemaType.KEYWORD), ("risk_type", PayloadSchemaType.KEYWORD), ("status", PayloadSchemaType.KEYWORD), ]: try: client.create_payload_index( collection_name=FEW_SHOT_COLLECTION, field_name=field, field_schema=field_type, ) except Exception: logger.debug("payload index 创建跳过 field=%s", field, exc_info=True) self._ensured = True logger.info("few-shot collection 创建成功 dim=%s", dim) return True except Exception: logger.warning("few-shot collection 初始化失败,本轮操作跳过", exc_info=True) return False def upsert(self, sample: Any) -> str | None: """把一条样本向量化并写入 Qdrant,返回 vector_id,失败返回 None。""" if not self._ensure_collection(): return None client = self._client try: vector = self._embedding_provider.embed([sample.case_text])[0] except Exception: logger.warning("few-shot embedding 失败 sample_key=%s", getattr(sample, "sample_key", ""), exc_info=True) return None vector_id = uuid.uuid4().hex payload = { "sample_id": sample.id, "scene": sample.scene, "label": sample.label, "domain": sample.domain, "risk_type": sample.risk_type, "risk_level": sample.risk_level, "status": getattr(sample, "status", "active"), "conclusion_text": sample.conclusion_text, "payload_json": sample.payload_json, } try: client.upsert( collection_name=FEW_SHOT_COLLECTION, points=[{"id": vector_id, "vector": vector, "payload": payload}], ) return vector_id except Exception: logger.warning("few-shot upsert 失败 sample_key=%s", getattr(sample, "sample_key", ""), exc_info=True) return None def search( self, case_text: str, *, scene: str | None = None, labels: list[str] | None = None, top_k: int = 3, ) -> list[dict[str, Any]]: """按 case_text 检索相似样本,可按 scene/label 过滤。失败返回空列表。""" if not case_text or not self._ensure_collection(): return [] client = self._client try: vector = self._embedding_provider.embed([case_text])[0] except Exception: logger.warning("few-shot 检索 embedding 失败", exc_info=True) return [] must: list[dict[str, Any]] = [{"key": "status", "match": {"value": "active"}}] if scene: must.append({"key": "scene", "match": {"value": scene}}) if labels: must.append({"key": "label", "match": {"any": labels}}) try: from qdrant_client.http.models import Filter results = client.query_points( collection_name=FEW_SHOT_COLLECTION, query=vector, query_filter=Filter(must=must), limit=top_k, with_payload=True, ).points except Exception: logger.warning("few-shot 检索失败", exc_info=True) return [] hits: list[dict[str, Any]] = [] for point in results: payload = getattr(point, "payload", None) or {} hits.append( { "sample_id": payload.get("sample_id"), "score": float(getattr(point, "score", 0.0)), "label": payload.get("label"), "domain": payload.get("domain"), "risk_type": payload.get("risk_type"), "conclusion_text": payload.get("conclusion_text") or "", "payload_json": payload.get("payload_json") or {}, } ) return hits def delete_by_vector_id(self, vector_id: str) -> bool: """按 vector_id 删除向量,失败返回 False。""" if not vector_id or not self._ensure_collection(): return False try: self._client.delete( collection_name=FEW_SHOT_COLLECTION, points_selector=[vector_id], ) return True except Exception: logger.warning("few-shot 删除失败 vector_id=%s", vector_id, exc_info=True) return False