xref: /aosp_15_r20/external/pigweed/pw_presubmit/py/pw_presubmit/format_code.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1#!/usr/bin/env python3
2
3# Copyright 2020 The Pigweed Authors
4#
5# Licensed under the Apache License, Version 2.0 (the "License"); you may not
6# use this file except in compliance with the License. You may obtain a copy of
7# the License at
8#
9#     https://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14# License for the specific language governing permissions and limitations under
15# the License.
16"""Checks and fixes formatting for source files.
17
18This uses clang-format, gn format, gofmt, and python -m yapf to format source
19code. These tools must be available on the path when this script is invoked!
20"""
21
22import argparse
23import collections
24import difflib
25import json
26import logging
27import os
28from pathlib import Path
29import re
30import shutil
31import subprocess
32import sys
33import tempfile
34from typing import (
35    Callable,
36    Collection,
37    Iterable,
38    NamedTuple,
39    Optional,
40    Pattern,
41    TextIO,
42)
43
44import pw_cli.color
45from pw_cli.diff import colorize_diff
46import pw_cli.env
47from pw_cli.file_filter import FileFilter
48from pw_cli.plural import plural
49import pw_env_setup.config_file
50from pw_presubmit.presubmit import filter_paths
51from pw_presubmit.presubmit_context import (
52    FormatContext,
53    FormatOptions,
54    PresubmitContext,
55    PresubmitFailure,
56)
57from pw_presubmit import (
58    cli,
59    git_repo,
60    owners_checks,
61    presubmit_context,
62)
63from pw_presubmit.format.core import FormattedDiff, FormatFixStatus
64from pw_presubmit.format.cpp import ClangFormatFormatter
65from pw_presubmit.format.bazel import BuildifierFormatter
66from pw_presubmit.format.gn import GnFormatter
67from pw_presubmit.format.python import BlackFormatter
68from pw_presubmit.tools import (
69    exclude_paths,
70    file_summary,
71    log_run,
72    PresubmitToolRunner,
73)
74from pw_presubmit.rst_format import reformat_rst
75
# Module-level logger, named after this module.
_LOG: logging.Logger = logging.getLogger(__name__)
# Color helper object obtained from pw_cli.color.
_COLOR = pw_cli.color.colors()
# Relative path of the formatter output directory.
_DEFAULT_PATH = Path('out', 'format')

# The check/fix functions in this module accept either context type.
_Context = PresubmitContext | FormatContext
81
82
def _ensure_newline(orig: bytes) -> bytes:
    """Returns ``orig``, adding a marker if it lacks a trailing newline.

    Content that already ends in a newline is returned as is. Otherwise a
    newline followed by a ``No newline at end of file`` line is appended.
    """
    missing_newline = not orig.endswith(b'\n')
    suffix = b'\nNo newline at end of file\n' if missing_newline else b''
    return orig + suffix
87
88
def _diff(path, original: bytes, formatted: bytes) -> str:
    """Builds a unified diff between original and formatted file contents.

    Both inputs are first passed through _ensure_newline() and then decoded
    with undecodable bytes replaced.

    Args:
      path: Path used to label both sides of the diff.
      original: Contents before formatting.
      formatted: Contents after formatting.

    Returns:
      The unified diff as a single string.
    """
    before_lines = (
        _ensure_newline(original).decode(errors='replace').splitlines(True)
    )
    after_lines = (
        _ensure_newline(formatted).decode(errors='replace').splitlines(True)
    )
    diff_lines = difflib.unified_diff(
        before_lines,
        after_lines,
        f'{path}  (original)',
        f'{path}  (reformatted)',
    )
    return ''.join(diff_lines)
100
101
# Type of a formatter callback: called with a path and the file's original
# contents, it returns the formatted contents.
FormatterT = Callable[[str, bytes], bytes]
103
104
def _diff_formatted(
    path, formatter: FormatterT, dry_run: bool = False
) -> str | None:
    """Runs ``formatter`` on a file and diffs the result against the file.

    Args:
      path: File to read and format.
      formatter: Callback invoked with the path and the file's contents.
      dry_run: If True, the formatter still runs but no diff is produced.

    Returns:
      A unified diff if the formatted contents differ from the file, or None
      if they match or ``dry_run`` is set.
    """
    with open(path, 'rb') as source:
        before = source.read()

    # The formatter is invoked even for dry runs; only the diff is skipped.
    after = formatter(path, before)

    if dry_run or after == before:
        return None
    return _diff(path, before, after)
118
119
def _check_files(
    files, formatter: FormatterT, dry_run: bool = False
) -> dict[Path, str]:
    """Collects formatting diffs for every file that has one.

    Args:
      files: Paths to check.
      formatter: Callback forwarded to _diff_formatted() for each file.
      dry_run: Forwarded to _diff_formatted().

    Returns:
      A {path: diff} dict containing only the files with a non-empty diff.
    """
    diffs_by_path = (
        (path, _diff_formatted(path, formatter, dry_run)) for path in files
    )
    return {path: diff for path, diff in diffs_by_path if diff}
131
132
def _make_formatting_diff_dict(
    diffs: Iterable[FormattedDiff],
) -> dict[Path, str]:
    """Converts formatting check results into this module's {path: str} form.

    Each result maps its file path to its diff when the result is ok, and to
    the stringified error message otherwise.
    """
    output: dict[Path, str] = {}
    for formatted_diff in diffs:
        key = formatted_diff.file_path
        if formatted_diff.ok:
            output[key] = formatted_diff.diff
        else:
            output[key] = str(formatted_diff.error_message)
    return output
