Files
YG-Rules/app/utils/schema_storage.py

145 lines
6.0 KiB
Python
Raw Normal View History

2026-06-10 19:15:24 +08:00
"""Schema 文件存储与解析。"""
from __future__ import annotations
import io
import json
import os
import re
from datetime import datetime
from typing import BinaryIO
from openpyxl import load_workbook
from app.utils.llm import strip_thinking
class SchemaStorage:
    """Persist a parsed schema as a JSON file and expose it for reading.

    The schema is produced by parsing the active sheet of an Excel workbook
    into a list of "modules", each holding a list of field definitions.
    """

    def __init__(self, file_path: str | None = None) -> None:
        """Bind the storage to *file_path* and create an empty schema if absent.

        Args:
            file_path: Location of the JSON file. Defaults to
                ``<cwd>/data/schema.json`` when omitted or empty.
        """
        self.file_path = file_path or os.path.join(os.getcwd(), "data", "schema.json")
        self._ensure_parent_dir()
        if not os.path.exists(self.file_path):
            self._write(self._empty_schema())

    @staticmethod
    def _empty_schema() -> dict:
        """Return a fresh placeholder document used when nothing is stored."""
        return {"processing_status": "empty", "modules": []}

    def _ensure_parent_dir(self) -> None:
        """Create the directory that holds ``file_path`` when it has one."""
        parent = os.path.dirname(self.file_path)
        # A bare filename has an empty dirname, and os.makedirs("") raises
        # FileNotFoundError, so only create a directory when one is named.
        if parent:
            os.makedirs(parent, exist_ok=True)

    def _read(self) -> dict:
        """Load the stored document, or the empty placeholder if no file exists."""
        try:
            with open(self.file_path, "r", encoding="utf-8") as file:
                return json.load(file)
        except FileNotFoundError:
            return self._empty_schema()

    def _write(self, data: dict) -> None:
        """Serialize *data* to ``file_path`` as indented UTF-8 JSON."""
        self._ensure_parent_dir()
        with open(self.file_path, "w", encoding="utf-8") as file:
            json.dump(data, file, ensure_ascii=False, indent=2)

    def save(
        self,
        source_type: str,
        content: bytes | BinaryIO,
        filename: str = "",
        source_url: str = "",
    ) -> dict:
        """Parse an Excel payload, store the resulting schema, and return it.

        Args:
            source_type: Caller-supplied label stored alongside the schema.
            content: Raw workbook bytes, or an object with a ``read()`` method
                that yields them.
            filename: Original file name, stored as metadata.
            source_url: Origin URL, stored as metadata.

        Returns:
            The document that was written to disk.
        """
        raw = content.read() if hasattr(content, "read") else content
        modules = self._parse_excel(io.BytesIO(raw))
        data = {
            "version": "2.0",
            "uploaded_at": datetime.now().isoformat(),
            "source_type": source_type,
            "source_url": source_url,
            "filename": filename,
            "processing_status": "done",
            "modules": modules,
        }
        self._write(data)
        return data

    def get(self) -> dict:
        """Return a copy of the stored document with sanitized descriptions."""
        return self._clean_descriptions_for_response(self._read())

    def status(self) -> dict:
        """Return summary metadata about the stored schema."""
        data = self._read()
        # Tolerate a stored null for "modules" instead of failing on len(None).
        modules = data.get("modules") or []
        return {
            "processing_status": data.get("processing_status", "empty"),
            "module_count": len(modules),
            "uploaded_at": self._format_time(data.get("uploaded_at", "")),
            "source_type": data.get("source_type", ""),
            "source_url": data.get("source_url", ""),
            "filename": data.get("filename", ""),
        }

    def delete_file(self) -> bool:
        """Remove the schema file.

        Returns:
            True if a file was removed, False if there was none.
        """
        try:
            os.remove(self.file_path)
        except FileNotFoundError:
            return False
        return True

    def _parse_excel(self, stream: BinaryIO) -> list[dict]:
        """Parse the active worksheet of a workbook into module dictionaries.

        Rows are interpreted in order:
          * entirely empty rows are skipped;
          * a row containing both "数据名称" and "数据标记" becomes the header row;
          * a row whose first cell is non-numeric text and whose remaining
            cells are empty starts a new module named after that cell;
          * any other row following a header is mapped against the header
            and appended as a field of the current module.

        Args:
            stream: Binary stream holding the workbook.

        Returns:
            Modules in the order they were encountered.
        """
        workbook = load_workbook(stream, data_only=True)
        modules: list[dict] = []
        try:
            sheet = workbook.active
            if sheet is None:
                # NOTE(review): presumably only possible for a workbook with
                # no worksheets; nothing to parse in that case.
                return modules
            current: dict | None = None
            headers: list[str] = []
            for row in sheet.iter_rows(values_only=True):
                values = ["" if value is None else str(value).strip() for value in row]
                if not any(values):
                    continue
                if "数据名称" in values and "数据标记" in values:
                    headers = values
                    continue
                # any(values) above guarantees at least one cell exists here.
                is_module_title = (
                    values[0]
                    and not values[0].isdigit()
                    and not any(values[1:])
                )
                if is_module_title:
                    current = {"module_name": values[0], "table_name": "", "description": "", "fields": []}
                    modules.append(current)
                    continue
                if not headers:
                    continue
                if current is None:
                    # Field rows seen before any module title go to a default module.
                    current = {"module_name": "默认模块", "table_name": "", "description": "", "fields": []}
                    modules.append(current)
                # zip stops at the shorter of the two sequences.
                row_map = dict(zip(headers, values))
                marker = row_map.get("数据标记", "")
                field_name = row_map.get("数据名称", "")
                table_name = row_map.get("表名", "")
                if table_name and not current.get("table_name"):
                    current["table_name"] = table_name
                if marker or field_name:
                    current["fields"].append({
                        "seq": row_map.get("序号", ""),
                        "name": field_name,
                        "marker": marker,
                        "type": row_map.get("数据类型", ""),
                        "length": row_map.get("数据长度", ""),
                        "rule": row_map.get("数据填写规则", ""),
                        "required": row_map.get("数据填写要求", ""),
                        "strong_check": row_map.get("强弱校验", ""),
                    })
            for module in modules:
                module["description"] = module.get("description") or self._fallback_description(module)
            return modules
        finally:
            workbook.close()

    def _sanitize_description(self, response: str, module: dict) -> str:
        """Clean a description string, falling back to a generated one.

        The text is passed through ``strip_thinking``. If it then contains
        double-quoted passages of at least eight characters, the last one is
        kept. A leading phrase starting with "好的", "我来" or "以下" and running
        to the first "。" or ":" is removed. An empty result, or one that
        still contains "<think>", is replaced by the fallback description.
        """
        cleaned = strip_thinking(response or "").strip()
        quoted = re.findall(r'"([^"]{8,})"', cleaned)
        if quoted:
            cleaned = quoted[-1].strip()
        cleaned = re.sub(r"^(好的|我来|以下).*?[。:]\s*", "", cleaned).strip()
        if "<think>" in cleaned or not cleaned:
            return self._fallback_description(module)
        return cleaned

    def _clean_descriptions_for_response(self, data: dict) -> dict:
        """Return a deep copy of *data* with every module description sanitized."""
        # The JSON round-trip produces an independent copy, so the caller's
        # dictionary is never mutated.
        cloned = json.loads(json.dumps(data, ensure_ascii=False))
        for module in cloned.get("modules", []):
            module["description"] = self._sanitize_description(module.get("description", ""), module)
        return cloned

    def _fallback_description(self, module: dict) -> str:
        """Build a generic description from the module name and its fields.

        Only the first six field entries are considered, and entries without
        a name are skipped.
        """
        name = module.get("module_name", "数据表")
        # Names are joined with the enumeration comma so they stay readable
        # instead of running together into one word.
        fields = "、".join(
            field.get("name", "")
            for field in module.get("fields", [])[:6]
            if field.get("name")
        )
        return f"{name}表用于记录和管理{fields or '关键业务'}等信息,支撑查询、校验和监管规则生成。"

    def _format_time(self, value: str) -> str:
        """Format an ISO-8601 timestamp as ``YYYY-MM-DD HH:MM:SS``.

        Returns an empty string for empty input and the input unchanged when
        it cannot be parsed as an ISO timestamp.
        """
        if not value:
            return ""
        try:
            return datetime.fromisoformat(value).strftime("%Y-%m-%d %H:%M:%S")
        except ValueError:
            return value