feat: add system settings with model connectivity and encrypted storage
This commit is contained in:
44
server/src/app/api/v1/endpoints/settings.py
Normal file
44
server/src/app/api/v1/endpoints/settings.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.deps import get_db
|
||||
from app.schemas.settings import (
|
||||
ModelConnectivityTestRead,
|
||||
ModelConnectivityTestRequest,
|
||||
SettingsRead,
|
||||
SettingsWrite,
|
||||
)
|
||||
from app.services.model_connectivity import probe_model_connectivity
|
||||
from app.services.settings import SettingsService
|
||||
|
||||
router = APIRouter(prefix="/settings")
|
||||
DbSession = Annotated[Session, Depends(get_db)]
|
||||
|
||||
|
||||
@router.get("", response_model=SettingsRead)
|
||||
def get_settings(db: DbSession) -> SettingsRead:
|
||||
return SettingsService(db).get_settings_snapshot()
|
||||
|
||||
|
||||
@router.put("", response_model=SettingsRead)
|
||||
def update_settings(payload: SettingsWrite, db: DbSession) -> SettingsRead:
|
||||
try:
|
||||
return SettingsService(db).save_settings_snapshot(payload)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@router.post("/model-connectivity", response_model=ModelConnectivityTestRead)
|
||||
def test_model_connectivity(payload: ModelConnectivityTestRequest, db: DbSession) -> ModelConnectivityTestRead:
|
||||
resolved_payload = payload
|
||||
|
||||
if not payload.api_key and payload.slot:
|
||||
stored_api_key = SettingsService(db).load_saved_model_api_key(payload.slot)
|
||||
if stored_api_key:
|
||||
resolved_payload = payload.model_copy(update={"api_key": stored_api_key})
|
||||
|
||||
return probe_model_connectivity(resolved_payload)
|
||||
@@ -5,6 +5,7 @@ from app.api.v1.endpoints.bootstrap import router as bootstrap_router
|
||||
from app.api.v1.endpoints.employees import router as employees_router
|
||||
from app.api.v1.endpoints.health import router as health_router
|
||||
from app.api.v1.endpoints.reimbursements import router as reimbursements_router
|
||||
from app.api.v1.endpoints.settings import router as settings_router
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(health_router, tags=["health"])
|
||||
@@ -12,3 +13,4 @@ router.include_router(bootstrap_router, tags=["bootstrap"])
|
||||
router.include_router(auth_router, tags=["auth"])
|
||||
router.include_router(employees_router, prefix="/employees", tags=["employees"])
|
||||
router.include_router(reimbursements_router, prefix="/reimbursements", tags=["reimbursements"])
|
||||
router.include_router(settings_router, tags=["settings"])
|
||||
|
||||
91
server/src/app/core/secret_box.py
Normal file
91
server/src/app/core/secret_box.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import secrets
|
||||
from pathlib import Path
|
||||
|
||||
from app.core.config import SERVER_DIR
|
||||
|
||||
SECRET_KEY_FILE = SERVER_DIR / ".secrets" / "settings.key"
|
||||
SECRET_BOX_VERSION = "v1"
|
||||
KEY_BYTES = 32
|
||||
NONCE_BYTES = 16
|
||||
MAC_BYTES = 32
|
||||
BLOCK_BYTES = 32
|
||||
|
||||
|
||||
def _ensure_secret_dir(path: Path) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def get_or_create_secret_key() -> bytes:
|
||||
_ensure_secret_dir(SECRET_KEY_FILE)
|
||||
|
||||
if SECRET_KEY_FILE.exists():
|
||||
encoded = SECRET_KEY_FILE.read_text(encoding="utf-8").strip()
|
||||
if encoded:
|
||||
return base64.urlsafe_b64decode(encoded.encode("ascii"))
|
||||
|
||||
secret_key = secrets.token_bytes(KEY_BYTES)
|
||||
encoded = base64.urlsafe_b64encode(secret_key).decode("ascii")
|
||||
SECRET_KEY_FILE.write_text(encoded, encoding="utf-8")
|
||||
return secret_key
|
||||
|
||||
|
||||
def _keystream(secret_key: bytes, nonce: bytes, length: int) -> bytes:
|
||||
chunks: list[bytes] = []
|
||||
counter = 0
|
||||
|
||||
while sum(len(chunk) for chunk in chunks) < length:
|
||||
block = hmac.new(
|
||||
secret_key,
|
||||
b"stream:" + nonce + counter.to_bytes(4, "big"),
|
||||
hashlib.sha256,
|
||||
).digest()
|
||||
chunks.append(block)
|
||||
counter += 1
|
||||
|
||||
return b"".join(chunks)[:length]
|
||||
|
||||
|
||||
def encrypt_secret(value: str) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
|
||||
secret_key = get_or_create_secret_key()
|
||||
nonce = secrets.token_bytes(NONCE_BYTES)
|
||||
plaintext = value.encode("utf-8")
|
||||
ciphertext = bytes(a ^ b for a, b in zip(plaintext, _keystream(secret_key, nonce, len(plaintext)), strict=False))
|
||||
mac = hmac.new(secret_key, b"mac:" + nonce + ciphertext, hashlib.sha256).digest()
|
||||
|
||||
encoded_nonce = base64.urlsafe_b64encode(nonce).decode("ascii")
|
||||
encoded_ciphertext = base64.urlsafe_b64encode(ciphertext).decode("ascii")
|
||||
encoded_mac = base64.urlsafe_b64encode(mac).decode("ascii")
|
||||
return f"{SECRET_BOX_VERSION}${encoded_nonce}${encoded_ciphertext}${encoded_mac}"
|
||||
|
||||
|
||||
def decrypt_secret(value: str) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
|
||||
try:
|
||||
version, encoded_nonce, encoded_ciphertext, encoded_mac = value.split("$", 3)
|
||||
except ValueError as exc:
|
||||
raise ValueError("Invalid secret payload format") from exc
|
||||
|
||||
if version != SECRET_BOX_VERSION:
|
||||
raise ValueError("Unsupported secret payload version")
|
||||
|
||||
secret_key = get_or_create_secret_key()
|
||||
nonce = base64.urlsafe_b64decode(encoded_nonce.encode("ascii"))
|
||||
ciphertext = base64.urlsafe_b64decode(encoded_ciphertext.encode("ascii"))
|
||||
expected_mac = base64.urlsafe_b64decode(encoded_mac.encode("ascii"))
|
||||
actual_mac = hmac.new(secret_key, b"mac:" + nonce + ciphertext, hashlib.sha256).digest()
|
||||
|
||||
if not hmac.compare_digest(actual_mac, expected_mac):
|
||||
raise ValueError("Secret payload integrity check failed")
|
||||
|
||||
plaintext = bytes(a ^ b for a, b in zip(ciphertext, _keystream(secret_key, nonce, len(ciphertext)), strict=False))
|
||||
return plaintext.decode("utf-8")
|
||||
@@ -5,6 +5,8 @@ from app.models.employee import Employee
|
||||
from app.models.organization import OrganizationUnit
|
||||
from app.models.reimbursement import ReimbursementRequest
|
||||
from app.models.role import Role
|
||||
from app.models.system_setting import SystemSetting
|
||||
from app.models.system_setting_secret import SystemSettingSecret
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
@@ -14,4 +16,6 @@ __all__ = [
|
||||
"OrganizationUnit",
|
||||
"ReimbursementRequest",
|
||||
"Role",
|
||||
"SystemSetting",
|
||||
"SystemSettingSecret",
|
||||
]
|
||||
|
||||
@@ -4,6 +4,8 @@ from app.models.employee import Employee
|
||||
from app.models.organization import OrganizationUnit
|
||||
from app.models.reimbursement import ReimbursementRequest
|
||||
from app.models.role import Role
|
||||
from app.models.system_setting import SystemSetting
|
||||
from app.models.system_setting_secret import SystemSettingSecret
|
||||
|
||||
__all__ = [
|
||||
"ApprovalRecord",
|
||||
@@ -12,4 +14,6 @@ __all__ = [
|
||||
"OrganizationUnit",
|
||||
"ReimbursementRequest",
|
||||
"Role",
|
||||
"SystemSetting",
|
||||
"SystemSettingSecret",
|
||||
]
|
||||
|
||||
68
server/src/app/models/system_setting.py
Normal file
68
server/src/app/models/system_setting.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, Integer, String, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.db.base_class import Base
|
||||
|
||||
|
||||
class SystemSetting(Base):
|
||||
__tablename__ = "system_settings"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(32), primary_key=True, default="default")
|
||||
|
||||
company_name: Mapped[str] = mapped_column(String(120), default="X-Financial")
|
||||
display_name: Mapped[str] = mapped_column(String(120), default="X-Financial")
|
||||
company_code: Mapped[str] = mapped_column(String(64), default="XF-001")
|
||||
record_number: Mapped[str] = mapped_column(String(120), default="")
|
||||
copyright_text: Mapped[str] = mapped_column(String(255), default="")
|
||||
|
||||
admin_account: Mapped[str] = mapped_column(String(120), default="superadmin")
|
||||
admin_email: Mapped[str] = mapped_column(String(255), default="")
|
||||
session_timeout: Mapped[int] = mapped_column(Integer, default=30)
|
||||
notice_email: Mapped[str] = mapped_column(String(255), default="")
|
||||
mfa_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
strong_password: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
login_alert_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
|
||||
main_provider: Mapped[str] = mapped_column(String(64), default="Codex")
|
||||
main_model: Mapped[str] = mapped_column(String(255), default="codex-mini-latest")
|
||||
main_endpoint: Mapped[str] = mapped_column(String(512), default="https://api.openai.com/v1")
|
||||
backup_provider: Mapped[str] = mapped_column(String(64), default="GLM")
|
||||
backup_model: Mapped[str] = mapped_column(String(255), default="glm-5.1")
|
||||
backup_endpoint: Mapped[str] = mapped_column(String(512), default="https://open.bigmodel.cn/api/paas/v4/")
|
||||
vlm_provider: Mapped[str] = mapped_column(String(64), default="Gemini")
|
||||
vlm_model: Mapped[str] = mapped_column(String(255), default="gemini-2.5-flash")
|
||||
vlm_endpoint: Mapped[str] = mapped_column(String(512), default="https://generativelanguage.googleapis.com/v1beta/openai/")
|
||||
embedding_provider: Mapped[str] = mapped_column(String(64), default="GLM")
|
||||
embedding_model: Mapped[str] = mapped_column(String(255), default="Embedding-3")
|
||||
embedding_endpoint: Mapped[str] = mapped_column(String(512), default="https://open.bigmodel.cn/api/paas/v4/")
|
||||
|
||||
log_level: Mapped[str] = mapped_column(String(16), default="INFO")
|
||||
retention_days: Mapped[int] = mapped_column(Integer, default=180)
|
||||
archive_cycle: Mapped[str] = mapped_column(String(32), default="weekly")
|
||||
log_path: Mapped[str] = mapped_column(String(255), default="server/logs/app.log")
|
||||
alert_email: Mapped[str] = mapped_column(String(255), default="")
|
||||
operation_audit: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
login_audit: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
mask_sensitive: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
|
||||
smtp_host: Mapped[str] = mapped_column(String(255), default="smtp.exmail.qq.com")
|
||||
smtp_port: Mapped[int] = mapped_column(Integer, default=465)
|
||||
smtp_encryption: Mapped[str] = mapped_column(String(32), default="SSL/TLS")
|
||||
sender_name: Mapped[str] = mapped_column(String(120), default="X-Financial")
|
||||
sender_address: Mapped[str] = mapped_column(String(255), default="")
|
||||
smtp_username: Mapped[str] = mapped_column(String(255), default="")
|
||||
alert_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
digest_enabled: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
digest_time: Mapped[str] = mapped_column(String(16), default="09:00")
|
||||
default_receiver: Mapped[str] = mapped_column(String(255), default="")
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
28
server/src/app/models/system_setting_secret.py
Normal file
28
server/src/app/models/system_setting_secret.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, String, Text, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.db.base_class import Base
|
||||
|
||||
|
||||
class SystemSettingSecret(Base):
|
||||
__tablename__ = "system_setting_secrets"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(32), primary_key=True, default="default")
|
||||
|
||||
admin_password_hash: Mapped[str] = mapped_column(Text, default="")
|
||||
main_api_key_encrypted: Mapped[str] = mapped_column(Text, default="")
|
||||
backup_api_key_encrypted: Mapped[str] = mapped_column(Text, default="")
|
||||
vlm_api_key_encrypted: Mapped[str] = mapped_column(Text, default="")
|
||||
embedding_api_key_encrypted: Mapped[str] = mapped_column(Text, default="")
|
||||
smtp_password_encrypted: Mapped[str] = mapped_column(Text, default="")
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
34
server/src/app/repositories/settings.py
Normal file
34
server/src/app/repositories/settings.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.system_setting import SystemSetting
|
||||
from app.models.system_setting_secret import SystemSettingSecret
|
||||
|
||||
SETTINGS_ROW_ID = "default"
|
||||
|
||||
|
||||
class SettingsRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def get_settings(self) -> SystemSetting | None:
|
||||
stmt = select(SystemSetting).where(SystemSetting.id == SETTINGS_ROW_ID)
|
||||
return self.db.execute(stmt).scalars().first()
|
||||
|
||||
def get_secrets(self) -> SystemSettingSecret | None:
|
||||
stmt = select(SystemSettingSecret).where(SystemSettingSecret.id == SETTINGS_ROW_ID)
|
||||
return self.db.execute(stmt).scalars().first()
|
||||
|
||||
def save_settings(self, settings: SystemSetting) -> SystemSetting:
|
||||
self.db.add(settings)
|
||||
self.db.commit()
|
||||
self.db.refresh(settings)
|
||||
return settings
|
||||
|
||||
def save_secrets(self, secrets: SystemSettingSecret) -> SystemSettingSecret:
|
||||
self.db.add(secrets)
|
||||
self.db.commit()
|
||||
self.db.refresh(secrets)
|
||||
return secrets
|
||||
185
server/src/app/schemas/settings.py
Normal file
185
server/src/app/schemas/settings.py
Normal file
@@ -0,0 +1,185 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class SettingsCompanyForm(BaseModel):
|
||||
companyName: str = Field(min_length=1, max_length=120)
|
||||
displayName: str = Field(min_length=1, max_length=120)
|
||||
companyCode: str = Field(default="", max_length=64)
|
||||
recordNumber: str = Field(default="", max_length=120)
|
||||
copyright: str = Field(default="", max_length=255)
|
||||
|
||||
@field_validator("companyName", "displayName", "companyCode", "recordNumber", "copyright", mode="before")
|
||||
@classmethod
|
||||
def strip_string(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
return value.strip()
|
||||
|
||||
|
||||
class SettingsAdminForm(BaseModel):
|
||||
adminAccount: str = Field(min_length=1, max_length=120)
|
||||
adminEmail: str = Field(min_length=1, max_length=255)
|
||||
newPassword: str = Field(default="", max_length=128)
|
||||
confirmPassword: str = Field(default="", max_length=128)
|
||||
sessionTimeout: int = Field(default=30, ge=5, le=240)
|
||||
noticeEmail: str = Field(default="", max_length=255)
|
||||
mfaEnabled: bool = True
|
||||
strongPassword: bool = True
|
||||
loginAlertEnabled: bool = True
|
||||
adminPasswordConfigured: bool = False
|
||||
|
||||
@field_validator("adminAccount", "adminEmail", "newPassword", "confirmPassword", "noticeEmail", mode="before")
|
||||
@classmethod
|
||||
def strip_string(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
return value.strip()
|
||||
|
||||
|
||||
class SettingsLlmForm(BaseModel):
|
||||
mainProvider: str = Field(min_length=1, max_length=64)
|
||||
mainModel: str = Field(min_length=1, max_length=255)
|
||||
mainEndpoint: str = Field(min_length=1, max_length=512)
|
||||
mainApiKey: str = Field(default="", max_length=1024)
|
||||
mainApiKeyConfigured: bool = False
|
||||
|
||||
backupProvider: str = Field(min_length=1, max_length=64)
|
||||
backupModel: str = Field(min_length=1, max_length=255)
|
||||
backupEndpoint: str = Field(min_length=1, max_length=512)
|
||||
backupApiKey: str = Field(default="", max_length=1024)
|
||||
backupApiKeyConfigured: bool = False
|
||||
|
||||
vlmProvider: str = Field(min_length=1, max_length=64)
|
||||
vlmModel: str = Field(min_length=1, max_length=255)
|
||||
vlmEndpoint: str = Field(min_length=1, max_length=512)
|
||||
vlmApiKey: str = Field(default="", max_length=1024)
|
||||
vlmApiKeyConfigured: bool = False
|
||||
|
||||
embeddingProvider: str = Field(min_length=1, max_length=64)
|
||||
embeddingModel: str = Field(min_length=1, max_length=255)
|
||||
embeddingEndpoint: str = Field(min_length=1, max_length=512)
|
||||
embeddingApiKey: str = Field(default="", max_length=1024)
|
||||
embeddingApiKeyConfigured: bool = False
|
||||
|
||||
@field_validator(
|
||||
"mainProvider",
|
||||
"mainModel",
|
||||
"mainEndpoint",
|
||||
"mainApiKey",
|
||||
"backupProvider",
|
||||
"backupModel",
|
||||
"backupEndpoint",
|
||||
"backupApiKey",
|
||||
"vlmProvider",
|
||||
"vlmModel",
|
||||
"vlmEndpoint",
|
||||
"vlmApiKey",
|
||||
"embeddingProvider",
|
||||
"embeddingModel",
|
||||
"embeddingEndpoint",
|
||||
"embeddingApiKey",
|
||||
mode="before",
|
||||
)
|
||||
@classmethod
|
||||
def strip_string(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
return value.strip()
|
||||
|
||||
|
||||
class SettingsLogForm(BaseModel):
|
||||
level: str = Field(min_length=1, max_length=16)
|
||||
retentionDays: int = Field(default=180, ge=1, le=3650)
|
||||
archiveCycle: str = Field(default="weekly", max_length=32)
|
||||
logPath: str = Field(min_length=1, max_length=255)
|
||||
alertEmail: str = Field(default="", max_length=255)
|
||||
operationAudit: bool = True
|
||||
loginAudit: bool = True
|
||||
maskSensitive: bool = True
|
||||
|
||||
@field_validator("level", "archiveCycle", "logPath", "alertEmail", mode="before")
|
||||
@classmethod
|
||||
def strip_string(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
return value.strip()
|
||||
|
||||
|
||||
class SettingsMailForm(BaseModel):
|
||||
smtpHost: str = Field(min_length=1, max_length=255)
|
||||
port: int = Field(default=465, ge=1, le=65535)
|
||||
encryption: str = Field(default="SSL/TLS", max_length=32)
|
||||
senderName: str = Field(default="", max_length=120)
|
||||
senderAddress: str = Field(default="", max_length=255)
|
||||
username: str = Field(default="", max_length=255)
|
||||
password: str = Field(default="", max_length=1024)
|
||||
passwordConfigured: bool = False
|
||||
alertEnabled: bool = True
|
||||
digestEnabled: bool = False
|
||||
digestTime: str = Field(default="09:00", max_length=16)
|
||||
defaultReceiver: str = Field(default="", max_length=255)
|
||||
|
||||
@field_validator(
|
||||
"smtpHost",
|
||||
"encryption",
|
||||
"senderName",
|
||||
"senderAddress",
|
||||
"username",
|
||||
"password",
|
||||
"digestTime",
|
||||
"defaultReceiver",
|
||||
mode="before",
|
||||
)
|
||||
@classmethod
|
||||
def strip_string(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
return value.strip()
|
||||
|
||||
|
||||
class SettingsRead(BaseModel):
|
||||
companyForm: SettingsCompanyForm
|
||||
adminForm: SettingsAdminForm
|
||||
llmForm: SettingsLlmForm
|
||||
logForm: SettingsLogForm
|
||||
mailForm: SettingsMailForm
|
||||
|
||||
|
||||
class SettingsWrite(BaseModel):
|
||||
companyForm: SettingsCompanyForm
|
||||
adminForm: SettingsAdminForm
|
||||
llmForm: SettingsLlmForm
|
||||
logForm: SettingsLogForm
|
||||
mailForm: SettingsMailForm
|
||||
|
||||
|
||||
class ModelConnectivityTestRequest(BaseModel):
|
||||
provider: str = Field(min_length=1, max_length=64)
|
||||
endpoint: str = Field(min_length=1, max_length=512)
|
||||
model: str = Field(min_length=1, max_length=255)
|
||||
api_key: str | None = Field(default=None, max_length=1024)
|
||||
capability: Literal["chat", "embedding"] = "chat"
|
||||
slot: Literal["main", "backup", "vlm", "embedding"] | None = None
|
||||
|
||||
@field_validator("provider", "endpoint", "model", "api_key", mode="before")
|
||||
@classmethod
|
||||
def strip_model_string(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
return value.strip()
|
||||
|
||||
|
||||
class ModelConnectivityTestRead(BaseModel):
|
||||
ok: bool
|
||||
provider: str
|
||||
model: str
|
||||
endpoint: str
|
||||
capability: str
|
||||
detail: str
|
||||
status_code: int | None = None
|
||||
checked_at: datetime
|
||||
@@ -5,7 +5,6 @@ from dataclasses import dataclass
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from app.core.admin_secret import read_admin_secret, verify_admin_secret
|
||||
from app.core.config import get_settings
|
||||
from app.core.logging import get_logger
|
||||
from app.core.security import verify_password
|
||||
@@ -13,6 +12,7 @@ from app.models.employee import Employee
|
||||
from app.schemas.auth import AuthUserRead, LoginRequest, LoginResponse
|
||||
from app.services.employee import EmployeeService
|
||||
from app.services.employee_seed import ROLE_DISPLAY_ORDER
|
||||
from app.services.settings import SettingsService
|
||||
|
||||
logger = get_logger("app.services.auth")
|
||||
|
||||
@@ -53,34 +53,25 @@ class AuthService:
|
||||
|
||||
employee_user = self._authenticate_employee(identifier, password)
|
||||
if employee_user is not None:
|
||||
logger.info("Employee login succeeded identifier=%s role_codes=%s", identifier, ",".join(employee_user.role_codes))
|
||||
logger.info(
|
||||
"Employee login succeeded identifier=%s role_codes=%s",
|
||||
identifier,
|
||||
",".join(employee_user.role_codes),
|
||||
)
|
||||
return LoginResponse(user=self._serialize_user(employee_user))
|
||||
|
||||
logger.warning("Login failed identifier=%s", identifier)
|
||||
raise ValueError("账号或密码错误。")
|
||||
|
||||
def _authenticate_admin(self, identifier: str, password: str) -> AuthenticatedUser | None:
|
||||
record = read_admin_secret()
|
||||
record = SettingsService(self.db).verify_admin_login(identifier, password)
|
||||
if record is None:
|
||||
return None
|
||||
|
||||
admin_username = str(record.get("username", "")).strip()
|
||||
admin_email = str(self.settings.admin_email or "").strip()
|
||||
normalized_identifier = identifier.casefold()
|
||||
|
||||
allowed_identifiers = {
|
||||
value.casefold()
|
||||
for value in [admin_username, admin_email]
|
||||
if value
|
||||
}
|
||||
|
||||
if normalized_identifier not in allowed_identifiers:
|
||||
return None
|
||||
|
||||
if not verify_admin_secret(password, record):
|
||||
return None
|
||||
|
||||
admin_username = record.account.strip()
|
||||
admin_email = record.email.strip()
|
||||
display_name = admin_username or admin_email or "系统管理员"
|
||||
|
||||
return AuthenticatedUser(
|
||||
username=admin_username or admin_email,
|
||||
name=display_name,
|
||||
|
||||
216
server/src/app/services/model_connectivity.py
Normal file
216
server/src/app/services/model_connectivity.py
Normal file
@@ -0,0 +1,216 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
from urllib.error import HTTPError, URLError
|
||||
from urllib.parse import quote
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
from app.schemas.settings import ModelConnectivityTestRead, ModelConnectivityTestRequest
|
||||
|
||||
AZURE_API_VERSION = "2024-10-21"
|
||||
DEFAULT_TIMEOUT_SECONDS = 12
|
||||
|
||||
|
||||
class ConnectivityCheckError(Exception):
|
||||
def __init__(self, message: str, status_code: int | None = None) -> None:
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
def probe_model_connectivity(payload: ModelConnectivityTestRequest) -> ModelConnectivityTestRead:
|
||||
checked_at = datetime.now(timezone.utc)
|
||||
|
||||
try:
|
||||
if payload.provider == "Azure OpenAI":
|
||||
status_code = _probe_azure_openai(payload)
|
||||
elif payload.provider == "Ollama":
|
||||
status_code = _probe_ollama(payload)
|
||||
else:
|
||||
status_code = _probe_openai_compatible(payload)
|
||||
|
||||
detail = f"{payload.provider} 已连接,模型 {payload.model} 可正常访问。"
|
||||
return ModelConnectivityTestRead(
|
||||
ok=True,
|
||||
provider=payload.provider,
|
||||
model=payload.model,
|
||||
endpoint=payload.endpoint,
|
||||
capability=payload.capability,
|
||||
detail=detail,
|
||||
status_code=status_code,
|
||||
checked_at=checked_at,
|
||||
)
|
||||
except ConnectivityCheckError as exc:
|
||||
return ModelConnectivityTestRead(
|
||||
ok=False,
|
||||
provider=payload.provider,
|
||||
model=payload.model,
|
||||
endpoint=payload.endpoint,
|
||||
capability=payload.capability,
|
||||
detail=str(exc),
|
||||
status_code=exc.status_code,
|
||||
checked_at=checked_at,
|
||||
)
|
||||
|
||||
|
||||
def _probe_openai_compatible(payload: ModelConnectivityTestRequest) -> int:
|
||||
normalized_endpoint = _normalize_endpoint(payload.endpoint)
|
||||
headers = _build_headers(api_key=payload.api_key, use_bearer=True)
|
||||
|
||||
if payload.capability == "embedding":
|
||||
url = _ensure_path(normalized_endpoint, "embeddings")
|
||||
body = {"model": payload.model, "input": "connectivity test"}
|
||||
else:
|
||||
url = _ensure_path(normalized_endpoint, "chat/completions")
|
||||
body = {
|
||||
"model": payload.model,
|
||||
"messages": [{"role": "user", "content": "ping"}],
|
||||
"max_tokens": 1,
|
||||
}
|
||||
|
||||
status_code, _ = _send_json_request("POST", url, headers=headers, payload=body)
|
||||
return status_code
|
||||
|
||||
|
||||
def _probe_ollama(payload: ModelConnectivityTestRequest) -> int:
|
||||
normalized_endpoint = _normalize_endpoint(payload.endpoint)
|
||||
headers = _build_headers(api_key=payload.api_key, use_bearer=False)
|
||||
|
||||
if payload.capability == "embedding":
|
||||
url = _ensure_path(normalized_endpoint, "api/embed")
|
||||
body = {"model": payload.model, "input": "connectivity test"}
|
||||
else:
|
||||
url = _ensure_path(normalized_endpoint, "api/chat")
|
||||
body = {
|
||||
"model": payload.model,
|
||||
"messages": [{"role": "user", "content": "ping"}],
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
status_code, _ = _send_json_request("POST", url, headers=headers, payload=body)
|
||||
return status_code
|
||||
|
||||
|
||||
def _probe_azure_openai(payload: ModelConnectivityTestRequest) -> int:
|
||||
deployment_base = _build_azure_deployment_base(payload.endpoint, payload.model)
|
||||
headers = _build_headers(api_key=payload.api_key, use_bearer=False, use_api_key=True)
|
||||
|
||||
if payload.capability == "embedding":
|
||||
url = f"{deployment_base}/embeddings?api-version={AZURE_API_VERSION}"
|
||||
body = {"input": "connectivity test"}
|
||||
else:
|
||||
url = f"{deployment_base}/chat/completions?api-version={AZURE_API_VERSION}"
|
||||
body = {
|
||||
"messages": [{"role": "user", "content": "ping"}],
|
||||
"max_tokens": 1,
|
||||
}
|
||||
|
||||
status_code, _ = _send_json_request("POST", url, headers=headers, payload=body)
|
||||
return status_code
|
||||
|
||||
|
||||
def _build_azure_deployment_base(endpoint: str, model: str) -> str:
|
||||
normalized_endpoint = _normalize_endpoint(endpoint)
|
||||
quoted_model = quote(model, safe="")
|
||||
|
||||
if "/openai/deployments/" in normalized_endpoint:
|
||||
return normalized_endpoint
|
||||
|
||||
if "/openai/v1" in normalized_endpoint:
|
||||
resource_root = normalized_endpoint.split("/openai/v1", maxsplit=1)[0]
|
||||
return f"{resource_root}/openai/deployments/{quoted_model}"
|
||||
|
||||
if normalized_endpoint.endswith("/openai"):
|
||||
return f"{normalized_endpoint}/deployments/{quoted_model}"
|
||||
|
||||
return f"{normalized_endpoint}/openai/deployments/{quoted_model}"
|
||||
|
||||
|
||||
def _build_headers(
|
||||
api_key: str | None,
|
||||
*,
|
||||
use_bearer: bool,
|
||||
use_api_key: bool = False,
|
||||
) -> dict[str, str]:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
if api_key:
|
||||
if use_api_key:
|
||||
headers["api-key"] = api_key
|
||||
elif use_bearer:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
def _normalize_endpoint(endpoint: str) -> str:
|
||||
normalized = endpoint.strip()
|
||||
if not normalized:
|
||||
raise ConnectivityCheckError("接口地址不能为空。", status_code=HTTPStatus.BAD_REQUEST)
|
||||
return normalized.rstrip("/")
|
||||
|
||||
|
||||
def _ensure_path(endpoint: str, suffix: str) -> str:
|
||||
suffix = suffix.lstrip("/")
|
||||
if endpoint.endswith(suffix):
|
||||
return endpoint
|
||||
return f"{endpoint}/{suffix}"
|
||||
|
||||
|
||||
def _send_json_request(
|
||||
method: str,
|
||||
url: str,
|
||||
*,
|
||||
headers: dict[str, str],
|
||||
payload: dict[str, Any],
|
||||
) -> tuple[int, Any]:
|
||||
data = json.dumps(payload).encode("utf-8")
|
||||
request = Request(url=url, data=data, headers=headers, method=method)
|
||||
|
||||
try:
|
||||
with urlopen(request, timeout=DEFAULT_TIMEOUT_SECONDS) as response:
|
||||
body = response.read().decode("utf-8") if response.length != 0 else ""
|
||||
return response.status, _parse_json_body(body)
|
||||
except HTTPError as exc:
|
||||
body = exc.read().decode("utf-8", errors="ignore")
|
||||
message = _extract_error_message(_parse_json_body(body)) or f"模型接口返回 {exc.code}。"
|
||||
raise ConnectivityCheckError(message, status_code=exc.code) from exc
|
||||
except URLError as exc:
|
||||
reason = getattr(exc, "reason", exc)
|
||||
raise ConnectivityCheckError(f"无法连接到模型接口:{reason}") from exc
|
||||
except TimeoutError as exc:
|
||||
raise ConnectivityCheckError("模型接口连接超时,请检查地址或网络。") from exc
|
||||
|
||||
|
||||
def _parse_json_body(body: str) -> Any:
|
||||
if not body:
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.loads(body)
|
||||
except json.JSONDecodeError:
|
||||
return {"message": body}
|
||||
|
||||
|
||||
def _extract_error_message(payload: Any) -> str | None:
|
||||
if payload is None:
|
||||
return None
|
||||
|
||||
if isinstance(payload, dict):
|
||||
if isinstance(payload.get("detail"), str):
|
||||
return payload["detail"]
|
||||
if isinstance(payload.get("message"), str):
|
||||
return payload["message"]
|
||||
error_payload = payload.get("error")
|
||||
if isinstance(error_payload, dict) and isinstance(error_payload.get("message"), str):
|
||||
return error_payload["message"]
|
||||
|
||||
if isinstance(payload, str):
|
||||
return payload
|
||||
|
||||
return None
|
||||
352
server/src/app/services/settings.py
Normal file
352
server/src/app/services/settings.py
Normal file
@@ -0,0 +1,352 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.admin_secret import read_admin_secret, verify_admin_secret
|
||||
from app.core.config import get_settings
|
||||
from app.core.secret_box import decrypt_secret, encrypt_secret
|
||||
from app.core.security import hash_password, verify_password
|
||||
from app.db.base import Base
|
||||
from app.models.system_setting import SystemSetting
|
||||
from app.models.system_setting_secret import SystemSettingSecret
|
||||
from app.repositories.settings import SETTINGS_ROW_ID, SettingsRepository
|
||||
from app.schemas.settings import SettingsRead, SettingsWrite
|
||||
|
||||
MODEL_SECRET_FIELDS = {
|
||||
"main": "main_api_key_encrypted",
|
||||
"backup": "backup_api_key_encrypted",
|
||||
"vlm": "vlm_api_key_encrypted",
|
||||
"embedding": "embedding_api_key_encrypted",
|
||||
}
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AdminCredentialRecord:
|
||||
account: str
|
||||
email: str
|
||||
password_hash: str
|
||||
|
||||
|
||||
class SettingsService:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
self.repository = SettingsRepository(db)
|
||||
self.runtime_settings = get_settings()
|
||||
|
||||
def ensure_settings_ready(self) -> tuple[SystemSetting, SystemSettingSecret]:
|
||||
Base.metadata.create_all(bind=self.db.get_bind())
|
||||
|
||||
settings_row = self.repository.get_settings()
|
||||
secrets_row = self.repository.get_secrets()
|
||||
should_commit = False
|
||||
|
||||
if settings_row is None:
|
||||
settings_row = self._build_default_settings()
|
||||
self.db.add(settings_row)
|
||||
should_commit = True
|
||||
|
||||
if secrets_row is None:
|
||||
secrets_row = SystemSettingSecret(id=SETTINGS_ROW_ID)
|
||||
self.db.add(secrets_row)
|
||||
should_commit = True
|
||||
|
||||
if should_commit:
|
||||
self.db.commit()
|
||||
self.db.refresh(settings_row)
|
||||
self.db.refresh(secrets_row)
|
||||
|
||||
return settings_row, secrets_row
|
||||
|
||||
def get_settings_snapshot(self) -> SettingsRead:
|
||||
settings_row, secrets_row = self.ensure_settings_ready()
|
||||
return self._serialize(settings_row, secrets_row)
|
||||
|
||||
def save_settings_snapshot(self, payload: SettingsWrite) -> SettingsRead:
|
||||
settings_row, secrets_row = self.ensure_settings_ready()
|
||||
|
||||
if payload.adminForm.newPassword:
|
||||
if len(payload.adminForm.newPassword) < 5:
|
||||
raise ValueError("管理员密码至少需要 5 位。")
|
||||
if payload.adminForm.newPassword != payload.adminForm.confirmPassword:
|
||||
raise ValueError("两次输入的管理员密码不一致。")
|
||||
secrets_row.admin_password_hash = hash_password(payload.adminForm.newPassword)
|
||||
|
||||
settings_row.company_name = payload.companyForm.companyName
|
||||
settings_row.display_name = payload.companyForm.displayName
|
||||
settings_row.company_code = payload.companyForm.companyCode
|
||||
settings_row.record_number = payload.companyForm.recordNumber
|
||||
settings_row.copyright_text = payload.companyForm.copyright
|
||||
|
||||
settings_row.admin_account = payload.adminForm.adminAccount
|
||||
settings_row.admin_email = payload.adminForm.adminEmail
|
||||
settings_row.session_timeout = payload.adminForm.sessionTimeout
|
||||
settings_row.notice_email = payload.adminForm.noticeEmail
|
||||
settings_row.mfa_enabled = payload.adminForm.mfaEnabled
|
||||
settings_row.strong_password = payload.adminForm.strongPassword
|
||||
settings_row.login_alert_enabled = payload.adminForm.loginAlertEnabled
|
||||
|
||||
settings_row.main_provider = payload.llmForm.mainProvider
|
||||
settings_row.main_model = payload.llmForm.mainModel
|
||||
settings_row.main_endpoint = payload.llmForm.mainEndpoint
|
||||
settings_row.backup_provider = payload.llmForm.backupProvider
|
||||
settings_row.backup_model = payload.llmForm.backupModel
|
||||
settings_row.backup_endpoint = payload.llmForm.backupEndpoint
|
||||
settings_row.vlm_provider = payload.llmForm.vlmProvider
|
||||
settings_row.vlm_model = payload.llmForm.vlmModel
|
||||
settings_row.vlm_endpoint = payload.llmForm.vlmEndpoint
|
||||
settings_row.embedding_provider = payload.llmForm.embeddingProvider
|
||||
settings_row.embedding_model = payload.llmForm.embeddingModel
|
||||
settings_row.embedding_endpoint = payload.llmForm.embeddingEndpoint
|
||||
|
||||
self._replace_secret_if_present(secrets_row, "main_api_key_encrypted", payload.llmForm.mainApiKey)
|
||||
self._replace_secret_if_present(secrets_row, "backup_api_key_encrypted", payload.llmForm.backupApiKey)
|
||||
self._replace_secret_if_present(secrets_row, "vlm_api_key_encrypted", payload.llmForm.vlmApiKey)
|
||||
self._replace_secret_if_present(
|
||||
secrets_row,
|
||||
"embedding_api_key_encrypted",
|
||||
payload.llmForm.embeddingApiKey,
|
||||
)
|
||||
|
||||
settings_row.log_level = payload.logForm.level
|
||||
settings_row.retention_days = payload.logForm.retentionDays
|
||||
settings_row.archive_cycle = payload.logForm.archiveCycle
|
||||
settings_row.log_path = payload.logForm.logPath
|
||||
settings_row.alert_email = payload.logForm.alertEmail
|
||||
settings_row.operation_audit = payload.logForm.operationAudit
|
||||
settings_row.login_audit = payload.logForm.loginAudit
|
||||
settings_row.mask_sensitive = payload.logForm.maskSensitive
|
||||
|
||||
settings_row.smtp_host = payload.mailForm.smtpHost
|
||||
settings_row.smtp_port = payload.mailForm.port
|
||||
settings_row.smtp_encryption = payload.mailForm.encryption
|
||||
settings_row.sender_name = payload.mailForm.senderName
|
||||
settings_row.sender_address = payload.mailForm.senderAddress
|
||||
settings_row.smtp_username = payload.mailForm.username
|
||||
settings_row.alert_enabled = payload.mailForm.alertEnabled
|
||||
settings_row.digest_enabled = payload.mailForm.digestEnabled
|
||||
settings_row.digest_time = payload.mailForm.digestTime
|
||||
settings_row.default_receiver = payload.mailForm.defaultReceiver
|
||||
|
||||
self._replace_secret_if_present(secrets_row, "smtp_password_encrypted", payload.mailForm.password)
|
||||
|
||||
self.db.add(settings_row)
|
||||
self.db.add(secrets_row)
|
||||
self.db.commit()
|
||||
self.db.refresh(settings_row)
|
||||
self.db.refresh(secrets_row)
|
||||
|
||||
return self._serialize(settings_row, secrets_row)
|
||||
|
||||
def load_saved_model_api_key(self, slot: str | None) -> str:
|
||||
if not slot or slot not in MODEL_SECRET_FIELDS:
|
||||
return ""
|
||||
|
||||
_, secrets_row = self.ensure_settings_ready()
|
||||
encrypted_value = getattr(secrets_row, MODEL_SECRET_FIELDS[slot], "")
|
||||
if not encrypted_value:
|
||||
return ""
|
||||
|
||||
return decrypt_secret(encrypted_value)
|
||||
|
||||
def get_admin_credentials(self) -> AdminCredentialRecord | None:
|
||||
settings_row, secrets_row = self.ensure_settings_ready()
|
||||
|
||||
if secrets_row.admin_password_hash:
|
||||
return AdminCredentialRecord(
|
||||
account=settings_row.admin_account,
|
||||
email=settings_row.admin_email,
|
||||
password_hash=secrets_row.admin_password_hash,
|
||||
)
|
||||
|
||||
legacy_record = read_admin_secret()
|
||||
if legacy_record is None:
|
||||
return None
|
||||
|
||||
username = str(legacy_record.get("username", "")).strip()
|
||||
email = str(settings_row.admin_email or self.runtime_settings.admin_email or "").strip()
|
||||
password_hash = ""
|
||||
|
||||
# Legacy admin.json uses scrypt fields rather than the app password format.
|
||||
# The auth flow handles this file separately when no DB-backed admin password exists.
|
||||
if username or email:
|
||||
return AdminCredentialRecord(account=username, email=email, password_hash=password_hash)
|
||||
|
||||
return None
|
||||
|
||||
def verify_admin_login(self, identifier: str, password: str) -> AdminCredentialRecord | None:
|
||||
settings_row, secrets_row = self.ensure_settings_ready()
|
||||
normalized_identifier = identifier.casefold()
|
||||
|
||||
if secrets_row.admin_password_hash:
|
||||
allowed_identifiers = {
|
||||
value.casefold()
|
||||
for value in [settings_row.admin_account, settings_row.admin_email]
|
||||
if value
|
||||
}
|
||||
|
||||
if normalized_identifier not in allowed_identifiers:
|
||||
return None
|
||||
|
||||
if not verify_password(password, secrets_row.admin_password_hash):
|
||||
return None
|
||||
|
||||
return AdminCredentialRecord(
|
||||
account=settings_row.admin_account,
|
||||
email=settings_row.admin_email,
|
||||
password_hash=secrets_row.admin_password_hash,
|
||||
)
|
||||
|
||||
legacy_record = read_admin_secret()
|
||||
if legacy_record is None:
|
||||
return None
|
||||
|
||||
admin_username = str(legacy_record.get("username", "")).strip()
|
||||
admin_email = str(settings_row.admin_email or self.runtime_settings.admin_email or "").strip()
|
||||
allowed_identifiers = {
|
||||
value.casefold()
|
||||
for value in [admin_username, admin_email]
|
||||
if value
|
||||
}
|
||||
|
||||
if normalized_identifier not in allowed_identifiers:
|
||||
return None
|
||||
|
||||
if not verify_admin_secret(password, legacy_record):
|
||||
return None
|
||||
|
||||
return AdminCredentialRecord(account=admin_username, email=admin_email, password_hash="")
|
||||
|
||||
def _build_default_settings(self) -> SystemSetting:
|
||||
current_year = datetime.now().year
|
||||
company_name = str(self.runtime_settings.company_name or "X-Financial").strip() or "X-Financial"
|
||||
company_code = str(self.runtime_settings.company_code or "XF-001").strip() or "XF-001"
|
||||
admin_email = str(self.runtime_settings.admin_email or "").strip()
|
||||
legacy_admin = read_admin_secret() or {}
|
||||
admin_account = str(legacy_admin.get("username", "")).strip() or "superadmin"
|
||||
|
||||
return SystemSetting(
|
||||
id=SETTINGS_ROW_ID,
|
||||
company_name=company_name,
|
||||
display_name=company_name,
|
||||
company_code=company_code,
|
||||
record_number="",
|
||||
copyright_text=f"Copyright © 2024-{current_year} {company_name}. All Rights Reserved.",
|
||||
admin_account=admin_account,
|
||||
admin_email=admin_email,
|
||||
session_timeout=30,
|
||||
notice_email=admin_email,
|
||||
mfa_enabled=True,
|
||||
strong_password=True,
|
||||
login_alert_enabled=True,
|
||||
main_provider="Codex",
|
||||
main_model="codex-mini-latest",
|
||||
main_endpoint="https://api.openai.com/v1",
|
||||
backup_provider="GLM",
|
||||
backup_model="glm-5.1",
|
||||
backup_endpoint="https://open.bigmodel.cn/api/paas/v4/",
|
||||
vlm_provider="Gemini",
|
||||
vlm_model="gemini-2.5-flash",
|
||||
vlm_endpoint="https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
embedding_provider="GLM",
|
||||
embedding_model="Embedding-3",
|
||||
embedding_endpoint="https://open.bigmodel.cn/api/paas/v4/",
|
||||
log_level="INFO",
|
||||
retention_days=180,
|
||||
archive_cycle="weekly",
|
||||
log_path="server/logs/app.log",
|
||||
alert_email=admin_email,
|
||||
operation_audit=True,
|
||||
login_audit=True,
|
||||
mask_sensitive=True,
|
||||
smtp_host="smtp.exmail.qq.com",
|
||||
smtp_port=465,
|
||||
smtp_encryption="SSL/TLS",
|
||||
sender_name=company_name,
|
||||
sender_address=admin_email,
|
||||
smtp_username=admin_email,
|
||||
alert_enabled=True,
|
||||
digest_enabled=False,
|
||||
digest_time="09:00",
|
||||
default_receiver=admin_email,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _replace_secret_if_present(secret_row: SystemSettingSecret, field_name: str, value: str) -> None:
|
||||
normalized = value.strip()
|
||||
if not normalized:
|
||||
return
|
||||
|
||||
setattr(secret_row, field_name, encrypt_secret(normalized))
|
||||
|
||||
@staticmethod
|
||||
def _serialize(settings_row: SystemSetting, secrets_row: SystemSettingSecret) -> SettingsRead:
|
||||
return SettingsRead(
|
||||
companyForm={
|
||||
"companyName": settings_row.company_name,
|
||||
"displayName": settings_row.display_name,
|
||||
"companyCode": settings_row.company_code,
|
||||
"recordNumber": settings_row.record_number,
|
||||
"copyright": settings_row.copyright_text,
|
||||
},
|
||||
adminForm={
|
||||
"adminAccount": settings_row.admin_account,
|
||||
"adminEmail": settings_row.admin_email,
|
||||
"newPassword": "",
|
||||
"confirmPassword": "",
|
||||
"sessionTimeout": settings_row.session_timeout,
|
||||
"noticeEmail": settings_row.notice_email,
|
||||
"mfaEnabled": settings_row.mfa_enabled,
|
||||
"strongPassword": settings_row.strong_password,
|
||||
"loginAlertEnabled": settings_row.login_alert_enabled,
|
||||
"adminPasswordConfigured": bool(secrets_row.admin_password_hash),
|
||||
},
|
||||
llmForm={
|
||||
"mainProvider": settings_row.main_provider,
|
||||
"mainModel": settings_row.main_model,
|
||||
"mainEndpoint": settings_row.main_endpoint,
|
||||
"mainApiKey": "",
|
||||
"mainApiKeyConfigured": bool(secrets_row.main_api_key_encrypted),
|
||||
"backupProvider": settings_row.backup_provider,
|
||||
"backupModel": settings_row.backup_model,
|
||||
"backupEndpoint": settings_row.backup_endpoint,
|
||||
"backupApiKey": "",
|
||||
"backupApiKeyConfigured": bool(secrets_row.backup_api_key_encrypted),
|
||||
"vlmProvider": settings_row.vlm_provider,
|
||||
"vlmModel": settings_row.vlm_model,
|
||||
"vlmEndpoint": settings_row.vlm_endpoint,
|
||||
"vlmApiKey": "",
|
||||
"vlmApiKeyConfigured": bool(secrets_row.vlm_api_key_encrypted),
|
||||
"embeddingProvider": settings_row.embedding_provider,
|
||||
"embeddingModel": settings_row.embedding_model,
|
||||
"embeddingEndpoint": settings_row.embedding_endpoint,
|
||||
"embeddingApiKey": "",
|
||||
"embeddingApiKeyConfigured": bool(secrets_row.embedding_api_key_encrypted),
|
||||
},
|
||||
logForm={
|
||||
"level": settings_row.log_level,
|
||||
"retentionDays": settings_row.retention_days,
|
||||
"archiveCycle": settings_row.archive_cycle,
|
||||
"logPath": settings_row.log_path,
|
||||
"alertEmail": settings_row.alert_email,
|
||||
"operationAudit": settings_row.operation_audit,
|
||||
"loginAudit": settings_row.login_audit,
|
||||
"maskSensitive": settings_row.mask_sensitive,
|
||||
},
|
||||
mailForm={
|
||||
"smtpHost": settings_row.smtp_host,
|
||||
"port": settings_row.smtp_port,
|
||||
"encryption": settings_row.smtp_encryption,
|
||||
"senderName": settings_row.sender_name,
|
||||
"senderAddress": settings_row.sender_address,
|
||||
"username": settings_row.smtp_username,
|
||||
"password": "",
|
||||
"passwordConfigured": bool(secrets_row.smtp_password_encrypted),
|
||||
"alertEnabled": settings_row.alert_enabled,
|
||||
"digestEnabled": settings_row.digest_enabled,
|
||||
"digestTime": settings_row.digest_time,
|
||||
"defaultReceiver": settings_row.default_receiver,
|
||||
},
|
||||
)
|
||||
@@ -35,7 +35,19 @@ set +a
|
||||
|
||||
SERVER_HOST="${SERVER_HOST:-127.0.0.1}"
|
||||
SERVER_PORT="${SERVER_PORT:-8000}"
|
||||
SERVER_RELOAD="${SERVER_RELOAD:-false}"
|
||||
DEFAULT_SERVER_RELOAD="false"
|
||||
|
||||
case "${APP_ENV:-local}" in
|
||||
local|dev|development)
|
||||
DEFAULT_SERVER_RELOAD="true"
|
||||
;;
|
||||
esac
|
||||
|
||||
if [ "${APP_DEBUG:-true}" = "true" ]; then
|
||||
DEFAULT_SERVER_RELOAD="true"
|
||||
fi
|
||||
|
||||
SERVER_RELOAD="${SERVER_RELOAD:-$DEFAULT_SERVER_RELOAD}"
|
||||
|
||||
is_wsl() {
|
||||
grep -qi microsoft /proc/version 2>/dev/null
|
||||
|
||||
84
server/tests/test_settings_persistence.py
Normal file
84
server/tests/test_settings_persistence.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from app.core import secret_box
|
||||
from app.db.base import Base
|
||||
from app.models.system_setting import SystemSetting
|
||||
from app.models.system_setting_secret import SystemSettingSecret
|
||||
from app.schemas.settings import SettingsWrite
|
||||
from app.services.settings import SettingsService
|
||||
|
||||
|
||||
def build_session(db_file: Path) -> Session:
|
||||
engine = create_engine(
|
||||
f"sqlite+pysqlite:///{db_file.as_posix()}",
|
||||
connect_args={"check_same_thread": False},
|
||||
)
|
||||
SystemSetting.__table__.create(bind=engine)
|
||||
SystemSettingSecret.__table__.create(bind=engine)
|
||||
session_factory = sessionmaker(bind=engine, autoflush=False, autocommit=False)
|
||||
return session_factory()
|
||||
|
||||
|
||||
def build_temp_secret_dir() -> Path:
|
||||
return Path(tempfile.mkdtemp(prefix="xf-settings-test-", dir="D:\\tmp"))
|
||||
|
||||
|
||||
def test_settings_service_persists_non_secret_and_secret_fields(monkeypatch) -> None:
|
||||
temp_dir = build_temp_secret_dir()
|
||||
monkeypatch.setattr(secret_box, "SECRET_KEY_FILE", temp_dir / "settings.key")
|
||||
monkeypatch.setattr(Base.metadata, "create_all", lambda *args, **kwargs: None)
|
||||
|
||||
with build_session(temp_dir / "settings.db") as db:
|
||||
service = SettingsService(db)
|
||||
initial_snapshot = service.get_settings_snapshot()
|
||||
payload = initial_snapshot.model_dump()
|
||||
|
||||
payload["companyForm"]["companyName"] = "YGSOFT"
|
||||
payload["companyForm"]["displayName"] = "云广软件"
|
||||
payload["adminForm"]["adminAccount"] = "admin-root"
|
||||
payload["adminForm"]["adminEmail"] = "admin@example.com"
|
||||
payload["adminForm"]["newPassword"] = "54321"
|
||||
payload["adminForm"]["confirmPassword"] = "54321"
|
||||
payload["llmForm"]["mainModel"] = "glm-4.5"
|
||||
payload["llmForm"]["mainApiKey"] = "main-secret"
|
||||
payload["mailForm"]["password"] = "smtp-secret"
|
||||
|
||||
saved_snapshot = service.save_settings_snapshot(SettingsWrite(**payload))
|
||||
|
||||
assert saved_snapshot.companyForm.companyName == "YGSOFT"
|
||||
assert saved_snapshot.companyForm.displayName == "云广软件"
|
||||
assert saved_snapshot.llmForm.mainModel == "glm-4.5"
|
||||
assert saved_snapshot.llmForm.mainApiKey == ""
|
||||
assert saved_snapshot.llmForm.mainApiKeyConfigured is True
|
||||
assert saved_snapshot.mailForm.password == ""
|
||||
assert saved_snapshot.mailForm.passwordConfigured is True
|
||||
assert saved_snapshot.adminForm.newPassword == ""
|
||||
assert saved_snapshot.adminForm.adminPasswordConfigured is True
|
||||
|
||||
assert service.load_saved_model_api_key("main") == "main-secret"
|
||||
assert service.verify_admin_login("admin-root", "54321") is not None
|
||||
assert service.verify_admin_login("admin@example.com", "54321") is not None
|
||||
|
||||
|
||||
def test_blank_secret_input_does_not_clear_saved_secret(monkeypatch) -> None:
|
||||
temp_dir = build_temp_secret_dir()
|
||||
monkeypatch.setattr(secret_box, "SECRET_KEY_FILE", temp_dir / "settings.key")
|
||||
monkeypatch.setattr(Base.metadata, "create_all", lambda *args, **kwargs: None)
|
||||
|
||||
with build_session(temp_dir / "settings.db") as db:
|
||||
service = SettingsService(db)
|
||||
first_payload = service.get_settings_snapshot().model_dump()
|
||||
first_payload["llmForm"]["mainApiKey"] = "persisted-key"
|
||||
service.save_settings_snapshot(SettingsWrite(**first_payload))
|
||||
|
||||
second_payload = service.get_settings_snapshot().model_dump()
|
||||
second_payload["llmForm"]["mainApiKey"] = ""
|
||||
service.save_settings_snapshot(SettingsWrite(**second_payload))
|
||||
|
||||
assert service.load_saved_model_api_key("main") == "persisted-key"
|
||||
84
server/tests/test_settings_service.py
Normal file
84
server/tests/test_settings_service.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.schemas.settings import ModelConnectivityTestRequest
|
||||
from app.services.model_connectivity import ConnectivityCheckError, probe_model_connectivity
|
||||
|
||||
|
||||
def test_probe_openai_compatible_chat_model(monkeypatch) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def fake_send_json_request(method, url, *, headers, payload):
|
||||
captured["method"] = method
|
||||
captured["url"] = url
|
||||
captured["headers"] = headers
|
||||
captured["payload"] = payload
|
||||
return 200, {"id": "ok"}
|
||||
|
||||
monkeypatch.setattr("app.services.model_connectivity._send_json_request", fake_send_json_request)
|
||||
|
||||
result = probe_model_connectivity(
|
||||
ModelConnectivityTestRequest(
|
||||
provider="OpenAI Compatible",
|
||||
endpoint="https://api.example.com/v1",
|
||||
model="gpt-test",
|
||||
api_key="secret",
|
||||
capability="chat",
|
||||
)
|
||||
)
|
||||
|
||||
assert result.ok is True
|
||||
assert result.status_code == 200
|
||||
assert captured["method"] == "POST"
|
||||
assert captured["url"] == "https://api.example.com/v1/chat/completions"
|
||||
assert captured["headers"]["Authorization"] == "Bearer secret"
|
||||
assert captured["payload"]["model"] == "gpt-test"
|
||||
|
||||
|
||||
def test_probe_azure_embedding_model(monkeypatch) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def fake_send_json_request(method, url, *, headers, payload):
|
||||
captured["url"] = url
|
||||
captured["headers"] = headers
|
||||
captured["payload"] = payload
|
||||
return 200, {"data": []}
|
||||
|
||||
monkeypatch.setattr("app.services.model_connectivity._send_json_request", fake_send_json_request)
|
||||
|
||||
result = probe_model_connectivity(
|
||||
ModelConnectivityTestRequest(
|
||||
provider="Azure OpenAI",
|
||||
endpoint="https://resource.openai.azure.com",
|
||||
model="embedding-deployment",
|
||||
api_key="azure-key",
|
||||
capability="embedding",
|
||||
)
|
||||
)
|
||||
|
||||
assert result.ok is True
|
||||
assert (
|
||||
captured["url"]
|
||||
== "https://resource.openai.azure.com/openai/deployments/embedding-deployment/embeddings?api-version=2024-10-21"
|
||||
)
|
||||
assert captured["headers"]["api-key"] == "azure-key"
|
||||
assert captured["payload"]["input"] == "connectivity test"
|
||||
|
||||
|
||||
def test_probe_ollama_failure_returns_error_payload(monkeypatch) -> None:
|
||||
def fake_send_json_request(method, url, *, headers, payload):
|
||||
raise ConnectivityCheckError("模型不存在或尚未拉取。", status_code=404)
|
||||
|
||||
monkeypatch.setattr("app.services.model_connectivity._send_json_request", fake_send_json_request)
|
||||
|
||||
result = probe_model_connectivity(
|
||||
ModelConnectivityTestRequest(
|
||||
provider="Ollama",
|
||||
endpoint="http://127.0.0.1:11434",
|
||||
model="llama3.1",
|
||||
capability="chat",
|
||||
)
|
||||
)
|
||||
|
||||
assert result.ok is False
|
||||
assert result.status_code == 404
|
||||
assert "模型不存在" in result.detail
|
||||
Reference in New Issue
Block a user