xref: /aosp_15_r20/external/tensorflow/tensorflow/python/tools/freeze_graph.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15r"""Converts checkpoint variables into Const ops in a standalone GraphDef file.
16
17This script is designed to take a GraphDef proto, a SaverDef proto, and a set of
18variable values stored in a checkpoint file, and output a GraphDef with all of
19the variable ops converted into const ops containing the values of the
20variables.
21
22It's useful to do this when we need to load a single file in C++, especially in
23environments like mobile or embedded where we may not have access to the
24RestoreTensor ops and file loading calls that they rely on.
25
26An example of command-line usage is:
27bazel build tensorflow/python/tools:freeze_graph && \
28bazel-bin/tensorflow/python/tools/freeze_graph \
29--input_graph=some_graph_def.pb \
30--input_checkpoint=model.ckpt-8361242 \
31--output_graph=/tmp/frozen_graph.pb --output_node_names=softmax
32
33You can also look at freeze_graph_test.py for an example of how to use it.
34
35"""
36import argparse
37import re
38import sys
39
40from absl import app
41
42from google.protobuf import text_format
43from tensorflow.core.framework import graph_pb2
44from tensorflow.core.protobuf import saver_pb2
45from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
46from tensorflow.python.checkpoint import checkpoint_management
47from tensorflow.python.client import session
48from tensorflow.python.framework import graph_util
49from tensorflow.python.framework import importer
50from tensorflow.python.platform import gfile
51from tensorflow.python.saved_model import loader
52from tensorflow.python.saved_model import tag_constants
53from tensorflow.python.tools import saved_model_utils
54from tensorflow.python.training import py_checkpoint_reader
55from tensorflow.python.training import saver as saver_lib
56
57
58def _has_no_variables(sess):
59  """Determines if the graph has any variables.
60
61  Args:
62    sess: TensorFlow Session.
63
64  Returns:
65    Bool.
66  """
67  for op in sess.graph.get_operations():
68    if op.type.startswith("Variable") or op.type.endswith("VariableOp"):
69      return False
70  return True
71
72
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_denylist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  Variable values are restored into a new `Session` from exactly one source,
  chosen in this order: `input_saver_def`, `input_meta_graph_def`,
  `input_saved_model_dir`, or else a `Saver` built from the session-graph
  tensors whose names match keys in `input_checkpoint`. The restored
  variables are then folded into constants with
  `graph_util.convert_variables_to_constants`.

  Args:
    input_graph_def: A `GraphDef`.
    input_saver_def: A `SaverDef` (optional).
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority.  Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen `GraphDef`. Nothing is
      written when it is empty.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated string of initializer nodes to run before
                       freezing. Only run on the `input_meta_graph_def` and
                       plain-checkpoint restore paths.
    variable_names_whitelist: The set of variable names to convert (optional, by
                              default, all variables are converted).
    variable_names_denylist: The set of variable names to omit converting
                              to constants (optional).
    input_meta_graph_def: A `MetaGraphDef` (optional).
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
                           and variables (optional).
    saved_model_tags: Tag(s) of the MetaGraphDef to load, passed to
                      `loader.load` (`freeze_graph` supplies a list of
                      strings); `None` is treated as no tags (optional).
    checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1
                        or saver_pb2.SaverDef.V2)

  Returns:
    The frozen `GraphDef` proto.

  Raises:
    ValueError: If `input_checkpoint` does not exist and no
      `input_saved_model_dir` is given; if `output_node_names` is empty; or,
      on the plain-checkpoint path, if `Saver` rejects the collected tensors
      because the model has partition variables or has no variables at all.
  """
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if (not input_saved_model_dir and
      not checkpoint_management.checkpoint_exists(input_checkpoint)):
    raise ValueError("Input checkpoint '" + input_checkpoint +
                     "' doesn't exist!")

  if not output_node_names:
    raise ValueError(
        "You need to supply the name of a node to --output_node_names.")

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    elif input_graph_def:
      for node in input_graph_def.node:
        node.device = ""

  if input_graph_def:
    _ = importer.import_graph_def(input_graph_def, name="")
  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(
          saver_def=input_saver_def, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      var_list = {}
      reader = py_checkpoint_reader.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()

      # List of all partition variables. Because the condition is heuristic
      # based, the list could include false positives.
      all_partition_variable_names = [
          tensor.name.split(":")[0]
          for op in sess.graph.get_operations()
          for tensor in op.values()
          if re.search(r"/part_\d+/", tensor.name)
      ]
      has_partition_var = False

      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        # Substring match against the heuristic partition list above.
        if any(key in name for name in all_partition_variable_names):
          has_partition_var = True
        var_list[key] = tensor

      try:
        saver = saver_lib.Saver(
            var_list=var_list, write_version=checkpoint_version)
      except TypeError as e:
        # `var_list` is required to be a map of variable names to Variable
        # tensors. Partition variables are Identity tensors that cannot be
        # handled by Saver.
        if has_partition_var:
          raise ValueError(
              "Models containing partition variables cannot be converted "
              "from checkpoint files. Please pass in a SavedModel using "
              "the flag --input_saved_model_dir.") from e
        # Models that have been frozen previously do not contain Variables.
        if _has_no_variables(sess):
          raise ValueError(
              "No variables were found in this model. It is likely the model "
              "was frozen previously. You cannot freeze a graph twice.") from e
        # Unrecognized failure: propagate the original TypeError unchanged.
        raise

      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))

    variable_names_whitelist = (
        variable_names_whitelist.replace(" ", "").split(",")
        if variable_names_whitelist else None)
    variable_names_denylist = (
        variable_names_denylist.replace(" ", "").split(",")
        if variable_names_denylist else None)

    # Freeze against the MetaGraphDef's graph when one was supplied, otherwise
    # against the plain GraphDef.
    graph_def_to_freeze = (
        input_meta_graph_def.graph_def
        if input_meta_graph_def else input_graph_def)
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        graph_def_to_freeze,
        output_node_names.replace(" ", "").split(","),
        variable_names_whitelist=variable_names_whitelist,
        variable_names_blacklist=variable_names_denylist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString(deterministic=True))

  return output_graph_def
237
238
def _parse_input_graph_proto(input_graph, input_binary):
  """Reads the file at `input_graph` and decodes it as a `GraphDef` proto.

  Args:
    input_graph: Path of the graph file.
    input_binary: True to decode a serialized binary proto; False to decode
      protobuf text format.

  Returns:
    The populated `graph_pb2.GraphDef`.

  Raises:
    IOError: If `input_graph` does not exist.
  """
  if not gfile.Exists(input_graph):
    raise IOError("Input graph file '" + input_graph + "' does not exist!")
  graph_def = graph_pb2.GraphDef()
  with gfile.GFile(input_graph, "rb" if input_binary else "r") as graph_file:
    contents = graph_file.read()
  if input_binary:
    graph_def.ParseFromString(contents)
  else:
    text_format.Merge(contents, graph_def)
  return graph_def
251
252
def _parse_input_meta_graph_proto(input_graph, input_binary):
  """Parses input tensorflow graph into MetaGraphDef proto.

  Args:
    input_graph: Path of the MetaGraphDef file to read.
    input_binary: True if the file is a serialized binary proto; False if it
      is in protobuf text format.

  Returns:
    The parsed `MetaGraphDef`.

  Raises:
    IOError: If `input_graph` does not exist.
  """
  if not gfile.Exists(input_graph):
    raise IOError("Input meta graph file '" + input_graph + "' does not exist!")
  input_meta_graph_def = MetaGraphDef()
  mode = "rb" if input_binary else "r"
  with gfile.GFile(input_graph, mode) as f:
    if input_binary:
      input_meta_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), input_meta_graph_def)
  # Quote the path on both sides so the message is well-formed.
  print("Loaded meta graph file '" + input_graph + "'")
  return input_meta_graph_def
266
267
def _parse_input_saver_proto(input_saver, input_binary):
  """Reads the file at `input_saver` and decodes it as a `SaverDef` proto.

  Args:
    input_saver: Path of the saver file.
    input_binary: True to decode a serialized binary proto; False to decode
      protobuf text format.

  Returns:
    The populated `saver_pb2.SaverDef`.

  Raises:
    IOError: If `input_saver` does not exist.
  """
  if not gfile.Exists(input_saver):
    raise IOError("Input saver file '" + input_saver + "' does not exist!")
  saver_def = saver_pb2.SaverDef()
  with gfile.GFile(input_saver, "rb" if input_binary else "r") as saver_file:
    serialized = saver_file.read()
  if input_binary:
    saver_def.ParseFromString(serialized)
  else:
    text_format.Merge(serialized, saver_def)
  return saver_def
280
281
def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_whitelist="",
                 variable_names_denylist="",
                 input_meta_graph=None,
                 input_saved_model_dir=None,
                 saved_model_tags=tag_constants.SERVING,
                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  File-based front end for `freeze_graph_with_def_protos`: reads the graph,
  meta graph and saver protos from disk, then delegates the freezing.

  Args:
    input_graph: A `GraphDef` file to load. Ignored when
      `input_saved_model_dir` is given (the SavedModel's graph is used).
    input_saver: A TensorFlow Saver file.
    input_binary: A Bool. True means input_graph is .pb, False indicates .pbtxt.
      Also selects the format for `input_meta_graph` and `input_saver`.
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority.  Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen `GraphDef`.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated list of initializer nodes to run before
                       freezing.
    variable_names_whitelist: The set of variable names to convert (optional, by
                              default, all variables are converted),
    variable_names_denylist: The set of variable names to omit converting
                              to constants (optional).
    input_meta_graph: A `MetaGraphDef` file to load (optional).
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file and
                           variables (optional).
    saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
                      load, in string format. `None` is treated as no tags
                      when no SavedModel is loaded.
    checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1
                        or saver_pb2.SaverDef.V2).
  Returns:
    The frozen `GraphDef` proto returned by `freeze_graph_with_def_protos`.
  """
  input_graph_def = None
  if input_saved_model_dir:
    input_graph_def = saved_model_utils.get_meta_graph_def(
        input_saved_model_dir, saved_model_tags).graph_def
  elif input_graph:
    input_graph_def = _parse_input_graph_proto(input_graph, input_binary)
  input_meta_graph_def = None
  if input_meta_graph:
    input_meta_graph_def = _parse_input_meta_graph_proto(
        input_meta_graph, input_binary)
  input_saver_def = None
  if input_saver:
    input_saver_def = _parse_input_saver_proto(input_saver, input_binary)
  # Split the comma-separated tag string and drop empty entries. A `None`
  # (or other falsy) value yields no tags instead of an AttributeError, in
  # line with `freeze_graph_with_def_protos`, which also accepts None.
  tag_list = [
      tag for tag in (saved_model_tags or "").replace(" ", "").split(",")
      if tag
  ]
  return freeze_graph_with_def_protos(
      input_graph_def,
      input_saver_def,
      input_checkpoint,
      output_node_names,
      restore_op_name,
      filename_tensor_name,
      output_graph,
      clear_devices,
      initializer_nodes,
      variable_names_whitelist,
      variable_names_denylist,
      input_meta_graph_def,
      input_saved_model_dir,
      tag_list,
      checkpoint_version=checkpoint_version)