143
144
def _make_format_fix_error_output_dict(
    statuses: Iterable[tuple[Path, FormatFixStatus]],
) -> dict[Path, str]:
    """Converts (path, fix status) pairs into a {path: error message} dict.

    Every pair is included; the value is the stringified ``error_message`` of
    the status.
    """
    output: dict[Path, str] = {}
    for file_path, status in statuses:
        output[file_path] = str(status.error_message)
    return output
152
153
def clang_format_check(ctx: _Context) -> dict[Path, str]:
    """Checks ``ctx.paths`` with ClangFormatFormatter.

    Returns:
      The formatter's diffs adapted by _make_formatting_diff_dict().
    """
    clang_format = ClangFormatFormatter(tool_runner=PresubmitToolRunner())
    diffs = clang_format.get_formatting_diffs(ctx.paths, ctx.dry_run)
    return _make_formatting_diff_dict(diffs)
160
161
def clang_format_fix(ctx: _Context) -> dict[Path, str]:
    """Formats ``ctx.paths`` with ClangFormatFormatter.

    Returns:
      The per-file statuses adapted by _make_format_fix_error_output_dict().
    """
    clang_format = ClangFormatFormatter(tool_runner=PresubmitToolRunner())
    statuses = clang_format.format_files(ctx.paths)
    return _make_format_fix_error_output_dict(statuses)
166
167
def _typescript_format(*args: Path | str, **kwargs) -> bytes:
    """Runs ``npm exec prettier`` with ``args`` and returns its stdout.

    Args:
      *args: Arguments appended after ``prettier`` on the command line.
      **kwargs: Extra keyword arguments forwarded to log_run().

    Raises:
      KeyError: If _PW_ACTUAL_ENVIRONMENT_ROOT is not in the environment.
    """
    # TODO: b/323378974 - Better integrate NPM actions with pw_env_setup so
    # we don't have to manually set `npm_config_cache` every time we run npm.
    # Force npm cache to live inside the environment directory.
    cache_dir = Path(os.environ['_PW_ACTUAL_ENVIRONMENT_ROOT']) / 'npm-cache'
    npm_env = {**os.environ, 'npm_config_cache': str(cache_dir)}

    npm_name = 'npm.cmd' if os.name == 'nt' else 'npm'
    command = [shutil.which(npm_name), 'exec', 'prettier', *args]
    completed = log_run(
        command,
        stdout=subprocess.PIPE,
        stdin=subprocess.DEVNULL,
        check=True,
        env=npm_env,
        **kwargs,
    )
    return completed.stdout
186
187
def typescript_format_check(ctx: _Context) -> dict[Path, str]:
    """Checks ``ctx.paths`` by comparing each file to _typescript_format()."""

    def run_prettier(path, _original: bytes) -> bytes:
        # Only the path is passed on; the file contents argument is unused.
        return _typescript_format(path)

    return _check_files(ctx.paths, run_prettier, ctx.dry_run)
195
196
def typescript_format_fix(ctx: _Context) -> dict[Path, str]:
    """Runs _typescript_format() with ``--write`` over ``ctx.paths``.

    The tool's stdout is logged with print_format_fix(). Always returns an
    empty dict.
    """
    prettier_output = _typescript_format(*ctx.paths, '--', '--write')
    print_format_fix(prettier_output)
    return {}
201
202
def check_gn_format(ctx: _Context) -> dict[Path, str]:
    """Checks ``ctx.paths`` with GnFormatter.

    Returns:
      The formatter's diffs adapted by _make_formatting_diff_dict().
    """
    gn_formatter = GnFormatter(tool_runner=PresubmitToolRunner())
    diffs = gn_formatter.get_formatting_diffs(ctx.paths, ctx.dry_run)
    return _make_formatting_diff_dict(diffs)
212
213
def fix_gn_format(ctx: _Context) -> dict[Path, str]:
    """Formats ``ctx.paths`` with GnFormatter.

    Returns:
      The per-file statuses adapted by _make_format_fix_error_output_dict().
    """
    gn_formatter = GnFormatter(tool_runner=PresubmitToolRunner())
    statuses = gn_formatter.format_files(ctx.paths)
    return _make_format_fix_error_output_dict(statuses)
218
219
def check_bazel_format(ctx: _Context) -> dict[Path, str]:
    """Checks ``ctx.paths`` with BuildifierFormatter.

    Returns:
      The formatter's diffs adapted by _make_formatting_diff_dict().
    """
    buildifier = BuildifierFormatter(tool_runner=PresubmitToolRunner())
    diffs = buildifier.get_formatting_diffs(ctx.paths, ctx.dry_run)
    return _make_formatting_diff_dict(diffs)
229
230
def fix_bazel_format(ctx: _Context) -> dict[Path, str]:
    """Formats ``ctx.paths`` with BuildifierFormatter.

    Returns:
      The per-file statuses adapted by _make_format_fix_error_output_dict().
    """
    buildifier = BuildifierFormatter(tool_runner=PresubmitToolRunner())
    statuses = buildifier.format_files(ctx.paths)
    return _make_format_fix_error_output_dict(statuses)
235
236
def check_owners_format(ctx: _Context) -> dict[Path, str]:
    """Returns the result of owners_checks.run_owners_checks(ctx.paths)."""
    owners_paths = ctx.paths
    return owners_checks.run_owners_checks(owners_paths)
