feat(backend): 更新核心模块和文件处理
- 更新配置模块 (config.py) - 更新数据库连接 (database.py) - 更新主应用入口 (main.py) - 更新数据模型 (models.py) - 更新基础 Schema (base.py) - 更新文件处理器 (docx, excel, pdf) - 更新 Dockerfile Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,27 +1,60 @@
|
|||||||
FROM python:3.11-slim
|
# Multi-stage build for Python FastAPI application
|
||||||
|
|
||||||
|
# Stage 1: Base image
|
||||||
|
FROM python:3.11-slim as base
|
||||||
|
|
||||||
|
# Set environment variables
|
||||||
|
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||||
|
PYTHONUNBUFFERED=1 \
|
||||||
|
PYTHONFAULTHANDLER=1 \
|
||||||
|
PIP_NO_CACHE_DIR=1 \
|
||||||
|
PIP_DISABLE_PIP_VERSION_CHECK=1
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Stage 2: Dependencies
|
||||||
|
FROM base as deps
|
||||||
|
|
||||||
# Install system dependencies
|
# Install system dependencies
|
||||||
RUN apt-get update && apt-get install -y \
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
build-essential \
|
build-essential \
|
||||||
libpq-dev \
|
libpq-dev \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# Copy requirements
|
# Create virtual environment
|
||||||
COPY requirements.txt .
|
RUN python -m venv /opt/venv
|
||||||
|
ENV PATH="/opt/venv/bin:$PATH"
|
||||||
|
|
||||||
# Install Python dependencies
|
# Install Python dependencies
|
||||||
|
COPY requirements.txt .
|
||||||
RUN pip install --no-cache-dir -r requirements.txt
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
# Copy application
|
|
||||||
COPY . .
|
|
||||||
|
|
||||||
# Create uploads directory
|
# Stage 3: Production
|
||||||
RUN mkdir -p uploads
|
FROM base
|
||||||
|
|
||||||
|
# Install system dependencies for production
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
libpq5 \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Copy virtual environment from deps stage
|
||||||
|
COPY --from=deps /opt/venv /opt/venv
|
||||||
|
ENV PATH="/opt/venv/bin:$PATH"
|
||||||
|
|
||||||
|
# Copy application
|
||||||
|
COPY --chown=app:app . /app
|
||||||
|
RUN mkdir -p /app/uploads /app/logs && chown -R app:app /app
|
||||||
|
|
||||||
|
# Switch to non-root user
|
||||||
|
USER app
|
||||||
|
|
||||||
# Expose port
|
# Expose port
|
||||||
EXPOSE 8000
|
EXPOSE 8000
|
||||||
|
|
||||||
|
# Health check
|
||||||
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||||
|
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
|
||||||
|
|
||||||
# Run application
|
# Run application
|
||||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]
|
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ Application Configuration
|
|||||||
|
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
from pydantic import Field
|
from pydantic import Field, field_validator
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
@@ -15,12 +15,16 @@ class Settings(BaseSettings):
|
|||||||
DEBUG: bool = True
|
DEBUG: bool = True
|
||||||
HOST: str = "0.0.0.0"
|
HOST: str = "0.0.0.0"
|
||||||
PORT: int = 8000
|
PORT: int = 8000
|
||||||
|
ALLOWED_ORIGINS: str = Field(
|
||||||
|
default="*",
|
||||||
|
description="Comma-separated list of allowed CORS origins"
|
||||||
|
)
|
||||||
|
|
||||||
# Database - 使用 SQLite 进行开发/测试
|
# Database - 使用 SQLite 进行开发/测试
|
||||||
# 生产环境可切换为 PostgreSQL
|
# 生产环境可切换为 PostgreSQL
|
||||||
DATABASE_URL: str = Field(
|
DATABASE_URL: str = Field(
|
||||||
default="sqlite:///./ygdataset.db",
|
default="sqlite+aiosqlite:///./ygdataset.db",
|
||||||
description="Database connection URL (sqlite:// or postgresql+asyncpg://)"
|
description="Database connection URL (sqlite+aiosqlite:// or postgresql+asyncpg://)"
|
||||||
)
|
)
|
||||||
DATABASE_URL_SYNC: str = Field(
|
DATABASE_URL_SYNC: str = Field(
|
||||||
default="sqlite:///./ygdataset.db",
|
default="sqlite:///./ygdataset.db",
|
||||||
@@ -38,8 +42,31 @@ class Settings(BaseSettings):
|
|||||||
DEFAULT_MODEL_PROVIDER: str = "openai"
|
DEFAULT_MODEL_PROVIDER: str = "openai"
|
||||||
DEFAULT_MODEL_NAME: str = "gpt-4o-mini"
|
DEFAULT_MODEL_NAME: str = "gpt-4o-mini"
|
||||||
|
|
||||||
|
# Security
|
||||||
|
SECRET_KEY: str = Field(
|
||||||
|
default="your-secret-key-change-in-production",
|
||||||
|
description="Secret key for JWT and other security operations"
|
||||||
|
)
|
||||||
|
API_KEY_HEADER: str = "X-API-Key"
|
||||||
|
|
||||||
|
# Pagination
|
||||||
|
DEFAULT_PAGE_SIZE: int = 20
|
||||||
|
MAX_PAGE_SIZE: int = 100
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
LOG_LEVEL: str = "INFO"
|
||||||
|
|
||||||
|
@field_validator("MAX_FILE_SIZE")
|
||||||
|
@classmethod
|
||||||
|
def validate_max_file_size(cls, v: int) -> int:
|
||||||
|
"""Validate max file size (max 500MB)"""
|
||||||
|
if v > 500 * 1024 * 1024:
|
||||||
|
return 500 * 1024 * 1024
|
||||||
|
return v
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
env_file = ".env"
|
env_file = ".env"
|
||||||
|
env_file_encoding = "utf-8"
|
||||||
extra = "allow"
|
extra = "allow"
|
||||||
|
|
||||||
|
|
||||||
@@ -47,3 +74,7 @@ class Settings(BaseSettings):
|
|||||||
def get_settings() -> Settings:
|
def get_settings() -> Settings:
|
||||||
"""Get cached settings"""
|
"""Get cached settings"""
|
||||||
return Settings()
|
return Settings()
|
||||||
|
|
||||||
|
|
||||||
|
# Create global settings instance
|
||||||
|
settings = get_settings()
|
||||||
|
|||||||
@@ -2,25 +2,32 @@
|
|||||||
Database Configuration and Session Management
|
Database Configuration and Session Management
|
||||||
支持 SQLite 和 PostgreSQL
|
支持 SQLite 和 PostgreSQL
|
||||||
"""
|
"""
|
||||||
|
import logging
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import AsyncGenerator
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine, event
|
||||||
|
from sqlalchemy.pool import NullPool
|
||||||
|
|
||||||
from app.core.config import get_settings
|
from app.core.config import get_settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
|
|
||||||
|
|
||||||
def get_engine_config():
|
def get_engine_config():
|
||||||
"""根据数据库类型返回引擎配置"""
|
"""根据数据库类型返回引擎配置"""
|
||||||
if settings.DATABASE_URL.startswith("sqlite"):
|
if settings.DATABASE_URL.startswith("sqlite"):
|
||||||
return {"echo": settings.DEBUG}
|
return {"echo": settings.DEBUG, "poolclass": NullPool}
|
||||||
else:
|
else:
|
||||||
return {
|
return {
|
||||||
"echo": settings.DEBUG,
|
"echo": settings.DEBUG,
|
||||||
"pool_pre_ping": True,
|
"pool_pre_ping": True,
|
||||||
"pool_size": 10,
|
"pool_size": 10,
|
||||||
"max_overflow": 20,
|
"max_overflow": 20,
|
||||||
|
"pool_recycle": 3600,
|
||||||
|
"pool_timeout": 30,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -30,14 +37,14 @@ async_engine = create_async_engine(
|
|||||||
**get_engine_config()
|
**get_engine_config()
|
||||||
)
|
)
|
||||||
|
|
||||||
# Sync engine for migrations
|
# Sync engine for migrations (use NullPool for SQLite)
|
||||||
sync_engine = create_engine(
|
sync_engine = create_engine(
|
||||||
settings.DATABASE_URL_SYNC,
|
settings.DATABASE_URL_SYNC,
|
||||||
echo=settings.DEBUG,
|
echo=settings.DEBUG,
|
||||||
pool_pre_ping=True,
|
pool_pre_ping=True,
|
||||||
|
poolclass=NullPool if settings.DATABASE_URL_SYNC.startswith("sqlite") else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Async session factory
|
# Async session factory
|
||||||
AsyncSessionLocal = async_sessionmaker(
|
AsyncSessionLocal = async_sessionmaker(
|
||||||
async_engine,
|
async_engine,
|
||||||
@@ -55,8 +62,31 @@ class Base(DeclarativeBase):
|
|||||||
|
|
||||||
async def init_db():
|
async def init_db():
|
||||||
"""Initialize database tables"""
|
"""Initialize database tables"""
|
||||||
|
logger.info("Initializing database...")
|
||||||
async with async_engine.begin() as conn:
|
async with async_engine.begin() as conn:
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
logger.info("Database initialized successfully")
|
||||||
|
|
||||||
|
|
||||||
|
async def close_db():
|
||||||
|
"""Close database connections"""
|
||||||
|
logger.info("Closing database connections...")
|
||||||
|
await async_engine.dispose()
|
||||||
|
logger.info("Database connections closed")
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""Context manager for database sessions with automatic cleanup"""
|
||||||
|
session = AsyncSessionLocal()
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Database session error: {str(e)}")
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
|
||||||
async def get_db() -> AsyncSession:
|
async def get_db() -> AsyncSession:
|
||||||
@@ -64,5 +94,9 @@ async def get_db() -> AsyncSession:
|
|||||||
async with AsyncSessionLocal() as session:
|
async with AsyncSessionLocal() as session:
|
||||||
try:
|
try:
|
||||||
yield session
|
yield session
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Database error in dependency: {str(e)}")
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
finally:
|
finally:
|
||||||
await session.close()
|
await session.close()
|
||||||
|
|||||||
@@ -3,23 +3,71 @@ YG-Dataset Backend Application
|
|||||||
FastAPI-based API server for dataset generation platform
|
FastAPI-based API server for dataset generation platform
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
||||||
from app.api.v1 import api_router
|
from app.api.v1 import api_router
|
||||||
|
from app.api.response import ApiResponse
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.database import init_db
|
from app.core.database import init_db, close_db
|
||||||
|
from app.core.exceptions import AppException
|
||||||
|
from app.core.logging import logger
|
||||||
|
|
||||||
|
|
||||||
|
class RequestIDMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""Middleware to add request ID to each request"""
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
request_id = str(uuid.uuid4())
|
||||||
|
request.state.request_id = request_id
|
||||||
|
|
||||||
|
# Add request ID to response headers
|
||||||
|
response = await call_next(request)
|
||||||
|
response.headers["X-Request-ID"] = request_id
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
class TimingMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""Middleware to measure request processing time"""
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Log request
|
||||||
|
logger.info(f"→ {request.method} {request.url.path}")
|
||||||
|
|
||||||
|
response = await call_next(request)
|
||||||
|
|
||||||
|
process_time = time.time() - start_time
|
||||||
|
response.headers["X-Process-Time"] = str(process_time)
|
||||||
|
|
||||||
|
# Log response
|
||||||
|
logger.info(f"← {request.method} {request.url.path} | Status: {response.status_code} | Time: {process_time:.3f}s")
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
"""Application lifespan events"""
|
"""Application lifespan events"""
|
||||||
# Startup
|
# Startup
|
||||||
|
logger.info("Starting YG-Dataset application...")
|
||||||
await init_db()
|
await init_db()
|
||||||
|
logger.info("Database initialized successfully")
|
||||||
yield
|
yield
|
||||||
# Shutdown
|
# Shutdown
|
||||||
pass
|
logger.info("Shutting down YG-Dataset application...")
|
||||||
|
await close_db()
|
||||||
|
logger.info("Database connections closed")
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
@@ -29,15 +77,83 @@ app = FastAPI(
|
|||||||
lifespan=lifespan,
|
lifespan=lifespan,
|
||||||
)
|
)
|
||||||
|
|
||||||
# CORS
|
# Add custom middleware (order matters: last added = first executed)
|
||||||
|
app.add_middleware(TimingMiddleware)
|
||||||
|
app.add_middleware(RequestIDMiddleware)
|
||||||
|
|
||||||
|
# CORS - Configure properly for production
|
||||||
|
# For development, you can use ["*"] but for production, specify exact origins
|
||||||
|
ALLOWED_ORIGINS = settings.ALLOWED_ORIGINS.split(",") if settings.ALLOWED_ORIGINS else ["*"]
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=["*"],
|
allow_origins=ALLOWED_ORIGINS,
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=["*"],
|
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Exception handlers
|
||||||
|
@app.exception_handler(AppException)
|
||||||
|
async def app_exception_handler(request: Request, exc: AppException):
|
||||||
|
"""Handle custom application exceptions"""
|
||||||
|
logger.warning(f"App exception: {exc.message} | Code: {exc.code}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=exc.status_code,
|
||||||
|
content=ApiResponse.fail(
|
||||||
|
message=exc.message,
|
||||||
|
error={"code": exc.code, "details": exc.details}
|
||||||
|
).model_dump()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.exception_handler(RequestValidationError)
|
||||||
|
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||||
|
"""Handle validation exceptions"""
|
||||||
|
errors = []
|
||||||
|
for error in exc.errors():
|
||||||
|
errors.append({
|
||||||
|
"field": ".".join(str(loc) for loc in error["loc"]),
|
||||||
|
"message": error["msg"],
|
||||||
|
"type": error["type"]
|
||||||
|
})
|
||||||
|
logger.warning(f"Validation error: {errors}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=422,
|
||||||
|
content=ApiResponse.fail(
|
||||||
|
message="Validation error",
|
||||||
|
error={"code": "VALIDATION_ERROR", "details": {"errors": errors}}
|
||||||
|
).model_dump()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.exception_handler(SQLAlchemyError)
|
||||||
|
async def database_exception_handler(request: Request, exc: SQLAlchemyError):
|
||||||
|
"""Handle database exceptions"""
|
||||||
|
logger.error(f"Database error: {str(exc)}", exc_info=True)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content=ApiResponse.fail(
|
||||||
|
message="Database operation failed",
|
||||||
|
error={"code": "DATABASE_ERROR"}
|
||||||
|
).model_dump()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.exception_handler(Exception)
|
||||||
|
async def general_exception_handler(request: Request, exc: Exception):
|
||||||
|
"""Handle unhandled exceptions"""
|
||||||
|
logger.error(f"Unhandled exception: {str(exc)}", exc_info=True)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content=ApiResponse.fail(
|
||||||
|
message="Internal server error",
|
||||||
|
error={"code": "INTERNAL_ERROR"}
|
||||||
|
).model_dump()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Include API routes
|
# Include API routes
|
||||||
app.include_router(api_router, prefix="/api/v1")
|
app.include_router(api_router, prefix="/api/v1")
|
||||||
|
|
||||||
@@ -45,7 +161,10 @@ app.include_router(api_router, prefix="/api/v1")
|
|||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health_check():
|
async def health_check():
|
||||||
"""Health check endpoint"""
|
"""Health check endpoint"""
|
||||||
return {"status": "healthy", "version": "1.0.0"}
|
return ApiResponse.ok(
|
||||||
|
data={"status": "healthy", "version": "1.0.0"},
|
||||||
|
message="Service is running"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ class Chunk(Base, UUIDMixin, TimestampMixin):
|
|||||||
content = Column(Text, nullable=False)
|
content = Column(Text, nullable=False)
|
||||||
summary = Column(Text)
|
summary = Column(Text)
|
||||||
word_count = Column(Integer)
|
word_count = Column(Integer)
|
||||||
metadata = Column(JSON) # store additional info like headings, page numbers
|
extra_data = Column(JSON) # store additional info like headings, page numbers
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
project = relationship("Project", back_populates="chunks")
|
project = relationship("Project", back_populates="chunks")
|
||||||
@@ -112,7 +112,7 @@ class Dataset(Base, UUIDMixin, TimestampMixin):
|
|||||||
name = Column(String(255), nullable=False)
|
name = Column(String(255), nullable=False)
|
||||||
description = Column(Text)
|
description = Column(Text)
|
||||||
dataset_type = Column(String(50)) # qa, conversation, instruction
|
dataset_type = Column(String(50)) # qa, conversation, instruction
|
||||||
metadata = Column(JSON)
|
extra_data = Column(JSON)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
project = relationship("Project", back_populates="datasets")
|
project = relationship("Project", back_populates="datasets")
|
||||||
@@ -125,7 +125,7 @@ class EvalDataset(Base, UUIDMixin, TimestampMixin):
|
|||||||
project_id = Column(UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
|
project_id = Column(UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
|
||||||
name = Column(String(255), nullable=False)
|
name = Column(String(255), nullable=False)
|
||||||
question_type = Column(String(50)) # mixed, fact, reasoning
|
question_type = Column(String(50)) # mixed, fact, reasoning
|
||||||
metadata = Column(JSON)
|
extra_data = Column(JSON)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
project = relationship("Project", back_populates="eval_datasets")
|
project = relationship("Project", back_populates="eval_datasets")
|
||||||
|
|||||||
@@ -1,3 +1,89 @@
|
|||||||
"""
|
"""
|
||||||
Pydantic Schemas
|
Pydantic Schemas
|
||||||
"""
|
"""
|
||||||
|
from app.schemas.base import (
|
||||||
|
TimestampMixin,
|
||||||
|
UUIDMixin,
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.schemas.project import (
|
||||||
|
ProjectBase,
|
||||||
|
ProjectCreate,
|
||||||
|
ProjectUpdate,
|
||||||
|
ProjectResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.schemas.file import (
|
||||||
|
FileBase,
|
||||||
|
FileCreate,
|
||||||
|
FileUpdate,
|
||||||
|
FileResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.schemas.chunk import (
|
||||||
|
ChunkBase,
|
||||||
|
ChunkCreate,
|
||||||
|
ChunkUpdate,
|
||||||
|
ChunkResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.schemas.question import (
|
||||||
|
QuestionBase,
|
||||||
|
QuestionCreate,
|
||||||
|
QuestionUpdate,
|
||||||
|
QuestionResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.schemas.dataset import (
|
||||||
|
DatasetBase,
|
||||||
|
DatasetCreate,
|
||||||
|
DatasetUpdate,
|
||||||
|
DatasetResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.schemas.eval import (
|
||||||
|
EvalDatasetBase,
|
||||||
|
EvalDatasetCreate,
|
||||||
|
EvalDatasetUpdate,
|
||||||
|
EvalDatasetResponse,
|
||||||
|
TaskBase,
|
||||||
|
TaskResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Base
|
||||||
|
"TimestampMixin",
|
||||||
|
"UUIDMixin",
|
||||||
|
# Project
|
||||||
|
"ProjectBase",
|
||||||
|
"ProjectCreate",
|
||||||
|
"ProjectUpdate",
|
||||||
|
"ProjectResponse",
|
||||||
|
# File
|
||||||
|
"FileBase",
|
||||||
|
"FileCreate",
|
||||||
|
"FileUpdate",
|
||||||
|
"FileResponse",
|
||||||
|
# Chunk
|
||||||
|
"ChunkBase",
|
||||||
|
"ChunkCreate",
|
||||||
|
"ChunkUpdate",
|
||||||
|
"ChunkResponse",
|
||||||
|
# Question
|
||||||
|
"QuestionBase",
|
||||||
|
"QuestionCreate",
|
||||||
|
"QuestionUpdate",
|
||||||
|
"QuestionResponse",
|
||||||
|
# Dataset
|
||||||
|
"DatasetBase",
|
||||||
|
"DatasetCreate",
|
||||||
|
"DatasetUpdate",
|
||||||
|
"DatasetResponse",
|
||||||
|
# Eval
|
||||||
|
"EvalDatasetBase",
|
||||||
|
"EvalDatasetCreate",
|
||||||
|
"EvalDatasetUpdate",
|
||||||
|
"EvalDatasetResponse",
|
||||||
|
"TaskBase",
|
||||||
|
"TaskResponse",
|
||||||
|
]
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ Base Pydantic schemas
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, Any
|
from typing import Optional, Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
|
||||||
class TimestampMixin(BaseModel):
|
class TimestampMixin(BaseModel):
|
||||||
@@ -18,153 +18,3 @@ class UUIDMixin(BaseModel):
|
|||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
id: UUID
|
id: UUID
|
||||||
|
|
||||||
|
|
||||||
class ProjectBase(BaseModel):
|
|
||||||
"""Base project schema"""
|
|
||||||
name: str
|
|
||||||
description: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ProjectCreate(ProjectBase):
|
|
||||||
"""Project create schema"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ProjectUpdate(ProjectBase):
|
|
||||||
"""Project update schema"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ProjectResponse(ProjectBase, UUIDMixin, TimestampMixin):
|
|
||||||
"""Project response schema"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class FileBase(BaseModel):
|
|
||||||
"""Base file schema"""
|
|
||||||
filename: str
|
|
||||||
file_type: str
|
|
||||||
size: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
class FileResponse(FileBase, UUIDMixin, TimestampMixin):
|
|
||||||
"""File response schema"""
|
|
||||||
status: str
|
|
||||||
|
|
||||||
|
|
||||||
class ChunkBase(BaseModel):
|
|
||||||
"""Base chunk schema"""
|
|
||||||
name: Optional[str] = None
|
|
||||||
content: str
|
|
||||||
summary: Optional[str] = None
|
|
||||||
word_count: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ChunkCreate(ChunkBase):
|
|
||||||
"""Chunk create schema"""
|
|
||||||
file_id: Optional[UUID] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ChunkResponse(ChunkBase, UUIDMixin, TimestampMixin):
|
|
||||||
"""Chunk response schema"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class QuestionBase(BaseModel):
|
|
||||||
"""Base question schema"""
|
|
||||||
content: str
|
|
||||||
answer: Optional[str] = None
|
|
||||||
question_type: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class QuestionCreate(QuestionBase):
|
|
||||||
"""Question create schema"""
|
|
||||||
chunk_id: Optional[UUID] = None
|
|
||||||
|
|
||||||
|
|
||||||
class QuestionResponse(QuestionBase, UUIDMixin, TimestampMixin):
|
|
||||||
"""Question response schema"""
|
|
||||||
source: str
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetBase(BaseModel):
|
|
||||||
"""Base dataset schema"""
|
|
||||||
name: str
|
|
||||||
description: Optional[str] = None
|
|
||||||
dataset_type: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetCreate(DatasetBase):
|
|
||||||
"""Dataset create schema"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetResponse(DatasetBase, UUIDMixin, TimestampMixin):
|
|
||||||
"""Dataset response schema"""
|
|
||||||
question_count: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
class EvalDatasetBase(BaseModel):
|
|
||||||
"""Base eval dataset schema"""
|
|
||||||
name: str
|
|
||||||
question_type: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class EvalDatasetCreate(EvalDatasetBase):
|
|
||||||
"""Eval dataset create schema"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class EvalDatasetResponse(EvalDatasetBase, UUIDMixin, TimestampMixin):
|
|
||||||
"""Eval dataset response schema"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class TagBase(BaseModel):
|
|
||||||
"""Base tag schema"""
|
|
||||||
label: str
|
|
||||||
parent_id: Optional[UUID] = None
|
|
||||||
color: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class TagCreate(TagBase):
|
|
||||||
"""Tag create schema"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class TagResponse(TagBase, UUIDMixin, TimestampMixin):
|
|
||||||
"""Tag response schema"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ModelConfigBase(BaseModel):
|
|
||||||
"""Base model config schema"""
|
|
||||||
provider: str
|
|
||||||
model_name: Optional[str] = None
|
|
||||||
api_key: Optional[str] = None
|
|
||||||
api_base: Optional[str] = None
|
|
||||||
is_default: Optional[str] = "false"
|
|
||||||
|
|
||||||
|
|
||||||
class ModelConfigCreate(ModelConfigBase):
|
|
||||||
"""Model config create schema"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ModelConfigResponse(ModelConfigBase, UUIDMixin, TimestampMixin):
|
|
||||||
"""Model config response schema"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class TaskBase(BaseModel):
|
|
||||||
"""Base task schema"""
|
|
||||||
task_type: str
|
|
||||||
status: Optional[str] = "pending"
|
|
||||||
progress: Optional[int] = 0
|
|
||||||
|
|
||||||
|
|
||||||
class TaskResponse(TaskBase, UUIDMixin, TimestampMixin):
|
|
||||||
"""Task response schema"""
|
|
||||||
result: Optional[Any] = None
|
|
||||||
error: Optional[str] = None
|
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
"""
|
"""
|
||||||
DOCX Text Extractor
|
DOCX Text Extractor
|
||||||
"""
|
"""
|
||||||
|
import asyncio
|
||||||
|
from typing import Dict
|
||||||
from docx import Document
|
from docx import Document
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
|
|
||||||
class DOCXProcessor:
|
class DOCXProcessor:
|
||||||
@@ -26,6 +27,12 @@ class DOCXProcessor:
|
|||||||
|
|
||||||
return "\n\n".join(text_parts)
|
return "\n\n".join(text_parts)
|
||||||
|
|
||||||
|
async def extract_text_async(self, file_path: str) -> str:
|
||||||
|
"""Extract all text from DOCX asynchronously"""
|
||||||
|
return await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None, self.extract_text, file_path
|
||||||
|
)
|
||||||
|
|
||||||
def extract_with_metadata(self, file_path: str) -> Dict:
|
def extract_with_metadata(self, file_path: str) -> Dict:
|
||||||
"""Extract text with DOCX metadata"""
|
"""Extract text with DOCX metadata"""
|
||||||
doc = Document(file_path)
|
doc = Document(file_path)
|
||||||
@@ -46,8 +53,14 @@ class DOCXProcessor:
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
async def extract_with_metadata_async(self, file_path: str) -> Dict:
|
||||||
|
"""Extract with metadata asynchronously"""
|
||||||
|
return await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None, self.extract_with_metadata, file_path
|
||||||
|
)
|
||||||
|
|
||||||
def process_docx(file_path: str) -> str:
|
|
||||||
|
async def process_docx(file_path: str) -> str:
|
||||||
"""Process DOCX file and return text"""
|
"""Process DOCX file and return text"""
|
||||||
processor = DOCXProcessor()
|
processor = DOCXProcessor()
|
||||||
return processor.extract_text(file_path)
|
return await processor.extract_text_async(file_path)
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
"""
|
"""
|
||||||
Excel/CSV Text Extractor
|
Excel/CSV Text Extractor
|
||||||
"""
|
"""
|
||||||
import pandas as pd
|
import asyncio
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
class ExcelProcessor:
|
class ExcelProcessor:
|
||||||
@@ -13,6 +14,12 @@ class ExcelProcessor:
|
|||||||
df = pd.read_csv(file_path)
|
df = pd.read_csv(file_path)
|
||||||
return self._dataframe_to_text(df)
|
return self._dataframe_to_text(df)
|
||||||
|
|
||||||
|
async def extract_csv_async(self, file_path: str) -> str:
|
||||||
|
"""Extract CSV asynchronously"""
|
||||||
|
return await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None, self.extract_csv, file_path
|
||||||
|
)
|
||||||
|
|
||||||
def extract_excel(self, file_path: str, sheet_name: str = None) -> str:
|
def extract_excel(self, file_path: str, sheet_name: str = None) -> str:
|
||||||
"""Extract text from Excel file"""
|
"""Extract text from Excel file"""
|
||||||
if sheet_name:
|
if sheet_name:
|
||||||
@@ -27,6 +34,12 @@ class ExcelProcessor:
|
|||||||
text_parts.append(self._dataframe_to_text(df))
|
text_parts.append(self._dataframe_to_text(df))
|
||||||
return "\n\n".join(text_parts)
|
return "\n\n".join(text_parts)
|
||||||
|
|
||||||
|
async def extract_excel_async(self, file_path: str, sheet_name: str = None) -> str:
|
||||||
|
"""Extract Excel asynchronously"""
|
||||||
|
return await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None, self.extract_excel, file_path, sheet_name
|
||||||
|
)
|
||||||
|
|
||||||
def _dataframe_to_text(self, df: pd.DataFrame) -> str:
|
def _dataframe_to_text(self, df: pd.DataFrame) -> str:
|
||||||
"""Convert DataFrame to readable text"""
|
"""Convert DataFrame to readable text"""
|
||||||
text_parts = []
|
text_parts = []
|
||||||
@@ -48,19 +61,25 @@ class ExcelProcessor:
|
|||||||
sheets = pd.read_excel(file_path, sheet_name=None)
|
sheets = pd.read_excel(file_path, sheet_name=None)
|
||||||
return {name: self._dataframe_to_text(df) for name, df in sheets.items()}
|
return {name: self._dataframe_to_text(df) for name, df in sheets.items()}
|
||||||
|
|
||||||
|
async def extract_all_sheets_async(self, file_path: str) -> Dict[str, str]:
|
||||||
|
"""Extract all sheets asynchronously"""
|
||||||
|
return await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None, self.extract_all_sheets, file_path
|
||||||
|
)
|
||||||
|
|
||||||
def get_sheet_names(self, file_path: str) -> List[str]:
|
def get_sheet_names(self, file_path: str) -> List[str]:
|
||||||
"""Get all sheet names from Excel file"""
|
"""Get all sheet names from Excel file"""
|
||||||
xl = pd.ExcelFile(file_path)
|
xl = pd.ExcelFile(file_path)
|
||||||
return xl.sheet_names
|
return xl.sheet_names
|
||||||
|
|
||||||
|
|
||||||
def process_csv(file_path: str) -> str:
|
async def process_csv(file_path: str) -> str:
|
||||||
"""Process CSV file and return text"""
|
"""Process CSV file and return text"""
|
||||||
processor = ExcelProcessor()
|
processor = ExcelProcessor()
|
||||||
return processor.extract_csv(file_path)
|
return await processor.extract_csv_async(file_path)
|
||||||
|
|
||||||
|
|
||||||
def process_excel(file_path: str) -> str:
|
async def process_excel(file_path: str) -> str:
|
||||||
"""Process Excel file and return text"""
|
"""Process Excel file and return text"""
|
||||||
processor = ExcelProcessor()
|
processor = ExcelProcessor()
|
||||||
return processor.extract_excel(file_path)
|
return await processor.extract_excel_async(file_path)
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
"""
|
"""
|
||||||
PDF Text Extractor
|
PDF Text Extractor
|
||||||
"""
|
"""
|
||||||
|
import asyncio
|
||||||
|
from typing import Dict, List
|
||||||
import pdfplumber
|
import pdfplumber
|
||||||
from typing import Dict, List, Optional
|
|
||||||
|
|
||||||
|
|
||||||
class PDFProcessor:
|
class PDFProcessor:
|
||||||
@@ -20,6 +21,12 @@ class PDFProcessor:
|
|||||||
|
|
||||||
return "\n\n".join(text_parts)
|
return "\n\n".join(text_parts)
|
||||||
|
|
||||||
|
async def extract_text_async(self, file_path: str) -> str:
|
||||||
|
"""Extract all text from PDF asynchronously"""
|
||||||
|
return await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None, self.extract_text, file_path
|
||||||
|
)
|
||||||
|
|
||||||
def extract_pages(self, file_path: str) -> List[Dict]:
|
def extract_pages(self, file_path: str) -> List[Dict]:
|
||||||
"""Extract text page by page with metadata"""
|
"""Extract text page by page with metadata"""
|
||||||
pages = []
|
pages = []
|
||||||
@@ -36,6 +43,12 @@ class PDFProcessor:
|
|||||||
|
|
||||||
return pages
|
return pages
|
||||||
|
|
||||||
|
async def extract_pages_async(self, file_path: str) -> List[Dict]:
|
||||||
|
"""Extract pages asynchronously"""
|
||||||
|
return await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None, self.extract_pages, file_path
|
||||||
|
)
|
||||||
|
|
||||||
def extract_with_metadata(self, file_path: str) -> Dict:
|
def extract_with_metadata(self, file_path: str) -> Dict:
|
||||||
"""Extract text with PDF metadata"""
|
"""Extract text with PDF metadata"""
|
||||||
result = {
|
result = {
|
||||||
@@ -58,8 +71,14 @@ class PDFProcessor:
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
async def extract_with_metadata_async(self, file_path: str) -> Dict:
|
||||||
|
"""Extract with metadata asynchronously"""
|
||||||
|
return await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None, self.extract_with_metadata, file_path
|
||||||
|
)
|
||||||
|
|
||||||
def process_pdf(file_path: str) -> str:
|
|
||||||
|
async def process_pdf(file_path: str) -> str:
|
||||||
"""Process PDF file and return text"""
|
"""Process PDF file and return text"""
|
||||||
processor = PDFProcessor()
|
processor = PDFProcessor()
|
||||||
return processor.extract_with_metadata(file_path)["text"]
|
return await processor.extract_with_metadata_async(file_path)
|
||||||
|
|||||||
Reference in New Issue
Block a user