358
359
def main(unused_args, flags):
  """Maps `--checkpoint_version` to a SaverDef version and runs freeze_graph.

  Args:
    unused_args: Positional arguments forwarded by `app.run`; ignored.
    flags: Parsed command-line flags (as built by `run_main`).

  Raises:
    ValueError: If `flags.checkpoint_version` is neither 1 nor 2.
  """
  saver_versions = {
      1: saver_pb2.SaverDef.V1,
      2: saver_pb2.SaverDef.V2,
  }
  if flags.checkpoint_version not in saver_versions:
    raise ValueError("Invalid checkpoint version (must be '1' or '2'): %d" %
                     flags.checkpoint_version)
  freeze_graph(
      input_graph=flags.input_graph,
      input_saver=flags.input_saver,
      input_binary=flags.input_binary,
      input_checkpoint=flags.input_checkpoint,
      output_node_names=flags.output_node_names,
      restore_op_name=flags.restore_op_name,
      filename_tensor_name=flags.filename_tensor_name,
      output_graph=flags.output_graph,
      clear_devices=flags.clear_devices,
      initializer_nodes=flags.initializer_nodes,
      variable_names_whitelist=flags.variable_names_whitelist,
      variable_names_denylist=flags.variable_names_denylist,
      input_meta_graph=flags.input_meta_graph,
      input_saved_model_dir=flags.input_saved_model_dir,
      saved_model_tags=flags.saved_model_tags,
      checkpoint_version=saver_versions[flags.checkpoint_version])