239
240
def fix_owners_format(ctx: _Context) -> dict[Path, str]:
    """Returns the result of owners_checks.format_owners_file(ctx.paths)."""
    owners_paths = ctx.paths
    return owners_checks.format_owners_file(owners_paths)
243
244
def check_go_format(ctx: _Context) -> dict[Path, str]:
    """Checks ``ctx.paths`` by comparing each file to the output of gofmt."""

    def run_gofmt(path, _original: bytes) -> bytes:
        # gofmt is given the path; its stdout is the formatted candidate.
        completed = log_run(
            ['gofmt', path], stdout=subprocess.PIPE, check=True
        )
        return completed.stdout

    return _check_files(ctx.paths, run_gofmt, ctx.dry_run)
254
255
def fix_go_format(ctx: _Context) -> dict[Path, str]:
    """Runs ``gofmt -w`` over ``ctx.paths``; always returns an empty dict."""
    command = ['gofmt', '-w', *ctx.paths]
    log_run(command, check=True)
    return {}
260
261
262# TODO: b/259595799 - Remove yapf support.
def _yapf(*args, **kwargs) -> subprocess.CompletedProcess:
    """Runs ``python -m yapf --parallel`` with ``args``, capturing output.

    Args:
      *args: Arguments appended to the yapf command line.
      **kwargs: Extra keyword arguments forwarded to log_run().
    """
    command = ['python', '-m', 'yapf', '--parallel']
    command.extend(args)
    return log_run(command, capture_output=True, **kwargs)
269
270
# Matches a "--- <path> (original)" header line; group 1 captures the path.
# Used to split a multi-file diff into one chunk per file.
_DIFF_START = re.compile(r'^--- (.*)\s+\(original\)$', flags=re.MULTILINE)
272
273
def check_py_format_yapf(ctx: _Context) -> dict[Path, str]:
    """Checks ``ctx.paths`` with ``yapf --diff``.

    Returns:
      A {path: diff} dict built by splitting yapf's stdout at each diff
      header. If yapf wrote anything to stderr, it is logged and every path
      in ``ctx.paths`` without a diff is added with an empty string.
    """
    result = _yapf('--diff', *ctx.paths)

    errors: dict[Path, str] = {}

    if result.stdout:
        diff_text = result.stdout.decode(errors='replace')
        headers = list(_DIFF_START.finditer(diff_text))

        # Each file's diff runs from its header up to the next header, or to
        # the end of the output for the last one.
        for index, header in enumerate(headers):
            is_last = index + 1 == len(headers)
            stop = None if is_last else headers[index + 1].start()
            errors[Path(header.group(1))] = diff_text[header.start() : stop]

    if result.stderr:
        _LOG.error(
            'yapf encountered an error:\n%s',
            result.stderr.decode(errors='replace').rstrip(),
        )
        for file in ctx.paths:
            errors.setdefault(file, '')

    return errors
297
298
def fix_py_format_yapf(ctx: _Context) -> dict[Path, str]:
    """Runs ``yapf --in-place`` over ``ctx.paths``.

    The tool's stdout is logged with print_format_fix(). Always returns an
    empty dict.
    """
    completed = _yapf('--in-place', *ctx.paths, check=True)
    print_format_fix(completed.stdout)
    return {}
303
304
def _enumerate_black_configs() -> Iterable[Path]:
    """Yields candidate Black config file paths in priority order.

    If the loaded pw_env_setup config sets
    pw.pw_presubmit.format.black_config_file, only that path is yielded.
    Otherwise ``.black.toml`` and ``pyproject.toml`` are yielded under
    PW_PROJECT_ROOT and then under PW_ROOT, for whichever of those
    environment variables are set. The implicit candidates are not checked
    for existence here.

    Raises:
      ValueError: If an explicitly configured path is not a file.
    """
    format_settings = (
        pw_env_setup.config_file.load()
        .get('pw', {})
        .get('pw_presubmit', {})
        .get('format', {})
    )
    configured = format_settings.get('black_config_file', {})
    if configured:
        configured_path = Path(configured)
        if configured_path.is_file():
            # If an explicit path is provided, don't try implicit paths.
            yield configured_path
            return
        raise ValueError(f'Black config file not found: {configured_path}')

    for variable in ('PW_PROJECT_ROOT', 'PW_ROOT'):
        if directory := os.environ.get(variable):
            yield Path(directory, '.black.toml')
            yield Path(directory, 'pyproject.toml')
327
328
def _select_black_config_file() -> Optional[Path]:
    """Returns the first candidate Black config that is a file, or None."""
    existing_configs = (
        candidate
        for candidate in _enumerate_black_configs()
        if candidate.is_file()
    )
    return next(existing_configs, None)
336
337
def check_py_format_black(ctx: _Context) -> dict[Path, str]:
    """Checks ``ctx.paths`` with BlackFormatter.

    The formatter is constructed with the config file chosen by
    _select_black_config_file().

    Returns:
      The formatter's diffs adapted by _make_formatting_diff_dict().
    """
    config_file = _select_black_config_file()
    black = BlackFormatter(config_file, tool_runner=PresubmitToolRunner())
    diffs = black.get_formatting_diffs(ctx.paths, ctx.dry_run)
    return _make_formatting_diff_dict(diffs)
