#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ QA生成器 - 整合版 整合所有功能于一个文件,减少冗余 """ import json import os import random from typing import List, Dict, Any class QAConfig: """QA生成配置类""" def __init__(self): # 基础配置 self.RANDOM_SEED = 42 self.INPUT_DIR = "Data_Export_Json" self.OUTPUT_DIR = "Data_QA_Outputs" # 复杂程度控制 (1-5) self.COMPLEXITY_LEVEL = 3 # 问题数量控制 self.BASIC_QUESTIONS_PER_ITEM = 1 self.MAX_QUESTIONS_PER_ITEM = 10 self.MULTI_COLUMN_RATIO = 0.3 # 输出控制 self.SHUFFLE_OUTPUT = True self.GENERATE_REPORT = True self.VERBOSE_LOG = True # 数据文件配置 self.DATA_FILES = [ {"name": "元素治理模板", "file": "元素治理模板.json", "output": "元素治理模板_QA.json", "enabled": True}, {"name": "物理模型", "file": "物理模型.json", "output": "物理模型_QA.json", "enabled": True}, {"name": "逻辑模型", "file": "逻辑模型.json", "output": "逻辑模型_QA.json", "enabled": True} ] # 初始化修饰语和连接词 self._init_templates() def _init_templates(self): """初始化模板列表""" if self.COMPLEXITY_LEVEL <= 2: # 简单模式 self.QUESTION_PREFIXES = ["请告诉我", "查询", "请问"] self.ANSWER_PREFIXES = ["根据表记录,该字段的", "查询结果显示,", "经查询,该字段的"] self.ANSWER_SUFFIXES = ["。", "。"] self.CONNECTORS = ["和", "与"] self.SINGLE_TEMPLATES = 6 if self.COMPLEXITY_LEVEL == 2 else 3 self.MULTI_TEMPLATES = 1 if self.COMPLEXITY_LEVEL == 2 else 0 self.MULTI_RATIO = 0.1 if self.COMPLEXITY_LEVEL == 2 else 0.0 else: # 普通/复杂模式 self.QUESTION_PREFIXES = ["请告诉我", "查询", "请问", "在", "请解释", "请输出", "请列举", "请说明", "请查找", "请确认"] self.ANSWER_PREFIXES = ["根据表记录,该字段的", "查询结果显示,", "经查询,该字段的", "根据数据库记录,", "在表中,此字段的", "查询结果:", "经系统查询,", "根据记录显示,", "在数据中,该字段的", "查询得知,该字段的"] self.ANSWER_SUFFIXES = ["。", ",请参考。", ",详情如上。", ",以上信息为准。", ",望知悉。", ",如需更多信息请联系。", ",希望能帮到您。", ",祝您工作顺利。", ",谢谢。", "。"] self.CONNECTORS = ["和", "与", "及", "、", ",还有", "以及"] self.SINGLE_TEMPLATES = 12 self.MULTI_TEMPLATES = 5 self.MULTI_RATIO = self.MULTI_COLUMN_RATIO def get_random_element(self, elements: List[str]) -> str: """从列表中随机获取一个元素""" return random.choice(elements) if elements else "" def update_config(self, **kwargs): """更新配置参数""" for key, value in kwargs.items(): if hasattr(self, key): setattr(self, key, value) if key == 'COMPLEXITY_LEVEL': self._init_templates() # 重新初始化模板 def print_config(self): """打印当前配置""" print(f"\n复杂程度等级: {self.COMPLEXITY_LEVEL}") print(f"单列模板数: {self.SINGLE_TEMPLATES}") print(f"多列模板数: {self.MULTI_TEMPLATES}") print(f"多列占比: {self.MULTI_RATIO}") print(f"输出目录: {self.OUTPUT_DIR}") class QAGenerator: """QA生成器 - 整合版""" def __init__(self, config: QAConfig = None): """初始化生成器""" self.config = config or QAConfig() os.makedirs(self.config.OUTPUT_DIR, exist_ok=True) random.seed(self.config.RANDOM_SEED) def load_json(self, file_path: str) -> List[Dict]: """加载JSON文件""" with open(file_path, 'r', encoding='utf-8') as f: return json.load(f) def generate_single_qa(self, item: Dict, template_count: int, data_type: str) -> List[Dict]: """生成单列QA - 整合了三种数据类型的模板""" qa_pairs = [] answer_prefixes = self.config.ANSWER_PREFIXES answer_suffixes = self.config.ANSWER_SUFFIXES prefixes = self.config.QUESTION_PREFIXES if data_type == "element": # 元素治理模板模板 templates = [] table_name = item.get("表名", "元素治理模板") if item.get("业务领域名称") and item.get("数据元素中文名"): templates.append((f"在{table_name}中,业务领域「{item['业务领域名称']}」对应的数据元素中文名是什么?", f"该数据元素为「{item['数据元素中文名']}」")) if item.get("业务领域名称") and item.get("数据元素中文名"): templates.append((f"「{item['数据元素中文名']}」这个数据元素在{table_name}中属于哪个业务领域?", f"该数据元素属于「{item['业务领域名称']}」")) if item.get("数据元素中文名") and item.get("值类型"): templates.append((f"查询{table_name}中,数据元素「{item['数据元素中文名']}」的值类型是什么?", f"值类型为「{item['值类型']}」")) if item.get("数据元素中文名") and item.get("总长度"): templates.append((f"在{table_name}里,「{item['数据元素中文名']}」的总长度设置是多少?", f"总长度为{item['总长度']}")) if item.get("数据元素中文名") and item.get("类别"): templates.append((f"请确认「{item['数据元素中文名']}」在{table_name}中属于哪个类别?", f"该数据元素属于「{item['类别']}」类别")) if item.get("数据元素中文名") and item.get("数据元素英文名"): templates.append((f"请查找{table_name}中「{item['数据元素中文名']}」对应的英文名是什么?", f"英文名为「{item['数据元素英文名']}」")) if item.get("数据元素中文名") and item.get("是否枚举"): templates.append((f"「{item['数据元素中文名']}」这个数据元素在{table_name}中是否枚举?", f"是否枚举为「{item['是否枚举']}」")) if item.get("数据元素中文名") and item.get("枚举数量"): templates.append((f"请问在{table_name}中,「{item['数据元素中文名']}」的枚举数量是多少?", f"枚举数量为{item['枚举数量']}")) if item.get("数据元素中文名") and item.get("小数位"): templates.append((f"请说明{table_name}中「{item['数据元素中文名']}」的小数位设置?", f"小数位为{item['小数位']}")) if item.get("数据元素中文名") and item.get("抽象元素中文名"): templates.append((f"在{table_name}里,「{item['数据元素中文名']}」的抽象元素中文名是什么?", f"抽象元素中文名为「{item['抽象元素中文名']}」")) if item.get("数据元素中文名") and item.get("说明"): templates.append((f"请解释「{item['数据元素中文名']}」在{table_name}中的作用和含义", f"该数据元素的说明为「{item['说明']}」")) if item.get("数据元素中文名") and item.get("是否上线"): templates.append((f"请问「{item['数据元素中文名']}」在{table_name}中是否已上线?", f"是否上线为「{item['是否上线']}」")) # 生成QA for i, (question, answer) in enumerate(templates[:template_count]): qa_pairs.append({ "instruct": question, "input": "", "output": f"{self.config.get_random_element(answer_prefixes)}{answer}{self.config.get_random_element(answer_suffixes)}" }) elif data_type == "physical": # 物理模型模板 table_name = item.get("表名", "物理模型") templates = [] if item.get("物理模型属性中文名") and item.get("值类型"): templates.append((f"在{table_name}中,「{item['物理模型属性中文名']}」的值类型是什么?", f"该属性的值类型为「{item['值类型']}」")) if item.get("物理模型属性中文名") and item.get("长度"): templates.append((f"请问「{item['物理模型属性中文名']}」在{table_name}中的长度是多少?", f"该属性长度为{item['长度']}")) if item.get("物理模型属性中文名") and item.get("小数位") is not None: templates.append((f"请确认{table_name}中「{item['物理模型属性中文名']}」的小数位设置", f"该属性小数位为{item['小数位']}")) if item.get("物理模型属性中文名") and item.get("关联数据元素"): templates.append((f"「{item['物理模型属性中文名']}」这个属性在{table_name}中关联哪个数据元素?", f"该属性关联的数据元素为「{item['关联数据元素']}」")) if item.get("物理模型属性中文名") and item.get("物理模型中文名"): templates.append((f"请查找「{item['物理模型属性中文名']}」属于哪个物理模型?", f"该属性属于「{item['物理模型中文名']}」")) if item.get("物理模型属性中文名") and item.get("说明"): templates.append((f"请说明「{item['物理模型属性中文名']}」在{table_name}中的作用和用途", f"该属性的说明为「{item['说明']}」")) if item.get("物理模型属性英文名") and item.get("物理模型属性中文名"): templates.append((f"在{table_name}中,属性「{item['物理模型属性英文名']}」的中文名是什么?", f"该属性的中文名为「{item['物理模型属性中文名']}」")) if item.get("物理模型中文名") and item.get("物理模型英文名"): templates.append((f"「{item['物理模型中文名']}」对应的物理模型英文名是什么?", f"该物理模型的英文名为「{item['物理模型英文名']}」")) if item.get("物理模型英文名") and item.get("物理模型中文名"): templates.append((f"在{table_name}中,物理模型「{item['物理模型英文名']}」的中文名是什么?", f"该物理模型的中文名为「{item['物理模型中文名']}」")) # 生成QA for i, (question, answer) in enumerate(templates[:template_count]): qa_pairs.append({ "instruct": question, "input": "", "output": f"{self.config.get_random_element(answer_prefixes)}{answer}{self.config.get_random_element(answer_suffixes)}" }) elif data_type == "logical": # 逻辑模型模板 table_name = item.get("表名", "逻辑模型") templates = [] if item.get("业务领域") and item.get("逻辑模型中文名"): templates.append((f"在{table_name}中,业务领域为「{item['业务领域']}」的逻辑模型中文名是什么?", f"该业务领域对应的逻辑模型中文名为「{item['逻辑模型中文名']}」")) if item.get("业务领域") and item.get("逻辑模型中文名"): templates.append((f"「{item['逻辑模型中文名']}」这个逻辑模型在{table_name}中属于哪个业务领域?", f"该逻辑模型属于「{item['业务领域']}」")) if item.get("字段英文名") and item.get("字段中文名"): templates.append((f"请告诉我「{item['字段英文名']}」在{table_name}中的中文名是什么?", f"该字段的中文名为「{item['字段中文名']}」")) if item.get("字段中文名") and item.get("字段英文名"): templates.append((f"在{table_name}中,「{item['字段中文名']}」对应的英文名是什么?", f"该字段的英文名为「{item['字段英文名']}」")) if item.get("字段中文名") and item.get("值类型"): templates.append((f"请问「{item['字段中文名']}」在{table_name}中的值类型是什么?", f"该字段的值类型为「{item['值类型']}」")) if item.get("字段中文名") and item.get("长度"): templates.append((f"查询{table_name}中,「{item['字段中文名']}」的长度是多少?", f"该字段长度为{item['长度']}")) if item.get("字段中文名") and item.get("小数位") is not None: templates.append((f"请确认「{item['字段中文名']}」在{table_name}中的小数位设置", f"该字段小数位为{item['小数位']}")) if item.get("逻辑模型中文名") and item.get("动态查询能力"): templates.append((f"「{item['逻辑模型中文名']}」的动态查询能力是什么级别?", f"该逻辑模型的动态查询能力为「{item['动态查询能力']}」")) if item.get("字段中文名") and item.get("关联数据元素英文名"): templates.append((f"在{table_name}中,「{item['字段中文名']}」关联的数据元素英文名是什么?", f"该字段关联的数据元素英文名为「{item['关联数据元素英文名']}」")) if item.get("逻辑模型中文名") and item.get("逻辑模型英文名"): templates.append((f"请查找「{item['逻辑模型中文名']}」对应的逻辑模型英文名", f"该逻辑模型的英文名为「{item['逻辑模型英文名']}」")) # 生成QA for i, (question, answer) in enumerate(templates[:template_count]): qa_pairs.append({ "instruct": question, "input": "", "output": f"{self.config.get_random_element(answer_prefixes)}{answer}{self.config.get_random_element(answer_suffixes)}" }) return qa_pairs def generate_multi_qa(self, item: Dict, template_count: int, data_type: str) -> List[Dict]: """生成多列QA""" qa_pairs = [] answer_prefixes = self.config.ANSWER_PREFIXES answer_suffixes = self.config.ANSWER_SUFFIXES connectors = self.config.CONNECTORS if data_type == "element": table_name = item.get("表名", "元素治理模板") templates = [] if item.get("数据元素中文名") and item.get("值类型") and item.get("总长度"): connector = self.config.get_random_element(connectors) templates.append(( f"请列举{table_name}中「{item['数据元素中文名']}」的值类型和总长度", f"该数据元素的{connector}值类型为「{item['值类型']}」,总长度为{item['总长度']}" )) if item.get("数据元素中文名") and item.get("类别") and item.get("业务领域名称") and item.get("是否枚举"): connector1 = self.config.get_random_element(connectors[:3]) connector2 = self.config.get_random_element(connectors[3:]) templates.append(( f"请输出「{item['数据元素中文名']}」在{table_name}中的类别、业务领域和是否枚举信息", f"该数据元素的类别为「{item['类别']}」,{connector1}业务领域为「{item['业务领域名称']}」,{connector2}是否枚举为「{item['是否枚举']}」" )) if item.get("数据元素中文名") and item.get("值类型") and item.get("总长度") and item.get("小数位"): connector = self.config.get_random_element(connectors) templates.append(( f"在{table_name}中,「{item['数据元素中文名']}」的值类型、长度和小数位分别是多少?", f"该数据元素的值类型为「{item['值类型']}」,{connector}长度为{item['总长度']},小数位为{item['小数位']}" )) if item.get("数据元素中文名") and item.get("数据元素英文名") and item.get("说明"): connector = self.config.get_random_element(connectors) templates.append(( f"请查找「{item['数据元素中文名']}」在{table_name}中的英文名和说明信息", f"该数据元素的英文名为「{item['数据元素英文名']}」,{connector}说明为「{item['说明']}」" )) if item.get("数据元素中文名") and item.get("枚举数量") and item.get("元素取值范围\n(枚举类型名称)"): connector = self.config.get_random_element(connectors) range_value = item['元素取值范围\n(枚举类型名称)'] or "无" templates.append(( f"在{table_name}中,「{item['数据元素中文名']}」的枚举数量和取值范围分别是什么?", f"该数据元素的枚举数量为{item['枚举数量']},{connector}取值范围为「{range_value}」" )) # 生成QA for i, (question, answer) in enumerate(templates[:template_count]): qa_pairs.append({ "instruct": question, "input": "", "output": f"{self.config.get_random_element(answer_prefixes)}{answer}{self.config.get_random_element(answer_suffixes)}" }) # 为物理模型和逻辑模型也添加类似的多列模板... elif data_type == "physical": table_name = item.get("表名", "物理模型") if item.get("物理模型属性中文名") and item.get("值类型") and item.get("长度"): connector = self.config.get_random_element(connectors) qa_pairs.append({ "instruct": f"请列举「{item['物理模型属性中文名']}」在{table_name}中的值类型和长度", "input": "", "output": f"{self.config.get_random_element(answer_prefixes)}该属性的{connector}值类型为「{item['值类型']}」,长度为{item['长度']}{self.config.get_random_element(answer_suffixes)}" }) elif data_type == "logical": table_name = item.get("表名", "逻辑模型") if item.get("字段中文名") and item.get("值类型") and item.get("长度"): connector = self.config.get_random_element(connectors) qa_pairs.append({ "instruct": f"请列举「{item['字段中文名']}」在{table_name}中的值类型和长度", "input": "", "output": f"{self.config.get_random_element(answer_prefixes)}该字段的{connector}值类型为「{item['值类型']}」,长度为{item['长度']}{self.config.get_random_element(answer_suffixes)}" }) return qa_pairs def generate_qa_for_data(self, data: List[Dict], data_type: str) -> List[Dict]: """为指定数据类型生成QA""" all_qa = [] for item in data: # 生成单列QA single_qa = self.generate_single_qa(item, self.config.SINGLE_TEMPLATES, data_type) all_qa.extend(single_qa) # 根据概率生成多列QA if random.random() < self.config.MULTI_RATIO: multi_qa = self.generate_multi_qa(item, self.config.MULTI_TEMPLATES, data_type) all_qa.extend(multi_qa) return all_qa def shuffle_qa_pairs(self, qa_pairs: List[Dict]) -> List[Dict]: """随机打乱问答对顺序""" if self.config.SHUFFLE_OUTPUT: random.shuffle(qa_pairs) return qa_pairs def save_qa(self, qa_pairs: List[Dict], filename: str): """保存QA到文件""" output_path = os.path.join(self.config.OUTPUT_DIR, filename) with open(output_path, 'w', encoding='utf-8') as f: json.dump(qa_pairs, f, ensure_ascii=False, indent=2) size_mb = os.path.getsize(output_path) / (1024 * 1024) if self.config.VERBOSE_LOG: print(f"[OK] 已生成: {output_path} (共 {len(qa_pairs)} 条问答对, {size_mb:.1f}MB)") else: print(f"[OK] 已生成: {output_path} (共 {len(qa_pairs)} 条问答对)") def merge_to_train(self, output_file: str = "train.json") -> Dict: """合并QA文件为train.json""" all_qa_pairs = [] file_stats = {} total_files = 0 # 遍历输出目录中的所有QA文件 for filename in os.listdir(self.config.OUTPUT_DIR): if filename.endswith('.json') and not filename.startswith('QA生成报告') and filename != 'train_stats.json': file_path = os.path.join(self.config.OUTPUT_DIR, filename) try: with open(file_path, 'r', encoding='utf-8') as f: qa_data = json.load(f) if isinstance(qa_data, list) and all(isinstance(item, dict) and 'instruct' in item and 'output' in item for item in qa_data): all_qa_pairs.extend(qa_data) file_stats[filename] = len(qa_data) total_files += 1 if self.config.VERBOSE_LOG: print(f"[OK] 已合并: {filename} ({len(qa_data)} 条问答对)") except Exception as e: print(f"[ERROR] 读取文件失败 {filename}: {str(e)}") # 打乱顺序 if self.config.SHUFFLE_OUTPUT and all_qa_pairs: random.shuffle(all_qa_pairs) if self.config.VERBOSE_LOG: print(f"\n[INFO] 已打乱 {len(all_qa_pairs)} 条问答对顺序") # 保存train.json train_path = os.path.join(self.config.OUTPUT_DIR, output_file) with open(train_path, 'w', encoding='utf-8') as f: json.dump(all_qa_pairs, f, ensure_ascii=False, indent=2) file_size_mb = os.path.getsize(train_path) / (1024 * 1024) # 生成统计 stats = { "合并时间": "2025-12-18", "处理文件数": total_files, "各文件问答对数量": file_stats, "总问答对数量": len(all_qa_pairs), "输出文件大小": f"{file_size_mb:.2f} MB", "打乱顺序": self.config.SHUFFLE_OUTPUT } # 保存统计信息 stats_file = os.path.join(self.config.OUTPUT_DIR, "train_stats.json") with open(stats_file, 'w', encoding='utf-8') as f: json.dump(stats, f, ensure_ascii=False, indent=2) print(f"\n[SUCCESS] 合并完成!") print(f"[OUTPUT] {train_path}") print(f"[TOTAL] 总计: {len(all_qa_pairs):,} 条问答对") print(f"[SIZE] 文件大小: {file_size_mb:.2f} MB") return stats def generate_report(self, qa_counts: Dict[str, int]): """生成生成报告""" if not self.config.GENERATE_REPORT: return report = { "生成时间": "2025-12-18", "版本": "整合版", "配置信息": { "复杂程度等级": self.config.COMPLEXITY_LEVEL, "随机种子": self.config.RANDOM_SEED, "多列查询占比": self.config.MULTI_COLUMN_RATIO, "打乱输出": self.config.SHUFFLE_OUTPUT }, "总文件数": len(qa_counts), "各文件问答对数量": qa_counts, "总计问答对数量": sum(qa_counts.values()), "说明": "所有问答对均基于原始JSON数据生成,未进行任何编撰或修改" } report_path = os.path.join(self.config.OUTPUT_DIR, "QA生成报告.json") with open(report_path, 'w', encoding='utf-8') as f: json.dump(report, f, ensure_ascii=False, indent=2) if self.config.VERBOSE_LOG: print(f"[OK] 已生成: {report_path}") def process_all(self): # 清理输出目录中的旧文件 print("[INFO] 清理输出目录中的旧文件...") if os.path.exists(self.config.OUTPUT_DIR): for filename in os.listdir(self.config.OUTPUT_DIR): file_path = os.path.join(self.config.OUTPUT_DIR, filename) try: if os.path.isfile(file_path): os.remove(file_path) if self.config.VERBOSE_LOG: print(f" [DEL] 已删除: {filename}") except Exception as e: print(f" [WARN] 删除文件失败 {filename}: {str(e)}") """处理所有数据文件""" qa_counts = {} for file_info in self.config.DATA_FILES: if not file_info["enabled"]: if self.config.VERBOSE_LOG: print(f"[SKIP] 已跳过: {file_info['name']}") continue if self.config.VERBOSE_LOG: print(f"\n[INFO] 正在处理: {file_info['name']}.json") else: print(f"[INFO] 正在处理: {file_info['name']}") input_file = os.path.join(self.config.INPUT_DIR, file_info["file"]) if not os.path.exists(input_file): print(f"[WARNING] 文件不存在: {input_file}") continue try: data = self.load_json(input_file) # 根据文件名确定数据类型 if "元素治理" in file_info["name"]: data_type = "element" elif "物理模型" in file_info["name"]: data_type = "physical" elif "逻辑模型" in file_info["name"]: data_type = "logical" else: print(f"[WARNING] 未知的文件类型: {file_info['name']}") continue qa_pairs = self.generate_qa_for_data(data, data_type) qa_pairs = self.shuffle_qa_pairs(qa_pairs) self.save_qa(qa_pairs, file_info["output"]) qa_counts[file_info["output"]] = len(qa_pairs) except Exception as e: print(f"[ERROR] 处理文件 {file_info['name']} 时出错: {str(e)}") # 生成报告 if self.config.GENERATE_REPORT: if self.config.VERBOSE_LOG: print("\n[INFO] 正在生成: 生成报告") else: print("\n[INFO] 正在生成: 生成报告") self.generate_report(qa_counts) if self.config.VERBOSE_LOG: print(f"\n[DONE] 所有文件处理完成!") print(f"[OUT] 输出目录: {self.config.OUTPUT_DIR}") print(f"[TOTAL] 总计生成: {sum(qa_counts.values())} 条问答对") else: print("\n[DONE] 所有文件处理完成!") # 预设配置 SIMPLE_CONFIG = QAConfig() SIMPLE_CONFIG.COMPLEXITY_LEVEL = 1 SIMPLE_CONFIG._init_templates() NORMAL_CONFIG = QAConfig() NORMAL_CONFIG.COMPLEXITY_LEVEL = 3 NORMAL_CONFIG._init_templates() COMPLEX_CONFIG = QAConfig() COMPLEX_CONFIG.COMPLEXITY_LEVEL = 5 COMPLEX_CONFIG._init_templates() def main(): """主函数 - 交互式运行""" print("="*60) print("QA生成器 - 整合版") print("="*60) print("\n可用的配置预设:") print("1. SIMPLE_CONFIG - 简单模式 (复杂程度=1)") print("2. NORMAL_CONFIG - 普通模式 (复杂程度=3)") print("3. COMPLEX_CONFIG - 复杂模式 (复杂程度=5)") print("4. 自定义配置") choice = input("\n请选择配置 (1-4): ").strip() if choice == "1": config = SIMPLE_CONFIG print("\n✓ 已选择: 简单模式") elif choice == "2": config = NORMAL_CONFIG print("\n✓ 已选择: 普通模式") elif choice == "3": config = COMPLEX_CONFIG print("\n✓ 已选择: 复杂模式") elif choice == "4": print("\n自定义配置:") config = QAConfig() complexity = input(f"复杂程度等级 (1-5, 当前:{config.COMPLEXITY_LEVEL}): ").strip() if complexity: config.COMPLEXITY_LEVEL = int(complexity) config._init_templates() multi_ratio = input(f"多列查询占比 0.0-1.0 (当前:{config.MULTI_COLUMN_RATIO}): ").strip() if multi_ratio: config.MULTI_COLUMN_RATIO = float(multi_ratio) config._init_templates() print("\n✓ 已应用自定义配置") else: print("\n无效选择,使用默认配置") config = NORMAL_CONFIG # 显示配置信息 config.print_config() # 创建生成器并处理 generator = QAGenerator(config) generator.process_all() # 询问是否合并为train.json merge_choice = input("\n是否合并为train.json? (y/n): ").strip().lower() if merge_choice == 'y': generator.merge_to_train() if __name__ == "__main__": main()