xref: /aosp_15_r20/external/tensorflow/tensorflow/tools/ci_build/update_version.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1#!/usr/bin/python
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16#
17# Automatically update TensorFlow version in source files
18#
19# Usage:
20#           ./tensorflow/tools/ci_build/update_version.py --version 1.4.0-rc1
21#           ./tensorflow/tools/ci_build/update_version.py --nightly
22#
23"""Update version of TensorFlow script."""
24
25# pylint: disable=superfluous-parens
26
27import argparse
28import os
29import re
30import subprocess
31import time
32
# File parameters. Paths are relative; the script expects to be run from the
# TensorFlow source root directory.
TF_SRC_DIR = "tensorflow"
VERSION_H = "%s/core/public/version.h" % TF_SRC_DIR
SETUP_PY = "%s/tools/pip_package/setup.py" % TF_SRC_DIR
SETUP_PARTNER_BUILD_PY = "%s/tools/pip_package/setup_partner_builds.py" % TF_SRC_DIR
README_MD = "./README.md"
TENSORFLOW_BZL = "%s/tensorflow.bzl" % TF_SRC_DIR
# Paths whose existence is verified before anything is rewritten. Every file
# the script modifies is listed (including TENSORFLOW_BZL) so that a missing
# file is reported up front instead of after other files were already changed.
RELEVANT_FILES = [
    TF_SRC_DIR, VERSION_H, SETUP_PY, SETUP_PARTNER_BUILD_PY, README_MD,
    TENSORFLOW_BZL
]

# Version type parameters.
NIGHTLY_VERSION = 1
REGULAR_VERSION = 0
47
48
def check_existence(filename):
  """Raise if nothing (file or directory) exists at `filename`.

  Args:
    filename: path to probe.

  Raises:
    RuntimeError: If the path does not exist.
  """
  if os.path.exists(filename):
    return
  message = ("%s not found. Are you under the TensorFlow source root"
             " directory?" % filename)
  raise RuntimeError(message)
54
55
def check_all_files():
  """Verify that every path listed in RELEVANT_FILES exists.

  Raises:
    RuntimeError: Propagated from check_existence() for the first listed path
      that is missing.
  """
  for required_path in RELEVANT_FILES:
    check_existence(required_path)
60
61
def replace_string_in_line(search, replace, filename):
  """Rewrite `filename` in place, substituting every regex match of `search`.

  The whole file is read, all matches of the regular expression `search` are
  replaced through re.sub, and the result is written back to the same path.

  Args:
    search: regular expression pattern to look for.
    replace: replacement handed to re.sub.
    filename: path of the file to rewrite.
  """
  with open(filename, "r") as reader:
    original_text = reader.read()
  with open(filename, "w") as writer:
    updated_text = re.sub(search, replace, original_text)
    writer.write(updated_text)
