xref: /aosp_15_r20/external/pigweed/pw_env_setup/py/pw_env_setup/cipd_setup/update.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1#!/usr/bin/env python
2# Copyright 2020 The Pigweed Authors
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may not
5# use this file except in compliance with the License. You may obtain a copy of
6# the License at
7#
8#     https://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations under
14# the License.
15"""Installs or updates prebuilt tools.
16
17Must be tested with Python 2 and Python 3.
18
19The stdout of this script is meant to be executed by the invoking shell.
20"""
21
22import collections
23import hashlib
24import json
25import os
26import platform as platform_module
27import re
28import subprocess
29import sys
30
31
def _stderr(*args):
    """Print the given values to standard error, exactly as print() would."""
    stream = sys.stderr
    return print(*args, file=stream)
34
35
def check_auth(cipd, package_files, cipd_service_account, spin):
    """Check that the CIPD paths named in the package files are accessible.

    Only the first three 'packages' entries of each package file are probed,
    and each distinct path prefix is probed once. If some paths are
    inaccessible and 'cipd auth-info' reported no login, an interactive
    'cipd auth-login' is attempted and the paths are probed again.

    Args:
      cipd: The cipd executable to invoke.
      package_files: Paths of the JSON package files to sample.
      cipd_service_account: Value passed to cipd's -service-account-json
        option, or a falsy value to omit the option.
      spin: Object whose pause() method returns a context manager that is
        held while the login is attempted, or None to attempt the login
        directly.

    Returns:
      True if every probed path is accessible. False if the login attempt
      failed or some paths are still inaccessible.
    """
    cmd = [cipd]
    extra_args = []
    if cipd_service_account:
        extra_args.extend(['-service-account-json', cipd_service_account])

    paths = []
    for package_file in package_files:
        with open(package_file, 'r') as ins:
            # This is an expensive RPC, so only check the first few entries
            # in each file.
            for i, entry in enumerate(json.load(ins).get('packages', ())):
                if i >= 3:
                    break
                parts = entry['path'].split('/')
                # Strip trailing templated components (those containing
                # '${'). The 'parts' guard stops the loop from indexing an
                # empty list when every component is templated.
                while parts and '${' in parts[-1]:
                    parts.pop()
                path = '/'.join(parts)
                # Nothing to probe for an empty prefix, and no reason to
                # repeat the expensive RPC for a prefix already queued.
                if path and path not in paths:
                    paths.append(path)

    username = None
    try:
        output = subprocess.check_output(
            cmd + ['auth-info'] + extra_args, stderr=subprocess.STDOUT
        ).decode()
        logged_in = True

        match = re.search(r'Logged in as (\S*)\.', output)
        if match:
            username = match.group(1)

    except subprocess.CalledProcessError:
        logged_in = False

    def _check_all_paths():
        """Return the subset of paths that cipd cannot see."""
        inaccessible = []

        for path in paths:
            # Not catching CalledProcessError because 'cipd ls' seems to never
            # return an error code unless it can't reach the CIPD server.
            output = subprocess.check_output(
                cmd + ['ls', path] + extra_args, stderr=subprocess.STDOUT
            ).decode()
            if 'No matching packages' not in output:
                continue

            # 'cipd ls' only lists sub-packages but ignores any packages at the
            # given path. 'cipd instances' will give versions of that package.
            # 'cipd instances' does use an error code if there's no such package
            # or that package is inaccessible.
            try:
                subprocess.check_output(
                    cmd + ['instances', path] + extra_args,
                    stderr=subprocess.STDOUT,
                )
            except subprocess.CalledProcessError:
                inaccessible.append(path)

        return inaccessible

    def _attempt_login(missing_paths):
        """Report the inaccessible paths and run 'cipd auth-login'."""
        _stderr()
        _stderr(
            'Not logged in to CIPD and no anonymous access to the '
            'following CIPD paths:'
        )
        for path in missing_paths:
            _stderr('  {}'.format(path))
        _stderr()
        _stderr('Attempting CIPD login')
        try:
            # Note that with -service-account-json, auth-login is a no-op.
            subprocess.check_call(cmd + ['auth-login'] + extra_args)
        except subprocess.CalledProcessError:
            _stderr('CIPD login failed')
            return False
        return True

    inaccessible_paths = _check_all_paths()

    if inaccessible_paths and not logged_in:
        # Callers may pass spin=None; only pause when there is something to
        # pause.
        if spin is None:
            login_ok = _attempt_login(inaccessible_paths)
        else:
            with spin.pause():
                login_ok = _attempt_login(inaccessible_paths)
        if not login_ok:
            return False

        inaccessible_paths = _check_all_paths()

    if inaccessible_paths:
        _stderr('=' * 60)
        username_part = ''
        if username:
            username_part = '({}) '.format(username)
        _stderr(
            'Your account {}does not have access to the following '
            'paths'.format(username_part)
        )
        _stderr('(or they do not exist)')
        for path in inaccessible_paths:
            _stderr('  {}'.format(path))
        _stderr('=' * 60)
        return False

    return True
