# Copyright 2024 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Helpful commands for working with a Git repository."""

from datetime import datetime
import logging
from pathlib import Path
import re
import shlex
import subprocess
from typing import Collection, Iterable, Pattern

from pw_cli.plural import plural
from pw_cli.tool_runner import ToolRunner

_LOG = logging.getLogger(__name__)

TRACKING_BRANCH_ALIAS = '@{upstream}'
_TRACKING_BRANCH_ALIASES = TRACKING_BRANCH_ALIAS, '@{u}'
_NON_TRACKING_FALLBACK = 'HEAD~10'

# Extracts the Gerrit Change-Id value from anywhere in a commit message.
_CHANGE_ID_PATTERN = re.compile(r'Change-Id: (I[a-fA-F0-9]+)', re.MULTILINE)


class GitError(Exception):
    """A Git-raised exception.

    Attributes:
        returncode: The exit code git returned.
        stderr: The stripped stderr text git printed; empty if it printed
            nothing.
    """

    def __init__(
        self,
        args: Iterable[str],
        message: str,
        returncode: int,
        stderr: str = '',
    ) -> None:
        super().__init__(f'`git {shlex.join(args)}` failed: {message}')
        self.returncode = returncode
        self.stderr = stderr


class _GitTool:
    """Runs git commands rooted at a fixed working directory."""

    def __init__(self, tool_runner: ToolRunner, working_dir: Path) -> None:
        self._run_tool = tool_runner
        self._working_dir = working_dir

    def __call__(self, *args, **kwargs) -> str:
        """Runs ``git -C <working_dir> *args`` and returns stripped stdout.

        Extra keyword arguments are forwarded to the tool runner.

        Raises:
            GitError: git exited with a non-zero return code.
        """
        cmd = ('-C', str(self._working_dir), *args)
        proc = self._run_tool(tool='git', args=cmd, **kwargs)

        if proc.returncode != 0:
            stderr = proc.stderr.decode().strip() if proc.stderr else ''
            raise GitError(
                (str(s) for s in cmd),
                # Whitespace-only stderr is as uninformative as no stderr.
                stderr or '(no output)',
                proc.returncode,
                stderr=stderr,
            )

        return proc.stdout.decode().strip() if proc.stdout else ''


