"""
Dataset management API module.

Provides dataset upload, listing, preview, and deletion endpoints.
"""

from typing import List, Optional

from fastapi import UploadFile, File, Form

from src.api.internal.base import BaseAPI, post, get, delete
from src.models.response import StandardResponse
from src.services.file_upload import file_upload_service

def format_file_size(size_bytes: int) -> str:
    """Render a byte count as a human-readable string.

    Args:
        size_bytes: File size in bytes.

    Returns:
        str: "<n> B" below 1 KiB, "<n> KB" below 1 MiB, otherwise "<n> MB"
        (KB/MB values rounded to two decimal places).
    """
    kib = 1024
    if size_bytes < kib:
        return f"{size_bytes} B"
    if size_bytes < kib * kib:
        return f"{round(size_bytes / kib, 2)} KB"
    return f"{round(size_bytes / (kib * kib), 2)} MB"
|
||
|
||
|
||
class DatasetAPI(BaseAPI):
    """Dataset management API: upload, list, preview and delete dataset files."""

    def __init__(self):
        """Initialize the API module."""
        # Must be set BEFORE super().__init__() so BaseAPI derives the
        # route prefix from it — presumably; confirm against BaseAPI.
        self._override_module_name = "datasets"
        super().__init__()
        self.logger.info("DatasetAPI 初始化完成")

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _read_filename_mappings(self) -> dict:
        """Load the file_id -> metadata mapping stored in the upload directory.

        Returns:
            dict: The "mappings" section of filename_mapping.json, or an empty
            dict when the file is missing or unreadable (listing then falls
            back to physical file names).
        """
        import json

        mapping_file = file_upload_service.upload_dir / "filename_mapping.json"
        if not mapping_file.exists():
            return {}
        try:
            with open(mapping_file, 'r', encoding='utf-8') as f:
                return json.load(f).get("mappings", {})
        except Exception:
            # Best effort: a corrupt mapping file must not break listing.
            return {}

    def _display_name(self, file_id: str, fallback: str) -> str:
        """Resolve the original (user-facing) filename for *file_id*.

        Falls back to *fallback* when no mapping entry exists.
        """
        mapping = file_upload_service.get_filename_mapping(file_id)
        return mapping["original_filename"] if mapping else fallback

    def _list_physical_files(self):
        """Build the response for list_datasets(list_all=True): every file
        physically present in the upload directory (mapping file excluded)."""
        mappings = self._read_filename_mappings()
        data_dir = file_upload_service.upload_dir

        datasets = []
        if data_dir.exists():
            for file_path in data_dir.iterdir():
                # Skip sub-directories and the mapping file itself.
                if not file_path.is_file() or file_path.name == "filename_mapping.json":
                    continue
                # Storage name without extension is the file_id.
                file_id = file_path.stem
                mapping_info = mappings.get(file_id, {})
                file_size = file_path.stat().st_size
                datasets.append({
                    "file_id": file_id,
                    "name": mapping_info.get("original_filename", file_path.name),
                    "size": file_size,
                    "size_mb": round(file_size / 1024 / 1024, 2),
                    "size_display": format_file_size(file_size),
                    "status": "已处理",
                    "description": mapping_info.get("original_filename", "") if mapping_info else "",
                    "uploaded_at": mapping_info.get("uploaded_at", ""),
                    "download_count": 0,  # physical files carry no download stats
                    "is_physical_file": True
                })

        datasets.sort(key=lambda x: x["name"])
        return StandardResponse.success({
            "datasets": datasets,
            "total": len(datasets),
            "source": "physical_files"
        })

    def _list_api_uploads(self):
        """Build the response for list_datasets(list_all=False): files
        registered through the upload API, filtered to JSON/JSONL."""
        datasets = []
        for uploaded_file in file_upload_service.get_all_files():
            name = uploaded_file.original_filename
            file_ext = name.lower().split('.')[-1] if '.' in name else ''
            if file_ext not in ['json', 'jsonl']:
                continue
            datasets.append({
                "file_id": uploaded_file.file_id,
                "name": self._display_name(uploaded_file.file_id, name),
                "size": uploaded_file.file_size,
                "size_mb": round(uploaded_file.file_size / 1024 / 1024, 2),
                "size_display": format_file_size(uploaded_file.file_size),
                "status": "已处理",
                "description": uploaded_file.description or "",
                "uploaded_at": uploaded_file.uploaded_at,
                "download_count": uploaded_file.download_count,
                "is_physical_file": False
            })
        return StandardResponse.success({
            "datasets": datasets,
            "total": len(datasets),
            "source": "api_uploaded"
        })

    # ------------------------------------------------------------------
    # Routes
    # ------------------------------------------------------------------

    @post("/upload", response_model=StandardResponse)
    async def upload_dataset(
        self,
        file: UploadFile = File(...),
        description: Optional[str] = Form(None)
    ):
        """Upload a dataset file.

        Args:
            file: The uploaded file (.json or .jsonl only).
            description: Optional file description.

        Returns:
            StandardResponse: upload result with the stored dataset's metadata.
        """
        try:
            # Validate the file extension before reading anything.
            filename = file.filename or "unknown"
            file_ext = filename.lower().split('.')[-1] if '.' in filename else ''
            if file_ext not in ['json', 'jsonl']:
                return StandardResponse.error("只支持 .json 和 .jsonl 格式的文件")

            file_content = await file.read()

            # Default description embeds the uploaded filename.
            # NOTE(review): the original contained a literal "(unknown)" inside an
            # f-string with no placeholder — this restores the intended {filename};
            # confirm against the original template.
            if not description:
                description = f"用户上传的数据集文件: {filename}"

            uploaded_file = await file_upload_service.upload_file(
                file_content=file_content,
                original_filename=filename,
                content_type=file.content_type,
                description=description
            )

            dataset_info = {
                "file_id": uploaded_file.file_id,
                "name": self._display_name(uploaded_file.file_id,
                                           uploaded_file.original_filename),
                "size": uploaded_file.file_size,
                "size_mb": round(uploaded_file.file_size / 1024 / 1024, 2),
                "size_display": format_file_size(uploaded_file.file_size),
                "status": "已处理",  # default status
                "uploaded_at": uploaded_file.uploaded_at,
                "description": uploaded_file.description
            }

            return StandardResponse.success({
                "message": "数据集上传成功",
                "dataset": dataset_info
            })

        except ValueError as e:
            # Validation errors raised by the upload service.
            return StandardResponse.error(str(e))
        except Exception as e:
            return StandardResponse.error(f"上传失败: {str(e)}")

    @get("", response_model=StandardResponse)
    async def list_datasets(self, list_all: bool = False):
        """List all datasets.

        Args:
            list_all: When True, list every physical file in the data directory;
                when False (default), list only files uploaded through the API.

        Returns:
            StandardResponse: dataset list, total count, and the listing source.
        """
        # Use the instance logger; the original called logging.basicConfig()
        # inside the handler, which reconfigures global logging per request.
        self.logger.info(f"list_datasets called with list_all={list_all}")
        try:
            if list_all:
                return self._list_physical_files()
            return self._list_api_uploads()
        except Exception as e:
            return StandardResponse.error(f"获取数据集列表失败: {str(e)}")

    # NOTE(review): declared before the "/{file_id}" routes so that, if routes
    # are registered in definition order (FastAPI matches in registration
    # order), "/list-files" is not captured by the path parameter — confirm
    # against BaseAPI's auto-discovery order.
    @get("/list-files", response_model=StandardResponse)
    async def list_data_files(self):
        """List the files in the data directory with storage-level detail.

        Returns:
            StandardResponse: file list including storage names, paths, sizes,
            and whether each file has a mapping entry.
        """
        try:
            mappings = self._read_filename_mappings()
            data_dir = file_upload_service.upload_dir

            files_info = []
            if data_dir.exists():
                for file_path in data_dir.iterdir():
                    # Skip sub-directories and the mapping file itself.
                    if not file_path.is_file() or file_path.name == "filename_mapping.json":
                        continue
                    file_id = file_path.stem
                    mapping_info = mappings.get(file_id, {})
                    file_size = file_path.stat().st_size
                    files_info.append({
                        "file_id": file_id,
                        "original_filename": mapping_info.get("original_filename", file_path.name),
                        "storage_filename": file_path.name,
                        "file_path": str(file_path),
                        "file_size": file_size,
                        "file_size_mb": round(file_size / 1024 / 1024, 2),
                        "uploaded_at": mapping_info.get("uploaded_at", ""),
                        "exists_in_mapping": file_id in mappings
                    })

            files_info.sort(key=lambda x: x["original_filename"])
            return StandardResponse.success({
                "total": len(files_info),
                "files": files_info
            })

        except Exception as e:
            return StandardResponse.error(f"查询文件列表失败: {str(e)}")

    # NOTE(review): the original file defined get_dataset TWICE verbatim; the
    # second definition shadowed the first and its decorator registered the
    # same route again. Only one definition is kept.
    @get("/{file_id}", response_model=StandardResponse)
    async def get_dataset(self, file_id: str):
        """Return detailed metadata for one dataset.

        Args:
            file_id: File ID.

        Returns:
            StandardResponse: dataset detail, or an error when the ID is unknown.
        """
        try:
            file_info = file_upload_service.get_file(file_id)
            if not file_info:
                return StandardResponse.error(f"数据集 {file_id} 不存在")

            dataset_info = {
                "file_id": file_info.file_id,
                "name": self._display_name(file_info.file_id,
                                           file_info.original_filename),
                "size": file_info.file_size,
                "size_mb": round(file_info.file_size / 1024 / 1024, 2),
                "size_display": format_file_size(file_info.file_size),
                "status": "已处理",
                "description": file_info.description or "",
                "uploaded_at": file_info.uploaded_at,
                "updated_at": file_info.updated_at,
                "download_count": file_info.download_count,
                "content_type": file_info.content_type,
                "file_hash": file_info.file_hash
            }

            return StandardResponse.success(dataset_info)

        except Exception as e:
            return StandardResponse.error(f"获取数据集详情失败: {str(e)}")

    @get("/{file_id}/content", response_model=StandardResponse)
    async def get_dataset_content(self, file_id: str, limit: int = 5):
        """Return a preview of a dataset file's content.

        Args:
            file_id: File ID.
            limit: Maximum number of records to return (default 5).

        Returns:
            StandardResponse: preview records, or an error response.
        """
        try:
            import json
            import jsonlines  # third-party; only needed for .jsonl previews

            file_info = file_upload_service.get_file(file_id)
            if not file_info:
                return StandardResponse.error(f"数据集 {file_id} 不存在")

            file_path = file_upload_service.get_file_path(file_id)
            if not file_path or not file_path.exists():
                return StandardResponse.error(f"文件 {file_id} 不存在")

            content_preview = []
            filename = file_info.original_filename.lower()

            try:
                if filename.endswith('.jsonl'):
                    # JSONL: stream records, stop as soon as the limit is hit.
                    with jsonlines.open(file_path) as reader:
                        for count, item in enumerate(reader):
                            if count >= limit:
                                break
                            content_preview.append(item)
                else:
                    # JSON: a top-level list is truncated to *limit*; any other
                    # top-level value is returned whole.
                    with open(file_path, 'r', encoding='utf-8') as f:
                        data = json.load(f)
                    content_preview = data[:limit] if isinstance(data, list) else data

            except json.JSONDecodeError as e:
                return StandardResponse.error(f"JSON文件格式错误: {str(e)}")
            except Exception as e:
                return StandardResponse.error(f"读取文件内容失败: {str(e)}")

            return StandardResponse.success({
                "file_id": file_id,
                "filename": self._display_name(file_id, file_info.original_filename),
                # For a non-list JSON object this counts top-level keys,
                # matching the original behavior.
                "total_records": len(content_preview),
                "preview": content_preview
            })

        except Exception as e:
            return StandardResponse.error(f"获取数据集内容失败: {str(e)}")

    @delete("/{file_id}", response_model=StandardResponse)
    async def delete_dataset(self, file_id: str):
        """Delete a dataset.

        Args:
            file_id: File ID.

        Returns:
            StandardResponse: deletion result, or an error when the ID is unknown.
        """
        try:
            if not file_upload_service.file_exists(file_id):
                return StandardResponse.error(f"数据集 {file_id} 不存在")

            if file_upload_service.delete_file(file_id):
                return StandardResponse.success({
                    "message": f"数据集 {file_id} 已删除"
                })
            return StandardResponse.error(f"删除数据集 {file_id} 失败")

        except Exception as e:
            return StandardResponse.error(f"删除数据集失败: {str(e)}")
|
||
|
||
|
||
# Module-level instance: the auto-discovery system finds and registers it.
dataset_api = DatasetAPI()
|