xref: /aosp_15_r20/external/pytorch/tools/linter/adapters/ruff_linter.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""Adapter for https://github.com/charliermarsh/ruff."""
2
3from __future__ import annotations
4
5import argparse
6import concurrent.futures
7import dataclasses
8import enum
9import json
10import logging
11import os
12import subprocess
13import sys
14import time
15from typing import Any, BinaryIO
16
17
# Linter code placed in the `code` field of every lint message this adapter emits.
LINTER_CODE = "RUFF"
# True when os.name reports "nt" (Windows); as_posix() consults this flag to
# decide whether backslashes need to be converted to forward slashes.
IS_WINDOWS: bool = os.name == "nt"
20
21
def eprint(*args: Any, **kwargs: Any) -> None:
    """Write the given values to standard error and flush immediately.

    Positional and keyword arguments are forwarded to ``print``; ``file`` and
    ``flush`` are supplied here, so passing them again raises ``TypeError``.
    """
    error_stream = sys.stderr
    print(*args, **kwargs, file=error_stream, flush=True)
25
26
class LintSeverity(str, enum.Enum):
    """Severity of a lint message.

    Members mix in ``str``, so each one compares equal to its lowercase
    string value and is written as that string when dumped to JSON.
    """

    # Assigned by get_issue_severity() to the F821 / E999 / PLE prefixes.
    ERROR = "error"
    # Fallback severity for rule codes with no more specific mapping.
    WARNING = "warning"
    # Assigned by get_issue_severity() to the style-oriented prefixes it lists.
    ADVICE = "advice"
    # Not produced by this file's default mapping; can still be selected by value.
    DISABLED = "disabled"
34
35
@dataclasses.dataclass(frozen=True)
class LintMessage:
    """One immutable diagnostic record in the shape lintrunner consumes.

    The field layout follows
    https://docs.rs/lintrunner/latest/lintrunner/lint_message/struct.LintMessage.html.
    """

    path: str | None
    line: int | None
    char: int | None
    code: str
    severity: LintSeverity
    name: str
    original: str | None
    replacement: str | None
    description: str | None

    def asdict(self) -> dict[str, Any]:
        """Return the fields of this message as a plain dictionary."""
        field_values = dataclasses.asdict(self)
        return field_values

    def display(self) -> None:
        """Write this message to stdout as one JSON line for lintrunner."""
        serialized = json.dumps(self.asdict())
        print(serialized, flush=True)
56
57
def as_posix(name: str) -> str:
    """Return ``name`` with backslashes turned into forward slashes on Windows.

    On every other platform the string is returned unchanged.
    """
    if not IS_WINDOWS:
        return name
    return name.replace("\\", "/")
60
61
62def _run_command(
63    args: list[str],
64    *,
65    timeout: int | None,
66    stdin: BinaryIO | None,
67    input: bytes | None,
68    check: bool,
69    cwd: os.PathLike[Any] | None,
70) -> subprocess.CompletedProcess[bytes]:
71    logging.debug("$ %s", " ".join(args))
72    start_time = time.monotonic()
73    try:
74        if input is not None:
75            return subprocess.run(
76                args,
77                capture_output=True,
78                shell=False,
79                input=input,
80                timeout=timeout,
81                check=check,
82                cwd=cwd,
83            )
84
85        return subprocess.run(
86            args,
87            stdin=stdin,
88            capture_output=True,
89            shell=False,
90            timeout=timeout,
91            check=check,
92            cwd=cwd,
93        )
94    finally:
95        end_time = time.monotonic()
96        logging.debug("took %dms", (end_time - start_time) * 1000)
97
98
def run_command(
    args: list[str],
    *,
    retries: int = 0,
    timeout: int | None = None,
    stdin: BinaryIO | None = None,
    input: bytes | None = None,
    check: bool = False,
    cwd: os.PathLike[Any] | None = None,
) -> subprocess.CompletedProcess[bytes]:
    """Run ``args`` through ``_run_command``, retrying when it times out.

    Only ``subprocess.TimeoutExpired`` triggers a retry; each retry is
    preceded by a warning log entry and a one-second pause. Once ``retries``
    retries have been used up, the timeout error is re-raised. Any other
    exception propagates immediately.
    """
    retries_left = retries
    while True:
        try:
            return _run_command(
                args,
                timeout=timeout,
                stdin=stdin,
                input=input,
                check=check,
                cwd=cwd,
            )
        except subprocess.TimeoutExpired as err:
            if retries_left == 0:
                raise
            retries_left -= 1
            attempt = retries - retries_left
            logging.warning(
                "(%s/%s) Retrying because command failed with: %r",
                attempt,
                retries,
                err,
            )
            time.sleep(1)
