xref: /aosp_15_r20/external/pytorch/tools/nightly.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# Much of the logging code here was forked from https://github.com/ezyang/ghstack
3# Copyright (c) Edward Z. Yang <[email protected]>
4"""Checks out the nightly development version of PyTorch and installs pre-built
5binaries into the repo.
6
7You can use this script to check out a new nightly branch with the following::
8
9    $ ./tools/nightly.py checkout -b my-nightly-branch
10    $ conda activate pytorch-deps
11
12Or if you would like to re-use an existing conda environment, you can pass in
13the regular environment parameters (--name or --prefix)::
14
15    $ ./tools/nightly.py checkout -b my-nightly-branch -n my-env
16    $ conda activate my-env
17
18To install the nightly binaries built with CUDA, you can pass in the flag --cuda::
19
20    $ ./tools/nightly.py checkout -b my-nightly-branch --cuda
21    $ conda activate pytorch-deps
22
23You can also use this tool to pull the nightly commits into the current branch as
24well. This can be done with::
25
26    $ ./tools/nightly.py pull -n my-env
27    $ conda activate my-env
28
29Pulling will reinstall the conda dependencies as well as the nightly binaries into
30the repo directory.
31"""
32
33from __future__ import annotations
34
35import argparse
36import contextlib
37import functools
38import glob
39import itertools
40import json
41import logging
42import os
43import re
44import shutil
45import subprocess
46import sys
47import tempfile
48import time
49import uuid
50from ast import literal_eval
51from datetime import datetime
52from pathlib import Path
53from platform import system as platform_system
54from typing import Any, Callable, cast, Generator, Iterable, Iterator, Sequence, TypeVar
55
56
# Root of the repository checkout: two directory levels above this file.
REPO_ROOT = Path(__file__).absolute().parent.parent
# Remote used to fetch commits that are missing from the local clone.
GITHUB_REMOTE_URL = "https://github.com/pytorch/pytorch.git"
# Conda specs that are always requested; main() appends the CUDA/CPU specs.
SPECS_TO_INSTALL = ("pytorch", "mypy", "pytest", "hypothesis", "ipython", "sphinx")

# Logger read by the ``timed`` decorator; assigned in main() once logging is set up.
LOGGER: logging.Logger | None = None
# Template for a package URL, filled from one entry of the solve's LINK actions.
URL_FORMAT = "{base_url}/{platform}/{dist_name}.tar.bz2"
# strftime format used for the timestamp prefix of per-run log directory names.
DATETIME_FORMAT = "%Y-%m-%d_%Hh%Mm%Ss"
# Matches a 40-hex-digit git SHA-1 and captures it as the ``sha1`` group.
SHA1_RE = re.compile(r"(?P<sha1>[0-9a-fA-F]{40})")
# Matches the "://...@" credentials part of a URL so it can be masked in logs.
USERNAME_PASSWORD_RE = re.compile(r":\/\/(.*?)\@")
# Matches per-run log directory names of the form "<timestamp>_<uuid>".
LOG_DIRNAME_RE = re.compile(
    r"(?P<datetime>\d{4}-\d\d-\d\d_\d\dh\d\dm\d\ds)_"
    r"(?P<uuid>[0-9a-f]{8}-(?:[0-9a-f]{4}-){3}[0-9a-f]{12})",
)
70
71
class Formatter(logging.Formatter):
    """Log formatter that scrubs credentials and registered secrets.

    INFO and DEBUG records are rendered as the bare message; every other
    level goes through the configured format string.
    """

    # Maps each sensitive substring to the text that replaces it in log output.
    redactions: dict[str, str]

    def __init__(self, fmt: str | None = None, datefmt: str | None = None) -> None:
        super().__init__(fmt, datefmt)
        self.redactions = {}

    def _filter(self, s: str) -> str:
        """Return ``s`` with URL credentials and all registered needles masked."""
        scrubbed = USERNAME_PASSWORD_RE.sub(r"://<USERNAME>:<PASSWORD>@", s)
        for secret, placeholder in self.redactions.items():
            scrubbed = scrubbed.replace(secret, placeholder)
        return scrubbed

    def formatMessage(self, record: logging.LogRecord) -> str:
        """Render INFO/DEBUG without adornment; defer to the base class otherwise."""
        if record.levelno in (logging.INFO, logging.DEBUG):
            return record.getMessage()
        return super().formatMessage(record)

    def format(self, record: logging.LogRecord) -> str:
        """Format the record, then pass the finished text through ``_filter``."""
        rendered = super().format(record)
        return self._filter(rendered)

    def redact(self, needle: str, replace: str = "<REDACTED>") -> None:
        """Register ``needle`` to be replaced by ``replace`` in future output.

        Text that was already emitted is not redacted retroactively, so
        register secrets (e.g., authorization tokens) as early as possible.
        """
        # An empty needle would match between every character and produce
        # something like s<REDACTED>t<REDACTED>r<REDACTED>..., so ignore it.
        if needle:
            self.redactions[needle] = replace
