新分支,重新设计UI
This commit is contained in:
129
src/README.md
129
src/README.md
@@ -1,129 +0,0 @@
|
||||
# FastAPI 服务器
|
||||
|
||||
## 功能特性
|
||||
|
||||
这个 FastAPI 服务器为大模型微调平台提供了 RESTful API 接口。
|
||||
|
||||
## API 端点
|
||||
|
||||
### 基础信息
|
||||
- `GET /` - 根路径,返回欢迎信息
|
||||
- `GET /api/health` - 健康检查
|
||||
|
||||
### 用户认证
|
||||
- `POST /api/login` - 用户登录
|
||||
```json
|
||||
{
|
||||
"username": "admin",
|
||||
"password": "your_password"
|
||||
}
|
||||
```
|
||||
|
||||
### 数据集管理
|
||||
- `GET /api/datasets` - 获取数据集列表
|
||||
- `POST /api/datasets` - 创建新数据集
|
||||
```json
|
||||
{
|
||||
"name": "新数据集名称",
|
||||
"description": "数据集描述",
|
||||
"size": "数据集大小"
|
||||
}
|
||||
```
|
||||
- `POST /api/datasets/upload` - 上传数据集文件(支持 JSON 和 JSONL 格式)
|
||||
```bash
|
||||
curl -X POST "http://10.10.10.77:8001/api/datasets/upload" \
|
||||
-F "file=@dataset.json" \
|
||||
-F "description=数据集描述"
|
||||
```
|
||||
**支持的文件格式**: .json, .jsonl
|
||||
**文件大小限制**: 100MB
|
||||
- `GET /api/datasets/files` - 获取data目录中保存的文件列表
|
||||
- `DELETE /api/datasets/{dataset_id}` - 删除数据集
|
||||
|
||||
### 模型管理
|
||||
- `GET /api/models` - 获取模型列表
|
||||
- `POST /api/models/config` - 配置模型参数
|
||||
```json
|
||||
{
|
||||
"model_name": "GPT-4",
|
||||
"learning_rate": 0.001,
|
||||
"batch_size": 32,
|
||||
"epochs": 100
|
||||
}
|
||||
```
|
||||
|
||||
### 训练管理
|
||||
- `GET /api/training/status` - 获取训练状态
|
||||
- `POST /api/training/start` - 开始训练任务
|
||||
- `POST /api/training/stop/{task_id}` - 停止训练任务
|
||||
- `GET /api/model/{model_id}/metrics` - 获取模型指标
|
||||
|
||||
### 系统监控
|
||||
- `GET /api/system/stats` - 获取系统统计信息
|
||||
|
||||
## 启动服务器
|
||||
|
||||
### 方法 1: 使用启动脚本(推荐)
|
||||
```bash
|
||||
cd src
|
||||
./run.sh
|
||||
```
|
||||
|
||||
### 方法 2: 手动启动
|
||||
```bash
|
||||
# 安装依赖
|
||||
pip3 install -r requirements.txt
|
||||
|
||||
# 启动服务器
|
||||
uvicorn main:app --host 0.0.0.0 --port 8001 --reload
|
||||
```
|
||||
|
||||
## 访问地址
|
||||
|
||||
- **服务器**: http://10.10.10.77:8001
|
||||
- **API 文档**: http://10.10.10.77:8001/docs
|
||||
- **替代文档**: http://10.10.10.77:8001/redoc
|
||||
|
||||
## 示例请求
|
||||
|
||||
### 登录
|
||||
```bash
|
||||
curl -X POST "http://10.10.10.77:8001/api/login" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"username": "admin", "password": "123456"}'
|
||||
```
|
||||
|
||||
### 获取数据集列表
|
||||
```bash
|
||||
curl -X GET "http://10.10.10.77:8001/api/datasets"
|
||||
```
|
||||
|
||||
### 上传数据集文件
|
||||
```bash
|
||||
curl -X POST "http://10.10.10.77:8001/api/datasets/upload" \
|
||||
-F "file=@dataset.json" \
|
||||
-F "description=数据集描述"
|
||||
```
|
||||
|
||||
### 获取data目录文件列表
|
||||
```bash
|
||||
curl -X GET "http://10.10.10.77:8001/api/datasets/files"
|
||||
```
|
||||
|
||||
### 获取系统统计
|
||||
```bash
|
||||
curl -X GET "http://10.10.10.77:8001/api/system/stats"
|
||||
```
|
||||
|
||||
## 依赖
|
||||
|
||||
- Python 3.7+
|
||||
- FastAPI 0.104.1
|
||||
- Uvicorn 0.24.0
|
||||
- Pydantic 2.5.0
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 服务器默认运行在端口 8001
|
||||
- 使用 `--reload` 参数启用热重载
|
||||
- 所有 API 响应都遵循统一格式
|
||||
805
src/main.py
805
src/main.py
@@ -1,381 +1,468 @@
|
||||
from fastapi import FastAPI, File, UploadFile, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
import uvicorn
|
||||
"""
|
||||
远光软件微调平台 - Flask 后端 API
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import pymysql
|
||||
import yaml
|
||||
from flask import Flask, request, jsonify
|
||||
from flask_cors import CORS
|
||||
|
||||
app = FastAPI(title="大模型微调平台 API", version="1.0.0")
|
||||
# 获取项目根目录
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, PROJECT_ROOT)
|
||||
|
||||
# 加载配置
|
||||
CONFIG_PATH = os.path.join(PROJECT_ROOT, 'config.yaml')
|
||||
|
||||
|
||||
# 请求模型
|
||||
class UserModel(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
def load_config():
|
||||
"""加载配置文件"""
|
||||
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
|
||||
return yaml.safe_load(f)
|
||||
|
||||
|
||||
class DatasetModel(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
size: str
|
||||
CONFIG = load_config()
|
||||
|
||||
|
||||
class ModelConfigModel(BaseModel):
|
||||
model_name: str
|
||||
learning_rate: float
|
||||
batch_size: int
|
||||
epochs: int
|
||||
|
||||
|
||||
# 响应模型
|
||||
class ResponseModel(BaseModel):
|
||||
code: int
|
||||
message: str
|
||||
data: Optional[dict] = None
|
||||
|
||||
|
||||
# 模拟数据存储
|
||||
datasets = [
|
||||
{"id": 1, "name": "中文对话数据集", "size": "1.2GB", "status": "已处理"},
|
||||
{"id": 2, "name": "英文文本分类数据集", "size": "856MB", "status": "处理中"},
|
||||
{"id": 3, "name": "图像识别数据集", "size": "2.5GB", "status": "待处理"},
|
||||
]
|
||||
|
||||
models = [
|
||||
{"id": 1, "name": "GPT-4", "status": "训练中", "accuracy": "92%"},
|
||||
{"id": 2, "name": "BERT", "status": "已完成", "accuracy": "89%"},
|
||||
{"id": 3, "name": "LLaMA", "status": "已完成", "accuracy": "95%"},
|
||||
]
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""根路径"""
|
||||
return {"message": "大模型微调平台 API 服务"}
|
||||
|
||||
|
||||
@app.get("/api/health")
|
||||
async def health_check():
|
||||
"""健康检查"""
|
||||
return ResponseModel(code=200, message="服务运行正常", data={"status": "healthy"})
|
||||
|
||||
|
||||
@app.post("/api/login", response_model=ResponseModel)
|
||||
async def login(user: UserModel):
|
||||
"""用户登录"""
|
||||
if user.username == "admin" and user.password:
|
||||
return ResponseModel(
|
||||
code=200,
|
||||
message="登录成功",
|
||||
data={"token": "mock_token_12345", "user": user.username}
|
||||
)
|
||||
else:
|
||||
return ResponseModel(code=401, message="用户名或密码错误")
|
||||
|
||||
|
||||
@app.get("/api/datasets", response_model=ResponseModel)
|
||||
async def get_datasets():
|
||||
"""获取数据集列表"""
|
||||
return ResponseModel(code=200, message="获取成功", data={"datasets": datasets})
|
||||
|
||||
|
||||
@app.post("/api/datasets", response_model=ResponseModel)
|
||||
async def create_dataset(dataset: DatasetModel):
|
||||
"""创建数据集"""
|
||||
new_dataset = {
|
||||
"id": len(datasets) + 1,
|
||||
"name": dataset.name,
|
||||
"description": dataset.description,
|
||||
"size": "0MB",
|
||||
"status": "待处理"
|
||||
}
|
||||
datasets.append(new_dataset)
|
||||
return ResponseModel(code=201, message="创建成功", data={"dataset": new_dataset})
|
||||
|
||||
|
||||
@app.post("/api/datasets/upload", response_model=ResponseModel)
|
||||
async def upload_dataset(file: UploadFile = File(...), description: Optional[str] = None):
|
||||
"""上传数据集文件(仅支持 JSON 和 JSONL 格式)"""
|
||||
# 检查文件类型
|
||||
allowed_extensions = ['.json', '.jsonl']
|
||||
file_extension = os.path.splitext(file.filename)[1].lower()
|
||||
|
||||
if file_extension not in allowed_extensions:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"不支持的文件类型。只能上传 {', '.join(allowed_extensions)} 格式的文件"
|
||||
)
|
||||
|
||||
# 检查文件大小(限制为 100MB)
|
||||
max_size = 100 * 1024 * 1024 # 100MB
|
||||
contents = await file.read()
|
||||
file_size = len(contents)
|
||||
|
||||
if file_size > max_size:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"文件大小超过限制。最大支持 100MB,当前文件大小: {file_size / (1024*1024):.2f}MB"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证文件内容
|
||||
if file_extension == '.json':
|
||||
# 验证 JSON 文件
|
||||
json.loads(contents.decode('utf-8'))
|
||||
elif file_extension == '.jsonl':
|
||||
# 验证 JSONL 文件(每行必须是有效的 JSON)
|
||||
lines = contents.decode('utf-8').strip().split('\n')
|
||||
for i, line in enumerate(lines):
|
||||
if line.strip():
|
||||
try:
|
||||
json.loads(line)
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"JSONL 文件格式错误:第 {i+1} 行不是有效的 JSON 格式"
|
||||
)
|
||||
|
||||
# 生成文件大小字符串
|
||||
if file_size < 1024:
|
||||
size_str = f"{file_size}B"
|
||||
elif file_size < 1024 * 1024:
|
||||
size_str = f"{file_size / 1024:.2f}KB"
|
||||
else:
|
||||
size_str = f"{file_size / (1024*1024):.2f}MB"
|
||||
|
||||
# 计算行数(用于统计)
|
||||
lines_count = len(contents.decode('utf-8').strip().split('\n')) if contents else 0
|
||||
|
||||
# 保存文件到 data 目录
|
||||
data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'data')
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
|
||||
# 生成唯一文件名(避免冲突)
|
||||
base_name = os.path.splitext(file.filename)[0]
|
||||
timestamp = int(time.time())
|
||||
saved_filename = f"{base_name}_{timestamp}{file_extension}"
|
||||
saved_path = os.path.join(data_dir, saved_filename)
|
||||
|
||||
# 写入文件
|
||||
with open(saved_path, 'wb') as f:
|
||||
f.write(contents)
|
||||
|
||||
# 创建新数据集记录
|
||||
new_dataset = {
|
||||
"id": len(datasets) + 1,
|
||||
"name": file.filename,
|
||||
"description": description or f"上传的数据集文件,包含 {lines_count} 行数据",
|
||||
"size": size_str,
|
||||
"status": "已处理",
|
||||
"upload_time": "刚刚",
|
||||
"file_extension": file_extension,
|
||||
"records_count": lines_count,
|
||||
"saved_path": saved_path # 添加保存路径信息
|
||||
}
|
||||
|
||||
# 添加到数据集列表
|
||||
datasets.append(new_dataset)
|
||||
|
||||
return ResponseModel(
|
||||
code=200,
|
||||
message="文件上传成功",
|
||||
data={
|
||||
"dataset": new_dataset,
|
||||
"file_info": {
|
||||
"filename": file.filename,
|
||||
"size": size_str,
|
||||
"extension": file_extension,
|
||||
"records": lines_count
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="JSON 文件格式错误:文件内容不是有效的 JSON 格式"
|
||||
)
|
||||
except UnicodeDecodeError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="文件编码错误:请确保文件使用 UTF-8 编码"
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"文件处理错误:{str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@app.get("/api/datasets/files", response_model=ResponseModel)
|
||||
async def list_dataset_files():
|
||||
"""列出data目录中所有保存的数据集文件"""
|
||||
try:
|
||||
data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'data')
|
||||
|
||||
if not os.path.exists(data_dir):
|
||||
return ResponseModel(
|
||||
code=200,
|
||||
message="获取成功",
|
||||
data={"files": [], "total": 0, "directory": data_dir}
|
||||
)
|
||||
|
||||
files = []
|
||||
for filename in os.listdir(data_dir):
|
||||
file_path = os.path.join(data_dir, filename)
|
||||
if os.path.isfile(file_path):
|
||||
stat = os.stat(file_path)
|
||||
files.append({
|
||||
"filename": filename,
|
||||
"size": stat.st_size,
|
||||
"size_human": format_size(stat.st_size),
|
||||
"modified_time": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(stat.st_mtime)),
|
||||
"path": file_path
|
||||
})
|
||||
|
||||
# 按修改时间排序(最新的在前)
|
||||
files.sort(key=lambda x: x["modified_time"], reverse=True)
|
||||
|
||||
return ResponseModel(
|
||||
code=200,
|
||||
message="获取成功",
|
||||
data={
|
||||
"files": files,
|
||||
"total": len(files),
|
||||
"directory": data_dir
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"获取文件列表失败:{str(e)}"
|
||||
)
|
||||
|
||||
|
||||
def format_size(size_bytes):
|
||||
"""格式化文件大小"""
|
||||
if size_bytes < 1024:
|
||||
return f"{size_bytes}B"
|
||||
elif size_bytes < 1024 * 1024:
|
||||
return f"{size_bytes / 1024:.2f}KB"
|
||||
else:
|
||||
return f"{size_bytes / (1024*1024):.2f}MB"
|
||||
|
||||
|
||||
@app.delete("/api/datasets/{dataset_id}", response_model=ResponseModel)
|
||||
async def delete_dataset(dataset_id: int):
|
||||
"""删除数据集"""
|
||||
global datasets
|
||||
for i, dataset in enumerate(datasets):
|
||||
if dataset["id"] == dataset_id:
|
||||
deleted_dataset = datasets.pop(i)
|
||||
return ResponseModel(
|
||||
code=200,
|
||||
message="删除成功",
|
||||
data={"deleted_dataset": deleted_dataset}
|
||||
)
|
||||
raise HTTPException(status_code=404, detail="数据集不存在")
|
||||
|
||||
|
||||
@app.get("/api/models", response_model=ResponseModel)
|
||||
async def get_models():
|
||||
"""获取模型列表"""
|
||||
return ResponseModel(code=200, message="获取成功", data={"models": models})
|
||||
|
||||
|
||||
@app.post("/api/models/config", response_model=ResponseModel)
|
||||
async def config_model(config: ModelConfigModel):
|
||||
"""配置模型参数"""
|
||||
return ResponseModel(
|
||||
code=200,
|
||||
message="配置成功",
|
||||
data={
|
||||
"model_name": config.model_name,
|
||||
"learning_rate": config.learning_rate,
|
||||
"batch_size": config.batch_size,
|
||||
"epochs": config.epochs,
|
||||
"status": "已配置"
|
||||
}
|
||||
def get_db_connection():
|
||||
"""获取数据库连接"""
|
||||
db_config = CONFIG['database']
|
||||
return pymysql.connect(
|
||||
host=db_config['host'],
|
||||
port=db_config['port'],
|
||||
user=db_config['username'],
|
||||
password=db_config['password'],
|
||||
database=db_config['name'],
|
||||
charset=db_config.get('charset', 'utf8mb4'),
|
||||
cursorclass=pymysql.cursors.DictCursor
|
||||
)
|
||||
|
||||
|
||||
@app.get("/api/training/status")
|
||||
async def get_training_status():
|
||||
"""获取训练状态"""
|
||||
return ResponseModel(
|
||||
code=200,
|
||||
message="获取成功",
|
||||
data={
|
||||
"current_task": "GPT-4微调",
|
||||
"progress": 75,
|
||||
"eta": "2小时",
|
||||
"loss": 0.23,
|
||||
"accuracy": 0.89
|
||||
}
|
||||
)
|
||||
def init_database():
|
||||
"""初始化数据库表"""
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
tables = [
|
||||
# 精调训练表
|
||||
"""CREATE TABLE IF NOT EXISTS fine_tune (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
base_model VARCHAR(255),
|
||||
train_type VARCHAR(50),
|
||||
train_method VARCHAR(50),
|
||||
dataset_id INT,
|
||||
valid_split VARCHAR(50),
|
||||
valid_ratio INT DEFAULT 10,
|
||||
output_model_name VARCHAR(255),
|
||||
status VARCHAR(50) DEFAULT 'pending',
|
||||
progress INT DEFAULT 0,
|
||||
create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
||||
)""",
|
||||
|
||||
# 我的模型表
|
||||
"""CREATE TABLE IF NOT EXISTS my_models (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
type VARCHAR(100),
|
||||
version VARCHAR(50),
|
||||
description TEXT,
|
||||
create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
||||
)""",
|
||||
|
||||
# 模型评测表
|
||||
"""CREATE TABLE IF NOT EXISTS model_eval (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
model_name VARCHAR(255) NOT NULL,
|
||||
dataset VARCHAR(255),
|
||||
metric VARCHAR(100),
|
||||
score DECIMAL(10, 4),
|
||||
status VARCHAR(50) DEFAULT 'completed',
|
||||
create_time DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)""",
|
||||
|
||||
# 模型部署表
|
||||
"""CREATE TABLE IF NOT EXISTS model_deploy (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
model_name VARCHAR(255) NOT NULL,
|
||||
endpoint VARCHAR(255),
|
||||
instance VARCHAR(100),
|
||||
status VARCHAR(50) DEFAULT 'running',
|
||||
create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
||||
)""",
|
||||
|
||||
# 数据集管理表
|
||||
"""CREATE TABLE IF NOT EXISTS dataset_manage (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
type VARCHAR(100),
|
||||
size VARCHAR(50),
|
||||
count INT,
|
||||
description TEXT,
|
||||
create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
||||
)""",
|
||||
|
||||
# 数据生成表
|
||||
"""CREATE TABLE IF NOT EXISTS data_generate (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
template VARCHAR(255),
|
||||
count INT DEFAULT 0,
|
||||
status VARCHAR(50) DEFAULT 'pending',
|
||||
create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
||||
)""",
|
||||
|
||||
# 权限管理表
|
||||
"""CREATE TABLE IF NOT EXISTS permission (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
username VARCHAR(100) NOT NULL,
|
||||
role VARCHAR(50) DEFAULT 'user',
|
||||
permissions TEXT,
|
||||
create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
||||
)""",
|
||||
|
||||
# 系统配置表
|
||||
"""CREATE TABLE IF NOT EXISTS sys_config (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
config_key VARCHAR(100) NOT NULL UNIQUE,
|
||||
config_value TEXT,
|
||||
description VARCHAR(255),
|
||||
update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
||||
)""",
|
||||
|
||||
# 用户表
|
||||
"""CREATE TABLE IF NOT EXISTS users (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
username VARCHAR(100) NOT NULL UNIQUE,
|
||||
password VARCHAR(255) NOT NULL,
|
||||
role VARCHAR(50) DEFAULT 'user',
|
||||
create_time DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)"""
|
||||
]
|
||||
|
||||
for table_sql in tables:
|
||||
cursor.execute(table_sql)
|
||||
|
||||
# 插入默认管理员用户
|
||||
cursor.execute("SELECT * FROM users WHERE username = 'admin'")
|
||||
if not cursor.fetchone():
|
||||
cursor.execute("INSERT INTO users (username, password, role) VALUES ('admin', 'admin', 'admin')")
|
||||
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
conn.close()
|
||||
print("数据库初始化完成")
|
||||
|
||||
|
||||
@app.get("/api/system/stats")
|
||||
async def get_system_stats():
|
||||
"""获取系统统计信息"""
|
||||
import random
|
||||
return ResponseModel(
|
||||
code=200,
|
||||
message="获取成功",
|
||||
data={
|
||||
"cpu_usage": random.randint(30, 80),
|
||||
"memory_usage": random.randint(40, 70),
|
||||
"gpu_usage": random.randint(50, 90),
|
||||
"active_tasks": 5,
|
||||
"completed_tasks": 158
|
||||
}
|
||||
)
|
||||
app = Flask(__name__)
|
||||
app.config['SECRET_KEY'] = CONFIG['secret_key']
|
||||
CORS(app, resources={r"/api/*": {"origins": "*"}})
|
||||
|
||||
|
||||
@app.post("/api/training/start")
|
||||
async def start_training(model_name: str, dataset_id: int):
|
||||
"""开始训练任务"""
|
||||
return ResponseModel(
|
||||
code=200,
|
||||
message="训练任务已启动",
|
||||
data={
|
||||
"task_id": random.randint(1000, 9999),
|
||||
"model_name": model_name,
|
||||
"dataset_id": dataset_id,
|
||||
"status": "running"
|
||||
}
|
||||
)
|
||||
# ============ 健康检查 ============
|
||||
@app.route('/api/health', methods=['GET'])
|
||||
def health_check():
|
||||
"""健康检查接口"""
|
||||
return jsonify({'status': 'ok', 'code': 0})
|
||||
|
||||
|
||||
@app.post("/api/training/stop/{task_id}")
|
||||
async def stop_training(task_id: int):
|
||||
"""停止训练任务"""
|
||||
return ResponseModel(
|
||||
code=200,
|
||||
message=f"训练任务 {task_id} 已停止",
|
||||
data={"task_id": task_id, "status": "stopped"}
|
||||
)
|
||||
# ============ 通用 CRUD 操作 ============
|
||||
def generic_get_all(table_name, order_by='create_time DESC'):
|
||||
"""通用查询所有"""
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(f"SELECT * FROM {table_name} ORDER BY {order_by}")
|
||||
result = cursor.fetchall()
|
||||
cursor.close()
|
||||
conn.close()
|
||||
return result
|
||||
|
||||
|
||||
@app.get("/api/model/{model_id}/metrics")
|
||||
async def get_model_metrics(model_id: int):
|
||||
"""获取模型指标"""
|
||||
return ResponseModel(
|
||||
code=200,
|
||||
message="获取成功",
|
||||
data={
|
||||
"model_id": model_id,
|
||||
"accuracy": round(random.uniform(0.85, 0.98), 3),
|
||||
"precision": round(random.uniform(0.80, 0.95), 3),
|
||||
"recall": round(random.uniform(0.82, 0.96), 3),
|
||||
"f1_score": round(random.uniform(0.83, 0.97), 3),
|
||||
"training_time": f"{random.randint(2, 24)}小时",
|
||||
"parameters": random.randint(1000000, 100000000)
|
||||
}
|
||||
)
|
||||
def generic_get_by_id(table_name, id_val):
|
||||
"""通用按ID查询"""
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(f"SELECT * FROM {table_name} WHERE id = %s", (id_val,))
|
||||
result = cursor.fetchone()
|
||||
cursor.close()
|
||||
conn.close()
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(app, host="0.0.0.0", port=8001)
|
||||
def generic_create(table_name, data):
|
||||
"""通用创建"""
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
columns = ', '.join(data.keys())
|
||||
placeholders = ', '.join(['%s'] * len(data))
|
||||
sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"
|
||||
cursor.execute(sql, list(data.values()))
|
||||
conn.commit()
|
||||
new_id = cursor.lastrowid
|
||||
cursor.close()
|
||||
conn.close()
|
||||
return new_id
|
||||
|
||||
|
||||
def generic_update(table_name, id_val, data):
|
||||
"""通用更新"""
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
set_clause = ', '.join([f"{k} = %s" for k in data.keys()])
|
||||
sql = f"UPDATE {table_name} SET {set_clause} WHERE id = %s"
|
||||
values = list(data.values()) + [id_val]
|
||||
cursor.execute(sql, values)
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
conn.close()
|
||||
|
||||
|
||||
def generic_delete(table_name, id_val):
|
||||
"""通用删除"""
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(f"DELETE FROM {table_name} WHERE id = %s", (id_val,))
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
conn.close()
|
||||
|
||||
|
||||
# ============ 登录接口 ============
|
||||
@app.route('/api/login', methods=['POST'])
|
||||
def login():
|
||||
data = request.json
|
||||
username = data.get('username')
|
||||
password = data.get('password')
|
||||
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM users WHERE username = %s AND password = %s", (username, password))
|
||||
user = cursor.fetchone()
|
||||
cursor.close()
|
||||
conn.close()
|
||||
|
||||
if user:
|
||||
return jsonify({'code': 0, 'message': '登录成功', 'data': {'username': user['username'], 'role': user['role']}})
|
||||
return jsonify({'code': 1, 'message': '用户名或密码错误'})
|
||||
|
||||
|
||||
# ============ 精调训练接口 ============
|
||||
@app.route('/api/fine-tune', methods=['GET'])
|
||||
def get_fine_tune():
|
||||
return jsonify({'code': 0, 'data': generic_get_all('fine_tune')})
|
||||
|
||||
|
||||
@app.route('/api/fine-tune', methods=['POST'])
|
||||
def create_fine_tune():
|
||||
data = request.json
|
||||
new_id = generic_create('fine_tune', data)
|
||||
return jsonify({'code': 0, 'message': '创建成功', 'id': new_id})
|
||||
|
||||
|
||||
@app.route('/api/fine-tune/<int:id>', methods=['PUT'])
|
||||
def update_fine_tune(id):
|
||||
data = request.json
|
||||
generic_update('fine_tune', id, data)
|
||||
return jsonify({'code': 0, 'message': '更新成功'})
|
||||
|
||||
|
||||
@app.route('/api/fine-tune/<int:id>', methods=['DELETE'])
|
||||
def delete_fine_tune(id):
|
||||
generic_delete('fine_tune', id)
|
||||
return jsonify({'code': 0, 'message': '删除成功'})
|
||||
|
||||
|
||||
# ============ 我的模型接口 ============
|
||||
@app.route('/api/my-models', methods=['GET'])
|
||||
def get_my_models():
|
||||
return jsonify({'code': 0, 'data': generic_get_all('my_models')})
|
||||
|
||||
|
||||
@app.route('/api/my-models', methods=['POST'])
|
||||
def create_my_model():
|
||||
data = request.json
|
||||
new_id = generic_create('my_models', data)
|
||||
return jsonify({'code': 0, 'message': '创建成功', 'id': new_id})
|
||||
|
||||
|
||||
@app.route('/api/my-models/<int:id>', methods=['PUT'])
|
||||
def update_my_model(id):
|
||||
data = request.json
|
||||
generic_update('my_models', id, data)
|
||||
return jsonify({'code': 0, 'message': '更新成功'})
|
||||
|
||||
|
||||
@app.route('/api/my-models/<int:id>', methods=['DELETE'])
|
||||
def delete_my_model(id):
|
||||
generic_delete('my_models', id)
|
||||
return jsonify({'code': 0, 'message': '删除成功'})
|
||||
|
||||
|
||||
# ============ 模型评测接口 ============
|
||||
@app.route('/api/model-eval', methods=['GET'])
|
||||
def get_model_eval():
|
||||
return jsonify({'code': 0, 'data': generic_get_all('model_eval')})
|
||||
|
||||
|
||||
@app.route('/api/model-eval', methods=['POST'])
|
||||
def create_model_eval():
|
||||
data = request.json
|
||||
new_id = generic_create('model_eval', data)
|
||||
return jsonify({'code': 0, 'message': '创建成功', 'id': new_id})
|
||||
|
||||
|
||||
@app.route('/api/model-eval/<int:id>', methods=['PUT'])
|
||||
def update_model_eval(id):
|
||||
data = request.json
|
||||
generic_update('model_eval', id, data)
|
||||
return jsonify({'code': 0, 'message': '更新成功'})
|
||||
|
||||
|
||||
@app.route('/api/model-eval/<int:id>', methods=['DELETE'])
|
||||
def delete_model_eval(id):
|
||||
generic_delete('model_eval', id)
|
||||
return jsonify({'code': 0, 'message': '删除成功'})
|
||||
|
||||
|
||||
# ============ 模型部署接口 ============
|
||||
@app.route('/api/model-deploy', methods=['GET'])
|
||||
def get_model_deploy():
|
||||
return jsonify({'code': 0, 'data': generic_get_all('model_deploy')})
|
||||
|
||||
|
||||
@app.route('/api/model-deploy', methods=['POST'])
|
||||
def create_model_deploy():
|
||||
data = request.json
|
||||
new_id = generic_create('model_deploy', data)
|
||||
return jsonify({'code': 0, 'message': '创建成功', 'id': new_id})
|
||||
|
||||
|
||||
@app.route('/api/model-deploy/<int:id>', methods=['PUT'])
|
||||
def update_model_deploy(id):
|
||||
data = request.json
|
||||
generic_update('model_deploy', id, data)
|
||||
return jsonify({'code': 0, 'message': '更新成功'})
|
||||
|
||||
|
||||
@app.route('/api/model-deploy/<int:id>', methods=['DELETE'])
|
||||
def delete_model_deploy(id):
|
||||
generic_delete('model_deploy', id)
|
||||
return jsonify({'code': 0, 'message': '删除成功'})
|
||||
|
||||
|
||||
# ============ 数据集管理接口 ============
|
||||
@app.route('/api/dataset-manage', methods=['GET'])
|
||||
def get_dataset_manage():
|
||||
return jsonify({'code': 0, 'data': generic_get_all('dataset_manage')})
|
||||
|
||||
|
||||
@app.route('/api/dataset-manage', methods=['POST'])
|
||||
def create_dataset_manage():
|
||||
data = request.json
|
||||
new_id = generic_create('dataset_manage', data)
|
||||
return jsonify({'code': 0, 'message': '创建成功', 'id': new_id})
|
||||
|
||||
|
||||
@app.route('/api/dataset-manage/<int:id>', methods=['PUT'])
|
||||
def update_dataset_manage(id):
|
||||
data = request.json
|
||||
generic_update('dataset_manage', id, data)
|
||||
return jsonify({'code': 0, 'message': '更新成功'})
|
||||
|
||||
|
||||
@app.route('/api/dataset-manage/<int:id>', methods=['DELETE'])
|
||||
def delete_dataset_manage(id):
|
||||
generic_delete('dataset_manage', id)
|
||||
return jsonify({'code': 0, 'message': '删除成功'})
|
||||
|
||||
|
||||
# ============ 数据生成接口 ============
|
||||
@app.route('/api/data-generate', methods=['GET'])
|
||||
def get_data_generate():
|
||||
return jsonify({'code': 0, 'data': generic_get_all('data_generate')})
|
||||
|
||||
|
||||
@app.route('/api/data-generate', methods=['POST'])
|
||||
def create_data_generate():
|
||||
data = request.json
|
||||
new_id = generic_create('data_generate', data)
|
||||
return jsonify({'code': 0, 'message': '创建成功', 'id': new_id})
|
||||
|
||||
|
||||
@app.route('/api/data-generate/<int:id>', methods=['PUT'])
|
||||
def update_data_generate(id):
|
||||
data = request.json
|
||||
generic_update('data_generate', id, data)
|
||||
return jsonify({'code': 0, 'message': '更新成功'})
|
||||
|
||||
|
||||
@app.route('/api/data-generate/<int:id>', methods=['DELETE'])
|
||||
def delete_data_generate(id):
|
||||
generic_delete('data_generate', id)
|
||||
return jsonify({'code': 0, 'message': '删除成功'})
|
||||
|
||||
|
||||
# ============ 权限管理接口 ============
|
||||
@app.route('/api/permission', methods=['GET'])
|
||||
def get_permission():
|
||||
return jsonify({'code': 0, 'data': generic_get_all('permission')})
|
||||
|
||||
|
||||
@app.route('/api/permission', methods=['POST'])
|
||||
def create_permission():
|
||||
data = request.json
|
||||
new_id = generic_create('permission', data)
|
||||
return jsonify({'code': 0, 'message': '创建成功', 'id': new_id})
|
||||
|
||||
|
||||
@app.route('/api/permission/<int:id>', methods=['PUT'])
|
||||
def update_permission(id):
|
||||
data = request.json
|
||||
generic_update('permission', id, data)
|
||||
return jsonify({'code': 0, 'message': '更新成功'})
|
||||
|
||||
|
||||
@app.route('/api/permission/<int:id>', methods=['DELETE'])
|
||||
def delete_permission(id):
|
||||
generic_delete('permission', id)
|
||||
return jsonify({'code': 0, 'message': '删除成功'})
|
||||
|
||||
|
||||
# ============ 系统配置接口 ============
|
||||
@app.route('/api/sys-config', methods=['GET'])
|
||||
def get_sys_config():
|
||||
return jsonify({'code': 0, 'data': generic_get_all('sys_config')})
|
||||
|
||||
|
||||
@app.route('/api/sys-config', methods=['POST'])
|
||||
def create_sys_config():
|
||||
data = request.json
|
||||
new_id = generic_create('sys_config', data)
|
||||
return jsonify({'code': 0, 'message': '创建成功', 'id': new_id})
|
||||
|
||||
|
||||
@app.route('/api/sys-config/<int:id>', methods=['PUT'])
|
||||
def update_sys_config(id):
|
||||
data = request.json
|
||||
generic_update('sys_config', id, data)
|
||||
return jsonify({'code': 0, 'message': '更新成功'})
|
||||
|
||||
|
||||
@app.route('/api/sys-config/<int:id>', methods=['DELETE'])
|
||||
def delete_sys_config(id):
|
||||
generic_delete('sys_config', id)
|
||||
return jsonify({'code': 0, 'message': '删除成功'})
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
init_database()
|
||||
app_config = CONFIG['app']
|
||||
app.run(host=app_config['host'], port=app_config['port'], debug=app_config.get('debug', True))
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
fastapi==0.104.1
|
||||
uvicorn[standard]==0.24.0
|
||||
pydantic==2.5.0
|
||||
python-multipart==0.0.6
|
||||
47
src/run.sh
Executable file → Normal file
47
src/run.sh
Executable file → Normal file
@@ -1,43 +1,14 @@
|
||||
#!/bin/bash
|
||||
# 启动远光软件微调平台后端服务
|
||||
|
||||
echo "🚀 启动 FastAPI 服务器..."
|
||||
cd "$(dirname "$0")"
|
||||
|
||||
# 确保在正确的目录中
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
echo "📂 当前目录: $SCRIPT_DIR"
|
||||
|
||||
# 检查Python是否安装
|
||||
if ! command -v python3 &> /dev/null; then
|
||||
echo "❌ 错误: Python3 未安装"
|
||||
echo "请先安装 Python3"
|
||||
exit 1
|
||||
# 检查并安装依赖
|
||||
if ! python3 -c "import flask" 2>/dev/null; then
|
||||
echo "正在安装依赖..."
|
||||
pip install -r ../requirements.txt
|
||||
fi
|
||||
|
||||
# 检查pip是否安装
|
||||
if ! command -v pip3 &> /dev/null; then
|
||||
echo "❌ 错误: pip3 未安装"
|
||||
echo "请先安装 pip3"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 安装依赖
|
||||
echo "📦 安装依赖包..."
|
||||
pip3 install -r requirements.txt
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "❌ 依赖安装失败"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "🌐 服务器地址: http://localhost:8001"
|
||||
echo "📚 API 文档: http://localhost:8001/docs"
|
||||
echo "🔍 替代文档: http://localhost:8001/redoc"
|
||||
echo ""
|
||||
echo "按 Ctrl+C 停止服务器"
|
||||
echo ""
|
||||
|
||||
# 启动服务器
|
||||
python3 -m uvicorn main:app --host 0.0.0.0 --port 8001 --reload
|
||||
# 启动服务
|
||||
echo "启动后端服务..."
|
||||
python3 main.py
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
echo "🧪 测试 FastAPI 服务器"
|
||||
echo "=================================="
|
||||
echo ""
|
||||
|
||||
BASE_URL="http://localhost:8001"
|
||||
|
||||
# 测试 1: 根路径
|
||||
echo "1. 测试根路径..."
|
||||
curl -s "$BASE_URL/" | python3 -m json.tool
|
||||
echo ""
|
||||
|
||||
# 测试 2: 健康检查
|
||||
echo "2. 测试健康检查..."
|
||||
curl -s "$BASE_URL/api/health" | python3 -m json.tool
|
||||
echo ""
|
||||
|
||||
# 测试 3: 用户登录
|
||||
echo "3. 测试用户登录..."
|
||||
curl -s -X POST "$BASE_URL/api/login" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"username": "admin", "password": "123456"}' | python3 -m json.tool
|
||||
echo ""
|
||||
|
||||
# 测试 4: 获取数据集列表
|
||||
echo "4. 测试获取数据集列表..."
|
||||
curl -s "$BASE_URL/api/datasets" | python3 -m json.tool
|
||||
echo ""
|
||||
|
||||
# 测试 5: 获取模型列表
|
||||
echo "5. 测试获取模型列表..."
|
||||
curl -s "$BASE_URL/api/models" | python3 -m json.tool
|
||||
echo ""
|
||||
|
||||
# 测试 6: 系统统计
|
||||
echo "6. 测试系统统计..."
|
||||
curl -s "$BASE_URL/api/system/stats" | python3 -m json.tool
|
||||
echo ""
|
||||
|
||||
# 测试 7: 训练状态
|
||||
echo "7. 测试训练状态..."
|
||||
curl -s "$BASE_URL/api/training/status" | python3 -m json.tool
|
||||
echo ""
|
||||
|
||||
echo "=================================="
|
||||
echo "✅ 所有测试完成!"
|
||||
Reference in New Issue
Block a user