#!/usr/bin/env python3

# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Checks and fixes formatting for source files.

This uses clang-format, gn format, gofmt, and python -m yapf to format source
code. These tools must be available on the path when this script is invoked!
"""

import argparse
import collections
import difflib
import json
import logging
import os
from pathlib import Path
import re
import shutil
import subprocess
import sys
import tempfile
from typing import (
    Callable,
    Collection,
    Iterable,
    NamedTuple,
    Optional,
    Pattern,
    TextIO,
)

import pw_cli.color
from pw_cli.diff import colorize_diff
import pw_cli.env
from pw_cli.file_filter import FileFilter
from pw_cli.plural import plural
import pw_env_setup.config_file
from pw_presubmit.presubmit import filter_paths
from pw_presubmit.presubmit_context import (
    FormatContext,
    FormatOptions,
    PresubmitContext,
    PresubmitFailure,
)
from pw_presubmit import (
    cli,
    git_repo,
    owners_checks,
    presubmit_context,
)
from pw_presubmit.format.core import FormattedDiff, FormatFixStatus
from pw_presubmit.format.cpp import ClangFormatFormatter
from pw_presubmit.format.bazel import BuildifierFormatter
from pw_presubmit.format.gn import GnFormatter
from pw_presubmit.format.python import BlackFormatter
from pw_presubmit.tools import (
    exclude_paths,
    file_summary,
    log_run,
    PresubmitToolRunner,
)
from pw_presubmit.rst_format import reformat_rst

_LOG: logging.Logger = logging.getLogger(__name__)
_COLOR = pw_cli.color.colors()
_DEFAULT_PATH = Path('out', 'format')

_Context = PresubmitContext | FormatContext


def _ensure_newline(orig: bytes) -> bytes:
    """Returns orig, appending a marker line if it lacks a final newline.

    The marker makes a missing trailing newline visible in the unified diff
    produced by _diff().
    """
    if orig.endswith(b'\n'):
        return orig
    return orig + b'\nNo newline at end of file\n'


def _diff(path, original: bytes, formatted: bytes) -> str:
    """Returns a unified diff between the original and formatted contents.

    Args:
      path: Path used to label both sides of the diff.
      original: File contents before formatting.
      formatted: File contents after formatting.
    """
    original = _ensure_newline(original)
    formatted = _ensure_newline(formatted)
    return ''.join(
        difflib.unified_diff(
            # Undecodable bytes are replaced rather than raising so a diff
            # can always be produced.
            original.decode(errors='replace').splitlines(True),
            formatted.decode(errors='replace').splitlines(True),
            f'{path} (original)',
            f'{path} (reformatted)',
        )
    )


FormatterT = Callable[[str, bytes], bytes]


def _diff_formatted(
    path, formatter: FormatterT, dry_run: bool = False
) -> str | None:
    """Returns a diff comparing a file to its formatted version.

    Returns None when the file is already formatted, or when dry_run is set
    (the formatter is still invoked in that case, but its output is ignored).
    """
    with open(path, 'rb') as fd:
        original = fd.read()

    formatted = formatter(path, original)

    if dry_run:
        return None

    return None if formatted == original else _diff(path, original, formatted)


def _check_files(
    files, formatter: FormatterT, dry_run: bool = False
) -> dict[Path, str]:
    """Runs formatter on each file; returns {path: diff} for changed files."""
    errors = {}

    for path in files:
        difference = _diff_formatted(path, formatter, dry_run)
        if difference:
            errors[path] = difference

    return errors


def _make_formatting_diff_dict(
    diffs: Iterable[FormattedDiff],
) -> dict[Path, str]:
    """Adapts the formatting check API to work with this presubmit tooling."""
    return {
        result.file_path: (
            result.diff if result.ok else str(result.error_message)
        )
        for result in diffs
    }


def _make_format_fix_error_output_dict(
    statuses: Iterable[tuple[Path, FormatFixStatus]],
) -> dict[Path, str]:
    """Adapts the formatter API to work with this presubmit tooling."""
    return {
        file_path: str(status.error_message) for file_path, status in statuses
    }


def clang_format_check(ctx: _Context) -> dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting."""
    formatter = ClangFormatFormatter(tool_runner=PresubmitToolRunner())
    return _make_formatting_diff_dict(
        formatter.get_formatting_diffs(ctx.paths, ctx.dry_run)
    )


