From 3a9d1547836c28622a4ba0f1f611c57bd633af01 Mon Sep 17 00:00:00 2001 From: caoxiaozhu Date: Fri, 3 Jul 2026 13:55:52 +0800 Subject: [PATCH] =?UTF-8?q?feat(flywheel):=20few-shot=20=E5=9C=A8=E7=BA=BF?= =?UTF-8?q?=E6=A3=80=E7=B4=A2=E6=B3=A8=E5=85=A5=E6=89=93=E9=80=9A=E9=A3=8E?= =?UTF-8?q?=E9=99=A9=E8=A7=84=E5=88=99=E7=BC=96=E8=AF=91=E9=93=BE=E8=B7=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 FewShotStore:独立 Qdrant collection few_shot_samples,向量 upsert/search/delete, 全程失败降级不阻塞主链路 - 新增 FewShotIngestionService:RiskObservation confirmed/false_positive → FewShotSample + 向量,带 sample_key 幂等去重 - 新增 FewShotRetriever:按 case 特征检索相似历史样本,去重 + token 预算 + 单条字符上限裁剪 - risk_observations.create_feedback commit 后挂 hook 自动入库,带 feature flag 和 try/except 兜底 - risk_rule_generation_prompt 新增 few_shot_samples 可选参数,合并进 examples 并标 source=historical_confirmed;risk_rule_generation 构造 prompt 前调 retriever,失败降级为空 --- server/src/app/services/few_shot_ingestion.py | 177 +++++++++++++++ server/src/app/services/few_shot_retrieval.py | 122 ++++++++++ server/src/app/services/few_shot_store.py | 214 ++++++++++++++++++ server/src/app/services/risk_observations.py | 23 ++ .../src/app/services/risk_rule_generation.py | 28 +++ .../services/risk_rule_generation_prompt.py | 21 +- 6 files changed, 584 insertions(+), 1 deletion(-) create mode 100644 server/src/app/services/few_shot_ingestion.py create mode 100644 server/src/app/services/few_shot_retrieval.py create mode 100644 server/src/app/services/few_shot_store.py diff --git a/server/src/app/services/few_shot_ingestion.py b/server/src/app/services/few_shot_ingestion.py new file mode 100644 index 0000000..d8914f9 --- /dev/null +++ b/server/src/app/services/few_shot_ingestion.py @@ -0,0 +1,177 @@ +"""Few-shot 样本入库编排:RiskObservation → FewShotSample → Qdrant。 + +只处理人工确认为 confirmed / false_positive 的观测,把它转成一条 +:class:`FewShotSample`,持久化到 DB,并同步向量到 Qdrant。 + +入库动作由 :meth:`RiskObservationService.create_feedback` 在 commit 后触发, +本服务全程吞异常(只记日志),保证反馈主流程不被 few-shot 链路拖崩。 +""" + +from __future__ import annotations + +from typing import Any + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from app.core.logging import get_logger +from app.models.few_shot_sample import FewShotSample +from app.models.risk_observation import RiskObservation, RiskObservationFeedback +from app.services.embedding_provider import EmbeddingProvider +from app.services.few_shot_store import FewShotStore + +logger = get_logger("app.services.few_shot_ingestion") + +# 仅这两个 feedback_status 视为已确认样本,会入库 +CONFIRMED_LABELS = {"confirmed", "false_positive"} + +# label → 自然语言结论(当 feedback.comment 缺失时兜底) +LABEL_CONCLUSION_FALLBACK = { + "confirmed": "经人工复核确认,该风险线索成立,需按规则拦截或补件。", + "false_positive": "经人工复核判定为误报,相似情形不应触发该风险规则。", +} + + +class FewShotIngestionService: + """把已确认的风险观测沉淀为 few-shot 样本。""" + + def __init__(self, db: Session) -> None: + self.db = db + + def ingest_observation_feedback( + self, + observation: RiskObservation, + feedback: RiskObservationFeedback, + ) -> FewShotSample | None: + """人工确认/误报后调用,写入并同步向量。""" + + label = observation.feedback_status + if label not in CONFIRMED_LABELS: + return None + + sample_key = f"obs:{observation.id}" + sample = self.db.scalar( + select(FewShotSample).where(FewShotSample.sample_key == sample_key) + ) + + domain = self._extract_domain(observation) + case_text = self._build_case_text(observation) + conclusion_text = self._build_conclusion_text(observation, feedback, label) + payload = self._build_payload(observation, feedback, label) + + if sample is None: + sample = FewShotSample( + sample_key=sample_key, + source_observation_id=observation.id, + scene="risk_rule_generation", + domain=domain, + risk_type=observation.risk_type or "", + risk_level=observation.risk_level or "", + label=label, + case_text=case_text, + conclusion_text=conclusion_text, + payload_json=payload, + status="active", + ) + self.db.add(sample) + else: + sample.label = label + sample.domain = domain + sample.risk_type = observation.risk_type or "" + sample.risk_level = observation.risk_level or "" + sample.case_text = case_text + sample.conclusion_text = conclusion_text + sample.payload_json = payload + sample.status = "active" + sample.vector_id = sample.vector_id + try: + self.db.commit() + self.db.refresh(sample) + except Exception: + logger.exception("few-shot 样本持久化失败 observation_id=%s", observation.id) + self.db.rollback() + return None + + vector_id = self._store().upsert(sample) + if vector_id: + sample.vector_id = vector_id + try: + self.db.commit() + except Exception: + logger.warning("few-shot vector_id 回写失败 sample_id=%s", sample.id) + return sample + + def retract_observation(self, observation_id: str) -> bool: + """观测被撤销时删掉对应样本及其向量。""" + + sample = self.db.scalar( + select(FewShotSample).where(FewShotSample.source_observation_id == observation_id) + ) + if sample is None: + return False + if sample.vector_id: + self._store().delete_by_vector_id(sample.vector_id) + try: + self.db.delete(sample) + self.db.commit() + return True + except Exception: + logger.exception("few-shot 样本删除失败 observation_id=%s", observation_id) + self.db.rollback() + return False + + def _store(self) -> FewShotStore: + provider = EmbeddingProvider.from_settings(self.db) + return FewShotStore(provider) + + def _extract_domain(self, observation: RiskObservation) -> str: + ontology = observation.ontology_json or {} + return str(ontology.get("domain") or "") + + def _build_case_text(self, observation: RiskObservation) -> str: + parts = [ + observation.title or "", + observation.description or "", + observation.risk_signal or "", + observation.risk_type or "", + ] + ontology = observation.ontology_json or {} + scenario = ontology.get("scenario") + if scenario: + parts.append(f"场景:{scenario}") + risk_signals = ontology.get("risk_signals") + if isinstance(risk_signals, list) and risk_signals: + parts.append("信号:" + "|".join(str(s) for s in risk_signals)) + return "\n".join(part for part in parts if part).strip() + + def _build_conclusion_text( + self, + observation: RiskObservation, + feedback: RiskObservationFeedback, + label: str, + ) -> str: + comment = (feedback.comment or "").strip() + if comment: + return f"[{label}] {comment}" + return LABEL_CONCLUSION_FALLBACK.get(label, label) + + def _build_payload( + self, + observation: RiskObservation, + feedback: RiskObservationFeedback, + label: str, + ) -> dict[str, Any]: + return { + "label": label, + "risk_type": observation.risk_type, + "risk_signal": observation.risk_signal, + "risk_level": observation.risk_level, + "feedback_type": feedback.feedback_type, + "feedback_comment": feedback.comment or "", + "feedback_actor": feedback.actor or "", + "ontology": observation.ontology_json or {}, + "policy_refs": observation.policy_refs_json or [], + "evidence": observation.evidence_json or [], + "subject_label": observation.subject_label or "", + "claim_no": observation.claim_no or "", + } diff --git a/server/src/app/services/few_shot_retrieval.py b/server/src/app/services/few_shot_retrieval.py new file mode 100644 index 0000000..d08f5f8 --- /dev/null +++ b/server/src/app/services/few_shot_retrieval.py @@ -0,0 +1,122 @@ +"""Few-shot 检索器:按当前 case 特征检索相似历史样本,拼成注入块。 + +从 :class:`FewShotStore` 取相似样本,转成可供 prompt 构造函数直接使用的结构。 +带 token 预算裁剪和去重,确保不撑爆 prompt。 + +典型用法(在构造 prompt 之前调用):: + + retriever = FewShotRetriever.from_session(session) + samples = retriever.retrieve_for_risk_rule_generation( + domain="travel", natural_language="票据城市与申报地不一致" + ) + messages = build_risk_rule_compiler_messages( + ..., + few_shot_samples=samples, + ) +""" + +from __future__ import annotations + +from typing import Any + +from sqlalchemy.orm import Session + +from app.core.logging import get_logger +from app.services.embedding_provider import EmbeddingProvider +from app.services.few_shot_store import FewShotStore + +logger = get_logger("app.services.few_shot_retrieval") + +# 单条 few-shot 样本估算 token 数(用于预算裁剪) +SAMPLE_TOKEN_BUDGET = 1200 +# 单条样本最大字符数,超长直接截断结论,避免撑爆 prompt +SINGLE_SAMPLE_MAX_CHARS = 400 +# 历史样本最多注入条数(与原内置 examples 合并后总量受限) +MAX_HISTORICAL_SAMPLES = 3 + + +class FewShotRetriever: + """按 case 特征检索已确认样本,返回 prompt 可直接消费的结构。""" + + def __init__(self, store: FewShotStore) -> None: + self._store = store + + @classmethod + def from_session(cls, session: Session) -> "FewShotRetriever": + provider = EmbeddingProvider.from_settings(session) + return cls(FewShotStore(provider)) + + def retrieve_for_risk_rule_generation( + self, + *, + domain: str = "", + risk_type: str = "", + natural_language: str, + top_k: int = MAX_HISTORICAL_SAMPLES, + ) -> list[dict[str, Any]]: + """检索与当前规则需求相似的历史样本,返回注入块列表。""" + + case_text = self._build_case_text( + natural_language=natural_language, + domain=domain, + risk_type=risk_type, + ) + if not case_text: + return [] + hits = self._store.search( + case_text, + scene="risk_rule_generation", + labels=["confirmed", "false_positive"], + top_k=top_k, + ) + return self._hits_to_injection_blocks(hits) + + def _build_case_text( + self, + *, + natural_language: str, + domain: str = "", + risk_type: str = "", + ) -> str: + parts = [natural_language, domain, risk_type] + return "\n".join(p for p in parts if p).strip() + + def _hits_to_injection_blocks( + self, + hits: list[dict[str, Any]], + ) -> list[dict[str, Any]]: + """把检索命中转成 prompt 可消费的块,做去重和预算裁剪。""" + + blocks: list[dict[str, Any]] = [] + seen_conclusions: set[str] = set() + budget = SAMPLE_TOKEN_BUDGET + for hit in hits: + conclusion = (hit.get("conclusion_text") or "").strip() + if not conclusion or conclusion in seen_conclusions: + continue + # 超长结论截断到上限,避免单条样本占用过多预算 + if len(conclusion) > SINGLE_SAMPLE_MAX_CHARS: + conclusion = conclusion[:SINGLE_SAMPLE_MAX_CHARS] + payload = hit.get("payload_json") or {} + block = { + "source": "historical_confirmed", + "label": hit.get("label"), + "domain": hit.get("domain") or "", + "risk_type": hit.get("risk_type") or "", + "score": round(float(hit.get("score") or 0.0), 4), + "conclusion": conclusion, + "context": { + "risk_signal": payload.get("risk_signal") or "", + "risk_level": payload.get("risk_level") or "", + "ontology": payload.get("ontology") or {}, + "feedback_comment": payload.get("feedback_comment") or "", + }, + } + # 粗略 token 估算(按字符数 / 1.6 近似中文 token 比) + estimated_tokens = int(len(conclusion) / 1.6) + 40 + if estimated_tokens > budget: + break + budget -= estimated_tokens + blocks.append(block) + seen_conclusions.add(conclusion) + return blocks diff --git a/server/src/app/services/few_shot_store.py b/server/src/app/services/few_shot_store.py new file mode 100644 index 0000000..071b28f --- /dev/null +++ b/server/src/app/services/few_shot_store.py @@ -0,0 +1,214 @@ +"""Few-shot 样本的 Qdrant 向量存储。 + +独立于 LightRAG 的 Qdrant 客户端,使用专用 collection ``few_shot_samples``, +与知识库 RAG 的 collection 隔离。所有操作失败都不抛异常(记日志返回空), +保证主链路不阻塞。 + +向量来自 :class:`EmbeddingProvider`,payload 带业务过滤字段(scene/label/domain/risk_type), +检索时按这些字段过滤 + 向量相似度排序。 +""" + +from __future__ import annotations + +import os +import uuid +from typing import Any + +from app.core.logging import get_logger +from app.services.knowledge_rag import _resolve_default_qdrant_url + +logger = get_logger("app.services.few_shot_store") + +FEW_SHOT_COLLECTION = "few_shot_samples" + + +def _resolve_qdrant_config() -> tuple[str, str]: + """复用 knowledge_rag 的 Qdrant URL/key 解析逻辑。""" + + url = os.environ.get("QDRANT_URL", "").strip() or _resolve_default_qdrant_url() + api_key = os.environ.get("QDRANT_API_KEY", "").strip() + return url, api_key + + +class FewShotStore: + """对 Qdrant 的轻量封装,专供 few-shot 样本检索使用。 + + 设计要点: + - 惰性创建 client 和 collection,首次操作时初始化。 + - 所有公共方法吞异常(返回空/False),主链路永远不被拖崩。 + - 向量写入和检索都依赖外部传入的 :class:`EmbeddingProvider`, + 由调用方保证与配置一致。 + """ + + def __init__(self, embedding_provider: Any) -> None: + self._embedding_provider = embedding_provider + self._client: Any = None + self._ensured = False + + def _client_or_none(self) -> Any: + """惰性初始化 QdrantClient,失败返回 None。""" + + if self._client is not None: + return self._client + try: + from qdrant_client import QdrantClient + + url, api_key = _resolve_qdrant_config() + self._client = QdrantClient(url=url, api_key=api_key or None) + except Exception: + logger.warning("few-shot QdrantClient 初始化失败,本轮操作跳过", exc_info=True) + self._client = None + return self._client + + def _ensure_collection(self) -> bool: + """确保 collection 存在,成功返回 True。""" + + if self._ensured: + return True + client = self._client_or_none() + if client is None: + return False + try: + from qdrant_client.http.exceptions import UnexpectedResponse + + try: + client.get_collection(FEW_SHOT_COLLECTION) + self._ensured = True + return True + except UnexpectedResponse as exc: + if exc.status_code != 404: + raise + # collection 不存在则创建 + dim = self._embedding_provider.dimension() + from qdrant_client.http.models import ( + Distance, + VectorParams, + PayloadSchemaType, + ) + + client.create_collection( + collection_name=FEW_SHOT_COLLECTION, + vectors_config=VectorParams(size=dim, distance=Distance.COSINE), + ) + for field, field_type in [ + ("sample_id", PayloadSchemaType.KEYWORD), + ("scene", PayloadSchemaType.KEYWORD), + ("label", PayloadSchemaType.KEYWORD), + ("domain", PayloadSchemaType.KEYWORD), + ("risk_type", PayloadSchemaType.KEYWORD), + ("status", PayloadSchemaType.KEYWORD), + ]: + try: + client.create_payload_index( + collection_name=FEW_SHOT_COLLECTION, + field_name=field, + field_schema=field_type, + ) + except Exception: + logger.debug("payload index 创建跳过 field=%s", field, exc_info=True) + self._ensured = True + logger.info("few-shot collection 创建成功 dim=%s", dim) + return True + except Exception: + logger.warning("few-shot collection 初始化失败,本轮操作跳过", exc_info=True) + return False + + def upsert(self, sample: Any) -> str | None: + """把一条样本向量化并写入 Qdrant,返回 vector_id,失败返回 None。""" + + if not self._ensure_collection(): + return None + client = self._client + try: + vector = self._embedding_provider.embed([sample.case_text])[0] + except Exception: + logger.warning("few-shot embedding 失败 sample_key=%s", getattr(sample, "sample_key", ""), exc_info=True) + return None + vector_id = uuid.uuid4().hex + payload = { + "sample_id": sample.id, + "scene": sample.scene, + "label": sample.label, + "domain": sample.domain, + "risk_type": sample.risk_type, + "risk_level": sample.risk_level, + "status": getattr(sample, "status", "active"), + "conclusion_text": sample.conclusion_text, + "payload_json": sample.payload_json, + } + try: + client.upsert( + collection_name=FEW_SHOT_COLLECTION, + points=[{"id": vector_id, "vector": vector, "payload": payload}], + ) + return vector_id + except Exception: + logger.warning("few-shot upsert 失败 sample_key=%s", getattr(sample, "sample_key", ""), exc_info=True) + return None + + def search( + self, + case_text: str, + *, + scene: str | None = None, + labels: list[str] | None = None, + top_k: int = 3, + ) -> list[dict[str, Any]]: + """按 case_text 检索相似样本,可按 scene/label 过滤。失败返回空列表。""" + + if not case_text or not self._ensure_collection(): + return [] + client = self._client + try: + vector = self._embedding_provider.embed([case_text])[0] + except Exception: + logger.warning("few-shot 检索 embedding 失败", exc_info=True) + return [] + must: list[dict[str, Any]] = [{"key": "status", "match": {"value": "active"}}] + if scene: + must.append({"key": "scene", "match": {"value": scene}}) + if labels: + must.append({"key": "label", "match": {"any": labels}}) + try: + from qdrant_client.http.models import Filter + + results = client.query_points( + collection_name=FEW_SHOT_COLLECTION, + query=vector, + query_filter=Filter(must=must), + limit=top_k, + with_payload=True, + ).points + except Exception: + logger.warning("few-shot 检索失败", exc_info=True) + return [] + hits: list[dict[str, Any]] = [] + for point in results: + payload = getattr(point, "payload", None) or {} + hits.append( + { + "sample_id": payload.get("sample_id"), + "score": float(getattr(point, "score", 0.0)), + "label": payload.get("label"), + "domain": payload.get("domain"), + "risk_type": payload.get("risk_type"), + "conclusion_text": payload.get("conclusion_text") or "", + "payload_json": payload.get("payload_json") or {}, + } + ) + return hits + + def delete_by_vector_id(self, vector_id: str) -> bool: + """按 vector_id 删除向量,失败返回 False。""" + + if not vector_id or not self._ensure_collection(): + return False + try: + self._client.delete( + collection_name=FEW_SHOT_COLLECTION, + points_selector=[vector_id], + ) + return True + except Exception: + logger.warning("few-shot 删除失败 vector_id=%s", vector_id, exc_info=True) + return False diff --git a/server/src/app/services/risk_observations.py b/server/src/app/services/risk_observations.py index 74d294b..1481a0c 100644 --- a/server/src/app/services/risk_observations.py +++ b/server/src/app/services/risk_observations.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os from datetime import UTC, datetime, timedelta from decimal import Decimal from typing import Any @@ -8,6 +9,7 @@ from sqlalchemy import func, select from sqlalchemy.orm import Session, joinedload from app.algorithem.risk_graph import RiskHistoryStats, RiskObservationDraft +from app.core.logging import get_logger from app.db.base import Base from app.models.financial_record import ExpenseClaim from app.models.risk_observation import RiskObservation, RiskObservationFeedback @@ -17,6 +19,8 @@ from app.schemas.risk_observation import ( ) from app.services.expense_claim_risk_stage import normalize_risk_business_stage +logger = get_logger("app.services.risk_observations") + HIGH_LEVELS = {"high", "critical"} SEVERITY_SCORE = { "low": 32, @@ -322,8 +326,27 @@ class RiskObservationService: observation.status, observation.feedback_status = mapped self.db.commit() self.db.refresh(feedback) + self._maybe_ingest_few_shot(observation, feedback) return feedback + def _maybe_ingest_few_shot( + self, + observation: RiskObservation, + feedback: RiskObservationFeedback, + ) -> None: + """人工确认/误报后把样本沉淀进 few-shot 池,任何失败都不影响主流程。""" + + if os.environ.get("FEW_SHOT_INJECTION_ENABLED", "true").strip().lower() in {"0", "false", "no"}: + return + if observation.feedback_status not in {"confirmed", "false_positive"}: + return + try: + from app.services.few_shot_ingestion import FewShotIngestionService + + FewShotIngestionService(self.db).ingest_observation_feedback(observation, feedback) + except Exception: + logger.exception("few-shot ingestion failed for observation %s", observation.id) + def summarize_dashboard( self, *, diff --git a/server/src/app/services/risk_rule_generation.py b/server/src/app/services/risk_rule_generation.py index 05088d8..c481d26 100644 --- a/server/src/app/services/risk_rule_generation.py +++ b/server/src/app/services/risk_rule_generation.py @@ -234,6 +234,10 @@ class RiskRuleGenerationService: } for item in fields ] + few_shot_samples = self._retrieve_few_shot_samples( + domain=domain, + natural_language=natural_language, + ) messages = build_risk_rule_compiler_messages( domain=domain, domain_label=BUSINESS_DOMAIN_LABELS[domain], @@ -243,6 +247,7 @@ class RiskRuleGenerationService: expense_category_label=expense_category_label, natural_language=natural_language, available_fields=field_payload, + few_shot_samples=few_shot_samples, ) answer = self.runtime_chat_service.complete( messages, @@ -263,6 +268,29 @@ class RiskRuleGenerationService: payload = unwrap_semantic_plan_payload(payload) return self._sanitize_model_draft(payload, fields=fields) + def _retrieve_few_shot_samples( + self, + *, + domain: str, + natural_language: str, + ) -> list[dict[str, Any]]: + """检索已确认历史样本,失败降级为空列表。""" + + import os + + if os.environ.get("FEW_SHOT_INJECTION_ENABLED", "true").strip().lower() in {"0", "false", "no"}: + return [] + try: + from app.services.few_shot_retrieval import FewShotRetriever + + retriever = FewShotRetriever.from_session(self.db) + return retriever.retrieve_for_risk_rule_generation( + domain=domain, + natural_language=natural_language, + ) + except Exception: + return [] + def _sanitize_model_draft( self, payload: dict[str, Any], diff --git a/server/src/app/services/risk_rule_generation_prompt.py b/server/src/app/services/risk_rule_generation_prompt.py index f38b92c..7707eb4 100644 --- a/server/src/app/services/risk_rule_generation_prompt.py +++ b/server/src/app/services/risk_rule_generation_prompt.py @@ -14,10 +14,15 @@ def build_risk_rule_compiler_messages( expense_category_label: str, natural_language: str, available_fields: list[dict[str, Any]], + few_shot_samples: list[dict[str, Any]] | None = None, ) -> list[dict[str, str]]: """构造自然语言规则编译提示词。 大模型只负责把业务语言拆成“语义计划”,后端会校验字段、操作符和模板。 + + ``few_shot_samples`` 是从已确认历史样本中检索出来的相似案例,会被合并进 + ``examples`` 字段并标注 ``source: "historical_confirmed"``,让编译器参考 + 过往人工结论。传 ``None`` 或空列表时行为与历史完全一致(向后兼容)。 """ schema = { @@ -161,6 +166,20 @@ def build_risk_rule_compiler_messages( }, } ] + historical_examples: list[dict[str, Any]] = [] + if few_shot_samples: + for sample in few_shot_samples: + historical_examples.append( + { + "source": "historical_confirmed", + "label": sample.get("label"), + "domain": sample.get("domain") or "", + "risk_type": sample.get("risk_type") or "", + "conclusion": sample.get("conclusion") or "", + "context": sample.get("context") or {}, + } + ) + merged_examples = historical_examples + examples return [ { "role": "system", @@ -186,7 +205,7 @@ def build_risk_rule_compiler_messages( "natural_language": natural_language, "available_fields": available_fields, "required_json_shape": response_schema, - "examples": examples, + "examples": merged_examples, }, ensure_ascii=False, ),