feat(flywheel): few-shot 在线检索注入打通风险规则编译链路
- 新增 FewShotStore:独立 Qdrant collection few_shot_samples,向量 upsert/search/delete,全程失败降级不阻塞主链路
- 新增 FewShotIngestionService:RiskObservation confirmed/false_positive → FewShotSample + 向量,带 sample_key 幂等去重
- 新增 FewShotRetriever:按 case 特征检索相似历史样本,去重 + token 预算 + 单条字符上限裁剪
- risk_observations.create_feedback commit 后挂 hook 自动入库,带 feature flag 和 try/except 兜底
- risk_rule_generation_prompt 新增 few_shot_samples 可选参数,合并进 examples 并标 source=historical_confirmed;risk_rule_generation 构造 prompt 前调 retriever,失败降级为空
This commit is contained in:
177
server/src/app/services/few_shot_ingestion.py
Normal file
177
server/src/app/services/few_shot_ingestion.py
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
"""Few-shot 样本入库编排:RiskObservation → FewShotSample → Qdrant。
|
||||||
|
|
||||||
|
只处理人工确认为 confirmed / false_positive 的观测,把它转成一条
|
||||||
|
:class:`FewShotSample`,持久化到 DB,并同步向量到 Qdrant。
|
||||||
|
|
||||||
|
入库动作由 :meth:`RiskObservationService.create_feedback` 在 commit 后触发,
|
||||||
|
本服务全程吞异常(只记日志),保证反馈主流程不被 few-shot 链路拖崩。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.models.few_shot_sample import FewShotSample
|
||||||
|
from app.models.risk_observation import RiskObservation, RiskObservationFeedback
|
||||||
|
from app.services.embedding_provider import EmbeddingProvider
|
||||||
|
from app.services.few_shot_store import FewShotStore
|
||||||
|
|
||||||
|
logger = get_logger("app.services.few_shot_ingestion")
|
||||||
|
|
||||||
|
# Only observations whose feedback_status is one of these two values are
# treated as confirmed samples and ingested.
CONFIRMED_LABELS = {"confirmed", "false_positive"}

# label -> natural-language conclusion, used as a fallback when
# feedback.comment is missing or blank.
LABEL_CONCLUSION_FALLBACK = {
    "confirmed": "经人工复核确认,该风险线索成立,需按规则拦截或补件。",
    "false_positive": "经人工复核判定为误报,相似情形不应触发该风险规则。",
}
|
||||||
|
|
||||||
|
|
||||||
|
class FewShotIngestionService:
    """Persist human-reviewed risk observations as few-shot samples.

    Each qualifying observation becomes (or updates) one ``FewShotSample`` row
    keyed by ``sample_key = "obs:<observation id>"`` and its case text is
    embedded into the few-shot vector store.

    Commit and vector-store failures are logged and absorbed here. The initial
    lookup query and the construction of the store are not wrapped, so callers
    that must never fail should still guard the call themselves.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    def ingest_observation_feedback(
        self,
        observation: RiskObservation,
        feedback: RiskObservationFeedback,
    ) -> FewShotSample | None:
        """Create or refresh the sample for a reviewed observation and sync its vector.

        Args:
            observation: The observation whose ``feedback_status`` was just set.
            feedback: The feedback record supplying the reviewer comment/actor.

        Returns:
            The persisted sample, or ``None`` when the observation's status is
            not in ``CONFIRMED_LABELS`` or the row could not be committed.
            A sample is still returned when only the vector sync failed.
        """
        label = observation.feedback_status
        if label not in CONFIRMED_LABELS:
            return None

        # sample_key makes ingestion idempotent: one row per observation.
        sample_key = f"obs:{observation.id}"
        sample = self.db.scalar(
            select(FewShotSample).where(FewShotSample.sample_key == sample_key)
        )

        domain = self._extract_domain(observation)
        case_text = self._build_case_text(observation)
        conclusion_text = self._build_conclusion_text(observation, feedback, label)
        payload = self._build_payload(observation, feedback, label)

        if sample is None:
            sample = FewShotSample(
                sample_key=sample_key,
                source_observation_id=observation.id,
                scene="risk_rule_generation",
                domain=domain,
                risk_type=observation.risk_type or "",
                risk_level=observation.risk_level or "",
                label=label,
                case_text=case_text,
                conclusion_text=conclusion_text,
                payload_json=payload,
                status="active",
            )
            self.db.add(sample)
        else:
            sample.label = label
            sample.domain = domain
            sample.risk_type = observation.risk_type or ""
            sample.risk_level = observation.risk_level or ""
            sample.case_text = case_text
            sample.conclusion_text = conclusion_text
            sample.payload_json = payload
            sample.status = "active"

        try:
            self.db.commit()
            self.db.refresh(sample)
        except Exception:
            logger.exception("few-shot 样本持久化失败 observation_id=%s", observation.id)
            self.db.rollback()
            return None

        # Read these before the second commit so they remain available for
        # logging/cleanup even if that commit fails.
        previous_vector_id = sample.vector_id
        sample_id = sample.id

        store = self._store()
        vector_id = store.upsert(sample)
        if not vector_id:
            return sample

        sample.vector_id = vector_id
        try:
            self.db.commit()
        except Exception:
            logger.warning(
                "few-shot vector_id 回写失败 sample_id=%s", sample_id, exc_info=True
            )
            # The session is shared with the caller; without a rollback it
            # would stay in a failed state and break the caller's later queries.
            self.db.rollback()
            return sample

        # If the store handed back a different id, the previous point is now
        # stale; drop it so one sample never shows up twice in search results.
        if previous_vector_id and previous_vector_id != vector_id:
            store.delete_by_vector_id(previous_vector_id)
        return sample

    def retract_observation(self, observation_id: str) -> bool:
        """Delete the sample (and its vector) derived from ``observation_id``.

        Returns:
            ``True`` when a sample existed and its row was deleted, ``False``
            when no sample exists or the delete could not be committed.
        """
        sample = self.db.scalar(
            select(FewShotSample).where(FewShotSample.source_observation_id == observation_id)
        )
        if sample is None:
            return False
        if sample.vector_id:
            self._store().delete_by_vector_id(sample.vector_id)
        try:
            self.db.delete(sample)
            self.db.commit()
            return True
        except Exception:
            logger.exception("few-shot 样本删除失败 observation_id=%s", observation_id)
            self.db.rollback()
            return False

    def _store(self) -> FewShotStore:
        """Build a ``FewShotStore`` backed by the currently configured embedding provider."""
        provider = EmbeddingProvider.from_settings(self.db)
        return FewShotStore(provider)

    def _extract_domain(self, observation: RiskObservation) -> str:
        """Return ``ontology_json["domain"]`` as a string, or ``""`` when absent."""
        ontology = observation.ontology_json or {}
        return str(ontology.get("domain") or "")

    def _build_case_text(self, observation: RiskObservation) -> str:
        """Join the observation's descriptive fields into the text that gets embedded.

        Empty parts are skipped; the optional ontology ``scenario`` and
        ``risk_signals`` entries are appended as labelled lines.
        """
        parts = [
            observation.title or "",
            observation.description or "",
            observation.risk_signal or "",
            observation.risk_type or "",
        ]
        ontology = observation.ontology_json or {}
        scenario = ontology.get("scenario")
        if scenario:
            parts.append(f"场景:{scenario}")
        risk_signals = ontology.get("risk_signals")
        if isinstance(risk_signals, list) and risk_signals:
            parts.append("信号:" + "|".join(str(s) for s in risk_signals))
        return "\n".join(part for part in parts if part).strip()

    def _build_conclusion_text(
        self,
        observation: RiskObservation,
        feedback: RiskObservationFeedback,
        label: str,
    ) -> str:
        """Return the reviewer's comment prefixed with the label, or a canned fallback."""
        comment = (feedback.comment or "").strip()
        if comment:
            return f"[{label}] {comment}"
        return LABEL_CONCLUSION_FALLBACK.get(label, label)

    def _build_payload(
        self,
        observation: RiskObservation,
        feedback: RiskObservationFeedback,
        label: str,
    ) -> dict[str, Any]:
        """Snapshot the observation/feedback fields stored alongside the sample."""
        return {
            "label": label,
            "risk_type": observation.risk_type,
            "risk_signal": observation.risk_signal,
            "risk_level": observation.risk_level,
            "feedback_type": feedback.feedback_type,
            "feedback_comment": feedback.comment or "",
            "feedback_actor": feedback.actor or "",
            "ontology": observation.ontology_json or {},
            "policy_refs": observation.policy_refs_json or [],
            "evidence": observation.evidence_json or [],
            "subject_label": observation.subject_label or "",
            "claim_no": observation.claim_no or "",
        }
