xref: /aosp_15_r20/external/pigweed/pw_presubmit/py/pw_presubmit/presubmit.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Tools for running presubmit checks in a Git repository.
15
16Presubmit checks are defined as a function or other callable. The function may
17take either no arguments or a list of the paths on which to run. Presubmit
18checks communicate failure by raising any exception.
19
20For example, either of these functions may be used as presubmit checks:
21
22  @pw_presubmit.filter_paths(endswith='.py')
23  def file_contains_ni(ctx: PresubmitContext):
24      for path in ctx.paths:
25          with open(path) as file:
26              contents = file.read()
27              if 'ni' not in contents and 'nee' not in contents:
                  raise PresubmitFailure('Files must say "ni"!', path=path)
29
30  def run_the_build():
31      subprocess.run(['make', 'release'], check=True)
32
33Presubmit checks that accept a list of paths may use the filter_paths decorator
34to automatically filter the paths list for file types they care about. See the
35pragma_once function for an example.
36
See pigweed_presubmit.py for an example of how to define presubmit checks.
38"""
39
40from __future__ import annotations
41
42import collections
43import contextlib
44import copy
45import dataclasses
46import enum
47from inspect import Parameter, signature
48import itertools
49import json
50import logging
51import os
52from pathlib import Path
53import re
54import signal
55import subprocess
56import sys
57import tempfile
58import time
59import types
60from typing import (
61    Any,
62    Callable,
63    Collection,
64    Iterable,
65    Iterator,
66    Pattern,
67    Sequence,
68    Set,
69)
70
71import pw_cli.color
72import pw_cli.env
73from pw_cli.plural import plural
74from pw_cli.file_filter import FileFilter
75from pw_package import package_manager
76from pw_presubmit import git_repo, tools
77from pw_presubmit.presubmit_context import (
78    FormatContext,
79    FormatOptions,
80    LuciContext,
81    PRESUBMIT_CONTEXT,
82    PresubmitContext,
83    PresubmitFailure,
84    log_check_traces,
85)
86
# Module-level logger. Presubmit._context temporarily attaches a per-check
# step.log file handler to it.
_LOG: logging.Logger = logging.getLogger(name=__name__)

# Terminal color helpers from pw_cli (e.g. _COLOR.green, _COLOR.red).
_COLOR = pw_cli.color.colors()
90
91_SUMMARY_BOX = '══╦╗ ║║══╩╝'
92_CHECK_UPPER = '━━━┓       '
93_CHECK_LOWER = '       ━━━┛'
94
95WIDTH = 80
96
97_LEFT = 7
98_RIGHT = 11
99
100
def _title(msg, style=_SUMMARY_BOX) -> str:
    """Renders ``msg`` centered inside a full-width title box.

    Args:
        msg: Text to show. It gets one space on each side and is then
            centered to WIDTH - 2 columns.
        style: Box-drawing glyphs passed positionally to the box template.

    Returns:
        The formatted box as a string.
    """
    centered = ' {} '.format(msg).center(WIDTH - 2)
    template = tools.make_box('^')
    return template.format(*style, section1=centered, width1=len(centered))
104
105
106def _format_time(time_s: float) -> str:
107    minutes, seconds = divmod(time_s, 60)
108    if minutes < 60:
109        return f' {int(minutes)}:{seconds:04.1f}'
110    hours, minutes = divmod(minutes, 60)
111    return f'{int(hours):d}:{int(minutes):02}:{int(seconds):02}'
112
113
def _box(style, left, middle, right, box=tools.make_box('><>')) -> str:
    """Formats one three-column line of a check or summary box.

    Args:
        style: Box-drawing glyphs passed positionally to ``box.format``.
        left: Left column text, formatted with width _LEFT. A trailing space
            is added unless it already ends with one.
        middle: Middle column text. A single leading space is always added.
        right: Right column text, formatted with width _RIGHT. A single
            trailing space is always added.
        box: Three-section template from tools.make_box.

    Returns:
        The formatted line.
    """
    left_text = left if left.endswith(' ') else left + ' '
    middle_width = WIDTH - _LEFT - _RIGHT - 4
    return box.format(
        *style,
        section1=left_text,
        width1=_LEFT,
        section2=' ' + middle,
        width2=middle_width,
        section3=right + ' ',
        width3=_RIGHT,
    )
124
125
class PresubmitResult(enum.Enum):
    """Outcome of a presubmit check or of a whole presubmit run."""

    PASS = 'PASSED'  # Check completed successfully.
    FAIL = 'FAILED'  # Check failed.
    CANCEL = 'CANCEL'  # Check didn't complete.

    def colorized(self, width: int, invert: bool = False) -> str:
        """Returns the result's label, colored and padded toward ``width``.

        The same padding, (width - len(label)) // 2 spaces, is placed on each
        side of the colored label.

        Args:
            width: Target width used to compute the padding.
            invert: Use the black-on-green / black-on-red variants for PASS
                and FAIL. CANCEL is always plain yellow.
        """
        if self is PresubmitResult.CANCEL:
            paint = _COLOR.yellow
        elif self is PresubmitResult.FAIL:
            paint = _COLOR.black_on_red if invert else _COLOR.red
        elif self is PresubmitResult.PASS:
            paint = _COLOR.black_on_green if invert else _COLOR.green
        else:

            def paint(text):
                return text

        side = ' ' * ((width - len(self.value)) // 2)
        return side + paint(self.value) + side
145
146
class Program(collections.abc.Sequence):
    """An ordered, de-duplicated sequence of presubmit Checks with a name.

    Behaves like a tuple of Check objects that also carries a name.
    """

    def __init__(self, name: str, steps: Iterable[Callable]):
        """Builds the program from possibly nested steps.

        Args:
            name: Program name, returned by str() and used by title().
            steps: Check objects or plain callables. Nested iterables are
                flattened with tools.flatten. Items that are not already
                Check instances are wrapped in Check. A Check that appears
                more than once is kept only at its first position.
        """
        self.name = name

        unique: dict[Check, None] = {}
        for step in tools.flatten(steps):
            wrapped = step if isinstance(step, Check) else Check(step)
            unique.setdefault(wrapped, None)
        self._steps: tuple[Check, ...] = tuple(unique)

    def __getitem__(self, i):
        return self._steps[i]

    def __len__(self):
        return len(self._steps)

    def __str__(self):
        return self.name

    def title(self):
        """Returns a heading such as '<name> presubmit checks'."""
        label = self.name or ''
        return f'{label} presubmit checks'.strip()
173
174
class Programs(collections.abc.Mapping):
    """A mapping from names to presubmit Programs.

    Use is optional. Helpful when a project manages several presubmit
    programs.
    """

    def __init__(self, **programs: Sequence):
        """Builds a name -> Program mapping from keyword arguments.

        Each value is a sequence of presubmit check functions, which may
        contain nested sequences; they are flattened by Program.
        """
        self._programs: dict[str, Program] = {}
        for name, checks in programs.items():
            self._programs[name] = Program(name, checks)

    def all_steps(self) -> dict[str, Check]:
        """Returns every step of every program, keyed by step name.

        If two steps share a name, the one seen later in iteration order
        wins.
        """
        steps: dict[str, Check] = {}
        for program in self.values():
            for step in program:
                steps[step.name] = step
        return steps

    def __getitem__(self, item: str) -> Program:
        return self._programs[item]

    def __iter__(self) -> Iterator[str]:
        return iter(self._programs)

    def __len__(self) -> int:
        return len(self._programs)
202
203
def download_cas_artifact(
    ctx: PresubmitContext, digest: str, output_dir: str
) -> None:
    """Downloads a CAS digest into a local directory using the ``cas`` tool.

    Args:
        ctx: The presubmit context. ``ctx.luci`` must be set; it supplies the
            CAS instance.
        digest: Digest in the form "<digest hash>/<size bytes>", e.g.
            693a04e41374150d9d4b645fccb49d6f96e10b527c7a24b1e17b331f508aa73b/86
        output_dir: Directory to download the artifacts into.

    Raises:
        PresubmitFailure: If ``ctx.luci`` is None or ``cas download`` exits
            with a non-zero status.
    """
    luci = ctx.luci
    if luci is None:
        raise PresubmitFailure('Lucicontext is None')

    cmd = ['cas', 'download']
    cmd += ['-cas-instance', luci.cas_instance]
    cmd += ['-digest', digest]
    cmd += ['-dir', output_dir]

    try:
        subprocess.check_call(cmd)
    except subprocess.CalledProcessError as failure:
        raise PresubmitFailure('cas download failed') from failure
232
233
def archive_cas_artifact(
    ctx: PresubmitContext, root: str, upload_paths: list[str]
) -> str:
    """Uploads the given files/directories to CAS using the ``cas`` tool.

    Args:
        ctx: The presubmit context. ``ctx.luci`` must be set; it supplies the
            CAS instance.
        root: Root directory of the archived tree. It is converted to an
            absolute path (relative paths resolve against the current
            working directory).
        upload_paths: Files/directories to archive, under ``root``. Each is
            converted to an absolute path. If empty, ``[root]`` is used.

    Returns:
        The contents of the digest file written by ``cas archive``, expected
        to be "<digest hash>/<size bytes>", e.g.
        693a04e41374150d9d4b645fccb49d6f96e10b527c7a24b1e17b331f508aa73b/86

    Raises:
        PresubmitFailure: If ``ctx.luci`` is None or ``cas archive`` exits
            with a non-zero status.
    """
    if ctx.luci is None:
        raise PresubmitFailure('Lucicontext is None')

    # Asserting on os.path.abspath() validates nothing (it always returns a
    # non-empty string), so normalize to absolute paths as documented.
    root = os.path.abspath(root)
    if not upload_paths:
        upload_paths = [root]
    upload_paths = [os.path.abspath(path) for path in upload_paths]

    json_paths = json.dumps(
        [[root, os.path.relpath(path, root)] for path in upload_paths]
    )

    with tempfile.NamedTemporaryFile(mode='w+t') as tmp_digest_file:
        with tempfile.NamedTemporaryFile(mode='w+t') as tmp_paths_file:
            tmp_paths_file.write(json_paths)
            # Make the JSON visible on disk before the subprocess reads it.
            tmp_paths_file.flush()
            cmd = [
                'cas',
                'archive',
                '-cas-instance',
                ctx.luci.cas_instance,
                '-paths-json',
                tmp_paths_file.name,
                '-dump-digest',
                tmp_digest_file.name,
            ]
            try:
                subprocess.check_call(cmd)
            except subprocess.CalledProcessError as failure:
                raise PresubmitFailure('cas archive failed') from failure

            # `cas` wrote the digest through its own handle; read it back.
            tmp_digest_file.seek(0)
            uploaded_digest = tmp_digest_file.read()
            return uploaded_digest
285
286
287def _print_ui(*args) -> None:
288    """Prints to stdout and flushes to stay in sync with logs on stderr."""
289    print(*args, flush=True)
290
291
@dataclasses.dataclass
class FilteredCheck:
    """A Check paired with the paths it should run on.

    Attributes:
        check: The check to run.
        paths: Paths that matched the check's filter.
        substep: Optional name of a single substep to run instead of the
            whole check.
    """

    check: Check
    paths: Sequence[Path]
    substep: str | None = dataclasses.field(default=None)

    @property
    def name(self) -> str:
        """The wrapped check's name."""
        wrapped = self.check
        return wrapped.name

    def run(self, ctx: PresubmitContext, count: int, total: int):
        """Runs the wrapped check, forwarding this entry's substep."""
        wrapped = self.check
        return wrapped.run(ctx, count, total, self.substep)
