"""Formatting lint adapter: runs isort, usort, then black or ruff-format.

Each file named on the command line is formatted in memory. When the result
differs from the text read from disk, a JSON-encoded ``LintMessage`` carrying
both the original and the replacement text is printed to stdout.
"""

from __future__ import annotations

import argparse
import concurrent.futures
import fnmatch
import json
import logging
import os
import re
import subprocess
import sys
from enum import Enum
from pathlib import Path
from typing import Any, NamedTuple

import black
import isort
import usort


IS_WINDOWS: bool = os.name == "nt"
REPO_ROOT = Path(__file__).absolute().parents[3]

# Repo-relative POSIX paths matching this pattern are formatted with black;
# everything else goes through ruff-format (see `check_file`).
# TODO: remove this when it gets empty and remove `black` in PYFMT
USE_BLACK_FILELIST = re.compile(
    "|".join(
        (
            r"\A\Z",  # empty string
            *map(
                fnmatch.translate,
                [
                    # **
                    # .ci/**
                    ".ci/**",
                    # .github/**
                    ".github/**",
                    # benchmarks/**
                    "benchmarks/**",
                    # functorch/**
                    "functorch/**",
                    # tools/**
                    "tools/**",
                    # torchgen/**
                    "torchgen/**",
                    # test/**
                    # test/[a-h]*/**
                    "test/[a-h]*/**",
                    # test/[i-j]*/**
                    "test/[i-j]*/**",
                    # test/[k-n]*/**
                    "test/[k-n]*/**",
                    # test/optim/**
                    "test/optim/**",
                    # "test/[p-z]*/**",
                    "test/[p-z]*/**",
                    # torch/**
                    # torch/_[a-h]*/**
                    "torch/_[a-h]*/**",
                    # torch/_i*/**
                    "torch/_i*/**",
                    # torch/_[j-z]*/**
                    "torch/_[j-z]*/**",
                    # torch/[a-c]*/**
                    "torch/[a-c]*/**",
                    # torch/d*/**
                    "torch/d*/**",
                    # torch/[e-n]*/**
                    "torch/[e-n]*/**",
                    # torch/optim/**
                    "torch/optim/**",
                    # torch/[p-z]*/**
                    "torch/[p-z]*/**",
                ],
            ),
        )
    )
)


def eprint(*args: Any, **kwargs: Any) -> None:
    """Print to stderr, flushing immediately."""
    print(*args, file=sys.stderr, flush=True, **kwargs)


class LintSeverity(str, Enum):
    """Severity levels attached to a ``LintMessage``."""

    ERROR = "error"
    WARNING = "warning"
    ADVICE = "advice"
    DISABLED = "disabled"


class LintMessage(NamedTuple):
    """One lint result; serialized to JSON via ``_asdict()`` in ``main``."""

    path: str | None
    line: int | None
    char: int | None
    code: str
    severity: LintSeverity
    name: str
    original: str | None
    replacement: str | None
    description: str | None


def as_posix(name: str) -> str:
    """Return *name* with backslashes turned into slashes on Windows only."""
    return name.replace("\\", "/") if IS_WINDOWS else name


def format_error_message(filename: str, err: Exception) -> LintMessage:
    """Build the ``command-failed`` message reported when formatting raises.

    Args:
        filename: Path of the file being formatted, as given by the caller.
        err: The exception that interrupted formatting.

    Returns:
        An ADVICE-severity ``LintMessage`` whose description names the
        exception class and includes its string form.
    """
    return LintMessage(
        path=filename,
        line=None,
        char=None,
        code="PYFMT",
        severity=LintSeverity.ADVICE,
        name="command-failed",
        original=None,
        replacement=None,
        description=(f"Failed due to {err.__class__.__name__}:\n{err}"),
    )


def run_isort(content: str, path: Path) -> str:
    """Sort imports in *content* with isort, using settings found at REPO_ROOT.

    ``usort: skip`` comment markers are rewritten to ``isort: split`` before
    isort runs and restored afterwards. The rewrite is skipped when *path* is
    this very file — presumably because the marker text appears literally in
    the regexes below.
    """
    isort_config = isort.Config(settings_path=str(REPO_ROOT))

    is_this_file = path.samefile(__file__)
    if not is_this_file:
        content = re.sub(r"(#.*\b)usort:\s*skip\b", r"\g<1>isort: split", content)

    content = isort.code(content, config=isort_config, file_path=path)

    if not is_this_file:
        content = re.sub(r"(#.*\b)isort: split\b", r"\g<1>usort: skip", content)

    return content


