diff --git a/document/work-log/2026-05-07.md b/document/work-log/2026-05-07.md index e590630..d549398 100644 --- a/document/work-log/2026-05-07.md +++ b/document/work-log/2026-05-07.md @@ -14,6 +14,20 @@ - 添加了安全模块(security.py) - 添加了单元测试 +- **提交 b8ba0ea** (14:32) + - feat: add auth module with login and access control + - 为系统实现了完整的登录认证功能 + - 后端使用 FastAPI 搭建了 auth 服务,支持管理员密钥验证 + - 前端对接了登录接口,实现了 Token 存储和自动登录逻辑 + - 设计并实现了基于角色的访问控制(RBAC),区分超级管理员和普通员工 + +- **提交 e8f3d97** (15:18) + - feat: add settings page with navigation and access control updates + - 搭建了系统设置页面,支持管理员配置系统参数 + - 优化了侧边栏导航交互,增加了收起/展开的流畅动画 + - 将访问控制规则统一收敛到 accessControl.js,避免散落各处 + - 统一了 useSystemState 和 useNavigation 两个 composable 的职责 + --- # 待处理 diff --git a/server/src/app/api/v1/endpoints/settings.py b/server/src/app/api/v1/endpoints/settings.py new file mode 100644 index 0000000..67b6f94 --- /dev/null +++ b/server/src/app/api/v1/endpoints/settings.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from typing import Annotated + +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.orm import Session + +from app.api.deps import get_db +from app.schemas.settings import ( + ModelConnectivityTestRead, + ModelConnectivityTestRequest, + SettingsRead, + SettingsWrite, +) +from app.services.model_connectivity import probe_model_connectivity +from app.services.settings import SettingsService + +router = APIRouter(prefix="/settings") +DbSession = Annotated[Session, Depends(get_db)] + + +@router.get("", response_model=SettingsRead) +def get_settings(db: DbSession) -> SettingsRead: + return SettingsService(db).get_settings_snapshot() + + +@router.put("", response_model=SettingsRead) +def update_settings(payload: SettingsWrite, db: DbSession) -> SettingsRead: + try: + return SettingsService(db).save_settings_snapshot(payload) + except ValueError as exc: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc + + +@router.post("/model-connectivity", response_model=ModelConnectivityTestRead) +def test_model_connectivity(payload: ModelConnectivityTestRequest, db: DbSession) -> ModelConnectivityTestRead: + resolved_payload = payload + + if not payload.api_key and payload.slot: + stored_api_key = SettingsService(db).load_saved_model_api_key(payload.slot) + if stored_api_key: + resolved_payload = payload.model_copy(update={"api_key": stored_api_key}) + + return probe_model_connectivity(resolved_payload) diff --git a/server/src/app/api/v1/router.py b/server/src/app/api/v1/router.py index bedee42..dec86d9 100644 --- a/server/src/app/api/v1/router.py +++ b/server/src/app/api/v1/router.py @@ -5,6 +5,7 @@ from app.api.v1.endpoints.bootstrap import router as bootstrap_router from app.api.v1.endpoints.employees import router as employees_router from app.api.v1.endpoints.health import router as health_router from app.api.v1.endpoints.reimbursements import router as reimbursements_router +from app.api.v1.endpoints.settings import router as settings_router router = APIRouter() router.include_router(health_router, tags=["health"]) @@ -12,3 +13,4 @@ router.include_router(bootstrap_router, tags=["bootstrap"]) router.include_router(auth_router, tags=["auth"]) router.include_router(employees_router, prefix="/employees", tags=["employees"]) router.include_router(reimbursements_router, prefix="/reimbursements", tags=["reimbursements"]) +router.include_router(settings_router, tags=["settings"]) diff --git a/server/src/app/core/secret_box.py b/server/src/app/core/secret_box.py new file mode 100644 index 0000000..78a074c --- /dev/null +++ b/server/src/app/core/secret_box.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import base64 +import hashlib +import hmac +import secrets +from pathlib import Path + +from app.core.config import SERVER_DIR + +SECRET_KEY_FILE = SERVER_DIR / ".secrets" / "settings.key" +SECRET_BOX_VERSION = "v1" +KEY_BYTES = 32 +NONCE_BYTES = 16 +MAC_BYTES = 32 +BLOCK_BYTES = 32 + + +def _ensure_secret_dir(path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + + +def get_or_create_secret_key() -> bytes: + _ensure_secret_dir(SECRET_KEY_FILE) + + if SECRET_KEY_FILE.exists(): + encoded = SECRET_KEY_FILE.read_text(encoding="utf-8").strip() + if encoded: + return base64.urlsafe_b64decode(encoded.encode("ascii")) + + secret_key = secrets.token_bytes(KEY_BYTES) + encoded = base64.urlsafe_b64encode(secret_key).decode("ascii") + SECRET_KEY_FILE.write_text(encoded, encoding="utf-8") + return secret_key + + +def _keystream(secret_key: bytes, nonce: bytes, length: int) -> bytes: + chunks: list[bytes] = [] + counter = 0 + + while sum(len(chunk) for chunk in chunks) < length: + block = hmac.new( + secret_key, + b"stream:" + nonce + counter.to_bytes(4, "big"), + hashlib.sha256, + ).digest() + chunks.append(block) + counter += 1 + + return b"".join(chunks)[:length] + + +def encrypt_secret(value: str) -> str: + if not value: + return "" + + secret_key = get_or_create_secret_key() + nonce = secrets.token_bytes(NONCE_BYTES) + plaintext = value.encode("utf-8") + ciphertext = bytes(a ^ b for a, b in zip(plaintext, _keystream(secret_key, nonce, len(plaintext)), strict=False)) + mac = hmac.new(secret_key, b"mac:" + nonce + ciphertext, hashlib.sha256).digest() + + encoded_nonce = base64.urlsafe_b64encode(nonce).decode("ascii") + encoded_ciphertext = base64.urlsafe_b64encode(ciphertext).decode("ascii") + encoded_mac = base64.urlsafe_b64encode(mac).decode("ascii") + return f"{SECRET_BOX_VERSION}${encoded_nonce}${encoded_ciphertext}${encoded_mac}" + + +def decrypt_secret(value: str) -> str: + if not value: + return "" + + try: + version, encoded_nonce, encoded_ciphertext, encoded_mac = value.split("$", 3) + except ValueError as exc: + raise ValueError("Invalid secret payload format") from exc + + if version != SECRET_BOX_VERSION: + raise ValueError("Unsupported secret payload version") + + secret_key = get_or_create_secret_key() + nonce = base64.urlsafe_b64decode(encoded_nonce.encode("ascii")) + ciphertext = base64.urlsafe_b64decode(encoded_ciphertext.encode("ascii")) + expected_mac = base64.urlsafe_b64decode(encoded_mac.encode("ascii")) + actual_mac = hmac.new(secret_key, b"mac:" + nonce + ciphertext, hashlib.sha256).digest() + + if not hmac.compare_digest(actual_mac, expected_mac): + raise ValueError("Secret payload integrity check failed") + + plaintext = bytes(a ^ b for a, b in zip(ciphertext, _keystream(secret_key, nonce, len(ciphertext)), strict=False)) + return plaintext.decode("utf-8") diff --git a/server/src/app/db/base.py b/server/src/app/db/base.py index 825f3cf..612b2bb 100644 --- a/server/src/app/db/base.py +++ b/server/src/app/db/base.py @@ -5,6 +5,8 @@ from app.models.employee import Employee from app.models.organization import OrganizationUnit from app.models.reimbursement import ReimbursementRequest from app.models.role import Role +from app.models.system_setting import SystemSetting +from app.models.system_setting_secret import SystemSettingSecret __all__ = [ "Base", @@ -14,4 +16,6 @@ __all__ = [ "OrganizationUnit", "ReimbursementRequest", "Role", + "SystemSetting", + "SystemSettingSecret", ] diff --git a/server/src/app/models/__init__.py b/server/src/app/models/__init__.py index 348c523..a25a818 100644 --- a/server/src/app/models/__init__.py +++ b/server/src/app/models/__init__.py @@ -4,6 +4,8 @@ from app.models.employee import Employee from app.models.organization import OrganizationUnit from app.models.reimbursement import ReimbursementRequest from app.models.role import Role +from app.models.system_setting import SystemSetting +from app.models.system_setting_secret import SystemSettingSecret __all__ = [ "ApprovalRecord", @@ -12,4 +14,6 @@ __all__ = [ "OrganizationUnit", "ReimbursementRequest", "Role", + "SystemSetting", + "SystemSettingSecret", ] diff --git a/server/src/app/models/system_setting.py b/server/src/app/models/system_setting.py new file mode 100644 index 0000000..c1ba54f --- /dev/null +++ b/server/src/app/models/system_setting.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from datetime import datetime + +from sqlalchemy import Boolean, DateTime, Integer, String, func +from sqlalchemy.orm import Mapped, mapped_column + +from app.db.base_class import Base + + +class SystemSetting(Base): + __tablename__ = "system_settings" + + id: Mapped[str] = mapped_column(String(32), primary_key=True, default="default") + + company_name: Mapped[str] = mapped_column(String(120), default="X-Financial") + display_name: Mapped[str] = mapped_column(String(120), default="X-Financial") + company_code: Mapped[str] = mapped_column(String(64), default="XF-001") + record_number: Mapped[str] = mapped_column(String(120), default="") + copyright_text: Mapped[str] = mapped_column(String(255), default="") + + admin_account: Mapped[str] = mapped_column(String(120), default="superadmin") + admin_email: Mapped[str] = mapped_column(String(255), default="") + session_timeout: Mapped[int] = mapped_column(Integer, default=30) + notice_email: Mapped[str] = mapped_column(String(255), default="") + mfa_enabled: Mapped[bool] = mapped_column(Boolean, default=True) + strong_password: Mapped[bool] = mapped_column(Boolean, default=True) + login_alert_enabled: Mapped[bool] = mapped_column(Boolean, default=True) + + main_provider: Mapped[str] = mapped_column(String(64), default="Codex") + main_model: Mapped[str] = mapped_column(String(255), default="codex-mini-latest") + main_endpoint: Mapped[str] = mapped_column(String(512), default="https://api.openai.com/v1") + backup_provider: Mapped[str] = mapped_column(String(64), default="GLM") + backup_model: Mapped[str] = mapped_column(String(255), default="glm-5.1") + backup_endpoint: Mapped[str] = mapped_column(String(512), default="https://open.bigmodel.cn/api/paas/v4/") + vlm_provider: Mapped[str] = mapped_column(String(64), default="Gemini") + vlm_model: Mapped[str] = mapped_column(String(255), default="gemini-2.5-flash") + vlm_endpoint: Mapped[str] = mapped_column(String(512), default="https://generativelanguage.googleapis.com/v1beta/openai/") + embedding_provider: Mapped[str] = mapped_column(String(64), default="GLM") + embedding_model: Mapped[str] = mapped_column(String(255), default="Embedding-3") + embedding_endpoint: Mapped[str] = mapped_column(String(512), default="https://open.bigmodel.cn/api/paas/v4/") + + log_level: Mapped[str] = mapped_column(String(16), default="INFO") + retention_days: Mapped[int] = mapped_column(Integer, default=180) + archive_cycle: Mapped[str] = mapped_column(String(32), default="weekly") + log_path: Mapped[str] = mapped_column(String(255), default="server/logs/app.log") + alert_email: Mapped[str] = mapped_column(String(255), default="") + operation_audit: Mapped[bool] = mapped_column(Boolean, default=True) + login_audit: Mapped[bool] = mapped_column(Boolean, default=True) + mask_sensitive: Mapped[bool] = mapped_column(Boolean, default=True) + + smtp_host: Mapped[str] = mapped_column(String(255), default="smtp.exmail.qq.com") + smtp_port: Mapped[int] = mapped_column(Integer, default=465) + smtp_encryption: Mapped[str] = mapped_column(String(32), default="SSL/TLS") + sender_name: Mapped[str] = mapped_column(String(120), default="X-Financial") + sender_address: Mapped[str] = mapped_column(String(255), default="") + smtp_username: Mapped[str] = mapped_column(String(255), default="") + alert_enabled: Mapped[bool] = mapped_column(Boolean, default=True) + digest_enabled: Mapped[bool] = mapped_column(Boolean, default=False) + digest_time: Mapped[str] = mapped_column(String(16), default="09:00") + default_receiver: Mapped[str] = mapped_column(String(255), default="") + + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + server_default=func.now(), + onupdate=func.now(), + ) diff --git a/server/src/app/models/system_setting_secret.py b/server/src/app/models/system_setting_secret.py new file mode 100644 index 0000000..4969e5e --- /dev/null +++ b/server/src/app/models/system_setting_secret.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from datetime import datetime + +from sqlalchemy import DateTime, String, Text, func +from sqlalchemy.orm import Mapped, mapped_column + +from app.db.base_class import Base + + +class SystemSettingSecret(Base): + __tablename__ = "system_setting_secrets" + + id: Mapped[str] = mapped_column(String(32), primary_key=True, default="default") + + admin_password_hash: Mapped[str] = mapped_column(Text, default="") + main_api_key_encrypted: Mapped[str] = mapped_column(Text, default="") + backup_api_key_encrypted: Mapped[str] = mapped_column(Text, default="") + vlm_api_key_encrypted: Mapped[str] = mapped_column(Text, default="") + embedding_api_key_encrypted: Mapped[str] = mapped_column(Text, default="") + smtp_password_encrypted: Mapped[str] = mapped_column(Text, default="") + + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + server_default=func.now(), + onupdate=func.now(), + ) diff --git a/server/src/app/repositories/settings.py b/server/src/app/repositories/settings.py new file mode 100644 index 0000000..e099f74 --- /dev/null +++ b/server/src/app/repositories/settings.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from app.models.system_setting import SystemSetting +from app.models.system_setting_secret import SystemSettingSecret + +SETTINGS_ROW_ID = "default" + + +class SettingsRepository: + def __init__(self, db: Session) -> None: + self.db = db + + def get_settings(self) -> SystemSetting | None: + stmt = select(SystemSetting).where(SystemSetting.id == SETTINGS_ROW_ID) + return self.db.execute(stmt).scalars().first() + + def get_secrets(self) -> SystemSettingSecret | None: + stmt = select(SystemSettingSecret).where(SystemSettingSecret.id == SETTINGS_ROW_ID) + return self.db.execute(stmt).scalars().first() + + def save_settings(self, settings: SystemSetting) -> SystemSetting: + self.db.add(settings) + self.db.commit() + self.db.refresh(settings) + return settings + + def save_secrets(self, secrets: SystemSettingSecret) -> SystemSettingSecret: + self.db.add(secrets) + self.db.commit() + self.db.refresh(secrets) + return secrets diff --git a/server/src/app/schemas/settings.py b/server/src/app/schemas/settings.py new file mode 100644 index 0000000..6274a70 --- /dev/null +++ b/server/src/app/schemas/settings.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Literal + +from pydantic import BaseModel, Field, field_validator + + +class SettingsCompanyForm(BaseModel): + companyName: str = Field(min_length=1, max_length=120) + displayName: str = Field(min_length=1, max_length=120) + companyCode: str = Field(default="", max_length=64) + recordNumber: str = Field(default="", max_length=120) + copyright: str = Field(default="", max_length=255) + + @field_validator("companyName", "displayName", "companyCode", "recordNumber", "copyright", mode="before") + @classmethod + def strip_string(cls, value: str | None) -> str | None: + if value is None: + return None + return value.strip() + + +class SettingsAdminForm(BaseModel): + adminAccount: str = Field(min_length=1, max_length=120) + adminEmail: str = Field(min_length=1, max_length=255) + newPassword: str = Field(default="", max_length=128) + confirmPassword: str = Field(default="", max_length=128) + sessionTimeout: int = Field(default=30, ge=5, le=240) + noticeEmail: str = Field(default="", max_length=255) + mfaEnabled: bool = True + strongPassword: bool = True + loginAlertEnabled: bool = True + adminPasswordConfigured: bool = False + + @field_validator("adminAccount", "adminEmail", "newPassword", "confirmPassword", "noticeEmail", mode="before") + @classmethod + def strip_string(cls, value: str | None) -> str | None: + if value is None: + return None + return value.strip() + + +class SettingsLlmForm(BaseModel): + mainProvider: str = Field(min_length=1, max_length=64) + mainModel: str = Field(min_length=1, max_length=255) + mainEndpoint: str = Field(min_length=1, max_length=512) + mainApiKey: str = Field(default="", max_length=1024) + mainApiKeyConfigured: bool = False + + backupProvider: str = Field(min_length=1, max_length=64) + backupModel: str = Field(min_length=1, max_length=255) + backupEndpoint: str = Field(min_length=1, max_length=512) + backupApiKey: str = Field(default="", max_length=1024) + backupApiKeyConfigured: bool = False + + vlmProvider: str = Field(min_length=1, max_length=64) + vlmModel: str = Field(min_length=1, max_length=255) + vlmEndpoint: str = Field(min_length=1, max_length=512) + vlmApiKey: str = Field(default="", max_length=1024) + vlmApiKeyConfigured: bool = False + + embeddingProvider: str = Field(min_length=1, max_length=64) + embeddingModel: str = Field(min_length=1, max_length=255) + embeddingEndpoint: str = Field(min_length=1, max_length=512) + embeddingApiKey: str = Field(default="", max_length=1024) + embeddingApiKeyConfigured: bool = False + + @field_validator( + "mainProvider", + "mainModel", + "mainEndpoint", + "mainApiKey", + "backupProvider", + "backupModel", + "backupEndpoint", + "backupApiKey", + "vlmProvider", + "vlmModel", + "vlmEndpoint", + "vlmApiKey", + "embeddingProvider", + "embeddingModel", + "embeddingEndpoint", + "embeddingApiKey", + mode="before", + ) + @classmethod + def strip_string(cls, value: str | None) -> str | None: + if value is None: + return None + return value.strip() + + +class SettingsLogForm(BaseModel): + level: str = Field(min_length=1, max_length=16) + retentionDays: int = Field(default=180, ge=1, le=3650) + archiveCycle: str = Field(default="weekly", max_length=32) + logPath: str = Field(min_length=1, max_length=255) + alertEmail: str = Field(default="", max_length=255) + operationAudit: bool = True + loginAudit: bool = True + maskSensitive: bool = True + + @field_validator("level", "archiveCycle", "logPath", "alertEmail", mode="before") + @classmethod + def strip_string(cls, value: str | None) -> str | None: + if value is None: + return None + return value.strip() + + +class SettingsMailForm(BaseModel): + smtpHost: str = Field(min_length=1, max_length=255) + port: int = Field(default=465, ge=1, le=65535) + encryption: str = Field(default="SSL/TLS", max_length=32) + senderName: str = Field(default="", max_length=120) + senderAddress: str = Field(default="", max_length=255) + username: str = Field(default="", max_length=255) + password: str = Field(default="", max_length=1024) + passwordConfigured: bool = False + alertEnabled: bool = True + digestEnabled: bool = False + digestTime: str = Field(default="09:00", max_length=16) + defaultReceiver: str = Field(default="", max_length=255) + + @field_validator( + "smtpHost", + "encryption", + "senderName", + "senderAddress", + "username", + "password", + "digestTime", + "defaultReceiver", + mode="before", + ) + @classmethod + def strip_string(cls, value: str | None) -> str | None: + if value is None: + return None + return value.strip() + + +class SettingsRead(BaseModel): + companyForm: SettingsCompanyForm + adminForm: SettingsAdminForm + llmForm: SettingsLlmForm + logForm: SettingsLogForm + mailForm: SettingsMailForm + + +class SettingsWrite(BaseModel): + companyForm: SettingsCompanyForm + adminForm: SettingsAdminForm + llmForm: SettingsLlmForm + logForm: SettingsLogForm + mailForm: SettingsMailForm + + +class ModelConnectivityTestRequest(BaseModel): + provider: str = Field(min_length=1, max_length=64) + endpoint: str = Field(min_length=1, max_length=512) + model: str = Field(min_length=1, max_length=255) + api_key: str | None = Field(default=None, max_length=1024) + capability: Literal["chat", "embedding"] = "chat" + slot: Literal["main", "backup", "vlm", "embedding"] | None = None + + @field_validator("provider", "endpoint", "model", "api_key", mode="before") + @classmethod + def strip_model_string(cls, value: str | None) -> str | None: + if value is None: + return None + return value.strip() + + +class ModelConnectivityTestRead(BaseModel): + ok: bool + provider: str + model: str + endpoint: str + capability: str + detail: str + status_code: int | None = None + checked_at: datetime diff --git a/server/src/app/services/auth.py b/server/src/app/services/auth.py index a4ffa2a..9f63157 100644 --- a/server/src/app/services/auth.py +++ b/server/src/app/services/auth.py @@ -5,7 +5,6 @@ from dataclasses import dataclass from sqlalchemy import func, select from sqlalchemy.orm import Session, selectinload -from app.core.admin_secret import read_admin_secret, verify_admin_secret from app.core.config import get_settings from app.core.logging import get_logger from app.core.security import verify_password @@ -13,6 +12,7 @@ from app.models.employee import Employee from app.schemas.auth import AuthUserRead, LoginRequest, LoginResponse from app.services.employee import EmployeeService from app.services.employee_seed import ROLE_DISPLAY_ORDER +from app.services.settings import SettingsService logger = get_logger("app.services.auth") @@ -53,34 +53,25 @@ class AuthService: employee_user = self._authenticate_employee(identifier, password) if employee_user is not None: - logger.info("Employee login succeeded identifier=%s role_codes=%s", identifier, ",".join(employee_user.role_codes)) + logger.info( + "Employee login succeeded identifier=%s role_codes=%s", + identifier, + ",".join(employee_user.role_codes), + ) return LoginResponse(user=self._serialize_user(employee_user)) logger.warning("Login failed identifier=%s", identifier) raise ValueError("账号或密码错误。") def _authenticate_admin(self, identifier: str, password: str) -> AuthenticatedUser | None: - record = read_admin_secret() + record = SettingsService(self.db).verify_admin_login(identifier, password) if record is None: return None - admin_username = str(record.get("username", "")).strip() - admin_email = str(self.settings.admin_email or "").strip() - normalized_identifier = identifier.casefold() - - allowed_identifiers = { - value.casefold() - for value in [admin_username, admin_email] - if value - } - - if normalized_identifier not in allowed_identifiers: - return None - - if not verify_admin_secret(password, record): - return None - + admin_username = record.account.strip() + admin_email = record.email.strip() display_name = admin_username or admin_email or "系统管理员" + return AuthenticatedUser( username=admin_username or admin_email, name=display_name, diff --git a/server/src/app/services/model_connectivity.py b/server/src/app/services/model_connectivity.py new file mode 100644 index 0000000..139606e --- /dev/null +++ b/server/src/app/services/model_connectivity.py @@ -0,0 +1,216 @@ +from __future__ import annotations + +import json +from datetime import datetime, timezone +from http import HTTPStatus +from typing import Any +from urllib.error import HTTPError, URLError +from urllib.parse import quote +from urllib.request import Request, urlopen + +from app.schemas.settings import ModelConnectivityTestRead, ModelConnectivityTestRequest + +AZURE_API_VERSION = "2024-10-21" +DEFAULT_TIMEOUT_SECONDS = 12 + + +class ConnectivityCheckError(Exception): + def __init__(self, message: str, status_code: int | None = None) -> None: + super().__init__(message) + self.status_code = status_code + + +def probe_model_connectivity(payload: ModelConnectivityTestRequest) -> ModelConnectivityTestRead: + checked_at = datetime.now(timezone.utc) + + try: + if payload.provider == "Azure OpenAI": + status_code = _probe_azure_openai(payload) + elif payload.provider == "Ollama": + status_code = _probe_ollama(payload) + else: + status_code = _probe_openai_compatible(payload) + + detail = f"{payload.provider} 已连接,模型 {payload.model} 可正常访问。" + return ModelConnectivityTestRead( + ok=True, + provider=payload.provider, + model=payload.model, + endpoint=payload.endpoint, + capability=payload.capability, + detail=detail, + status_code=status_code, + checked_at=checked_at, + ) + except ConnectivityCheckError as exc: + return ModelConnectivityTestRead( + ok=False, + provider=payload.provider, + model=payload.model, + endpoint=payload.endpoint, + capability=payload.capability, + detail=str(exc), + status_code=exc.status_code, + checked_at=checked_at, + ) + + +def _probe_openai_compatible(payload: ModelConnectivityTestRequest) -> int: + normalized_endpoint = _normalize_endpoint(payload.endpoint) + headers = _build_headers(api_key=payload.api_key, use_bearer=True) + + if payload.capability == "embedding": + url = _ensure_path(normalized_endpoint, "embeddings") + body = {"model": payload.model, "input": "connectivity test"} + else: + url = _ensure_path(normalized_endpoint, "chat/completions") + body = { + "model": payload.model, + "messages": [{"role": "user", "content": "ping"}], + "max_tokens": 1, + } + + status_code, _ = _send_json_request("POST", url, headers=headers, payload=body) + return status_code + + +def _probe_ollama(payload: ModelConnectivityTestRequest) -> int: + normalized_endpoint = _normalize_endpoint(payload.endpoint) + headers = _build_headers(api_key=payload.api_key, use_bearer=False) + + if payload.capability == "embedding": + url = _ensure_path(normalized_endpoint, "api/embed") + body = {"model": payload.model, "input": "connectivity test"} + else: + url = _ensure_path(normalized_endpoint, "api/chat") + body = { + "model": payload.model, + "messages": [{"role": "user", "content": "ping"}], + "stream": False, + } + + status_code, _ = _send_json_request("POST", url, headers=headers, payload=body) + return status_code + + +def _probe_azure_openai(payload: ModelConnectivityTestRequest) -> int: + deployment_base = _build_azure_deployment_base(payload.endpoint, payload.model) + headers = _build_headers(api_key=payload.api_key, use_bearer=False, use_api_key=True) + + if payload.capability == "embedding": + url = f"{deployment_base}/embeddings?api-version={AZURE_API_VERSION}" + body = {"input": "connectivity test"} + else: + url = f"{deployment_base}/chat/completions?api-version={AZURE_API_VERSION}" + body = { + "messages": [{"role": "user", "content": "ping"}], + "max_tokens": 1, + } + + status_code, _ = _send_json_request("POST", url, headers=headers, payload=body) + return status_code + + +def _build_azure_deployment_base(endpoint: str, model: str) -> str: + normalized_endpoint = _normalize_endpoint(endpoint) + quoted_model = quote(model, safe="") + + if "/openai/deployments/" in normalized_endpoint: + return normalized_endpoint + + if "/openai/v1" in normalized_endpoint: + resource_root = normalized_endpoint.split("/openai/v1", maxsplit=1)[0] + return f"{resource_root}/openai/deployments/{quoted_model}" + + if normalized_endpoint.endswith("/openai"): + return f"{normalized_endpoint}/deployments/{quoted_model}" + + return f"{normalized_endpoint}/openai/deployments/{quoted_model}" + + +def _build_headers( + api_key: str | None, + *, + use_bearer: bool, + use_api_key: bool = False, +) -> dict[str, str]: + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + } + + if api_key: + if use_api_key: + headers["api-key"] = api_key + elif use_bearer: + headers["Authorization"] = f"Bearer {api_key}" + + return headers + + +def _normalize_endpoint(endpoint: str) -> str: + normalized = endpoint.strip() + if not normalized: + raise ConnectivityCheckError("接口地址不能为空。", status_code=HTTPStatus.BAD_REQUEST) + return normalized.rstrip("/") + + +def _ensure_path(endpoint: str, suffix: str) -> str: + suffix = suffix.lstrip("/") + if endpoint.endswith(suffix): + return endpoint + return f"{endpoint}/{suffix}" + + +def _send_json_request( + method: str, + url: str, + *, + headers: dict[str, str], + payload: dict[str, Any], +) -> tuple[int, Any]: + data = json.dumps(payload).encode("utf-8") + request = Request(url=url, data=data, headers=headers, method=method) + + try: + with urlopen(request, timeout=DEFAULT_TIMEOUT_SECONDS) as response: + body = response.read().decode("utf-8") if response.length != 0 else "" + return response.status, _parse_json_body(body) + except HTTPError as exc: + body = exc.read().decode("utf-8", errors="ignore") + message = _extract_error_message(_parse_json_body(body)) or f"模型接口返回 {exc.code}。" + raise ConnectivityCheckError(message, status_code=exc.code) from exc + except URLError as exc: + reason = getattr(exc, "reason", exc) + raise ConnectivityCheckError(f"无法连接到模型接口:{reason}") from exc + except TimeoutError as exc: + raise ConnectivityCheckError("模型接口连接超时,请检查地址或网络。") from exc + + +def _parse_json_body(body: str) -> Any: + if not body: + return None + + try: + return json.loads(body) + except json.JSONDecodeError: + return {"message": body} + + +def _extract_error_message(payload: Any) -> str | None: + if payload is None: + return None + + if isinstance(payload, dict): + if isinstance(payload.get("detail"), str): + return payload["detail"] + if isinstance(payload.get("message"), str): + return payload["message"] + error_payload = payload.get("error") + if isinstance(error_payload, dict) and isinstance(error_payload.get("message"), str): + return error_payload["message"] + + if isinstance(payload, str): + return payload + + return None diff --git a/server/src/app/services/settings.py b/server/src/app/services/settings.py new file mode 100644 index 0000000..ed8a71b --- /dev/null +++ b/server/src/app/services/settings.py @@ -0,0 +1,352 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime + +from sqlalchemy.orm import Session + +from app.core.admin_secret import read_admin_secret, verify_admin_secret +from app.core.config import get_settings +from app.core.secret_box import decrypt_secret, encrypt_secret +from app.core.security import hash_password, verify_password +from app.db.base import Base +from app.models.system_setting import SystemSetting +from app.models.system_setting_secret import SystemSettingSecret +from app.repositories.settings import SETTINGS_ROW_ID, SettingsRepository +from app.schemas.settings import SettingsRead, SettingsWrite + +MODEL_SECRET_FIELDS = { + "main": "main_api_key_encrypted", + "backup": "backup_api_key_encrypted", + "vlm": "vlm_api_key_encrypted", + "embedding": "embedding_api_key_encrypted", +} + + +@dataclass(slots=True) +class AdminCredentialRecord: + account: str + email: str + password_hash: str + + +class SettingsService: + def __init__(self, db: Session) -> None: + self.db = db + self.repository = SettingsRepository(db) + self.runtime_settings = get_settings() + + def ensure_settings_ready(self) -> tuple[SystemSetting, SystemSettingSecret]: + Base.metadata.create_all(bind=self.db.get_bind()) + + settings_row = self.repository.get_settings() + secrets_row = self.repository.get_secrets() + should_commit = False + + if settings_row is None: + settings_row = self._build_default_settings() + self.db.add(settings_row) + should_commit = True + + if secrets_row is None: + secrets_row = SystemSettingSecret(id=SETTINGS_ROW_ID) + self.db.add(secrets_row) + should_commit = True + + if should_commit: + self.db.commit() + self.db.refresh(settings_row) + self.db.refresh(secrets_row) + + return settings_row, secrets_row + + def get_settings_snapshot(self) -> SettingsRead: + settings_row, secrets_row = self.ensure_settings_ready() + return self._serialize(settings_row, secrets_row) + + def save_settings_snapshot(self, payload: SettingsWrite) -> SettingsRead: + settings_row, secrets_row = self.ensure_settings_ready() + + if payload.adminForm.newPassword: + if len(payload.adminForm.newPassword) < 5: + raise ValueError("管理员密码至少需要 5 位。") + if payload.adminForm.newPassword != payload.adminForm.confirmPassword: + raise ValueError("两次输入的管理员密码不一致。") + secrets_row.admin_password_hash = hash_password(payload.adminForm.newPassword) + + settings_row.company_name = payload.companyForm.companyName + settings_row.display_name = payload.companyForm.displayName + settings_row.company_code = payload.companyForm.companyCode + settings_row.record_number = payload.companyForm.recordNumber + settings_row.copyright_text = payload.companyForm.copyright + + settings_row.admin_account = payload.adminForm.adminAccount + settings_row.admin_email = payload.adminForm.adminEmail + settings_row.session_timeout = payload.adminForm.sessionTimeout + settings_row.notice_email = payload.adminForm.noticeEmail + settings_row.mfa_enabled = payload.adminForm.mfaEnabled + settings_row.strong_password = payload.adminForm.strongPassword + settings_row.login_alert_enabled = payload.adminForm.loginAlertEnabled + + settings_row.main_provider = payload.llmForm.mainProvider + settings_row.main_model = payload.llmForm.mainModel + settings_row.main_endpoint = payload.llmForm.mainEndpoint + settings_row.backup_provider = payload.llmForm.backupProvider + settings_row.backup_model = payload.llmForm.backupModel + settings_row.backup_endpoint = payload.llmForm.backupEndpoint + settings_row.vlm_provider = payload.llmForm.vlmProvider + settings_row.vlm_model = payload.llmForm.vlmModel + settings_row.vlm_endpoint = payload.llmForm.vlmEndpoint + settings_row.embedding_provider = payload.llmForm.embeddingProvider + settings_row.embedding_model = payload.llmForm.embeddingModel + settings_row.embedding_endpoint = payload.llmForm.embeddingEndpoint + + self._replace_secret_if_present(secrets_row, "main_api_key_encrypted", payload.llmForm.mainApiKey) + self._replace_secret_if_present(secrets_row, "backup_api_key_encrypted", payload.llmForm.backupApiKey) + self._replace_secret_if_present(secrets_row, "vlm_api_key_encrypted", payload.llmForm.vlmApiKey) + self._replace_secret_if_present( + secrets_row, + "embedding_api_key_encrypted", + payload.llmForm.embeddingApiKey, + ) + + settings_row.log_level = payload.logForm.level + settings_row.retention_days = payload.logForm.retentionDays + settings_row.archive_cycle = payload.logForm.archiveCycle + settings_row.log_path = payload.logForm.logPath + settings_row.alert_email = payload.logForm.alertEmail + settings_row.operation_audit = payload.logForm.operationAudit + settings_row.login_audit = payload.logForm.loginAudit + settings_row.mask_sensitive = payload.logForm.maskSensitive + + settings_row.smtp_host = payload.mailForm.smtpHost + settings_row.smtp_port = payload.mailForm.port + settings_row.smtp_encryption = payload.mailForm.encryption + settings_row.sender_name = payload.mailForm.senderName + settings_row.sender_address = payload.mailForm.senderAddress + settings_row.smtp_username = payload.mailForm.username + settings_row.alert_enabled = payload.mailForm.alertEnabled + settings_row.digest_enabled = payload.mailForm.digestEnabled + settings_row.digest_time = payload.mailForm.digestTime + settings_row.default_receiver = payload.mailForm.defaultReceiver + + self._replace_secret_if_present(secrets_row, "smtp_password_encrypted", payload.mailForm.password) + + self.db.add(settings_row) + self.db.add(secrets_row) + self.db.commit() + self.db.refresh(settings_row) + self.db.refresh(secrets_row) + + return self._serialize(settings_row, secrets_row) + + def load_saved_model_api_key(self, slot: str | None) -> str: + if not slot or slot not in MODEL_SECRET_FIELDS: + return "" + + _, secrets_row = self.ensure_settings_ready() + encrypted_value = getattr(secrets_row, MODEL_SECRET_FIELDS[slot], "") + if not encrypted_value: + return "" + + return decrypt_secret(encrypted_value) + + def get_admin_credentials(self) -> AdminCredentialRecord | None: + settings_row, secrets_row = self.ensure_settings_ready() + + if secrets_row.admin_password_hash: + return AdminCredentialRecord( + account=settings_row.admin_account, + email=settings_row.admin_email, + password_hash=secrets_row.admin_password_hash, + ) + + legacy_record = read_admin_secret() + if legacy_record is None: + return None + + username = str(legacy_record.get("username", "")).strip() + email = str(settings_row.admin_email or self.runtime_settings.admin_email or "").strip() + password_hash = "" + + # Legacy admin.json uses scrypt fields rather than the app password format. + # The auth flow handles this file separately when no DB-backed admin password exists. + if username or email: + return AdminCredentialRecord(account=username, email=email, password_hash=password_hash) + + return None + + def verify_admin_login(self, identifier: str, password: str) -> AdminCredentialRecord | None: + settings_row, secrets_row = self.ensure_settings_ready() + normalized_identifier = identifier.casefold() + + if secrets_row.admin_password_hash: + allowed_identifiers = { + value.casefold() + for value in [settings_row.admin_account, settings_row.admin_email] + if value + } + + if normalized_identifier not in allowed_identifiers: + return None + + if not verify_password(password, secrets_row.admin_password_hash): + return None + + return AdminCredentialRecord( + account=settings_row.admin_account, + email=settings_row.admin_email, + password_hash=secrets_row.admin_password_hash, + ) + + legacy_record = read_admin_secret() + if legacy_record is None: + return None + + admin_username = str(legacy_record.get("username", "")).strip() + admin_email = str(settings_row.admin_email or self.runtime_settings.admin_email or "").strip() + allowed_identifiers = { + value.casefold() + for value in [admin_username, admin_email] + if value + } + + if normalized_identifier not in allowed_identifiers: + return None + + if not verify_admin_secret(password, legacy_record): + return None + + return AdminCredentialRecord(account=admin_username, email=admin_email, password_hash="") + + def _build_default_settings(self) -> SystemSetting: + current_year = datetime.now().year + company_name = str(self.runtime_settings.company_name or "X-Financial").strip() or "X-Financial" + company_code = str(self.runtime_settings.company_code or "XF-001").strip() or "XF-001" + admin_email = str(self.runtime_settings.admin_email or "").strip() + legacy_admin = read_admin_secret() or {} + admin_account = str(legacy_admin.get("username", "")).strip() or "superadmin" + + return SystemSetting( + id=SETTINGS_ROW_ID, + company_name=company_name, + display_name=company_name, + company_code=company_code, + record_number="", + copyright_text=f"Copyright © 2024-{current_year} {company_name}. All Rights Reserved.", + admin_account=admin_account, + admin_email=admin_email, + session_timeout=30, + notice_email=admin_email, + mfa_enabled=True, + strong_password=True, + login_alert_enabled=True, + main_provider="Codex", + main_model="codex-mini-latest", + main_endpoint="https://api.openai.com/v1", + backup_provider="GLM", + backup_model="glm-5.1", + backup_endpoint="https://open.bigmodel.cn/api/paas/v4/", + vlm_provider="Gemini", + vlm_model="gemini-2.5-flash", + vlm_endpoint="https://generativelanguage.googleapis.com/v1beta/openai/", + embedding_provider="GLM", + embedding_model="Embedding-3", + embedding_endpoint="https://open.bigmodel.cn/api/paas/v4/", + log_level="INFO", + retention_days=180, + archive_cycle="weekly", + log_path="server/logs/app.log", + alert_email=admin_email, + operation_audit=True, + login_audit=True, + mask_sensitive=True, + smtp_host="smtp.exmail.qq.com", + smtp_port=465, + smtp_encryption="SSL/TLS", + sender_name=company_name, + sender_address=admin_email, + smtp_username=admin_email, + alert_enabled=True, + digest_enabled=False, + digest_time="09:00", + default_receiver=admin_email, + ) + + @staticmethod + def _replace_secret_if_present(secret_row: SystemSettingSecret, field_name: str, value: str) -> None: + normalized = value.strip() + if not normalized: + return + + setattr(secret_row, field_name, encrypt_secret(normalized)) + + @staticmethod + def _serialize(settings_row: SystemSetting, secrets_row: SystemSettingSecret) -> SettingsRead: + return SettingsRead( + companyForm={ + "companyName": settings_row.company_name, + "displayName": settings_row.display_name, + "companyCode": settings_row.company_code, + "recordNumber": settings_row.record_number, + "copyright": settings_row.copyright_text, + }, + adminForm={ + "adminAccount": settings_row.admin_account, + "adminEmail": settings_row.admin_email, + "newPassword": "", + "confirmPassword": "", + "sessionTimeout": settings_row.session_timeout, + "noticeEmail": settings_row.notice_email, + "mfaEnabled": settings_row.mfa_enabled, + "strongPassword": settings_row.strong_password, + "loginAlertEnabled": settings_row.login_alert_enabled, + "adminPasswordConfigured": bool(secrets_row.admin_password_hash), + }, + llmForm={ + "mainProvider": settings_row.main_provider, + "mainModel": settings_row.main_model, + "mainEndpoint": settings_row.main_endpoint, + "mainApiKey": "", + "mainApiKeyConfigured": bool(secrets_row.main_api_key_encrypted), + "backupProvider": settings_row.backup_provider, + "backupModel": settings_row.backup_model, + "backupEndpoint": settings_row.backup_endpoint, + "backupApiKey": "", + "backupApiKeyConfigured": bool(secrets_row.backup_api_key_encrypted), + "vlmProvider": settings_row.vlm_provider, + "vlmModel": settings_row.vlm_model, + "vlmEndpoint": settings_row.vlm_endpoint, + "vlmApiKey": "", + "vlmApiKeyConfigured": bool(secrets_row.vlm_api_key_encrypted), + "embeddingProvider": settings_row.embedding_provider, + "embeddingModel": settings_row.embedding_model, + "embeddingEndpoint": settings_row.embedding_endpoint, + "embeddingApiKey": "", + "embeddingApiKeyConfigured": bool(secrets_row.embedding_api_key_encrypted), + }, + logForm={ + "level": settings_row.log_level, + "retentionDays": settings_row.retention_days, + "archiveCycle": settings_row.archive_cycle, + "logPath": settings_row.log_path, + "alertEmail": settings_row.alert_email, + "operationAudit": settings_row.operation_audit, + "loginAudit": settings_row.login_audit, + "maskSensitive": settings_row.mask_sensitive, + }, + mailForm={ + "smtpHost": settings_row.smtp_host, + "port": settings_row.smtp_port, + "encryption": settings_row.smtp_encryption, + "senderName": settings_row.sender_name, + "senderAddress": settings_row.sender_address, + "username": settings_row.smtp_username, + "password": "", + "passwordConfigured": bool(secrets_row.smtp_password_encrypted), + "alertEnabled": settings_row.alert_enabled, + "digestEnabled": settings_row.digest_enabled, + "digestTime": settings_row.digest_time, + "defaultReceiver": settings_row.default_receiver, + }, + ) diff --git a/server/start.sh b/server/start.sh index 067075b..2fc52fe 100644 --- a/server/start.sh +++ b/server/start.sh @@ -35,7 +35,19 @@ set +a SERVER_HOST="${SERVER_HOST:-127.0.0.1}" SERVER_PORT="${SERVER_PORT:-8000}" -SERVER_RELOAD="${SERVER_RELOAD:-false}" +DEFAULT_SERVER_RELOAD="false" + +case "${APP_ENV:-local}" in + local|dev|development) + DEFAULT_SERVER_RELOAD="true" + ;; +esac + +if [ "${APP_DEBUG:-true}" = "true" ]; then + DEFAULT_SERVER_RELOAD="true" +fi + +SERVER_RELOAD="${SERVER_RELOAD:-$DEFAULT_SERVER_RELOAD}" is_wsl() { grep -qi microsoft /proc/version 2>/dev/null diff --git a/server/tests/test_settings_persistence.py b/server/tests/test_settings_persistence.py new file mode 100644 index 0000000..a3f31c6 --- /dev/null +++ b/server/tests/test_settings_persistence.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from pathlib import Path +import tempfile + +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker + +from app.core import secret_box +from app.db.base import Base +from app.models.system_setting import SystemSetting +from app.models.system_setting_secret import SystemSettingSecret +from app.schemas.settings import SettingsWrite +from app.services.settings import SettingsService + + +def build_session(db_file: Path) -> Session: + engine = create_engine( + f"sqlite+pysqlite:///{db_file.as_posix()}", + connect_args={"check_same_thread": False}, + ) + SystemSetting.__table__.create(bind=engine) + SystemSettingSecret.__table__.create(bind=engine) + session_factory = sessionmaker(bind=engine, autoflush=False, autocommit=False) + return session_factory() + + +def build_temp_secret_dir() -> Path: + return Path(tempfile.mkdtemp(prefix="xf-settings-test-", dir="D:\\tmp")) + + +def test_settings_service_persists_non_secret_and_secret_fields(monkeypatch) -> None: + temp_dir = build_temp_secret_dir() + monkeypatch.setattr(secret_box, "SECRET_KEY_FILE", temp_dir / "settings.key") + monkeypatch.setattr(Base.metadata, "create_all", lambda *args, **kwargs: None) + + with build_session(temp_dir / "settings.db") as db: + service = SettingsService(db) + initial_snapshot = service.get_settings_snapshot() + payload = initial_snapshot.model_dump() + + payload["companyForm"]["companyName"] = "YGSOFT" + payload["companyForm"]["displayName"] = "云广软件" + payload["adminForm"]["adminAccount"] = "admin-root" + payload["adminForm"]["adminEmail"] = "admin@example.com" + payload["adminForm"]["newPassword"] = "54321" + payload["adminForm"]["confirmPassword"] = "54321" + payload["llmForm"]["mainModel"] = "glm-4.5" + payload["llmForm"]["mainApiKey"] = "main-secret" + payload["mailForm"]["password"] = "smtp-secret" + + saved_snapshot = service.save_settings_snapshot(SettingsWrite(**payload)) + + assert saved_snapshot.companyForm.companyName == "YGSOFT" + assert saved_snapshot.companyForm.displayName == "云广软件" + assert saved_snapshot.llmForm.mainModel == "glm-4.5" + assert saved_snapshot.llmForm.mainApiKey == "" + assert saved_snapshot.llmForm.mainApiKeyConfigured is True + assert saved_snapshot.mailForm.password == "" + assert saved_snapshot.mailForm.passwordConfigured is True + assert saved_snapshot.adminForm.newPassword == "" + assert saved_snapshot.adminForm.adminPasswordConfigured is True + + assert service.load_saved_model_api_key("main") == "main-secret" + assert service.verify_admin_login("admin-root", "54321") is not None + assert service.verify_admin_login("admin@example.com", "54321") is not None + + +def test_blank_secret_input_does_not_clear_saved_secret(monkeypatch) -> None: + temp_dir = build_temp_secret_dir() + monkeypatch.setattr(secret_box, "SECRET_KEY_FILE", temp_dir / "settings.key") + monkeypatch.setattr(Base.metadata, "create_all", lambda *args, **kwargs: None) + + with build_session(temp_dir / "settings.db") as db: + service = SettingsService(db) + first_payload = service.get_settings_snapshot().model_dump() + first_payload["llmForm"]["mainApiKey"] = "persisted-key" + service.save_settings_snapshot(SettingsWrite(**first_payload)) + + second_payload = service.get_settings_snapshot().model_dump() + second_payload["llmForm"]["mainApiKey"] = "" + service.save_settings_snapshot(SettingsWrite(**second_payload)) + + assert service.load_saved_model_api_key("main") == "persisted-key" diff --git a/server/tests/test_settings_service.py b/server/tests/test_settings_service.py new file mode 100644 index 0000000..d72b208 --- /dev/null +++ b/server/tests/test_settings_service.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from app.schemas.settings import ModelConnectivityTestRequest +from app.services.model_connectivity import ConnectivityCheckError, probe_model_connectivity + + +def test_probe_openai_compatible_chat_model(monkeypatch) -> None: + captured: dict[str, object] = {} + + def fake_send_json_request(method, url, *, headers, payload): + captured["method"] = method + captured["url"] = url + captured["headers"] = headers + captured["payload"] = payload + return 200, {"id": "ok"} + + monkeypatch.setattr("app.services.model_connectivity._send_json_request", fake_send_json_request) + + result = probe_model_connectivity( + ModelConnectivityTestRequest( + provider="OpenAI Compatible", + endpoint="https://api.example.com/v1", + model="gpt-test", + api_key="secret", + capability="chat", + ) + ) + + assert result.ok is True + assert result.status_code == 200 + assert captured["method"] == "POST" + assert captured["url"] == "https://api.example.com/v1/chat/completions" + assert captured["headers"]["Authorization"] == "Bearer secret" + assert captured["payload"]["model"] == "gpt-test" + + +def test_probe_azure_embedding_model(monkeypatch) -> None: + captured: dict[str, object] = {} + + def fake_send_json_request(method, url, *, headers, payload): + captured["url"] = url + captured["headers"] = headers + captured["payload"] = payload + return 200, {"data": []} + + monkeypatch.setattr("app.services.model_connectivity._send_json_request", fake_send_json_request) + + result = probe_model_connectivity( + ModelConnectivityTestRequest( + provider="Azure OpenAI", + endpoint="https://resource.openai.azure.com", + model="embedding-deployment", + api_key="azure-key", + capability="embedding", + ) + ) + + assert result.ok is True + assert ( + captured["url"] + == "https://resource.openai.azure.com/openai/deployments/embedding-deployment/embeddings?api-version=2024-10-21" + ) + assert captured["headers"]["api-key"] == "azure-key" + assert captured["payload"]["input"] == "connectivity test" + + +def test_probe_ollama_failure_returns_error_payload(monkeypatch) -> None: + def fake_send_json_request(method, url, *, headers, payload): + raise ConnectivityCheckError("模型不存在或尚未拉取。", status_code=404) + + monkeypatch.setattr("app.services.model_connectivity._send_json_request", fake_send_json_request) + + result = probe_model_connectivity( + ModelConnectivityTestRequest( + provider="Ollama", + endpoint="http://127.0.0.1:11434", + model="llama3.1", + capability="chat", + ) + ) + + assert result.ok is False + assert result.status_code == 404 + assert "模型不存在" in result.detail diff --git a/start.sh b/start.sh index 54f8629..9b9e5d5 100644 --- a/start.sh +++ b/start.sh @@ -32,6 +32,9 @@ set +a SERVER_STARTUP_TIMEOUT="${SERVER_STARTUP_TIMEOUT:-300}" SETUP_COMPLETED="${SETUP_COMPLETED:-false}" +APP_DEBUG="${APP_DEBUG:-true}" +APP_ENV="${APP_ENV:-local}" +SERVER_RELOAD="${SERVER_RELOAD:-}" server_probe_url() { echo "http://${SERVER_HOST:-127.0.0.1}:${SERVER_PORT:-8000}${API_V1_PREFIX:-/api/v1}/health" @@ -155,6 +158,9 @@ start_all() { if probe_server_ready "$probe_url" "$smoke_url"; then warn "FastAPI is already ready at $probe_url. Reusing the existing backend process." + if [ "$APP_DEBUG" = "true" ] && [ "$SERVER_RELOAD" != "true" ]; then + warn "This backend may be stale because SERVER_RELOAD is disabled. If new API routes are missing, stop the old backend process and rerun ./start.sh." + fi elif probe_server_health "$probe_url"; then error "An existing backend process is responding at $probe_url, but the smoke check failed at $smoke_url. Stop the old FastAPI process and rerun ./start.sh." else diff --git a/web/src/assets/styles/views/settings-view.css b/web/src/assets/styles/views/settings-view.css index c64061e..3e949b5 100644 --- a/web/src/assets/styles/views/settings-view.css +++ b/web/src/assets/styles/views/settings-view.css @@ -18,7 +18,7 @@ min-width: 0; min-height: 0; display: grid; - grid-template-rows: auto minmax(0, 1fr) auto; + grid-template-rows: auto minmax(0, 1fr); gap: 10px; padding: 22px 16px 18px; border-right: 1px solid #e7edf3; @@ -65,10 +65,7 @@ .settings-nav-item { width: 100%; min-height: 74px; - display: grid; - grid-template-columns: minmax(0, 1fr) auto; - align-items: center; - gap: 12px; + display: block; padding: 14px 14px 14px 16px; border: 1px solid transparent; border-radius: 18px; @@ -115,42 +112,6 @@ line-height: 1.45; } -.nav-item-state { - width: 26px; - height: 26px; - display: inline-flex; - align-items: center; - justify-content: center; - border-radius: 999px; - background: #f1f5f9; - color: #94a3b8; - font-size: 14px; -} - -.settings-nav-item.complete .nav-item-state { - background: #ecfdf5; - color: #059669; -} - -.settings-nav-foot { - display: grid; - gap: 4px; - padding: 16px 12px 2px; - border-top: 1px solid #eef3f7; -} - -.settings-nav-foot span { - color: #64748b; - font-size: 12px; - font-weight: 700; -} - -.settings-nav-foot strong { - color: #0f172a; - font-size: 16px; - font-weight: 820; -} - .settings-body { min-width: 0; min-height: 0; @@ -211,25 +172,6 @@ gap: 12px; } -.section-status { - min-height: 36px; - display: inline-flex; - align-items: center; - gap: 8px; - padding: 0 13px; - border-radius: 999px; - background: #fff7ed; - color: #c2410c; - font-size: 12px; - font-weight: 800; - white-space: nowrap; -} - -.section-status.complete { - background: #ecfdf5; - color: #059669; -} - .save-button { min-height: 42px; display: inline-flex; @@ -265,6 +207,12 @@ padding: 24px 28px 28px; } +.model-grid { + display: grid; + grid-template-columns: repeat(2, minmax(0, 1fr)); + gap: 18px; +} + .settings-card { padding: 22px 22px 24px; border: 1px solid #e8eef3; @@ -280,6 +228,13 @@ margin-bottom: 18px; } +.card-head-actions { + display: flex; + align-items: center; + gap: 10px; + flex: 0 0 auto; +} + .card-head h4 { color: #0f172a; font-size: 18px; @@ -294,6 +249,37 @@ line-height: 1.65; } +.test-button { + min-height: 38px; + display: inline-flex; + align-items: center; + justify-content: center; + gap: 8px; + padding: 0 14px; + border: 1px solid rgba(16, 185, 129, 0.18); + border-radius: 12px; + background: #f7fffb; + color: #047857; + font-size: 13px; + font-weight: 800; + transition: + border-color 180ms var(--ease), + background 180ms var(--ease), + color 180ms var(--ease), + transform 180ms var(--ease); +} + +.test-button:hover:not(:disabled) { + transform: translateY(-1px); + border-color: rgba(16, 185, 129, 0.34); + background: #ecfdf5; +} + +.test-button:disabled { + cursor: wait; + opacity: 0.78; +} + .form-grid { display: grid; grid-template-columns: repeat(2, minmax(0, 1fr)); @@ -363,6 +349,38 @@ box-shadow: 0 0 0 4px rgba(16, 185, 129, 0.12); } +.test-feedback { + display: flex; + align-items: flex-start; + gap: 8px; + margin-top: 16px; + padding: 12px 14px; + border-radius: 14px; + font-size: 12px; + font-weight: 700; + line-height: 1.6; +} + +.test-feedback i { + margin-top: 2px; + font-size: 15px; +} + +.test-feedback.is-success { + background: #ecfdf5; + color: #047857; +} + +.test-feedback.is-error { + background: #fef2f2; + color: #b91c1c; +} + +.test-feedback.is-testing { + background: #eff6ff; + color: #1d4ed8; +} + .logo-field { align-self: stretch; } @@ -587,8 +605,7 @@ justify-items: stretch; } - .save-button, - .section-status { + .save-button { justify-content: center; } } @@ -620,6 +637,7 @@ padding-inline: 20px; } + .model-grid, .form-grid, .profile-grid { grid-template-columns: 1fr; diff --git a/web/src/services/settings.js b/web/src/services/settings.js new file mode 100644 index 0000000..505445d --- /dev/null +++ b/web/src/services/settings.js @@ -0,0 +1,19 @@ +import { apiRequest } from './api.js' + +export function fetchSettings() { + return apiRequest('/settings') +} + +export function saveSettings(payload) { + return apiRequest('/settings', { + method: 'PUT', + body: JSON.stringify(payload) + }) +} + +export function testModelConnectivity(payload) { + return apiRequest('/settings/model-connectivity', { + method: 'POST', + body: JSON.stringify(payload) + }) +} diff --git a/web/src/views/SettingsView.vue b/web/src/views/SettingsView.vue index e96013f..fae3ee6 100644 --- a/web/src/views/SettingsView.vue +++ b/web/src/views/SettingsView.vue @@ -14,8 +14,7 @@ :key="section.id" class="settings-nav-item" :class="{ - active: activeSection === section.id, - complete: sectionStatus[section.id] + active: activeSection === section.id }" type="button" @click="activateSection(section.id)" @@ -24,17 +23,8 @@ {{ section.label }} {{ section.desc }} - - - -
用于确认侧边栏品牌、页脚版权和系统入口名称的实际展示效果。
-{{ pageState.companyForm.companyName || '企业法定名称' }}
- {{ pageState.companyForm.copyright || '版权信息将显示在这里' }} -配置大语言模型的供应商、模型名称和接入地址,用于 AI 助手与识别流程。
-控制响应质量、输出长度以及知识库、引用回溯等增强能力。
-