#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
QA generator - consolidated edition.

All functionality is merged into this single file to reduce redundancy.
"""

import json
import os
import random
from typing import List, Dict, Any


class QAConfig:
    """Configuration for QA generation."""

    def __init__(self):
        # Basic settings
        self.RANDOM_SEED = 42
        self.INPUT_DIR = "Data_Export_Json"
        self.OUTPUT_DIR = "Data_QA_Outputs"

        # Complexity level (1-5)
        self.COMPLEXITY_LEVEL = 3

        # Question count controls
        self.BASIC_QUESTIONS_PER_ITEM = 1
        self.MAX_QUESTIONS_PER_ITEM = 10
        self.MULTI_COLUMN_RATIO = 0.3

        # Output controls
        self.SHUFFLE_OUTPUT = True
        self.GENERATE_REPORT = True
        self.VERBOSE_LOG = True

        # Data file configuration
        self.DATA_FILES = [
            {"name": "元素治理模板", "file": "远光数据架构元素治理模板表.json", "output": "远光数据架构元素治理模板表.json", "enabled": True},
            {"name": "物理模型", "file": "远光数据架构物理模型表.json", "output": "远光数据架构物理模型表.json", "enabled": True},
            {"name": "逻辑模型", "file": "远光数据架构逻辑模型表.json", "output": "远光数据架构逻辑模型表.json", "enabled": True}
        ]

        # Initialize the question/answer templates, modifiers and connectors
        self._init_templates()

    def _init_templates(self):
        """Initialize the template lists according to COMPLEXITY_LEVEL."""
        if self.COMPLEXITY_LEVEL <= 2:
            # Simple mode
            self.QUESTION_PREFIXES = ["请告诉我", "查询", "请问"]
            self.ANSWER_PREFIXES = ["根据表记录,该字段的", "查询结果显示,", "经查询,该字段的"]
            self.ANSWER_SUFFIXES = ["。", "。"]
            self.CONNECTORS = ["和", "与"]
            self.SINGLE_TEMPLATES = 6 if self.COMPLEXITY_LEVEL == 2 else 3
            self.MULTI_TEMPLATES = 1 if self.COMPLEXITY_LEVEL == 2 else 0
            self.MULTI_RATIO = 0.1 if self.COMPLEXITY_LEVEL == 2 else 0.0
        else:
            # Normal / complex mode
            self.QUESTION_PREFIXES = ["请告诉我", "查询", "请问", "在", "请解释", "请输出", "请列举", "请说明", "请查找", "请确认"]
            self.ANSWER_PREFIXES = ["根据表记录,该字段的", "查询结果显示,", "经查询,该字段的", "根据数据库记录,", "在表中,此字段的", "查询结果:", "经系统查询,", "根据记录显示,", "在数据中,该字段的", "查询得知,该字段的"]
            self.ANSWER_SUFFIXES = ["。", ",请参考。", ",详情如上。", ",以上信息为准。", ",望知悉。", ",如需更多信息请联系。", ",希望能帮到您。", ",祝您工作顺利。", ",谢谢。", "。"]
            self.CONNECTORS = ["和", "与", "及", "、", ",还有", "以及"]
            self.SINGLE_TEMPLATES = 12
            self.MULTI_TEMPLATES = 5
            self.MULTI_RATIO = self.MULTI_COLUMN_RATIO

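    # Illustrative note on the branches above (not executed): COMPLEXITY_LEVEL 1 yields
    # 3 single-column templates and no multi-column questions, level 2 yields 6
    # single-column templates plus a 10% multi-column ratio, and levels 3-5 share the
    # richer template set with MULTI_RATIO taken from MULTI_COLUMN_RATIO.
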
    def get_random_element(self, elements: List[str]) -> str:
        """Return a random element from the list, or an empty string if it is empty."""
        return random.choice(elements) if elements else ""

    def update_config(self, **kwargs):
        """Update configuration attributes from keyword arguments."""
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
                if key == 'COMPLEXITY_LEVEL':
                    self._init_templates()  # re-initialize the templates

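    # Usage sketch for update_config (illustrative, not executed here):
    #     cfg = QAConfig()
    #     cfg.update_config(COMPLEXITY_LEVEL=2, OUTPUT_DIR="Data_QA_Outputs")
    #     # setting COMPLEXITY_LEVEL re-runs _init_templates(), so cfg.SINGLE_TEMPLATES == 6
    # Unknown keys are silently ignored because of the hasattr() guard.
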
    def print_config(self):
        """Print the current configuration."""
        print(f"\n[INFO] 复杂程度等级: {self.COMPLEXITY_LEVEL}")
        print(f"[INFO] 单列模板数: {self.SINGLE_TEMPLATES}")
        print(f"[INFO] 多列模板数: {self.MULTI_TEMPLATES}")
        print(f"[INFO] 多列占比: {self.MULTI_RATIO}")
        print(f"[INFO] 输出目录: {self.OUTPUT_DIR}")

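# Configuration usage sketch (illustrative): the three preset instances defined near
# the bottom of this file (SIMPLE_CONFIG / NORMAL_CONFIG / COMPLEX_CONFIG) are built
# exactly this way -- create a QAConfig, set COMPLEXITY_LEVEL, then call
# _init_templates() so the template lists match the new level.
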
class QAGenerator:
    """Consolidated QA generator."""

    def __init__(self, config: QAConfig = None):
        """Initialize the generator, create the output directory and seed the RNG."""
        self.config = config or QAConfig()
        os.makedirs(self.config.OUTPUT_DIR, exist_ok=True)
        random.seed(self.config.RANDOM_SEED)

    def load_json(self, file_path: str) -> List[Dict]:
        """Load a JSON file and return its content."""
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def generate_single_qa(self, item: Dict, template_count: int, data_type: str) -> List[Dict]:
        """Generate single-column QA pairs, asking strictly by the item's Chinese name field."""
        qa_pairs = []
        answer_prefixes = self.config.ANSWER_PREFIXES
        answer_suffixes = self.config.ANSWER_SUFFIXES

        if data_type == "element":
            # Element governance template - keyed strictly on "数据元素中文名"
            templates = []
            table_name = item.get("表名", "远光数据架构元素治理模板表")

            # Keep only templates that use 数据元素中文名 as the identifier
            if item.get("数据元素中文名") and item.get("业务领域名称"):
                templates.append((f"在{table_name}中数据元素中文名为:{item['数据元素中文名']}属于哪个业务领域?", f"业务领域:{item['业务领域名称']}"))
            if item.get("数据元素中文名") and item.get("值类型"):
                templates.append((f"查询{table_name}中数据元素中文名为:{item['数据元素中文名']}的值类型是什么?", f"值类型:{item['值类型']}"))
            if item.get("数据元素中文名") and item.get("总长度"):
                templates.append((f"在{table_name}中数据元素中文名为:{item['数据元素中文名']}的总长度设置是多少?", f"总长度:{item['总长度']}"))
            if item.get("数据元素中文名") and item.get("类别"):
                templates.append((f"请确认在{table_name}中数据元素中文名为:{item['数据元素中文名']}属于哪个类别?", f"类别:{item['类别']}"))
            if item.get("数据元素中文名") and item.get("数据元素英文名"):
                templates.append((f"在{table_name}中数据元素中文名为:{item['数据元素中文名']}对应的英文名是什么?", f"英文名:{item['数据元素英文名']}"))
            if item.get("数据元素中文名") and item.get("是否枚举"):
                templates.append((f"在{table_name}中数据元素中文名为:{item['数据元素中文名']}是否枚举?", f"是否枚举:{item['是否枚举']}"))
            if item.get("数据元素中文名") and item.get("枚举数量"):
                templates.append((f"请问在{table_name}中数据元素中文名为:{item['数据元素中文名']}的枚举数量是多少?", f"枚举数量:{item['枚举数量']}"))
            if item.get("数据元素中文名") and item.get("小数位"):
                templates.append((f"在{table_name}中数据元素中文名为:{item['数据元素中文名']}的小数位设置是多少?", f"小数位:{item['小数位']}"))
            if item.get("数据元素中文名") and item.get("抽象元素中文名"):
                templates.append((f"在{table_name}中数据元素中文名为:{item['数据元素中文名']}的抽象元素中文名是什么?", f"抽象元素中文名:{item['抽象元素中文名']}"))
            if item.get("数据元素中文名") and item.get("说明"):
                templates.append((f"请解释在{table_name}中数据元素中文名为:{item['数据元素中文名']}的作用和含义", f"说明:{item['说明']}"))
            if item.get("数据元素中文名") and item.get("是否上线"):
                templates.append((f"请问在{table_name}中数据元素中文名为:{item['数据元素中文名']}是否已上线?", f"是否上线:{item['是否上线']}"))

            # Build the QA pairs
            for question, answer in templates[:template_count]:
                qa_pairs.append({
                    "instruct": question,
                    "input": "",
                    "output": f"{self.config.get_random_element(answer_prefixes)}{answer}{self.config.get_random_element(answer_suffixes)}"
                })

        elif data_type == "physical":
            # Physical model - keyed strictly on "字段中文名"
            table_name = item.get("表名", "远光数据架构物理模型表")
            field_name = item.get("字段中文名")
            templates = []

            # 字段中文名 is the primary subject of every question
            if field_name and item.get("值类型"):
                templates.append((f"请问在{table_name}中字段中文名为:{field_name}的值类型是什么?", f"值类型:{item['值类型']}"))
            if field_name and item.get("长度"):
                templates.append((f"请问在{table_name}中字段中文名为:{field_name}的长度是多少?", f"长度:{item['长度']}"))
            if field_name and item.get("小数位") is not None:
                templates.append((f"请问在{table_name}中字段中文名为:{field_name}的小数位设置是多少?", f"小数位:{item['小数位']}"))
            if field_name and item.get("关联数据元素"):
                templates.append((f"请问在{table_name}中字段中文名为:{field_name}关联的数据元素是什么?", f"关联数据元素:{item['关联数据元素']}"))
            if field_name and item.get("物理模型中文名"):
                templates.append((f"请问在{table_name}中字段中文名为:{field_name}属于哪个物理模型?", f"物理模型:{item['物理模型中文名']}"))
            if field_name and item.get("说明"):
                templates.append((f"请问在{table_name}中字段中文名为:{field_name}的说明是什么?", f"说明:{item['说明']}"))
            if field_name and item.get("物理模型属性英文名"):
                templates.append((f"请问在{table_name}中字段中文名为:{field_name}对应的英文名是什么?", f"英文名:{item['物理模型属性英文名']}"))

            # Build the QA pairs
            for question, answer in templates[:template_count]:
                qa_pairs.append({
                    "instruct": question,
                    "input": "",
                    "output": f"{self.config.get_random_element(answer_prefixes)}{answer}{self.config.get_random_element(answer_suffixes)}"
                })

        elif data_type == "logical":
            # Logical model - keyed strictly on "字段中文名"
            table_name = item.get("表名", "远光数据架构逻辑模型表")
            templates = []

            # Keep only templates that use 字段中文名 as the identifier
            if item.get("字段中文名") and item.get("业务领域"):
                templates.append((f"在{table_name}中字段中文名为:{item['字段中文名']}属于哪个业务领域?", f"业务领域:{item['业务领域']}"))
            if item.get("字段中文名") and item.get("逻辑模型中文名"):
                templates.append((f"在{table_name}中字段中文名为:{item['字段中文名']}属于哪个逻辑模型?", f"逻辑模型:{item['逻辑模型中文名']}"))
            if item.get("字段中文名") and item.get("字段英文名"):
                templates.append((f"在{table_name}中字段中文名为:{item['字段中文名']}对应的英文名是什么?", f"英文名:{item['字段英文名']}"))
            if item.get("字段中文名") and item.get("值类型"):
                templates.append((f"请问在{table_name}中字段中文名为:{item['字段中文名']}的值类型是什么?", f"值类型:{item['值类型']}"))
            if item.get("字段中文名") and item.get("长度"):
                templates.append((f"查询{table_name}中字段中文名为:{item['字段中文名']}的长度是多少?", f"长度:{item['长度']}"))
            if item.get("字段中文名") and item.get("小数位") is not None:
                templates.append((f"请确认在{table_name}中字段中文名为:{item['字段中文名']}的小数位设置", f"小数位:{item['小数位']}"))
            if item.get("字段中文名") and item.get("动态查询能力"):
                templates.append((f"在{table_name}中字段中文名为:{item['字段中文名']}的动态查询能力是什么级别?", f"动态查询能力:{item['动态查询能力']}"))
            if item.get("字段中文名") and item.get("关联数据元素英文名"):
                templates.append((f"在{table_name}中字段中文名为:{item['字段中文名']}关联的数据元素英文名是什么?", f"关联数据元素英文名:{item['关联数据元素英文名']}"))

            # Build the QA pairs
            for question, answer in templates[:template_count]:
                qa_pairs.append({
                    "instruct": question,
                    "input": "",
                    "output": f"{self.config.get_random_element(answer_prefixes)}{answer}{self.config.get_random_element(answer_suffixes)}"
                })

        return qa_pairs

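    # Output shape sketch for generate_single_qa (illustrative values, assuming an
    # element item whose 数据元素中文名 is "成本中心" and whose 值类型 is "字符"):
    #     {"instruct": "查询远光数据架构元素治理模板表中数据元素中文名为:成本中心的值类型是什么?",
    #      "input": "",
    #      "output": "查询结果显示,值类型:字符。"}
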
    def generate_multi_field_qa(self, item: Dict, field_count: int, data_type: str) -> List[Dict]:
        """Generate a multi-field QA pair by dynamically picking the requested number of fields."""
        qa_pairs = []
        answer_prefixes = self.config.ANSWER_PREFIXES
        answer_suffixes = self.config.ANSWER_SUFFIXES
        connectors = self.config.CONNECTORS

        # Collect the fields that can be asked about
        available_fields = []
        field_mapping: Dict[str, Any] = {}  # stays empty for unknown data types
        if data_type == "element":
            # Queryable fields of the element governance template
            field_mapping = {
                "业务领域名称": item.get("业务领域名称"),
                "数据元素中文名": item.get("数据元素中文名"),
                "数据元素英文名": item.get("数据元素英文名"),
                "值类型": item.get("值类型"),
                "总长度": item.get("总长度"),
                "小数位": item.get("小数位"),
                "类别": item.get("类别"),
                "是否枚举": item.get("是否枚举"),
                "枚举数量": item.get("枚举数量"),
                "抽象元素中文名": item.get("抽象元素中文名"),
                "说明": item.get("说明"),
                "是否上线": item.get("是否上线")
            }
        elif data_type == "physical":
            # Queryable fields of the physical model
            field_mapping = {
                "物理模型中文名": item.get("物理模型中文名"),
                "物理模型英文名": item.get("物理模型英文名"),
                "字段中文名": item.get("字段中文名"),
                "物理模型属性英文名": item.get("物理模型属性英文名"),
                "值类型": item.get("值类型"),
                "长度": item.get("长度"),
                "小数位": item.get("小数位"),
                "关联数据元素": item.get("关联数据元素"),
                "说明": item.get("说明")
            }
        elif data_type == "logical":
            # Queryable fields of the logical model
            field_mapping = {
                "业务领域": item.get("业务领域"),
                "逻辑模型中文名": item.get("逻辑模型中文名"),
                "逻辑模型英文名": item.get("逻辑模型英文名"),
                "字段中文名": item.get("字段中文名"),
                "字段英文名": item.get("字段英文名"),
                "值类型": item.get("值类型"),
                "长度": item.get("长度"),
                "小数位": item.get("小数位"),
                "动态查询能力": item.get("动态查询能力"),
                "关联数据元素英文名": item.get("关联数据元素英文名")
            }

        # Keep only the fields that actually have a value
        for field_name, field_value in field_mapping.items():
            if field_value is not None and field_value != "":
                available_fields.append((field_name, field_value))

        # If fewer fields are available than requested, use all of them
        if len(available_fields) < field_count:
            selected_fields = available_fields
        else:
            # Prefer combinations that include the identifier field
            identifier_fields = {
                "element": "数据元素中文名",
                "physical": "字段中文名",
                "logical": "字段中文名"
            }

            identifier_field = identifier_fields.get(data_type)
            identifier_value = None

            # Look up the identifier field's value
            for field_name, field_value in available_fields:
                if field_name == identifier_field:
                    identifier_value = (field_name, field_value)
                    break

            if identifier_value:
                # Build a field list that always contains the identifier field
                other_fields = [f for f in available_fields if f[0] != identifier_field]

                # If there are enough other fields, sample the remainder at random
                if len(other_fields) >= field_count - 1:
                    selected_others = random.sample(other_fields, field_count - 1)
                    selected_fields = [identifier_value] + selected_others
                else:
                    # Otherwise include every other field
                    selected_fields = [identifier_value] + other_fields
            else:
                # No identifier field found - pick fields at random
                selected_fields = random.sample(available_fields, field_count)

        # Build the question
        question_parts = []
        answer_parts = []

        table_name = item.get("表名", "远光数据架构表")

        # With a single field, ask about that field directly
        if field_count == 1:
            # The first field serves both as the identifier and as the field being asked about
            field_name, field_value = selected_fields[0]
            question = f"请告诉我{table_name}中{field_value}是什么"
            answer = f"{field_name}:{field_value}"

            qa_pairs.append({
                "instruct": question,
                "input": "",
                "output": f"{self.config.get_random_element(answer_prefixes)}{answer}{self.config.get_random_element(answer_suffixes)}"
            })

            return qa_pairs

        # Determine the identifier field (the Chinese name column) and the fields to ask about
        main_field = None
        query_fields = []

        # Identifier field per data type
        identifier_fields = {
            "element": "数据元素中文名",
            "physical": "字段中文名",
            "logical": "字段中文名"
        }

        identifier_field = identifier_fields.get(data_type)

        # Look up the identifier field's value
        for field_name, field_value in selected_fields:
            if field_name == identifier_field:
                main_field = field_value
                break

        # If the identifier field is missing, fall back to the first selected field
        if main_field is None:
            if selected_fields:
                main_field = selected_fields[0][1]  # use the first field's value as the identifier
                # Ask about the remaining fields
                for field_name, field_value in selected_fields[1:]:
                    query_fields.append((field_name, field_value))
                    question_parts.append(field_name)
            else:
                return qa_pairs
        else:
            # Ask about every selected field except the identifier
            for field_name, field_value in selected_fields:
                if field_name != identifier_field:
                    query_fields.append((field_name, field_value))
                    question_parts.append(field_name)

        # If nothing is left to ask about, ask about the identifier field itself
        if not query_fields:
            query_fields = [(identifier_field, main_field)]
            question_parts = [identifier_field]

        # Build the question text
        if len(question_parts) == 1:
            question = f"请告诉我{table_name}中字段中文名为:{main_field}的{question_parts[0]}"
        elif len(question_parts) == 2:
            connector = self.config.get_random_element(connectors[:3])
            question = f"请列举{table_name}中字段中文名为:{main_field}的{question_parts[0]}{connector}{question_parts[1]}"
        else:
            connector1 = self.config.get_random_element(connectors[:3])
            connector2 = self.config.get_random_element(connectors[3:])
            question = f"请列举{table_name}中字段中文名为:{main_field}的{question_parts[0]}{connector1}{question_parts[1]}{connector2}{question_parts[2]}"

        # Build the answer
        for field_name, field_value in query_fields:
            answer_parts.append(f"{field_name}:{field_value}")

        if len(answer_parts) == 1:
            answer = answer_parts[0]
        elif len(answer_parts) == 2:
            answer = f"{answer_parts[0]};{answer_parts[1]}"
        else:
            answer = ";".join(answer_parts)

        # Assemble the QA pair
        qa_pairs.append({
            "instruct": question,
            "input": "",
            "output": f"{self.config.get_random_element(answer_prefixes)}{answer}{self.config.get_random_element(answer_suffixes)}"
        })

        return qa_pairs

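    # Construction sketch for generate_multi_field_qa (illustrative, assuming a
    # logical-model item whose 字段中文名 is "成本中心" and whose 表名 falls back to the
    # default "远光数据架构表"): field_count=3 keeps the identifier plus two sampled
    # fields, e.g. 值类型 and 长度, producing a question such as
    #     "请列举远光数据架构表中字段中文名为:成本中心的值类型和长度"
    # and an answer like "值类型:字符;长度:10" before the random prefix/suffix are added;
    # the connector is drawn from CONNECTORS at random.
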
    def generate_qa_for_data(self, data: List[Dict], data_type: str) -> List[Dict]:
        """Generate QA pairs for a whole dataset of the given type."""
        all_qa = []

        for item in data:
            # Single-column QA
            single_qa = self.generate_single_qa(item, self.config.SINGLE_TEMPLATES, data_type)
            all_qa.extend(single_qa)

            # Multi-column QA, generated with probability MULTI_RATIO
            if random.random() < self.config.MULTI_RATIO:
                multi_qa = self.generate_multi_field_qa(item, 2, data_type)  # default: 2-column questions
                all_qa.extend(multi_qa)

        return all_qa

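    # Sampling note (illustrative): with the default NORMAL_CONFIG (MULTI_RATIO = 0.3),
    # roughly 3 out of every 10 items also receive one extra 2-column question on top
    # of their single-column questions; the draw is independent for each item.
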
    def shuffle_qa_pairs(self, qa_pairs: List[Dict]) -> List[Dict]:
        """Shuffle the QA pairs in place when SHUFFLE_OUTPUT is enabled."""
        if self.config.SHUFFLE_OUTPUT:
            random.shuffle(qa_pairs)
        return qa_pairs

    def save_qa(self, qa_pairs: List[Dict], filename: str):
        """Save QA pairs to a JSON file in the output directory."""
        output_path = os.path.join(self.config.OUTPUT_DIR, filename)
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(qa_pairs, f, ensure_ascii=False, indent=2)

        size_mb = os.path.getsize(output_path) / (1024 * 1024)
        if self.config.VERBOSE_LOG:
            print(f"[OK] 已生成: {output_path} (共 {len(qa_pairs)} 条问答对, {size_mb:.1f}MB)")
        else:
            print(f"[OK] 已生成: {output_path} (共 {len(qa_pairs)} 条问答对)")

    def merge_to_train(self, output_file: str = "train.json") -> Dict:
        """Merge every generated QA file in the output directory into train.json."""
        all_qa_pairs = []
        file_stats = {}
        total_files = 0

        # Walk every QA file in the output directory; skip the report/stats files and any
        # previously merged or split outputs so that a re-run does not double-count pairs
        for filename in os.listdir(self.config.OUTPUT_DIR):
            if filename.endswith('.json') and not filename.startswith('QA生成报告') \
                    and filename not in (output_file, 'val.json', 'validation.json', 'train_stats.json'):
                file_path = os.path.join(self.config.OUTPUT_DIR, filename)

                try:
                    with open(file_path, 'r', encoding='utf-8') as f:
                        qa_data = json.load(f)

                    if isinstance(qa_data, list) and all(isinstance(item, dict) and 'instruct' in item and 'output' in item for item in qa_data):
                        all_qa_pairs.extend(qa_data)
                        file_stats[filename] = len(qa_data)
                        total_files += 1
                        if self.config.VERBOSE_LOG:
                            print(f"[OK] 已合并: {filename} ({len(qa_data)} 条问答对)")

                except Exception as e:
                    print(f"[ERROR] 读取文件失败 {filename}: {str(e)}")

        # Shuffle the merged pairs
        if self.config.SHUFFLE_OUTPUT and all_qa_pairs:
            random.shuffle(all_qa_pairs)
            if self.config.VERBOSE_LOG:
                print(f"\n[INFO] 已打乱 {len(all_qa_pairs)} 条问答对顺序")

        # Save train.json
        train_path = os.path.join(self.config.OUTPUT_DIR, output_file)
        with open(train_path, 'w', encoding='utf-8') as f:
            json.dump(all_qa_pairs, f, ensure_ascii=False, indent=2)

        file_size_mb = os.path.getsize(train_path) / (1024 * 1024)

        # Build the statistics
        stats = {
            "合并时间": "2025-12-18",
            "处理文件数": total_files,
            "各文件问答对数量": file_stats,
            "总问答对数量": len(all_qa_pairs),
            "输出文件大小": f"{file_size_mb:.2f} MB",
            "打乱顺序": self.config.SHUFFLE_OUTPUT
        }

        # Save the statistics
        stats_file = os.path.join(self.config.OUTPUT_DIR, "train_stats.json")
        with open(stats_file, 'w', encoding='utf-8') as f:
            json.dump(stats, f, ensure_ascii=False, indent=2)

        print(f"\n[SUCCESS] 合并完成!")
        print(f"[OUTPUT] {train_path}")
        print(f"[TOTAL] 总计: {len(all_qa_pairs):,} 条问答对")
        print(f"[SIZE] 文件大小: {file_size_mb:.2f} MB")

        return stats

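    # Output sketch for merge_to_train (illustrative counts): after a merge, OUTPUT_DIR
    # holds train.json with all merged pairs plus train_stats.json along the lines of
    #     {"处理文件数": 3, "总问答对数量": 12345, "打乱顺序": true, ...}
    # where the per-file counts appear under "各文件问答对数量".
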
    def generate_report(self, qa_counts: Dict[str, int]):
        """Write the generation report."""
        if not self.config.GENERATE_REPORT:
            return

        report = {
            "生成时间": "2025-12-18",
            "版本": "整合版",
            "配置信息": {
                "复杂程度等级": self.config.COMPLEXITY_LEVEL,
                "随机种子": self.config.RANDOM_SEED,
                "多列查询占比": self.config.MULTI_COLUMN_RATIO,
                "打乱输出": self.config.SHUFFLE_OUTPUT
            },
            "总文件数": len(qa_counts),
            "各文件问答对数量": qa_counts,
            "总计问答对数量": sum(qa_counts.values()),
            "说明": "所有问答对均基于原始JSON数据生成,未进行任何编撰或修改"
        }

        report_path = os.path.join(self.config.OUTPUT_DIR, "QA生成报告.json")
        with open(report_path, 'w', encoding='utf-8') as f:
            json.dump(report, f, ensure_ascii=False, indent=2)

        if self.config.VERBOSE_LOG:
            print(f"[OK] 已生成: {report_path}")

    def generate_report_custom(self, qa_counts: Dict[str, int], generate_single: bool, generate_multi: bool, single_field_count: int, multi_field_count: int):
        """Write the report for a custom-configured run."""
        if not self.config.GENERATE_REPORT:
            return

        report = {
            "生成时间": "2025-12-19",
            "版本": "整合版-自定义",
            "自定义配置": {
                "生成单列问题": generate_single,
                "生成多列问题": generate_multi,
                "单列问题字段数量": single_field_count,
                "多列问题字段数量": multi_field_count,
                "随机种子": self.config.RANDOM_SEED,
                "打乱输出": self.config.SHUFFLE_OUTPUT
            },
            "总文件数": len(qa_counts),
            "各文件问答对数量": qa_counts,
            "总计问答对数量": sum(qa_counts.values()),
            "说明": "所有问答对均基于用户自定义配置生成,灵活控制问题类型和字段数量"
        }

        report_path = os.path.join(self.config.OUTPUT_DIR, "QA生成报告.json")
        with open(report_path, 'w', encoding='utf-8') as f:
            json.dump(report, f, ensure_ascii=False, indent=2)

        print(f"[OK] 已生成: {report_path}")

    def process_all_custom(self, generate_single: bool, generate_multi: bool, single_field_count: int, multi_field_count: int):
        """Process every data file with custom generation parameters."""
        # Remove stale files from the output directory
        print("\n[INFO] 清理输出目录中的旧文件...")
        if os.path.exists(self.config.OUTPUT_DIR):
            for filename in os.listdir(self.config.OUTPUT_DIR):
                file_path = os.path.join(self.config.OUTPUT_DIR, filename)
                try:
                    if os.path.isfile(file_path):
                        os.remove(file_path)
                        if self.config.VERBOSE_LOG:
                            print(f"  [DEL] 已删除: {filename}")
                except Exception as e:
                    print(f"  [WARN] 删除文件失败 {filename}: {str(e)}")

        qa_counts = {}

        for file_info in self.config.DATA_FILES:
            if not file_info["enabled"]:
                if self.config.VERBOSE_LOG:
                    print(f"[SKIP] 已跳过: {file_info['name']}")
                continue

            print(f"\n[INFO] 正在处理: {file_info['name']}")

            input_file = os.path.join(self.config.INPUT_DIR, file_info["file"])

            if not os.path.exists(input_file):
                print(f"[WARNING] 文件不存在: {input_file}")
                continue

            try:
                data = self.load_json(input_file)

                # Determine the data type from the file name
                if "元素治理" in file_info["name"]:
                    data_type = "element"
                elif "物理模型" in file_info["name"]:
                    data_type = "physical"
                elif "逻辑模型" in file_info["name"]:
                    data_type = "logical"
                else:
                    print(f"[WARNING] 未知的文件类型: {file_info['name']}")
                    continue

                # Generate QA pairs with the custom parameters
                qa_pairs = []
                for item in data:
                    # Single-column QA
                    if generate_single and single_field_count > 0:
                        single_qa = self.generate_single_qa(item, single_field_count, data_type)
                        qa_pairs.extend(single_qa)

                    # Multi-field QA
                    if generate_multi and multi_field_count > 0:
                        multi_qa = self.generate_multi_field_qa(item, multi_field_count, data_type)
                        qa_pairs.extend(multi_qa)

                qa_pairs = self.shuffle_qa_pairs(qa_pairs)
                self.save_qa(qa_pairs, file_info["output"])
                qa_counts[file_info["output"]] = len(qa_pairs)

            except Exception as e:
                print(f"[ERROR] 处理文件 {file_info['name']} 时出错: {str(e)}")

        # Write the report
        if self.config.GENERATE_REPORT:
            print("\n[INFO] 正在生成: 生成报告")
            self.generate_report_custom(qa_counts, generate_single, generate_multi, single_field_count, multi_field_count)

        print(f"\n[DONE] 所有文件处理完成!")
        print(f"[OUT] 输出目录: {self.config.OUTPUT_DIR}")
        print(f"[TOTAL] 总计生成: {sum(qa_counts.values())} 条问答对")

    def process_all_simplified(self, fields_per_question: int, questions_per_item: int):
        """Simplified processing - one uniform way of generating questions."""
        # Remove stale files from the output directory
        print("\n[INFO] 清理输出目录中的旧文件...")
        if os.path.exists(self.config.OUTPUT_DIR):
            for filename in os.listdir(self.config.OUTPUT_DIR):
                file_path = os.path.join(self.config.OUTPUT_DIR, filename)
                try:
                    if os.path.isfile(file_path):
                        os.remove(file_path)
                        if self.config.VERBOSE_LOG:
                            print(f"  [DEL] 已删除: {filename}")
                except Exception as e:
                    print(f"  [WARN] 删除文件失败 {filename}: {str(e)}")

        qa_counts = {}

        for file_info in self.config.DATA_FILES:
            if not file_info["enabled"]:
                if self.config.VERBOSE_LOG:
                    print(f"[SKIP] 已跳过: {file_info['name']}")
                continue

            print(f"\n[INFO] 正在处理: {file_info['name']}")

            input_file = os.path.join(self.config.INPUT_DIR, file_info["file"])

            if not os.path.exists(input_file):
                print(f"[WARNING] 文件不存在: {input_file}")
                continue

            try:
                data = self.load_json(input_file)

                # Determine the data type from the file name
                if "元素治理" in file_info["name"]:
                    data_type = "element"
                elif "物理模型" in file_info["name"]:
                    data_type = "physical"
                elif "逻辑模型" in file_info["name"]:
                    data_type = "logical"
                else:
                    print(f"[WARNING] 未知的文件类型: {file_info['name']}")
                    continue

                # Generate QA pairs with the simplified parameters
                qa_pairs = []
                for item in data:
                    # Generate the requested number of questions per data item
                    for _ in range(questions_per_item):
                        # Each question asks about the requested number of fields
                        multi_qa = self.generate_multi_field_qa(item, fields_per_question, data_type)
                        qa_pairs.extend(multi_qa)

                qa_pairs = self.shuffle_qa_pairs(qa_pairs)
                self.save_qa(qa_pairs, file_info["output"])
                qa_counts[file_info["output"]] = len(qa_pairs)

            except Exception as e:
                print(f"[ERROR] 处理文件 {file_info['name']} 时出错: {str(e)}")

        # Write the report
        if self.config.GENERATE_REPORT:
            print("\n[INFO] 正在生成: 生成报告")
            self.generate_report_simplified(qa_counts, fields_per_question, questions_per_item)

        print(f"\n[DONE] 所有文件处理完成!")
        print(f"[OUT] 输出目录: {self.config.OUTPUT_DIR}")
        print(f"[TOTAL] 总计生成: {sum(qa_counts.values())} 条问答对")

    def generate_report_simplified(self, qa_counts: Dict[str, int], fields_per_question: int, questions_per_item: int):
        """Write the report for a simplified run."""
        if not self.config.GENERATE_REPORT:
            return

        report = {
            "生成时间": "2025-12-19",
            "版本": "整合版-简化",
            "简化配置": {
                "每个问题询问字段数": fields_per_question,
                "每个JSON数据项生成问题数": questions_per_item,
                "随机种子": self.config.RANDOM_SEED,
                "打乱输出": self.config.SHUFFLE_OUTPUT
            },
            "总文件数": len(qa_counts),
            "各文件问答对数量": qa_counts,
            "总计问答对数量": sum(qa_counts.values()),
            "说明": "所有问答对均基于简化配置生成:每个问题询问指定字段数,每个JSON数据项生成指定问题数"
        }

        report_path = os.path.join(self.config.OUTPUT_DIR, "QA生成报告.json")
        with open(report_path, 'w', encoding='utf-8') as f:
            json.dump(report, f, ensure_ascii=False, indent=2)

        print(f"[OK] 已生成: {report_path}")

    def process_all_balanced(self, max_columns: int, distribution: List[int], comprehensive_count: int = 0):
        """Balanced processing - distribute questions of different column counts by ratio."""
        # Remove stale files from the output directory
        print("\n[INFO] 清理输出目录中的旧文件...")
        if os.path.exists(self.config.OUTPUT_DIR):
            for filename in os.listdir(self.config.OUTPUT_DIR):
                file_path = os.path.join(self.config.OUTPUT_DIR, filename)
                try:
                    if os.path.isfile(file_path):
                        os.remove(file_path)
                        if self.config.VERBOSE_LOG:
                            print(f"  [DEL] 已删除: {filename}")
                except Exception as e:
                    print(f"  [WARN] 删除文件失败 {filename}: {str(e)}")

        qa_counts = {}

        for file_info in self.config.DATA_FILES:
            if not file_info["enabled"]:
                if self.config.VERBOSE_LOG:
                    print(f"[SKIP] 已跳过: {file_info['name']}")
                continue

            print(f"\n[INFO] 正在处理: {file_info['name']}")

            input_file = os.path.join(self.config.INPUT_DIR, file_info["file"])

            if not os.path.exists(input_file):
                print(f"[WARNING] 文件不存在: {input_file}")
                continue

            try:
                data = self.load_json(input_file)

                # Determine the data type from the file name
                if "元素治理" in file_info["name"]:
                    data_type = "element"
                elif "物理模型" in file_info["name"]:
                    data_type = "physical"
                elif "逻辑模型" in file_info["name"]:
                    data_type = "logical"
                else:
                    print(f"[WARNING] 未知的文件类型: {file_info['name']}")
                    continue

                # Generate QA pairs according to the balanced distribution
                qa_pairs = []
                for item in data:
                    # Generate questions of each column count according to the distribution
                    for col_count, question_count in enumerate(distribution, 1):
                        for _ in range(question_count):
                            # One question with this column count
                            multi_qa = self.generate_multi_field_qa(item, col_count, data_type)
                            qa_pairs.extend(multi_qa)

                    # Comprehensive QA
                    if comprehensive_count > 0:
                        for _ in range(comprehensive_count):
                            # The table name comes from the item itself, not from file_info["name"]
                            comprehensive_qa = self.generate_comprehensive_qa(item, data_type)
                            qa_pairs.extend(comprehensive_qa)

                qa_pairs = self.shuffle_qa_pairs(qa_pairs)
                self.save_qa(qa_pairs, file_info["output"])
                qa_counts[file_info["output"]] = len(qa_pairs)

            except Exception as e:
                print(f"[ERROR] 处理文件 {file_info['name']} 时出错: {str(e)}")

        # Write the report
        if self.config.GENERATE_REPORT:
            print("\n[INFO] 正在生成: 生成报告")
            self.generate_report_balanced(qa_counts, max_columns, distribution, comprehensive_count)

        print(f"\n[DONE] 所有文件处理完成!")
        print(f"[OUT] 输出目录: {self.config.OUTPUT_DIR}")
        print(f"[TOTAL] 总计生成: {sum(qa_counts.values())} 条问答对")

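    # Distribution sketch (illustrative): process_all_balanced(3, [1, 1, 2], 1) would
    # generate, per data item, one 1-column question, one 2-column question, two
    # 3-column questions and one comprehensive question; main() builds such a list by
    # spreading questions_per_item across 1..max_columns, favouring higher column counts.
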
    def generate_comprehensive_qa(self, item: Dict, data_type: str, table_name: str = None) -> List[Dict]:
        """Generate a comprehensive QA pair asking for the definition/meaning of all fields."""
        qa_pairs = []
        answer_prefixes = self.config.ANSWER_PREFIXES
        answer_suffixes = self.config.ANSWER_SUFFIXES

        # Collect every field that has a value
        all_fields = {}

        # If no table name was passed in, read it from the item, falling back to a default
        if table_name is None:
            table_name = item.get("表名", "远光数据架构表")

        field_mapping: Dict[str, Any] = {}  # stays empty for unknown data types
        if data_type == "element":
            # All fields of the element governance template
            field_mapping = {
                "业务领域名称": item.get("业务领域名称"),
                "数据元素中文名": item.get("数据元素中文名"),
                "数据元素英文名": item.get("数据元素英文名"),
                "值类型": item.get("值类型"),
                "总长度": item.get("总长度"),
                "小数位": item.get("小数位"),
                "类别": item.get("类别"),
                "是否枚举": item.get("是否枚举"),
                "枚举数量": item.get("枚举数量"),
                "抽象元素中文名": item.get("抽象元素中文名"),
                "说明": item.get("说明"),
                "是否上线": item.get("是否上线")
            }
        elif data_type == "physical":
            # All fields of the physical model
            field_mapping = {
                "物理模型中文名": item.get("物理模型中文名"),
                "物理模型英文名": item.get("物理模型英文名"),
                "字段中文名": item.get("字段中文名"),
                "物理模型属性英文名": item.get("物理模型属性英文名"),
                "值类型": item.get("值类型"),
                "长度": item.get("长度"),
                "小数位": item.get("小数位"),
                "关联数据元素": item.get("关联数据元素"),
                "说明": item.get("说明")
            }
        elif data_type == "logical":
            # All fields of the logical model
            field_mapping = {
                "业务领域": item.get("业务领域"),
                "逻辑模型中文名": item.get("逻辑模型中文名"),
                "逻辑模型英文名": item.get("逻辑模型英文名"),
                "字段中文名": item.get("字段中文名"),
                "字段英文名": item.get("字段英文名"),
                "值类型": item.get("值类型"),
                "长度": item.get("长度"),
                "小数位": item.get("小数位"),
                "动态查询能力": item.get("动态查询能力"),
                "关联数据元素英文名": item.get("关联数据元素英文名")
            }

        # Keep only the fields that actually have a value
        for field_name, field_value in field_mapping.items():
            if field_value is not None and field_value != "":
                all_fields[field_name] = field_value

        # Skip items with too few fields
        if len(all_fields) < 2:
            return qa_pairs

        # Identifier field per data type
        identifier_fields = {
            "element": "数据元素中文名",
            "physical": "字段中文名",
            "logical": "字段中文名"
        }

        identifier_field = identifier_fields.get(data_type)
        identifier_value = all_fields.get(identifier_field)

        if not identifier_value:
            # No identifier field - use the first field as the identifier instead
            identifier_value = list(all_fields.values())[0]
            # and drop that first field from the fields to be asked about
            first_key = list(all_fields.keys())[0]
            all_fields.pop(first_key)
        else:
            # Drop the identifier field from the fields to be asked about
            all_fields.pop(identifier_field)

        # Nothing left to ask about - skip
        if not all_fields:
            return qa_pairs

        # Several phrasings of the question (the table name is included to disambiguate)
        question_templates = [
            f"在{table_name}中,请详细说明字段中文名为:{identifier_value}的定义和相关信息",
            f"在{table_name}中,字段中文名为:{identifier_value}是什么意思?请解释其含义和特征",
            f"请全面介绍{table_name}中字段中文名为:{identifier_value}的概念和属性",
            f"在{table_name}中,字段中文名为:{identifier_value}是什么?请详细描述其特点",
            f"请解释{table_name}中字段中文名为:{identifier_value}的定义、作用和属性",
            f"在{table_name}里,字段中文名为:{identifier_value}具体指什么?请提供详细信息",
            f"请说明{table_name}中字段中文名为:{identifier_value}的含义、类型和相关属性",
            f"在{table_name}中,字段中文名为:{identifier_value}的概念和特征是什么?",
            f"请详细介绍{table_name}中字段中文名为:{identifier_value}的定义和相关信息",
            f"在{table_name}中,字段中文名为:{identifier_value}是什么概念?请解释其属性和含义"
        ]

        # Build the answer
        answer_parts = []
        for field_name, field_value in all_fields.items():
            answer_parts.append(f"{field_name}:{field_value}")

        # Join the answer parts with full-width semicolons
        answer = ";".join(answer_parts)

        # Pick one phrasing at random
        question = self.config.get_random_element(question_templates)

        # Assemble the QA pair
        qa_pairs.append({
            "instruct": question,
            "input": "",
            "output": f"{self.config.get_random_element(answer_prefixes)}{answer}{self.config.get_random_element(answer_suffixes)}"
        })

        return qa_pairs

    def generate_report_balanced(self, qa_counts: Dict[str, int], max_columns: int, distribution: List[int], comprehensive_count: int = 0):
        """Write the report for a balanced-distribution run."""
        if not self.config.GENERATE_REPORT:
            return

        # Describe how the questions were allocated
        allocation_detail = {}
        for i, count in enumerate(distribution, 1):
            if count > 0:
                allocation_detail[f"{i}列问题"] = f"每个JSON数据项生成{count}个"

        # Assemble the configuration section
        config_info = {
            "最大列数": max_columns,
            "问题分配详情": allocation_detail,
            "随机种子": self.config.RANDOM_SEED,
            "打乱输出": self.config.SHUFFLE_OUTPUT
        }

        # Record the comprehensive QA setting when it is used
        if comprehensive_count > 0:
            config_info["综合性问答"] = f"每个JSON数据项生成{comprehensive_count}个"

        report = {
            "生成时间": "2025-12-19",
            "版本": "整合版-平衡分配",
            "平衡配置": config_info,
            "总文件数": len(qa_counts),
            "各文件问答对数量": qa_counts,
            "总计问答对数量": sum(qa_counts.values()),
            "说明": "所有问答对均基于平衡分配配置生成:按比例平均分配不同列数的问题,优先多列"
        }

        report_path = os.path.join(self.config.OUTPUT_DIR, "QA生成报告.json")
        with open(report_path, 'w', encoding='utf-8') as f:
            json.dump(report, f, ensure_ascii=False, indent=2)

        print(f"[OK] 已生成: {report_path}")

    def split_train_validation(self, train_ratio: float = 0.9, val_ratio: float = 0.1):
        """Split the generated data into a training set and a validation set."""
        all_qa_pairs = []

        # Read every QA file, skipping reports and previously produced splits
        for filename in os.listdir(self.config.OUTPUT_DIR):
            if filename.endswith('.json') and not filename.startswith('QA生成报告') and filename not in ['train.json', 'val.json', 'validation.json', 'train_stats.json']:
                file_path = os.path.join(self.config.OUTPUT_DIR, filename)

                try:
                    with open(file_path, 'r', encoding='utf-8') as f:
                        qa_data = json.load(f)

                    if isinstance(qa_data, list) and all(isinstance(item, dict) and 'instruct' in item and 'output' in item for item in qa_data):
                        all_qa_pairs.extend(qa_data)

                except Exception as e:
                    print(f"[ERROR] 读取文件失败 {filename}: {str(e)}")

        if not all_qa_pairs:
            print("[WARNING] 没有找到可分割的QA数据")
            return

        # Shuffle before splitting
        random.shuffle(all_qa_pairs)

        # Compute the split point
        total_count = len(all_qa_pairs)
        train_count = int(total_count * train_ratio)

        # Split the data
        train_data = all_qa_pairs[:train_count]
        val_data = all_qa_pairs[train_count:]

        # Save the training set
        train_path = os.path.join(self.config.OUTPUT_DIR, "train.json")
        with open(train_path, 'w', encoding='utf-8') as f:
            json.dump(train_data, f, ensure_ascii=False, indent=2)

        # Save the validation set
        val_path = os.path.join(self.config.OUTPUT_DIR, "val.json")
        with open(val_path, 'w', encoding='utf-8') as f:
            json.dump(val_data, f, ensure_ascii=False, indent=2)

        # Report file sizes
        train_size_mb = os.path.getsize(train_path) / (1024 * 1024)
        val_size_mb = os.path.getsize(val_path) / (1024 * 1024)

        print(f"\n[SUCCESS] 数据分割完成!")
        print(f"[TRAIN] 训练集: {train_path}")
        print(f"  - 数据量: {len(train_data):,} 条")
        print(f"  - 文件大小: {train_size_mb:.2f} MB")
        print(f"  - 比例: {train_ratio*100:.0f}%")
        print(f"[VAL] 验证集: {val_path}")
        print(f"  - 数据量: {len(val_data):,} 条")
        print(f"  - 文件大小: {val_size_mb:.2f} MB")
        print(f"  - 比例: {val_ratio*100:.0f}%")
        print(f"[TOTAL] 总计: {total_count:,} 条")

        # Write the split report
        split_report = {
            "分割时间": "2025-12-19",
            "分割比例": f"{train_ratio*100:.0f}% : {val_ratio*100:.0f}%",
            "训练集": {
                "文件路径": "train.json",
                "数据量": len(train_data),
                "文件大小": f"{train_size_mb:.2f} MB",
                "占比": f"{train_ratio*100:.0f}%"
            },
            "验证集": {
                "文件路径": "val.json",
                "数据量": len(val_data),
                "文件大小": f"{val_size_mb:.2f} MB",
                "占比": f"{val_ratio*100:.0f}%"
            },
            "总计": {
                "数据量": total_count,
                "文件大小": f"{(train_size_mb + val_size_mb):.2f} MB"
            }
        }

        split_report_path = os.path.join(self.config.OUTPUT_DIR, "数据分割报告.json")
        with open(split_report_path, 'w', encoding='utf-8') as f:
            json.dump(split_report, f, ensure_ascii=False, indent=2)

        print(f"[OK] 已生成: {split_report_path}")

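    # Split sketch (illustrative): with the default 0.9/0.1 ratios and, say, 1000 merged
    # pairs, train.json receives the first 900 shuffled pairs and val.json the remaining
    # 100; the ratio is applied to whatever per-table QA files exist in OUTPUT_DIR.
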
    def process_all(self):
        """Process every data file - kept for backward compatibility."""
        # Remove stale files from the output directory
        print("[INFO] 清理输出目录中的旧文件...")
        if os.path.exists(self.config.OUTPUT_DIR):
            for filename in os.listdir(self.config.OUTPUT_DIR):
                file_path = os.path.join(self.config.OUTPUT_DIR, filename)
                try:
                    if os.path.isfile(file_path):
                        os.remove(file_path)
                        if self.config.VERBOSE_LOG:
                            print(f"  [DEL] 已删除: {filename}")
                except Exception as e:
                    print(f"  [WARN] 删除文件失败 {filename}: {str(e)}")

        qa_counts = {}

        for file_info in self.config.DATA_FILES:
            if not file_info["enabled"]:
                if self.config.VERBOSE_LOG:
                    print(f"[SKIP] 已跳过: {file_info['name']}")
                continue

            print(f"\n[INFO] 正在处理: {file_info['name']}")

            input_file = os.path.join(self.config.INPUT_DIR, file_info["file"])

            if not os.path.exists(input_file):
                print(f"[WARNING] 文件不存在: {input_file}")
                continue

            try:
                data = self.load_json(input_file)

                # Determine the data type from the file name
                if "元素治理" in file_info["name"]:
                    data_type = "element"
                elif "物理模型" in file_info["name"]:
                    data_type = "physical"
                elif "逻辑模型" in file_info["name"]:
                    data_type = "logical"
                else:
                    print(f"[WARNING] 未知的文件类型: {file_info['name']}")
                    continue

                qa_pairs = self.generate_qa_for_data(data, data_type)
                qa_pairs = self.shuffle_qa_pairs(qa_pairs)
                self.save_qa(qa_pairs, file_info["output"])
                qa_counts[file_info["output"]] = len(qa_pairs)

            except Exception as e:
                print(f"[ERROR] 处理文件 {file_info['name']} 时出错: {str(e)}")

        # Write the report
        if self.config.GENERATE_REPORT:
            print("\n[INFO] 正在生成: 生成报告")
            self.generate_report(qa_counts)

        if self.config.VERBOSE_LOG:
            print(f"\n[DONE] 所有文件处理完成!")
            print(f"[OUT] 输出目录: {self.config.OUTPUT_DIR}")
            print(f"[TOTAL] 总计生成: {sum(qa_counts.values())} 条问答对")
        else:
            print("\n[DONE] 所有文件处理完成!")


# Preset configurations
SIMPLE_CONFIG = QAConfig()
SIMPLE_CONFIG.COMPLEXITY_LEVEL = 1
SIMPLE_CONFIG._init_templates()

NORMAL_CONFIG = QAConfig()
NORMAL_CONFIG.COMPLEXITY_LEVEL = 3
NORMAL_CONFIG._init_templates()

COMPLEX_CONFIG = QAConfig()
COMPLEX_CONFIG.COMPLEXITY_LEVEL = 5
COMPLEX_CONFIG._init_templates()

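# Preset usage sketch (illustrative): passing one of these to the generator, e.g.
#     generator = QAGenerator(SIMPLE_CONFIG)
#     generator.process_all()
# runs the legacy pipeline with the simple template set; note that main() below builds
# its own QAConfig instead of using these presets.
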
def main():
    """Interactive entry point."""
    print("=" * 60)
    print("QA生成器 - 整合版")
    print("=" * 60)

    # Interactively collect the generation parameters
    print("\n" + "=" * 60)
    print("请自定义生成参数:")
    print("=" * 60)

    # Maximum column count
    print("\n1. 最大列数设置:")
    print("   指定最多生成几列内容的问题(支持1列到N列平均分配)")

    max_columns = 1
    max_columns_input = input("最多生成几列内容? (默认1): ").strip()
    if max_columns_input:
        max_columns = max(1, int(max_columns_input))
        print(f"   最多生成 {max_columns} 列内容的问题")

    # Questions per data item
    print("\n2. 问题数量设置:")
    print("   指定每个JSON数据项生成多少个问题")

    questions_per_item = 1
    questions_input = input("每个JSON数据项生成几个问题? (默认1): ").strip()
    if questions_input:
        questions_per_item = max(1, int(questions_input))
        print(f"   每个JSON数据项将生成 {questions_per_item} 个问题")

    # Comprehensive QA
    print("\n3. 综合性问答设置:")
    print("   综合性问答指一个问题询问所有字段的定义/含义")
    print("   例如:实际过账成本的成本中心的定义是什么?")

    comprehensive_choice = input("是否添加综合性问答? (1=添加, 0=不添加, 默认0): ").strip()
    comprehensive_count = 0
    if comprehensive_choice == '1':
        comprehensive_count = input("每个JSON数据项生成几个综合性问题? (默认1): ").strip()
        comprehensive_count = max(1, int(comprehensive_count)) if comprehensive_count else 1
        print(f"   每个JSON数据项将额外生成 {comprehensive_count} 个综合性问题")
    else:
        print("   不生成综合性问题")

    # Work out how the questions are spread across the column counts
    distribution = []
    base_count = questions_per_item // max_columns
    remainder = questions_per_item % max_columns

    # Give every column count the base amount first ...
    for i in range(max_columns):
        distribution.append(base_count)

    # ... then hand the remainder to the highest column counts
    for i in range(remainder):
        distribution[-(1 + i)] += 1

    # Show the allocation
    print("\n" + "=" * 60)
    print("问题类型分配:")
    print("=" * 60)
    for i, count in enumerate(distribution, 1):
        if count > 0:
            print(f"[OK] {i}列问题: 每个JSON数据项生成 {count} 个")

    # Build a custom configuration
    config = QAConfig()
    config.SINGLE_TEMPLATES = 1
    config.MULTI_TEMPLATES = 1
    config.MULTI_RATIO = 1.0  # always generate questions

    # Show the configuration
    print("\n" + "=" * 60)
    print("当前配置:")
    print("=" * 60)
    print(f"[OK] 最大列数: {max_columns} 列")
    print(f"[OK] 每个JSON数据项生成问题数: {questions_per_item} 个")
    if comprehensive_count > 0:
        print(f"[OK] 综合性问题: {comprehensive_count} 个")
    print(f"[OK] 输出目录: {config.OUTPUT_DIR}")

    # Create the generator and run the balanced pipeline
    generator = QAGenerator(config)
    generator.process_all_balanced(max_columns, distribution, comprehensive_count)

    # Optionally merge everything into train.json
    merge_choice = input("\n是否合并为train.json? (y/n): ").strip().lower()
    if merge_choice == 'y':
        generator.merge_to_train()

    # Optionally split the data into training and validation sets
    split_choice = input("\n是否将数据分割为训练集和验证集 (9:1)? (y/n): ").strip().lower()
    if split_choice == 'y':
        generator.split_train_validation(train_ratio=0.9, val_ratio=0.1)


if __name__ == "__main__":
    main()
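
# Non-interactive usage sketch (illustrative, bypasses the prompts in main()):
#     generator = QAGenerator(NORMAL_CONFIG)
#     generator.process_all_balanced(max_columns=2, distribution=[1, 1], comprehensive_count=1)
#     generator.merge_to_train()
#     generator.split_train_validation(train_ratio=0.9, val_ratio=0.1)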