xref: /aosp_15_r20/external/pytorch/tools/extract_scripts.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2
3from __future__ import annotations
4
5import argparse
6import re
7import sys
8from pathlib import Path
9from typing import Any, Dict
10from typing_extensions import TypedDict  # Python 3.11+
11
12import yaml
13
14
# A single workflow step as loaded from YAML: a mapping from step keys
# (such as "run", "shell", "uses", "with", "name") to their values.
Step = Dict[str, Any]
16
17
class Script(TypedDict):
    """A script extracted from a workflow step, ready to be written to disk."""

    # File suffix to use for the script, including the leading dot (e.g. ".sh").
    extension: str
    # Full text of the script.
    script: str
21
22
def extract(step: Step) -> Script | None:
    """Pull the embedded script, if any, out of a single workflow step.

    Two kinds of steps yield a script:

    * A step with a ``run`` key whose ``shell`` (default ``"bash"``) is one of
      the recognized shells. For ``bash`` and ``sh`` a shebang line and a
      ``set`` line are prepended; other shells get the ``run`` text as-is.
    * A step that ``uses`` ``actions/github-script@...`` and provides
      ``with.script``; this is returned as a ``.js`` script.

    Returns ``None`` for every other step.
    """
    run = step.get("run")

    # https://docs.github.com/en/actions/reference/workflow-syntax-for-github-actions#using-a-specific-shell
    shell = step.get("shell", "bash")
    known_extensions = {
        "bash": ".sh",
        "pwsh": ".ps1",
        "python": ".py",
        "sh": ".sh",
        "cmd": ".cmd",
        "powershell": ".ps1",
    }
    extension = known_extensions.get(shell)

    uses_gh_script = step.get("uses", "").startswith("actions/github-script@")
    gh_script = step.get("with", {}).get("script")

    if run is not None and extension is not None:
        # Only the POSIX-style shells get a prologue; everything else is
        # emitted verbatim.
        if shell == "bash":
            body = f"#!/usr/bin/env bash\nset -eo pipefail\n{run}"
        elif shell == "sh":
            body = f"#!/usr/bin/env sh\nset -e\n{run}"
        else:
            body = run
        return {"extension": extension, "script": body}

    if uses_gh_script and gh_script is not None:
        return {"extension": ".js", "script": gh_script}

    return None
50
51
def main() -> None:
    """Write every script embedded in ``.github/workflows`` to ``--out``.

    For each workflow file, each job with a ``steps`` list, and each step for
    which ``extract`` returns a script, the script is written to::

        <out>/<workflow path>/<job name>/<index><_sanitized step name><extension>

    where ``<index>`` is the 1-based position of the step within its job,
    zero-padded so that all steps of the job have the same index width.

    Exits with an error message if ``--out`` already exists. After all scripts
    have been written, also exits with an error message if any extracted
    script contained a ``${{`` GitHub Actions expression; the offending steps
    are listed on stderr as they are encountered.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--out", required=True)
    args = parser.parse_args()

    out = Path(args.out)
    if out.exists():
        sys.exit(f"{out} already exists; aborting to avoid overwriting")

    gha_expressions_found = False

    for p in Path(".github/workflows").iterdir():
        with open(p, "rb") as f:
            workflow = yaml.safe_load(f)

        for job_name, job in workflow["jobs"].items():
            if "steps" not in job:
                continue
            job_dir = out / p / job_name
            steps = job["steps"]
            # Steps are numbered from 1, so the largest index is len(steps);
            # size the zero-padding to that so e.g. a 10-step job yields
            # 01..10 rather than 1..9 followed by 10.
            index_chars = len(str(len(steps)))
            for i, step in enumerate(steps, start=1):
                extracted = extract(step)
                if not extracted:
                    continue

                script = extracted["script"]
                step_name = step.get("name", "")
                if "${{" in script:
                    gha_expressions_found = True
                    print(
                        f"{p} job `{job_name}` step {i}: {step_name}",
                        file=sys.stderr,
                    )

                job_dir.mkdir(parents=True, exist_ok=True)

                # Collapse every run of characters other than ASCII letters
                # and underscores into a single "_"; the leading "_"
                # separates the name from the numeric index.
                sanitized = re.sub(
                    r"[^a-zA-Z_]+",
                    "_",
                    f"_{step_name}",
                ).rstrip("_")
                extension = extracted["extension"]
                filename = f"{i:0{index_chars}}{sanitized}{extension}"
                # Write UTF-8 explicitly so the output does not depend on the
                # locale's default encoding.
                (job_dir / filename).write_text(script, encoding="utf-8")

    if gha_expressions_found:
        sys.exit(
            "Each of the above scripts contains a GitHub Actions "
            "${{ <expression> }} which must be replaced with an `env` variable"
            " for security reasons."
        )
102
103
# Run the extraction only when this file is executed as a script.
if __name__ == "__main__":
    main()
106