#!/usr/bin/env python3
# Much of the logging code here was forked from https://github.com/ezyang/ghstack
# Copyright (c) Edward Z. Yang <[email protected]>
"""Checks out the nightly development version of PyTorch and installs pre-built
binaries into the repo.

You can use this script to check out a new nightly branch with the following::

    $ ./tools/nightly.py checkout -b my-nightly-branch
    $ conda activate pytorch-deps

Or if you would like to re-use an existing conda environment, you can pass in
the regular environment parameters (--name or --prefix)::

    $ ./tools/nightly.py checkout -b my-nightly-branch -n my-env
    $ conda activate my-env

To install the nightly binaries built with CUDA, you can pass in the flag --cuda
(only available on Linux and Windows)::

    $ ./tools/nightly.py checkout -b my-nightly-branch --cuda
    $ conda activate pytorch-deps

You can also use this tool to pull the nightly commits into the current branch.
This can be done with::

    $ ./tools/nightly.py pull -n my-env
    $ conda activate my-env

Pulling will reinstall the conda dependencies as well as the nightly binaries into
the repo directory.
"""

from __future__ import annotations

import argparse
import contextlib
import functools
import glob
import itertools
import json
import logging
import os
import re
import shutil
import subprocess
import sys
import tempfile
import time
import uuid
from ast import literal_eval
from datetime import datetime
from pathlib import Path
from platform import system as platform_system
from typing import Any, Callable, cast, Generator, Iterable, Iterator, Sequence, TypeVar


REPO_ROOT = Path(__file__).absolute().parent.parent
GITHUB_REMOTE_URL = "https://github.com/pytorch/pytorch.git"
SPECS_TO_INSTALL = ("pytorch", "mypy", "pytest", "hypothesis", "ipython", "sphinx")

# Set by main() to the logger yielded from logging_manager().
LOGGER: logging.Logger | None = None
URL_FORMAT = "{base_url}/{platform}/{dist_name}.tar.bz2"
DATETIME_FORMAT = "%Y-%m-%d_%Hh%Mm%Ss"
SHA1_RE = re.compile(r"(?P<sha1>[0-9a-fA-F]{40})")
USERNAME_PASSWORD_RE = re.compile(r":\/\/(.*?)\@")
LOG_DIRNAME_RE = re.compile(
    r"(?P<datetime>\d{4}-\d\d-\d\d_\d\dh\d\dm\d\ds)_"
    r"(?P<uuid>[0-9a-f]{8}-(?:[0-9a-f]{4}-){3}[0-9a-f]{12})",
)

# Name of the logger that logging_manager() configures with handlers.
_LOGGER_NAME = "conda-pytorch"


class Formatter(logging.Formatter):
    """Log formatter that prints INFO/DEBUG verbatim and redacts secrets.

    Records at other levels use the ``fmt`` passed to the constructor. Every
    fully formatted record (including any traceback text) is passed through
    :meth:`_filter`, which masks ``user:password@`` URL credentials and any
    strings registered with :meth:`redact`.
    """

    redactions: dict[str, str]

    def __init__(self, fmt: str | None = None, datefmt: str | None = None) -> None:
        super().__init__(fmt, datefmt)
        self.redactions = {}

    # Remove sensitive information from URLs
    def _filter(self, s: str) -> str:
        s = USERNAME_PASSWORD_RE.sub(r"://<USERNAME>:<PASSWORD>@", s)
        for needle, replace in self.redactions.items():
            s = s.replace(needle, replace)
        return s

    def formatMessage(self, record: logging.LogRecord) -> str:
        if record.levelno in (logging.INFO, logging.DEBUG):
            # Log INFO/DEBUG without any adornment
            return record.getMessage()
        return super().formatMessage(record)

    def format(self, record: logging.LogRecord) -> str:
        return self._filter(super().format(record))

    def redact(self, needle: str, replace: str = "<REDACTED>") -> None:
        """Redact specific strings; e.g., authorization tokens. This won't
        retroactively redact stuff you've already leaked, so make sure
        you redact things as soon as possible.
        """
        # Don't redact empty strings; this will lead to something
        # that looks like s<REDACTED>t<REDACTED>r<REDACTED>...
        if needle == "":
            return
        self.redactions[needle] = replace


