# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Converts checkpoint variables into Const ops in a standalone GraphDef file.

This script is designed to take a GraphDef proto, a SaverDef proto, and a set of
variable values stored in a checkpoint file, and output a GraphDef with all of
the variable ops converted into const ops containing the values of the
variables.

It's useful to do this when we need to load a single file in C++, especially in
environments like mobile or embedded where we may not have access to the
RestoreTensor ops and file loading calls that they rely on.

An example of command-line usage is:
bazel build tensorflow/python/tools:freeze_graph && \
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=some_graph_def.pb \
--input_checkpoint=model.ckpt-8361242 \
--output_graph=/tmp/frozen_graph.pb --output_node_names=softmax

You can also look at freeze_graph_test.py for an example of how to use it.

"""
import argparse
import re
import sys

from absl import app

from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training import saver as saver_lib


def _has_no_variables(sess):
  """Determines if the session's graph contains no variable ops.

  An op counts as a variable when its type starts with "Variable" or ends
  with "VariableOp".

  Args:
    sess: TensorFlow Session.

  Returns:
    True if no variable op was found, False otherwise.
  """
  return not any(
      op.type.startswith("Variable") or op.type.endswith("VariableOp")
      for op in sess.graph.get_operations())


def _split_comma_list(value):
  """Splits a comma separated string into its non-empty items.

  All spaces are removed first, so "a, b" and "a,b" both give ["a", "b"].
  Empty items, such as the one produced by a trailing comma, are dropped.

  Args:
    value: Comma separated string.

  Returns:
    List of the non-empty items, in their original order.
  """
  return [item for item in value.replace(" ", "").split(",") if item]


def _build_checkpoint_saver(sess, input_checkpoint, checkpoint_version):
  """Builds a `Saver` that maps checkpoint keys onto tensors in `sess.graph`.

  Each checkpoint key is matched against the graph tensor named "<key>:0".
  Keys with no such tensor are skipped.

  Args:
    sess: Session whose graph holds the variables to restore.
    input_checkpoint: Checkpoint prefix whose variable names are read.
    checkpoint_version: saver_pb2.SaverDef.V1 or saver_pb2.SaverDef.V2.

  Returns:
    A `saver_lib.Saver` over the matched tensors.

  Raises:
    ValueError: If the `Saver` cannot be built and the graph either looks like
      it uses partition variables or contains no variables at all.
    TypeError: Re-raised from `Saver` construction in any other case.
  """
  reader = py_checkpoint_reader.NewCheckpointReader(input_checkpoint)
  var_to_shape_map = reader.get_variable_to_shape_map()

  # List of all partition variables. Because the condition is heuristic
  # based, the list could include false positives.
  all_partition_variable_names = [
      tensor.name.split(":")[0]
      for op in sess.graph.get_operations()
      for tensor in op.values()
      if re.search(r"/part_\d+/", tensor.name)
  ]
  has_partition_var = False

  var_list = {}
  for key in var_to_shape_map:
    try:
      tensor = sess.graph.get_tensor_by_name(key + ":0")
    except KeyError:
      # This tensor doesn't exist in the graph (for example it's
      # 'global_step' or a similar housekeeping element) so skip it.
      continue
    if any(key in name for name in all_partition_variable_names):
      has_partition_var = True
    var_list[key] = tensor

  try:
    return saver_lib.Saver(var_list=var_list, write_version=checkpoint_version)
  except TypeError as e:
    # `var_list` is required to be a map of variable names to Variable
    # tensors. Partition variables are Identity tensors that cannot be
    # handled by Saver.
    if has_partition_var:
      raise ValueError(
          "Models containing partition variables cannot be converted "
          "from checkpoint files. Please pass in a SavedModel using "
          "the flag --input_saved_model_dir.") from e
    # Models that have been frozen previously do not contain Variables.
    if _has_no_variables(sess):
      raise ValueError(
          "No variables were found in this model. It is likely the model "
          "was frozen previously. You cannot freeze a graph twice.") from e
    raise


def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_denylist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  Variable values are restored from one source, chosen in this order of
  precedence: `input_saver_def`, `input_meta_graph_def`,
  `input_saved_model_dir`. If none of these is given, the values are read
  directly from the checkpoint at `input_checkpoint`.

  Args:
    input_graph_def: A `GraphDef`.
    input_saver_def: A `SaverDef` (optional).
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority. Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
      Spaces and empty entries are ignored.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen `GraphDef`. Nothing is
      written when it is empty.
    clear_devices: A Bool whether to remove device specifications. Note that
      this mutates the passed-in `GraphDef`/`MetaGraphDef` protos in place.
    initializer_nodes: Comma separated string of initializer nodes to run before
      freezing. Not run when restoring through `input_saver_def` or
      `input_saved_model_dir`.
    variable_names_whitelist: The set of variable names to convert (optional, by
      default, all variables are converted).
    variable_names_denylist: The set of variable names to omit converting
      to constants (optional).
    input_meta_graph_def: A `MetaGraphDef` (optional).
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
      and variables (optional).
    saved_model_tags: Sequence of tag strings selecting the MetaGraphDef to
      load from `input_saved_model_dir` (optional; `None` means no tags).
      `freeze_graph` passes a list here.
    checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1
      or saver_pb2.SaverDef.V2)

  Returns:
    The frozen `GraphDef`.

  Raises:
    ValueError: If the checkpoint does not exist and no SavedModel dir was
      given, if no output node name is supplied, or if the checkpoint cannot
      be mapped onto the graph's variables.
  """
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if (not input_saved_model_dir and
      not checkpoint_management.checkpoint_exists(input_checkpoint)):
    raise ValueError("Input checkpoint '" + input_checkpoint +
                     "' doesn't exist!")

  # Drop empty entries (e.g. from "softmax,") that would otherwise name a
  # node that cannot exist in the graph.
  output_node_list = (
      _split_comma_list(output_node_names) if output_node_names else [])
  if not output_node_list:
    raise ValueError(
        "You need to supply the name of a node to --output_node_names.")

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    elif input_graph_def:
      for node in input_graph_def.node:
        node.device = ""

  if input_graph_def:
    importer.import_graph_def(input_graph_def, name="")
  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(
          saver_def=input_saver_def, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      # NOTE(review): devices are always cleared here, independent of
      # `clear_devices`; confirm this is intended.
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(_split_comma_list(initializer_nodes))
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      saver = _build_checkpoint_saver(sess, input_checkpoint,
                                      checkpoint_version)
      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(_split_comma_list(initializer_nodes))

    variable_names_whitelist = (
        _split_comma_list(variable_names_whitelist)
        if variable_names_whitelist else None)
    variable_names_denylist = (
        _split_comma_list(variable_names_denylist)
        if variable_names_denylist else None)

    # Freeze the GraphDef embedded in the MetaGraphDef when one was given,
    # otherwise the plain input GraphDef.
    graph_def_to_freeze = (
        input_meta_graph_def.graph_def
        if input_meta_graph_def else input_graph_def)
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        graph_def_to_freeze,
        output_node_list,
        variable_names_whitelist=variable_names_whitelist,
        variable_names_blacklist=variable_names_denylist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString(deterministic=True))

  return output_graph_def


