xref: /aosp_15_r20/external/pytorch/tools/linter/adapters/pyfmt_linter.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3import argparse
4import concurrent.futures
5import fnmatch
6import json
7import logging
8import os
9import re
10import subprocess
11import sys
12from enum import Enum
13from pathlib import Path
14from typing import Any, NamedTuple
15
16import black
17import isort
18import usort
19
20
# True when the interpreter reports a Windows platform (os.name == "nt").
IS_WINDOWS: bool = os.name == "nt"
# Directory three levels above the one containing this file; the rest of the
# module treats it as the repository root (config lookup, relative paths).
REPO_ROOT = Path(__file__).absolute().parents[3]
23
# Regex over POSIX-style paths relative to the repository root. `check_file`
# formats a matching file with `black` and every other file with `ruff format`.
# Each glob below is preceded by a comment naming it (or the group it belongs to).
# TODO: remove this when it gets empty and remove `black` in PYFMT
USE_BLACK_FILELIST = re.compile(
    "|".join(
        [
            r"\A\Z",  # empty string
            *(
                fnmatch.translate(glob_pattern)
                for glob_pattern in (
                    # **
                    # .ci/**
                    ".ci/**",
                    # .github/**
                    ".github/**",
                    # benchmarks/**
                    "benchmarks/**",
                    # functorch/**
                    "functorch/**",
                    # tools/**
                    "tools/**",
                    # torchgen/**
                    "torchgen/**",
                    # test/**
                    # test/[a-h]*/**
                    "test/[a-h]*/**",
                    # test/[i-j]*/**
                    "test/[i-j]*/**",
                    # test/[k-n]*/**
                    "test/[k-n]*/**",
                    # test/optim/**
                    "test/optim/**",
                    # test/[p-z]*/**
                    "test/[p-z]*/**",
                    # torch/**
                    # torch/_[a-h]*/**
                    "torch/_[a-h]*/**",
                    # torch/_i*/**
                    "torch/_i*/**",
                    # torch/_[j-z]*/**
                    "torch/_[j-z]*/**",
                    # torch/[a-c]*/**
                    "torch/[a-c]*/**",
                    # torch/d*/**
                    "torch/d*/**",
                    # torch/[e-n]*/**
                    "torch/[e-n]*/**",
                    # torch/optim/**
                    "torch/optim/**",
                    # torch/[p-z]*/**
                    "torch/[p-z]*/**",
                )
            ),
        ]
    )
)
78
79
def eprint(*args: Any, **kwargs: Any) -> None:
    """Print to standard error and flush immediately.

    Takes the same arguments as the built-in ``print``. ``file`` and ``flush``
    are already supplied here, so passing either one raises ``TypeError``.
    """
    forced = {"file": sys.stderr, "flush": True}
    print(*args, **forced, **kwargs)
82
83
class LintSeverity(str, Enum):
    """Severity level attached to a ``LintMessage``.

    Members are also ``str`` instances, so ``json.dumps`` (used in ``main``)
    emits them as their plain string values.
    """

    ERROR = "error"
    WARNING = "warning"
    ADVICE = "advice"
    DISABLED = "disabled"
89
90
class LintMessage(NamedTuple):
    """A single lint result; ``main`` prints each one as a JSON object on its own line."""

    # File the message refers to (this module passes the filename it was given).
    path: str | None
    # Line and column of the finding; always None for messages built in this file.
    line: int | None
    char: int | None
    # Linter code; every message built in this file uses "PYFMT".
    code: str
    severity: LintSeverity
    # Short identifier of the message kind ("format" or "command-failed" here).
    name: str
    # File contents before and after formatting; both None when no patch is offered.
    original: str | None
    replacement: str | None
    # Human-readable explanation or instructions.
    description: str | None
101
102
def as_posix(name: str) -> str:
    """Return ``name`` with backslashes turned into forward slashes on Windows.

    On every other platform the string is returned unchanged.
    """
    if not IS_WINDOWS:
        return name
    return name.replace("\\", "/")
105
106
def format_error_message(filename: str, err: Exception) -> LintMessage:
    """Wrap an exception raised while formatting ``filename`` in a lint message.

    The result is a "command-failed" PYFMT message at ADVICE severity that
    carries no patch; its description names the exception class and includes
    the exception text.
    """
    description = f"Failed due to {err.__class__.__name__}:\n{err}"
    return LintMessage(
        path=filename,
        line=None,
        char=None,
        code="PYFMT",
        severity=LintSeverity.ADVICE,
        name="command-failed",
        original=None,
        replacement=None,
        description=description,
    )
119
120
def run_isort(content: str, path: Path) -> str:
    """Sort the imports in ``content`` with isort, using the repo-root settings.

    For every file except this one, ``usort: skip`` markers inside comments are
    rewritten to ``isort: split`` before isort runs and rewritten back
    afterwards.
    NOTE(review): this file is presumably exempt because its own source
    contains the marker text inside the patterns below -- confirm.
    """
    config = isort.Config(settings_path=str(REPO_ROOT))

    if path.samefile(__file__):
        # No marker translation for this module itself.
        return isort.code(content, config=config, file_path=path)

    with_isort_markers = re.sub(
        r"(#.*\b)usort:\s*skip\b", r"\g<1>isort: split", content
    )
    sorted_source = isort.code(with_isort_markers, config=config, file_path=path)
    return re.sub(r"(#.*\b)isort: split\b", r"\g<1>usort: skip", sorted_source)