def git(*args: str) -> list[str]:
    """Build a ``git`` command line that operates on the repository root."""
    return ["git", "-C", str(REPO_ROOT), *args]


@functools.lru_cache
def logging_base_dir() -> Path:
    """Return (creating if needed) ``<repo>/nightly/log``."""
    base_dir = REPO_ROOT / "nightly" / "log"
    base_dir.mkdir(parents=True, exist_ok=True)
    return base_dir


@functools.lru_cache
def logging_run_dir() -> Path:
    """Return (creating once per process) a ``<datetime>_<uuid>`` log dir."""
    base_dir = logging_base_dir()
    cur_dir = base_dir / f"{datetime.now().strftime(DATETIME_FORMAT)}_{uuid.uuid1()}"
    cur_dir.mkdir(parents=True, exist_ok=True)
    return cur_dir


@functools.lru_cache
def logging_record_argv() -> None:
    """Write the command line of this process to ``<run dir>/argv`` (once)."""
    s = subprocess.list2cmdline(sys.argv)
    (logging_run_dir() / "argv").write_text(s, encoding="utf-8")


def logging_record_exception(e: BaseException) -> None:
    """Write the exception's type name to ``<run dir>/exception``."""
    (logging_run_dir() / "exception").write_text(type(e).__name__, encoding="utf-8")


def logging_rotate() -> None:
    """Delete log run directories beyond the 1000 newest.

    Directory names start with a sortable timestamp, so reverse lexicographic
    order is newest first. Entries whose names do not look like run
    directories are never deleted.
    """
    log_base = logging_base_dir()
    old_logs = sorted(log_base.iterdir(), reverse=True)
    for stale_log in old_logs[1000:]:
        # Sanity check that it looks like a log
        if LOG_DIRNAME_RE.fullmatch(stale_log.name) is not None:
            shutil.rmtree(stale_log)


@contextlib.contextmanager
def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, None]:
    """Setup logging. If a failure starts here we won't
    be able to save the user in a reasonable way.

    Logging structure: there is one logger (named ``"conda-pytorch"``) and it
    processes all events. There are two handlers: stderr (INFO, or DEBUG when
    ``debug`` is true) and a file handler writing ``nightly.log`` in the run
    directory (all levels).

    Any exception escaping the ``with`` body is logged through that logger
    (so it also lands in ``nightly.log``), its type is recorded, and the
    process exits with status 1. Handlers are detached and closed on exit.
    """
    formatter = Formatter(fmt="%(levelname)s: %(message)s", datefmt="")
    root_logger = logging.getLogger(_LOGGER_NAME)
    root_logger.setLevel(logging.DEBUG)

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG if debug else logging.INFO)
    console_handler.setFormatter(formatter)
    root_logger.addHandler(console_handler)

    log_file = logging_run_dir() / "nightly.log"

    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(formatter)
    root_logger.addHandler(file_handler)
    logging_record_argv()

    try:
        logging_rotate()
        print(f"log file: {log_file}")
        yield root_logger
    except Exception as e:
        # Use the configured logger, not the module-level logging.* helpers:
        # those target the root logger, which has none of our handlers.
        root_logger.exception("Fatal exception")
        logging_record_exception(e)
        print(f"log file: {log_file}")
        sys.exit(1)
    except BaseException as e:
        # You could log at DEBUG here to suppress the backtrace on the
        # console entirely, but there is no reason to hide it from
        # technically savvy users.
        root_logger.info("", exc_info=True)
        logging_record_exception(e)
        print(f"log file: {log_file}")
        sys.exit(1)
    finally:
        # Avoid accumulating handlers if this manager is entered again.
        for handler in (console_handler, file_handler):
            root_logger.removeHandler(handler)
            handler.close()


