feat(flywheel): 抽公共 EmbeddingProvider 并新增 FewShotSample 模型

- 从 knowledge_rag_runtime 抽出 embedding 调用逻辑为独立 EmbeddingProvider,
  复用现有 HTTP 纯函数,RAG 路径零回归
- 新增 FewShotSample 表模型(样本池),注册到 db/base.py 和 models/__init__.py
  供 few-shot 飞轮沉淀已确认风险观测
This commit is contained in:
caoxiaozhu
2026-07-03 13:55:39 +08:00
parent 08f023243e
commit 765cfb40f3
4 changed files with 196 additions and 0 deletions

View File

@@ -15,6 +15,7 @@ from app.models.budget import BudgetAllocation, BudgetReservation, BudgetTransac
from app.models.employee_change_log import EmployeeChangeLog
from app.models.employee_behavior_profile import EmployeeBehaviorProfileSnapshot
from app.models.employee import Employee
from app.models.few_shot_sample import FewShotSample
from app.models.financial_record import (
AccountsPayableRecord,
AccountsReceivableRecord,
@@ -57,6 +58,7 @@ __all__ = [
"EmployeeBehaviorProfileSnapshot",
"EmployeeChangeLog",
"ExpenseClaim",
"FewShotSample",
"ExpenseClaimItem",
"HermesTaskConfig",
"HermesTaskExecutionLog",

View File

@@ -8,6 +8,7 @@ from app.models.budget import BudgetAllocation, BudgetReservation, BudgetTransac
from app.models.employee_change_log import EmployeeChangeLog
from app.models.employee_behavior_profile import EmployeeBehaviorProfileSnapshot
from app.models.employee import Employee
from app.models.few_shot_sample import FewShotSample
from app.models.financial_record import (
AccountsPayableRecord,
AccountsReceivableRecord,
@@ -49,6 +50,7 @@ __all__ = [
"EmployeeChangeLog",
"ExpenseClaim",
"ExpenseClaimItem",
"FewShotSample",
"HermesTaskConfig",
"HermesTaskExecutionLog",
"HermesRiskReport",

View File

@@ -0,0 +1,54 @@
from __future__ import annotations
import uuid
from datetime import datetime
from typing import Any
from sqlalchemy import DateTime, ForeignKey, Index, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.types import JSON
from app.db.base_class import Base
class FewShotSample(Base):
"""已确认的风险观测样本,供 few-shot 检索注入使用。
数据来源是 ``RiskObservation`` 上人工确认为 confirmed / false_positive 的观测,
入库后同时写一份向量到 Qdrant 的 ``few_shot_samples`` collection。
"""
__tablename__ = "few_shot_samples"
__table_args__ = (
Index("ix_few_shot_samples_scene_label", "scene", "label"),
Index("ix_few_shot_samples_domain_risk_type", "domain", "risk_type"),
)
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
sample_key: Mapped[str] = mapped_column(String(160), unique=True, index=True)
source_observation_id: Mapped[str | None] = mapped_column(
ForeignKey("risk_observations.id"),
nullable=True,
index=True,
)
scene: Mapped[str] = mapped_column(String(50), default="risk_rule_generation", index=True)
domain: Mapped[str] = mapped_column(String(50), default="", index=True)
risk_type: Mapped[str] = mapped_column(String(80), default="", index=True)
risk_level: Mapped[str] = mapped_column(String(20), default="")
label: Mapped[str] = mapped_column(String(30), default="confirmed", index=True)
case_text: Mapped[str] = mapped_column(Text(), default="")
conclusion_text: Mapped[str] = mapped_column(Text(), default="")
payload_json: Mapped[dict[str, Any]] = mapped_column(JSON, default=dict)
vector_id: Mapped[str | None] = mapped_column(String(100), nullable=True)
status: Mapped[str] = mapped_column(String(20), default="active", index=True)
created_at: Mapped[datetime] = mapped_column(DateTime, default=func.now(), server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(
DateTime,
default=func.now(),
onupdate=func.now(),
server_default=func.now(),
)

View File

@@ -0,0 +1,138 @@
"""公共 Embedding 提供者。
把 ``knowledge_rag_runtime`` 里 embedding 调用逻辑抽出来,供 RAG 和
few-shot 检索复用。本模块只依赖现有模块级纯函数和 ``RuntimeModelConfig``
不改动 ``_LightRagRuntime`` 的行为RAG 路径保持零回归风险。
典型用法::
provider = EmbeddingProvider.from_settings(session)
vectors = provider.embed(["差旅超标", "票单不一致"])
dim = provider.dimension()
"""
from __future__ import annotations
from typing import Any
from app.core.logging import get_logger
from app.services.knowledge_rag_runtime import (
DEFAULT_EMBEDDING_TIMEOUT_SECONDS,
KnowledgeRagError,
RuntimeModelConfig,
_build_headers,
_ensure_path,
_extract_embedding_vectors,
_normalize_endpoint,
_send_json_request,
)
logger = get_logger("app.services.embedding_provider")
def _runtime_model_config_from_dict(config: dict[str, str]) -> RuntimeModelConfig:
"""把 SettingsService.get_runtime_model_config 返回的 dict 转成 dataclass。"""
return RuntimeModelConfig(
slot=str(config.get("slot") or "embedding"),
provider=str(config.get("provider") or ""),
model=str(config.get("model") or ""),
endpoint=str(config.get("endpoint") or ""),
api_key=str(config.get("apiKey") or ""),
capability=str(config.get("capability") or ""),
)
class EmbeddingProvider:
"""对 embedding 模型的轻量封装。
设计要点:
- 持有一个 ``RuntimeModelConfig``,构造即固定,不依赖 LightRAG。
- 复用 ``knowledge_rag_runtime`` 的 HTTP 调用纯函数,行为与 RAG 完全一致。
- 维度采用惰性探测(首次 embed 后缓存),避免空构造就打远端。
"""
def __init__(self, config: RuntimeModelConfig) -> None:
self.config = config
self._dimension: int | None = None
@classmethod
def from_settings(cls, session: Any) -> "EmbeddingProvider":
"""从 SettingsService 取 embedding 配置构造 provider。"""
from app.services.settings import SettingsService
raw = SettingsService(session).get_runtime_model_config("embedding")
return cls(_runtime_model_config_from_dict(raw))
def embed(self, texts: list[str]) -> list[list[float]]:
"""对一组文本做 embedding返回与输入等长的向量列表。"""
if not texts:
return []
return _request_embeddings_public(self.config, texts)
def dimension(self) -> int:
"""探测 embedding 维度,结果缓存。失败抛 KnowledgeRagError。"""
if self._dimension is None:
vectors = self.embed(["dimension probe"])
if not vectors or not isinstance(vectors[0], list):
raise KnowledgeRagError("无法从 embedding 模型返回结果中解析向量维度。")
self._dimension = len(vectors[0])
if self._dimension <= 0:
raise KnowledgeRagError("embedding 模型返回了无效的向量维度。")
return self._dimension
def _request_embeddings_public(
config: RuntimeModelConfig,
texts: list[str],
) -> list[list[float]]:
"""按 provider 分支构造 embedding 请求。
与 ``_LightRagRuntime._request_embeddings`` 实现保持一致,
保证 few-shot 检索与 RAG 走同一套调用语义。
"""
from app.services.model_connectivity import AZURE_API_VERSION
if config.provider == "Azure OpenAI":
from app.services.knowledge_rag_runtime import _build_azure_deployment_base
url = f"{_build_azure_deployment_base(config.endpoint, config.model)}/embeddings?api-version={AZURE_API_VERSION}"
payload: dict[str, Any] = {"input": texts}
status_code, body = _send_json_request(
"POST",
url,
headers=_build_headers(config.api_key, use_bearer=False, use_api_key=True),
payload=payload,
timeout_seconds=DEFAULT_EMBEDDING_TIMEOUT_SECONDS,
)
elif config.provider == "Ollama":
url = _ensure_path(_normalize_endpoint(config.endpoint), "api/embed")
payload = {"model": config.model, "input": texts}
status_code, body = _send_json_request(
"POST",
url,
headers={"Content-Type": "application/json", "Accept": "application/json"},
payload=payload,
timeout_seconds=DEFAULT_EMBEDDING_TIMEOUT_SECONDS,
)
else:
url = _ensure_path(_normalize_endpoint(config.endpoint), "embeddings")
payload = {"model": config.model, "input": texts}
status_code, body = _send_json_request(
"POST",
url,
headers=_build_headers(config.api_key, use_bearer=True),
payload=payload,
timeout_seconds=DEFAULT_EMBEDDING_TIMEOUT_SECONDS,
)
from http import HTTPStatus
if status_code >= HTTPStatus.BAD_REQUEST:
raise KnowledgeRagError(f"embedding 模型返回异常状态码 {status_code}")
return _extract_embedding_vectors(body, provider=config.provider)