#!/usr/bin/env python3

"""Updates the default value of opset_version.

The current policy is that the default should be set to the
latest released version as of 18 months ago.

Concretely, the script finds the last commit in third_party/onnx made at
least ~18 months (approximated as 18 * 30 days) ago, picks the earliest
``vX.Y.Z`` release tag that contains that commit, looks up the opset of that
release in ``onnx.helper.VERSION_TABLE``, and rewrites the default in
torch/onnx/_constants.py and torch/onnx/utils.py.

Usage:
Run with no arguments. Pass --skip-build to skip rebuilding PyTorch and
regenerating the ONNX operator .expect files.
"""

import argparse
import datetime
import os
import re
import subprocess
import sys
from pathlib import Path
from subprocess import DEVNULL
from typing import Any, Optional

# Only final release tags (e.g. "v1.13.0"); pre-releases such as
# "v1.13.0rc1" are rejected because fullmatch() is used below.
_SEMVER_TAG_PAT = re.compile(r"v(\d+)\.(\d+)\.(\d+)")


def read_sub_write(path: str, prefix_pat: str, new_default: int) -> None:
    """Replace the number following ``prefix_pat`` in *path* with *new_default*.

    Args:
        path: File to rewrite in place (read and written as UTF-8).
        prefix_pat: Regex whose first capturing group is the text to keep;
            whatever the rest of the pattern matches is replaced by
            ``new_default``.
        new_default: The new default opset version.

    Exits the process with an error (without touching the file) if the
    pattern does not occur, so a renamed constant cannot be skipped silently.
    """
    with open(path, encoding="utf-8") as f:
        content_str = f.read()
    content_str, n_subs = re.subn(prefix_pat, rf"\g<1>{new_default}", content_str)
    if n_subs == 0:
        sys.exit(f"pattern {prefix_pat!r} not found in {path}")
    with open(path, "w", encoding="utf-8") as f:
        f.write(content_str)
    print("modified", path)


def _earliest_release(tag_lines: str) -> Optional[str]:
    """Return the lowest ``X.Y.Z`` among lines of the exact form ``vX.Y.Z``.

    Versions are compared numerically (so 1.9.0 < 1.10.0). Returns ``None``
    when no line is a final release tag.
    """
    versions = [
        tuple(int(part) for part in m.groups())
        for m in (_SEMVER_TAG_PAT.fullmatch(line.strip()) for line in tag_lines.splitlines())
        if m
    ]
    if not versions:
        return None
    return "{}.{}.{}".format(*min(versions))


def main(args: Any) -> None:
    """Pick the new default ONNX opset and write it into the PyTorch sources.

    Args:
        args: Parsed command-line arguments; only ``args.skip_build`` is read.

    Changes the working directory (first to third_party/onnx, then to the
    PyTorch root), temporarily checks out a release tag in the ONNX submodule
    (restored in ``finally``), and unless ``--skip-build`` is given, rebuilds
    PyTorch and regenerates the operator .expect files.
    """
    # Resolve before walking up so a relative __file__ still finds the repo root.
    pytorch_dir = Path(__file__).resolve().parents[2]
    onnx_dir = pytorch_dir / "third_party" / "onnx"
    os.chdir(onnx_dir)

    # "18 months" approximated as 18 * 30 days.
    date = datetime.datetime.now() - datetime.timedelta(days=18 * 30)
    onnx_commit = subprocess.check_output(
        ("git", "log", f"--until={date}", "--max-count=1", "--format=%H"),
        encoding="utf-8",
    ).strip()
    if not onnx_commit:
        # e.g. a shallow submodule clone with no history that old.
        sys.exit(f"no ONNX commit found before {date} in {onnx_dir}")
    onnx_tags = subprocess.check_output(
        ("git", "tag", "--list", f"--contains={onnx_commit}"), encoding="utf-8"
    )

    # The first release that includes the state of ONNX as of 18 months ago.
    version_str = _earliest_release(onnx_tags)
    if version_str is None:
        sys.exit(f"no ONNX release tag contains commit {onnx_commit}")

    print("Using ONNX release", version_str)

    head_commit = subprocess.check_output(
        ("git", "log", "--max-count=1", "--format=%H", "HEAD"), encoding="utf-8"
    ).strip()

    new_default = None

    subprocess.check_call(
        ("git", "checkout", f"v{version_str}"), stdout=DEVNULL, stderr=DEVNULL
    )
    try:
        # NOTE(review): this resolves ``onnx`` through sys.path, which for a
        # script starts with the script's own directory rather than the cwd
        # set above; confirm it really reads the checked-out release and not
        # an installed onnx package.
        from onnx import helper  # type: ignore[import]

        # Each VERSION_TABLE entry is indexed as (release string, ..., opset, ...).
        for version in helper.VERSION_TABLE:
            if version[0] == version_str:
                new_default = version[2]
                print("found new default opset_version", new_default)
                break
        if new_default is None:
            sys.exit(
                f"failed to find version {version_str} in onnx.helper.VERSION_TABLE at commit {onnx_commit}"
            )
    finally:
        # Always put the submodule back where it was, even on sys.exit().
        subprocess.check_call(
            ("git", "checkout", head_commit), stdout=DEVNULL, stderr=DEVNULL
        )

    os.chdir(pytorch_dir)

    read_sub_write(
        os.path.join("torch", "onnx", "_constants.py"),
        r"(ONNX_DEFAULT_OPSET = )\d+",
        new_default,
    )
    read_sub_write(
        os.path.join("torch", "onnx", "utils.py"),
        r"(opset_version \(int, default )\d+",
        new_default,
    )

    if not args.skip_build:
        # Use the running interpreter; a bare "python" may be a different one
        # (or missing entirely on systems that only ship "python3").
        print("Building PyTorch...")
        subprocess.check_call(
            (sys.executable, "setup.py", "develop"),
        )
        print("Updating operator .expect files")
        subprocess.check_call(
            (
                sys.executable,
                os.path.join("test", "onnx", "test_operators.py"),
                "--accept",
            ),
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--skip-build",
        "--skip_build",
        action="store_true",
        help="Skip building pytorch",
    )
    main(parser.parse_args())