class GitRepo:
    """Represents a checked out Git repository that may be queried for info."""

    def __init__(self, root: Path, tool_runner: ToolRunner):
        self._root = root.resolve()
        self._git = _GitTool(tool_runner, self._root)

    def tracking_branch(
        self,
        fallback: str | None = None,
    ) -> str | None:
        """Returns the tracking branch of the current branch.

        Since most callers of this function can safely handle a return value of
        None, suppress exceptions and return ``fallback`` if there is no
        tracking branch.

        Args:
            fallback: Value returned when the upstream cannot be resolved.

        Returns:
            The remote tracking branch name, or ``fallback`` (``None`` by
            default) if there is none.
        """

        # This command should only error out if there's no upstream branch set.
        try:
            return self._git(
                'rev-parse',
                '--abbrev-ref',
                '--symbolic-full-name',
                TRACKING_BRANCH_ALIAS,
            )
        except GitError:
            return fallback

    def current_branch(self) -> str | None:
        """Returns the current branch, or None if it cannot be determined."""
        try:
            return self._git('rev-parse', '--abbrev-ref', 'HEAD')
        except GitError:
            return None

    def _existing_files(self, git_output: str) -> Iterable[Path]:
        """Yields absolute paths for each listed file that is a regular file.

        Modified submodules show up as directories in git's file listings and
        are skipped, as are paths that no longer exist on disk.
        """
        for file in git_output.splitlines():
            full_path = self._root / file
            if full_path.is_file():
                yield full_path

    def _ls_files(self, pathspecs: Collection[Path | str]) -> Iterable[Path]:
        """Returns results of git ls-files as absolute paths."""
        return self._existing_files(self._git('ls-files', '--', *pathspecs))

    def _diff_names(
        self, commit: str, pathspecs: Collection[Path | str]
    ) -> Iterable[Path]:
        """Returns paths of files changed since the specified commit.

        Deleted files are excluded (``--diff-filter=d``). All returned paths
        are absolute file paths.

        Raises:
            GitError: If git cannot diff against ``commit``.
        """
        return self._existing_files(
            self._git(
                'diff',
                '--name-only',
                '--diff-filter=d',
                commit,
                '--',
                *pathspecs,
            )
        )

    def list_files(
        self,
        commit: str | None = None,
        pathspecs: Collection[Path | str] = (),
    ) -> list[Path]:
        """Lists files modified since the specified commit.

        If ``commit`` is not found in the current repo, all files in the
        repository are listed. If ``commit`` is a tracking-branch alias and
        there is no tracking branch, ``HEAD~10`` is used instead.

        Arguments:
            commit: The Git hash to start from when listing modified files
            pathspecs: Git pathspecs to use when filtering results

        Returns:
            A sorted list of absolute paths.
        """

        if commit in _TRACKING_BRANCH_ALIASES:
            commit = self.tracking_branch(fallback=_NON_TRACKING_FALLBACK)

        if commit:
            try:
                # sorted() consumes the generator here, so any GitError from
                # the diff is raised inside this try block.
                return sorted(self._diff_names(commit, pathspecs))
            except GitError:
                _LOG.warning(
                    'Error comparing with base revision %s of %s, listing all '
                    'files instead of just changed files',
                    commit,
                    self._root,
                )

        return sorted(self._ls_files(pathspecs))

    def has_uncommitted_changes(self) -> bool:
        """Returns True if this Git repo has uncommitted changes in it.

        Note: This does not check for untracked files.

        Returns:
            True if the Git repo has uncommitted changes in it.

        Raises:
            GitError: If refreshing the index or diffing against HEAD fails
                for a reason other than uncommitted changes.
        """

        # Refresh the Git index so that the diff-index command will be accurate.
        # The `git update-index` command isn't reliable when run in parallel
        # with other processes that may touch files in the repo directory, so
        # retry a few times before giving up. The hallmark of this failure mode
        # is the lack of an error message on stderr, so if we see something
        # there we can assume it's some other issue and raise.
        retries = 6
        for attempt in range(retries):
            try:
                self._git(
                    'update-index',
                    '-q',
                    '--refresh',
                    pw_presubmit_ignore_dry_run=True,
                )
            # _GitTool raises GitError; CalledProcessError is kept in case the
            # tool runner itself raises on failure.
            except (GitError, subprocess.CalledProcessError) as err:
                if err.stderr or attempt == retries - 1:
                    raise
            else:
                # The refresh succeeded; don't run it again.
                break

        try:
            self._git(
                'diff-index',
                '--quiet',
                'HEAD',
                '--',
                pw_presubmit_ignore_dry_run=True,
            )
        except GitError as err:
            # diff-index exits with 1 if there are uncommitted changes.
            if err.returncode == 1:
                return True

            # Unexpected error.
            raise

        return False

    def root(self) -> Path:
        """The root file path of this Git repository.

        Returns:
            The repository root as an absolute path.
        """
        return self._root

    def list_submodules(
        self, excluded_paths: Collection[Pattern | str] = ()
    ) -> list[Path]:
        """Query Git and return a list of submodules in the current project.

        Arguments:
            excluded_paths: Pattern or string that match submodules that should
                not be returned. All matches are done on posix-style paths
                relative to the project root. Patterns must match the whole
                relative path; strings must be equal to it.

        Returns:
            List of "Path"s which were found but not excluded. All paths are
            absolute.
        """
        discovery_report = self._git(
            'submodule',
            'foreach',
            '--quiet',
            '--recursive',
            'echo $toplevel/$sm_path',
        )
        # One path per line; split on newlines so paths with spaces survive.
        module_dirs = [
            Path(line) for line in discovery_report.splitlines() if line
        ]

        if not excluded_paths:
            return module_dirs

        def is_excluded(module_dir: Path) -> bool:
            relative = module_dir.relative_to(self._root).as_posix()
            for exclude in excluded_paths:
                if isinstance(exclude, re.Pattern):
                    if exclude.fullmatch(relative):
                        return True
                elif exclude == relative:
                    return True
            return False

        return [d for d in module_dirs if not is_excluded(d)]

    def commit_message(self, commit: str = 'HEAD') -> str:
        """Returns the commit message of the specified commit.

        Defaults to ``HEAD`` if no commit specified.

        Returns:
            Commit message contents as a string.
        """
        return self._git('log', '--format=%B', '-n1', commit)

    def commit_author(self, commit: str = 'HEAD') -> str:
        """Returns the author email of the specified commit.

        Defaults to ``HEAD`` if no commit specified.

        Returns:
            Commit author email (``%ae``) as a string.
        """
        return self._git('log', '--format=%ae', '-n1', commit)

    def commit_date(self, commit: str = 'HEAD') -> datetime:
        """Returns the author datetime of the specified commit.

        Defaults to ``HEAD`` if no commit specified.

        Returns:
            Commit datetime as a datetime object, parsed from the strict
            ISO 8601 author date (``%aI``).
        """
        return datetime.fromisoformat(
            self._git('log', '--format=%aI', '-n1', commit)
        )

    def commit_hash(
        self,
        commit: str = 'HEAD',
        short: bool = True,
    ) -> str:
        """Returns the hash associated with the specified commit.

        Defaults to ``HEAD`` if no commit specified.

        Args:
            commit: The revision to resolve.
            short: If True, return the abbreviated hash (``--short``).

        Returns:
            Commit hash as a string.
        """
        args = ['rev-parse']
        if short:
            args.append('--short')
        args.append(commit)
        return self._git(*args)

    def commit_change_id(self, commit: str = 'HEAD') -> str | None:
        """Returns the Gerrit Change-Id of the specified commit.

        Defaults to ``HEAD`` if no commit specified.

        Returns:
            Change-Id as a string, or ``None`` if it does not exist.
        """
        match = _CHANGE_ID_PATTERN.search(self.commit_message(commit))
        return match.group(1) if match else None

    def commit_parents(self, commit: str = 'HEAD') -> list[str]:
        """Returns the parent commit hashes of the specified commit.

        Defaults to ``HEAD`` if no commit specified.

        Returns:
            A list of full parent hashes; empty for a root commit.
        """
        return self._git('log', '--pretty=%P', '-n', '1', commit).split()