def check_branch(subcommand: str, branch: str | None) -> str | None:
    """Checks that the branch name can be checked out.

    Returns an error message, or ``None`` when there is nothing to complain
    about (always ``None`` for subcommands other than ``"checkout"``).
    """
    if subcommand != "checkout":
        return None
    # first make sure actual branch name was given
    if branch is None:
        return "Branch name to checkout must be supplied with '-b' option"
    # next check that the local repo is clean
    cmd = git("status", "--untracked-files=no", "--porcelain")
    stdout = subprocess.check_output(cmd, text=True, encoding="utf-8")
    if stdout.strip():
        return "Need to have clean working tree to checkout!\n\n" + stdout
    # next check that the branch name doesn't already exist
    cmd = git("show-ref", "--verify", "--quiet", f"refs/heads/{branch}")
    p = subprocess.run(cmd, capture_output=True, check=False)
    if not p.returncode:
        return f"Branch {branch!r} already exists"
    return None


@contextlib.contextmanager
def timer(logger: logging.Logger, prefix: str) -> Iterator[None]:
    """Timed context manager: logs ``prefix`` and the elapsed seconds on success."""
    start_time = time.perf_counter()
    yield
    logger.info("%s took %.3f [s]", prefix, time.perf_counter() - start_time)


F = TypeVar("F", bound=Callable[..., Any])


def timed(prefix: str) -> Callable[[F], F]:
    """Decorator for timing functions.

    Logs ``prefix`` before the call and its duration after it. Uses the
    module-level ``LOGGER`` set by :func:`main`; if that is not set yet, falls
    back to the ``"conda-pytorch"`` logger so decorated functions remain
    callable outside of :func:`main`.
    """

    def dec(f: F) -> F:
        @functools.wraps(f)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            logger = LOGGER if LOGGER is not None else logging.getLogger(_LOGGER_NAME)
            logger.info(prefix)
            with timer(logger, prefix):
                return f(*args, **kwargs)

        return cast(F, wrapper)

    return dec


def _make_channel_args(
    channels: Iterable[str] = ("pytorch-nightly",),
    override_channels: bool = False,
) -> list[str]:
    """Translate channels into conda ``--channel``/``--override-channels`` args."""
    args = []
    for channel in channels:
        args.extend(["--channel", channel])
    if override_channels:
        args.append("--override-channels")
    return args


@timed("Solving conda environment")
def conda_solve(
    specs: Iterable[str],
    *,
    name: str | None = None,
    prefix: str | None = None,
    channels: Iterable[str] = ("pytorch-nightly",),
    override_channels: bool = False,
) -> tuple[list[str], str, str, bool, list[str]]:
    """Performs the conda solve and splits the deps from the package.

    Returns ``(deps, pytorch_url, platform, existing_env, env_opts)`` where
    ``deps`` are package URLs for everything except pytorch, and ``env_opts``
    is ``["--prefix", prefix]``, ``["--name", name]`` or, when neither is
    given, ``["--name", "pytorch-deps"]`` (a new environment).

    Raises:
        RuntimeError: if the solve does not contain a pytorch package.
    """
    # compute what environment to use
    if prefix is not None:
        existing_env = True
        env_opts = ["--prefix", prefix]
    elif name is not None:
        existing_env = True
        env_opts = ["--name", name]
    else:
        # create new environment
        existing_env = False
        env_opts = ["--name", "pytorch-deps"]
    # run solve
    if existing_env:
        cmd = [
            "conda",
            "install",
            "--yes",
            "--dry-run",
            "--json",
        ]
        cmd.extend(env_opts)
    else:
        cmd = [
            "conda",
            "create",
            "--yes",
            "--dry-run",
            "--json",
            "--name",
            "__pytorch__",
        ]
    channel_args = _make_channel_args(
        channels=channels,
        override_channels=override_channels,
    )
    cmd.extend(channel_args)
    cmd.extend(specs)
    stdout = subprocess.check_output(cmd, text=True, encoding="utf-8")
    # parse solution
    solve = json.loads(stdout)
    link = solve["actions"]["LINK"]
    deps = []
    pytorch, platform = "", ""
    for pkg in link:
        url = URL_FORMAT.format(**pkg)
        if pkg["name"] == "pytorch":
            pytorch = url
            platform = pkg["platform"]
        else:
            deps.append(url)
    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if not pytorch:
        raise RuntimeError("PyTorch package not found in solve")
    if not platform:
        raise RuntimeError("Platform not found in solve")
    return deps, pytorch, platform, existing_env, env_opts