134
135
def platform(rosetta=False):
    """Return the CIPD platform string ('<os>-<arch>') of this system.

    Args:
      rosetta: If true, report 'mac-amd64' in place of 'mac-arm64'.

    Raises:
      KeyError: The lowercased platform_module.system() value is not
        'darwin', 'linux' or 'windows'.
    """
    os_names = {
        'darwin': 'mac',
        'linux': 'linux',
        'windows': 'windows',
    }
    osname = os_names[platform_module.system().lower()]

    machine = platform_module.machine()
    if machine.startswith(('aarch64', 'armv8')):
        arch = 'arm64'
    else:
        # Any machine name without a CIPD-specific spelling is used as-is.
        arch = {'x86_64': 'amd64', 'i686': 'i386'}.get(machine, machine)

    result = '{}-{}'.format(osname, arch).lower()

    # Support `mac-arm64` through Rosetta until `mac-arm64` binaries are ready
    if rosetta and result == 'mac-arm64':
        return 'mac-amd64'
    return result
160
161
def all_package_files(env_vars, package_files):
    """Expand package files into a flat list that follows their includes.

    Each starting file is joined onto the 'PW_PROJECT_ROOT' value of env_vars
    when env_vars is truthy, and used as given otherwise. Files are then
    walked depth-first: a file is listed, then every 'included_files' entry
    it names (resolved relative to that file's directory) that has not been
    listed yet is walked in turn.

    The "not listed yet" filter runs when a file is read, before any of its
    includes are walked, so a file named by a parent and also pulled in by one
    of that parent's earlier includes appears in the result more than once.

    Args:
      env_vars: Object with a get() method, or a falsy value.
      package_files: Paths of the starting package files.

    Returns:
      List of package file paths in the order they were visited.
    """
    # Resolve every starting path before any file is opened.
    starting_points = []
    for pkg_file in package_files:
        if env_vars:
            resolved = os.path.join(env_vars.get('PW_PROJECT_ROOT'), pkg_file)
        else:
            resolved = os.path.join(pkg_file)
        starting_points.append(resolved)

    visited = []

    def _walk(path):
        """Record path, then walk the includes that are not yet recorded."""
        visited.append(path)

        with open(path, 'r') as ins:
            included = json.load(ins).get('included_files', ())

        base_dir = os.path.dirname(path)
        candidates = [os.path.join(base_dir, name) for name in included]
        pending = [item for item in candidates if item not in visited]

        for item in pending:
            _walk(item)

    for start in starting_points:
        _walk(start)

    return visited
203
204
def update_subdir(package, package_file):
    """Prefix the package's 'subdir' with the package file's name, in place.

    When the package already has a 'subdir', its value is saved under
    'original_subdir' and replaced by '<name>/<subdir>'. Otherwise 'subdir'
    is set to '<name>' alone. <name> is package_file_name(package_file).
    """
    prefix = package_file_name(package_file)

    if 'subdir' not in package:
        package['subdir'] = prefix
        return

    existing = package['subdir']
    package['original_subdir'] = existing
    package['subdir'] = '/'.join([prefix, existing])
213
214
def all_packages(package_files):
    """Load the 'packages' entries of every package file into one list.

    Each entry is passed through update_subdir() with the file it came from.
    Entries are returned in file order and no duplicates are removed.
    """
    collected = []
    for path in package_files:
        with open(path, 'r') as ins:
            entries = json.load(ins).get('packages', ())
            for entry in entries:
                update_subdir(entry, path)
            collected += entries
    return collected
224
225
def deduplicate_packages(packages):
    """Remove duplicate package entries, keeping the last occurrence of each.

    Two entries are duplicates when they share both 'path' and
    'original_subdir' (a missing 'original_subdir' counts as ''). The
    surviving entry sits where its last occurrence was, relative to the other
    survivors, so later definitions override earlier ones.

    Args:
      packages: Iterable of package dicts, each with a 'path' key.

    Returns:
      A new list of the surviving package dicts.
    """
    deduped = collections.OrderedDict()
    for package in packages:
        # Use the package + the subdir as the key. A tuple keeps the two parts
        # separate, so ('foo', 'bar') cannot collide with ('foobar', '') the
        # way plain string concatenation would.
        pkg_key = (package['path'], package.get('original_subdir', ''))

        # Drop any earlier entry so that re-inserting puts the key at the end.
        deduped.pop(pkg_key, None)
        deduped[pkg_key] = package
    return list(deduped.values())
