feat(backend): 添加核心架构模块
- 添加认证模块 (auth.py) - 添加 CRUD 基础操作 (crud.py) - 添加异常处理 (exceptions.py) - 添加日志模块 (logging.py) - 添加响应格式 (response.py) - 添加依赖注入 (dependencies.py) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
20
backend/app/api/dependencies.py
Normal file
20
backend/app/api/dependencies.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
"""
|
||||||
|
API Dependencies
|
||||||
|
API 依赖项
|
||||||
|
"""
|
||||||
|
from typing import Annotated
|
||||||
|
from fastapi import Depends
|
||||||
|
from app.core.auth import verify_api_key
|
||||||
|
|
||||||
|
|
||||||
|
# Type alias for API key dependency
|
||||||
|
ApiKey = Annotated[str, Depends(verify_api_key)]
|
||||||
|
|
||||||
|
|
||||||
|
# Optional API key (for endpoints that can work with or without auth)
|
||||||
|
async def get_optional_api_key(api_key: str = None) -> Optional[str]:
|
||||||
|
"""Get optional API key"""
|
||||||
|
return api_key
|
||||||
|
|
||||||
|
|
||||||
|
OptionalApiKey = Annotated[Optional[str], Depends(get_optional_api_key)]
|
||||||
75
backend/app/api/response.py
Normal file
75
backend/app/api/response.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
"""
|
||||||
|
API Response Wrapper
|
||||||
|
统一 API 响应格式
|
||||||
|
"""
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Generic, List, Optional, TypeVar
|
||||||
|
from uuid import UUID
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class ApiResponse(BaseModel, Generic[T]):
|
||||||
|
"""统一 API 响应格式"""
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
success: bool = True
|
||||||
|
message: str = "Success"
|
||||||
|
data: Optional[T] = None
|
||||||
|
error: Optional[dict] = None
|
||||||
|
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def ok(cls, data: T = None, message: str = "Success") -> "ApiResponse[T]":
|
||||||
|
"""成功响应"""
|
||||||
|
return cls(success=True, message=message, data=data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def fail(cls, message: str, error: dict = None) -> "ApiResponse[None]":
|
||||||
|
"""失败响应"""
|
||||||
|
return cls(success=False, message=message, error=error)
|
||||||
|
|
||||||
|
|
||||||
|
class PaginatedResponse(BaseModel, Generic[T]):
|
||||||
|
"""分页响应格式"""
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
success: bool = True
|
||||||
|
message: str = "Success"
|
||||||
|
data: List[T] = []
|
||||||
|
pagination: dict = Field(default_factory=lambda: {
|
||||||
|
"page": 1,
|
||||||
|
"page_size": 20,
|
||||||
|
"total": 0,
|
||||||
|
"total_pages": 0
|
||||||
|
})
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def ok(
|
||||||
|
cls,
|
||||||
|
items: List[T],
|
||||||
|
page: int = 1,
|
||||||
|
page_size: int = 20,
|
||||||
|
total: int = 0
|
||||||
|
) -> "PaginatedResponse[T]":
|
||||||
|
"""创建分页响应"""
|
||||||
|
total_pages = (total + page_size - 1) // page_size if page_size > 0 else 0
|
||||||
|
return cls(
|
||||||
|
success=True,
|
||||||
|
data=items,
|
||||||
|
pagination={
|
||||||
|
"page": page,
|
||||||
|
"page_size": page_size,
|
||||||
|
"total": total,
|
||||||
|
"total_pages": total_pages
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ErrorDetail(BaseModel):
|
||||||
|
"""错误详情"""
|
||||||
|
code: str
|
||||||
|
message: str
|
||||||
|
details: Optional[dict] = None
|
||||||
|
field: Optional[str] = None
|
||||||
38
backend/app/core/auth.py
Normal file
38
backend/app/core/auth.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""
|
||||||
|
API Key Authentication
|
||||||
|
API Key 认证中间件
|
||||||
|
"""
|
||||||
|
from typing import Optional
|
||||||
|
from fastapi import Header, HTTPException, Request
|
||||||
|
from fastapi.security import APIKeyHeader
|
||||||
|
|
||||||
|
from app.core.config import get_settings
|
||||||
|
|
||||||
|
settings = get_settings()
|
||||||
|
|
||||||
|
# API Key header
|
||||||
|
API_KEY_HEADER = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_api_key(api_key: Optional[str] = Header(None)) -> str:
|
||||||
|
"""Verify API key from header"""
|
||||||
|
if not api_key:
|
||||||
|
raise HTTPException(status_code=401, detail="API key is required")
|
||||||
|
|
||||||
|
# In production, you would validate against a database or cache
|
||||||
|
# For development, we can use a simple validation
|
||||||
|
if settings.DEBUG and api_key == "dev-api-key":
|
||||||
|
return api_key
|
||||||
|
|
||||||
|
# TODO: Implement proper API key validation
|
||||||
|
# This is a placeholder - in production, validate against stored keys
|
||||||
|
if len(api_key) < 32:
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||||
|
|
||||||
|
return api_key
|
||||||
|
|
||||||
|
|
||||||
|
def create_api_key() -> str:
|
||||||
|
"""Generate a new API key"""
|
||||||
|
import secrets
|
||||||
|
return secrets.token_hex(32)
|
||||||
178
backend/app/core/crud.py
Normal file
178
backend/app/core/crud.py
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
"""
|
||||||
|
Database CRUD Operations
|
||||||
|
数据库通用 CRUD 操作
|
||||||
|
"""
|
||||||
|
from typing import Any, Generic, List, Optional, Type, TypeVar
|
||||||
|
from uuid import UUID
|
||||||
|
from sqlalchemy import select, func, and_
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import selectinload, joinedload
|
||||||
|
|
||||||
|
from app.core.exceptions import NotFoundException, DuplicateException
|
||||||
|
from app.core.logging import LoggerMixin
|
||||||
|
|
||||||
|
ModelType = TypeVar("ModelType")
|
||||||
|
|
||||||
|
|
||||||
|
class CRUDBase(Generic[ModelType], LoggerMixin):
|
||||||
|
"""Base CRUD class with common operations"""
|
||||||
|
|
||||||
|
def __init__(self, model: Type[ModelType]):
|
||||||
|
"""Initialize CRUD with model"""
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
async def get(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
id: UUID,
|
||||||
|
load_relations: Optional[List[str]] = None
|
||||||
|
) -> Optional[ModelType]:
|
||||||
|
"""Get single record by ID"""
|
||||||
|
query = select(self.model).where(self.model.id == id)
|
||||||
|
|
||||||
|
# Load relationships if specified
|
||||||
|
if load_relations:
|
||||||
|
for relation in load_relations:
|
||||||
|
if hasattr(self.model, relation):
|
||||||
|
query = query.options(selectinload(getattr(self.model, relation)))
|
||||||
|
|
||||||
|
result = await db.execute(query)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_or_raise(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
id: UUID,
|
||||||
|
resource_name: str = "Resource",
|
||||||
|
load_relations: Optional[List[str]] = None
|
||||||
|
) -> ModelType:
|
||||||
|
"""Get single record by ID or raise NotFoundException"""
|
||||||
|
obj = await self.get(db, id, load_relations)
|
||||||
|
if not obj:
|
||||||
|
raise NotFoundException(resource_name, id)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
async def get_multi(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 20,
|
||||||
|
load_relations: Optional[List[str]] = None,
|
||||||
|
filters: Optional[dict] = None,
|
||||||
|
order_by: Optional[str] = "created_at",
|
||||||
|
descending: bool = True
|
||||||
|
) -> tuple[List[ModelType], int]:
|
||||||
|
"""Get multiple records with pagination"""
|
||||||
|
query = select(self.model)
|
||||||
|
count_query = select(func.count()).select_from(self.model)
|
||||||
|
|
||||||
|
# Apply filters
|
||||||
|
if filters:
|
||||||
|
conditions = []
|
||||||
|
for key, value in filters.items():
|
||||||
|
if hasattr(self.model, key):
|
||||||
|
conditions.append(getattr(self.model, key) == value)
|
||||||
|
if conditions:
|
||||||
|
query = query.where(and_(*conditions))
|
||||||
|
count_query = count_query.where(and_(*conditions))
|
||||||
|
|
||||||
|
# Load relationships if specified
|
||||||
|
if load_relations:
|
||||||
|
for relation in load_relations:
|
||||||
|
if hasattr(self.model, relation):
|
||||||
|
query = query.options(selectinload(getattr(self.model, relation)))
|
||||||
|
|
||||||
|
# Count total
|
||||||
|
total_result = await db.execute(count_query)
|
||||||
|
total = total_result.scalar() or 0
|
||||||
|
|
||||||
|
# Apply ordering
|
||||||
|
if order_by and hasattr(self.model, order_by):
|
||||||
|
order_column = getattr(self.model, order_by)
|
||||||
|
if descending:
|
||||||
|
query = query.order_by(order_column.desc())
|
||||||
|
else:
|
||||||
|
query = query.order_by(order_column.asc())
|
||||||
|
|
||||||
|
# Apply pagination
|
||||||
|
query = query.offset(skip).limit(limit)
|
||||||
|
|
||||||
|
result = await db.execute(query)
|
||||||
|
items = result.scalars().all()
|
||||||
|
|
||||||
|
return list(items), total
|
||||||
|
|
||||||
|
async def create(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
obj_in: Any,
|
||||||
|
commit: bool = True
|
||||||
|
) -> ModelType:
|
||||||
|
"""Create new record"""
|
||||||
|
obj_data = obj_in.model_dump() if hasattr(obj_in, 'model_dump') else obj_in.dict()
|
||||||
|
db_obj = self.model(**obj_data)
|
||||||
|
db.add(db_obj)
|
||||||
|
|
||||||
|
if commit:
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(db_obj)
|
||||||
|
self.log.debug(f"Created {self.model.__name__}: {db_obj.id}")
|
||||||
|
|
||||||
|
return db_obj
|
||||||
|
|
||||||
|
async def update(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
db_obj: ModelType,
|
||||||
|
obj_in: Any,
|
||||||
|
commit: bool = True
|
||||||
|
) -> ModelType:
|
||||||
|
"""Update existing record"""
|
||||||
|
if hasattr(obj_in, 'model_dump'):
|
||||||
|
obj_data = obj_in.model_dump(exclude_unset=True)
|
||||||
|
elif hasattr(obj_in, 'dict'):
|
||||||
|
obj_data = obj_in.dict(exclude_unset=True)
|
||||||
|
else:
|
||||||
|
obj_data = obj_in
|
||||||
|
|
||||||
|
for field, value in obj_data.items():
|
||||||
|
if hasattr(db_obj, field):
|
||||||
|
setattr(db_obj, field, value)
|
||||||
|
|
||||||
|
if commit:
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(db_obj)
|
||||||
|
self.log.debug(f"Updated {self.model.__name__}: {db_obj.id}")
|
||||||
|
|
||||||
|
return db_obj
|
||||||
|
|
||||||
|
async def delete(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
id: UUID,
|
||||||
|
commit: bool = True
|
||||||
|
) -> bool:
|
||||||
|
"""Delete record by ID"""
|
||||||
|
obj = await self.get(db, id)
|
||||||
|
if obj:
|
||||||
|
await db.delete(obj)
|
||||||
|
if commit:
|
||||||
|
await db.commit()
|
||||||
|
self.log.debug(f"Deleted {self.model.__name__}: {id}")
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def exists(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
filters: dict
|
||||||
|
) -> bool:
|
||||||
|
"""Check if record exists"""
|
||||||
|
query = select(func.count()).select_from(self.model)
|
||||||
|
for key, value in filters.items():
|
||||||
|
if hasattr(self.model, key):
|
||||||
|
query = query.where(getattr(self.model, key) == value)
|
||||||
|
|
||||||
|
result = await db.execute(query)
|
||||||
|
count = result.scalar() or 0
|
||||||
|
return count > 0
|
||||||
119
backend/app/core/exceptions.py
Normal file
119
backend/app/core/exceptions.py
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
"""
|
||||||
|
Custom Exceptions
|
||||||
|
自定义异常类
|
||||||
|
"""
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class AppException(Exception):
|
||||||
|
"""基础应用异常"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
code: str = "INTERNAL_ERROR",
|
||||||
|
status_code: int = 500,
|
||||||
|
details: Optional[dict] = None
|
||||||
|
):
|
||||||
|
self.message = message
|
||||||
|
self.code = code
|
||||||
|
self.status_code = status_code
|
||||||
|
self.details = details
|
||||||
|
super().__init__(self.message)
|
||||||
|
|
||||||
|
|
||||||
|
class NotFoundException(AppException):
|
||||||
|
"""资源未找到异常"""
|
||||||
|
|
||||||
|
def __init__(self, resource: str, resource_id: Any = None):
|
||||||
|
message = f"{resource} not found"
|
||||||
|
if resource_id:
|
||||||
|
message = f"{resource} with id '{resource_id}' not found"
|
||||||
|
super().__init__(
|
||||||
|
message=message,
|
||||||
|
code="NOT_FOUND",
|
||||||
|
status_code=404
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationException(AppException):
|
||||||
|
"""验证异常"""
|
||||||
|
|
||||||
|
def __init__(self, message: str, field: str = None, details: dict = None):
|
||||||
|
super().__init__(
|
||||||
|
message=message,
|
||||||
|
code="VALIDATION_ERROR",
|
||||||
|
status_code=422,
|
||||||
|
details={"field": field, **(details or {})}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DuplicateException(AppException):
|
||||||
|
"""重复资源异常"""
|
||||||
|
|
||||||
|
def __init__(self, resource: str, field: str = None):
|
||||||
|
message = f"{resource} already exists"
|
||||||
|
if field:
|
||||||
|
message = f"{resource} with {field} already exists"
|
||||||
|
super().__init__(
|
||||||
|
message=message,
|
||||||
|
code="DUPLICATE",
|
||||||
|
status_code=409
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UnauthorizedException(AppException):
|
||||||
|
"""未授权异常"""
|
||||||
|
|
||||||
|
def __init__(self, message: str = "Unauthorized"):
|
||||||
|
super().__init__(
|
||||||
|
message=message,
|
||||||
|
code="UNAUTHORIZED",
|
||||||
|
status_code=401
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ForbiddenException(AppException):
|
||||||
|
"""禁止访问异常"""
|
||||||
|
|
||||||
|
def __init__(self, message: str = "Forbidden"):
|
||||||
|
super().__init__(
|
||||||
|
message=message,
|
||||||
|
code="FORBIDDEN",
|
||||||
|
status_code=403
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimitException(AppException):
|
||||||
|
"""速率限制异常"""
|
||||||
|
|
||||||
|
def __init__(self, message: str = "Rate limit exceeded"):
|
||||||
|
super().__init__(
|
||||||
|
message=message,
|
||||||
|
code="RATE_LIMIT",
|
||||||
|
status_code=429
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FileProcessingException(AppException):
|
||||||
|
"""文件处理异常"""
|
||||||
|
|
||||||
|
def __init__(self, message: str, file_name: str = None):
|
||||||
|
details = {"file_name": file_name} if file_name else None
|
||||||
|
super().__init__(
|
||||||
|
message=message,
|
||||||
|
code="FILE_PROCESSING_ERROR",
|
||||||
|
status_code=422,
|
||||||
|
details=details
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseException(AppException):
|
||||||
|
"""数据库异常"""
|
||||||
|
|
||||||
|
def __init__(self, message: str = "Database operation failed"):
|
||||||
|
super().__init__(
|
||||||
|
message=message,
|
||||||
|
code="DATABASE_ERROR",
|
||||||
|
status_code=500
|
||||||
|
)
|
||||||
66
backend/app/core/logging.py
Normal file
66
backend/app/core/logging.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
"""
|
||||||
|
Logging Configuration
|
||||||
|
日志配置
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from typing import Any
|
||||||
|
from logging.handlers import RotatingFileHandler
|
||||||
|
from pathlib import Path
|
||||||
|
from app.core.config import get_settings
|
||||||
|
|
||||||
|
settings = get_settings()
|
||||||
|
|
||||||
|
# Log directory
|
||||||
|
LOG_DIR = Path("./logs")
|
||||||
|
LOG_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging(name: str = "yg_dataset") -> logging.Logger:
|
||||||
|
"""Setup application logging"""
|
||||||
|
logger = logging.getLogger(name)
|
||||||
|
logger.setLevel(logging.DEBUG if settings.DEBUG else logging.INFO)
|
||||||
|
|
||||||
|
# Avoid duplicate handlers
|
||||||
|
if logger.handlers:
|
||||||
|
return logger
|
||||||
|
|
||||||
|
# Console handler
|
||||||
|
console_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
console_handler.setLevel(logging.DEBUG if settings.DEBUG else logging.INFO)
|
||||||
|
console_formatter = logging.Formatter(
|
||||||
|
fmt="%(asctime)s | %(levelname)-8s | %(name)s:%(funcName)s:%(lineno)d | %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S"
|
||||||
|
)
|
||||||
|
console_handler.setFormatter(console_formatter)
|
||||||
|
logger.addHandler(console_handler)
|
||||||
|
|
||||||
|
# File handler
|
||||||
|
file_handler = RotatingFileHandler(
|
||||||
|
LOG_DIR / f"{name}.log",
|
||||||
|
maxBytes=10 * 1024 * 1024, # 10MB
|
||||||
|
backupCount=5,
|
||||||
|
encoding="utf-8"
|
||||||
|
)
|
||||||
|
file_handler.setLevel(logging.INFO)
|
||||||
|
file_formatter = logging.Formatter(
|
||||||
|
fmt="%(asctime)s | %(levelname)-8s | %(name)s:%(funcName)s:%(lineno)d | %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S"
|
||||||
|
)
|
||||||
|
file_handler.setFormatter(file_formatter)
|
||||||
|
logger.addHandler(file_handler)
|
||||||
|
|
||||||
|
return logger
|
||||||
|
|
||||||
|
|
||||||
|
# Create default logger
|
||||||
|
logger = setup_logging()
|
||||||
|
|
||||||
|
|
||||||
|
class LoggerMixin:
|
||||||
|
"""Mixin to add logging capability to classes"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def log(self) -> logging.Logger:
|
||||||
|
"""Get logger for this class"""
|
||||||
|
return logging.getLogger(self.__class__.__module__ + "." + self.__class__.__name__)
|
||||||
Reference in New Issue
Block a user