118 lines
4.6 KiB
Python
118 lines
4.6 KiB
Python
"""领域和指引文件存储。"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
import uuid
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import BinaryIO
|
|
|
|
from app.utils.guidance_analysis import GuidanceAnalyzer
|
|
|
|
|
|
class DomainStorage:
    """JSON-file-backed store for risk domains and their uploaded guidance files.

    The store is one JSON document shaped like ``{"domains": [domain, ...]}``.
    Every mutating method reads the whole document, changes it in memory and
    writes the whole document back.
    """

    def __init__(self, file_path: str | None = None) -> None:
        """Open the domain store, creating an empty one if the file is missing.

        Args:
            file_path: Path of the JSON store. Defaults to
                ``<cwd>/data/domains.json``.
        """
        self.file_path = file_path or os.path.join(os.getcwd(), "data", "domains.json")
        self._ensure_parent_dir()
        if not os.path.exists(self.file_path):
            self._write({"domains": []})

    def _ensure_parent_dir(self) -> None:
        """Create the directory that contains ``self.file_path`` if needed."""
        # os.makedirs("") raises FileNotFoundError, which is what happens when
        # file_path is a bare file name such as "domains.json"; there is no
        # directory to create in that case.
        directory = os.path.dirname(self.file_path)
        if directory:
            os.makedirs(directory, exist_ok=True)

    def _read(self) -> dict:
        """Load the whole store; a missing file reads as an empty store."""
        if not os.path.exists(self.file_path):
            return {"domains": []}
        with open(self.file_path, "r", encoding="utf-8") as file:
            return json.load(file)

    def _write(self, data: dict) -> None:
        """Replace the store on disk with *data*.

        The JSON is written to a sibling temporary file first and then moved
        over the store with ``os.replace``. If serialization fails part-way
        (e.g. a value that is not JSON-serializable), the previous store is left
        intact instead of being truncated.
        """
        self._ensure_parent_dir()
        tmp_path = f"{self.file_path}.{uuid.uuid4().hex}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as file:
                json.dump(data, file, ensure_ascii=False, indent=2)
            os.replace(tmp_path, self.file_path)
        except BaseException:
            # Best-effort cleanup of the partial temp file; the original error
            # is what the caller needs to see, so it is always re-raised.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def list_domains(self) -> list[dict]:
        """Return all stored domain dicts (empty list if there are none)."""
        return self._read().get("domains", [])

    def save_domains(self, rows: list[dict]) -> list[dict]:
        """Replace the stored domain list with the domains described by *rows*.

        Each row names its domain under ``domain``, ``领域``, ``风险领域`` or
        ``name`` (first non-empty value wins) and may carry a note under
        ``note`` or ``备注``. Rows whose domain name is missing or blank are
        skipped. Rows without a ``token`` get a fresh UUID string.

        A row whose token matches an already stored domain keeps that domain's
        ``created_at`` and ``guidance_files``, so re-saving the domain list does
        not detach previously uploaded guidance files.

        Args:
            rows: Row dicts, e.g. from a table/spreadsheet import or from
                :meth:`list_domains`.

        Returns:
            The list of domain dicts that was written.
        """
        existing = {
            str(item["token"]): item
            for item in self.list_domains()
            if item.get("token")
        }
        now = datetime.now().isoformat()

        domains = []
        for row in rows:
            raw_name = row.get("domain") or row.get("领域") or row.get("风险领域") or row.get("name")
            # Strip before the emptiness check so whitespace-only names are
            # skipped rather than stored as "".
            name = str(raw_name).strip() if raw_name else ""
            if not name:
                continue

            # Tokens are compared as strings by _find_domain / callers, so
            # normalise them to str on the way in.
            token = str(row.get("token") or uuid.uuid4())
            previous = existing.get(token, {})
            domain = {
                "token": token,
                "domain": name,
                "note": str(row.get("note") or row.get("备注") or ""),
                "created_at": previous.get("created_at") or now,
            }
            if "guidance_files" in previous:
                domain["guidance_files"] = previous["guidance_files"]
            domains.append(domain)

        self._write({"domains": domains})
        return domains

    def clear_all(self) -> int:
        """Remove every stored domain and return how many there were.

        Only the JSON store is cleared; guidance files already written under
        ``data/guidance`` are not deleted here.
        """
        count = len(self.list_domains())
        self._write({"domains": []})
        return count

    def save_guidance_file(self, token_id: str, file_obj: BinaryIO, filename: str) -> dict:
        """Store an uploaded guidance file under the domain identified by *token_id*.

        The bytes are written to ``<cwd>/data/guidance/<token_id>/<uuid><ext>``,
        where ``<ext>`` is the suffix of *filename* (``.txt`` if it has none).
        A record describing the file is appended to the domain's
        ``guidance_files`` list and the store is saved.

        Args:
            token_id: Token of an already stored domain.
            file_obj: Binary stream; it is read to the end.
            filename: Original client-side file name, kept in the record.

        Returns:
            The new record, including the server-side ``stored_path``; pass it
            through :meth:`public_guidance_file` before sending it to a client.

        Raises:
            ValueError: If no stored domain has this token, or the token cannot
                be used as a single directory name.
        """
        data = self._read()
        domain = self._find_domain(data, token_id)
        # token_id becomes a directory name and tokens originate from
        # caller-supplied rows, so refuse anything that could leave the
        # guidance folder (e.g. "../x" or "a/b").
        if token_id in (".", "..") or Path(token_id).name != token_id:
            raise ValueError(f"非法的 token={token_id},不能用作目录名")

        content = file_obj.read()
        file_id = str(uuid.uuid4())
        extension = Path(filename).suffix or ".txt"
        folder = Path(os.getcwd()) / "data" / "guidance" / token_id
        folder.mkdir(parents=True, exist_ok=True)
        stored_path = folder / f"{file_id}{extension}"
        stored_path.write_bytes(content)

        record = {
            "file_id": file_id,
            "filename": filename,
            "file_size": len(content),
            "stored_path": str(stored_path),
            "uploaded_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        }
        # `domain` is a live reference into `data`, so appending here is what
        # gets persisted by the write below.
        domain.setdefault("guidance_files", []).append(record)
        self._write(data)
        return record

    @staticmethod
    def public_guidance_file(record: dict) -> dict:
        """Return a copy of *record* for the frontend, without the internal ``stored_path``."""
        return {key: value for key, value in record.items() if key != "stored_path"}

    def analyze_guidance(self, analysis_options: dict | None = None, token_id: str = "") -> list[dict]:
        """Analyze stored guidance files and persist the results on each record.

        Args:
            analysis_options: Optional dict; its ``granularity`` entry (default
                ``"high"``) is validated and normalized by ``GuidanceAnalyzer``.
            token_id: If non-empty, only that domain's files are analyzed;
                otherwise files of all domains are.

        Returns:
            The analyzed file records (including ``stored_path``), each with a
            ``guidance_analysis`` entry set.

        Raises:
            ValueError: For an unsupported granularity or an unknown token_id.
            OSError: If a stored file cannot be read; in that case nothing is
                written back, because the store is only saved after the loop.
        """
        options = analysis_options or {}
        analyzer = GuidanceAnalyzer()
        granularity = analyzer.normalize_granularity(options.get("granularity", "high"))
        # NOTE(review): the raw option value is validated (falling back to the
        # normalized value when absent); presumably is_supported_granularity
        # accepts the same aliases normalize_granularity does — confirm.
        if not analyzer.is_supported_granularity(options.get("granularity", granularity)):
            raise ValueError("不支持的 granularity")

        data = self._read()
        domains = data.get("domains", [])
        if token_id and not any(item.get("token") == token_id for item in domains):
            raise ValueError(f"找不到 token_id={token_id}")

        results = []
        for domain in domains:
            if token_id and domain.get("token") != token_id:
                continue
            for record in domain.get("guidance_files", []):
                analysis = self._analyze_guidance_file(domain, record, {"granularity": granularity})
                record["guidance_analysis"] = analysis
                results.append(record)
        self._write(data)
        return results

    def _analyze_guidance_file(self, domain: dict, record: dict, options: dict) -> dict:
        """Read one stored guidance file and return GuidanceAnalyzer's result for it."""
        path = Path(record["stored_path"])
        if not path.exists():
            # Fallback for paths recorded with Windows separators: map
            # backslashes to the current platform's separator and retry.
            path = Path(str(record["stored_path"]).replace("\\", os.sep))
        content = path.read_bytes()
        return GuidanceAnalyzer().analyze(domain, content, record.get("filename", path.name), options.get("granularity", "high"))

    def _find_domain(self, data: dict, token_id: str) -> dict:
        """Return the domain dict inside *data* whose ``token`` equals *token_id*.

        The returned dict is the same object stored in ``data["domains"]``, so
        mutating it mutates *data*.

        Raises:
            ValueError: If no domain has that token.
        """
        for domain in data.get("domains", []):
            if domain.get("token") == token_id:
                return domain
        raise ValueError(f"找不到 token={token_id},请先保存风险领域后再上传指引文件")