240
241
def write_ensure_file(
    package_files, ensure_file, platform
):  # pylint: disable=redefined-outer-name
    """Write the CIPD ensure file for the given package files.

    Also writes 'all-packages.json' (every package entry) and
    'deduped-packages.json' (entries after deduplication) into the directory
    that contains ensure_file.

    Args:
      package_files: Paths of the package files to read.
      ensure_file: Path of the ensure file to write.
      platform: Platform string; packages that have a 'platforms' list not
        containing it are left out of the ensure file.
    """
    log_dir = os.path.dirname(ensure_file)

    def _dump_json(filename, payload):
        with open(os.path.join(log_dir, filename), 'w') as outs:
            json.dump(payload, outs, indent=4)

    every_package = all_packages(package_files)
    _dump_json('all-packages.json', every_package)

    unique_packages = deduplicate_packages(every_package)
    _dump_json('deduped-packages.json', unique_packages)

    header = (
        '$VerifiedPlatform linux-amd64\n'
        '$VerifiedPlatform mac-amd64\n'
        '$ParanoidMode CheckPresence\n'
    )

    with open(ensure_file, 'w') as outs:
        outs.write(header)

        for pkg in unique_packages:
            # If this is a new-style package manifest platform handling must
            # be done here instead of by the cipd executable.
            wanted = 'platforms' not in pkg or platform in pkg['platforms']
            if wanted:
                subdir_line = '@Subdir {}\n'.format(pkg.get('subdir', ''))
                outs.write(subdir_line)
                package_line = '{} {}\n'.format(
                    pkg['path'], ' '.join(pkg['tags'])
                )
                outs.write(package_line)
268
269
def package_file_name(package_file):
    """Return the base name of package_file with its last extension removed."""
    stem, _extension = os.path.splitext(package_file)
    return os.path.basename(stem)
272
273
def package_installation_path(root_install_dir, package_file):
    """Return the directory a package file's packages are installed into.

    The result is '<root_install_dir>/packages/<name>', where <name> is
    package_file_name(package_file).

    Args:
      root_install_dir: The CIPD installation directory.
      package_file: The path to the .json package definition file.
    """
    leaf = package_file_name(package_file)
    return os.path.join(root_install_dir, 'packages', leaf)
284
285
def _ensure_digest(cmd, ensure_file):
    """Return the SHA-256 hex digest of the cipd command and ensure file."""
    hasher = hashlib.sha256()
    encoded = '\0'.join(cmd)
    if hasattr(encoded, 'encode'):
        encoded = encoded.encode()
    hasher.update(encoded)
    with open(ensure_file, 'rb') as ins:
        hasher.update(ins.read())
    return hasher.hexdigest()


def _stored_digest_matches(root_install_dir, hash_file, digest, trust_hash):
    """Write hash.log and return whether the stored digest can be trusted."""
    with open(os.path.join(root_install_dir, 'hash.log'), 'w') as hashlog:
        print('calculated digest:', digest, file=hashlog)
        print('hash file path:', hash_file, file=hashlog)
        print('exists:', os.path.isfile(hash_file), file=hashlog)
        print('trust_hash:', trust_hash, file=hashlog)
        if trust_hash and os.path.isfile(hash_file):
            with open(hash_file, 'r') as ins:
                digest_file = ins.read().strip()
            print('contents:', digest_file, file=hashlog)
            print('equal:', digest == digest_file, file=hashlog)
            return digest == digest_file
    return False


def _export_package_env(env_vars, package_files, install_dir, plat):
    """Add each package file's install and bin directories to env_vars."""
    for package_file in reversed(package_files):
        name = package_file_name(package_file)
        file_install_dir = os.path.join(install_dir, name)

        # The MinGW package isn't always structured correctly, and might
        # live nested in a `mingw64` subdirectory.
        maybe_mingw = os.path.join(file_install_dir, 'mingw64', 'bin')
        if os.name == 'nt' and os.path.isdir(maybe_mingw):
            env_vars.prepend('PATH', maybe_mingw)

        # If this package file has no packages and just includes one other
        # file, there won't be any contents of the folder for this package.
        # In that case, point the variable that would point to this folder
        # to the folder of the included file.
        with open(package_file) as ins:
            contents = json.load(ins)
        entries = contents.get('included_files', ())
        file_packages = contents.get('packages', ())
        if not file_packages and len(entries) == 1:
            file_install_dir = os.path.join(
                install_dir,
                package_file_name(os.path.basename(entries[0])),
            )

        # Some executables get installed at top-level and some get
        # installed under 'bin'. A small number of old packages prefix the
        # entire tree with the platform (e.g., chromium/third_party/tcl).
        for bin_dir in (
            file_install_dir,
            os.path.join(file_install_dir, 'bin'),
            os.path.join(file_install_dir, plat, 'bin'),
        ):
            if os.path.isdir(bin_dir):
                env_vars.prepend('PATH', bin_dir)
        env_vars.set(
            'PW_{}_CIPD_INSTALL_DIR'.format(name.upper().replace('-', '_')),
            file_install_dir,
        )


