xref: /aosp_15_r20/external/pigweed/pw_arduino_build/py/pw_arduino_build/file_operations.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1#!/usr/bin/env python3
2# Copyright 2020 The Pigweed Authors
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may not
5# use this file except in compliance with the License. You may obtain a copy of
6# the License at
7#
8#     https://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations under
14# the License.
15"""File Helper Functions."""
16
17import glob
18import hashlib
19import json
20import logging
21import os
22import shutil
23import sys
24import subprocess
25import tarfile
26import urllib.request
27import zipfile
28from pathlib import Path
29
30_LOG = logging.getLogger(__name__)
31
32
class InvalidChecksumError(Exception):
    # Raised by verify_file_checksum() when a file's computed digest does
    # not match the expected value.
    pass
35
36
def find_files(
    starting_dir: str, patterns: list[str], directories_only=False
) -> list[str]:
    """Glob for paths under ``starting_dir``.

    Args:
        starting_dir: Existing directory to search from.
        patterns: Glob patterns (``**`` is supported via recursive glob),
            interpreted relative to ``starting_dir``.
        directories_only: If True, only matches that are directories are
            returned.

    Returns:
        Sorted list of matching paths, relative to ``starting_dir``.

    Raises:
        FileNotFoundError: If ``starting_dir`` is not an existing directory.
    """
    # isdir() is False for nonexistent paths, so a separate exists() check
    # is unnecessary.
    if not os.path.isdir(starting_dir):
        raise FileNotFoundError(
            "Directory '{}' does not exist.".format(starting_dir)
        )

    # glob patterns are relative, so temporarily chdir into starting_dir.
    original_working_dir = os.getcwd()
    os.chdir(starting_dir)
    try:
        files = []
        for pattern in patterns:
            for file_path in glob.glob(pattern, recursive=True):
                if not directories_only or os.path.isdir(file_path):
                    files.append(file_path)
    finally:
        # Always restore the working directory, even if glob raises;
        # otherwise the process is left running in starting_dir.
        os.chdir(original_working_dir)
    return sorted(files)
56
57
def sha256_sum(file_name):
    """Return the hex SHA-256 digest of the file at ``file_name``."""
    digest = hashlib.sha256()
    # Read in fixed-size chunks so large files don't need to fit in memory.
    with open(file_name, "rb") as stream:
        while chunk := stream.read(4096):
            digest.update(chunk)
    return digest.hexdigest()
64
65
def md5_sum(file_name):
    """Return the hex MD5 digest of the file at ``file_name``."""
    digest = hashlib.md5()
    # Read in fixed-size chunks so large files don't need to fit in memory.
    with open(file_name, "rb") as stream:
        while chunk := stream.read(4096):
            digest.update(chunk)
    return digest.hexdigest()
72
73
def verify_file_checksum(file_path, expected_checksum, sum_function=sha256_sum):
    """Check ``file_path`` against ``expected_checksum``.

    ``sum_function`` computes the digest (defaults to sha256_sum).
    Returns True on a match; raises InvalidChecksumError otherwise.
    """
    actual_checksum = sum_function(file_path)
    if actual_checksum == expected_checksum:
        _LOG.debug("  %s:", sum_function.__name__)
        _LOG.debug("  %s %s", actual_checksum, os.path.basename(file_path))
        return True

    raise InvalidChecksumError(
        f"Invalid {sum_function.__name__}\n"
        f"{actual_checksum} {os.path.basename(file_path)}\n"
        f"{expected_checksum} (expected)\n\n"
        "Please delete this file and try again:\n"
        f"{file_path}"
    )
88
89
def relative_or_absolute_path(file_string: str):
    """Return a Path relative to the current directory when possible.

    If ``file_string`` is not under os.getcwd(), fall back to the
    resolved absolute path.
    """
    candidate = Path(file_string)
    try:
        return candidate.relative_to(os.getcwd())
    except ValueError:
        # Not a subpath of the working directory.
        return candidate.resolve()
