feat(backend): 更新核心模块和文件处理
- 更新配置模块 (config.py)
- 更新数据库连接 (database.py)
- 更新主应用入口 (main.py)
- 更新数据模型 (models.py)
- 更新基础 Schema (base.py)
- 更新文件处理器 (docx, excel, pdf)
- 更新 Dockerfile

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,27 +1,60 @@
|
||||
FROM python:3.11-slim
|
||||
# Multi-stage build for Python FastAPI application
|
||||
|
||||
# Stage 1: Base image
|
||||
FROM python:3.11-slim as base
|
||||
|
||||
# Set environment variables
|
||||
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
PYTHONFAULTHANDLER=1 \
|
||||
PIP_NO_CACHE_DIR=1 \
|
||||
PIP_DISABLE_PIP_VERSION_CHECK=1
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Stage 2: Dependencies
|
||||
FROM base as deps
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
libpq-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy requirements
|
||||
COPY requirements.txt .
|
||||
# Create virtual environment
|
||||
RUN python -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# Install Python dependencies
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy application
|
||||
COPY . .
|
||||
|
||||
# Create uploads directory
|
||||
RUN mkdir -p uploads
|
||||
# Stage 3: Production
|
||||
FROM base
|
||||
|
||||
# Install system dependencies for production
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libpq5 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy virtual environment from deps stage
|
||||
COPY --from=deps /opt/venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# Create the unprivileged account that owns and runs the application.
# It must exist before `COPY --chown=app:app` and `USER app` reference it;
# the base image does not ship an `app` user, so the build fails without this.
RUN groupadd --system app \
    && useradd --system --gid app --home-dir /app --no-create-home app

# Copy application (files are owned by `app` straight from the COPY)
COPY --chown=app:app . /app

# Create writable runtime directories. Only the directories themselves need a
# chown here: their contents already got the right owner from COPY --chown.
RUN mkdir -p /app/uploads /app/logs \
    && chown app:app /app /app/uploads /app/logs

