import json
import os
import random
import re
import time
import uuid
from typing import List, Optional

import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from pydantic import BaseModel

app = FastAPI(title="大模型微调平台 API", version="1.0.0")

# Uploaded dataset files are stored in a `data` directory located one level
# above the directory that contains this module.
DATA_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'data')

# Upload constraints enforced by /api/datasets/upload.
ALLOWED_EXTENSIONS = ['.json', '.jsonl']
MAX_UPLOAD_SIZE = 100 * 1024 * 1024  # 100MB
# Uploads are read in 1MB chunks so an oversized file is never held in memory whole.
UPLOAD_CHUNK_SIZE = 1024 * 1024


# ---- Request models ----

class UserModel(BaseModel):
    """Login credentials."""
    username: str
    password: str


class DatasetModel(BaseModel):
    """Request body for creating a dataset record."""
    name: str
    description: Optional[str] = None
    size: str


class ModelConfigModel(BaseModel):
    """Training hyper-parameters submitted for a model."""
    model_name: str
    learning_rate: float
    batch_size: int
    epochs: int


# ---- Response model ----

class ResponseModel(BaseModel):
    """Uniform response envelope: application-level code, message, optional payload."""
    code: int
    message: str
    data: Optional[dict] = None


# ---- Mock in-memory storage (process-local, lost on restart) ----

datasets = [
    {"id": 1, "name": "中文对话数据集", "size": "1.2GB", "status": "已处理"},
    {"id": 2, "name": "英文文本分类数据集", "size": "856MB", "status": "处理中"},
    {"id": 3, "name": "图像识别数据集", "size": "2.5GB", "status": "待处理"},
]

models = [
    {"id": 1, "name": "GPT-4", "status": "训练中", "accuracy": "92%"},
    {"id": 2, "name": "BERT", "status": "已完成", "accuracy": "89%"},
    {"id": 3, "name": "LLaMA", "status": "已完成", "accuracy": "95%"},
]


# ---- Helpers ----

def format_size(size_bytes):
    """Format a byte count as ``"<n>B"``, ``"<x.xx>KB"`` or ``"<x.xx>MB"``."""
    if size_bytes < 1024:
        return f"{size_bytes}B"
    elif size_bytes < 1024 * 1024:
        return f"{size_bytes / 1024:.2f}KB"
    else:
        return f"{size_bytes / (1024*1024):.2f}MB"


def _next_dataset_id():
    """Return one more than the largest existing dataset id (1 when empty).

    ``len(datasets) + 1`` is not safe: after a deletion it can hand out an id
    that is still in use (delete id 1 from ids [1, 2, 3] -> next id 3 collides).
    """
    return max((d["id"] for d in datasets), default=0) + 1


async def _read_upload(file: UploadFile, max_size: int):
    """Read *file* in chunks, buffering at most *max_size* bytes.

    Returns ``(contents, total_size)``. When ``total_size > max_size`` the
    returned contents are truncated and must not be used; the stream is
    still consumed to the end so ``total_size`` is exact for error reporting.
    """
    buffer = bytearray()
    total_size = 0
    while True:
        chunk = await file.read(UPLOAD_CHUNK_SIZE)
        if not chunk:
            break
        total_size += len(chunk)
        if total_size <= max_size:
            buffer.extend(chunk)
    return bytes(buffer), total_size


def _validate_json_text(text: str, file_extension: str) -> None:
    """Check that *text* is valid JSON (``.json``) or JSON Lines (``.jsonl``).

    Raises:
        HTTPException: 400 describing the first syntax error found.
    """
    if file_extension == '.json':
        try:
            json.loads(text)
        except json.JSONDecodeError:
            raise HTTPException(
                status_code=400,
                detail="JSON 文件格式错误:文件内容不是有效的 JSON 格式"
            ) from None
    elif file_extension == '.jsonl':
        # Number lines on the unstripped text so the reported line number
        # matches what the user sees in their editor; blank lines are allowed.
        for line_no, line in enumerate(text.split('\n'), start=1):
            if not line.strip():
                continue
            try:
                json.loads(line)
            except json.JSONDecodeError:
                raise HTTPException(
                    status_code=400,
                    detail=f"JSONL 文件格式错误:第 {line_no} 行不是有效的 JSON 格式"
                ) from None


def _save_upload(contents: bytes, original_filename: str, file_extension: str) -> str:
    """Write *contents* to a new, uniquely named file in DATA_DIR; return its path.

    Only the last path component of the client-supplied name is kept (``/``
    and ``\\`` are both treated as separators), so names like
    ``../../x.jsonl`` cannot escape DATA_DIR. A timestamp plus a random
    suffix keeps two same-named uploads in the same second from colliding,
    and mode ``'xb'`` refuses to overwrite an existing file.

    Raises:
        OSError: if the directory or file cannot be created/written.
    """
    os.makedirs(DATA_DIR, exist_ok=True)
    safe_name = os.path.basename(original_filename.replace('\\', '/'))
    base_name = os.path.splitext(safe_name)[0]
    saved_filename = f"{base_name}_{int(time.time())}_{uuid.uuid4().hex[:8]}{file_extension}"
    saved_path = os.path.join(DATA_DIR, saved_filename)
    # NOTE(review): blocking file I/O inside an async request path; consider
    # offloading to a thread pool if uploads become large or frequent.
    with open(saved_path, 'xb') as f:
        f.write(contents)
    return saved_path