349
350
def fix_py_format_black(ctx: _Context) -> dict[Path, str]:
    """Formats ``ctx.paths`` with BlackFormatter.

    The formatter is constructed with the config file chosen by
    _select_black_config_file().

    Returns:
      The per-file statuses adapted by _make_format_fix_error_output_dict().
    """
    config_file = _select_black_config_file()
    black = BlackFormatter(config_file, tool_runner=PresubmitToolRunner())
    statuses = black.format_files(ctx.paths)
    return _make_format_fix_error_output_dict(statuses)
357
358
def check_py_format(ctx: _Context) -> dict[Path, str]:
    """Checks Python formatting with the formatter named in the options.

    Raises:
      ValueError: If ``ctx.format_options.python_formatter`` is neither
        'black' nor 'yapf'.
    """
    selected = ctx.format_options.python_formatter
    checkers = (
        ('black', check_py_format_black),
        ('yapf', check_py_format_yapf),
    )
    for formatter_name, checker in checkers:
        if selected == formatter_name:
            return checker(ctx)
    raise ValueError(selected)
365
366
def fix_py_format(ctx: _Context) -> dict[Path, str]:
    """Fixes Python formatting with the formatter named in the options.

    Raises:
      ValueError: If ``ctx.format_options.python_formatter`` is neither
        'black' nor 'yapf'.
    """
    selected = ctx.format_options.python_formatter
    fixers = (
        ('black', fix_py_format_black),
        ('yapf', fix_py_format_yapf),
    )
    for formatter_name, fixer in fixers:
        if selected == formatter_name:
            return fixer(ctx)
    raise ValueError(selected)
373
374
# Matches a run of spaces and/or tabs at the end of any line (bytes pattern).
_TRAILING_SPACE = re.compile(rb'[ \t]+$', flags=re.MULTILINE)
376
377
def _check_trailing_space(paths: Iterable[Path], fix: bool) -> dict[Path, str]:
    """Finds trailing whitespace and optionally strips it in place.

    Args:
      paths: Files to inspect.
      fix: If True, files with trailing whitespace are rewritten without it.

    Returns:
      A {path: diff} dict for the files that contain trailing whitespace.
    """
    errors = {}

    for path in paths:
        before = path.read_bytes()
        after = _TRAILING_SPACE.sub(b'', before)
        if after == before:
            continue

        errors[path] = _diff(path, before, after)
        if fix:
            path.write_bytes(after)

    return errors
395
396
def _format_json(contents: bytes) -> bytes:
    """Parses JSON and re-serializes it with a two-space indent.

    Returns:
      The encoded, re-serialized document followed by a single newline.

    Raises:
      json.JSONDecodeError: If ``contents`` is not valid JSON.
    """
    parsed = json.loads(contents)
    text = json.dumps(parsed, indent=2)
    return f'{text}\n'.encode()
399
400
def _json_error(exc: json.JSONDecodeError, path: Path) -> str:
    """Formats a JSON decode error as ``<path>: <message> <line>:<column>``.

    The returned string ends with a newline.
    """
    location = f'{exc.lineno}:{exc.colno}'
    return f'{path}: {exc.msg} {location}\n'
403
404
def check_json_format(ctx: _Context) -> dict[Path, str]:
    """Checks that each file in ``ctx.paths`` matches _format_json() output.

    Returns:
      A dict mapping each offending path to a diff, or to a formatted error
      message if the file could not be decoded as JSON.
    """
    errors = {}

    for path in ctx.paths:
        raw = path.read_bytes()
        try:
            pretty = _format_json(raw)
        except json.JSONDecodeError as exc:
            errors[path] = _json_error(exc, path)
        else:
            if raw != pretty:
                errors[path] = _diff(path, raw, pretty)

    return errors
420
421
def fix_json_format(ctx: _Context) -> dict[Path, str]:
    """Rewrites each file in ``ctx.paths`` with its _format_json() output.

    Files whose contents already match are not written.

    Returns:
      A dict mapping each path that could not be decoded as JSON to a
      formatted error message.
    """
    errors = {}

    for path in ctx.paths:
        raw = path.read_bytes()
        try:
            pretty = _format_json(raw)
        except json.JSONDecodeError as exc:
            errors[path] = _json_error(exc, path)
        else:
            if raw != pretty:
                path.write_bytes(pretty)

    return errors
436
437
def check_trailing_space(ctx: _Context) -> dict[Path, str]:
    """Reports trailing whitespace in ``ctx.paths`` without modifying files."""
    errors = _check_trailing_space(ctx.paths, fix=False)
    return errors
440
441
def fix_trailing_space(ctx: _Context) -> dict[Path, str]:
    """Strips trailing whitespace from ``ctx.paths`` in place.

    The diffs produced while fixing are discarded; always returns an empty
    dict.
    """
    _ = _check_trailing_space(ctx.paths, fix=True)
    return {}
445
446
def rst_format_check(ctx: _Context) -> dict[Path, str]:
    """Checks ``ctx.paths`` with reformat_rst() without modifying files.

    Returns:
      A {path: str} dict holding the joined reformat_rst() result for each
      file where that result is non-empty.
    """
    errors: dict[Path, str] = {}
    for path in ctx.paths:
        diff_lines = reformat_rst(
            path, diff=True, in_place=False, suppress_stdout=True
        )
        if not diff_lines:
            continue
        errors[path] = ''.join(diff_lines)
    return errors