@timed("Installing dependencies")
def deps_install(deps: list[str], existing_env: bool, env_opts: list[str]) -> None:
    """Install dependencies to deps environment"""
    if not existing_env:
        # first remove previous pytorch-deps env
        cmd = ["conda", "env", "remove", "--yes", *env_opts]
        subprocess.check_call(cmd)
    # install new deps
    install_command = "install" if existing_env else "create"
    cmd = ["conda", install_command, "--yes", "--no-deps", *env_opts, *deps]
    subprocess.check_call(cmd)


@timed("Installing pytorch nightly binaries")
def pytorch_install(url: str) -> tempfile.TemporaryDirectory[str]:
    """Install pytorch into a temporary directory.

    The caller owns the returned directory (use it as a context manager).
    If the install fails, the directory is removed before re-raising.
    """
    pytorch_dir = tempfile.TemporaryDirectory(prefix="conda-pytorch-")
    cmd = ["conda", "create", "--yes", "--no-deps", f"--prefix={pytorch_dir.name}", url]
    try:
        subprocess.check_call(cmd)
    except BaseException:
        pytorch_dir.cleanup()
        raise
    return pytorch_dir


def _site_packages(dirname: str, platform: str) -> Path:
    """Locate the ``site-packages`` directory of the environment at ``dirname``.

    Raises:
        RuntimeError: if no directory matches the platform's layout.
    """
    if platform.startswith("win"):
        template = os.path.join(dirname, "Lib", "site-packages")
    else:
        template = os.path.join(dirname, "lib", "python*.*", "site-packages")
    match = next(glob.iglob(template), None)
    if match is None:
        raise RuntimeError(f"Could not find site-packages directory matching {template!r}")
    return Path(match).absolute()


def _ensure_commit(git_sha1: str) -> None:
    """Make sure that we actually have the commit locally"""
    cmd = git("cat-file", "-e", git_sha1 + r"^{commit}")
    p = subprocess.run(cmd, capture_output=True, check=False)
    if p.returncode == 0:
        # we have the commit locally
        return
    # we don't have the commit, must fetch
    cmd = git("fetch", GITHUB_REMOTE_URL, git_sha1)
    subprocess.check_call(cmd)


def _nightly_version(site_dir: Path) -> str:
    """Return the nightly-branch commit matching the installed binaries.

    Reads ``git_version`` from ``torch/version.py`` under ``site_dir``, then
    takes the first 40-hex-digit SHA from that commit's subject line; both
    commits are fetched if missing locally.

    Raises:
        RuntimeError: if ``git_version`` or the nightly SHA cannot be found.
    """
    # first get the git version from the installed module
    version_file = site_dir / "torch" / "version.py"
    with version_file.open(encoding="utf-8") as f:
        for line in f:
            if not line.startswith("git_version"):
                continue
            git_version = literal_eval(line.partition("=")[2].strip())
            break
        else:
            raise RuntimeError(f"Could not find git_version in {version_file}")

    print(f"Found released git version {git_version}")
    # now cross reference with nightly version
    _ensure_commit(git_version)
    cmd = git("show", "--no-patch", "--format=%s", git_version)
    stdout = subprocess.check_output(cmd, text=True, encoding="utf-8")
    m = SHA1_RE.search(stdout)
    if m is None:
        raise RuntimeError(
            f"Could not find nightly release in git history:\n    {stdout}"
        )
    nightly_version = m.group("sha1")
    print(f"Found nightly release version {nightly_version}")
    # now checkout nightly version
    _ensure_commit(nightly_version)
    return nightly_version