304
305
class Presubmit:
    """Runs a series of presubmit checks on a list of files."""

    def __init__(  # pylint: disable=too-many-arguments
        self,
        root: Path,
        repos: Sequence[Path],
        output_directory: Path,
        paths: Sequence[Path],
        all_paths: Sequence[Path],
        package_root: Path,
        override_gn_args: dict[str, str],
        continue_after_build_error: bool,
        rng_seed: int,
        full: bool,
    ):
        """Stores the settings shared by every check in a run.

        Args:
            root: Project root; stored resolved.
            repos: Git repository roots being checked.
            output_directory: Parent of the per-check output directories;
                stored resolved.
            paths: Paths the checks are filtered against and run on.
            all_paths: Every path known to the run, exposed to each check.
            package_root: Package directory; stored resolved.
            override_gn_args: Extra GN args exposed to each check.
            continue_after_build_error: Exposed to each check's context.
            rng_seed: Seed exposed to each check's context.
            full: Exposed to each check's context.
        """
        self._root = root.resolve()
        self._repos = tuple(repos)
        self._output_directory = output_directory.resolve()
        self._paths = tuple(paths)
        self._all_paths = tuple(all_paths)
        self._relative_paths = tuple(
            tools.relative_paths(self._paths, self._root)
        )
        self._package_root = package_root.resolve()
        self._override_gn_args = override_gn_args
        self._continue_after_build_error = continue_after_build_error
        self._rng_seed = rng_seed
        self._full = full

    def run(
        self,
        program: Program,
        keep_going: bool = False,
        substep: str | None = None,
        dry_run: bool = False,
    ) -> bool:
        """Executes a series of presubmit checks on the paths.

        Args:
            program: The checks to run.
            keep_going: Continue running checks after one fails.
            substep: Run only this substep; requires exactly one applicable
                check.
            dry_run: Passed to each check's context.

        Returns:
            True if no check failed and none was left unrun.
        """
        checks = self.apply_filters(program)
        if substep:
            assert (
                len(checks) == 1
            ), 'substeps not supported with multiple steps'
            checks[0].substep = substep

        _LOG.debug('Running %s for %s', program.title(), self._root.name)
        _print_ui(_title(f'{self._root.name}: {program.title()}'))

        _LOG.info(
            '%d of %d checks apply to %s in %s',
            len(checks),
            len(program),
            plural(self._paths, 'file'),
            self._root,
        )

        _print_ui()
        for line in tools.file_summary(self._relative_paths):
            _print_ui(line)
        _print_ui()

        if not self._paths:
            _print_ui(_COLOR.yellow('No files are being checked!'))

        _LOG.debug('Checks:\n%s', '\n'.join(c.name for c in checks))

        start_time: float = time.time()
        passed, failed, skipped = self._execute_checks(
            checks, keep_going, dry_run
        )
        self._log_summary(time.time() - start_time, passed, failed, skipped)

        return not failed and not skipped

    def apply_filters(self, program: Sequence[Callable]) -> list[FilteredCheck]:
        """Returns list of FilteredCheck for checks that should run.

        Plain callables are wrapped in Check. Checks are grouped by their
        filter so each distinct filter is matched against the paths once.
        """
        checks = [c if isinstance(c, Check) else Check(c) for c in program]
        filter_to_checks: dict[
            FileFilter, list[Check]
        ] = collections.defaultdict(list)

        for chk in checks:
            filter_to_checks[chk.filter].append(chk)

        check_to_paths = self._map_checks_to_paths(filter_to_checks)
        return [
            FilteredCheck(c, check_to_paths[c])
            for c in checks
            if c in check_to_paths
        ]

    def _map_checks_to_paths(
        self, filter_to_checks: dict[FileFilter, list[Check]]
    ) -> dict[Check, Sequence[Path]]:
        """Maps each check to the paths its filter matches.

        Filters are matched against root-relative POSIX paths; the returned
        paths are the corresponding entries of self._paths. Checks with no
        matching paths are omitted unless their always_run flag is set.
        """
        checks_to_paths: dict[Check, Sequence[Path]] = {}

        posix_paths = tuple(p.as_posix() for p in self._relative_paths)

        for filt, checks in filter_to_checks.items():
            filtered_paths = tuple(
                path
                for path, filter_path in zip(self._paths, posix_paths)
                if filt.matches(filter_path)
            )

            for chk in checks:
                if filtered_paths or chk.always_run:
                    checks_to_paths[chk] = filtered_paths
                else:
                    _LOG.debug('Skipping "%s": no relevant files', chk.name)

        return checks_to_paths

    def _log_summary(
        self, time_s: float, passed: int, failed: int, skipped: int
    ) -> None:
        """Logs and prints the final summary box for the run."""
        summary_items = []
        if passed:
            summary_items.append(f'{passed} passed')
        if failed:
            summary_items.append(f'{failed} failed')
        if skipped:
            summary_items.append(f'{skipped} not run')
        summary = ', '.join(summary_items) or 'nothing was done'

        if failed or skipped:
            result = PresubmitResult.FAIL
        else:
            result = PresubmitResult.PASS
        total = passed + failed + skipped

        _LOG.debug(
            'Finished running %d checks on %s in %.1f s',
            total,
            plural(self._paths, 'file'),
            time_s,
        )
        _LOG.debug('Presubmit checks %s: %s', result.value, summary)

        _print_ui(
            _box(
                _SUMMARY_BOX,
                result.colorized(_LEFT, invert=True),
                f'{total} checks on {plural(self._paths, "file")}: {summary}',
                _format_time(time_s),
            )
        )

    def _create_presubmit_context(  # pylint: disable=no-self-use
        self, **kwargs
    ):
        """Create a PresubmitContext. Override if needed in subclasses."""
        return PresubmitContext(**kwargs)

    @contextlib.contextmanager
    def _context(self, filtered_check: FilteredCheck, dry_run: bool = False):
        """Yields a context for one check, logging to its step.log meanwhile.

        Creates the check's output directory, removes any stale
        failure-summary.log, and attaches a DEBUG-level file handler to _LOG
        for the duration of the check.
        """
        # There are many characters banned from filenames on Windows. To
        # simplify things, just strip everything that's not a letter, digit,
        # or underscore.
        sanitized_name = re.sub(r'[\W_]+', '_', filtered_check.name).lower()
        output_directory = self._output_directory.joinpath(sanitized_name)
        os.makedirs(output_directory, exist_ok=True)

        failure_summary_log = output_directory / 'failure-summary.log'
        failure_summary_log.unlink(missing_ok=True)

        handler = logging.FileHandler(
            output_directory.joinpath('step.log'), mode='w'
        )
        handler.setLevel(logging.DEBUG)

        try:
            _LOG.addHandler(handler)

            yield self._create_presubmit_context(
                root=self._root,
                repos=self._repos,
                output_dir=output_directory,
                failure_summary_log=failure_summary_log,
                paths=filtered_check.paths,
                all_paths=self._all_paths,
                package_root=self._package_root,
                override_gn_args=self._override_gn_args,
                continue_after_build_error=self._continue_after_build_error,
                rng_seed=self._rng_seed,
                full=self._full,
                luci=LuciContext.create_from_environment(),
                format_options=FormatOptions.load(),
                dry_run=dry_run,
            )

        finally:
            _LOG.removeHandler(handler)
            # removeHandler() does not close the handler; close it so the
            # step.log file descriptor is not leaked once per check.
            handler.close()

    def _execute_checks(
        self,
        program: list[FilteredCheck],
        keep_going: bool,
        dry_run: bool = False,
    ) -> tuple[int, int, int]:
        """Runs presubmit checks; returns (passed, failed, skipped) counts.

        A CANCEL result stops the run. A failure stops it unless
        ``keep_going`` is set. Every check that did not pass or fail
        (including a cancelled one) is counted as skipped.
        """
        passed = failed = 0

        for i, filtered_check in enumerate(program, 1):
            with self._context(filtered_check, dry_run) as ctx:
                result = filtered_check.run(ctx, i, len(program))

            if result is PresubmitResult.PASS:
                passed += 1
            elif result is PresubmitResult.CANCEL:
                break
            else:
                failed += 1
                if not keep_going:
                    break

        return passed, failed, len(program) - passed - failed
