2026-01-20 16:16:13 +08:00
|
|
|
|
"""
Model management API routes (模型管理 API 路由).
"""
|
|
|
|
|
|
import os
|
|
|
|
|
|
import pymysql
|
|
|
|
|
|
import yaml
|
|
|
|
|
|
from flask import Blueprint, request, jsonify
|
|
|
|
|
|
|
2026-01-28 10:31:09 +08:00
|
|
|
|
# Resolve the project root. The MOUNT_BASE environment variable (default
# '/app/base') names the mounted data directory; the root itself is first
# derived from this file's location, three directory levels up.
MOUNT_BASE = os.environ.get('MOUNT_BASE', '/app/base')

# Locations that indicate the code is running from inside the container image
# rather than from the mounted project directory.
_CONTAINER_ROOTS = ('/app', '/app/src/llamafactory')

_this_file = os.path.abspath(__file__)
_computed_root = os.path.dirname(os.path.dirname(os.path.dirname(_this_file)))

# When the computed root is one of the container paths, use the mount instead.
PROJECT_ROOT = MOUNT_BASE if _computed_root in _CONTAINER_ROOTS else _computed_root
|
2026-01-20 16:16:13 +08:00
|
|
|
|
|
|
|
|
|
|
# Blueprint that groups every endpoint in this module under /api/model-manage.
model_manage_bp = Blueprint(
    'model_manage',
    __name__,
    url_prefix='/api/model-manage',
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_db_connection():
    """Open a new MySQL connection described by ``config.yaml``.

    The file ``config.yaml`` under ``PROJECT_ROOT`` is re-read on every call
    and its ``database`` section supplies the connection settings
    (``host``, ``port``, ``username``, ``password``, ``name`` and an optional
    ``charset`` that falls back to ``utf8mb4``).

    Returns:
        A ``pymysql`` connection whose cursors return rows as dictionaries.
    """
    config_file = os.path.join(PROJECT_ROOT, 'config.yaml')
    with open(config_file, 'r', encoding='utf-8') as stream:
        loaded = yaml.safe_load(stream)
    settings = loaded['database']

    connect_kwargs = {
        'host': settings['host'],
        'port': settings['port'],
        'user': settings['username'],
        'password': settings['password'],
        'database': settings['name'],
        'charset': settings.get('charset', 'utf8mb4'),
        'cursorclass': pymysql.cursors.DictCursor,
    }
    return pymysql.connect(**connect_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generic_get_all(table_name, order_by='create_time DESC'):
    """Fetch every row of a table, sorted by an ORDER BY expression.

    Args:
        table_name: Table to read. It is interpolated directly into the SQL
            text, so it must come from trusted code, never from user input.
        order_by: ORDER BY expression, interpolated verbatim as well; the
            same trust requirement applies.

    Returns:
        The result of ``cursor.fetchall()`` for the query.
    """
    conn = get_db_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(f"SELECT * FROM {table_name} ORDER BY {order_by}")
            return cursor.fetchall()
    finally:
        # Release the connection even when the query raises.
        conn.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generic_create(table_name, data):
    """Insert one row into a table and return its auto-generated id.

    Args:
        table_name: Table to insert into. Interpolated directly into the SQL
            text, so it must come from trusted code.
        data: Mapping of column name to value. Values are passed as query
            parameters; column names are backtick-quoted (embedded backticks
            doubled) so a key cannot break out of the column list.

    Returns:
        ``cursor.lastrowid`` after the insert has been committed.
    """
    # Quote identifiers instead of splicing raw keys into the statement.
    columns = ', '.join(
        '`' + str(column).replace('`', '``') + '`' for column in data
    )
    placeholders = ', '.join(['%s'] * len(data))
    sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"

    conn = get_db_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(sql, list(data.values()))
            conn.commit()
            return cursor.lastrowid
    finally:
        # Release the connection even when the insert or commit raises.
        conn.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generic_update(table_name, id_val, data):
    """Update the row whose ``id`` equals ``id_val`` with the given columns.

    Args:
        table_name: Table to update. Interpolated directly into the SQL text,
            so it must come from trusted code.
        id_val: Value matched against the ``id`` column (query parameter).
        data: Mapping of column name to new value. Values are passed as query
            parameters; column names are backtick-quoted (embedded backticks
            doubled) so a key cannot inject SQL. An empty mapping is a no-op.

    Returns:
        None.
    """
    if not data:
        # "UPDATE t SET  WHERE ..." is not valid SQL; nothing to change.
        return

    set_clause = ', '.join(
        '`' + str(column).replace('`', '``') + '` = %s' for column in data
    )
    sql = f"UPDATE {table_name} SET {set_clause} WHERE id = %s"
    values = list(data.values()) + [id_val]

    conn = get_db_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(sql, values)
        conn.commit()
    finally:
        # Release the connection even when the update or commit raises.
        conn.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generic_delete(table_name, id_val):
    """Delete the row whose ``id`` equals ``id_val``.

    Args:
        table_name: Table to delete from. Interpolated directly into the SQL
            text, so it must come from trusted code.
        id_val: Value matched against the ``id`` column (query parameter).

    Returns:
        None.
    """
    conn = get_db_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(f"DELETE FROM {table_name} WHERE id = %s", (id_val,))
        conn.commit()
    finally:
        # Release the connection even when the delete or commit raises.
        conn.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generic_get_by_id(table_name, id_val):
    """Fetch the single row whose ``id`` equals ``id_val``.

    Args:
        table_name: Table to read. Interpolated directly into the SQL text,
            so it must come from trusted code.
        id_val: Value matched against the ``id`` column (query parameter).

    Returns:
        The result of ``cursor.fetchone()`` for the query.
    """
    conn = get_db_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(f"SELECT * FROM {table_name} WHERE id = %s", (id_val,))
            return cursor.fetchone()
    finally:
        # Release the connection even when the query raises.
        conn.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ============ 模型管理 CRUD ============
|
|
|
|
|
|
|
|
|
|
|
|
@model_manage_bp.route('', methods=['GET'])
def get_model_manage():
    """List every row of the ``model_manage`` table.

    Returns:
        JSON response ``{'code': 0, 'data': <rows>}``.
    """
    rows = generic_get_all('model_manage')
    payload = {'code': 0, 'data': rows}
    return jsonify(payload)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@model_manage_bp.route('/<int:id>', methods=['GET'])
def get_model_manage_by_id(id):
    """Return a single ``model_manage`` row.

    Args:
        id: Row id taken from the URL (the name is fixed by the route
            variable ``<int:id>``).

    Returns:
        JSON ``{'code': 0, 'data': <row>}`` when the lookup yields a truthy
        row, otherwise ``{'code': 1, 'message': ...}``.
    """
    record = generic_get_by_id('model_manage', id)
    if not record:
        return jsonify({'code': 1, 'message': '模型不存在'})
    return jsonify({'code': 0, 'data': record})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@model_manage_bp.route('', methods=['POST'])
def create_model_manage():
    """Create a ``model_manage`` row from the JSON request body.

    Recognised body fields: ``name``, ``type``, ``model_source`` (defaults to
    ``'local'``), ``description``, ``purpose`` (defaults to ``'inference'``).
    For a local source ``path`` is stored; for any other source ``api_url``,
    ``api_key`` and ``model_name`` are stored (each defaulting to ``''``).

    Returns:
        JSON ``{'code': 0, 'message': ..., 'id': <new id>}`` on success, or
        ``{'code': 1, 'message': ...}`` when the body is not a JSON object.
    """
    data = request.json
    if not isinstance(data, dict):
        # A body such as `null` or a JSON array would crash on .get() below.
        return jsonify({'code': 1, 'message': '请求体必须是 JSON 对象'})

    # Resolve the source once so the stored value and the branch below agree.
    # Branching on the raw field treated an omitted model_source as remote
    # even though 'local' was what got stored.
    model_source = data.get('model_source', 'local')

    insert_data = {
        'name': data.get('name'),
        'type': data.get('type'),
        'model_source': model_source,
        'description': data.get('description'),
        'purpose': data.get('purpose', 'inference'),  # inference by default
    }

    if model_source == 'local':
        insert_data['path'] = data.get('path', '')
    else:
        insert_data['api_url'] = data.get('api_url', '')
        insert_data['api_key'] = data.get('api_key', '')
        insert_data['model_name'] = data.get('model_name', '')

    new_id = generic_create('model_manage', insert_data)
    return jsonify({'code': 0, 'message': '创建成功', 'id': new_id})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@model_manage_bp.route('/<int:id>', methods=['PUT'])
def update_model_manage(id):
    """Update columns of a ``model_manage`` row from the JSON request body.

    Every key of the body is treated as a column name and every value as the
    new column value.

    Args:
        id: Row id taken from the URL (the name is fixed by the route
            variable ``<int:id>``).

    Returns:
        JSON ``{'code': 0, 'message': ...}`` on success, or
        ``{'code': 1, 'message': ...}`` when the body is empty, is not a JSON
        object, or contains a key that is not a plain identifier.
    """
    data = request.json
    if not isinstance(data, dict) or not data:
        # Nothing to update; an empty SET clause would be invalid SQL.
        return jsonify({'code': 1, 'message': '没有可更新的字段'})

    # Keys end up as column names in the SQL text, so only plain identifiers
    # (letters, digits, underscore) are accepted from the client.
    if not all(isinstance(key, str) and key.isidentifier() for key in data):
        return jsonify({'code': 1, 'message': '包含无效的字段名'})

    generic_update('model_manage', id, data)
    return jsonify({'code': 0, 'message': '更新成功'})
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-01-22 16:46:12 +08:00
|
|
|
|
@model_manage_bp.route('/<int:id>/purpose', methods=['PUT'])
def update_model_purpose(id):
    """Set the ``purpose`` column of a ``model_manage`` row.

    The JSON body must carry ``purpose`` with one of ``'training'``,
    ``'inference'`` or ``'evaluation'``.

    Args:
        id: Row id taken from the URL (the name is fixed by the route
            variable ``<int:id>``).

    Returns:
        JSON ``{'code': 0, 'message': ...}`` on success, or
        ``{'code': 1, 'message': ...}`` when the purpose is missing or invalid.
    """
    data = request.json
    # A non-object body (e.g. `null` or a list) has no .get(); treat it as a
    # missing purpose so it is rejected below instead of raising.
    purpose = data.get('purpose') if isinstance(data, dict) else None

    if purpose not in ('training', 'inference', 'evaluation'):
        return jsonify({'code': 1, 'message': '无效的用途类型'})

    generic_update('model_manage', id, {'purpose': purpose})
    return jsonify({'code': 0, 'message': '更新成功'})
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-01-20 16:16:13 +08:00
|
|
|
|
@model_manage_bp.route('/<int:id>', methods=['DELETE'])
def delete_model_manage(id):
    """Delete a ``model_manage`` row.

    Args:
        id: Row id taken from the URL (the name is fixed by the route
            variable ``<int:id>``).

    Returns:
        JSON ``{'code': 0, 'message': ...}`` once the delete has been issued.
    """
    generic_delete('model_manage', id)
    response_body = {'code': 0, 'message': '删除成功'}
    return jsonify(response_body)
|
2026-01-26 17:23:34 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ============ 本地模型列表接口 ============
|
|
|
|
|
|
|
|
|
|
|
|
@model_manage_bp.route('/local-models', methods=['GET'])
def get_local_models():
    """List the sub-directories of ``<PROJECT_ROOT>/local_models``.

    Each directory directly under the base path is reported as one model with
    its ``name`` (directory name) and ``path`` (joined full path). A base path
    that does not exist yields an empty list.

    Returns:
        JSON ``{'code': 0, 'data': {'models': [...], 'base_path': ...}}`` on
        success, or ``{'code': 1, 'message': <error text>}`` when any
        exception is raised while listing.
    """
    import logging

    log = logging.getLogger(__name__)

    try:
        models_dir = os.path.join(PROJECT_ROOT, 'local_models')

        entries = os.listdir(models_dir) if os.path.exists(models_dir) else []
        candidates = [(entry, os.path.join(models_dir, entry)) for entry in entries]
        # Only directories count as models; plain files are skipped.
        found = [
            {'name': entry, 'path': full_path}
            for entry, full_path in candidates
            if os.path.isdir(full_path)
        ]

        payload = {'models': found, 'base_path': models_dir}
        return jsonify({'code': 0, 'data': payload})
    except Exception as e:
        log.error(f"获取本地模型列表失败: {e}")
        return jsonify({'code': 1, 'message': str(e)})
|