|
||||||
122
server/src/app/services/few_shot_retrieval.py
Normal file
122
server/src/app/services/few_shot_retrieval.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
"""Few-shot 检索器:按当前 case 特征检索相似历史样本,拼成注入块。
|
||||||
|
|
||||||
|
从 :class:`FewShotStore` 取相似样本,转成可供 prompt 构造函数直接使用的结构。
|
||||||
|
带 token 预算裁剪和去重,确保不撑爆 prompt。
|
||||||
|
|
||||||
|
典型用法(在构造 prompt 之前调用)::
|
||||||
|
|
||||||
|
retriever = FewShotRetriever.from_session(session)
|
||||||
|
samples = retriever.retrieve_for_risk_rule_generation(
|
||||||
|
domain="travel", natural_language="票据城市与申报地不一致"
|
||||||
|
)
|
||||||
|
messages = build_risk_rule_compiler_messages(
|
||||||
|
...,
|
||||||
|
few_shot_samples=samples,
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.services.embedding_provider import EmbeddingProvider
|
||||||
|
from app.services.few_shot_store import FewShotStore
|
||||||
|
|
||||||
|
logger = get_logger("app.services.few_shot_retrieval")
|
||||||
|
|
||||||
|
# Total estimated-token budget shared by ALL injected samples of one retrieval
# (each accepted sample is subtracted from it).
SAMPLE_TOKEN_BUDGET = 1200
# Maximum characters kept from a single sample's conclusion; longer text is cut.
SINGLE_SAMPLE_MAX_CHARS = 400
# Default number of historical samples requested from the vector store.
MAX_HISTORICAL_SAMPLES = 3


class FewShotRetriever:
    """Retrieve confirmed historical samples and shape them for prompt injection."""

    def __init__(self, store: FewShotStore) -> None:
        self._store = store

    @classmethod
    def from_session(cls, session: Session) -> "FewShotRetriever":
        """Build a retriever whose store uses the embedding provider configured for ``session``."""
        provider = EmbeddingProvider.from_settings(session)
        return cls(FewShotStore(provider))

    def retrieve_for_risk_rule_generation(
        self,
        *,
        domain: str = "",
        risk_type: str = "",
        natural_language: str,
        top_k: int = MAX_HISTORICAL_SAMPLES,
    ) -> list[dict[str, Any]]:
        """Return injection blocks for samples similar to the current rule request.

        Args:
            domain: Optional business domain appended to the query text.
            risk_type: Optional risk type appended to the query text.
            natural_language: The rule requirement as written by the user.
            top_k: Maximum number of hits requested from the store.

        Returns:
            A list of prompt-ready dicts; empty when there is no query text or
            the store returns nothing.
        """
        case_text = self._build_case_text(
            natural_language=natural_language,
            domain=domain,
            risk_type=risk_type,
        )
        if not case_text:
            return []
        hits = self._store.search(
            case_text,
            scene="risk_rule_generation",
            labels=["confirmed", "false_positive"],
            top_k=top_k,
        )
        return self._hits_to_injection_blocks(hits)

    def _build_case_text(
        self,
        *,
        natural_language: str,
        domain: str = "",
        risk_type: str = "",
    ) -> str:
        """Join the non-empty query parts with newlines."""
        parts = [natural_language, domain, risk_type]
        return "\n".join(p for p in parts if p).strip()

    def _hits_to_injection_blocks(
        self,
        hits: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Convert store hits into prompt blocks with dedupe and budget trimming.

        Hits are processed in the order given. A hit is skipped when its
        conclusion is empty or duplicates an already accepted one; processing
        stops at the first hit that no longer fits the remaining token budget.
        """
        blocks: list[dict[str, Any]] = []
        seen_conclusions: set[str] = set()
        budget = SAMPLE_TOKEN_BUDGET
        for hit in hits:
            conclusion = (hit.get("conclusion_text") or "").strip()
            if not conclusion:
                continue
            # Truncate BEFORE the duplicate check: the set stores the injected
            # (truncated) text, so comparing the untruncated text against it
            # would never match for conclusions longer than the limit.
            conclusion = conclusion[:SINGLE_SAMPLE_MAX_CHARS]
            if conclusion in seen_conclusions:
                continue
            # Rough token estimate: chars / 1.6 for the conclusion plus a fixed
            # 40-token allowance for the rest of the block.
            estimated_tokens = int(len(conclusion) / 1.6) + 40
            if estimated_tokens > budget:
                break
            payload = hit.get("payload_json") or {}
            blocks.append(
                {
                    "source": "historical_confirmed",
                    "label": hit.get("label"),
                    "domain": hit.get("domain") or "",
                    "risk_type": hit.get("risk_type") or "",
                    "score": round(float(hit.get("score") or 0.0), 4),
                    "conclusion": conclusion,
                    "context": {
                        "risk_signal": payload.get("risk_signal") or "",
                        "risk_level": payload.get("risk_level") or "",
                        "ontology": payload.get("ontology") or {},
                        "feedback_comment": payload.get("feedback_comment") or "",
                    },
                }
            )
            budget -= estimated_tokens
            seen_conclusions.add(conclusion)
        return blocks
