From 765cfb40f30d3dd8b67ddf516fa6133eef761624 Mon Sep 17 00:00:00 2001 From: caoxiaozhu Date: Fri, 3 Jul 2026 13:55:39 +0800 Subject: [PATCH] =?UTF-8?q?feat(flywheel):=20=E6=8A=BD=E5=85=AC=E5=85=B1?= =?UTF-8?q?=20EmbeddingProvider=20=E5=B9=B6=E6=96=B0=E5=A2=9E=20FewShotSam?= =?UTF-8?q?ple=20=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 从 knowledge_rag_runtime 抽出 embedding 调用逻辑为独立 EmbeddingProvider, 复用现有 HTTP 纯函数,RAG 路径零回归 - 新增 FewShotSample 表模型(样本池),注册到 db/base.py 和 models/__init__.py 供 few-shot 飞轮沉淀已确认风险观测 --- server/src/app/db/base.py | 2 + server/src/app/models/__init__.py | 2 + server/src/app/models/few_shot_sample.py | 54 +++++++ server/src/app/services/embedding_provider.py | 138 ++++++++++++++++++ 4 files changed, 196 insertions(+) create mode 100644 server/src/app/models/few_shot_sample.py create mode 100644 server/src/app/services/embedding_provider.py diff --git a/server/src/app/db/base.py b/server/src/app/db/base.py index b9ee3c7..2af5c13 100644 --- a/server/src/app/db/base.py +++ b/server/src/app/db/base.py @@ -15,6 +15,7 @@ from app.models.budget import BudgetAllocation, BudgetReservation, BudgetTransac from app.models.employee_change_log import EmployeeChangeLog from app.models.employee_behavior_profile import EmployeeBehaviorProfileSnapshot from app.models.employee import Employee +from app.models.few_shot_sample import FewShotSample from app.models.financial_record import ( AccountsPayableRecord, AccountsReceivableRecord, @@ -57,6 +58,7 @@ __all__ = [ "EmployeeBehaviorProfileSnapshot", "EmployeeChangeLog", "ExpenseClaim", + "FewShotSample", "ExpenseClaimItem", "HermesTaskConfig", "HermesTaskExecutionLog", diff --git a/server/src/app/models/__init__.py b/server/src/app/models/__init__.py index b3549eb..8c4796a 100644 --- a/server/src/app/models/__init__.py +++ b/server/src/app/models/__init__.py @@ -8,6 +8,7 @@ from app.models.budget import BudgetAllocation, BudgetReservation, BudgetTransac from app.models.employee_change_log import EmployeeChangeLog from app.models.employee_behavior_profile import EmployeeBehaviorProfileSnapshot from app.models.employee import Employee +from app.models.few_shot_sample import FewShotSample from app.models.financial_record import ( AccountsPayableRecord, AccountsReceivableRecord, @@ -49,6 +50,7 @@ __all__ = [ "EmployeeChangeLog", "ExpenseClaim", "ExpenseClaimItem", + "FewShotSample", "HermesTaskConfig", "HermesTaskExecutionLog", "HermesRiskReport", diff --git a/server/src/app/models/few_shot_sample.py b/server/src/app/models/few_shot_sample.py new file mode 100644 index 0000000..79ee503 --- /dev/null +++ b/server/src/app/models/few_shot_sample.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import uuid +from datetime import datetime +from typing import Any + +from sqlalchemy import DateTime, ForeignKey, Index, String, Text, func +from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.types import JSON + +from app.db.base_class import Base + + +class FewShotSample(Base): + """已确认的风险观测样本,供 few-shot 检索注入使用。 + + 数据来源是 ``RiskObservation`` 上人工确认为 confirmed / false_positive 的观测, + 入库后同时写一份向量到 Qdrant 的 ``few_shot_samples`` collection。 + """ + + __tablename__ = "few_shot_samples" + __table_args__ = ( + Index("ix_few_shot_samples_scene_label", "scene", "label"), + Index("ix_few_shot_samples_domain_risk_type", "domain", "risk_type"), + ) + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + sample_key: Mapped[str] = mapped_column(String(160), unique=True, index=True) + source_observation_id: Mapped[str | None] = mapped_column( + ForeignKey("risk_observations.id"), + nullable=True, + index=True, + ) + + scene: Mapped[str] = mapped_column(String(50), default="risk_rule_generation", index=True) + domain: Mapped[str] = mapped_column(String(50), default="", index=True) + risk_type: Mapped[str] = mapped_column(String(80), default="", index=True) + risk_level: Mapped[str] = mapped_column(String(20), default="") + + label: Mapped[str] = mapped_column(String(30), default="confirmed", index=True) + case_text: Mapped[str] = mapped_column(Text(), default="") + conclusion_text: Mapped[str] = mapped_column(Text(), default="") + payload_json: Mapped[dict[str, Any]] = mapped_column(JSON, default=dict) + + vector_id: Mapped[str | None] = mapped_column(String(100), nullable=True) + status: Mapped[str] = mapped_column(String(20), default="active", index=True) + + created_at: Mapped[datetime] = mapped_column(DateTime, default=func.now(), server_default=func.now()) + updated_at: Mapped[datetime] = mapped_column( + DateTime, + default=func.now(), + onupdate=func.now(), + server_default=func.now(), + ) diff --git a/server/src/app/services/embedding_provider.py b/server/src/app/services/embedding_provider.py new file mode 100644 index 0000000..5cce2d2 --- /dev/null +++ b/server/src/app/services/embedding_provider.py @@ -0,0 +1,138 @@ +"""公共 Embedding 提供者。 + +把 ``knowledge_rag_runtime`` 里 embedding 调用逻辑抽出来,供 RAG 和 +few-shot 检索复用。本模块只依赖现有模块级纯函数和 ``RuntimeModelConfig``, +不改动 ``_LightRagRuntime`` 的行为,RAG 路径保持零回归风险。 + +典型用法:: + + provider = EmbeddingProvider.from_settings(session) + vectors = provider.embed(["差旅超标", "票单不一致"]) + dim = provider.dimension() +""" + +from __future__ import annotations + +from typing import Any + +from app.core.logging import get_logger +from app.services.knowledge_rag_runtime import ( + DEFAULT_EMBEDDING_TIMEOUT_SECONDS, + KnowledgeRagError, + RuntimeModelConfig, + _build_headers, + _ensure_path, + _extract_embedding_vectors, + _normalize_endpoint, + _send_json_request, +) + +logger = get_logger("app.services.embedding_provider") + + +def _runtime_model_config_from_dict(config: dict[str, str]) -> RuntimeModelConfig: + """把 SettingsService.get_runtime_model_config 返回的 dict 转成 dataclass。""" + + return RuntimeModelConfig( + slot=str(config.get("slot") or "embedding"), + provider=str(config.get("provider") or ""), + model=str(config.get("model") or ""), + endpoint=str(config.get("endpoint") or ""), + api_key=str(config.get("apiKey") or ""), + capability=str(config.get("capability") or ""), + ) + + +class EmbeddingProvider: + """对 embedding 模型的轻量封装。 + + 设计要点: + - 持有一个 ``RuntimeModelConfig``,构造即固定,不依赖 LightRAG。 + - 复用 ``knowledge_rag_runtime`` 的 HTTP 调用纯函数,行为与 RAG 完全一致。 + - 维度采用惰性探测(首次 embed 后缓存),避免空构造就打远端。 + """ + + def __init__(self, config: RuntimeModelConfig) -> None: + self.config = config + self._dimension: int | None = None + + @classmethod + def from_settings(cls, session: Any) -> "EmbeddingProvider": + """从 SettingsService 取 embedding 配置构造 provider。""" + + from app.services.settings import SettingsService + + raw = SettingsService(session).get_runtime_model_config("embedding") + return cls(_runtime_model_config_from_dict(raw)) + + def embed(self, texts: list[str]) -> list[list[float]]: + """对一组文本做 embedding,返回与输入等长的向量列表。""" + + if not texts: + return [] + return _request_embeddings_public(self.config, texts) + + def dimension(self) -> int: + """探测 embedding 维度,结果缓存。失败抛 KnowledgeRagError。""" + + if self._dimension is None: + vectors = self.embed(["dimension probe"]) + if not vectors or not isinstance(vectors[0], list): + raise KnowledgeRagError("无法从 embedding 模型返回结果中解析向量维度。") + self._dimension = len(vectors[0]) + if self._dimension <= 0: + raise KnowledgeRagError("embedding 模型返回了无效的向量维度。") + return self._dimension + + +def _request_embeddings_public( + config: RuntimeModelConfig, + texts: list[str], +) -> list[list[float]]: + """按 provider 分支构造 embedding 请求。 + + 与 ``_LightRagRuntime._request_embeddings`` 实现保持一致, + 保证 few-shot 检索与 RAG 走同一套调用语义。 + """ + + from app.services.model_connectivity import AZURE_API_VERSION + + if config.provider == "Azure OpenAI": + from app.services.knowledge_rag_runtime import _build_azure_deployment_base + + url = f"{_build_azure_deployment_base(config.endpoint, config.model)}/embeddings?api-version={AZURE_API_VERSION}" + payload: dict[str, Any] = {"input": texts} + status_code, body = _send_json_request( + "POST", + url, + headers=_build_headers(config.api_key, use_bearer=False, use_api_key=True), + payload=payload, + timeout_seconds=DEFAULT_EMBEDDING_TIMEOUT_SECONDS, + ) + elif config.provider == "Ollama": + url = _ensure_path(_normalize_endpoint(config.endpoint), "api/embed") + payload = {"model": config.model, "input": texts} + status_code, body = _send_json_request( + "POST", + url, + headers={"Content-Type": "application/json", "Accept": "application/json"}, + payload=payload, + timeout_seconds=DEFAULT_EMBEDDING_TIMEOUT_SECONDS, + ) + else: + url = _ensure_path(_normalize_endpoint(config.endpoint), "embeddings") + payload = {"model": config.model, "input": texts} + status_code, body = _send_json_request( + "POST", + url, + headers=_build_headers(config.api_key, use_bearer=True), + payload=payload, + timeout_seconds=DEFAULT_EMBEDDING_TIMEOUT_SECONDS, + ) + + from http import HTTPStatus + + if status_code >= HTTPStatus.BAD_REQUEST: + raise KnowledgeRagError(f"embedding 模型返回异常状态码 {status_code}。") + + return _extract_embedding_vectors(body, provider=config.provider)