@timed("Checking out nightly PyTorch")
def checkout_nightly_version(branch: str, site_dir: Path) -> None:
    """Gets the nightly version and then checks it out as a new branch."""
    nightly_version = _nightly_version(site_dir)
    cmd = git("checkout", "-b", branch, nightly_version)
    subprocess.check_call(cmd)


@timed("Pulling nightly PyTorch")
def pull_nightly_version(site_dir: Path) -> None:
    """Fetches the nightly version and then merges it into the current branch."""
    nightly_version = _nightly_version(site_dir)
    cmd = git("merge", nightly_version)
    subprocess.check_call(cmd)


def _get_listing_linux(source_dir: Path) -> list[Path]:
    return list(
        itertools.chain(
            source_dir.glob("*.so"),
            (source_dir / "lib").glob("*.so"),
            (source_dir / "lib").glob("*.so.*"),
        )
    )


def _get_listing_osx(source_dir: Path) -> list[Path]:
    # oddly, these are .so files even on Mac
    return list(
        itertools.chain(
            source_dir.glob("*.so"),
            (source_dir / "lib").glob("*.dylib"),
        )
    )


def _get_listing_win(source_dir: Path) -> list[Path]:
    return list(
        itertools.chain(
            source_dir.glob("*.pyd"),
            (source_dir / "lib").glob("*.lib"),
            # Must be a wildcard: ".dll" alone only matches a file named ".dll".
            (source_dir / "lib").glob("*.dll"),
        )
    )


def _glob_pyis(d: Path) -> set[str]:
    """Return POSIX-style paths, relative to ``d``, of all ``.pyi`` files under it."""
    return {p.relative_to(d).as_posix() for p in d.rglob("*.pyi")}


def _find_missing_pyi(source_dir: Path, target_dir: Path) -> list[Path]:
    """Return (sorted) source ``.pyi`` files that have no counterpart in target."""
    source_pyis = _glob_pyis(source_dir)
    target_pyis = _glob_pyis(target_dir)
    missing_pyis = sorted(source_dir / p for p in (source_pyis - target_pyis))
    return missing_pyis


def _get_listing(source_dir: Path, target_dir: Path, platform: str) -> list[Path]:
    """List the paths under ``source_dir`` that should be placed into the repo.

    Raises:
        RuntimeError: if ``platform`` is not a linux/osx/win conda platform.
    """
    if platform.startswith("linux"):
        listing = _get_listing_linux(source_dir)
    elif platform.startswith("osx"):
        listing = _get_listing_osx(source_dir)
    elif platform.startswith("win"):
        listing = _get_listing_win(source_dir)
    else:
        raise RuntimeError(f"Platform {platform!r} not recognized")
    listing.extend(_find_missing_pyi(source_dir, target_dir))
    listing.append(source_dir / "version.py")
    listing.append(source_dir / "testing" / "_internal" / "generated")
    listing.append(source_dir / "bin")
    listing.append(source_dir / "include")
    return listing


def _remove_existing(path: Path) -> None:
    """Delete ``path`` if present: real directories recursively, anything else
    (files, symlinks, including dangling ones) via unlink."""
    if path.is_dir() and not path.is_symlink():
        shutil.rmtree(path)
    elif path.exists() or path.is_symlink():
        path.unlink()


def _move_single(
    src: Path,
    source_dir: Path,
    target_dir: Path,
    mover: Callable[[Path, Path], None],
    verb: str,
) -> None:
    """Replace ``target_dir / src.relative_to(source_dir)`` with ``src``.

    Directories are recreated and each file inside is transferred with
    ``mover`` (e.g. ``shutil.copy2`` or ``os.link``).
    """
    relpath = src.relative_to(source_dir)
    trg = target_dir / relpath
    _remove_existing(trg)
    # move over new files
    if src.is_dir():
        trg.mkdir(parents=True, exist_ok=True)
        for root, dirs, files in os.walk(src):
            relroot = Path(root).relative_to(src)
            for name in files:
                relname = relroot / name
                s = src / relname
                t = trg / relname
                print(f"{verb} {s} -> {t}")
                mover(s, t)
            for name in dirs:
                (trg / relroot / name).mkdir(parents=True, exist_ok=True)
    else:
        # e.g. a .pyi inside a package directory that does not exist in the repo
        trg.parent.mkdir(parents=True, exist_ok=True)
        print(f"{verb} {src} -> {trg}")
        mover(src, trg)


