feat(flywheel): 抽公共 EmbeddingProvider 并新增 FewShotSample 模型
- 从 knowledge_rag_runtime 抽出 embedding 调用逻辑为独立 EmbeddingProvider,复用现有 HTTP 纯函数,RAG 路径零回归
- 新增 FewShotSample 表模型(样本池),注册到 db/base.py 和 models/__init__.py,供 few-shot 飞轮沉淀已确认风险观测
This commit is contained in:
@@ -15,6 +15,7 @@ from app.models.budget import BudgetAllocation, BudgetReservation, BudgetTransac
|
||||
from app.models.employee_change_log import EmployeeChangeLog
|
||||
from app.models.employee_behavior_profile import EmployeeBehaviorProfileSnapshot
|
||||
from app.models.employee import Employee
|
||||
from app.models.few_shot_sample import FewShotSample
|
||||
from app.models.financial_record import (
|
||||
AccountsPayableRecord,
|
||||
AccountsReceivableRecord,
|
||||
@@ -57,6 +58,7 @@ __all__ = [
|
||||
"EmployeeBehaviorProfileSnapshot",
|
||||
"EmployeeChangeLog",
|
||||
"ExpenseClaim",
|
||||
"FewShotSample",
|
||||
"ExpenseClaimItem",
|
||||
"HermesTaskConfig",
|
||||
"HermesTaskExecutionLog",
|
||||
|
||||
@@ -8,6 +8,7 @@ from app.models.budget import BudgetAllocation, BudgetReservation, BudgetTransac
|
||||
from app.models.employee_change_log import EmployeeChangeLog
|
||||
from app.models.employee_behavior_profile import EmployeeBehaviorProfileSnapshot
|
||||
from app.models.employee import Employee
|
||||
from app.models.few_shot_sample import FewShotSample
|
||||
from app.models.financial_record import (
|
||||
AccountsPayableRecord,
|
||||
AccountsReceivableRecord,
|
||||
@@ -49,6 +50,7 @@ __all__ = [
|
||||
"EmployeeChangeLog",
|
||||
"ExpenseClaim",
|
||||
"ExpenseClaimItem",
|
||||
"FewShotSample",
|
||||
"HermesTaskConfig",
|
||||
"HermesTaskExecutionLog",
|
||||
"HermesRiskReport",
|
||||
|
||||
54
server/src/app/models/few_shot_sample.py
Normal file
54
server/src/app/models/few_shot_sample.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, Index, String, Text, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.types import JSON
|
||||
|
||||
from app.db.base_class import Base
|
||||
|
||||
|
||||
def _new_few_shot_sample_id() -> str:
    """Return a fresh random UUID4 rendered as a 36-character string."""
    return str(uuid.uuid4())


class FewShotSample(Base):
    """Confirmed risk-observation sample used for few-shot retrieval injection.

    Rows originate from ``RiskObservation`` records that a human has marked as
    ``confirmed`` / ``false_positive``. After a row is stored, a vector copy is
    also written to the Qdrant ``few_shot_samples`` collection.
    """

    __tablename__ = "few_shot_samples"
    __table_args__ = (
        # Composite indexes for the two paired lookups declared by this model.
        Index("ix_few_shot_samples_scene_label", "scene", "label"),
        Index("ix_few_shot_samples_domain_risk_type", "domain", "risk_type"),
    )

    # Primary key: UUID4 string generated on the Python side at insert time.
    id: Mapped[str] = mapped_column(
        String(36),
        primary_key=True,
        default=_new_few_shot_sample_id,
    )
    # Unique, indexed key identifying the sample.
    sample_key: Mapped[str] = mapped_column(
        String(160),
        unique=True,
        index=True,
    )
    # Optional link back to the originating row in ``risk_observations``.
    source_observation_id: Mapped[str | None] = mapped_column(
        ForeignKey("risk_observations.id"),
        nullable=True,
        index=True,
    )

    # Classification attributes; each one defaults to a fixed string.
    scene: Mapped[str] = mapped_column(
        String(50),
        default="risk_rule_generation",
        index=True,
    )
    domain: Mapped[str] = mapped_column(
        String(50),
        default="",
        index=True,
    )
    risk_type: Mapped[str] = mapped_column(
        String(80),
        default="",
        index=True,
    )
    risk_level: Mapped[str] = mapped_column(
        String(20),
        default="",
    )

    # Review label plus the free-text case / conclusion and a JSON payload.
    label: Mapped[str] = mapped_column(
        String(30),
        default="confirmed",
        index=True,
    )
    case_text: Mapped[str] = mapped_column(
        Text(),
        default="",
    )
    conclusion_text: Mapped[str] = mapped_column(
        Text(),
        default="",
    )
    payload_json: Mapped[dict[str, Any]] = mapped_column(
        JSON,
        default=dict,
    )

    # NOTE(review): presumably the id of the matching vector-store point - confirm.
    vector_id: Mapped[str | None] = mapped_column(
        String(100),
        nullable=True,
    )
    status: Mapped[str] = mapped_column(
        String(20),
        default="active",
        index=True,
    )

    # Timestamps are filled by ``func.now()`` both as ORM default and as
    # server default; ``updated_at`` is additionally refreshed on UPDATE.
    created_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=func.now(),
        server_default=func.now(),
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=func.now(),
        onupdate=func.now(),
        server_default=func.now(),
    )
|
||||
138
server/src/app/services/embedding_provider.py
Normal file
138
server/src/app/services/embedding_provider.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""公共 Embedding 提供者。
|
||||
|
||||
把 ``knowledge_rag_runtime`` 里 embedding 调用逻辑抽出来,供 RAG 和
|
||||
few-shot 检索复用。本模块只依赖现有模块级纯函数和 ``RuntimeModelConfig``,
|
||||
不改动 ``_LightRagRuntime`` 的行为,RAG 路径保持零回归风险。
|
||||
|
||||
典型用法::
|
||||
|
||||
provider = EmbeddingProvider.from_settings(session)
|
||||
vectors = provider.embed(["差旅超标", "票单不一致"])
|
||||
dim = provider.dimension()
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.services.knowledge_rag_runtime import (
|
||||
DEFAULT_EMBEDDING_TIMEOUT_SECONDS,
|
||||
KnowledgeRagError,
|
||||
RuntimeModelConfig,
|
||||
_build_headers,
|
||||
_ensure_path,
|
||||
_extract_embedding_vectors,
|
||||
_normalize_endpoint,
|
||||
_send_json_request,
|
||||
)
|
||||
|
||||
logger = get_logger("app.services.embedding_provider")
|
||||
|
||||
|
||||
def _runtime_model_config_from_dict(config: dict[str, str]) -> RuntimeModelConfig:
    """Build a ``RuntimeModelConfig`` from a settings dict.

    The input is the dict returned by
    ``SettingsService.get_runtime_model_config``. Every field is coerced with
    ``str``; a missing or falsy value becomes the empty string, except
    ``slot`` which falls back to ``"embedding"``. The API key is read from the
    camel-cased ``"apiKey"`` entry.
    """

    def _text(key: str, fallback: str = "") -> str:
        # Falsy values (None, "", 0, ...) collapse to the fallback before str().
        return str(config.get(key) or fallback)

    return RuntimeModelConfig(
        slot=_text("slot", "embedding"),
        provider=_text("provider"),
        model=_text("model"),
        endpoint=_text("endpoint"),
        api_key=_text("apiKey"),
        capability=_text("capability"),
    )
|
||||
|
||||
|
||||
class EmbeddingProvider:
    """Thin wrapper around the configured embedding model.

    Design notes:

    - Holds a single ``RuntimeModelConfig`` that is fixed at construction time
      and does not depend on LightRAG.
    - Delegates the HTTP call to ``_request_embeddings_public``, which reuses
      the helper functions imported from ``knowledge_rag_runtime``.
    - The vector dimension is probed lazily on the first ``dimension()`` call
      and cached afterwards, so constructing a provider never contacts the
      remote endpoint.
    """

    def __init__(self, config: RuntimeModelConfig) -> None:
        """Store *config*; no network access happens here."""
        self.config = config
        # Cached vector dimension; stays None until a probe has succeeded.
        self._dimension: int | None = None

    @classmethod
    def from_settings(cls, session: Any) -> "EmbeddingProvider":
        """Build a provider from the ``"embedding"`` runtime model settings.

        ``session`` is passed straight to ``SettingsService``.
        """
        # Imported at call time rather than at module import time.
        # NOTE(review): presumably to avoid an import cycle - confirm.
        from app.services.settings import SettingsService

        raw = SettingsService(session).get_runtime_model_config("embedding")
        return cls(_runtime_model_config_from_dict(raw))

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts* and return the vectors produced by the model.

        An empty input returns ``[]`` without issuing any request.
        """
        if not texts:
            return []
        return _request_embeddings_public(self.config, texts)

    def dimension(self) -> int:
        """Return the embedding dimension, probing the model on first use.

        The probe embeds a single fixed string and measures the length of the
        returned vector. Only a valid (positive) dimension is cached; a failed
        probe leaves the cache empty so that the next call probes again instead
        of returning a bogus value.

        Raises:
            KnowledgeRagError: if no vector can be parsed from the response or
                the returned vector is empty.
        """
        if self._dimension is None:
            vectors = self.embed(["dimension probe"])
            if not vectors or not isinstance(vectors[0], list):
                raise KnowledgeRagError("无法从 embedding 模型返回结果中解析向量维度。")
            probed = len(vectors[0])
            # Validate BEFORE caching: storing 0 here would make every later
            # call skip the probe and silently return an invalid dimension.
            if probed <= 0:
                raise KnowledgeRagError("embedding 模型返回了无效的向量维度。")
            self._dimension = probed
        return self._dimension
|
||||
|
||||
|
||||
def _request_embeddings_public(
    config: RuntimeModelConfig,
    texts: list[str],
) -> list[list[float]]:
    """Send one embedding request, shaped according to ``config.provider``.

    Three request shapes are supported:

    - ``"Azure OpenAI"``: deployment-scoped ``/embeddings`` URL carrying an
      ``api-version`` query parameter, API-key headers, payload without a
      ``model`` field.
    - ``"Ollama"``: ``api/embed`` path, plain JSON headers, no credentials.
    - anything else: ``embeddings`` path with bearer-token headers.

    This is intended to stay in step with
    ``_LightRagRuntime._request_embeddings`` so that few-shot retrieval and
    RAG share the same call semantics.

    Raises:
        KnowledgeRagError: if the response status code is 400 or above.
    """
    from http import HTTPStatus

    from app.services.model_connectivity import AZURE_API_VERSION

    provider_name = config.provider

    # Each branch only decides *what* to send; the request itself is issued
    # once below with the shared timeout.
    if provider_name == "Azure OpenAI":
        from app.services.knowledge_rag_runtime import _build_azure_deployment_base

        target_url = f"{_build_azure_deployment_base(config.endpoint, config.model)}/embeddings?api-version={AZURE_API_VERSION}"
        request_body: dict[str, Any] = {"input": texts}
        request_headers = _build_headers(config.api_key, use_bearer=False, use_api_key=True)
    elif provider_name == "Ollama":
        target_url = _ensure_path(_normalize_endpoint(config.endpoint), "api/embed")
        request_body = {"model": config.model, "input": texts}
        request_headers = {"Content-Type": "application/json", "Accept": "application/json"}
    else:
        target_url = _ensure_path(_normalize_endpoint(config.endpoint), "embeddings")
        request_body = {"model": config.model, "input": texts}
        request_headers = _build_headers(config.api_key, use_bearer=True)

    status_code, body = _send_json_request(
        "POST",
        target_url,
        headers=request_headers,
        payload=request_body,
        timeout_seconds=DEFAULT_EMBEDDING_TIMEOUT_SECONDS,
    )

    if status_code >= HTTPStatus.BAD_REQUEST:
        raise KnowledgeRagError(f"embedding 模型返回异常状态码 {status_code}。")

    return _extract_embedding_vectors(body, provider=config.provider)
|
||||
Reference in New Issue
Block a user