# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests the graph freezing tool."""

import os
import re

from absl.testing import parameterized

from tensorflow.core.example import example_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.tools import freeze_graph
from tensorflow.python.training import saver as saver_lib


class FreezeGraphTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Tests `freeze_graph` with checkpoint, MetaGraph and SavedModel inputs."""

  def _loadFrozenGraphDef(self, frozen_graph_path):
    """Parses a binary GraphDef file and imports it into the default graph.

    Args:
      frozen_graph_path: Path of the serialized `GraphDef` file to read.

    Returns:
      The parsed `graph_pb2.GraphDef`.
    """
    frozen_graph_def = graph_pb2.GraphDef()
    with open(frozen_graph_path, "rb") as f:
      frozen_graph_def.ParseFromString(f.read())
    importer.import_graph_def(frozen_graph_def, name="")
    return frozen_graph_def

  def _assertNoVariableOps(self, graph_def):
    """Asserts that no node in `graph_def` has op "VariableV2" or "Variable"."""
    for node in graph_def.node:
      self.assertNotEqual("VariableV2", node.op)
      self.assertNotEqual("Variable", node.op)

  def _assertFrozenMultiplyGraph(self, frozen_graph_path):
    """Checks a frozen graph built from `variable_node * 2.0`.

    Asserts that the graph stored at `frozen_graph_path` has exactly four
    nodes, contains no variable ops, and that running "output_node:0" yields
    2.0.

    Args:
      frozen_graph_path: Path of the frozen binary `GraphDef` to verify.
    """
    with ops.Graph().as_default():
      frozen_graph_def = self._loadFrozenGraphDef(frozen_graph_path)

      self.assertEqual(4, len(frozen_graph_def.node))
      self._assertNoVariableOps(frozen_graph_def)

      with session.Session() as sess:
        output_node = sess.graph.get_tensor_by_name("output_node:0")
        output = sess.run(output_node)
        self.assertNear(2.0, output, 0.00001)

  def _testFreezeGraph(self, saver_write_version):
    """Freezes a one-variable graph from a checkpoint and verifies the result.

    Args:
      saver_write_version: Passed as `write_version` to the `Saver` that writes
        the checkpoint and as `checkpoint_version` to
        `freeze_graph.freeze_graph`.
    """
    checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint")
    checkpoint_state_name = "checkpoint_state"
    input_graph_name = "input_graph.pb"
    output_graph_name = "output_graph.pb"

    # We'll create an input graph that has a single variable containing 1.0,
    # and that then multiplies it by 2.
    with ops.Graph().as_default():
      variable_node = variables.VariableV1(1.0, name="variable_node")
      output_node = math_ops.multiply(variable_node, 2.0, name="output_node")
      # The context manager closes the session once the checkpoint and the
      # graph have been written, so it is not leaked.
      with session.Session() as sess:
        init = variables.global_variables_initializer()
        sess.run(init)
        output = sess.run(output_node)
        self.assertNear(2.0, output, 0.00001)
        saver = saver_lib.Saver(write_version=saver_write_version)
        checkpoint_path = saver.save(
            sess,
            checkpoint_prefix,
            global_step=0,
            latest_filename=checkpoint_state_name)
        graph_io.write_graph(sess.graph, self.get_temp_dir(), input_graph_name)

    # We save out the graph to disk, and then call the const conversion
    # routine.
    input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name)
    input_saver_def_path = ""
    input_binary = False
    output_node_names = "output_node"
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"
    output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)
    clear_devices = False

    freeze_graph.freeze_graph(
        input_graph_path,
        input_saver_def_path,
        input_binary,
        checkpoint_path,
        output_node_names,
        restore_op_name,
        filename_tensor_name,
        output_graph_path,
        clear_devices,
        "",
        "",
        "",
        checkpoint_version=saver_write_version)

    # Make sure the variable is now a constant, and that the graph still
    # produces the expected result.
    self._assertFrozenMultiplyGraph(output_graph_path)

  def _createTFExampleString(self, feature_name, feature_value):
    """Creates a serialized tensorflow example.

    Args:
      feature_name: Key of the feature to set in the example.
      feature_value: Value appended to that feature's `float_list`.

    Returns:
      The serialized `example_pb2.Example`.
    """
    example = example_pb2.Example()
    example.features.feature[feature_name].float_list.value.extend([
        feature_value])
    return example.SerializeToString()

  def _writeDummySavedModel(self, path, feature_name, tags):
    """Writes a classifier with a single float input feature to `path`.

    The model parses `feature_name` from serialized examples fed to
    "input_node", multiplies it by a variable initialized to 1.0 (the
    "output_node" op), and is exported with a classification signature under
    the default serving signature key.

    Args:
      path: Directory handed to `SavedModelBuilder`.
      feature_name: Name of the scalar float32 feature parsed from the input.
      tags: Tags passed to `add_meta_graph_and_variables`.
    """
    with ops.Graph().as_default():
      examples = array_ops.placeholder(dtypes.string, name="input_node")
      feature_configs = {
          feature_name: parsing_ops.FixedLenFeature(shape=[],
                                                    dtype=dtypes.float32),
      }
      features = parsing_ops.parse_example(examples, feature_configs)
      feature = features[feature_name]

      variable_node = variables.VariableV1(1.0, name="variable_node")
      scores = math_ops.multiply(variable_node, feature, name="output_node")
      class_feature = array_ops.fill(array_ops.shape(feature),
                                     "class_%s" % feature_name)
      classes = array_ops.transpose(class_feature)

      with session.Session() as sess:
        sess.run(variables.global_variables_initializer())
        signature = (
            signature_def_utils.classification_signature_def(
                examples=examples,
                classes=classes,
                scores=scores,))
        builder = saved_model_builder.SavedModelBuilder(path)
        builder.add_meta_graph_and_variables(
            sess,
            tags,
            signature_def_map={
                signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                    signature,
            },
        )
        builder.save(as_text=True)

  @test_util.run_v1_only("b/120545219")
  def testFreezeGraphV1(self):
    """Runs the checkpoint freezing test with `SaverDef.V1`."""
    self._testFreezeGraph(saver_pb2.SaverDef.V1)

  @test_util.run_v1_only("b/120545219")
  def testFreezeGraphV2(self):
    """Runs the checkpoint freezing test with `SaverDef.V2`."""
    self._testFreezeGraph(saver_pb2.SaverDef.V2)

  def testFreezeMetaGraph(self):
    """Freezes a graph from the `.meta` file written next to a checkpoint."""
    tmp_dir = self.get_temp_dir()
    checkpoint_prefix = os.path.join(tmp_dir, "meta_graph_checkpoint")
    checkpoint_state_name = "checkpoint_state"
    output_graph_filename = os.path.join(tmp_dir, "output_graph.pb")

    with ops.Graph().as_default():
      variable_node = variables.VariableV1(1.0, name="variable_node")
      output_node = math_ops.multiply(variable_node, 2.0, name="output_node")
      # The context manager closes the session once the checkpoint has been
      # written, so it is not leaked.
      with session.Session() as sess:
        init = variables.global_variables_initializer()
        sess.run(init)
        output = sess.run(output_node)
        self.assertNear(2.0, output, 0.00001)
        saver = saver_lib.Saver()
        checkpoint_path = saver.save(
            sess,
            checkpoint_prefix,
            global_step=0,
            latest_filename=checkpoint_state_name)

    input_saver_def_path = ""
    input_binary = True
    output_node_names = "output_node"
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"
    clear_devices = False
    input_meta_graph = checkpoint_path + ".meta"

    freeze_graph.freeze_graph(
        "", input_saver_def_path, input_binary, checkpoint_path,
        output_node_names, restore_op_name, filename_tensor_name,
        output_graph_filename, clear_devices, "", "", "", input_meta_graph)

    # Make sure the variable is now a constant, and that the graph still
    # produces the expected result.
    self._assertFrozenMultiplyGraph(output_graph_filename)

  @parameterized.named_parameters(
      ("empty_tags_set", "", []),
      ("default_tags_set", tag_constants.SERVING, [tag_constants.SERVING]))
  def testFreezeSavedModel(self, tags_string, tags_list):
    """Freezes a graph loaded from a SavedModel directory.

    Args:
      tags_string: Tag string passed to `freeze_graph` as `saved_model_tags`.
      tags_list: Tags used when writing the SavedModel.
    """
    tmp_dir = self.get_temp_dir()
    saved_model_dir = os.path.join(tmp_dir, "saved_model_dir")
    feature_name = "feature"
    self._writeDummySavedModel(saved_model_dir, feature_name, tags_list)
    output_graph_filename = os.path.join(tmp_dir, "output_graph.pb")

    input_saved_model_dir = saved_model_dir
    output_node_names = "output_node"
    input_binary = False
    input_saver_def_path = False
    restore_op_name = None
    filename_tensor_name = None
    clear_devices = False
    input_meta_graph = False
    checkpoint_path = None
    input_graph_filename = None
    saved_model_tags = tags_string

    freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path,
                              input_binary, checkpoint_path, output_node_names,
                              restore_op_name, filename_tensor_name,
                              output_graph_filename, clear_devices, "", "", "",
                              input_meta_graph, input_saved_model_dir,
                              saved_model_tags)

    # Make sure the variable is now a constant, and that the graph still
    # produces the expected result.
    with ops.Graph().as_default():
      output_graph_def = self._loadFrozenGraphDef(output_graph_filename)

      # The expected node count depends on whether the frozen graph contains a
      # node whose name includes "ParseExampleV2".
      if any(u"ParseExampleV2" in node.name for node in output_graph_def.node):
        expected_node_count = 10
      else:
        expected_node_count = 8
      self.assertEqual(expected_node_count, len(output_graph_def.node))
      self._assertNoVariableOps(output_graph_def)

      feature_value = 2.0
      example = self._createTFExampleString(feature_name, feature_value)
      with session.Session() as sess:
        input_node = sess.graph.get_tensor_by_name("input_node:0")
        output_node = sess.graph.get_tensor_by_name("output_node:0")
        output = sess.run(output_node, feed_dict={input_node: [example]})
        self.assertNear(feature_value, output, 0.00001)

  def testSinglePartitionedVariable(self):
    """Ensures partitioned variables fail cleanly with freeze graph."""
    checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint")
    checkpoint_state_name = "checkpoint_state"
    input_graph_name = "input_graph.pb"
    output_graph_name = "output_graph.pb"

    # Create a graph with partition variables. When weights are partitioned into
    # a single partition, the weights variable is followed by an identity ->
    # identity (an additional identity node).
    partitioner = partitioned_variables.fixed_size_partitioner(1)
    with ops.Graph().as_default():
      with variable_scope.variable_scope("part", partitioner=partitioner):
        batch_size, height, width, depth = 5, 128, 128, 3
        input1 = array_ops.zeros(
            (batch_size, height, width, depth), name="input1")
        input2 = array_ops.zeros(
            (batch_size, height, width, depth), name="input2")

        num_nodes = depth
        filter1 = variable_scope.get_variable("filter", [num_nodes, num_nodes])
        filter2 = array_ops.reshape(filter1, [1, 1, num_nodes, num_nodes])
        conv = nn.conv2d(
            input=input1, filter=filter2, strides=[1, 1, 1, 1], padding="SAME")
        node = math_ops.add(conv, input2, name="test/add")
        node = nn.relu6(node, name="test/relu6")

      # Save graph and checkpoints. The context manager closes the session
      # afterwards, so it is not leaked.
      with session.Session() as sess:
        sess.run(variables.global_variables_initializer())

        saver = saver_lib.Saver()
        checkpoint_path = saver.save(
            sess,
            checkpoint_prefix,
            global_step=0,
            latest_filename=checkpoint_state_name)
        graph_io.write_graph(sess.graph, self.get_temp_dir(), input_graph_name)

        # Ensure this graph has partition variables.
        self.assertTrue([
            tensor.name.split(":")[0]
            for op in sess.graph.get_operations()
            for tensor in op.values()
            if re.search(r"/part_\d+/", tensor.name)
        ])

        # Captured while the session is still open; the freeze call below runs
        # after the session has been closed.
        partitioned_graph_def = sess.graph_def

    # Freezing a graph with partitioned variables from a checkpoint is expected
    # to fail with a ValueError.
    output_node_names = "save/restore_all"
    output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)

    with self.assertRaises(ValueError):
      freeze_graph.freeze_graph_with_def_protos(
          input_graph_def=partitioned_graph_def,
          input_saver_def=None,
          input_checkpoint=checkpoint_path,
          output_node_names=output_node_names,
          restore_op_name="save/restore_all",  # default value
          filename_tensor_name="save/Const:0",  # default value
          output_graph=output_graph_path,
          clear_devices=False,
          initializer_nodes="")


if __name__ == "__main__":
  test.main()