456
457
def rst_format_fix(ctx: _Context) -> dict[Path, str]:
    """Runs reformat_rst() in place on every file in ``ctx.paths``.

    The results of reformat_rst() are not inspected; always returns an empty
    dict.
    """
    no_errors: dict[Path, str] = {}
    for rst_path in ctx.paths:
        reformat_rst(rst_path, diff=True, in_place=True, suppress_stdout=True)
    return no_errors
463
464
465def print_format_check(
466    errors: dict[Path, str],
467    show_fix_commands: bool,
468    show_summary: bool = True,
469    colors: bool | None = None,
470    file: TextIO = sys.stdout,
471) -> None:
472    """Prints and returns the result of a check_*_format function."""
473    if not errors:
474        # Don't print anything in the all-good case.
475        return
476
477    if colors is None:
478        colors = file == sys.stdout
479
480    # Show the format fixing diff suggested by the tooling (with colors).
481    if show_summary:
482        _LOG.warning(
483            'Found %d files with formatting errors. Format changes:',
484            len(errors),
485        )
486    for diff in errors.values():
487        if colors:
488            diff = colorize_diff(diff)
489        print(diff, end='', file=file)
490
491    # Show a copy-and-pastable command to fix the issues.
492    if show_fix_commands:
493
494        def path_relative_to_cwd(path: Path):
495            try:
496                return Path(path).resolve().relative_to(Path.cwd().resolve())
497            except ValueError:
498                return Path(path).resolve()
499
500        message = (
501            f'  pw format --fix {path_relative_to_cwd(path)}' for path in errors
502        )
503        _LOG.warning('To fix formatting, run:\n\n%s\n', '\n'.join(message))
504
505
def print_format_fix(stdout: bytes):
    """Logs each line of a format --fix tool's stdout at INFO level.

    Args:
      stdout: Raw bytes captured from the tool. Bytes that are not valid
        UTF-8 are replaced rather than raising, matching how the rest of
        this module decodes tool output. Empty or missing output logs
        nothing.
    """
    for line in (stdout or b'').splitlines():
        # Logging tool output must never abort the fix itself, so decoding
        # is lenient.
        _LOG.info('Fix cmd stdout: %r', line.decode('utf-8', errors='replace'))
510
511
class CodeFormat(NamedTuple):
    """Pairs a language with a file filter and its check and fix functions."""

    # Human-readable language name; also used to derive check and output
    # directory names elsewhere in this module.
    language: str
    # Selects which files this format applies to.
    filter: FileFilter
    # Called with a context; returns a {path: str} dict of formatting problems.
    check: Callable[[_Context], dict[Path, str]]
    # Called with a context; returns a {path: str} dict of fix failures.
    fix: Callable[[_Context], dict[Path, str]]

    @property
    def extensions(self):
        """Returns the ``endswith`` attribute of this format's file filter."""
        # TODO: b/23842636 - Switch calls of this to using 'filter' and remove.
        return self.filter.endswith
522
523
# File extensions treated as C/C++ headers and sources.
CPP_HEADER_EXTS = frozenset(('.h', '.hpp', '.hxx', '.h++', '.hh', '.H'))
CPP_SOURCE_EXTS = frozenset(
    ('.c', '.cpp', '.cxx', '.c++', '.cc', '.C', '.inc', '.inl')
)
CPP_EXTS = CPP_HEADER_EXTS.union(CPP_SOURCE_EXTS)
# Matches C/C++ files, excluding names ending in .pb.h or .pb.c.
CPP_FILE_FILTER = FileFilter(
    endswith=CPP_EXTS, exclude=[r'\.pb\.h$', r'\.pb\.c$']
)

C_FORMAT = CodeFormat(
    'C and C++', CPP_FILE_FILTER, clang_format_check, clang_format_fix
)

# Protocol buffer and Java files use the same clang-format check/fix
# functions as C and C++.
PROTO_FORMAT: CodeFormat = CodeFormat(
    'Protocol buffer',
    FileFilter(endswith=['.proto']),
    clang_format_check,
    clang_format_fix,
)

JAVA_FORMAT: CodeFormat = CodeFormat(
    'Java',
    FileFilter(endswith=['.java']),
    clang_format_check,
    clang_format_fix,
)

# JavaScript and TypeScript share the same check/fix functions.
JAVASCRIPT_FORMAT: CodeFormat = CodeFormat(
    'JavaScript',
    FileFilter(endswith=['.js']),
    typescript_format_check,
    typescript_format_fix,
)

TYPESCRIPT_FORMAT: CodeFormat = CodeFormat(
    'TypeScript',
    FileFilter(endswith=['.ts']),
    typescript_format_check,
    typescript_format_fix,
)

# TODO: b/308948504 - Add real code formatting support for CSS
CSS_FORMAT: CodeFormat = CodeFormat(
    'css',
    FileFilter(endswith=['.css']),
    check_trailing_space,
    fix_trailing_space,
)

GO_FORMAT: CodeFormat = CodeFormat(
    'Go', FileFilter(endswith=['.go']), check_go_format, fix_go_format
)

# The Python check/fix functions dispatch on the configured python_formatter.
PYTHON_FORMAT: CodeFormat = CodeFormat(
    'Python',
    FileFilter(endswith=['.py']),
    check_py_format,
    fix_py_format,
)

GN_FORMAT: CodeFormat = CodeFormat(
    'GN', FileFilter(endswith=['.gn', '.gni']), check_gn_format, fix_gn_format
)

