- 更新 Chunks API 端点 - 更新 Datasets API 端点 - 更新 Evaluation API 端点 - 更新 Files API 端点 - 更新 Projects API 端点 - 更新 Questions API 端点 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
204 lines
5.6 KiB
Python
204 lines
5.6 KiB
Python
"""
Files API Router

Endpoints for uploading, listing, fetching, deleting and downloading the
files that belong to a project.
"""
|
|
import asyncio
import mimetypes
import os
from pathlib import Path
from typing import Optional
from uuid import UUID, uuid4

from fastapi import APIRouter, Depends, UploadFile, File, Query
from fastapi.responses import FileResponse
# NOTE: `FileResponse` is re-bound further down by the schema import from
# app.schemas.file; this alias keeps a handle on FastAPI's streaming response
# class for the download endpoint.
from fastapi.responses import FileResponse as FastAPIFileResponse
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.response import ApiResponse, PaginatedResponse
from app.core.config import get_settings
from app.core.database import get_db
from app.core.exceptions import ValidationException, NotFoundException
from app.core.crud import CRUDBase
from app.models.models import File as FileModel
from app.schemas.file import FileResponse, FileCreateSchema
|
|
|
|
router = APIRouter()
settings = get_settings()

# Uploaded files are stored in this directory. It is created (together with
# any missing parents) at import time so the request handlers can assume it
# is present.
UPLOAD_DIR = Path(settings.UPLOAD_DIR)
UPLOAD_DIR.mkdir(exist_ok=True, parents=True)

# Generic CRUD helper bound to the File ORM model, shared by the handlers.
file_crud = CRUDBase(FileModel)
|
|
|
|
|
|
def get_file_type(filename: str) -> str:
    """Map a filename's extension onto the canonical file type stored in the DB.

    The extension is the text after the last '.', compared case-insensitively.
    Alternate spellings are folded onto one canonical type
    ('doc' -> 'docx', 'xls' -> 'xlsx', 'markdown' -> 'md'). Anything not
    recognised, including a name without any '.', falls back to 'txt'.
    """
    canonical_types = {'pdf', 'docx', 'xlsx', 'csv', 'epub', 'md', 'txt'}
    aliases = {'doc': 'docx', 'xls': 'xlsx', 'markdown': 'md'}

    _, dot, suffix = filename.rpartition('.')
    ext = suffix.lower() if dot else ''

    if ext in aliases:
        return aliases[ext]
    if ext in canonical_types:
        return ext
    return 'txt'
|
|
|
|
|
|
# Extensions accepted by validate_file(), compared case-insensitively against
# the text after the last '.' of the upload's name. This set includes every
# extension that get_file_type() recognises, including the 'markdown' spelling
# that it folds onto 'md'; previously such files were mapped but rejected.
ALLOWED_EXTENSIONS = {'pdf', 'docx', 'doc', 'xlsx', 'xls', 'csv', 'epub', 'md', 'markdown', 'txt'}
|
|
|
|
|
|
def validate_file(filename: str, file_size: int) -> None:
    """Reject an upload whose extension or size is not acceptable.

    Args:
        filename: Client-supplied filename. Only the text after the last '.'
            is inspected, case-insensitively.
        file_size: Size of the upload in bytes.

    Raises:
        ValidationException: (field "file") if the extension is not in
            ALLOWED_EXTENSIONS, or if file_size exceeds settings.MAX_FILE_SIZE.
    """
    _, dot, suffix = filename.rpartition('.')
    ext = suffix.lower() if dot else ''

    if ext not in ALLOWED_EXTENSIONS:
        raise ValidationException(f"File type '{ext}' not allowed", field="file")

    limit = settings.MAX_FILE_SIZE
    if file_size > limit:
        # The message reports the limit in whole megabytes (floor division).
        raise ValidationException(
            f"File size exceeds maximum allowed size of {limit // (1024*1024)}MB",
            field="file",
        )
|
|
|
|
|
|
async def save_file_async(file: UploadFile, destination: Path) -> None:
    """Read an upload to the end and write its bytes to *destination*.

    The whole remaining upload is buffered in memory. The blocking disk write
    then runs in the default thread-pool executor so the event loop is not
    stalled.

    Args:
        file: The incoming upload; its remaining content is read.
        destination: Target path; Path.write_bytes overwrites an existing file.
    """
    content = await file.read()
    # Inside a coroutine, get_running_loop() is the documented way to reach the
    # current loop; get_event_loop() is discouraged in that context.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, destination.write_bytes, content)
