Files
YG-Datasets/backend/app/main.py
Developer 7514e7e763 feat: 完善模型管理功能
- 新增模型 API 路由,支持 CRUD 和测试连接
- 支持 MiniMax、GLM、OpenAI Compatible 三种供应商
- 添加连接状态持久化 (untested/connected/disconnected)
- 修复 CORS 和数据库模型兼容性问题
- 前端 UI 优化:供应商默认 API 地址自动填充

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-17 23:02:43 +08:00

181 lines
5.2 KiB
Python

"""
YG-Dataset Backend Application
FastAPI-based API server for dataset generation platform
"""
import logging
import time
import uuid
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from starlette.middleware.base import BaseHTTPMiddleware
from sqlalchemy.exc import SQLAlchemyError
from app.api.v1 import api_router
from app.api.response import ApiResponse
from app.core.config import settings
from app.core.database import init_db, close_db
from app.core.exceptions import AppException
from app.core.logging import logger
# Import all models to register them with Base.metadata
from app.models.models import * # noqa: F401, F403
class RequestIDMiddleware(BaseHTTPMiddleware):
    """Tag every request/response pair with a freshly generated UUID4.

    The identifier is stored on ``request.state.request_id`` before the rest
    of the stack runs, and is echoed back to the client in the
    ``X-Request-ID`` response header.
    """

    async def dispatch(self, request: Request, call_next):
        """Assign the request ID, run the downstream stack, stamp the response."""
        rid = str(uuid.uuid4())
        # Set the ID before delegating so downstream code can read it.
        request.state.request_id = rid

        outgoing = await call_next(request)
        outgoing.headers["X-Request-ID"] = rid
        return outgoing
class TimingMiddleware(BaseHTTPMiddleware):
    """Log each request and report how long it took to process.

    One log line is written when the request arrives and another once the
    downstream stack has produced a response. The elapsed time, in seconds,
    is also exposed to the client in the ``X-Process-Time`` response header.
    """

    async def dispatch(self, request: Request, call_next):
        """Time the downstream call and annotate the response with the result."""
        # perf_counter() is monotonic, so the measured duration cannot be
        # skewed (or go negative) when the system wall clock is adjusted,
        # which time.time() does not guarantee.
        started_at = time.perf_counter()
        # Log request
        logger.info(f"{request.method} {request.url.path}")

        response = await call_next(request)

        process_time = time.perf_counter() - started_at
        response.headers["X-Process-Time"] = str(process_time)
        # Log response
        logger.info(f"{request.method} {request.url.path} | Status: {response.status_code} | Time: {process_time:.3f}s")
        return response
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan events.

    Startup: awaits ``init_db()`` before the application starts serving.
    Shutdown: awaits ``close_db()`` once the application stops.

    The shutdown half runs inside ``finally`` so that ``close_db()`` is still
    awaited when an exception or cancellation is thrown into the context
    manager at the ``yield`` point instead of a clean exit.
    """
    # Startup
    logger.info("Starting YG-Dataset application...")
    await init_db()
    logger.info("Database initialized successfully")
    try:
        yield
    finally:
        # Shutdown
        logger.info("Shutting down YG-Dataset application...")
        await close_db()
        logger.info("Database connections closed")
# ASGI application instance. `lifespan` (defined above) supplies the
# startup/shutdown hooks. The version string is also hard-coded in the
# /health payload further down; keep the two in sync when bumping it.
app = FastAPI(
    title="YG-Dataset API",
    description="Dataset Generation Platform API",
    version="1.0.0",
    lifespan=lifespan,
)
# Add custom middleware (order matters: last added = first executed)
# Given that ordering, RequestIDMiddleware wraps TimingMiddleware, and the
# CORS middleware registered last below is the outermost of the three.
app.add_middleware(TimingMiddleware)
app.add_middleware(RequestIDMiddleware)
# CORS - Configure properly for production
# For development, you can use ["*"] but for production, specify exact origins
# NOTE(review): a wildcard origin is combined with allow_credentials=True.
# The FastAPI CORS documentation says origins must be listed explicitly when
# credentials are allowed - confirm this is a development-only setting and
# replace "*" with the real frontend origin(s) before deploying.
# NOTE(review): PATCH is not in allow_methods - verify no route relies on it.
ALLOWED_ORIGINS = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
    allow_headers=["*"],
)
# Exception handlers
@app.exception_handler(AppException)
async def app_exception_handler(request: Request, exc: AppException):
    """Convert an ``AppException`` into a JSON failure response.

    The HTTP status comes from ``exc.status_code``. The body is the dumped
    ``ApiResponse.fail(...)`` payload carrying ``exc.message`` and an
    ``error`` object built from ``exc.code`` and ``exc.details``.
    """
    logger.warning(f"App exception: {exc.message} | Code: {exc.code}")

    error_info = {"code": exc.code, "details": exc.details}
    failure = ApiResponse.fail(message=exc.message, error=error_info)
    return JSONResponse(status_code=exc.status_code, content=failure.model_dump())
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Convert a request validation failure into a 422 JSON response.

    Each entry from ``exc.errors()`` is reduced to three keys: ``field`` (the
    ``loc`` components joined with dots), ``message`` and ``type``. The list
    is logged and returned under ``error.details.errors``.
    """
    field_errors = [
        {
            "field": ".".join(map(str, issue["loc"])),
            "message": issue["msg"],
            "type": issue["type"],
        }
        for issue in exc.errors()
    ]
    logger.warning(f"Validation error: {field_errors}")

    failure = ApiResponse.fail(
        message="Validation error",
        error={"code": "VALIDATION_ERROR", "details": {"errors": field_errors}},
    )
    return JSONResponse(status_code=422, content=failure.model_dump())
@app.exception_handler(SQLAlchemyError)
async def database_exception_handler(request: Request, exc: SQLAlchemyError):
    """Convert any ``SQLAlchemyError`` into a generic 500 JSON response.

    The exception text and traceback go to the log only; the response body
    carries a fixed message and the ``DATABASE_ERROR`` code.
    """
    logger.error(f"Database error: {exc!s}", exc_info=True)

    failure = ApiResponse.fail(
        message="Database operation failed",
        error={"code": "DATABASE_ERROR"},
    )
    return JSONResponse(status_code=500, content=failure.model_dump())
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: log the exception and answer with a generic 500.

    The exception text and traceback go to the log only; the response body
    carries a fixed message and the ``INTERNAL_ERROR`` code.
    """
    logger.error(f"Unhandled exception: {exc!s}", exc_info=True)

    failure = ApiResponse.fail(
        message="Internal server error",
        error={"code": "INTERNAL_ERROR"},
    )
    return JSONResponse(status_code=500, content=failure.model_dump())
# Include API routes
# Everything registered on `api_router` is mounted under the /api/v1 prefix.
app.include_router(api_router, prefix="/api/v1")
@app.get("/health")
async def health_check():
    """Health check endpoint returning a static status/version payload."""
    status_payload = {"status": "healthy", "version": "1.0.0"}
    return ApiResponse.ok(data=status_payload, message="Service is running")
if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn using the
    # host, port and reload flag read from `settings`.
    import uvicorn

    server_options = {
        "host": settings.HOST,
        "port": settings.PORT,
        "reload": settings.DEBUG,
    }
    uvicorn.run("app.main:app", **server_options)