# File: YG-Rules/skill/yg-rules-pipeline/scripts/run_pipeline.py
# Source-viewer listing metadata: 182 lines, 6.7 KiB, Python

import argparse
import io
import json
import sys
import time
from pathlib import Path
from typing import Any
# Suffixes accepted as guidance documents. Compared against the lower-cased
# file suffix, so every entry here must be lower-case with a leading dot.
GUIDANCE_EXTENSIONS = {
    ".txt",
    ".pdf",
    ".doc",
    ".docx",
    ".md",
}

# Suffixes tried, in this order, when looking for the "domains" input file.
DOMAIN_EXTENSIONS = (
    ".xlsx",
    ".xls",
    ".csv",
    ".json",
)

# Suffixes tried, in this order, when looking for the "schema" input file.
SCHEMA_EXTENSIONS = (
    ".xlsx",
    ".xls",
)
def repo_root() -> Path:
    """Return the directory used as the repository root.

    The root is the fourth ancestor of this script's resolved path:
    ``parents[0]`` is the directory holding the script, ``parents[3]`` is
    three levels above that.
    """
    script_path = Path(__file__).resolve()
    ancestors = script_path.parents
    return ancestors[3]
def import_project() -> None:
    """Put the repository root at the front of ``sys.path``.

    Does nothing when the root is already listed, so repeated calls do not
    add duplicate entries.
    """
    root_entry = str(repo_root())
    if root_entry in sys.path:
        return
    sys.path.insert(0, root_entry)
def find_named_file(input_dir: Path, stem: str, extensions: tuple[str, ...]) -> Path | None:
for extension in extensions:
candidate = input_dir / f"{stem}{extension}"
if candidate.exists():
return candidate
return None
def guidance_files_for_domain(guidance_dir: Path, domain_name: str) -> list[Path]:
    """Collect the guidance files that apply to one domain.

    Files are gathered recursively from ``<guidance_dir>/_all`` first and
    then from ``<guidance_dir>/<domain_name>``; within each folder they are
    returned in sorted path order. Only regular files whose lower-cased
    suffix is in ``GUIDANCE_EXTENSIONS`` are kept.

    Args:
        guidance_dir: Root folder containing the per-domain guidance folders.
        domain_name: Name of the domain-specific sub-folder.

    Returns:
        The matching file paths; empty when neither folder exists.
    """
    # dict.fromkeys de-duplicates while keeping order, so a domain literally
    # named "_all" is not scanned (and its files returned) twice. Blank names
    # are dropped because ``guidance_dir / ""`` is guidance_dir itself, which
    # would sweep in every other domain's files.
    folder_names = [name for name in dict.fromkeys(("_all", domain_name)) if name]

    files: list[Path] = []
    for folder_name in folder_names:
        folder = guidance_dir / folder_name
        # is_dir() is already False for a missing path.
        if not folder.is_dir():
            continue
        files.extend(
            path
            for path in sorted(folder.rglob("*"))
            if path.is_file() and path.suffix.lower() in GUIDANCE_EXTENSIONS
        )
    return files
def wait_for_rule_task(service: Any, task_id: str, timeout: int, interval: float) -> dict[str, Any]:
    """Poll a rule-generation task until it reaches a terminal status.

    Args:
        service: Object exposing ``get_status(task_id)``, which returns a
            state dict (or a falsy value when no state is available).
        task_id: Identifier passed to ``service.get_status``.
        timeout: Maximum time to wait, in seconds.
        interval: Pause between polls, in seconds.

    Returns:
        The first state whose ``"status"`` is ``"done"`` or ``"failed"``.

    Raises:
        TimeoutError: If no terminal status was seen before the deadline.
            The message includes the last non-empty state observed.
    """
    # Monotonic clock: the deadline must not move when the wall clock is
    # adjusted while we are waiting.
    deadline = time.monotonic() + timeout
    last_state: dict[str, Any] | None = None
    while True:
        # Poll before checking the deadline so the task is queried at least
        # once (even with timeout <= 0) and once more right at the deadline.
        state = service.get_status(task_id)
        if state:
            last_state = state
            if state.get("status") in {"done", "failed"}:
                return state
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Never sleep past the deadline.
        time.sleep(min(interval, remaining))
    raise TimeoutError(f"Timed out waiting for rule task {task_id}. Last state: {last_state}")
def validate_output_dir(output_dir: str) -> list[str]:
    """Validate a task output directory with the yg-rules-output validator.

    The validator lives in ``<repo root>/skill/yg-rules-output/scripts``;
    that directory is put at the front of ``sys.path`` (once) so that
    ``validate_task_output`` can be imported from it.

    Args:
        output_dir: Path of the directory to validate.

    Returns:
        Whatever ``validate_task_output`` returns for the directory, or a
        single-element list describing the failure when the validator
        cannot be imported.
    """
    validator_dir = repo_root() / "skill" / "yg-rules-output" / "scripts"
    validator_entry = str(validator_dir)
    if validator_entry not in sys.path:
        sys.path.insert(0, validator_entry)
    try:
        from validate_task_output import validate_task_output
    except Exception as exc:
        # Broad on purpose: any import-time failure of the validator is
        # reported as a validation error instead of aborting the run.
        return [f"Could not import output validator: {exc}"]
    return validate_task_output(Path(output_dir))