|
|
|
|
|
|
@router.post("/upload", response_model=ApiResponse)
async def upload_file(
    project_id: UUID,
    file: UploadFile = File(...),
    db: AsyncSession = Depends(get_db)
):
    """Upload a file to a project and register it with status "pending".

    The upload is validated (extension and size), written into UPLOAD_DIR
    under a name prefixed with the project id and a random hex tag, and then
    recorded in the database.

    Raises:
        ValidationException: if the extension is not allowed or the upload is
            larger than settings.MAX_FILE_SIZE.
    """
    # UploadFile.filename may be None; treat it as an empty name so validation
    # rejects it cleanly instead of failing on None.
    original_name = file.filename or ""

    # Read at most one byte past the limit. That is enough for validate_file to
    # detect an oversized upload without buffering an arbitrarily large body.
    content = await file.read(settings.MAX_FILE_SIZE + 1)
    file_size = len(content)

    validate_file(original_name, file_size)

    # Keep only the final path component of the client-supplied name, so that
    # values like "/../../x" cannot place the file outside UPLOAD_DIR.
    # (The original code also called the nonexistent UUID.uuid4().)
    safe_filename = f"{project_id}_{uuid4().hex[:8]}_{Path(original_name).name}"
    file_path = UPLOAD_DIR / safe_filename

    # Blocking disk write is pushed to the default thread-pool executor.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, file_path.write_bytes, content)

    db_file = FileModel(
        project_id=project_id,
        filename=original_name,
        file_type=get_file_type(original_name),
        file_path=str(file_path),
        size=file_size,
        status="pending"
    )
    db.add(db_file)
    try:
        await db.commit()
    except Exception:
        # Don't leave an orphaned file on disk when its record was not saved.
        file_path.unlink(missing_ok=True)
        raise
    await db.refresh(db_file)

    return ApiResponse.ok(
        data={"id": str(db_file.id), "filename": db_file.filename, "status": db_file.status},
        message="File uploaded successfully"
    )
|
|
|
|
|
|
@router.get("", response_model=ApiResponse)
async def list_files(
    project_id: UUID,
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db)
):
    """Return one page of the files that belong to *project_id*.

    Files are requested from file_crud.get_multi ordered by created_at,
    descending.

    Args:
        project_id: Project whose files are listed.
        page: 1-based page number.
        page_size: Number of items per page (1-100).
    """
    offset = page_size * (page - 1)
    rows, total_count = await file_crud.get_multi(
        db,
        skip=offset,
        limit=page_size,
        filters={"project_id": project_id},
        order_by="created_at",
        descending=True,
    )
    # `FileResponse` here is the pydantic schema from app.schemas.file.
    items = list(map(FileResponse.model_validate, rows))
    return PaginatedResponse.ok(items=items, page=page, page_size=page_size, total=total_count)
|
|
|
|
|
|
@router.get("/{file_id}", response_model=ApiResponse)
async def get_file(
    project_id: UUID,
    file_id: UUID,
    db: AsyncSession = Depends(get_db)
):
    """Return the metadata of a single file belonging to *project_id*.

    Raises:
        NotFoundException: if no file has *file_id*, or if it belongs to a
            different project.
    """
    record = await file_crud.get(db, file_id)
    if record and record.project_id == project_id:
        return ApiResponse.ok(data=FileResponse.model_validate(record))
    raise NotFoundException("File", file_id)
|
|
|
|
|
|
@router.delete("/{file_id}", response_model=ApiResponse)
async def delete_file(
    project_id: UUID,
    file_id: UUID,
    db: AsyncSession = Depends(get_db)
):
    """Delete a file's stored bytes (when present) and then its database record.

    Raises:
        NotFoundException: if no file has *file_id*, or if it belongs to a
            different project.
    """
    file = await file_crud.get(db, file_id)
    if not file or file.project_id != project_id:
        raise NotFoundException("File", file_id)

    if file.file_path:
        # unlink(missing_ok=True) replaces the old exists()+remove() pair. The
        # file could disappear between the check and the removal (for example,
        # during a concurrent delete), which used to raise FileNotFoundError.
        stored_path = Path(file.file_path)
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, lambda: stored_path.unlink(missing_ok=True))

    await file_crud.delete(db, file_id)
    return ApiResponse.ok(message="File deleted successfully")
|
|
|
|
|
|
@router.get("/{file_id}/download", response_class=FastAPIFileResponse)
async def download_file(
    project_id: UUID,
    file_id: UUID,
    db: AsyncSession = Depends(get_db)
):
    """Send a stored file back to the client as a download.

    In this module the bare name ``FileResponse`` is bound to the pydantic
    schema from app.schemas.file, so FastAPI's response class is used through
    the ``FastAPIFileResponse`` alias. Previously both the route's
    response_class and the returned object were the schema.

    Raises:
        NotFoundException: if no file has *file_id*, or if it belongs to a
            different project.
        ValidationException: if the record has no path or the file is missing
            on disk.
    """
    file = await file_crud.get(db, file_id)
    if not file or file.project_id != project_id:
        raise NotFoundException("File", file_id)

    if not file.file_path or not os.path.exists(file.file_path):
        raise ValidationException("File not found on disk", field="file")

    # Derive a real MIME type from the original filename. The former
    # f"application/{file_type}" produced invalid types such as
    # "application/md" or "application/txt".
    media_type, _ = mimetypes.guess_type(file.filename or "")
    return FastAPIFileResponse(
        path=file.file_path,
        filename=file.filename,
        media_type=media_type or "application/octet-stream",
    )
|