"""Files API router.

Endpoints for uploading, listing, fetching, deleting and downloading the files
that belong to a project. Uploaded files are written under
``DATA_ROOT/<project_id>/raw`` and tracked by a ``File`` database row.
"""

import asyncio
import contextlib
import os
from pathlib import Path
from typing import Optional
from uuid import UUID, uuid4

from fastapi import APIRouter, Depends, File, Query, UploadFile
# Aliased: the Pydantic schema imported below is also called ``FileResponse``
# and previously shadowed this class, breaking the download endpoint.
from fastapi.responses import FileResponse as FileDownloadResponse
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.response import ApiResponse, PaginatedResponse
from app.core.config import get_settings
from app.core.crud import CRUDBase
from app.core.database import get_db
from app.core.exceptions import NotFoundException, ValidationException
from app.core.logging import log_failure, log_success
from app.models.models import File as FileModel
from app.schemas.file import FileCreateSchema, FileResponse

settings = get_settings()
router = APIRouter()

# Initialize CRUD
file_crud = CRUDBase(FileModel)

# Root directory holding one sub-directory per project.
# TODO(review): consider sourcing this from settings instead of hard-coding it.
DATA_ROOT = Path("/data/code/YG-Datasets/data")

# Extension -> normalised file type stored on the File row.
_EXTENSION_TO_TYPE = {
    'pdf': 'pdf',
    'docx': 'docx',
    'doc': 'docx',
    'xlsx': 'xlsx',
    'xls': 'xlsx',
    'csv': 'csv',
    'epub': 'epub',
    'md': 'md',
    'markdown': 'md',
    'txt': 'txt',
}

# Allowed file extensions ('markdown' is included because get_file_type maps it).
ALLOWED_EXTENSIONS = {'pdf', 'docx', 'doc', 'xlsx', 'xls', 'csv', 'epub', 'md', 'markdown', 'txt'}

# Extension -> MIME type used for the download response.
_MEDIA_TYPES = {
    'pdf': 'application/pdf',
    'doc': 'application/msword',
    'docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
    'xls': 'application/vnd.ms-excel',
    'xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
    'csv': 'text/csv',
    'epub': 'application/epub+zip',
    'md': 'text/markdown',
    'markdown': 'text/markdown',
    'txt': 'text/plain',
}
_DEFAULT_MEDIA_TYPE = 'application/octet-stream'


def _get_extension(filename: str) -> str:
    """Return the lower-cased text after the last '.' in *filename*, or '' if none."""
    return filename.rsplit('.', 1)[-1].lower() if '.' in filename else ''


def _safe_basename(filename: Optional[str]) -> str:
    """Strip any directory components from a client-supplied filename.

    Both '/' and '\\' are treated as separators, so names such as
    ``"../../etc/x.txt"`` or ``"C:\\docs\\a.pdf"`` reduce to their last
    component. Returns '' when *filename* is None or empty.
    """
    if not filename:
        return ''
    return filename.replace('\\', '/').rsplit('/', 1)[-1]


def _get_media_type(filename: Optional[str]) -> str:
    """Return the MIME type for *filename*'s extension, or a binary fallback."""
    return _MEDIA_TYPES.get(_get_extension(filename or ''), _DEFAULT_MEDIA_TYPE)


def _remove_if_exists(path: str) -> None:
    """Delete *path*, ignoring the case where it has already disappeared."""
    with contextlib.suppress(FileNotFoundError):
        os.remove(path)


def _project_subdir(project_id: str, name: str) -> Path:
    """Return ``DATA_ROOT/<project_id>/<name>``, creating it if it does not exist."""
    directory = DATA_ROOT / project_id / name
    directory.mkdir(parents=True, exist_ok=True)
    return directory


def get_project_raw_dir(project_id: str) -> Path:
    """Return (creating if needed) the project's ``raw`` directory for uploaded files."""
    return _project_subdir(project_id, "raw")


def get_project_ready_dir(project_id: str) -> Path:
    """Return (creating if needed) the project's ``ready`` directory for processed files."""
    return _project_subdir(project_id, "ready")


def get_file_type(filename: str) -> str:
    """Map *filename*'s extension to a normalised file type; unknown extensions map to 'txt'."""
    return _EXTENSION_TO_TYPE.get(_get_extension(filename), 'txt')


def validate_file(filename: Optional[str], file_size: int) -> None:
    """Validate a file's extension and size.

    Args:
        filename: Name of the uploaded file; may be None when the client omitted it.
        file_size: Size of the content in bytes.

    Raises:
        ValidationException: If the name is missing, the extension is not in
            ``ALLOWED_EXTENSIONS``, or the size exceeds ``settings.MAX_FILE_SIZE``.
    """
    if not filename:
        raise ValidationException("No filename provided", field="file")
    ext = _get_extension(filename)
    if ext not in ALLOWED_EXTENSIONS:
        raise ValidationException(
            f"File type '{ext}' not allowed",
            field="file"
        )
    if file_size > settings.MAX_FILE_SIZE:
        raise ValidationException(
            f"File size exceeds maximum allowed size of {settings.MAX_FILE_SIZE // (1024*1024)}MB",
            field="file"
        )