def clang_format_fix(ctx: _Context) -> dict[Path, str]:
    """Fixes formatting for the provided files in place."""
    formatter = ClangFormatFormatter(tool_runner=PresubmitToolRunner())
    return _make_format_fix_error_output_dict(formatter.format_files(ctx.paths))


def _typescript_format(*args: Path | str, **kwargs) -> bytes:
    """Runs `npm exec prettier` with the given arguments; returns its stdout.

    Args:
      *args: Arguments appended after `npm exec prettier`.
      **kwargs: Extra keyword arguments forwarded to log_run().
    """
    # TODO: b/323378974 - Better integrate NPM actions with pw_env_setup so
    # we don't have to manually set `npm_config_cache` every time we run npm.
    # Force npm cache to live inside the environment directory.
    npm_env = os.environ.copy()
    npm_env['npm_config_cache'] = str(
        Path(npm_env['_PW_ACTUAL_ENVIRONMENT_ROOT']) / 'npm-cache'
    )

    npm = shutil.which('npm.cmd' if os.name == 'nt' else 'npm')
    return log_run(
        [npm, 'exec', 'prettier', *args],
        stdout=subprocess.PIPE,
        stdin=subprocess.DEVNULL,
        check=True,
        env=npm_env,
        **kwargs,
    ).stdout


def typescript_format_check(ctx: _Context) -> dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting."""
    return _check_files(
        ctx.paths,
        lambda path, _: _typescript_format(path),
        ctx.dry_run,
    )


def typescript_format_fix(ctx: _Context) -> dict[Path, str]:
    """Fixes formatting for the provided files in place."""
    print_format_fix(_typescript_format(*ctx.paths, '--', '--write'))
    return {}


def check_gn_format(ctx: _Context) -> dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting."""
    formatter = GnFormatter(tool_runner=PresubmitToolRunner())
    return _make_formatting_diff_dict(
        formatter.get_formatting_diffs(
            ctx.paths,
            ctx.dry_run,
        )
    )


def fix_gn_format(ctx: _Context) -> dict[Path, str]:
    """Fixes formatting for the provided files in place."""
    formatter = GnFormatter(tool_runner=PresubmitToolRunner())
    return _make_format_fix_error_output_dict(formatter.format_files(ctx.paths))


def check_bazel_format(ctx: _Context) -> dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting."""
    formatter = BuildifierFormatter(tool_runner=PresubmitToolRunner())
    return _make_formatting_diff_dict(
        formatter.get_formatting_diffs(
            ctx.paths,
            ctx.dry_run,
        )
    )


def fix_bazel_format(ctx: _Context) -> dict[Path, str]:
    """Fixes formatting for the provided files in place."""
    formatter = BuildifierFormatter(tool_runner=PresubmitToolRunner())
    return _make_format_fix_error_output_dict(formatter.format_files(ctx.paths))


def check_owners_format(ctx: _Context) -> dict[Path, str]:
    """Runs the OWNERS file checks on ctx.paths and returns their result."""
    return owners_checks.run_owners_checks(ctx.paths)


def fix_owners_format(ctx: _Context) -> dict[Path, str]:
    """Formats the OWNERS files in ctx.paths and returns the result."""
    return owners_checks.format_owners_file(ctx.paths)


def check_go_format(ctx: _Context) -> dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting."""
    return _check_files(
        ctx.paths,
        lambda path, _: log_run(
            ['gofmt', path], stdout=subprocess.PIPE, check=True
        ).stdout,
        ctx.dry_run,
    )


def fix_go_format(ctx: _Context) -> dict[Path, str]:
    """Fixes formatting for the provided files in place."""
    log_run(['gofmt', '-w', *ctx.paths], check=True)
    return {}


# TODO: b/259595799 - Remove yapf support.
def _yapf(*args, **kwargs) -> subprocess.CompletedProcess:
    """Runs `python -m yapf --parallel` with args, capturing its output."""
    return log_run(
        ['python', '-m', 'yapf', '--parallel', *args],
        capture_output=True,
        **kwargs,
    )


# Matches the "--- <path> (original)" header that starts each per-file diff.
_DIFF_START = re.compile(r'^--- (.*)\s+\(original\)$', flags=re.MULTILINE)


