"""Extract files from a wheel's RECORD."""

import csv
import io
import re
import sys
import zipfile
from collections.abc import Iterable
from pathlib import Path

# Maps an archive member path to its (hash, size-in-bytes) entry from RECORD.
WhlRecord = dict[str, tuple[str, int]]


def get_record(whl_path: Path) -> WhlRecord:
    """Parse the ``*.dist-info/RECORD`` file of a wheel.

    Args:
        whl_path: Path to the wheel archive.

    Returns:
        A mapping of archive member path -> (hash, size in bytes). Rows with
        an empty hash field (such as RECORD's own entry) are omitted.

    Raises:
        RuntimeError: If ``whl_path`` is not a valid zip file, or if it does
            not contain exactly one ``.dist-info/RECORD`` member.
        ValueError: If a RECORD row does not have exactly three fields.
    """
    try:
        zipf = zipfile.ZipFile(whl_path)
    except zipfile.BadZipFile as ex:
        raise RuntimeError(f"{whl_path} is not a valid zip file") from ex
    # Close the archive as soon as RECORD has been read.
    with zipf:
        record_files = [
            name for name in zipf.namelist() if name.endswith(".dist-info/RECORD")
        ]
        if len(record_files) != 1:
            raise RuntimeError(
                f"{whl_path} doesn't contain exactly one .dist-info/RECORD"
            )
        record_text = zipf.read(record_files[0]).decode("utf-8")

    record: WhlRecord = {}
    # RECORD is a CSV file: paths containing commas or quotes are quoted, so
    # it must be parsed with the csv module rather than a plain split(",").
    for row in csv.reader(io.StringIO(record_text, newline="")):
        if not row:
            continue  # tolerate blank lines
        file, filehash, filelen = row
        if filehash:  # Skip RECORD itself, which has no hash or length
            record[file] = (filehash, int(filelen))
    return record


def get_files(whl_record: WhlRecord, regex_pattern: str) -> list[str]:
    """Get files in a wheel that match a regex pattern.

    The pattern is matched at the start of each path (``re.match`` semantics);
    an empty pattern therefore selects every file in the record.
    """
    pattern = re.compile(regex_pattern)
    return [filepath for filepath in whl_record if pattern.match(filepath)]


def extract_files(whl_path: Path, files: Iterable[str], outdir: Path) -> None:
    """Extract ``files`` (archive member names) from ``whl_path`` into ``outdir``."""
    with zipfile.ZipFile(whl_path) as zipf:
        for file in files:
            zipf.extract(file, outdir)


def main() -> None:
    """Command-line entry point: ``<wheel> <out_dir> [regex_pattern]``."""
    if len(sys.argv) not in {3, 4}:
        print(
            f"Usage: {sys.argv[0]} <wheel> <out_dir> [regex_pattern]",
            file=sys.stderr,
        )
        sys.exit(1)

    whl_path = Path(sys.argv[1]).resolve()
    outdir = Path(sys.argv[2])
    # With no pattern, the empty regex matches (and extracts) every file.
    regex_pattern = sys.argv[3] if len(sys.argv) == 4 else ""

    whl_record = get_record(whl_path)
    files = get_files(whl_record, regex_pattern)
    extract_files(whl_path, files, outdir)


if __name__ == "__main__":
    main()