# Source file: YG_FT_Platform/request/src/api/modules/dataset.py
# (502 lines, 19 KiB, Python)

"""
Dataset management API module.

Provides dataset upload, listing, deletion and related features.
"""
from typing import List, Optional
from fastapi import UploadFile, File, Form
from src.api.internal.base import BaseAPI, post, get, delete
from src.models.response import StandardResponse
from src.services.file_upload import file_upload_service
def format_file_size(size_bytes: int) -> str:
    """
    Format a byte count as a human-readable size string.

    The unit is chosen *after* rounding so that a value which rounds up to
    1024 KB is reported in MB (e.g. 1048575 bytes -> "1.0 MB") instead of the
    confusing "1024.0 KB".

    Args:
        size_bytes: File size in bytes.

    Returns:
        str: The size rendered as "<n> B", "<n> KB" or "<n> MB"; KB and MB
        values are rounded to two decimal places.
    """
    if size_bytes < 1024:
        return f"{size_bytes} B"

    size_kb = round(size_bytes / 1024, 2)
    if size_kb < 1024:
        return f"{size_kb} KB"

    # Either the size is >= 1 MiB, or it rounded up to 1024.0 KB: report in MB.
    size_mb = round(size_bytes / 1024 / 1024, 2)
    return f"{size_mb} MB"