def run_pipeline(args: argparse.Namespace) -> dict[str, Any]:
    """Run the pipeline for one input folder and return a summary dict.

    Steps, in order: parse and save the ``domains`` file, save the
    ``schema`` file (unless ``args.skip_schema``), upload the guidance files
    of every domain and analyze them, start rule generation, wait for the
    task to finish, and validate the task's output directory when the final
    state reports one.

    Args:
        args: Parsed CLI options. Reads ``input``, ``skip_schema``,
            ``granularity``, ``allow_no_guidance``, ``create_sql``,
            ``limit``, ``timeout`` and ``interval``.

    Returns:
        A dict with the keys ``input_dir``, ``domains_file``,
        ``domain_count``, ``schema_file``, ``schema_module_count``,
        ``uploaded_guidance_count``, ``task`` and ``validation_errors``.

    Raises:
        FileNotFoundError: The input directory, the domains file, or the
            schema file (when not skipped) is missing.
        ValueError: No guidance file was uploaded and
            ``args.allow_no_guidance`` is not set.
        TimeoutError: Rule generation did not finish within ``args.timeout``.
    """
    import_project()
    # Project modules are imported only after the repo root is on sys.path.
    from app.utils.parser import parse_upload_file
    from app.utils.rule_generation import RuleGenerationService
    from app.utils.schema_storage import SchemaStorage
    from app.utils.storage import DomainStorage

    input_dir = args.input.resolve()
    # is_dir() is already False for a path that does not exist.
    if not input_dir.is_dir():
        raise FileNotFoundError(f"Input directory not found: {input_dir}")

    # --- domains -----------------------------------------------------------
    domains_file = find_named_file(input_dir, "domains", DOMAIN_EXTENSIONS)
    if not domains_file:
        raise FileNotFoundError(f"Missing domains file in {input_dir}")
    with domains_file.open("rb") as file:
        domain_rows = parse_upload_file(file, domains_file.name)
    domain_storage = DomainStorage()
    domains = domain_storage.save_domains(domain_rows)

    # --- schema ------------------------------------------------------------
    # Keep the resolved path so the summary reports the file that was
    # actually saved instead of searching the folder a second time.
    schema_file: Path | None = None
    schema_result = None
    if not args.skip_schema:
        schema_file = find_named_file(input_dir, "schema", SCHEMA_EXTENSIONS)
        if not schema_file:
            raise FileNotFoundError(f"Missing schema file in {input_dir}; pass --skip-schema to reuse existing schema")
        schema_result = SchemaStorage().save(
            source_type="file",
            content=schema_file.read_bytes(),
            filename=schema_file.name,
        )

    # --- guidance ----------------------------------------------------------
    guidance_dir = input_dir / "guidance"
    uploaded_guidance: list[dict[str, Any]] = []
    if guidance_dir.exists():
        for domain in domains:
            for path in guidance_files_for_domain(guidance_dir, domain["domain"]):
                uploaded = domain_storage.save_guidance_file(
                    domain["token"],
                    io.BytesIO(path.read_bytes()),
                    path.name,
                )
                uploaded_guidance.append({
                    "domain": domain["domain"],
                    "token": domain["token"],
                    "filename": uploaded["filename"],
                })
    if uploaded_guidance:
        # NOTE(review): token_id="" presumably selects every domain --
        # confirm against DomainStorage.analyze_guidance.
        domain_storage.analyze_guidance(
            analysis_options={"granularity": args.granularity},
            token_id="",
        )
    elif not args.allow_no_guidance:
        raise ValueError(f"No guidance files found under {guidance_dir}")

    # --- rule generation ---------------------------------------------------
    service = RuleGenerationService(create_sql=args.create_sql)
    initial_state = service.start(limit=args.limit)
    final_state = wait_for_rule_task(service, initial_state["task_id"], args.timeout, args.interval)

    # Output can only be validated when the task reported where it wrote it.
    validation_errors = []
    if final_state.get("output_dir"):
        validation_errors = validate_output_dir(final_state["output_dir"])

    return {
        "input_dir": str(input_dir),
        "domains_file": str(domains_file),
        "domain_count": len(domains),
        "schema_file": "" if schema_file is None else str(schema_file),
        "schema_module_count": len(schema_result.get("modules", [])) if schema_result else None,
        "uploaded_guidance_count": len(uploaded_guidance),
        "task": final_state,
        "validation_errors": validation_errors,
    }
def main() -> int:
parser = argparse.ArgumentParser(description="Run YG-Rules directly from an input folder.")
parser.add_argument("--input", type=Path, default=Path("input"), help="Input folder path.")
parser.add_argument("--limit", type=int, default=2, help="Rules per policy point, 1-30.")
parser.add_argument(
"--granularity",
choices=["low", "high", "coarse", "medium", "fine"],
default="low",
help="Guidance extraction granularity. Use low or high; old coarse/medium/fine values are accepted.",
)
parser.add_argument("--create-sql", action="store_true", help="Include SQL output.")
parser.add_argument("--skip-schema", action="store_true", help="Reuse existing data/schema.json.")
parser.add_argument("--allow-no-guidance", action="store_true", help="Allow running without guidance files.")
parser.add_argument("--timeout", type=int, default=900, help="Rule generation timeout in seconds.")
parser.add_argument("--interval", type=float, default=2.0, help="Polling interval in seconds.")
args = parser.parse_args()
try:
result = run_pipeline(args)
except Exception as exc:
print(f"ERROR: {exc}", file=sys.stderr)
return 1
print(json.dumps(result, ensure_ascii=False, indent=2))
task = result.get("task", {})
if result.get("validation_errors"):
return 1
return 0 if task.get("status") == "done" else 1
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())