BAZEL_FORMAT: CodeFormat = CodeFormat(
    'Bazel',
    FileFilter(endswith=['.bazel', '.bzl'], name=['^BUILD$', '^WORKSPACE$']),
    check_bazel_format,
    fix_bazel_format,
)

# Copybara files use the same check/fix functions as Bazel files.
COPYBARA_FORMAT: CodeFormat = CodeFormat(
    'Copybara',
    FileFilter(endswith=['.bara.sky']),
    check_bazel_format,
    fix_bazel_format,
)
601
# TODO: b/234881054 - Add real code formatting support for CMake
# Until then, CMake files are only checked for trailing whitespace.
CMAKE_FORMAT: CodeFormat = CodeFormat(
    'CMake',
    # The dot is escaped so the pattern matches only the literal file name
    # CMakeLists.txt rather than any character in that position.
    FileFilter(endswith=['.cmake'], name=[r'^CMakeLists\.txt$']),
    check_trailing_space,
    fix_trailing_space,
)
609
RST_FORMAT: CodeFormat = CodeFormat(
    'reStructuredText',
    FileFilter(endswith=['.rst']),
    rst_format_check,
    rst_format_fix,
)

# Markdown files are only checked for trailing whitespace.
MARKDOWN_FORMAT: CodeFormat = CodeFormat(
    'Markdown',
    FileFilter(endswith=['.md']),
    check_trailing_space,
    fix_trailing_space,
)

OWNERS_CODE_FORMAT = CodeFormat(
    'OWNERS',
    filter=FileFilter(name=['^OWNERS$']),
    check=check_owners_format,
    fix=fix_owners_format,
)

JSON_FORMAT: CodeFormat = CodeFormat(
    'JSON',
    FileFilter(endswith=['.json']),
    check=check_json_format,
    fix=fix_json_format,
)

# All formats enabled by default. The JavaScript and TypeScript entries
# evaluate to None when npm is not found on the PATH, and filter(None, ...)
# drops those placeholders.
CODE_FORMATS: tuple[CodeFormat, ...] = tuple(
    filter(
        None,
        (
            # keep-sorted: start
            BAZEL_FORMAT,
            CMAKE_FORMAT,
            COPYBARA_FORMAT,
            CSS_FORMAT,
            C_FORMAT,
            GN_FORMAT,
            GO_FORMAT,
            JAVASCRIPT_FORMAT if shutil.which('npm') else None,
            JAVA_FORMAT,
            JSON_FORMAT,
            MARKDOWN_FORMAT,
            OWNERS_CODE_FORMAT,
            PROTO_FORMAT,
            PYTHON_FORMAT,
            RST_FORMAT,
            TYPESCRIPT_FORMAT if shutil.which('npm') else None,
            # keep-sorted: end
        ),
    )
)
663
664
# TODO: b/264578594 - Remove these lines when these globals aren't referenced.
# Both names are plain aliases of CODE_FORMATS.
CODE_FORMATS_WITH_BLACK: tuple[CodeFormat, ...] = CODE_FORMATS
CODE_FORMATS_WITH_YAPF: tuple[CodeFormat, ...] = CODE_FORMATS
668
669
def presubmit_check(
    code_format: CodeFormat,
    *,
    exclude: Collection[str | Pattern[str]] = (),
) -> Callable:
    """Creates a presubmit check function from a CodeFormat object.

    Args:
      code_format: The format whose filter and check function are used.
      exclude: Additional exclusion regexes to apply.

    Returns:
      A check function wrapped with filter_paths(). It raises
      PresubmitFailure when the format's check reports any errors.
    """

    # Make a copy of the FileFilter and add in any additional excludes.
    combined_filter = FileFilter(**vars(code_format.filter))
    extra_excludes = tuple(re.compile(pattern) for pattern in exclude)
    combined_filter.exclude += extra_excludes

    @filter_paths(file_filter=combined_filter)
    def check_code_format(ctx: PresubmitContext):
        ctx.paths = presubmit_context.apply_exclusions(ctx)
        errors = code_format.check(ctx)
        if not errors:
            return

        # When running as part of presubmit, show the fix command help.
        print_format_check(errors, show_fix_commands=True)

        # Also record the diffs, without summary or fix hints, in the
        # failure summary log.
        with ctx.failure_summary_log.open('w') as summary_file:
            print_format_check(
                errors,
                show_summary=False,
                show_fix_commands=False,
                file=summary_file,
            )

        raise PresubmitFailure

    # Derive an identifier-like name, e.g. 'C and C++' -> 'c_and_cpp'.
    check_name = code_format.language.lower()
    for old, new in (('+', 'p'), (' ', '_')):
        check_name = check_name.replace(old, new)

    check_code_format.name = f'{check_name}_format'
    check_code_format.doc = f'Check the format of {code_format.language} files.'

    return check_code_format
712
713
def presubmit_checks(
    *,
    exclude: Collection[str | Pattern[str]] = (),
    code_formats: Collection[CodeFormat] = CODE_FORMATS,
) -> tuple[Callable, ...]:
    """Returns a tuple with all supported code format presubmit checks.

    Args:
      exclude: Additional exclusion regexes to apply.
      code_formats: A list of CodeFormat objects to run checks with.
    """
    checks = [
        presubmit_check(code_format, exclude=exclude)
        for code_format in code_formats
    ]
    return tuple(checks)