68
69
class Version(object):
  """Version class object that stores SemVer version information.

  Attributes:
    major: major version string, e.g. "1".
    minor: minor version string, e.g. "3".
    patch: patch version string, e.g. "1".
    identifier_string: extension string including its leading dash, e.g.
      "-rc0", or "" when there is none.
    version_type: REGULAR_VERSION or NIGHTLY_VERSION.
    string: "<major>.<minor>.<patch><identifier_string>", kept in sync by
      _update_string().
  """

  def __init__(self, major, minor, patch, identifier_string, version_type):
    """Constructor.

    Args:
      major: major string eg. (1)
      minor: minor string eg. (3)
      patch: patch string eg. (1)
      identifier_string: extension string eg. (-rc0)
      version_type: version parameter ((REGULAR|NIGHTLY)_VERSION)
    """
    self.major = major
    self.minor = minor
    self.patch = patch
    self.identifier_string = identifier_string
    self.version_type = version_type
    self._update_string()

  def _update_string(self):
    """Recompute `self.string` from the current components."""
    self.string = "%s.%s.%s%s" % (self.major,
                                  self.minor,
                                  self.patch,
                                  self.identifier_string)

  def __str__(self):
    return self.string

  def set_identifier_string(self, identifier_string):
    """Replace the identifier string and refresh `self.string`."""
    self.identifier_string = identifier_string
    self._update_string()

  @property
  def pep_440_str(self):
    """Version string with every "-" removed.

    Regular versions render as <major>.<minor>.<patch><identifier>; any other
    version type renders as <major>.<minor>.<identifier>, i.e. without the
    patch component.
    """
    if self.version_type == REGULAR_VERSION:
      return_string = "%s.%s.%s%s" % (self.major,
                                      self.minor,
                                      self.patch,
                                      self.identifier_string)
    else:
      # NOTE(review): the patch component is left out here, e.g.
      # 1.5.0-dev20171010 becomes 1.5.dev20171010 -- confirm this is the
      # intended nightly format.
      return_string = "%s.%s.%s" % (self.major,
                                    self.minor,
                                    self.identifier_string)
    return return_string.replace("-", "")

  @staticmethod
  def parse_from_string(string, version_type):
    """Returns version object from Semver string.

    Args:
      string: version string
      version_type: version parameter

    Returns:
      A Version built from the parsed components.

    Raises:
      RuntimeError: If the version string is not valid.
    """
    # Check validity of new version string. re.match anchors the check at the
    # start of the string, so leading junk such as "v1.2.3" is rejected
    # instead of silently ending up inside the major component.
    if not re.match(r"[0-9]+\.[0-9]+\.[a-zA-Z0-9]+", string):
      raise RuntimeError("Invalid version string: %s" % string)

    major, minor, extension = string.split(".", 2)

    # Isolate patch and identifier string if identifier string exists.
    extension_split = extension.split("-", 1)
    patch = extension_split[0]
    if len(extension_split) == 2:
      identifier_string = "-" + extension_split[1]
    else:
      identifier_string = ""

    return Version(major,
                   minor,
                   patch,
                   identifier_string,
                   version_type)
147
148
def get_current_semver_version():
  """Returns a Version object of current version.

  Scans core/public/version.h for the TF_MAJOR_VERSION, TF_MINOR_VERSION,
  TF_PATCH_VERSION and TF_VERSION_SUFFIX #defines. Scanning stops at the first
  TF_VERSION_SUFFIX line. The version is treated as nightly when the suffix
  contains "dev".

  Returns:
    version: Version object of current SemVer string based on information from
    core/public/version.h

  Raises:
    RuntimeError: If any of the four #defines has not been found by the time
      scanning stops.
  """
  patterns = (
      ("TF_MAJOR_VERSION", r"^#define TF_MAJOR_VERSION ([0-9]+)"),
      ("TF_MINOR_VERSION", r"^#define TF_MINOR_VERSION ([0-9]+)"),
      ("TF_PATCH_VERSION", r"^#define TF_PATCH_VERSION ([0-9]+)"),
      ("TF_VERSION_SUFFIX", r'^#define TF_VERSION_SUFFIX "(.*)"'),
  )

  # Get current version information. The context manager guarantees the
  # header is closed even if parsing fails.
  found = {}
  with open(VERSION_H, "r") as version_file:
    for line in version_file:
      for name, pattern in patterns:
        match = re.search(pattern, line)
        if match:
          found[name] = match.group(1)
      if "TF_VERSION_SUFFIX" in found:
        break

  # Fail with a clear message instead of an unbound-variable error.
  missing = [name for name, _ in patterns if name not in found]
  if missing:
    raise RuntimeError("Could not find %s in %s"
                       % (", ".join(missing), VERSION_H))

  old_extension = found["TF_VERSION_SUFFIX"]
  if "dev" in old_extension:
    version_type = NIGHTLY_VERSION
  else:
    version_type = REGULAR_VERSION

  return Version(found["TF_MAJOR_VERSION"],
                 found["TF_MINOR_VERSION"],
                 found["TF_PATCH_VERSION"],
                 old_extension,
                 version_type)