108
109
def git(*args: str) -> list[str]:
    """Build a ``git`` command line that operates on the repository root."""
    command = ["git", "-C", str(REPO_ROOT)]
    command.extend(args)
    return command
112
113
@functools.lru_cache
def logging_base_dir() -> Path:
    """Return the directory that holds all run logs, creating it if needed."""
    log_root = REPO_ROOT.joinpath("nightly", "log")
    log_root.mkdir(parents=True, exist_ok=True)
    return log_root
119
120
@functools.lru_cache
def logging_run_dir() -> Path:
    """Return this run's log directory, creating it on first use.

    The directory name is the current timestamp followed by a UUID; the
    result is cached, so every caller in the process shares one directory.
    """
    parent = logging_base_dir()
    timestamp = datetime.now().strftime(DATETIME_FORMAT)
    run_dir = parent / f"{timestamp}_{uuid.uuid1()}"
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir
127
128
@functools.lru_cache
def logging_record_argv() -> None:
    """Write this invocation's command line to the run's ``argv`` file."""
    command_line = subprocess.list2cmdline(sys.argv)
    argv_file = logging_run_dir() / "argv"
    argv_file.write_text(command_line, encoding="utf-8")
133
134
def logging_record_exception(e: BaseException) -> None:
    """Write the class name of ``e`` to the run's ``exception`` file."""
    exception_file = logging_run_dir() / "exception"
    exception_file.write_text(type(e).__name__, encoding="utf-8")
137
138
def logging_rotate() -> None:
    """Delete run-log directories beyond the first 1000.

    Entries of the log base directory are sorted in descending order, which
    puts the most recently timestamped run directories first; everything
    after the first 1000 entries is removed if its name looks like a run log.
    """
    entries = sorted(logging_base_dir().iterdir(), reverse=True)
    for candidate in entries[1000:]:
        # Only delete things that are named like one of our log directories.
        if LOG_DIRNAME_RE.fullmatch(candidate.name) is None:
            continue
        shutil.rmtree(candidate)
146
147
@contextlib.contextmanager
def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, None]:
    """Setup logging. If a failure starts here we won't
    be able to save the user in a reasonable way.

    Logging structure: there is one logger (the "conda-pytorch" logger)
    and it processes all events.  There are two handlers:
    stderr (INFO, or DEBUG when ``debug`` is true) and file handler (DEBUG).

    Args:
        debug: When true, the stderr handler also shows DEBUG records.

    Yields:
        The configured logger.

    Any exception escaping the ``with`` body is logged through the yielded
    logger, its type name is recorded in the run directory, and the process
    exits with status 1.
    """
    formatter = Formatter(fmt="%(levelname)s: %(message)s", datefmt="")
    root_logger = logging.getLogger("conda-pytorch")
    root_logger.setLevel(logging.DEBUG)

    console_handler = logging.StreamHandler()
    if debug:
        console_handler.setLevel(logging.DEBUG)
    else:
        console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)
    root_logger.addHandler(console_handler)

    log_file = logging_run_dir() / "nightly.log"

    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(formatter)
    root_logger.addHandler(file_handler)
    logging_record_argv()

    try:
        logging_rotate()
        print(f"log file: {log_file}")
        yield root_logger
    except Exception as e:
        # Log through root_logger (not the ``logging`` module functions) so
        # the traceback reaches the log file and the redacting formatter.
        root_logger.exception("Fatal exception")
        logging_record_exception(e)
        print(f"log file: {log_file}")
        sys.exit(1)
    except BaseException as e:
        # You could log at DEBUG here to suppress the backtrace
        # entirely, but there is no reason to hide it from technically
        # savvy users.
        root_logger.info("", exc_info=True)
        logging_record_exception(e)
        print(f"log file: {log_file}")
        sys.exit(1)