727
728
class CodeFormatter:
    """Checks or fixes the formatting of a set of files."""

    def __init__(
        self,
        root: Path | None,
        files: Iterable[Path],
        output_dir: Path,
        code_formats: Collection[CodeFormat] = CODE_FORMATS_WITH_YAPF,
        package_root: Path | None = None,
    ):
        """Groups the given files by the first code format that matches them.

        Args:
          root: Root directory passed on to each FormatContext; may be None.
          files: Paths to consider. May be a one-shot iterator.
          output_dir: Directory under which per-language output directories
            are created.
          code_formats: Formats to try, in order, for each path.
          package_root: Package root passed on to each FormatContext.
            Defaults to ``output_dir / 'packages'``.
        """
        self.root = root
        self._formats: dict[CodeFormat, list] = collections.defaultdict(list)
        self.root_output_dir = output_dir
        self.package_root = package_root or output_dir / 'packages'
        self._format_options = FormatOptions.load()

        # Materialize the input first: `files` may be a generator, and it is
        # needed twice (for filtering and for reporting what was filtered).
        raw_paths = tuple(files)
        self.paths: tuple[Path, ...] = self._format_options.filter_paths(
            raw_paths
        )

        filtered_paths = set(raw_paths) - set(self.paths)
        for path in sorted(filtered_paths):
            _LOG.debug('filtered out %s', path)

        for path in self.paths:
            # Only the first matching format is used for each path.
            for code_format in code_formats:
                if code_format.filter.matches(path):
                    _LOG.debug(
                        'Formatting %s as %s', path, code_format.language
                    )
                    self._formats[code_format].append(path)
                    break
            else:
                _LOG.debug('No formatter found for %s', path)

    def _context(self, code_format: CodeFormat):
        """Builds the FormatContext for one format, creating its output dir."""
        outdir = self.root_output_dir / code_format.language.replace(' ', '_')
        os.makedirs(outdir, exist_ok=True)

        return FormatContext(
            root=self.root,
            output_dir=outdir,
            paths=tuple(self._formats[code_format]),
            package_root=self.package_root,
            format_options=self._format_options,
        )

    def check(self) -> dict[Path, str]:
        """Returns {path: diff} for files with incorrect formatting.

        The result is ordered by path.
        """
        errors: dict[Path, str] = {}

        for code_format, files in self._formats.items():
            _LOG.debug('Checking %s', ', '.join(str(f) for f in files))
            errors.update(code_format.check(self._context(code_format)))

        return collections.OrderedDict(sorted(errors.items()))

    def fix(self) -> dict[Path, str]:
        """Fixes format errors for supported files in place.

        Returns:
          A {path: error} dict for the files that could not be fixed; each
          failure is also logged.
        """
        all_errors: dict[Path, str] = {}
        for code_format, files in self._formats.items():
            errors = code_format.fix(self._context(code_format))
            if errors:
                for path, error in errors.items():
                    _LOG.error('Failed to format %s', path)
                    for line in error.splitlines():
                        _LOG.error('%s', line)
                all_errors.update(errors)
                # Skip the success message for formats that reported errors.
                continue

            _LOG.info(
                'Formatted %s', plural(files, code_format.language + ' file')
            )
        return all_errors
802
803
def _file_summary(files: Iterable[Path | str], base: Path) -> list[str]:
    """Summarizes ``files`` relative to ``base`` using file_summary().

    Returns:
      The lines produced by file_summary(), or an empty list if a
      ValueError is raised, e.g. for a file that is not under ``base``.
    """

    def relative_paths():
        # Lazily resolved so that a ValueError surfaces inside the try below.
        for file in files:
            yield Path(file).resolve().relative_to(base.resolve())

    try:
        return file_summary(relative_paths())
    except ValueError:
        return []
811
812
def format_paths_in_repo(
    paths: Collection[Path | str],
    exclude: Collection[Pattern[str]],
    fix: bool,
    base: str,
    code_formats: Collection[CodeFormat] = CODE_FORMATS,
    output_directory: Path | None = None,
    package_root: Path | None = None,
) -> int:
    """Checks or fixes formatting for files in a Git repo.

    Args:
      paths: Paths to format; entries that are existing files are always
        included, and in a Git repo all entries are also passed to Git to
        list files.
      exclude: Patterns for Git-listed files to leave out.
      fix: Forwarded to format_files().
      base: Git base to list files against; only valid inside a Git repo.
      code_formats: Forwarded to format_files().
      output_directory: Forwarded to format_files().
      package_root: Forwarded to format_files().

    Returns:
      1 if ``base`` is given outside of a Git repo; otherwise the result of
      format_files().
    """
    files = [Path(path).resolve() for path in filter(os.path.isfile, paths)]

    if git_repo.is_repo():
        repo = git_repo.root()
    else:
        repo = None

    # Implement a graceful fallback in case the tracking branch isn't available.
    wants_tracking_branch = base == git_repo.TRACKING_BRANCH_ALIAS
    if wants_tracking_branch and not git_repo.tracking_branch(repo):
        _LOG.warning(
            'Failed to determine the tracking branch, using --base HEAD~1 '
            'instead of listing all files'
        )
        base = 'HEAD~1'

    # If this is a Git repo, list the original paths with git ls-files or diff.
    if repo:
        project_root = pw_cli.env.pigweed_environment().PW_PROJECT_ROOT
        description = git_repo.describe_files(
            repo, Path.cwd(), base, paths, exclude, project_root
        )
        _LOG.info('Formatting %s', description)

        # Add files from Git and remove duplicates.
        git_files = exclude_paths(exclude, git_repo.list_files(base, paths))
        files = sorted({*git_files, *files})
    elif base:
        _LOG.critical(
            'A base commit may only be provided if running from a Git repo'
        )
        return 1

    return format_files(
        files,
        fix,
        repo=repo,
        code_formats=code_formats,
        output_directory=output_directory,
        package_root=package_root,
    )
