xref: /aosp_15_r20/external/pigweed/pw_env_setup/py/pw_env_setup/python_packages.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1#!/usr/bin/env python
2
3# Copyright 2021 The Pigweed Authors
4#
5# Licensed under the Apache License, Version 2.0 (the "License"); you may not
6# use this file except in compliance with the License. You may obtain a copy of
7# the License at
8#
9#     https://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14# License for the specific language governing permissions and limitations under
15# the License.
16"""Save list of installed packages and versions."""
17
18import argparse
19import itertools
20import sys
21from pathlib import Path
22from typing import Iterator
23
24import pkg_resources
25
26
def _installed_packages() -> Iterator[str]:
    """Yield requirement strings for packages in the active environment.

    Each string is the ``str()`` of ``Distribution.as_requirement()`` for an
    installed distribution (e.g. ``foo==1.2.3``). Only
    ``DistInfoDistribution`` instances without a local version label are
    reported, and ``pip``, ``setuptools`` and ``wheel`` are skipped.

    Yields:
        Requirement strings, sorted case-insensitively by project name.
    """
    installed_packages = (
        pkg.as_requirement()
        for pkg in pkg_resources.working_set  # pylint: disable=not-an-iterable
        # Non-editable packages only
        if isinstance(pkg, pkg_resources.DistInfoDistribution)  # type: ignore
        # This will skip packages with local versions.
        #   For example text after a plus sign: 1.2.3+dev456
        and not pkg.parsed_version.local  # type: ignore
        # These are always installed by default in:
        #   pw_env_setup/py/pw_env_setup/virtualenv_setup/install.py
        and pkg.key not in ('pip', 'setuptools', 'wheel')
    )
    for req in sorted(
        installed_packages, key=lambda req: req.name.lower()  # type: ignore
    ):
        yield str(req)
45
46
def ls(output_file: Path | None) -> int:  # pylint: disable=invalid-name
    """List installed package versions as requirement lines.

    Writes the requirements reported by ``_installed_packages`` to
    ``output_file``, or to stdout when no file is given, sorted by package
    name.

    When ``output_file`` already exists, its requirement lines that the
    active environment does not reproduce verbatim are handled as follows:

    * lines whose environment marker does not apply here are kept as-is;
    * lines still satisfied by an installed distribution replace that
      distribution's plain ``name==version`` line (this preserves markers
      and non-pinned specifiers that apply to this environment);
    * all other lines are dropped.

    Args:
        output_file: File to (re)write, or None to print to stdout. The file
            does not need to exist yet.

    Returns:
        Always 0.
    """
    actual_requirements = frozenset(
        pkg_resources.Requirement.parse(line) for line in _installed_packages()
    )
    missing_requirements: set[pkg_resources.Requirement] = set()

    # If updating an existing file, load the existing requirements to find
    # lines that are missing in the active environment. Comments, blank lines
    # and constraint lines are skipped by _load_requirements_lines. A file
    # that doesn't exist yet is simply created below.
    if output_file and output_file.is_file():
        expected_requirements = {
            pkg_resources.Requirement.parse(line)
            for line in _load_requirements_lines(output_file)
        }
        missing_requirements = expected_requirements - actual_requirements

    new_requirements: list[pkg_resources.Requirement] = list(
        actual_requirements
    )

    for requirement in missing_requirements:
        # Preserve this requirement if it has a marker that doesn't apply to
        # the current environment. For example a line with
        # ;python_version < "3.9" is kept when running on Python 3.9 or higher.
        if requirement.marker and not requirement.marker.evaluate():
            new_requirements.append(requirement)
            continue

        # If this package is in the active environment then the existing
        # line (e.g. one carrying a marker) should be preserved.
        try:
            found_package = pkg_resources.working_set.find(requirement)
        except pkg_resources.VersionConflict:
            # Installed at a different version: keep the new pinned line.
            found_package = None
        if not found_package:
            continue

        plain_requirement = found_package.as_requirement()
        # The plain line may be absent: the package was filtered out by
        # _installed_packages, or another line already replaced it. There is
        # then nothing to swap, and list.remove would raise ValueError.
        if plain_requirement not in new_requirements:
            continue
        # Replace the old line with no marker by the existing line.
        new_requirements.remove(plain_requirement)
        new_requirements.append(requirement)

    # Sort by name, then by full text, so several lines for one package
    # (e.g. with different markers) come out in a stable order instead of
    # depending on set iteration order.
    ordered = sorted(
        new_requirements,
        key=lambda req: (req.name.lower(), str(req)),  # type: ignore
    )

    if output_file:
        with output_file.open('w') as out:
            for requirement in ordered:
                print(requirement, file=out)
    else:
        for requirement in ordered:
            print(requirement, file=sys.stdout)
    return 0