97
98
def download_to_cache(
    url: str,
    expected_md5sum=None,
    expected_sha256sum=None,
    cache_directory=".cache",
    downloaded_file_name=None,
) -> str:
    """Download a file into a local cache directory, reusing cached copies.

    Args:
        url: URL to fetch.
        expected_md5sum: Optional md5 hex digest to verify against.
        expected_sha256sum: Optional sha256 hex digest to verify against;
            takes precedence over expected_md5sum.
        cache_directory: Directory the file is saved into (env vars and
            ``~`` are expanded); created if it does not exist.
        downloaded_file_name: File name to save as; defaults to the last
            path component of the URL.

    Returns:
        Absolute path to the downloaded (or already-cached) file.

    Raises:
        InvalidChecksumError: If a supplied checksum does not match.
    """
    cache_dir = os.path.realpath(
        os.path.expanduser(os.path.expandvars(cache_directory))
    )
    # urlretrieve() fails if the destination directory is missing, so
    # create the cache directory up front.
    os.makedirs(cache_dir, exist_ok=True)

    if not downloaded_file_name:
        # Use the last part of the URL as the file name.
        downloaded_file_name = url.split("/")[-1]
    downloaded_file = os.path.join(cache_dir, downloaded_file_name)

    # Skip the download entirely if a cached copy already exists.
    if not os.path.exists(downloaded_file):
        _LOG.info("Downloading: %s", url)
        _LOG.info("Please wait...")
        urllib.request.urlretrieve(url, filename=downloaded_file)

    if os.path.exists(downloaded_file):
        _LOG.info("Downloaded: %s", relative_or_absolute_path(downloaded_file))
        # Prefer sha256; md5 is only checked when no sha256 was given.
        if expected_sha256sum:
            verify_file_checksum(
                downloaded_file, expected_sha256sum, sum_function=sha256_sum
            )
        elif expected_md5sum:
            verify_file_checksum(
                downloaded_file, expected_md5sum, sum_function=md5_sum
            )

    return downloaded_file
133
134
def extract_zipfile(archive_file: str, dest_dir: str):
    """Extract a zipfile preserving Unix file permissions.

    ZipFile.extract() does not restore mode bits, so each entry is
    chmod'ed manually from its stored external attributes.

    Args:
        archive_file: Path to the .zip archive.
        dest_dir: Directory to extract into.
    """
    destination_path = Path(dest_dir)
    with zipfile.ZipFile(archive_file) as archive:
        for info in archive.infolist():
            archive.extract(info.filename, path=dest_dir)
            # The upper 16 bits of external_attr hold the Unix mode bits.
            permissions = info.external_attr >> 16
            # Archives created on non-Unix hosts store no mode bits
            # (permissions == 0); skip chmod in that case rather than
            # chmod'ing the file to 0 and making it unreadable.
            if permissions:
                out_path = destination_path / info.filename
                out_path.chmod(permissions)
144
145
def extract_tarfile(archive_file: str, dest_dir: str):
    """Extract a tar archive into ``dest_dir``.

    Uses the 'tar' extraction filter (Python 3.12+, CVE-2007-4559
    mitigation) to reject absolute paths and path traversal while still
    preserving file permissions; falls back to unfiltered extraction on
    interpreters that don't support the ``filter`` argument.
    """
    with tarfile.open(archive_file, 'r') as archive:
        try:
            archive.extractall(path=dest_dir, filter='tar')
        except TypeError:
            # Python < 3.12: extractall() has no 'filter' parameter.
            archive.extractall(path=dest_dir)
