# xref: /aosp_15_r20/external/bazelbuild-rules_python/python/private/pypi/repack_whl.py (revision 60517a1edbc8ecf509223e9af94a7adec7d736b8)
# Copyright 2023 The Bazel Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15"""
Regenerate a whl file after patching and clean up the patched contents.

This script will take the contents of the current directory and create a new
wheel out of it and will remove all files that were written to the wheel.
20"""
21
22from __future__ import annotations
23
24import argparse
25import difflib
26import logging
27import pathlib
28import sys
29import tempfile
30
31from tools.wheelmaker import _WhlFile
32
33# NOTE: Implement the following matching of what goes into the RECORD
34# https://peps.python.org/pep-0491/#the-dist-info-directory
35_EXCLUDES = [
36    "RECORD",
37    "INSTALLER",
38    "RECORD.jws",
39    "RECORD.p7s",
40    "REQUESTED",
41]
42
43_DISTINFO = "dist-info"
44
45
46def _unidiff_output(expected, actual, record):
47    """
48    Helper function. Returns a string containing the unified diff of two
49    multiline strings.
50    """
51
52    expected = expected.splitlines(1)
53    actual = actual.splitlines(1)
54
55    diff = difflib.unified_diff(
56        expected, actual, fromfile=f"a/{record}", tofile=f"b/{record}"
57    )
58
59    return "".join(diff)
60
61
def _files_to_pack(dir: pathlib.Path, want_record: str) -> list[pathlib.Path]:
    """Return the files below `dir` that should be written to the wheel.

    Args:
        dir: The directory holding the unpacked (and patched) wheel contents.
        want_record: The contents of the existing RECORD file; the first
            comma-separated field of every line is treated as a path
            relative to `dir`.

    Returns:
        Paths in the following order: files listed in `want_record` (in
        RECORD order), sorted files found on disk but not listed, listed
        dist-info files, sorted unlisted dist-info files. Files named in
        `_EXCLUDES` that live in a dist-info directory are never returned.
    """

    # First get existing files by using the RECORD file
    got_files = []
    got_distinfos = []
    listed = set()
    for line in want_record.splitlines():
        rec, _, _ = line.partition(",")
        if not rec:
            # A blank line (or an empty first field) would resolve to `dir`
            # itself, which is not a file that can be packed.
            continue

        path = dir / rec
        if not path.exists() or path.is_dir():
            # skip files that do not exist as they won't be present in the
            # final RECORD file; directories are skipped just like in the
            # directory walk below.
            continue

        if path in listed:
            # A repeated RECORD entry must not be packed twice.
            continue
        listed.add(path)

        if not path.parent.name.endswith(_DISTINFO):
            got_files.append(path)
        elif path.name not in _EXCLUDES:
            got_distinfos.append(path)

    # Then get extra files present in the directory but not in the RECORD file
    extra_files = []
    extra_distinfos = []
    for path in dir.rglob("*"):
        if path.is_dir():
            continue

        if path.parent.name.endswith(_DISTINFO):
            # NOTE: we implement the following matching of what goes into the RECORD
            # https://peps.python.org/pep-0491/#the-dist-info-directory
            if path.name not in _EXCLUDES and path not in listed:
                extra_distinfos.append(path)
        elif path not in listed:
            extra_files.append(path)

    # sort the extra files for reproducibility
    extra_files.sort()
    extra_distinfos.sort()

    # This order ensures that the structure of the RECORD file is always the
    # same and ensures smaller patchsets to the RECORD file in general
    return got_files + extra_files + got_distinfos + extra_distinfos
107
108
def main(sys_argv):
    """Repack the contents of the current directory into a new wheel.

    Every entry of the current working directory except the original wheel
    is moved into a temporary directory, written to the wheel at ``output``
    and removed together with the temporary directory. If the regenerated
    RECORD differs from the one found in the ``*dist-info`` directory, the
    unified diff is logged and, when ``--record-patch`` is given, written to
    that path.

    Args:
        sys_argv: The command line arguments without the program name.

    Raises:
        RuntimeError: If no ``*dist-info`` directory is found.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "whl_path",
        type=pathlib.Path,
        help="The original wheel file that we have patched.",
    )
    parser.add_argument(
        "--record-patch",
        type=pathlib.Path,
        help="The output path that we are going to write the RECORD file patch to.",
    )
    parser.add_argument(
        "output",
        type=pathlib.Path,
        help="The output path that we are going to write a new file to.",
    )
    args = parser.parse_args(sys_argv)

    cwd = pathlib.Path.cwd()
    logging.debug("=" * 80)
    logging.debug("Repackaging the wheel")
    logging.debug("=" * 80)

    with tempfile.TemporaryDirectory(dir=cwd) as tmpdir:
        patched_wheel_dir = cwd / tmpdir
        logging.debug(f"Created a tmpdir: {patched_wheel_dir}")

        # `cwd.glob` yields absolute paths, so a relative wheel path has to
        # be anchored at cwd to be recognised below; an absolute one is left
        # unchanged by the join.
        excludes = [cwd / args.whl_path, patched_wheel_dir]

        logging.debug("Moving whl contents to the newly created tmpdir")
        for p in cwd.glob("*"):
            if p in excludes:
                logging.debug(f"Ignoring: {p}")
                continue

            rel_path = p.relative_to(cwd)
            dst = p.rename(patched_wheel_dir / rel_path)
            logging.debug(f"mv {p} -> {dst}")

        distinfo_dir = next(patched_wheel_dir.glob("*dist-info"), None)
        if distinfo_dir is None:
            raise RuntimeError(f"No *dist-info directory found in: {cwd}")
        logging.debug(f"Found dist-info dir: {distinfo_dir}")
        record_path = distinfo_dir / "RECORD"
        record_contents = record_path.read_text() if record_path.exists() else ""
        distribution_prefix = distinfo_dir.with_suffix("").name

        with _WhlFile(
            args.output, mode="w", distribution_prefix=distribution_prefix
        ) as out:
            for p in _files_to_pack(patched_wheel_dir, record_contents):
                rel_path = p.relative_to(patched_wheel_dir)
                out.add_file(str(rel_path), p)

            logging.debug("Writing RECORD file")
            got_record = out.add_recordfile().decode("utf-8", "surrogateescape")

    if got_record == record_contents:
        logging.info(f"Created a whl file: {args.output}")
        return

    record_diff = _unidiff_output(
        record_contents,
        got_record,
        out.distinfo_path("RECORD"),
    )

    if args.record_patch is None:
        # `--record-patch` is optional; without it there is nowhere to write
        # the patch, so only report it.
        logging.warning(f"The RECORD file needs to be patched:\n{record_diff}")
        return

    args.record_patch.write_text(record_diff)
    logging.warning(
        f"Please apply patch to the RECORD file ({args.record_patch}):\n{record_diff}"
    )
178
179
if __name__ == "__main__":
    # Script entry point: log everything from DEBUG upwards, prefixed with the
    # module name and the level, then exit with whatever main() returns.
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(module)s: %(levelname)s: %(message)s",
    )

    exit_status = main(sys.argv[1:])
    sys.exit(exit_status)
185    sys.exit(main(sys.argv[1:]))
186