523
524
525def _process_pathspecs(
526    repos: Iterable[Path], pathspecs: Iterable[str]
527) -> dict[Path, list[str]]:
528    pathspecs_by_repo: dict[Path, list[str]] = {repo: [] for repo in repos}
529    repos_with_paths: Set[Path] = set()
530
531    for pathspec in pathspecs:
532        # If the pathspec is a path to an existing file, only use it for the
533        # repo it is in.
534        if os.path.exists(pathspec):
535            # Raise an exception if the path exists but is not in a known repo.
536            repo = git_repo.within_repo(pathspec)
537            if repo not in pathspecs_by_repo:
538                raise ValueError(
539                    f'{pathspec} is not in a Git repository in this presubmit'
540                )
541
542            # Make the path relative to the repo's root.
543            pathspecs_by_repo[repo].append(os.path.relpath(pathspec, repo))
544            repos_with_paths.add(repo)
545        else:
546            # Pathspecs that are not paths (e.g. '*.h') are used for all repos.
547            for patterns in pathspecs_by_repo.values():
548                patterns.append(pathspec)
549
550    # If any paths were specified, only search for paths in those repos.
551    if repos_with_paths:
552        for repo in set(pathspecs_by_repo) - repos_with_paths:
553            del pathspecs_by_repo[repo]
554
555    return pathspecs_by_repo
556
557
def fetch_file_lists(
    root: Path,
    repo: Path,
    pathspecs: list[str],
    exclude: Sequence[Pattern] = (),
    base: str | None = None,
) -> tuple[list[Path], list[Path]]:
    """Returns (all files, modified files) for one Git repository.

    Args:
        root: Root path of the project; passed to tools.exclude_paths and
            git_repo.describe_files.
        repo: Path to the root of the Git repository to check.
        pathspecs: Git pathspecs restricting which files are listed.
        exclude: Regular expressions for POSIX-style paths to exclude.
        base: Optional base Git commit to list modified files against. When
            None, every listed file counts as modified.

    Returns:
        A tuple of (all files, modified files), both after exclusion.
    """
    every_file = list(
        tools.exclude_paths(
            exclude, git_repo.list_files(None, pathspecs, repo), root
        )
    )

    if base is None:
        changed = list(every_file)
    else:
        changed = list(
            tools.exclude_paths(
                exclude, git_repo.list_files(base, pathspecs, repo), root
            )
        )

    _LOG.info(
        'Checking %s',
        git_repo.describe_files(repo, repo, base, pathspecs, exclude, root),
    )

    return every_file, changed