193
194
def check_branch(subcommand: str, branch: str | None) -> str | None:
    """Validate that ``branch`` can be created by the ``checkout`` subcommand.

    Returns ``None`` when there is nothing to complain about (including every
    subcommand other than ``checkout``), otherwise an error message.
    """
    if subcommand != "checkout":
        return None
    # a branch name is mandatory for checkout
    if branch is None:
        return "Branch name to checkout must be supplied with '-b' option"
    # the working tree must not have modifications to tracked files
    status_cmd = git("status", "--untracked-files=no", "--porcelain")
    dirty = subprocess.check_output(status_cmd, text=True, encoding="utf-8")
    if dirty.strip():
        return "Need to have clean working tree to checkout!\n\n" + dirty
    # the branch must not exist yet; show-ref exits with 0 when it does
    ref_cmd = git("show-ref", "--verify", "--quiet", f"refs/heads/{branch}")
    ref_probe = subprocess.run(ref_cmd, capture_output=True, check=False)
    if ref_probe.returncode == 0:
        return f"Branch {branch!r} already exists"
    return None
213
214
@contextlib.contextmanager
def timer(logger: logging.Logger, prefix: str) -> Iterator[None]:
    """Context manager that logs how long its body took.

    The duration is logged at INFO level, labelled with ``prefix``, once the
    body finishes without raising.
    """
    began = time.perf_counter()
    yield
    elapsed = time.perf_counter() - began
    logger.info("%s took %.3f [s]", prefix, elapsed)
221
222
# Callable type variable used by ``timed`` so the decorator's annotation
# returns the same callable type it was given.
F = TypeVar("F", bound=Callable[..., Any])
224
225
def timed(prefix: str) -> Callable[[F], F]:
    """Build a decorator that announces ``prefix`` and times each call.

    The wrapped function logs ``prefix`` through the module-level ``LOGGER``
    before running and reports the elapsed time via ``timer`` afterwards.
    """

    def decorate(func: F) -> F:
        @functools.wraps(func)
        def timed_call(*args: Any, **kwargs: Any) -> Any:
            # LOGGER is looked up at call time, not at decoration time.
            active_logger = cast(logging.Logger, LOGGER)
            active_logger.info(prefix)
            with timer(active_logger, prefix):
                result = func(*args, **kwargs)
            return result

        return cast(F, timed_call)

    return decorate
240
241
242def _make_channel_args(
243    channels: Iterable[str] = ("pytorch-nightly",),
244    override_channels: bool = False,
245) -> list[str]:
246    args = []
247    for channel in channels:
248        args.extend(["--channel", channel])
249    if override_channels:
250        args.append("--override-channels")
251    return args
252
253
@timed("Solving conda environment")
def conda_solve(
    specs: Iterable[str],
    *,
    name: str | None = None,
    prefix: str | None = None,
    channels: Iterable[str] = ("pytorch-nightly",),
    override_channels: bool = False,
) -> tuple[list[str], str, str, bool, list[str]]:
    """Performs the conda solve and splits the deps from the package.

    Runs conda in ``--dry-run --json`` mode and reads the packages to link
    from the ``actions.LINK`` entry of its output.

    Args:
        specs: Conda package specs to solve for.
        name: Name of an existing environment to solve against.
        prefix: Path of an existing environment; takes precedence over ``name``.
        channels: Channels passed to conda via ``--channel``.
        override_channels: Whether to pass ``--override-channels``.

    Returns:
        ``(deps, pytorch, platform, existing_env, env_opts)``: the URLs of all
        non-pytorch packages, the URL of the pytorch package, its platform,
        whether an existing environment was selected, and the conda options
        that identify the environment.

    Raises:
        RuntimeError: If the solve contains no pytorch package or the pytorch
            package has an empty platform.
        subprocess.CalledProcessError: If conda exits with a non-zero status.
    """
    # compute what environment to use
    if prefix is not None:
        existing_env = True
        env_opts = ["--prefix", prefix]
    elif name is not None:
        existing_env = True
        env_opts = ["--name", name]
    else:
        # create new environment
        existing_env = False
        env_opts = ["--name", "pytorch-deps"]
    # run solve
    if existing_env:
        cmd = [
            "conda",
            "install",
            "--yes",
            "--dry-run",
            "--json",
        ]
        cmd.extend(env_opts)
    else:
        # The dry run solves against the placeholder name ``__pytorch__``;
        # the name of the environment to create is carried in env_opts.
        cmd = [
            "conda",
            "create",
            "--yes",
            "--dry-run",
            "--json",
            "--name",
            "__pytorch__",
        ]
    channel_args = _make_channel_args(
        channels=channels,
        override_channels=override_channels,
    )
    cmd.extend(channel_args)
    cmd.extend(specs)
    stdout = subprocess.check_output(cmd, text=True, encoding="utf-8")
    # parse solution
    solve = json.loads(stdout)
    link = solve["actions"]["LINK"]
    deps = []
    pytorch, platform = "", ""
    for pkg in link:
        url = URL_FORMAT.format(**pkg)
        if pkg["name"] == "pytorch":
            pytorch = url
            platform = pkg["platform"]
        else:
            deps.append(url)
    # Raise explicitly instead of asserting: asserts are stripped under -O.
    if not pytorch:
        raise RuntimeError("PyTorch package not found in solve")
    if not platform:
        raise RuntimeError("Platform not found in solve")
    return deps, pytorch, platform, existing_env, env_opts