98
99
class UpdateRequiredError(Exception):
    """Exception indicating that an update is required.

    Carries no extra state or behavior beyond :class:`Exception`.
    """
102
103
104def _stderr(*args, **kwargs):
105    return print(*args, file=sys.stderr, **kwargs)
106
107
108def _load_requirements_lines(*req_files: Path) -> Iterator[str]:
109    for req_file in req_files:
110        for line in req_file.read_text().splitlines():
111            # Ignore constraints, comments and blank lines
112            if line.startswith('-c') or line.startswith('#') or line == '':
113                continue
114            yield line
115
116
def diff(
    expected: Path, ignore_requirements_file: list[Path] | None = None
) -> int:
    """Report on differences between installed and expected versions.

    Compares the packages in the active environment (as reported by
    ``_installed_packages``) with the requirement lines in ``expected`` and
    prints a summary to stderr. The summary covers lines needing
    reformatting, and updated, removed and new packages.

    Args:
        expected: Requirements file describing the expected environment.
        ignore_requirements_file: Optional requirements files listing
            packages that may be installed without appearing in
            ``expected``. Their installed dependencies that are not listed
            in ``expected`` are ignored too.

    Returns:
        -1 if an expected package is installed at a different version, or an
        installed package is absent from ``expected`` and not ignored;
        otherwise 0. Removed packages and formatting problems are reported
        but do not affect the result.
    """
    actual_lines = set(_installed_packages())
    expected_lines = set(_load_requirements_lines(expected))
    ignored_lines = set()
    if ignore_requirements_file:
        ignored_lines = set(_load_requirements_lines(*ignore_requirements_file))

    if actual_lines == expected_lines:
        _stderr('package versions are identical')
        return 0

    actual_requirements = frozenset(
        pkg_resources.Requirement.parse(line) for line in actual_lines
    )
    expected_requirements = frozenset(
        pkg_resources.Requirement.parse(line) for line in expected_lines
    )
    ignored_requirements = frozenset(
        pkg_resources.Requirement.parse(line) for line in ignored_lines
    )

    removed_requirements = expected_requirements - actual_requirements
    added_requirements = actual_requirements - expected_requirements

    removed_packages: dict[pkg_resources.Requirement, str] = {}
    updated_packages: dict[pkg_resources.Requirement, str] = {}
    new_packages: dict[pkg_resources.Requirement, str] = {}
    reformatted_packages: dict[pkg_resources.Requirement, str] = {}

    # A parsed expected line may compare equal to its counterpart in
    # actual_requirements even when the text differs, because Requirement
    # equality and hashing go through its .hashCmp, which checks the
    # normalized name. Mapping each actual requirement to itself lets us
    # retrieve the installed instance (and its formatting) for an equal
    # parsed line. Built once here rather than on every loop iteration.
    actual_by_requirement = {req: req for req in actual_requirements}

    for line in expected_lines:
        requirement = pkg_resources.Requirement.parse(line)

        # Check for lines that need reformatting. This catches lines that
        # use underscores instead of dashes in the name, or that are missing
        # spaces after specifiers.
        matching_found_requirement = actual_by_requirement.get(requirement)

        # If the actual requirement line doesn't match.
        if (
            matching_found_requirement
            and str(matching_found_requirement) != line
        ):
            reformatted_packages[matching_found_requirement] = line

        # If this requirement isn't in the active environment and the line
        # doesn't match its string form: flag for reformatting.
        if not matching_found_requirement and str(requirement) != line:
            reformatted_packages[requirement] = line

        # If an environment marker is used and it doesn't apply, skip this
        # line. See details for requirement specifiers at:
        # https://pip.pypa.io/en/stable/reference/requirement-specifiers/#requirement-specifiers
        if requirement.marker and not requirement.marker.evaluate():
            continue

        # Try to find this requirement in the current environment.
        try:
            found_package = pkg_resources.working_set.find(requirement)
        # If the package version doesn't match, keep the installed
        # distribution so the new version is reported.
        except pkg_resources.VersionConflict as err:
            found_package = None
            if err.dist:
                found_package = err.dist

        # If this requirement isn't in the environment, it was removed.
        if not found_package:
            removed_packages[requirement] = line
            continue

        # The installed version differs from the expected one: it was updated.
        if requirement.specs != found_package.as_requirement().specs:
            updated_packages[found_package.as_requirement()] = line

    ignored_distributions = [
        distribution
        for distribution in pkg_resources.working_set  # pylint: disable=not-an-iterable
        if distribution.as_requirement() in ignored_requirements
    ]
    # Only used for membership tests, so a set avoids repeated linear scans.
    expected_distributions = {
        distribution
        for distribution in pkg_resources.working_set  # pylint: disable=not-an-iterable
        if distribution.as_requirement() in expected_requirements
    }

    def get_requirements(
        dist_info: pkg_resources.Distribution,
    ) -> Iterator[pkg_resources.Distribution]:
        """Yield installed deps of dist_info not in expected_distributions."""
        for req in dist_info.requires():
            try:
                req_dist_info = pkg_resources.working_set.find(req)
            except pkg_resources.VersionConflict as err:
                # Installed, but at a version that doesn't satisfy req. It
                # is still a dependency present in the environment.
                req_dist_info = err.dist
            if not req_dist_info:
                continue
            if req_dist_info in expected_distributions:
                continue
            yield req_dist_info

    def expand_requirements(
        reqs: list[pkg_resources.Distribution],
        visited: set[pkg_resources.Distribution],
    ) -> Iterator[list[pkg_resources.Distribution]]:
        """Recursively expand requirements, visiting each distribution once."""
        for dist_info in reqs:
            # Guard against dependency cycles, which would otherwise recurse
            # until RecursionError.
            if dist_info in visited:
                continue
            visited.add(dist_info)
            deps = list(get_requirements(dist_info))
            if deps:
                yield deps
            yield from expand_requirements(deps, visited)

    ignored_transitive_deps = set(
        itertools.chain.from_iterable(
            expand_requirements(ignored_distributions, set())
        )
    )
    # New-package candidates below are Requirements, so compare them against
    # the requirement form of each ignored transitive distribution.
    ignored_transitive_requirements = frozenset(
        dist.as_requirement() for dist in ignored_transitive_deps
    )

    # Check for new packages
    for requirement in added_requirements - removed_requirements:
        if requirement in updated_packages:
            continue
        if (
            requirement in ignored_requirements
            or requirement in ignored_transitive_requirements
        ):
            continue

        new_packages[requirement] = str(requirement)

    # Print status messages to stderr

    if reformatted_packages:
        _stderr('Requirements that need reformatting:')
        for requirement, line in reformatted_packages.items():
            _stderr(f'  {line}')
            _stderr('  should be:')
            _stderr(f'  {str(requirement)}')

    if updated_packages:
        _stderr('Updated packages')
        for requirement, line in updated_packages.items():
            _stderr(f'  {str(requirement)} (from {line})')

    if removed_packages:
        _stderr('Removed packages')
        for requirement in removed_packages:
            _stderr(f'  {requirement}')

    if new_packages:
        _stderr('New packages')
        for requirement in new_packages:
            _stderr(f'  {requirement}')

    if updated_packages or new_packages:
        _stderr("Package versions don't match!")
        _stderr(
            f"""
Please do the following:

* purge your environment directory
  * Linux/Mac: 'rm -rf "$_PW_ACTUAL_ENVIRONMENT_ROOT"'
  * Windows: 'rmdir /S %_PW_ACTUAL_ENVIRONMENT_ROOT%'
* bootstrap
  * Linux/Mac: '. ./bootstrap.sh'
  * Windows: 'bootstrap.bat'
* update the constraint file
  * 'pw python-packages list {expected.name}'
"""
        )
        return -1

    return 0