598
599
def run(  # pylint: disable=too-many-arguments,too-many-locals
    program: Sequence[Check],
    root: Path,
    repos: Collection[Path] = (),
    base: str | None = None,
    paths: Sequence[str] = (),
    exclude: Sequence[Pattern] = (),
    output_directory: Path | None = None,
    package_root: Path | None = None,
    only_list_steps: bool = False,
    override_gn_args: Sequence[tuple[str, str]] = (),
    keep_going: bool = False,
    continue_after_build_error: bool = False,
    rng_seed: int = 1,
    presubmit_class: type = Presubmit,
    list_steps_file: Path | None = None,
    substep: str | None = None,
    dry_run: bool = False,
) -> bool:
    """Lists files in the given Git repos and runs a Presubmit with them.

    The paths argument contains Git pathspecs. If no pathspecs are provided,
    all paths in all repos are included. If paths to files or directories
    are provided, only the repositories containing them are searched.
    Patterns are searched across all repositories. For example, given the
    pathspecs "my_module/" and "*.h", paths under "my_module/" in the
    containing repo and paths in all repos matching "*.h" are included in
    the presubmit.

    Args:
        program: list of presubmit check functions to run
        root: root path of the project
        repos: paths to the roots of Git repositories to check; empty
            directories are skipped
        base: optional base Git commit to list files against
        paths: optional list of Git pathspecs to run the checks against
        exclude: regular expressions for Posix-style paths to exclude
        output_directory: where to place output files (default:
            root / '.presubmit')
        package_root: where to place package files (default:
            output_directory / 'packages')
        only_list_steps: print the steps as JSON on stdout instead of
            running them
        override_gn_args: additional GN args to set on steps
        keep_going: continue running presubmit steps after a step fails
        continue_after_build_error: continue building if a build step fails
        rng_seed: seed for a random number generator, for the few steps that
            need one
        presubmit_class: class to use to run Presubmits, should inherit from
            the Presubmit class
        list_steps_file: File created by --only-list-steps, used to keep from
            recalculating affected files.
        substep: run only part of a single check
        dry_run: passed through to each check's PresubmitContext

    Returns:
        True if all presubmit checks succeeded (always True when
        only_list_steps is set)

    Raises:
        ValueError: if a non-empty repo is not the root of a Git repo, or a
            path in ``paths`` is not inside one of the repos
    """
    repos = [repo.resolve() for repo in repos]

    non_empty_repos = []
    for repo in repos:
        # Skip empty directories; any() stops at the first entry instead of
        # listing the whole directory.
        if any(repo.iterdir()):
            non_empty_repos.append(repo)
            if git_repo.root(repo) != repo:
                raise ValueError(
                    f'{repo} is not the root of a Git repo; '
                    'presubmit checks must be run from a Git repo'
                )
    repos = non_empty_repos

    pathspecs_by_repo = _process_pathspecs(repos, paths)

    all_files: list[Path] = []
    modified_files: list[Path] = []
    list_steps_data: dict[str, Any] = {}

    if list_steps_file:
        with list_steps_file.open() as ins:
            list_steps_data = json.load(ins)
        # The file stores paths as strings. Convert them back to Path so
        # all_files has the same element type as in the Git-listed branch.
        all_files.extend(Path(x) for x in list_steps_data['all_files'])
        for step in list_steps_data['steps']:
            modified_files.extend(Path(x) for x in step.get("paths", ()))
        modified_files = sorted(set(modified_files))
        _LOG.info(
            'Loaded %d paths from file %s',
            len(modified_files),
            list_steps_file,
        )

    else:
        for repo, pathspecs in pathspecs_by_repo.items():
            new_all_files_items, new_modified_file_items = fetch_file_lists(
                root, repo, pathspecs, exclude, base
            )
            all_files.extend(new_all_files_items)
            modified_files.extend(new_modified_file_items)

    if output_directory is None:
        output_directory = root / '.presubmit'

    if package_root is None:
        package_root = output_directory / 'packages'

    presubmit = presubmit_class(
        root=root,
        repos=repos,
        output_directory=output_directory,
        paths=modified_files,
        all_paths=all_files,
        package_root=package_root,
        override_gn_args=dict(override_gn_args or {}),
        continue_after_build_error=continue_after_build_error,
        rng_seed=rng_seed,
        full=bool(base is None),
    )

    if only_list_steps:
        steps: list[dict] = []
        for filtered_check in presubmit.apply_filters(program):
            step = {
                'name': filtered_check.name,
                'paths': [str(x) for x in filtered_check.paths],
            }
            substeps = filtered_check.check.substeps()
            if len(substeps) > 1:
                step['substeps'] = [x.name for x in substeps]
            steps.append(step)

        list_steps_data = {
            'steps': steps,
            'all_files': [str(x) for x in all_files],
        }
        json.dump(list_steps_data, sys.stdout, indent=2)
        sys.stdout.write('\n')
        return True

    if not isinstance(program, Program):
        program = Program('', program)

    return presubmit.run(program, keep_going, substep=substep, dry_run=dry_run)