def run_usort(content: str, path: Path) -> str:
    """Sort imports in *content* with usort, using the config found for *path*."""
    usort_config = usort.Config.find(path)

    return usort.usort_string(content, path=path, config=usort_config)


def run_black(content: str, path: Path) -> str:
    """Format *content* with black, using the pyproject.toml found for *path*."""
    black_config = black.parse_pyproject_toml(black.find_pyproject_toml((str(path),)))  # type: ignore[attr-defined,arg-type]
    # manually patch options that do not have a 1-to-1 match in Mode arguments
    black_config["target_versions"] = {
        black.TargetVersion[ver.upper()]  # type: ignore[attr-defined]
        for ver in black_config.pop("target_version", [])
    }
    black_config["string_normalization"] = not black_config.pop(
        "skip_string_normalization", False
    )
    black_mode = black.Mode(**black_config)
    suffix = path.suffix.lower()
    black_mode.is_pyi = suffix == ".pyi"
    black_mode.is_ipynb = suffix == ".ipynb"

    return black.format_str(content, mode=black_mode)


def run_ruff_format(content: str, path: Path) -> str:
    """Format *content* by piping it through ``python -m ruff format``.

    Args:
        content: Source text sent to ruff on stdin.
        path: Passed as ``--stdin-filename``.

    Returns:
        The text ruff wrote to stdout.

    Raises:
        ValueError: If ruff exits with a non-zero status; the message is
            ruff's stderr (or its stdout when stderr is empty).
    """
    try:
        return subprocess.check_output(
            [
                sys.executable,
                "-m",
                "ruff",
                "format",
                "--config",
                str(REPO_ROOT / "pyproject.toml"),
                "--stdin-filename",
                str(path),
                "-",
            ],
            input=content,
            # Capture stderr separately: merging it into stdout would splice
            # any diagnostics printed during a successful run into the
            # formatted source returned to the caller.
            stderr=subprocess.PIPE,
            text=True,
            encoding="utf-8",
        )
    except subprocess.CalledProcessError as exc:
        raise ValueError(exc.stderr or exc.output) from exc


def check_file(filename: str) -> list[LintMessage]:
    """Format one file in memory and report whether it would change.

    Args:
        filename: Path of the file to check.

    Returns:
        An empty list when the file is already formatted; otherwise a single
        message — either a ``format`` warning carrying the original and
        replacement text, or a ``command-failed`` message if a formatter
        raised.
    """
    path = Path(filename).absolute()
    original = replacement = path.read_text(encoding="utf-8")

    try:
        # NB: run isort first to enforce style for blank lines
        replacement = run_isort(replacement, path=path)
        replacement = run_usort(replacement, path=path)
        # `path` is already absolute; relative_to raises ValueError for files
        # outside REPO_ROOT, which is reported through the handler below.
        if USE_BLACK_FILELIST.match(path.relative_to(REPO_ROOT).as_posix()):
            replacement = run_black(replacement, path=path)
        else:
            replacement = run_ruff_format(replacement, path=path)

        if original == replacement:
            return []

        return [
            LintMessage(
                path=filename,
                line=None,
                char=None,
                code="PYFMT",
                severity=LintSeverity.WARNING,
                name="format",
                original=original,
                replacement=replacement,
                description="Run `lintrunner -a` to apply this patch.",
            )
        ]
    except Exception as err:
        return [format_error_message(filename, err)]


def main() -> None:
    """Parse arguments, check every file in a process pool, print JSON results."""
    parser = argparse.ArgumentParser(
        description="Format files with usort + ruff-format.",
        fromfile_prefix_chars="@",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="verbose logging",
    )
    parser.add_argument(
        "filenames",
        nargs="+",
        help="paths to lint",
    )
    args = parser.parse_args()

    if args.verbose:
        log_level = logging.NOTSET
    elif len(args.filenames) < 1000:
        log_level = logging.DEBUG
    else:
        log_level = logging.INFO
    logging.basicConfig(
        format="<%(processName)s:%(levelname)s> %(message)s",
        level=log_level,
        stream=sys.stderr,
    )

    with concurrent.futures.ProcessPoolExecutor(
        max_workers=os.cpu_count(),
    ) as executor:
        # Map each future back to its filename so a failure can be attributed.
        futures = {executor.submit(check_file, x): x for x in args.filenames}
        for future in concurrent.futures.as_completed(futures):
            try:
                for lint_message in future.result():
                    print(json.dumps(lint_message._asdict()), flush=True)
            except Exception:
                logging.critical('Failed at "%s".', futures[future])
                raise


if __name__ == "__main__":
    main()