# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test class for plan_utils."""

import os

import tensorflow as tf

from fcp.artifact_building import checkpoint_utils
from fcp.artifact_building import plan_utils
from fcp.artifact_building import test_utils
from fcp.protos import plan_pb2


def _create_square_graph_with_tensor_specs():
  """Builds a graph computing `y = x ^ 2` on a scalar int32 placeholder.

  Returns:
    A tuple `(graph, input_tensor_spec, output_tensor_spec)`. The specs are
    the `experimental_as_proto()` forms of scalar int32 `tf.TensorSpec`s named
    `x:0` (the placeholder) and `y:0` (the `tf.math.pow` result).
  """
  graph = tf.Graph()
  with graph.as_default():
    x = tf.compat.v1.placeholder(tf.int32, shape=[], name='x')
    _ = tf.math.pow(x, 2, name='y')
  input_tensor_spec = tf.TensorSpec(
      shape=tf.TensorShape([]), dtype=tf.int32, name='x:0'
  ).experimental_as_proto()
  output_tensor_spec = tf.TensorSpec(
      shape=tf.TensorShape([]), dtype=tf.int32, name='y:0'
  ).experimental_as_proto()
  return graph, input_tensor_spec, output_tensor_spec


def _evaluate_square_model(tflite_model, x=3):
  """Runs a TFLite flatbuffer on the scalar int32 input `x`.

  Args:
    tflite_model: Serialized TFLite model bytes. The model is assumed to have a
      single input and a single output; only the first of each is used.
    x: Integer fed to the model's first input as a scalar.

  Returns:
    The value of the model's first output tensor after invocation.
  """
  interpreter = tf.lite.Interpreter(model_content=tflite_model)
  interpreter.allocate_tensors()
  # Model has single input.
  model_input = interpreter.get_input_details()[0]
  # Model has single output.
  model_output = interpreter.get_output_details()[0]
  interpreter.set_tensor(model_input['index'], tf.constant(x, shape=[]))
  interpreter.invoke()
  return interpreter.get_tensor(model_output['index'])


class PlanUtilsTest(tf.test.TestCase):

  def test_write_checkpoint(self):
    """A written checkpoint captures values at write time, not later ones."""
    checkpoint_op = plan_pb2.CheckpointOp()
    graph = tf.Graph()
    with graph.as_default():
      v = tf.compat.v1.get_variable('v', initializer=tf.constant(1))
      saver = checkpoint_utils.create_deterministic_saver([v])
      test_utils.set_checkpoint_op(checkpoint_op, saver)
      init_op = v.assign(tf.constant(2))
      change_op = v.assign(tf.constant(3))

    with tf.compat.v1.Session(graph=graph) as sess:
      sess.run(init_op)
      temp_file = self.create_tempfile().full_path
      plan_utils.write_checkpoint(sess, checkpoint_op, temp_file)
      # Change the variable in this session.
      sess.run(change_op)

    with tf.compat.v1.Session(graph=graph) as sess:
      saver.restore(sess, temp_file)
      # Should not see update to 3.
      self.assertEqual(2, sess.run(v))

  def test_write_checkpoint_not_checkpoint_op(self):
    """A non-CheckpointOp argument is rejected with ValueError."""
    with self.assertRaises(ValueError):
      plan_utils.write_checkpoint(None, 'not_checkpoint_op', None)

  def test_write_checkpoint_skips_when_no_saver_def(self):
    """An empty CheckpointOp results in no checkpoint file being written."""
    checkpoint_op = plan_pb2.CheckpointOp()
    with tf.compat.v1.Session() as sess:
      temp_file = self.create_tempfile().full_path
      # Only a unique path is needed: delete the file create_tempfile made so
      # the assertion below can tell whether write_checkpoint created one.
      os.remove(temp_file)
      plan_utils.write_checkpoint(sess, checkpoint_op, temp_file)
      self.assertFalse(os.path.isfile(temp_file))

  def test_read_checkpoint(self):
    """read_checkpoint restores saved values over later assignments."""
    checkpoint_op = plan_pb2.CheckpointOp()
    graph = tf.Graph()
    with graph.as_default():
      v = tf.compat.v1.get_variable('v', initializer=tf.constant(1))
      saver = checkpoint_utils.create_deterministic_saver([v])
      test_utils.set_checkpoint_op(checkpoint_op, saver)
      init_op = v.assign(tf.constant(2))
      change_op = v.assign(tf.constant(3))

    with tf.compat.v1.Session(graph=graph) as sess:
      sess.run(init_op)
      temp_file = self.create_tempfile().full_path
      saver.save(sess, temp_file)
      sess.run(change_op)

      plan_utils.read_checkpoint(sess, checkpoint_op, temp_file)
      # Should not see update to 3.
      self.assertEqual(2, sess.run(v))

  def test_generate_and_add_tflite_model_to_plan(self):
    """Converts a `y = x ^ 2` GraphDef to TFLite and checks y(3) == 9.

    Despite its name, this test exercises
    `plan_utils.convert_graphdef_to_flatbuffer` directly; the plan-level
    entry point is covered by `TfLiteTest`.
    """
    graph, input_tensor_spec, output_tensor_spec = (
        _create_square_graph_with_tensor_specs()
    )

    tensorflow_spec = plan_pb2.TensorflowSpec()
    tensorflow_spec.input_tensor_specs.append(input_tensor_spec)
    tensorflow_spec.output_tensor_specs.append(output_tensor_spec)

    flatbuffer = plan_utils.convert_graphdef_to_flatbuffer(
        graph.as_graph_def(), tensorflow_spec
    )

    self.assertEqual(_evaluate_square_model(flatbuffer), 9)


