"""
Dataset management API module.

Provides dataset upload, listing, deletion and related functionality.
"""
|
|||
|
|
|
|||
|
|
import json
from itertools import islice
from typing import List, Optional

from fastapi import UploadFile, File, Form

from src.api.internal.base import BaseAPI, post, get, delete
from src.models.response import StandardResponse
from src.services.file_upload import file_upload_service
|
|||
|
|
|
|||
|
|
|
|||
|
|
def format_file_size(size_bytes: int) -> str:
    """
    Render a byte count as a short human-readable string.

    Values below 1024 are shown as whole bytes ("B"); values below
    1024 * 1024 are shown in KB; everything larger is shown in MB.
    KB and MB figures are rounded to two decimal places.

    Args:
        size_bytes: File size in bytes.

    Returns:
        str: The formatted size, e.g. "512 B", "1.5 KB" or "5.5 MB".
    """
    one_kb = 1024
    one_mb = one_kb * one_kb

    # Guard clauses, smallest unit first.
    if size_bytes < one_kb:
        return f"{size_bytes} B"
    if size_bytes < one_mb:
        kilobytes = round(size_bytes / one_kb, 2)
        return f"{kilobytes} KB"

    megabytes = round(size_bytes / one_kb / one_kb, 2)
    return f"{megabytes} MB"
