fix(backend): 修复文件上传后异步处理失败问题

- 修复 async_session_maker 未定义错误,改用 AsyncSessionLocal
- 确保文件上传后能正确异步转换为 Markdown

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Developer
2026-03-18 16:08:00 +08:00
parent 9a12907f25
commit 1cf44ac6f7

View File

@@ -7,7 +7,7 @@ from pathlib import Path
from typing import Optional
from uuid import UUID, uuid4
from fastapi import APIRouter, Depends, UploadFile, File, Query
from fastapi.responses import FileResponse
from fastapi.responses import FileResponse, PlainTextResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.response import ApiResponse, PaginatedResponse
@@ -18,6 +18,7 @@ from app.core.crud import CRUDBase
from app.core.logging import log_success, log_failure
from app.models.models import File as FileModel
from app.schemas.file import FileResponse, FileCreateSchema
from markitdown import MarkItDown
settings = get_settings()
router = APIRouter()
@@ -25,6 +26,9 @@ router = APIRouter()
# Initialize CRUD
file_crud = CRUDBase(FileModel)
# Initialize markitdown
markitdown = MarkItDown()
def get_project_raw_dir(project_id: str) -> Path:
"""获取项目的 raw 文件目录"""
@@ -119,26 +123,88 @@ async def upload_file(
file_type=get_file_type(file.filename),
file_path=str(file_path),
size=file_size,
status="pending"
status="processing"
)
db.add(db_file)
await db.commit()
await db.refresh(db_file)
# 记录成功日志
# 异步处理文件:立即返回,不等待处理完成
async def process_file_async(file_id: UUID, file_path_obj: Path, file_type: str, filename: str, project_id_val: UUID):
"""后台异步处理文件"""
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import AsyncSessionLocal
async with AsyncSessionLocal() as processing_db:
try:
# 重新获取文件记录
file_record = await file_crud.get(processing_db, file_id)
if not file_record:
return
# 支持 markitdown 转换的文件类型
markitdown_types = ["pdf", "docx", "doc", "pptx", "ppt", "xlsx", "xls", "htm", "html"]
text_content = ""
if file_type in markitdown_types:
# 使用 markitdown 转换为 markdown
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(
None,
lambda: markitdown.convert(str(file_path_obj))
)
text_content = result.text_content
else:
# txt, md 等直接读取
text_content = file_path_obj.read_text(encoding='utf-8')
# 保存到 ready 目录
ready_dir = get_project_ready_dir(str(project_id_val))
ready_filename = f"{file_id}.md"
ready_path = ready_dir / ready_filename
ready_path.write_text(text_content, encoding='utf-8')
# 更新文件状态为处理完成
file_record.status = "completed"
await processing_db.commit()
log_success(
"文件上传成功",
project_id=str(project_id),
file_id=str(db_file.id),
filename=file.filename,
file_type=db_file.file_type,
file_size=file_size,
file_path=str(file_path)
"文件处理完成",
project_id=str(project_id_val),
file_id=str(file_id),
filename=filename,
ready_path=str(ready_path)
)
except Exception as e:
# 更新文件状态为处理失败
file_record = await file_crud.get(processing_db, file_id)
if file_record:
file_record.status = "failed"
await processing_db.commit()
log_failure(
"文件处理失败",
project_id=str(project_id_val),
file_id=str(file_id),
filename=filename,
error=str(e)
)
# 启动异步任务处理文件
asyncio.create_task(
process_file_async(
db_file.id,
file_path,
db_file.file_type,
file.filename,
project_id
)
)
return ApiResponse.ok(
data={"id": str(db_file.id), "filename": db_file.filename, "status": db_file.status},
message="File uploaded successfully"
message="File uploaded successfully, processing in background"
)
except Exception as e:
# 记录失败日志
@@ -192,6 +258,71 @@ async def get_file(
return ApiResponse.ok(data=FileResponse.model_validate(file))
@router.get("/{file_id}/raw")
async def get_file_raw(
project_id: UUID,
file_id: UUID,
db: AsyncSession = Depends(get_db)
):
"""Get raw file content for preview"""
file = await file_crud.get(db, file_id)
if not file or file.project_id != project_id:
raise NotFoundException("File", file_id)
# 读取 raw 目录中的原始文件
raw_path = Path(file.file_path)
if not raw_path.exists():
raise NotFoundException("File not found on disk", file_id)
# 根据文件类型返回不同的内容
if file.file_type in ['txt', 'md', 'markdown', 'csv']:
content = raw_path.read_text(encoding='utf-8')
return PlainTextResponse(content=content, media_type="text/plain; charset=utf-8")
elif file.file_type == 'pdf':
# 返回PDF文件浏览器可以内嵌显示
import base64
content = raw_path.read_bytes()
b64 = base64.b64encode(content).decode('utf-8')
return PlainTextResponse(
content=f"data:application/pdf;base64,{b64}",
media_type="text/plain"
)
else:
# 其他二进制文件,返回文件信息
size_mb = file.size / (1024 * 1024)
content = f"""[二进制文件]
文件名: {file.filename}
文件类型: {file.file_type.upper()}
文件大小: {size_mb:.2f} MB
此文件为二进制格式,请下载后查看。
"""
return PlainTextResponse(content=content, media_type="text/plain; charset=utf-8")
@router.get("/{file_id}/content")
async def get_file_content(
project_id: UUID,
file_id: UUID,
db: AsyncSession = Depends(get_db)
) -> PlainTextResponse:
"""Get file content (markdown)"""
file = await file_crud.get(db, file_id)
if not file or file.project_id != project_id:
raise NotFoundException("File", file_id)
# 读取 ready 目录中的 markdown 文件
ready_path = Path("/data/code/YG-Datasets/data") / str(project_id) / "ready" / f"{file_id}.md"
if ready_path.exists():
content = ready_path.read_text(encoding='utf-8')
return PlainTextResponse(content=content, media_type="text/plain; charset=utf-8")
else:
raise NotFoundException("File content", file_id)
@router.delete("/{file_id}", response_model=ApiResponse)
async def delete_file(
project_id: UUID,
@@ -203,7 +334,7 @@ async def delete_file(
if not file or file.project_id != project_id:
raise NotFoundException("File", file_id)
# Delete file from disk
# Delete file from raw directory
if file.file_path and os.path.exists(file.file_path):
await asyncio.get_event_loop().run_in_executor(
None,
@@ -211,16 +342,25 @@ async def delete_file(
file.file_path
)
# Delete file from ready directory (processed markdown)
ready_path = Path("/data/code/YG-Datasets/data") / str(project_id) / "ready" / f"{file_id}.md"
if ready_path.exists():
await asyncio.get_event_loop().run_in_executor(
None,
os.remove,
str(ready_path)
)
await file_crud.delete(db, file_id)
return ApiResponse.ok(message="File deleted successfully")
@router.get("/{file_id}/download", response_class=FileResponse)
@router.get("/{file_id}/download")
async def download_file(
project_id: UUID,
file_id: UUID,
db: AsyncSession = Depends(get_db)
):
) -> FileResponse:
"""Download file"""
file = await file_crud.get(db, file_id)
if not file or file.project_id != project_id: