Files
YG_TDgenerator/qa_generator.py
DESKTOP-72TV0V4\caoxiaozhu f408c87564 1. 修改了表名表述的问题
2. 修改了生成别的列表述中文字段名的问题
2025-12-23 16:21:01 +08:00

1251 lines
56 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
QA生成器 - 整合版
整合所有功能于一个文件,减少冗余
"""
import json
import os
import random
from typing import List, Dict, Any
class QAConfig:
"""QA生成配置类"""
def __init__(self):
# 基础配置
self.RANDOM_SEED = 42
self.INPUT_DIR = "Data_Export_Json"
self.OUTPUT_DIR = "Data_QA_Outputs"
# 复杂程度控制 (1-5)
self.COMPLEXITY_LEVEL = 3
# 问题数量控制
self.BASIC_QUESTIONS_PER_ITEM = 1
self.MAX_QUESTIONS_PER_ITEM = 10
self.MULTI_COLUMN_RATIO = 0.3
# 输出控制
self.SHUFFLE_OUTPUT = True
self.GENERATE_REPORT = True
self.VERBOSE_LOG = True
# 数据文件配置
self.DATA_FILES = [
{"name": "元素治理模板", "file": "远光数据架构元素治理模板表.json", "output": "远光数据架构元素治理模板表.json", "enabled": True},
{"name": "物理模型", "file": "远光数据架构物理模型表.json", "output": "远光数据架构物理模型表.json", "enabled": True},
{"name": "逻辑模型", "file": "远光数据架构逻辑模型表.json", "output": "远光数据架构逻辑模型表.json", "enabled": True}
]
# 初始化修饰语和连接词
self._init_templates()
def _init_templates(self):
"""初始化模板列表"""
if self.COMPLEXITY_LEVEL <= 2:
# 简单模式
self.QUESTION_PREFIXES = ["请告诉我", "查询", "请问"]
self.ANSWER_PREFIXES = ["根据表记录,该字段的", "查询结果显示,", "经查询,该字段的"]
self.ANSWER_SUFFIXES = ["", ""]
self.CONNECTORS = ["", ""]
self.SINGLE_TEMPLATES = 6 if self.COMPLEXITY_LEVEL == 2 else 3
self.MULTI_TEMPLATES = 1 if self.COMPLEXITY_LEVEL == 2 else 0
self.MULTI_RATIO = 0.1 if self.COMPLEXITY_LEVEL == 2 else 0.0
else:
# 普通/复杂模式
self.QUESTION_PREFIXES = ["请告诉我", "查询", "请问", "", "请解释", "请输出", "请列举", "请说明", "请查找", "请确认"]
self.ANSWER_PREFIXES = ["根据表记录,该字段的", "查询结果显示,", "经查询,该字段的", "根据数据库记录,", "在表中,此字段的", "查询结果:", "经系统查询,", "根据记录显示,", "在数据中,该字段的", "查询得知,该字段的"]
self.ANSWER_SUFFIXES = ["", ",请参考。", ",详情如上。", ",以上信息为准。", ",望知悉。", ",如需更多信息请联系。", ",希望能帮到您。", ",祝您工作顺利。", ",谢谢。", ""]
self.CONNECTORS = ["", "", "", "", ",还有", "以及"]
self.SINGLE_TEMPLATES = 12
self.MULTI_TEMPLATES = 5
self.MULTI_RATIO = self.MULTI_COLUMN_RATIO
def get_random_element(self, elements: List[str]) -> str:
"""从列表中随机获取一个元素"""
return random.choice(elements) if elements else ""
def update_config(self, **kwargs):
"""更新配置参数"""
for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)
if key == 'COMPLEXITY_LEVEL':
self._init_templates() # 重新初始化模板
def print_config(self):
"""打印当前配置"""
print(f"\n[INFO] 复杂程度等级: {self.COMPLEXITY_LEVEL}")
print(f"[INFO] 单列模板数: {self.SINGLE_TEMPLATES}")
print(f"[INFO] 多列模板数: {self.MULTI_TEMPLATES}")
print(f"[INFO] 多列占比: {self.MULTI_RATIO}")
print(f"[INFO] 输出目录: {self.OUTPUT_DIR}")
class QAGenerator:
"""QA生成器 - 整合版"""
def __init__(self, config: QAConfig = None):
"""初始化生成器"""
self.config = config or QAConfig()
os.makedirs(self.config.OUTPUT_DIR, exist_ok=True)
random.seed(self.config.RANDOM_SEED)
def load_json(self, file_path: str) -> List[Dict]:
"""加载JSON文件"""
with open(file_path, 'r', encoding='utf-8') as f:
return json.load(f)
def generate_single_qa(self, item: Dict, template_count: int, data_type: str) -> List[Dict]:
"""生成单列QA - 严格基于字段中文名提问"""
qa_pairs = []
answer_prefixes = self.config.ANSWER_PREFIXES
answer_suffixes = self.config.ANSWER_SUFFIXES
if data_type == "element":
# 元素治理模板 - 严格基于"数据元素中文名"
templates = []
table_name = item.get("表名", "远光数据架构元素治理模板表")
# 只保留以数据元素中文名为标识符的模板
if item.get("数据元素中文名") and item.get("业务领域名称"):
templates.append((f"{table_name}中数据元素中文名为:{item['数据元素中文名']}属于哪个业务领域?", f"业务领域:{item['业务领域名称']}"))
if item.get("数据元素中文名") and item.get("值类型"):
templates.append((f"查询{table_name}中数据元素中文名为:{item['数据元素中文名']}的值类型是什么?", f"值类型:{item['值类型']}"))
if item.get("数据元素中文名") and item.get("总长度"):
templates.append((f"{table_name}中数据元素中文名为:{item['数据元素中文名']}的总长度设置是多少?", f"总长度:{item['总长度']}"))
if item.get("数据元素中文名") and item.get("类别"):
templates.append((f"请确认在{table_name}中数据元素中文名为:{item['数据元素中文名']}属于哪个类别?", f"类别:{item['类别']}"))
if item.get("数据元素中文名") and item.get("数据元素英文名"):
templates.append((f"{table_name}中数据元素中文名为:{item['数据元素中文名']}对应的英文名是什么?", f"英文名:{item['数据元素英文名']}"))
if item.get("数据元素中文名") and item.get("是否枚举"):
templates.append((f"{table_name}中数据元素中文名为:{item['数据元素中文名']}是否枚举?", f"是否枚举:{item['是否枚举']}"))
if item.get("数据元素中文名") and item.get("枚举数量"):
templates.append((f"请问在{table_name}中数据元素中文名为:{item['数据元素中文名']}的枚举数量是多少?", f"枚举数量:{item['枚举数量']}"))
if item.get("数据元素中文名") and item.get("小数位"):
templates.append((f"{table_name}中数据元素中文名为:{item['数据元素中文名']}的小数位设置是多少?", f"小数位:{item['小数位']}"))
if item.get("数据元素中文名") and item.get("抽象元素中文名"):
templates.append((f"{table_name}中数据元素中文名为:{item['数据元素中文名']}的抽象元素中文名是什么?", f"抽象元素中文名:{item['抽象元素中文名']}"))
if item.get("数据元素中文名") and item.get("说明"):
templates.append((f"请解释在{table_name}中数据元素中文名为:{item['数据元素中文名']}的作用和含义", f"说明:{item['说明']}"))
if item.get("数据元素中文名") and item.get("是否上线"):
templates.append((f"请问在{table_name}中数据元素中文名为:{item['数据元素中文名']}是否已上线?", f"是否上线:{item['是否上线']}"))
# 生成QA
for i, (question, answer) in enumerate(templates[:template_count]):
qa_pairs.append({
"instruct": question,
"input": "",
"output": f"{self.config.get_random_element(answer_prefixes)}{answer}{self.config.get_random_element(answer_suffixes)}"
})
elif data_type == "physical":
# 物理模型 - 严格基于"字段中文名"提问
table_name = item.get("表名", "远光数据架构物理模型表")
field_name = item.get("字段中文名")
templates = []
# 以字段中文名为主要提问对象
if field_name and item.get("值类型"):
templates.append((f"请问在{table_name}中字段中文名为:{field_name}的值类型是什么?", f"值类型:{item['值类型']}"))
if field_name and item.get("长度"):
templates.append((f"请问在{table_name}中字段中文名为:{field_name}的长度是多少?", f"长度:{item['长度']}"))
if field_name and item.get("小数位") is not None:
templates.append((f"请问在{table_name}中字段中文名为:{field_name}的小数位设置是多少?", f"小数位:{item['小数位']}"))
if field_name and item.get("关联数据元素"):
templates.append((f"请问在{table_name}中字段中文名为:{field_name}关联的数据元素是什么?", f"关联数据元素:{item['关联数据元素']}"))
if field_name and item.get("物理模型中文名"):
templates.append((f"请问在{table_name}中字段中文名为:{field_name}属于哪个物理模型?", f"物理模型:{item['物理模型中文名']}"))
if field_name and item.get("说明"):
templates.append((f"请问在{table_name}中字段中文名为:{field_name}的说明是什么?", f"说明:{item['说明']}"))
if field_name and item.get("物理模型属性英文名"):
templates.append((f"请问在{table_name}中字段中文名为:{field_name}对应的英文名是什么?", f"英文名:{item['物理模型属性英文名']}"))
# 生成QA
for i, (question, answer) in enumerate(templates[:template_count]):
qa_pairs.append({
"instruct": question,
"input": "",
"output": f"{self.config.get_random_element(answer_prefixes)}{answer}{self.config.get_random_element(answer_suffixes)}"
})
elif data_type == "logical":
# 逻辑模型 - 严格基于"字段中文名"
table_name = item.get("表名", "远光数据架构逻辑模型表")
templates = []
# 只保留以字段中文名为标识符的模板
if item.get("字段中文名") and item.get("业务领域"):
templates.append((f"{table_name}中字段中文名为:{item['字段中文名']}属于哪个业务领域?", f"业务领域:{item['业务领域']}"))
if item.get("字段中文名") and item.get("逻辑模型中文名"):
templates.append((f"{table_name}中字段中文名为:{item['字段中文名']}属于哪个逻辑模型?", f"逻辑模型:{item['逻辑模型中文名']}"))
if item.get("字段中文名") and item.get("字段英文名"):
templates.append((f"{table_name}中字段中文名为:{item['字段中文名']}对应的英文名是什么?", f"英文名:{item['字段英文名']}"))
if item.get("字段中文名") and item.get("值类型"):
templates.append((f"请问在{table_name}中字段中文名为:{item['字段中文名']}的值类型是什么?", f"值类型:{item['值类型']}"))
if item.get("字段中文名") and item.get("长度"):
templates.append((f"查询{table_name}中字段中文名为:{item['字段中文名']}的长度是多少?", f"长度:{item['长度']}"))
if item.get("字段中文名") and item.get("小数位") is not None:
templates.append((f"请确认在{table_name}中字段中文名为:{item['字段中文名']}的小数位设置", f"小数位:{item['小数位']}"))
if item.get("字段中文名") and item.get("动态查询能力"):
templates.append((f"{table_name}中字段中文名为:{item['字段中文名']}的动态查询能力是什么级别?", f"动态查询能力:{item['动态查询能力']}"))
if item.get("字段中文名") and item.get("关联数据元素英文名"):
templates.append((f"{table_name}中字段中文名为:{item['字段中文名']}关联的数据元素英文名是什么?", f"关联数据元素英文名:{item['关联数据元素英文名']}"))
# 生成QA
for i, (question, answer) in enumerate(templates[:template_count]):
qa_pairs.append({
"instruct": question,
"input": "",
"output": f"{self.config.get_random_element(answer_prefixes)}{answer}{self.config.get_random_element(answer_suffixes)}"
})
return qa_pairs
def generate_multi_field_qa(self, item: Dict, field_count: int, data_type: str) -> List[Dict]:
"""生成多字段QA - 动态选择指定数量的字段进行提问"""
qa_pairs = []
answer_prefixes = self.config.ANSWER_PREFIXES
answer_suffixes = self.config.ANSWER_SUFFIXES
connectors = self.config.CONNECTORS
# 获取可用的字段
available_fields = []
if data_type == "element":
# 元素治理模板的可查询字段
field_mapping = {
"业务领域名称": item.get("业务领域名称"),
"数据元素中文名": item.get("数据元素中文名"),
"数据元素英文名": item.get("数据元素英文名"),
"值类型": item.get("值类型"),
"总长度": item.get("总长度"),
"小数位": item.get("小数位"),
"类别": item.get("类别"),
"是否枚举": item.get("是否枚举"),
"枚举数量": item.get("枚举数量"),
"抽象元素中文名": item.get("抽象元素中文名"),
"说明": item.get("说明"),
"是否上线": item.get("是否上线")
}
elif data_type == "physical":
# 物理模型的可查询字段
field_mapping = {
"物理模型中文名": item.get("物理模型中文名"),
"物理模型英文名": item.get("物理模型英文名"),
"字段中文名": item.get("字段中文名"),
"物理模型属性英文名": item.get("物理模型属性英文名"),
"值类型": item.get("值类型"),
"长度": item.get("长度"),
"小数位": item.get("小数位"),
"关联数据元素": item.get("关联数据元素"),
"说明": item.get("说明")
}
elif data_type == "logical":
# 逻辑模型的可查询字段
field_mapping = {
"业务领域": item.get("业务领域"),
"逻辑模型中文名": item.get("逻辑模型中文名"),
"逻辑模型英文名": item.get("逻辑模型英文名"),
"字段中文名": item.get("字段中文名"),
"字段英文名": item.get("字段英文名"),
"值类型": item.get("值类型"),
"长度": item.get("长度"),
"小数位": item.get("小数位"),
"动态查询能力": item.get("动态查询能力"),
"关联数据元素英文名": item.get("关联数据元素英文名")
}
# 筛选出有值的字段
for field_name, field_value in field_mapping.items():
if field_value is not None and field_value != "":
available_fields.append((field_name, field_value))
# 如果可用字段少于要求的字段数量,使用所有可用字段
if len(available_fields) < field_count:
selected_fields = available_fields
else:
# 优先确保包含标识字段的组合
identifier_fields = {
"element": "数据元素中文名",
"physical": "字段中文名",
"logical": "字段中文名"
}
identifier_field = identifier_fields.get(data_type)
identifier_value = None
# 查找标识字段的值
for field_name, field_value in available_fields:
if field_name == identifier_field:
identifier_value = (field_name, field_value)
break
if identifier_value:
# 创建包含标识字段的字段列表
other_fields = [f for f in available_fields if f[0] != identifier_field]
# 如果其他字段足够,随机选择剩余字段
if len(other_fields) >= field_count - 1:
selected_others = random.sample(other_fields, field_count - 1)
selected_fields = [identifier_value] + selected_others
else:
# 如果其他字段不够,包含所有其他字段
selected_fields = [identifier_value] + other_fields
else:
# 如果没找到标识字段,随机选择
selected_fields = random.sample(available_fields, field_count)
# 生成问题
question_parts = []
answer_parts = []
table_name = item.get("表名", "远光数据架构表")
# 当字段数量为1时直接询问这个字段
if field_count == 1:
# 直接使用第一个字段作为标识和要询问的字段
field_name, field_value = selected_fields[0]
question = f"请告诉我{table_name}{field_value}是什么"
answer = f"{field_name}{field_value}"
qa_pairs.append({
"instruct": question,
"input": "",
"output": f"{self.config.get_random_element(answer_prefixes)}{answer}{self.config.get_random_element(answer_suffixes)}"
})
return qa_pairs
# 找到标识字段(固定使用字段中文名)和要询问的字段
main_field = None
query_fields = []
# 查找标识字段
identifier_fields = {
"element": "数据元素中文名",
"physical": "字段中文名",
"logical": "字段中文名"
}
identifier_field = identifier_fields.get(data_type)
# 查找标识字段的值
for field_name, field_value in selected_fields:
if field_name == identifier_field:
main_field = field_value
break
# 如果没找到标识字段,使用第一个字段作为主标识
if main_field is None:
if selected_fields:
main_field = selected_fields[0][1] # 使用第一个字段的值作为标识
# 收集要询问的字段(使用其他字段)
for field_name, field_value in selected_fields[1:]:
query_fields.append((field_name, field_value))
question_parts.append(field_name)
else:
return qa_pairs
else:
# 收集要询问的字段(排除标识字段)
for field_name, field_value in selected_fields:
if field_name != identifier_field:
query_fields.append((field_name, field_value))
question_parts.append(field_name)
# 如果没有要询问的字段,使用主字段本身作为询问内容
if not query_fields:
query_fields = [(identifier_field, main_field)]
question_parts = [identifier_field]
# 构建问题文本
if len(question_parts) == 1:
question = f"请告诉我{table_name}中字段中文名为:{main_field}{question_parts[0]}"
elif len(question_parts) == 2:
connector = self.config.get_random_element(connectors[:3])
question = f"请列举{table_name}中字段中文名为:{main_field}{question_parts[0]}{connector}{question_parts[1]}"
else:
connector1 = self.config.get_random_element(connectors[:3])
connector2 = self.config.get_random_element(connectors[3:])
question = f"请列举{table_name}中字段中文名为:{main_field}{question_parts[0]}{connector1}{question_parts[1]}{connector2}{question_parts[2]}"
# 构建答案
for field_name, field_value in query_fields:
answer_parts.append(f"{field_name}{field_value}")
if len(answer_parts) == 1:
answer = answer_parts[0]
elif len(answer_parts) == 2:
answer = f"{answer_parts[0]}{answer_parts[1]}"
else:
answer = "".join(answer_parts)
# 生成QA对
qa_pairs.append({
"instruct": question,
"input": "",
"output": f"{self.config.get_random_element(answer_prefixes)}{answer}{self.config.get_random_element(answer_suffixes)}"
})
return qa_pairs
def generate_qa_for_data(self, data: List[Dict], data_type: str) -> List[Dict]:
"""为指定数据类型生成QA"""
all_qa = []
for item in data:
# 生成单列QA
single_qa = self.generate_single_qa(item, self.config.SINGLE_TEMPLATES, data_type)
all_qa.extend(single_qa)
# 根据概率生成多列QA
if random.random() < self.config.MULTI_RATIO:
multi_qa = self.generate_multi_field_qa(item, 2, data_type) # 默认生成2列问题
all_qa.extend(multi_qa)
return all_qa
def shuffle_qa_pairs(self, qa_pairs: List[Dict]) -> List[Dict]:
"""随机打乱问答对顺序"""
if self.config.SHUFFLE_OUTPUT:
random.shuffle(qa_pairs)
return qa_pairs
def save_qa(self, qa_pairs: List[Dict], filename: str):
"""保存QA到文件"""
output_path = os.path.join(self.config.OUTPUT_DIR, filename)
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(qa_pairs, f, ensure_ascii=False, indent=2)
size_mb = os.path.getsize(output_path) / (1024 * 1024)
if self.config.VERBOSE_LOG:
print(f"[OK] 已生成: {output_path} (共 {len(qa_pairs)} 条问答对, {size_mb:.1f}MB)")
else:
print(f"[OK] 已生成: {output_path} (共 {len(qa_pairs)} 条问答对)")
def merge_to_train(self, output_file: str = "train.json") -> Dict:
"""合并QA文件为train.json"""
all_qa_pairs = []
file_stats = {}
total_files = 0
# 遍历输出目录中的所有QA文件
for filename in os.listdir(self.config.OUTPUT_DIR):
if filename.endswith('.json') and not filename.startswith('QA生成报告') and filename != 'train_stats.json':
file_path = os.path.join(self.config.OUTPUT_DIR, filename)
try:
with open(file_path, 'r', encoding='utf-8') as f:
qa_data = json.load(f)
if isinstance(qa_data, list) and all(isinstance(item, dict) and 'instruct' in item and 'output' in item for item in qa_data):
all_qa_pairs.extend(qa_data)
file_stats[filename] = len(qa_data)
total_files += 1
if self.config.VERBOSE_LOG:
print(f"[OK] 已合并: {filename} ({len(qa_data)} 条问答对)")
except Exception as e:
print(f"[ERROR] 读取文件失败 {filename}: {str(e)}")
# 打乱顺序
if self.config.SHUFFLE_OUTPUT and all_qa_pairs:
random.shuffle(all_qa_pairs)
if self.config.VERBOSE_LOG:
print(f"\n[INFO] 已打乱 {len(all_qa_pairs)} 条问答对顺序")
# 保存train.json
train_path = os.path.join(self.config.OUTPUT_DIR, output_file)
with open(train_path, 'w', encoding='utf-8') as f:
json.dump(all_qa_pairs, f, ensure_ascii=False, indent=2)
file_size_mb = os.path.getsize(train_path) / (1024 * 1024)
# 生成统计
stats = {
"合并时间": "2025-12-18",
"处理文件数": total_files,
"各文件问答对数量": file_stats,
"总问答对数量": len(all_qa_pairs),
"输出文件大小": f"{file_size_mb:.2f} MB",
"打乱顺序": self.config.SHUFFLE_OUTPUT
}
# 保存统计信息
stats_file = os.path.join(self.config.OUTPUT_DIR, "train_stats.json")
with open(stats_file, 'w', encoding='utf-8') as f:
json.dump(stats, f, ensure_ascii=False, indent=2)
print(f"\n[SUCCESS] 合并完成!")
print(f"[OUTPUT] {train_path}")
print(f"[TOTAL] 总计: {len(all_qa_pairs):,} 条问答对")
print(f"[SIZE] 文件大小: {file_size_mb:.2f} MB")
return stats
def generate_report(self, qa_counts: Dict[str, int]):
"""生成生成报告"""
if not self.config.GENERATE_REPORT:
return
report = {
"生成时间": "2025-12-18",
"版本": "整合版",
"配置信息": {
"复杂程度等级": self.config.COMPLEXITY_LEVEL,
"随机种子": self.config.RANDOM_SEED,
"多列查询占比": self.config.MULTI_COLUMN_RATIO,
"打乱输出": self.config.SHUFFLE_OUTPUT
},
"总文件数": len(qa_counts),
"各文件问答对数量": qa_counts,
"总计问答对数量": sum(qa_counts.values()),
"说明": "所有问答对均基于原始JSON数据生成未进行任何编撰或修改"
}
report_path = os.path.join(self.config.OUTPUT_DIR, "QA生成报告.json")
with open(report_path, 'w', encoding='utf-8') as f:
json.dump(report, f, ensure_ascii=False, indent=2)
if self.config.VERBOSE_LOG:
print(f"[OK] 已生成: {report_path}")
def generate_report_custom(self, qa_counts: Dict[str, int], generate_single: bool, generate_multi: bool, single_field_count: int, multi_field_count: int):
"""生成自定义配置的报告"""
if not self.config.GENERATE_REPORT:
return
report = {
"生成时间": "2025-12-19",
"版本": "整合版-自定义",
"自定义配置": {
"生成单列问题": generate_single,
"生成多列问题": generate_multi,
"单列问题字段数量": single_field_count,
"多列问题字段数量": multi_field_count,
"随机种子": self.config.RANDOM_SEED,
"打乱输出": self.config.SHUFFLE_OUTPUT
},
"总文件数": len(qa_counts),
"各文件问答对数量": qa_counts,
"总计问答对数量": sum(qa_counts.values()),
"说明": "所有问答对均基于用户自定义配置生成,灵活控制问题类型和字段数量"
}
report_path = os.path.join(self.config.OUTPUT_DIR, "QA生成报告.json")
with open(report_path, 'w', encoding='utf-8') as f:
json.dump(report, f, ensure_ascii=False, indent=2)
print(f"[OK] 已生成: {report_path}")
def process_all_custom(self, generate_single: bool, generate_multi: bool, single_field_count: int, multi_field_count: int):
"""自定义处理所有数据文件"""
# 清理输出目录中的旧文件
print("\n[INFO] 清理输出目录中的旧文件...")
if os.path.exists(self.config.OUTPUT_DIR):
for filename in os.listdir(self.config.OUTPUT_DIR):
file_path = os.path.join(self.config.OUTPUT_DIR, filename)
try:
if os.path.isfile(file_path):
os.remove(file_path)
if self.config.VERBOSE_LOG:
print(f" [DEL] 已删除: {filename}")
except Exception as e:
print(f" [WARN] 删除文件失败 {filename}: {str(e)}")
qa_counts = {}
for file_info in self.config.DATA_FILES:
if not file_info["enabled"]:
if self.config.VERBOSE_LOG:
print(f"[SKIP] 已跳过: {file_info['name']}")
continue
print(f"\n[INFO] 正在处理: {file_info['name']}")
input_file = os.path.join(self.config.INPUT_DIR, file_info["file"])
if not os.path.exists(input_file):
print(f"[WARNING] 文件不存在: {input_file}")
continue
try:
data = self.load_json(input_file)
# 根据文件名确定数据类型
if "元素治理" in file_info["name"]:
data_type = "element"
elif "物理模型" in file_info["name"]:
data_type = "physical"
elif "逻辑模型" in file_info["name"]:
data_type = "logical"
else:
print(f"[WARNING] 未知的文件类型: {file_info['name']}")
continue
# 使用自定义参数生成QA
qa_pairs = []
for item in data:
# 生成单列QA
if generate_single and single_field_count > 0:
single_qa = self.generate_single_qa(item, single_field_count, data_type)
qa_pairs.extend(single_qa)
# 生成多字段QA
if generate_multi and multi_field_count > 0:
multi_qa = self.generate_multi_field_qa(item, multi_field_count, data_type)
qa_pairs.extend(multi_qa)
qa_pairs = self.shuffle_qa_pairs(qa_pairs)
self.save_qa(qa_pairs, file_info["output"])
qa_counts[file_info["output"]] = len(qa_pairs)
except Exception as e:
print(f"[ERROR] 处理文件 {file_info['name']} 时出错: {str(e)}")
# 生成报告
if self.config.GENERATE_REPORT:
print("\n[INFO] 正在生成: 生成报告")
self.generate_report_custom(qa_counts, generate_single, generate_multi, single_field_count, multi_field_count)
print(f"\n[DONE] 所有文件处理完成!")
print(f"[OUT] 输出目录: {self.config.OUTPUT_DIR}")
print(f"[TOTAL] 总计生成: {sum(qa_counts.values())} 条问答对")
def process_all_simplified(self, fields_per_question: int, questions_per_item: int):
"""简化版处理 - 统一的问题生成方式"""
# 清理输出目录中的旧文件
print("\n[INFO] 清理输出目录中的旧文件...")
if os.path.exists(self.config.OUTPUT_DIR):
for filename in os.listdir(self.config.OUTPUT_DIR):
file_path = os.path.join(self.config.OUTPUT_DIR, filename)
try:
if os.path.isfile(file_path):
os.remove(file_path)
if self.config.VERBOSE_LOG:
print(f" [DEL] 已删除: {filename}")
except Exception as e:
print(f" [WARN] 删除文件失败 {filename}: {str(e)}")
qa_counts = {}
for file_info in self.config.DATA_FILES:
if not file_info["enabled"]:
if self.config.VERBOSE_LOG:
print(f"[SKIP] 已跳过: {file_info['name']}")
continue
print(f"\n[INFO] 正在处理: {file_info['name']}")
input_file = os.path.join(self.config.INPUT_DIR, file_info["file"])
if not os.path.exists(input_file):
print(f"[WARNING] 文件不存在: {input_file}")
continue
try:
data = self.load_json(input_file)
# 根据文件名确定数据类型
if "元素治理" in file_info["name"]:
data_type = "element"
elif "物理模型" in file_info["name"]:
data_type = "physical"
elif "逻辑模型" in file_info["name"]:
data_type = "logical"
else:
print(f"[WARNING] 未知的文件类型: {file_info['name']}")
continue
# 使用简化参数生成QA
qa_pairs = []
for item in data:
# 每个数据项生成指定数量的问题
for _ in range(questions_per_item):
# 每个问题询问指定数量的字段
multi_qa = self.generate_multi_field_qa(item, fields_per_question, data_type)
qa_pairs.extend(multi_qa)
qa_pairs = self.shuffle_qa_pairs(qa_pairs)
self.save_qa(qa_pairs, file_info["output"])
qa_counts[file_info["output"]] = len(qa_pairs)
except Exception as e:
print(f"[ERROR] 处理文件 {file_info['name']} 时出错: {str(e)}")
# 生成报告
if self.config.GENERATE_REPORT:
print("\n[INFO] 正在生成: 生成报告")
self.generate_report_simplified(qa_counts, fields_per_question, questions_per_item)
print(f"\n[DONE] 所有文件处理完成!")
print(f"[OUT] 输出目录: {self.config.OUTPUT_DIR}")
print(f"[TOTAL] 总计生成: {sum(qa_counts.values())} 条问答对")
def generate_report_simplified(self, qa_counts: Dict[str, int], fields_per_question: int, questions_per_item: int):
"""生成简化配置的报告"""
if not self.config.GENERATE_REPORT:
return
report = {
"生成时间": "2025-12-19",
"版本": "整合版-简化",
"简化配置": {
"每个问题询问字段数": fields_per_question,
"每个JSON数据项生成问题数": questions_per_item,
"随机种子": self.config.RANDOM_SEED,
"打乱输出": self.config.SHUFFLE_OUTPUT
},
"总文件数": len(qa_counts),
"各文件问答对数量": qa_counts,
"总计问答对数量": sum(qa_counts.values()),
"说明": "所有问答对均基于简化配置生成每个问题询问指定字段数每个JSON数据项生成指定问题数"
}
report_path = os.path.join(self.config.OUTPUT_DIR, "QA生成报告.json")
with open(report_path, 'w', encoding='utf-8') as f:
json.dump(report, f, ensure_ascii=False, indent=2)
print(f"[OK] 已生成: {report_path}")
def process_all_balanced(self, max_columns: int, distribution: List[int], comprehensive_count: int = 0):
"""平衡分配版处理 - 按比例分配不同列数的问题"""
# 清理输出目录中的旧文件
print("\n[INFO] 清理输出目录中的旧文件...")
if os.path.exists(self.config.OUTPUT_DIR):
for filename in os.listdir(self.config.OUTPUT_DIR):
file_path = os.path.join(self.config.OUTPUT_DIR, filename)
try:
if os.path.isfile(file_path):
os.remove(file_path)
if self.config.VERBOSE_LOG:
print(f" [DEL] 已删除: {filename}")
except Exception as e:
print(f" [WARN] 删除文件失败 {filename}: {str(e)}")
qa_counts = {}
for file_info in self.config.DATA_FILES:
if not file_info["enabled"]:
if self.config.VERBOSE_LOG:
print(f"[SKIP] 已跳过: {file_info['name']}")
continue
print(f"\n[INFO] 正在处理: {file_info['name']}")
input_file = os.path.join(self.config.INPUT_DIR, file_info["file"])
if not os.path.exists(input_file):
print(f"[WARNING] 文件不存在: {input_file}")
continue
try:
data = self.load_json(input_file)
# 根据文件名确定数据类型
if "元素治理" in file_info["name"]:
data_type = "element"
elif "物理模型" in file_info["name"]:
data_type = "physical"
elif "逻辑模型" in file_info["name"]:
data_type = "logical"
else:
print(f"[WARNING] 未知的文件类型: {file_info['name']}")
continue
# 使用平衡分配参数生成QA
qa_pairs = []
for item in data:
# 按照分配比例生成不同列数的问题
for col_count, question_count in enumerate(distribution, 1):
for _ in range(question_count):
# 生成指定列数的问题
multi_qa = self.generate_multi_field_qa(item, col_count, data_type)
qa_pairs.extend(multi_qa)
# 生成综合性问答
if comprehensive_count > 0:
for _ in range(comprehensive_count):
# 直接从item中获取表名不再使用file_info["name"]
comprehensive_qa = self.generate_comprehensive_qa(item, data_type)
qa_pairs.extend(comprehensive_qa)
qa_pairs = self.shuffle_qa_pairs(qa_pairs)
self.save_qa(qa_pairs, file_info["output"])
qa_counts[file_info["output"]] = len(qa_pairs)
except Exception as e:
print(f"[ERROR] 处理文件 {file_info['name']} 时出错: {str(e)}")
# 生成报告
if self.config.GENERATE_REPORT:
print("\n[INFO] 正在生成: 生成报告")
self.generate_report_balanced(qa_counts, max_columns, distribution, comprehensive_count)
print(f"\n[DONE] 所有文件处理完成!")
print(f"[OUT] 输出目录: {self.config.OUTPUT_DIR}")
print(f"[TOTAL] 总计生成: {sum(qa_counts.values())} 条问答对")
def generate_comprehensive_qa(self, item: Dict, data_type: str, table_name: str = None) -> List[Dict]:
"""生成综合性问答 - 询问所有字段的定义/含义"""
qa_pairs = []
answer_prefixes = self.config.ANSWER_PREFIXES
answer_suffixes = self.config.ANSWER_SUFFIXES
# 获取所有有值的字段
all_fields = {}
# 如果没有传入表名从item中获取否则使用默认值
if table_name is None:
table_name = item.get("表名", "远光数据架构表")
if data_type == "element":
# 元素治理模板的所有字段
field_mapping = {
"业务领域名称": item.get("业务领域名称"),
"数据元素中文名": item.get("数据元素中文名"),
"数据元素英文名": item.get("数据元素英文名"),
"值类型": item.get("值类型"),
"总长度": item.get("总长度"),
"小数位": item.get("小数位"),
"类别": item.get("类别"),
"是否枚举": item.get("是否枚举"),
"枚举数量": item.get("枚举数量"),
"抽象元素中文名": item.get("抽象元素中文名"),
"说明": item.get("说明"),
"是否上线": item.get("是否上线")
}
elif data_type == "physical":
# 物理模型的所有字段
field_mapping = {
"物理模型中文名": item.get("物理模型中文名"),
"物理模型英文名": item.get("物理模型英文名"),
"字段中文名": item.get("字段中文名"),
"物理模型属性英文名": item.get("物理模型属性英文名"),
"值类型": item.get("值类型"),
"长度": item.get("长度"),
"小数位": item.get("小数位"),
"关联数据元素": item.get("关联数据元素"),
"说明": item.get("说明")
}
elif data_type == "logical":
# 逻辑模型的所有字段
field_mapping = {
"业务领域": item.get("业务领域"),
"逻辑模型中文名": item.get("逻辑模型中文名"),
"逻辑模型英文名": item.get("逻辑模型英文名"),
"字段中文名": item.get("字段中文名"),
"字段英文名": item.get("字段英文名"),
"值类型": item.get("值类型"),
"长度": item.get("长度"),
"小数位": item.get("小数位"),
"动态查询能力": item.get("动态查询能力"),
"关联数据元素英文名": item.get("关联数据元素英文名")
}
# 筛选出有值的字段
for field_name, field_value in field_mapping.items():
if field_value is not None and field_value != "":
all_fields[field_name] = field_value
# 如果字段太少,跳过
if len(all_fields) < 2:
return qa_pairs
# 获取标识字段
identifier_fields = {
"element": "数据元素中文名",
"physical": "字段中文名",
"logical": "字段中文名"
}
identifier_field = identifier_fields.get(data_type)
identifier_value = all_fields.get(identifier_field)
if not identifier_value:
# 如果没有标识字段,使用第一个字段作为标识
identifier_value = list(all_fields.values())[0]
# 从字段列表中移除第一个字段
first_key = list(all_fields.keys())[0]
all_fields.pop(first_key)
else:
# 从字段列表中移除标识字段
all_fields.pop(identifier_field)
# 如果没有其他字段可询问,跳过
if not all_fields:
return qa_pairs
# 多种问法模板(加上表名进行区分)
question_templates = [
f"{table_name}中,请详细说明字段中文名为:{identifier_value}的定义和相关信息",
f"{table_name}中,字段中文名为:{identifier_value}是什么意思?请解释其含义和特征",
f"请全面介绍{table_name}中字段中文名为:{identifier_value}的概念和属性",
f"{table_name}中,字段中文名为:{identifier_value}是什么?请详细描述其特点",
f"请解释{table_name}中字段中文名为:{identifier_value}的定义、作用和属性",
f"{table_name}里,字段中文名为:{identifier_value}具体指什么?请提供详细信息",
f"请说明{table_name}中字段中文名为:{identifier_value}的含义、类型和相关属性",
f"{table_name}中,字段中文名为:{identifier_value}的概念和特征是什么?",
f"请详细介绍{table_name}中字段中文名为:{identifier_value}的定义和相关信息",
f"{table_name}中,字段中文名为:{identifier_value}是什么概念?请解释其属性和含义"
]
# 构建答案
answer_parts = []
for field_name, field_value in all_fields.items():
answer_parts.append(f"{field_name}{field_value}")
# 连接答案(使用分号分隔)
answer = "".join(answer_parts)
# 随机选择一个问法
question = self.config.get_random_element(question_templates)
# 生成QA对
qa_pairs.append({
"instruct": question,
"input": "",
"output": f"{self.config.get_random_element(answer_prefixes)}{answer}{self.config.get_random_element(answer_suffixes)}"
})
return qa_pairs
def generate_report_balanced(self, qa_counts: Dict[str, int], max_columns: int, distribution: List[int], comprehensive_count: int = 0):
"""生成平衡分配配置的报告"""
if not self.config.GENERATE_REPORT:
return
# 计算分配详情
allocation_detail = {}
for i, count in enumerate(distribution, 1):
if count > 0:
allocation_detail[f"{i}列问题"] = f"每个JSON数据项生成{count}"
# 准备配置信息
config_info = {
"最大列数": max_columns,
"问题分配详情": allocation_detail,
"随机种子": self.config.RANDOM_SEED,
"打乱输出": self.config.SHUFFLE_OUTPUT
}
# 如果有综合性问答,添加到配置中
if comprehensive_count > 0:
config_info["综合性问答"] = f"每个JSON数据项生成{comprehensive_count}"
report = {
"生成时间": "2025-12-19",
"版本": "整合版-平衡分配",
"平衡配置": config_info,
"总文件数": len(qa_counts),
"各文件问答对数量": qa_counts,
"总计问答对数量": sum(qa_counts.values()),
"说明": "所有问答对均基于平衡分配配置生成:按比例平均分配不同列数的问题,优先多列"
}
report_path = os.path.join(self.config.OUTPUT_DIR, "QA生成报告.json")
with open(report_path, 'w', encoding='utf-8') as f:
json.dump(report, f, ensure_ascii=False, indent=2)
print(f"[OK] 已生成: {report_path}")
def split_train_validation(self, train_ratio: float = 0.9, val_ratio: float = 0.1):
"""将数据分割为训练集和验证集"""
all_qa_pairs = []
# 读取所有QA文件
for filename in os.listdir(self.config.OUTPUT_DIR):
if filename.endswith('.json') and not filename.startswith('QA生成报告') and filename not in ['train.json', 'val.json', 'validation.json', 'train_stats.json']:
file_path = os.path.join(self.config.OUTPUT_DIR, filename)
try:
with open(file_path, 'r', encoding='utf-8') as f:
qa_data = json.load(f)
if isinstance(qa_data, list) and all(isinstance(item, dict) and 'instruct' in item and 'output' in item for item in qa_data):
all_qa_pairs.extend(qa_data)
except Exception as e:
print(f"[ERROR] 读取文件失败 {filename}: {str(e)}")
if not all_qa_pairs:
print("[WARNING] 没有找到可分割的QA数据")
return
# 打乱数据顺序
random.shuffle(all_qa_pairs)
# 计算分割点
total_count = len(all_qa_pairs)
train_count = int(total_count * train_ratio)
val_count = total_count - train_count
# 分割数据
train_data = all_qa_pairs[:train_count]
val_data = all_qa_pairs[train_count:]
# 保存训练集
train_path = os.path.join(self.config.OUTPUT_DIR, "train.json")
with open(train_path, 'w', encoding='utf-8') as f:
json.dump(train_data, f, ensure_ascii=False, indent=2)
# 保存验证集
val_path = os.path.join(self.config.OUTPUT_DIR, "val.json")
with open(val_path, 'w', encoding='utf-8') as f:
json.dump(val_data, f, ensure_ascii=False, indent=2)
# 计算文件大小
train_size_mb = os.path.getsize(train_path) / (1024 * 1024)
val_size_mb = os.path.getsize(val_path) / (1024 * 1024)
print(f"\n[SUCCESS] 数据分割完成!")
print(f"[TRAIN] 训练集: {train_path}")
print(f" - 数据量: {len(train_data):,}")
print(f" - 文件大小: {train_size_mb:.2f} MB")
print(f" - 比例: {train_ratio*100:.0f}%")
print(f"[VAL] 验证集: {val_path}")
print(f" - 数据量: {len(val_data):,}")
print(f" - 文件大小: {val_size_mb:.2f} MB")
print(f" - 比例: {val_ratio*100:.0f}%")
print(f"[TOTAL] 总计: {total_count:,}")
# 生成分割报告
split_report = {
"分割时间": "2025-12-19",
"分割比例": f"{train_ratio*100:.0f}% : {val_ratio*100:.0f}%",
"训练集": {
"文件路径": "train.json",
"数据量": len(train_data),
"文件大小": f"{train_size_mb:.2f} MB",
"占比": f"{train_ratio*100:.0f}%"
},
"验证集": {
"文件路径": "val.json",
"数据量": len(val_data),
"文件大小": f"{val_size_mb:.2f} MB",
"占比": f"{val_ratio*100:.0f}%"
},
"总计": {
"数据量": total_count,
"文件大小": f"{(train_size_mb + val_size_mb):.2f} MB"
}
}
split_report_path = os.path.join(self.config.OUTPUT_DIR, "数据分割报告.json")
with open(split_report_path, 'w', encoding='utf-8') as f:
json.dump(split_report, f, ensure_ascii=False, indent=2)
print(f"[OK] 已生成: {split_report_path}")
def process_all(self):
"""处理所有数据文件 - 兼容旧版本"""
# 清理输出目录中的旧文件
print("[INFO] 清理输出目录中的旧文件...")
if os.path.exists(self.config.OUTPUT_DIR):
for filename in os.listdir(self.config.OUTPUT_DIR):
file_path = os.path.join(self.config.OUTPUT_DIR, filename)
try:
if os.path.isfile(file_path):
os.remove(file_path)
if self.config.VERBOSE_LOG:
print(f" [DEL] 已删除: {filename}")
except Exception as e:
print(f" [WARN] 删除文件失败 {filename}: {str(e)}")
qa_counts = {}
for file_info in self.config.DATA_FILES:
if not file_info["enabled"]:
if self.config.VERBOSE_LOG:
print(f"[SKIP] 已跳过: {file_info['name']}")
continue
if self.config.VERBOSE_LOG:
print(f"\n[INFO] 正在处理: {file_info['name']}.json")
else:
print(f"[INFO] 正在处理: {file_info['name']}")
input_file = os.path.join(self.config.INPUT_DIR, file_info["file"])
if not os.path.exists(input_file):
print(f"[WARNING] 文件不存在: {input_file}")
continue
try:
data = self.load_json(input_file)
# 根据文件名确定数据类型
if "元素治理" in file_info["name"]:
data_type = "element"
elif "物理模型" in file_info["name"]:
data_type = "physical"
elif "逻辑模型" in file_info["name"]:
data_type = "logical"
else:
print(f"[WARNING] 未知的文件类型: {file_info['name']}")
continue
qa_pairs = self.generate_qa_for_data(data, data_type)
qa_pairs = self.shuffle_qa_pairs(qa_pairs)
self.save_qa(qa_pairs, file_info["output"])
qa_counts[file_info["output"]] = len(qa_pairs)
except Exception as e:
print(f"[ERROR] 处理文件 {file_info['name']} 时出错: {str(e)}")
# 生成报告
if self.config.GENERATE_REPORT:
if self.config.VERBOSE_LOG:
print("\n[INFO] 正在生成: 生成报告")
else:
print("\n[INFO] 正在生成: 生成报告")
self.generate_report(qa_counts)
if self.config.VERBOSE_LOG:
print(f"\n[DONE] 所有文件处理完成!")
print(f"[OUT] 输出目录: {self.config.OUTPUT_DIR}")
print(f"[TOTAL] 总计生成: {sum(qa_counts.values())} 条问答对")
else:
print("\n[DONE] 所有文件处理完成!")
# 预设配置
SIMPLE_CONFIG = QAConfig()
SIMPLE_CONFIG.COMPLEXITY_LEVEL = 1
SIMPLE_CONFIG._init_templates()
NORMAL_CONFIG = QAConfig()
NORMAL_CONFIG.COMPLEXITY_LEVEL = 3
NORMAL_CONFIG._init_templates()
COMPLEX_CONFIG = QAConfig()
COMPLEX_CONFIG.COMPLEXITY_LEVEL = 5
COMPLEX_CONFIG._init_templates()
def main():
"""主函数 - 交互式运行"""
print("="*60)
print("QA生成器 - 整合版")
print("="*60)
# 交互式设置生成参数
print("\n" + "="*60)
print("请自定义生成参数:")
print("="*60)
# 设置最大列数
print("\n1. 最大列数设置:")
print(" 指定最多生成几列内容的问题支持1列到N列平均分配")
max_columns = 1
max_columns_input = input(f"最多生成几列内容? (默认1): ").strip()
if max_columns_input:
max_columns = max(1, int(max_columns_input))
print(f" 最多生成 {max_columns} 列内容的问题")
# 设置问题数量
print("\n2. 问题数量设置:")
print(" 指定每个JSON数据项生成多少个问题")
questions_per_item = 1
questions_input = input(f"每个JSON数据项生成几个问题? (默认1): ").strip()
if questions_input:
questions_per_item = max(1, int(questions_input))
print(f" 每个JSON数据项将生成 {questions_per_item} 个问题")
# 设置是否生成综合性问答
print("\n3. 综合性问答设置:")
print(" 综合性问答指一个问题询问所有字段的定义/含义")
print(" 例如:实际过账成本的成本中心的定义是什么?")
comprehensive_choice = input("是否添加综合性问答? (1=添加, 0=不添加, 默认0): ").strip()
comprehensive_count = 0
if comprehensive_choice == '1':
comprehensive_count = input("每个JSON数据项生成几个综合性问题? (默认1): ").strip()
comprehensive_count = max(1, int(comprehensive_count)) if comprehensive_count else 1
print(f" 每个JSON数据项将额外生成 {comprehensive_count} 个综合性问题")
else:
print(" 不生成综合性问题")
# 计算列数分配
distribution = []
base_count = questions_per_item // max_columns
remainder = questions_per_item % max_columns
# 优先分配给高列数
# 先给所有列分配基础数量
for i in range(max_columns):
distribution.append(base_count)
# 剩余的优先分配给高列数
for i in range(remainder):
distribution[-(1 + i)] += 1
# 显示分配信息
print("\n" + "="*60)
print("问题类型分配:")
print("="*60)
for i, count in enumerate(distribution, 1):
if count > 0:
print(f"[OK] {i}列问题: 每个JSON数据项生成 {count}")
# 创建自定义配置
config = QAConfig()
config.SINGLE_TEMPLATES = 1
config.MULTI_TEMPLATES = 1
config.MULTI_RATIO = 1.0 # 总是生成问题
# 显示配置信息
print("\n" + "="*60)
print("当前配置:")
print("="*60)
print(f"[OK] 最大列数: {max_columns}")
print(f"[OK] 每个JSON数据项生成问题数: {questions_per_item}")
if comprehensive_count > 0:
print(f"[OK] 综合性问题: {comprehensive_count}")
print(f"[OK] 输出目录: {config.OUTPUT_DIR}")
# 创建生成器并处理
generator = QAGenerator(config)
generator.process_all_balanced(max_columns, distribution, comprehensive_count)
# 询问是否合并为train.json
merge_choice = input("\n是否合并为train.json? (y/n): ").strip().lower()
if merge_choice == 'y':
generator.merge_to_train()
# 询问是否分割训练集和验证集
split_choice = input("\n是否将数据分割为训练集和验证集 (9:1)? (y/n): ").strip().lower()
if split_choice == 'y':
generator.split_train_validation(train_ratio=0.9, val_ratio=0.1)
if __name__ == "__main__":
main()