1# Ideally, there would be a way in Bazel to parse version.txt 2# and use the version numbers from there as substitutions for 3# an expand_template action. Since there isn't, this silly script exists. 4 5from __future__ import annotations 6 7import argparse 8import os 9from typing import cast, Tuple 10 11 12Version = Tuple[int, int, int] 13 14 15def parse_version(version: str) -> Version: 16 """ 17 Parses a version string into (major, minor, patch) version numbers. 18 19 Args: 20 version: Full version number string, possibly including revision / commit hash. 21 22 Returns: 23 An int 3-tuple of (major, minor, patch) version numbers. 24 """ 25 # Extract version number part (i.e. toss any revision / hash parts). 26 version_number_str = version 27 for i in range(len(version)): 28 c = version[i] 29 if not (c.isdigit() or c == "."): 30 version_number_str = version[:i] 31 break 32 33 return cast(Version, tuple([int(n) for n in version_number_str.split(".")])) 34 35 36def apply_replacements(replacements: dict[str, str], text: str) -> str: 37 """ 38 Applies the given replacements within the text. 39 40 Args: 41 replacements (dict): Mapping of str -> str replacements. 42 text (str): Text in which to make replacements. 43 44 Returns: 45 Text with replacements applied, if any. 46 """ 47 for before, after in replacements.items(): 48 text = text.replace(before, after) 49 return text 50 51 52def main(args: argparse.Namespace) -> None: 53 with open(args.version_path) as f: 54 version = f.read().strip() 55 (major, minor, patch) = parse_version(version) 56 57 replacements = { 58 "@TORCH_VERSION_MAJOR@": str(major), 59 "@TORCH_VERSION_MINOR@": str(minor), 60 "@TORCH_VERSION_PATCH@": str(patch), 61 } 62 63 # Create the output dir if it doesn't exist. 64 os.makedirs(os.path.dirname(args.output_path), exist_ok=True) 65 66 with open(args.template_path) as input: 67 with open(args.output_path, "w") as output: 68 for line in input: 69 output.write(apply_replacements(replacements, line)) 70 71 72if __name__ == "__main__": 73 parser = argparse.ArgumentParser( 74 description="Generate version.h from version.h.in template", 75 ) 76 parser.add_argument( 77 "--template-path", 78 required=True, 79 help="Path to the template (i.e. version.h.in)", 80 ) 81 parser.add_argument( 82 "--version-path", 83 required=True, 84 help="Path to the file specifying the version", 85 ) 86 parser.add_argument( 87 "--output-path", 88 required=True, 89 help="Output path for expanded template (i.e. version.h)", 90 ) 91 args = parser.parse_args() 92 main(args) 93