# Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test class for plan_utils.""" import os import tensorflow as tf from fcp.artifact_building import checkpoint_utils from fcp.artifact_building import plan_utils from fcp.artifact_building import test_utils from fcp.protos import plan_pb2 class PlanUtilsTest(tf.test.TestCase): def test_write_checkpoint(self): checkpoint_op = plan_pb2.CheckpointOp() graph = tf.Graph() with graph.as_default(): v = tf.compat.v1.get_variable('v', initializer=tf.constant(1)) saver = checkpoint_utils.create_deterministic_saver([v]) test_utils.set_checkpoint_op(checkpoint_op, saver) init_op = v.assign(tf.constant(2)) change_op = v.assign(tf.constant(3)) with tf.compat.v1.Session(graph=graph) as sess: sess.run(init_op) temp_file = self.create_tempfile().full_path plan_utils.write_checkpoint(sess, checkpoint_op, temp_file) # Change the variable in this session. sess.run(change_op) with tf.compat.v1.Session(graph=graph) as sess: saver.restore(sess, temp_file) # Should not see update to 3. self.assertEqual(2, sess.run(v)) def test_write_checkpoint_not_checkpoint_op(self): with self.assertRaises(ValueError): plan_utils.write_checkpoint(None, 'not_checkpoint_op', None) def test_write_checkpoint_skips_when_no_saver_def(self): checkpoint_op = plan_pb2.CheckpointOp() with tf.compat.v1.Session() as sess: temp_file = self.create_tempfile().full_path # Close deletes the file, we just want a good name. os.remove(temp_file) plan_utils.write_checkpoint(sess, checkpoint_op, temp_file) self.assertFalse(os.path.isfile(temp_file)) def test_read_checkpoint(self): checkpoint_op = plan_pb2.CheckpointOp() graph = tf.Graph() with graph.as_default(): v = tf.compat.v1.get_variable('v', initializer=tf.constant(1)) saver = checkpoint_utils.create_deterministic_saver([v]) test_utils.set_checkpoint_op(checkpoint_op, saver) init_op = v.assign(tf.constant(2)) change_op = v.assign(tf.constant(3)) with tf.compat.v1.Session(graph=graph) as sess: sess.run(init_op) temp_file = self.create_tempfile().full_path saver.save(sess, temp_file) sess.run(change_op) plan_utils.read_checkpoint(sess, checkpoint_op, temp_file) # Should not see update to 3. self.assertEqual(2, sess.run(v)) def test_generate_and_add_tflite_model_to_plan(self): # Create a graph for y = x ^ 2. graph = tf.Graph() with graph.as_default(): x = tf.compat.v1.placeholder(tf.int32, shape=[], name='x') _ = tf.math.pow(x, 2, name='y') input_tensor_spec = tf.TensorSpec( shape=tf.TensorShape([]), dtype=tf.int32, name='x:0' ).experimental_as_proto() output_tensor_spec = tf.TensorSpec( shape=tf.TensorShape([]), dtype=tf.int32, name='y:0' ).experimental_as_proto() tensorflow_spec = plan_pb2.TensorflowSpec() tensorflow_spec.input_tensor_specs.append(input_tensor_spec) tensorflow_spec.output_tensor_specs.append(output_tensor_spec) flatbuffer = plan_utils.convert_graphdef_to_flatbuffer( graph.as_graph_def(), tensorflow_spec ) interpreter = tf.lite.Interpreter(model_content=flatbuffer) interpreter.allocate_tensors() input_data = tf.constant(3, shape=[]) # Model has single output. model_output = interpreter.get_output_details()[0] # Model has single input. model_input = interpreter.get_input_details()[0] interpreter.set_tensor(model_input['index'], input_data) interpreter.invoke() self.assertEqual(interpreter.get_tensor(model_output['index']), 9) class TfLiteTest(tf.test.TestCase): """Tests common methods related to TFLite support.""" def test_caught_exception_in_tflite_conversion_failure_for_plan(self): plan = plan_pb2.Plan() plan.client_graph_bytes.Pack(tf.compat.v1.GraphDef()) plan.phase.add() with self.assertRaisesRegex( RuntimeError, 'Failure during TFLite conversion' ): plan_utils.generate_and_add_flat_buffer_to_plan( plan, forgive_tflite_conversion_failure=False ) def test_forgive_tflite_conversion_failure_for_plan(self): plan = plan_pb2.Plan() plan.client_graph_bytes.Pack(tf.compat.v1.GraphDef()) plan.phase.add() plan_after_conversion = plan_utils.generate_and_add_flat_buffer_to_plan( plan, forgive_tflite_conversion_failure=True ) self.assertIsInstance(plan_after_conversion, plan_pb2.Plan) self.assertEmpty(plan_after_conversion.client_tflite_graph_bytes) def test_caught_exception_in_tflite_conversion_failure_for_client_only_plan( self, ): client_only_plan = plan_pb2.ClientOnlyPlan() client_only_plan.graph = tf.compat.v1.GraphDef().SerializeToString() with self.assertRaisesRegex( RuntimeError, 'Failure during TFLite conversion' ): plan_utils.generate_and_add_flat_buffer_to_plan( client_only_plan, forgive_tflite_conversion_failure=False ) def test_forgive_tflite_conversion_failure_for_client_only_plan(self): client_only_plan = plan_pb2.ClientOnlyPlan() client_only_plan.graph = tf.compat.v1.GraphDef().SerializeToString() plan_after_conversion = plan_utils.generate_and_add_flat_buffer_to_plan( client_only_plan, forgive_tflite_conversion_failure=True ) self.assertIsInstance(plan_after_conversion, plan_pb2.ClientOnlyPlan) self.assertEmpty(plan_after_conversion.tflite_graph) def _create_test_graph_with_associated_tensor_specs(self): # Create a graph for y = x ^ 2. graph = tf.Graph() with graph.as_default(): x = tf.compat.v1.placeholder(tf.int32, shape=[], name='x') _ = tf.math.pow(x, 2, name='y') input_tensor_spec = tf.TensorSpec( shape=tf.TensorShape([]), dtype=tf.int32, name='x:0' ).experimental_as_proto() output_tensor_spec = tf.TensorSpec( shape=tf.TensorShape([]), dtype=tf.int32, name='y:0' ).experimental_as_proto() return graph, input_tensor_spec, output_tensor_spec def _assert_tflite_flatbuffer_is_equivalent_to_test_graph(self, tflite_graph): # Check that the generated TFLite model also is y = x ^ 2. self.assertNotEmpty(tflite_graph) interpreter = tf.lite.Interpreter(model_content=tflite_graph) interpreter.allocate_tensors() input_data = tf.constant(3, shape=[]) # Model has single output. model_output = interpreter.get_output_details()[0] # Model has single input. model_input = interpreter.get_input_details()[0] interpreter.set_tensor(model_input['index'], input_data) interpreter.invoke() self.assertEqual(interpreter.get_tensor(model_output['index']), 9) def test_add_equivalent_tflite_model_to_plan(self): """Tests that the generated tflite model is identical to the tf.Graph.""" graph, input_tensor_spec, output_tensor_spec = ( self._create_test_graph_with_associated_tensor_specs() ) # Create a fairly empty Plan with just the graph and the # TensorSpecProtos populated (since that is all that is needed for # conversion.) plan_proto = plan_pb2.Plan() plan_proto.client_graph_bytes.Pack(graph.as_graph_def()) plan_proto.phase.add() plan_proto.phase[0].client_phase.tensorflow_spec.input_tensor_specs.append( input_tensor_spec ) plan_proto.phase[0].client_phase.tensorflow_spec.output_tensor_specs.append( output_tensor_spec ) # Generate the TFLite model. plan_after_conversion = plan_utils.generate_and_add_flat_buffer_to_plan( plan_proto ) self.assertIsInstance(plan_after_conversion, plan_pb2.Plan) self.assertEqual(plan_after_conversion, plan_proto) self._assert_tflite_flatbuffer_is_equivalent_to_test_graph( plan_after_conversion.client_tflite_graph_bytes ) def test_add_equivalent_tflite_model_to_client_only_plan(self): """Tests that the generated tflite model is identical to the tf.Graph.""" graph, input_tensor_spec, output_tensor_spec = ( self._create_test_graph_with_associated_tensor_specs() ) # Create a fairly empty ClientOnlyPlan with just the graph and the # TensorSpecProtos populated (since that is all that is needed for # conversion.) client_only_plan_proto = plan_pb2.ClientOnlyPlan() client_only_plan_proto.graph = graph.as_graph_def().SerializeToString() client_only_plan_proto.phase.tensorflow_spec.input_tensor_specs.append( input_tensor_spec ) client_only_plan_proto.phase.tensorflow_spec.output_tensor_specs.append( output_tensor_spec ) # Generate the TFLite model. plan_after_conversion = plan_utils.generate_and_add_flat_buffer_to_plan( client_only_plan_proto ) self.assertIsInstance(plan_after_conversion, plan_pb2.ClientOnlyPlan) self.assertEqual(plan_after_conversion, client_only_plan_proto) self._assert_tflite_flatbuffer_is_equivalent_to_test_graph( plan_after_conversion.tflite_graph ) if __name__ == '__main__': tf.test.main()