317
318
@timed("Installing dependencies")
def deps_install(deps: list[str], existing_env: bool, env_opts: list[str]) -> None:
    """Install ``deps`` into the environment selected by ``env_opts``.

    For an existing environment the packages are installed in place.
    Otherwise the environment is removed first and then created from the
    given packages.
    """
    if existing_env:
        action = "install"
    else:
        # start from scratch: drop the previous pytorch-deps env
        subprocess.check_call(["conda", "env", "remove", "--yes", *env_opts])
        action = "create"
    subprocess.check_call(["conda", action, "--yes", "--no-deps", *env_opts, *deps])
330
331
@timed("Installing pytorch nightly binaries")
def pytorch_install(url: str) -> tempfile.TemporaryDirectory[str]:
    """Create a throwaway conda prefix holding only the package at ``url``.

    Returns the ``TemporaryDirectory`` object that owns the prefix; the
    caller decides when it gets cleaned up.
    """
    scratch = tempfile.TemporaryDirectory(prefix="conda-pytorch-")
    subprocess.check_call(
        ["conda", "create", "--yes", "--no-deps", "--prefix=" + scratch.name, url]
    )
    return scratch
339
340
341def _site_packages(dirname: str, platform: str) -> Path:
342    if platform.startswith("win"):
343        template = os.path.join(dirname, "Lib", "site-packages")
344    else:
345        template = os.path.join(dirname, "lib", "python*.*", "site-packages")
346    return Path(next(glob.iglob(template))).absolute()
347
348
def _ensure_commit(git_sha1: str) -> None:
    """Fetch ``git_sha1`` from the GitHub remote unless it is already local."""
    probe = subprocess.run(
        git("cat-file", "-e", f"{git_sha1}^{{commit}}"),
        capture_output=True,
        check=False,
    )
    if probe.returncode != 0:
        # the commit is not available locally, so fetch it
        subprocess.check_call(git("fetch", GITHUB_REMOTE_URL, git_sha1))
359
360
def _nightly_version(site_dir: Path) -> str:
    """Resolve the commit hash referenced by the installed nightly.

    Reads ``git_version`` from ``torch/version.py`` under ``site_dir``, then
    takes the first 40-digit hash found in that commit's subject line. Both
    commits are fetched if they are not available locally.

    Raises:
        RuntimeError: If ``git_version`` is missing from the version file or
            the commit subject contains no hash.
    """
    # read the git version recorded in the installed module
    version_file = site_dir / "torch" / "version.py"
    with version_file.open(encoding="utf-8") as handle:
        assignment = next(
            (text for text in handle if text.startswith("git_version")), None
        )
    if assignment is None:
        raise RuntimeError(f"Could not find git_version in {version_file}")
    git_version = literal_eval(assignment.partition("=")[2].strip())

    print(f"Found released git version {git_version}")
    # cross reference with the nightly version named in the commit subject
    _ensure_commit(git_version)
    subject = subprocess.check_output(
        git("show", "--no-patch", "--format=%s", git_version),
        text=True,
        encoding="utf-8",
    )
    found = SHA1_RE.search(subject)
    if found is None:
        raise RuntimeError(
            f"Could not find nightly release in git history:\n  {subject}"
        )
    nightly_version = found.group("sha1")
    print(f"Found nightly release version {nightly_version}")
    # make sure the nightly commit itself is available locally
    _ensure_commit(nightly_version)
    return nightly_version