def check_py_format_yapf(ctx: _Context) -> dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting."""
    process = _yapf('--diff', *ctx.paths)

    errors: dict[Path, str] = {}

    if process.stdout:
        raw_diff = process.stdout.decode(errors='replace')

        # Split the combined diff into per-file chunks: each chunk runs from
        # one header match up to the next (or to the end of the output).
        matches = tuple(_DIFF_START.finditer(raw_diff))
        for start, end in zip(matches, (*matches[1:], None)):
            errors[Path(start.group(1))] = raw_diff[
                start.start() : end.start() if end else None
            ]

    if process.stderr:
        _LOG.error(
            'yapf encountered an error:\n%s',
            process.stderr.decode(errors='replace').rstrip(),
        )
        # Any stderr output marks every file without a diff as failed.
        errors.update({file: '' for file in ctx.paths if file not in errors})

    return errors


def fix_py_format_yapf(ctx: _Context) -> dict[Path, str]:
    """Fixes formatting for the provided files in place."""
    print_format_fix(_yapf('--in-place', *ctx.paths, check=True).stdout)
    return {}


def _enumerate_black_configs() -> Iterable[Path]:
    """Yields candidate Black config file paths in priority order.

    Raises:
      ValueError: An explicitly configured config file does not exist.
    """
    config = pw_env_setup.config_file.load()
    black_config_file = (
        config.get('pw', {})
        .get('pw_presubmit', {})
        .get('format', {})
        .get('black_config_file', {})
    )
    if black_config_file:
        explicit_path = Path(black_config_file)
        if not explicit_path.is_file():
            raise ValueError(f'Black config file not found: {explicit_path}')
        yield explicit_path
        return  # If an explicit path is provided, don't try implicit paths.

    if directory := os.environ.get('PW_PROJECT_ROOT'):
        yield Path(directory, '.black.toml')
        yield Path(directory, 'pyproject.toml')

    if directory := os.environ.get('PW_ROOT'):
        yield Path(directory, '.black.toml')
        yield Path(directory, 'pyproject.toml')


def _select_black_config_file() -> Optional[Path]:
    """Returns the first candidate Black config that exists, or None."""
    config = None
    for config_location in _enumerate_black_configs():
        if config_location.is_file():
            config = config_location
            break
    return config


def check_py_format_black(ctx: _Context) -> dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting."""
    formatter = BlackFormatter(
        _select_black_config_file(), tool_runner=PresubmitToolRunner()
    )
    return _make_formatting_diff_dict(
        formatter.get_formatting_diffs(
            ctx.paths,
            ctx.dry_run,
        )
    )


def fix_py_format_black(ctx: _Context) -> dict[Path, str]:
    """Fixes formatting for the provided files in place."""
    formatter = BlackFormatter(
        _select_black_config_file(), tool_runner=PresubmitToolRunner()
    )
    return _make_format_fix_error_output_dict(formatter.format_files(ctx.paths))


def check_py_format(ctx: _Context) -> dict[Path, str]:
    """Checks Python formatting with the formatter selected in the options.

    Raises:
      ValueError: The configured python_formatter is not 'black' or 'yapf'.
    """
    if ctx.format_options.python_formatter == 'black':
        return check_py_format_black(ctx)
    if ctx.format_options.python_formatter == 'yapf':
        return check_py_format_yapf(ctx)
    raise ValueError(ctx.format_options.python_formatter)


def fix_py_format(ctx: _Context) -> dict[Path, str]:
    """Fixes Python formatting with the formatter selected in the options.

    Raises:
      ValueError: The configured python_formatter is not 'black' or 'yapf'.
    """
    if ctx.format_options.python_formatter == 'black':
        return fix_py_format_black(ctx)
    if ctx.format_options.python_formatter == 'yapf':
        return fix_py_format_yapf(ctx)
    raise ValueError(ctx.format_options.python_formatter)


_TRAILING_SPACE = re.compile(rb'[ \t]+$', flags=re.MULTILINE)


def _check_trailing_space(paths: Iterable[Path], fix: bool) -> dict[Path, str]:
    """Checks for and optionally removes trailing whitespace.

    Returns {path: diff} for every file that contains trailing whitespace,
    whether or not it was fixed.
    """
    errors = {}

    for path in paths:
        with path.open('rb') as fd:
            contents = fd.read()

        corrected = _TRAILING_SPACE.sub(b'', contents)
        if corrected != contents:
            errors[path] = _diff(path, contents, corrected)

            if fix:
                with path.open('wb') as fd:
                    fd.write(corrected)

    return errors