866
867
def _check_and_fix_formatting(
    formatter: CodeFormatter,
    paths: Collection[Path | str],
    fix: bool,
    repo: Path | None,
) -> int:
    """Runs the check (and optional fix) flow; returns the exit code."""
    _LOG.info('Checking formatting for %s', plural(formatter.paths, 'file'))

    for line in _file_summary(paths, repo if repo else Path.cwd()):
        print(line, file=sys.stderr)

    check_errors = formatter.check()
    print_format_check(check_errors, show_fix_commands=(not fix))

    if check_errors:
        if fix:
            _LOG.info(
                'Applying formatting fixes to %d files', len(check_errors)
            )
            fix_errors = formatter.fix()
            if fix_errors:
                _LOG.info('Failed to apply formatting fixes')
                print_format_check(fix_errors, show_fix_commands=False)
                return 1

            _LOG.info('Formatting fixes applied successfully')
            return 0

        _LOG.error('Formatting errors found')
        return 1

    _LOG.info('Congratulations! No formatting changes needed')
    return 0


def format_files(
    paths: Collection[Path | str],
    fix: bool,
    repo: Path | None = None,
    code_formats: Collection[CodeFormat] = CODE_FORMATS,
    output_directory: Path | None = None,
    package_root: Path | None = None,
) -> int:
    """Checks or fixes formatting for the specified files.

    Args:
      paths: Files to check or fix.
      fix: If True, apply fixes in place when the check finds errors.
      repo: Base directory for the printed file summary; the current
        working directory is used when it is None.
      code_formats: Formats to apply.
      output_directory: Output directory. Defaults to ``out/format`` under
        the detected Git root, or a temporary directory when no Git root is
        found.
      package_root: Package root handed to the formatter.

    Returns:
      0 if no formatting changes are needed or all fixes were applied;
      1 if formatting errors were found or a fix failed.
    """

    root: Path | None = None

    if git_repo.is_repo():
        root = git_repo.root()
    elif paths:
        # Not running from a repo: see if the first file lives in one.
        parent = Path(next(iter(paths))).parent
        if git_repo.is_repo(parent):
            root = git_repo.root(parent)

    tempdir: tempfile.TemporaryDirectory | None = None
    output_dir: Path
    if output_directory:
        output_dir = output_directory
    elif root:
        output_dir = root / _DEFAULT_PATH
    else:
        tempdir = tempfile.TemporaryDirectory()
        output_dir = Path(tempdir.name)

    try:
        formatter = CodeFormatter(
            files=(Path(p) for p in paths),
            code_formats=code_formats,
            root=root,
            output_dir=output_dir,
            package_root=package_root,
        )
        return _check_and_fix_formatting(formatter, paths, fix, repo)
    finally:
        # Remove the temporary output directory explicitly rather than
        # relying on the finalizer, which emits a ResourceWarning.
        if tempdir is not None:
            tempdir.cleanup()
931
932
def arguments(git_paths: bool) -> argparse.ArgumentParser:
    """Creates an argument parser for format_files or format_paths_in_repo.

    Args:
      git_paths: If True, path arguments are added with
        cli.add_path_arguments(); otherwise a positional ``paths`` argument
        accepting one or more existing files is added.

    Returns:
      The configured parser. ``--package-root`` defaults to the value of the
      PW_PACKAGE_ROOT environment variable, or to None when it is not set.
    """

    parser = argparse.ArgumentParser(description=__doc__)

    if git_paths:
        cli.add_path_arguments(parser)
    else:

        def existing_path(arg: str) -> Path:
            path = Path(arg)
            if not path.is_file():
                raise argparse.ArgumentTypeError(
                    f'{arg} is not a path to a file'
                )

            return path

        parser.add_argument(
            'paths',
            metavar='path',
            nargs='+',
            type=existing_path,
            help='File paths to check',
        )

    parser.add_argument(
        '--fix', action='store_true', help='Apply formatting fixes in place.'
    )

    parser.add_argument(
        '--output-directory',
        type=Path,
        help=f"Output directory (default: {'<repo root>' / _DEFAULT_PATH})",
    )

    # Don't fail while building the parser if PW_PACKAGE_ROOT is unset:
    # CodeFormatter falls back to a directory under the output directory
    # when no package root is given.
    package_root_env = os.environ.get('PW_PACKAGE_ROOT')
    parser.add_argument(
        '--package-root',
        type=Path,
        default=(
            Path(package_root_env) if package_root_env is not None else None
        ),
        help='Package root directory',
    )

    return parser
976
977
def main() -> int:
    """Check and fix formatting for source files."""
    parser = arguments(git_paths=True)
    parsed = vars(parser.parse_args())
    return format_paths_in_repo(**parsed)
981
982
if __name__ == '__main__':
    try:
        # If pw_cli is available, use it to initialize logs.
        from pw_cli import log  # pylint: disable=ungrouped-imports

        log.install(logging.INFO)
    except ImportError:
        # If pw_cli isn't available, display log messages like a simple print.
        logging.basicConfig(format='%(message)s', level=logging.INFO)

    # Exit with the status code returned by main().
    sys.exit(main())
994