xref: /aosp_15_r20/external/pytorch/tools/linter/adapters/mypy_linter.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3import argparse
4import json
5import logging
6import os
7import re
8import subprocess
9import sys
10import time
11from enum import Enum
12from pathlib import Path
13from typing import Any, NamedTuple
14
15
# True when the interpreter runs on Windows (``os.name`` is "nt"); as_posix
# consults it to decide whether backslashes must be normalised.
IS_WINDOWS: bool = (os.name == "nt")
17
18
def eprint(*args: Any, **kwargs: Any) -> None:
    """Like ``print``, but writes to standard error and flushes at once.

    Any extra keyword arguments (``sep``, ``end``, ...) are forwarded to
    ``print``; passing ``file`` or ``flush`` is an error.
    """
    stream = sys.stderr  # looked up per call so redirections are honoured
    print(*args, file=stream, flush=True, **kwargs)
21
22
class LintSeverity(str, Enum):
    """Severity attached to each lint message.

    Members mix in ``str``, so they compare equal to their lowercase values
    and ``json.dumps`` emits them as plain strings.
    """

    ERROR = "error"
    WARNING = "warning"
    ADVICE = "advice"
    DISABLED = "disabled"
28
29
class LintMessage(NamedTuple):
    """A single lint result; ``main`` prints each one as a JSON object per line.

    Optional fields are ``None`` when they do not apply (e.g. a failure of the
    tool itself has no ``path``/``line``/``char``).
    """

    # File the message is about.
    path: str | None
    # Line number within ``path``.
    line: int | None
    # Column within ``line``.
    char: int | None
    # Lint code stamped on every message (the ``--code`` option).
    code: str
    # How serious the message is.
    severity: LintSeverity
    # Short identifier, e.g. a bracketed mypy error code or "command-failed".
    name: str
    # Text that ``replacement`` would replace; this linter leaves it as None.
    original: str | None
    # Suggested replacement text; this linter leaves it as None.
    replacement: str | None
    # Human-readable explanation.
    description: str | None
40
41
def as_posix(name: str) -> str:
    """Return *name* with backslashes converted to forward slashes on Windows.

    On any other platform *name* is returned untouched.
    """
    if not IS_WINDOWS:
        return name
    return name.replace("\\", "/")
44
45
# Matches one mypy diagnostic per line of dmypy's stdout, for example:
#   tools/linter/flake8_linter.py:15:13: error: <message>  [assignment]
# The column is optional and may be negative (check_files treats a negative
# column as unknown). The ``code`` group keeps the surrounding brackets.
RESULTS_RE: re.Pattern[str] = re.compile(
    r"""(?mx)
    ^
    (?P<file>.*?):
    (?P<line>\d+):
    (?:(?P<column>-?\d+):)?
    \s(?P<severity>\S+?):?
    \s(?P<message>.*)
    \s(?P<code>\[.*\])
    $
    """
)

# Matches mypy crash reports, which check_files searches for on stderr and
# which carry no trailing ``[code]``, for example:
#   torch/_dynamo/variables/tensor.py:363: error: INTERNAL ERROR
INTERNAL_ERROR_RE: re.Pattern[str] = re.compile(
    r"""(?mx)
    ^
    (?P<file>.*?):
    (?P<line>\d+):
    \s(?P<severity>\S+?):?
    \s(?P<message>INTERNAL\sERROR.*)
    $
    """
)
71
72
def run_command(
    args: list[str],
    *,
    extra_env: dict[str, str] | None,
    retries: int,
    timeout: float | None = None,
) -> subprocess.CompletedProcess[bytes]:
    """Run *args* (no shell) and capture its output, retrying on timeout.

    Args:
        args: Command line to execute.
        extra_env: Variables layered on top of the current environment.
            ``None`` or an empty dict runs the command with the inherited
            environment unchanged.
        retries: Number of extra attempts made after an attempt times out.
        timeout: Per-attempt limit in seconds. ``None`` (the default) waits
            indefinitely, so no timeout and hence no retry can occur.

    Returns:
        The completed process with ``stdout``/``stderr`` captured as bytes.
        A non-zero exit status is not treated as an error.

    Raises:
        OSError: If the executable cannot be started.
        subprocess.TimeoutExpired: If every attempt exceeded *timeout*.
    """
    # Merge rather than replace: passing only extra_env as ``env`` would strip
    # PATH and everything else the child needs from the environment.
    env = {**os.environ, **extra_env} if extra_env else None
    remaining = max(retries, 0)
    while True:
        logging.debug("$ %s", " ".join(args))
        start_time = time.monotonic()
        try:
            return subprocess.run(
                args,
                capture_output=True,
                env=env,
                timeout=timeout,
            )
        except subprocess.TimeoutExpired:
            if remaining <= 0:
                raise
            remaining -= 1
            logging.warning(
                "(%s) timed out, retrying (%d retries left)",
                " ".join(args),
                remaining,
            )
        finally:
            end_time = time.monotonic()
            logging.debug("took %dms", (end_time - start_time) * 1000)