def _format_json(contents: bytes) -> bytes:
    """Returns contents re-serialized as 2-space-indented JSON plus newline."""
    return json.dumps(json.loads(contents), indent=2).encode() + b'\n'


def _json_error(exc: json.JSONDecodeError, path: Path) -> str:
    """Formats a JSON decode error as '<path>: <msg> <line>:<col>'."""
    return f'{path}: {exc.msg} {exc.lineno}:{exc.colno}\n'


def check_json_format(ctx: _Context) -> dict[Path, str]:
    """Returns {path: diff or parse error} for badly formatted JSON files."""
    errors = {}

    for path in ctx.paths:
        orig = path.read_bytes()
        try:
            formatted = _format_json(orig)
        except json.JSONDecodeError as exc:
            errors[path] = _json_error(exc, path)
            continue

        if orig != formatted:
            errors[path] = _diff(path, orig, formatted)

    return errors


def fix_json_format(ctx: _Context) -> dict[Path, str]:
    """Rewrites JSON files in place; returns {path: error} for parse errors."""
    errors = {}
    for path in ctx.paths:
        orig = path.read_bytes()
        try:
            formatted = _format_json(orig)
        except json.JSONDecodeError as exc:
            errors[path] = _json_error(exc, path)
            continue

        if orig != formatted:
            path.write_bytes(formatted)

    return errors


def check_trailing_space(ctx: _Context) -> dict[Path, str]:
    """Returns {path: diff} for files containing trailing whitespace."""
    return _check_trailing_space(ctx.paths, fix=False)


def fix_trailing_space(ctx: _Context) -> dict[Path, str]:
    """Removes trailing whitespace in place; always reports no errors."""
    _check_trailing_space(ctx.paths, fix=True)
    return {}


def rst_format_check(ctx: _Context) -> dict[Path, str]:
    """Returns {path: diff} for reStructuredText files that need changes."""
    errors: dict[Path, str] = {}
    for path in ctx.paths:
        result = reformat_rst(
            path, diff=True, in_place=False, suppress_stdout=True
        )
        if result:
            errors[path] = ''.join(result)
    return errors


def rst_format_fix(ctx: _Context) -> dict[Path, str]:
    """Reformats reStructuredText files in place; reports no errors."""
    errors: dict[Path, str] = {}
    for path in ctx.paths:
        reformat_rst(path, diff=True, in_place=True, suppress_stdout=True)
    return errors


def print_format_check(
    errors: dict[Path, str],
    show_fix_commands: bool,
    show_summary: bool = True,
    colors: bool | None = None,
    file: TextIO = sys.stdout,
) -> None:
    """Prints the result of a check_*_format function.

    Args:
      errors: {path: diff} mapping, as returned by a check_*_format function.
      show_fix_commands: Whether to log `pw format --fix` commands.
      show_summary: Whether to log the number of files with errors.
      colors: Whether to colorize the diffs. If None, colors are used only
        when file is sys.stdout.
      file: Stream the diffs are printed to.
    """
    if not errors:
        # Don't print anything in the all-good case.
        return

    if colors is None:
        colors = file == sys.stdout

    # Show the format fixing diff suggested by the tooling (with colors).
    if show_summary:
        _LOG.warning(
            'Found %d files with formatting errors. Format changes:',
            len(errors),
        )
    for diff in errors.values():
        if colors:
            diff = colorize_diff(diff)
        print(diff, end='', file=file)

    # Show a copy-and-pastable command to fix the issues.
    if show_fix_commands:

        def path_relative_to_cwd(path: Path):
            try:
                return Path(path).resolve().relative_to(Path.cwd().resolve())
            except ValueError:
                # The path is outside the current directory; use it as-is.
                return Path(path).resolve()

        message = (
            f' pw format --fix {path_relative_to_cwd(path)}' for path in errors
        )
        _LOG.warning('To fix formatting, run:\n\n%s\n', '\n'.join(message))


def print_format_fix(stdout: bytes):
    """Prints the output of a format --fix call."""
    for line in stdout.splitlines():
        _LOG.info('Fix cmd stdout: %r', line.decode('utf-8'))


