xref: /aosp_15_r20/external/tensorflow/tensorflow/python/tools/freeze_graph_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests the graph freezing tool."""
16
17import os
18import re
19
20from absl.testing import parameterized
21
22from tensorflow.core.example import example_pb2
23from tensorflow.core.framework import graph_pb2
24from tensorflow.core.protobuf import saver_pb2
25from tensorflow.python.client import session
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import graph_io
28from tensorflow.python.framework import importer
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import test_util
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops import math_ops
33from tensorflow.python.ops import nn
34from tensorflow.python.ops import parsing_ops
35from tensorflow.python.ops import partitioned_variables
36from tensorflow.python.ops import variable_scope
37from tensorflow.python.ops import variables
38from tensorflow.python.platform import test
39from tensorflow.python.saved_model import builder as saved_model_builder
40from tensorflow.python.saved_model import signature_constants
41from tensorflow.python.saved_model import signature_def_utils
42from tensorflow.python.saved_model import tag_constants
43from tensorflow.python.tools import freeze_graph
44from tensorflow.python.training import saver as saver_lib
45
46
class FreezeGraphTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Tests `freeze_graph.freeze_graph` and `freeze_graph_with_def_protos`."""

  def _readGraphDef(self, graph_def_path):
    """Parses and returns the serialized `GraphDef` stored at a path.

    Args:
      graph_def_path: Path of a binary-serialized `GraphDef` file.

    Returns:
      The parsed `graph_pb2.GraphDef`.
    """
    graph_def = graph_pb2.GraphDef()
    with open(graph_def_path, "rb") as f:
      graph_def.ParseFromString(f.read())
    return graph_def

  def _assertNoVariableOps(self, graph_def):
    """Asserts that no node of `graph_def` is still a variable op."""
    for node in graph_def.node:
      self.assertNotEqual("VariableV2", node.op)
      self.assertNotEqual("Variable", node.op)

  def _writeMultiplyByTwoCheckpoint(self,
                                    checkpoint_prefix,
                                    checkpoint_state_name,
                                    saver_write_version=saver_pb2.SaverDef.V2,
                                    input_graph_name=None):
    """Checkpoints a graph computing `output_node = variable_node * 2.0`.

    `variable_node` is initialized to 1.0, so `output_node` evaluates to 2.0
    both before and after freezing.

    Args:
      checkpoint_prefix: Path prefix passed to `Saver.save`.
      checkpoint_state_name: Passed to `Saver.save` as `latest_filename`.
      saver_write_version: `SaverDef` version used by the `Saver`; V2 matches
        the `Saver` constructor default.
      input_graph_name: If not None, the graph is additionally written with
        `graph_io.write_graph` under this file name in `self.get_temp_dir()`.

    Returns:
      The checkpoint path returned by `Saver.save`.
    """
    with ops.Graph().as_default():
      variable_node = variables.VariableV1(1.0, name="variable_node")
      output_node = math_ops.multiply(variable_node, 2.0, name="output_node")
      init = variables.global_variables_initializer()
      # Use the session as a context manager so it is closed once the
      # checkpoint has been written.
      with session.Session() as sess:
        sess.run(init)
        self.assertNear(2.0, sess.run(output_node), 0.00001)
        saver = saver_lib.Saver(write_version=saver_write_version)
        checkpoint_path = saver.save(
            sess,
            checkpoint_prefix,
            global_step=0,
            latest_filename=checkpoint_state_name)
        if input_graph_name is not None:
          graph_io.write_graph(sess.graph, self.get_temp_dir(),
                               input_graph_name)
    return checkpoint_path

  def _assertFrozenMultiplyByTwoGraph(self, output_graph_path):
    """Verifies the frozen form of the multiply-by-two graph.

    The frozen graph must contain exactly four nodes and no variable ops, and
    `output_node` must still evaluate to 2.0.

    Args:
      output_graph_path: Path of the frozen `GraphDef` written by
        `freeze_graph.freeze_graph`.
    """
    with ops.Graph().as_default():
      output_graph_def = self._readGraphDef(output_graph_path)
      importer.import_graph_def(output_graph_def, name="")

      self.assertEqual(4, len(output_graph_def.node))
      self._assertNoVariableOps(output_graph_def)

      with session.Session() as sess:
        output_node = sess.graph.get_tensor_by_name("output_node:0")
        self.assertNear(2.0, sess.run(output_node), 0.00001)

  def _testFreezeGraph(self, saver_write_version):
    """Freezes a checkpointed graph that is read from a GraphDef file.

    Args:
      saver_write_version: `SaverDef` version used both to write the
        checkpoint and as `checkpoint_version` when freezing.
    """
    checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint")
    checkpoint_state_name = "checkpoint_state"
    input_graph_name = "input_graph.pb"
    output_graph_name = "output_graph.pb"

    # Build a graph with a single variable holding 1.0 that is multiplied by
    # 2, then save both the checkpoint and the graph to disk.
    checkpoint_path = self._writeMultiplyByTwoCheckpoint(
        checkpoint_prefix,
        checkpoint_state_name,
        saver_write_version=saver_write_version,
        input_graph_name=input_graph_name)

    # Run the variable-to-constant conversion on the saved graph.
    input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name)
    input_saver_def_path = ""
    input_binary = False
    output_node_names = "output_node"
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"
    output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)
    clear_devices = False

    freeze_graph.freeze_graph(
        input_graph_path,
        input_saver_def_path,
        input_binary,
        checkpoint_path,
        output_node_names,
        restore_op_name,
        filename_tensor_name,
        output_graph_path,
        clear_devices,
        "",
        "",
        "",
        checkpoint_version=saver_write_version)

    # The variable must now be a constant and the result must be unchanged.
    self._assertFrozenMultiplyByTwoGraph(output_graph_path)

  def _createTFExampleString(self, feature_name, feature_value):
    """Returns a serialized `Example` with one float feature.

    Args:
      feature_name: Key of the feature in the example.
      feature_value: Single float stored in the feature's `float_list`.
    """
    example = example_pb2.Example()
    example.features.feature[feature_name].float_list.value.extend([
        feature_value])
    return example.SerializeToString()

  def _writeDummySavedModel(self, path, feature_name, tags):
    """Writes a classifier SavedModel with a single float input feature.

    The model parses serialized examples fed to the string placeholder
    `input_node`, multiplies feature `feature_name` by `variable_node` (1.0)
    to produce `output_node`, and exports a classification signature under
    the default serving signature key. The SavedModel is saved as text.

    Args:
      path: Export directory passed to `SavedModelBuilder`.
      feature_name: Name of the float feature parsed from each example.
      tags: Tags under which the meta graph and variables are added.
    """
    with ops.Graph().as_default():
      examples = array_ops.placeholder(dtypes.string, name="input_node")
      feature_configs = {
          feature_name: parsing_ops.FixedLenFeature(shape=[],
                                                    dtype=dtypes.float32),
      }
      features = parsing_ops.parse_example(examples, feature_configs)
      feature = features[feature_name]

      variable_node = variables.VariableV1(1.0, name="variable_node")
      scores = math_ops.multiply(variable_node, feature, name="output_node")
      class_feature = array_ops.fill(array_ops.shape(feature),
                                     "class_%s" % feature_name)
      classes = array_ops.transpose(class_feature)

      with session.Session() as sess:
        sess.run(variables.global_variables_initializer())
        signature = (
            signature_def_utils.classification_signature_def(
                examples=examples,
                classes=classes,
                scores=scores,))
        builder = saved_model_builder.SavedModelBuilder(path)
        builder.add_meta_graph_and_variables(
            sess,
            tags,
            signature_def_map={
                signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                    signature,
            },
        )
        builder.save(as_text=True)

  @test_util.run_v1_only("b/120545219")
  def testFreezeGraphV1(self):
    self._testFreezeGraph(saver_pb2.SaverDef.V1)

  @test_util.run_v1_only("b/120545219")
  def testFreezeGraphV2(self):
    self._testFreezeGraph(saver_pb2.SaverDef.V2)

  def testFreezeMetaGraph(self):
    """Freezes a graph supplied through the checkpoint's `.meta` file."""
    tmp_dir = self.get_temp_dir()
    checkpoint_prefix = os.path.join(tmp_dir, "meta_graph_checkpoint")
    checkpoint_state_name = "checkpoint_state"
    output_graph_filename = os.path.join(tmp_dir, "output_graph.pb")

    checkpoint_path = self._writeMultiplyByTwoCheckpoint(
        checkpoint_prefix, checkpoint_state_name)

    input_saver_def_path = ""
    input_binary = True
    output_node_names = "output_node"
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"
    clear_devices = False
    input_meta_graph = checkpoint_path + ".meta"

    freeze_graph.freeze_graph(
        "", input_saver_def_path, input_binary, checkpoint_path,
        output_node_names, restore_op_name, filename_tensor_name,
        output_graph_filename, clear_devices, "", "", "", input_meta_graph)

    # The variable must now be a constant and the result must be unchanged.
    self._assertFrozenMultiplyByTwoGraph(output_graph_filename)

  @parameterized.named_parameters(
      ("empty_tags_set", "", []),
      ("default_tags_set", tag_constants.SERVING, [tag_constants.SERVING]))
  def testFreezeSavedModel(self, tags_string, tags_list):
    """Freezes the graph of a SavedModel exported with `tags_list`."""
    tmp_dir = self.get_temp_dir()
    saved_model_dir = os.path.join(tmp_dir, "saved_model_dir")
    feature_name = "feature"
    self._writeDummySavedModel(saved_model_dir, feature_name, tags_list)
    output_graph_filename = os.path.join(tmp_dir, "output_graph.pb")

    # Only the SavedModel directory, output node names and tags are supplied;
    # the graph-file, saver, meta-graph and checkpoint inputs are left unset
    # (falsy or None).
    input_saved_model_dir = saved_model_dir
    output_node_names = "output_node"
    input_binary = False
    input_saver_def_path = False
    restore_op_name = None
    filename_tensor_name = None
    clear_devices = False
    input_meta_graph = False
    checkpoint_path = None
    input_graph_filename = None
    saved_model_tags = tags_string

    freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path,
                              input_binary, checkpoint_path, output_node_names,
                              restore_op_name, filename_tensor_name,
                              output_graph_filename, clear_devices, "", "", "",
                              input_meta_graph, input_saved_model_dir,
                              saved_model_tags)

    # The variable must now be a constant and the graph must still compute
    # feature * 1.0.
    with ops.Graph().as_default():
      output_graph_def = self._readGraphDef(output_graph_filename)
      importer.import_graph_def(output_graph_def, name="")

      # The expected node count depends on whether the parsing op is
      # ParseExampleV2 (NOTE(review): presumably it keeps extra input nodes
      # compared to the older parse op -- confirm).
      if any("ParseExampleV2" in node.name for node in output_graph_def.node):
        expected_node_count = 10
      else:
        expected_node_count = 8
      self.assertEqual(expected_node_count, len(output_graph_def.node))
      self._assertNoVariableOps(output_graph_def)

      feature_value = 2.0
      example = self._createTFExampleString(feature_name, feature_value)
      with session.Session() as sess:
        input_node = sess.graph.get_tensor_by_name("input_node:0")
        output_node = sess.graph.get_tensor_by_name("output_node:0")
        output = sess.run(output_node, feed_dict={input_node: [example]})
        self.assertNear(feature_value, output, 0.00001)

  def testSinglePartitionedVariable(self):
    """Ensures partitioned variables fail cleanly with freeze graph."""
    checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint")
    checkpoint_state_name = "checkpoint_state"
    input_graph_name = "input_graph.pb"
    output_graph_name = "output_graph.pb"

    # Create a graph with partitioned variables. When weights are partitioned
    # into a single partition, the weights variable is followed by an
    # identity -> identity chain (an additional identity node).
    partitioner = partitioned_variables.fixed_size_partitioner(1)
    with ops.Graph().as_default():
      with variable_scope.variable_scope("part", partitioner=partitioner):
        batch_size, height, width, depth = 5, 128, 128, 3
        input1 = array_ops.zeros(
            (batch_size, height, width, depth), name="input1")
        input2 = array_ops.zeros(
            (batch_size, height, width, depth), name="input2")

        num_nodes = depth
        filter1 = variable_scope.get_variable("filter", [num_nodes, num_nodes])
        filter2 = array_ops.reshape(filter1, [1, 1, num_nodes, num_nodes])
        conv = nn.conv2d(
            input=input1, filter=filter2, strides=[1, 1, 1, 1], padding="SAME")
        node = math_ops.add(conv, input2, name="test/add")
        nn.relu6(node, name="test/relu6")

      # Save graph and checkpoints.
      with session.Session() as sess:
        sess.run(variables.global_variables_initializer())

        saver = saver_lib.Saver()
        checkpoint_path = saver.save(
            sess,
            checkpoint_prefix,
            global_step=0,
            latest_filename=checkpoint_state_name)
        graph_io.write_graph(sess.graph, self.get_temp_dir(), input_graph_name)

        # Ensure this graph has partitioned variables.
        self.assertTrue(
            any(
                re.search(r"/part_\d+/", tensor.name)
                for op in sess.graph.get_operations()
                for tensor in op.values()))

        # Capture the GraphDef while the session is still open.
        input_graph_def = sess.graph_def

    # Freezing must reject the partitioned variable with a ValueError instead
    # of crashing.
    output_node_names = "save/restore_all"
    output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)

    with self.assertRaises(ValueError):
      freeze_graph.freeze_graph_with_def_protos(
          input_graph_def=input_graph_def,
          input_saver_def=None,
          input_checkpoint=checkpoint_path,
          output_node_names=output_node_names,
          restore_op_name="save/restore_all",  # default value
          filename_tensor_name="save/Const:0",  # default value
          output_graph=output_graph_path,
          clear_devices=False,
          initializer_nodes="")
336
337
if __name__ == "__main__":
  # Hand control to TensorFlow's `test.main()`, which runs the test cases
  # defined in this module.
  test.main()
340