739
740
741def _make_str_tuple(value: Iterable[str] | str) -> tuple[str, ...]:
742    return tuple([value] if isinstance(value, str) else value)
743
744
def check(*args, **kwargs):
    """Turns a function into a presubmit Check; usable with or without args.

    Args:
        *args: Positional arguments passed to the Check constructor after the
            function (decorator-factory form).
        **kwargs: Keyword arguments passed to the Check constructor.

    When called with exactly one plain function and no keyword arguments,
    this acts as a bare decorator and returns a Check for that function:

    @check
    def pragma_once(ctx: PresubmitContext):
        pass

    Otherwise the arguments are saved, and a decorator is returned that
    turns a function into a Check with those arguments added to the Check
    constructor call:

    @check(name='pragma_twice')
    def pragma_once(ctx: PresubmitContext):
        pass
    """
    used_bare = (
        not kwargs
        and len(args) == 1
        and isinstance(args[0], types.FunctionType)
    )
    if used_bare:
        return Check(args[0])

    def _wrap(check_function):
        return Check(check_function, *args, **kwargs)

    return _wrap
780
781
@dataclasses.dataclass
class SubStep:
    """One piece of a Check: a callable plus the extra arguments it needs.

    Attributes:
        name: If set, logged at INFO level before the callable runs.
        _func: Called as ``_func(ctx, *args, **kwargs)``.
        args: Extra positional arguments for ``_func``.
        kwargs: Extra keyword arguments for ``_func``.
    """

    name: str | None
    _func: Callable[..., PresubmitResult]
    args: Sequence[Any] = ()
    kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)

    def __call__(self, ctx: PresubmitContext) -> PresubmitResult:
        if self.name:
            _LOG.info('%s', self.name)
        func = self._func
        return func(ctx, *self.args, **self.kwargs)
