feat(backend): 更新核心模块和文件处理

- 更新配置模块 (config.py)
- 更新数据库连接 (database.py)
- 更新主应用入口 (main.py)
- 更新数据模型 (models.py)
- 更新基础 Schema (base.py)
- 更新文件处理器 (docx, excel, pdf)
- 更新 Dockerfile

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Developer
2026-03-17 17:30:11 +08:00
parent db11429290
commit 47d1da7cea
10 changed files with 393 additions and 189 deletions

View File

@@ -1,27 +1,60 @@
FROM python:3.11-slim
# Multi-stage build for Python FastAPI application
# Stage 1: Base image
FROM python:3.11-slim as base
# Set environment variables
ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1 \
PYTHONFAULTHANDLER=1 \
PIP_NO_CACHE_DIR=1 \
PIP_DISABLE_PIP_VERSION_CHECK=1
WORKDIR /app
# Stage 2: Dependencies
FROM base as deps
# Install system dependencies
RUN apt-get update && apt-get install -y \
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
libpq-dev \
&& rm -rf /var/lib/apt/lists/*
# Copy requirements
COPY requirements.txt .
# Create virtual environment
RUN python -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# Install Python dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy application
COPY . .
# Create uploads directory
RUN mkdir -p uploads
# Stage 3: Production
FROM base
# Install system dependencies for production
RUN apt-get update && apt-get install -y --no-install-recommends \
libpq5 \
&& rm -rf /var/lib/apt/lists/*
# Copy virtual environment from deps stage
COPY --from=deps /opt/venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# Create the unprivileged account that --chown and USER below refer to;
# no earlier instruction in this Dockerfile creates an "app" user or group,
# and name-based --chown/USER need the entry to exist in the image.
RUN groupadd --system app && useradd --system --gid app --no-create-home app
# Copy application
COPY --chown=app:app . /app
RUN mkdir -p /app/uploads /app/logs && chown -R app:app /app
# Switch to non-root user
USER app
# Expose port
EXPOSE 8000
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
# Run application
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]

View File

@@ -4,7 +4,7 @@ Application Configuration
from functools import lru_cache
from pydantic_settings import BaseSettings
from pydantic import Field
from pydantic import Field, field_validator
class Settings(BaseSettings):
@@ -15,12 +15,16 @@ class Settings(BaseSettings):
DEBUG: bool = True
HOST: str = "0.0.0.0"
PORT: int = 8000
ALLOWED_ORIGINS: str = Field(
default="*",
description="Comma-separated list of allowed CORS origins"
)
# Database - 使用 SQLite 进行开发/测试
# 生产环境可切换为 PostgreSQL
DATABASE_URL: str = Field(
default="sqlite:///./ygdataset.db",
description="Database connection URL (sqlite:// or postgresql+asyncpg://)"
default="sqlite+aiosqlite:///./ygdataset.db",
description="Database connection URL (sqlite+aiosqlite:// or postgresql+asyncpg://)"
)
DATABASE_URL_SYNC: str = Field(
default="sqlite:///./ygdataset.db",
@@ -38,8 +42,31 @@ class Settings(BaseSettings):
DEFAULT_MODEL_PROVIDER: str = "openai"
DEFAULT_MODEL_NAME: str = "gpt-4o-mini"
# Security
SECRET_KEY: str = Field(
default="your-secret-key-change-in-production",
description="Secret key for JWT and other security operations"
)
API_KEY_HEADER: str = "X-API-Key"
# Pagination
DEFAULT_PAGE_SIZE: int = 20
MAX_PAGE_SIZE: int = 100
# Logging
LOG_LEVEL: str = "INFO"
@field_validator("MAX_FILE_SIZE")
@classmethod
def validate_max_file_size(cls, v: int) -> int:
    """Cap the configured maximum file size at 500 MB.

    Values above the ceiling are lowered to the ceiling without raising;
    every other value is returned unchanged.
    """
    ceiling = 500 * 1024 * 1024
    return ceiling if v > ceiling else v
class Config:
env_file = ".env"
env_file_encoding = "utf-8"
extra = "allow"
@@ -47,3 +74,7 @@ class Settings(BaseSettings):
def get_settings() -> Settings:
"""Get cached settings"""
return Settings()
# Create global settings instance
settings = get_settings()

View File

@@ -2,25 +2,32 @@
Database Configuration and Session Management
支持 SQLite 和 PostgreSQL
"""
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy import create_engine
from sqlalchemy import create_engine, event
from sqlalchemy.pool import NullPool
from app.core.config import get_settings
logger = logging.getLogger(__name__)
settings = get_settings()
def get_engine_config():
"""根据数据库类型返回引擎配置"""
if settings.DATABASE_URL.startswith("sqlite"):
return {"echo": settings.DEBUG}
return {"echo": settings.DEBUG, "poolclass": NullPool}
else:
return {
"echo": settings.DEBUG,
"pool_pre_ping": True,
"pool_size": 10,
"max_overflow": 20,
"pool_recycle": 3600,
"pool_timeout": 30,
}
@@ -30,14 +37,14 @@ async_engine = create_async_engine(
**get_engine_config()
)
# Sync engine for migrations
# Sync engine for migrations (use NullPool for SQLite)
sync_engine = create_engine(
settings.DATABASE_URL_SYNC,
echo=settings.DEBUG,
pool_pre_ping=True,
poolclass=NullPool if settings.DATABASE_URL_SYNC.startswith("sqlite") else None,
)
# Async session factory
AsyncSessionLocal = async_sessionmaker(
async_engine,
@@ -55,8 +62,31 @@ class Base(DeclarativeBase):
async def init_db():
    """Create the tables registered on ``Base.metadata``.

    The synchronous ``create_all`` is executed through the async engine
    inside an ``async_engine.begin()`` block, with a log line before and
    after.
    """
    logger.info("Initializing database...")
    create_tables = Base.metadata.create_all
    async with async_engine.begin() as connection:
        await connection.run_sync(create_tables)
    logger.info("Database initialized successfully")
async def close_db():
    """Dispose of the async engine, logging before and after the call."""
    logger.info("Closing database connections...")
    await async_engine.dispose()
    logger.info("Database connections closed")
@asynccontextmanager
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield a new session; roll back on error and always close it.

    An exception raised inside the ``async with`` body is logged, the
    session is rolled back, and the exception is re-raised to the caller.
    """
    db_session = AsyncSessionLocal()
    try:
        yield db_session
    except Exception as exc:
        logger.error(f"Database session error: {exc}")
        await db_session.rollback()
        raise
    finally:
        await db_session.close()
async def get_db() -> AsyncSession:
@@ -64,5 +94,9 @@ async def get_db() -> AsyncSession:
async with AsyncSessionLocal() as session:
try:
yield session
except Exception as e:
logger.error(f"Database error in dependency: {str(e)}")
await session.rollback()
raise
finally:
await session.close()

View File

@@ -3,23 +3,71 @@ YG-Dataset Backend Application
FastAPI-based API server for dataset generation platform
"""
import logging
import time
import uuid
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from starlette.middleware.base import BaseHTTPMiddleware
from sqlalchemy.exc import SQLAlchemyError
from app.api.v1 import api_router
from app.api.response import ApiResponse
from app.core.config import settings
from app.core.database import init_db
from app.core.database import init_db, close_db
from app.core.exceptions import AppException
from app.core.logging import logger
class RequestIDMiddleware(BaseHTTPMiddleware):
    """Tag every request with a freshly generated UUID4 identifier.

    The identifier is stored on ``request.state.request_id`` before the
    rest of the stack runs, and is echoed back in the ``X-Request-ID``
    response header.
    """

    async def dispatch(self, request: Request, call_next):
        rid = str(uuid.uuid4())
        request.state.request_id = rid
        downstream_response = await call_next(request)
        # Echo the identifier back to the caller.
        downstream_response.headers["X-Request-ID"] = rid
        return downstream_response
class TimingMiddleware(BaseHTTPMiddleware):
    """Middleware that logs each request and reports its processing time.

    The elapsed time in seconds is returned to the client in the
    ``X-Process-Time`` response header (as a string) and written to the
    log together with the method, path and status code.
    """

    async def dispatch(self, request: Request, call_next):
        # perf_counter() is monotonic; time.time() follows the wall clock
        # and can jump when the system time is adjusted, which would
        # produce wrong or even negative durations.
        started = time.perf_counter()
        # Log request
        logger.info(f"{request.method} {request.url.path}")
        response = await call_next(request)
        process_time = time.perf_counter() - started
        response.headers["X-Process-Time"] = str(process_time)
        # Log response
        logger.info(f"{request.method} {request.url.path} | Status: {response.status_code} | Time: {process_time:.3f}s")
        return response
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan events"""
# Startup
logger.info("Starting YG-Dataset application...")
await init_db()
logger.info("Database initialized successfully")
yield
# Shutdown
pass
logger.info("Shutting down YG-Dataset application...")
await close_db()
logger.info("Database connections closed")
app = FastAPI(
@@ -29,15 +77,83 @@ app = FastAPI(
lifespan=lifespan,
)
# CORS
# Add custom middleware (order matters: last added = first executed)
app.add_middleware(TimingMiddleware)
app.add_middleware(RequestIDMiddleware)
# CORS - Configure properly for production
# For development, you can use ["*"] but for production, specify exact origins.
# Each entry is stripped and blank entries are dropped, so "a.com, b.com" or a
# trailing comma does not produce origins that can never match. An empty
# result falls back to ["*"], as before.
ALLOWED_ORIGINS = [
    origin.strip()
    for origin in (settings.ALLOWED_ORIGINS or "").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_origins=ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
allow_headers=["*"],
)
# Exception handlers
@app.exception_handler(AppException)
async def app_exception_handler(request: Request, exc: AppException):
    """Turn an ``AppException`` into a JSON failure response.

    The response uses the exception's own status code; its code and
    details are placed under ``error``.
    """
    logger.warning(f"App exception: {exc.message} | Code: {exc.code}")
    failure = ApiResponse.fail(
        message=exc.message,
        error={"code": exc.code, "details": exc.details},
    )
    return JSONResponse(status_code=exc.status_code, content=failure.model_dump())
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Turn a request validation failure into a 422 JSON response.

    Each reported error is reduced to its dotted field path, message and
    type, and the list is returned under ``error.details.errors``.
    """
    errors = [
        {
            "field": ".".join(map(str, item["loc"])),
            "message": item["msg"],
            "type": item["type"],
        }
        for item in exc.errors()
    ]
    logger.warning(f"Validation error: {errors}")
    payload = ApiResponse.fail(
        message="Validation error",
        error={"code": "VALIDATION_ERROR", "details": {"errors": errors}},
    )
    return JSONResponse(status_code=422, content=payload.model_dump())
@app.exception_handler(SQLAlchemyError)
async def database_exception_handler(request: Request, exc: SQLAlchemyError):
    """Log a SQLAlchemy error with its traceback and answer with a generic 500.

    The exception text goes to the log only; the response body carries a
    fixed message and the ``DATABASE_ERROR`` code.
    """
    logger.error(f"Database error: {exc}", exc_info=True)
    payload = ApiResponse.fail(
        message="Database operation failed",
        error={"code": "DATABASE_ERROR"},
    )
    return JSONResponse(status_code=500, content=payload.model_dump())
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Log an otherwise unhandled exception and answer with a generic 500.

    The exception text goes to the log only; the response body carries a
    fixed message and the ``INTERNAL_ERROR`` code.
    """
    logger.error(f"Unhandled exception: {exc}", exc_info=True)
    payload = ApiResponse.fail(
        message="Internal server error",
        error={"code": "INTERNAL_ERROR"},
    )
    return JSONResponse(status_code=500, content=payload.model_dump())
# Include API routes
app.include_router(api_router, prefix="/api/v1")
@@ -45,7 +161,10 @@ app.include_router(api_router, prefix="/api/v1")
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {"status": "healthy", "version": "1.0.0"}
return ApiResponse.ok(
data={"status": "healthy", "version": "1.0.0"},
message="Service is running"
)
if __name__ == "__main__":

View File

@@ -51,7 +51,7 @@ class Chunk(Base, UUIDMixin, TimestampMixin):
content = Column(Text, nullable=False)
summary = Column(Text)
word_count = Column(Integer)
metadata = Column(JSON) # store additional info like headings, page numbers
extra_data = Column(JSON) # store additional info like headings, page numbers
# Relationships
project = relationship("Project", back_populates="chunks")
@@ -112,7 +112,7 @@ class Dataset(Base, UUIDMixin, TimestampMixin):
name = Column(String(255), nullable=False)
description = Column(Text)
dataset_type = Column(String(50)) # qa, conversation, instruction
metadata = Column(JSON)
extra_data = Column(JSON)
# Relationships
project = relationship("Project", back_populates="datasets")
@@ -125,7 +125,7 @@ class EvalDataset(Base, UUIDMixin, TimestampMixin):
project_id = Column(UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
name = Column(String(255), nullable=False)
question_type = Column(String(50)) # mixed, fact, reasoning
metadata = Column(JSON)
extra_data = Column(JSON)
# Relationships
project = relationship("Project", back_populates="eval_datasets")

View File

@@ -1,3 +1,89 @@
"""
Pydantic Schemas
"""
from app.schemas.base import (
TimestampMixin,
UUIDMixin,
)
from app.schemas.project import (
ProjectBase,
ProjectCreate,
ProjectUpdate,
ProjectResponse,
)
from app.schemas.file import (
FileBase,
FileCreate,
FileUpdate,
FileResponse,
)
from app.schemas.chunk import (
ChunkBase,
ChunkCreate,
ChunkUpdate,
ChunkResponse,
)
from app.schemas.question import (
QuestionBase,
QuestionCreate,
QuestionUpdate,
QuestionResponse,
)
from app.schemas.dataset import (
DatasetBase,
DatasetCreate,
DatasetUpdate,
DatasetResponse,
)
from app.schemas.eval import (
EvalDatasetBase,
EvalDatasetCreate,
EvalDatasetUpdate,
EvalDatasetResponse,
TaskBase,
TaskResponse,
)
__all__ = [
# Base
"TimestampMixin",
"UUIDMixin",
# Project
"ProjectBase",
"ProjectCreate",
"ProjectUpdate",
"ProjectResponse",
# File
"FileBase",
"FileCreate",
"FileUpdate",
"FileResponse",
# Chunk
"ChunkBase",
"ChunkCreate",
"ChunkUpdate",
"ChunkResponse",
# Question
"QuestionBase",
"QuestionCreate",
"QuestionUpdate",
"QuestionResponse",
# Dataset
"DatasetBase",
"DatasetCreate",
"DatasetUpdate",
"DatasetResponse",
# Eval
"EvalDatasetBase",
"EvalDatasetCreate",
"EvalDatasetUpdate",
"EvalDatasetResponse",
"TaskBase",
"TaskResponse",
]

View File

@@ -4,7 +4,7 @@ Base Pydantic schemas
from datetime import datetime
from typing import Optional, Any
from uuid import UUID
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
class TimestampMixin(BaseModel):
@@ -18,153 +18,3 @@ class UUIDMixin(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: UUID
class ProjectBase(BaseModel):
"""Base project schema"""
name: str
description: Optional[str] = None
class ProjectCreate(ProjectBase):
"""Project create schema"""
pass
class ProjectUpdate(ProjectBase):
"""Project update schema"""
pass
class ProjectResponse(ProjectBase, UUIDMixin, TimestampMixin):
"""Project response schema"""
pass
class FileBase(BaseModel):
"""Base file schema"""
filename: str
file_type: str
size: Optional[int] = None
class FileResponse(FileBase, UUIDMixin, TimestampMixin):
"""File response schema"""
status: str
class ChunkBase(BaseModel):
"""Base chunk schema"""
name: Optional[str] = None
content: str
summary: Optional[str] = None
word_count: Optional[int] = None
class ChunkCreate(ChunkBase):
"""Chunk create schema"""
file_id: Optional[UUID] = None
class ChunkResponse(ChunkBase, UUIDMixin, TimestampMixin):
"""Chunk response schema"""
pass
class QuestionBase(BaseModel):
"""Base question schema"""
content: str
answer: Optional[str] = None
question_type: Optional[str] = None
class QuestionCreate(QuestionBase):
"""Question create schema"""
chunk_id: Optional[UUID] = None
class QuestionResponse(QuestionBase, UUIDMixin, TimestampMixin):
"""Question response schema"""
source: str
class DatasetBase(BaseModel):
"""Base dataset schema"""
name: str
description: Optional[str] = None
dataset_type: Optional[str] = None
class DatasetCreate(DatasetBase):
"""Dataset create schema"""
pass
class DatasetResponse(DatasetBase, UUIDMixin, TimestampMixin):
"""Dataset response schema"""
question_count: Optional[int] = None
class EvalDatasetBase(BaseModel):
"""Base eval dataset schema"""
name: str
question_type: Optional[str] = None
class EvalDatasetCreate(EvalDatasetBase):
"""Eval dataset create schema"""
pass
class EvalDatasetResponse(EvalDatasetBase, UUIDMixin, TimestampMixin):
"""Eval dataset response schema"""
pass
class TagBase(BaseModel):
"""Base tag schema"""
label: str
parent_id: Optional[UUID] = None
color: Optional[str] = None
class TagCreate(TagBase):
"""Tag create schema"""
pass
class TagResponse(TagBase, UUIDMixin, TimestampMixin):
"""Tag response schema"""
pass
class ModelConfigBase(BaseModel):
"""Base model config schema"""
provider: str
model_name: Optional[str] = None
api_key: Optional[str] = None
api_base: Optional[str] = None
is_default: Optional[str] = "false"
class ModelConfigCreate(ModelConfigBase):
"""Model config create schema"""
pass
class ModelConfigResponse(ModelConfigBase, UUIDMixin, TimestampMixin):
"""Model config response schema"""
pass
class TaskBase(BaseModel):
"""Base task schema"""
task_type: str
status: Optional[str] = "pending"
progress: Optional[int] = 0
class TaskResponse(TaskBase, UUIDMixin, TimestampMixin):
"""Task response schema"""
result: Optional[Any] = None
error: Optional[str] = None

View File

@@ -1,8 +1,9 @@
"""
DOCX Text Extractor
"""
import asyncio
from typing import Dict
from docx import Document
from typing import Dict, List
class DOCXProcessor:
@@ -26,6 +27,12 @@ class DOCXProcessor:
return "\n\n".join(text_parts)
async def extract_text_async(self, file_path: str) -> str:
    """Extract all text from DOCX asynchronously.

    Runs the synchronous ``extract_text`` in the event loop's default
    executor and awaits its result.

    Args:
        file_path: Path passed through unchanged to ``extract_text``.

    Returns:
        Whatever ``extract_text`` returns for ``file_path``.
    """
    # get_running_loop() is the preferred call inside a coroutine;
    # get_event_loop() has policy-dependent behaviour.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, self.extract_text, file_path)
def extract_with_metadata(self, file_path: str) -> Dict:
"""Extract text with DOCX metadata"""
doc = Document(file_path)
@@ -46,8 +53,14 @@ class DOCXProcessor:
return result
async def extract_with_metadata_async(self, file_path: str) -> Dict:
    """Extract DOCX text with metadata asynchronously.

    Runs the synchronous ``extract_with_metadata`` in the event loop's
    default executor and awaits its result.

    Args:
        file_path: Path passed through unchanged to ``extract_with_metadata``.

    Returns:
        Whatever ``extract_with_metadata`` returns for ``file_path``.
    """
    # get_running_loop() is the preferred call inside a coroutine;
    # get_event_loop() has policy-dependent behaviour.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, self.extract_with_metadata, file_path)
def process_docx(file_path: str) -> str:
async def process_docx(file_path: str) -> str:
"""Process DOCX file and return text"""
processor = DOCXProcessor()
return processor.extract_text(file_path)
return await processor.extract_text_async(file_path)

View File

@@ -1,8 +1,9 @@
"""
Excel/CSV Text Extractor
"""
import pandas as pd
import asyncio
from typing import Dict, List
import pandas as pd
class ExcelProcessor:
@@ -13,6 +14,12 @@ class ExcelProcessor:
df = pd.read_csv(file_path)
return self._dataframe_to_text(df)
async def extract_csv_async(self, file_path: str) -> str:
    """Extract CSV asynchronously.

    Runs the synchronous ``extract_csv`` in the event loop's default
    executor and awaits its result.

    Args:
        file_path: Path passed through unchanged to ``extract_csv``.

    Returns:
        Whatever ``extract_csv`` returns for ``file_path``.
    """
    # get_running_loop() is the preferred call inside a coroutine;
    # get_event_loop() has policy-dependent behaviour.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, self.extract_csv, file_path)
def extract_excel(self, file_path: str, sheet_name: str = None) -> str:
"""Extract text from Excel file"""
if sheet_name:
@@ -27,6 +34,12 @@ class ExcelProcessor:
text_parts.append(self._dataframe_to_text(df))
return "\n\n".join(text_parts)
async def extract_excel_async(self, file_path: str, sheet_name: str = None) -> str:
    """Extract Excel asynchronously.

    Runs the synchronous ``extract_excel`` in the event loop's default
    executor and awaits its result.

    Args:
        file_path: Path passed through unchanged to ``extract_excel``.
        sheet_name: Forwarded positionally to ``extract_excel``; ``None``
            (the default) is forwarded as-is.

    Returns:
        Whatever ``extract_excel`` returns for the given arguments.
    """
    # get_running_loop() is the preferred call inside a coroutine;
    # get_event_loop() has policy-dependent behaviour.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, self.extract_excel, file_path, sheet_name)
def _dataframe_to_text(self, df: pd.DataFrame) -> str:
"""Convert DataFrame to readable text"""
text_parts = []
@@ -48,19 +61,25 @@ class ExcelProcessor:
sheets = pd.read_excel(file_path, sheet_name=None)
return {name: self._dataframe_to_text(df) for name, df in sheets.items()}
async def extract_all_sheets_async(self, file_path: str) -> Dict[str, str]:
    """Extract all sheets asynchronously.

    Runs the synchronous ``extract_all_sheets`` in the event loop's
    default executor and awaits its result.

    Args:
        file_path: Path passed through unchanged to ``extract_all_sheets``.

    Returns:
        Whatever ``extract_all_sheets`` returns for ``file_path``.
    """
    # get_running_loop() is the preferred call inside a coroutine;
    # get_event_loop() has policy-dependent behaviour.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, self.extract_all_sheets, file_path)
def get_sheet_names(self, file_path: str) -> List[str]:
"""Get all sheet names from Excel file"""
xl = pd.ExcelFile(file_path)
return xl.sheet_names
def process_csv(file_path: str) -> str:
async def process_csv(file_path: str) -> str:
"""Process CSV file and return text"""
processor = ExcelProcessor()
return processor.extract_csv(file_path)
return await processor.extract_csv_async(file_path)
def process_excel(file_path: str) -> str:
async def process_excel(file_path: str) -> str:
"""Process Excel file and return text"""
processor = ExcelProcessor()
return processor.extract_excel(file_path)
return await processor.extract_excel_async(file_path)

View File

@@ -1,8 +1,9 @@
"""
PDF Text Extractor
"""
import asyncio
from typing import Dict, List
import pdfplumber
from typing import Dict, List, Optional
class PDFProcessor:
@@ -20,6 +21,12 @@ class PDFProcessor:
return "\n\n".join(text_parts)
async def extract_text_async(self, file_path: str) -> str:
    """Extract all text from PDF asynchronously.

    Runs the synchronous ``extract_text`` in the event loop's default
    executor and awaits its result.

    Args:
        file_path: Path passed through unchanged to ``extract_text``.

    Returns:
        Whatever ``extract_text`` returns for ``file_path``.
    """
    # get_running_loop() is the preferred call inside a coroutine;
    # get_event_loop() has policy-dependent behaviour.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, self.extract_text, file_path)
def extract_pages(self, file_path: str) -> List[Dict]:
"""Extract text page by page with metadata"""
pages = []
@@ -36,6 +43,12 @@ class PDFProcessor:
return pages
async def extract_pages_async(self, file_path: str) -> List[Dict]:
    """Extract pages asynchronously.

    Runs the synchronous ``extract_pages`` in the event loop's default
    executor and awaits its result.

    Args:
        file_path: Path passed through unchanged to ``extract_pages``.

    Returns:
        Whatever ``extract_pages`` returns for ``file_path``.
    """
    # get_running_loop() is the preferred call inside a coroutine;
    # get_event_loop() has policy-dependent behaviour.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, self.extract_pages, file_path)
def extract_with_metadata(self, file_path: str) -> Dict:
"""Extract text with PDF metadata"""
result = {
@@ -58,8 +71,14 @@ class PDFProcessor:
return result
async def extract_with_metadata_async(self, file_path: str) -> Dict:
    """Extract PDF text with metadata asynchronously.

    Runs the synchronous ``extract_with_metadata`` in the event loop's
    default executor and awaits its result.

    Args:
        file_path: Path passed through unchanged to ``extract_with_metadata``.

    Returns:
        Whatever ``extract_with_metadata`` returns for ``file_path``.
    """
    # get_running_loop() is the preferred call inside a coroutine;
    # get_event_loop() has policy-dependent behaviour.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, self.extract_with_metadata, file_path)
async def process_pdf(file_path: str) -> str:
    """Process PDF file and return text.

    ``extract_with_metadata_async`` resolves to a dict; only its ``"text"``
    entry is returned, which keeps the result a string as the previous
    synchronous version did and as the return annotation declares.

    Args:
        file_path: Path of the PDF handed to ``PDFProcessor``.

    Returns:
        The ``"text"`` value of the extracted metadata dict.
    """
    processor = PDFProcessor()
    extracted = await processor.extract_with_metadata_async(file_path)
    return extracted["text"]