# Switch to non-root user
USER app
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8000
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
|
||||
|
||||
# Run application
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]
|
||||
|
||||
@@ -4,7 +4,7 @@ Application Configuration
|
||||
|
||||
from functools import lru_cache
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import Field
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
@@ -15,12 +15,16 @@ class Settings(BaseSettings):
|
||||
DEBUG: bool = True
|
||||
HOST: str = "0.0.0.0"
|
||||
PORT: int = 8000
|
||||
ALLOWED_ORIGINS: str = Field(
|
||||
default="*",
|
||||
description="Comma-separated list of allowed CORS origins"
|
||||
)
|
||||
|
||||
# Database - 使用 SQLite 进行开发/测试
|
||||
# 生产环境可切换为 PostgreSQL
|
||||
DATABASE_URL: str = Field(
|
||||
default="sqlite:///./ygdataset.db",
|
||||
description="Database connection URL (sqlite:// or postgresql+asyncpg://)"
|
||||
default="sqlite+aiosqlite:///./ygdataset.db",
|
||||
description="Database connection URL (sqlite+aiosqlite:// or postgresql+asyncpg://)"
|
||||
)
|
||||
DATABASE_URL_SYNC: str = Field(
|
||||
default="sqlite:///./ygdataset.db",
|
||||
@@ -38,8 +42,31 @@ class Settings(BaseSettings):
|
||||
DEFAULT_MODEL_PROVIDER: str = "openai"
|
||||
DEFAULT_MODEL_NAME: str = "gpt-4o-mini"
|
||||
|
||||
# Security
|
||||
SECRET_KEY: str = Field(
|
||||
default="your-secret-key-change-in-production",
|
||||
description="Secret key for JWT and other security operations"
|
||||
)
|
||||
API_KEY_HEADER: str = "X-API-Key"
|
||||
|
||||
# Pagination
|
||||
DEFAULT_PAGE_SIZE: int = 20
|
||||
MAX_PAGE_SIZE: int = 100
|
||||
|
||||
# Logging
|
||||
LOG_LEVEL: str = "INFO"
|
||||
|
||||
@field_validator("MAX_FILE_SIZE")
@classmethod
def validate_max_file_size(cls, v: int) -> int:
    """Cap the configured maximum file size at 500 MB.

    Values above the ceiling are clamped to it rather than rejected;
    values at or below the ceiling are returned unchanged.
    """
    ceiling = 500 * 1024 * 1024  # 500 MB expressed in bytes
    return min(v, ceiling)
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
extra = "allow"
|
||||
|
||||
|
||||
@@ -47,3 +74,7 @@ class Settings(BaseSettings):
|
||||
def get_settings() -> Settings:
|
||||
"""Get cached settings"""
|
||||
return Settings()
|
||||
|
||||
|
||||
# Create global settings instance
|
||||
settings = get_settings()
|
||||
|
||||
@@ -2,25 +2,32 @@
|
||||
Database Configuration and Session Management
|
||||
支持 SQLite 和 PostgreSQL
|
||||
"""
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
def get_engine_config():
    """Return the engine keyword arguments that match the configured database.

    SQLite URLs get ``NullPool``; any other URL gets an explicitly sized
    connection pool with pre-ping, recycle and timeout settings. ``echo``
    follows ``settings.DEBUG`` in both cases.
    """
    engine_kwargs = {"echo": settings.DEBUG}

    if settings.DATABASE_URL.startswith("sqlite"):
        engine_kwargs["poolclass"] = NullPool
        return engine_kwargs

    engine_kwargs.update(
        pool_pre_ping=True,
        pool_size=10,
        max_overflow=20,
        pool_recycle=3600,
        pool_timeout=30,
    )
    return engine_kwargs
|
||||
|
||||
|
||||
@@ -30,14 +37,14 @@ async_engine = create_async_engine(
|
||||
**get_engine_config()
|
||||
)
|
||||
|
||||
# Sync engine for migrations
|
||||
# Sync engine for migrations (use NullPool for SQLite)
|
||||
sync_engine = create_engine(
|
||||
settings.DATABASE_URL_SYNC,
|
||||
echo=settings.DEBUG,
|
||||
pool_pre_ping=True,
|
||||
poolclass=NullPool if settings.DATABASE_URL_SYNC.startswith("sqlite") else None,
|
||||
)
|
||||
|
||||
|
||||
# Async session factory
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
async_engine,
|
||||
@@ -55,8 +62,31 @@ class Base(DeclarativeBase):
|
||||
|
||||
async def init_db():
    """Create the tables declared on ``Base.metadata`` through the async engine."""
    logger.info("Initializing database...")
    create_tables = Base.metadata.create_all
    async with async_engine.begin() as connection:
        # create_all is a synchronous callable, so it is dispatched via run_sync.
        await connection.run_sync(create_tables)
    logger.info("Database initialized successfully")
|
||||
|
||||
|
||||
async def close_db():
    """Dispose of the async engine, logging before and after the call."""
    logger.info("Closing database connections...")
    engine = async_engine
    await engine.dispose()
    logger.info("Database connections closed")
|
||||
|
||||
|
||||
@asynccontextmanager
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield a new session from ``AsyncSessionLocal`` and clean it up afterwards.

    If the managed body raises, the error is logged, the session is rolled
    back and the exception is re-raised. The session is closed in every case.
    """
    db_session = AsyncSessionLocal()
    try:
        yield db_session
    except Exception as exc:
        logger.error(f"Database session error: {str(exc)}")
        await db_session.rollback()
        raise
    finally:
        await db_session.close()
|
||||
|
||||
|
||||
async def get_db() -> AsyncSession:
|
||||
@@ -64,5 +94,9 @@ async def get_db() -> AsyncSession:
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
except Exception as e:
|
||||
logger.error(f"Database error in dependency: {str(e)}")
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
@@ -3,23 +3,71 @@ YG-Dataset Backend Application
|
||||
FastAPI-based API server for dataset generation platform
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from app.api.v1 import api_router
|
||||
from app.api.response import ApiResponse
|
||||
from app.core.config import settings
|
||||
from app.core.database import init_db
|
||||
from app.core.database import init_db, close_db
|
||||
from app.core.exceptions import AppException
|
||||
from app.core.logging import logger
|
||||
|
||||
|
||||
class RequestIDMiddleware(BaseHTTPMiddleware):
    """Attach a freshly generated UUID to every request.

    The ID is stored on ``request.state.request_id`` before the request is
    passed on, and echoed back in the ``X-Request-ID`` response header.
    """

    async def dispatch(self, request: Request, call_next):
        generated_id = str(uuid.uuid4())
        request.state.request_id = generated_id

        downstream_response = await call_next(request)

        # Expose the same ID to the client so it can be matched against the request.
        downstream_response.headers["X-Request-ID"] = generated_id
        return downstream_response
|
||||
|
||||
|
||||
class TimingMiddleware(BaseHTTPMiddleware):
    """Measure how long each request takes.

    Logs the request on the way in and the status code plus duration on the
    way out, and reports the duration (in seconds) in the ``X-Process-Time``
    response header.
    """

    async def dispatch(self, request: Request, call_next):
        # perf_counter is monotonic, so the measured interval cannot be skewed
        # (or go negative) when the system clock is adjusted mid-request.
        start_time = time.perf_counter()

        # Log request
        logger.info(f"→ {request.method} {request.url.path}")

        response = await call_next(request)

        process_time = time.perf_counter() - start_time
        response.headers["X-Process-Time"] = str(process_time)

        # Log response
        logger.info(f"← {request.method} {request.url.path} | Status: {response.status_code} | Time: {process_time:.3f}s")

        return response
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup and shutdown work around the application's lifetime.

    Startup initializes the database via ``init_db``. Shutdown releases the
    database connections via ``close_db``.
    """
    # Startup
    logger.info("Starting YG-Dataset application...")
    await init_db()
    logger.info("Database initialized successfully")
    try:
        yield
    finally:
        # Shutdown: placed in ``finally`` so connections are released even when
        # an exception is thrown into the generator at the yield point.
        logger.info("Shutting down YG-Dataset application...")
        await close_db()
        logger.info("Database connections closed")
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
@@ -29,15 +77,83 @@ app = FastAPI(
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# CORS
|
||||
# Add custom middleware (order matters: last added = first executed)
|
||||
app.add_middleware(TimingMiddleware)
|
||||
app.add_middleware(RequestIDMiddleware)
|
||||
|
||||
# CORS - Configure properly for production
# For development, you can use ["*"] but for production, specify exact origins
# The setting is a comma-separated string; surrounding whitespace and empty
# entries (e.g. from "a.com, b.com," ) are dropped so every origin can match.
ALLOWED_ORIGINS = [
    origin.strip()
    for origin in (settings.ALLOWED_ORIGINS or "").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    # Credentialed requests are only enabled for an explicit origin list;
    # combining them with the "*" wildcard would let any site send cookies.
    allow_credentials="*" not in ALLOWED_ORIGINS,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
    allow_headers=["*"],
)
|
||||
|
||||
|
||||
# Exception handlers
|
||||
@app.exception_handler(AppException)
async def app_exception_handler(request: Request, exc: AppException):
    """Convert an ``AppException`` into a failure response.

    The exception's own status code is used, and its code and details are
    placed under the ``error`` key of the ``ApiResponse.fail`` payload.
    """
    logger.warning(f"App exception: {exc.message} | Code: {exc.code}")
    error_info = {"code": exc.code, "details": exc.details}
    envelope = ApiResponse.fail(message=exc.message, error=error_info)
    return JSONResponse(status_code=exc.status_code, content=envelope.model_dump())
|
||||
|
||||
|
||||
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Convert request validation failures into a 422 failure response.

    Each validation error is flattened to a dict with the dotted field
    location, the message and the error type.
    """
    errors = [
        {
            "field": ".".join(str(part) for part in item["loc"]),
            "message": item["msg"],
            "type": item["type"],
        }
        for item in exc.errors()
    ]
    logger.warning(f"Validation error: {errors}")

    envelope = ApiResponse.fail(
        message="Validation error",
        error={"code": "VALIDATION_ERROR", "details": {"errors": errors}},
    )
    return JSONResponse(status_code=422, content=envelope.model_dump())
|
||||
|
||||
|
||||
@app.exception_handler(SQLAlchemyError)
async def database_exception_handler(request: Request, exc: SQLAlchemyError):
    """Log a SQLAlchemy error with its traceback and answer with a generic 500.

    The exception text is only logged; it is not included in the response.
    """
    logger.error(f"Database error: {str(exc)}", exc_info=True)
    envelope = ApiResponse.fail(
        message="Database operation failed",
        error={"code": "DATABASE_ERROR"},
    )
    return JSONResponse(status_code=500, content=envelope.model_dump())
|
||||
|
||||
|
||||
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: log the exception with traceback and return a generic 500.

    The exception text is only logged; it is not included in the response.
    """
    logger.error(f"Unhandled exception: {str(exc)}", exc_info=True)
    envelope = ApiResponse.fail(
        message="Internal server error",
        error={"code": "INTERNAL_ERROR"},
    )
    return JSONResponse(status_code=500, content=envelope.model_dump())
|
||||
|
||||
|
||||
# Include API routes
|
||||
app.include_router(api_router, prefix="/api/v1")
|
||||
|
||||
@@ -45,7 +161,10 @@ app.include_router(api_router, prefix="/api/v1")
|
||||
@app.get("/health")
async def health_check():
    """Report service status and version wrapped in an ``ApiResponse.ok`` result."""
    status_payload = {"status": "healthy", "version": "1.0.0"}
    return ApiResponse.ok(data=status_payload, message="Service is running")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -51,7 +51,7 @@ class Chunk(Base, UUIDMixin, TimestampMixin):
|
||||
content = Column(Text, nullable=False)
|
||||
summary = Column(Text)
|
||||
word_count = Column(Integer)
|
||||
metadata = Column(JSON) # store additional info like headings, page numbers
|
||||
extra_data = Column(JSON) # store additional info like headings, page numbers
|
||||
|
||||
# Relationships
|
||||
project = relationship("Project", back_populates="chunks")
|
||||
@@ -112,7 +112,7 @@ class Dataset(Base, UUIDMixin, TimestampMixin):
|
||||
name = Column(String(255), nullable=False)
|
||||
description = Column(Text)
|
||||
dataset_type = Column(String(50)) # qa, conversation, instruction
|
||||
metadata = Column(JSON)
|
||||
extra_data = Column(JSON)
|
||||
|
||||
# Relationships
|
||||
project = relationship("Project", back_populates="datasets")
|
||||
@@ -125,7 +125,7 @@ class EvalDataset(Base, UUIDMixin, TimestampMixin):
|
||||
project_id = Column(UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
|
||||
name = Column(String(255), nullable=False)
|
||||
question_type = Column(String(50)) # mixed, fact, reasoning
|
||||
metadata = Column(JSON)
|
||||
extra_data = Column(JSON)
|
||||
|
||||
# Relationships
|
||||
project = relationship("Project", back_populates="eval_datasets")
|
||||
|
||||
@@ -1,3 +1,89 @@
|
||||
"""
|
||||
Pydantic Schemas
|
||||
"""
|
||||
from app.schemas.base import (
|
||||
TimestampMixin,
|
||||
UUIDMixin,
|
||||
)
|
||||
|
||||
from app.schemas.project import (
|
||||
ProjectBase,
|
||||
ProjectCreate,
|
||||
ProjectUpdate,
|
||||
ProjectResponse,
|
||||
)
|
||||
|
||||
from app.schemas.file import (
|
||||
FileBase,
|
||||
FileCreate,
|
||||
FileUpdate,
|
||||
FileResponse,
|
||||
)
|
||||
|
||||
from app.schemas.chunk import (
|
||||
ChunkBase,
|
||||
ChunkCreate,
|
||||
ChunkUpdate,
|
||||
ChunkResponse,
|
||||
)
|
||||
|
||||
from app.schemas.question import (
|
||||
QuestionBase,
|
||||
QuestionCreate,
|
||||
QuestionUpdate,
|
||||
QuestionResponse,
|
||||
)
|
||||
|
||||
from app.schemas.dataset import (
|
||||
DatasetBase,
|
||||
DatasetCreate,
|
||||
DatasetUpdate,
|
||||
DatasetResponse,
|
||||
)
|
||||
|
||||
from app.schemas.eval import (
|
||||
EvalDatasetBase,
|
||||
EvalDatasetCreate,
|
||||
EvalDatasetUpdate,
|
||||
EvalDatasetResponse,
|
||||
TaskBase,
|
||||
TaskResponse,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Base
|
||||
"TimestampMixin",
|
||||
"UUIDMixin",
|
||||
# Project
|
||||
"ProjectBase",
|
||||
"ProjectCreate",
|
||||
"ProjectUpdate",
|
||||
"ProjectResponse",
|
||||
# File
|
||||
"FileBase",
|
||||
"FileCreate",
|
||||
"FileUpdate",
|
||||
"FileResponse",
|
||||
# Chunk
|
||||
"ChunkBase",
|
||||
"ChunkCreate",
|
||||
"ChunkUpdate",
|
||||
"ChunkResponse",
|
||||
# Question
|
||||
"QuestionBase",
|
||||
"QuestionCreate",
|
||||
"QuestionUpdate",
|
||||
"QuestionResponse",
|
||||
# Dataset
|
||||
"DatasetBase",
|
||||
"DatasetCreate",
|
||||
"DatasetUpdate",
|
||||
"DatasetResponse",
|
||||
# Eval
|
||||
"EvalDatasetBase",
|
||||
"EvalDatasetCreate",
|
||||
"EvalDatasetUpdate",
|
||||
"EvalDatasetResponse",
|
||||
"TaskBase",
|
||||
"TaskResponse",
|
||||
]
|
||||
|
||||
@@ -4,7 +4,7 @@ Base Pydantic schemas
|
||||
from datetime import datetime
|
||||
from typing import Optional, Any
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class TimestampMixin(BaseModel):
|
||||
@@ -18,153 +18,3 @@ class UUIDMixin(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
|
||||
|
||||
class ProjectBase(BaseModel):
|
||||
"""Base project schema"""
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class ProjectCreate(ProjectBase):
|
||||
"""Project create schema"""
|
||||
pass
|
||||
|
||||
|
||||
class ProjectUpdate(ProjectBase):
|
||||
"""Project update schema"""
|
||||
pass
|
||||
|
||||
|
||||
class ProjectResponse(ProjectBase, UUIDMixin, TimestampMixin):
|
||||
"""Project response schema"""
|
||||
pass
|
||||
|
||||
|
||||
class FileBase(BaseModel):
|
||||
"""Base file schema"""
|
||||
filename: str
|
||||
file_type: str
|
||||
size: Optional[int] = None
|
||||
|
||||
|
||||
class FileResponse(FileBase, UUIDMixin, TimestampMixin):
|
||||
"""File response schema"""
|
||||
status: str
|
||||
|
||||
|
||||
class ChunkBase(BaseModel):
|
||||
"""Base chunk schema"""
|
||||
name: Optional[str] = None
|
||||
content: str
|
||||
summary: Optional[str] = None
|
||||
word_count: Optional[int] = None
|
||||
|
||||
|
||||
class ChunkCreate(ChunkBase):
|
||||
"""Chunk create schema"""
|
||||
file_id: Optional[UUID] = None
|
||||
|
||||
|
||||
class ChunkResponse(ChunkBase, UUIDMixin, TimestampMixin):
|
||||
"""Chunk response schema"""
|
||||
pass
|
||||
|
||||
|
||||
class QuestionBase(BaseModel):
|
||||
"""Base question schema"""
|
||||
content: str
|
||||
answer: Optional[str] = None
|
||||
question_type: Optional[str] = None
|
||||
|
||||
|
||||
class QuestionCreate(QuestionBase):
|
||||
"""Question create schema"""
|
||||
chunk_id: Optional[UUID] = None
|
||||
|
||||
|
||||
class QuestionResponse(QuestionBase, UUIDMixin, TimestampMixin):
|
||||
"""Question response schema"""
|
||||
source: str
|
||||
|
||||
|
||||
class DatasetBase(BaseModel):
|
||||
"""Base dataset schema"""
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
dataset_type: Optional[str] = None
|
||||
|
||||
|
||||
class DatasetCreate(DatasetBase):
|
||||
"""Dataset create schema"""
|
||||
pass
|
||||
|
||||
|
||||
class DatasetResponse(DatasetBase, UUIDMixin, TimestampMixin):
|
||||
"""Dataset response schema"""
|
||||
question_count: Optional[int] = None
|
||||
|
||||
|
||||
class EvalDatasetBase(BaseModel):
|
||||
"""Base eval dataset schema"""
|
||||
name: str
|
||||
question_type: Optional[str] = None
|
||||
|
||||
|
||||
class EvalDatasetCreate(EvalDatasetBase):
|
||||
"""Eval dataset create schema"""
|
||||
pass
|
||||
|
||||
|
||||
class EvalDatasetResponse(EvalDatasetBase, UUIDMixin, TimestampMixin):
|
||||
"""Eval dataset response schema"""
|
||||
pass
|
||||
|
||||
|
||||
class TagBase(BaseModel):
|
||||
"""Base tag schema"""
|
||||
label: str
|
||||
parent_id: Optional[UUID] = None
|
||||
color: Optional[str] = None
|
||||
|
||||
|
||||
class TagCreate(TagBase):
|
||||
"""Tag create schema"""
|
||||
pass
|
||||
|
||||
|
||||
class TagResponse(TagBase, UUIDMixin, TimestampMixin):
|
||||
"""Tag response schema"""
|
||||
pass
|
||||
|
||||
|
||||
class ModelConfigBase(BaseModel):
|
||||
"""Base model config schema"""
|
||||
provider: str
|
||||
model_name: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
api_base: Optional[str] = None
|
||||
is_default: Optional[str] = "false"
|
||||
|
||||
|
||||
class ModelConfigCreate(ModelConfigBase):
|
||||
"""Model config create schema"""
|
||||
pass
|
||||
|
||||
|
||||
class ModelConfigResponse(ModelConfigBase, UUIDMixin, TimestampMixin):
|
||||
"""Model config response schema"""
|
||||
pass
|
||||
|
||||
|
||||
class TaskBase(BaseModel):
|
||||
"""Base task schema"""
|
||||
task_type: str
|
||||
status: Optional[str] = "pending"
|
||||
progress: Optional[int] = 0
|
||||
|
||||
|
||||
class TaskResponse(TaskBase, UUIDMixin, TimestampMixin):
|
||||
"""Task response schema"""
|
||||
result: Optional[Any] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""
|
||||
DOCX Text Extractor
|
||||
"""
|
||||
import asyncio
|
||||
from typing import Dict
|
||||
from docx import Document
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
class DOCXProcessor:
|
||||
@@ -26,6 +27,12 @@ class DOCXProcessor:
|
||||
|
||||
return "\n\n".join(text_parts)
|
||||
|
||||
async def extract_text_async(self, file_path: str) -> str:
    """Extract all text from a DOCX file without blocking the event loop.

    Runs the synchronous ``extract_text`` in the loop's default executor.

    Args:
        file_path: Path passed through unchanged to ``extract_text``.

    Returns:
        Whatever ``extract_text`` returns for ``file_path``.
    """
    # get_running_loop is the supported way to obtain the loop inside a
    # coroutine; get_event_loop is deprecated for this use.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, self.extract_text, file_path)
|
||||
|
||||
def extract_with_metadata(self, file_path: str) -> Dict:
|
||||
"""Extract text with DOCX metadata"""
|
||||
doc = Document(file_path)
|
||||
@@ -46,8 +53,14 @@ class DOCXProcessor:
|
||||
|
||||
return result
|
||||
|
||||
async def extract_with_metadata_async(self, file_path: str) -> Dict:
    """Extract DOCX text and metadata without blocking the event loop.

    Runs the synchronous ``extract_with_metadata`` in the loop's default
    executor.

    Args:
        file_path: Path passed through unchanged to ``extract_with_metadata``.

    Returns:
        Whatever ``extract_with_metadata`` returns for ``file_path``.
    """
    # get_running_loop is the supported way to obtain the loop inside a
    # coroutine; get_event_loop is deprecated for this use.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, self.extract_with_metadata, file_path)
|
||||
|
||||
async def process_docx(file_path: str) -> str:
    """Return the text of the DOCX file at ``file_path``.

    Convenience wrapper that creates a ``DOCXProcessor`` and awaits its
    ``extract_text_async`` method.
    """
    return await DOCXProcessor().extract_text_async(file_path)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""
|
||||
Excel/CSV Text Extractor
|
||||
"""
|
||||
import pandas as pd
|
||||
import asyncio
|
||||
from typing import Dict, List
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class ExcelProcessor:
|
||||
@@ -13,6 +14,12 @@ class ExcelProcessor:
|
||||
df = pd.read_csv(file_path)
|
||||
return self._dataframe_to_text(df)
|
||||
|
||||
async def extract_csv_async(self, file_path: str) -> str:
    """Extract text from a CSV file without blocking the event loop.

    Runs the synchronous ``extract_csv`` in the loop's default executor.

    Args:
        file_path: Path passed through unchanged to ``extract_csv``.

    Returns:
        Whatever ``extract_csv`` returns for ``file_path``.
    """
    # get_running_loop is the supported way to obtain the loop inside a
    # coroutine; get_event_loop is deprecated for this use.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, self.extract_csv, file_path)
|
||||
|
||||
def extract_excel(self, file_path: str, sheet_name: str = None) -> str:
|
||||
"""Extract text from Excel file"""
|
||||
if sheet_name:
|
||||
@@ -27,6 +34,12 @@ class ExcelProcessor:
|
||||
text_parts.append(self._dataframe_to_text(df))
|
||||
return "\n\n".join(text_parts)
|
||||
|
||||
async def extract_excel_async(self, file_path: str, sheet_name: str = None) -> str:
    """Extract text from an Excel file without blocking the event loop.

    Runs the synchronous ``extract_excel`` in the loop's default executor.

    Args:
        file_path: Path passed through unchanged to ``extract_excel``.
        sheet_name: Optional sheet name, forwarded as-is to ``extract_excel``.

    Returns:
        Whatever ``extract_excel`` returns for the given arguments.
    """
    # get_running_loop is the supported way to obtain the loop inside a
    # coroutine; get_event_loop is deprecated for this use.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, self.extract_excel, file_path, sheet_name)
|
||||
|
||||
def _dataframe_to_text(self, df: pd.DataFrame) -> str:
|
||||
"""Convert DataFrame to readable text"""
|
||||
text_parts = []
|
||||
@@ -48,19 +61,25 @@ class ExcelProcessor:
|
||||
sheets = pd.read_excel(file_path, sheet_name=None)
|
||||
return {name: self._dataframe_to_text(df) for name, df in sheets.items()}
|
||||
|
||||
async def extract_all_sheets_async(self, file_path: str) -> Dict[str, str]:
    """Extract every sheet of an Excel file without blocking the event loop.

    Runs the synchronous ``extract_all_sheets`` in the loop's default executor.

    Args:
        file_path: Path passed through unchanged to ``extract_all_sheets``.

    Returns:
        Whatever ``extract_all_sheets`` returns for ``file_path``.
    """
    # get_running_loop is the supported way to obtain the loop inside a
    # coroutine; get_event_loop is deprecated for this use.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, self.extract_all_sheets, file_path)
|
||||
|
||||
def get_sheet_names(self, file_path: str) -> List[str]:
    """Return the names of all sheets in the Excel file at ``file_path``."""
    # ExcelFile keeps the workbook open; the context manager closes it as
    # soon as the names have been read instead of leaking the file handle.
    with pd.ExcelFile(file_path) as workbook:
        return workbook.sheet_names
|
||||
|
||||
|
||||
async def process_csv(file_path: str) -> str:
    """Return the text of the CSV file at ``file_path``.

    Convenience wrapper that creates an ``ExcelProcessor`` and awaits its
    ``extract_csv_async`` method.
    """
    return await ExcelProcessor().extract_csv_async(file_path)
|
||||
|
||||
|
||||
async def process_excel(file_path: str) -> str:
    """Return the text of the Excel file at ``file_path``.

    Convenience wrapper that creates an ``ExcelProcessor`` and awaits its
    ``extract_excel_async`` method with no sheet name given.
    """
    return await ExcelProcessor().extract_excel_async(file_path)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""
|
||||
PDF Text Extractor
|
||||
"""
|
||||
import asyncio
|
||||
from typing import Dict, List
|
||||
import pdfplumber
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
|
||||
class PDFProcessor:
|
||||
@@ -20,6 +21,12 @@ class PDFProcessor:
|
||||
|
||||
return "\n\n".join(text_parts)
|
||||
|
||||
async def extract_text_async(self, file_path: str) -> str:
    """Extract all text from a PDF file without blocking the event loop.

    Runs the synchronous ``extract_text`` in the loop's default executor.

    Args:
        file_path: Path passed through unchanged to ``extract_text``.

    Returns:
        Whatever ``extract_text`` returns for ``file_path``.
    """
    # get_running_loop is the supported way to obtain the loop inside a
    # coroutine; get_event_loop is deprecated for this use.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, self.extract_text, file_path)
|
||||
|
||||
def extract_pages(self, file_path: str) -> List[Dict]:
|
||||
"""Extract text page by page with metadata"""
|
||||
pages = []
|
||||
@@ -36,6 +43,12 @@ class PDFProcessor:
|
||||
|
||||
return pages
|
||||
|
||||
async def extract_pages_async(self, file_path: str) -> List[Dict]:
    """Extract PDF pages without blocking the event loop.

    Runs the synchronous ``extract_pages`` in the loop's default executor.

    Args:
        file_path: Path passed through unchanged to ``extract_pages``.

    Returns:
        Whatever ``extract_pages`` returns for ``file_path``.
    """
    # get_running_loop is the supported way to obtain the loop inside a
    # coroutine; get_event_loop is deprecated for this use.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, self.extract_pages, file_path)
|
||||
|
||||
def extract_with_metadata(self, file_path: str) -> Dict:
|
||||
"""Extract text with PDF metadata"""
|
||||
result = {
|
||||
@@ -58,8 +71,14 @@ class PDFProcessor:
|
||||
|
||||
return result
|
||||
|
||||
async def extract_with_metadata_async(self, file_path: str) -> Dict:
    """Extract PDF text and metadata without blocking the event loop.

    Runs the synchronous ``extract_with_metadata`` in the loop's default
    executor.

    Args:
        file_path: Path passed through unchanged to ``extract_with_metadata``.

    Returns:
        Whatever ``extract_with_metadata`` returns for ``file_path``.
    """
    # get_running_loop is the supported way to obtain the loop inside a
    # coroutine; get_event_loop is deprecated for this use.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, self.extract_with_metadata, file_path)
|
||||
|
||||
async def process_pdf(file_path: str) -> str:
    """Return the text of the PDF file at ``file_path``.

    Creates a ``PDFProcessor``, awaits ``extract_with_metadata_async`` and
    returns only the ``"text"`` entry of the resulting dict.
    """
    processor = PDFProcessor()
    extracted = await processor.extract_with_metadata_async(file_path)
    # The extractor returns a dict; callers of this helper expect the plain
    # text (see the ``-> str`` annotation), so unwrap the "text" entry.
    return extracted["text"]
|
||||
|
||||
Reference in New Issue
Block a user