def _parse_proto_file(path, proto, input_binary, description):
  """Reads the file at `path` into `proto`.

  Args:
    path: Path of the file to read.
    proto: Protocol buffer message instance to fill.
    input_binary: True if the file holds a binary-serialized proto, False if
      it holds text format.
    description: Kind of file, used in the error message (e.g. "graph").

  Returns:
    `proto`, populated from the file.

  Raises:
    IOError: If `path` does not exist.
  """
  if not gfile.Exists(path):
    raise IOError("Input " + description + " file '" + path +
                  "' does not exist!")
  mode = "rb" if input_binary else "r"
  with gfile.GFile(path, mode) as f:
    if input_binary:
      proto.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), proto)
  return proto


def _parse_input_graph_proto(input_graph, input_binary):
  """Parses input tensorflow graph into GraphDef proto."""
  return _parse_proto_file(input_graph, graph_pb2.GraphDef(), input_binary,
                           "graph")


def _parse_input_meta_graph_proto(input_graph, input_binary):
  """Parses input tensorflow graph into MetaGraphDef proto."""
  input_meta_graph_def = _parse_proto_file(input_graph, MetaGraphDef(),
                                           input_binary, "meta graph")
  print("Loaded meta graph file '" + input_graph + "'")
  return input_meta_graph_def