|
|||
|
|
|
|||
|
|
|
|||
|
|
class DatasetAPI(BaseAPI):
    """Dataset management API, served under the ``/api/datasets`` prefix.

    Every endpoint reports failures by returning ``StandardResponse.error(...)``
    and successes by returning ``StandardResponse.success(...)``; exceptions are
    caught inside the handlers rather than propagated.
    """

    # Lower-cased extensions (without the dot) accepted as dataset files.
    _DATASET_EXTENSIONS = ("json", "jsonl")
    # Bookkeeping file kept in the upload directory next to the stored files;
    # it is never reported as a data file itself.
    _MAPPING_FILENAME = "filename_mapping.json"

    def __init__(self):
        """Create the router with the fixed prefix and register all endpoints.

        The parent ``__init__`` is deliberately not called. Instead this override
        creates ``self.router`` with the desired prefix first and then calls
        ``_auto_register_routes`` manually, after the router has been replaced.
        """
        from fastapi import APIRouter

        from src.utils.logger import get_logger

        # 1. Module name reported in log records.
        self.module_name = "api.datasets"

        # 2. Router with the desired prefix; it must exist before registration.
        self.router_prefix = "/api/datasets"
        self.router = APIRouter(prefix=self.router_prefix, tags=["Datasets"])

        # 3. Logger.
        self.logger = get_logger(self.__class__.__module__)

        # 4. Register the decorated endpoint methods on self.router.
        self._auto_register_routes()

        # 5. Routes are matched in registration order. "/list-files" is declared
        #    after "/{file_id}", so without this step a GET on /list-files would
        #    be captured by get_dataset(file_id="list-files"). Move static paths
        #    ahead of parameterised ones; the sort is stable, so the relative
        #    order inside each group is unchanged.
        self.router.routes.sort(key=lambda route: "{" in getattr(route, "path", ""))

        # 6. Record the initialisation.
        self.logger.info(
            "API模块初始化完成",
            module=self.module_name,
            prefix=self.router_prefix,
            routes=len(self.router.routes),
        )

    # ------------------------------------------------------------------ #
    # Endpoints
    # ------------------------------------------------------------------ #

    @post("/upload", response_model=StandardResponse)
    async def upload_dataset(
        self,
        file: UploadFile = File(...),
        description: Optional[str] = Form(None),
    ):
        """
        Upload a dataset file.

        Args:
            file: The uploaded file; only ``.json`` and ``.jsonl`` are accepted.
            description: Optional description. When empty, a default mentioning
                the filename is used.

        Returns:
            StandardResponse: On success, ``{"message": ..., "dataset": {...}}``.
            On failure, an error response: a ``ValueError`` from the upload
            service is reported verbatim, anything else is prefixed with
            "上传失败".
        """
        try:
            filename = file.filename or "unknown"
            if self._file_extension(filename) not in self._DATASET_EXTENSIONS:
                return StandardResponse.error("只支持 .json 和 .jsonl 格式的文件")

            file_content = await file.read()

            if not description:
                description = f"用户上传的数据集文件: {filename}"

            uploaded_file = await file_upload_service.upload_file(
                file_content=file_content,
                original_filename=filename,
                content_type=file.content_type,
                description=description,
            )

            dataset_info = self._base_dataset_info(
                uploaded_file.file_id,
                self._display_name(uploaded_file.file_id, uploaded_file.original_filename),
                uploaded_file.file_size,
            )
            dataset_info.update(
                uploaded_at=uploaded_file.uploaded_at,
                description=uploaded_file.description,
            )

            return StandardResponse.success({
                "message": "数据集上传成功",
                "dataset": dataset_info,
            })

        except ValueError as e:
            return StandardResponse.error(str(e))
        except Exception as e:
            # Unexpected failure: keep a server-side trace, not only the client message.
            self.logger.error(f"Dataset upload failed: {e}")
            return StandardResponse.error(f"上传失败: {str(e)}")

    @get("", response_model=StandardResponse)
    async def list_datasets(self, list_all: bool = False):
        """
        List datasets.

        Args:
            list_all: When True, list every physical file in the upload
                directory (sorted by name). When False (the default), list only
                the .json/.jsonl files known to the upload service.

        Returns:
            StandardResponse: ``{"datasets": [...], "total": n, "source": ...}``,
            where ``source`` is ``"physical_files"`` or ``"api_uploaded"``.
        """
        try:
            if list_all:
                return self._list_physical_datasets()
            return self._list_uploaded_datasets()
        except Exception as e:
            return StandardResponse.error(f"获取数据集列表失败: {str(e)}")

    @get("/{file_id}", response_model=StandardResponse)
    async def get_dataset(self, file_id: str):
        """
        Return the details of a single dataset.

        Args:
            file_id: File ID.

        Returns:
            StandardResponse: The dataset details, or an error if the ID is
            unknown.
        """
        try:
            file_info = file_upload_service.get_file(file_id)
            if not file_info:
                return StandardResponse.error(f"数据集 {file_id} 不存在")

            dataset_info = self._base_dataset_info(
                file_info.file_id,
                self._display_name(file_info.file_id, file_info.original_filename),
                file_info.file_size,
            )
            dataset_info.update(
                description=file_info.description or "",
                uploaded_at=file_info.uploaded_at,
                updated_at=file_info.updated_at,
                download_count=file_info.download_count,
                content_type=file_info.content_type,
                file_hash=file_info.file_hash,
            )
            return StandardResponse.success(dataset_info)

        except Exception as e:
            return StandardResponse.error(f"获取数据集详情失败: {str(e)}")

    @get("/{file_id}/content", response_model=StandardResponse)
    async def get_dataset_content(self, file_id: str, limit: int = 5):
        """
        Preview the first records of a dataset file.

        Args:
            file_id: File ID.
            limit: Maximum number of records to return (default 5). Negative
                values are treated as 0.

        Returns:
            StandardResponse: ``{"file_id", "filename", "total_records", "preview"}``.
            For a JSON file whose top level is not a list, ``preview`` is the
            whole parsed value.
        """
        try:
            file_info = file_upload_service.get_file(file_id)
            if not file_info:
                return StandardResponse.error(f"数据集 {file_id} 不存在")

            file_path = file_upload_service.get_file_path(file_id)
            if not file_path or not file_path.exists():
                return StandardResponse.error(f"文件 {file_id} 不存在")

            # Without clamping, a negative limit returned nothing for JSONL but
            # dropped records from the end for JSON lists (data[:-n]).
            limit = max(limit, 0)
            is_jsonl = file_info.original_filename.lower().endswith(".jsonl")

            try:
                content_preview = self._read_preview(file_path, is_jsonl, limit)
            except json.JSONDecodeError as e:
                return StandardResponse.error(f"JSON文件格式错误: {str(e)}")
            except Exception as e:
                return StandardResponse.error(f"读取文件内容失败: {str(e)}")

            return StandardResponse.success({
                "file_id": file_id,
                "filename": self._display_name(file_id, file_info.original_filename),
                # NOTE(review): when the JSON top level is an object this is its
                # key count, not a record count — confirm what the frontend expects.
                "total_records": len(content_preview),
                "preview": content_preview,
            })

        except Exception as e:
            return StandardResponse.error(f"获取数据集内容失败: {str(e)}")

    @delete("/{file_id}", response_model=StandardResponse)
    async def delete_dataset(self, file_id: str):
        """
        Delete a dataset.

        Args:
            file_id: File ID.

        Returns:
            StandardResponse: A success message, or an error if the dataset does
            not exist or the upload service reports a failed deletion.
        """
        try:
            if not file_upload_service.file_exists(file_id):
                return StandardResponse.error(f"数据集 {file_id} 不存在")

            if file_upload_service.delete_file(file_id):
                return StandardResponse.success({
                    "message": f"数据集 {file_id} 已删除"
                })
            return StandardResponse.error(f"删除数据集 {file_id} 失败")

        except Exception as e:
            return StandardResponse.error(f"删除数据集失败: {str(e)}")

    @get("/list-files", response_model=StandardResponse)
    async def list_data_files(self):
        """
        List the raw files in the upload directory.

        Returns:
            StandardResponse: ``{"total": n, "files": [...]}``, sorted by
            original filename.
        """
        try:
            files_info = []
            for record in self._scan_physical_files():
                file_path = record["file_path"]
                file_size = record["file_size"]
                files_info.append({
                    "file_id": record["file_id"],
                    "original_filename": record["original_filename"],
                    "storage_filename": file_path.name,
                    # NOTE(review): exposes the server-side storage path to clients;
                    # confirm this is intended.
                    "file_path": str(file_path),
                    "file_size": file_size,
                    "file_size_mb": round(file_size / 1024 / 1024, 2),
                    "uploaded_at": record["uploaded_at"],
                    "exists_in_mapping": record["in_mapping"],
                })

            files_info.sort(key=lambda x: x["original_filename"])

            return StandardResponse.success({
                "total": len(files_info),
                "files": files_info,
            })

        except Exception as e:
            return StandardResponse.error(f"查询文件列表失败: {str(e)}")

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #

    def _list_physical_datasets(self):
        """Build the ``list_all=True`` response from the files on disk."""
        datasets = []
        for record in self._scan_physical_files():
            mapping_info = record["mapping_info"]
            entry = self._base_dataset_info(
                record["file_id"], record["original_filename"], record["file_size"]
            )
            entry.update(
                # NOTE(review): the original filename is reported as the
                # description here; verify whether a real description is intended.
                description=mapping_info.get("original_filename", ""),
                uploaded_at=record["uploaded_at"],
                download_count=0,
                is_physical_file=True,
            )
            datasets.append(entry)

        datasets.sort(key=lambda x: x["name"])

        return StandardResponse.success({
            "datasets": datasets,
            "total": len(datasets),
            "source": "physical_files",
        })

    def _list_uploaded_datasets(self):
        """Build the default listing from the upload service, keeping .json/.jsonl only."""
        datasets = []
        for uploaded_file in file_upload_service.get_all_files():
            if self._file_extension(uploaded_file.original_filename) not in self._DATASET_EXTENSIONS:
                continue

            entry = self._base_dataset_info(
                uploaded_file.file_id,
                self._display_name(uploaded_file.file_id, uploaded_file.original_filename),
                uploaded_file.file_size,
            )
            entry.update(
                description=uploaded_file.description or "",
                uploaded_at=uploaded_file.uploaded_at,
                download_count=uploaded_file.download_count,
                is_physical_file=False,
            )
            datasets.append(entry)

        return StandardResponse.success({
            "datasets": datasets,
            "total": len(datasets),
            "source": "api_uploaded",
        })

    def _scan_physical_files(self):
        """Return one record dict per regular file in the upload directory.

        Record keys: ``file_path``, ``file_id`` (file stem), ``in_mapping``,
        ``mapping_info`` (``{}`` when absent), ``original_filename``,
        ``uploaded_at`` and ``file_size``. Sub-directories and the mapping file
        are skipped; a missing directory yields an empty list.
        """
        data_dir = file_upload_service.upload_dir
        mappings = self._load_filename_mappings(data_dir / self._MAPPING_FILENAME)

        records = []
        if not data_dir.exists():
            return records

        for file_path in data_dir.iterdir():
            if not file_path.is_file() or file_path.name == self._MAPPING_FILENAME:
                continue
            try:
                file_size = file_path.stat().st_size
            except FileNotFoundError:
                # The file disappeared between listing and stat; skip it instead
                # of failing the whole listing.
                continue

            file_id = file_path.stem  # storage name without its extension
            mapping_info = mappings.get(file_id, {})
            records.append({
                "file_path": file_path,
                "file_id": file_id,
                "in_mapping": file_id in mappings,
                "mapping_info": mapping_info,
                "original_filename": mapping_info.get("original_filename", file_path.name),
                "uploaded_at": mapping_info.get("uploaded_at", ""),
                "file_size": file_size,
            })
        return records

    def _load_filename_mappings(self, mapping_file):
        """Return the ``"mappings"`` object stored in *mapping_file*, or ``{}``.

        Best effort: a missing file yields ``{}``. An unreadable or malformed
        file is logged and also yields ``{}``, so listings fall back to storage
        names instead of failing.
        """
        if not mapping_file.exists():
            return {}
        try:
            with open(mapping_file, "r", encoding="utf-8") as f:
                mapping_data = json.load(f)
            return mapping_data.get("mappings", {})
        except Exception as e:
            self.logger.warning(f"Failed to read filename mapping file {mapping_file}: {e}")
            return {}

    @staticmethod
    def _read_preview(file_path, is_jsonl: bool, limit: int):
        """Return at most *limit* records from a dataset file (*limit* must be >= 0).

        For JSON Lines, iteration stops once *limit* records have been read, so
        later lines are not consumed. A JSON file is parsed whole; a top-level
        list is truncated to *limit* items, and any other top-level value is
        returned unchanged.
        """
        if is_jsonl:
            import jsonlines  # third-party; only needed for .jsonl files

            with jsonlines.open(file_path) as reader:
                return list(islice(reader, limit))

        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return data[:limit] if isinstance(data, list) else data

    @staticmethod
    def _file_extension(filename: str) -> str:
        """Return the lower-cased text after the last dot in *filename*, or "" if there is none."""
        return filename.lower().split(".")[-1] if "." in filename else ""

    @staticmethod
    def _display_name(file_id: str, fallback: str) -> str:
        """Return the original filename recorded by the upload service, else *fallback*."""
        mapping = file_upload_service.get_filename_mapping(file_id)
        return mapping["original_filename"] if mapping else fallback

    @staticmethod
    def _base_dataset_info(file_id, name, size_bytes) -> dict:
        """Return the leading fields shared by every dataset entry, in response order."""
        return {
            "file_id": file_id,
            "name": name,
            "size": size_bytes,
            "size_mb": round(size_bytes / 1024 / 1024, 2),
            "size_display": format_file_size(size_bytes),
            "status": "已处理",  # default status; always this value
        }
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Create the module-level instance; the auto-discovery system finds this instance.
dataset_api = DatasetAPI()
|