async def save_file_async(file: UploadFile, destination: Path) -> None:
    """Read *file* fully and write its bytes to *destination* in a worker thread."""
    content = await file.read()
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, destination.write_bytes, content)


@router.post("/upload", response_model=ApiResponse)
async def upload_file(
    project_id: UUID,
    file: UploadFile = File(...),
    db: AsyncSession = Depends(get_db)
):
    """Upload a file into the project's raw directory and record it as ``pending``.

    The stored name is the client filename with directory components removed,
    prefixed by 8 random hex characters to avoid collisions. If anything fails
    before the DB commit, the partially written file is removed again.

    Raises:
        ValidationException: For a missing name, disallowed extension or oversized file.
    """
    file_path: Optional[Path] = None
    committed = False
    try:
        # NOTE(review): the whole upload is buffered in memory before the size check.
        content = await file.read()
        file_size = len(content)

        original_name = _safe_basename(file.filename)
        validate_file(original_name, file_size)

        # Save to the project's raw directory.
        project_dir = get_project_raw_dir(str(project_id))
        file_path = project_dir / f"{uuid4().hex[:8]}_{original_name}"

        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, file_path.write_bytes, content)

        db_file = FileModel(
            project_id=project_id,
            filename=original_name,
            file_type=get_file_type(original_name),
            file_path=str(file_path),
            size=file_size,
            status="pending"
        )
        db.add(db_file)
        await db.commit()
        committed = True
        await db.refresh(db_file)

        # Log the successful upload.
        log_success(
            "文件上传成功",
            project_id=str(project_id),
            file_id=str(db_file.id),
            filename=original_name,
            file_type=db_file.file_type,
            file_size=file_size,
            file_path=str(file_path)
        )

        return ApiResponse.ok(
            data={"id": str(db_file.id), "filename": db_file.filename, "status": db_file.status},
            message="File uploaded successfully"
        )
    except Exception as e:
        # Don't leave an orphaned file on disk when no DB row was committed for it.
        if file_path is not None and not committed:
            with contextlib.suppress(OSError):
                file_path.unlink()
        # Log the failed upload, then let the error propagate.
        log_failure(
            "文件上传失败",
            project_id=str(project_id),
            filename=file.filename or "unknown",
            error=str(e)
        )
        raise


@router.get("", response_model=ApiResponse)
async def list_files(
    project_id: UUID,
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db)
):
    """List a project's files, newest first, one page at a time."""
    skip = (page - 1) * page_size
    files, total = await file_crud.get_multi(
        db,
        skip=skip,
        limit=page_size,
        filters={"project_id": project_id},
        order_by="created_at",
        descending=True
    )

    file_responses = [FileResponse.model_validate(f) for f in files]

    return PaginatedResponse.ok(
        items=file_responses,
        page=page,
        page_size=page_size,
        total=total
    )


@router.get("/{file_id}", response_model=ApiResponse)
async def get_file(
    project_id: UUID,
    file_id: UUID,
    db: AsyncSession = Depends(get_db)
):
    """Return one file's metadata.

    Raises:
        NotFoundException: If the file does not exist or belongs to another project.
    """
    file = await file_crud.get(db, file_id)
    if not file or file.project_id != project_id:
        raise NotFoundException("File", file_id)

    return ApiResponse.ok(data=FileResponse.model_validate(file))


@router.delete("/{file_id}", response_model=ApiResponse)
async def delete_file(
    project_id: UUID,
    file_id: UUID,
    db: AsyncSession = Depends(get_db)
):
    """Delete a file from disk and then its database record.

    Raises:
        NotFoundException: If the file does not exist or belongs to another project.
    """
    file = await file_crud.get(db, file_id)
    if not file or file.project_id != project_id:
        raise NotFoundException("File", file_id)

    if file.file_path:
        # Tolerates the file vanishing between lookup and removal (no exists/remove race).
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, _remove_if_exists, file.file_path)

    await file_crud.delete(db, file_id)

    return ApiResponse.ok(message="File deleted successfully")


@router.get("/{file_id}/download", response_class=FileDownloadResponse)
async def download_file(
    project_id: UUID,
    file_id: UUID,
    db: AsyncSession = Depends(get_db)
):
    """Stream the stored file back to the client with a MIME type derived from its extension.

    Raises:
        NotFoundException: If the file does not exist or belongs to another project.
        ValidationException: If the record has no path or the file is missing on disk.
    """
    file = await file_crud.get(db, file_id)
    if not file or file.project_id != project_id:
        raise NotFoundException("File", file_id)

    if not file.file_path or not os.path.exists(file.file_path):
        raise ValidationException("File not found on disk", field="file")

    return FileDownloadResponse(
        path=file.file_path,
        filename=file.filename,
        media_type=_get_media_type(file.filename)
    )