126
127
def add_default_options(parser: argparse.ArgumentParser) -> None:
    """Register the common adapter options on ``parser``.

    Adds ``--retries`` (int, default 3), the ``--verbose`` flag, and the
    positional ``filenames`` list (one or more paths). Intended to be the
    last step in a chain of ``add_argument`` calls.
    """
    common_arguments: tuple[tuple[str, dict[str, Any]], ...] = (
        (
            "--retries",
            {
                "type": int,
                "default": 3,
                "help": "number of times to retry if the linter times out.",
            },
        ),
        ("--verbose", {"action": "store_true", "help": "verbose logging"}),
        ("filenames", {"nargs": "+", "help": "paths to lint"}),
    )
    for flag, settings in common_arguments:
        parser.add_argument(flag, **settings)
149
150
def explain_rule(code: str) -> str:
    """Return a one-line explanation of the ruff rule ``code``.

    Runs ``ruff rule --output-format=json <code>`` and formats the ``linter``
    and ``summary`` fields of the JSON it prints as
    ``"\\n<linter>: <summary>"``.

    Raises:
        subprocess.CalledProcessError: if ruff exits with a non-zero status.
        OSError: if the interpreter cannot be launched.
    """
    proc = run_command(
        # Go through the running interpreter, as the other ruff invocations
        # in this file do, so the explanation comes from the same ruff
        # installation that produced the diagnostics rather than whatever
        # `ruff` happens to be first on PATH.
        [sys.executable, "-m", "ruff", "rule", "--output-format=json", code],
        check=True,
    )
    rule = json.loads(str(proc.stdout, "utf-8").strip())
    return f"\n{rule['linter']}: {rule['summary']}"
158
159
def get_issue_severity(code: str) -> LintSeverity:
    """Pick the default severity for a rule code by prefix.

    Advice prefixes are tested first, then error prefixes; every other code
    is reported as a warning.
    """
    # "B901": `return x` inside a generator
    # "B902": Invalid first argument to a method
    # "B903": __slots__ efficiency
    # "B950": Line too long
    # "C4": Flake8 Comprehensions
    # "C9": Cyclomatic complexity
    # "E2": PEP8 horizontal whitespace "errors"
    # "E3": PEP8 blank line "errors"
    # "E5": PEP8 line length "errors"
    # "T400": type checking Notes
    # "T49": internal type checker errors or unmatched messages
    advice_prefixes = (
        "B9",
        "C4",
        "C9",
        "E2",
        "E3",
        "E5",
        "T400",
        "T49",
        "PLC",
        "PLR",
    )
    # "F821": Undefined name
    # "E999": syntax error
    error_prefixes = ("F821", "E999", "PLE")

    if code.startswith(advice_prefixes):
        return LintSeverity.ADVICE
    if code.startswith(error_prefixes):
        return LintSeverity.ERROR

    # "F": PyFlakes Error
    # "B": flake8-bugbear Error
    # "E": PEP8 "Error"
    # "W": PEP8 Warning
    # possibly other plugins...
    return LintSeverity.WARNING
200
201
def format_lint_message(
    message: str, code: str, rules: dict[str, str], show_disable: bool
) -> str:
    """Build the description text shown for one ruff diagnostic.

    Starts from ``message``, appends the explanation stored for ``code`` when
    ``rules`` is non-empty (an empty string if the code has no entry), then a
    pointer to the ruff rule docs, and finally a ``noqa`` hint when
    ``show_disable`` is true.
    """
    pieces = [message]
    if rules:
        explanation = rules.get(code) or ""
        pieces.append(f".\n{explanation}")
    pieces.append(".\nSee https://beta.ruff.rs/docs/rules/")
    if show_disable:
        pieces.append(f".\n\nTo disable, use `  # noqa: {code}`")
    return "".join(pieces)