793
794
795class Check:
796    """Wraps a presubmit check function.
797
798    This class consolidates the logic for running and logging a presubmit check.
799    It also supports filtering the paths passed to the presubmit check.
800    """
801
802    def __init__(
803        self,
804        check: (  # pylint: disable=redefined-outer-name
805            Callable | Iterable[SubStep]
806        ),
807        path_filter: FileFilter = FileFilter(),
808        always_run: bool = True,
809        name: str | None = None,
810        doc: str | None = None,
811    ) -> None:
812        # Since Check wraps a presubmit function, adopt that function's name.
813        self.name: str = ''
814        self.doc: str = ''
815        if isinstance(check, Check):
816            self.name = check.name
817            self.doc = check.doc
818        elif callable(check):
819            self.name = check.__name__
820            self.doc = check.__doc__ or ''
821
822        if name:
823            self.name = name
824        if doc:
825            self.doc = doc
826
827        if not self.name:
828            raise ValueError('no name for step')
829
830        self._substeps_raw: Iterable[SubStep]
831        if isinstance(check, collections.abc.Iterator):
832            self._substeps_raw = check
833        else:
834            assert callable(check)
835            _ensure_is_valid_presubmit_check_function(check)
836            self._substeps_raw = iter((SubStep(None, check),))
837        self._substeps_saved: Sequence[SubStep] = ()
838
839        self.filter = path_filter
840        self.always_run: bool = always_run
841
842        self._is_presubmit_check_object = True
843
844    def substeps(self) -> Sequence[SubStep]:
845        """Return the SubSteps of the current step.
846
847        This is where the list of SubSteps is actually evaluated. It can't be
848        evaluated in the constructor because the Iterable passed into the
849        constructor might not be ready yet.
850        """
851        if not self._substeps_saved:
852            self._substeps_saved = tuple(self._substeps_raw)
853        return self._substeps_saved
854
855    def __repr__(self):
856        # This returns just the name so it's easy to show the entire list of
857        # steps with '--help'.
858        return self.name
859
860    def unfiltered(self) -> Check:
861        """Create a new check identical to this one, but without the filter."""
862        clone = copy.copy(self)
863        clone.filter = FileFilter()
864        return clone
865
866    def with_filter(
867        self,
868        *,
869        endswith: Iterable[str] = (),
870        exclude: Iterable[Pattern[str] | str] = (),
871    ) -> Check:
872        """Create a new check identical to this one, but with extra filters.
873
874        Add to the existing filter, perhaps to exclude an additional directory.
875
876        Args:
877            endswith: Passed through to FileFilter.
878            exclude: Passed through to FileFilter.
879
880        Returns a new check.
881        """
882        return self.with_file_filter(
883            FileFilter(endswith=_make_str_tuple(endswith), exclude=exclude)
884        )
885
886    def with_file_filter(self, file_filter: FileFilter) -> Check:
887        """Create a new check identical to this one, but with extra filters.
888
889        Add to the existing filter, perhaps to exclude an additional directory.
890
891        Args:
892            file_filter: Additional filter rules.
893
894        Returns a new check.
895        """
896        clone = copy.copy(self)
897        if clone.filter:
898            clone.filter.exclude = clone.filter.exclude + file_filter.exclude
899            clone.filter.endswith = clone.filter.endswith + file_filter.endswith
900            clone.filter.name = file_filter.name or clone.filter.name
901            clone.filter.suffix = clone.filter.suffix + file_filter.suffix
902        else:
903            clone.filter = file_filter
904        return clone
905
906    def run(
907        self,
908        ctx: PresubmitContext,
909        count: int,
910        total: int,
911        substep: str | None = None,
912    ) -> PresubmitResult:
913        """Runs the presubmit check on the provided paths."""
914
915        _print_ui(
916            _box(
917                _CHECK_UPPER,
918                f'{count}/{total}',
919                self.name,
920                plural(ctx.paths, "file"),
921            )
922        )
923
924        substep_part = f'.{substep}' if substep else ''
925        _LOG.debug(
926            '[%d/%d] Running %s%s on %s',
927            count,
928            total,
929            self.name,
930            substep_part,
931            plural(ctx.paths, "file"),
932        )
933
934        start_time_s = time.time()
935        result: PresubmitResult
936        if substep:
937            result = self.run_substep(ctx, substep)
938        else:
939            result = self(ctx)
940        time_str = _format_time(time.time() - start_time_s)
941        _LOG.debug('%s %s', self.name, result.value)
942
943        if ctx.dry_run:
944            log_check_traces(ctx)
945
946        _print_ui(
947            _box(_CHECK_LOWER, result.colorized(_LEFT), self.name, time_str)
948        )
949        _LOG.debug('%s duration:%s', self.name, time_str)
950
951        return result
952
953    def _try_call(
954        self,
955        func: Callable,
956        ctx,
957        *args,
958        **kwargs,
959    ) -> PresubmitResult:
960        try:
961            result = func(ctx, *args, **kwargs)
962            if ctx.failed:
963                return PresubmitResult.FAIL
964            if isinstance(result, PresubmitResult):
965                return result
966            return PresubmitResult.PASS
967
968        except PresubmitFailure as failure:
969            if str(failure):
970                _LOG.warning('%s', failure)
971            return PresubmitResult.FAIL
972
973        except Exception as _failure:  # pylint: disable=broad-except
974            _LOG.exception('Presubmit check %s failed!', self.name)
975            return PresubmitResult.FAIL
976
977        except KeyboardInterrupt:
978            _print_ui()
979            return PresubmitResult.CANCEL
980
981    def run_substep(
982        self, ctx: PresubmitContext, name: str | None
983    ) -> PresubmitResult:
984        for substep in self.substeps():
985            if substep.name == name:
986                return substep(ctx)
987
988        expected = ', '.join(repr(s.name) for s in self.substeps())
989        raise LookupError(f'bad substep name: {name!r} (expected: {expected})')
990
991    def __call__(self, ctx: PresubmitContext) -> PresubmitResult:
992        """Calling a Check calls its underlying substeps directly.
993
994        This makes it possible to call functions wrapped by @filter_paths. The
995        prior filters are ignored, so new filters may be applied.
996        """
997        result: PresubmitResult
998        for substep in self.substeps():
999            result = self._try_call(substep, ctx)
1000            if result and result != PresubmitResult.PASS:
1001                return result
1002        return PresubmitResult.PASS
1003
1004
def _required_args(function: Callable) -> Iterable[Parameter]:
    """Yield the parameters of ``function`` that have no default value.

    ``*args`` and ``**kwargs`` catch-alls are never considered required.
    Keyword-only parameters without a default are yielded. This is a
    generator, so inspect.signature errors surface on first iteration.
    """
    variadic_kinds = (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD)
    yield from (
        param
        for param in signature(function).parameters.values()
        if param.kind not in variadic_kinds and param.default is param.empty
    )