# ---- Endpoints ----

@app.get("/")
async def root():
    """Root endpoint: return a short service banner."""
    return {"message": "大模型微调平台 API 服务"}


@app.get("/api/health")
async def health_check():
    """Health check."""
    return ResponseModel(code=200, message="服务运行正常", data={"status": "healthy"})


@app.post("/api/login", response_model=ResponseModel)
async def login(user: UserModel):
    """Log a user in (mock implementation).

    A failed login is reported in the response body (``code=401``); the
    HTTP status itself is not changed.
    """
    # NOTE(review): mock authentication. For "admin" the password is only
    # checked for being non-empty, and the token is a fixed string. Replace
    # with real credential verification before exposing this service.
    if user.username == "admin" and user.password:
        return ResponseModel(
            code=200,
            message="登录成功",
            data={"token": "mock_token_12345", "user": user.username}
        )
    return ResponseModel(code=401, message="用户名或密码错误")


@app.get("/api/datasets", response_model=ResponseModel)
async def get_datasets():
    """Return all dataset records from the in-memory store."""
    return ResponseModel(code=200, message="获取成功", data={"datasets": datasets})


@app.post("/api/datasets", response_model=ResponseModel)
async def create_dataset(dataset: DatasetModel):
    """Create a new dataset record in the in-memory store.

    The record always starts with size "0MB" and status "待处理". The
    request's ``size`` field is currently not used.
    """
    new_dataset = {
        "id": _next_dataset_id(),
        "name": dataset.name,
        "description": dataset.description,
        "size": "0MB",
        "status": "待处理"
    }
    datasets.append(new_dataset)
    return ResponseModel(code=201, message="创建成功", data={"dataset": new_dataset})


@app.post("/api/datasets/upload", response_model=ResponseModel)
async def upload_dataset(file: UploadFile = File(...), description: Optional[str] = None):
    """Upload a dataset file (only JSON and JSONL are accepted).

    The handler checks the extension, the size limit (100MB), UTF-8
    decoding, and JSON/JSONL syntax, in that order. It then saves the file
    under DATA_DIR and appends a dataset record.

    Raises:
        HTTPException: 400 for an unsupported file type, an oversized file,
            non-UTF-8 content or malformed JSON/JSONL; 500 if saving fails.
    """
    # UploadFile.filename may be None if the client omits it; treat that as
    # an unsupported type instead of crashing in os.path.splitext.
    filename = file.filename or ""
    file_extension = os.path.splitext(filename)[1].lower()
    if file_extension not in ALLOWED_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=f"不支持的文件类型。只能上传 {', '.join(ALLOWED_EXTENSIONS)} 格式的文件"
        )

    contents, file_size = await _read_upload(file, MAX_UPLOAD_SIZE)
    if file_size > MAX_UPLOAD_SIZE:
        raise HTTPException(
            status_code=400,
            detail=f"文件大小超过限制。最大支持 100MB,当前文件大小: {file_size / (1024*1024):.2f}MB"
        )

    # Decode once; the text is reused for validation and line counting.
    try:
        text = contents.decode('utf-8')
    except UnicodeDecodeError:
        raise HTTPException(
            status_code=400,
            detail="文件编码错误:请确保文件使用 UTF-8 编码"
        ) from None

    # Raises HTTPException(400) directly; nothing below re-wraps it as a 500.
    _validate_json_text(text, file_extension)

    size_str = format_size(file_size)
    # Count non-blank lines; for JSONL this equals the number of records.
    lines_count = sum(1 for line in text.split('\n') if line.strip())

    try:
        saved_path = _save_upload(contents, filename, file_extension)
    except OSError as e:
        raise HTTPException(
            status_code=500,
            detail=f"文件处理错误:{str(e)}"
        ) from e

    new_dataset = {
        "id": _next_dataset_id(),
        "name": file.filename,
        "description": description or f"上传的数据集文件,包含 {lines_count} 行数据",
        "size": size_str,
        "status": "已处理",
        "upload_time": "刚刚",
        "file_extension": file_extension,
        "records_count": lines_count,
        # NOTE(review): exposes the absolute server-side path to clients.
        "saved_path": saved_path
    }
    datasets.append(new_dataset)

    return ResponseModel(
        code=200,
        message="文件上传成功",
        data={
            "dataset": new_dataset,
            "file_info": {
                "filename": file.filename,
                "size": size_str,
                "extension": file_extension,
                "records": lines_count
            }
        }
    )


