xref: /aosp_15_r20/external/federated-compute/fcp/artifact_building/plan_utils_test.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Test class for plan_utils."""
15
16import os
17
18import tensorflow as tf
19
20from fcp.artifact_building import checkpoint_utils
21from fcp.artifact_building import plan_utils
22from fcp.artifact_building import test_utils
23from fcp.protos import plan_pb2
24
25
class PlanUtilsTest(tf.test.TestCase):
  """Covers checkpoint read/write and graph-to-flatbuffer helpers."""

  def test_write_checkpoint(self):
    """The checkpoint written mid-session restores the value held then."""
    ckpt_op = plan_pb2.CheckpointOp()
    test_graph = tf.Graph()
    with test_graph.as_default():
      var = tf.compat.v1.get_variable('v', initializer=tf.constant(1))
      var_saver = checkpoint_utils.create_deterministic_saver([var])
      test_utils.set_checkpoint_op(ckpt_op, var_saver)
      set_to_two = var.assign(tf.constant(2))
      set_to_three = var.assign(tf.constant(3))

    ckpt_path = self.create_tempfile().full_path

    with tf.compat.v1.Session(graph=test_graph) as write_sess:
      write_sess.run(set_to_two)
      plan_utils.write_checkpoint(write_sess, ckpt_op, ckpt_path)
      # Mutate the variable after the checkpoint has been written.
      write_sess.run(set_to_three)

    with tf.compat.v1.Session(graph=test_graph) as restore_sess:
      var_saver.restore(restore_sess, ckpt_path)
      restored_value = restore_sess.run(var)
      # The later assignment to 3 must not be visible in the checkpoint.
      self.assertEqual(2, restored_value)

  def test_write_checkpoint_not_checkpoint_op(self):
    """Passing a string in place of a CheckpointOp raises ValueError."""
    self.assertRaises(
        ValueError,
        plan_utils.write_checkpoint,
        None,
        'not_checkpoint_op',
        None,
    )

  def test_write_checkpoint_skips_when_no_saver_def(self):
    """With a default CheckpointOp, no file appears at the target path."""
    empty_ckpt_op = plan_pb2.CheckpointOp()
    ckpt_path = self.create_tempfile().full_path
    # Only the unique path is wanted: delete the file that was just created
    # so that its absence after the call is meaningful.
    os.remove(ckpt_path)
    with tf.compat.v1.Session() as sess:
      plan_utils.write_checkpoint(sess, empty_ckpt_op, ckpt_path)
      file_exists = os.path.isfile(ckpt_path)
      self.assertFalse(file_exists)

  def test_read_checkpoint(self):
    """Reading a checkpoint rolls the variable back to the saved value."""
    ckpt_op = plan_pb2.CheckpointOp()
    test_graph = tf.Graph()
    with test_graph.as_default():
      var = tf.compat.v1.get_variable('v', initializer=tf.constant(1))
      var_saver = checkpoint_utils.create_deterministic_saver([var])
      test_utils.set_checkpoint_op(ckpt_op, var_saver)
      set_to_two = var.assign(tf.constant(2))
      set_to_three = var.assign(tf.constant(3))

    ckpt_path = self.create_tempfile().full_path

    with tf.compat.v1.Session(graph=test_graph) as sess:
      sess.run(set_to_two)
      var_saver.save(sess, ckpt_path)
      sess.run(set_to_three)

      plan_utils.read_checkpoint(sess, ckpt_op, ckpt_path)
      current_value = sess.run(var)
      # The assignment to 3 must have been overwritten by the restore.
      self.assertEqual(2, current_value)

  def test_generate_and_add_tflite_model_to_plan(self):
    """A y = x ^ 2 graph converted to a flatbuffer maps 3 to 9."""
    # Build the TensorFlow graph computing y = x ^ 2.
    square_graph = tf.Graph()
    with square_graph.as_default():
      x_in = tf.compat.v1.placeholder(tf.int32, shape=[], name='x')
      tf.math.pow(x_in, 2, name='y')

    scalar_shape = tf.TensorShape([])
    x_spec_proto = tf.TensorSpec(
        shape=scalar_shape, dtype=tf.int32, name='x:0'
    ).experimental_as_proto()
    y_spec_proto = tf.TensorSpec(
        shape=tf.TensorShape([]), dtype=tf.int32, name='y:0'
    ).experimental_as_proto()

    tf_spec = plan_pb2.TensorflowSpec()
    tf_spec.input_tensor_specs.append(x_spec_proto)
    tf_spec.output_tensor_specs.append(y_spec_proto)

    graph_def = square_graph.as_graph_def()
    model_bytes = plan_utils.convert_graphdef_to_flatbuffer(graph_def, tf_spec)

    tflite_interpreter = tf.lite.Interpreter(model_content=model_bytes)
    tflite_interpreter.allocate_tensors()
    three = tf.constant(3, shape=[])
    # The model exposes exactly one output and one input.
    output_details = tflite_interpreter.get_output_details()[0]
    input_details = tflite_interpreter.get_input_details()[0]
    tflite_interpreter.set_tensor(input_details['index'], three)
    tflite_interpreter.invoke()
    result = tflite_interpreter.get_tensor(output_details['index'])
    self.assertEqual(result, 9)
114
115
class TfLiteTest(tf.test.TestCase):
  """Tests for generate_and_add_flat_buffer_to_plan on both plan types."""

  def test_caught_exception_in_tflite_conversion_failure_for_plan(self):
    """An unforgiven conversion failure on a Plan raises RuntimeError."""
    broken_plan = plan_pb2.Plan()
    # A default-constructed GraphDef and a default phase are all it holds.
    broken_plan.client_graph_bytes.Pack(tf.compat.v1.GraphDef())
    broken_plan.phase.add()
    expected_message = 'Failure during TFLite conversion'
    with self.assertRaisesRegex(RuntimeError, expected_message):
      plan_utils.generate_and_add_flat_buffer_to_plan(
          broken_plan, forgive_tflite_conversion_failure=False
      )

  def test_forgive_tflite_conversion_failure_for_plan(self):
    """A forgiven failure returns a Plan with no TFLite graph bytes."""
    broken_plan = plan_pb2.Plan()
    broken_plan.client_graph_bytes.Pack(tf.compat.v1.GraphDef())
    broken_plan.phase.add()
    converted = plan_utils.generate_and_add_flat_buffer_to_plan(
        broken_plan, forgive_tflite_conversion_failure=True
    )
    self.assertIsInstance(converted, plan_pb2.Plan)
    self.assertEmpty(converted.client_tflite_graph_bytes)

  def test_caught_exception_in_tflite_conversion_failure_for_client_only_plan(
      self,
  ):
    """An unforgiven failure on a ClientOnlyPlan raises RuntimeError."""
    broken_plan = plan_pb2.ClientOnlyPlan()
    serialized_graph = tf.compat.v1.GraphDef().SerializeToString()
    broken_plan.graph = serialized_graph
    expected_message = 'Failure during TFLite conversion'
    with self.assertRaisesRegex(RuntimeError, expected_message):
      plan_utils.generate_and_add_flat_buffer_to_plan(
          broken_plan, forgive_tflite_conversion_failure=False
      )

  def test_forgive_tflite_conversion_failure_for_client_only_plan(self):
    """A forgiven failure returns a ClientOnlyPlan with no TFLite graph."""
    broken_plan = plan_pb2.ClientOnlyPlan()
    serialized_graph = tf.compat.v1.GraphDef().SerializeToString()
    broken_plan.graph = serialized_graph
    converted = plan_utils.generate_and_add_flat_buffer_to_plan(
        broken_plan, forgive_tflite_conversion_failure=True
    )
    self.assertIsInstance(converted, plan_pb2.ClientOnlyPlan)
    self.assertEmpty(converted.tflite_graph)

  def _create_test_graph_with_associated_tensor_specs(self):
    """Builds a y = x ^ 2 graph plus TensorSpecProtos for 'x:0' and 'y:0'.

    Returns:
      A (graph, input_tensor_spec, output_tensor_spec) tuple.
    """
    square_graph = tf.Graph()
    with square_graph.as_default():
      x_in = tf.compat.v1.placeholder(tf.int32, shape=[], name='x')
      tf.math.pow(x_in, 2, name='y')

    x_spec = tf.TensorSpec(
        shape=tf.TensorShape([]), dtype=tf.int32, name='x:0'
    )
    x_spec_proto = x_spec.experimental_as_proto()
    y_spec = tf.TensorSpec(
        shape=tf.TensorShape([]), dtype=tf.int32, name='y:0'
    )
    y_spec_proto = y_spec.experimental_as_proto()
    return square_graph, x_spec_proto, y_spec_proto

  def _assert_tflite_flatbuffer_is_equivalent_to_test_graph(self, tflite_graph):
    """Asserts the flatbuffer is non-empty and evaluates 3 to 9."""
    self.assertNotEmpty(tflite_graph)
    tflite_interpreter = tf.lite.Interpreter(model_content=tflite_graph)
    tflite_interpreter.allocate_tensors()
    three = tf.constant(3, shape=[])
    # The model exposes exactly one output and one input.
    output_details = tflite_interpreter.get_output_details()[0]
    input_details = tflite_interpreter.get_input_details()[0]
    tflite_interpreter.set_tensor(input_details['index'], three)
    tflite_interpreter.invoke()
    result = tflite_interpreter.get_tensor(output_details['index'])
    self.assertEqual(result, 9)

  def test_add_equivalent_tflite_model_to_plan(self):
    """The TFLite model added to a Plan computes the same y = x ^ 2."""
    square_graph, x_spec_proto, y_spec_proto = (
        self._create_test_graph_with_associated_tensor_specs()
    )

    # Populate only what the conversion consumes: the packed graph and the
    # input/output TensorSpecProtos of the first phase.
    source_plan = plan_pb2.Plan()
    source_plan.client_graph_bytes.Pack(square_graph.as_graph_def())
    source_plan.phase.add()
    tf_spec = source_plan.phase[0].client_phase.tensorflow_spec
    tf_spec.input_tensor_specs.append(x_spec_proto)
    tf_spec.output_tensor_specs.append(y_spec_proto)

    # Run the conversion with default arguments.
    converted = plan_utils.generate_and_add_flat_buffer_to_plan(source_plan)

    self.assertIsInstance(converted, plan_pb2.Plan)
    self.assertEqual(converted, source_plan)
    self._assert_tflite_flatbuffer_is_equivalent_to_test_graph(
        converted.client_tflite_graph_bytes
    )

  def test_add_equivalent_tflite_model_to_client_only_plan(self):
    """The TFLite model added to a ClientOnlyPlan computes y = x ^ 2."""
    square_graph, x_spec_proto, y_spec_proto = (
        self._create_test_graph_with_associated_tensor_specs()
    )

    # Populate only what the conversion consumes: the serialized graph and
    # the input/output TensorSpecProtos of the single phase.
    source_plan = plan_pb2.ClientOnlyPlan()
    source_plan.graph = square_graph.as_graph_def().SerializeToString()
    tf_spec = source_plan.phase.tensorflow_spec
    tf_spec.input_tensor_specs.append(x_spec_proto)
    tf_spec.output_tensor_specs.append(y_spec_proto)

    # Run the conversion with default arguments.
    converted = plan_utils.generate_and_add_flat_buffer_to_plan(source_plan)

    self.assertIsInstance(converted, plan_pb2.ClientOnlyPlan)
    self.assertEqual(converted, source_plan)
    self._assert_tflite_flatbuffer_is_equivalent_to_test_graph(
        converted.tflite_graph
    )
249
250
# Script entry point: when this file is executed directly (not imported),
# hand control to the TensorFlow test runner.
if __name__ == '__main__':
  tf.test.main()
253