211
212
def check_files(
    filenames: list[str],
    severities: dict[str, LintSeverity],
    *,
    config: str | None,
    retries: int,
    timeout: int,
    explain: bool,
    show_disable: bool,
) -> list[LintMessage]:
    """Run ``ruff check`` over ``filenames`` and convert its JSON report.

    Args:
        filenames: paths handed to ruff in a single invocation.
        severities: per-rule-code overrides; codes without an entry fall back
            to ``get_issue_severity``.
        config: optional configuration path passed as ``--config=...``.
        retries: how many times a timed-out run is retried.
        timeout: seconds allowed per ruff run.
        explain: when true, each distinct rule code is looked up with
            ``explain_rule`` and the result is added to the description.
        show_disable: when true, descriptions include a ``noqa`` hint.

    Returns:
        One ``LintMessage`` per reported diagnostic, or a single
        ``command-failed`` message if ruff could not be run or exited with a
        non-zero status.
    """
    command = [
        sys.executable,
        "-m",
        "ruff",
        "check",
        "--exit-zero",
        "--quiet",
        "--output-format=json",
    ]
    if config:
        command.append(f"--config={config}")
    command.extend(filenames)

    try:
        proc = run_command(command, retries=retries, timeout=timeout, check=True)
    except (OSError, subprocess.CalledProcessError) as err:
        if isinstance(err, subprocess.CalledProcessError):
            failure_text = (
                f"COMMAND (exit code {err.returncode})\n"
                f"{' '.join(as_posix(x) for x in err.cmd)}\n\n"
                f"STDERR\n{err.stderr.decode('utf-8').strip() or '(empty)'}\n\n"
                f"STDOUT\n{err.stdout.decode('utf-8').strip() or '(empty)'}"
            )
        else:
            failure_text = f"Failed due to {err.__class__.__name__}:\n{err}"
        return [
            LintMessage(
                path=None,
                line=None,
                char=None,
                code=LINTER_CODE,
                severity=LintSeverity.ERROR,
                name="command-failed",
                original=None,
                replacement=None,
                description=failure_text,
            )
        ]

    report_text = str(proc.stdout, "utf-8").strip()
    diagnostics = json.loads(report_text)

    # Explanations are fetched once per distinct rule code, not per diagnostic.
    rules = (
        {code: explain_rule(code) for code in {d["code"] for d in diagnostics}}
        if explain
        else {}
    )

    messages = []
    for diag in diagnostics:
        file_path = diag["filename"]
        rule_code = diag["code"]
        text = format_lint_message(diag["message"], rule_code, rules, show_disable)
        row = int(diag["location"]["row"])
        column = int(diag["location"]["column"])
        messages.append(
            LintMessage(
                path=file_path,
                line=row,
                char=column,
                code=LINTER_CODE,
                severity=severities.get(rule_code, get_issue_severity(rule_code)),
                name=rule_code,
                original=None,
                replacement=None,
                description=text,
            )
        )
    return messages
294
295
def check_file_for_fixes(
    filename: str,
    *,
    config: str | None,
    retries: int,
    timeout: int,
) -> list[LintMessage]:
    """Ask ruff for auto-fixes to ``filename`` and report them as a patch.

    The file is read once and its bytes are piped to
    ``ruff check --fix-only --exit-zero --stdin-filename <filename> -``.

    Args:
        filename: path of the file to fix; also passed as ``--stdin-filename``.
        config: optional configuration path passed as ``--config=...``.
        retries: how many times a timed-out run is retried.
        timeout: seconds allowed per ruff run.

    Returns:
        An empty list when ruff's output equals the file content, a single
        ``format`` message carrying the original and fixed text otherwise, or
        a single ``command-failed`` message if the file could not be read or
        ruff could not be run / exited with a non-zero status.
    """
    try:
        with open(filename, "rb") as f:
            original = f.read()
        proc_fix = run_command(
            [
                sys.executable,
                "-m",
                "ruff",
                "check",
                "--fix-only",
                "--exit-zero",
                *([f"--config={config}"] if config else []),
                "--stdin-filename",
                filename,
                "-",
            ],
            # Send the bytes that were just read instead of a second open
            # file handle. A handle handed to a timed-out child is left
            # partly consumed, so a retry would feed ruff truncated input and
            # yield a bogus patch; bytes are re-sent in full on every
            # attempt, and the patch is guaranteed to be computed from
            # exactly the content stored in `original`.
            input=original,
            retries=retries,
            timeout=timeout,
            check=True,
        )
    except (OSError, subprocess.CalledProcessError) as err:
        return [
            LintMessage(
                path=None,
                line=None,
                char=None,
                code=LINTER_CODE,
                severity=LintSeverity.ERROR,
                name="command-failed",
                original=None,
                replacement=None,
                description=(
                    f"Failed due to {err.__class__.__name__}:\n{err}"
                    if not isinstance(err, subprocess.CalledProcessError)
                    else (
                        f"COMMAND (exit code {err.returncode})\n"
                        f"{' '.join(as_posix(x) for x in err.cmd)}\n\n"
                        f"STDERR\n{err.stderr.decode('utf-8').strip() or '(empty)'}\n\n"
                        f"STDOUT\n{err.stdout.decode('utf-8').strip() or '(empty)'}"
                    )
                ),
            )
        ]

    replacement = proc_fix.stdout
    if original == replacement:
        # ruff had nothing to fix.
        return []

    return [
        LintMessage(
            path=filename,
            name="format",
            description="Run `lintrunner -a` to apply this patch.",
            line=None,
            char=None,
            code=LINTER_CODE,
            severity=LintSeverity.WARNING,
            original=original.decode("utf-8"),
            replacement=replacement.decode("utf-8"),
        )
    ]
