"""
Database Configuration and Session Management

Supports SQLite and PostgreSQL.
"""

import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator

from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy import create_engine, event
from sqlalchemy.pool import NullPool

from app.core.config import get_settings

logger = logging.getLogger(__name__)
settings = get_settings()


def get_engine_config() -> dict:
    """Return keyword arguments for ``create_async_engine`` based on the database type.

    A ``DATABASE_URL`` starting with ``"sqlite"`` gets ``NullPool`` (no
    connection pooling). Any other URL gets an explicit, bounded pool
    configuration.

    Returns:
        A dict of engine keyword arguments; ``echo`` follows ``settings.DEBUG``.
    """
    if settings.DATABASE_URL.startswith("sqlite"):
        return {"echo": settings.DEBUG, "poolclass": NullPool}
    return {
        "echo": settings.DEBUG,
        "pool_pre_ping": True,  # test a connection before handing it out
        "pool_size": 10,  # connections kept open in the pool
        "max_overflow": 20,  # extra connections allowed beyond pool_size
        "pool_recycle": 3600,  # seconds; replace connections older than one hour
        "pool_timeout": 30,  # seconds to wait for a free connection
    }


# Async engine for FastAPI
async_engine = create_async_engine(
    settings.DATABASE_URL,
    **get_engine_config(),
)

# Sync engine for migrations (NullPool for SQLite; poolclass=None lets
# SQLAlchemy pick the dialect's default pool otherwise).
sync_engine = create_engine(
    settings.DATABASE_URL_SYNC,
    echo=settings.DEBUG,
    pool_pre_ping=True,
    poolclass=NullPool if settings.DATABASE_URL_SYNC.startswith("sqlite") else None,
)

# Async session factory
AsyncSessionLocal = async_sessionmaker(
    async_engine,
    class_=AsyncSession,
    expire_on_commit=False,  # keep ORM objects readable after commit
    autocommit=False,
    autoflush=False,  # flush only when explicitly requested or on commit
)


class Base(DeclarativeBase):
    """Base class for all models"""

    pass


async def init_db():
    """Create all tables registered on ``Base.metadata`` (existing tables are left as-is)."""
    logger.info("Initializing database...")
    async with async_engine.begin() as conn:
        # create_all is a sync API; run it on the async connection.
        await conn.run_sync(Base.metadata.create_all)
    logger.info("Database initialized successfully")


async def close_db():
    """Dispose of the async engine, closing its pooled connections."""
    logger.info("Closing database connections...")
    await async_engine.dispose()
    logger.info("Database connections closed")


@asynccontextmanager
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """Provide a session as an async context manager with automatic cleanup.

    If the body of the ``async with`` raises, the error is logged with its
    traceback, the session is rolled back, and the exception is re-raised.
    The session is always closed on exit. Committing is left to the caller.
    """
    session = AsyncSessionLocal()
    try:
        yield session
    except Exception as e:
        logger.exception("Database session error: %s", e)
        await session.rollback()
        raise
    finally:
        await session.close()


async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Dependency that yields a database session.

    On an exception propagated back into the generator, the error is logged
    with its traceback, the session is rolled back, and the exception is
    re-raised. Committing is left to the caller.
    """
    # ``async with`` closes the session on exit, so no explicit close() is needed.
    async with AsyncSessionLocal() as session:
        try:
            yield session
        except Exception as e:
            logger.exception("Database error in dependency: %s", e)
            await session.rollback()
            raise