184
185
def update_version_h(old_version, new_version):
  """Update tensorflow/core/public/version.h.

  Rewrites the TF_MAJOR_VERSION, TF_MINOR_VERSION, TF_PATCH_VERSION and
  TF_VERSION_SUFFIX #defines from the old values to the new ones.

  Args:
    old_version: Version currently present in the header.
    new_version: Version to write.
  """
  substitutions = (
      ("#define TF_MAJOR_VERSION %s", old_version.major, new_version.major),
      ("#define TF_MINOR_VERSION %s", old_version.minor, new_version.minor),
      ("#define TF_PATCH_VERSION %s", old_version.patch, new_version.patch),
      ("#define TF_VERSION_SUFFIX \"%s\"",
       old_version.identifier_string, new_version.identifier_string),
  )
  for template, old_value, new_value in substitutions:
    # replace_string_in_line() treats its first argument as a regular
    # expression; escape it so the old text is matched literally.
    replace_string_in_line(re.escape(template % old_value),
                           template % new_value,
                           VERSION_H)
201
202
def update_setup_dot_py(old_version, new_version):
  """Update the _VERSION assignment in setup.py.

  Args:
    old_version: Version whose string is currently assigned to _VERSION.
    new_version: Version whose string should be assigned instead.
  """
  old_assignment = "_VERSION = '%s'" % old_version.string
  new_assignment = "_VERSION = '%s'" % new_version.string
  # The search text is used as a regular expression; escape it so that the
  # dots in the version match only literal dots.
  replace_string_in_line(re.escape(old_assignment), new_assignment, SETUP_PY)
207
208
def update_setup_partner_builds_dot_py(old_version, new_version):
  """Update the _VERSION assignment in setup_partner_builds.py.

  Args:
    old_version: Version whose string is currently assigned to _VERSION.
    new_version: Version whose string should be assigned instead.
  """
  old_assignment = "_VERSION = '%s'" % old_version.string
  new_assignment = "_VERSION = '%s'" % new_version.string
  # The search text is used as a regular expression; escape it so that the
  # dots in the version match only literal dots.
  replace_string_in_line(re.escape(old_assignment), new_assignment,
                         SETUP_PARTNER_BUILD_PY)
214
215
def update_readme(old_version, new_version):
  """Update README.

  Replaces every "<old major>.<old minor>.<alphanumerics>-" occurrence in the
  README with "<new PEP 440 version>-".

  Args:
    old_version: Version whose major/minor components are searched for.
    new_version: Version whose pep_440_str is written.
  """
  pep_440_str = new_version.pep_440_str
  # Python's re module has no POSIX bracket classes such as [[:alnum:]]
  # (that is sed/grep syntax), so the alphanumeric class is spelled out.
  old_pattern = r"%s\.%s\.([a-zA-Z0-9]+)-" % (re.escape(old_version.major),
                                             re.escape(old_version.minor))
  replace_string_in_line(old_pattern, "%s-" % pep_440_str, README_MD)
222
223
def update_tensorflow_bzl(old_version, new_version):
  """Update tensorflow.bzl.

  Rewrites the VERSION = "<major>.<minor>.<patch>" assignment; the identifier
  string is not part of the value written.

  Args:
    old_version: Version currently recorded in the file.
    new_version: Version to record.
  """
  old_mmp = "%s.%s.%s" % (old_version.major, old_version.minor,
                          old_version.patch)
  new_mmp = "%s.%s.%s" % (new_version.major, new_version.minor,
                          new_version.patch)
  # The search text is used as a regular expression; escape it so that the
  # dots in the version match only literal dots.
  replace_string_in_line(re.escape('VERSION = "%s"' % old_mmp),
                         'VERSION = "%s"' % new_mmp, TENSORFLOW_BZL)
232
233
def major_minor_change(old_version, new_version):
  """Tell whether the major or the minor component differs.

  Args:
    old_version: object exposing `major` and `minor`.
    new_version: object exposing `major` and `minor`.

  Returns:
    True when either component differs between the two versions, False when
    both are equal.
  """
  old_pair = (old_version.major, old_version.minor)
  new_pair = (new_version.major, new_version.minor)
  return old_pair != new_pair