1012
1013
def _ensure_is_valid_presubmit_check_function(chk: Callable) -> None:
    """Checks if a Callable can be used as a presubmit check.

    A valid check has exactly one required argument (the PresubmitContext).
    Any other parameters must have defaults or be ``*args``/``**kwargs``.

    Args:
        chk: The candidate presubmit check.

    Raises:
        TypeError: if no signature can be obtained for ``chk`` (e.g. it is not
            callable), or if it does not have exactly one required argument.
    """
    try:
        required_args = tuple(_required_args(chk))
    except (TypeError, ValueError) as err:
        raise TypeError(
            'Presubmit checks must be callable, but '
            f'{chk!r} is a {type(chk).__name__}'
        ) from err

    if len(required_args) != 1:
        # Callables such as functools.partial objects or instances defining
        # __call__ have no __name__. Fall back to their repr so the intended
        # TypeError is not masked by an AttributeError.
        chk_name = getattr(chk, '__name__', repr(chk))
        arg_list = ', '.join(a.name for a in required_args)
        raise TypeError(
            'Presubmit check functions must have exactly one required '
            'positional argument (the PresubmitContext), but '
            f'{chk_name} has {len(required_args)} required arguments'
            + (f' ({arg_list})' if required_args else '')
        )
1035
1036
def filter_paths(
    *,
    endswith: Iterable[str] = (),
    exclude: Iterable[Pattern[str] | str] = (),
    file_filter: FileFilter | None = None,
    always_run: bool = False,
) -> Callable[[Callable], Check]:
    """Decorator that attaches a path filter to a presubmit check function.

    The filter is only applied when the decorated function runs as a
    presubmit check. Calling the resulting Check directly ignores it. This
    lets functions wrapped in @filter_paths be reused by other presubmit
    checks, possibly with different path filtering rules.

    Args:
        endswith: str or iterable of path endings to include
        exclude: regular expressions of paths to exclude
        file_filter: FileFilter used to select files; may not be combined
            with endswith/exclude
        always_run: Run check even when no files match
    Returns:
        a decorator that wraps a presubmit function in a Check
    Raises:
        ValueError: if file_filter is given together with endswith/exclude
    """
    if not file_filter:
        # TODO: b/238426363 - Remove these arguments and use FileFilter only.
        selected_filter = FileFilter(
            endswith=_make_str_tuple(endswith), exclude=exclude
        )
    elif endswith or exclude:
        raise ValueError(
            'Must specify either file_filter or '
            'endswith/exclude args, not both'
        )
    else:
        selected_filter = file_filter

    def filter_paths_for_function(function: Callable):
        return Check(function, selected_filter, always_run=always_run)

    return filter_paths_for_function
