xref: /aosp_15_r20/external/pytorch/tools/onnx/update_default_opset_version.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2
3"""Updates the default value of opset_version.
4
5The current policy is that the default should be set to the
6latest released version as of 18 months ago.
7
8Usage:
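Run with no arguments to update the default, rebuild PyTorch and regenerate
the ONNX operator .expect files; pass --skip-build to skip the rebuild.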
10"""
11
12import argparse
13import datetime
14import os
15import re
16import subprocess
17import sys
18from pathlib import Path
19from subprocess import DEVNULL
20from typing import Any
21
22
def read_sub_write(path: str, prefix_pat: str, new_default: int) -> None:
    r"""Replace the number that follows ``prefix_pat`` in ``path`` with ``new_default``.

    ``prefix_pat`` is a regex whose first capture group is the text to keep,
    immediately followed by the digits to replace, e.g.
    ``r"(ONNX_DEFAULT_OPSET = )\d+"``. Every match in the file is rewritten.

    If the pattern does not occur in the file, a warning is printed to stderr
    and the file is left untouched. This way a stale pattern is not silently
    reported as "modified".

    Args:
        path: File to edit in place (read and written as UTF-8).
        prefix_pat: Regex with the kept prefix as capture group 1.
        new_default: Integer written in place of the matched digits.
    """
    with open(path, encoding="utf-8") as f:
        content_str = f.read()
    # \g<1> (rather than \1) keeps the group reference unambiguous when
    # new_default's digits directly follow it in the replacement string.
    content_str, count = re.subn(prefix_pat, rf"\g<1>{new_default}", content_str)
    if count == 0:
        print(
            f"warning: pattern {prefix_pat!r} not found in {path}; not modified",
            file=sys.stderr,
        )
        return
    with open(path, "w", encoding="utf-8") as f:
        f.write(content_str)
    print("modified", path)
30
31
def main(args: Any) -> None:
    """Update PyTorch's default ONNX opset to that of an ONNX release from ~18 months ago.

    Steps:
      1. In the ``third_party/onnx`` submodule, find the last commit made
         before now minus 18 * 30 days. Then find the earliest final
         ``vX.Y.Z`` release tag that contains that commit.
      2. Check out that tag and look the release up in
         ``onnx.helper.VERSION_TABLE`` to get its opset. The submodule is
         always restored to its previous HEAD afterwards.
      3. Rewrite ``ONNX_DEFAULT_OPSET`` in ``torch/onnx/_constants.py`` and
         the ``opset_version`` default in ``torch/onnx/utils.py``.
      4. Unless ``args.skip_build`` is set, rebuild PyTorch. Then regenerate
         the ONNX operator ``.expect`` files.

    NOTE(review): step 1 picks the first release cut *after* the cutoff date.
    The module docstring says "latest released version as of 18 months ago";
    confirm which policy is intended.

    Args:
        args: Parsed command-line arguments; only ``skip_build`` is read.

    Exits via ``sys.exit`` with a message if no commit before the cutoff, no
    containing release tag, or no matching VERSION_TABLE entry is found.
    """
    pytorch_dir = Path(__file__).parent.parent.parent.resolve()
    onnx_dir = pytorch_dir / "third_party" / "onnx"
    os.chdir(onnx_dir)

    # "18 months" is approximated as 18 * 30 days. Drop microseconds so git
    # receives a plain "YYYY-MM-DD HH:MM:SS" timestamp.
    date = datetime.datetime.now() - datetime.timedelta(days=18 * 30)
    until = date.strftime("%Y-%m-%d %H:%M:%S")
    onnx_commit = subprocess.check_output(
        ("git", "log", f"--until={until}", "--max-count=1", "--format=%H"),
        encoding="utf-8",
    ).strip()
    if not onnx_commit:
        sys.exit(f"no ONNX commit found before {until}")

    onnx_tags = subprocess.check_output(
        ("git", "tag", "--list", f"--contains={onnx_commit}"), encoding="utf-8"
    )
    # fullmatch keeps only final releases. Pre-release tags such as
    # "v1.13.0rc1" would otherwise collapse onto the same (1, 13, 0) tuple.
    semver_pat = re.compile(r"v(\d+)\.(\d+)\.(\d+)")
    tag_tups = []
    for tag in onnx_tags.splitlines():
        match = semver_pat.fullmatch(tag.strip())
        if match:
            tag_tups.append(tuple(int(x) for x in match.groups()))
    if not tag_tups:
        sys.exit(f"no ONNX release tag contains commit {onnx_commit}")

    # Lowest version among releases containing the cutoff commit, i.e. the
    # earliest release that includes it.
    version_str = "{}.{}.{}".format(*min(tag_tups))

    print("Using ONNX release", version_str)

    head_commit = subprocess.check_output(
        ("git", "log", "--max-count=1", "--format=%H", "HEAD"), encoding="utf-8"
    ).strip()

    new_default = None

    subprocess.check_call(
        ("git", "checkout", f"v{version_str}"), stdout=DEVNULL, stderr=DEVNULL
    )
    try:
        from onnx import helper  # type: ignore[import]

        # Column 0 of each VERSION_TABLE row is matched against the release
        # string and column 2 is taken as its opset. Presumably this is the
        # ai.onnx opset column; verify against the onnx.helper docs.
        for version in helper.VERSION_TABLE:
            if version[0] == version_str:
                new_default = version[2]
                print("found new default opset_version", new_default)
                break
        if new_default is None:
            sys.exit(
                f"failed to find version {version_str} in onnx.helper.VERSION_TABLE at commit {onnx_commit}"
            )
    finally:
        # Restore the submodule even if the lookup fails or sys.exit() raises
        # SystemExit above.
        subprocess.check_call(
            ("git", "checkout", head_commit), stdout=DEVNULL, stderr=DEVNULL
        )

    os.chdir(pytorch_dir)

    read_sub_write(
        os.path.join("torch", "onnx", "_constants.py"),
        r"(ONNX_DEFAULT_OPSET = )\d+",
        new_default,
    )
    read_sub_write(
        os.path.join("torch", "onnx", "utils.py"),
        r"(opset_version \(int, default )\d+",
        new_default,
    )

    # Use the interpreter running this script rather than whatever "python"
    # resolves to on PATH. That name may be missing or point elsewhere.
    if not args.skip_build:
        print("Building PyTorch...")
        subprocess.check_call(
            (sys.executable, "setup.py", "develop"),
        )
    print("Updating operator .expect files")
    subprocess.check_call(
        (
            sys.executable,
            os.path.join("test", "onnx", "test_operators.py"),
            "--accept",
        ),
    )
105
106
if __name__ == "__main__":
    # One optional flag, accepted in both dash and underscore spellings;
    # main() reads it back as args.skip_build.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--skip-build", "--skip_build", action="store_true", help="Skip building pytorch"
    )
    main(cli.parse_args())
116