# xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/python/convert_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Lite Python Interface: Sanity check."""
import numpy as np

from tensorflow.lite.python import convert
from tensorflow.lite.python import op_hint
from tensorflow.lite.python.interpreter import Interpreter
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes
from tensorflow.python.framework.graph_util_impl import _extract_graph_summary
from tensorflow.python.framework.graph_util_impl import _node_name
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


33class ConvertTest(test_util.TensorFlowTestCase):
34
35  def testBasic(self):
36    with ops.Graph().as_default():
37      in_tensor = array_ops.placeholder(
38          shape=[1, 16, 16, 3], dtype=dtypes.float32)
39      out_tensor = in_tensor + in_tensor
40      sess = session.Session()
41
42    # Try running on valid graph
43    tflite_model = convert.convert_graphdef(
44        sess.graph_def, input_tensors=[in_tensor], output_tensors=[out_tensor])
45    self.assertTrue(tflite_model)
46
47  def testQuantization(self):
48    with ops.Graph().as_default():
49      in_tensor = array_ops.placeholder(
50          shape=[1, 16, 16, 3], dtype=dtypes.float32)
51      out_tensor = array_ops.fake_quant_with_min_max_args(
52          in_tensor + in_tensor, min=0., max=1.)
53      sess = session.Session()
54
55    tflite_model = convert.convert_graphdef(
56        sess.graph_def,
57        input_tensors=[in_tensor],
58        output_tensors=[out_tensor],
59        inference_type=dtypes.uint8,
60        quantized_input_stats=[(0., 1.)])
61    self.assertTrue(tflite_model)
62
63  def testGraphDefBasic(self):
64    with ops.Graph().as_default():
65      in_tensor = array_ops.placeholder(
66          shape=[1, 16, 16, 3], dtype=dtypes.float32, name="input")
67      _ = in_tensor + in_tensor
68      sess = session.Session()
69
70    tflite_model = convert.convert_graphdef_with_arrays(
71        sess.graph_def,
72        input_arrays_with_shape=[("input", [1, 16, 16, 3])],
73        output_arrays=["add"],
74        control_output_arrays=None,
75        inference_type=dtypes.float32,
76        enable_mlir_converter=False)
77    self.assertTrue(tflite_model)
78
79    # Check values from converted model.
80    interpreter = Interpreter(model_content=tflite_model)
81    interpreter.allocate_tensors()
82
83    input_details = interpreter.get_input_details()
84    self.assertEqual(1, len(input_details))
85    self.assertEqual("input", input_details[0]["name"])
86    self.assertEqual(np.float32, input_details[0]["dtype"])
87    self.assertTrue(([1, 16, 16, 3] == input_details[0]["shape"]).all())
88    self.assertEqual((0., 0.), input_details[0]["quantization"])
89
90    output_details = interpreter.get_output_details()
91    self.assertEqual(1, len(output_details))
92    self.assertEqual("add", output_details[0]["name"])
93    self.assertEqual(np.float32, output_details[0]["dtype"])
94    self.assertTrue(([1, 16, 16, 3] == output_details[0]["shape"]).all())
95    self.assertEqual((0., 0.), output_details[0]["quantization"])
96
97  def testGraphDefQuantization(self):
98    with ops.Graph().as_default():
99      in_tensor_1 = array_ops.placeholder(
100          shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputA")
101      in_tensor_2 = array_ops.placeholder(
102          shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputB")
103      _ = array_ops.fake_quant_with_min_max_args(
104          in_tensor_1 + in_tensor_2, min=0., max=1., name="output")
105      sess = session.Session()
106
107    tflite_model = convert.convert_graphdef_with_arrays(
108        sess.graph_def,
109        input_arrays_with_shape=[("inputA", [1, 16, 16, 3]),
110                                 ("inputB", [1, 16, 16, 3])],
111        output_arrays=["output"],
112        control_output_arrays=None,
113        inference_type=dtypes.uint8,
114        quantized_input_stats=[(0., 1.), (0., 1.)],
115        enable_mlir_converter=False,
116    )
117    self.assertTrue(tflite_model)
118
119    # Check values from converted model.
120    interpreter = Interpreter(model_content=tflite_model)
121    interpreter.allocate_tensors()
122
123    input_details = interpreter.get_input_details()
124    self.assertEqual(2, len(input_details))
125    self.assertEqual("inputA", input_details[0]["name"])
126    self.assertEqual(np.uint8, input_details[0]["dtype"])
127    self.assertTrue(([1, 16, 16, 3] == input_details[0]["shape"]).all())
128    self.assertEqual((1., 0.),
129                     input_details[0]["quantization"])  # scale, zero_point
130
131    self.assertEqual("inputB", input_details[1]["name"])
132    self.assertEqual(np.uint8, input_details[1]["dtype"])
133    self.assertTrue(([1, 16, 16, 3] == input_details[1]["shape"]).all())
134    self.assertEqual((1., 0.),
135                     input_details[1]["quantization"])  # scale, zero_point
136
137    output_details = interpreter.get_output_details()
138    self.assertEqual(1, len(output_details))
139    self.assertEqual("output", output_details[0]["name"])
140    self.assertEqual(np.uint8, output_details[0]["dtype"])
141    self.assertTrue(([1, 16, 16, 3] == output_details[0]["shape"]).all())
142    self.assertGreater(output_details[0]["quantization"][0], 0)  # scale
143
144  def testGraphDefQuantizationInvalid(self):
145    with ops.Graph().as_default():
146      in_tensor_1 = array_ops.placeholder(
147          shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputA")
148      in_tensor_2 = array_ops.placeholder(
149          shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputB")
150      _ = array_ops.fake_quant_with_min_max_args(
151          in_tensor_1 + in_tensor_2, min=0., max=1., name="output")
152      sess = session.Session()
153
154    with self.assertRaises(ValueError) as error:
155      convert.convert_graphdef_with_arrays(
156          sess.graph_def,
157          input_arrays_with_shape=[("inputA", [1, 16, 16, 3]),
158                                   ("inputB", [1, 16, 16, 3])],
159          output_arrays=["output"],
160          control_output_arrays=None,
161          inference_type=dtypes.uint8,
162          enable_mlir_converter=False)
163    self.assertEqual(
164        "The `quantized_input_stats` flag must be defined when either "
165        "`inference_type` flag or `inference_input_type` flag is set to "
166        "tf.int8 or tf.uint8.", str(error.exception))
167
168
169class ConvertTestOpHint(test_util.TensorFlowTestCase):
170  """Test the hint to stub functionality."""
171
172  def _getGraphOpTypes(self, graphdef, output_nodes):
173    """Returns used op types in `graphdef` reachable from `output_nodes`.
174
175    This is used to check that after the stub transformation the expected
176    nodes are there.
177
178    NOTE: this is not a exact test that the graph is the correct output, but
179      it balances compact expressibility of test with sanity checking.
180
181    Args:
182      graphdef: TensorFlow proto graphdef.
183      output_nodes: A list of output node names that we need to reach.
184
185    Returns:
186      A set of node types reachable from `output_nodes`.
187    """
188    name_to_input_name, name_to_node, _ = (
189        _extract_graph_summary(graphdef))
190    # Find all nodes that are needed by the outputs
191    used_node_names = _bfs_for_reachable_nodes(output_nodes, name_to_input_name)
192    return set([name_to_node[node_name].op for node_name in used_node_names])
193
194  def _countIdentities(self, nodes):
195    """Count the number of "Identity" op types in the list of proto nodes.
196
197    Args:
198      nodes: NodeDefs of the graph.
199
200    Returns:
201      The number of nodes with op type "Identity" found.
202    """
203    return len([x for x in nodes if x.op == "Identity"])
204
205  def testSwishLiteHint(self):
206    """Makes a custom op swish and makes sure it gets converted as a unit."""
207    with ops.Graph().as_default():
208      image = array_ops.constant([1., 2., 3., 4.])
209      swish_scale = array_ops.constant(1.0)
210
211      def _swish(input_tensor, scale):
212        custom = op_hint.OpHint("cool_activation")
213        input_tensor, scale = custom.add_inputs(input_tensor, scale)
214        output = math_ops.sigmoid(input_tensor) * input_tensor * scale
215        output, = custom.add_outputs(output)
216        return output
217
218      output = array_ops.identity(
219          _swish(image, swish_scale), name="ModelOutput")
220
221      with self.cached_session() as sess:
222        # check if identities have been put into the graph (2 input, 1 output,
223        # and 1 final output).
224        self.assertEqual(self._countIdentities(sess.graph_def.node), 4)
225
226        stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
227            graph_def=sess.graph_def)
228
229        self.assertEqual(
230            self._getGraphOpTypes(
231                stubbed_graphdef,
232                output_nodes=[op_hint._tensor_name_base(output.name)]),
233            set(["cool_activation", "Const", "Identity"]))
234
235  def testScaleAndBiasAndIdentity(self):
236    """This tests a scaled add which has 3 inputs and 2 outputs."""
237    with ops.Graph().as_default():
238      a = array_ops.constant(1.)
239      x = array_ops.constant([2., 3.])
240      b = array_ops.constant([4., 5.])
241
242      def _scaled_and_bias_and_identity(a, x, b):
243        custom = op_hint.OpHint("scale_and_bias_and_identity")
244        a, x, b = custom.add_inputs(a, x, b)
245        return custom.add_outputs(a * x + b, x)
246
247      output = array_ops.identity(
248          _scaled_and_bias_and_identity(a, x, b), name="ModelOutput")
249
250      with self.cached_session() as sess:
251        # make sure one identity for each input (3) and output (2) => 3 + 2 = 5
252        # +1 for the final output
253        self.assertEqual(self._countIdentities(sess.graph_def.node), 6)
254
255        stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
256            graph_def=sess.graph_def)
257
258        self.assertEqual(
259            self._getGraphOpTypes(
260                stubbed_graphdef,
261                output_nodes=[op_hint._tensor_name_base(output.name)]),
262            set(["scale_and_bias_and_identity", "Const", "Identity", "Pack"]))
263
264  def testTwoFunctions(self):
265    """Tests if two functions are converted correctly."""
266    with ops.Graph().as_default():
267      a = array_ops.constant([1.])
268      b = array_ops.constant([1.])
269
270      def _double_values(x):
271        custom = op_hint.OpHint("add_test")
272        x, = custom.add_inputs(x)
273        output = math_ops.multiply(x, x)
274        output, = custom.add_outputs(output)
275        return output
276
277      output = array_ops.identity(
278          math_ops.add(_double_values(a), _double_values(b)),
279          name="ModelOutput")
280
281      with self.cached_session() as sess:
282        # make sure one identity for each input (2) and output (2) => 2 + 2
283        # +1 for the final output
284        self.assertEqual(self._countIdentities(sess.graph_def.node), 5)
285        stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
286            graph_def=sess.graph_def)
287        self.assertEqual(
288            self._getGraphOpTypes(
289                stubbed_graphdef,
290                output_nodes=[op_hint._tensor_name_base(output.name)]),
291            set(["add_test", "Const", "Identity", "AddV2"]))
292
293  def _get_input_index(self, x):
294    return x.op.node_def.attr[op_hint.OpHint.FUNCTION_INPUT_INDEX_ATTR].i
295
296  def _get_output_index(self, x):
297    return x.op.node_def.attr[op_hint.OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i
298
299  def _get_sort_index(self, x):
300    return x.op.node_def.attr[op_hint.OpHint.FUNCTION_SORT_INDEX_ATTR].i
301
302  def testTags(self):
303    """Test if multiple args with the same tag are grouped."""
304    with ops.Graph().as_default():
305      a = array_ops.constant([1.])
306      b = array_ops.constant([2.])
307      c = array_ops.constant([3.])
308      d = array_ops.constant([4.])
309      custom = op_hint.OpHint("test_tag")
310      a = custom.add_input(
311          a, tag="mytag", aggregate=op_hint.OpHint.AGGREGATE_STACK)
312      b, = custom.add_inputs(b)
313      c = custom.add_input(
314          c, tag="mytag", aggregate=op_hint.OpHint.AGGREGATE_STACK)
315      d = custom.add_input(
316          d, tag="mytag2", aggregate=op_hint.OpHint.AGGREGATE_STACK)
317      res = math_ops.add(math_ops.mul(a, b), math_ops.mul(c, b))
318      custom.add_outputs([res])
319      with self.cached_session():
320        self.assertEqual(self._get_input_index(a), 0)
321        self.assertEqual(self._get_sort_index(a), 0)
322        self.assertEqual(self._get_input_index(b), 1)
323        self.assertEqual(self._get_sort_index(b), 0)
324        self.assertEqual(self._get_input_index(c), 0)
325        self.assertEqual(self._get_sort_index(c), 1)
326
327  def testOverrideIndex(self):
328    with ops.Graph().as_default():
329      a = array_ops.constant([1.])
330      b = array_ops.constant([2.])
331      c = array_ops.constant([3.])
332      custom = op_hint.OpHint("test_override")
333      b = custom.add_input(b)  # should auto assign 0
334      a = custom.add_input(a, index_override=1)
335      c = custom.add_input(c)  # should auto assign 2
336      with self.cached_session():
337        self.assertEqual(self._get_input_index(a), 1)
338        self.assertEqual(self._get_input_index(b), 0)
339        self.assertEqual(self._get_input_index(c), 2)
340
341  def testAggregate(self):
342    with ops.Graph().as_default():
343      a = array_ops.constant([3., 4.])
344      b = array_ops.constant([5., 6.])
345      hint = op_hint.OpHint("agg")
346      a0, a1 = array_ops.unstack(a)
347      b0, b1 = array_ops.unstack(b)
348
349      a0 = hint.add_input(a0, tag="c", aggregate=op_hint.OpHint.AGGREGATE_STACK)
350      b0 = hint.add_input(b0, tag="n", aggregate=op_hint.OpHint.AGGREGATE_STACK)
351      a1 = hint.add_input(a1, tag="c", aggregate=op_hint.OpHint.AGGREGATE_STACK)
352      b1 = hint.add_input(b1, tag="n", aggregate=op_hint.OpHint.AGGREGATE_STACK)
353
354      c0 = math_ops.add(a0, b0, name="addleft")
355      c1 = math_ops.add(a1, b1, name="addright")
356      c0 = hint.add_output(
357          c0, tag="out", aggregate=op_hint.OpHint.AGGREGATE_STACK)
358      c1 = hint.add_output(
359          c1, tag="out", aggregate=op_hint.OpHint.AGGREGATE_STACK)
360
361      curr = array_ops.stack([c0, c1])
362      output = array_ops.identity(curr, name="FINAL_OUTPUT")
363      with self.cached_session() as sess:
364        stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
365            graph_def=sess.graph_def)
366        self.assertEqual(
367            self._getGraphOpTypes(
368                stubbed_graphdef,
369                output_nodes=[op_hint._tensor_name_base(output.name)]),
370            set(["agg", "Const", "Identity"]))
371
372  def testFindHintedOutputNodes(self):
373    """Test if all hinted output nodes are correctly found."""
374    with ops.Graph().as_default():
375
376      def _build_ophinted_op(name, input1, input2):
377        custom_op = op_hint.OpHint(name)
378        input1 = custom_op.add_input(input1)
379        input2 = custom_op.add_input(input2)
380        output = math_ops.mul(input1, input2)
381        return custom_op.add_output(output)
382
383      output_1 = _build_ophinted_op("custom_op_1", array_ops.constant([1.]),
384                                    array_ops.constant([2.]))
385      output_2 = _build_ophinted_op("custom_op_2", array_ops.constant([3.]),
386                                    array_ops.constant([4.]))
387      with self.cached_session() as sess:
388        hinted_outputs_nodes = op_hint.find_all_hinted_output_nodes(sess)
389        expected_hinted_output_nodes = [
390            _node_name(output_1.name),
391            _node_name(output_2.name)
392        ]
393        self.assertEqual(
394            len(hinted_outputs_nodes), len(expected_hinted_output_nodes))
395
396
397if __name__ == "__main__":
398  test.main()
399