feat: 完善系统配置、安全增强与知识库功能
- .env.example: API基础路径改为相对路径 /api/v1,支持代理转发 - README.md: 完善项目结构与启动说明文档 - docker-compose.yml: 新增Docker编排配置,支持容器化部署 - docker/: 新增Docker部署相关文档与配置 - server_start.sh: 重构启动脚本,添加容器环境检测、隔离虚拟环境路径、环境变量覆盖机制 - deps.py: 完善API依赖注入,增强权限验证逻辑 - admin_secret.py: 优化管理员密钥加密存储与验证 - config.py: 扩展配置管理,支持多环境变量绑定 - security.py: 增强安全模块,完善加密与认证机制 - db/base.py: 优化数据库基础架构与连接管理 - main.py: 更新应用入口,整合新模块路由 - models/: 完善系统模型配置,支持模型设置持久化 - repositories/settings.py: 优化设置仓储层,增强数据持久化 - services/settings.py: 重构设置服务,精简代码结构 - router.py: 更新API路由配置 - endpoints/knowledge.py: 新增知识库API端点 - schemas/knowledge.py: 新增知识库数据模型 - services/knowledge.py: 新增知识库业务逻辑 - storage/knowledge/.index.json: 知识库索引存储 - api.js: 完善API服务层,增强错误处理 - bootstrap.js: 优化前端初始化与引导流程 - useSetupView.js / useSystemState.js: 重构组合式函数 - TopBar.vue: 优化顶部导航栏组件 - SettingsView.vue: 重构设置页面UI,增强用户体验 - SetupView.vue / SetupRouteView.vue: 完善引导流程页面 - PoliciesView.vue: 优化策略视图组件 - vite.config.js: 更新Vite构建配置 - web_start.sh: 完善前端启动脚本 - views/scripts/: 优化各业务视图JS逻辑 - settings-view.css: 重构设置页面样式 - setup-view.css: 完善引导页样式 - policies-view.css: 优化策略页样式 - test_auth_service.py: 完善认证服务测试 - test_settings_persistence.py: 增强设置持久化测试 - document/: 新增开发文档与工作日志
This commit is contained in:
@@ -1,5 +1,8 @@
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, Header, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db.session import get_session_factory
|
||||
@@ -11,3 +14,49 @@ def get_db() -> Generator[Session, None, None]:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class CurrentUserContext:
|
||||
username: str
|
||||
name: str
|
||||
role_codes: list[str]
|
||||
is_admin: bool
|
||||
|
||||
|
||||
def get_current_user(
|
||||
x_auth_username: Annotated[str | None, Header()] = None,
|
||||
x_auth_name: Annotated[str | None, Header()] = None,
|
||||
x_auth_role_codes: Annotated[str | None, Header()] = None,
|
||||
x_auth_is_admin: Annotated[str | None, Header()] = None,
|
||||
) -> CurrentUserContext:
|
||||
role_codes = [item.strip() for item in (x_auth_role_codes or "").split(",") if item.strip()]
|
||||
is_admin = str(x_auth_is_admin or "").strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
username = (x_auth_username or "").strip()
|
||||
name = (x_auth_name or username).strip()
|
||||
|
||||
if not username and not name:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="请先登录后再访问知识库。",
|
||||
)
|
||||
|
||||
return CurrentUserContext(
|
||||
username=username or name,
|
||||
name=name or username,
|
||||
role_codes=role_codes,
|
||||
is_admin=is_admin,
|
||||
)
|
||||
|
||||
|
||||
def require_admin_user(
|
||||
current_user: Annotated[CurrentUserContext, Depends(get_current_user)],
|
||||
) -> CurrentUserContext:
|
||||
if current_user.is_admin or "manager" in current_user.role_codes:
|
||||
return current_user
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="只有管理员可以上传、删除或修改知识库文件。",
|
||||
)
|
||||
|
||||
76
server/src/app/api/v1/endpoints/knowledge.py
Normal file
76
server/src/app/api/v1/endpoints/knowledge.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from app.api.deps import CurrentUserContext, get_current_user, require_admin_user
|
||||
from app.schemas.knowledge import (
|
||||
KnowledgeActionResponse,
|
||||
KnowledgeDocumentDetailRead,
|
||||
KnowledgeLibraryRead,
|
||||
)
|
||||
from app.services.knowledge import KnowledgeService
|
||||
|
||||
router = APIRouter(prefix="/knowledge")
|
||||
|
||||
|
||||
@router.get("/library", response_model=KnowledgeLibraryRead)
|
||||
def get_knowledge_library(
|
||||
_: Annotated[CurrentUserContext, Depends(get_current_user)],
|
||||
) -> KnowledgeLibraryRead:
|
||||
return KnowledgeService().list_library()
|
||||
|
||||
|
||||
@router.get("/documents/{document_id}", response_model=KnowledgeDocumentDetailRead)
|
||||
def get_knowledge_document(
|
||||
document_id: str,
|
||||
_: Annotated[CurrentUserContext, Depends(get_current_user)],
|
||||
) -> KnowledgeDocumentDetailRead:
|
||||
try:
|
||||
return KnowledgeService().get_document_detail(document_id)
|
||||
except FileNotFoundError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="知识库文件不存在。") from exc
|
||||
|
||||
|
||||
@router.post("/documents", response_model=KnowledgeDocumentDetailRead, status_code=status.HTTP_201_CREATED)
|
||||
async def upload_knowledge_document(
|
||||
request: Request,
|
||||
folder: Annotated[str, Query(min_length=1)],
|
||||
filename: Annotated[str, Query(min_length=1)],
|
||||
current_user: Annotated[CurrentUserContext, Depends(require_admin_user)],
|
||||
) -> KnowledgeDocumentDetailRead:
|
||||
content = await request.body()
|
||||
try:
|
||||
return KnowledgeService().upload_document(folder, filename, content, current_user)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@router.delete("/documents/{document_id}", response_model=KnowledgeActionResponse)
|
||||
def delete_knowledge_document(
|
||||
document_id: str,
|
||||
_: Annotated[CurrentUserContext, Depends(require_admin_user)],
|
||||
) -> KnowledgeActionResponse:
|
||||
try:
|
||||
KnowledgeService().delete_document(document_id)
|
||||
except FileNotFoundError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="知识库文件不存在。") from exc
|
||||
|
||||
return KnowledgeActionResponse(detail="知识库文件已删除。")
|
||||
|
||||
|
||||
@router.get("/documents/{document_id}/content")
|
||||
def get_knowledge_document_content(
|
||||
document_id: str,
|
||||
disposition: Annotated[str, Query(pattern="^(inline|attachment)$")] = "inline",
|
||||
_: Annotated[CurrentUserContext, Depends(get_current_user)] = None,
|
||||
) -> FileResponse:
|
||||
try:
|
||||
file_path, media_type, filename = KnowledgeService().get_document_content(document_id)
|
||||
except FileNotFoundError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="知识库文件不存在。") from exc
|
||||
|
||||
_ = disposition
|
||||
return FileResponse(file_path, media_type=media_type, filename=filename)
|
||||
@@ -4,6 +4,7 @@ from app.api.v1.endpoints.auth import router as auth_router
|
||||
from app.api.v1.endpoints.bootstrap import router as bootstrap_router
|
||||
from app.api.v1.endpoints.employees import router as employees_router
|
||||
from app.api.v1.endpoints.health import router as health_router
|
||||
from app.api.v1.endpoints.knowledge import router as knowledge_router
|
||||
from app.api.v1.endpoints.reimbursements import router as reimbursements_router
|
||||
from app.api.v1.endpoints.settings import router as settings_router
|
||||
|
||||
@@ -11,6 +12,7 @@ router = APIRouter()
|
||||
router.include_router(health_router, tags=["health"])
|
||||
router.include_router(bootstrap_router, tags=["bootstrap"])
|
||||
router.include_router(auth_router, tags=["auth"])
|
||||
router.include_router(knowledge_router, tags=["knowledge"])
|
||||
router.include_router(employees_router, prefix="/employees", tags=["employees"])
|
||||
router.include_router(reimbursements_router, prefix="/reimbursements", tags=["reimbursements"])
|
||||
router.include_router(settings_router, tags=["settings"])
|
||||
|
||||
@@ -1,63 +1,63 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import secrets
|
||||
from pathlib import Path
|
||||
|
||||
from app.core.config import SERVER_DIR
|
||||
|
||||
ADMIN_SECRET_FILE = SERVER_DIR / ".secrets" / "admin.json"
|
||||
|
||||
|
||||
def read_admin_secret() -> dict[str, object] | None:
|
||||
if not ADMIN_SECRET_FILE.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
payload = json.loads(ADMIN_SECRET_FILE.read_text(encoding="utf-8"))
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
if (
|
||||
payload
|
||||
and payload.get("algorithm") == "scrypt"
|
||||
and isinstance(payload.get("username"), str)
|
||||
and isinstance(payload.get("salt"), str)
|
||||
and isinstance(payload.get("derived_key"), str)
|
||||
):
|
||||
return payload
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def verify_admin_secret(password: str, record: dict[str, object]) -> bool:
|
||||
try:
|
||||
salt = bytes.fromhex(str(record["salt"]))
|
||||
stored_key = bytes.fromhex(str(record["derived_key"]))
|
||||
key_length = int(record.get("key_length", 64))
|
||||
n_value = int(record.get("N", 16384))
|
||||
r_value = int(record.get("r", 8))
|
||||
p_value = int(record.get("p", 1))
|
||||
except (KeyError, TypeError, ValueError):
|
||||
return False
|
||||
|
||||
derived_key = hashlib.scrypt(
|
||||
password.encode("utf-8"),
|
||||
salt=salt,
|
||||
n=n_value,
|
||||
r=r_value,
|
||||
p=p_value,
|
||||
dklen=key_length,
|
||||
)
|
||||
return secrets.compare_digest(derived_key, stored_key)
|
||||
|
||||
|
||||
def legacy_admin_secret_to_password_hash(record: dict[str, object]) -> str:
|
||||
salt = str(record["salt"])
|
||||
derived_key = str(record["derived_key"])
|
||||
key_length = int(record.get("key_length", 64))
|
||||
n_value = int(record.get("N", 16384))
|
||||
r_value = int(record.get("r", 8))
|
||||
p_value = int(record.get("p", 1))
|
||||
return f"scrypt${n_value}${r_value}${p_value}${key_length}${salt}${derived_key}"
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import secrets
|
||||
from pathlib import Path
|
||||
|
||||
from app.core.config import SERVER_DIR
|
||||
|
||||
ADMIN_SECRET_FILE = SERVER_DIR / ".secrets" / "admin.json"
|
||||
|
||||
|
||||
def read_admin_secret() -> dict[str, object] | None:
|
||||
if not ADMIN_SECRET_FILE.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
payload = json.loads(ADMIN_SECRET_FILE.read_text(encoding="utf-8"))
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
if (
|
||||
payload
|
||||
and payload.get("algorithm") == "scrypt"
|
||||
and isinstance(payload.get("username"), str)
|
||||
and isinstance(payload.get("salt"), str)
|
||||
and isinstance(payload.get("derived_key"), str)
|
||||
):
|
||||
return payload
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def verify_admin_secret(password: str, record: dict[str, object]) -> bool:
|
||||
try:
|
||||
salt = bytes.fromhex(str(record["salt"]))
|
||||
stored_key = bytes.fromhex(str(record["derived_key"]))
|
||||
key_length = int(record.get("key_length", 64))
|
||||
n_value = int(record.get("N", 16384))
|
||||
r_value = int(record.get("r", 8))
|
||||
p_value = int(record.get("p", 1))
|
||||
except (KeyError, TypeError, ValueError):
|
||||
return False
|
||||
|
||||
derived_key = hashlib.scrypt(
|
||||
password.encode("utf-8"),
|
||||
salt=salt,
|
||||
n=n_value,
|
||||
r=r_value,
|
||||
p=p_value,
|
||||
dklen=key_length,
|
||||
)
|
||||
return secrets.compare_digest(derived_key, stored_key)
|
||||
|
||||
|
||||
def legacy_admin_secret_to_password_hash(record: dict[str, object]) -> str:
|
||||
salt = str(record["salt"])
|
||||
derived_key = str(record["derived_key"])
|
||||
key_length = int(record.get("key_length", 64))
|
||||
n_value = int(record.get("N", 16384))
|
||||
r_value = int(record.get("r", 8))
|
||||
p_value = int(record.get("p", 1))
|
||||
return f"scrypt${n_value}${r_value}${p_value}${key_length}${salt}${derived_key}"
|
||||
|
||||
@@ -1,76 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from os import environ
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
SERVER_DIR = Path(__file__).resolve().parents[3]
|
||||
ROOT_DIR = SERVER_DIR.parent
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=(ROOT_DIR / ".env", SERVER_DIR / ".env"),
|
||||
env_file_encoding="utf-8",
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
app_name: str = Field(default="X-Financial Server", alias="APP_NAME")
|
||||
app_env: str = Field(default="local", alias="APP_ENV")
|
||||
app_debug: bool = Field(default=True, alias="APP_DEBUG")
|
||||
setup_completed: bool = Field(default=False, alias="SETUP_COMPLETED")
|
||||
|
||||
company_name: str = Field(default="", alias="COMPANY_NAME")
|
||||
company_code: str = Field(default="", alias="COMPANY_CODE")
|
||||
admin_email: str = Field(default="", alias="ADMIN_EMAIL")
|
||||
|
||||
web_host: str = Field(default="0.0.0.0", alias="WEB_HOST")
|
||||
web_port: int = Field(default=5173, alias="WEB_PORT")
|
||||
app_host: str = Field(default="0.0.0.0", alias="SERVER_HOST")
|
||||
app_port: int = Field(default=8000, alias="SERVER_PORT")
|
||||
api_v1_prefix: str = Field(default="/api/v1", alias="API_V1_PREFIX")
|
||||
|
||||
postgres_host: str = Field(default="127.0.0.1", alias="POSTGRES_HOST")
|
||||
postgres_port: int = Field(default=5432, alias="POSTGRES_PORT")
|
||||
postgres_db: str = Field(default="x_financial", alias="POSTGRES_DB")
|
||||
postgres_user: str = Field(default="postgres", alias="POSTGRES_USER")
|
||||
postgres_password: str = Field(default="postgres", alias="POSTGRES_PASSWORD")
|
||||
|
||||
database_url: str | None = Field(default=None, alias="DATABASE_URL")
|
||||
sqlalchemy_echo: bool = Field(default=False, alias="SQLALCHEMY_ECHO")
|
||||
|
||||
redis_url: str | None = Field(default=None, alias="REDIS_URL")
|
||||
cors_origins: list[str] = Field(default_factory=list, alias="CORS_ORIGINS")
|
||||
vite_api_base_url: str = Field(
|
||||
default="http://127.0.0.1:8000/api/v1", alias="VITE_API_BASE_URL"
|
||||
)
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from os import environ
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
SERVER_DIR = Path(__file__).resolve().parents[3]
|
||||
ROOT_DIR = SERVER_DIR.parent
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=(ROOT_DIR / ".env", SERVER_DIR / ".env"),
|
||||
env_file_encoding="utf-8",
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
app_name: str = Field(default="X-Financial Server", alias="APP_NAME")
|
||||
app_env: str = Field(default="local", alias="APP_ENV")
|
||||
app_debug: bool = Field(default=True, alias="APP_DEBUG")
|
||||
setup_completed: bool = Field(default=False, alias="SETUP_COMPLETED")
|
||||
|
||||
company_name: str = Field(default="", alias="COMPANY_NAME")
|
||||
company_code: str = Field(default="", alias="COMPANY_CODE")
|
||||
admin_email: str = Field(default="", alias="ADMIN_EMAIL")
|
||||
|
||||
web_host: str = Field(default="0.0.0.0", alias="WEB_HOST")
|
||||
web_port: int = Field(default=5173, alias="WEB_PORT")
|
||||
app_host: str = Field(default="0.0.0.0", alias="SERVER_HOST")
|
||||
app_port: int = Field(default=8000, alias="SERVER_PORT")
|
||||
api_v1_prefix: str = Field(default="/api/v1", alias="API_V1_PREFIX")
|
||||
|
||||
postgres_host: str = Field(default="127.0.0.1", alias="POSTGRES_HOST")
|
||||
postgres_port: int = Field(default=5432, alias="POSTGRES_PORT")
|
||||
postgres_db: str = Field(default="x_financial", alias="POSTGRES_DB")
|
||||
postgres_user: str = Field(default="postgres", alias="POSTGRES_USER")
|
||||
postgres_password: str = Field(default="postgres", alias="POSTGRES_PASSWORD")
|
||||
|
||||
database_url: str | None = Field(default=None, alias="DATABASE_URL")
|
||||
sqlalchemy_echo: bool = Field(default=False, alias="SQLALCHEMY_ECHO")
|
||||
|
||||
redis_url: str | None = Field(default=None, alias="REDIS_URL")
|
||||
cors_origins: list[str] = Field(default_factory=list, alias="CORS_ORIGINS")
|
||||
vite_api_base_url: str = Field(
|
||||
default="http://127.0.0.1:8000/api/v1", alias="VITE_API_BASE_URL"
|
||||
)
|
||||
|
||||
log_level: str = Field(default="INFO", alias="LOG_LEVEL")
|
||||
log_dir: str = Field(default="logs", alias="LOG_DIR")
|
||||
log_file_enabled: bool = Field(default=True, alias="LOG_FILE_ENABLED")
|
||||
storage_root_dir: str = Field(default="storage", alias="STORAGE_ROOT_DIR")
|
||||
|
||||
@property
|
||||
def resolved_database_url(self) -> str:
|
||||
if self.database_url:
|
||||
return self.database_url
|
||||
|
||||
|
||||
return (
|
||||
f"postgresql+psycopg://{self.postgres_user}:{self.postgres_password}"
|
||||
f"@{self.postgres_host}:{self.postgres_port}/{self.postgres_db}"
|
||||
)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_settings() -> Settings:
|
||||
return Settings()
|
||||
|
||||
|
||||
def refresh_settings(updated_values: dict[str, str]) -> Settings:
|
||||
for key, value in updated_values.items():
|
||||
environ[key] = value
|
||||
|
||||
get_settings.cache_clear()
|
||||
return get_settings()
|
||||
@property
|
||||
def resolved_storage_root_dir(self) -> Path:
|
||||
path = Path(self.storage_root_dir)
|
||||
if not path.is_absolute():
|
||||
path = SERVER_DIR / path
|
||||
return path.resolve()
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_settings() -> Settings:
|
||||
return Settings()
|
||||
|
||||
|
||||
def refresh_settings(updated_values: dict[str, str]) -> Settings:
|
||||
for key, value in updated_values.items():
|
||||
environ[key] = value
|
||||
|
||||
get_settings.cache_clear()
|
||||
return get_settings()
|
||||
|
||||
@@ -1,71 +1,71 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import secrets
|
||||
from base64 import urlsafe_b64decode, urlsafe_b64encode
|
||||
|
||||
PBKDF2_ALGORITHM = "sha256"
|
||||
PBKDF2_ITERATIONS = 120_000
|
||||
SALT_BYTES = 16
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
salt = secrets.token_bytes(SALT_BYTES)
|
||||
digest = hashlib.pbkdf2_hmac(
|
||||
PBKDF2_ALGORITHM,
|
||||
password.encode("utf-8"),
|
||||
salt,
|
||||
PBKDF2_ITERATIONS,
|
||||
)
|
||||
encoded_salt = urlsafe_b64encode(salt).decode("utf-8")
|
||||
encoded_digest = urlsafe_b64encode(digest).decode("utf-8")
|
||||
return f"pbkdf2_{PBKDF2_ALGORITHM}${PBKDF2_ITERATIONS}${encoded_salt}${encoded_digest}"
|
||||
|
||||
|
||||
def verify_password(password: str, password_hash: str) -> bool:
|
||||
if password_hash.startswith("scrypt$"):
|
||||
return verify_scrypt_password(password, password_hash)
|
||||
|
||||
try:
|
||||
scheme, iterations, encoded_salt, encoded_digest = password_hash.split("$", 3)
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
if scheme != f"pbkdf2_{PBKDF2_ALGORITHM}":
|
||||
return False
|
||||
|
||||
salt = urlsafe_b64decode(encoded_salt.encode("utf-8"))
|
||||
expected_digest = urlsafe_b64decode(encoded_digest.encode("utf-8"))
|
||||
computed_digest = hashlib.pbkdf2_hmac(
|
||||
PBKDF2_ALGORITHM,
|
||||
password.encode("utf-8"),
|
||||
salt,
|
||||
int(iterations),
|
||||
)
|
||||
return secrets.compare_digest(computed_digest, expected_digest)
|
||||
|
||||
|
||||
def verify_scrypt_password(password: str, password_hash: str) -> bool:
|
||||
try:
|
||||
scheme, n_value, r_value, p_value, key_length, salt_hex, derived_key_hex = password_hash.split("$", 6)
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
if scheme != "scrypt":
|
||||
return False
|
||||
|
||||
try:
|
||||
salt = bytes.fromhex(salt_hex)
|
||||
expected_key = bytes.fromhex(derived_key_hex)
|
||||
derived_key = hashlib.scrypt(
|
||||
password.encode("utf-8"),
|
||||
salt=salt,
|
||||
n=int(n_value),
|
||||
r=int(r_value),
|
||||
p=int(p_value),
|
||||
dklen=int(key_length),
|
||||
)
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
return secrets.compare_digest(derived_key, expected_key)
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import secrets
|
||||
from base64 import urlsafe_b64decode, urlsafe_b64encode
|
||||
|
||||
PBKDF2_ALGORITHM = "sha256"
|
||||
PBKDF2_ITERATIONS = 120_000
|
||||
SALT_BYTES = 16
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
salt = secrets.token_bytes(SALT_BYTES)
|
||||
digest = hashlib.pbkdf2_hmac(
|
||||
PBKDF2_ALGORITHM,
|
||||
password.encode("utf-8"),
|
||||
salt,
|
||||
PBKDF2_ITERATIONS,
|
||||
)
|
||||
encoded_salt = urlsafe_b64encode(salt).decode("utf-8")
|
||||
encoded_digest = urlsafe_b64encode(digest).decode("utf-8")
|
||||
return f"pbkdf2_{PBKDF2_ALGORITHM}${PBKDF2_ITERATIONS}${encoded_salt}${encoded_digest}"
|
||||
|
||||
|
||||
def verify_password(password: str, password_hash: str) -> bool:
|
||||
if password_hash.startswith("scrypt$"):
|
||||
return verify_scrypt_password(password, password_hash)
|
||||
|
||||
try:
|
||||
scheme, iterations, encoded_salt, encoded_digest = password_hash.split("$", 3)
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
if scheme != f"pbkdf2_{PBKDF2_ALGORITHM}":
|
||||
return False
|
||||
|
||||
salt = urlsafe_b64decode(encoded_salt.encode("utf-8"))
|
||||
expected_digest = urlsafe_b64decode(encoded_digest.encode("utf-8"))
|
||||
computed_digest = hashlib.pbkdf2_hmac(
|
||||
PBKDF2_ALGORITHM,
|
||||
password.encode("utf-8"),
|
||||
salt,
|
||||
int(iterations),
|
||||
)
|
||||
return secrets.compare_digest(computed_digest, expected_digest)
|
||||
|
||||
|
||||
def verify_scrypt_password(password: str, password_hash: str) -> bool:
|
||||
try:
|
||||
scheme, n_value, r_value, p_value, key_length, salt_hex, derived_key_hex = password_hash.split("$", 6)
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
if scheme != "scrypt":
|
||||
return False
|
||||
|
||||
try:
|
||||
salt = bytes.fromhex(salt_hex)
|
||||
expected_key = bytes.fromhex(derived_key_hex)
|
||||
derived_key = hashlib.scrypt(
|
||||
password.encode("utf-8"),
|
||||
salt=salt,
|
||||
n=int(n_value),
|
||||
r=int(r_value),
|
||||
p=int(p_value),
|
||||
dklen=int(key_length),
|
||||
)
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
return secrets.compare_digest(derived_key, expected_key)
|
||||
|
||||
@@ -1,23 +1,23 @@
|
||||
from app.db.base_class import Base
|
||||
from app.models.approval import ApprovalRecord
|
||||
from app.models.employee_change_log import EmployeeChangeLog
|
||||
from app.models.employee import Employee
|
||||
from app.models.organization import OrganizationUnit
|
||||
from app.models.reimbursement import ReimbursementRequest
|
||||
from app.models.role import Role
|
||||
from app.models.system_model_setting import SystemModelSetting
|
||||
from app.models.system_setting import SystemSetting
|
||||
from app.models.system_setting_secret import SystemSettingSecret
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"ApprovalRecord",
|
||||
"Employee",
|
||||
"EmployeeChangeLog",
|
||||
"OrganizationUnit",
|
||||
"ReimbursementRequest",
|
||||
"Role",
|
||||
"SystemModelSetting",
|
||||
"SystemSetting",
|
||||
"SystemSettingSecret",
|
||||
]
|
||||
from app.db.base_class import Base
|
||||
from app.models.approval import ApprovalRecord
|
||||
from app.models.employee_change_log import EmployeeChangeLog
|
||||
from app.models.employee import Employee
|
||||
from app.models.organization import OrganizationUnit
|
||||
from app.models.reimbursement import ReimbursementRequest
|
||||
from app.models.role import Role
|
||||
from app.models.system_model_setting import SystemModelSetting
|
||||
from app.models.system_setting import SystemSetting
|
||||
from app.models.system_setting_secret import SystemSettingSecret
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"ApprovalRecord",
|
||||
"Employee",
|
||||
"EmployeeChangeLog",
|
||||
"OrganizationUnit",
|
||||
"ReimbursementRequest",
|
||||
"Role",
|
||||
"SystemModelSetting",
|
||||
"SystemSetting",
|
||||
"SystemSettingSecret",
|
||||
]
|
||||
|
||||
@@ -8,6 +8,7 @@ from app.core.config import get_settings
|
||||
from app.core.logging import get_logger, setup_logging
|
||||
from app.middleware.logging import AccessLogMiddleware
|
||||
from app.services.employee import prepare_employee_directory
|
||||
from app.services.knowledge import prepare_knowledge_library
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
@@ -50,6 +51,7 @@ def create_app() -> FastAPI:
|
||||
@app.on_event("startup")
|
||||
def _on_startup() -> None:
|
||||
prepare_employee_directory()
|
||||
prepare_knowledge_library()
|
||||
logger.info(
|
||||
"Server ready - host=%s port=%s prefix=%s",
|
||||
settings.app_host,
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
from app.models.approval import ApprovalRecord
|
||||
from app.models.employee_change_log import EmployeeChangeLog
|
||||
from app.models.employee import Employee
|
||||
from app.models.organization import OrganizationUnit
|
||||
from app.models.reimbursement import ReimbursementRequest
|
||||
from app.models.role import Role
|
||||
from app.models.system_model_setting import SystemModelSetting
|
||||
from app.models.system_setting import SystemSetting
|
||||
from app.models.system_setting_secret import SystemSettingSecret
|
||||
|
||||
__all__ = [
|
||||
"ApprovalRecord",
|
||||
"Employee",
|
||||
"EmployeeChangeLog",
|
||||
"OrganizationUnit",
|
||||
"ReimbursementRequest",
|
||||
"Role",
|
||||
"SystemModelSetting",
|
||||
"SystemSetting",
|
||||
"SystemSettingSecret",
|
||||
]
|
||||
from app.models.approval import ApprovalRecord
|
||||
from app.models.employee_change_log import EmployeeChangeLog
|
||||
from app.models.employee import Employee
|
||||
from app.models.organization import OrganizationUnit
|
||||
from app.models.reimbursement import ReimbursementRequest
|
||||
from app.models.role import Role
|
||||
from app.models.system_model_setting import SystemModelSetting
|
||||
from app.models.system_setting import SystemSetting
|
||||
from app.models.system_setting_secret import SystemSettingSecret
|
||||
|
||||
__all__ = [
|
||||
"ApprovalRecord",
|
||||
"Employee",
|
||||
"EmployeeChangeLog",
|
||||
"OrganizationUnit",
|
||||
"ReimbursementRequest",
|
||||
"Role",
|
||||
"SystemModelSetting",
|
||||
"SystemSetting",
|
||||
"SystemSettingSecret",
|
||||
]
|
||||
|
||||
@@ -1,28 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, Integer, String, Text, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.db.base_class import Base
|
||||
|
||||
|
||||
class SystemModelSetting(Base):
|
||||
__tablename__ = "system_model_settings"
|
||||
|
||||
slot: Mapped[str] = mapped_column(String(32), primary_key=True)
|
||||
provider: Mapped[str] = mapped_column(String(64), default="")
|
||||
model_name: Mapped[str] = mapped_column(String(255), default="")
|
||||
endpoint: Mapped[str] = mapped_column(String(512), default="")
|
||||
capability: Mapped[str] = mapped_column(String(32), default="chat")
|
||||
priority: Mapped[int] = mapped_column(Integer, default=0)
|
||||
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
api_key_encrypted: Mapped[str] = mapped_column(Text, default="")
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, Integer, String, Text, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.db.base_class import Base
|
||||
|
||||
|
||||
class SystemModelSetting(Base):
|
||||
__tablename__ = "system_model_settings"
|
||||
|
||||
slot: Mapped[str] = mapped_column(String(32), primary_key=True)
|
||||
provider: Mapped[str] = mapped_column(String(64), default="")
|
||||
model_name: Mapped[str] = mapped_column(String(255), default="")
|
||||
endpoint: Mapped[str] = mapped_column(String(512), default="")
|
||||
capability: Mapped[str] = mapped_column(String(32), default="chat")
|
||||
priority: Mapped[int] = mapped_column(Integer, default=0)
|
||||
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
api_key_encrypted: Mapped[str] = mapped_column(Text, default="")
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
|
||||
@@ -1,43 +1,43 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.system_model_setting import SystemModelSetting
|
||||
from app.models.system_setting import SystemSetting
|
||||
from app.models.system_setting_secret import SystemSettingSecret
|
||||
|
||||
SETTINGS_ROW_ID = "default"
|
||||
|
||||
|
||||
class SettingsRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def get_settings(self) -> SystemSetting | None:
|
||||
stmt = select(SystemSetting).where(SystemSetting.id == SETTINGS_ROW_ID)
|
||||
return self.db.execute(stmt).scalars().first()
|
||||
|
||||
def get_secrets(self) -> SystemSettingSecret | None:
|
||||
stmt = select(SystemSettingSecret).where(SystemSettingSecret.id == SETTINGS_ROW_ID)
|
||||
return self.db.execute(stmt).scalars().first()
|
||||
|
||||
def get_model_settings(self) -> list[SystemModelSetting]:
|
||||
stmt = select(SystemModelSetting)
|
||||
return list(self.db.execute(stmt).scalars().all())
|
||||
|
||||
def get_model_setting(self, slot: str) -> SystemModelSetting | None:
|
||||
stmt = select(SystemModelSetting).where(SystemModelSetting.slot == slot)
|
||||
return self.db.execute(stmt).scalars().first()
|
||||
|
||||
def save_settings(self, settings: SystemSetting) -> SystemSetting:
|
||||
self.db.add(settings)
|
||||
self.db.commit()
|
||||
self.db.refresh(settings)
|
||||
return settings
|
||||
|
||||
def save_secrets(self, secrets: SystemSettingSecret) -> SystemSettingSecret:
|
||||
self.db.add(secrets)
|
||||
self.db.commit()
|
||||
self.db.refresh(secrets)
|
||||
return secrets
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.system_model_setting import SystemModelSetting
|
||||
from app.models.system_setting import SystemSetting
|
||||
from app.models.system_setting_secret import SystemSettingSecret
|
||||
|
||||
SETTINGS_ROW_ID = "default"
|
||||
|
||||
|
||||
class SettingsRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def get_settings(self) -> SystemSetting | None:
|
||||
stmt = select(SystemSetting).where(SystemSetting.id == SETTINGS_ROW_ID)
|
||||
return self.db.execute(stmt).scalars().first()
|
||||
|
||||
def get_secrets(self) -> SystemSettingSecret | None:
|
||||
stmt = select(SystemSettingSecret).where(SystemSettingSecret.id == SETTINGS_ROW_ID)
|
||||
return self.db.execute(stmt).scalars().first()
|
||||
|
||||
def get_model_settings(self) -> list[SystemModelSetting]:
|
||||
stmt = select(SystemModelSetting)
|
||||
return list(self.db.execute(stmt).scalars().all())
|
||||
|
||||
def get_model_setting(self, slot: str) -> SystemModelSetting | None:
|
||||
stmt = select(SystemModelSetting).where(SystemModelSetting.slot == slot)
|
||||
return self.db.execute(stmt).scalars().first()
|
||||
|
||||
def save_settings(self, settings: SystemSetting) -> SystemSetting:
|
||||
self.db.add(settings)
|
||||
self.db.commit()
|
||||
self.db.refresh(settings)
|
||||
return settings
|
||||
|
||||
def save_secrets(self, secrets: SystemSettingSecret) -> SystemSettingSecret:
|
||||
self.db.add(secrets)
|
||||
self.db.commit()
|
||||
self.db.refresh(secrets)
|
||||
return secrets
|
||||
|
||||
@@ -1 +1 @@
|
||||
__all__ = ["employee", "reimbursement"]
|
||||
__all__ = ["employee", "knowledge", "reimbursement"]
|
||||
|
||||
61
server/src/app/schemas/knowledge.py
Normal file
61
server/src/app/schemas/knowledge.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class KnowledgeFolderRead(BaseModel):
|
||||
name: str
|
||||
count: int
|
||||
icon: str = "mdi mdi-folder"
|
||||
|
||||
|
||||
class KnowledgePreviewStatRead(BaseModel):
|
||||
label: str
|
||||
value: str
|
||||
|
||||
|
||||
class KnowledgePreviewBlockRead(BaseModel):
|
||||
heading: str
|
||||
lines: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class KnowledgePreviewPageRead(BaseModel):
|
||||
title: str
|
||||
subtitle: str
|
||||
stats: list[KnowledgePreviewStatRead] = Field(default_factory=list)
|
||||
blocks: list[KnowledgePreviewBlockRead] = Field(default_factory=list)
|
||||
|
||||
|
||||
class KnowledgeDocumentRead(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
folder: str
|
||||
tag: str
|
||||
time: str
|
||||
version: str
|
||||
state: str
|
||||
stateTone: str
|
||||
owner: str
|
||||
icon: str
|
||||
fileType: str
|
||||
fileTypeLabel: str
|
||||
summary: str
|
||||
mimeType: str
|
||||
extension: str
|
||||
sizeBytes: int
|
||||
canPreview: bool = False
|
||||
|
||||
|
||||
class KnowledgeDocumentDetailRead(KnowledgeDocumentRead):
|
||||
previewKind: str
|
||||
previewPages: list[KnowledgePreviewPageRead] = Field(default_factory=list)
|
||||
|
||||
|
||||
class KnowledgeLibraryRead(BaseModel):
|
||||
folders: list[KnowledgeFolderRead] = Field(default_factory=list)
|
||||
documents: list[KnowledgeDocumentRead] = Field(default_factory=list)
|
||||
|
||||
|
||||
class KnowledgeActionResponse(BaseModel):
|
||||
ok: bool = True
|
||||
detail: str
|
||||
634
server/src/app/services/knowledge.py
Normal file
634
server/src/app/services/knowledge.py
Normal file
@@ -0,0 +1,634 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import mimetypes
|
||||
import re
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
from xml.etree import ElementTree
|
||||
from zipfile import BadZipFile, ZipFile
|
||||
|
||||
from app.api.deps import CurrentUserContext
|
||||
from app.core.config import get_settings
|
||||
from app.core.logging import get_logger
|
||||
from app.schemas.knowledge import (
|
||||
KnowledgeDocumentDetailRead,
|
||||
KnowledgeDocumentRead,
|
||||
KnowledgeFolderRead,
|
||||
KnowledgeLibraryRead,
|
||||
KnowledgePreviewBlockRead,
|
||||
KnowledgePreviewPageRead,
|
||||
KnowledgePreviewStatRead,
|
||||
)
|
||||
|
||||
logger = get_logger("app.services.knowledge")
|
||||
|
||||
FIXED_KNOWLEDGE_FOLDERS = [
|
||||
"财务知识库",
|
||||
"制度政策",
|
||||
"报销制度",
|
||||
"差旅规范",
|
||||
"发票管理",
|
||||
"税务合规",
|
||||
"预算管理",
|
||||
"财务共享",
|
||||
"培训资料",
|
||||
"常见问答",
|
||||
]
|
||||
|
||||
ICON_BY_TYPE = {
|
||||
"pdf": "mdi mdi-file-document-outline-pdf pdf",
|
||||
"word": "mdi mdi-file-document-outline-word word",
|
||||
"excel": "mdi mdi-file-document-outline-excel excel",
|
||||
"ppt": "mdi mdi-file-powerpoint-box ppt",
|
||||
"image": "mdi mdi-file-image-outline image",
|
||||
"text": "mdi mdi-file-document-outline text",
|
||||
"archive": "mdi mdi-folder-zip-outline archive",
|
||||
"binary": "mdi mdi-file-outline",
|
||||
}
|
||||
|
||||
TEXT_EXTENSIONS = {"txt", "md", "csv", "json", "xml", "yml", "yaml", "log"}
|
||||
WORD_EXTENSIONS = {"doc", "docx"}
|
||||
EXCEL_EXTENSIONS = {"xls", "xlsx", "csv"}
|
||||
PPT_EXTENSIONS = {"ppt", "pptx"}
|
||||
IMAGE_EXTENSIONS = {"png", "jpg", "jpeg", "gif", "bmp", "webp", "svg"}
|
||||
ARCHIVE_EXTENSIONS = {"zip", "rar", "7z"}
|
||||
STRUCTURED_PREVIEW_EXTENSIONS = {"docx", "xlsx", "pptx"} | TEXT_EXTENSIONS
|
||||
INLINE_PREVIEW_EXTENSIONS = {"pdf"} | IMAGE_EXTENSIONS
|
||||
|
||||
|
||||
def prepare_knowledge_library() -> None:
|
||||
KnowledgeService().ensure_library_ready()
|
||||
|
||||
|
||||
class KnowledgeService:
|
||||
def __init__(self, storage_root: Path | None = None) -> None:
|
||||
settings = get_settings()
|
||||
self.storage_root = Path(storage_root or settings.resolved_storage_root_dir)
|
||||
self.library_root = self.storage_root / "knowledge"
|
||||
self.index_path = self.library_root / ".index.json"
|
||||
|
||||
def ensure_library_ready(self) -> None:
|
||||
self.library_root.mkdir(parents=True, exist_ok=True)
|
||||
for folder_name in FIXED_KNOWLEDGE_FOLDERS:
|
||||
(self.library_root / folder_name).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if not self.index_path.exists():
|
||||
self._save_index({"version": 1, "documents": []})
|
||||
|
||||
index = self._load_index()
|
||||
if self._reconcile_index(index):
|
||||
self._save_index(index)
|
||||
|
||||
def list_library(self) -> KnowledgeLibraryRead:
|
||||
documents = self._load_documents()
|
||||
folders = [
|
||||
KnowledgeFolderRead(
|
||||
name=folder_name,
|
||||
count=sum(1 for item in documents if item.folder == folder_name),
|
||||
icon="mdi mdi-folder-open" if folder_name == "差旅规范" else "mdi mdi-folder",
|
||||
)
|
||||
for folder_name in FIXED_KNOWLEDGE_FOLDERS
|
||||
]
|
||||
return KnowledgeLibraryRead(folders=folders, documents=documents)
|
||||
|
||||
def get_document_detail(self, document_id: str) -> KnowledgeDocumentDetailRead:
|
||||
self.ensure_library_ready()
|
||||
index = self._load_index()
|
||||
entry = self._require_entry(index, document_id)
|
||||
preview_kind, preview_pages = self._build_preview(entry)
|
||||
document = self._serialize_document(entry)
|
||||
return KnowledgeDocumentDetailRead(
|
||||
**document.model_dump(),
|
||||
previewKind=preview_kind,
|
||||
previewPages=preview_pages,
|
||||
)
|
||||
|
||||
def upload_document(
|
||||
self,
|
||||
folder: str,
|
||||
filename: str,
|
||||
content: bytes,
|
||||
current_user: CurrentUserContext,
|
||||
) -> KnowledgeDocumentDetailRead:
|
||||
self.ensure_library_ready()
|
||||
normalized_folder = self._normalize_folder(folder)
|
||||
normalized_name = self._normalize_filename(filename)
|
||||
|
||||
if not content:
|
||||
raise ValueError("上传文件不能为空。")
|
||||
|
||||
index = self._load_index()
|
||||
existing_entry = next(
|
||||
(
|
||||
item
|
||||
for item in index["documents"]
|
||||
if item["folder"] == normalized_folder
|
||||
and item["original_name"].lower() == normalized_name.lower()
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
document_id = existing_entry["id"] if existing_entry else uuid4().hex
|
||||
stored_name = f"{document_id}__{normalized_name}"
|
||||
target_path = self.library_root / normalized_folder / stored_name
|
||||
|
||||
if existing_entry is not None and existing_entry["stored_name"] != stored_name:
|
||||
old_path = self.library_root / existing_entry["folder"] / existing_entry["stored_name"]
|
||||
if old_path.exists():
|
||||
old_path.unlink()
|
||||
|
||||
target_path.write_bytes(content)
|
||||
|
||||
now = datetime.now(UTC).isoformat()
|
||||
mime_type = mimetypes.guess_type(normalized_name)[0] or "application/octet-stream"
|
||||
checksum = hashlib.sha256(content).hexdigest()
|
||||
extension = self._extract_extension(normalized_name)
|
||||
|
||||
if existing_entry is None:
|
||||
entry = {
|
||||
"id": document_id,
|
||||
"folder": normalized_folder,
|
||||
"original_name": normalized_name,
|
||||
"stored_name": stored_name,
|
||||
"mime_type": mime_type,
|
||||
"extension": extension,
|
||||
"size_bytes": len(content),
|
||||
"sha256": checksum,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"uploaded_by": current_user.name,
|
||||
"version_number": 1,
|
||||
}
|
||||
index["documents"].append(entry)
|
||||
logger.info(
|
||||
"Knowledge document uploaded id=%s folder=%s filename=%s by=%s",
|
||||
document_id,
|
||||
normalized_folder,
|
||||
normalized_name,
|
||||
current_user.name,
|
||||
)
|
||||
else:
|
||||
existing_entry.update(
|
||||
{
|
||||
"stored_name": stored_name,
|
||||
"mime_type": mime_type,
|
||||
"extension": extension,
|
||||
"size_bytes": len(content),
|
||||
"sha256": checksum,
|
||||
"updated_at": now,
|
||||
"uploaded_by": current_user.name,
|
||||
"version_number": int(existing_entry.get("version_number", 1)) + 1,
|
||||
}
|
||||
)
|
||||
entry = existing_entry
|
||||
logger.info(
|
||||
"Knowledge document updated id=%s folder=%s filename=%s by=%s",
|
||||
document_id,
|
||||
normalized_folder,
|
||||
normalized_name,
|
||||
current_user.name,
|
||||
)
|
||||
|
||||
self._save_index(index)
|
||||
return self.get_document_detail(document_id)
|
||||
|
||||
def delete_document(self, document_id: str) -> None:
|
||||
self.ensure_library_ready()
|
||||
index = self._load_index()
|
||||
entry = self._require_entry(index, document_id)
|
||||
file_path = self._resolve_document_path(entry)
|
||||
if file_path.exists():
|
||||
file_path.unlink()
|
||||
|
||||
index["documents"] = [item for item in index["documents"] if item["id"] != document_id]
|
||||
self._save_index(index)
|
||||
logger.info("Knowledge document deleted id=%s filename=%s", document_id, entry["original_name"])
|
||||
|
||||
def get_document_content(self, document_id: str) -> tuple[Path, str, str]:
|
||||
self.ensure_library_ready()
|
||||
index = self._load_index()
|
||||
entry = self._require_entry(index, document_id)
|
||||
file_path = self._resolve_document_path(entry)
|
||||
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(entry["original_name"])
|
||||
|
||||
return file_path, entry["mime_type"], entry["original_name"]
|
||||
|
||||
def _load_documents(self) -> list[KnowledgeDocumentRead]:
|
||||
self.ensure_library_ready()
|
||||
index = self._load_index()
|
||||
self._reconcile_index(index)
|
||||
self._save_index(index)
|
||||
|
||||
documents = [self._serialize_document(entry) for entry in index["documents"]]
|
||||
return sorted(documents, key=lambda item: item.time, reverse=True)
|
||||
|
||||
def _serialize_document(self, entry: dict[str, Any]) -> KnowledgeDocumentRead:
|
||||
extension = entry.get("extension") or self._extract_extension(entry["original_name"])
|
||||
file_type = self._resolve_file_type(extension)
|
||||
size_bytes = int(entry.get("size_bytes") or 0)
|
||||
updated_at = self._format_time(entry.get("updated_at") or entry.get("created_at"))
|
||||
|
||||
return KnowledgeDocumentRead(
|
||||
id=entry["id"],
|
||||
name=entry["original_name"],
|
||||
folder=entry["folder"],
|
||||
tag=f"{entry['folder']} / {extension.upper() or 'FILE'}",
|
||||
time=updated_at,
|
||||
version=f"v{int(entry.get('version_number', 1))}.0",
|
||||
state="已发布",
|
||||
stateTone="success",
|
||||
owner=entry.get("uploaded_by") or "系统导入",
|
||||
icon=ICON_BY_TYPE.get(file_type, ICON_BY_TYPE["binary"]),
|
||||
fileType=file_type,
|
||||
fileTypeLabel=self._resolve_file_type_label(file_type),
|
||||
summary=f"{entry['folder']} · {extension.upper() or 'FILE'} · {self._format_size(size_bytes)}",
|
||||
mimeType=entry.get("mime_type") or "application/octet-stream",
|
||||
extension=extension,
|
||||
sizeBytes=size_bytes,
|
||||
canPreview=self._can_preview(extension),
|
||||
)
|
||||
|
||||
def _build_preview(
|
||||
self, entry: dict[str, Any]
|
||||
) -> tuple[str, list[KnowledgePreviewPageRead]]:
|
||||
extension = self._extract_extension(entry["original_name"])
|
||||
file_path = self._resolve_document_path(entry)
|
||||
|
||||
if extension == "pdf":
|
||||
return "pdf", []
|
||||
|
||||
if extension in IMAGE_EXTENSIONS:
|
||||
return "image", []
|
||||
|
||||
if extension in TEXT_EXTENSIONS:
|
||||
text = self._read_text_preview(file_path)
|
||||
return "text", [self._build_text_preview_page(entry, text)]
|
||||
|
||||
if extension == "docx":
|
||||
text = self._extract_docx_text(file_path)
|
||||
return "text", [self._build_text_preview_page(entry, text)]
|
||||
|
||||
if extension == "xlsx":
|
||||
return "table", [self._build_xlsx_preview_page(entry, file_path)]
|
||||
|
||||
if extension == "pptx":
|
||||
return "slides", self._build_pptx_preview_pages(entry, file_path)
|
||||
|
||||
return (
|
||||
"unsupported",
|
||||
[
|
||||
KnowledgePreviewPageRead(
|
||||
title=entry["original_name"],
|
||||
subtitle="当前格式暂不支持在线解析预览。",
|
||||
stats=[
|
||||
KnowledgePreviewStatRead(label="文件格式", value=extension.upper() or "FILE"),
|
||||
KnowledgePreviewStatRead(label="文件大小", value=self._format_size(entry["size_bytes"])),
|
||||
KnowledgePreviewStatRead(label="建议操作", value="下载后查看"),
|
||||
],
|
||||
blocks=[
|
||||
KnowledgePreviewBlockRead(
|
||||
heading="预览说明",
|
||||
lines=[
|
||||
"当前系统已支持该文件的上传、下载和权限控制。",
|
||||
"如需在线预览,可后续接入专门的文档转换服务。",
|
||||
],
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def _build_text_preview_page(
|
||||
self, entry: dict[str, Any], text: str
|
||||
) -> KnowledgePreviewPageRead:
|
||||
lines = [line.strip() for line in text.splitlines() if line.strip()]
|
||||
if not lines:
|
||||
lines = ["文件内容为空,或当前文档未提取到可展示文本。"]
|
||||
|
||||
groups = [lines[index : index + 8] for index in range(0, min(len(lines), 24), 8)]
|
||||
blocks = [
|
||||
KnowledgePreviewBlockRead(heading=f"内容片段 {index + 1}", lines=group)
|
||||
for index, group in enumerate(groups)
|
||||
]
|
||||
|
||||
return KnowledgePreviewPageRead(
|
||||
title=entry["original_name"],
|
||||
subtitle="文本提取预览",
|
||||
stats=[
|
||||
KnowledgePreviewStatRead(label="文件格式", value=entry["extension"].upper() or "TEXT"),
|
||||
KnowledgePreviewStatRead(label="可见行数", value=str(len(lines))),
|
||||
KnowledgePreviewStatRead(label="文件大小", value=self._format_size(entry["size_bytes"])),
|
||||
],
|
||||
blocks=blocks,
|
||||
)
|
||||
|
||||
def _build_xlsx_preview_page(
|
||||
self, entry: dict[str, Any], file_path: Path
|
||||
) -> KnowledgePreviewPageRead:
|
||||
rows, sheet_count = self._extract_xlsx_rows(file_path)
|
||||
if not rows:
|
||||
rows = [["未提取到表格内容。"]]
|
||||
|
||||
blocks = [
|
||||
KnowledgePreviewBlockRead(
|
||||
heading=f"第 {index + 1} 行",
|
||||
lines=[" | ".join(cell for cell in row if cell) or "(空行)"],
|
||||
)
|
||||
for index, row in enumerate(rows[:12])
|
||||
]
|
||||
|
||||
return KnowledgePreviewPageRead(
|
||||
title=entry["original_name"],
|
||||
subtitle="表格内容预览",
|
||||
stats=[
|
||||
KnowledgePreviewStatRead(label="工作表数量", value=str(sheet_count)),
|
||||
KnowledgePreviewStatRead(label="预览行数", value=str(min(len(rows), 12))),
|
||||
KnowledgePreviewStatRead(label="文件大小", value=self._format_size(entry["size_bytes"])),
|
||||
],
|
||||
blocks=blocks,
|
||||
)
|
||||
|
||||
def _build_pptx_preview_pages(
|
||||
self, entry: dict[str, Any], file_path: Path
|
||||
) -> list[KnowledgePreviewPageRead]:
|
||||
slides = self._extract_pptx_slides(file_path)
|
||||
if not slides:
|
||||
slides = [["未提取到幻灯片文本。"]]
|
||||
|
||||
pages: list[KnowledgePreviewPageRead] = []
|
||||
for index, slide_lines in enumerate(slides[:8]):
|
||||
pages.append(
|
||||
KnowledgePreviewPageRead(
|
||||
title=entry["original_name"],
|
||||
subtitle=f"幻灯片 {index + 1}",
|
||||
stats=[
|
||||
KnowledgePreviewStatRead(label="页码", value=str(index + 1)),
|
||||
KnowledgePreviewStatRead(label="文本条数", value=str(len(slide_lines))),
|
||||
KnowledgePreviewStatRead(label="文件格式", value="PPTX"),
|
||||
],
|
||||
blocks=[
|
||||
KnowledgePreviewBlockRead(
|
||||
heading="幻灯片内容",
|
||||
lines=slide_lines or ["该页未提取到文本内容。"],
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
return pages
|
||||
|
||||
def _load_index(self) -> dict[str, Any]:
|
||||
try:
|
||||
payload = json.loads(self.index_path.read_text(encoding="utf-8"))
|
||||
except (FileNotFoundError, json.JSONDecodeError):
|
||||
payload = {"version": 1, "documents": []}
|
||||
payload.setdefault("documents", [])
|
||||
return payload
|
||||
|
||||
def _save_index(self, index: dict[str, Any]) -> None:
|
||||
self.index_path.write_text(
|
||||
json.dumps(index, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
def _reconcile_index(self, index: dict[str, Any]) -> bool:
|
||||
changed = False
|
||||
documents = index.setdefault("documents", [])
|
||||
known_by_stored = {
|
||||
(item["folder"], item["stored_name"]): item
|
||||
for item in documents
|
||||
if item.get("folder") and item.get("stored_name")
|
||||
}
|
||||
|
||||
existing_items: list[dict[str, Any]] = []
|
||||
for item in documents:
|
||||
file_path = self._resolve_document_path(item)
|
||||
if file_path.exists():
|
||||
item["size_bytes"] = file_path.stat().st_size
|
||||
item["extension"] = self._extract_extension(item["original_name"])
|
||||
item["mime_type"] = item.get("mime_type") or (
|
||||
mimetypes.guess_type(item["original_name"])[0] or "application/octet-stream"
|
||||
)
|
||||
existing_items.append(item)
|
||||
else:
|
||||
changed = True
|
||||
|
||||
for folder_name in FIXED_KNOWLEDGE_FOLDERS:
|
||||
folder_path = self.library_root / folder_name
|
||||
for file_path in folder_path.iterdir():
|
||||
if not file_path.is_file() or file_path.name.startswith("."):
|
||||
continue
|
||||
|
||||
key = (folder_name, file_path.name)
|
||||
if key in known_by_stored:
|
||||
continue
|
||||
|
||||
document_id, original_name = self._parse_stored_name(file_path.name)
|
||||
stat = file_path.stat()
|
||||
existing_items.append(
|
||||
{
|
||||
"id": document_id,
|
||||
"folder": folder_name,
|
||||
"original_name": original_name,
|
||||
"stored_name": file_path.name,
|
||||
"mime_type": mimetypes.guess_type(original_name)[0]
|
||||
or "application/octet-stream",
|
||||
"extension": self._extract_extension(original_name),
|
||||
"size_bytes": stat.st_size,
|
||||
"sha256": "",
|
||||
"created_at": datetime.fromtimestamp(stat.st_ctime, tz=UTC).isoformat(),
|
||||
"updated_at": datetime.fromtimestamp(stat.st_mtime, tz=UTC).isoformat(),
|
||||
"uploaded_by": "系统导入",
|
||||
"version_number": 1,
|
||||
}
|
||||
)
|
||||
changed = True
|
||||
|
||||
if changed or len(existing_items) != len(documents):
|
||||
index["documents"] = existing_items
|
||||
return True
|
||||
return False
|
||||
|
||||
def _require_entry(self, index: dict[str, Any], document_id: str) -> dict[str, Any]:
|
||||
for entry in index["documents"]:
|
||||
if entry["id"] == document_id:
|
||||
return entry
|
||||
raise FileNotFoundError(document_id)
|
||||
|
||||
def _resolve_document_path(self, entry: dict[str, Any]) -> Path:
|
||||
return self.library_root / entry["folder"] / entry["stored_name"]
|
||||
|
||||
@staticmethod
|
||||
def _normalize_filename(filename: str) -> str:
|
||||
normalized = Path(str(filename or "").strip()).name.strip()
|
||||
normalized = normalized.replace("/", "_").replace("\\", "_")
|
||||
if not normalized:
|
||||
raise ValueError("文件名不能为空。")
|
||||
return normalized
|
||||
|
||||
@staticmethod
|
||||
def _normalize_folder(folder: str) -> str:
|
||||
normalized = str(folder or "").strip()
|
||||
if normalized not in FIXED_KNOWLEDGE_FOLDERS:
|
||||
raise ValueError("只能上传到预设知识库文件夹。")
|
||||
return normalized
|
||||
|
||||
@staticmethod
|
||||
def _extract_extension(filename: str) -> str:
|
||||
suffix = Path(filename).suffix.lower().lstrip(".")
|
||||
return suffix
|
||||
|
||||
@staticmethod
|
||||
def _parse_stored_name(stored_name: str) -> tuple[str, str]:
|
||||
if "__" not in stored_name:
|
||||
return uuid4().hex, stored_name
|
||||
document_id, original_name = stored_name.split("__", 1)
|
||||
return document_id or uuid4().hex, original_name or stored_name
|
||||
|
||||
@staticmethod
|
||||
def _format_time(value: str | None) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
try:
|
||||
parsed = datetime.fromisoformat(value)
|
||||
except ValueError:
|
||||
return value
|
||||
return parsed.astimezone(UTC).strftime("%Y-%m-%d %H:%M")
|
||||
|
||||
@staticmethod
|
||||
def _format_size(size_bytes: int) -> str:
|
||||
if size_bytes < 1024:
|
||||
return f"{size_bytes} B"
|
||||
if size_bytes < 1024 * 1024:
|
||||
return f"{size_bytes / 1024:.1f} KB"
|
||||
return f"{size_bytes / (1024 * 1024):.1f} MB"
|
||||
|
||||
@staticmethod
|
||||
def _resolve_file_type(extension: str) -> str:
|
||||
if extension == "pdf":
|
||||
return "pdf"
|
||||
if extension in WORD_EXTENSIONS:
|
||||
return "word"
|
||||
if extension in EXCEL_EXTENSIONS:
|
||||
return "excel"
|
||||
if extension in PPT_EXTENSIONS:
|
||||
return "ppt"
|
||||
if extension in IMAGE_EXTENSIONS:
|
||||
return "image"
|
||||
if extension in TEXT_EXTENSIONS:
|
||||
return "text"
|
||||
if extension in ARCHIVE_EXTENSIONS:
|
||||
return "archive"
|
||||
return "binary"
|
||||
|
||||
@staticmethod
|
||||
def _resolve_file_type_label(file_type: str) -> str:
|
||||
mapping = {
|
||||
"pdf": "PDF 预览",
|
||||
"word": "Word 预览",
|
||||
"excel": "Excel 预览",
|
||||
"ppt": "PPT 预览",
|
||||
"image": "图片预览",
|
||||
"text": "文本预览",
|
||||
"archive": "压缩包",
|
||||
"binary": "文件预览",
|
||||
}
|
||||
return mapping.get(file_type, "文件预览")
|
||||
|
||||
@staticmethod
|
||||
def _can_preview(extension: str) -> bool:
|
||||
return extension in INLINE_PREVIEW_EXTENSIONS or extension in STRUCTURED_PREVIEW_EXTENSIONS
|
||||
|
||||
@staticmethod
|
||||
def _read_text_preview(file_path: Path) -> str:
|
||||
encodings = ("utf-8", "utf-8-sig", "gbk")
|
||||
for encoding in encodings:
|
||||
try:
|
||||
return file_path.read_text(encoding=encoding)
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
return "当前文本文件编码暂不支持在线解析。"
|
||||
|
||||
@staticmethod
|
||||
def _extract_docx_text(file_path: Path) -> str:
|
||||
try:
|
||||
with ZipFile(file_path) as archive:
|
||||
xml_content = archive.read("word/document.xml")
|
||||
except (BadZipFile, KeyError):
|
||||
return "当前 Word 文件解析失败。"
|
||||
|
||||
root = ElementTree.fromstring(xml_content)
|
||||
texts = [node.text.strip() for node in root.iter() if node.tag.endswith("}t") and node.text]
|
||||
return "\n".join(texts)
|
||||
|
||||
@staticmethod
|
||||
def _extract_xlsx_rows(file_path: Path) -> tuple[list[list[str]], int]:
|
||||
try:
|
||||
with ZipFile(file_path) as archive:
|
||||
shared_strings: list[str] = []
|
||||
if "xl/sharedStrings.xml" in archive.namelist():
|
||||
shared_root = ElementTree.fromstring(archive.read("xl/sharedStrings.xml"))
|
||||
shared_strings = [
|
||||
"".join(node.itertext()).strip()
|
||||
for node in shared_root.iter()
|
||||
if node.tag.endswith("}si")
|
||||
]
|
||||
|
||||
sheet_names = sorted(
|
||||
name
|
||||
for name in archive.namelist()
|
||||
if re.fullmatch(r"xl/worksheets/sheet\d+\.xml", name)
|
||||
)
|
||||
if not sheet_names:
|
||||
return [], 0
|
||||
|
||||
first_sheet = ElementTree.fromstring(archive.read(sheet_names[0]))
|
||||
rows: list[list[str]] = []
|
||||
for row in first_sheet.iter():
|
||||
if not row.tag.endswith("}row"):
|
||||
continue
|
||||
row_values: list[str] = []
|
||||
for cell in row:
|
||||
if not cell.tag.endswith("}c"):
|
||||
continue
|
||||
cell_type = cell.attrib.get("t")
|
||||
value_node = next((item for item in cell if item.tag.endswith("}v")), None)
|
||||
if value_node is None or value_node.text is None:
|
||||
row_values.append("")
|
||||
continue
|
||||
raw_value = value_node.text.strip()
|
||||
if cell_type == "s" and raw_value.isdigit():
|
||||
index = int(raw_value)
|
||||
row_values.append(shared_strings[index] if index < len(shared_strings) else raw_value)
|
||||
else:
|
||||
row_values.append(raw_value)
|
||||
if row_values:
|
||||
rows.append(row_values)
|
||||
|
||||
return rows, len(sheet_names)
|
||||
except (BadZipFile, ElementTree.ParseError, KeyError, ValueError):
|
||||
return [], 0
|
||||
|
||||
@staticmethod
|
||||
def _extract_pptx_slides(file_path: Path) -> list[list[str]]:
|
||||
try:
|
||||
with ZipFile(file_path) as archive:
|
||||
slide_names = sorted(
|
||||
name
|
||||
for name in archive.namelist()
|
||||
if re.fullmatch(r"ppt/slides/slide\d+\.xml", name)
|
||||
)
|
||||
slides: list[list[str]] = []
|
||||
for slide_name in slide_names:
|
||||
root = ElementTree.fromstring(archive.read(slide_name))
|
||||
texts = [node.text.strip() for node in root.iter() if node.tag.endswith("}t") and node.text]
|
||||
slides.append(texts)
|
||||
return slides
|
||||
except (BadZipFile, ElementTree.ParseError, KeyError):
|
||||
return []
|
||||
File diff suppressed because it is too large
Load Diff
4
server/storage/knowledge/.index.json
Normal file
4
server/storage/knowledge/.index.json
Normal file
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"version": 1,
|
||||
"documents": []
|
||||
}
|
||||
@@ -1,70 +1,70 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.db.base import Base
|
||||
from app.schemas.auth import LoginRequest
|
||||
from app.schemas.settings import SettingsWrite
|
||||
from app.services.auth import AuthService
|
||||
from app.services.employee import EmployeeService
|
||||
from app.services.settings import SettingsService
|
||||
|
||||
|
||||
def build_session() -> Session:
|
||||
engine = create_engine(
|
||||
"sqlite+pysqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
session_factory = sessionmaker(bind=engine, autoflush=False, autocommit=False)
|
||||
return session_factory()
|
||||
|
||||
|
||||
def test_employee_can_login_with_seed_default_password() -> None:
|
||||
with build_session() as db:
|
||||
employee = EmployeeService(db).list_employees()[0]
|
||||
result = AuthService(db).login(
|
||||
LoginRequest(username=employee.email, password="123456")
|
||||
)
|
||||
|
||||
assert result.ok is True
|
||||
assert result.user.username == employee.email
|
||||
assert result.user.name == employee.name
|
||||
assert result.user.roleCodes
|
||||
assert result.user.isAdmin is False
|
||||
|
||||
|
||||
def test_admin_can_login_with_database_password() -> None:
|
||||
with build_session() as db:
|
||||
settings_service = SettingsService(db)
|
||||
payload = settings_service.get_settings_snapshot().model_dump()
|
||||
payload["adminForm"]["adminAccount"] = "superadmin"
|
||||
payload["adminForm"]["newPassword"] = "admin123"
|
||||
payload["adminForm"]["confirmPassword"] = "admin123"
|
||||
settings_service.save_settings_snapshot(SettingsWrite(**payload))
|
||||
|
||||
result = AuthService(db).login(
|
||||
LoginRequest(username="superadmin", password="admin123")
|
||||
)
|
||||
|
||||
assert result.ok is True
|
||||
assert result.user.username == "superadmin"
|
||||
assert result.user.isAdmin is True
|
||||
assert result.user.roleCodes == ["manager"]
|
||||
|
||||
|
||||
def test_disabled_employee_cannot_login() -> None:
|
||||
with build_session() as db:
|
||||
service = EmployeeService(db)
|
||||
employee = service.list_employees()[0]
|
||||
service.disable_employee(employee.id)
|
||||
|
||||
try:
|
||||
AuthService(db).login(LoginRequest(username=employee.email, password="123456"))
|
||||
except ValueError as exc:
|
||||
assert "账号或密码错误" in str(exc)
|
||||
else:
|
||||
raise AssertionError("disabled employee login should be rejected")
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.db.base import Base
|
||||
from app.schemas.auth import LoginRequest
|
||||
from app.schemas.settings import SettingsWrite
|
||||
from app.services.auth import AuthService
|
||||
from app.services.employee import EmployeeService
|
||||
from app.services.settings import SettingsService
|
||||
|
||||
|
||||
def build_session() -> Session:
|
||||
engine = create_engine(
|
||||
"sqlite+pysqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
session_factory = sessionmaker(bind=engine, autoflush=False, autocommit=False)
|
||||
return session_factory()
|
||||
|
||||
|
||||
def test_employee_can_login_with_seed_default_password() -> None:
|
||||
with build_session() as db:
|
||||
employee = EmployeeService(db).list_employees()[0]
|
||||
result = AuthService(db).login(
|
||||
LoginRequest(username=employee.email, password="123456")
|
||||
)
|
||||
|
||||
assert result.ok is True
|
||||
assert result.user.username == employee.email
|
||||
assert result.user.name == employee.name
|
||||
assert result.user.roleCodes
|
||||
assert result.user.isAdmin is False
|
||||
|
||||
|
||||
def test_admin_can_login_with_database_password() -> None:
|
||||
with build_session() as db:
|
||||
settings_service = SettingsService(db)
|
||||
payload = settings_service.get_settings_snapshot().model_dump()
|
||||
payload["adminForm"]["adminAccount"] = "superadmin"
|
||||
payload["adminForm"]["newPassword"] = "admin123"
|
||||
payload["adminForm"]["confirmPassword"] = "admin123"
|
||||
settings_service.save_settings_snapshot(SettingsWrite(**payload))
|
||||
|
||||
result = AuthService(db).login(
|
||||
LoginRequest(username="superadmin", password="admin123")
|
||||
)
|
||||
|
||||
assert result.ok is True
|
||||
assert result.user.username == "superadmin"
|
||||
assert result.user.isAdmin is True
|
||||
assert result.user.roleCodes == ["manager"]
|
||||
|
||||
|
||||
def test_disabled_employee_cannot_login() -> None:
|
||||
with build_session() as db:
|
||||
service = EmployeeService(db)
|
||||
employee = service.list_employees()[0]
|
||||
service.disable_employee(employee.id)
|
||||
|
||||
try:
|
||||
AuthService(db).login(LoginRequest(username=employee.email, password="123456"))
|
||||
except ValueError as exc:
|
||||
assert "账号或密码错误" in str(exc)
|
||||
else:
|
||||
raise AssertionError("disabled employee login should be rejected")
|
||||
|
||||
@@ -1,132 +1,132 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
import hashlib
|
||||
import json
|
||||
import secrets
|
||||
import tempfile
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from app.core import admin_secret
|
||||
from app.core import secret_box
|
||||
from app.db.base import Base
|
||||
from app.models.system_model_setting import SystemModelSetting
|
||||
from app.models.system_setting import SystemSetting
|
||||
from app.models.system_setting_secret import SystemSettingSecret
|
||||
from app.schemas.settings import SettingsWrite
|
||||
from app.services.settings import SettingsService
|
||||
|
||||
|
||||
def build_session(db_file: Path) -> Session:
|
||||
engine = create_engine(
|
||||
f"sqlite+pysqlite:///{db_file.as_posix()}",
|
||||
connect_args={"check_same_thread": False},
|
||||
)
|
||||
SystemSetting.__table__.create(bind=engine)
|
||||
SystemSettingSecret.__table__.create(bind=engine)
|
||||
SystemModelSetting.__table__.create(bind=engine)
|
||||
session_factory = sessionmaker(bind=engine, autoflush=False, autocommit=False)
|
||||
return session_factory()
|
||||
|
||||
|
||||
def build_temp_secret_dir() -> Path:
|
||||
return Path(tempfile.mkdtemp(prefix="xf-settings-test-"))
|
||||
|
||||
|
||||
def test_settings_service_persists_non_secret_and_secret_fields(monkeypatch) -> None:
|
||||
temp_dir = build_temp_secret_dir()
|
||||
monkeypatch.setattr(secret_box, "SECRET_KEY_FILE", temp_dir / "settings.key")
|
||||
monkeypatch.setattr(Base.metadata, "create_all", lambda *args, **kwargs: None)
|
||||
|
||||
with build_session(temp_dir / "settings.db") as db:
|
||||
service = SettingsService(db)
|
||||
initial_snapshot = service.get_settings_snapshot()
|
||||
payload = initial_snapshot.model_dump()
|
||||
|
||||
payload["companyForm"]["companyName"] = "YGSOFT"
|
||||
payload["companyForm"]["displayName"] = "云广软件"
|
||||
payload["adminForm"]["adminAccount"] = "admin-root"
|
||||
payload["adminForm"]["adminEmail"] = "admin@example.com"
|
||||
payload["adminForm"]["newPassword"] = "54321"
|
||||
payload["adminForm"]["confirmPassword"] = "54321"
|
||||
payload["llmForm"]["mainModel"] = "glm-4.5"
|
||||
payload["llmForm"]["mainApiKey"] = "main-secret"
|
||||
payload["mailForm"]["password"] = "smtp-secret"
|
||||
|
||||
saved_snapshot = service.save_settings_snapshot(SettingsWrite(**payload))
|
||||
|
||||
assert saved_snapshot.companyForm.companyName == "YGSOFT"
|
||||
assert saved_snapshot.companyForm.displayName == "云广软件"
|
||||
assert saved_snapshot.llmForm.mainModel == "glm-4.5"
|
||||
assert saved_snapshot.llmForm.mainApiKey == ""
|
||||
assert saved_snapshot.llmForm.mainApiKeyConfigured is True
|
||||
assert saved_snapshot.mailForm.password == ""
|
||||
assert saved_snapshot.mailForm.passwordConfigured is True
|
||||
assert saved_snapshot.adminForm.newPassword == ""
|
||||
assert saved_snapshot.adminForm.adminPasswordConfigured is True
|
||||
|
||||
model_row = db.get(SystemModelSetting, "main")
|
||||
assert model_row is not None
|
||||
assert model_row.model_name == "glm-4.5"
|
||||
assert model_row.api_key_encrypted
|
||||
|
||||
assert service.load_saved_model_api_key("main") == "main-secret"
|
||||
assert service.verify_admin_login("admin-root", "54321") is not None
|
||||
assert service.verify_admin_login("admin@example.com", "54321") is not None
|
||||
|
||||
|
||||
def test_blank_secret_input_does_not_clear_saved_secret(monkeypatch) -> None:
|
||||
temp_dir = build_temp_secret_dir()
|
||||
monkeypatch.setattr(secret_box, "SECRET_KEY_FILE", temp_dir / "settings.key")
|
||||
monkeypatch.setattr(Base.metadata, "create_all", lambda *args, **kwargs: None)
|
||||
|
||||
with build_session(temp_dir / "settings.db") as db:
|
||||
service = SettingsService(db)
|
||||
first_payload = service.get_settings_snapshot().model_dump()
|
||||
first_payload["llmForm"]["mainApiKey"] = "persisted-key"
|
||||
service.save_settings_snapshot(SettingsWrite(**first_payload))
|
||||
|
||||
second_payload = service.get_settings_snapshot().model_dump()
|
||||
second_payload["llmForm"]["mainApiKey"] = ""
|
||||
service.save_settings_snapshot(SettingsWrite(**second_payload))
|
||||
|
||||
assert service.load_saved_model_api_key("main") == "persisted-key"
|
||||
|
||||
|
||||
def test_legacy_setup_admin_password_is_migrated_to_database(monkeypatch) -> None:
|
||||
temp_dir = build_temp_secret_dir()
|
||||
admin_file = temp_dir / "admin.json"
|
||||
monkeypatch.setattr(admin_secret, "ADMIN_SECRET_FILE", admin_file)
|
||||
monkeypatch.setattr(secret_box, "SECRET_KEY_FILE", temp_dir / "settings.key")
|
||||
monkeypatch.setattr(Base.metadata, "create_all", lambda *args, **kwargs: None)
|
||||
|
||||
password = "setup-secret"
|
||||
salt = secrets.token_bytes(16)
|
||||
derived_key = hashlib.scrypt(password.encode("utf-8"), salt=salt, n=16384, r=8, p=1, dklen=64)
|
||||
admin_file.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"algorithm": "scrypt",
|
||||
"username": "setup-admin",
|
||||
"salt": salt.hex(),
|
||||
"derived_key": derived_key.hex(),
|
||||
"key_length": 64,
|
||||
"N": 16384,
|
||||
"r": 8,
|
||||
"p": 1,
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
with build_session(temp_dir / "settings.db") as db:
|
||||
service = SettingsService(db)
|
||||
snapshot = service.get_settings_snapshot()
|
||||
secrets_row = db.get(SystemSettingSecret, "default")
|
||||
|
||||
assert snapshot.adminForm.adminPasswordConfigured is True
|
||||
assert secrets_row is not None
|
||||
assert secrets_row.admin_password_hash.startswith("scrypt$")
|
||||
assert service.verify_admin_login("setup-admin", password) is not None
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
import hashlib
|
||||
import json
|
||||
import secrets
|
||||
import tempfile
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from app.core import admin_secret
|
||||
from app.core import secret_box
|
||||
from app.db.base import Base
|
||||
from app.models.system_model_setting import SystemModelSetting
|
||||
from app.models.system_setting import SystemSetting
|
||||
from app.models.system_setting_secret import SystemSettingSecret
|
||||
from app.schemas.settings import SettingsWrite
|
||||
from app.services.settings import SettingsService
|
||||
|
||||
|
||||
def build_session(db_file: Path) -> Session:
|
||||
engine = create_engine(
|
||||
f"sqlite+pysqlite:///{db_file.as_posix()}",
|
||||
connect_args={"check_same_thread": False},
|
||||
)
|
||||
SystemSetting.__table__.create(bind=engine)
|
||||
SystemSettingSecret.__table__.create(bind=engine)
|
||||
SystemModelSetting.__table__.create(bind=engine)
|
||||
session_factory = sessionmaker(bind=engine, autoflush=False, autocommit=False)
|
||||
return session_factory()
|
||||
|
||||
|
||||
def build_temp_secret_dir() -> Path:
|
||||
return Path(tempfile.mkdtemp(prefix="xf-settings-test-"))
|
||||
|
||||
|
||||
def test_settings_service_persists_non_secret_and_secret_fields(monkeypatch) -> None:
|
||||
temp_dir = build_temp_secret_dir()
|
||||
monkeypatch.setattr(secret_box, "SECRET_KEY_FILE", temp_dir / "settings.key")
|
||||
monkeypatch.setattr(Base.metadata, "create_all", lambda *args, **kwargs: None)
|
||||
|
||||
with build_session(temp_dir / "settings.db") as db:
|
||||
service = SettingsService(db)
|
||||
initial_snapshot = service.get_settings_snapshot()
|
||||
payload = initial_snapshot.model_dump()
|
||||
|
||||
payload["companyForm"]["companyName"] = "YGSOFT"
|
||||
payload["companyForm"]["displayName"] = "云广软件"
|
||||
payload["adminForm"]["adminAccount"] = "admin-root"
|
||||
payload["adminForm"]["adminEmail"] = "admin@example.com"
|
||||
payload["adminForm"]["newPassword"] = "54321"
|
||||
payload["adminForm"]["confirmPassword"] = "54321"
|
||||
payload["llmForm"]["mainModel"] = "glm-4.5"
|
||||
payload["llmForm"]["mainApiKey"] = "main-secret"
|
||||
payload["mailForm"]["password"] = "smtp-secret"
|
||||
|
||||
saved_snapshot = service.save_settings_snapshot(SettingsWrite(**payload))
|
||||
|
||||
assert saved_snapshot.companyForm.companyName == "YGSOFT"
|
||||
assert saved_snapshot.companyForm.displayName == "云广软件"
|
||||
assert saved_snapshot.llmForm.mainModel == "glm-4.5"
|
||||
assert saved_snapshot.llmForm.mainApiKey == ""
|
||||
assert saved_snapshot.llmForm.mainApiKeyConfigured is True
|
||||
assert saved_snapshot.mailForm.password == ""
|
||||
assert saved_snapshot.mailForm.passwordConfigured is True
|
||||
assert saved_snapshot.adminForm.newPassword == ""
|
||||
assert saved_snapshot.adminForm.adminPasswordConfigured is True
|
||||
|
||||
model_row = db.get(SystemModelSetting, "main")
|
||||
assert model_row is not None
|
||||
assert model_row.model_name == "glm-4.5"
|
||||
assert model_row.api_key_encrypted
|
||||
|
||||
assert service.load_saved_model_api_key("main") == "main-secret"
|
||||
assert service.verify_admin_login("admin-root", "54321") is not None
|
||||
assert service.verify_admin_login("admin@example.com", "54321") is not None
|
||||
|
||||
|
||||
def test_blank_secret_input_does_not_clear_saved_secret(monkeypatch) -> None:
|
||||
temp_dir = build_temp_secret_dir()
|
||||
monkeypatch.setattr(secret_box, "SECRET_KEY_FILE", temp_dir / "settings.key")
|
||||
monkeypatch.setattr(Base.metadata, "create_all", lambda *args, **kwargs: None)
|
||||
|
||||
with build_session(temp_dir / "settings.db") as db:
|
||||
service = SettingsService(db)
|
||||
first_payload = service.get_settings_snapshot().model_dump()
|
||||
first_payload["llmForm"]["mainApiKey"] = "persisted-key"
|
||||
service.save_settings_snapshot(SettingsWrite(**first_payload))
|
||||
|
||||
second_payload = service.get_settings_snapshot().model_dump()
|
||||
second_payload["llmForm"]["mainApiKey"] = ""
|
||||
service.save_settings_snapshot(SettingsWrite(**second_payload))
|
||||
|
||||
assert service.load_saved_model_api_key("main") == "persisted-key"
|
||||
|
||||
|
||||
def test_legacy_setup_admin_password_is_migrated_to_database(monkeypatch) -> None:
|
||||
temp_dir = build_temp_secret_dir()
|
||||
admin_file = temp_dir / "admin.json"
|
||||
monkeypatch.setattr(admin_secret, "ADMIN_SECRET_FILE", admin_file)
|
||||
monkeypatch.setattr(secret_box, "SECRET_KEY_FILE", temp_dir / "settings.key")
|
||||
monkeypatch.setattr(Base.metadata, "create_all", lambda *args, **kwargs: None)
|
||||
|
||||
password = "setup-secret"
|
||||
salt = secrets.token_bytes(16)
|
||||
derived_key = hashlib.scrypt(password.encode("utf-8"), salt=salt, n=16384, r=8, p=1, dklen=64)
|
||||
admin_file.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"algorithm": "scrypt",
|
||||
"username": "setup-admin",
|
||||
"salt": salt.hex(),
|
||||
"derived_key": derived_key.hex(),
|
||||
"key_length": 64,
|
||||
"N": 16384,
|
||||
"r": 8,
|
||||
"p": 1,
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
with build_session(temp_dir / "settings.db") as db:
|
||||
service = SettingsService(db)
|
||||
snapshot = service.get_settings_snapshot()
|
||||
secrets_row = db.get(SystemSettingSecret, "default")
|
||||
|
||||
assert snapshot.adminForm.adminPasswordConfigured is True
|
||||
assert secrets_row is not None
|
||||
assert secrets_row.admin_password_hash.startswith("scrypt$")
|
||||
assert service.verify_admin_login("setup-admin", password) is not None
|
||||
|
||||
Reference in New Issue
Block a user