@app.get("/api/datasets/files", response_model=ResponseModel)
async def list_dataset_files():
    """List the regular files saved in DATA_DIR, newest modification first.

    Raises:
        HTTPException: 500 if the directory cannot be read.
    """
    if not os.path.exists(DATA_DIR):
        return ResponseModel(
            code=200,
            message="获取成功",
            data={"files": [], "total": 0, "directory": DATA_DIR}
        )

    entries = []
    try:
        with os.scandir(DATA_DIR) as it:
            for entry in it:
                try:
                    if not entry.is_file():
                        continue
                    stat = entry.stat()
                except FileNotFoundError:
                    # The file vanished between listing and stat; skip it
                    # rather than failing the whole listing.
                    continue
                entries.append((stat.st_mtime, {
                    "filename": entry.name,
                    "size": stat.st_size,
                    "size_human": format_size(stat.st_size),
                    "modified_time": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(stat.st_mtime)),
                    "path": entry.path
                }))
    except OSError as e:
        raise HTTPException(
            status_code=500,
            detail=f"获取文件列表失败:{str(e)}"
        ) from e

    # Sort on the raw mtime, not the formatted string, so files modified
    # within the same second keep their true order.
    entries.sort(key=lambda item: item[0], reverse=True)
    files = [info for _, info in entries]

    return ResponseModel(
        code=200,
        message="获取成功",
        data={
            "files": files,
            "total": len(files),
            "directory": DATA_DIR
        }
    )


@app.delete("/api/datasets/{dataset_id}", response_model=ResponseModel)
async def delete_dataset(dataset_id: int):
    """Remove the record with *dataset_id* from the in-memory store.

    Only the record is removed; a file referenced by its ``saved_path`` is
    left on disk.

    Raises:
        HTTPException: 404 if no record has that id.
    """
    for index, dataset in enumerate(datasets):
        if dataset["id"] == dataset_id:
            deleted_dataset = datasets.pop(index)
            return ResponseModel(
                code=200,
                message="删除成功",
                data={"deleted_dataset": deleted_dataset}
            )
    raise HTTPException(status_code=404, detail="数据集不存在")


@app.get("/api/models", response_model=ResponseModel)
async def get_models():
    """Return all model records from the in-memory store."""
    return ResponseModel(code=200, message="获取成功", data={"models": models})


@app.post("/api/models/config", response_model=ResponseModel)
async def config_model(config: ModelConfigModel):
    """Echo back the submitted model configuration, marked as configured (nothing is persisted)."""
    return ResponseModel(
        code=200,
        message="配置成功",
        data={
            "model_name": config.model_name,
            "learning_rate": config.learning_rate,
            "batch_size": config.batch_size,
            "epochs": config.epochs,
            "status": "已配置"
        }
    )


@app.get("/api/training/status")
async def get_training_status():
    """Return a fixed, mock training-status snapshot."""
    return ResponseModel(
        code=200,
        message="获取成功",
        data={
            "current_task": "GPT-4微调",
            "progress": 75,
            "eta": "2小时",
            "loss": 0.23,
            "accuracy": 0.89
        }
    )


@app.get("/api/system/stats")
async def get_system_stats():
    """Return mock system statistics; the usage percentages are random."""
    return ResponseModel(
        code=200,
        message="获取成功",
        data={
            "cpu_usage": random.randint(30, 80),
            "memory_usage": random.randint(40, 70),
            "gpu_usage": random.randint(50, 90),
            "active_tasks": 5,
            "completed_tasks": 158
        }
    )


@app.post("/api/training/start")
async def start_training(model_name: str, dataset_id: int):
    """Pretend to start a training task (mock).

    The task id is a random 4-digit number and is not guaranteed to be unique.
    """
    return ResponseModel(
        code=200,
        message="训练任务已启动",
        data={
            "task_id": random.randint(1000, 9999),
            "model_name": model_name,
            "dataset_id": dataset_id,
            "status": "running"
        }
    )


@app.post("/api/training/stop/{task_id}")
async def stop_training(task_id: int):
    """Pretend to stop training task *task_id* (mock; always succeeds)."""
    return ResponseModel(
        code=200,
        message=f"训练任务 {task_id} 已停止",
        data={"task_id": task_id, "status": "stopped"}
    )


@app.get("/api/model/{model_id}/metrics")
async def get_model_metrics(model_id: int):
    """Return mock evaluation metrics for *model_id*; every value is random."""
    return ResponseModel(
        code=200,
        message="获取成功",
        data={
            "model_id": model_id,
            "accuracy": round(random.uniform(0.85, 0.98), 3),
            "precision": round(random.uniform(0.80, 0.95), 3),
            "recall": round(random.uniform(0.82, 0.96), 3),
            "f1_score": round(random.uniform(0.83, 0.97), 3),
            "training_time": f"{random.randint(2, 24)}小时",
            "parameters": random.randint(1000000, 100000000)
        }
    )


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8001)