class CodeFormat(NamedTuple):
    """Bundles a language name with its file filter and check/fix functions."""

    # Human-readable language name; also used to derive check and directory
    # names.
    language: str
    # Selects the files this format applies to.
    filter: FileFilter
    # Returns {path: diff} for files with bad formatting.
    check: Callable[[_Context], dict[Path, str]]
    # Fixes files in place; returns {path: error} for failures.
    fix: Callable[[_Context], dict[Path, str]]

    @property
    def extensions(self):
        # TODO: b/23842636 - Switch calls of this to using 'filter' and remove.
        return self.filter.endswith


CPP_HEADER_EXTS = frozenset(('.h', '.hpp', '.hxx', '.h++', '.hh', '.H'))
CPP_SOURCE_EXTS = frozenset(
    ('.c', '.cpp', '.cxx', '.c++', '.cc', '.C', '.inc', '.inl')
)
CPP_EXTS = CPP_HEADER_EXTS.union(CPP_SOURCE_EXTS)
CPP_FILE_FILTER = FileFilter(
    endswith=CPP_EXTS, exclude=[r'\.pb\.h$', r'\.pb\.c$']
)

C_FORMAT = CodeFormat(
    'C and C++', CPP_FILE_FILTER, clang_format_check, clang_format_fix
)

PROTO_FORMAT: CodeFormat = CodeFormat(
    'Protocol buffer',
    FileFilter(endswith=['.proto']),
    clang_format_check,
    clang_format_fix,
)

JAVA_FORMAT: CodeFormat = CodeFormat(
    'Java',
    FileFilter(endswith=['.java']),
    clang_format_check,
    clang_format_fix,
)

JAVASCRIPT_FORMAT: CodeFormat = CodeFormat(
    'JavaScript',
    FileFilter(endswith=['.js']),
    typescript_format_check,
    typescript_format_fix,
)

TYPESCRIPT_FORMAT: CodeFormat = CodeFormat(
    'TypeScript',
    FileFilter(endswith=['.ts']),
    typescript_format_check,
    typescript_format_fix,
)

# TODO: b/308948504 - Add real code formatting support for CSS
CSS_FORMAT: CodeFormat = CodeFormat(
    'css',
    FileFilter(endswith=['.css']),
    check_trailing_space,
    fix_trailing_space,
)

GO_FORMAT: CodeFormat = CodeFormat(
    'Go', FileFilter(endswith=['.go']), check_go_format, fix_go_format
)

PYTHON_FORMAT: CodeFormat = CodeFormat(
    'Python',
    FileFilter(endswith=['.py']),
    check_py_format,
    fix_py_format,
)

GN_FORMAT: CodeFormat = CodeFormat(
    'GN', FileFilter(endswith=['.gn', '.gni']), check_gn_format, fix_gn_format
)

BAZEL_FORMAT: CodeFormat = CodeFormat(
    'Bazel',
    FileFilter(endswith=['.bazel', '.bzl'], name=['^BUILD$', '^WORKSPACE$']),
    check_bazel_format,
    fix_bazel_format,
)

COPYBARA_FORMAT: CodeFormat = CodeFormat(
    'Copybara',
    FileFilter(endswith=['.bara.sky']),
    check_bazel_format,
    fix_bazel_format,
)

# TODO: b/234881054 - Add real code formatting support for CMake
CMAKE_FORMAT: CodeFormat = CodeFormat(
    'CMake',
    FileFilter(endswith=['.cmake'], name=['^CMakeLists.txt$']),
    check_trailing_space,
    fix_trailing_space,
)

RST_FORMAT: CodeFormat = CodeFormat(
    'reStructuredText',
    FileFilter(endswith=['.rst']),
    rst_format_check,
    rst_format_fix,
)

MARKDOWN_FORMAT: CodeFormat = CodeFormat(
    'Markdown',
    FileFilter(endswith=['.md']),
    check_trailing_space,
    fix_trailing_space,
)

OWNERS_CODE_FORMAT = CodeFormat(
    'OWNERS',
    filter=FileFilter(name=['^OWNERS$']),
    check=check_owners_format,
    fix=fix_owners_format,
)

JSON_FORMAT: CodeFormat = CodeFormat(
    'JSON',
    FileFilter(endswith=['.json']),
    check=check_json_format,
    fix=fix_json_format,
)

