#!/usr/bin/env python3
# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""File Helper Functions."""

import glob
import hashlib
import json
import logging
import os
import shutil
import subprocess
import sys
import tarfile
import urllib.request
import zipfile
from pathlib import Path

_LOG = logging.getLogger(__name__)


class InvalidChecksumError(Exception):
    """Raised when a file's checksum does not match the expected digest."""


def find_files(
    starting_dir: str, patterns: list[str], directories_only=False
) -> list[str]:
    """Return sorted paths under starting_dir matching the glob patterns.

    Paths are returned relative to starting_dir. Recursive ``**`` patterns
    are supported.

    Args:
        starting_dir: Directory to search in.
        patterns: Glob patterns evaluated relative to starting_dir.
        directories_only: If True, return only matching directories.

    Returns:
        Sorted list of matching relative paths.

    Raises:
        FileNotFoundError: If starting_dir is not an existing directory.
    """
    if not (os.path.exists(starting_dir) and os.path.isdir(starting_dir)):
        raise FileNotFoundError(
            "Directory '{}' does not exist.".format(starting_dir)
        )

    original_working_dir = os.getcwd()
    os.chdir(starting_dir)
    try:
        files = []
        for pattern in patterns:
            for file_path in glob.glob(pattern, recursive=True):
                # Simplified from the redundant
                # `not d or (d and isdir)` form.
                if not directories_only or os.path.isdir(file_path):
                    files.append(file_path)
    finally:
        # Restore the caller's working directory even if glob raises;
        # the original code leaked the chdir on exceptions.
        os.chdir(original_working_dir)
    return sorted(files)


def _hash_file(file_name, hash_obj) -> str:
    """Feed file_name's bytes into hash_obj in 4 KiB chunks; return hexdigest."""
    with open(file_name, "rb") as file_handle:
        for chunk in iter(lambda: file_handle.read(4096), b""):
            hash_obj.update(chunk)
    return hash_obj.hexdigest()


def sha256_sum(file_name) -> str:
    """Return the hex SHA-256 digest of the file at file_name."""
    return _hash_file(file_name, hashlib.sha256())


def md5_sum(file_name) -> str:
    """Return the hex MD5 digest of the file at file_name."""
    return _hash_file(file_name, hashlib.md5())


def verify_file_checksum(file_path, expected_checksum, sum_function=sha256_sum):
    """Check a file against an expected hex digest.

    Args:
        file_path: File to hash.
        expected_checksum: Expected hex digest string.
        sum_function: Callable mapping a path to a hex digest
            (``sha256_sum`` or ``md5_sum``).

    Returns:
        True if the checksum matches.

    Raises:
        InvalidChecksumError: If the computed digest differs.
    """
    downloaded_checksum = sum_function(file_path)
    if downloaded_checksum != expected_checksum:
        raise InvalidChecksumError(
            f"Invalid {sum_function.__name__}\n"
            f"{downloaded_checksum} {os.path.basename(file_path)}\n"
            f"{expected_checksum} (expected)\n\n"
            "Please delete this file and try again:\n"
            f"{file_path}"
        )

    _LOG.debug("  %s:", sum_function.__name__)
    _LOG.debug("  %s %s", downloaded_checksum, os.path.basename(file_path))
    return True


def relative_or_absolute_path(file_string: str):
    """Return a Path relative to os.getcwd(), else an absolute path."""
    file_path = Path(file_string)
    try:
        return file_path.relative_to(os.getcwd())
    except ValueError:
        # Not under the current directory; fall back to an absolute path.
        return file_path.resolve()


def download_to_cache(
    url: str,
    expected_md5sum=None,
    expected_sha256sum=None,
    cache_directory=".cache",
    downloaded_file_name=None,
) -> str:
    """Download a URL into a cache directory, verifying an optional checksum.

    The download is skipped if the file is already present in the cache.

    Args:
        url: URL to fetch.
        expected_md5sum: Optional expected MD5 hex digest.
        expected_sha256sum: Optional expected SHA-256 hex digest; takes
            precedence over expected_md5sum when both are given.
        cache_directory: Directory to store the download in; ``~`` and
            environment variables are expanded, and it is created if missing.
        downloaded_file_name: Name for the cached file; defaults to the last
            path component of the URL.

    Returns:
        Absolute path to the downloaded (or previously cached) file.

    Raises:
        InvalidChecksumError: If the file does not match the expected digest.
    """
    cache_dir = os.path.realpath(
        os.path.expanduser(os.path.expandvars(cache_directory))
    )
    # Bug fix: urlretrieve fails if the destination directory does not
    # exist yet; the original code never created it.
    os.makedirs(cache_dir, exist_ok=True)

    if not downloaded_file_name:
        # Use the last part of the URL as the file name.
        downloaded_file_name = url.split("/")[-1]
    downloaded_file = os.path.join(cache_dir, downloaded_file_name)

    if not os.path.exists(downloaded_file):
        _LOG.info("Downloading: %s", url)
        _LOG.info("Please wait...")
        urllib.request.urlretrieve(url, filename=downloaded_file)

    if os.path.exists(downloaded_file):
        _LOG.info("Downloaded: %s", relative_or_absolute_path(downloaded_file))
        if expected_sha256sum:
            verify_file_checksum(
                downloaded_file, expected_sha256sum, sum_function=sha256_sum
            )
        elif expected_md5sum:
            verify_file_checksum(
                downloaded_file, expected_md5sum, sum_function=md5_sum
            )

    return downloaded_file