def _copy_files(listing: list[Path], source_dir: Path, target_dir: Path) -> None:
    for src in listing:
        _move_single(src, source_dir, target_dir, shutil.copy2, "Copying")


def _link_files(listing: list[Path], source_dir: Path, target_dir: Path) -> None:
    for src in listing:
        _move_single(src, source_dir, target_dir, os.link, "Linking")


@timed("Moving nightly files into repo")
def move_nightly_files(site_dir: Path, platform: str) -> None:
    """Moves PyTorch files from temporary installed location to repo.

    On Windows files are copied; elsewhere they are hard-linked, falling back
    to copying if linking raises an OSError.
    """
    # get file listing
    source_dir = site_dir / "torch"
    target_dir = REPO_ROOT / "torch"
    listing = _get_listing(source_dir, target_dir, platform)
    # copy / link files
    if platform.startswith("win"):
        _copy_files(listing, source_dir, target_dir)
    else:
        try:
            _link_files(listing, source_dir, target_dir)
        except OSError as e:
            # Hard links can fail, e.g. across filesystems; copying still works.
            print(f"Linking failed ({e}); falling back to copying files")
            _copy_files(listing, source_dir, target_dir)


def _available_envs() -> dict[str, str]:
    """Map conda environment names to their paths, from ``conda env list``.

    Comment lines and unnamed (path-only) environments are skipped.
    """
    cmd = ["conda", "env", "list"]
    stdout = subprocess.check_output(cmd, text=True, encoding="utf-8")
    envs = {}
    for line in map(str.strip, stdout.splitlines()):
        if not line or line.startswith("#"):
            continue
        parts = line.split()
        if len(parts) == 1:
            # unnamed env
            continue
        envs[parts[0]] = parts[-1]
    return envs


@timed("Writing pytorch-nightly.pth")
def write_pth(env_opts: list[str], platform: str) -> None:
    """Writes a ``pytorch-nightly.pth`` file into the environment's
    site-packages so that the repo root is importable from that environment.

    Raises:
        RuntimeError: if a named environment is not listed by ``conda env list``.
    """
    env_type, env_dir = env_opts
    if env_type == "--name":
        # have to find directory
        envs = _available_envs()
        try:
            env_dir = envs[env_dir]
        except KeyError:
            raise RuntimeError(
                f"Conda environment {env_dir!r} not found in `conda env list`"
            ) from None
    site_dir = _site_packages(env_dir, platform)
    (site_dir / "pytorch-nightly.pth").write_text(
        "# This file was autogenerated by PyTorch's tools/nightly.py\n"
        "# Please delete this file if you no longer need the following development\n"
        "# version of PyTorch to be importable\n"
        f"{REPO_ROOT}\n",
        encoding="utf-8",
    )


def install(
    specs: Iterable[str],
    *,
    logger: logging.Logger,
    subcommand: str = "checkout",
    branch: str | None = None,
    name: str | None = None,
    prefix: str | None = None,
    channels: Iterable[str] = ("pytorch-nightly",),
    override_channels: bool = False,
) -> None:
    """Development install of PyTorch.

    Solves and installs the conda dependencies, installs the nightly pytorch
    package into a temporary prefix, checks out (``"checkout"``) or merges
    (``"pull"``) the matching nightly commit, moves the binaries into the
    repo and writes a ``.pth`` file into the environment.

    Raises:
        ValueError: if ``subcommand`` is not ``"checkout"`` or ``"pull"``, or
            ``"checkout"`` is requested without a ``branch``. These checks run
            before any conda command.
    """
    if subcommand not in ("checkout", "pull"):
        raise ValueError(f"Subcommand {subcommand} must be one of: checkout, pull.")
    if subcommand == "checkout" and branch is None:
        raise ValueError("A branch name is required for the 'checkout' subcommand.")
    specs = list(specs)
    deps, pytorch, platform, existing_env, env_opts = conda_solve(
        specs=specs,
        name=name,
        prefix=prefix,
        channels=channels,
        override_channels=override_channels,
    )
    if deps:
        deps_install(deps, existing_env, env_opts)

    with pytorch_install(pytorch) as pytorch_dir:
        site_dir = _site_packages(pytorch_dir, platform)
        if subcommand == "checkout":
            checkout_nightly_version(cast(str, branch), site_dir)
        else:
            pull_nightly_version(site_dir)
        move_nightly_files(site_dir, platform)

    write_pth(env_opts, platform)
    logger.info(
        "-------\nPyTorch Development Environment set up!\nPlease activate to "
        "enable this environment:\n  $ conda activate %s",
        env_opts[1],
    )