89
90
# mypy only emits the severities "error" and "note":
# https://github.com/python/mypy/blob/8b47a032e1317fb8e3f9a818005a6b63e9bf0311/mypy/errors.py#L46-L47
# Map them onto lint severities; lookups fall back to ERROR for anything else.
severities = dict(
    error=LintSeverity.ERROR,
    note=LintSeverity.ADVICE,
)
97
98
def check_mypy_installed(code: str) -> list[LintMessage]:
    """Check that mypy can be invoked with the running interpreter.

    Runs ``<sys.executable> -mmypy -V``. Returns an empty list on success.
    If the command exits with a non-zero status, returns a single
    ``command-failed`` ERROR message tagged with *code*, whose description
    quotes the command and its (leniently decoded) stderr.
    """
    command = [sys.executable, "-mmypy", "-V"]
    try:
        subprocess.run(command, check=True, capture_output=True)
    except subprocess.CalledProcessError as exc:
        error_output = exc.stderr.decode(errors="replace")
        joined = " ".join(command)
        return [
            LintMessage(
                path=None,
                line=None,
                char=None,
                code=code,
                severity=LintSeverity.ERROR,
                name="command-failed",
                original=None,
                replacement=None,
                description=f"Could not run '{joined}': {error_output}",
            )
        ]
    return []