# The JavaScript and TypeScript formats are only enabled when npm is on the
# path; filter(None, ...) drops the None placeholders.
CODE_FORMATS: tuple[CodeFormat, ...] = tuple(
    filter(
        None,
        (
            # keep-sorted: start
            BAZEL_FORMAT,
            CMAKE_FORMAT,
            COPYBARA_FORMAT,
            CSS_FORMAT,
            C_FORMAT,
            GN_FORMAT,
            GO_FORMAT,
            JAVASCRIPT_FORMAT if shutil.which('npm') else None,
            JAVA_FORMAT,
            JSON_FORMAT,
            MARKDOWN_FORMAT,
            OWNERS_CODE_FORMAT,
            PROTO_FORMAT,
            PYTHON_FORMAT,
            RST_FORMAT,
            TYPESCRIPT_FORMAT if shutil.which('npm') else None,
            # keep-sorted: end
        ),
    )
)


# TODO: b/264578594 - Remove these lines when these globals aren't referenced.
CODE_FORMATS_WITH_BLACK: tuple[CodeFormat, ...] = CODE_FORMATS
CODE_FORMATS_WITH_YAPF: tuple[CodeFormat, ...] = CODE_FORMATS


def presubmit_check(
    code_format: CodeFormat,
    *,
    exclude: Collection[str | Pattern[str]] = (),
) -> Callable:
    """Creates a presubmit check function from a CodeFormat object.

    Args:
      code_format: The CodeFormat whose check function is wrapped.
      exclude: Additional exclusion regexes to apply.
    """

    # Make a copy of the FileFilter and add in any additional excludes.
    file_filter = FileFilter(**vars(code_format.filter))
    file_filter.exclude += tuple(re.compile(e) for e in exclude)

    @filter_paths(file_filter=file_filter)
    def check_code_format(ctx: PresubmitContext):
        ctx.paths = presubmit_context.apply_exclusions(ctx)
        errors = code_format.check(ctx)
        print_format_check(
            errors,
            # When running as part of presubmit, show the fix command help.
            show_fix_commands=True,
        )
        if not errors:
            return

        with ctx.failure_summary_log.open('w') as outs:
            print_format_check(
                errors,
                show_summary=False,
                show_fix_commands=False,
                file=outs,
            )

        raise PresubmitFailure

    language = code_format.language.lower().replace('+', 'p').replace(' ', '_')
    check_code_format.name = f'{language}_format'
    check_code_format.doc = f'Check the format of {code_format.language} files.'

    return check_code_format


def presubmit_checks(
    *,
    exclude: Collection[str | Pattern[str]] = (),
    code_formats: Collection[CodeFormat] = CODE_FORMATS,
) -> tuple[Callable, ...]:
    """Returns a tuple with all supported code format presubmit checks.

    Args:
      exclude: Additional exclusion regexes to apply.
      code_formats: A list of CodeFormat objects to run checks with.
    """

    return tuple(presubmit_check(fmt, exclude=exclude) for fmt in code_formats)