366
367
def main() -> None:
    """Command-line entry point of the ruff adapter.

    Parses the options, lints every given file in one ruff run and prints the
    resulting messages. Unless ``--no-fix`` was given or nothing was
    reported, each file that had at least one message is then checked for
    auto-fixes on a thread pool, and the resulting patch messages are printed
    as they complete.

    Malformed ``--severity`` values (missing ``:`` or an unknown severity
    name) terminate the program through ``parser.error``.
    """
    parser = argparse.ArgumentParser(
        description=f"Ruff linter. Linter code: {LINTER_CODE}. Use with RUFF-FIX to auto-fix issues.",
        fromfile_prefix_chars="@",
    )
    parser.add_argument(
        "--config",
        default=None,
        help="Path to the `pyproject.toml` or `ruff.toml` file to use for configuration",
    )
    parser.add_argument(
        "--explain",
        action="store_true",
        help="Explain a rule",
    )
    parser.add_argument(
        "--show-disable",
        action="store_true",
        help="Show how to disable a lint message",
    )
    parser.add_argument(
        "--timeout",
        default=90,
        type=int,
        help="Seconds to wait for ruff",
    )
    parser.add_argument(
        "--severity",
        action="append",
        help="map code to severity (e.g. `F401:advice`). This option can be used multiple times.",
    )
    parser.add_argument(
        "--no-fix",
        action="store_true",
        help="Do not suggest fixes",
    )
    add_default_options(parser)
    args = parser.parse_args()

    # --verbose lets everything through; otherwise debug output is kept only
    # for runs over fewer than 1000 files.
    if args.verbose:
        log_level = logging.NOTSET
    elif len(args.filenames) < 1000:
        log_level = logging.DEBUG
    else:
        log_level = logging.INFO
    logging.basicConfig(
        format="<%(threadName)s:%(levelname)s> %(message)s",
        level=log_level,
        stream=sys.stderr,
    )

    severities: dict[str, LintSeverity] = {}
    for spec in args.severity or ():
        # Validate explicitly instead of with `assert`, which is stripped
        # when Python runs with -O, and report bad values as usage errors
        # rather than tracebacks.
        rule_code, separator, level_name = spec.partition(":")
        if not separator:
            parser.error(f"invalid severity `{spec}`")
        try:
            severities[rule_code] = LintSeverity(level_name)
        except ValueError:
            parser.error(f"invalid severity `{spec}`")

    lint_messages = check_files(
        args.filenames,
        severities=severities,
        config=args.config,
        retries=args.retries,
        timeout=args.timeout,
        explain=args.explain,
        show_disable=args.show_disable,
    )
    for lint_message in lint_messages:
        lint_message.display()

    if args.no_fix or not lint_messages:
        # If we're not fixing, we can exit early
        return

    # Only files that produced a diagnostic are worth asking for fixes;
    # messages without a path (e.g. command failures) are skipped.
    files_with_lints = {lint.path for lint in lint_messages if lint.path is not None}
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=os.cpu_count(),
        thread_name_prefix="Thread",
    ) as executor:
        futures = {
            executor.submit(
                check_file_for_fixes,
                path,
                config=args.config,
                retries=args.retries,
                timeout=args.timeout,
            ): path
            for path in files_with_lints
        }
        for future in concurrent.futures.as_completed(futures):
            try:
                for lint_message in future.result():
                    lint_message.display()
            except Exception:  # Catch all exceptions for lintrunner
                logging.critical('Failed at "%s".', futures[future])
                raise
462
463
# Run the adapter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
466