Files
YG-Rules/app/utils/schema_storage.py

145 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Schema 文件存储与解析。"""
from __future__ import annotations
import io
import json
import os
import re
from datetime import datetime
from typing import BinaryIO
from openpyxl import load_workbook
from app.utils.llm import strip_thinking
class SchemaStorage:
    """Store a parsed schema workbook as a JSON file and serve it back.

    The JSON document holds upload metadata plus a ``modules`` list. Each
    module is a dict with ``module_name``, ``table_name``, ``description``
    and ``fields`` keys, as produced by :meth:`_parse_excel`.
    """

    def __init__(self, file_path: str | None = None) -> None:
        """Bind the store to *file_path* and create an empty file if missing.

        Args:
            file_path: Location of the JSON file. Defaults to
                ``<cwd>/data/schema.json`` when omitted or empty.
        """
        self.file_path = file_path or os.path.join(os.getcwd(), "data", "schema.json")
        if not os.path.exists(self.file_path):
            self._write(self._empty_state())

    @staticmethod
    def _empty_state() -> dict:
        """Return a fresh copy of the document used when nothing is stored."""
        return {"processing_status": "empty", "modules": []}

    def _ensure_parent_dir(self) -> None:
        """Create the directory that holds ``file_path`` if it has one."""
        parent = os.path.dirname(self.file_path)
        # A bare filename has an empty dirname, and os.makedirs("") raises.
        if parent:
            os.makedirs(parent, exist_ok=True)

    def _read(self) -> dict:
        """Load the stored document, or the empty state if the file is absent."""
        try:
            with open(self.file_path, "r", encoding="utf-8") as file:
                return json.load(file)
        except FileNotFoundError:
            return self._empty_state()

    def _write(self, data: dict) -> None:
        """Serialize *data* as UTF-8 JSON and replace the stored file.

        Raises:
            TypeError: If *data* contains values ``json`` cannot serialize.
                The previously stored file is left untouched in that case.
        """
        # Serialize before touching the disk so a bad payload cannot
        # truncate the existing file.
        payload = json.dumps(data, ensure_ascii=False, indent=2)
        self._ensure_parent_dir()
        # Write to a sibling file and swap it in, so readers never observe
        # a half-written document.
        tmp_path = f"{self.file_path}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as file:
            file.write(payload)
        os.replace(tmp_path, self.file_path)

    def save(self, source_type: str, content: bytes | BinaryIO, filename: str = "", source_url: str = "") -> dict:
        """Parse an uploaded workbook and persist the resulting schema.

        Args:
            source_type: Caller-supplied label stored verbatim.
            content: Raw workbook bytes, or an object with a ``read()``
                method returning them.
            filename: Original file name, stored verbatim.
            source_url: Origin URL, stored verbatim.

        Returns:
            The document that was written to disk.
        """
        raw = content.read() if hasattr(content, "read") else content
        modules = self._parse_excel(io.BytesIO(raw))
        data = {
            "version": "2.0",
            "uploaded_at": datetime.now().isoformat(),
            "source_type": source_type,
            "source_url": source_url,
            "filename": filename,
            "processing_status": "done",
            "modules": modules,
        }
        self._write(data)
        return data

    def get(self) -> dict:
        """Return the stored document with module descriptions sanitized."""
        return self._clean_descriptions_for_response(self._read())

    def status(self) -> dict:
        """Return a summary of the stored document (status, counts, origin)."""
        data = self._read()
        modules = data.get("modules", [])
        return {
            "processing_status": data.get("processing_status", "empty"),
            "module_count": len(modules),
            "uploaded_at": self._format_time(data.get("uploaded_at", "")),
            "source_type": data.get("source_type", ""),
            "source_url": data.get("source_url", ""),
            "filename": data.get("filename", ""),
        }

    def delete_file(self) -> bool:
        """Remove the stored file.

        Returns:
            ``True`` if a file was removed, ``False`` if none existed.
        """
        try:
            os.remove(self.file_path)
        except FileNotFoundError:
            return False
        return True

    def _parse_excel(self, stream: BinaryIO) -> list[dict]:
        """Parse the active sheet of the workbook in *stream* into modules.

        Returns:
            The module list built by :meth:`_parse_rows`; an empty list when
            the workbook has no active sheet.
        """
        workbook = load_workbook(stream, data_only=True)
        try:
            sheet = workbook.active
            if sheet is None:
                return []
            return self._parse_rows(sheet.iter_rows(values_only=True))
        finally:
            workbook.close()

    def _parse_rows(self, rows) -> list[dict]:
        """Turn worksheet rows into a list of module dicts.

        Args:
            rows: Iterable of row tuples holding raw cell values.

        Row handling, in order of precedence:

        * rows whose cells are all empty are skipped;
        * a row containing both ``数据名称`` and ``数据标记`` becomes the
          current header row;
        * a row whose only non-empty cell is a non-numeric first cell
          starts a new module named after that cell;
        * any other row is a field row, read through the current headers.
          Field rows seen before any header row are ignored.
        """
        modules: list[dict] = []
        current: dict | None = None
        headers: list[str] = []
        for row in rows:
            values = ["" if value is None else str(value).strip() for value in row]
            if not any(values):
                continue
            if "数据名称" in values and "数据标记" in values:
                headers = values
                continue
            if values[0] and not values[0].isdigit() and not any(values[1:]):
                current = {"module_name": values[0], "table_name": "", "description": "", "fields": []}
                modules.append(current)
                continue
            if not headers:
                continue
            if current is None:
                # Field rows that precede any module title go to a default module.
                current = {"module_name": "默认模块", "table_name": "", "description": "", "fields": []}
                modules.append(current)
            # zip stops at the shorter sequence, so ragged rows are tolerated.
            row_map = dict(zip(headers, values))
            marker = row_map.get("数据标记", "")
            field_name = row_map.get("数据名称", "")
            table_name = row_map.get("表名", "")
            # The first non-empty table name seen within a module wins.
            if table_name and not current.get("table_name"):
                current["table_name"] = table_name
            if marker or field_name:
                current["fields"].append({
                    "seq": row_map.get("序号", ""),
                    "name": field_name,
                    "marker": marker,
                    "type": row_map.get("数据类型", ""),
                    "length": row_map.get("数据长度", ""),
                    "rule": row_map.get("数据填写规则", ""),
                    "required": row_map.get("数据填写要求", ""),
                    "strong_check": row_map.get("强弱校验", ""),
                })
        for module in modules:
            module["description"] = module.get("description") or self._fallback_description(module)
        return modules

    def _sanitize_description(self, response: str, module: dict) -> str:
        """Reduce *response* to a bare description sentence.

        The text is passed through ``strip_thinking`` (imported at file
        level). If it contains double-quoted passages of at least eight
        characters, the last one is kept. A leading preamble starting with
        ``好的``, ``我来`` or ``以下`` is then dropped. When the result is
        empty or still contains ``<think>``, the fallback description for
        *module* is returned instead.
        """
        cleaned = strip_thinking(response or "").strip()
        quoted = re.findall(r'"([^"]{8,})"', cleaned)
        if quoted:
            cleaned = quoted[-1].strip()
        cleaned = re.sub(r"^(好的|我来|以下).*?[。:]\s*", "", cleaned).strip()
        if "<think>" in cleaned or not cleaned:
            return self._fallback_description(module)
        return cleaned

    def _clean_descriptions_for_response(self, data: dict) -> dict:
        """Return a deep copy of *data* with every module description sanitized."""
        # A JSON round-trip yields an independent copy, so the caller's dict
        # is never mutated.
        cloned = json.loads(json.dumps(data, ensure_ascii=False))
        for module in cloned.get("modules", []):
            module["description"] = self._sanitize_description(module.get("description", ""), module)
        return cloned

    def _fallback_description(self, module: dict) -> str:
        """Build a template description from the module name and field names.

        Only the first six fields are considered, and those without a name
        are skipped.
        """
        name = module.get("module_name", "数据表")
        fields = "".join(field.get("name", "") for field in module.get("fields", [])[:6] if field.get("name"))
        return f"{name}表用于记录和管理{fields or '关键业务'}等信息,支撑查询、校验和监管规则生成。"

    def _format_time(self, value: str) -> str:
        """Format an ISO-8601 timestamp as ``YYYY-MM-DD HH:MM:SS``.

        Empty input yields ``""``; text that is not valid ISO-8601 is
        returned unchanged.
        """
        if not value:
            return ""
        try:
            return datetime.fromisoformat(value).strftime("%Y-%m-%d %H:%M:%S")
        except ValueError:
            return value