class CodeFormatter:
    """Checks or fixes the formatting of a set of files."""

    def __init__(
        self,
        root: Path | None,
        files: Iterable[Path],
        output_dir: Path,
        code_formats: Collection[CodeFormat] = CODE_FORMATS_WITH_YAPF,
        package_root: Path | None = None,
    ):
        """Groups the files to format by the first CodeFormat they match.

        Args:
          root: Root directory passed to each FormatContext.
          files: Paths to consider; may be any iterable, including a
            generator.
          output_dir: Directory under which per-language output directories
            are created.
          code_formats: Formats to try, in order, for each path.
          package_root: Package root; defaults to output_dir / 'packages'.
        """
        self.root = root
        self._formats: dict[CodeFormat, list] = collections.defaultdict(list)
        self.root_output_dir = output_dir
        self.package_root = package_root or output_dir / 'packages'
        self._format_options = FormatOptions.load()
        # Materialize the input once. Callers may pass a generator, which
        # would be exhausted by filter_paths() and leave nothing to compare
        # against when logging the filtered-out paths below.
        raw_paths = tuple(files)
        self.paths: tuple[Path, ...] = self._format_options.filter_paths(
            raw_paths
        )

        filtered_paths = set(raw_paths) - set(self.paths)
        for path in sorted(filtered_paths):
            _LOG.debug('filtered out %s', path)

        for path in self.paths:
            # Only the first matching format is used for each path.
            for code_format in code_formats:
                if code_format.filter.matches(path):
                    _LOG.debug(
                        'Formatting %s as %s', path, code_format.language
                    )
                    self._formats[code_format].append(path)
                    break
            else:
                _LOG.debug('No formatter found for %s', path)

    def _context(self, code_format: CodeFormat):
        """Creates the FormatContext (and output dir) for a code format."""
        outdir = self.root_output_dir / code_format.language.replace(' ', '_')
        os.makedirs(outdir, exist_ok=True)

        return FormatContext(
            root=self.root,
            output_dir=outdir,
            paths=tuple(self._formats[code_format]),
            package_root=self.package_root,
            format_options=self._format_options,
        )

    def check(self) -> dict[Path, str]:
        """Returns {path: diff} for files with incorrect formatting."""
        errors: dict[Path, str] = {}

        for code_format, files in self._formats.items():
            _LOG.debug('Checking %s', ', '.join(str(f) for f in files))
            errors.update(code_format.check(self._context(code_format)))

        return collections.OrderedDict(sorted(errors.items()))

    def fix(self) -> dict[Path, str]:
        """Fixes format errors for supported files in place.

        Returns {path: error} for the files that could not be fixed.
        """
        all_errors: dict[Path, str] = {}
        for code_format, files in self._formats.items():
            errors = code_format.fix(self._context(code_format))
            if errors:
                for path, error in errors.items():
                    _LOG.error('Failed to format %s', path)
                    for line in error.splitlines():
                        _LOG.error('%s', line)
                all_errors.update(errors)
                continue

            _LOG.info(
                'Formatted %s', plural(files, code_format.language + ' file')
            )
        return all_errors


def _file_summary(files: Iterable[Path | str], base: Path) -> list[str]:
    """Summarizes files relative to base; returns [] if any is outside it."""
    try:
        return file_summary(
            Path(f).resolve().relative_to(base.resolve()) for f in files
        )
    except ValueError:
        return []


def format_paths_in_repo(
    paths: Collection[Path | str],
    exclude: Collection[Pattern[str]],
    fix: bool,
    base: str,
    code_formats: Collection[CodeFormat] = CODE_FORMATS,
    output_directory: Path | None = None,
    package_root: Path | None = None,
) -> int:
    """Checks or fixes formatting for files in a Git repo.

    Args:
      paths: Files or directories to format.
      exclude: Regexes for paths to skip.
      fix: Apply fixes in place instead of only checking.
      base: Git commit to diff against when listing files.
      code_formats: Formats to run.
      output_directory: Directory for formatter output.
      package_root: Package root directory.

    Returns:
      0 on success; 1 on formatting errors or invalid arguments.
    """

    files = [Path(path).resolve() for path in paths if os.path.isfile(path)]
    repo = git_repo.root() if git_repo.is_repo() else None

    # Implement a graceful fallback in case the tracking branch isn't available.
    if base == git_repo.TRACKING_BRANCH_ALIAS and not git_repo.tracking_branch(
        repo
    ):
        _LOG.warning(
            'Failed to determine the tracking branch, using --base HEAD~1 '
            'instead of listing all files'
        )
        base = 'HEAD~1'

    # If this is a Git repo, list the original paths with git ls-files or diff.
    if repo:
        project_root = pw_cli.env.pigweed_environment().PW_PROJECT_ROOT
        _LOG.info(
            'Formatting %s',
            git_repo.describe_files(
                repo, Path.cwd(), base, paths, exclude, project_root
            ),
        )

        # Add files from Git and remove duplicates.
        files = sorted(
            set(exclude_paths(exclude, git_repo.list_files(base, paths)))
            | set(files)
        )
    elif base:
        _LOG.critical(
            'A base commit may only be provided if running from a Git repo'
        )
        return 1

    return format_files(
        files,
        fix,
        repo=repo,
        code_formats=code_formats,
        output_directory=output_directory,
        package_root=package_root,
    )


