重构了main.html的主函数

重构了大量的页面的sidebar
优化了代码结构
This commit is contained in:
2026-02-02 09:22:52 +08:00
33 changed files with 5566 additions and 2383 deletions

View File

@@ -4,9 +4,7 @@
import os
import sys
import logging
import yaml
from datetime import datetime
from logging.handlers import TimedRotatingFileHandler, RotatingFileHandler
from flask import Blueprint, request, jsonify
# 获取项目根目录
@@ -14,74 +12,40 @@ PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(_
sys.path.insert(0, PROJECT_ROOT)
# 加载配置
import yaml
CONFIG_PATH = os.path.join(PROJECT_ROOT, 'config.yaml')
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
CONFIG = yaml.safe_load(f)
# 日志目录
LOG_BASE_DIR = os.path.join(PROJECT_ROOT, 'logs')
# 日志目录 - 使用与 main.py 相同的配置
def get_log_base_dir():
"""获取日志基础目录"""
# 1. 检查环境变量
if 'LOG_BASE_DIR' in os.environ:
return os.environ['LOG_BASE_DIR']
# 2. 检查是否在容器环境中
mount_base = os.environ.get('MOUNT_BASE', '/app/base')
if os.path.exists(mount_base):
return os.path.join(mount_base, 'logs')
# 3. 使用本地项目路径
return os.path.join(PROJECT_ROOT, 'logs')
LOG_BASE_DIR = get_log_base_dir()
# 创建蓝图
logs_bp = Blueprint('logs', __name__, url_prefix='/api')
def setup_logs_logger():
"""配置日志系统,按日期分目录存储"""
today = datetime.now().strftime('%Y-%m-%d')
log_dir = os.path.join(LOG_BASE_DIR, today)
os.makedirs(log_dir, exist_ok=True)
logger = logging.getLogger('logs_api')
logger.setLevel(logging.DEBUG)
logger.handlers.clear()
# 全部日志处理器
all_log_path = os.path.join(log_dir, 'all.log')
all_handler = TimedRotatingFileHandler(all_log_path, when='midnight', interval=1, backupCount=30, encoding='utf-8')
all_handler.setLevel(logging.DEBUG)
all_handler.setFormatter(logging.Formatter(
'[%(asctime)s] %(levelname)s in %(module)s: %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
))
# 控制台处理器
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(logging.Formatter(
'[%(asctime)s] %(levelname)s: %(message)s',
datefmt='%H:%M:%S'
))
logger.addHandler(all_handler)
logger.addHandler(console_handler)
return logger
def get_logs_logger():
"""从 main.py 获取日志记录器"""
return logging.getLogger('logs_api')
logs_logger = setup_logs_logger()
@logs_bp.route('/web-log', methods=['POST'])
def receive_web_log():
"""接收前端页面发送的日志"""
data = request.json
level = data.get('level', 'info')
message = data.get('message', '')
page = data.get('page', 'unknown')
timestamp = data.get('timestamp', datetime.now().isoformat())
log_message = f'[WEB-{page}] {message}'
if level == 'error':
logs_logger.error(log_message)
elif level == 'warning':
logs_logger.warning(log_message)
elif level == 'debug':
logs_logger.debug(log_message)
else:
logs_logger.info(log_message)
return jsonify({'code': 0, 'message': '日志接收成功'})
def get_request_logger():
"""获取请求日志记录器"""
return logging.getLogger('request')
def format_file_size(size_bytes):
@@ -94,6 +58,30 @@ def format_file_size(size_bytes):
return f'{size_bytes / (1024 * 1024):.1f} MB'
@logs_bp.route('/web-log', methods=['POST'])
def receive_web_log():
"""接收前端页面发送的日志"""
data = request.json
level = data.get('level', 'info')
message = data.get('message', '')
page = data.get('page', 'unknown')
timestamp = data.get('timestamp', '')
log_message = f'[WEB-{page}] {message}'
logger = get_logs_logger()
if level == 'error':
logger.error(log_message)
elif level == 'warning':
logger.warning(log_message)
elif level == 'debug':
logger.debug(log_message)
else:
logger.info(log_message)
return jsonify({'code': 0, 'message': '日志接收成功'})
@logs_bp.route('/log-files', methods=['GET'])
def get_log_files():
"""获取指定日期的日志文件列表"""
@@ -113,7 +101,8 @@ def get_log_files():
return jsonify({'code': 0, 'data': []})
log_files = []
file_names = ['all.log', 'error.log', 'request.log']
# 定义日志文件的优先级顺序
file_names = ['all.log', 'api.log', 'error.log', 'request.log', 'train.log']
for file_name in file_names:
file_path = os.path.join(log_dir, file_name)
@@ -178,15 +167,13 @@ TRAINING_LOGS_BASE_DIR = '/app/base/logs'
# 本地开发时的备用路径Windows
LOCAL_TRAINING_LOGS_BASE_DIR = os.path.join(PROJECT_ROOT, 'logs')
# 添加调试日志
logs_logger.info(f"[DEBUG] TRAINING_LOGS_BASE_DIR: {TRAINING_LOGS_BASE_DIR}")
logs_logger.info(f"[DEBUG] LOCAL_TRAINING_LOGS_BASE_DIR: {LOCAL_TRAINING_LOGS_BASE_DIR}")
@logs_bp.route('/training-log-files', methods=['GET'])
def get_training_log_files():
"""获取训练日志文件列表 - 从 logs/{日期} 目录下的 .log 文件"""
try:
logs_logger = get_logs_logger()
# 确定基础目录
logs_base_dir = TRAINING_LOGS_BASE_DIR
if not os.path.exists(logs_base_dir):
@@ -280,7 +267,7 @@ def get_training_log_files():
return jsonify({'code': 0, 'data': log_files})
except Exception as e:
logs_logger.error(f"[DEBUG] 获取训练日志列表失败: {e}")
get_logs_logger().error(f"[DEBUG] 获取训练日志列表失败: {e}")
return jsonify({'code': 1, 'message': f'获取训练日志列表失败: {str(e)}'})
@@ -291,6 +278,7 @@ def get_training_log_content():
if not file_name:
return jsonify({'code': 1, 'message': '缺少文件参数'})
logs_logger = get_logs_logger()
logs_logger.info(f"[DEBUG] ============ get_training_log_content ============")
logs_logger.info(f"[DEBUG] file: {file_name}")

View File

@@ -8,8 +8,12 @@ import json
import requests
import concurrent.futures
import subprocess
import logging
from flask import Blueprint, request, jsonify
# 获取模块 logger继承 main.py 的日志配置)
logger = logging.getLogger(__name__)
# 获取项目根目录
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -218,8 +222,6 @@ def preload_trained_model():
import sys as sys_module
import pymysql
import yaml
import logging
logger = logging.getLogger(__name__)
data = request.json
model_name = data.get('model_name') # 模型名称
@@ -572,3 +574,502 @@ if __name__ == "__main__":
return jsonify({'code': 1, 'message': '推理超时,请稍后重试'})
except Exception as e:
return jsonify({'code': 1, 'message': f'推理异常: {str(e)}'})
# ==================== Transformers 本地模型接口 ====================
@model_chat_bp.route('/local/preload', methods=['POST'])
def preload_local_model():
"""预加载本地模型(使用 transformers"""
import yaml
import subprocess
import sys as sys_module
data = request.json
model_path = data.get('model_path') # 模型路径
model_name = data.get('model_name', '本地模型') # 模型名称(用于显示)
if not model_path:
return jsonify({'code': 1, 'message': '缺少模型路径'})
logger.info(f"[PRELOAD_LOCAL] 开始预加载本地模型: {model_name}, 路径: {model_path}")
# 先生成唯一ID避免并发冲突必须在f-string之前定义
import uuid as uuid_module
temp_id = uuid_module.uuid4().hex[:8]
# 获取项目根目录
PROJECT_ROOT = os.path.dirname(os.dirname(os.path.dirname(os.path.abspath(__file__))))
CONFIG_PATH = os.path.join(PROJECT_ROOT, 'config.yaml')
try:
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
CONFIG = yaml.safe_load(f)
except Exception as e:
return jsonify({'code': 1, 'message': f'读取配置失败: {str(e)}'})
# 构建 transformers 预加载脚本不使用f-string避免import语句导致的模块shadow问题
preload_script = '''# -*- coding: utf-8 -*-
import sys
import uuid
import logging
import os
logging.basicConfig(level=logging.WARNING, format='%(message)s')
# 生成唯一的临时文件路径,避免并发冲突
temp_id = "TEMP_ID"
log_file = "/app/base/logs/preload_local_{}.log".format(temp_id)
# 设置环境变量
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
try:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model_id = "MODEL_PATH"
with open(log_file, "w", encoding="utf-8") as f:
f.write("开始加载模型: {}\\n".format(model_id))
# 加载 tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
with open(log_file, "a", encoding="utf-8") as f:
f.write("Tokenizer 加载完成\\n")
# 加载模型
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=True
)
with open(log_file, "a", encoding="utf-8") as f:
f.write("模型加载完成\\n")
# 测试推理
test_input = "你好"
inputs = tokenizer(test_input, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=10, do_sample=False)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
with open(log_file, "a", encoding="utf-8") as f:
f.write("测试推理成功: {}\\n".format(response))
with open(log_file, "a", encoding="utf-8") as f:
f.write("SUCCESS\\n")
except Exception as e:
with open(log_file, "a", encoding="utf-8") as f:
f.write("ERROR: {}\\n".format(str(e)))
import traceback
with open(log_file, "a", encoding="utf-8") as f:
f.write(traceback.format_exc())
sys.exit(1)
'''
# 使用 replace 替换变量(避免 f-string 中包含 import 语句)
preload_script = preload_script.replace('TEMP_ID', temp_id).replace('MODEL_PATH', model_path)
# 写入临时脚本 - 使用唯一文件名避免并发冲突
work_dir = '/app/base'
script_path = os.path.join(work_dir, f'temp_preload_local_{temp_id}.py')
log_path = os.path.join('/app/base/logs', f'preload_local_{temp_id}.log')
try:
# 确保logs目录存在
os.makedirs(os.path.dirname(log_path), exist_ok=True)
with open(script_path, 'w', encoding='utf-8') as f:
f.write(preload_script)
# 查找系统 Python不使用虚拟环境
def get_system_python():
# 尝试常见的系统 Python 路径
common_pythons = [
'/usr/bin/python3',
'/usr/bin/python',
'/usr/local/bin/python3',
'/usr/local/bin/python',
]
for py in common_pythons:
if os.path.exists(py) and os.access(py, os.X_OK):
return py
# 如果都没找到,使用系统 PATH 中的 python3
import shutil
py_path = shutil.which('python3')
if py_path:
return py_path
return sys_module.executable # 兜底使用当前 Python
python_executable = get_system_python()
logger.info(f"[PRELOAD_LOCAL] 使用系统 Python: {python_executable}")
# 继承环境变量,但清除虚拟环境相关的变量
env = {**os.environ}
env['CUDA_VISIBLE_DEVICES'] = '0'
env['TOKENIZERS_PARALLELISM'] = 'false'
# 清除虚拟环境相关变量
env.pop('VIRTUAL_ENV', None)
env.pop('PYTHONHOME', None)
logger.info(f"[PRELOAD_LOCAL] 脚本路径: {script_path}")
logger.info(f"[PRELOAD_LOCAL] 日志路径: {log_path}")
logger.info(f"[PRELOAD_LOCAL] 工作目录: {work_dir}")
logger.info(f"[PRELOAD_LOCAL] Python: {python_executable}")
logger.info(f"[PRELOAD_LOCAL] 脚本是否存在: {os.path.exists(script_path)}")
# 执行预加载脚本
try:
result = subprocess.run(
[python_executable, script_path],
capture_output=True,
text=True,
timeout=600, # 10分钟超时
cwd=work_dir,
env=env
)
logger.info(f"[PRELOAD_LOCAL] 返回码: {result.returncode}")
logger.info(f"[PRELOAD_LOCAL] stdout: {result.stdout[:500] if result.stdout else 'empty'}")
logger.info(f"[PRELOAD_LOCAL] stderr: {result.stderr[:500] if result.stderr else 'empty'}")
except Exception as sub_err:
logger.error(f"[PRELOAD_LOCAL] subprocess执行异常: {sub_err}")
return jsonify({'code': 1, 'message': f'执行异常: {str(sub_err)}'})
# 读取日志文件获取实际输出
try:
if os.path.exists(log_path):
with open(log_path, 'r', encoding='utf-8') as f:
full_output = f.read()
logger.info(f"[PRELOAD_LOCAL] 日志文件内容: {full_output[:500]}")
else:
logger.warning(f"[PRELOAD_LOCAL] 日志文件不存在: {log_path}")
full_output = result.stdout + result.stderr
except Exception as read_err:
logger.error(f"[PRELOAD_LOCAL] 读取日志失败: {read_err}")
full_output = result.stdout + result.stderr
logger.info(f"[PRELOAD_LOCAL] 脚本输出: {full_output[:500] if len(full_output) > 500 else full_output}")
# 清理临时文件
try:
if os.path.exists(script_path):
os.remove(script_path)
except Exception:
pass
try:
if os.path.exists(log_path):
os.remove(log_path)
except Exception:
pass
if result.returncode == 0 and 'SUCCESS' in full_output:
logger.info(f"[PRELOAD_LOCAL] 模型预加载成功: {model_name}")
return jsonify({
'code': 0,
'message': '模型预加载成功',
'data': {'model_name': model_name}
})
else:
error_msg = '预加载失败'
for line in full_output.split('\n'):
if 'ERROR:' in line:
error_msg = line.split('ERROR:')[1].strip()
break
logger.error(f"[PRELOAD_LOCAL] 预加载失败: {error_msg}")
return jsonify({'code': 1, 'message': f'预加载失败: {error_msg}'})
except subprocess.TimeoutExpired:
# 清理临时文件
try:
if os.path.exists(script_path):
os.remove(script_path)
except Exception:
pass
try:
if os.path.exists(log_path):
os.remove(log_path)
except Exception:
pass
logger.error("[PRELOAD_LOCAL] 预加载超时")
return jsonify({'code': 1, 'message': '预加载超时,请确保模型路径正确且有足够显存'})
except Exception as e:
# 清理临时文件
try:
if os.path.exists(script_path):
os.remove(script_path)
except Exception:
pass
try:
if os.path.exists(log_path):
os.remove(log_path)
except Exception:
pass
logger.error(f"[PRELOAD_LOCAL] 预加载异常: {str(e)}")
return jsonify({'code': 1, 'message': f'预加载异常: {str(e)}'})
@model_chat_bp.route('/local/chat', methods=['POST'])
def chat_local_model():
"""使用本地模型进行对话推理(使用 transformers"""
import yaml
import subprocess
import sys as sys_module
import json
data = request.json
model_path = data.get('model_path') # 模型路径
system_prompt = data.get('system_prompt', '')
user_question = data.get('user_question')
temperature = data.get('temperature', 0.7)
max_tokens = data.get('max_tokens', 2048)
if not model_path:
return jsonify({'code': 1, 'message': '缺少模型路径'})
if not user_question:
return jsonify({'code': 1, 'message': '缺少用户提问'})
logger.info(f"[CHAT_LOCAL] 开始对话推理, 模型路径: {model_path}")
# 构建消息
messages = []
if system_prompt:
messages.append({'role': 'system', 'content': system_prompt})
messages.append({'role': 'user', 'content': user_question})
# 先生成唯一ID避免并发冲突必须在f-string之前定义
import uuid as uuid_module
temp_id = uuid_module.uuid4().hex[:8]
# 构建 transformers 推理脚本不使用f-string避免import语句导致的模块shadow问题
# 使用唯一占位符避免与其他内容冲突
import json as json_module
messages_json = json_module.dumps(messages, ensure_ascii=False)
inference_script = '''# -*- coding: utf-8 -*-
import sys
import os
import uuid
import logging
import json
logging.basicConfig(level=logging.WARNING, format='%(message)s')
# 生成唯一的临时文件路径
temp_id = "__TEMP_ID__"
log_file = "/app/base/logs/chat_local_{}.log".format(temp_id)
# 设置环境变量
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
try:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model_id = "__MODEL_PATH__"
with open(log_file, "w", encoding="utf-8") as f:
f.write("正在加载模型: {}\\n".format(model_id))
# 加载 tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
with open(log_file, "a", encoding="utf-8") as f:
f.write("Tokenizer 加载完成\\n")
# 加载模型
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=True
)
with open(log_file, "a", encoding="utf-8") as f:
f.write("模型加载完成\\n")
# 构建消息格式
messages = __MESSAGES_JSON__
# 应用 chat template
if tokenizer.chat_template:
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
else:
# 手动构建
text = ""
for msg in messages:
text += "<|im_start|>{}<|im_end|>\\n".format(msg['role'])
text += "{}<|im_end|>\\n".format(msg['content'])
text += "<|im_start|>assistant\\n"
with open(log_file, "a", encoding="utf-8") as f:
f.write("构建输入完成\\n")
# 编码输入
inputs = tokenizer(text, return_tensors="pt").to(model.device)
# 生成回复
outputs = model.generate(
**inputs,
max_new_tokens=__MAX_TOKENS__,
temperature=__TEMPERATURE__,
do_sample=True,
top_p=0.9,
pad_token_id=tokenizer.eos_token_id
)
# 解码输出
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
# 提取 assistant 的回复
if "assistant\\n" in response:
response = response.split("assistant\\n")[-1]
elif "<|im_start|>assistant" in response:
response = response.split("<|im_start|>assistant")[-1]
response = response.strip()
with open(log_file, "a", encoding="utf-8") as f:
f.write("{}\\n".format(response))
with open(log_file, "a", encoding="utf-8") as f:
f.write("SUCCESS\\n")
except Exception as e:
with open(log_file, "a", encoding="utf-8") as f:
f.write("ERROR: {}\\n".format(str(e)))
import traceback
with open(log_file, "a", encoding="utf-8") as f:
f.write(traceback.format_exc())
sys.exit(1)
'''
# 使用 replace 替换变量(避免 f-string 中包含 import 语句)
inference_script = inference_script.replace('__TEMP_ID__', temp_id)
inference_script = inference_script.replace('__MODEL_PATH__', model_path)
inference_script = inference_script.replace('__MESSAGES_JSON__', messages_json)
inference_script = inference_script.replace('__MAX_TOKENS__', str(max_tokens))
inference_script = inference_script.replace('__TEMPERATURE__', str(temperature))
# 写入临时脚本 - 使用唯一文件名避免并发冲突
work_dir = '/app/base'
script_path = os.path.join(work_dir, f'temp_chat_local_{temp_id}.py')
log_path = os.path.join('/app/base/logs', f'chat_local_{temp_id}.log')
try:
# 确保logs目录存在
os.makedirs(os.path.dirname(log_path), exist_ok=True)
with open(script_path, 'w', encoding='utf-8') as f:
f.write(inference_script)
# 查找系统 Python不使用虚拟环境
def get_system_python():
# 尝试常见的系统 Python 路径
common_pythons = [
'/usr/bin/python3',
'/usr/bin/python',
'/usr/local/bin/python3',
'/usr/local/bin/python',
]
for py in common_pythons:
if os.path.exists(py) and os.access(py, os.X_OK):
return py
# 如果都没找到,使用系统 PATH 中的 python3
import shutil
py_path = shutil.which('python3')
if py_path:
return py_path
return sys_module.executable # 兜底使用当前 Python
python_executable = get_system_python()
logger.info(f"[CHAT_LOCAL] 使用系统 Python: {python_executable}")
# 继承环境变量,但清除虚拟环境相关的变量
env = {**os.environ}
env['CUDA_VISIBLE_DEVICES'] = '0'
env['TOKENIZERS_PARALLELISM'] = 'false'
# 清除虚拟环境相关变量
env.pop('VIRTUAL_ENV', None)
env.pop('PYTHONHOME', None)
# 执行推理脚本
result = subprocess.run(
[python_executable, script_path],
capture_output=True,
text=True,
timeout=600, # 10分钟超时
cwd=work_dir,
env=env
)
# 读取日志文件获取实际输出
try:
if os.path.exists(log_path):
with open(log_path, 'r', encoding='utf-8') as f:
full_output = f.read()
else:
full_output = result.stdout + result.stderr
except Exception as read_err:
logger.error(f"[CHAT_LOCAL] 读取日志失败: {read_err}")
full_output = result.stdout + result.stderr
logger.info(f"[CHAT_LOCAL] 脚本输出: {full_output[:500] if len(full_output) > 500 else full_output}")
# 清理临时文件
try:
if os.path.exists(script_path):
os.remove(script_path)
except Exception:
pass
try:
if os.path.exists(log_path):
os.remove(log_path)
except Exception:
pass
if result.returncode == 0 and 'SUCCESS' in full_output:
# 提取实际回复(去掉最后的 SUCCESS
lines = full_output.strip().split('\n')
response_lines = [line for line in lines if line.strip() and line.strip() != 'SUCCESS']
assistant_content = '\n'.join(response_lines).strip()
logger.info(f"[CHAT_LOCAL] 对话成功, 回复长度: {len(assistant_content)}")
return jsonify({
'code': 0,
'data': {
'model_path': model_path,
'response': assistant_content
}
})
else:
error_msg = '推理失败'
for line in full_output.split('\n'):
if 'ERROR:' in line:
error_msg = line.split('ERROR:')[1].strip()
break
logger.error(f"[CHAT_LOCAL] 推理失败: {error_msg}")
return jsonify({'code': 1, 'message': f'推理失败: {error_msg}'})
except subprocess.TimeoutExpired:
# 清理临时文件
try:
if os.path.exists(script_path):
os.remove(script_path)
except Exception:
pass
try:
if os.path.exists(log_path):
os.remove(log_path)
except Exception:
pass
logger.error("[CHAT_LOCAL] 推理超时")
return jsonify({'code': 1, 'message': '推理超时请减少生成token数量或调整参数'})
except Exception as e:
# 清理临时文件
try:
if os.path.exists(script_path):
os.remove(script_path)
except Exception:
pass
try:
if os.path.exists(log_path):
os.remove(log_path)
except Exception:
pass
logger.error(f"[CHAT_LOCAL] 推理异常: {str(e)}")
return jsonify({'code': 1, 'message': f'推理异常: {str(e)}'})

View File

@@ -4,8 +4,12 @@
import os
import pymysql
import yaml
import logging
from flask import Blueprint, request, jsonify
# 获取模块 logger继承 main.py 的日志配置)
logger = logging.getLogger(__name__)
# 获取项目根目录 - 优先使用环境变量,否则从文件路径计算
MOUNT_BASE = os.environ.get('MOUNT_BASE', '/app/base')
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -47,8 +51,6 @@ def generic_get_all(table_name, order_by='create_time DESC'):
def get_model_path_by_name(model_name):
"""根据模型名称查询模型路径(用于获取基座模型路径)"""
import logging
logger = logging.getLogger(__name__)
logger.info(f"[DEBUG get_model_path_by_name] 查询模型: {model_name}")
try:
@@ -165,6 +167,23 @@ def get_model_manage_by_id(id):
return jsonify({'code': 1, 'message': '模型不存在'})
@model_manage_bp.route('/name/<model_name>', methods=['GET'])
def get_model_manage_by_name(model_name):
"""根据名称获取模型"""
logger.info(f"[DEBUG] 按名称查询模型: {model_name}")
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute("SELECT * FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
model = cursor.fetchone()
cursor.close()
conn.close()
if model:
return jsonify({'code': 0, 'data': model})
return jsonify({'code': 1, 'message': '模型不存在'})
@model_manage_bp.route('', methods=['POST'])
def create_model_manage():
"""创建模型"""
@@ -575,25 +594,57 @@ def merge_model():
@model_manage_bp.route('/trained-models/<model_name>', methods=['DELETE'])
def delete_trained_model(model_name):
"""删除已训练模型从local_trained_models目录"""
"""删除已训练模型
type=merged: 删除合并模型local_trained_models目录
type=lora: 删除权重saves目录下的lora等权重文件
"""
import shutil
import logging
logger = logging.getLogger(__name__)
# 获取删除类型参数
delete_type = request.args.get('type', 'merged') # 默认删除合并模型
try:
# 删除 local_trained_models 目录下的模型
model_path = os.path.join(PROJECT_ROOT, 'local_trained_models', model_name)
if delete_type == 'lora':
# 删除权重:删除 saves 目录下的权重
saves_path = os.path.join(PROJECT_ROOT, 'saves')
train_methods = ['lora', 'full', 'qlora', 'dpo', 'cpt', 'prefix', 'adapter', 'peft']
if not os.path.exists(model_path):
return jsonify({'code': 1, 'message': f'模型不存在: {model_name}'})
deleted = False
for method in train_methods:
weight_path = os.path.join(saves_path, method, model_name)
if os.path.exists(weight_path):
shutil.rmtree(weight_path)
logger.info(f"[DELETE] 已删除权重: {weight_path}")
deleted = True
# 删除目录
shutil.rmtree(model_path)
logger.info(f"[DELETE] 已删除模型: {model_path}")
if not deleted:
# 也可能是老结构,直接在 saves 下的 model_name 目录
old_path = os.path.join(saves_path, model_name)
if os.path.exists(old_path):
shutil.rmtree(old_path)
logger.info(f"[DELETE] 已删除老结构权重: {old_path}")
deleted = True
return jsonify({'code': 0, 'message': '删除成功'})
if deleted:
return jsonify({'code': 0, 'message': '权重已删除'})
else:
return jsonify({'code': 1, 'message': f'权重不存在: {model_name}'})
else:
# 默认删除合并模型local_trained_models目录
model_path = os.path.join(PROJECT_ROOT, 'local_trained_models', model_name)
if not os.path.exists(model_path):
return jsonify({'code': 1, 'message': f'合并模型不存在: {model_name}'})
# 删除目录
shutil.rmtree(model_path)
logger.info(f"[DELETE] 已删除合并模型: {model_path}")
return jsonify({'code': 0, 'message': '合并模型已删除'})
except Exception as e:
logger.error(f"[DELETE] 删除模型失败: {str(e)}")
logger.error(f"[DELETE] 删除失败: {str(e)}")
return jsonify({'code': 1, 'message': f'删除失败: {str(e)}'})

View File

@@ -37,7 +37,25 @@ CONFIG = load_config()
TRAINING_LOGS_DIR = CONFIG.get('training_logs_path', '/app/base/training_logs')
# ============ 日志系统配置 ============
LOG_BASE_DIR = os.path.join(PROJECT_ROOT, 'logs')
# 日志目录逻辑:
# 1. 优先使用环境变量 LOG_BASE_DIR
# 2. 如果是容器环境(存在 /app/base使用 /app/base/logs
# 3. 否则使用本地项目路径 PROJECT_ROOT/logs
def get_log_base_dir():
"""获取日志基础目录"""
# 1. 检查环境变量
if 'LOG_BASE_DIR' in os.environ:
return os.environ['LOG_BASE_DIR']
# 2. 检查是否在容器环境中
mount_base = os.environ.get('MOUNT_BASE', '/app/base')
if os.path.exists(mount_base):
return os.path.join(mount_base, 'logs')
# 3. 使用本地项目路径
return os.path.join(PROJECT_ROOT, 'logs')
LOG_BASE_DIR = get_log_base_dir()
def setup_logger(name='app'):
@@ -98,6 +116,15 @@ def setup_logger(name='app'):
datefmt='%Y-%m-%d %H:%M:%S'
))
# 6. API日志处理器 - 专门记录API相关日志
api_log_path = os.path.join(log_dir, 'api.log')
api_handler = RotatingFileHandler(api_log_path, maxBytes=10*1024*1024, backupCount=10, encoding='utf-8')
api_handler.setLevel(logging.DEBUG)
api_handler.setFormatter(logging.Formatter(
'[%(asctime)s] %(levelname)s: %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
))
# 添加处理器到 logger
logger.addHandler(all_handler)
logger.addHandler(error_handler)
@@ -117,6 +144,13 @@ def setup_logger(name='app'):
train_logger.addHandler(train_handler)
train_logger.addHandler(console_handler)
# 为 logs_api 创建单独的 logger (供 src/api/logs.py 使用)
logs_api_logger = logging.getLogger('logs_api')
logs_api_logger.setLevel(logging.DEBUG)
logs_api_logger.handlers.clear()
logs_api_logger.addHandler(api_handler)
logs_api_logger.addHandler(console_handler)
return logger
@@ -588,6 +622,8 @@ def system_info():
# ============ 通用 CRUD 操作 ============
import json
def generic_get_all(table_name, order_by='create_time DESC'):
"""通用查询所有"""
conn = get_db_connection()
@@ -596,6 +632,19 @@ def generic_get_all(table_name, order_by='create_time DESC'):
result = cursor.fetchall()
cursor.close()
conn.close()
# 自动解析 JSON 字段
for row in result:
for key, value in row.items():
if isinstance(value, str) and value.startswith('[') and value.endswith(']'):
try:
row[key] = json.loads(value)
except:
pass
elif isinstance(value, str) and value.startswith('{') and value.endswith('}'):
try:
row[key] = json.loads(value)
except:
pass
return result
@@ -607,6 +656,19 @@ def generic_get_by_id(table_name, id_val):
result = cursor.fetchone()
cursor.close()
conn.close()
# 自动解析 JSON 字段
if result:
for key, value in result.items():
if isinstance(value, str) and value.startswith('[') and value.endswith(']'):
try:
result[key] = json.loads(value)
except:
pass
elif isinstance(value, str) and value.startswith('{') and value.endswith('}'):
try:
result[key] = json.loads(value)
except:
pass
return result
@@ -1053,36 +1115,19 @@ def delete_model_deploy(id):
return jsonify({'code': 0, 'message': '删除成功'})
# ============ 模型管理接口 ============
@app.route('/api/model-manage', methods=['GET'])
def get_model_manage():
return jsonify({'code': 0, 'data': generic_get_all('model_manage')})
@app.route('/api/model-manage', methods=['POST'])
def create_model_manage():
data = request.json
new_id = generic_create('model_manage', data)
return jsonify({'code': 0, 'message': '创建成功', 'id': new_id})
@app.route('/api/model-manage/<int:id>', methods=['PUT'])
def update_model_manage(id):
data = request.json
generic_update('model_manage', id, data)
return jsonify({'code': 0, 'message': '更新成功'})
@app.route('/api/model-manage/<int:id>', methods=['DELETE'])
def delete_model_manage(id):
generic_delete('model_manage', id)
return jsonify({'code': 0, 'message': '删除成功'})
# ============ 模型对比接口 ============
@app.route('/api/model-compare', methods=['GET'])
def get_model_compare():
return jsonify({'code': 0, 'data': generic_get_all('model_compare')})
result = generic_get_all('model_compare')
# 确保 models 字段被正确解析为 JSON
for row in result:
if 'models' in row and isinstance(row['models'], str):
try:
row['models'] = json.loads(row['models'])
logger.debug(f"[model-compare] 解析 models 字段成功: {row['models']}")
except Exception as e:
logger.error(f"[model-compare] 解析 models 字段失败: {e}, 原始值: {row['models']}")
return jsonify({'code': 0, 'data': result})
@app.route('/api/model-compare/<int:id>', methods=['GET'])
@@ -1090,13 +1135,25 @@ def get_model_compare_by_id(id):
"""获取单个模型对比任务"""
result = generic_get_by_id('model_compare', id)
if result:
# 确保 models 字段被正确解析为 JSON
if 'models' in result and isinstance(result['models'], str):
try:
result['models'] = json.loads(result['models'])
except Exception as e:
logger.error(f"[model-compare] 解析 models 字段失败: {e}")
return jsonify({'code': 0, 'data': result})
return jsonify({'code': 1, 'message': '任务不存在'})
@app.route('/api/model-compare', methods=['POST'])
def create_model_compare():
data = request.json
data = request.json.copy()
# 字段映射: name -> model_name
if 'name' in data:
data['model_name'] = data.pop('name')
# 设置默认加载状态
if 'status' not in data:
data['status'] = 'pending'
new_id = generic_create('model_compare', data)
return jsonify({'code': 0, 'message': '创建成功', 'id': new_id})
@@ -1110,10 +1167,278 @@ def update_model_compare(id):
@app.route('/api/model-compare/<int:id>', methods=['DELETE'])
def delete_model_compare(id):
# 先停止加载的模型服务
try:
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute("SELECT models, load_status FROM model_compare WHERE id = %s", (id,))
row = cursor.fetchone()
cursor.close()
conn.close()
if row and row[1] == 'loaded':
# 停止模型服务
from subprocess import Popen
import signal
load_status = json.loads(row[0]) if isinstance(row[0], str) else row[0]
if isinstance(load_status, dict) and 'processes' in load_status:
for proc in load_status['processes']:
if 'pid' in proc:
try:
os.kill(proc['pid'], signal.SIGTERM)
except:
pass
except Exception as e:
logger.error(f"[model-compare] 停止模型服务失败: {e}")
generic_delete('model_compare', id)
return jsonify({'code': 0, 'message': '删除成功'})
# ============ 模型加载接口 ============
import subprocess
import signal
# 存储加载状态 (生产环境应使用数据库或Redis)
model_compare_processes = {}
@app.route('/api/model-compare/<int:id>/load', methods=['POST'])
def load_model_compare(id):
"""加载模型 - 启动模型服务"""
try:
# 获取对比任务
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute("SELECT models FROM model_compare WHERE id = %s", (id,))
row = cursor.fetchone()
cursor.close()
conn.close()
if not row:
return jsonify({'code': 1, 'message': '任务不存在'})
models = json.loads(row[0]) if isinstance(row[0], str) else row[0]
if not models or len(models) < 2:
return jsonify({'code': 1, 'message': '模型配置无效'})
# 更新状态为 loading
generic_update('model_compare', id, {'status': 'loading'})
# 启动模型服务
processes = []
loaded_models = []
for i, model in enumerate(models):
model_path = model.get('model_path', '')
gpu_id = model.get('gpu_id', 0)
port = model.get('port', 7862 + i * 10)
if not model_path:
continue
# 构建启动命令
# 获取模型名称和模板
model_name = model.get('model_name', '')
template = detect_template(model_path)
cmd = [
sys.executable,
'-m', 'llamafactory.api',
'--model_name_or_path', model_path,
'--template', template or 'default',
'--port', str(port),
'--gpu_ids', str(gpu_id)
]
logger.info(f"[model-compare] 启动模型服务: {' '.join(cmd)}")
# 启动进程
proc = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
cwd=PROJECT_ROOT
)
processes.append({
'pid': proc.pid,
'port': port,
'model_name': model_name,
'model_path': model_path
})
loaded_models.append({
'port': port,
'model_name': model_name,
'status': 'starting'
})
# 保存进程信息
model_compare_processes[id] = {
'processes': processes,
'loaded_models': loaded_models,
'start_time': datetime.now().timestamp()
}
# 更新数据库中的加载状态
load_status = {
'processes': processes,
'loaded_models': loaded_models,
'started_at': datetime.now().isoformat()
}
generic_update('model_compare', id, {
'status': 'loading',
'load_status': json.dumps(load_status, ensure_ascii=False)
})
return jsonify({
'code': 0,
'message': '正在加载模型...',
'data': {
'models': loaded_models,
'check_url': f'/api/model-compare/{id}/load-status'
}
})
except Exception as e:
logger.error(f"[model-compare] 加载模型失败: {e}")
generic_update('model_compare', id, {'status': 'pending'})
return jsonify({'code': 1, 'message': f'加载失败: {str(e)}'})
def detect_template(model_path):
"""根据模型路径检测模板类型"""
model_lower = model_path.lower()
if 'qwen' in model_lower:
return 'qwen'
elif 'llama' in model_lower or 'llama' in model_lower:
return 'llama'
elif 'chatglm' in model_lower:
return 'chatglm'
elif 'baichuan' in model_lower:
return 'baichuan'
elif 'mistral' in model_lower:
return 'mistral'
elif 'yi' in model_lower:
return 'yi'
elif 'deepseek' in model_lower:
return 'deepseek'
return 'default'
@app.route('/api/model-compare/<int:id>/load-status', methods=['GET'])
def get_load_status(id):
"""获取模型加载状态"""
try:
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute("SELECT models, load_status FROM model_compare WHERE id = %s", (id,))
row = cursor.fetchone()
cursor.close()
conn.close()
if not row:
return jsonify({'code': 1, 'message': '任务不存在'})
models = json.loads(row[0]) if isinstance(row[0], str) else row[0]
load_status = json.loads(row[1]) if row[1] and isinstance(row[1], str) else {}
loaded_models = load_status.get('loaded_models', [])
processes = load_status.get('processes', [])
# 检查每个模型服务是否就绪
all_ready = True
for i, model in enumerate(loaded_models):
port = model.get('port')
try:
import requests
resp = requests.get(f'http://localhost:{port}/health', timeout=2)
if resp.status_code == 200:
model['status'] = 'ready'
else:
model['status'] = 'starting'
all_ready = False
except:
# 检查进程是否还在运行
if 'pid' in processes[i] if i < len(processes) else False:
model['status'] = 'starting'
all_ready = False
else:
model['status'] = 'failed'
all_ready = False
# 更新状态
if all_ready:
generic_update('model_compare', id, {'status': 'loaded'})
else:
generic_update('model_compare', id, {'status': 'loading'})
# 更新加载状态
load_status['loaded_models'] = loaded_models
load_status['all_ready'] = all_ready
generic_update('model_compare', id, {
'load_status': json.dumps(load_status, ensure_ascii=False)
})
return jsonify({
'code': 0,
'data': {
'status': 'loaded' if all_ready else 'loading',
'models': loaded_models,
'all_ready': all_ready
}
})
except Exception as e:
logger.error(f"[model-compare] 获取加载状态失败: {e}")
return jsonify({'code': 1, 'message': str(e)})
@app.route('/api/model-compare/<int:id>/unload', methods=['POST'])
def unload_model_compare(id):
"""停止加载的模型服务"""
try:
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute("SELECT load_status FROM model_compare WHERE id = %s", (id,))
row = cursor.fetchone()
cursor.close()
conn.close()
if row and row[0]:
load_status = json.loads(row[0]) if isinstance(row[0], str) else row[0]
processes = load_status.get('processes', [])
# 停止所有进程
for proc in processes:
pid = proc.get('pid')
if pid:
try:
os.kill(pid, signal.SIGTERM)
logger.info(f"[model-compare] 已停止进程 {pid}")
except ProcessLookupError:
pass
except Exception as e:
logger.warning(f"[model-compare] 停止进程 {pid} 失败: {e}")
# 清理内存
if id in model_compare_processes:
del model_compare_processes[id]
# 更新状态
generic_update('model_compare', id, {
'status': 'pending',
'load_status': None
})
return jsonify({'code': 0, 'message': '已停止模型服务'})
except Exception as e:
logger.error(f"[model-compare] 停止模型服务失败: {e}")
return jsonify({'code': 1, 'message': str(e)})
# ============ 数据生成接口 ============
@app.route('/api/data-generate', methods=['GET'])
def get_data_generate():