241
242
def check_for_lingering_string(lingering_string):
  """Check for given lingering strings.

  Runs `grep -rnoH` for `lingering_string` (with its dots escaped) over
  TF_SRC_DIR and prints either the matches found or a message that there are
  none.

  Args:
    lingering_string: old version string to look for.
  """
  formatted_string = lingering_string.replace(".", r"\.")
  grep_status = 0
  try:
    linger_str_output = subprocess.check_output(
        ["grep", "-rnoH", formatted_string, TF_SRC_DIR])
  except subprocess.CalledProcessError as err:
    # grep exits with status 1 when nothing matched; any other status is a
    # real error, possibly accompanied by partial output.
    grep_status = err.returncode
    linger_str_output = err.output or b""

  # splitlines() plus the filter drops the empty entry that a trailing
  # newline would otherwise produce; undecodable bytes are replaced rather
  # than aborting the check.
  linger_strs = [
      line for line in linger_str_output.decode("utf8", "replace").splitlines()
      if line
  ]

  if grep_status not in (0, 1):
    print("WARNING: grep exited with status %d while searching for \"%s\"; "
          "the results below may be incomplete."
          % (grep_status, lingering_string))

  if linger_strs:
    print("WARNING: Below are potentially instances of lingering old version "
          "string \"%s\" in source directory \"%s/\" that are not "
          "updated by this script. Please check them manually!"
          % (lingering_string, TF_SRC_DIR))
    for linger_str in linger_strs:
      print(linger_str)
  elif grep_status in (0, 1):
    print("No lingering old version strings \"%s\" found in source directory"
          " \"%s/\". Good." % (lingering_string, TF_SRC_DIR))
263
264
def check_for_old_version(old_version, new_version):
  """Check for old version references.

  Searches for the old SemVer string and its PEP 440 form and, when the major
  or minor version changed, also for the old "r<major>.<minor>" string.

  Args:
    old_version: Version that was replaced.
    new_version: Version that was written.
  """
  candidates = [old_version.string, old_version.pep_440_str]
  if major_minor_change(old_version, new_version):
    candidates.append("r%s.%s" % (old_version.major, old_version.minor))

  # The SemVer and PEP 440 forms are identical when there is no identifier
  # string; search each distinct string only once to avoid duplicate output.
  already_checked = []
  for candidate in candidates:
    if candidate in already_checked:
      continue
    already_checked.append(candidate)
    check_for_lingering_string(candidate)
273
274
def main():
  """This script updates all instances of version in the tensorflow directory.

  Requirements:
    version: The version tag
    OR
    nightly: Create a nightly tag with current date

  Raises:
    RuntimeError: If the script is not being run from tf source dir, or if the
      given version string is invalid.
    SystemExit: Raised by argparse when neither --version nor --nightly is
      given.
  """

  parser = argparse.ArgumentParser(
      description="Update the TensorFlow version in source files.")

  # Arg information
  parser.add_argument("--version",
                      help="<new_major_ver>.<new_minor_ver>.<new_patch_ver>",
                      default="")
  parser.add_argument("--nightly",
                      help="create a nightly version tagged with today's date",
                      action="store_true")

  args = parser.parse_args()

  # Without either flag there is nothing to update to; report it as a usage
  # error instead of failing later on an empty version string.
  if not args.version and not args.nightly:
    parser.error("one of --version or --nightly is required")

  check_all_files()
  old_version = get_current_semver_version()

  if args.nightly:
    nightly_identifier = "-dev" + time.strftime("%Y%m%d")
    if args.version:
      new_version = Version.parse_from_string(args.version, NIGHTLY_VERSION)
      new_version.set_identifier_string(nightly_identifier)
    else:
      # Keep the current major/minor/patch and only stamp today's date.
      new_version = Version(old_version.major,
                            old_version.minor,
                            old_version.patch,
                            nightly_identifier,
                            NIGHTLY_VERSION)
  else:
    new_version = Version.parse_from_string(args.version, REGULAR_VERSION)

  update_version_h(old_version, new_version)
  update_setup_dot_py(old_version, new_version)
  update_setup_partner_builds_dot_py(old_version, new_version)
  update_readme(old_version, new_version)
  update_tensorflow_bzl(old_version, new_version)

  # Print transition details.
  print("Major: %s -> %s" % (old_version.major, new_version.major))
  print("Minor: %s -> %s" % (old_version.minor, new_version.minor))
  print("Patch: %s -> %s\n" % (old_version.patch, new_version.patch))

  check_for_old_version(old_version, new_version)
327
328
# Run the updater only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
  main()
331