xref: /aosp_15_r20/external/bazelbuild-rules_python/python/private/whl_filegroup/extract_wheel_files.py (revision 60517a1edbc8ecf509223e9af94a7adec7d736b8)
1"""Extract files from a wheel's RECORD."""
2
3import re
4import sys
5import zipfile
6from collections.abc import Iterable
7from pathlib import Path
8
# Parsed RECORD contents: maps each listed file path to a (hash, length) tuple.
WhlRecord = dict[str, tuple[str, int]]
10
11
def get_record(whl_path: Path) -> WhlRecord:
    """Read the RECORD manifest stored inside a wheel.

    Args:
        whl_path: Path to the wheel (zip) file.

    Returns:
        A mapping from each path listed in RECORD to a ``(hash, length)``
        tuple. Rows without a hash (such as the one for RECORD itself) are
        omitted.

    Raises:
        RuntimeError: If ``whl_path`` is not a valid zip file, or if it does
            not contain exactly one ``.dist-info/RECORD`` member.
        ValueError: If a non-empty RECORD row does not have exactly three
            fields, or its length field is not an integer.
    """
    # RECORD is a CSV file: paths containing commas are quoted, so a plain
    # ``str.split(",")`` would mis-parse them. Imported here so the function
    # does not depend on a module-level import being present.
    import csv

    try:
        zipf = zipfile.ZipFile(whl_path)
    except zipfile.BadZipFile as ex:
        raise RuntimeError(f"{whl_path} is not a valid zip file") from ex

    # The context manager closes the archive even when validation fails.
    with zipf:
        record_files = [
            name for name in zipf.namelist() if name.endswith(".dist-info/RECORD")
        ]
        if len(record_files) != 1:
            raise RuntimeError(
                f"{whl_path} doesn't contain exactly one .dist-info/RECORD"
            )
        record_lines = zipf.read(record_files[0]).decode().splitlines()

    record = {}
    for row in csv.reader(record_lines):
        if not row:
            # Tolerate blank lines instead of failing on the unpack below.
            continue
        file, filehash, filelen = row
        if filehash:  # Skip RECORD itself, which has no hash or length
            record[file] = (filehash, int(filelen))
    return record
29
30
def get_files(whl_record: WhlRecord, regex_pattern: str) -> list[str]:
    """Return the RECORD paths whose beginning matches ``regex_pattern``.

    Paths are returned in the iteration order of ``whl_record``. Matching is
    anchored at the start of each path (``re.match`` semantics), so an empty
    pattern selects every path.
    """
    matches_start = re.compile(regex_pattern).match
    # Iterating a dict yields its keys; a match object is always truthy.
    return list(filter(matches_start, whl_record))
35
36
def extract_files(whl_path: Path, files: Iterable[str], outdir: Path) -> None:
    """Extract the named members of a wheel into a directory.

    Args:
        whl_path: Path to the wheel (zip) file.
        files: Archive member names to extract.
        outdir: Directory that receives the extracted members.

    Raises:
        KeyError: If a name in ``files`` is not a member of the archive.
    """
    # Use a context manager so the archive handle is closed even if an
    # extraction fails part-way through.
    with zipfile.ZipFile(whl_path) as zipf:
        for file in files:
            zipf.extract(file, outdir)
42
43
def main() -> None:
    """Command-line entry point: extract matching files from a wheel.

    Expects ``<wheel> <out_dir> [regex_pattern]`` on the command line. With
    any other number of arguments, a usage message is written to stderr and
    the process exits with status 1. When the pattern is omitted, an empty
    pattern is used, which selects every file listed in RECORD.
    """
    argc = len(sys.argv)
    if argc != 3 and argc != 4:
        print(
            f"Usage: {sys.argv[0]} <wheel> <out_dir> [regex_pattern]",
            file=sys.stderr,
        )
        sys.exit(1)

    wheel_arg, outdir_arg, *optional = sys.argv[1:]
    whl_path = Path(wheel_arg).resolve()
    outdir = Path(outdir_arg)
    regex_pattern = optional[0] if optional else ""

    selected = get_files(get_record(whl_path), regex_pattern)
    extract_files(whl_path, selected, outdir)
59
60
# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
63