1from __future__ import annotations 2 3import argparse 4import os 5import re 6import subprocess 7from pathlib import Path 8 9from setuptools import distutils # type: ignore[import] 10 11 12UNKNOWN = "Unknown" 13RELEASE_PATTERN = re.compile(r"/v[0-9]+(\.[0-9]+)*(-rc[0-9]+)?/") 14 15 16def get_sha(pytorch_root: str | Path) -> str: 17 try: 18 rev = None 19 if os.path.exists(os.path.join(pytorch_root, ".git")): 20 rev = subprocess.check_output( 21 ["git", "rev-parse", "HEAD"], cwd=pytorch_root 22 ) 23 elif os.path.exists(os.path.join(pytorch_root, ".hg")): 24 rev = subprocess.check_output( 25 ["hg", "identify", "-r", "."], cwd=pytorch_root 26 ) 27 if rev: 28 return rev.decode("ascii").strip() 29 except Exception: 30 pass 31 return UNKNOWN 32 33 34def get_tag(pytorch_root: str | Path) -> str: 35 try: 36 tag = subprocess.run( 37 ["git", "describe", "--tags", "--exact"], 38 cwd=pytorch_root, 39 encoding="ascii", 40 capture_output=True, 41 ).stdout.strip() 42 if RELEASE_PATTERN.match(tag): 43 return tag 44 else: 45 return UNKNOWN 46 except Exception: 47 return UNKNOWN 48 49 50def get_torch_version(sha: str | None = None) -> str: 51 pytorch_root = Path(__file__).absolute().parent.parent 52 version = open(pytorch_root / "version.txt").read().strip() 53 54 if os.getenv("PYTORCH_BUILD_VERSION"): 55 assert os.getenv("PYTORCH_BUILD_NUMBER") is not None 56 build_number = int(os.getenv("PYTORCH_BUILD_NUMBER", "")) 57 version = os.getenv("PYTORCH_BUILD_VERSION", "") 58 if build_number > 1: 59 version += ".post" + str(build_number) 60 elif sha != UNKNOWN: 61 if sha is None: 62 sha = get_sha(pytorch_root) 63 version += "+git" + sha[:7] 64 return version 65 66 67if __name__ == "__main__": 68 parser = argparse.ArgumentParser( 69 description="Generate torch/version.py from build and environment metadata." 70 ) 71 parser.add_argument( 72 "--is-debug", 73 "--is_debug", 74 type=distutils.util.strtobool, 75 help="Whether this build is debug mode or not.", 76 ) 77 parser.add_argument("--cuda-version", "--cuda_version", type=str) 78 parser.add_argument("--hip-version", "--hip_version", type=str) 79 80 args = parser.parse_args() 81 82 assert args.is_debug is not None 83 args.cuda_version = None if args.cuda_version == "" else args.cuda_version 84 args.hip_version = None if args.hip_version == "" else args.hip_version 85 86 pytorch_root = Path(__file__).parent.parent 87 version_path = pytorch_root / "torch" / "version.py" 88 # Attempt to get tag first, fall back to sha if a tag was not found 89 tagged_version = get_tag(pytorch_root) 90 sha = get_sha(pytorch_root) 91 if tagged_version == UNKNOWN: 92 version = get_torch_version(sha) 93 else: 94 version = tagged_version 95 96 with open(version_path, "w") as f: 97 f.write("from typing import Optional\n\n") 98 f.write("__all__ = ['__version__', 'debug', 'cuda', 'git_version', 'hip']\n") 99 f.write(f"__version__ = '{version}'\n") 100 # NB: This is not 100% accurate, because you could have built the 101 # library code with DEBUG, but csrc without DEBUG (in which case 102 # this would claim to be a release build when it's not.) 103 f.write(f"debug = {repr(bool(args.is_debug))}\n") 104 f.write(f"cuda: Optional[str] = {repr(args.cuda_version)}\n") 105 f.write(f"git_version = {repr(sha)}\n") 106 f.write(f"hip: Optional[str] = {repr(args.hip_version)}\n") 107