重构了main.html的主函数
重构了大量的页面的sidebar 优化了代码结构
This commit is contained in:
118
src/api/logs.py
118
src/api/logs.py
@@ -4,9 +4,7 @@
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import yaml
|
||||
from datetime import datetime
|
||||
from logging.handlers import TimedRotatingFileHandler, RotatingFileHandler
|
||||
from flask import Blueprint, request, jsonify
|
||||
|
||||
# 获取项目根目录
|
||||
@@ -14,74 +12,40 @@ PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(_
|
||||
sys.path.insert(0, PROJECT_ROOT)
|
||||
|
||||
# 加载配置
|
||||
import yaml
|
||||
CONFIG_PATH = os.path.join(PROJECT_ROOT, 'config.yaml')
|
||||
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
|
||||
CONFIG = yaml.safe_load(f)
|
||||
|
||||
# 日志目录
|
||||
LOG_BASE_DIR = os.path.join(PROJECT_ROOT, 'logs')
|
||||
# 日志目录 - 使用与 main.py 相同的配置
|
||||
def get_log_base_dir():
|
||||
"""获取日志基础目录"""
|
||||
# 1. 检查环境变量
|
||||
if 'LOG_BASE_DIR' in os.environ:
|
||||
return os.environ['LOG_BASE_DIR']
|
||||
|
||||
# 2. 检查是否在容器环境中
|
||||
mount_base = os.environ.get('MOUNT_BASE', '/app/base')
|
||||
if os.path.exists(mount_base):
|
||||
return os.path.join(mount_base, 'logs')
|
||||
|
||||
# 3. 使用本地项目路径
|
||||
return os.path.join(PROJECT_ROOT, 'logs')
|
||||
|
||||
LOG_BASE_DIR = get_log_base_dir()
|
||||
|
||||
# 创建蓝图
|
||||
logs_bp = Blueprint('logs', __name__, url_prefix='/api')
|
||||
|
||||
|
||||
def setup_logs_logger():
|
||||
"""配置日志系统,按日期分目录存储"""
|
||||
today = datetime.now().strftime('%Y-%m-%d')
|
||||
log_dir = os.path.join(LOG_BASE_DIR, today)
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
logger = logging.getLogger('logs_api')
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.handlers.clear()
|
||||
|
||||
# 全部日志处理器
|
||||
all_log_path = os.path.join(log_dir, 'all.log')
|
||||
all_handler = TimedRotatingFileHandler(all_log_path, when='midnight', interval=1, backupCount=30, encoding='utf-8')
|
||||
all_handler.setLevel(logging.DEBUG)
|
||||
all_handler.setFormatter(logging.Formatter(
|
||||
'[%(asctime)s] %(levelname)s in %(module)s: %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
))
|
||||
|
||||
# 控制台处理器
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.INFO)
|
||||
console_handler.setFormatter(logging.Formatter(
|
||||
'[%(asctime)s] %(levelname)s: %(message)s',
|
||||
datefmt='%H:%M:%S'
|
||||
))
|
||||
|
||||
logger.addHandler(all_handler)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
return logger
|
||||
def get_logs_logger():
|
||||
"""从 main.py 获取日志记录器"""
|
||||
return logging.getLogger('logs_api')
|
||||
|
||||
|
||||
logs_logger = setup_logs_logger()
|
||||
|
||||
|
||||
@logs_bp.route('/web-log', methods=['POST'])
|
||||
def receive_web_log():
|
||||
"""接收前端页面发送的日志"""
|
||||
data = request.json
|
||||
level = data.get('level', 'info')
|
||||
message = data.get('message', '')
|
||||
page = data.get('page', 'unknown')
|
||||
timestamp = data.get('timestamp', datetime.now().isoformat())
|
||||
|
||||
log_message = f'[WEB-{page}] {message}'
|
||||
|
||||
if level == 'error':
|
||||
logs_logger.error(log_message)
|
||||
elif level == 'warning':
|
||||
logs_logger.warning(log_message)
|
||||
elif level == 'debug':
|
||||
logs_logger.debug(log_message)
|
||||
else:
|
||||
logs_logger.info(log_message)
|
||||
|
||||
return jsonify({'code': 0, 'message': '日志接收成功'})
|
||||
def get_request_logger():
|
||||
"""获取请求日志记录器"""
|
||||
return logging.getLogger('request')
|
||||
|
||||
|
||||
def format_file_size(size_bytes):
|
||||
@@ -94,6 +58,30 @@ def format_file_size(size_bytes):
|
||||
return f'{size_bytes / (1024 * 1024):.1f} MB'
|
||||
|
||||
|
||||
@logs_bp.route('/web-log', methods=['POST'])
|
||||
def receive_web_log():
|
||||
"""接收前端页面发送的日志"""
|
||||
data = request.json
|
||||
level = data.get('level', 'info')
|
||||
message = data.get('message', '')
|
||||
page = data.get('page', 'unknown')
|
||||
timestamp = data.get('timestamp', '')
|
||||
|
||||
log_message = f'[WEB-{page}] {message}'
|
||||
|
||||
logger = get_logs_logger()
|
||||
if level == 'error':
|
||||
logger.error(log_message)
|
||||
elif level == 'warning':
|
||||
logger.warning(log_message)
|
||||
elif level == 'debug':
|
||||
logger.debug(log_message)
|
||||
else:
|
||||
logger.info(log_message)
|
||||
|
||||
return jsonify({'code': 0, 'message': '日志接收成功'})
|
||||
|
||||
|
||||
@logs_bp.route('/log-files', methods=['GET'])
|
||||
def get_log_files():
|
||||
"""获取指定日期的日志文件列表"""
|
||||
@@ -113,7 +101,8 @@ def get_log_files():
|
||||
return jsonify({'code': 0, 'data': []})
|
||||
|
||||
log_files = []
|
||||
file_names = ['all.log', 'error.log', 'request.log']
|
||||
# 定义日志文件的优先级顺序
|
||||
file_names = ['all.log', 'api.log', 'error.log', 'request.log', 'train.log']
|
||||
|
||||
for file_name in file_names:
|
||||
file_path = os.path.join(log_dir, file_name)
|
||||
@@ -178,15 +167,13 @@ TRAINING_LOGS_BASE_DIR = '/app/base/logs'
|
||||
# 本地开发时的备用路径(Windows)
|
||||
LOCAL_TRAINING_LOGS_BASE_DIR = os.path.join(PROJECT_ROOT, 'logs')
|
||||
|
||||
# 添加调试日志
|
||||
logs_logger.info(f"[DEBUG] TRAINING_LOGS_BASE_DIR: {TRAINING_LOGS_BASE_DIR}")
|
||||
logs_logger.info(f"[DEBUG] LOCAL_TRAINING_LOGS_BASE_DIR: {LOCAL_TRAINING_LOGS_BASE_DIR}")
|
||||
|
||||
|
||||
@logs_bp.route('/training-log-files', methods=['GET'])
|
||||
def get_training_log_files():
|
||||
"""获取训练日志文件列表 - 从 logs/{日期} 目录下的 .log 文件"""
|
||||
try:
|
||||
logs_logger = get_logs_logger()
|
||||
|
||||
# 确定基础目录
|
||||
logs_base_dir = TRAINING_LOGS_BASE_DIR
|
||||
if not os.path.exists(logs_base_dir):
|
||||
@@ -280,7 +267,7 @@ def get_training_log_files():
|
||||
|
||||
return jsonify({'code': 0, 'data': log_files})
|
||||
except Exception as e:
|
||||
logs_logger.error(f"[DEBUG] 获取训练日志列表失败: {e}")
|
||||
get_logs_logger().error(f"[DEBUG] 获取训练日志列表失败: {e}")
|
||||
return jsonify({'code': 1, 'message': f'获取训练日志列表失败: {str(e)}'})
|
||||
|
||||
|
||||
@@ -291,6 +278,7 @@ def get_training_log_content():
|
||||
if not file_name:
|
||||
return jsonify({'code': 1, 'message': '缺少文件参数'})
|
||||
|
||||
logs_logger = get_logs_logger()
|
||||
logs_logger.info(f"[DEBUG] ============ get_training_log_content ============")
|
||||
logs_logger.info(f"[DEBUG] file: {file_name}")
|
||||
|
||||
|
||||
@@ -8,8 +8,12 @@ import json
|
||||
import requests
|
||||
import concurrent.futures
|
||||
import subprocess
|
||||
import logging
|
||||
from flask import Blueprint, request, jsonify
|
||||
|
||||
# 获取模块 logger(继承 main.py 的日志配置)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 获取项目根目录
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
@@ -218,8 +222,6 @@ def preload_trained_model():
|
||||
import sys as sys_module
|
||||
import pymysql
|
||||
import yaml
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
data = request.json
|
||||
model_name = data.get('model_name') # 模型名称
|
||||
@@ -572,3 +574,502 @@ if __name__ == "__main__":
|
||||
return jsonify({'code': 1, 'message': '推理超时,请稍后重试'})
|
||||
except Exception as e:
|
||||
return jsonify({'code': 1, 'message': f'推理异常: {str(e)}'})
|
||||
|
||||
|
||||
# ==================== Transformers 本地模型接口 ====================
|
||||
|
||||
@model_chat_bp.route('/local/preload', methods=['POST'])
|
||||
def preload_local_model():
|
||||
"""预加载本地模型(使用 transformers)"""
|
||||
import yaml
|
||||
import subprocess
|
||||
import sys as sys_module
|
||||
|
||||
data = request.json
|
||||
model_path = data.get('model_path') # 模型路径
|
||||
model_name = data.get('model_name', '本地模型') # 模型名称(用于显示)
|
||||
|
||||
if not model_path:
|
||||
return jsonify({'code': 1, 'message': '缺少模型路径'})
|
||||
|
||||
logger.info(f"[PRELOAD_LOCAL] 开始预加载本地模型: {model_name}, 路径: {model_path}")
|
||||
|
||||
# 先生成唯一ID,避免并发冲突(必须在f-string之前定义)
|
||||
import uuid as uuid_module
|
||||
temp_id = uuid_module.uuid4().hex[:8]
|
||||
|
||||
# 获取项目根目录
|
||||
PROJECT_ROOT = os.path.dirname(os.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
CONFIG_PATH = os.path.join(PROJECT_ROOT, 'config.yaml')
|
||||
|
||||
try:
|
||||
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
|
||||
CONFIG = yaml.safe_load(f)
|
||||
except Exception as e:
|
||||
return jsonify({'code': 1, 'message': f'读取配置失败: {str(e)}'})
|
||||
|
||||
# 构建 transformers 预加载脚本(不使用f-string避免import语句导致的模块shadow问题)
|
||||
preload_script = '''# -*- coding: utf-8 -*-
|
||||
import sys
|
||||
import uuid
|
||||
import logging
|
||||
import os
|
||||
logging.basicConfig(level=logging.WARNING, format='%(message)s')
|
||||
|
||||
# 生成唯一的临时文件路径,避免并发冲突
|
||||
temp_id = "TEMP_ID"
|
||||
log_file = "/app/base/logs/preload_local_{}.log".format(temp_id)
|
||||
|
||||
# 设置环境变量
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
try:
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
import torch
|
||||
|
||||
model_id = "MODEL_PATH"
|
||||
with open(log_file, "w", encoding="utf-8") as f:
|
||||
f.write("开始加载模型: {}\\n".format(model_id))
|
||||
|
||||
# 加载 tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write("Tokenizer 加载完成\\n")
|
||||
|
||||
# 加载模型
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
trust_remote_code=True
|
||||
)
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write("模型加载完成\\n")
|
||||
|
||||
# 测试推理
|
||||
test_input = "你好"
|
||||
inputs = tokenizer(test_input, return_tensors="pt").to(model.device)
|
||||
outputs = model.generate(**inputs, max_new_tokens=10, do_sample=False)
|
||||
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write("测试推理成功: {}\\n".format(response))
|
||||
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write("SUCCESS\\n")
|
||||
|
||||
except Exception as e:
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write("ERROR: {}\\n".format(str(e)))
|
||||
import traceback
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(traceback.format_exc())
|
||||
sys.exit(1)
|
||||
'''
|
||||
|
||||
# 使用 replace 替换变量(避免 f-string 中包含 import 语句)
|
||||
preload_script = preload_script.replace('TEMP_ID', temp_id).replace('MODEL_PATH', model_path)
|
||||
|
||||
# 写入临时脚本 - 使用唯一文件名避免并发冲突
|
||||
work_dir = '/app/base'
|
||||
script_path = os.path.join(work_dir, f'temp_preload_local_{temp_id}.py')
|
||||
log_path = os.path.join('/app/base/logs', f'preload_local_{temp_id}.log')
|
||||
|
||||
try:
|
||||
# 确保logs目录存在
|
||||
os.makedirs(os.path.dirname(log_path), exist_ok=True)
|
||||
|
||||
with open(script_path, 'w', encoding='utf-8') as f:
|
||||
f.write(preload_script)
|
||||
|
||||
# 查找系统 Python(不使用虚拟环境)
|
||||
def get_system_python():
|
||||
# 尝试常见的系统 Python 路径
|
||||
common_pythons = [
|
||||
'/usr/bin/python3',
|
||||
'/usr/bin/python',
|
||||
'/usr/local/bin/python3',
|
||||
'/usr/local/bin/python',
|
||||
]
|
||||
for py in common_pythons:
|
||||
if os.path.exists(py) and os.access(py, os.X_OK):
|
||||
return py
|
||||
# 如果都没找到,使用系统 PATH 中的 python3
|
||||
import shutil
|
||||
py_path = shutil.which('python3')
|
||||
if py_path:
|
||||
return py_path
|
||||
return sys_module.executable # 兜底使用当前 Python
|
||||
|
||||
python_executable = get_system_python()
|
||||
logger.info(f"[PRELOAD_LOCAL] 使用系统 Python: {python_executable}")
|
||||
|
||||
# 继承环境变量,但清除虚拟环境相关的变量
|
||||
env = {**os.environ}
|
||||
env['CUDA_VISIBLE_DEVICES'] = '0'
|
||||
env['TOKENIZERS_PARALLELISM'] = 'false'
|
||||
# 清除虚拟环境相关变量
|
||||
env.pop('VIRTUAL_ENV', None)
|
||||
env.pop('PYTHONHOME', None)
|
||||
|
||||
logger.info(f"[PRELOAD_LOCAL] 脚本路径: {script_path}")
|
||||
logger.info(f"[PRELOAD_LOCAL] 日志路径: {log_path}")
|
||||
logger.info(f"[PRELOAD_LOCAL] 工作目录: {work_dir}")
|
||||
logger.info(f"[PRELOAD_LOCAL] Python: {python_executable}")
|
||||
logger.info(f"[PRELOAD_LOCAL] 脚本是否存在: {os.path.exists(script_path)}")
|
||||
|
||||
# 执行预加载脚本
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[python_executable, script_path],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=600, # 10分钟超时
|
||||
cwd=work_dir,
|
||||
env=env
|
||||
)
|
||||
logger.info(f"[PRELOAD_LOCAL] 返回码: {result.returncode}")
|
||||
logger.info(f"[PRELOAD_LOCAL] stdout: {result.stdout[:500] if result.stdout else 'empty'}")
|
||||
logger.info(f"[PRELOAD_LOCAL] stderr: {result.stderr[:500] if result.stderr else 'empty'}")
|
||||
except Exception as sub_err:
|
||||
logger.error(f"[PRELOAD_LOCAL] subprocess执行异常: {sub_err}")
|
||||
return jsonify({'code': 1, 'message': f'执行异常: {str(sub_err)}'})
|
||||
|
||||
# 读取日志文件获取实际输出
|
||||
try:
|
||||
if os.path.exists(log_path):
|
||||
with open(log_path, 'r', encoding='utf-8') as f:
|
||||
full_output = f.read()
|
||||
logger.info(f"[PRELOAD_LOCAL] 日志文件内容: {full_output[:500]}")
|
||||
else:
|
||||
logger.warning(f"[PRELOAD_LOCAL] 日志文件不存在: {log_path}")
|
||||
full_output = result.stdout + result.stderr
|
||||
except Exception as read_err:
|
||||
logger.error(f"[PRELOAD_LOCAL] 读取日志失败: {read_err}")
|
||||
full_output = result.stdout + result.stderr
|
||||
|
||||
logger.info(f"[PRELOAD_LOCAL] 脚本输出: {full_output[:500] if len(full_output) > 500 else full_output}")
|
||||
|
||||
# 清理临时文件
|
||||
try:
|
||||
if os.path.exists(script_path):
|
||||
os.remove(script_path)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if os.path.exists(log_path):
|
||||
os.remove(log_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if result.returncode == 0 and 'SUCCESS' in full_output:
|
||||
logger.info(f"[PRELOAD_LOCAL] 模型预加载成功: {model_name}")
|
||||
return jsonify({
|
||||
'code': 0,
|
||||
'message': '模型预加载成功',
|
||||
'data': {'model_name': model_name}
|
||||
})
|
||||
else:
|
||||
error_msg = '预加载失败'
|
||||
for line in full_output.split('\n'):
|
||||
if 'ERROR:' in line:
|
||||
error_msg = line.split('ERROR:')[1].strip()
|
||||
break
|
||||
logger.error(f"[PRELOAD_LOCAL] 预加载失败: {error_msg}")
|
||||
return jsonify({'code': 1, 'message': f'预加载失败: {error_msg}'})
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
# 清理临时文件
|
||||
try:
|
||||
if os.path.exists(script_path):
|
||||
os.remove(script_path)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if os.path.exists(log_path):
|
||||
os.remove(log_path)
|
||||
except Exception:
|
||||
pass
|
||||
logger.error("[PRELOAD_LOCAL] 预加载超时")
|
||||
return jsonify({'code': 1, 'message': '预加载超时,请确保模型路径正确且有足够显存'})
|
||||
except Exception as e:
|
||||
# 清理临时文件
|
||||
try:
|
||||
if os.path.exists(script_path):
|
||||
os.remove(script_path)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if os.path.exists(log_path):
|
||||
os.remove(log_path)
|
||||
except Exception:
|
||||
pass
|
||||
logger.error(f"[PRELOAD_LOCAL] 预加载异常: {str(e)}")
|
||||
return jsonify({'code': 1, 'message': f'预加载异常: {str(e)}'})
|
||||
|
||||
|
||||
@model_chat_bp.route('/local/chat', methods=['POST'])
|
||||
def chat_local_model():
|
||||
"""使用本地模型进行对话推理(使用 transformers)"""
|
||||
import yaml
|
||||
import subprocess
|
||||
import sys as sys_module
|
||||
import json
|
||||
|
||||
data = request.json
|
||||
model_path = data.get('model_path') # 模型路径
|
||||
system_prompt = data.get('system_prompt', '')
|
||||
user_question = data.get('user_question')
|
||||
temperature = data.get('temperature', 0.7)
|
||||
max_tokens = data.get('max_tokens', 2048)
|
||||
|
||||
if not model_path:
|
||||
return jsonify({'code': 1, 'message': '缺少模型路径'})
|
||||
if not user_question:
|
||||
return jsonify({'code': 1, 'message': '缺少用户提问'})
|
||||
|
||||
logger.info(f"[CHAT_LOCAL] 开始对话推理, 模型路径: {model_path}")
|
||||
|
||||
# 构建消息
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({'role': 'system', 'content': system_prompt})
|
||||
messages.append({'role': 'user', 'content': user_question})
|
||||
|
||||
# 先生成唯一ID,避免并发冲突(必须在f-string之前定义)
|
||||
import uuid as uuid_module
|
||||
temp_id = uuid_module.uuid4().hex[:8]
|
||||
|
||||
# 构建 transformers 推理脚本(不使用f-string避免import语句导致的模块shadow问题)
|
||||
# 使用唯一占位符避免与其他内容冲突
|
||||
import json as json_module
|
||||
messages_json = json_module.dumps(messages, ensure_ascii=False)
|
||||
|
||||
inference_script = '''# -*- coding: utf-8 -*-
|
||||
import sys
|
||||
import os
|
||||
import uuid
|
||||
import logging
|
||||
import json
|
||||
logging.basicConfig(level=logging.WARNING, format='%(message)s')
|
||||
|
||||
# 生成唯一的临时文件路径
|
||||
temp_id = "__TEMP_ID__"
|
||||
log_file = "/app/base/logs/chat_local_{}.log".format(temp_id)
|
||||
|
||||
# 设置环境变量
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
try:
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
import torch
|
||||
|
||||
model_id = "__MODEL_PATH__"
|
||||
with open(log_file, "w", encoding="utf-8") as f:
|
||||
f.write("正在加载模型: {}\\n".format(model_id))
|
||||
|
||||
# 加载 tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write("Tokenizer 加载完成\\n")
|
||||
|
||||
# 加载模型
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
trust_remote_code=True
|
||||
)
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write("模型加载完成\\n")
|
||||
|
||||
# 构建消息格式
|
||||
messages = __MESSAGES_JSON__
|
||||
|
||||
# 应用 chat template
|
||||
if tokenizer.chat_template:
|
||||
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
else:
|
||||
# 手动构建
|
||||
text = ""
|
||||
for msg in messages:
|
||||
text += "<|im_start|>{}<|im_end|>\\n".format(msg['role'])
|
||||
text += "{}<|im_end|>\\n".format(msg['content'])
|
||||
text += "<|im_start|>assistant\\n"
|
||||
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write("构建输入完成\\n")
|
||||
|
||||
# 编码输入
|
||||
inputs = tokenizer(text, return_tensors="pt").to(model.device)
|
||||
|
||||
# 生成回复
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=__MAX_TOKENS__,
|
||||
temperature=__TEMPERATURE__,
|
||||
do_sample=True,
|
||||
top_p=0.9,
|
||||
pad_token_id=tokenizer.eos_token_id
|
||||
)
|
||||
|
||||
# 解码输出
|
||||
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
|
||||
# 提取 assistant 的回复
|
||||
if "assistant\\n" in response:
|
||||
response = response.split("assistant\\n")[-1]
|
||||
elif "<|im_start|>assistant" in response:
|
||||
response = response.split("<|im_start|>assistant")[-1]
|
||||
response = response.strip()
|
||||
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write("{}\\n".format(response))
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write("SUCCESS\\n")
|
||||
|
||||
except Exception as e:
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write("ERROR: {}\\n".format(str(e)))
|
||||
import traceback
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(traceback.format_exc())
|
||||
sys.exit(1)
|
||||
'''
|
||||
|
||||
# 使用 replace 替换变量(避免 f-string 中包含 import 语句)
|
||||
inference_script = inference_script.replace('__TEMP_ID__', temp_id)
|
||||
inference_script = inference_script.replace('__MODEL_PATH__', model_path)
|
||||
inference_script = inference_script.replace('__MESSAGES_JSON__', messages_json)
|
||||
inference_script = inference_script.replace('__MAX_TOKENS__', str(max_tokens))
|
||||
inference_script = inference_script.replace('__TEMPERATURE__', str(temperature))
|
||||
|
||||
# 写入临时脚本 - 使用唯一文件名避免并发冲突
|
||||
work_dir = '/app/base'
|
||||
script_path = os.path.join(work_dir, f'temp_chat_local_{temp_id}.py')
|
||||
log_path = os.path.join('/app/base/logs', f'chat_local_{temp_id}.log')
|
||||
|
||||
try:
|
||||
# 确保logs目录存在
|
||||
os.makedirs(os.path.dirname(log_path), exist_ok=True)
|
||||
|
||||
with open(script_path, 'w', encoding='utf-8') as f:
|
||||
f.write(inference_script)
|
||||
|
||||
# 查找系统 Python(不使用虚拟环境)
|
||||
def get_system_python():
|
||||
# 尝试常见的系统 Python 路径
|
||||
common_pythons = [
|
||||
'/usr/bin/python3',
|
||||
'/usr/bin/python',
|
||||
'/usr/local/bin/python3',
|
||||
'/usr/local/bin/python',
|
||||
]
|
||||
for py in common_pythons:
|
||||
if os.path.exists(py) and os.access(py, os.X_OK):
|
||||
return py
|
||||
# 如果都没找到,使用系统 PATH 中的 python3
|
||||
import shutil
|
||||
py_path = shutil.which('python3')
|
||||
if py_path:
|
||||
return py_path
|
||||
return sys_module.executable # 兜底使用当前 Python
|
||||
|
||||
python_executable = get_system_python()
|
||||
logger.info(f"[CHAT_LOCAL] 使用系统 Python: {python_executable}")
|
||||
|
||||
# 继承环境变量,但清除虚拟环境相关的变量
|
||||
env = {**os.environ}
|
||||
env['CUDA_VISIBLE_DEVICES'] = '0'
|
||||
env['TOKENIZERS_PARALLELISM'] = 'false'
|
||||
# 清除虚拟环境相关变量
|
||||
env.pop('VIRTUAL_ENV', None)
|
||||
env.pop('PYTHONHOME', None)
|
||||
|
||||
# 执行推理脚本
|
||||
result = subprocess.run(
|
||||
[python_executable, script_path],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=600, # 10分钟超时
|
||||
cwd=work_dir,
|
||||
env=env
|
||||
)
|
||||
|
||||
# 读取日志文件获取实际输出
|
||||
try:
|
||||
if os.path.exists(log_path):
|
||||
with open(log_path, 'r', encoding='utf-8') as f:
|
||||
full_output = f.read()
|
||||
else:
|
||||
full_output = result.stdout + result.stderr
|
||||
except Exception as read_err:
|
||||
logger.error(f"[CHAT_LOCAL] 读取日志失败: {read_err}")
|
||||
full_output = result.stdout + result.stderr
|
||||
|
||||
logger.info(f"[CHAT_LOCAL] 脚本输出: {full_output[:500] if len(full_output) > 500 else full_output}")
|
||||
|
||||
# 清理临时文件
|
||||
try:
|
||||
if os.path.exists(script_path):
|
||||
os.remove(script_path)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if os.path.exists(log_path):
|
||||
os.remove(log_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if result.returncode == 0 and 'SUCCESS' in full_output:
|
||||
# 提取实际回复(去掉最后的 SUCCESS)
|
||||
lines = full_output.strip().split('\n')
|
||||
response_lines = [line for line in lines if line.strip() and line.strip() != 'SUCCESS']
|
||||
assistant_content = '\n'.join(response_lines).strip()
|
||||
|
||||
logger.info(f"[CHAT_LOCAL] 对话成功, 回复长度: {len(assistant_content)}")
|
||||
return jsonify({
|
||||
'code': 0,
|
||||
'data': {
|
||||
'model_path': model_path,
|
||||
'response': assistant_content
|
||||
}
|
||||
})
|
||||
else:
|
||||
error_msg = '推理失败'
|
||||
for line in full_output.split('\n'):
|
||||
if 'ERROR:' in line:
|
||||
error_msg = line.split('ERROR:')[1].strip()
|
||||
break
|
||||
logger.error(f"[CHAT_LOCAL] 推理失败: {error_msg}")
|
||||
return jsonify({'code': 1, 'message': f'推理失败: {error_msg}'})
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
# 清理临时文件
|
||||
try:
|
||||
if os.path.exists(script_path):
|
||||
os.remove(script_path)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if os.path.exists(log_path):
|
||||
os.remove(log_path)
|
||||
except Exception:
|
||||
pass
|
||||
logger.error("[CHAT_LOCAL] 推理超时")
|
||||
return jsonify({'code': 1, 'message': '推理超时,请减少生成token数量或调整参数'})
|
||||
except Exception as e:
|
||||
# 清理临时文件
|
||||
try:
|
||||
if os.path.exists(script_path):
|
||||
os.remove(script_path)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if os.path.exists(log_path):
|
||||
os.remove(log_path)
|
||||
except Exception:
|
||||
pass
|
||||
logger.error(f"[CHAT_LOCAL] 推理异常: {str(e)}")
|
||||
return jsonify({'code': 1, 'message': f'推理异常: {str(e)}'})
|
||||
|
||||
@@ -4,8 +4,12 @@
|
||||
import os
|
||||
import pymysql
|
||||
import yaml
|
||||
import logging
|
||||
from flask import Blueprint, request, jsonify
|
||||
|
||||
# 获取模块 logger(继承 main.py 的日志配置)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 获取项目根目录 - 优先使用环境变量,否则从文件路径计算
|
||||
MOUNT_BASE = os.environ.get('MOUNT_BASE', '/app/base')
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
@@ -47,8 +51,6 @@ def generic_get_all(table_name, order_by='create_time DESC'):
|
||||
|
||||
def get_model_path_by_name(model_name):
|
||||
"""根据模型名称查询模型路径(用于获取基座模型路径)"""
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(f"[DEBUG get_model_path_by_name] 查询模型: {model_name}")
|
||||
|
||||
try:
|
||||
@@ -165,6 +167,23 @@ def get_model_manage_by_id(id):
|
||||
return jsonify({'code': 1, 'message': '模型不存在'})
|
||||
|
||||
|
||||
@model_manage_bp.route('/name/<model_name>', methods=['GET'])
|
||||
def get_model_manage_by_name(model_name):
|
||||
"""根据名称获取模型"""
|
||||
logger.info(f"[DEBUG] 按名称查询模型: {model_name}")
|
||||
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
|
||||
model = cursor.fetchone()
|
||||
cursor.close()
|
||||
conn.close()
|
||||
|
||||
if model:
|
||||
return jsonify({'code': 0, 'data': model})
|
||||
return jsonify({'code': 1, 'message': '模型不存在'})
|
||||
|
||||
|
||||
@model_manage_bp.route('', methods=['POST'])
|
||||
def create_model_manage():
|
||||
"""创建模型"""
|
||||
@@ -575,25 +594,57 @@ def merge_model():
|
||||
|
||||
@model_manage_bp.route('/trained-models/<model_name>', methods=['DELETE'])
|
||||
def delete_trained_model(model_name):
|
||||
"""删除已训练模型(从local_trained_models目录)"""
|
||||
"""删除已训练模型
|
||||
type=merged: 删除合并模型(local_trained_models目录)
|
||||
type=lora: 删除权重(saves目录下的lora等权重文件)
|
||||
"""
|
||||
import shutil
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 获取删除类型参数
|
||||
delete_type = request.args.get('type', 'merged') # 默认删除合并模型
|
||||
|
||||
try:
|
||||
# 删除 local_trained_models 目录下的模型
|
||||
model_path = os.path.join(PROJECT_ROOT, 'local_trained_models', model_name)
|
||||
if delete_type == 'lora':
|
||||
# 删除权重:删除 saves 目录下的权重
|
||||
saves_path = os.path.join(PROJECT_ROOT, 'saves')
|
||||
train_methods = ['lora', 'full', 'qlora', 'dpo', 'cpt', 'prefix', 'adapter', 'peft']
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
return jsonify({'code': 1, 'message': f'模型不存在: {model_name}'})
|
||||
deleted = False
|
||||
for method in train_methods:
|
||||
weight_path = os.path.join(saves_path, method, model_name)
|
||||
if os.path.exists(weight_path):
|
||||
shutil.rmtree(weight_path)
|
||||
logger.info(f"[DELETE] 已删除权重: {weight_path}")
|
||||
deleted = True
|
||||
|
||||
# 删除目录
|
||||
shutil.rmtree(model_path)
|
||||
logger.info(f"[DELETE] 已删除模型: {model_path}")
|
||||
if not deleted:
|
||||
# 也可能是老结构,直接在 saves 下的 model_name 目录
|
||||
old_path = os.path.join(saves_path, model_name)
|
||||
if os.path.exists(old_path):
|
||||
shutil.rmtree(old_path)
|
||||
logger.info(f"[DELETE] 已删除老结构权重: {old_path}")
|
||||
deleted = True
|
||||
|
||||
return jsonify({'code': 0, 'message': '删除成功'})
|
||||
if deleted:
|
||||
return jsonify({'code': 0, 'message': '权重已删除'})
|
||||
else:
|
||||
return jsonify({'code': 1, 'message': f'权重不存在: {model_name}'})
|
||||
else:
|
||||
# 默认删除合并模型(local_trained_models目录)
|
||||
model_path = os.path.join(PROJECT_ROOT, 'local_trained_models', model_name)
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
return jsonify({'code': 1, 'message': f'合并模型不存在: {model_name}'})
|
||||
|
||||
# 删除目录
|
||||
shutil.rmtree(model_path)
|
||||
logger.info(f"[DELETE] 已删除合并模型: {model_path}")
|
||||
|
||||
return jsonify({'code': 0, 'message': '合并模型已删除'})
|
||||
except Exception as e:
|
||||
logger.error(f"[DELETE] 删除模型失败: {str(e)}")
|
||||
logger.error(f"[DELETE] 删除失败: {str(e)}")
|
||||
return jsonify({'code': 1, 'message': f'删除失败: {str(e)}'})
|
||||
|
||||
|
||||
|
||||
383
src/main.py
383
src/main.py
@@ -37,7 +37,25 @@ CONFIG = load_config()
|
||||
TRAINING_LOGS_DIR = CONFIG.get('training_logs_path', '/app/base/training_logs')
|
||||
|
||||
# ============ 日志系统配置 ============
|
||||
LOG_BASE_DIR = os.path.join(PROJECT_ROOT, 'logs')
|
||||
# 日志目录逻辑:
|
||||
# 1. 优先使用环境变量 LOG_BASE_DIR
|
||||
# 2. 如果是容器环境(存在 /app/base),使用 /app/base/logs
|
||||
# 3. 否则使用本地项目路径 PROJECT_ROOT/logs
|
||||
def get_log_base_dir():
|
||||
"""获取日志基础目录"""
|
||||
# 1. 检查环境变量
|
||||
if 'LOG_BASE_DIR' in os.environ:
|
||||
return os.environ['LOG_BASE_DIR']
|
||||
|
||||
# 2. 检查是否在容器环境中
|
||||
mount_base = os.environ.get('MOUNT_BASE', '/app/base')
|
||||
if os.path.exists(mount_base):
|
||||
return os.path.join(mount_base, 'logs')
|
||||
|
||||
# 3. 使用本地项目路径
|
||||
return os.path.join(PROJECT_ROOT, 'logs')
|
||||
|
||||
LOG_BASE_DIR = get_log_base_dir()
|
||||
|
||||
|
||||
def setup_logger(name='app'):
|
||||
@@ -98,6 +116,15 @@ def setup_logger(name='app'):
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
))
|
||||
|
||||
# 6. API日志处理器 - 专门记录API相关日志
|
||||
api_log_path = os.path.join(log_dir, 'api.log')
|
||||
api_handler = RotatingFileHandler(api_log_path, maxBytes=10*1024*1024, backupCount=10, encoding='utf-8')
|
||||
api_handler.setLevel(logging.DEBUG)
|
||||
api_handler.setFormatter(logging.Formatter(
|
||||
'[%(asctime)s] %(levelname)s: %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
))
|
||||
|
||||
# 添加处理器到 logger
|
||||
logger.addHandler(all_handler)
|
||||
logger.addHandler(error_handler)
|
||||
@@ -117,6 +144,13 @@ def setup_logger(name='app'):
|
||||
train_logger.addHandler(train_handler)
|
||||
train_logger.addHandler(console_handler)
|
||||
|
||||
# 为 logs_api 创建单独的 logger (供 src/api/logs.py 使用)
|
||||
logs_api_logger = logging.getLogger('logs_api')
|
||||
logs_api_logger.setLevel(logging.DEBUG)
|
||||
logs_api_logger.handlers.clear()
|
||||
logs_api_logger.addHandler(api_handler)
|
||||
logs_api_logger.addHandler(console_handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
@@ -588,6 +622,8 @@ def system_info():
|
||||
|
||||
|
||||
# ============ 通用 CRUD 操作 ============
|
||||
import json
|
||||
|
||||
def generic_get_all(table_name, order_by='create_time DESC'):
|
||||
"""通用查询所有"""
|
||||
conn = get_db_connection()
|
||||
@@ -596,6 +632,19 @@ def generic_get_all(table_name, order_by='create_time DESC'):
|
||||
result = cursor.fetchall()
|
||||
cursor.close()
|
||||
conn.close()
|
||||
# 自动解析 JSON 字段
|
||||
for row in result:
|
||||
for key, value in row.items():
|
||||
if isinstance(value, str) and value.startswith('[') and value.endswith(']'):
|
||||
try:
|
||||
row[key] = json.loads(value)
|
||||
except:
|
||||
pass
|
||||
elif isinstance(value, str) and value.startswith('{') and value.endswith('}'):
|
||||
try:
|
||||
row[key] = json.loads(value)
|
||||
except:
|
||||
pass
|
||||
return result
|
||||
|
||||
|
||||
@@ -607,6 +656,19 @@ def generic_get_by_id(table_name, id_val):
|
||||
result = cursor.fetchone()
|
||||
cursor.close()
|
||||
conn.close()
|
||||
# 自动解析 JSON 字段
|
||||
if result:
|
||||
for key, value in result.items():
|
||||
if isinstance(value, str) and value.startswith('[') and value.endswith(']'):
|
||||
try:
|
||||
result[key] = json.loads(value)
|
||||
except:
|
||||
pass
|
||||
elif isinstance(value, str) and value.startswith('{') and value.endswith('}'):
|
||||
try:
|
||||
result[key] = json.loads(value)
|
||||
except:
|
||||
pass
|
||||
return result
|
||||
|
||||
|
||||
@@ -1053,36 +1115,19 @@ def delete_model_deploy(id):
|
||||
return jsonify({'code': 0, 'message': '删除成功'})
|
||||
|
||||
|
||||
# ============ 模型管理接口 ============
|
||||
@app.route('/api/model-manage', methods=['GET'])
|
||||
def get_model_manage():
|
||||
return jsonify({'code': 0, 'data': generic_get_all('model_manage')})
|
||||
|
||||
|
||||
@app.route('/api/model-manage', methods=['POST'])
|
||||
def create_model_manage():
|
||||
data = request.json
|
||||
new_id = generic_create('model_manage', data)
|
||||
return jsonify({'code': 0, 'message': '创建成功', 'id': new_id})
|
||||
|
||||
|
||||
@app.route('/api/model-manage/<int:id>', methods=['PUT'])
|
||||
def update_model_manage(id):
|
||||
data = request.json
|
||||
generic_update('model_manage', id, data)
|
||||
return jsonify({'code': 0, 'message': '更新成功'})
|
||||
|
||||
|
||||
@app.route('/api/model-manage/<int:id>', methods=['DELETE'])
|
||||
def delete_model_manage(id):
|
||||
generic_delete('model_manage', id)
|
||||
return jsonify({'code': 0, 'message': '删除成功'})
|
||||
|
||||
|
||||
# ============ 模型对比接口 ============
|
||||
@app.route('/api/model-compare', methods=['GET'])
|
||||
def get_model_compare():
|
||||
return jsonify({'code': 0, 'data': generic_get_all('model_compare')})
|
||||
result = generic_get_all('model_compare')
|
||||
# 确保 models 字段被正确解析为 JSON
|
||||
for row in result:
|
||||
if 'models' in row and isinstance(row['models'], str):
|
||||
try:
|
||||
row['models'] = json.loads(row['models'])
|
||||
logger.debug(f"[model-compare] 解析 models 字段成功: {row['models']}")
|
||||
except Exception as e:
|
||||
logger.error(f"[model-compare] 解析 models 字段失败: {e}, 原始值: {row['models']}")
|
||||
return jsonify({'code': 0, 'data': result})
|
||||
|
||||
|
||||
@app.route('/api/model-compare/<int:id>', methods=['GET'])
|
||||
@@ -1090,13 +1135,25 @@ def get_model_compare_by_id(id):
|
||||
"""获取单个模型对比任务"""
|
||||
result = generic_get_by_id('model_compare', id)
|
||||
if result:
|
||||
# 确保 models 字段被正确解析为 JSON
|
||||
if 'models' in result and isinstance(result['models'], str):
|
||||
try:
|
||||
result['models'] = json.loads(result['models'])
|
||||
except Exception as e:
|
||||
logger.error(f"[model-compare] 解析 models 字段失败: {e}")
|
||||
return jsonify({'code': 0, 'data': result})
|
||||
return jsonify({'code': 1, 'message': '任务不存在'})
|
||||
|
||||
|
||||
@app.route('/api/model-compare', methods=['POST'])
|
||||
def create_model_compare():
|
||||
data = request.json
|
||||
data = request.json.copy()
|
||||
# 字段映射: name -> model_name
|
||||
if 'name' in data:
|
||||
data['model_name'] = data.pop('name')
|
||||
# 设置默认加载状态
|
||||
if 'status' not in data:
|
||||
data['status'] = 'pending'
|
||||
new_id = generic_create('model_compare', data)
|
||||
return jsonify({'code': 0, 'message': '创建成功', 'id': new_id})
|
||||
|
||||
@@ -1110,10 +1167,278 @@ def update_model_compare(id):
|
||||
|
||||
@app.route('/api/model-compare/<int:id>', methods=['DELETE'])
|
||||
def delete_model_compare(id):
|
||||
# 先停止加载的模型服务
|
||||
try:
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT models, load_status FROM model_compare WHERE id = %s", (id,))
|
||||
row = cursor.fetchone()
|
||||
cursor.close()
|
||||
conn.close()
|
||||
|
||||
if row and row[1] == 'loaded':
|
||||
# 停止模型服务
|
||||
from subprocess import Popen
|
||||
import signal
|
||||
load_status = json.loads(row[0]) if isinstance(row[0], str) else row[0]
|
||||
if isinstance(load_status, dict) and 'processes' in load_status:
|
||||
for proc in load_status['processes']:
|
||||
if 'pid' in proc:
|
||||
try:
|
||||
os.kill(proc['pid'], signal.SIGTERM)
|
||||
except:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"[model-compare] 停止模型服务失败: {e}")
|
||||
|
||||
generic_delete('model_compare', id)
|
||||
return jsonify({'code': 0, 'message': '删除成功'})
|
||||
|
||||
|
||||
# ============ 模型加载接口 ============
|
||||
import subprocess
|
||||
import signal
|
||||
|
||||
# 存储加载状态 (生产环境应使用数据库或Redis)
|
||||
model_compare_processes = {}
|
||||
|
||||
|
||||
@app.route('/api/model-compare/<int:id>/load', methods=['POST'])
|
||||
def load_model_compare(id):
|
||||
"""加载模型 - 启动模型服务"""
|
||||
try:
|
||||
# 获取对比任务
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT models FROM model_compare WHERE id = %s", (id,))
|
||||
row = cursor.fetchone()
|
||||
cursor.close()
|
||||
conn.close()
|
||||
|
||||
if not row:
|
||||
return jsonify({'code': 1, 'message': '任务不存在'})
|
||||
|
||||
models = json.loads(row[0]) if isinstance(row[0], str) else row[0]
|
||||
if not models or len(models) < 2:
|
||||
return jsonify({'code': 1, 'message': '模型配置无效'})
|
||||
|
||||
# 更新状态为 loading
|
||||
generic_update('model_compare', id, {'status': 'loading'})
|
||||
|
||||
# 启动模型服务
|
||||
processes = []
|
||||
loaded_models = []
|
||||
|
||||
for i, model in enumerate(models):
|
||||
model_path = model.get('model_path', '')
|
||||
gpu_id = model.get('gpu_id', 0)
|
||||
port = model.get('port', 7862 + i * 10)
|
||||
|
||||
if not model_path:
|
||||
continue
|
||||
|
||||
# 构建启动命令
|
||||
# 获取模型名称和模板
|
||||
model_name = model.get('model_name', '')
|
||||
template = detect_template(model_path)
|
||||
|
||||
cmd = [
|
||||
sys.executable,
|
||||
'-m', 'llamafactory.api',
|
||||
'--model_name_or_path', model_path,
|
||||
'--template', template or 'default',
|
||||
'--port', str(port),
|
||||
'--gpu_ids', str(gpu_id)
|
||||
]
|
||||
|
||||
logger.info(f"[model-compare] 启动模型服务: {' '.join(cmd)}")
|
||||
|
||||
# 启动进程
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
cwd=PROJECT_ROOT
|
||||
)
|
||||
|
||||
processes.append({
|
||||
'pid': proc.pid,
|
||||
'port': port,
|
||||
'model_name': model_name,
|
||||
'model_path': model_path
|
||||
})
|
||||
|
||||
loaded_models.append({
|
||||
'port': port,
|
||||
'model_name': model_name,
|
||||
'status': 'starting'
|
||||
})
|
||||
|
||||
# 保存进程信息
|
||||
model_compare_processes[id] = {
|
||||
'processes': processes,
|
||||
'loaded_models': loaded_models,
|
||||
'start_time': datetime.now().timestamp()
|
||||
}
|
||||
|
||||
# 更新数据库中的加载状态
|
||||
load_status = {
|
||||
'processes': processes,
|
||||
'loaded_models': loaded_models,
|
||||
'started_at': datetime.now().isoformat()
|
||||
}
|
||||
generic_update('model_compare', id, {
|
||||
'status': 'loading',
|
||||
'load_status': json.dumps(load_status, ensure_ascii=False)
|
||||
})
|
||||
|
||||
return jsonify({
|
||||
'code': 0,
|
||||
'message': '正在加载模型...',
|
||||
'data': {
|
||||
'models': loaded_models,
|
||||
'check_url': f'/api/model-compare/{id}/load-status'
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[model-compare] 加载模型失败: {e}")
|
||||
generic_update('model_compare', id, {'status': 'pending'})
|
||||
return jsonify({'code': 1, 'message': f'加载失败: {str(e)}'})
|
||||
|
||||
|
||||
def detect_template(model_path):
|
||||
"""根据模型路径检测模板类型"""
|
||||
model_lower = model_path.lower()
|
||||
if 'qwen' in model_lower:
|
||||
return 'qwen'
|
||||
elif 'llama' in model_lower or 'llama' in model_lower:
|
||||
return 'llama'
|
||||
elif 'chatglm' in model_lower:
|
||||
return 'chatglm'
|
||||
elif 'baichuan' in model_lower:
|
||||
return 'baichuan'
|
||||
elif 'mistral' in model_lower:
|
||||
return 'mistral'
|
||||
elif 'yi' in model_lower:
|
||||
return 'yi'
|
||||
elif 'deepseek' in model_lower:
|
||||
return 'deepseek'
|
||||
return 'default'
|
||||
|
||||
|
||||
@app.route('/api/model-compare/<int:id>/load-status', methods=['GET'])
|
||||
def get_load_status(id):
|
||||
"""获取模型加载状态"""
|
||||
try:
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT models, load_status FROM model_compare WHERE id = %s", (id,))
|
||||
row = cursor.fetchone()
|
||||
cursor.close()
|
||||
conn.close()
|
||||
|
||||
if not row:
|
||||
return jsonify({'code': 1, 'message': '任务不存在'})
|
||||
|
||||
models = json.loads(row[0]) if isinstance(row[0], str) else row[0]
|
||||
load_status = json.loads(row[1]) if row[1] and isinstance(row[1], str) else {}
|
||||
|
||||
loaded_models = load_status.get('loaded_models', [])
|
||||
processes = load_status.get('processes', [])
|
||||
|
||||
# 检查每个模型服务是否就绪
|
||||
all_ready = True
|
||||
for i, model in enumerate(loaded_models):
|
||||
port = model.get('port')
|
||||
try:
|
||||
import requests
|
||||
resp = requests.get(f'http://localhost:{port}/health', timeout=2)
|
||||
if resp.status_code == 200:
|
||||
model['status'] = 'ready'
|
||||
else:
|
||||
model['status'] = 'starting'
|
||||
all_ready = False
|
||||
except:
|
||||
# 检查进程是否还在运行
|
||||
if 'pid' in processes[i] if i < len(processes) else False:
|
||||
model['status'] = 'starting'
|
||||
all_ready = False
|
||||
else:
|
||||
model['status'] = 'failed'
|
||||
all_ready = False
|
||||
|
||||
# 更新状态
|
||||
if all_ready:
|
||||
generic_update('model_compare', id, {'status': 'loaded'})
|
||||
else:
|
||||
generic_update('model_compare', id, {'status': 'loading'})
|
||||
|
||||
# 更新加载状态
|
||||
load_status['loaded_models'] = loaded_models
|
||||
load_status['all_ready'] = all_ready
|
||||
generic_update('model_compare', id, {
|
||||
'load_status': json.dumps(load_status, ensure_ascii=False)
|
||||
})
|
||||
|
||||
return jsonify({
|
||||
'code': 0,
|
||||
'data': {
|
||||
'status': 'loaded' if all_ready else 'loading',
|
||||
'models': loaded_models,
|
||||
'all_ready': all_ready
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[model-compare] 获取加载状态失败: {e}")
|
||||
return jsonify({'code': 1, 'message': str(e)})
|
||||
|
||||
|
||||
@app.route('/api/model-compare/<int:id>/unload', methods=['POST'])
|
||||
def unload_model_compare(id):
|
||||
"""停止加载的模型服务"""
|
||||
try:
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT load_status FROM model_compare WHERE id = %s", (id,))
|
||||
row = cursor.fetchone()
|
||||
cursor.close()
|
||||
conn.close()
|
||||
|
||||
if row and row[0]:
|
||||
load_status = json.loads(row[0]) if isinstance(row[0], str) else row[0]
|
||||
processes = load_status.get('processes', [])
|
||||
|
||||
# 停止所有进程
|
||||
for proc in processes:
|
||||
pid = proc.get('pid')
|
||||
if pid:
|
||||
try:
|
||||
os.kill(pid, signal.SIGTERM)
|
||||
logger.info(f"[model-compare] 已停止进程 {pid}")
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"[model-compare] 停止进程 {pid} 失败: {e}")
|
||||
|
||||
# 清理内存
|
||||
if id in model_compare_processes:
|
||||
del model_compare_processes[id]
|
||||
|
||||
# 更新状态
|
||||
generic_update('model_compare', id, {
|
||||
'status': 'pending',
|
||||
'load_status': None
|
||||
})
|
||||
|
||||
return jsonify({'code': 0, 'message': '已停止模型服务'})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[model-compare] 停止模型服务失败: {e}")
|
||||
return jsonify({'code': 1, 'message': str(e)})
|
||||
|
||||
|
||||
# ============ 数据生成接口 ============
|
||||
@app.route('/api/data-generate', methods=['GET'])
|
||||
def get_data_generate():
|
||||
|
||||
Reference in New Issue
Block a user