215 lines
7.9 KiB
Python
215 lines
7.9 KiB
Python
|
|
"""Few-shot 样本的 Qdrant 向量存储。
|
|||
|
|
|
|||
|
|
独立于 LightRAG 的 Qdrant 客户端,使用专用 collection ``few_shot_samples``,
|
|||
|
|
与知识库 RAG 的 collection 隔离。所有操作失败都不抛异常(记日志返回空),
|
|||
|
|
保证主链路不阻塞。
|
|||
|
|
|
|||
|
|
向量来自 :class:`EmbeddingProvider`,payload 带业务过滤字段(scene/label/domain/risk_type),
|
|||
|
|
检索时按这些字段过滤 + 向量相似度排序。
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import os
|
|||
|
|
import uuid
|
|||
|
|
from typing import Any
|
|||
|
|
|
|||
|
|
from app.core.logging import get_logger
|
|||
|
|
from app.services.knowledge_rag import _resolve_default_qdrant_url
|
|||
|
|
|
|||
|
|
logger = get_logger("app.services.few_shot_store")
|
|||
|
|
|
|||
|
|
FEW_SHOT_COLLECTION = "few_shot_samples"
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _resolve_qdrant_config() -> tuple[str, str]:
    """Return the ``(url, api_key)`` pair used to connect to Qdrant.

    The URL is read from the ``QDRANT_URL`` environment variable; when that
    is unset or blank, the default resolved by
    ``knowledge_rag._resolve_default_qdrant_url()`` is used instead.  The API
    key is read from ``QDRANT_API_KEY`` and is an empty string when unset.
    Both environment values are stripped of surrounding whitespace.
    """
    env = os.environ

    configured_url = env.get("QDRANT_URL", "").strip()
    if not configured_url:
        # Fall back to the same default the knowledge-base RAG service uses.
        configured_url = _resolve_default_qdrant_url()

    return configured_url, env.get("QDRANT_API_KEY", "").strip()
|
|||
|
|
|
|||
|
|
|
|||
|
|
class FewShotStore:
    """Thin wrapper around Qdrant used only for few-shot sample retrieval.

    Design notes:

    - The client and the collection are created lazily, on the first
      operation that needs them.
    - Public methods swallow exceptions and report failure through their
      return value (``None`` / ``[]`` / ``False``), so a Qdrant or embedding
      problem never takes down the calling pipeline.
    - Writes and searches both embed text through the injected embedding
      provider; the caller is responsible for keeping it consistent with the
      configuration.  The collection's vector size is taken from
      ``embedding_provider.dimension()`` when the collection is created.
    """

    def __init__(self, embedding_provider: Any) -> None:
        """Store the provider; no network access happens here.

        Args:
            embedding_provider: Object exposing ``embed(list_of_texts)``
                (indexable result, one vector per text) and ``dimension()``.
        """
        self._embedding_provider = embedding_provider
        # Created on first use by _client_or_none().
        self._client: Any = None
        # Set once the collection has been seen or created successfully.
        self._ensured = False

    def _client_or_none(self) -> Any:
        """Return the cached ``QdrantClient``, creating it on first use.

        Returns:
            The client, or ``None`` when it could not be imported or
            constructed.  A failure is not cached, so the next call retries.
        """
        if self._client is not None:
            return self._client
        try:
            from qdrant_client import QdrantClient

            url, api_key = _resolve_qdrant_config()
            # An empty key is passed as None rather than as "".
            self._client = QdrantClient(url=url, api_key=api_key or None)
        except Exception:
            logger.warning("few-shot QdrantClient 初始化失败,本轮操作跳过", exc_info=True)
            self._client = None
        return self._client

    def _ensure_collection(self) -> bool:
        """Make sure the few-shot collection exists, creating it if needed.

        A successful check or creation is remembered on the instance, so
        Qdrant is only asked again after a failure.

        Returns:
            ``True`` when the collection is known to exist; ``False`` when the
            client is unavailable or the check/creation failed (logged).
        """
        if self._ensured:
            return True
        client = self._client_or_none()
        if client is None:
            return False
        try:
            from qdrant_client.http.exceptions import UnexpectedResponse

            try:
                client.get_collection(FEW_SHOT_COLLECTION)
                self._ensured = True
                return True
            except UnexpectedResponse as exc:
                # Only a 404 means "create it"; any other status is a real
                # failure and is handled by the outer except.
                if exc.status_code != 404:
                    raise

            # The collection does not exist yet: create it.
            dim = self._embedding_provider.dimension()
            from qdrant_client.http.models import (
                Distance,
                PayloadSchemaType,
                VectorParams,
            )

            client.create_collection(
                collection_name=FEW_SHOT_COLLECTION,
                vectors_config=VectorParams(size=dim, distance=Distance.COSINE),
            )
            # NOTE(review): "sample_id" gets a keyword index; confirm that
            # sample ids are strings, otherwise this index may not cover them.
            keyword_fields = (
                "sample_id",
                "scene",
                "label",
                "domain",
                "risk_type",
                "status",
            )
            for field in keyword_fields:
                # Index creation is best-effort: a failure on one field is
                # logged at debug level and does not fail the whole setup.
                try:
                    client.create_payload_index(
                        collection_name=FEW_SHOT_COLLECTION,
                        field_name=field,
                        field_schema=PayloadSchemaType.KEYWORD,
                    )
                except Exception:
                    logger.debug("payload index 创建跳过 field=%s", field, exc_info=True)
            self._ensured = True
            logger.info("few-shot collection 创建成功 dim=%s", dim)
            return True
        except Exception:
            logger.warning("few-shot collection 初始化失败,本轮操作跳过", exc_info=True)
            return False

    def upsert(self, sample: Any) -> str | None:
        """Embed one sample and write it to Qdrant under a fresh point id.

        Args:
            sample: Object providing ``case_text`` (the text that is
                embedded) plus ``id``, ``scene``, ``label``, ``domain``,
                ``risk_type``, ``risk_level``, ``conclusion_text`` and
                ``payload_json`` for the payload.  ``status`` is optional and
                defaults to ``"active"``; ``sample_key`` is optional and only
                used in log messages.

        Returns:
            The generated point id (a 32-character hex UUID), or ``None`` when
            the collection is unavailable, embedding fails, the sample lacks a
            required attribute, or the write fails.
        """
        if not self._ensure_collection():
            return None
        client = self._client
        try:
            vector = self._embedding_provider.embed([sample.case_text])[0]
        except Exception:
            logger.warning(
                "few-shot embedding 失败 sample_key=%s",
                getattr(sample, "sample_key", ""),
                exc_info=True,
            )
            return None

        vector_id = uuid.uuid4().hex
        try:
            # Payload construction is inside the guard: a sample missing one
            # of these attributes must yield None, not an AttributeError.
            payload = {
                "sample_id": sample.id,
                "scene": sample.scene,
                "label": sample.label,
                "domain": sample.domain,
                "risk_type": sample.risk_type,
                "risk_level": sample.risk_level,
                "status": getattr(sample, "status", "active"),
                "conclusion_text": sample.conclusion_text,
                "payload_json": sample.payload_json,
            }
            client.upsert(
                collection_name=FEW_SHOT_COLLECTION,
                points=[{"id": vector_id, "vector": vector, "payload": payload}],
            )
        except Exception:
            logger.warning(
                "few-shot upsert 失败 sample_key=%s",
                getattr(sample, "sample_key", ""),
                exc_info=True,
            )
            return None
        return vector_id

    @staticmethod
    def _point_to_hit(point: Any) -> dict[str, Any]:
        """Flatten one Qdrant result point into the dict shape ``search`` returns."""
        payload = getattr(point, "payload", None) or {}
        score = getattr(point, "score", 0.0)
        return {
            "sample_id": payload.get("sample_id"),
            # A None score is reported as 0.0 instead of raising in float().
            "score": float(score) if score is not None else 0.0,
            "label": payload.get("label"),
            "domain": payload.get("domain"),
            "risk_type": payload.get("risk_type"),
            "conclusion_text": payload.get("conclusion_text") or "",
            "payload_json": payload.get("payload_json") or {},
        }

    def search(
        self,
        case_text: str,
        *,
        scene: str | None = None,
        labels: list[str] | None = None,
        top_k: int = 3,
    ) -> list[dict[str, Any]]:
        """Find the samples most similar to ``case_text``.

        Only points whose ``status`` payload field is ``"active"`` are
        considered.

        Args:
            case_text: Query text; an empty value short-circuits to ``[]``.
            scene: When given, restrict hits to this exact ``scene``.
            labels: When given and non-empty, restrict hits to points whose
                ``label`` is any of these values.
            top_k: Maximum number of hits requested from Qdrant.

        Returns:
            A list of dicts with the keys ``sample_id``, ``score``, ``label``,
            ``domain``, ``risk_type``, ``conclusion_text`` and
            ``payload_json``.  Empty when nothing matches or on any failure.
        """
        if not case_text or not self._ensure_collection():
            return []
        client = self._client
        try:
            vector = self._embedding_provider.embed([case_text])[0]
        except Exception:
            logger.warning("few-shot 检索 embedding 失败", exc_info=True)
            return []

        must: list[dict[str, Any]] = [{"key": "status", "match": {"value": "active"}}]
        if scene:
            must.append({"key": "scene", "match": {"value": scene}})
        if labels:
            must.append({"key": "label", "match": {"any": labels}})

        try:
            from qdrant_client.http.models import Filter

            results = client.query_points(
                collection_name=FEW_SHOT_COLLECTION,
                query=vector,
                query_filter=Filter(must=must),
                limit=top_k,
                with_payload=True,
            ).points
            # Converted inside the guard so an unexpected point shape cannot
            # raise out of this method.
            return [self._point_to_hit(point) for point in results]
        except Exception:
            logger.warning("few-shot 检索失败", exc_info=True)
            return []

    def delete_by_vector_id(self, vector_id: str) -> bool:
        """Delete the point with the given id.

        Args:
            vector_id: Point id as returned by :meth:`upsert`.

        Returns:
            ``True`` when the delete call completed without raising; ``False``
            for an empty id, an unavailable collection, or a failed call.
        """
        if not vector_id or not self._ensure_collection():
            return False
        try:
            self._client.delete(
                collection_name=FEW_SHOT_COLLECTION,
                points_selector=[vector_id],
            )
        except Exception:
            logger.warning("few-shot 删除失败 vector_id=%s", vector_id, exc_info=True)
            return False
        return True
|