1077
1078
def call(
    *args, call_annotation: dict[Any, Any] | None = None, **kwargs
) -> None:
    """Optional subprocess wrapper that causes a PresubmitFailure on errors.

    If a presubmit context is active, the command is first recorded on it via
    ``append_check_command``. In dry-run mode the command is then not run.

    Args:
        *args: The command; passed as the argument sequence to
            subprocess.Popen.
        call_annotation: Extra metadata recorded with the command on the
            active presubmit context (an empty dict if not given).
        **kwargs: Passed through to subprocess.Popen, which by default pipes
            stdout and merges stderr into it. stdout must remain a pipe. Two
            extra keys are consumed here:
            tee: optional writable text stream that also receives the
                command's decoded output.
            propagate_sigterm: if True, a SIGTERM received while the command
                runs terminates the subprocess, and this process then exits
                with status 1.

    Raises:
        PresubmitFailure: if the command exits with a nonzero return code.
    """
    ctx = PRESUBMIT_CONTEXT.get()
    if ctx:
        # Save the subprocess command args for pw build presubmit runner.
        call_annotation = call_annotation if call_annotation else {}
        ctx.append_check_command(
            *args, call_annotation=call_annotation, **kwargs
        )
        # Return without running if dry-run mode is on.
        if ctx.dry_run:
            return

    attributes, command = tools.format_command(args, kwargs)
    _LOG.debug('[RUN] %s\n%s', attributes, command)

    tee = kwargs.pop('tee', None)
    propagate_sigterm = kwargs.pop('propagate_sigterm', False)

    env = pw_cli.env.pigweed_environment()
    kwargs.setdefault('stdout', subprocess.PIPE)
    kwargs.setdefault('stderr', subprocess.STDOUT)

    process = subprocess.Popen(args, **kwargs)
    assert process.stdout

    # Set up signal handler if requested.
    signaled = False
    previous_signal_handler: Any = None
    if propagate_sigterm:

        def signal_handler(_signal_number: int, _stack_frame: Any) -> None:
            nonlocal signaled
            signaled = True
            process.terminate()

        previous_signal_handler = signal.signal(signal.SIGTERM, signal_handler)

    # Everything below runs under our SIGTERM handler. Restore the previous
    # handler even if reading, waiting or logging raises (e.g. on Ctrl-C), so
    # a stale handler does not leak into the rest of the process.
    try:
        if env.PW_PRESUBMIT_DISABLE_SUBPROCESS_CAPTURE:
            # Stream output live instead of only reporting it at the end.
            while True:
                line = process.stdout.readline().decode(
                    errors='backslashreplace'
                )
                if not line:
                    break
                _LOG.info(line.rstrip())
                if tee:
                    tee.write(line)

        stdout, _ = process.communicate()
        if tee:
            tee.write(stdout.decode(errors='backslashreplace'))

        logfunc = _LOG.warning if process.returncode else _LOG.debug
        logfunc('[FINISHED]\n%s', command)
        logfunc(
            '[RESULT] %s with return code %d',
            'Failed' if process.returncode else 'Passed',
            process.returncode,
        )
        if stdout:
            logfunc('[OUTPUT]\n%s', stdout.decode(errors='backslashreplace'))
    finally:
        if propagate_sigterm:
            # signal.signal() returns None when the previous handler was not
            # installed from Python, and None cannot be passed back to it.
            signal.signal(
                signal.SIGTERM,
                previous_signal_handler
                if previous_signal_handler is not None
                else signal.SIG_DFL,
            )

    if signaled:
        _LOG.warning('Exiting due to SIGTERM.')
        sys.exit(1)

    if process.returncode:
        raise PresubmitFailure
1149
1150
def install_package(
    ctx: FormatContext | PresubmitContext,
    name: str,
    force: bool = False,
) -> None:
    """Install the named pw_package into ``ctx.package_root``.

    The equivalent ``pw package install <name>`` command is always recorded
    on ``ctx``. In dry-run mode nothing further happens.

    Args:
        ctx: Context providing ``package_root`` and ``dry_run``.
        name: Package to install.
        force: Install even when the manager's ``status(name)`` is truthy
            (presumably meaning "already installed"). Also forwarded to
            ``install``.

    Raises:
        PresubmitFailure: if the package manager has no packages configured.
    """
    manager = package_manager.PackageManager(ctx.package_root)

    ctx.append_check_command(
        'pw',
        'package',
        'install',
        name,
        call_annotation={'pw_package_install': name},
    )
    if ctx.dry_run:
        return

    if not manager.list():
        raise PresubmitFailure(
            'no packages configured, please import your pw_package '
            'configuration module'
        )

    # Guard clause equivalent of "install if not status(name) or force".
    if manager.status(name) and not force:
        return
    manager.install(name, force=force)
1178