""" 数据集管理 API 模块 提供数据集上传、列表、删除等功能 """ from typing import List, Optional from fastapi import UploadFile, File, Form from src.api.internal.base import BaseAPI, post, get, delete from src.models.response import StandardResponse from src.services.file_upload import file_upload_service def format_file_size(size_bytes: int) -> str: """ 格式化文件大小显示 Args: size_bytes: 文件大小(字节) Returns: str: 格式化后的文件大小字符串 """ if size_bytes < 1024: return f"{size_bytes} B" elif size_bytes < 1024 * 1024: size_kb = round(size_bytes / 1024, 2) return f"{size_kb} KB" else: size_mb = round(size_bytes / 1024 / 1024, 2) return f"{size_mb} MB" class DatasetAPI(BaseAPI): """数据集管理 API""" def __init__(self): """初始化""" # 在调用super().__init__()之前设置module_name self._override_module_name = "datasets" super().__init__() self.logger.info("DatasetAPI 初始化完成") @post("/upload", response_model=StandardResponse) async def upload_dataset( self, file: UploadFile = File(...), description: Optional[str] = Form(None) ): """ 上传数据集文件 Args: file: 上传的文件(支持 .json, .jsonl 格式) description: 文件描述(可选) Returns: StandardResponse: 包含上传结果的标准响应 """ try: # 验证文件类型 filename = file.filename or "unknown" file_ext = filename.lower().split('.')[-1] if '.' in filename else '' if file_ext not in ['json', 'jsonl']: return StandardResponse.error("只支持 .json 和 .jsonl 格式的文件") # 读取文件内容 file_content = await file.read() # 如果未提供描述,使用默认描述 if not description: description = f"用户上传的数据集文件: {filename}" # 使用文件上传服务上传文件 uploaded_file = await file_upload_service.upload_file( file_content=file_content, original_filename=filename, content_type=file.content_type, description=description ) # 转换为前端期望的格式 # 显示真实文件名(从映射文件中获取) mapping = file_upload_service.get_filename_mapping(uploaded_file.file_id) display_name = mapping["original_filename"] if mapping else uploaded_file.original_filename # 格式化文件大小 size_mb = round(uploaded_file.file_size / 1024 / 1024, 2) size_display = format_file_size(uploaded_file.file_size) dataset_info = { "file_id": uploaded_file.file_id, "name": display_name, "size": uploaded_file.file_size, "size_mb": size_mb, "size_display": size_display, "status": "已处理", # 默认状态 "uploaded_at": uploaded_file.uploaded_at, "description": uploaded_file.description } return StandardResponse.success({ "message": "数据集上传成功", "dataset": dataset_info }) except ValueError as e: return StandardResponse.error(str(e)) except Exception as e: return StandardResponse.error(f"上传失败: {str(e)}") @get("", response_model=StandardResponse) async def list_datasets(self, list_all: bool = False): """ 获取所有数据集列表 Args: list_all: 是否列出data目录下的所有文件(物理文件),默认False(只列出API上传的文件) Returns: StandardResponse: 包含数据集列表的标准响应 """ # 添加调试日志 import logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) logger.info(f"list_datasets called with list_all={list_all}") try: if list_all: # 列出data目录下的所有文件(物理文件) import json from pathlib import Path data_dir = file_upload_service.upload_dir mapping_file = data_dir / "filename_mapping.json" # 读取文件名映射 mappings = {} if mapping_file.exists(): try: with open(mapping_file, 'r', encoding='utf-8') as f: mapping_data = json.load(f) mappings = mapping_data.get("mappings", {}) except Exception: mappings = {} # 获取data目录下的所有JSON文件 datasets = [] if data_dir.exists(): for file_path in data_dir.iterdir(): # 跳过目录和映射文件本身 if file_path.is_file() and file_path.name != "filename_mapping.json": file_id = file_path.stem # 去掉.json后缀得到file_id # 从映射文件获取真实文件名 mapping_info = mappings.get(file_id, {}) original_filename = mapping_info.get("original_filename", file_path.name) uploaded_at = 
mapping_info.get("uploaded_at", "") # 获取文件大小 file_size = file_path.stat().st_size # 格式化文件大小 size_mb = round(file_size / 1024 / 1024, 2) size_display = format_file_size(file_size) datasets.append({ "file_id": file_id, "name": original_filename, "size": file_size, "size_mb": size_mb, "size_display": size_display, "status": "已处理", "description": mapping_info.get("original_filename", "") if mapping_info else "", "uploaded_at": uploaded_at, "download_count": 0, "is_physical_file": True }) # 按文件名排序 datasets.sort(key=lambda x: x["name"]) return StandardResponse.success({ "datasets": datasets, "total": len(datasets), "source": "physical_files" }) else: # 获取所有文件(API上传的文件) all_files = file_upload_service.get_all_files() # 转换为前端期望的格式 datasets = [] for uploaded_file in all_files: # 只返回JSON/JSONL文件(数据集文件) file_ext = uploaded_file.original_filename.lower().split('.')[-1] if '.' in uploaded_file.original_filename else '' if file_ext in ['json', 'jsonl']: # 获取文件名映射(显示真实文件名) mapping = file_upload_service.get_filename_mapping(uploaded_file.file_id) display_name = mapping["original_filename"] if mapping else uploaded_file.original_filename # 格式化文件大小 size_mb = round(uploaded_file.file_size / 1024 / 1024, 2) size_display = format_file_size(uploaded_file.file_size) datasets.append({ "file_id": uploaded_file.file_id, "name": display_name, "size": uploaded_file.file_size, "size_mb": size_mb, "size_display": size_display, "status": "已处理", "description": uploaded_file.description or "", "uploaded_at": uploaded_file.uploaded_at, "download_count": uploaded_file.download_count, "is_physical_file": False }) return StandardResponse.success({ "datasets": datasets, "total": len(datasets), "source": "api_uploaded" }) except Exception as e: return StandardResponse.error(f"获取数据集列表失败: {str(e)}") @get("/{file_id}", response_model=StandardResponse) async def get_dataset(self, file_id: str): """ 获取特定数据集的详细信息 Args: file_id: 文件ID Returns: StandardResponse: 包含数据集详情的标准响应 """ try: file_info = file_upload_service.get_file(file_id) if not file_info: return StandardResponse.error(f"数据集 {file_id} 不存在") # 转换为前端期望的格式 # 显示真实文件名(从映射文件中获取) mapping = file_upload_service.get_filename_mapping(file_info.file_id) display_name = mapping["original_filename"] if mapping else file_info.original_filename # 格式化文件大小 size_mb = round(file_info.file_size / 1024 / 1024, 2) size_display = format_file_size(file_info.file_size) dataset_info = { "file_id": file_info.file_id, "name": display_name, "size": file_info.file_size, "size_mb": size_mb, "size_display": size_display, "status": "已处理", "description": file_info.description or "", "uploaded_at": file_info.uploaded_at, "updated_at": file_info.updated_at, "download_count": file_info.download_count, "content_type": file_info.content_type, "file_hash": file_info.file_hash } return StandardResponse.success(dataset_info) except Exception as e: return StandardResponse.error(f"获取数据集详情失败: {str(e)}") @get("/{file_id}", response_model=StandardResponse) async def get_dataset(self, file_id: str): """ 获取特定数据集的详细信息 Args: file_id: 文件ID Returns: StandardResponse: 包含数据集详情的标准响应 """ try: file_info = file_upload_service.get_file(file_id) if not file_info: return StandardResponse.error(f"数据集 {file_id} 不存在") # 转换为前端期望的格式 # 显示真实文件名(从映射文件中获取) mapping = file_upload_service.get_filename_mapping(file_info.file_id) display_name = mapping["original_filename"] if mapping else file_info.original_filename # 格式化文件大小 size_mb = round(file_info.file_size / 1024 / 1024, 2) size_display = format_file_size(file_info.file_size) dataset_info = { 
"file_id": file_info.file_id, "name": display_name, "size": file_info.file_size, "size_mb": size_mb, "size_display": size_display, "status": "已处理", "description": file_info.description or "", "uploaded_at": file_info.uploaded_at, "updated_at": file_info.updated_at, "download_count": file_info.download_count, "content_type": file_info.content_type, "file_hash": file_info.file_hash } return StandardResponse.success(dataset_info) except Exception as e: return StandardResponse.error(f"获取数据集详情失败: {str(e)}") @get("/{file_id}/content", response_model=StandardResponse) async def get_dataset_content(self, file_id: str, limit: int = 5): """ 获取数据集文件内容(前N条记录) Args: file_id: 文件ID limit: 返回的记录数量,默认5条 Returns: StandardResponse: 包含数据集内容的标准响应 """ try: import json import jsonlines # 获取文件信息 file_info = file_upload_service.get_file(file_id) if not file_info: return StandardResponse.error(f"数据集 {file_id} 不存在") # 获取文件路径 file_path = file_upload_service.get_file_path(file_id) if not file_path or not file_path.exists(): return StandardResponse.error(f"文件 {file_id} 不存在") # 读取文件内容 content_preview = [] filename = file_info.original_filename.lower() try: if filename.endswith('.jsonl'): # 处理JSONL格式 with jsonlines.open(file_path) as reader: count = 0 for item in reader: if count >= limit: break content_preview.append(item) count += 1 else: # 处理JSON格式 with open(file_path, 'r', encoding='utf-8') as f: data = json.load(f) if isinstance(data, list): # 如果是数组,取前N条 content_preview = data[:limit] else: # 如果是对象,直接返回 content_preview = data except json.JSONDecodeError as e: return StandardResponse.error(f"JSON文件格式错误: {str(e)}") except Exception as e: return StandardResponse.error(f"读取文件内容失败: {str(e)}") # 获取真实文件名(从映射文件中获取) mapping = file_upload_service.get_filename_mapping(file_id) display_filename = mapping["original_filename"] if mapping else file_info.original_filename return StandardResponse.success({ "file_id": file_id, "filename": display_filename, "total_records": len(content_preview), "preview": content_preview }) except Exception as e: return StandardResponse.error(f"获取数据集内容失败: {str(e)}") @delete("/{file_id}", response_model=StandardResponse) async def delete_dataset(self, file_id: str): """ 删除数据集 Args: file_id: 文件ID Returns: StandardResponse: 包含删除结果的标准响应 """ try: if not file_upload_service.file_exists(file_id): return StandardResponse.error(f"数据集 {file_id} 不存在") success = file_upload_service.delete_file(file_id) if success: return StandardResponse.success({ "message": f"数据集 {file_id} 已删除" }) else: return StandardResponse.error(f"删除数据集 {file_id} 失败") except Exception as e: return StandardResponse.error(f"删除数据集失败: {str(e)}") @get("/list-files", response_model=StandardResponse) async def list_data_files(self): """ 查询data目录下的文件列表 Returns: StandardResponse: 包含文件列表的标准响应 """ try: import json import os from pathlib import Path data_dir = file_upload_service.upload_dir mapping_file = data_dir / "filename_mapping.json" # 读取文件名映射 mappings = {} if mapping_file.exists(): try: with open(mapping_file, 'r', encoding='utf-8') as f: mapping_data = json.load(f) mappings = mapping_data.get("mappings", {}) except Exception: mappings = {} # 获取data目录下的所有JSON文件 files_info = [] if data_dir.exists(): for file_path in data_dir.iterdir(): # 跳过目录和映射文件本身 if file_path.is_file() and file_path.name != "filename_mapping.json": file_id = file_path.stem # 去掉.json后缀得到file_id # 从映射文件获取真实文件名 mapping_info = mappings.get(file_id, {}) original_filename = mapping_info.get("original_filename", file_path.name) uploaded_at = mapping_info.get("uploaded_at", "") # 

    @get("/list-files", response_model=StandardResponse)
    async def list_data_files(self):
        """
        List the files in the data directory.

        Returns:
            StandardResponse: Standard response containing the file list
        """
        try:
            import json

            data_dir = file_upload_service.upload_dir
            mapping_file = data_dir / "filename_mapping.json"

            # Read the filename mapping
            mappings = {}
            if mapping_file.exists():
                try:
                    with open(mapping_file, 'r', encoding='utf-8') as f:
                        mapping_data = json.load(f)
                        mappings = mapping_data.get("mappings", {})
                except Exception:
                    mappings = {}

            # Collect all JSON files in the data directory
            files_info = []
            if data_dir.exists():
                for file_path in data_dir.iterdir():
                    # Skip directories and the mapping file itself
                    if file_path.is_file() and file_path.name != "filename_mapping.json":
                        file_id = file_path.stem  # file_id is the name without the .json suffix

                        # Look up the real filename in the mapping
                        mapping_info = mappings.get(file_id, {})
                        original_filename = mapping_info.get("original_filename", file_path.name)
                        uploaded_at = mapping_info.get("uploaded_at", "")

                        # File size
                        file_size = file_path.stat().st_size

                        files_info.append({
                            "file_id": file_id,
                            "original_filename": original_filename,
                            "storage_filename": file_path.name,
                            "file_path": str(file_path),
                            "file_size": file_size,
                            "file_size_mb": round(file_size / 1024 / 1024, 2),
                            "uploaded_at": uploaded_at,
                            "exists_in_mapping": file_id in mappings
                        })

            # Sort by filename
            files_info.sort(key=lambda x: x["original_filename"])

            return StandardResponse.success({
                "total": len(files_info),
                "files": files_info
            })

        except Exception as e:
            return StandardResponse.error(f"查询文件列表失败: {str(e)}")


# Create the instance (picked up by the auto-discovery system)
dataset_api = DatasetAPI()
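

# Minimal usage sketch (illustrative only, not part of the module's API): how a client
# might exercise these endpoints over HTTP. It assumes the service is running locally
# and that the auto-discovery system mounts this module under a "/datasets" prefix
# derived from the "datasets" module name; adjust BASE_URL for your deployment.
# `httpx` is used purely for illustration and is not a dependency of this module;
# the exact response envelope depends on StandardResponse.
if __name__ == "__main__":
    import httpx

    BASE_URL = "http://localhost:8000/datasets"  # assumed host and route prefix

    # Upload a small JSONL dataset
    payload = b'{"prompt": "hi", "completion": "hello"}\n'
    resp = httpx.post(
        f"{BASE_URL}/upload",
        files={"file": ("example.jsonl", payload, "application/jsonl")},
        data={"description": "smoke-test dataset"},
    )
    print(resp.json())

    # List datasets uploaded through the API
    print(httpx.get(BASE_URL).json())

    # Preview or delete a dataset by its file_id (taken from the upload response):
    # file_id = "..."
    # print(httpx.get(f"{BASE_URL}/{file_id}/content", params={"limit": 2}).json())
    # print(httpx.delete(f"{BASE_URL}/{file_id}").json())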