149
150
def extract_archive(
    archive_file: str,
    dest_dir: str,
    cache_dir: str,
    remove_single_toplevel_folder=True,
):
    """Extract a tar or zip file.

    Args:
        archive_file (str): Absolute path to the archive file.
        dest_dir (str): Extraction destination directory; created if it
            does not exist.
        cache_dir (str): Directory where temp files can be created.
        remove_single_toplevel_folder (bool): If the archive contains only a
            single folder move the contents of that into the destination
            directory.

    Returns:
        List of Paths for everything under dest_dir after extraction.
    """
    # Make a temporary directory to extract files into
    temp_extract_dir = os.path.join(
        cache_dir, "." + os.path.basename(archive_file)
    )
    os.makedirs(temp_extract_dir, exist_ok=True)

    _LOG.info("Extracting: %s", relative_or_absolute_path(archive_file))
    if zipfile.is_zipfile(archive_file):
        extract_zipfile(archive_file, temp_extract_dir)
    elif tarfile.is_tarfile(archive_file):
        extract_tarfile(archive_file, temp_extract_dir)
    else:
        _LOG.error("Unknown archive format: %s", archive_file)
        # Aborts the whole process; callers never see a return value here.
        sys.exit(1)

    _LOG.info("Installing into: %s", relative_or_absolute_path(dest_dir))
    # Ensure the destination exists before moving files into it.
    os.makedirs(dest_dir, exist_ok=True)
    path_to_extracted_files = temp_extract_dir

    extracted_top_level_files = os.listdir(temp_extract_dir)
    # If the archive holds exactly one top-level *directory*, optionally
    # strip it by treating that directory as the extraction root. The
    # isdir() guard avoids calling listdir() on a lone top-level file.
    if (
        remove_single_toplevel_folder
        and len(extracted_top_level_files) == 1
        and os.path.isdir(
            os.path.join(temp_extract_dir, extracted_top_level_files[0])
        )
    ):
        path_to_extracted_files = os.path.join(
            temp_extract_dir, extracted_top_level_files[0]
        )

    # Move extracted files to dest_dir
    for file_name in os.listdir(path_to_extracted_files):
        source_file = os.path.join(path_to_extracted_files, file_name)
        dest_file = os.path.join(dest_dir, file_name)
        shutil.move(source_file, dest_file)

    # rm -rf temp_extract_dir
    shutil.rmtree(temp_extract_dir, ignore_errors=True)

    # Return List of extracted files
    return list(Path(dest_dir).rglob("*"))
205
206
def remove_empty_directories(directory):
    """Recursively delete empty directories under ``directory``.

    Dangling symlinks are removed as well. Iterating deepest-first lets
    nested empty trees collapse in a single pass.
    """
    for entry in sorted(Path(directory).rglob("*"), reverse=True):
        if entry.is_symlink() and not entry.exists():
            # Broken symlink: remove it.
            entry.unlink()
        elif entry.is_dir() and not any(entry.iterdir()):
            # Directory with no remaining entries.
            entry.rmdir()
217
218
def decode_file_json(file_name):
    """Decode JSON values from a file.

    Does not raise an error if the file cannot be read or decoded; an
    empty dict is returned instead.

    Args:
        file_name: Path to the JSON file; env vars and ``~`` are expanded.

    Returns:
        Tuple of (decoded JSON value or {} on failure, absolute file path).
    """

    # Get absolute path to the file.
    file_path = os.path.realpath(
        os.path.expanduser(os.path.expandvars(file_name))
    )

    json_file_options = {}
    try:
        # Explicit encoding: JSON files are UTF-8 regardless of locale.
        with open(file_path, "r", encoding="utf-8") as jfile:
            json_file_options = json.loads(jfile.read())
    except (FileNotFoundError, json.JSONDecodeError):
        _LOG.warning("Unable to read file '%s'", file_path)

    return json_file_options, file_path
237
238
def git_apply_patch(
    root_directory, patch_file, ignore_whitespace=True, unsafe_paths=False
):
    """Use `git apply` to apply a diff file.

    Args:
        root_directory: Directory prepended to all paths in the patch
            (passed to ``git apply --directory``).
        patch_file: Path of the diff file to apply.
        ignore_whitespace: Pass ``--ignore-whitespace`` to git.
        unsafe_paths: Pass ``--unsafe-paths`` to git (allows touching
            files outside the repository).

    Returns:
        The subprocess.CompletedProcess from running git.
    """

    _LOG.info("Applying Patch: %s", patch_file)
    git_apply_command = ["git", "apply"]
    if ignore_whitespace:
        git_apply_command.append("--ignore-whitespace")
    if unsafe_paths:
        git_apply_command.append("--unsafe-paths")
    git_apply_command += ["--directory", root_directory, patch_file]
    # check=False keeps the original non-raising behavior, but a failed
    # patch is now logged instead of being silently ignored.
    result = subprocess.run(git_apply_command, check=False)
    if result.returncode != 0:
        _LOG.error(
            "Patch failed (exit %d): %s", result.returncode, patch_file
        )
    return result
252