xref: /aosp_15_r20/external/pigweed/pw_presubmit/py/pw_presubmit/install_hook.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
#!/usr/bin/env python3
# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
15"""Creates a Git hook that calls a script with certain arguments."""
16
17import argparse
18import logging
19import os
20from pathlib import Path
21import re
22import shlex
23import subprocess
24from typing import Sequence
25
26_LOG: logging.Logger = logging.getLogger(__name__)
27
28
29def git_repo_root(path: Path | str) -> Path:
30    return Path(
31        subprocess.run(
32            ['git', '-C', path, 'rev-parse', '--show-toplevel'],
33            check=True,
34            stdout=subprocess.PIPE,
35        )
36        .stdout.strip()
37        .decode()
38    )
39
40
41def _stdin_args_for_hook(hook) -> Sequence[str]:
42    """Gives stdin arguments for each hook.
43
44    See https://git-scm.com/docs/githooks for more information.
45    """
46    if hook == 'pre-push':
47        return (
48            'local_ref',
49            'local_object_name',
50            'remote_ref',
51            'remote_object_name',
52        )
53    if hook in ('pre-receive', 'post-receive', 'reference-transaction'):
54        return ('old_value', 'new_value', 'ref_name')
55    if hook == 'post-rewrite':
56        return ('old_object_name', 'new_object_name')
57    return ()
58
59
60def _replace_arg_in_hook(arg: str, unquoted_args: Sequence[str]) -> str:
61    if arg in unquoted_args:
62        return arg
63    return shlex.quote(arg)
64
65
def install_git_hook(
    hook: str,
    command: Sequence[Path | str],
    repository: Path | str = '.',
) -> None:
    """Installs a simple Git hook that executes the provided command.

    The hook is written as a ``#!/bin/sh`` script, overwriting any existing
    file at the hook path. For hooks that Git feeds values on standard input
    (see ``_stdin_args_for_hook``), the script first reads one stdin line into
    shell variables of those names; for every other hook no ``read`` line is
    emitted.

    Args:
      hook: Git hook to install, e.g. 'pre-push'.
      command: Command to execute as the hook. The command is executed from the
          root of the repo. Arguments are sanitised with `shlex.quote`, except
          for any arguments that are equal to f'${stdin_arg}' for some
          `stdin_arg` that matches a standard-input argument to the git hook.
      repository: Repository to install the hook in.

    Raises:
      ValueError: If `command` is empty, or if the repository's `.git` entry
          is a file whose contents are not a single `gitdir: <path>` line.
    """
    if not command:
        raise ValueError('The command cannot be empty!')

    root = git_repo_root(repository).resolve()
    git_entry = root.joinpath('.git')

    if git_entry.is_dir():
        hook_path = git_entry.joinpath('hooks', hook)
    else:  # This repo is probably a submodule with a .git file instead
        # NOTE(review): a linked worktree also has a .git file (pointing at
        # .git/worktrees/<name>), but Git reads hooks from the common git
        # directory, not from there; confirm whether worktrees need support
        # (e.g. via `git rev-parse --git-common-dir`).
        match = re.match('^gitdir: (.*)$', git_entry.read_text())
        if not match:
            raise ValueError('Unexpected format for .git file')

        hook_path = root.joinpath(match.group(1), 'hooks', hook).resolve()

    try:
        hook_path.parent.mkdir(exist_ok=True)
    except FileExistsError as exc:
        # Something that is not a directory occupies the hooks path; installing
        # is best-effort, so warn and give up rather than raising.
        _LOG.warning('Failed to install %s hook: %s', hook, exc)
        return

    hook_stdin_args = _stdin_args_for_hook(hook)

    # Only the `$name` placeholders for this hook's stdin values stay unquoted
    # so the shell expands them; every other argument is shell-quoted.
    unquoted_args = [f'${arg}' for arg in hook_stdin_args]
    args = (_replace_arg_in_hook(str(a), unquoted_args) for a in command[1:])

    command_str = ' '.join([shlex.quote(str(command[0])), *args])

    # The script is run by /bin/sh, so always write LF line endings (Windows
    # text-mode translation would leave a CR on every line and break it) and
    # UTF-8 so non-ASCII paths in the command do not depend on the locale.
    with hook_path.open('w', encoding='utf-8', newline='\n') as file:

        def line(*parts):
            return print(*parts, file=file)

        line('#!/bin/sh')
        line(f'# {hook} hook generated by {__file__}')
        line()
        line('# Unset Git environment variables, which are set when this is ')
        line('# run as a Git hook. These environment variables cause issues ')
        line('# when trying to run Git commands on other repos from a ')
        line('# submodule hook.')
        line('unset $(git rev-parse --local-env-vars)')
        line()
        # POSIX `read` requires at least one variable name (dash fails with
        # "read: arg count" otherwise), so skip it for hooks without stdin.
        if hook_stdin_args:
            line('# Read the stdin args for the hook, made available by git.')
            line('read ' + ' '.join(hook_stdin_args))
            line()
        line(command_str)

    hook_path.chmod(0o755)
    _LOG.info('Installed %s hook for `%s` at %s', hook, command_str, hook_path)
130
131
132def argument_parser(
133    parser: argparse.ArgumentParser | None = None,
134) -> argparse.ArgumentParser:
135    if parser is None:
136        parser = argparse.ArgumentParser(description=__doc__)
137
138    def path(arg: str) -> Path:
139        if not os.path.exists(arg):
140            raise argparse.ArgumentTypeError(f'"{arg}" is not a valid path')
141
142        return Path(arg)
143
144    parser.add_argument(
145        '-r',
146        '--repository',
147        default='.',
148        type=path,
149        help='Path to the repository in which to install the hook',
150    )
151    parser.add_argument(
152        '--hook', required=True, help='Which type of Git hook to create'
153    )
154    parser.add_argument(
155        'command', nargs='*', help='Command to run in the commit hook'
156    )
157
158    return parser
159
160
if __name__ == '__main__':
    # Log bare messages (no level/logger prefix) at INFO and above.
    logging.basicConfig(format='%(message)s', level=logging.INFO)
    _PARSED_ARGS = argument_parser().parse_args()
    install_git_hook(
        hook=_PARSED_ARGS.hook,
        command=_PARSED_ARGS.command,
        repository=_PARSED_ARGS.repository,
    )
164