xref: /aosp_15_r20/external/pytorch/tools/linter/adapters/bazel_linter.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""
This linter ensures that users don't set a SHA-256 checksum on Bazel http_archive rules
whose archives are downloaded from sites with unstable checksums, such as GitHub.
Although pinning a checksum is good security practice, it doesn't work for such archives
because the generated archive bytes, and hence the checksum, can change: GitHub gives no
guarantee to keep the same value forever. See https://github.com/community/community/discussions/46034 for details.
7"""
8
9from __future__ import annotations
10
11import argparse
12import json
13import re
14import shlex
15import subprocess
16import sys
17import xml.etree.ElementTree as ET
18from enum import Enum
19from typing import NamedTuple
20from urllib.parse import urlparse
21
22
# Identifier reported in the `code` field of lint messages from this linter.
LINTER_CODE = "BAZEL_LINTER"
# Matches (from the start of a line) `sha256 = "<64 hex digits>",` with single
# or double quotes and arbitrary surrounding whitespace. The digest is captured
# as the `sha256` group. A SHA-256 digest is hexadecimal, so only hex digits
# are accepted. The trailing comma is required; since check_bazel drops a
# whole matching line, this presumably guards against deleting a line that
# also holds a closing parenthesis.
SHA256_REGEX = re.compile(r"\s*sha256\s*=\s*['\"](?P<sha256>[a-fA-F0-9]{64})['\"]\s*,")
# Hosts whose downloadable archives are not guaranteed to keep a stable
# checksum; checksums for archives fetched from them are disallowed.
DOMAINS_WITH_UNSTABLE_CHECKSUM = {"github.com"}
26
27
class LintSeverity(str, Enum):
    """
    Severity level attached to each LintMessage.

    Subclasses ``str`` so that members serialize as their plain string value
    (e.g. ``"advice"``) when a message is passed through ``json.dumps``.
    """

    ERROR = "error"
    WARNING = "warning"
    ADVICE = "advice"
    DISABLED = "disabled"
33
34
class LintMessage(NamedTuple):
    """
    One diagnostic produced by this linter.

    Messages are printed as one JSON object per line (via ``_asdict()`` and
    ``json.dumps``), so the field names below become the JSON keys and their
    order is part of the output format.
    """

    # File the message refers to; None for tool-level failures.
    path: str | None
    # Line number, if any (every message built in this file passes None).
    line: int | None
    # Column, if any (every message built in this file passes None).
    char: int | None
    # Identifier of the linter that produced the message.
    code: str
    # How serious the finding is.
    severity: LintSeverity
    # Short message kind, e.g. "format" or "command-failed".
    name: str
    # Full file contents before the suggested fix; None when no fix is offered.
    original: str | None
    # Full file contents after the suggested fix; None when no fix is offered.
    replacement: str | None
    # Human-readable explanation of the finding or failure.
    description: str | None
45
46
def is_required_checksum(
    urls: list[str | None],
    unstable_domains: set[str] | None = None,
) -> bool:
    """
    Return whether a checksum may be kept for an archive fetched from ``urls``.

    Returns False (checksum not required, hence disallowed) when ``urls`` is
    empty, or when any URL's host is one of ``unstable_domains`` or a
    subdomain of one. For example, ``codeload.github.com`` serves the same
    generated archives as ``github.com``. ``None`` and empty entries are
    ignored. Otherwise returns True.

    Args:
        urls: URLs of an http_archive rule; entries may be None.
        unstable_domains: Hostnames whose archives have unstable checksums.
            Defaults to ``DOMAINS_WITH_UNSTABLE_CHECKSUM``.
    """
    if not urls:
        return False

    if unstable_domains is None:
        unstable_domains = DOMAINS_WITH_UNSTABLE_CHECKSUM
    # urlparse().hostname is lower-cased, so compare against lower-cased domains.
    domains = [domain.lower() for domain in unstable_domains]

    for url in urls:
        if not url:
            continue

        hostname = urlparse(url).hostname
        if hostname is None:
            continue

        if any(hostname == d or hostname.endswith(f".{d}") for d in domains):
            return False

    return True
60
61
def _rule_urls(rule: ET.Element) -> list[str]:
    """Return the non-empty URLs of an http_archive rule from ``urls`` and ``url``."""
    values: list[str | None] = []

    urls_node = rule.find('.//list[@name="urls"]')
    if urls_node is not None:
        values.extend(n.get("value") for n in urls_node.findall(".//string"))

    # http_archive also accepts a single `url` attribute instead of `urls`.
    url_node = rule.find('.//string[@name="url"]')
    if url_node is not None:
        values.append(url_node.get("value"))

    # Unset attributes may show up with an empty value; drop those.
    return [value for value in values if value]


def get_disallowed_checksums(
    binary: str,
) -> set[str]:
    """
    Return the set of disallowed checksums from all http_archive rules.

    Runs ``<binary> query "kind(http_archive, //external:*)" --output=xml``. It
    collects the non-empty ``sha256`` of every rule whose URLs fail
    ``is_required_checksum``. URLs come from both the ``urls`` list and the
    singular ``url`` attribute. Rules with no URL or no checksum are skipped,
    so their checksums are never reported.

    Raises:
        subprocess.CalledProcessError: if the bazel query exits non-zero.
        OSError: if ``binary`` cannot be executed.
        xml.etree.ElementTree.ParseError: if the query output is not valid XML.
    """
    # Use bazel to get the list of external dependencies in XML format
    proc = subprocess.run(
        [binary, "query", "kind(http_archive, //external:*)", "--output=xml"],
        capture_output=True,
        check=True,
        text=True,
    )

    root = ET.fromstring(proc.stdout)

    disallowed_checksums: set[str] = set()
    for rule in root.findall('.//rule[@class="http_archive"]'):
        urls = _rule_urls(rule)
        # With no URL there is nothing to judge: keep the checksum rather than
        # treating an empty list as "checksum not required".
        if not urls:
            continue

        checksum_node = rule.find('.//string[@name="sha256"]')
        checksum = checksum_node.get("value") if checksum_node is not None else None
        if not checksum:
            continue

        if not is_required_checksum(urls):
            disallowed_checksums.add(checksum)

    return disallowed_checksums
98
99
def check_bazel(
    filename: str,
    disallowed_checksums: set[str],
) -> list[LintMessage]:
    """
    Lint ``filename`` for checksum lines whose digest is disallowed.

    Every line matching ``SHA256_REGEX`` whose captured digest is in
    ``disallowed_checksums`` is dropped from the proposed replacement. If no
    line is dropped, an empty list is returned. Otherwise a single
    ADVICE-level "format" LintMessage is returned, carrying the full original
    and replacement file contents.
    """
    original_lines: list[str] = []
    replacement_lines: list[str] = []

    # Bazel files are UTF-8; don't depend on the platform's locale encoding.
    with open(filename, encoding="utf-8") as f:
        for line in f:
            original_lines.append(line)

            m = SHA256_REGEX.match(line)
            if m and m.group("sha256") in disallowed_checksums:
                # Drop the whole line holding the disallowed checksum.
                continue

            replacement_lines.append(line)

    original = "".join(original_lines)
    replacement = "".join(replacement_lines)
    if original == replacement:
        return []

    return [
        LintMessage(
            path=filename,
            line=None,
            char=None,
            code=LINTER_CODE,
            severity=LintSeverity.ADVICE,
            name="format",
            original=original,
            replacement=replacement,
            description="Found redundant SHA checksums. Run `lintrunner -a` to apply this patch.",
        )
    ]
136
137
def main() -> None:
    """
    Command-line entry point.

    Queries bazel (``--binary``) once for the disallowed checksums. It then
    lints every path in ``filenames`` and prints each resulting LintMessage
    as one line of JSON on stdout. If the query fails, a single
    ``command-failed`` message is printed instead and no file is linted.
    """
    parser = argparse.ArgumentParser(
        description="A custom linter to detect redundant SHA checksums in Bazel",
        fromfile_prefix_chars="@",
    )
    parser.add_argument(
        "--binary",
        required=True,
        help="bazel binary path",
    )
    parser.add_argument(
        "filenames",
        nargs="+",
        help="paths to lint",
    )
    args = parser.parse_args()

    try:
        disallowed_checksums = get_disallowed_checksums(args.binary)
    except subprocess.CalledProcessError as err:
        # bazel ran but exited non-zero: report the command and its output at
        # ADVICE severity, tagged with LINTER_CODE like every other message
        # this linter emits.
        err_msg = LintMessage(
            path=None,
            line=None,
            char=None,
            code=LINTER_CODE,
            severity=LintSeverity.ADVICE,
            name="command-failed",
            original=None,
            replacement=None,
            description=(
                f"COMMAND (exit code {err.returncode})\n"
                f"{shlex.join(err.cmd)}\n\n"
                f"STDERR\n{err.stderr or '(empty)'}\n\n"
                f"STDOUT\n{err.stdout or '(empty)'}"
            ),
        )
        print(json.dumps(err_msg._asdict()), flush=True)
        return
    except Exception as e:
        # Any other failure (e.g. the binary could not be executed, or its
        # output was not valid XML) is reported as an ERROR.
        err_msg = LintMessage(
            path=None,
            line=None,
            char=None,
            code=LINTER_CODE,
            severity=LintSeverity.ERROR,
            name="command-failed",
            original=None,
            replacement=None,
            description=(f"Failed due to {e.__class__.__name__}:\n{e}"),
        )
        print(json.dumps(err_msg._asdict()), flush=True)
        # NOTE(review): exits 0 presumably so the caller parses the JSON
        # message rather than treating the linter itself as crashed.
        sys.exit(0)

    for filename in args.filenames:
        for lint_message in check_bazel(filename, disallowed_checksums):
            print(json.dumps(lint_message._asdict()), flush=True)
194
195
# Run the linter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
198