1. Added the request framework
2. Added a script for removing the virtual environment
request/src/api/modules/dataset.py (new file)
@@ -0,0 +1,501 @@
"""
|
||||
数据集管理 API 模块
|
||||
提供数据集上传、列表、删除等功能
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from fastapi import UploadFile, File, Form
|
||||
|
||||
from src.api.internal.base import BaseAPI, post, get, delete
|
||||
from src.models.response import StandardResponse
|
||||
from src.services.file_upload import file_upload_service
|
||||
|
||||
|
||||
def format_file_size(size_bytes: int) -> str:
    """
    Format a file size for display.

    Args:
        size_bytes: File size in bytes.

    Returns:
        str: Human-readable file size string.
    """
    if size_bytes < 1024:
        return f"{size_bytes} B"
    elif size_bytes < 1024 * 1024:
        size_kb = round(size_bytes / 1024, 2)
        return f"{size_kb} KB"
    else:
        size_mb = round(size_bytes / 1024 / 1024, 2)
        return f"{size_mb} MB"

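# Quick illustration of the helper above (values follow directly from the logic as written):
#   format_file_size(512)             -> "512 B"
#   format_file_size(1536)            -> "1.5 KB"
#   format_file_size(5 * 1024 * 1024) -> "5.0 MB"
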
class DatasetAPI(BaseAPI):
    """Dataset management API - automatically registered under /api/datasets"""

    def __init__(self):
        # Override the initialization logic to set the desired route prefix.
        # 1. Set module_name manually.
        self.module_name = "api.datasets"

        # 2. Create the router with the expected prefix.
        from fastapi import APIRouter
        self.router_prefix = "/api/datasets"
        self.router = APIRouter(
            prefix=self.router_prefix,
            tags=["Datasets"]
        )

        # 3. Obtain the logger.
        from src.utils.logger import get_logger
        self.logger = get_logger(self.__class__.__module__)

        # 4. Trigger the base class's automatic route registration
        #    (the router has already been replaced at this point).
        # Note: we do not call the parent __init__; we call _auto_register_routes directly.
        self._auto_register_routes()

        # 5. Log the initialization.
        self.logger.info(
            "API module initialized",
            module=self.module_name,
            prefix=self.router_prefix,
            routes=len(self.router.routes)
        )

@post("/upload", response_model=StandardResponse)
|
||||
async def upload_dataset(
|
||||
self,
|
||||
file: UploadFile = File(...),
|
||||
description: Optional[str] = Form(None)
|
||||
):
|
||||
"""
|
||||
上传数据集文件
|
||||
|
||||
Args:
|
||||
file: 上传的文件(支持 .json, .jsonl 格式)
|
||||
description: 文件描述(可选)
|
||||
|
||||
Returns:
|
||||
StandardResponse: 包含上传结果的标准响应
|
||||
"""
|
||||
try:
|
||||
# 验证文件类型
|
||||
filename = file.filename or "unknown"
|
||||
file_ext = filename.lower().split('.')[-1] if '.' in filename else ''
|
||||
|
||||
if file_ext not in ['json', 'jsonl']:
|
||||
return StandardResponse.error("只支持 .json 和 .jsonl 格式的文件")
|
||||
|
||||
# 读取文件内容
|
||||
file_content = await file.read()
|
||||
|
||||
# 如果未提供描述,使用默认描述
|
||||
if not description:
|
||||
description = f"用户上传的数据集文件: {filename}"
|
||||
|
||||
# 使用文件上传服务上传文件
|
||||
uploaded_file = await file_upload_service.upload_file(
|
||||
file_content=file_content,
|
||||
original_filename=filename,
|
||||
content_type=file.content_type,
|
||||
description=description
|
||||
)
|
||||
|
||||
# 转换为前端期望的格式
|
||||
# 显示真实文件名(从映射文件中获取)
|
||||
mapping = file_upload_service.get_filename_mapping(uploaded_file.file_id)
|
||||
display_name = mapping["original_filename"] if mapping else uploaded_file.original_filename
|
||||
|
||||
# 格式化文件大小
|
||||
size_mb = round(uploaded_file.file_size / 1024 / 1024, 2)
|
||||
size_display = format_file_size(uploaded_file.file_size)
|
||||
|
||||
dataset_info = {
|
||||
"file_id": uploaded_file.file_id,
|
||||
"name": display_name,
|
||||
"size": uploaded_file.file_size,
|
||||
"size_mb": size_mb,
|
||||
"size_display": size_display,
|
||||
"status": "已处理", # 默认状态
|
||||
"uploaded_at": uploaded_file.uploaded_at,
|
||||
"description": uploaded_file.description
|
||||
}
|
||||
|
||||
return StandardResponse.success({
|
||||
"message": "数据集上传成功",
|
||||
"dataset": dataset_info
|
||||
})
|
||||
|
||||
except ValueError as e:
|
||||
return StandardResponse.error(str(e))
|
||||
except Exception as e:
|
||||
return StandardResponse.error(f"上传失败: {str(e)}")
|
||||
|
||||
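    # Illustrative request (assumes the app that mounts this router is served at
    # http://localhost:8000; host and port are not part of this module):
    #   curl -X POST http://localhost:8000/api/datasets/upload \
    #        -F "file=@train.jsonl" \
    #        -F "description=training split"
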
@get("", response_model=StandardResponse)
|
||||
async def list_datasets(self, list_all: bool = False):
|
||||
"""
|
||||
获取所有数据集列表
|
||||
|
||||
Args:
|
||||
list_all: 是否列出data目录下的所有文件(物理文件),默认False(只列出API上传的文件)
|
||||
|
||||
Returns:
|
||||
StandardResponse: 包含数据集列表的标准响应
|
||||
"""
|
||||
try:
|
||||
if list_all:
|
||||
# 列出data目录下的所有文件(物理文件)
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
data_dir = file_upload_service.upload_dir
|
||||
mapping_file = data_dir / "filename_mapping.json"
|
||||
|
||||
# 读取文件名映射
|
||||
mappings = {}
|
||||
if mapping_file.exists():
|
||||
try:
|
||||
with open(mapping_file, 'r', encoding='utf-8') as f:
|
||||
mapping_data = json.load(f)
|
||||
mappings = mapping_data.get("mappings", {})
|
||||
except Exception:
|
||||
mappings = {}
|
||||
|
||||
# 获取data目录下的所有JSON文件
|
||||
datasets = []
|
||||
if data_dir.exists():
|
||||
for file_path in data_dir.iterdir():
|
||||
# 跳过目录和映射文件本身
|
||||
if file_path.is_file() and file_path.name != "filename_mapping.json":
|
||||
file_id = file_path.stem # 去掉.json后缀得到file_id
|
||||
|
||||
# 从映射文件获取真实文件名
|
||||
mapping_info = mappings.get(file_id, {})
|
||||
original_filename = mapping_info.get("original_filename", file_path.name)
|
||||
uploaded_at = mapping_info.get("uploaded_at", "")
|
||||
|
||||
# 获取文件大小
|
||||
file_size = file_path.stat().st_size
|
||||
|
||||
# 格式化文件大小
|
||||
size_mb = round(file_size / 1024 / 1024, 2)
|
||||
size_display = format_file_size(file_size)
|
||||
|
||||
datasets.append({
|
||||
"file_id": file_id,
|
||||
"name": original_filename,
|
||||
"size": file_size,
|
||||
"size_mb": size_mb,
|
||||
"size_display": size_display,
|
||||
"status": "已处理",
|
||||
"description": mapping_info.get("original_filename", "") if mapping_info else "",
|
||||
"uploaded_at": uploaded_at,
|
||||
"download_count": 0,
|
||||
"is_physical_file": True
|
||||
})
|
||||
|
||||
# 按文件名排序
|
||||
datasets.sort(key=lambda x: x["name"])
|
||||
|
||||
return StandardResponse.success({
|
||||
"datasets": datasets,
|
||||
"total": len(datasets),
|
||||
"source": "physical_files"
|
||||
})
|
||||
else:
|
||||
# 获取所有文件(API上传的文件)
|
||||
all_files = file_upload_service.get_all_files()
|
||||
|
||||
# 转换为前端期望的格式
|
||||
datasets = []
|
||||
for uploaded_file in all_files:
|
||||
# 只返回JSON/JSONL文件(数据集文件)
|
||||
file_ext = uploaded_file.original_filename.lower().split('.')[-1] if '.' in uploaded_file.original_filename else ''
|
||||
if file_ext in ['json', 'jsonl']:
|
||||
# 获取文件名映射(显示真实文件名)
|
||||
mapping = file_upload_service.get_filename_mapping(uploaded_file.file_id)
|
||||
display_name = mapping["original_filename"] if mapping else uploaded_file.original_filename
|
||||
|
||||
# 格式化文件大小
|
||||
size_mb = round(uploaded_file.file_size / 1024 / 1024, 2)
|
||||
size_display = format_file_size(uploaded_file.file_size)
|
||||
|
||||
datasets.append({
|
||||
"file_id": uploaded_file.file_id,
|
||||
"name": display_name,
|
||||
"size": uploaded_file.file_size,
|
||||
"size_mb": size_mb,
|
||||
"size_display": size_display,
|
||||
"status": "已处理",
|
||||
"description": uploaded_file.description or "",
|
||||
"uploaded_at": uploaded_file.uploaded_at,
|
||||
"download_count": uploaded_file.download_count,
|
||||
"is_physical_file": False
|
||||
})
|
||||
|
||||
return StandardResponse.success({
|
||||
"datasets": datasets,
|
||||
"total": len(datasets),
|
||||
"source": "api_uploaded"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
return StandardResponse.error(f"获取数据集列表失败: {str(e)}")
|
||||
|
||||
@get("/{file_id}", response_model=StandardResponse)
|
||||
async def get_dataset(self, file_id: str):
|
||||
"""
|
||||
获取特定数据集的详细信息
|
||||
|
||||
Args:
|
||||
file_id: 文件ID
|
||||
|
||||
Returns:
|
||||
StandardResponse: 包含数据集详情的标准响应
|
||||
"""
|
||||
try:
|
||||
file_info = file_upload_service.get_file(file_id)
|
||||
|
||||
if not file_info:
|
||||
return StandardResponse.error(f"数据集 {file_id} 不存在")
|
||||
|
||||
# 转换为前端期望的格式
|
||||
# 显示真实文件名(从映射文件中获取)
|
||||
mapping = file_upload_service.get_filename_mapping(file_info.file_id)
|
||||
display_name = mapping["original_filename"] if mapping else file_info.original_filename
|
||||
|
||||
# 格式化文件大小
|
||||
size_mb = round(file_info.file_size / 1024 / 1024, 2)
|
||||
size_display = format_file_size(file_info.file_size)
|
||||
|
||||
dataset_info = {
|
||||
"file_id": file_info.file_id,
|
||||
"name": display_name,
|
||||
"size": file_info.file_size,
|
||||
"size_mb": size_mb,
|
||||
"size_display": size_display,
|
||||
"status": "已处理",
|
||||
"description": file_info.description or "",
|
||||
"uploaded_at": file_info.uploaded_at,
|
||||
"updated_at": file_info.updated_at,
|
||||
"download_count": file_info.download_count,
|
||||
"content_type": file_info.content_type,
|
||||
"file_hash": file_info.file_hash
|
||||
}
|
||||
|
||||
return StandardResponse.success(dataset_info)
|
||||
|
||||
except Exception as e:
|
||||
return StandardResponse.error(f"获取数据集详情失败: {str(e)}")
|
||||
|
||||
@get("/{file_id}", response_model=StandardResponse)
|
||||
async def get_dataset(self, file_id: str):
|
||||
"""
|
||||
获取特定数据集的详细信息
|
||||
|
||||
Args:
|
||||
file_id: 文件ID
|
||||
|
||||
Returns:
|
||||
StandardResponse: 包含数据集详情的标准响应
|
||||
"""
|
||||
try:
|
||||
file_info = file_upload_service.get_file(file_id)
|
||||
|
||||
if not file_info:
|
||||
return StandardResponse.error(f"数据集 {file_id} 不存在")
|
||||
|
||||
# 转换为前端期望的格式
|
||||
# 显示真实文件名(从映射文件中获取)
|
||||
mapping = file_upload_service.get_filename_mapping(file_info.file_id)
|
||||
display_name = mapping["original_filename"] if mapping else file_info.original_filename
|
||||
|
||||
# 格式化文件大小
|
||||
size_mb = round(file_info.file_size / 1024 / 1024, 2)
|
||||
size_display = format_file_size(file_info.file_size)
|
||||
|
||||
dataset_info = {
|
||||
"file_id": file_info.file_id,
|
||||
"name": display_name,
|
||||
"size": file_info.file_size,
|
||||
"size_mb": size_mb,
|
||||
"size_display": size_display,
|
||||
"status": "已处理",
|
||||
"description": file_info.description or "",
|
||||
"uploaded_at": file_info.uploaded_at,
|
||||
"updated_at": file_info.updated_at,
|
||||
"download_count": file_info.download_count,
|
||||
"content_type": file_info.content_type,
|
||||
"file_hash": file_info.file_hash
|
||||
}
|
||||
|
||||
return StandardResponse.success(dataset_info)
|
||||
|
||||
except Exception as e:
|
||||
return StandardResponse.error(f"获取数据集详情失败: {str(e)}")
|
||||
|
||||
@get("/{file_id}/content", response_model=StandardResponse)
|
||||
async def get_dataset_content(self, file_id: str, limit: int = 5):
|
||||
"""
|
||||
获取数据集文件内容(前N条记录)
|
||||
|
||||
Args:
|
||||
file_id: 文件ID
|
||||
limit: 返回的记录数量,默认5条
|
||||
|
||||
Returns:
|
||||
StandardResponse: 包含数据集内容的标准响应
|
||||
"""
|
||||
try:
|
||||
import json
|
||||
import jsonlines
|
||||
|
||||
# 获取文件信息
|
||||
file_info = file_upload_service.get_file(file_id)
|
||||
if not file_info:
|
||||
return StandardResponse.error(f"数据集 {file_id} 不存在")
|
||||
|
||||
# 获取文件路径
|
||||
file_path = file_upload_service.get_file_path(file_id)
|
||||
if not file_path or not file_path.exists():
|
||||
return StandardResponse.error(f"文件 {file_id} 不存在")
|
||||
|
||||
# 读取文件内容
|
||||
content_preview = []
|
||||
filename = file_info.original_filename.lower()
|
||||
|
||||
try:
|
||||
if filename.endswith('.jsonl'):
|
||||
# 处理JSONL格式
|
||||
with jsonlines.open(file_path) as reader:
|
||||
count = 0
|
||||
for item in reader:
|
||||
if count >= limit:
|
||||
break
|
||||
content_preview.append(item)
|
||||
count += 1
|
||||
else:
|
||||
# 处理JSON格式
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
if isinstance(data, list):
|
||||
# 如果是数组,取前N条
|
||||
content_preview = data[:limit]
|
||||
else:
|
||||
# 如果是对象,直接返回
|
||||
content_preview = data
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
return StandardResponse.error(f"JSON文件格式错误: {str(e)}")
|
||||
except Exception as e:
|
||||
return StandardResponse.error(f"读取文件内容失败: {str(e)}")
|
||||
|
||||
# 获取真实文件名(从映射文件中获取)
|
||||
mapping = file_upload_service.get_filename_mapping(file_id)
|
||||
display_filename = mapping["original_filename"] if mapping else file_info.original_filename
|
||||
|
||||
return StandardResponse.success({
|
||||
"file_id": file_id,
|
||||
"filename": display_filename,
|
||||
"total_records": len(content_preview),
|
||||
"preview": content_preview
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
return StandardResponse.error(f"获取数据集内容失败: {str(e)}")
|
||||
|
||||
@delete("/{file_id}", response_model=StandardResponse)
|
||||
async def delete_dataset(self, file_id: str):
|
||||
"""
|
||||
删除数据集
|
||||
|
||||
Args:
|
||||
file_id: 文件ID
|
||||
|
||||
Returns:
|
||||
StandardResponse: 包含删除结果的标准响应
|
||||
"""
|
||||
try:
|
||||
if not file_upload_service.file_exists(file_id):
|
||||
return StandardResponse.error(f"数据集 {file_id} 不存在")
|
||||
|
||||
success = file_upload_service.delete_file(file_id)
|
||||
|
||||
if success:
|
||||
return StandardResponse.success({
|
||||
"message": f"数据集 {file_id} 已删除"
|
||||
})
|
||||
else:
|
||||
return StandardResponse.error(f"删除数据集 {file_id} 失败")
|
||||
|
||||
except Exception as e:
|
||||
return StandardResponse.error(f"删除数据集失败: {str(e)}")
|
||||
|
||||
@get("/list-files", response_model=StandardResponse)
|
||||
async def list_data_files(self):
|
||||
"""
|
||||
查询data目录下的文件列表
|
||||
|
||||
Returns:
|
||||
StandardResponse: 包含文件列表的标准响应
|
||||
"""
|
||||
try:
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
data_dir = file_upload_service.upload_dir
|
||||
mapping_file = data_dir / "filename_mapping.json"
|
||||
|
||||
# 读取文件名映射
|
||||
mappings = {}
|
||||
if mapping_file.exists():
|
||||
try:
|
||||
with open(mapping_file, 'r', encoding='utf-8') as f:
|
||||
mapping_data = json.load(f)
|
||||
mappings = mapping_data.get("mappings", {})
|
||||
except Exception:
|
||||
mappings = {}
|
||||
|
||||
# 获取data目录下的所有JSON文件
|
||||
files_info = []
|
||||
if data_dir.exists():
|
||||
for file_path in data_dir.iterdir():
|
||||
# 跳过目录和映射文件本身
|
||||
if file_path.is_file() and file_path.name != "filename_mapping.json":
|
||||
file_id = file_path.stem # 去掉.json后缀得到file_id
|
||||
|
||||
# 从映射文件获取真实文件名
|
||||
mapping_info = mappings.get(file_id, {})
|
||||
original_filename = mapping_info.get("original_filename", file_path.name)
|
||||
uploaded_at = mapping_info.get("uploaded_at", "")
|
||||
|
||||
# 获取文件大小
|
||||
file_size = file_path.stat().st_size
|
||||
|
||||
files_info.append({
|
||||
"file_id": file_id,
|
||||
"original_filename": original_filename,
|
||||
"storage_filename": file_path.name,
|
||||
"file_path": str(file_path),
|
||||
"file_size": file_size,
|
||||
"file_size_mb": round(file_size / 1024 / 1024, 2),
|
||||
"uploaded_at": uploaded_at,
|
||||
"exists_in_mapping": file_id in mappings
|
||||
})
|
||||
|
||||
# 按文件名排序
|
||||
files_info.sort(key=lambda x: x["original_filename"])
|
||||
|
||||
return StandardResponse.success({
|
||||
"total": len(files_info),
|
||||
"files": files_info
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
return StandardResponse.error(f"查询文件列表失败: {str(e)}")
|
||||
|
||||
|
||||

# Create the module-level instance (picked up by the auto-discovery system).
dataset_api = DatasetAPI()
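For reference, a minimal client-side sketch of how these endpoints could be exercised. It assumes the FastAPI app that mounts this router is reachable at http://localhost:8000, that the `requests` package is installed, and that `StandardResponse.success(...)` nests its payload under a "data" key; none of these assumptions come from this commit.

# dataset_api_client_example.py -- hypothetical smoke test, not part of this commit.
import requests

BASE_URL = "http://localhost:8000/api/datasets"  # assumed host and port

# List the datasets that were uploaded through the API.
listing = requests.get(BASE_URL, timeout=10).json()
datasets = listing.get("data", {}).get("datasets", [])  # assumed response envelope
print(f"{len(datasets)} dataset(s) found")

if datasets:
    file_id = datasets[0]["file_id"]

    # Preview the first three records of the first dataset.
    preview = requests.get(f"{BASE_URL}/{file_id}/content", params={"limit": 3}, timeout=10).json()
    print(preview)

    # Delete it again.
    print(requests.delete(f"{BASE_URL}/{file_id}", timeout=10).json())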