375
376
def run_main():
  """Parses freeze_graph's command-line flags and hands off to `main`.

  Flags this parser recognizes are consumed here; any unrecognized arguments
  are forwarded (after the program name) to `absl.app.run` as its argv.
  """

  def _str_to_bool(value):
    # Case-insensitive: only the literal "true" is truthy, anything else is
    # False.
    return value.lower() == "true"

  parser = argparse.ArgumentParser()
  parser.register("type", "bool", _str_to_bool)
  add_flag = parser.add_argument

  # Boolean flags take an optional value: a bare `--flag` yields True, while
  # `--flag=<v>` is converted with `_str_to_bool`.
  optional_bool = {"nargs": "?", "const": True, "type": "bool"}

  add_flag("--input_graph", type=str, default="",
           help="TensorFlow 'GraphDef' file to load.")
  add_flag("--input_saver", type=str, default="",
           help="TensorFlow saver file to load.")
  add_flag("--input_checkpoint", type=str, default="",
           help="TensorFlow variables file to load.")
  add_flag("--checkpoint_version", type=int, default=2,
           help="Tensorflow variable file format")
  add_flag("--output_graph", type=str, default="",
           help="Output 'GraphDef' file name.")
  add_flag("--input_binary", default=False,
           help="Whether the input files are in binary format.",
           **optional_bool)
  add_flag("--output_node_names", type=str, default="",
           help="The name of the output nodes, comma separated.")
  add_flag(
      "--restore_op_name",
      type=str,
      default="save/restore_all",
      help="""\
      The name of the master restore operator. Deprecated, unused by updated \
      loading code.
      """)
  add_flag(
      "--filename_tensor_name",
      type=str,
      default="save/Const:0",
      help="""\
      The name of the tensor holding the save path. Deprecated, unused by \
      updated loading code.
      """)
  add_flag("--clear_devices", default=True,
           help="Whether to remove device specifications.",
           **optional_bool)
  add_flag("--initializer_nodes", type=str, default="",
           help="Comma separated list of initializer nodes to run before "
           "freezing.")
  add_flag(
      "--variable_names_whitelist",
      type=str,
      default="",
      help="""\
      Comma separated list of variables to convert to constants. If specified, \
      only those variables will be converted to constants.\
      """)
  add_flag(
      "--variable_names_denylist",
      type=str,
      default="",
      help="""\
      Comma separated list of variables to skip converting to constants.\
      """)
  add_flag("--input_meta_graph", type=str, default="",
           help="TensorFlow 'MetaGraphDef' file to load.")
  add_flag("--input_saved_model_dir", type=str, default="",
           help="Path to the dir with TensorFlow 'SavedModel' file and "
           "variables.")
  add_flag(
      "--saved_model_tags",
      type=str,
      default="serve",
      help="""\
      Group of tag(s) of the MetaGraphDef to load, in string format,\
      separated by \',\'. For tag-set contains multiple tags, all tags \
      must be passed in.\
      """)
  flags, unparsed = parser.parse_known_args()

  def _run(unused_args):
    # Closes over the parsed flags so absl only has to supply argv.
    return main(unused_args, flags)

  app.run(main=_run, argv=[sys.argv[0], *unparsed])


if __name__ == "__main__":
  run_main()
488