388
389
@timed("Checking out nightly PyTorch")
def checkout_nightly_version(branch: str, site_dir: Path) -> None:
    """Create ``branch`` at the nightly commit and check it out."""
    target_commit = _nightly_version(site_dir)
    subprocess.check_call(git("checkout", "-b", branch, target_commit))
396
397
@timed("Pulling nightly PyTorch")
def pull_nightly_version(site_dir: Path) -> None:
    """Merge the nightly commit into the current branch."""
    target_commit = _nightly_version(site_dir)
    subprocess.check_call(git("merge", target_commit))
404
405
406def _get_listing_linux(source_dir: Path) -> list[Path]:
407    return list(
408        itertools.chain(
409            source_dir.glob("*.so"),
410            (source_dir / "lib").glob("*.so"),
411            (source_dir / "lib").glob("*.so.*"),
412        )
413    )
414
415
416def _get_listing_osx(source_dir: Path) -> list[Path]:
417    # oddly, these are .so files even on Mac
418    return list(
419        itertools.chain(
420            source_dir.glob("*.so"),
421            (source_dir / "lib").glob("*.dylib"),
422        )
423    )
424
425
426def _get_listing_win(source_dir: Path) -> list[Path]:
427    return list(
428        itertools.chain(
429            source_dir.glob("*.pyd"),
430            (source_dir / "lib").glob("*.lib"),
431            (source_dir / "lib").glob(".dll"),
432        )
433    )
434
435
436def _glob_pyis(d: Path) -> set[str]:
437    return {p.relative_to(d).as_posix() for p in d.rglob("*.pyi")}
438
439
def _find_missing_pyi(source_dir: Path, target_dir: Path) -> list[Path]:
    """Return, sorted, the ``.pyi`` files under ``source_dir`` that have no
    counterpart at the same relative path under ``target_dir``.
    """
    absent = _glob_pyis(source_dir).difference(_glob_pyis(target_dir))
    return sorted(source_dir / relative for relative in absent)
445
446
def _get_listing(source_dir: Path, target_dir: Path, platform: str) -> list[Path]:
    """Collect every path that has to be moved from ``source_dir``.

    The result holds the platform-specific binaries, the ``.pyi`` files
    missing from ``target_dir``, and a fixed set of extra files and
    directories.

    Raises:
        RuntimeError: If ``platform`` starts with none of ``linux``, ``osx``
            or ``win``.
    """
    finders = (
        ("linux", _get_listing_linux),
        ("osx", _get_listing_osx),
        ("win", _get_listing_win),
    )
    for tag, finder in finders:
        if platform.startswith(tag):
            listing = finder(source_dir)
            break
    else:
        raise RuntimeError(f"Platform {platform!r} not recognized")
    listing.extend(_find_missing_pyi(source_dir, target_dir))
    listing.extend(
        [
            source_dir / "version.py",
            source_dir / "testing" / "_internal" / "generated",
            source_dir / "bin",
            source_dir / "include",
        ]
    )
    return listing
462
463
464def _remove_existing(path: Path) -> None:
465    if path.exists():
466        if path.is_dir():
467            shutil.rmtree(path)
468        else:
469            path.unlink()
470
471
def _move_single(
    src: Path,
    source_dir: Path,
    target_dir: Path,
    mover: Callable[[Path, Path], None],
    verb: str,
) -> None:
    """Transfer one file or directory tree from ``source_dir`` to ``target_dir``.

    Args:
        src: File or directory to transfer; must be located under ``source_dir``.
        source_dir: Root that ``src`` is made relative to.
        target_dir: Root under which the same relative path is written.
        mover: Called as ``mover(source_file, target_file)`` for every file.
        verb: Word printed in front of each ``source -> target`` line.

    Anything already present at the target location is removed first.
    Directories are recreated at the target and their files transferred one
    by one.
    """
    relpath = src.relative_to(source_dir)
    trg = target_dir / relpath
    _remove_existing(trg)
    # move over new files
    if src.is_dir():
        trg.mkdir(parents=True, exist_ok=True)
        for root, dirs, files in os.walk(src):
            relroot = Path(root).relative_to(src)
            for name in files:
                relname = relroot / name
                s = src / relname
                t = trg / relname
                print(f"{verb} {s} -> {t}")
                mover(s, t)
            # Create the subdirectories before os.walk descends into them.
            for name in dirs:
                (trg / relroot / name).mkdir(parents=True, exist_ok=True)
    else:
        # The target's parent may be absent from the target tree; create it
        # so that linking or copying the file cannot fail for that reason.
        trg.parent.mkdir(parents=True, exist_ok=True)
        print(f"{verb} {src} -> {trg}")
        mover(src, trg)