134
135
def run_usort(content: str, path: Path) -> str:
    """Sort the imports in ``content`` with usort.

    The usort configuration is looked up from ``path`` via ``usort.Config.find``.
    """
    return usort.usort_string(
        content,
        path=path,
        config=usort.Config.find(path),
    )
140
141
def run_black(content: str, path: Path) -> str:
    """Format ``content`` with black, configured from the pyproject.toml found for ``path``.

    The file suffix decides whether the source is treated as a stub (``.pyi``)
    or a notebook (``.ipynb``).
    """
    pyproject = black.find_pyproject_toml((str(path),))  # type: ignore[attr-defined]
    options = black.parse_pyproject_toml(pyproject)  # type: ignore[attr-defined,arg-type]

    # Two pyproject options have no identically named Mode argument, so
    # translate them by hand: `target_version` (list of names) becomes the
    # `target_versions` set, and `skip_string_normalization` is inverted.
    version_names = options.pop("target_version", [])
    options["target_versions"] = {
        black.TargetVersion[name.upper()]  # type: ignore[attr-defined]
        for name in version_names
    }
    skip_normalization = options.pop("skip_string_normalization", False)
    options["string_normalization"] = not skip_normalization

    mode = black.Mode(**options)
    suffix = path.suffix.lower()
    mode.is_pyi = suffix == ".pyi"
    mode.is_ipynb = suffix == ".ipynb"

    return black.format_str(content, mode=mode)
157
158
def run_ruff_format(content: str, path: Path) -> str:
    """Format ``content`` by piping it through ``ruff format``.

    ruff is run as ``python -m ruff`` with the current interpreter, reads the
    source from stdin, and is configured from ``pyproject.toml`` in the repo
    root. ``path`` is passed as ``--stdin-filename``.

    Returns:
        The formatted source, i.e. exactly what ruff wrote to stdout.

    Raises:
        ValueError: if ruff exits with a non-zero status. The message is ruff's
            stderr output, or its stdout when stderr is empty.
    """
    command = [
        sys.executable,
        "-m",
        "ruff",
        "format",
        "--config",
        str(REPO_ROOT / "pyproject.toml"),
        "--stdin-filename",
        str(path),
        "-",
    ]
    try:
        completed = subprocess.run(
            command,
            input=content,
            # Capture the streams separately. Merging stderr into stdout would
            # splice any diagnostic ruff prints on a successful run into the
            # formatted source that is returned and offered as a patch.
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            encoding="utf-8",
            check=True,
        )
    except subprocess.CalledProcessError as exc:
        raise ValueError(exc.stderr or exc.stdout) from exc

    if completed.stderr:
        # Keep diagnostics from successful runs visible without touching the result.
        logging.debug("ruff format (%s): %s", path, completed.stderr.rstrip())
    return completed.stdout
180
181
def check_file(filename: str) -> list[LintMessage]:
    """Format one file in memory and report whether its contents would change.

    Returns an empty list when the file is already formatted, a single
    WARNING "format" message carrying the original and reformatted text when
    it is not, or a single "command-failed" message when any formatter raises.
    Errors while reading the file are not caught here and propagate.
    """
    path = Path(filename).absolute()
    original = path.read_text(encoding="utf-8")

    try:
        # NB: run isort first to enforce style for blank lines
        formatted = run_isort(original, path=path)
        formatted = run_usort(formatted, path=path)
        relative_name = path.relative_to(REPO_ROOT).as_posix()
        if USE_BLACK_FILELIST.match(relative_name):
            formatter = run_black
        else:
            formatter = run_ruff_format
        formatted = formatter(formatted, path=path)
    except Exception as err:
        return [format_error_message(filename, err)]

    if formatted == original:
        return []

    patch = LintMessage(
        path=filename,
        line=None,
        char=None,
        code="PYFMT",
        severity=LintSeverity.WARNING,
        name="format",
        original=original,
        replacement=formatted,
        description="Run `lintrunner -a` to apply this patch.",
    )
    return [patch]
213
214
def main() -> None:
    """Command-line entry point: check every given file and print the results.

    Files are checked in a process pool sized to the CPU count. Each lint
    message is printed to stdout as one JSON object per line as soon as its
    file finishes. If a worker raises, the offending filename is logged at
    CRITICAL level and the exception is re-raised.
    """
    parser = argparse.ArgumentParser(
        description="Format files with usort + ruff-format.",
        fromfile_prefix_chars="@",
    )
    parser.add_argument("--verbose", action="store_true", help="verbose logging")
    parser.add_argument("filenames", nargs="+", help="paths to lint")
    args = parser.parse_args()

    # --verbose logs everything; otherwise DEBUG for runs under 1000 files,
    # INFO for larger ones.
    if args.verbose:
        log_level = logging.NOTSET
    elif len(args.filenames) < 1000:
        log_level = logging.DEBUG
    else:
        log_level = logging.INFO
    logging.basicConfig(
        format="<%(processName)s:%(levelname)s> %(message)s",
        level=log_level,
        stream=sys.stderr,
    )

    with concurrent.futures.ProcessPoolExecutor(max_workers=os.cpu_count()) as pool:
        # Map each future back to its filename for error reporting.
        pending = {pool.submit(check_file, name): name for name in args.filenames}
        for finished in concurrent.futures.as_completed(pending):
            try:
                messages = finished.result()
                for message in messages:
                    print(json.dumps(message._asdict()), flush=True)
            except Exception:
                logging.critical('Failed at "%s".', pending[finished])
                raise
253
254
# Run the linter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
257