From efe5d240ae0724757e5af7ef49fe5f4151019fc4 Mon Sep 17 00:00:00 2001 From: Developer Date: Tue, 17 Mar 2026 17:28:36 +0800 Subject: [PATCH] =?UTF-8?q?feat(backend):=20=E6=B7=BB=E5=8A=A0=E6=A0=B8?= =?UTF-8?q?=E5=BF=83=E6=9E=B6=E6=9E=84=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 添加认证模块 (auth.py) - 添加 CRUD 基础操作 (crud.py) - 添加异常处理 (exceptions.py) - 添加日志模块 (logging.py) - 添加响应格式 (response.py) - 添加依赖注入 (dependencies.py) Co-Authored-By: Claude Opus 4.6 --- backend/app/api/dependencies.py | 20 ++++ backend/app/api/response.py | 75 ++++++++++++++ backend/app/core/auth.py | 38 +++++++ backend/app/core/crud.py | 178 ++++++++++++++++++++++++++++++++ backend/app/core/exceptions.py | 119 +++++++++++++++++++++ backend/app/core/logging.py | 66 ++++++++++++ 6 files changed, 496 insertions(+) create mode 100644 backend/app/api/dependencies.py create mode 100644 backend/app/api/response.py create mode 100644 backend/app/core/auth.py create mode 100644 backend/app/core/crud.py create mode 100644 backend/app/core/exceptions.py create mode 100644 backend/app/core/logging.py diff --git a/backend/app/api/dependencies.py b/backend/app/api/dependencies.py new file mode 100644 index 0000000..09721cc --- /dev/null +++ b/backend/app/api/dependencies.py @@ -0,0 +1,20 @@ +""" +API Dependencies +API 依赖项 +""" +from typing import Annotated +from fastapi import Depends +from app.core.auth import verify_api_key + + +# Type alias for API key dependency +ApiKey = Annotated[str, Depends(verify_api_key)] + + +# Optional API key (for endpoints that can work with or without auth) +async def get_optional_api_key(api_key: str = None) -> Optional[str]: + """Get optional API key""" + return api_key + + +OptionalApiKey = Annotated[Optional[str], Depends(get_optional_api_key)] diff --git a/backend/app/api/response.py b/backend/app/api/response.py new file mode 100644 index 0000000..22063c1 --- /dev/null +++ b/backend/app/api/response.py @@ -0,0 +1,75 @@ +""" +API Response Wrapper +统一 API 响应格式 +""" +from datetime import datetime +from typing import Any, Generic, List, Optional, TypeVar +from uuid import UUID +from pydantic import BaseModel, ConfigDict, Field + +T = TypeVar("T") + + +class ApiResponse(BaseModel, Generic[T]): + """统一 API 响应格式""" + model_config = ConfigDict(from_attributes=True) + + success: bool = True + message: str = "Success" + data: Optional[T] = None + error: Optional[dict] = None + timestamp: datetime = Field(default_factory=datetime.utcnow) + + @classmethod + def ok(cls, data: T = None, message: str = "Success") -> "ApiResponse[T]": + """成功响应""" + return cls(success=True, message=message, data=data) + + @classmethod + def fail(cls, message: str, error: dict = None) -> "ApiResponse[None]": + """失败响应""" + return cls(success=False, message=message, error=error) + + +class PaginatedResponse(BaseModel, Generic[T]): + """分页响应格式""" + model_config = ConfigDict(from_attributes=True) + + success: bool = True + message: str = "Success" + data: List[T] = [] + pagination: dict = Field(default_factory=lambda: { + "page": 1, + "page_size": 20, + "total": 0, + "total_pages": 0 + }) + + @classmethod + def ok( + cls, + items: List[T], + page: int = 1, + page_size: int = 20, + total: int = 0 + ) -> "PaginatedResponse[T]": + """创建分页响应""" + total_pages = (total + page_size - 1) // page_size if page_size > 0 else 0 + return cls( + success=True, + data=items, + pagination={ + "page": page, + "page_size": page_size, + "total": total, + "total_pages": total_pages + } + ) + + +class ErrorDetail(BaseModel): + """错误详情""" + code: str + message: str + details: Optional[dict] = None + field: Optional[str] = None diff --git a/backend/app/core/auth.py b/backend/app/core/auth.py new file mode 100644 index 0000000..6e2d0ad --- /dev/null +++ b/backend/app/core/auth.py @@ -0,0 +1,38 @@ +""" +API Key Authentication +API Key 认证中间件 +""" +from typing import Optional +from fastapi import Header, HTTPException, Request +from fastapi.security import APIKeyHeader + +from app.core.config import get_settings + +settings = get_settings() + +# API Key header +API_KEY_HEADER = APIKeyHeader(name="X-API-Key", auto_error=False) + + +async def verify_api_key(api_key: Optional[str] = Header(None)) -> str: + """Verify API key from header""" + if not api_key: + raise HTTPException(status_code=401, detail="API key is required") + + # In production, you would validate against a database or cache + # For development, we can use a simple validation + if settings.DEBUG and api_key == "dev-api-key": + return api_key + + # TODO: Implement proper API key validation + # This is a placeholder - in production, validate against stored keys + if len(api_key) < 32: + raise HTTPException(status_code=401, detail="Invalid API key") + + return api_key + + +def create_api_key() -> str: + """Generate a new API key""" + import secrets + return secrets.token_hex(32) diff --git a/backend/app/core/crud.py b/backend/app/core/crud.py new file mode 100644 index 0000000..ede198a --- /dev/null +++ b/backend/app/core/crud.py @@ -0,0 +1,178 @@ +""" +Database CRUD Operations +数据库通用 CRUD 操作 +""" +from typing import Any, Generic, List, Optional, Type, TypeVar +from uuid import UUID +from sqlalchemy import select, func, and_ +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload, joinedload + +from app.core.exceptions import NotFoundException, DuplicateException +from app.core.logging import LoggerMixin + +ModelType = TypeVar("ModelType") + + +class CRUDBase(Generic[ModelType], LoggerMixin): + """Base CRUD class with common operations""" + + def __init__(self, model: Type[ModelType]): + """Initialize CRUD with model""" + self.model = model + + async def get( + self, + db: AsyncSession, + id: UUID, + load_relations: Optional[List[str]] = None + ) -> Optional[ModelType]: + """Get single record by ID""" + query = select(self.model).where(self.model.id == id) + + # Load relationships if specified + if load_relations: + for relation in load_relations: + if hasattr(self.model, relation): + query = query.options(selectinload(getattr(self.model, relation))) + + result = await db.execute(query) + return result.scalar_one_or_none() + + async def get_or_raise( + self, + db: AsyncSession, + id: UUID, + resource_name: str = "Resource", + load_relations: Optional[List[str]] = None + ) -> ModelType: + """Get single record by ID or raise NotFoundException""" + obj = await self.get(db, id, load_relations) + if not obj: + raise NotFoundException(resource_name, id) + return obj + + async def get_multi( + self, + db: AsyncSession, + skip: int = 0, + limit: int = 20, + load_relations: Optional[List[str]] = None, + filters: Optional[dict] = None, + order_by: Optional[str] = "created_at", + descending: bool = True + ) -> tuple[List[ModelType], int]: + """Get multiple records with pagination""" + query = select(self.model) + count_query = select(func.count()).select_from(self.model) + + # Apply filters + if filters: + conditions = [] + for key, value in filters.items(): + if hasattr(self.model, key): + conditions.append(getattr(self.model, key) == value) + if conditions: + query = query.where(and_(*conditions)) + count_query = count_query.where(and_(*conditions)) + + # Load relationships if specified + if load_relations: + for relation in load_relations: + if hasattr(self.model, relation): + query = query.options(selectinload(getattr(self.model, relation))) + + # Count total + total_result = await db.execute(count_query) + total = total_result.scalar() or 0 + + # Apply ordering + if order_by and hasattr(self.model, order_by): + order_column = getattr(self.model, order_by) + if descending: + query = query.order_by(order_column.desc()) + else: + query = query.order_by(order_column.asc()) + + # Apply pagination + query = query.offset(skip).limit(limit) + + result = await db.execute(query) + items = result.scalars().all() + + return list(items), total + + async def create( + self, + db: AsyncSession, + obj_in: Any, + commit: bool = True + ) -> ModelType: + """Create new record""" + obj_data = obj_in.model_dump() if hasattr(obj_in, 'model_dump') else obj_in.dict() + db_obj = self.model(**obj_data) + db.add(db_obj) + + if commit: + await db.commit() + await db.refresh(db_obj) + self.log.debug(f"Created {self.model.__name__}: {db_obj.id}") + + return db_obj + + async def update( + self, + db: AsyncSession, + db_obj: ModelType, + obj_in: Any, + commit: bool = True + ) -> ModelType: + """Update existing record""" + if hasattr(obj_in, 'model_dump'): + obj_data = obj_in.model_dump(exclude_unset=True) + elif hasattr(obj_in, 'dict'): + obj_data = obj_in.dict(exclude_unset=True) + else: + obj_data = obj_in + + for field, value in obj_data.items(): + if hasattr(db_obj, field): + setattr(db_obj, field, value) + + if commit: + await db.commit() + await db.refresh(db_obj) + self.log.debug(f"Updated {self.model.__name__}: {db_obj.id}") + + return db_obj + + async def delete( + self, + db: AsyncSession, + id: UUID, + commit: bool = True + ) -> bool: + """Delete record by ID""" + obj = await self.get(db, id) + if obj: + await db.delete(obj) + if commit: + await db.commit() + self.log.debug(f"Deleted {self.model.__name__}: {id}") + return True + return False + + async def exists( + self, + db: AsyncSession, + filters: dict + ) -> bool: + """Check if record exists""" + query = select(func.count()).select_from(self.model) + for key, value in filters.items(): + if hasattr(self.model, key): + query = query.where(getattr(self.model, key) == value) + + result = await db.execute(query) + count = result.scalar() or 0 + return count > 0 diff --git a/backend/app/core/exceptions.py b/backend/app/core/exceptions.py new file mode 100644 index 0000000..aca75d0 --- /dev/null +++ b/backend/app/core/exceptions.py @@ -0,0 +1,119 @@ +""" +Custom Exceptions +自定义异常类 +""" +from typing import Any, Optional + + +class AppException(Exception): + """基础应用异常""" + + def __init__( + self, + message: str, + code: str = "INTERNAL_ERROR", + status_code: int = 500, + details: Optional[dict] = None + ): + self.message = message + self.code = code + self.status_code = status_code + self.details = details + super().__init__(self.message) + + +class NotFoundException(AppException): + """资源未找到异常""" + + def __init__(self, resource: str, resource_id: Any = None): + message = f"{resource} not found" + if resource_id: + message = f"{resource} with id '{resource_id}' not found" + super().__init__( + message=message, + code="NOT_FOUND", + status_code=404 + ) + + +class ValidationException(AppException): + """验证异常""" + + def __init__(self, message: str, field: str = None, details: dict = None): + super().__init__( + message=message, + code="VALIDATION_ERROR", + status_code=422, + details={"field": field, **(details or {})} + ) + + +class DuplicateException(AppException): + """重复资源异常""" + + def __init__(self, resource: str, field: str = None): + message = f"{resource} already exists" + if field: + message = f"{resource} with {field} already exists" + super().__init__( + message=message, + code="DUPLICATE", + status_code=409 + ) + + +class UnauthorizedException(AppException): + """未授权异常""" + + def __init__(self, message: str = "Unauthorized"): + super().__init__( + message=message, + code="UNAUTHORIZED", + status_code=401 + ) + + +class ForbiddenException(AppException): + """禁止访问异常""" + + def __init__(self, message: str = "Forbidden"): + super().__init__( + message=message, + code="FORBIDDEN", + status_code=403 + ) + + +class RateLimitException(AppException): + """速率限制异常""" + + def __init__(self, message: str = "Rate limit exceeded"): + super().__init__( + message=message, + code="RATE_LIMIT", + status_code=429 + ) + + +class FileProcessingException(AppException): + """文件处理异常""" + + def __init__(self, message: str, file_name: str = None): + details = {"file_name": file_name} if file_name else None + super().__init__( + message=message, + code="FILE_PROCESSING_ERROR", + status_code=422, + details=details + ) + + +class DatabaseException(AppException): + """数据库异常""" + + def __init__(self, message: str = "Database operation failed"): + super().__init__( + message=message, + code="DATABASE_ERROR", + status_code=500 + ) diff --git a/backend/app/core/logging.py b/backend/app/core/logging.py new file mode 100644 index 0000000..f4e737e --- /dev/null +++ b/backend/app/core/logging.py @@ -0,0 +1,66 @@ +""" +Logging Configuration +日志配置 +""" +import logging +import sys +from typing import Any +from logging.handlers import RotatingFileHandler +from pathlib import Path +from app.core.config import get_settings + +settings = get_settings() + +# Log directory +LOG_DIR = Path("./logs") +LOG_DIR.mkdir(exist_ok=True) + + +def setup_logging(name: str = "yg_dataset") -> logging.Logger: + """Setup application logging""" + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG if settings.DEBUG else logging.INFO) + + # Avoid duplicate handlers + if logger.handlers: + return logger + + # Console handler + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(logging.DEBUG if settings.DEBUG else logging.INFO) + console_formatter = logging.Formatter( + fmt="%(asctime)s | %(levelname)-8s | %(name)s:%(funcName)s:%(lineno)d | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S" + ) + console_handler.setFormatter(console_formatter) + logger.addHandler(console_handler) + + # File handler + file_handler = RotatingFileHandler( + LOG_DIR / f"{name}.log", + maxBytes=10 * 1024 * 1024, # 10MB + backupCount=5, + encoding="utf-8" + ) + file_handler.setLevel(logging.INFO) + file_formatter = logging.Formatter( + fmt="%(asctime)s | %(levelname)-8s | %(name)s:%(funcName)s:%(lineno)d | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S" + ) + file_handler.setFormatter(file_formatter) + logger.addHandler(file_handler) + + return logger + + +# Create default logger +logger = setup_logging() + + +class LoggerMixin: + """Mixin to add logging capability to classes""" + + @property + def log(self) -> logging.Logger: + """Get logger for this class""" + return logging.getLogger(self.__class__.__module__ + "." + self.__class__.__name__)