498
499
def _copy_files(listing: list[Path], source_dir: Path, target_dir: Path) -> None:
    """Copy every listed path to its matching location under ``target_dir``."""
    for entry in listing:
        _move_single(
            entry, source_dir, target_dir, mover=shutil.copy2, verb="Copying"
        )
503
504
def _link_files(listing: list[Path], source_dir: Path, target_dir: Path) -> None:
    """Hard-link every listed path to its matching location under ``target_dir``."""
    for entry in listing:
        _move_single(entry, source_dir, target_dir, mover=os.link, verb="Linking")
508
509
@timed("Moving nightly files into repo")
def move_nightly_files(site_dir: Path, platform: str) -> None:
    """Transfer the PyTorch files from the temporary install into the repo.

    On platforms starting with ``win`` the files are copied. Elsewhere they
    are hard-linked, and the whole listing is copied instead if linking
    raises.
    """
    source_dir = site_dir / "torch"
    target_dir = REPO_ROOT / "torch"
    listing = _get_listing(source_dir, target_dir, platform)
    if platform.startswith("win"):
        _copy_files(listing, source_dir, target_dir)
        return
    try:
        _link_files(listing, source_dir, target_dir)
    except Exception:
        # linking failed part-way or entirely; redo everything as copies
        _copy_files(listing, source_dir, target_dir)
525
526
def _available_envs() -> dict[str, str]:
    """Map conda environment names to the last field of their listing line.

    Parses the output of ``conda env list``; blank lines, comment lines and
    lines with a single field (unnamed environments) are skipped.
    """
    listing = subprocess.check_output(
        ["conda", "env", "list"], text=True, encoding="utf-8"
    )
    envs = {}
    for raw in listing.splitlines():
        entry = raw.strip()
        if not entry or entry.startswith("#"):
            continue
        fields = entry.split()
        # a single field means an unnamed env, which has no name to key on
        if len(fields) > 1:
            envs[fields[0]] = fields[-1]
    return envs
540
541
@timed("Writing pytorch-nightly.pth")
def write_pth(env_opts: list[str], platform: str) -> None:
    """Write a ``pytorch-nightly.pth`` file pointing at the repository root.

    ``env_opts`` is an ``[option, value]`` pair; for ``--name`` the value is
    translated to a directory via ``_available_envs``, otherwise it is used
    as the environment directory directly.
    """
    env_type, env_dir = env_opts
    if env_type == "--name":
        # only the name is known, so look up the directory
        env_dir = _available_envs()[env_dir]
    pth_file = _site_packages(env_dir, platform) / "pytorch-nightly.pth"
    contents = (
        "# This file was autogenerated by PyTorch's tools/nightly.py\n"
        "# Please delete this file if you no longer need the following development\n"
        "# version of PyTorch to be importable\n"
        f"{REPO_ROOT}\n"
    )
    pth_file.write_text(contents, encoding="utf-8")
558
559
def install(
    specs: Iterable[str],
    *,
    logger: logging.Logger,
    subcommand: str = "checkout",
    branch: str | None = None,
    name: str | None = None,
    prefix: str | None = None,
    channels: Iterable[str] = ("pytorch-nightly",),
    override_channels: bool = False,
) -> None:
    """Set up a development install of PyTorch.

    Solves the conda environment, installs the dependencies, installs the
    nightly binaries into a temporary prefix, checks out or merges the
    matching commit, moves the binaries into the repo and finally writes the
    ``.pth`` file into the environment.

    Raises:
        ValueError: If ``subcommand`` is neither ``checkout`` nor ``pull``;
            this is only detected after the nightly binaries are installed.
    """
    deps, pytorch, platform, existing_env, env_opts = conda_solve(
        specs=list(specs),
        name=name,
        prefix=prefix,
        channels=channels,
        override_channels=override_channels,
    )
    if deps:
        deps_install(deps, existing_env, env_opts)

    with pytorch_install(pytorch) as pytorch_dir:
        site_dir = _site_packages(pytorch_dir, platform)
        if subcommand == "pull":
            pull_nightly_version(site_dir)
        elif subcommand == "checkout":
            checkout_nightly_version(cast(str, branch), site_dir)
        else:
            raise ValueError(f"Subcommand {subcommand} must be one of: checkout, pull.")
        move_nightly_files(site_dir, platform)

    write_pth(env_opts, platform)
    success_message = (
        "-------\nPyTorch Development Environment set up!\n"
        "Please activate to enable this environment:\n"
        "  $ conda activate %s"
    )
    logger.info(success_message, env_opts[1])