def find_git_repo(path_in_repo: Path, tool_runner: ToolRunner) -> GitRepo:
    """Tries to find the root of the Git repo that owns ``path_in_repo``.

    Raises:
        GitError: The specified path does not live in a Git repository.

    Returns:
        A GitRepo representing the enclosing repository that tracks the
        specified file or folder.
    """
    git_tool = _GitTool(
        tool_runner,
        path_in_repo if path_in_repo.is_dir() else path_in_repo.parent,
    )
    root = Path(
        git_tool(
            'rev-parse',
            '--show-toplevel',
        )
    )

    return GitRepo(root, tool_runner)


def is_in_git_repo(p: Path, tool_runner: ToolRunner) -> bool:
    """Returns true if the specified path is tracked by a Git repository.

    Returns:
        True if the specified file or folder is tracked by a Git repository.
    """
    try:
        find_git_repo(p, tool_runner)
    except GitError:
        return False

    return True


def _describe_constraints(
    repo: GitRepo,
    working_dir: Path,
    commit: str | None,
    pathspecs: Collection[Path | str],
    exclude: Collection[Pattern[str]],
) -> Iterable[str]:
    """Yields human-readable phrases describing each active file filter."""
    if not repo.root().samefile(working_dir):
        yield (
            'under the '
            f'{working_dir.resolve().relative_to(repo.root().resolve())}'
            ' subdirectory'
        )

    if commit in _TRACKING_BRANCH_ALIASES:
        commit = repo.tracking_branch()
        if commit is None:
            _LOG.warning(
                'Attempted to list files changed since the remote tracking '
                'branch, but the repo is not tracking a branch'
            )

    if commit:
        yield f'that have changed since {commit}'

    if pathspecs:
        paths_str = ', '.join(str(p) for p in pathspecs)
        yield f'that match {plural(pathspecs, "pathspec")} ({paths_str})'

    if exclude:
        yield (
            f'that do not match {plural(exclude, "pattern")} ('
            + ', '.join(p.pattern for p in exclude)
            + ')'
        )


def describe_git_pattern(
    working_dir: Path,
    commit: str | None,
    pathspecs: Collection[Path | str],
    exclude: Collection[Pattern],
    tool_runner: ToolRunner,
    project_root: Path | None = None,
) -> str:
    """Provides a description for a set of files in a Git repo.

    Example:

        files in the pigweed repo
        - that have changed since origin/main..HEAD
        - that do not match 7 patterns (...)

    The unit tests for this function are the source of truth for the expected
    output.

    Returns:
        A multi-line string with descriptive information about the provided
        Git pathspecs.
    """
    repo = find_git_repo(working_dir, tool_runner)
    constraints = list(
        _describe_constraints(repo, working_dir, commit, pathspecs, exclude)
    )

    name = repo.root().name
    if project_root and project_root != repo.root():
        name = str(repo.root().relative_to(project_root))

    if not constraints:
        return f'all files in the {name} repo'

    msg = f'files in the {name} repo'
    if len(constraints) == 1:
        return f'{msg} {constraints[0]}'

    return msg + ''.join(f'\n  - {line}' for line in constraints)