def extract_zipfile(archive_file: str, dest_dir: str):
    """Extract a zipfile preserving Unix permissions."""
    destination_path = Path(dest_dir)
    with zipfile.ZipFile(archive_file) as archive:
        for info in archive.infolist():
            archive.extract(info.filename, path=dest_dir)
            # Unix mode bits are stored in the high 16 bits of external_attr.
            permissions = info.external_attr >> 16
            # Bug fix: archives created on non-Unix hosts store no mode bits;
            # the original unconditionally chmod'ed, so chmod(0) made those
            # files inaccessible. Only apply a real, nonzero mode.
            if permissions:
                out_path = destination_path / info.filename
                out_path.chmod(permissions)


def extract_tarfile(archive_file: str, dest_dir: str):
    """Extract a tar archive (any compression tarfile supports) into dest_dir."""
    with tarfile.open(archive_file, 'r') as archive:
        if hasattr(tarfile, 'data_filter'):
            # Guard against path-traversal entries in downloaded archives
            # (CVE-2007-4559); available on Python 3.12+ and backports.
            archive.extractall(path=dest_dir, filter='data')
        else:
            archive.extractall(path=dest_dir)


def extract_archive(
    archive_file: str,
    dest_dir: str,
    cache_dir: str,
    remove_single_toplevel_folder=True,
):
    """Extract a tar or zip file.

    Args:
        archive_file (str): Absolute path to the archive file.
        dest_dir (str): Extraction destination directory.
        cache_dir (str): Directory where temp files can be created.
        remove_single_toplevel_folder (bool): If the archive contains only a
            single folder move the contents of that into the destination
            directory.

    Returns:
        List of Paths for every file extracted into dest_dir.
    """
    # Make a temporary directory to extract files into
    temp_extract_dir = os.path.join(
        cache_dir, "." + os.path.basename(archive_file)
    )
    os.makedirs(temp_extract_dir, exist_ok=True)

    _LOG.info("Extracting: %s", relative_or_absolute_path(archive_file))
    if zipfile.is_zipfile(archive_file):
        extract_zipfile(archive_file, temp_extract_dir)
    elif tarfile.is_tarfile(archive_file):
        extract_tarfile(archive_file, temp_extract_dir)
    else:
        _LOG.error("Unknown archive format: %s", archive_file)
        # sys.exit raises SystemExit; nothing after this runs.
        sys.exit(1)

    _LOG.info("Installing into: %s", relative_or_absolute_path(dest_dir))
    path_to_extracted_files = temp_extract_dir

    extracted_top_level_files = os.listdir(temp_extract_dir)
    # Check if tarfile has only one folder
    # If yes, make that the new path_to_extracted_files
    if remove_single_toplevel_folder and len(extracted_top_level_files) == 1:
        path_to_extracted_files = os.path.join(
            temp_extract_dir, extracted_top_level_files[0]
        )

    # Move extracted files to dest_dir
    extracted_files = os.listdir(path_to_extracted_files)
    for file_name in extracted_files:
        source_file = os.path.join(path_to_extracted_files, file_name)
        dest_file = os.path.join(dest_dir, file_name)
        shutil.move(source_file, dest_file)

    # rm -rf temp_extract_dir
    shutil.rmtree(temp_extract_dir, ignore_errors=True)

    # Return List of extracted files
    return list(Path(dest_dir).rglob("*"))


def remove_empty_directories(directory):
    """Recursively remove empty directories and broken symlinks."""

    # Reverse-sorted traversal visits children before their parents, so a
    # directory emptied on this pass is removed in the same call.
    for path in sorted(Path(directory).rglob("*"), reverse=True):
        # If broken symlink
        if path.is_symlink() and not path.exists():
            path.unlink()
        # if empty directory
        elif path.is_dir() and len(os.listdir(path)) == 0:
            path.rmdir()


def decode_file_json(file_name):
    """Decode JSON values from a file.

    Does not raise an error if the file cannot be decoded; a warning is
    logged and an empty dict is returned instead.

    Args:
        file_name: Path to a JSON file; ``~`` and environment variables are
            expanded.

    Returns:
        Tuple of (decoded JSON value or {} on failure, absolute file path).
    """

    # Get absolute path to the file.
    file_path = os.path.realpath(
        os.path.expanduser(os.path.expandvars(file_name))
    )

    json_file_options = {}
    try:
        # JSON is specified as UTF-8; don't depend on the locale encoding.
        with open(file_path, "r", encoding="utf-8") as jfile:
            json_file_options = json.loads(jfile.read())
    except (FileNotFoundError, json.JSONDecodeError):
        _LOG.warning("Unable to read file '%s'", file_path)

    return json_file_options, file_path


def git_apply_patch(
    root_directory, patch_file, ignore_whitespace=True, unsafe_paths=False
):
    """Use `git apply` to apply a diff file.

    Args:
        root_directory: Directory prepended to patch paths via
            ``git apply --directory``.
        patch_file: Path to the diff file to apply.
        ignore_whitespace: Pass ``--ignore-whitespace`` to git.
        unsafe_paths: Pass ``--unsafe-paths`` to git.
    """

    _LOG.info("Applying Patch: %s", patch_file)
    git_apply_command = ["git", "apply"]
    if ignore_whitespace:
        git_apply_command.append("--ignore-whitespace")
    if unsafe_paths:
        git_apply_command.append("--unsafe-paths")
    git_apply_command += ["--directory", root_directory, patch_file]
    result = subprocess.run(git_apply_command)
    # Bug fix: the original silently discarded the exit status, so failed
    # patches went unnoticed. Callers historically ignore the return value,
    # so log instead of raising.
    if result.returncode != 0:
        _LOG.error(
            "git apply failed with exit code %d", result.returncode
        )