"""
Database configuration and session management.

Supports SQLite and PostgreSQL.
"""
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator

from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy import create_engine, event, inspect, text
from sqlalchemy.pool import NullPool

from app.core.config import get_settings

logger = logging.getLogger(__name__)
settings = get_settings()


def _is_sqlite(url: str) -> bool:
    """Return True when *url* is a SQLite URL (any driver, e.g. ``sqlite+aiosqlite``)."""
    return url.startswith("sqlite")


def get_engine_config() -> dict:
    """Return keyword arguments for ``create_async_engine`` based on the database type.

    SQLite URLs get ``NullPool`` (no connection pooling). Any other backend gets
    a bounded connection pool with pre-ping and periodic connection recycling.
    """
    if _is_sqlite(settings.DATABASE_URL):
        return {"echo": settings.DEBUG, "poolclass": NullPool}
    return {
        "echo": settings.DEBUG,
        # Test each connection on checkout so stale connections are replaced.
        "pool_pre_ping": True,
        "pool_size": 10,
        "max_overflow": 20,
        # Recycle connections older than one hour (seconds).
        "pool_recycle": 3600,
        # Seconds to wait for a free pooled connection before raising.
        "pool_timeout": 30,
    }


# Async engine for FastAPI
async_engine = create_async_engine(
    settings.DATABASE_URL,
    **get_engine_config(),
)

# Sync engine for migrations (NullPool for SQLite; poolclass=None lets
# SQLAlchemy pick the dialect's default pool for other backends).
sync_engine = create_engine(
    settings.DATABASE_URL_SYNC,
    echo=settings.DEBUG,
    pool_pre_ping=True,
    poolclass=NullPool if _is_sqlite(settings.DATABASE_URL_SYNC) else None,
)

# Async session factory. Objects stay usable after commit (expire_on_commit=False),
# and nothing is flushed implicitly before queries (autoflush=False).
AsyncSessionLocal = async_sessionmaker(
    async_engine,
    class_=AsyncSession,
    expire_on_commit=False,
    autocommit=False,
    autoflush=False,
)


class Base(DeclarativeBase):
    """Declarative base class shared by all ORM models."""


async def init_db():
    """Create all tables registered on ``Base.metadata`` and patch legacy schemas.

    ``create_all`` only creates tables that do not exist yet. Columns added to
    existing tables are handled by ``_ensure_legacy_columns``.
    """
    logger.info("Initializing database...")
    async with async_engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
        await conn.run_sync(_ensure_legacy_columns)
    logger.info("Database initialized successfully")


def _ensure_legacy_columns(sync_conn) -> None:
    """Add columns that were introduced after a table was first created.

    Currently this adds ``model_configs.model_type`` when the table exists but
    the column does not. It is a no-op if the table is missing or already patched.

    Args:
        sync_conn: Synchronous connection passed in by ``AsyncConnection.run_sync``.
    """
    inspector = inspect(sync_conn)
    if "model_configs" not in inspector.get_table_names():
        return

    columns = {column["name"] for column in inspector.get_columns("model_configs")}
    if "model_type" in columns:
        return

    logger.info("Adding missing model_type column to model_configs table")
    # The same DDL works on both SQLite and PostgreSQL. SQLite permits NOT NULL
    # on ADD COLUMN because a non-NULL default is supplied. No per-dialect
    # branch is needed.
    sync_conn.execute(
        text("ALTER TABLE model_configs ADD COLUMN model_type VARCHAR(50) NOT NULL DEFAULT 'chat'")
    )


async def close_db():
    """Dispose of the async engine's connection pool."""
    logger.info("Closing database connections...")
    await async_engine.dispose()
    logger.info("Database connections closed")


@asynccontextmanager
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """Provide a session for ``async with`` blocks outside request handling.

    If the body raises, the error is logged, the transaction is rolled back,
    and the exception is re-raised. The session is always closed on exit by
    the ``async with`` around it.
    """
    async with AsyncSessionLocal() as session:
        try:
            yield session
        except Exception as e:
            logger.error("Database session error: %s", e)
            await session.rollback()
            raise


async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Dependency that yields a database session.

    Errors raised while the session is in use are logged, rolled back, and
    re-raised. ``async with`` closes the session afterwards, so no explicit
    ``close()`` is needed.
    """
    async with AsyncSessionLocal() as session:
        try:
            yield session
        except Exception as e:
            logger.error("Database error in dependency: %s", e)
            await session.rollback()
            raise


# Import all models to register them with Base.metadata
# This ensures all models are loaded before create_all is called.
# NOTE(review): kept at the bottom, presumably because the models import Base
# from this module (circular import); confirm before moving it.
from app.models.models import *  # noqa: F401, F403, E402