class TfLiteTest(tf.test.TestCase):
  """Tests common methods related to TFLite support."""

  def test_caught_exception_in_tflite_conversion_failure_for_plan(self):
    """An unconvertible Plan raises when failures are not forgiven."""
    plan = plan_pb2.Plan()
    plan.client_graph_bytes.Pack(tf.compat.v1.GraphDef())
    plan.phase.add()
    with self.assertRaisesRegex(
        RuntimeError, 'Failure during TFLite conversion'
    ):
      plan_utils.generate_and_add_flat_buffer_to_plan(
          plan, forgive_tflite_conversion_failure=False
      )

  def test_forgive_tflite_conversion_failure_for_plan(self):
    """A forgiven conversion failure leaves the Plan's TFLite bytes empty."""
    plan = plan_pb2.Plan()
    plan.client_graph_bytes.Pack(tf.compat.v1.GraphDef())
    plan.phase.add()
    plan_after_conversion = plan_utils.generate_and_add_flat_buffer_to_plan(
        plan, forgive_tflite_conversion_failure=True
    )
    self.assertIsInstance(plan_after_conversion, plan_pb2.Plan)
    self.assertEmpty(plan_after_conversion.client_tflite_graph_bytes)

  def test_caught_exception_in_tflite_conversion_failure_for_client_only_plan(
      self,
  ):
    """An unconvertible ClientOnlyPlan raises when failures are not forgiven."""
    client_only_plan = plan_pb2.ClientOnlyPlan()
    client_only_plan.graph = tf.compat.v1.GraphDef().SerializeToString()
    with self.assertRaisesRegex(
        RuntimeError, 'Failure during TFLite conversion'
    ):
      plan_utils.generate_and_add_flat_buffer_to_plan(
          client_only_plan, forgive_tflite_conversion_failure=False
      )

  def test_forgive_tflite_conversion_failure_for_client_only_plan(self):
    """A forgiven failure leaves the ClientOnlyPlan's tflite_graph empty."""
    client_only_plan = plan_pb2.ClientOnlyPlan()
    client_only_plan.graph = tf.compat.v1.GraphDef().SerializeToString()
    plan_after_conversion = plan_utils.generate_and_add_flat_buffer_to_plan(
        client_only_plan, forgive_tflite_conversion_failure=True
    )
    self.assertIsInstance(plan_after_conversion, plan_pb2.ClientOnlyPlan)
    self.assertEmpty(plan_after_conversion.tflite_graph)

  def _create_test_graph_with_associated_tensor_specs(self):
    """Returns `(graph, input_spec, output_spec)` for a `y = x ^ 2` graph."""
    return _create_square_graph_with_tensor_specs()

  def _assert_tflite_flatbuffer_is_equivalent_to_test_graph(self, tflite_graph):
    """Asserts `tflite_graph` is non-empty and computes y(3) == 3 ^ 2 == 9."""
    self.assertNotEmpty(tflite_graph)
    self.assertEqual(_evaluate_square_model(tflite_graph), 9)

  def test_add_equivalent_tflite_model_to_plan(self):
    """Tests that the generated tflite model is equivalent to the tf.Graph."""

    graph, input_tensor_spec, output_tensor_spec = (
        self._create_test_graph_with_associated_tensor_specs()
    )

    # Create a fairly empty Plan with just the graph and the
    # TensorSpecProtos populated (since that is all that is needed for
    # conversion.)
    plan_proto = plan_pb2.Plan()
    plan_proto.client_graph_bytes.Pack(graph.as_graph_def())
    plan_proto.phase.add()
    tensorflow_spec = plan_proto.phase[0].client_phase.tensorflow_spec
    tensorflow_spec.input_tensor_specs.append(input_tensor_spec)
    tensorflow_spec.output_tensor_specs.append(output_tensor_spec)

    # Generate the TFLite model.
    plan_after_conversion = plan_utils.generate_and_add_flat_buffer_to_plan(
        plan_proto
    )

    self.assertIsInstance(plan_after_conversion, plan_pb2.Plan)
    # The conversion is expected to be reflected in the proto that was passed
    # in, so the returned plan must compare equal to it.
    self.assertEqual(plan_after_conversion, plan_proto)
    self._assert_tflite_flatbuffer_is_equivalent_to_test_graph(
        plan_after_conversion.client_tflite_graph_bytes
    )

  def test_add_equivalent_tflite_model_to_client_only_plan(self):
    """Tests that the generated tflite model is equivalent to the tf.Graph."""

    graph, input_tensor_spec, output_tensor_spec = (
        self._create_test_graph_with_associated_tensor_specs()
    )

    # Create a fairly empty ClientOnlyPlan with just the graph and the
    # TensorSpecProtos populated (since that is all that is needed for
    # conversion.)
    client_only_plan_proto = plan_pb2.ClientOnlyPlan()
    client_only_plan_proto.graph = graph.as_graph_def().SerializeToString()
    tensorflow_spec = client_only_plan_proto.phase.tensorflow_spec
    tensorflow_spec.input_tensor_specs.append(input_tensor_spec)
    tensorflow_spec.output_tensor_specs.append(output_tensor_spec)

    # Generate the TFLite model.
    plan_after_conversion = plan_utils.generate_and_add_flat_buffer_to_plan(
        client_only_plan_proto
    )

    self.assertIsInstance(plan_after_conversion, plan_pb2.ClientOnlyPlan)
    # The conversion is expected to be reflected in the proto that was passed
    # in, so the returned plan must compare equal to it.
    self.assertEqual(plan_after_conversion, client_only_plan_proto)
    self._assert_tflite_flatbuffer_is_equivalent_to_test_graph(
        plan_after_conversion.tflite_graph
    )


if __name__ == '__main__':
  tf.test.main()