def _parse_input_saver_proto(input_saver, input_binary):
  """Parses input tensorflow Saver into SaverDef proto."""
  return _parse_proto_file(input_saver, saver_pb2.SaverDef(), input_binary,
                           "saver")


def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_whitelist="",
                 variable_names_denylist="",
                 input_meta_graph=None,
                 input_saved_model_dir=None,
                 saved_model_tags=tag_constants.SERVING,
                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  Args:
    input_graph: A `GraphDef` file to load. Ignored when
      `input_saved_model_dir` is given, in which case the graph comes from the
      SavedModel.
    input_saver: A TensorFlow Saver file.
    input_binary: A Bool. True means the input graph, meta graph and saver
      files are binary protos (.pb); False means text format (.pbtxt).
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority. Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen `GraphDef`.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated list of initializer nodes to run before
      freezing.
    variable_names_whitelist: The set of variable names to convert (optional, by
      default, all variables are converted).
    variable_names_denylist: The set of variable names to omit converting
      to constants (optional).
    input_meta_graph: A `MetaGraphDef` file to load (optional).
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file and
      variables (optional).
    saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
      load, in string format. An empty string or `None` selects no tags.
    checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1
      or saver_pb2.SaverDef.V2).

  Returns:
    The frozen `GraphDef`. It is also written to `output_graph` when that
    path is non-empty.
  """
  input_graph_def = None
  if input_saved_model_dir:
    input_graph_def = saved_model_utils.get_meta_graph_def(
        input_saved_model_dir, saved_model_tags).graph_def
  elif input_graph:
    input_graph_def = _parse_input_graph_proto(input_graph, input_binary)
  input_meta_graph_def = None
  if input_meta_graph:
    input_meta_graph_def = _parse_input_meta_graph_proto(
        input_meta_graph, input_binary)
  input_saver_def = None
  if input_saver:
    input_saver_def = _parse_input_saver_proto(input_saver, input_binary)
  # Treat a missing tag string as "no tags" instead of failing on `.replace`.
  saved_model_tag_list = (
      _split_comma_list(saved_model_tags) if saved_model_tags else [])
  return freeze_graph_with_def_protos(
      input_graph_def,
      input_saver_def,
      input_checkpoint,
      output_node_names,
      restore_op_name,
      filename_tensor_name,
      output_graph,
      clear_devices,
      initializer_nodes,
      variable_names_whitelist,
      variable_names_denylist,
      input_meta_graph_def,
      input_saved_model_dir,
      saved_model_tag_list,
      checkpoint_version=checkpoint_version)


def main(unused_args, flags):
  """Runs `freeze_graph` with the values parsed from the command line.

  Args:
    unused_args: Arguments forwarded by `app.run`; ignored.
    flags: Namespace of parsed command-line flags built by `run_main`.

  Raises:
    ValueError: If `flags.checkpoint_version` is neither 1 nor 2.
  """
  if flags.checkpoint_version == 1:
    checkpoint_version = saver_pb2.SaverDef.V1
  elif flags.checkpoint_version == 2:
    checkpoint_version = saver_pb2.SaverDef.V2
  else:
    raise ValueError("Invalid checkpoint version (must be '1' or '2'): %d" %
                     flags.checkpoint_version)
  freeze_graph(flags.input_graph, flags.input_saver, flags.input_binary,
               flags.input_checkpoint, flags.output_node_names,
               flags.restore_op_name, flags.filename_tensor_name,
               flags.output_graph, flags.clear_devices, flags.initializer_nodes,
               flags.variable_names_whitelist, flags.variable_names_denylist,
               flags.input_meta_graph, flags.input_saved_model_dir,
               flags.saved_model_tags, checkpoint_version)


def run_main():
  """Main function of freeze_graph.

  Parses the command-line flags defined below. Any arguments argparse does not
  recognize are handed to `app.run`, whose callback invokes `main` with the
  parsed flags.
  """
  parser = argparse.ArgumentParser()
  # argparse has no built-in bool type: "true" (any case) maps to True and
  # every other value maps to False.
  parser.register("type", "bool", lambda v: v.lower() == "true")
  parser.add_argument(
      "--input_graph",
      type=str,
      default="",
      help="TensorFlow 'GraphDef' file to load.")
  parser.add_argument(
      "--input_saver",
      type=str,
      default="",
      help="TensorFlow saver file to load.")
  parser.add_argument(
      "--input_checkpoint",
      type=str,
      default="",
      help="TensorFlow variables file to load.")
  parser.add_argument(
      "--checkpoint_version",
      type=int,
      default=2,
      help="Tensorflow variable file format")
  parser.add_argument(
      "--output_graph",
      type=str,
      default="",
      help="Output 'GraphDef' file name.")
  parser.add_argument(
      "--input_binary",
      nargs="?",
      const=True,
      type="bool",
      default=False,
      help="Whether the input files are in binary format.")
  parser.add_argument(
      "--output_node_names",
      type=str,
      default="",
      help="The name of the output nodes, comma separated.")
  parser.add_argument(
      "--restore_op_name",
      type=str,
      default="save/restore_all",
      help="""\
      The name of the master restore operator. Deprecated, unused by updated \
      loading code.
      """)
  parser.add_argument(
      "--filename_tensor_name",
      type=str,
      default="save/Const:0",
      help="""\
      The name of the tensor holding the save path. Deprecated, unused by \
      updated loading code.
      """)
  parser.add_argument(
      "--clear_devices",
      nargs="?",
      const=True,
      type="bool",
      default=True,
      help="Whether to remove device specifications.")
  parser.add_argument(
      "--initializer_nodes",
      type=str,
      default="",
      help="Comma separated list of initializer nodes to run before freezing.")
  parser.add_argument(
      "--variable_names_whitelist",
      type=str,
      default="",
      help="""\
      Comma separated list of variables to convert to constants. If specified, \
      only those variables will be converted to constants.\
      """)
  parser.add_argument(
      "--variable_names_denylist",
      type=str,
      default="",
      help="""\
      Comma separated list of variables to skip converting to constants.\
      """)
  parser.add_argument(
      "--input_meta_graph",
      type=str,
      default="",
      help="TensorFlow 'MetaGraphDef' file to load.")
  parser.add_argument(
      "--input_saved_model_dir",
      type=str,
      default="",
      help="Path to the dir with TensorFlow 'SavedModel' file and variables.")
  parser.add_argument(
      "--saved_model_tags",
      type=str,
      default="serve",
      help="""\
      Group of tag(s) of the MetaGraphDef to load, in string format,\
      separated by ','. For tag-set contains multiple tags, all tags \
      must be passed in.\
      """)
  flags, unparsed = parser.parse_known_args()

  my_main = lambda unused_args: main(unused_args, flags)
  app.run(main=my_main, argv=[sys.argv[0]] + unparsed)


if __name__ == "__main__":
  run_main()