def make_parser() -> argparse.ArgumentParser:
    """Build the command line parser; a subcommand is required."""
    p = argparse.ArgumentParser()
    # subcommands
    subcmd = p.add_subparsers(
        dest="subcmd", help="subcommand to execute", required=True
    )
    checkout = subcmd.add_parser("checkout", help="checkout a new branch")
    checkout.add_argument(
        "-b",
        "--branch",
        help="Branch name to checkout",
        dest="branch",
        default=None,
        metavar="NAME",
    )
    pull = subcmd.add_parser(
        "pull", help="pulls the nightly commits into the current branch"
    )
    # general arguments
    subparsers = [checkout, pull]
    for subparser in subparsers:
        subparser.add_argument(
            "-n",
            "--name",
            help="Name of environment",
            dest="name",
            default=None,
            metavar="ENVIRONMENT",
        )
        subparser.add_argument(
            "-p",
            "--prefix",
            help="Full path to environment location (i.e. prefix)",
            dest="prefix",
            default=None,
            metavar="PATH",
        )
        subparser.add_argument(
            "-v",
            "--verbose",
            help="Provide debugging info",
            dest="verbose",
            default=False,
            action="store_true",
        )
        subparser.add_argument(
            "--override-channels",
            help="Do not search default or .condarc channels.",
            dest="override_channels",
            default=False,
            action="store_true",
        )
        subparser.add_argument(
            "-c",
            "--channel",
            help=(
                "Additional channel to search for packages. "
                "'pytorch-nightly' will always be prepended to this list."
            ),
            dest="channels",
            action="append",
            metavar="CHANNEL",
        )
        if platform_system() in {"Linux", "Windows"}:
            subparser.add_argument(
                "--cuda",
                help=(
                    "CUDA version to install "
                    "(defaults to the latest version available on the platform)"
                ),
                dest="cuda",
                nargs="?",
                default=argparse.SUPPRESS,
                metavar="VERSION",
            )
    return p


def main(args: Sequence[str] | None = None) -> None:
    """Main entry point"""
    global LOGGER
    p = make_parser()
    ns = p.parse_args(args)
    ns.branch = getattr(ns, "branch", None)
    status = check_branch(ns.subcmd, ns.branch)
    if status:
        sys.exit(status)
    specs = list(SPECS_TO_INSTALL)
    channels = ["pytorch-nightly"]
    if hasattr(ns, "cuda"):
        if ns.cuda is not None:
            specs.append(f"pytorch-cuda={ns.cuda}")
        else:
            specs.append("pytorch-cuda")
        specs.append("pytorch-mutex=*=*cuda*")
        channels.append("nvidia")
    else:
        specs.append("pytorch-mutex=*=*cpu*")
    if ns.channels:
        channels.extend(ns.channels)
    with logging_manager(debug=ns.verbose) as logger:
        LOGGER = logger
        install(
            specs=specs,
            subcommand=ns.subcmd,
            branch=ns.branch,
            name=ns.name,
            prefix=ns.prefix,
            logger=logger,
            channels=channels,
            override_channels=ns.override_channels,
        )


if __name__ == "__main__":
    main()