119
120
def check_files(
    filenames: list[str],
    config: str,
    retries: int,
    code: str,
) -> list[LintMessage]:
    """Type-check *filenames* with ``dmypy run`` and convert the output to lint messages.

    Args:
        filenames: Paths to check. They are made relative to the current
            working directory before being passed to dmypy.
        config: Path to the mypy config file, passed as ``--config=<config>``.
        retries: Forwarded unchanged to ``run_command``.
        code: Lint code stamped on every returned message.

    Returns:
        One message per diagnostic parsed from dmypy's stdout, followed by one
        per ``INTERNAL ERROR`` line parsed from its stderr. If dmypy cannot be
        started, or exits with a status other than 0/1 without producing any
        parseable diagnostic, a single ``command-failed`` ERROR message is
        returned, so the failure is not mistaken for a clean run.
    """
    # dmypy has a bug where it won't pick up changes if you pass it absolute
    # file names, see https://github.com/python/mypy/issues/16768
    filenames = [os.path.relpath(f) for f in filenames]
    try:
        proc = run_command(
            ["dmypy", "run", "--", f"--config={config}"] + filenames,
            extra_env={},
            retries=retries,
        )
    except OSError as err:
        return [
            LintMessage(
                path=None,
                line=None,
                char=None,
                code=code,
                severity=LintSeverity.ERROR,
                name="command-failed",
                original=None,
                replacement=None,
                description=(f"Failed due to {err.__class__.__name__}:\n{err}"),
            )
        ]
    # Decode leniently (as check_mypy_installed does): one undecodable byte in
    # the tool's output must not crash the linter with UnicodeDecodeError.
    stdout = proc.stdout.decode("utf-8", errors="replace").strip()
    stderr = proc.stderr.decode("utf-8", errors="replace").strip()
    messages = [
        LintMessage(
            path=match["file"],
            name=match["code"],
            description=match["message"],
            line=int(match["line"]),
            # A negative column means mypy did not know it; report "unknown".
            char=int(match["column"])
            if match["column"] is not None and not match["column"].startswith("-")
            else None,
            code=code,
            severity=severities.get(match["severity"], LintSeverity.ERROR),
            original=None,
            replacement=None,
        )
        for match in RESULTS_RE.finditer(stdout)
    ] + [
        LintMessage(
            path=match["file"],
            name="INTERNAL ERROR",
            description=match["message"],
            line=int(match["line"]),
            char=None,
            code=code,
            severity=severities.get(match["severity"], LintSeverity.ERROR),
            original=None,
            replacement=None,
        )
        for match in INTERNAL_ERROR_RE.finditer(stderr)
    ]
    # mypy exits with 0 when clean and 1 when it reported type errors. Any other
    # status means dmypy/mypy itself failed. If nothing it printed could be
    # parsed, surface the raw output rather than returning an empty list.
    if not messages and proc.returncode not in (0, 1):
        detail = stderr or stdout
        messages.append(
            LintMessage(
                path=None,
                line=None,
                char=None,
                code=code,
                severity=LintSeverity.ERROR,
                name="command-failed",
                original=None,
                replacement=None,
                description=f"dmypy exited with status {proc.returncode}:\n{detail}",
            )
        )
    return messages
182
183
def main() -> None:
    """Entry point: parse arguments, run mypy via dmypy, print JSON lint messages.

    Each resulting ``LintMessage`` is written to stdout as one JSON object per
    line.
    """
    parser = argparse.ArgumentParser(
        description="mypy wrapper linter.",
        fromfile_prefix_chars="@",
    )
    parser.add_argument(
        "--retries",
        default=3,
        type=int,
        help="times to retry timed out mypy",
    )
    parser.add_argument(
        "--config",
        required=True,
        help="path to an mypy .ini config file",
    )
    parser.add_argument(
        "--code",
        default="MYPY",
        help="the code this lint should report as",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="verbose logging",
    )
    parser.add_argument(
        "filenames",
        nargs="+",
        help="paths to lint",
    )
    args = parser.parse_args()

    # --verbose logs at every level; otherwise debug output is kept for small
    # runs and reduced to INFO once 1000 or more files are linted.
    if args.verbose:
        log_level = logging.NOTSET
    elif len(args.filenames) < 1000:
        log_level = logging.DEBUG
    else:
        log_level = logging.INFO
    logging.basicConfig(
        format="<%(threadName)s:%(levelname)s> %(message)s",
        level=log_level,
        stream=sys.stderr,
    )

    # Use a dictionary here to preserve order. mypy cares about order,
    # tragically, e.g. https://github.com/python/mypy/issues/2015
    filenames: dict[str, bool] = {}

    # If a stub file exists, have mypy check it instead of the original file, in
    # accordance with PEP-484 (see https://www.python.org/dev/peps/pep-0484/#stub-files)
    for filename in args.filenames:
        if filename.endswith(".pyi"):
            filenames[filename] = True
            continue

        # Derive the stub name by changing only a trailing ".py" suffix, so that
        # ".py" elsewhere in the path (e.g. "pkg.py_utils/mod.py") or in other
        # extensions (e.g. ".pyx") is left alone.
        if filename.endswith(".py"):
            stub_filename = filename + "i"
            if Path(stub_filename).exists():
                filenames[stub_filename] = True
                continue
        filenames[filename] = True

    lint_messages = check_mypy_installed(args.code) + check_files(
        list(filenames), args.config, args.retries, args.code
    )
    for lint_message in lint_messages:
        print(json.dumps(lint_message._asdict()), flush=True)
249
250
# Run the linter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
253