# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Lite Python Interface: Sanity check."""
import numpy as np

from tensorflow.lite.python import convert
from tensorflow.lite.python import op_hint
from tensorflow.lite.python.interpreter import Interpreter
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes
from tensorflow.python.framework.graph_util_impl import _extract_graph_summary
from tensorflow.python.framework.graph_util_impl import _node_name
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class ConvertTest(test_util.TensorFlowTestCase):
  """Sanity checks for the GraphDef conversion entry points in `convert`."""

  def testBasic(self):
    """Converts a simple add graph given its input and output tensors."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      # The session is only needed to obtain the GraphDef; using it as a
      # context manager closes it instead of leaving it open.
      with session.Session() as sess:
        graph_def = sess.graph_def

    # Try running on valid graph
    tflite_model = convert.convert_graphdef(
        graph_def, input_tensors=[in_tensor], output_tensors=[out_tensor])
    self.assertTrue(tflite_model)

  def testQuantization(self):
    """Converts a fake-quantized add graph with uint8 inference type."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = array_ops.fake_quant_with_min_max_args(
          in_tensor + in_tensor, min=0., max=1.)
      with session.Session() as sess:
        graph_def = sess.graph_def

    tflite_model = convert.convert_graphdef(
        graph_def,
        input_tensors=[in_tensor],
        output_tensors=[out_tensor],
        inference_type=dtypes.uint8,
        quantized_input_stats=[(0., 1.)])
    self.assertTrue(tflite_model)

  def testGraphDefBasic(self):
    """Converts by array name and checks the interpreter's tensor details."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32, name="input")
      _ = in_tensor + in_tensor
      with session.Session() as sess:
        graph_def = sess.graph_def

    tflite_model = convert.convert_graphdef_with_arrays(
        graph_def,
        input_arrays_with_shape=[("input", [1, 16, 16, 3])],
        output_arrays=["add"],
        control_output_arrays=None,
        inference_type=dtypes.float32,
        enable_mlir_converter=False)
    self.assertTrue(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    self.assertEqual("input", input_details[0]["name"])
    self.assertEqual(np.float32, input_details[0]["dtype"])
    self.assertTrue(([1, 16, 16, 3] == input_details[0]["shape"]).all())
    self.assertEqual((0., 0.), input_details[0]["quantization"])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual("add", output_details[0]["name"])
    self.assertEqual(np.float32, output_details[0]["dtype"])
    self.assertTrue(([1, 16, 16, 3] == output_details[0]["shape"]).all())
    self.assertEqual((0., 0.), output_details[0]["quantization"])

  def testGraphDefQuantization(self):
    """Converts a two-input quantized graph by name and checks its details."""
    with ops.Graph().as_default():
      in_tensor_1 = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputA")
      in_tensor_2 = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputB")
      _ = array_ops.fake_quant_with_min_max_args(
          in_tensor_1 + in_tensor_2, min=0., max=1., name="output")
      with session.Session() as sess:
        graph_def = sess.graph_def

    tflite_model = convert.convert_graphdef_with_arrays(
        graph_def,
        input_arrays_with_shape=[("inputA", [1, 16, 16, 3]),
                                 ("inputB", [1, 16, 16, 3])],
        output_arrays=["output"],
        control_output_arrays=None,
        inference_type=dtypes.uint8,
        quantized_input_stats=[(0., 1.), (0., 1.)],
        enable_mlir_converter=False,
    )
    self.assertTrue(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(2, len(input_details))
    self.assertEqual("inputA", input_details[0]["name"])
    self.assertEqual(np.uint8, input_details[0]["dtype"])
    self.assertTrue(([1, 16, 16, 3] == input_details[0]["shape"]).all())
    self.assertEqual((1., 0.),
                     input_details[0]["quantization"])  # scale, zero_point

    self.assertEqual("inputB", input_details[1]["name"])
    self.assertEqual(np.uint8, input_details[1]["dtype"])
    self.assertTrue(([1, 16, 16, 3] == input_details[1]["shape"]).all())
    self.assertEqual((1., 0.),
                     input_details[1]["quantization"])  # scale, zero_point

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual("output", output_details[0]["name"])
    self.assertEqual(np.uint8, output_details[0]["dtype"])
    self.assertTrue(([1, 16, 16, 3] == output_details[0]["shape"]).all())
    self.assertGreater(output_details[0]["quantization"][0], 0)  # scale

  def testGraphDefQuantizationInvalid(self):
    """Checks that uint8 inference without input stats raises ValueError."""
    with ops.Graph().as_default():
      in_tensor_1 = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputA")
      in_tensor_2 = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputB")
      _ = array_ops.fake_quant_with_min_max_args(
          in_tensor_1 + in_tensor_2, min=0., max=1., name="output")
      with session.Session() as sess:
        graph_def = sess.graph_def

    # `quantized_input_stats` is deliberately omitted here.
    with self.assertRaises(ValueError) as error:
      convert.convert_graphdef_with_arrays(
          graph_def,
          input_arrays_with_shape=[("inputA", [1, 16, 16, 3]),
                                   ("inputB", [1, 16, 16, 3])],
          output_arrays=["output"],
          control_output_arrays=None,
          inference_type=dtypes.uint8,
          enable_mlir_converter=False)
    self.assertEqual(
        "The `quantized_input_stats` flag must be defined when either "
        "`inference_type` flag or `inference_input_type` flag is set to "
        "tf.int8 or tf.uint8.", str(error.exception))


class ConvertTestOpHint(test_util.TensorFlowTestCase):
  """Test the hint to stub functionality."""

  def _getGraphOpTypes(self, graphdef, output_nodes):
    """Returns used op types in `graphdef` reachable from `output_nodes`.

    This is used to check that after the stub transformation the expected
    nodes are there.

    NOTE: this is not an exact test that the graph is the correct output, but
    it balances compact expressibility of test with sanity checking.

    Args:
      graphdef: TensorFlow proto graphdef.
      output_nodes: A list of output node names that we need to reach.

    Returns:
      A set of node types reachable from `output_nodes`.
    """
    name_to_input_name, name_to_node, _ = (
        _extract_graph_summary(graphdef))
    # Find all nodes that are needed by the outputs
    used_node_names = _bfs_for_reachable_nodes(output_nodes, name_to_input_name)
    return {name_to_node[node_name].op for node_name in used_node_names}

  def _countIdentities(self, nodes):
    """Count the number of "Identity" op types in the list of proto nodes.

    Args:
      nodes: NodeDefs of the graph.

    Returns:
      The number of nodes with op type "Identity" found.
    """
    return sum(1 for node in nodes if node.op == "Identity")

  def testSwishLiteHint(self):
    """Makes a custom op swish and makes sure it gets converted as a unit."""
    with ops.Graph().as_default():
      image = array_ops.constant([1., 2., 3., 4.])
      swish_scale = array_ops.constant(1.0)

      def _swish(input_tensor, scale):
        custom = op_hint.OpHint("cool_activation")
        input_tensor, scale = custom.add_inputs(input_tensor, scale)
        output = math_ops.sigmoid(input_tensor) * input_tensor * scale
        output, = custom.add_outputs(output)
        return output

      output = array_ops.identity(
          _swish(image, swish_scale), name="ModelOutput")

      with self.cached_session() as sess:
        # check if identities have been put into the graph (2 input, 1 output,
        # and 1 final output).
        self.assertEqual(self._countIdentities(sess.graph_def.node), 4)

        stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
            graph_def=sess.graph_def)

        self.assertEqual(
            self._getGraphOpTypes(
                stubbed_graphdef,
                output_nodes=[op_hint._tensor_name_base(output.name)]),
            {"cool_activation", "Const", "Identity"})

  def testScaleAndBiasAndIdentity(self):
    """This tests a scaled add which has 3 inputs and 2 outputs."""
    with ops.Graph().as_default():
      a = array_ops.constant(1.)
      x = array_ops.constant([2., 3.])
      b = array_ops.constant([4., 5.])

      def _scaled_and_bias_and_identity(a, x, b):
        custom = op_hint.OpHint("scale_and_bias_and_identity")
        a, x, b = custom.add_inputs(a, x, b)
        return custom.add_outputs(a * x + b, x)

      output = array_ops.identity(
          _scaled_and_bias_and_identity(a, x, b), name="ModelOutput")

      with self.cached_session() as sess:
        # make sure one identity for each input (3) and output (2) => 3 + 2 = 5
        # +1 for the final output
        self.assertEqual(self._countIdentities(sess.graph_def.node), 6)

        stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
            graph_def=sess.graph_def)

        self.assertEqual(
            self._getGraphOpTypes(
                stubbed_graphdef,
                output_nodes=[op_hint._tensor_name_base(output.name)]),
            {"scale_and_bias_and_identity", "Const", "Identity", "Pack"})

  def testTwoFunctions(self):
    """Tests if two functions are converted correctly."""
    with ops.Graph().as_default():
      a = array_ops.constant([1.])
      b = array_ops.constant([1.])

      def _double_values(x):
        custom = op_hint.OpHint("add_test")
        x, = custom.add_inputs(x)
        output = math_ops.multiply(x, x)
        output, = custom.add_outputs(output)
        return output

      output = array_ops.identity(
          math_ops.add(_double_values(a), _double_values(b)),
          name="ModelOutput")

      with self.cached_session() as sess:
        # make sure one identity for each input (2) and output (2) => 2 + 2
        # +1 for the final output
        self.assertEqual(self._countIdentities(sess.graph_def.node), 5)
        stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
            graph_def=sess.graph_def)
        self.assertEqual(
            self._getGraphOpTypes(
                stubbed_graphdef,
                output_nodes=[op_hint._tensor_name_base(output.name)]),
            {"add_test", "Const", "Identity", "AddV2"})

  def _get_input_index(self, x):
    """Returns the OpHint function-input-index attr of the op producing `x`."""
    return x.op.node_def.attr[op_hint.OpHint.FUNCTION_INPUT_INDEX_ATTR].i

  def _get_output_index(self, x):
    """Returns the OpHint function-output-index attr of the op producing `x`."""
    return x.op.node_def.attr[op_hint.OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i

  def _get_sort_index(self, x):
    """Returns the OpHint function-sort-index attr of the op producing `x`."""
    return x.op.node_def.attr[op_hint.OpHint.FUNCTION_SORT_INDEX_ATTR].i

  def testTags(self):
    """Test if multiple args with the same tag are grouped."""
    with ops.Graph().as_default():
      a = array_ops.constant([1.])
      b = array_ops.constant([2.])
      c = array_ops.constant([3.])
      d = array_ops.constant([4.])
      custom = op_hint.OpHint("test_tag")
      a = custom.add_input(
          a, tag="mytag", aggregate=op_hint.OpHint.AGGREGATE_STACK)
      b, = custom.add_inputs(b)
      c = custom.add_input(
          c, tag="mytag", aggregate=op_hint.OpHint.AGGREGATE_STACK)
      d = custom.add_input(
          d, tag="mytag2", aggregate=op_hint.OpHint.AGGREGATE_STACK)
      res = math_ops.add(math_ops.mul(a, b), math_ops.mul(c, b))
      custom.add_outputs([res])
      with self.cached_session():
        # `a` and `c` share the tag "mytag": same input index, increasing
        # sort index. `b` is untagged and gets the next input index.
        self.assertEqual(self._get_input_index(a), 0)
        self.assertEqual(self._get_sort_index(a), 0)
        self.assertEqual(self._get_input_index(b), 1)
        self.assertEqual(self._get_sort_index(b), 0)
        self.assertEqual(self._get_input_index(c), 0)
        self.assertEqual(self._get_sort_index(c), 1)

  def testOverrideIndex(self):
    """Checks input indices when one input uses `index_override`."""
    with ops.Graph().as_default():
      a = array_ops.constant([1.])
      b = array_ops.constant([2.])
      c = array_ops.constant([3.])
      custom = op_hint.OpHint("test_override")
      b = custom.add_input(b)  # should auto assign 0
      a = custom.add_input(a, index_override=1)
      c = custom.add_input(c)  # should auto assign 2
      with self.cached_session():
        self.assertEqual(self._get_input_index(a), 1)
        self.assertEqual(self._get_input_index(b), 0)
        self.assertEqual(self._get_input_index(c), 2)

  def testAggregate(self):
    """Checks that stack-aggregated inputs/outputs stub down to one op."""
    with ops.Graph().as_default():
      a = array_ops.constant([3., 4.])
      b = array_ops.constant([5., 6.])
      hint = op_hint.OpHint("agg")
      a0, a1 = array_ops.unstack(a)
      b0, b1 = array_ops.unstack(b)

      a0 = hint.add_input(a0, tag="c", aggregate=op_hint.OpHint.AGGREGATE_STACK)
      b0 = hint.add_input(b0, tag="n", aggregate=op_hint.OpHint.AGGREGATE_STACK)
      a1 = hint.add_input(a1, tag="c", aggregate=op_hint.OpHint.AGGREGATE_STACK)
      b1 = hint.add_input(b1, tag="n", aggregate=op_hint.OpHint.AGGREGATE_STACK)

      c0 = math_ops.add(a0, b0, name="addleft")
      c1 = math_ops.add(a1, b1, name="addright")
      c0 = hint.add_output(
          c0, tag="out", aggregate=op_hint.OpHint.AGGREGATE_STACK)
      c1 = hint.add_output(
          c1, tag="out", aggregate=op_hint.OpHint.AGGREGATE_STACK)

      curr = array_ops.stack([c0, c1])
      output = array_ops.identity(curr, name="FINAL_OUTPUT")
      with self.cached_session() as sess:
        stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
            graph_def=sess.graph_def)
        self.assertEqual(
            self._getGraphOpTypes(
                stubbed_graphdef,
                output_nodes=[op_hint._tensor_name_base(output.name)]),
            {"agg", "Const", "Identity"})

  def testFindHintedOutputNodes(self):
    """Test if all hinted output nodes are correctly found."""
    with ops.Graph().as_default():

      def _build_ophinted_op(name, input1, input2):
        custom_op = op_hint.OpHint(name)
        input1 = custom_op.add_input(input1)
        input2 = custom_op.add_input(input2)
        output = math_ops.mul(input1, input2)
        return custom_op.add_output(output)

      output_1 = _build_ophinted_op("custom_op_1", array_ops.constant([1.]),
                                    array_ops.constant([2.]))
      output_2 = _build_ophinted_op("custom_op_2", array_ops.constant([3.]),
                                    array_ops.constant([4.]))
      with self.cached_session() as sess:
        hinted_outputs_nodes = op_hint.find_all_hinted_output_nodes(sess)
        expected_hinted_output_nodes = [
            _node_name(output_1.name),
            _node_name(output_2.name)
        ]
        # NOTE(review): only the number of hinted outputs is compared, not the
        # node names themselves.
        self.assertEqual(
            len(hinted_outputs_nodes), len(expected_hinted_output_nodes))


if __name__ == "__main__":
  test.main()