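# xref: /aosp_15_r20/external/pytorch/tools/linter/adapters/grep_linter.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
"""
Generic linter that greps files for an extended regular expression and
optionally suggests a `sed -r` based replacement.

Each finding is printed to stdout as one JSON-encoded LintMessage per line.
"""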
4
from __future__ import annotations

import argparse
import json
import logging
import os
import re
import subprocess
import sys
import time
from enum import Enum
from typing import Any, NamedTuple
16
17
# True when running on Windows (os.name == "nt"); as_posix() uses it to decide
# whether backslash separators need normalizing.
IS_WINDOWS: bool = "nt" == os.name
19
20
def eprint(*args: Any, **kwargs: Any) -> None:
    """Behave like print(), but write to stderr and flush immediately.

    Extra keyword arguments (sep, end, ...) are forwarded to print().
    """
    target = sys.stderr
    print(*args, file=target, flush=True, **kwargs)
23
24
class LintSeverity(str, Enum):
    """Severity level carried by every LintMessage.

    Members subclass ``str``, so each compares equal to its lowercase value
    and ``json.dumps`` encodes it as that plain string.
    """

    ERROR = "error"
    WARNING = "warning"
    ADVICE = "advice"
    DISABLED = "disabled"
30
31
class LintMessage(NamedTuple):
    """A single finding, emitted on stdout as JSON via ``_asdict()``.

    Fields that do not apply to a given finding are ``None``.
    """

    # File the finding refers to; None for failures not tied to one file.
    path: str | None
    # Line number taken from grep output, when one is available.
    line: int | None
    # Column of the finding; every message built in this script leaves it None.
    char: int | None
    # Identifier of the linter that produced the finding.
    code: str
    severity: LintSeverity
    # Short error identifier (e.g. the configured error name or "command-failed").
    name: str
    # Full file text before / after the suggested replacement, when one exists.
    original: str | None
    replacement: str | None
    # Human-readable explanation of the finding.
    description: str | None
42
43
def as_posix(name: str) -> str:
    """Return *name* with backslashes turned into forward slashes on Windows.

    On every other platform the string is returned unchanged.
    """
    if not IS_WINDOWS:
        return name
    return name.replace("\\", "/")
46
47
def run_command(
    args: list[str],
) -> subprocess.CompletedProcess[bytes]:
    """Run *args* without a shell, capturing stdout and stderr as bytes.

    The command line and its wall-clock duration are logged at DEBUG level.
    A non-zero exit status is not treated as an error: it is left on the
    returned CompletedProcess for the caller to inspect. Exceptions raised
    while launching the process propagate after the timing is logged.
    """
    logging.debug("$ %s", " ".join(args))
    started = time.monotonic()
    try:
        completed = subprocess.run(args, capture_output=True)
    finally:
        elapsed_ms = (time.monotonic() - started) * 1000
        logging.debug("took %dms", elapsed_ms)
    return completed
61
62
# Matches the "path:lineno:" prefix that `grep -nH` puts on every hit. The
# path is matched lazily up to the first ":<digits>:" so that a colon inside
# the path itself (e.g. a Windows drive letter) does not truncate it.
_GREP_HIT_RE = re.compile(r"^(?P<path>.*?):(?P<line>\d+):")


def _parse_grep_hit(matching_line: str) -> tuple[str, int | None]:
    """Split one line of grep output into ``(path, line_number)``.

    ``grep -nH`` prints ``path:lineno:text``; with ``--files-with-matches`` it
    prints only ``path``, in which case the line number is ``None``.
    """
    match = _GREP_HIT_RE.match(matching_line)
    if match is None:
        return matching_line, None
    return match.group("path"), int(match.group("line"))


def _check_returncode(
    proc: subprocess.CompletedProcess[bytes], ok_codes: tuple[int, ...] = (0,)
) -> None:
    """Raise CalledProcessError unless *proc* exited with one of *ok_codes*."""
    if proc.returncode not in ok_codes:
        raise subprocess.CalledProcessError(
            proc.returncode, proc.args, output=proc.stdout, stderr=proc.stderr
        )


def _command_failed(linter_name: str, err: Exception) -> LintMessage:
    """Build the ``command-failed`` LintMessage for an error while linting.

    A CalledProcessError is rendered with its command line, exit code and
    captured output; any other exception is rendered by type and message.
    """
    if isinstance(err, subprocess.CalledProcessError):
        description = (
            "COMMAND (exit code {returncode})\n"
            "{command}\n\n"
            "STDERR\n{stderr}\n\n"
            "STDOUT\n{stdout}"
        ).format(
            returncode=err.returncode,
            command=" ".join(as_posix(x) for x in err.cmd),
            # Captured output may be missing or not valid UTF-8; formatting
            # the report must never raise a second exception.
            stderr=(err.stderr or b"").decode("utf-8", errors="replace").strip()
            or "(empty)",
            stdout=(err.stdout or b"").decode("utf-8", errors="replace").strip()
            or "(empty)",
        )
    else:
        description = f"Failed due to {err.__class__.__name__}:\n{err}"
    return LintMessage(
        path=None,
        line=None,
        char=None,
        code=linter_name,
        severity=LintSeverity.ERROR,
        name="command-failed",
        original=None,
        replacement=None,
        description=description,
    )


def lint_file(
    matching_line: str,
    allowlist_pattern: str,
    replace_pattern: str,
    linter_name: str,
    error_name: str,
    error_description: str,
) -> LintMessage | None:
    """Turn one line of grep output into a LintMessage.

    Args:
        matching_line: one line of ``grep -nEHI`` output, e.g.
            ``tools/linter/clangtidy_linter.py:13:import foo.bar.baz``, or a
            bare path when grep ran with ``--files-with-matches``.
        allowlist_pattern: extended regex (``grep -E``); if it matches
            anywhere in the file, the file is exempt. Falsy disables it.
        replace_pattern: ``sed -r`` script; if set, the message carries the
            file's current text as ``original`` and sed's output as
            ``replacement``. Falsy disables it.
        linter_name: value for the message's ``code`` field.
        error_name: value for the message's ``name`` field.
        error_description: value for the message's ``description`` field.

    Returns:
        ``None`` if the allowlist pattern matched; a ``command-failed``
        message if grep, sed or reading the file failed; otherwise the lint
        message for this hit.
    """
    filename, line = _parse_grep_hit(matching_line)

    if allowlist_pattern:
        try:
            # "-e" keeps a pattern starting with "-" from being read as an option.
            proc = run_command(["grep", "-nEHI", "-e", allowlist_pattern, filename])
            # grep exits 0 on a match and 1 on no match; anything else (an
            # invalid regex, an unreadable file) is an error, which must not
            # be mistaken for "not allowlisted".
            _check_returncode(proc, ok_codes=(0, 1))
        except Exception as err:
            return _command_failed(linter_name, err)

        # allowlist pattern was found, abort lint
        if proc.returncode == 0:
            return None

    original = None
    replacement = None
    if replace_pattern:
        try:
            # Read raw bytes and decode them exactly like sed's output below,
            # so encoding and line endings (e.g. CRLF) match and the diff
            # shows only the actual substitution.
            with open(filename, "rb") as f:
                original = f.read().decode("utf-8")
            proc = run_command(["sed", "-r", replace_pattern, filename])
            # A failing sed leaves empty or partial stdout; proposing that as
            # the replacement would wipe or truncate the whole file.
            _check_returncode(proc)
            replacement = proc.stdout.decode("utf-8")
        except Exception as err:
            return _command_failed(linter_name, err)

    return LintMessage(
        path=filename,
        line=line,
        char=None,
        code=linter_name,
        severity=LintSeverity.ERROR,
        name=error_name,
        original=original,
        replacement=replacement,
        description=error_description,
    )
157
158
def main() -> None:
    """Command-line entry point: grep the given files and print lint messages.

    Runs ``grep -nEHI`` for ``--pattern`` over every filename, converts each
    hit into a LintMessage with ``lint_file`` and prints each message to
    stdout as one JSON object per line. Failures of grep itself are reported
    as a ``command-failed`` message instead of being raised.
    """
    parser = argparse.ArgumentParser(
        description="grep wrapper linter.",
        fromfile_prefix_chars="@",
    )
    parser.add_argument(
        "--pattern",
        required=True,
        help="pattern to grep for",
    )
    parser.add_argument(
        "--allowlist-pattern",
        help="if this pattern is true in the file, we don't grep for pattern",
    )
    parser.add_argument(
        "--linter-name",
        required=True,
        help="name of the linter",
    )
    parser.add_argument(
        "--match-first-only",
        action="store_true",
        help="only match the first hit in the file",
    )
    parser.add_argument(
        "--error-name",
        required=True,
        help="human-readable description of what the error is",
    )
    parser.add_argument(
        "--error-description",
        required=True,
        help="message to display when the pattern is found",
    )
    parser.add_argument(
        "--replace-pattern",
        help=(
            "the form of a pattern passed to `sed -r`. "
            "If specified, this will become proposed replacement text."
        ),
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="verbose logging",
    )
    parser.add_argument(
        "filenames",
        nargs="+",
        help="paths to lint",
    )
    args = parser.parse_args()

    # NOTSET and DEBUG both let run_command's per-command debug lines through;
    # runs over 1000+ files drop to INFO (presumably to limit stderr noise).
    logging.basicConfig(
        format="<%(threadName)s:%(levelname)s> %(message)s",
        level=logging.NOTSET
        if args.verbose
        else logging.DEBUG
        if len(args.filenames) < 1000
        else logging.INFO,
        stream=sys.stderr,
    )

    def command_failed(err: Exception) -> LintMessage:
        """Wrap a failure to run grep in a ``command-failed`` LintMessage."""
        if isinstance(err, subprocess.CalledProcessError):
            description = (
                "COMMAND (exit code {returncode})\n"
                "{command}\n\n"
                "STDERR\n{stderr}\n\n"
                "STDOUT\n{stdout}"
            ).format(
                returncode=err.returncode,
                command=" ".join(as_posix(x) for x in err.cmd),
                stderr=(err.stderr or b"").decode("utf-8", errors="replace").strip()
                or "(empty)",
                stdout=(err.stdout or b"").decode("utf-8", errors="replace").strip()
                or "(empty)",
            )
        else:
            description = f"Failed due to {err.__class__.__name__}:\n{err}"
        return LintMessage(
            path=None,
            line=None,
            char=None,
            code=args.linter_name,
            severity=LintSeverity.ERROR,
            name="command-failed",
            original=None,
            replacement=None,
            description=description,
        )

    # With --files-with-matches grep prints each matching filename once (and
    # no line number), so at most one message is produced per file.
    files_with_matches = ["--files-with-matches"] if args.match_first_only else []

    try:
        # "-e" keeps a --pattern starting with "-" from being read as an option.
        proc = run_command(
            ["grep", "-nEHI", *files_with_matches, "-e", args.pattern, *args.filenames]
        )
    except Exception as err:
        print(json.dumps(command_failed(err)._asdict()), flush=True)
        sys.exit(0)

    # grep exits 0 when something matched and 1 when nothing did; any other
    # status means an error (e.g. an invalid --pattern or an unreadable file).
    # Treating that as "no matches" would let a broken linter silently pass,
    # so report it, then still lint whatever hits grep did print.
    if proc.returncode not in (0, 1):
        failure = subprocess.CalledProcessError(
            proc.returncode, proc.args, output=proc.stdout, stderr=proc.stderr
        )
        print(json.dumps(command_failed(failure)._asdict()), flush=True)

    # grep ends every hit with "\n". str.splitlines() would also split inside
    # the matched text at characters such as form feed or U+2028, turning the
    # tail into a bogus "filename". Only the path/line prefix is consumed, so
    # undecodable bytes in the matched text are replaced instead of crashing.
    output = proc.stdout.decode("utf-8", errors="replace")
    for line in output.split("\n"):
        if not line:
            continue
        lint_message = lint_file(
            line,
            args.allowlist_pattern,
            args.replace_pattern,
            args.linter_name,
            args.error_name,
            args.error_description,
        )
        if lint_message is not None:
            print(json.dumps(lint_message._asdict()), flush=True)


if __name__ == "__main__":
    main()
275