295
296
297def parse(argv: list[str] | None = None) -> argparse.Namespace:
298    """Parse command-line arguments."""
299    parser = argparse.ArgumentParser(
300        prog="python -m pw_env_setup.python_packages"
301    )
302    subparsers = parser.add_subparsers(dest='cmd')
303
304    list_parser = subparsers.add_parser(
305        'list', aliases=('ls',), help='List installed package versions.'
306    )
307    list_parser.add_argument('output_file', type=Path, nargs='?')
308
309    diff_parser = subparsers.add_parser(
310        'diff',
311        help='Show differences between expected and actual package versions.',
312    )
313    diff_parser.add_argument('expected', type=Path)
314    diff_parser.add_argument(
315        '--ignore-requirements-file', type=Path, action='append'
316    )
317
318    return parser.parse_args(argv)
319
320
def main() -> int:
    """Run the subcommand selected on the command line.

    Returns:
        The subcommand's exit status, or -1 if the command is not recognized.
    """
    args = vars(parse())
    cmd = args.pop('cmd')
    if cmd == 'diff':
        return diff(**args)
    # argparse records the subcommand name exactly as typed, so the 'ls'
    # alias registered in parse() arrives here as 'ls', not 'list'.
    if cmd in ('list', 'ls'):
        return ls(**args)
    return -1
329
330
if __name__ == '__main__':
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
333