Files
YG_TDgenerator/qa_generator.py

585 lines
29 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
QA生成器 - 整合版
整合所有功能于一个文件,减少冗余
"""
import json
import os
import random
from typing import List, Dict, Any
class QAConfig:
"""QA生成配置类"""
def __init__(self):
# 基础配置
self.RANDOM_SEED = 42
self.INPUT_DIR = "Data_Export_Json"
self.OUTPUT_DIR = "Data_QA_Outputs"
# 复杂程度控制 (1-5)
self.COMPLEXITY_LEVEL = 3
# 问题数量控制
self.BASIC_QUESTIONS_PER_ITEM = 1
self.MAX_QUESTIONS_PER_ITEM = 10
self.MULTI_COLUMN_RATIO = 0.3
# 输出控制
self.SHUFFLE_OUTPUT = True
self.GENERATE_REPORT = True
self.VERBOSE_LOG = True
# 数据文件配置
self.DATA_FILES = [
{"name": "元素治理模板", "file": "元素治理模板.json", "output": "元素治理模板_QA.json", "enabled": True},
{"name": "物理模型", "file": "物理模型.json", "output": "物理模型_QA.json", "enabled": True},
{"name": "逻辑模型", "file": "逻辑模型.json", "output": "逻辑模型_QA.json", "enabled": True}
]
# 初始化修饰语和连接词
self._init_templates()
def _init_templates(self):
"""初始化模板列表"""
if self.COMPLEXITY_LEVEL <= 2:
# 简单模式
self.QUESTION_PREFIXES = ["请告诉我", "查询", "请问"]
self.ANSWER_PREFIXES = ["根据表记录,该字段的", "查询结果显示,", "经查询,该字段的"]
self.ANSWER_SUFFIXES = ["", ""]
self.CONNECTORS = ["", ""]
self.SINGLE_TEMPLATES = 6 if self.COMPLEXITY_LEVEL == 2 else 3
self.MULTI_TEMPLATES = 1 if self.COMPLEXITY_LEVEL == 2 else 0
self.MULTI_RATIO = 0.1 if self.COMPLEXITY_LEVEL == 2 else 0.0
else:
# 普通/复杂模式
self.QUESTION_PREFIXES = ["请告诉我", "查询", "请问", "", "请解释", "请输出", "请列举", "请说明", "请查找", "请确认"]
self.ANSWER_PREFIXES = ["根据表记录,该字段的", "查询结果显示,", "经查询,该字段的", "根据数据库记录,", "在表中,此字段的", "查询结果:", "经系统查询,", "根据记录显示,", "在数据中,该字段的", "查询得知,该字段的"]
self.ANSWER_SUFFIXES = ["", ",请参考。", ",详情如上。", ",以上信息为准。", ",望知悉。", ",如需更多信息请联系。", ",希望能帮到您。", ",祝您工作顺利。", ",谢谢。", ""]
self.CONNECTORS = ["", "", "", "", ",还有", "以及"]
self.SINGLE_TEMPLATES = 12
self.MULTI_TEMPLATES = 5
self.MULTI_RATIO = self.MULTI_COLUMN_RATIO
def get_random_element(self, elements: List[str]) -> str:
"""从列表中随机获取一个元素"""
return random.choice(elements) if elements else ""
def update_config(self, **kwargs):
"""更新配置参数"""
for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)
if key == 'COMPLEXITY_LEVEL':
self._init_templates() # 重新初始化模板
def print_config(self):
"""打印当前配置"""
print(f"\n复杂程度等级: {self.COMPLEXITY_LEVEL}")
print(f"单列模板数: {self.SINGLE_TEMPLATES}")
print(f"多列模板数: {self.MULTI_TEMPLATES}")
print(f"多列占比: {self.MULTI_RATIO}")
print(f"输出目录: {self.OUTPUT_DIR}")
class QAGenerator:
"""QA生成器 - 整合版"""
def __init__(self, config: QAConfig = None):
"""初始化生成器"""
self.config = config or QAConfig()
os.makedirs(self.config.OUTPUT_DIR, exist_ok=True)
random.seed(self.config.RANDOM_SEED)
def load_json(self, file_path: str) -> List[Dict]:
"""加载JSON文件"""
with open(file_path, 'r', encoding='utf-8') as f:
return json.load(f)
def generate_single_qa(self, item: Dict, template_count: int, data_type: str) -> List[Dict]:
"""生成单列QA - 整合了三种数据类型的模板"""
qa_pairs = []
answer_prefixes = self.config.ANSWER_PREFIXES
answer_suffixes = self.config.ANSWER_SUFFIXES
prefixes = self.config.QUESTION_PREFIXES
if data_type == "element":
# 元素治理模板模板
templates = []
table_name = item.get("表名", "元素治理模板")
if item.get("业务领域名称") and item.get("数据元素中文名"):
templates.append((f"{table_name}中,业务领域「{item['业务领域名称']}」对应的数据元素中文名是什么?", f"该数据元素为「{item['数据元素中文名']}"))
if item.get("业务领域名称") and item.get("数据元素中文名"):
templates.append((f"{item['数据元素中文名']}」这个数据元素在{table_name}中属于哪个业务领域?", f"该数据元素属于「{item['业务领域名称']}"))
if item.get("数据元素中文名") and item.get("值类型"):
templates.append((f"查询{table_name}中,数据元素「{item['数据元素中文名']}」的值类型是什么?", f"值类型为「{item['值类型']}"))
if item.get("数据元素中文名") and item.get("总长度"):
templates.append((f"{table_name}里,「{item['数据元素中文名']}」的总长度设置是多少?", f"总长度为{item['总长度']}"))
if item.get("数据元素中文名") and item.get("类别"):
templates.append((f"请确认「{item['数据元素中文名']}」在{table_name}中属于哪个类别?", f"该数据元素属于「{item['类别']}」类别"))
if item.get("数据元素中文名") and item.get("数据元素英文名"):
templates.append((f"请查找{table_name}中「{item['数据元素中文名']}」对应的英文名是什么?", f"英文名为「{item['数据元素英文名']}"))
if item.get("数据元素中文名") and item.get("是否枚举"):
templates.append((f"{item['数据元素中文名']}」这个数据元素在{table_name}中是否枚举?", f"是否枚举为「{item['是否枚举']}"))
if item.get("数据元素中文名") and item.get("枚举数量"):
templates.append((f"请问在{table_name}中,「{item['数据元素中文名']}」的枚举数量是多少?", f"枚举数量为{item['枚举数量']}"))
if item.get("数据元素中文名") and item.get("小数位"):
templates.append((f"请说明{table_name}中「{item['数据元素中文名']}」的小数位设置?", f"小数位为{item['小数位']}"))
if item.get("数据元素中文名") and item.get("抽象元素中文名"):
templates.append((f"{table_name}里,「{item['数据元素中文名']}」的抽象元素中文名是什么?", f"抽象元素中文名为「{item['抽象元素中文名']}"))
if item.get("数据元素中文名") and item.get("说明"):
templates.append((f"请解释「{item['数据元素中文名']}」在{table_name}中的作用和含义", f"该数据元素的说明为「{item['说明']}"))
if item.get("数据元素中文名") and item.get("是否上线"):
templates.append((f"请问「{item['数据元素中文名']}」在{table_name}中是否已上线?", f"是否上线为「{item['是否上线']}"))
# 生成QA
for i, (question, answer) in enumerate(templates[:template_count]):
qa_pairs.append({
"instruct": question,
"input": "",
"output": f"{self.config.get_random_element(answer_prefixes)}{answer}{self.config.get_random_element(answer_suffixes)}"
})
elif data_type == "physical":
# 物理模型模板
table_name = item.get("表名", "物理模型")
templates = []
if item.get("物理模型属性中文名") and item.get("值类型"):
templates.append((f"{table_name}中,「{item['物理模型属性中文名']}」的值类型是什么?", f"该属性的值类型为「{item['值类型']}"))
if item.get("物理模型属性中文名") and item.get("长度"):
templates.append((f"请问「{item['物理模型属性中文名']}」在{table_name}中的长度是多少?", f"该属性长度为{item['长度']}"))
if item.get("物理模型属性中文名") and item.get("小数位") is not None:
templates.append((f"请确认{table_name}中「{item['物理模型属性中文名']}」的小数位设置", f"该属性小数位为{item['小数位']}"))
if item.get("物理模型属性中文名") and item.get("关联数据元素"):
templates.append((f"{item['物理模型属性中文名']}」这个属性在{table_name}中关联哪个数据元素?", f"该属性关联的数据元素为「{item['关联数据元素']}"))
if item.get("物理模型属性中文名") and item.get("物理模型中文名"):
templates.append((f"请查找「{item['物理模型属性中文名']}」属于哪个物理模型?", f"该属性属于「{item['物理模型中文名']}"))
if item.get("物理模型属性中文名") and item.get("说明"):
templates.append((f"请说明「{item['物理模型属性中文名']}」在{table_name}中的作用和用途", f"该属性的说明为「{item['说明']}"))
if item.get("物理模型属性英文名") and item.get("物理模型属性中文名"):
templates.append((f"{table_name}中,属性「{item['物理模型属性英文名']}」的中文名是什么?", f"该属性的中文名为「{item['物理模型属性中文名']}"))
if item.get("物理模型中文名") and item.get("物理模型英文名"):
templates.append((f"{item['物理模型中文名']}」对应的物理模型英文名是什么?", f"该物理模型的英文名为「{item['物理模型英文名']}"))
if item.get("物理模型英文名") and item.get("物理模型中文名"):
templates.append((f"{table_name}中,物理模型「{item['物理模型英文名']}」的中文名是什么?", f"该物理模型的中文名为「{item['物理模型中文名']}"))
# 生成QA
for i, (question, answer) in enumerate(templates[:template_count]):
qa_pairs.append({
"instruct": question,
"input": "",
"output": f"{self.config.get_random_element(answer_prefixes)}{answer}{self.config.get_random_element(answer_suffixes)}"
})
elif data_type == "logical":
# 逻辑模型模板
table_name = item.get("表名", "逻辑模型")
templates = []
if item.get("业务领域") and item.get("逻辑模型中文名"):
templates.append((f"{table_name}中,业务领域为「{item['业务领域']}」的逻辑模型中文名是什么?", f"该业务领域对应的逻辑模型中文名为「{item['逻辑模型中文名']}"))
if item.get("业务领域") and item.get("逻辑模型中文名"):
templates.append((f"{item['逻辑模型中文名']}」这个逻辑模型在{table_name}中属于哪个业务领域?", f"该逻辑模型属于「{item['业务领域']}"))
if item.get("字段英文名") and item.get("字段中文名"):
templates.append((f"请告诉我「{item['字段英文名']}」在{table_name}中的中文名是什么?", f"该字段的中文名为「{item['字段中文名']}"))
if item.get("字段中文名") and item.get("字段英文名"):
templates.append((f"{table_name}中,「{item['字段中文名']}」对应的英文名是什么?", f"该字段的英文名为「{item['字段英文名']}"))
if item.get("字段中文名") and item.get("值类型"):
templates.append((f"请问「{item['字段中文名']}」在{table_name}中的值类型是什么?", f"该字段的值类型为「{item['值类型']}"))
if item.get("字段中文名") and item.get("长度"):
templates.append((f"查询{table_name}中,「{item['字段中文名']}」的长度是多少?", f"该字段长度为{item['长度']}"))
if item.get("字段中文名") and item.get("小数位") is not None:
templates.append((f"请确认「{item['字段中文名']}」在{table_name}中的小数位设置", f"该字段小数位为{item['小数位']}"))
if item.get("逻辑模型中文名") and item.get("动态查询能力"):
templates.append((f"{item['逻辑模型中文名']}」的动态查询能力是什么级别?", f"该逻辑模型的动态查询能力为「{item['动态查询能力']}"))
if item.get("字段中文名") and item.get("关联数据元素英文名"):
templates.append((f"{table_name}中,「{item['字段中文名']}」关联的数据元素英文名是什么?", f"该字段关联的数据元素英文名为「{item['关联数据元素英文名']}"))
if item.get("逻辑模型中文名") and item.get("逻辑模型英文名"):
templates.append((f"请查找「{item['逻辑模型中文名']}」对应的逻辑模型英文名", f"该逻辑模型的英文名为「{item['逻辑模型英文名']}"))
# 生成QA
for i, (question, answer) in enumerate(templates[:template_count]):
qa_pairs.append({
"instruct": question,
"input": "",
"output": f"{self.config.get_random_element(answer_prefixes)}{answer}{self.config.get_random_element(answer_suffixes)}"
})
return qa_pairs
def generate_multi_qa(self, item: Dict, template_count: int, data_type: str) -> List[Dict]:
"""生成多列QA"""
qa_pairs = []
answer_prefixes = self.config.ANSWER_PREFIXES
answer_suffixes = self.config.ANSWER_SUFFIXES
connectors = self.config.CONNECTORS
if data_type == "element":
table_name = item.get("表名", "元素治理模板")
templates = []
if item.get("数据元素中文名") and item.get("值类型") and item.get("总长度"):
connector = self.config.get_random_element(connectors)
templates.append((
f"请列举{table_name}中「{item['数据元素中文名']}」的值类型和总长度",
f"该数据元素的{connector}值类型为「{item['值类型']}」,总长度为{item['总长度']}"
))
if item.get("数据元素中文名") and item.get("类别") and item.get("业务领域名称") and item.get("是否枚举"):
connector1 = self.config.get_random_element(connectors[:3])
connector2 = self.config.get_random_element(connectors[3:])
templates.append((
f"请输出「{item['数据元素中文名']}」在{table_name}中的类别、业务领域和是否枚举信息",
f"该数据元素的类别为「{item['类别']}」,{connector1}业务领域为「{item['业务领域名称']}」,{connector2}是否枚举为「{item['是否枚举']}"
))
if item.get("数据元素中文名") and item.get("值类型") and item.get("总长度") and item.get("小数位"):
connector = self.config.get_random_element(connectors)
templates.append((
f"{table_name}中,「{item['数据元素中文名']}」的值类型、长度和小数位分别是多少?",
f"该数据元素的值类型为「{item['值类型']}」,{connector}长度为{item['总长度']},小数位为{item['小数位']}"
))
if item.get("数据元素中文名") and item.get("数据元素英文名") and item.get("说明"):
connector = self.config.get_random_element(connectors)
templates.append((
f"请查找「{item['数据元素中文名']}」在{table_name}中的英文名和说明信息",
f"该数据元素的英文名为「{item['数据元素英文名']}」,{connector}说明为「{item['说明']}"
))
if item.get("数据元素中文名") and item.get("枚举数量") and item.get("元素取值范围\n(枚举类型名称)"):
connector = self.config.get_random_element(connectors)
range_value = item['元素取值范围\n(枚举类型名称)'] or ""
templates.append((
f"{table_name}中,「{item['数据元素中文名']}」的枚举数量和取值范围分别是什么?",
f"该数据元素的枚举数量为{item['枚举数量']}{connector}取值范围为「{range_value}"
))
# 生成QA
for i, (question, answer) in enumerate(templates[:template_count]):
qa_pairs.append({
"instruct": question,
"input": "",
"output": f"{self.config.get_random_element(answer_prefixes)}{answer}{self.config.get_random_element(answer_suffixes)}"
})
# 为物理模型和逻辑模型也添加类似的多列模板...
elif data_type == "physical":
table_name = item.get("表名", "物理模型")
if item.get("物理模型属性中文名") and item.get("值类型") and item.get("长度"):
connector = self.config.get_random_element(connectors)
qa_pairs.append({
"instruct": f"请列举「{item['物理模型属性中文名']}」在{table_name}中的值类型和长度",
"input": "",
"output": f"{self.config.get_random_element(answer_prefixes)}该属性的{connector}值类型为「{item['值类型']}」,长度为{item['长度']}{self.config.get_random_element(answer_suffixes)}"
})
elif data_type == "logical":
table_name = item.get("表名", "逻辑模型")
if item.get("字段中文名") and item.get("值类型") and item.get("长度"):
connector = self.config.get_random_element(connectors)
qa_pairs.append({
"instruct": f"请列举「{item['字段中文名']}」在{table_name}中的值类型和长度",
"input": "",
"output": f"{self.config.get_random_element(answer_prefixes)}该字段的{connector}值类型为「{item['值类型']}」,长度为{item['长度']}{self.config.get_random_element(answer_suffixes)}"
})
return qa_pairs
def generate_qa_for_data(self, data: List[Dict], data_type: str) -> List[Dict]:
"""为指定数据类型生成QA"""
all_qa = []
for item in data:
# 生成单列QA
single_qa = self.generate_single_qa(item, self.config.SINGLE_TEMPLATES, data_type)
all_qa.extend(single_qa)
# 根据概率生成多列QA
if random.random() < self.config.MULTI_RATIO:
multi_qa = self.generate_multi_qa(item, self.config.MULTI_TEMPLATES, data_type)
all_qa.extend(multi_qa)
return all_qa
def shuffle_qa_pairs(self, qa_pairs: List[Dict]) -> List[Dict]:
"""随机打乱问答对顺序"""
if self.config.SHUFFLE_OUTPUT:
random.shuffle(qa_pairs)
return qa_pairs
def save_qa(self, qa_pairs: List[Dict], filename: str):
"""保存QA到文件"""
output_path = os.path.join(self.config.OUTPUT_DIR, filename)
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(qa_pairs, f, ensure_ascii=False, indent=2)
size_mb = os.path.getsize(output_path) / (1024 * 1024)
if self.config.VERBOSE_LOG:
print(f"[OK] 已生成: {output_path} (共 {len(qa_pairs)} 条问答对, {size_mb:.1f}MB)")
else:
print(f"[OK] 已生成: {output_path} (共 {len(qa_pairs)} 条问答对)")
def merge_to_train(self, output_file: str = "train.json") -> Dict:
"""合并QA文件为train.json"""
all_qa_pairs = []
file_stats = {}
total_files = 0
# 遍历输出目录中的所有QA文件
for filename in os.listdir(self.config.OUTPUT_DIR):
if filename.endswith('.json') and not filename.startswith('QA生成报告') and filename != 'train_stats.json':
file_path = os.path.join(self.config.OUTPUT_DIR, filename)
try:
with open(file_path, 'r', encoding='utf-8') as f:
qa_data = json.load(f)
if isinstance(qa_data, list) and all(isinstance(item, dict) and 'instruct' in item and 'output' in item for item in qa_data):
all_qa_pairs.extend(qa_data)
file_stats[filename] = len(qa_data)
total_files += 1
if self.config.VERBOSE_LOG:
print(f"[OK] 已合并: {filename} ({len(qa_data)} 条问答对)")
except Exception as e:
print(f"[ERROR] 读取文件失败 {filename}: {str(e)}")
# 打乱顺序
if self.config.SHUFFLE_OUTPUT and all_qa_pairs:
random.shuffle(all_qa_pairs)
if self.config.VERBOSE_LOG:
print(f"\n[INFO] 已打乱 {len(all_qa_pairs)} 条问答对顺序")
# 保存train.json
train_path = os.path.join(self.config.OUTPUT_DIR, output_file)
with open(train_path, 'w', encoding='utf-8') as f:
json.dump(all_qa_pairs, f, ensure_ascii=False, indent=2)
file_size_mb = os.path.getsize(train_path) / (1024 * 1024)
# 生成统计
stats = {
"合并时间": "2025-12-18",
"处理文件数": total_files,
"各文件问答对数量": file_stats,
"总问答对数量": len(all_qa_pairs),
"输出文件大小": f"{file_size_mb:.2f} MB",
"打乱顺序": self.config.SHUFFLE_OUTPUT
}
# 保存统计信息
stats_file = os.path.join(self.config.OUTPUT_DIR, "train_stats.json")
with open(stats_file, 'w', encoding='utf-8') as f:
json.dump(stats, f, ensure_ascii=False, indent=2)
print(f"\n[SUCCESS] 合并完成!")
print(f"[OUTPUT] {train_path}")
print(f"[TOTAL] 总计: {len(all_qa_pairs):,} 条问答对")
print(f"[SIZE] 文件大小: {file_size_mb:.2f} MB")
return stats
def generate_report(self, qa_counts: Dict[str, int]):
"""生成生成报告"""
if not self.config.GENERATE_REPORT:
return
report = {
"生成时间": "2025-12-18",
"版本": "整合版",
"配置信息": {
"复杂程度等级": self.config.COMPLEXITY_LEVEL,
"随机种子": self.config.RANDOM_SEED,
"多列查询占比": self.config.MULTI_COLUMN_RATIO,
"打乱输出": self.config.SHUFFLE_OUTPUT
},
"总文件数": len(qa_counts),
"各文件问答对数量": qa_counts,
"总计问答对数量": sum(qa_counts.values()),
"说明": "所有问答对均基于原始JSON数据生成未进行任何编撰或修改"
}
report_path = os.path.join(self.config.OUTPUT_DIR, "QA生成报告.json")
with open(report_path, 'w', encoding='utf-8') as f:
json.dump(report, f, ensure_ascii=False, indent=2)
if self.config.VERBOSE_LOG:
print(f"[OK] 已生成: {report_path}")
def process_all(self):
# 清理输出目录中的旧文件
print("[INFO] 清理输出目录中的旧文件...")
if os.path.exists(self.config.OUTPUT_DIR):
for filename in os.listdir(self.config.OUTPUT_DIR):
file_path = os.path.join(self.config.OUTPUT_DIR, filename)
try:
if os.path.isfile(file_path):
os.remove(file_path)
if self.config.VERBOSE_LOG:
print(f" [DEL] 已删除: {filename}")
except Exception as e:
print(f" [WARN] 删除文件失败 {filename}: {str(e)}")
"""处理所有数据文件"""
qa_counts = {}
for file_info in self.config.DATA_FILES:
if not file_info["enabled"]:
if self.config.VERBOSE_LOG:
print(f"[SKIP] 已跳过: {file_info['name']}")
continue
if self.config.VERBOSE_LOG:
print(f"\n[INFO] 正在处理: {file_info['name']}.json")
else:
print(f"[INFO] 正在处理: {file_info['name']}")
input_file = os.path.join(self.config.INPUT_DIR, file_info["file"])
if not os.path.exists(input_file):
print(f"[WARNING] 文件不存在: {input_file}")
continue
try:
data = self.load_json(input_file)
# 根据文件名确定数据类型
if "元素治理" in file_info["name"]:
data_type = "element"
elif "物理模型" in file_info["name"]:
data_type = "physical"
elif "逻辑模型" in file_info["name"]:
data_type = "logical"
else:
print(f"[WARNING] 未知的文件类型: {file_info['name']}")
continue
qa_pairs = self.generate_qa_for_data(data, data_type)
qa_pairs = self.shuffle_qa_pairs(qa_pairs)
self.save_qa(qa_pairs, file_info["output"])
qa_counts[file_info["output"]] = len(qa_pairs)
except Exception as e:
print(f"[ERROR] 处理文件 {file_info['name']} 时出错: {str(e)}")
# 生成报告
if self.config.GENERATE_REPORT:
if self.config.VERBOSE_LOG:
print("\n[INFO] 正在生成: 生成报告")
else:
print("\n[INFO] 正在生成: 生成报告")
self.generate_report(qa_counts)
if self.config.VERBOSE_LOG:
print(f"\n[DONE] 所有文件处理完成!")
print(f"[OUT] 输出目录: {self.config.OUTPUT_DIR}")
print(f"[TOTAL] 总计生成: {sum(qa_counts.values())} 条问答对")
else:
print("\n[DONE] 所有文件处理完成!")
# 预设配置
SIMPLE_CONFIG = QAConfig()
SIMPLE_CONFIG.COMPLEXITY_LEVEL = 1
SIMPLE_CONFIG._init_templates()
NORMAL_CONFIG = QAConfig()
NORMAL_CONFIG.COMPLEXITY_LEVEL = 3
NORMAL_CONFIG._init_templates()
COMPLEX_CONFIG = QAConfig()
COMPLEX_CONFIG.COMPLEXITY_LEVEL = 5
COMPLEX_CONFIG._init_templates()
def main():
"""主函数 - 交互式运行"""
print("="*60)
print("QA生成器 - 整合版")
print("="*60)
print("\n可用的配置预设:")
print("1. SIMPLE_CONFIG - 简单模式 (复杂程度=1)")
print("2. NORMAL_CONFIG - 普通模式 (复杂程度=3)")
print("3. COMPLEX_CONFIG - 复杂模式 (复杂程度=5)")
print("4. 自定义配置")
choice = input("\n请选择配置 (1-4): ").strip()
if choice == "1":
config = SIMPLE_CONFIG
print("\n✓ 已选择: 简单模式")
elif choice == "2":
config = NORMAL_CONFIG
print("\n✓ 已选择: 普通模式")
elif choice == "3":
config = COMPLEX_CONFIG
print("\n✓ 已选择: 复杂模式")
elif choice == "4":
print("\n自定义配置:")
config = QAConfig()
complexity = input(f"复杂程度等级 (1-5, 当前:{config.COMPLEXITY_LEVEL}): ").strip()
if complexity:
config.COMPLEXITY_LEVEL = int(complexity)
config._init_templates()
multi_ratio = input(f"多列查询占比 0.0-1.0 (当前:{config.MULTI_COLUMN_RATIO}): ").strip()
if multi_ratio:
config.MULTI_COLUMN_RATIO = float(multi_ratio)
config._init_templates()
print("\n✓ 已应用自定义配置")
else:
print("\n无效选择,使用默认配置")
config = NORMAL_CONFIG
# 显示配置信息
config.print_config()
# 创建生成器并处理
generator = QAGenerator(config)
generator.process_all()
# 询问是否合并为train.json
merge_choice = input("\n是否合并为train.json? (y/n): ").strip().lower()
if merge_choice == 'y':
generator.merge_to_train()
if __name__ == "__main__":
main()