|
||||||
214
server/src/app/services/few_shot_store.py
Normal file
214
server/src/app/services/few_shot_store.py
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
"""Few-shot 样本的 Qdrant 向量存储。
|
||||||
|
|
||||||
|
独立于 LightRAG 的 Qdrant 客户端,使用专用 collection ``few_shot_samples``,
|
||||||
|
与知识库 RAG 的 collection 隔离。所有操作失败都不抛异常(记日志返回空),
|
||||||
|
保证主链路不阻塞。
|
||||||
|
|
||||||
|
向量来自 :class:`EmbeddingProvider`,payload 带业务过滤字段(scene/label/domain/risk_type),
|
||||||
|
检索时按这些字段过滤 + 向量相似度排序。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.services.knowledge_rag import _resolve_default_qdrant_url
|
||||||
|
|
||||||
|
logger = get_logger("app.services.few_shot_store")
|
||||||
|
|
||||||
|
# Name of the dedicated Qdrant collection that holds few-shot sample vectors.
FEW_SHOT_COLLECTION = "few_shot_samples"
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_qdrant_config() -> tuple[str, str]:
    """Return ``(url, api_key)`` for the Qdrant connection.

    A non-blank ``QDRANT_URL`` environment variable takes precedence; otherwise
    the URL comes from ``_resolve_default_qdrant_url`` (shared with
    knowledge_rag). The API key is the stripped ``QDRANT_API_KEY`` value, or an
    empty string when it is not set.
    """
    env_url = os.environ.get("QDRANT_URL", "").strip()
    endpoint = env_url if env_url else _resolve_default_qdrant_url()
    key = os.environ.get("QDRANT_API_KEY", "").strip()
    return endpoint, key
|
||||||
|
|
||||||
|
|
||||||
|
class FewShotStore:
    """Thin Qdrant wrapper used only for few-shot sample vectors.

    Design notes:
    - The client and the collection are created lazily on first use.
    - Public methods absorb exceptions and return ``None`` / ``[]`` / ``False``
      so a vector-store outage never propagates into the calling flow.
    - Embeddings come from the provider passed in by the caller, which is
      responsible for keeping it consistent with the collection's dimension.
    """

    def __init__(self, embedding_provider: Any) -> None:
        self._embedding_provider = embedding_provider
        self._client: Any = None
        # Set once the collection is known to exist, to skip repeated checks.
        self._ensured = False

    def _client_or_none(self) -> Any:
        """Return the lazily created ``QdrantClient``, or ``None`` if creation fails."""
        if self._client is not None:
            return self._client
        try:
            from qdrant_client import QdrantClient

            url, api_key = _resolve_qdrant_config()
            self._client = QdrantClient(url=url, api_key=api_key or None)
        except Exception:
            logger.warning("few-shot QdrantClient 初始化失败,本轮操作跳过", exc_info=True)
            self._client = None
        return self._client

    def _ensure_collection(self) -> bool:
        """Make sure the collection exists, creating it on a 404; return success."""
        if self._ensured:
            return True
        client = self._client_or_none()
        if client is None:
            return False
        try:
            from qdrant_client.http.exceptions import UnexpectedResponse

            try:
                client.get_collection(FEW_SHOT_COLLECTION)
                self._ensured = True
                return True
            except UnexpectedResponse as exc:
                # Anything other than "not found" is a real failure and is
                # handled by the outer except.
                if exc.status_code != 404:
                    raise
            # The collection does not exist yet: create it.
            dim = self._embedding_provider.dimension()
            from qdrant_client.http.models import (
                Distance,
                VectorParams,
                PayloadSchemaType,
            )

            client.create_collection(
                collection_name=FEW_SHOT_COLLECTION,
                vectors_config=VectorParams(size=dim, distance=Distance.COSINE),
            )
            for field, field_type in [
                ("sample_id", PayloadSchemaType.KEYWORD),
                ("scene", PayloadSchemaType.KEYWORD),
                ("label", PayloadSchemaType.KEYWORD),
                ("domain", PayloadSchemaType.KEYWORD),
                ("risk_type", PayloadSchemaType.KEYWORD),
                ("status", PayloadSchemaType.KEYWORD),
            ]:
                # Index creation is best-effort; a failure for one field does
                # not abort collection setup.
                try:
                    client.create_payload_index(
                        collection_name=FEW_SHOT_COLLECTION,
                        field_name=field,
                        field_schema=field_type,
                    )
                except Exception:
                    logger.debug("payload index 创建跳过 field=%s", field, exc_info=True)
            self._ensured = True
            logger.info("few-shot collection 创建成功 dim=%s", dim)
            return True
        except Exception:
            logger.warning("few-shot collection 初始化失败,本轮操作跳过", exc_info=True)
            return False

    def upsert(self, sample: Any) -> str | None:
        """Embed ``sample.case_text`` and write the point; return its id or ``None``.

        When the sample already carries a ``vector_id`` that id is reused, so
        re-upserting an edited sample overwrites its existing point instead of
        leaving the old one behind as a duplicate. A fresh uuid hex is
        generated only for samples without one.
        """
        if not self._ensure_collection():
            return None
        client = self._client
        try:
            vector = self._embedding_provider.embed([sample.case_text])[0]
        except Exception:
            logger.warning("few-shot embedding 失败 sample_key=%s", getattr(sample, "sample_key", ""), exc_info=True)
            return None
        vector_id = getattr(sample, "vector_id", None) or uuid.uuid4().hex
        payload = {
            "sample_id": sample.id,
            "scene": sample.scene,
            "label": sample.label,
            "domain": sample.domain,
            "risk_type": sample.risk_type,
            "risk_level": sample.risk_level,
            "status": getattr(sample, "status", "active"),
            "conclusion_text": sample.conclusion_text,
            "payload_json": sample.payload_json,
        }
        try:
            client.upsert(
                collection_name=FEW_SHOT_COLLECTION,
                points=[{"id": vector_id, "vector": vector, "payload": payload}],
            )
            return vector_id
        except Exception:
            logger.warning("few-shot upsert 失败 sample_key=%s", getattr(sample, "sample_key", ""), exc_info=True)
            return None

    def search(
        self,
        case_text: str,
        *,
        scene: str | None = None,
        labels: list[str] | None = None,
        top_k: int = 3,
    ) -> list[dict[str, Any]]:
        """Return up to ``top_k`` active samples similar to ``case_text``.

        Results can be narrowed by ``scene`` and by a list of ``labels``.
        Returns an empty list for empty input or on any failure.
        """
        if not case_text or not self._ensure_collection():
            return []
        client = self._client
        try:
            vector = self._embedding_provider.embed([case_text])[0]
        except Exception:
            logger.warning("few-shot 检索 embedding 失败", exc_info=True)
            return []
        # Only samples still marked active are eligible.
        must: list[dict[str, Any]] = [{"key": "status", "match": {"value": "active"}}]
        if scene:
            must.append({"key": "scene", "match": {"value": scene}})
        if labels:
            must.append({"key": "label", "match": {"any": labels}})
        try:
            from qdrant_client.http.models import Filter

            results = client.query_points(
                collection_name=FEW_SHOT_COLLECTION,
                query=vector,
                query_filter=Filter(must=must),
                limit=top_k,
                with_payload=True,
            ).points
        except Exception:
            logger.warning("few-shot 检索失败", exc_info=True)
            return []
        hits: list[dict[str, Any]] = []
        for point in results:
            payload = getattr(point, "payload", None) or {}
            hits.append(
                {
                    "sample_id": payload.get("sample_id"),
                    "score": float(getattr(point, "score", 0.0)),
                    "label": payload.get("label"),
                    "domain": payload.get("domain"),
                    "risk_type": payload.get("risk_type"),
                    "conclusion_text": payload.get("conclusion_text") or "",
                    "payload_json": payload.get("payload_json") or {},
                }
            )
        return hits

    def delete_by_vector_id(self, vector_id: str) -> bool:
        """Delete the point with ``vector_id``; return ``False`` on empty id or failure."""
        if not vector_id or not self._ensure_collection():
            return False
        try:
            self._client.delete(
                collection_name=FEW_SHOT_COLLECTION,
                points_selector=[vector_id],
            )
            return True
        except Exception:
            logger.warning("few-shot 删除失败 vector_id=%s", vector_id, exc_info=True)
            return False
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
from datetime import UTC, datetime, timedelta
|
from datetime import UTC, datetime, timedelta
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -8,6 +9,7 @@ from sqlalchemy import func, select
|
|||||||
from sqlalchemy.orm import Session, joinedload
|
from sqlalchemy.orm import Session, joinedload
|
||||||
|
|
||||||
from app.algorithem.risk_graph import RiskHistoryStats, RiskObservationDraft
|
from app.algorithem.risk_graph import RiskHistoryStats, RiskObservationDraft
|
||||||
|
from app.core.logging import get_logger
|
||||||
from app.db.base import Base
|
from app.db.base import Base
|
||||||
from app.models.financial_record import ExpenseClaim
|
from app.models.financial_record import ExpenseClaim
|
||||||
from app.models.risk_observation import RiskObservation, RiskObservationFeedback
|
from app.models.risk_observation import RiskObservation, RiskObservationFeedback
|
||||||
@@ -17,6 +19,8 @@ from app.schemas.risk_observation import (
|
|||||||
)
|
)
|
||||||
from app.services.expense_claim_risk_stage import normalize_risk_business_stage
|
from app.services.expense_claim_risk_stage import normalize_risk_business_stage
|
||||||
|
|
||||||
|
logger = get_logger("app.services.risk_observations")
|
||||||
|
|
||||||
HIGH_LEVELS = {"high", "critical"}
|
HIGH_LEVELS = {"high", "critical"}
|
||||||
SEVERITY_SCORE = {
|
SEVERITY_SCORE = {
|
||||||
"low": 32,
|
"low": 32,
|
||||||
@@ -322,8 +326,27 @@ class RiskObservationService:
|
|||||||
observation.status, observation.feedback_status = mapped
|
observation.status, observation.feedback_status = mapped
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
self.db.refresh(feedback)
|
self.db.refresh(feedback)
|
||||||
|
self._maybe_ingest_few_shot(observation, feedback)
|
||||||
return feedback
|
return feedback
|
||||||
|
|
||||||
|
def _maybe_ingest_few_shot(
    self,
    observation: RiskObservation,
    feedback: RiskObservationFeedback,
) -> None:
    """Best-effort hook that stores a reviewed observation as a few-shot sample.

    Does nothing when the ``FEW_SHOT_INJECTION_ENABLED`` environment variable
    is set to ``0`` / ``false`` / ``no`` (case-insensitive), or when the
    observation's feedback status is neither ``confirmed`` nor
    ``false_positive``. Any exception raised by the ingestion service is
    logged and suppressed.
    """
    flag = os.environ.get("FEW_SHOT_INJECTION_ENABLED", "true").strip().lower()
    feature_on = flag not in {"0", "false", "no"}
    if not feature_on or observation.feedback_status not in {"confirmed", "false_positive"}:
        return

    try:
        # Imported lazily so a problem loading the few-shot module is caught here too.
        from app.services.few_shot_ingestion import FewShotIngestionService

        ingestion = FewShotIngestionService(self.db)
        ingestion.ingest_observation_feedback(observation, feedback)
    except Exception:
        logger.exception("few-shot ingestion failed for observation %s", observation.id)
|
||||||
|
|
||||||
def summarize_dashboard(
|
def summarize_dashboard(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
|||||||
@@ -234,6 +234,10 @@ class RiskRuleGenerationService:
|
|||||||
}
|
}
|
||||||
for item in fields
|
for item in fields
|
||||||
]
|
]
|
||||||
|
few_shot_samples = self._retrieve_few_shot_samples(
|
||||||
|
domain=domain,
|
||||||
|
natural_language=natural_language,
|
||||||
|
)
|
||||||
messages = build_risk_rule_compiler_messages(
|
messages = build_risk_rule_compiler_messages(
|
||||||
domain=domain,
|
domain=domain,
|
||||||
domain_label=BUSINESS_DOMAIN_LABELS[domain],
|
domain_label=BUSINESS_DOMAIN_LABELS[domain],
|
||||||
@@ -243,6 +247,7 @@ class RiskRuleGenerationService:
|
|||||||
expense_category_label=expense_category_label,
|
expense_category_label=expense_category_label,
|
||||||
natural_language=natural_language,
|
natural_language=natural_language,
|
||||||
available_fields=field_payload,
|
available_fields=field_payload,
|
||||||
|
few_shot_samples=few_shot_samples,
|
||||||
)
|
)
|
||||||
answer = self.runtime_chat_service.complete(
|
answer = self.runtime_chat_service.complete(
|
||||||
messages,
|
messages,
|
||||||
@@ -263,6 +268,29 @@ class RiskRuleGenerationService:
|
|||||||
payload = unwrap_semantic_plan_payload(payload)
|
payload = unwrap_semantic_plan_payload(payload)
|
||||||
return self._sanitize_model_draft(payload, fields=fields)
|
return self._sanitize_model_draft(payload, fields=fields)
|
||||||
|
|
||||||
|
def _retrieve_few_shot_samples(
    self,
    *,
    domain: str,
    natural_language: str,
) -> list[dict[str, Any]]:
    """Fetch confirmed historical samples for the prompt, degrading to ``[]``.

    Args:
        domain: Business domain of the rule being generated.
        natural_language: The user's rule description used as the query.

    Returns:
        The retriever's injection blocks, or an empty list when the
        ``FEW_SHOT_INJECTION_ENABLED`` environment variable is ``0`` /
        ``false`` / ``no`` or when retrieval fails for any reason.
    """
    import logging
    import os

    if os.environ.get("FEW_SHOT_INJECTION_ENABLED", "true").strip().lower() in {"0", "false", "no"}:
        return []
    try:
        from app.services.few_shot_retrieval import FewShotRetriever

        retriever = FewShotRetriever.from_session(self.db)
        return retriever.retrieve_for_risk_rule_generation(
            domain=domain,
            natural_language=natural_language,
        )
    except Exception:
        # Rule generation must continue without few-shot context, but the
        # failure is recorded so a broken retrieval path does not go unnoticed.
        logging.getLogger(__name__).warning(
            "few-shot retrieval failed; continuing without historical samples",
            exc_info=True,
        )
        return []
|
||||||
|
|
||||||
def _sanitize_model_draft(
|
def _sanitize_model_draft(
|
||||||
self,
|
self,
|
||||||
payload: dict[str, Any],
|
payload: dict[str, Any],
|
||||||
|
|||||||
@@ -14,10 +14,15 @@ def build_risk_rule_compiler_messages(
|
|||||||
expense_category_label: str,
|
expense_category_label: str,
|
||||||
natural_language: str,
|
natural_language: str,
|
||||||
available_fields: list[dict[str, Any]],
|
available_fields: list[dict[str, Any]],
|
||||||
|
few_shot_samples: list[dict[str, Any]] | None = None,
|
||||||
) -> list[dict[str, str]]:
|
) -> list[dict[str, str]]:
|
||||||
"""构造自然语言规则编译提示词。
|
"""构造自然语言规则编译提示词。
|
||||||
|
|
||||||
大模型只负责把业务语言拆成“语义计划”,后端会校验字段、操作符和模板。
|
大模型只负责把业务语言拆成“语义计划”,后端会校验字段、操作符和模板。
|
||||||
|
|
||||||
|
``few_shot_samples`` 是从已确认历史样本中检索出来的相似案例,会被合并进
|
||||||
|
``examples`` 字段并标注 ``source: "historical_confirmed"``,让编译器参考
|
||||||
|
过往人工结论。传 ``None`` 或空列表时行为与历史完全一致(向后兼容)。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
schema = {
|
schema = {
|
||||||
@@ -161,6 +166,20 @@ def build_risk_rule_compiler_messages(
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
historical_examples: list[dict[str, Any]] = []
|
||||||
|
if few_shot_samples:
|
||||||
|
for sample in few_shot_samples:
|
||||||
|
historical_examples.append(
|
||||||
|
{
|
||||||
|
"source": "historical_confirmed",
|
||||||
|
"label": sample.get("label"),
|
||||||
|
"domain": sample.get("domain") or "",
|
||||||
|
"risk_type": sample.get("risk_type") or "",
|
||||||
|
"conclusion": sample.get("conclusion") or "",
|
||||||
|
"context": sample.get("context") or {},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
merged_examples = historical_examples + examples
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
@@ -186,7 +205,7 @@ def build_risk_rule_compiler_messages(
|
|||||||
"natural_language": natural_language,
|
"natural_language": natural_language,
|
||||||
"available_fields": available_fields,
|
"available_fields": available_fields,
|
||||||
"required_json_shape": response_schema,
|
"required_json_shape": response_schema,
|
||||||
"examples": examples,
|
"examples": merged_examples,
|
||||||
},
|
},
|
||||||
ensure_ascii=False,
|
ensure_ascii=False,
|
||||||
),
|
),
|
||||||
|
|||||||
Reference in New Issue
Block a user