# Copyright 2022 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Keep specified lists sorted."""
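# This module provides both the `keep_sorted` presubmit check
# (presubmit_check() below) and the standalone checker behind the
# `pw keep-sorted --fix` command suggested by _print_howto_fix()
# (keep_sorted_in_repo() and main()).
#
# Blocks to sort are delimited by the start/end markers matched by _START and
# _END below. The start marker may also carry ignore-case, allow-dupes,
# ignore-prefix=, and sticky-comments= directives, which are parsed in
# _FileSorter._parse_file(). The literal marker strings are deliberately not
# spelled out in comments here because this file is itself checked by
# keep_sorted.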

import argparse
import dataclasses
import difflib
import logging
import os
from pathlib import Path
import re
import sys
from typing import (
    Callable,
    Collection,
    Pattern,
    Sequence,
)

import pw_cli
from pw_cli.diff import colorize_diff
from pw_cli.plural import plural
from . import cli, git_repo, presubmit, presubmit_context, tools

DEFAULT_PATH = Path('out', 'presubmit', 'keep_sorted')

_LOG: logging.Logger = logging.getLogger(__name__)

# Ignore a whole section. Please do not change the order of these lines.
_START = re.compile(r'keep-sorted: (begin|start)', re.IGNORECASE)
_END = re.compile(r'keep-sorted: (stop|end)', re.IGNORECASE)
_IGNORE_CASE = re.compile(r'ignore-case', re.IGNORECASE)
_ALLOW_DUPES = re.compile(r'allow-dupes', re.IGNORECASE)
_IGNORE_PREFIX = re.compile(r'ignore-prefix=(\S+)', re.IGNORECASE)
_STICKY_COMMENTS = re.compile(r'sticky-comments=(\S+)', re.IGNORECASE)

# Only include these literals here so keep_sorted doesn't try to reorder later
# lines in this file.
(
    START,
    END,
) = """
keep-sorted: start
keep-sorted: end
""".strip().splitlines()


@dataclasses.dataclass
class KeepSortedContext:
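    """State for a standalone (non-presubmit) keep-sorted run.

    Tracks the files to check, whether fixes are applied in place, where
    output and the failure summary log are written, and whether any check has
    failed so far.
    """
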
    paths: list[Path]
    fix: bool
    output_dir: Path
    failure_summary_log: Path
    failed: bool = False

    def fail(
        self,
        description: str = '',
        path: Path | None = None,
        line: int | None = None,
    ) -> None:
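        """Logs a keep-sorted problem and records the failure.

        When fixes are being applied (self.fix is True), the problem is only
        logged as a warning and does not mark the run as failed.
        """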
        if not self.fix:
            self.failed = True

        line_part: str = ''
        if line is not None:
            line_part = f'{line}:'

        log = _LOG.error
        if self.fix:
            log = _LOG.warning

        if path:
            log('%s:%s %s', path, line_part, description)
        else:
            log('%s', description)


class KeepSortedParsingError(presubmit.PresubmitFailure):
    pass


@dataclasses.dataclass
class _Line:
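    """One logical line within a keep-sorted block.

    `value` is the line itself, `sticky_comments` holds comment lines that
    stay attached above it, and `continuations` holds more-indented lines
    that move together with it when the block is sorted.
    """
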
    value: str = ''
    sticky_comments: Sequence[str] = ()
    continuations: Sequence[str] = ()

    @property
    def full(self):
        return ''.join((*self.sticky_comments, self.value, *self.continuations))

    def __lt__(self, other):
        if not isinstance(other, _Line):
            return NotImplemented
        left = (self.value, self.continuations, self.sticky_comments)
        right = (other.value, other.continuations, other.sticky_comments)
        return left < right


@dataclasses.dataclass
class _Block:
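    """One keep-sorted block plus the options parsed from its start line.

    The options correspond to the ignore-case, allow-dupes, ignore-prefix=,
    and sticky-comments= directives.
    """
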
    ignore_case: bool = False
    allow_dupes: bool = False
    ignored_prefixes: Sequence[str] = dataclasses.field(default_factory=list)
    sticky_comments: tuple[str, ...] = ()
    start_line_number: int = -1
    start_line: str = ''
    end_line: str = ''
    lines: list[str] = dataclasses.field(default_factory=list)


class _FileSorter:
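    """Sorts the keep-sorted blocks of a single file.

    The (possibly fixed) file contents are accumulated in `all_lines` so they
    can be written back with write(). A diff hunk for every unsorted block is
    recorded in `self._errors` (keyed by path), which is shared with the
    caller when an `errors` dict is passed in.
    """
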
    def __init__(
        self,
        ctx: presubmit.PresubmitContext | KeepSortedContext,
        path: Path,
        errors: dict[Path, Sequence[str]] | None = None,
    ):
        self.ctx = ctx
        self.path: Path = path
        self.all_lines: list[str] = []
        self.changed: bool = False
        self._errors: dict[Path, Sequence[str]] = {}
        if errors is not None:
            self._errors = errors

    def _process_block(self, block: _Block) -> Sequence[str]:
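        """Sorts one block's lines and returns them in raw (string) form.

        Raw lines are grouped into _Line objects (attaching sticky comments
        and indented continuation lines), de-duplicated unless allow-dupes is
        set, then sorted honoring ignore-case and any ignored prefixes. If
        the result differs from the input, a diff hunk is recorded for the
        file and self.changed is set.
        """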
        raw_lines: list[str] = block.lines
        lines: list[_Line] = []

        def prefix(x):
            return len(x) - len(x.lstrip())

        prev_prefix: int | None = None
        comments: list[str] = []
        for raw_line in raw_lines:
            curr_prefix: int = prefix(raw_line)
            _LOG.debug('prev_prefix %r', prev_prefix)
            _LOG.debug('curr_prefix %r', curr_prefix)
            # A "sticky" comment is a comment in the middle of a list of
            # non-comments. The keep-sorted check keeps this comment with the
            # following item in the list. For more details see
            # https://pigweed.dev/pw_presubmit/#sorted-blocks.
            if block.sticky_comments and raw_line.lstrip().startswith(
                block.sticky_comments
            ):
                _LOG.debug('found sticky %r', raw_line)
                comments.append(raw_line)
            elif prev_prefix is not None and curr_prefix > prev_prefix:
                _LOG.debug('found continuation %r', raw_line)
                lines[-1].continuations = (*lines[-1].continuations, raw_line)
                _LOG.debug('modified line %s', lines[-1])
            else:
                _LOG.debug('non-sticky %r', raw_line)
                line = _Line(raw_line, tuple(comments))
                _LOG.debug('line %s', line)
                lines.append(line)
                comments = []
                prev_prefix = curr_prefix
        if comments:
            self.ctx.fail(
                f'sticky comment at end of block: {comments[0].strip()}',
                self.path,
                block.start_line_number,
            )

        if not block.allow_dupes:
            lines = list({x.full: x for x in lines}.values())

        StrLinePair = tuple[str, _Line]  # pylint: disable=invalid-name
        sort_key_funcs: list[Callable[[StrLinePair], StrLinePair]] = []

        if block.ignored_prefixes:

            def strip_ignored_prefixes(val):
                """Remove one ignored prefix from val, if present."""
                wo_white = val[0].lstrip()
                white = val[0][0 : -len(wo_white)]
                for prefix in block.ignored_prefixes:
                    if wo_white.startswith(prefix):
                        return (f'{white}{wo_white[len(prefix):]}', val[1])
                return (val[0], val[1])

            sort_key_funcs.append(strip_ignored_prefixes)

        if block.ignore_case:
            sort_key_funcs.append(lambda val: (val[0].lower(), val[1]))

        def sort_key(line):
            vals = (line.value, line)
            for sort_key_func in sort_key_funcs:
                vals = sort_key_func(vals)
            return vals

        for val in lines:
            _LOG.debug('For sorting: %r => %r', val, sort_key(val))

        sorted_lines = sorted(lines, key=sort_key)
        raw_sorted_lines: list[str] = []
        for line in sorted_lines:
            raw_sorted_lines.extend(line.sticky_comments)
            raw_sorted_lines.append(line.value)
            raw_sorted_lines.extend(line.continuations)

        if block.lines != raw_sorted_lines:
            self.changed = True
            diff = difflib.Differ()
            diff_lines = ''.join(diff.compare(block.lines, raw_sorted_lines))

            # A file may contain multiple unsorted blocks, so append this
            # block's diff hunk rather than overwriting earlier ones.
            self._errors.setdefault(self.path, [])
            self._errors[self.path].append(
                f'@@ {block.start_line_number},{len(block.lines)+2} '
                f'{block.start_line_number},{len(raw_sorted_lines)+2} @@\n'
                f'  {block.start_line}{diff_lines}  {block.end_line}'
            )

        return raw_sorted_lines

    def _parse_file(self, ins):
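        """Reads `ins` line by line and sorts every keep-sorted block found.

        Lines outside of blocks are copied through unchanged; block contents
        are replaced with the output of _process_block(). Nested, stray, or
        unterminated markers raise KeepSortedParsingError.
        """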
        block: _Block | None = None

        for i, line in enumerate(ins, start=1):
            if block:
                if _START.search(line):
                    raise KeepSortedParsingError(
                        f'found {line.strip()!r} inside keep-sorted block',
                        self.path,
                        i,
                    )

                if _END.search(line):
                    _LOG.debug('Found end line %d %r', i, line)
                    block.end_line = line
                    self.all_lines.extend(self._process_block(block))
                    block = None
                    self.all_lines.append(line)

                else:
                    _LOG.debug('Adding to block line %d %r', i, line)
                    block.lines.append(line)

            elif start_match := _START.search(line):
                _LOG.debug('Found start line %d %r', i, line)

                block = _Block()

                block.ignore_case = bool(_IGNORE_CASE.search(line))
                _LOG.debug('ignore_case: %s', block.ignore_case)

                block.allow_dupes = bool(_ALLOW_DUPES.search(line))
                _LOG.debug('allow_dupes: %s', block.allow_dupes)

                match = _IGNORE_PREFIX.search(line)
                if match:
                    block.ignored_prefixes = match.group(1).split(',')

                    # We want to check the longest prefixes first, in case one
                    # prefix is a prefix of another prefix.
                    block.ignored_prefixes.sort(key=lambda x: (-len(x), x))
                _LOG.debug('ignored_prefixes: %r', block.ignored_prefixes)

                match = _STICKY_COMMENTS.search(line)
                if match:
                    if match.group(1) == 'no':
                        block.sticky_comments = ()
                    else:
                        block.sticky_comments = tuple(match.group(1).split(','))
                else:
                    prefix = line[: start_match.start()].strip()
                    if prefix and len(prefix) <= 3:
                        block.sticky_comments = (prefix,)
                _LOG.debug('sticky_comments: %s', block.sticky_comments)

                block.start_line = line
                block.start_line_number = i
                self.all_lines.append(line)

                remaining = line[start_match.end() :].strip()
                remaining = _IGNORE_CASE.sub('', remaining, count=1).strip()
                remaining = _ALLOW_DUPES.sub('', remaining, count=1).strip()
                remaining = _IGNORE_PREFIX.sub('', remaining, count=1).strip()
                remaining = _STICKY_COMMENTS.sub('', remaining, count=1).strip()
                if remaining.strip():
                    raise KeepSortedParsingError(
                        f'unrecognized directive on keep-sorted line: '
                        f'{remaining}',
                        self.path,
                        i,
                    )

            elif _END.search(line):
                raise KeepSortedParsingError(
                    f'found {line.strip()!r} outside keep-sorted block',
                    self.path,
                    i,
                )

            else:
                self.all_lines.append(line)

        if block:
            raise KeepSortedParsingError(
                f'found EOF while looking for "{END}"', self.path
            )

    def sort(self) -> None:
        """Check for unsorted keep-sorted blocks."""
        _LOG.debug('Evaluating path %s', self.path)
        try:
            with self.path.open() as ins:
                _LOG.debug('Processing %s', self.path)
                self._parse_file(ins)

        except UnicodeDecodeError:
            # File is not text, like a gif.
            _LOG.debug('File %s is not a text file', self.path)

    def write(self, path: Path | None = None) -> None:
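        """Writes the fixed file contents.

        Writes to `path` if given, otherwise back to the original file. Does
        nothing if no blocks needed sorting.
        """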
        if not self.changed:
            return
        if not path:
            path = self.path
        with path.open('w') as outs:
            outs.writelines(self.all_lines)
            _LOG.info('Applied keep-sorted changes to %s', path)


def _print_howto_fix(paths: Sequence[Path]) -> None:
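    """Logs the `pw keep-sorted --fix` commands that would fix `paths`."""
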
    def path_relative_to_cwd(path):
        try:
            return Path(path).resolve().relative_to(Path.cwd().resolve())
        except ValueError:
            return Path(path).resolve()

    message = (
        f'  pw keep-sorted --fix {path_relative_to_cwd(path)}' for path in paths
    )
    _LOG.warning('To sort these blocks, run:\n\n%s\n', '\n'.join(message))


def _process_files(
    ctx: presubmit.PresubmitContext | KeepSortedContext,
) -> dict[Path, Sequence[str]]:
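    """Checks (and optionally fixes) keep-sorted blocks in ctx.paths.

    Symlinks and directories are skipped. When ctx.fix is set, changed files
    are rewritten in place. Diffs for all unsorted blocks are written to
    ctx.failure_summary_log, printed, and returned keyed by path.
    """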
    fix = getattr(ctx, 'fix', False)
    errors: dict[Path, Sequence[str]] = {}

    for path in ctx.paths:
        if path.is_symlink() or path.is_dir():
            continue

        try:
            sorter = _FileSorter(ctx, path, errors)

            sorter.sort()
            if sorter.changed:
                if fix:
                    sorter.write()

        except KeepSortedParsingError as exc:
            ctx.fail(str(exc))

    if not errors:
        return errors

    ctx.fail(f'Found {plural(errors, "file")} with keep-sorted errors:')

    with ctx.failure_summary_log.open('w') as outs:
        for path, diffs in errors.items():
            diff = ''.join(
                [
                    f'--- {path} (original)\n',
                    f'+++ {path} (sorted)\n',
                    *diffs,
                ]
            )

            outs.write(diff)
            print(colorize_diff(diff))

    return errors


@presubmit.check(name='keep_sorted')
def presubmit_check(ctx: presubmit.PresubmitContext) -> None:
    """Presubmit check that ensures specified lists remain sorted."""

    ctx.paths = presubmit_context.apply_exclusions(ctx)
    errors = _process_files(ctx)

    if errors:
        _print_howto_fix(list(errors.keys()))


def parse_args() -> argparse.Namespace:
    """Creates an argument parser and parses arguments."""

    parser = argparse.ArgumentParser(description=__doc__)
    cli.add_path_arguments(parser)
    parser.add_argument(
        '--fix', action='store_true', help='Apply fixes in place.'
    )

    parser.add_argument(
        '--output-directory',
        type=Path,
        help=f'Output directory (default: {"<repo root>" / DEFAULT_PATH})',
    )

    return parser.parse_args()


def keep_sorted_in_repo(
    paths: Collection[Path | str],
    fix: bool,
    exclude: Collection[Pattern[str]],
    base: str,
    output_directory: Path | None,
) -> int:
    """Checks or fixes keep-sorted blocks for files in a Git repo."""

    files = [Path(path).resolve() for path in paths if os.path.isfile(path)]
    repo = git_repo.root() if git_repo.is_repo() else None

    # Implement a graceful fallback in case the tracking branch isn't available.
    if base == git_repo.TRACKING_BRANCH_ALIAS and not git_repo.tracking_branch(
        repo
    ):
        _LOG.warning(
            'Failed to determine the tracking branch, using --base HEAD~1 '
            'instead of listing all files'
        )
        base = 'HEAD~1'

    # If this is a Git repo, list the original paths with git ls-files or diff.
    project_root = pw_cli.env.pigweed_environment().PW_PROJECT_ROOT
    if repo:
        _LOG.info(
            'Sorting %s',
            git_repo.describe_files(
                repo, Path.cwd(), base, paths, exclude, project_root
            ),
        )

        # Add files from Git and remove duplicates.
        files = sorted(
            set(tools.exclude_paths(exclude, git_repo.list_files(base, paths)))
            | set(files)
        )
    elif base:
        _LOG.critical(
            'A base commit may only be provided if running from a Git repo'
        )
        return 1

    outdir: Path
    if output_directory:
        outdir = output_directory
    elif repo:
        outdir = repo / DEFAULT_PATH
    else:
        outdir = project_root / DEFAULT_PATH

    ctx = KeepSortedContext(
        paths=files,
        fix=fix,
        output_dir=outdir,
        failure_summary_log=outdir / 'failure-summary.log',
    )
    errors = _process_files(ctx)

    if not fix and errors:
        _print_howto_fix(list(errors.keys()))

    return int(ctx.failed)


def main() -> int:
    return keep_sorted_in_repo(**vars(parse_args()))


if __name__ == '__main__':
    pw_cli.log.install(logging.INFO)
    sys.exit(main())
