feat(backend): 添加核心架构模块
- 添加认证模块 (auth.py) - 添加 CRUD 基础操作 (crud.py) - 添加异常处理 (exceptions.py) - 添加日志模块 (logging.py) - 添加响应格式 (response.py) - 添加依赖注入 (dependencies.py) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
20
backend/app/api/dependencies.py
Normal file
20
backend/app/api/dependencies.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""
|
||||
API Dependencies
|
||||
API 依赖项
|
||||
"""
|
||||
from typing import Annotated
|
||||
from fastapi import Depends
|
||||
from app.core.auth import verify_api_key
|
||||
|
||||
|
||||
# Type alias for API key dependency
|
||||
ApiKey = Annotated[str, Depends(verify_api_key)]
|
||||
|
||||
|
||||
# Optional API key (for endpoints that can work with or without auth)
|
||||
async def get_optional_api_key(api_key: str = None) -> Optional[str]:
|
||||
"""Get optional API key"""
|
||||
return api_key
|
||||
|
||||
|
||||
OptionalApiKey = Annotated[Optional[str], Depends(get_optional_api_key)]
|
||||
75
backend/app/api/response.py
Normal file
75
backend/app/api/response.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
API Response Wrapper
|
||||
统一 API 响应格式
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Any, Generic, List, Optional, TypeVar
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ApiResponse(BaseModel, Generic[T]):
|
||||
"""统一 API 响应格式"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
success: bool = True
|
||||
message: str = "Success"
|
||||
data: Optional[T] = None
|
||||
error: Optional[dict] = None
|
||||
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
@classmethod
|
||||
def ok(cls, data: T = None, message: str = "Success") -> "ApiResponse[T]":
|
||||
"""成功响应"""
|
||||
return cls(success=True, message=message, data=data)
|
||||
|
||||
@classmethod
|
||||
def fail(cls, message: str, error: dict = None) -> "ApiResponse[None]":
|
||||
"""失败响应"""
|
||||
return cls(success=False, message=message, error=error)
|
||||
|
||||
|
||||
class PaginatedResponse(BaseModel, Generic[T]):
|
||||
"""分页响应格式"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
success: bool = True
|
||||
message: str = "Success"
|
||||
data: List[T] = []
|
||||
pagination: dict = Field(default_factory=lambda: {
|
||||
"page": 1,
|
||||
"page_size": 20,
|
||||
"total": 0,
|
||||
"total_pages": 0
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def ok(
|
||||
cls,
|
||||
items: List[T],
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
total: int = 0
|
||||
) -> "PaginatedResponse[T]":
|
||||
"""创建分页响应"""
|
||||
total_pages = (total + page_size - 1) // page_size if page_size > 0 else 0
|
||||
return cls(
|
||||
success=True,
|
||||
data=items,
|
||||
pagination={
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"total": total,
|
||||
"total_pages": total_pages
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ErrorDetail(BaseModel):
|
||||
"""错误详情"""
|
||||
code: str
|
||||
message: str
|
||||
details: Optional[dict] = None
|
||||
field: Optional[str] = None
|
||||
38
backend/app/core/auth.py
Normal file
38
backend/app/core/auth.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
API Key Authentication
|
||||
API Key 认证中间件
|
||||
"""
|
||||
from typing import Optional
|
||||
from fastapi import Header, HTTPException, Request
|
||||
from fastapi.security import APIKeyHeader
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
# API Key header
|
||||
API_KEY_HEADER = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
|
||||
|
||||
async def verify_api_key(api_key: Optional[str] = Header(None)) -> str:
|
||||
"""Verify API key from header"""
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=401, detail="API key is required")
|
||||
|
||||
# In production, you would validate against a database or cache
|
||||
# For development, we can use a simple validation
|
||||
if settings.DEBUG and api_key == "dev-api-key":
|
||||
return api_key
|
||||
|
||||
# TODO: Implement proper API key validation
|
||||
# This is a placeholder - in production, validate against stored keys
|
||||
if len(api_key) < 32:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
return api_key
|
||||
|
||||
|
||||
def create_api_key() -> str:
|
||||
"""Generate a new API key"""
|
||||
import secrets
|
||||
return secrets.token_hex(32)
|
||||
178
backend/app/core/crud.py
Normal file
178
backend/app/core/crud.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
Database CRUD Operations
|
||||
数据库通用 CRUD 操作
|
||||
"""
|
||||
from typing import Any, Generic, List, Optional, Type, TypeVar
|
||||
from uuid import UUID
|
||||
from sqlalchemy import select, func, and_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload, joinedload
|
||||
|
||||
from app.core.exceptions import NotFoundException, DuplicateException
|
||||
from app.core.logging import LoggerMixin
|
||||
|
||||
ModelType = TypeVar("ModelType")
|
||||
|
||||
|
||||
class CRUDBase(Generic[ModelType], LoggerMixin):
|
||||
"""Base CRUD class with common operations"""
|
||||
|
||||
def __init__(self, model: Type[ModelType]):
|
||||
"""Initialize CRUD with model"""
|
||||
self.model = model
|
||||
|
||||
async def get(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
id: UUID,
|
||||
load_relations: Optional[List[str]] = None
|
||||
) -> Optional[ModelType]:
|
||||
"""Get single record by ID"""
|
||||
query = select(self.model).where(self.model.id == id)
|
||||
|
||||
# Load relationships if specified
|
||||
if load_relations:
|
||||
for relation in load_relations:
|
||||
if hasattr(self.model, relation):
|
||||
query = query.options(selectinload(getattr(self.model, relation)))
|
||||
|
||||
result = await db.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_or_raise(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
id: UUID,
|
||||
resource_name: str = "Resource",
|
||||
load_relations: Optional[List[str]] = None
|
||||
) -> ModelType:
|
||||
"""Get single record by ID or raise NotFoundException"""
|
||||
obj = await self.get(db, id, load_relations)
|
||||
if not obj:
|
||||
raise NotFoundException(resource_name, id)
|
||||
return obj
|
||||
|
||||
async def get_multi(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
skip: int = 0,
|
||||
limit: int = 20,
|
||||
load_relations: Optional[List[str]] = None,
|
||||
filters: Optional[dict] = None,
|
||||
order_by: Optional[str] = "created_at",
|
||||
descending: bool = True
|
||||
) -> tuple[List[ModelType], int]:
|
||||
"""Get multiple records with pagination"""
|
||||
query = select(self.model)
|
||||
count_query = select(func.count()).select_from(self.model)
|
||||
|
||||
# Apply filters
|
||||
if filters:
|
||||
conditions = []
|
||||
for key, value in filters.items():
|
||||
if hasattr(self.model, key):
|
||||
conditions.append(getattr(self.model, key) == value)
|
||||
if conditions:
|
||||
query = query.where(and_(*conditions))
|
||||
count_query = count_query.where(and_(*conditions))
|
||||
|
||||
# Load relationships if specified
|
||||
if load_relations:
|
||||
for relation in load_relations:
|
||||
if hasattr(self.model, relation):
|
||||
query = query.options(selectinload(getattr(self.model, relation)))
|
||||
|
||||
# Count total
|
||||
total_result = await db.execute(count_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
# Apply ordering
|
||||
if order_by and hasattr(self.model, order_by):
|
||||
order_column = getattr(self.model, order_by)
|
||||
if descending:
|
||||
query = query.order_by(order_column.desc())
|
||||
else:
|
||||
query = query.order_by(order_column.asc())
|
||||
|
||||
# Apply pagination
|
||||
query = query.offset(skip).limit(limit)
|
||||
|
||||
result = await db.execute(query)
|
||||
items = result.scalars().all()
|
||||
|
||||
return list(items), total
|
||||
|
||||
async def create(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
obj_in: Any,
|
||||
commit: bool = True
|
||||
) -> ModelType:
|
||||
"""Create new record"""
|
||||
obj_data = obj_in.model_dump() if hasattr(obj_in, 'model_dump') else obj_in.dict()
|
||||
db_obj = self.model(**obj_data)
|
||||
db.add(db_obj)
|
||||
|
||||
if commit:
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
self.log.debug(f"Created {self.model.__name__}: {db_obj.id}")
|
||||
|
||||
return db_obj
|
||||
|
||||
async def update(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
db_obj: ModelType,
|
||||
obj_in: Any,
|
||||
commit: bool = True
|
||||
) -> ModelType:
|
||||
"""Update existing record"""
|
||||
if hasattr(obj_in, 'model_dump'):
|
||||
obj_data = obj_in.model_dump(exclude_unset=True)
|
||||
elif hasattr(obj_in, 'dict'):
|
||||
obj_data = obj_in.dict(exclude_unset=True)
|
||||
else:
|
||||
obj_data = obj_in
|
||||
|
||||
for field, value in obj_data.items():
|
||||
if hasattr(db_obj, field):
|
||||
setattr(db_obj, field, value)
|
||||
|
||||
if commit:
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
self.log.debug(f"Updated {self.model.__name__}: {db_obj.id}")
|
||||
|
||||
return db_obj
|
||||
|
||||
async def delete(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
id: UUID,
|
||||
commit: bool = True
|
||||
) -> bool:
|
||||
"""Delete record by ID"""
|
||||
obj = await self.get(db, id)
|
||||
if obj:
|
||||
await db.delete(obj)
|
||||
if commit:
|
||||
await db.commit()
|
||||
self.log.debug(f"Deleted {self.model.__name__}: {id}")
|
||||
return True
|
||||
return False
|
||||
|
||||
async def exists(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
filters: dict
|
||||
) -> bool:
|
||||
"""Check if record exists"""
|
||||
query = select(func.count()).select_from(self.model)
|
||||
for key, value in filters.items():
|
||||
if hasattr(self.model, key):
|
||||
query = query.where(getattr(self.model, key) == value)
|
||||
|
||||
result = await db.execute(query)
|
||||
count = result.scalar() or 0
|
||||
return count > 0
|
||||
119
backend/app/core/exceptions.py
Normal file
119
backend/app/core/exceptions.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""
|
||||
Custom Exceptions
|
||||
自定义异常类
|
||||
"""
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class AppException(Exception):
|
||||
"""基础应用异常"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
code: str = "INTERNAL_ERROR",
|
||||
status_code: int = 500,
|
||||
details: Optional[dict] = None
|
||||
):
|
||||
self.message = message
|
||||
self.code = code
|
||||
self.status_code = status_code
|
||||
self.details = details
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class NotFoundException(AppException):
|
||||
"""资源未找到异常"""
|
||||
|
||||
def __init__(self, resource: str, resource_id: Any = None):
|
||||
message = f"{resource} not found"
|
||||
if resource_id:
|
||||
message = f"{resource} with id '{resource_id}' not found"
|
||||
super().__init__(
|
||||
message=message,
|
||||
code="NOT_FOUND",
|
||||
status_code=404
|
||||
)
|
||||
|
||||
|
||||
class ValidationException(AppException):
|
||||
"""验证异常"""
|
||||
|
||||
def __init__(self, message: str, field: str = None, details: dict = None):
|
||||
super().__init__(
|
||||
message=message,
|
||||
code="VALIDATION_ERROR",
|
||||
status_code=422,
|
||||
details={"field": field, **(details or {})}
|
||||
)
|
||||
|
||||
|
||||
class DuplicateException(AppException):
|
||||
"""重复资源异常"""
|
||||
|
||||
def __init__(self, resource: str, field: str = None):
|
||||
message = f"{resource} already exists"
|
||||
if field:
|
||||
message = f"{resource} with {field} already exists"
|
||||
super().__init__(
|
||||
message=message,
|
||||
code="DUPLICATE",
|
||||
status_code=409
|
||||
)
|
||||
|
||||
|
||||
class UnauthorizedException(AppException):
|
||||
"""未授权异常"""
|
||||
|
||||
def __init__(self, message: str = "Unauthorized"):
|
||||
super().__init__(
|
||||
message=message,
|
||||
code="UNAUTHORIZED",
|
||||
status_code=401
|
||||
)
|
||||
|
||||
|
||||
class ForbiddenException(AppException):
|
||||
"""禁止访问异常"""
|
||||
|
||||
def __init__(self, message: str = "Forbidden"):
|
||||
super().__init__(
|
||||
message=message,
|
||||
code="FORBIDDEN",
|
||||
status_code=403
|
||||
)
|
||||
|
||||
|
||||
class RateLimitException(AppException):
|
||||
"""速率限制异常"""
|
||||
|
||||
def __init__(self, message: str = "Rate limit exceeded"):
|
||||
super().__init__(
|
||||
message=message,
|
||||
code="RATE_LIMIT",
|
||||
status_code=429
|
||||
)
|
||||
|
||||
|
||||
class FileProcessingException(AppException):
|
||||
"""文件处理异常"""
|
||||
|
||||
def __init__(self, message: str, file_name: str = None):
|
||||
details = {"file_name": file_name} if file_name else None
|
||||
super().__init__(
|
||||
message=message,
|
||||
code="FILE_PROCESSING_ERROR",
|
||||
status_code=422,
|
||||
details=details
|
||||
)
|
||||
|
||||
|
||||
class DatabaseException(AppException):
|
||||
"""数据库异常"""
|
||||
|
||||
def __init__(self, message: str = "Database operation failed"):
|
||||
super().__init__(
|
||||
message=message,
|
||||
code="DATABASE_ERROR",
|
||||
status_code=500
|
||||
)
|
||||
66
backend/app/core/logging.py
Normal file
66
backend/app/core/logging.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
Logging Configuration
|
||||
日志配置
|
||||
"""
|
||||
import logging
|
||||
import sys
|
||||
from typing import Any
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
from app.core.config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
# Log directory
|
||||
LOG_DIR = Path("./logs")
|
||||
LOG_DIR.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
def setup_logging(name: str = "yg_dataset") -> logging.Logger:
|
||||
"""Setup application logging"""
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(logging.DEBUG if settings.DEBUG else logging.INFO)
|
||||
|
||||
# Avoid duplicate handlers
|
||||
if logger.handlers:
|
||||
return logger
|
||||
|
||||
# Console handler
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setLevel(logging.DEBUG if settings.DEBUG else logging.INFO)
|
||||
console_formatter = logging.Formatter(
|
||||
fmt="%(asctime)s | %(levelname)-8s | %(name)s:%(funcName)s:%(lineno)d | %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
console_handler.setFormatter(console_formatter)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
# File handler
|
||||
file_handler = RotatingFileHandler(
|
||||
LOG_DIR / f"{name}.log",
|
||||
maxBytes=10 * 1024 * 1024, # 10MB
|
||||
backupCount=5,
|
||||
encoding="utf-8"
|
||||
)
|
||||
file_handler.setLevel(logging.INFO)
|
||||
file_formatter = logging.Formatter(
|
||||
fmt="%(asctime)s | %(levelname)-8s | %(name)s:%(funcName)s:%(lineno)d | %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
file_handler.setFormatter(file_formatter)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
# Create default logger
|
||||
logger = setup_logging()
|
||||
|
||||
|
||||
class LoggerMixin:
|
||||
"""Mixin to add logging capability to classes"""
|
||||
|
||||
@property
|
||||
def log(self) -> logging.Logger:
|
||||
"""Get logger for this class"""
|
||||
return logging.getLogger(self.__class__.__module__ + "." + self.__class__.__name__)
|
||||
Reference in New Issue
Block a user