def _check_and_fix(
    formatter: CodeFormatter,
    paths: Collection[Path | str],
    fix: bool,
    repo: Path | None,
) -> int:
    """Runs the check pass and optional fix pass for format_files().

    Returns 0 if no changes were needed or all fixes applied, 1 otherwise.
    """
    _LOG.info('Checking formatting for %s', plural(formatter.paths, 'file'))

    for line in _file_summary(paths, repo if repo else Path.cwd()):
        print(line, file=sys.stderr)

    check_errors = formatter.check()
    print_format_check(check_errors, show_fix_commands=(not fix))

    if check_errors:
        if fix:
            _LOG.info(
                'Applying formatting fixes to %d files', len(check_errors)
            )
            fix_errors = formatter.fix()
            if fix_errors:
                _LOG.info('Failed to apply formatting fixes')
                print_format_check(fix_errors, show_fix_commands=False)
                return 1

            _LOG.info('Formatting fixes applied successfully')
            return 0

        _LOG.error('Formatting errors found')
        return 1

    _LOG.info('Congratulations! No formatting changes needed')
    return 0


def format_files(
    paths: Collection[Path | str],
    fix: bool,
    repo: Path | None = None,
    code_formats: Collection[CodeFormat] = CODE_FORMATS,
    output_directory: Path | None = None,
    package_root: Path | None = None,
) -> int:
    """Checks or fixes formatting for the specified files.

    Args:
      paths: Files to check.
      fix: Apply fixes in place instead of only checking.
      repo: Base directory for the printed file summary; defaults to the
        current directory.
      code_formats: Formats to run.
      output_directory: Directory for formatter output. Defaults to
        out/format under the Git root, or a temporary directory when no Git
        root is found.
      package_root: Package root directory.

    Returns:
      0 if no changes were needed or all fixes applied, 1 otherwise.
    """

    root: Path | None = None

    if git_repo.is_repo():
        root = git_repo.root()
    elif paths:
        parent = Path(next(iter(paths))).parent
        if git_repo.is_repo(parent):
            root = git_repo.root(parent)

    tempdir: tempfile.TemporaryDirectory | None = None
    output_dir: Path
    if output_directory:
        output_dir = output_directory
    elif root:
        output_dir = root / _DEFAULT_PATH
    else:
        tempdir = tempfile.TemporaryDirectory()
        output_dir = Path(tempdir.name)

    try:
        formatter = CodeFormatter(
            files=(Path(p) for p in paths),
            code_formats=code_formats,
            root=root,
            output_dir=output_dir,
            package_root=package_root,
        )
        return _check_and_fix(formatter, paths, fix, repo)
    finally:
        # Remove the temporary output directory explicitly instead of relying
        # on the TemporaryDirectory finalizer.
        if tempdir is not None:
            tempdir.cleanup()


def arguments(git_paths: bool) -> argparse.ArgumentParser:
    """Creates an argument parser for format_files or format_paths_in_repo."""

    parser = argparse.ArgumentParser(description=__doc__)

    if git_paths:
        cli.add_path_arguments(parser)
    else:

        def existing_path(arg: str) -> Path:
            path = Path(arg)
            if not path.is_file():
                raise argparse.ArgumentTypeError(
                    f'{arg} is not a path to a file'
                )

            return path

        parser.add_argument(
            'paths',
            metavar='path',
            nargs='+',
            type=existing_path,
            help='File paths to check',
        )

    parser.add_argument(
        '--fix', action='store_true', help='Apply formatting fixes in place.'
    )

    parser.add_argument(
        '--output-directory',
        type=Path,
        help=f"Output directory (default: {'<repo root>' / _DEFAULT_PATH})",
    )
    # Don't fail to build the parser when PW_PACKAGE_ROOT is unset. A None
    # default makes CodeFormatter fall back to <output_dir>/packages.
    package_root = os.environ.get('PW_PACKAGE_ROOT')
    parser.add_argument(
        '--package-root',
        type=Path,
        default=Path(package_root) if package_root is not None else None,
        help='Package root directory',
    )

    return parser


def main() -> int:
    """Check and fix formatting for source files."""
    return format_paths_in_repo(**vars(arguments(git_paths=True).parse_args()))


if __name__ == '__main__':
    try:
        # If pw_cli is available, use it to initialize logs.
        from pw_cli import log  # pylint: disable=ungrouped-imports

        log.install(logging.INFO)
    except ImportError:
        # If pw_cli isn't available, display log messages like a simple print.
        logging.basicConfig(format='%(message)s', level=logging.INFO)

    sys.exit(main())