feat(flywheel): 抽公共 EmbeddingProvider 并新增 FewShotSample 模型
- 从 knowledge_rag_runtime 抽出 embedding 调用逻辑为独立 EmbeddingProvider, 复用现有 HTTP 纯函数,RAG 路径零回归 - 新增 FewShotSample 表模型(样本池),注册到 db/base.py 和 models/__init__.py 供 few-shot 飞轮沉淀已确认风险观测
This commit is contained in:
@@ -15,6 +15,7 @@ from app.models.budget import BudgetAllocation, BudgetReservation, BudgetTransac
|
|||||||
from app.models.employee_change_log import EmployeeChangeLog
|
from app.models.employee_change_log import EmployeeChangeLog
|
||||||
from app.models.employee_behavior_profile import EmployeeBehaviorProfileSnapshot
|
from app.models.employee_behavior_profile import EmployeeBehaviorProfileSnapshot
|
||||||
from app.models.employee import Employee
|
from app.models.employee import Employee
|
||||||
|
from app.models.few_shot_sample import FewShotSample
|
||||||
from app.models.financial_record import (
|
from app.models.financial_record import (
|
||||||
AccountsPayableRecord,
|
AccountsPayableRecord,
|
||||||
AccountsReceivableRecord,
|
AccountsReceivableRecord,
|
||||||
@@ -57,6 +58,7 @@ __all__ = [
|
|||||||
"EmployeeBehaviorProfileSnapshot",
|
"EmployeeBehaviorProfileSnapshot",
|
||||||
"EmployeeChangeLog",
|
"EmployeeChangeLog",
|
||||||
"ExpenseClaim",
|
"ExpenseClaim",
|
||||||
|
"FewShotSample",
|
||||||
"ExpenseClaimItem",
|
"ExpenseClaimItem",
|
||||||
"HermesTaskConfig",
|
"HermesTaskConfig",
|
||||||
"HermesTaskExecutionLog",
|
"HermesTaskExecutionLog",
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from app.models.budget import BudgetAllocation, BudgetReservation, BudgetTransac
|
|||||||
from app.models.employee_change_log import EmployeeChangeLog
|
from app.models.employee_change_log import EmployeeChangeLog
|
||||||
from app.models.employee_behavior_profile import EmployeeBehaviorProfileSnapshot
|
from app.models.employee_behavior_profile import EmployeeBehaviorProfileSnapshot
|
||||||
from app.models.employee import Employee
|
from app.models.employee import Employee
|
||||||
|
from app.models.few_shot_sample import FewShotSample
|
||||||
from app.models.financial_record import (
|
from app.models.financial_record import (
|
||||||
AccountsPayableRecord,
|
AccountsPayableRecord,
|
||||||
AccountsReceivableRecord,
|
AccountsReceivableRecord,
|
||||||
@@ -49,6 +50,7 @@ __all__ = [
|
|||||||
"EmployeeChangeLog",
|
"EmployeeChangeLog",
|
||||||
"ExpenseClaim",
|
"ExpenseClaim",
|
||||||
"ExpenseClaimItem",
|
"ExpenseClaimItem",
|
||||||
|
"FewShotSample",
|
||||||
"HermesTaskConfig",
|
"HermesTaskConfig",
|
||||||
"HermesTaskExecutionLog",
|
"HermesTaskExecutionLog",
|
||||||
"HermesRiskReport",
|
"HermesRiskReport",
|
||||||
|
|||||||
54
server/src/app/models/few_shot_sample.py
Normal file
54
server/src/app/models/few_shot_sample.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import DateTime, ForeignKey, Index, String, Text, func
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
from sqlalchemy.types import JSON
|
||||||
|
|
||||||
|
from app.db.base_class import Base
|
||||||
|
|
||||||
|
|
||||||
|
class FewShotSample(Base):
|
||||||
|
"""已确认的风险观测样本,供 few-shot 检索注入使用。
|
||||||
|
|
||||||
|
数据来源是 ``RiskObservation`` 上人工确认为 confirmed / false_positive 的观测,
|
||||||
|
入库后同时写一份向量到 Qdrant 的 ``few_shot_samples`` collection。
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "few_shot_samples"
|
||||||
|
__table_args__ = (
|
||||||
|
Index("ix_few_shot_samples_scene_label", "scene", "label"),
|
||||||
|
Index("ix_few_shot_samples_domain_risk_type", "domain", "risk_type"),
|
||||||
|
)
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||||
|
sample_key: Mapped[str] = mapped_column(String(160), unique=True, index=True)
|
||||||
|
source_observation_id: Mapped[str | None] = mapped_column(
|
||||||
|
ForeignKey("risk_observations.id"),
|
||||||
|
nullable=True,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
scene: Mapped[str] = mapped_column(String(50), default="risk_rule_generation", index=True)
|
||||||
|
domain: Mapped[str] = mapped_column(String(50), default="", index=True)
|
||||||
|
risk_type: Mapped[str] = mapped_column(String(80), default="", index=True)
|
||||||
|
risk_level: Mapped[str] = mapped_column(String(20), default="")
|
||||||
|
|
||||||
|
label: Mapped[str] = mapped_column(String(30), default="confirmed", index=True)
|
||||||
|
case_text: Mapped[str] = mapped_column(Text(), default="")
|
||||||
|
conclusion_text: Mapped[str] = mapped_column(Text(), default="")
|
||||||
|
payload_json: Mapped[dict[str, Any]] = mapped_column(JSON, default=dict)
|
||||||
|
|
||||||
|
vector_id: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
|
status: Mapped[str] = mapped_column(String(20), default="active", index=True)
|
||||||
|
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=func.now(), server_default=func.now())
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime,
|
||||||
|
default=func.now(),
|
||||||
|
onupdate=func.now(),
|
||||||
|
server_default=func.now(),
|
||||||
|
)
|
||||||
138
server/src/app/services/embedding_provider.py
Normal file
138
server/src/app/services/embedding_provider.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
"""公共 Embedding 提供者。
|
||||||
|
|
||||||
|
把 ``knowledge_rag_runtime`` 里 embedding 调用逻辑抽出来,供 RAG 和
|
||||||
|
few-shot 检索复用。本模块只依赖现有模块级纯函数和 ``RuntimeModelConfig``,
|
||||||
|
不改动 ``_LightRagRuntime`` 的行为,RAG 路径保持零回归风险。
|
||||||
|
|
||||||
|
典型用法::
|
||||||
|
|
||||||
|
provider = EmbeddingProvider.from_settings(session)
|
||||||
|
vectors = provider.embed(["差旅超标", "票单不一致"])
|
||||||
|
dim = provider.dimension()
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.services.knowledge_rag_runtime import (
|
||||||
|
DEFAULT_EMBEDDING_TIMEOUT_SECONDS,
|
||||||
|
KnowledgeRagError,
|
||||||
|
RuntimeModelConfig,
|
||||||
|
_build_headers,
|
||||||
|
_ensure_path,
|
||||||
|
_extract_embedding_vectors,
|
||||||
|
_normalize_endpoint,
|
||||||
|
_send_json_request,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = get_logger("app.services.embedding_provider")
|
||||||
|
|
||||||
|
|
||||||
|
def _runtime_model_config_from_dict(config: dict[str, str]) -> RuntimeModelConfig:
|
||||||
|
"""把 SettingsService.get_runtime_model_config 返回的 dict 转成 dataclass。"""
|
||||||
|
|
||||||
|
return RuntimeModelConfig(
|
||||||
|
slot=str(config.get("slot") or "embedding"),
|
||||||
|
provider=str(config.get("provider") or ""),
|
||||||
|
model=str(config.get("model") or ""),
|
||||||
|
endpoint=str(config.get("endpoint") or ""),
|
||||||
|
api_key=str(config.get("apiKey") or ""),
|
||||||
|
capability=str(config.get("capability") or ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingProvider:
|
||||||
|
"""对 embedding 模型的轻量封装。
|
||||||
|
|
||||||
|
设计要点:
|
||||||
|
- 持有一个 ``RuntimeModelConfig``,构造即固定,不依赖 LightRAG。
|
||||||
|
- 复用 ``knowledge_rag_runtime`` 的 HTTP 调用纯函数,行为与 RAG 完全一致。
|
||||||
|
- 维度采用惰性探测(首次 embed 后缓存),避免空构造就打远端。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: RuntimeModelConfig) -> None:
|
||||||
|
self.config = config
|
||||||
|
self._dimension: int | None = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_settings(cls, session: Any) -> "EmbeddingProvider":
|
||||||
|
"""从 SettingsService 取 embedding 配置构造 provider。"""
|
||||||
|
|
||||||
|
from app.services.settings import SettingsService
|
||||||
|
|
||||||
|
raw = SettingsService(session).get_runtime_model_config("embedding")
|
||||||
|
return cls(_runtime_model_config_from_dict(raw))
|
||||||
|
|
||||||
|
def embed(self, texts: list[str]) -> list[list[float]]:
|
||||||
|
"""对一组文本做 embedding,返回与输入等长的向量列表。"""
|
||||||
|
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
return _request_embeddings_public(self.config, texts)
|
||||||
|
|
||||||
|
def dimension(self) -> int:
|
||||||
|
"""探测 embedding 维度,结果缓存。失败抛 KnowledgeRagError。"""
|
||||||
|
|
||||||
|
if self._dimension is None:
|
||||||
|
vectors = self.embed(["dimension probe"])
|
||||||
|
if not vectors or not isinstance(vectors[0], list):
|
||||||
|
raise KnowledgeRagError("无法从 embedding 模型返回结果中解析向量维度。")
|
||||||
|
self._dimension = len(vectors[0])
|
||||||
|
if self._dimension <= 0:
|
||||||
|
raise KnowledgeRagError("embedding 模型返回了无效的向量维度。")
|
||||||
|
return self._dimension
|
||||||
|
|
||||||
|
|
||||||
|
def _request_embeddings_public(
|
||||||
|
config: RuntimeModelConfig,
|
||||||
|
texts: list[str],
|
||||||
|
) -> list[list[float]]:
|
||||||
|
"""按 provider 分支构造 embedding 请求。
|
||||||
|
|
||||||
|
与 ``_LightRagRuntime._request_embeddings`` 实现保持一致,
|
||||||
|
保证 few-shot 检索与 RAG 走同一套调用语义。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from app.services.model_connectivity import AZURE_API_VERSION
|
||||||
|
|
||||||
|
if config.provider == "Azure OpenAI":
|
||||||
|
from app.services.knowledge_rag_runtime import _build_azure_deployment_base
|
||||||
|
|
||||||
|
url = f"{_build_azure_deployment_base(config.endpoint, config.model)}/embeddings?api-version={AZURE_API_VERSION}"
|
||||||
|
payload: dict[str, Any] = {"input": texts}
|
||||||
|
status_code, body = _send_json_request(
|
||||||
|
"POST",
|
||||||
|
url,
|
||||||
|
headers=_build_headers(config.api_key, use_bearer=False, use_api_key=True),
|
||||||
|
payload=payload,
|
||||||
|
timeout_seconds=DEFAULT_EMBEDDING_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
elif config.provider == "Ollama":
|
||||||
|
url = _ensure_path(_normalize_endpoint(config.endpoint), "api/embed")
|
||||||
|
payload = {"model": config.model, "input": texts}
|
||||||
|
status_code, body = _send_json_request(
|
||||||
|
"POST",
|
||||||
|
url,
|
||||||
|
headers={"Content-Type": "application/json", "Accept": "application/json"},
|
||||||
|
payload=payload,
|
||||||
|
timeout_seconds=DEFAULT_EMBEDDING_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
url = _ensure_path(_normalize_endpoint(config.endpoint), "embeddings")
|
||||||
|
payload = {"model": config.model, "input": texts}
|
||||||
|
status_code, body = _send_json_request(
|
||||||
|
"POST",
|
||||||
|
url,
|
||||||
|
headers=_build_headers(config.api_key, use_bearer=True),
|
||||||
|
payload=payload,
|
||||||
|
timeout_seconds=DEFAULT_EMBEDDING_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
|
||||||
|
from http import HTTPStatus
|
||||||
|
|
||||||
|
if status_code >= HTTPStatus.BAD_REQUEST:
|
||||||
|
raise KnowledgeRagError(f"embedding 模型返回异常状态码 {status_code}。")
|
||||||
|
|
||||||
|
return _extract_embedding_vectors(body, provider=config.provider)
|
||||||
Reference in New Issue
Block a user