599
600
def make_parser() -> argparse.ArgumentParser:
    """Build the command-line parser with the ``checkout`` and ``pull`` subcommands.

    Both subcommands share the environment, verbosity and channel options;
    ``checkout`` additionally takes ``-b/--branch``. ``--cuda`` is only
    offered when running on Linux or Windows. A subcommand is mandatory, so
    an invocation without one ends with a usage error.
    """
    p = argparse.ArgumentParser()
    # subcommands; required so that a bare invocation is rejected by argparse
    # instead of producing a namespace without the per-subcommand options
    subcmd = p.add_subparsers(
        dest="subcmd", help="subcommand to execute", required=True
    )
    checkout = subcmd.add_parser("checkout", help="checkout a new branch")
    checkout.add_argument(
        "-b",
        "--branch",
        help="Branch name to checkout",
        dest="branch",
        default=None,
        metavar="NAME",
    )
    pull = subcmd.add_parser(
        "pull", help="pulls the nightly commits into the current branch"
    )
    # general arguments
    subparsers = [checkout, pull]
    for subparser in subparsers:
        subparser.add_argument(
            "-n",
            "--name",
            help="Name of environment",
            dest="name",
            default=None,
            metavar="ENVIRONMENT",
        )
        subparser.add_argument(
            "-p",
            "--prefix",
            help="Full path to environment location (i.e. prefix)",
            dest="prefix",
            default=None,
            metavar="PATH",
        )
        subparser.add_argument(
            "-v",
            "--verbose",
            help="Provide debugging info",
            dest="verbose",
            default=False,
            action="store_true",
        )
        subparser.add_argument(
            "--override-channels",
            help="Do not search default or .condarc channels.",
            dest="override_channels",
            default=False,
            action="store_true",
        )
        subparser.add_argument(
            "-c",
            "--channel",
            help=(
                "Additional channel to search for packages. "
                "'pytorch-nightly' will always be prepended to this list."
            ),
            dest="channels",
            action="append",
            metavar="CHANNEL",
        )
        if platform_system() in {"Linux", "Windows"}:
            # SUPPRESS keeps ``cuda`` off the namespace unless the flag is
            # given; a bare ``--cuda`` stores None.
            subparser.add_argument(
                "--cuda",
                help=(
                    "CUDA version to install "
                    "(defaults to the latest version available on the platform)"
                ),
                dest="cuda",
                nargs="?",
                default=argparse.SUPPRESS,
                metavar="VERSION",
            )
    return p
675
676
def main(args: Sequence[str] | None = None) -> None:
    """Parse the command line and run the requested subcommand.

    Exits early with an error message if ``check_branch`` reports a problem.
    Publishes the configured logger through the module-level ``LOGGER``
    before starting the install.
    """
    global LOGGER
    ns = make_parser().parse_args(args)
    # only the checkout subcommand defines --branch
    ns.branch = getattr(ns, "branch", None)
    status = check_branch(ns.subcmd, ns.branch)
    if status:
        sys.exit(status)
    specs = [*SPECS_TO_INSTALL]
    channels = ["pytorch-nightly"]
    if not hasattr(ns, "cuda"):
        specs.append("pytorch-mutex=*=*cpu*")
    else:
        # --cuda without a value leaves the version up to the solver
        cuda_spec = "pytorch-cuda" if ns.cuda is None else f"pytorch-cuda={ns.cuda}"
        specs += [cuda_spec, "pytorch-mutex=*=*cuda*"]
        channels.append("nvidia")
    if ns.channels:
        channels.extend(ns.channels)
    with logging_manager(debug=ns.verbose) as logger:
        LOGGER = logger
        install(
            specs=specs,
            subcommand=ns.subcmd,
            branch=ns.branch,
            name=ns.name,
            prefix=ns.prefix,
            logger=logger,
            channels=channels,
            override_channels=ns.override_channels,
        )
711
712
# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()
715