def update(  # pylint: disable=too-many-locals
    cipd,
    package_files,
    root_install_dir,
    cache_dir,
    rosetta=False,
    env_vars=None,
    spin=None,
    trust_hash=False,
):
    """Install or update the CIPD packages listed in the package files.

    Writes these files into root_install_dir: '_all_package_files.json',
    'packages.ensure' (plus the JSON dumps produced by write_ensure_file()),
    'hash.log', 'packages.log' and, after a successful 'cipd ensure',
    'packages.sha256'. 'packages.json' is passed to cipd as its -json-output
    path. Packages are installed under '<root_install_dir>/packages'.

    Args:
      cipd: The cipd executable to invoke.
      package_files: Package file paths, expanded by all_package_files().
      root_install_dir: Directory for the files above; created if missing.
      cache_dir: Passed to cipd as -cache-dir and exported as CIPD_CACHE_DIR,
        or a falsy value to omit both.
      rosetta: Forwarded to platform().
      env_vars: Optional object with get(), set() and prepend() methods that
        receives PATH entries and install-directory variables.
      spin: Forwarded to check_auth().
      trust_hash: If true and 'packages.sha256' holds the digest computed for
        this run, return early without invoking cipd. The per-package
        variables are not added to env_vars in that case.

    Returns:
      True if the packages were installed or the stored hash was trusted,
      False if check_auth() failed.

    Raises:
      subprocess.CalledProcessError: 'cipd ensure' failed; the contents of
        'packages.log' are written to stderr first.
    """

    package_files = all_package_files(env_vars, package_files)

    # Create the directory without a check-then-create race: attempt the
    # creation and only re-raise if the directory still does not exist.
    try:
        os.makedirs(root_install_dir)
    except OSError:
        if not os.path.isdir(root_install_dir):
            raise

    # This file is read by 'pw doctor' which needs to know which package files
    # were used in the environment.
    package_files_file = os.path.join(
        root_install_dir, '_all_package_files.json'
    )
    with open(package_files_file, 'w') as outs:
        json.dump(package_files, outs, indent=2)

    if env_vars:
        env_vars.prepend('PATH', root_install_dir)
        env_vars.set('PW_CIPD_INSTALL_DIR', root_install_dir)
        if cache_dir:
            env_vars.set('CIPD_CACHE_DIR', cache_dir)

    plat = platform(rosetta)

    ensure_file = os.path.join(root_install_dir, 'packages.ensure')
    write_ensure_file(package_files, ensure_file, plat)

    install_dir = os.path.join(root_install_dir, 'packages')

    cmd = [
        cipd,
        'ensure',
        '-ensure-file',
        ensure_file,
        '-root',
        install_dir,
        '-log-level',
        'debug',
        '-json-output',
        os.path.join(root_install_dir, 'packages.json'),
        '-max-threads',
        '0',  # 0 means use CPU count.
    ]

    if cache_dir:
        cmd.extend(('-cache-dir', cache_dir))

    cipd_service_account = None
    if env_vars:
        cipd_service_account = env_vars.get('PW_CIPD_SERVICE_ACCOUNT_JSON')
    if not cipd_service_account:
        cipd_service_account = os.environ.get('PW_CIPD_SERVICE_ACCOUNT_JSON')
    if cipd_service_account:
        cmd.extend(['-service-account-json', cipd_service_account])

    digest = _ensure_digest(cmd, ensure_file)
    hash_file = os.path.join(root_install_dir, 'packages.sha256')
    if _stored_digest_matches(root_install_dir, hash_file, digest, trust_hash):
        return True

    if not check_auth(cipd, package_files, cipd_service_account, spin):
        return False

    log = os.path.join(root_install_dir, 'packages.log')
    try:
        with open(log, 'w') as outs:
            print(*cmd, file=outs)
            # The child writes straight to the file descriptor, so push the
            # buffered command line out first to keep it at the top of the
            # log.
            outs.flush()
            subprocess.check_call(cmd, stdout=outs, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError:
        with open(log, 'r') as ins:
            sys.stderr.write(ins.read())
        raise

    with open(hash_file, 'w') as outs:
        print(digest, file=outs)

    # Set environment variables so tools can later find things under, for
    # example, 'share'.
    if env_vars:
        _export_package_env(env_vars, package_files, install_dir, plat)

    return True
441