class DatasetAPI(BaseAPI):
    """
    Dataset management API, served under the ``/api/datasets`` prefix.

    Routes (all return a ``StandardResponse``):
        POST   /upload             upload a .json / .jsonl dataset file
        GET    (prefix root)       list datasets
        GET    /list-files         list the raw files in the upload directory
        GET    /{file_id}          dataset details
        GET    /{file_id}/content  preview of the first records
        DELETE /{file_id}          delete a dataset

    NOTE(review): the static ``/list-files`` route is declared before the
    dynamic ``/{file_id}`` route because FastAPI matches path operations in
    registration order; otherwise ``/list-files`` would be captured as a
    ``file_id``. This only helps if ``_auto_register_routes`` registers
    methods in definition order -- confirm against ``BaseAPI``.
    """

    # Name of the JSON file (inside the upload directory) that maps file ids
    # to their original filenames; it is never reported as a dataset itself.
    _MAPPING_FILENAME = "filename_mapping.json"

    def __init__(self):
        """
        Build the router with the ``/api/datasets`` prefix and register routes.

        ``BaseAPI.__init__`` is deliberately not called: the attributes are
        set up here by hand so the router carries the desired prefix, and
        ``_auto_register_routes`` is then invoked manually.
        """
        # Local imports, as in the original implementation.
        from fastapi import APIRouter
        from src.utils.logger import get_logger

        self.module_name = "api.datasets"
        self.router_prefix = "/api/datasets"
        self.router = APIRouter(prefix=self.router_prefix, tags=["Datasets"])
        self.logger = get_logger(self.__class__.__module__)

        # Register the decorated methods on the router created above.
        self._auto_register_routes()

        self.logger.info(
            "API模块初始化完成",
            module=self.module_name,
            prefix=self.router_prefix,
            routes=len(self.router.routes),
        )

    # ------------------------------------------------------------------
    # Private helpers (no route decorator, so they are not endpoints)
    # ------------------------------------------------------------------

    @staticmethod
    def _is_dataset_filename(filename: str) -> bool:
        """Return True when *filename* ends in ``.json`` or ``.jsonl`` (case-insensitive)."""
        return filename.lower().endswith((".json", ".jsonl"))

    @staticmethod
    def _size_fields(file_size: int) -> dict:
        """Return the ``size`` / ``size_mb`` / ``size_display`` entries for a byte count."""
        return {
            "size": file_size,
            "size_mb": round(file_size / 1024 / 1024, 2),
            "size_display": format_file_size(file_size),
        }

    @staticmethod
    def _display_name(file_id: str, fallback: str) -> str:
        """Return the original filename recorded in the mapping, or *fallback* if there is no mapping."""
        mapping = file_upload_service.get_filename_mapping(file_id)
        return mapping["original_filename"] if mapping else fallback

    def _load_filename_mappings(self, data_dir) -> dict:
        """
        Read the ``mappings`` section of the filename-mapping file.

        Best effort: a missing or unreadable file yields an empty dict, but a
        read/parse failure is logged instead of being swallowed silently.
        """
        import json

        mapping_file = data_dir / self._MAPPING_FILENAME
        if not mapping_file.exists():
            return {}
        try:
            with open(mapping_file, "r", encoding="utf-8") as f:
                return json.load(f).get("mappings", {}) or {}
        except Exception as exc:
            self.logger.warning(f"读取文件名映射失败: {exc}")
            return {}

    def _scan_data_dir(self):
        """
        Return ``(mappings, files)`` for the upload directory.

        ``files`` holds every regular file directly inside the directory
        except the mapping file itself (no extension filter is applied);
        it is empty when the directory does not exist.
        """
        data_dir = file_upload_service.upload_dir
        mappings = self._load_filename_mappings(data_dir)
        if not data_dir.exists():
            return mappings, []
        files = [
            path
            for path in data_dir.iterdir()
            if path.is_file() and path.name != self._MAPPING_FILENAME
        ]
        return mappings, files

    def _list_physical_datasets(self) -> list:
        """Describe every file found on disk in the upload directory, sorted by name."""
        mappings, files = self._scan_data_dir()
        datasets = []
        for file_path in files:
            # The file id is the stored filename without its extension.
            file_id = file_path.stem
            mapping_info = mappings.get(file_id, {})
            datasets.append({
                "file_id": file_id,
                "name": mapping_info.get("original_filename", file_path.name),
                **self._size_fields(file_path.stat().st_size),
                "status": "已处理",
                # The mapping's original filename doubles as the description.
                "description": mapping_info.get("original_filename", ""),
                "uploaded_at": mapping_info.get("uploaded_at", ""),
                "download_count": 0,
                "is_physical_file": True,
            })
        datasets.sort(key=lambda item: item["name"])
        return datasets

    def _list_uploaded_datasets(self) -> list:
        """Describe the .json / .jsonl files known to the upload service."""
        datasets = []
        for uploaded_file in file_upload_service.get_all_files():
            # Only JSON / JSONL files count as datasets.
            if not self._is_dataset_filename(uploaded_file.original_filename):
                continue
            datasets.append({
                "file_id": uploaded_file.file_id,
                "name": self._display_name(
                    uploaded_file.file_id, uploaded_file.original_filename
                ),
                **self._size_fields(uploaded_file.file_size),
                "status": "已处理",
                "description": uploaded_file.description or "",
                "uploaded_at": uploaded_file.uploaded_at,
                "download_count": uploaded_file.download_count,
                "is_physical_file": False,
            })
        return datasets

    # ------------------------------------------------------------------
    # Routes
    # ------------------------------------------------------------------

    @post("/upload", response_model=StandardResponse)
    async def upload_dataset(
        self,
        file: UploadFile = File(...),
        description: Optional[str] = Form(None),
    ):
        """
        Upload a dataset file.

        Args:
            file: The uploaded file; only ``.json`` and ``.jsonl`` are accepted.
            description: Optional description; a default mentioning the
                filename is used when it is empty.

        Returns:
            StandardResponse: Success with the stored dataset's info, or an
            error response (unsupported extension, ``ValueError`` from the
            upload service, or any other failure).
        """
        try:
            filename = file.filename or "unknown"
            if not self._is_dataset_filename(filename):
                return StandardResponse.error("只支持 .json 和 .jsonl 格式的文件")

            file_content = await file.read()

            if not description:
                description = f"用户上传的数据集文件: {filename}"

            uploaded_file = await file_upload_service.upload_file(
                file_content=file_content,
                original_filename=filename,
                content_type=file.content_type,
                description=description,
            )

            # Shape the result the way the front end expects it.
            dataset_info = {
                "file_id": uploaded_file.file_id,
                "name": self._display_name(
                    uploaded_file.file_id, uploaded_file.original_filename
                ),
                **self._size_fields(uploaded_file.file_size),
                "status": "已处理",  # default status
                "uploaded_at": uploaded_file.uploaded_at,
                "description": uploaded_file.description,
            }
            return StandardResponse.success({
                "message": "数据集上传成功",
                "dataset": dataset_info,
            })
        except ValueError as e:
            return StandardResponse.error(str(e))
        except Exception as e:
            return StandardResponse.error(f"上传失败: {str(e)}")

    @get("", response_model=StandardResponse)
    async def list_datasets(self, list_all: bool = False):
        """
        List datasets.

        Args:
            list_all: When True, list every file physically present in the
                upload directory; when False (default), list only the
                .json / .jsonl files tracked by the upload service.

        Returns:
            StandardResponse: ``datasets``, ``total`` and ``source``
            (``"physical_files"`` or ``"api_uploaded"``), or an error response.
        """
        try:
            if list_all:
                datasets = self._list_physical_datasets()
                source = "physical_files"
            else:
                datasets = self._list_uploaded_datasets()
                source = "api_uploaded"
            return StandardResponse.success({
                "datasets": datasets,
                "total": len(datasets),
                "source": source,
            })
        except Exception as e:
            return StandardResponse.error(f"获取数据集列表失败: {str(e)}")

    # Declared before the "/{file_id}" routes so the static path is matched first.
    @get("/list-files", response_model=StandardResponse)
    async def list_data_files(self):
        """
        List the files stored in the upload directory.

        Returns:
            StandardResponse: ``total`` and ``files`` (sorted by original
            filename), where each entry carries the file id, original and
            storage filenames, path, size and whether the id appears in the
            filename mapping; or an error response.
        """
        try:
            mappings, files = self._scan_data_dir()
            files_info = []
            for file_path in files:
                # The file id is the stored filename without its extension.
                file_id = file_path.stem
                mapping_info = mappings.get(file_id, {})
                file_size = file_path.stat().st_size
                files_info.append({
                    "file_id": file_id,
                    "original_filename": mapping_info.get(
                        "original_filename", file_path.name
                    ),
                    "storage_filename": file_path.name,
                    "file_path": str(file_path),
                    "file_size": file_size,
                    "file_size_mb": round(file_size / 1024 / 1024, 2),
                    "uploaded_at": mapping_info.get("uploaded_at", ""),
                    "exists_in_mapping": file_id in mappings,
                })
            files_info.sort(key=lambda item: item["original_filename"])
            return StandardResponse.success({
                "total": len(files_info),
                "files": files_info,
            })
        except Exception as e:
            return StandardResponse.error(f"查询文件列表失败: {str(e)}")

    @get("/{file_id}", response_model=StandardResponse)
    async def get_dataset(self, file_id: str):
        """
        Return the details of one dataset.

        Args:
            file_id: Identifier of the file.

        Returns:
            StandardResponse: The dataset's details, or an error response when
            the upload service does not know the id or any other failure occurs.
        """
        try:
            file_info = file_upload_service.get_file(file_id)
            if not file_info:
                return StandardResponse.error(f"数据集 {file_id} 不存在")

            dataset_info = {
                "file_id": file_info.file_id,
                "name": self._display_name(
                    file_info.file_id, file_info.original_filename
                ),
                **self._size_fields(file_info.file_size),
                "status": "已处理",
                "description": file_info.description or "",
                "uploaded_at": file_info.uploaded_at,
                "updated_at": file_info.updated_at,
                "download_count": file_info.download_count,
                "content_type": file_info.content_type,
                "file_hash": file_info.file_hash,
            }
            return StandardResponse.success(dataset_info)
        except Exception as e:
            return StandardResponse.error(f"获取数据集详情失败: {str(e)}")

    @get("/{file_id}/content", response_model=StandardResponse)
    async def get_dataset_content(self, file_id: str, limit: int = 5):
        """
        Return a preview of a dataset file's content.

        For ``.jsonl`` files the first ``limit`` records are returned; for a
        JSON array the first ``limit`` elements; any other JSON document is
        returned whole.

        Args:
            file_id: Identifier of the file.
            limit: Maximum number of records to return (default 5). Negative
                values are treated as 0 rather than slicing from the end.

        Returns:
            StandardResponse: ``file_id``, ``filename``, ``total_records``
            (length of the preview) and ``preview``, or an error response.
        """
        try:
            import json
            from itertools import islice

            file_info = file_upload_service.get_file(file_id)
            if not file_info:
                return StandardResponse.error(f"数据集 {file_id} 不存在")

            file_path = file_upload_service.get_file_path(file_id)
            if not file_path or not file_path.exists():
                return StandardResponse.error(f"文件 {file_id} 不存在")

            # A negative limit would make ``data[:limit]`` drop items from the
            # end instead of limiting the preview.
            limit = max(limit, 0)

            try:
                if file_info.original_filename.lower().endswith(".jsonl"):
                    # Imported here so plain JSON previews do not require it.
                    import jsonlines

                    with jsonlines.open(file_path) as reader:
                        content_preview = list(islice(reader, limit))
                else:
                    with open(file_path, "r", encoding="utf-8") as f:
                        data = json.load(f)
                    # Arrays are truncated; any other document is returned as is.
                    content_preview = data[:limit] if isinstance(data, list) else data
            except json.JSONDecodeError as e:
                return StandardResponse.error(f"JSON文件格式错误: {str(e)}")
            except Exception as e:
                return StandardResponse.error(f"读取文件内容失败: {str(e)}")

            return StandardResponse.success({
                "file_id": file_id,
                "filename": self._display_name(file_id, file_info.original_filename),
                "total_records": len(content_preview),
                "preview": content_preview,
            })
        except Exception as e:
            return StandardResponse.error(f"获取数据集内容失败: {str(e)}")

    @delete("/{file_id}", response_model=StandardResponse)
    async def delete_dataset(self, file_id: str):
        """
        Delete a dataset.

        Args:
            file_id: Identifier of the file.

        Returns:
            StandardResponse: A success message, or an error response when the
            file does not exist or the upload service fails to delete it.
        """
        try:
            if not file_upload_service.file_exists(file_id):
                return StandardResponse.error(f"数据集 {file_id} 不存在")

            if not file_upload_service.delete_file(file_id):
                return StandardResponse.error(f"删除数据集 {file_id} 失败")

            return StandardResponse.success({
                "message": f"数据集 {file_id} 已删除"
            })
        except Exception as e:
            return StandardResponse.error(f"删除数据集失败: {str(e)}")
# Create the module-level instance (per the original note, the auto-discovery system locates this instance).
dataset_api = DatasetAPI()