xref: /aosp_15_r20/external/pytorch/tools/setup_helpers/gen_version_header.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Ideally, there would be a way in Bazel to parse version.txt
2# and use the version numbers from there as substitutions for
3# an expand_template action. Since there isn't, this silly script exists.
4
5from __future__ import annotations
6
import argparse
import os
import re
from typing import cast, Tuple
10
11
# Version triple as returned by parse_version: (major, minor, patch).
Version = Tuple[int, int, int]
13
14
def parse_version(version: str) -> Version:
    """
    Parses a version string into (major, minor, patch) version numbers.

    Only the leading run of dot-separated decimal components is used; any
    trailing revision / hash / pre-release suffix (e.g. "a0", "+git1234",
    ".dev20240101") is ignored.

    Args:
      version: Full version number string, possibly including revision / commit hash.

    Returns:
      An int 3-tuple of (major, minor, patch) version numbers.

    Raises:
      ValueError: If the string does not start with exactly three
        dot-separated numeric components.
    """
    # Greedily match "N(.N)*" using ASCII digits only. A "." that is not
    # followed by a digit (as in "2.5.0.dev1") is left out of the match, so
    # splitting never yields an empty component.
    match = re.match(r"[0-9]+(?:\.[0-9]+)*", version)
    if match is None:
        raise ValueError(f"Invalid version string: {version!r}")

    number_str = match.group(0)
    parts = number_str.split(".")
    # The return type promises exactly (major, minor, patch); fail loudly
    # here rather than hand callers a tuple of the wrong length.
    if len(parts) != 3:
        raise ValueError(
            f"Expected a major.minor.patch version, got {number_str!r} "
            f"from {version!r}"
        )

    major, minor, patch = (int(part) for part in parts)
    return (major, minor, patch)
34
35
def apply_replacements(replacements: dict[str, str], text: str) -> str:
    """
    Performs every substitution from the mapping on the given text.

    Substitutions are applied one after another in the mapping's iteration
    order, so a later substitution sees the output of earlier ones.

    Args:
      replacements: Mapping from substring to search for -> its replacement.
      text: The string to substitute into.

    Returns:
      The text after all substitutions; unchanged if nothing matched.
    """
    result = text
    for placeholder in replacements:
        result = result.replace(placeholder, replacements[placeholder])
    return result
50
51
def main(args: argparse.Namespace) -> None:
    """
    Expands the version placeholders in a template file.

    Reads the version string from ``args.version_path``, parses it into
    (major, minor, patch), and copies ``args.template_path`` to
    ``args.output_path`` line by line with the ``@TORCH_VERSION_MAJOR@``,
    ``@TORCH_VERSION_MINOR@`` and ``@TORCH_VERSION_PATCH@`` placeholders
    substituted.

    Args:
      args: Parsed command-line arguments providing ``version_path``,
        ``template_path`` and ``output_path``.
    """
    with open(args.version_path, encoding="utf-8") as f:
        version = f.read().strip()
    major, minor, patch = parse_version(version)

    replacements = {
        "@TORCH_VERSION_MAJOR@": str(major),
        "@TORCH_VERSION_MINOR@": str(minor),
        "@TORCH_VERSION_PATCH@": str(patch),
    }

    # Create the output dir if it doesn't exist. A bare filename has an empty
    # dirname, and os.makedirs("") raises, so only create a directory when
    # the output path actually names one.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(args.template_path, encoding="utf-8") as template:
        with open(args.output_path, "w", encoding="utf-8") as output:
            for line in template:
                output.write(apply_replacements(replacements, line))
70
71
if __name__ == "__main__":
    # Command-line entry point: all three paths are mandatory.
    arg_parser = argparse.ArgumentParser(
        description="Generate version.h from version.h.in template",
    )
    for flag, help_text in (
        ("--template-path", "Path to the template (i.e. version.h.in)"),
        ("--version-path", "Path to the file specifying the version"),
        ("--output-path", "Output path for expanded template (i.e. version.h)"),
    ):
        arg_parser.add_argument(flag, required=True, help=help_text)
    main(arg_parser.parse_args())
93