xref: /aosp_15_r20/external/federated-compute/fcp/artifact_building/graph_helpers_test.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for graph_helpers.py."""
15
16import collections
17
18from absl.testing import absltest
19
20import numpy as np
21import tensorflow as tf
22import tensorflow_federated as tff
23
24from fcp.artifact_building import data_spec
25from fcp.artifact_building import graph_helpers
26from fcp.artifact_building import variable_helpers
27from fcp.protos import plan_pb2
28from tensorflow_federated.proto.v0 import computation_pb2
29
# Module-level test constants.
# NOTE(review): none of these names is referenced by the tests in this file;
# confirm whether they are still needed before relying on them.
TRAIN_URI = 'boo'
TEST_URI = 'foo'
NUM_PIXELS = 28 * 28  # == 784
FAKE_INPUT_DIRECTORY_TENSOR = tf.constant('/path/to/input_dir')
34
35
class EmbedDataLogicTest(absltest.TestCase):
  """Tests for `graph_helpers.embed_data_logic`.

  Each case unpacks the three values returned by `embed_data_logic` and checks:
    * the scalar `tf.string` placeholder named `data_token:0`;
    * one scalar `tf.variant` tensor per sequence in the input type, named
      `data_<i>/Identity:0` with `i` counting from 0;
    * the example-selector placeholders: none when a `DataSpec` structure is
      supplied, otherwise one per sequence.

  NOTE: although several method names say `..._of_integers_...`, every case
  uses sequences whose elements are `tf.string`.
  """

  def assertTensorSpec(self, tensor, name, shape, dtype):
    """Asserts that `tensor` is a `tf.Tensor` with the given name/shape/dtype."""
    self.assertIsInstance(tensor, tf.Tensor)
    self.assertEqual(tensor.name, name)
    self.assertEqual(tensor.shape, shape)
    self.assertEqual(tensor.dtype, dtype)

  def _assert_token_and_data_values(
      self, token_placeholder, data_values, num_datasets
  ):
    """Checks the data-token placeholder and the per-sequence data tensors.

    Args:
      token_placeholder: First value returned by `embed_data_logic`.
      data_values: Second value returned by `embed_data_logic`.
      num_datasets: Expected number of sequences, and therefore of entries in
        `data_values`.
    """
    self.assertTensorSpec(token_placeholder, 'data_token:0', [], tf.string)
    self.assertLen(data_values, num_datasets)
    for index, data_value in enumerate(data_values):
      self.assertTensorSpec(
          data_value, 'data_{}/Identity:0'.format(index), [], tf.variant
      )

  def test_one_dataset_of_integers_w_dataspec(self):
    """One sequence plus a DataSpec yields no example-selector placeholder."""
    with tf.Graph().as_default():
      token_placeholder, data_values, placeholders = (
          graph_helpers.embed_data_logic(
              tff.SequenceType(tf.string),
              data_spec.DataSpec(
                  plan_pb2.ExampleSelector(collection_uri='app://fake_uri')
              ),
          )
      )

    self._assert_token_and_data_values(token_placeholder, data_values, 1)
    self.assertEmpty(placeholders)

  def test_two_datasets_of_integers_w_dataspec(self):
    """A struct of two sequences with matching DataSpecs yields two tensors."""
    with tf.Graph().as_default():
      token_placeholder, data_values, placeholders = (
          graph_helpers.embed_data_logic(
              collections.OrderedDict(
                  A=tff.SequenceType(tf.string),
                  B=tff.SequenceType(tf.string),
              ),
              collections.OrderedDict(
                  A=data_spec.DataSpec(
                      plan_pb2.ExampleSelector(collection_uri='app://foo')
                  ),
                  B=data_spec.DataSpec(
                      plan_pb2.ExampleSelector(collection_uri='app://bar')
                  ),
              ),
          )
      )

    self._assert_token_and_data_values(token_placeholder, data_values, 2)
    self.assertEmpty(placeholders)

  def test_nested_dataspec(self):
    """A DataSpec structure may be nested to mirror a nested input type."""
    with tf.Graph().as_default():
      token_placeholder, data_values, placeholders = (
          graph_helpers.embed_data_logic(
              collections.OrderedDict(
                  A=collections.OrderedDict(B=tff.SequenceType(tf.string))
              ),
              collections.OrderedDict(
                  A=collections.OrderedDict(
                      B=data_spec.DataSpec(
                          plan_pb2.ExampleSelector(collection_uri='app://foo')
                      )
                  )
              ),
          )
      )

    self._assert_token_and_data_values(token_placeholder, data_values, 1)
    self.assertEmpty(placeholders)

  def test_one_dataset_of_integers_without_dataspec(self):
    """Without a DataSpec, a single `example_selector` placeholder is made."""
    with tf.Graph().as_default():
      token_placeholder, data_values, placeholders = (
          graph_helpers.embed_data_logic(tff.SequenceType(tf.string))
      )

    self._assert_token_and_data_values(token_placeholder, data_values, 1)
    self.assertLen(placeholders, 1)
    self.assertEqual(placeholders[0].name, 'example_selector:0')

  def test_two_datasets_of_integers_without_dataspec(self):
    """Without a DataSpec, each sequence gets an indexed selector placeholder."""
    with tf.Graph().as_default():
      token_placeholder, data_values, placeholders = (
          graph_helpers.embed_data_logic(
              collections.OrderedDict(
                  A=tff.SequenceType(tf.string),
                  B=tff.SequenceType(tf.string),
              )
          )
      )

    self._assert_token_and_data_values(token_placeholder, data_values, 2)
    self.assertLen(placeholders, 2)
    self.assertEqual(placeholders[0].name, 'example_selector_0:0')
    self.assertEqual(placeholders[1].name, 'example_selector_1:0')

  def test_nested_input_without_dataspec(self):
    """Nested sequences get selector placeholders named by their index path."""
    with tf.Graph().as_default():
      token_placeholder, data_values, placeholders = (
          graph_helpers.embed_data_logic(
              collections.OrderedDict(
                  A=collections.OrderedDict(B=tff.SequenceType(tf.string))
              )
          )
      )

    self._assert_token_and_data_values(token_placeholder, data_values, 1)
    self.assertLen(placeholders, 1)
    self.assertEqual(placeholders[0].name, 'example_selector_0_0:0')
155
156
class GraphHelperTest(absltest.TestCase):
  """Tests for the graph import and graph-rewriting helpers in graph_helpers."""

  @staticmethod
  def _tensor_binding(tensor_name):
    """Returns a `Binding` that wraps a `TensorBinding` for `tensor_name`."""
    return computation_pb2.TensorFlow.Binding(
        tensor=computation_pb2.TensorFlow.TensorBinding(tensor_name=tensor_name)
    )

  @staticmethod
  def _sequence_binding(variant_tensor_name):
    """Returns a `Binding` that wraps a `SequenceBinding` for the variant."""
    return computation_pb2.TensorFlow.Binding(
        sequence=computation_pb2.TensorFlow.SequenceBinding(
            variant_tensor_name=variant_tensor_name
        )
    )

  @staticmethod
  def _struct_binding(*elements):
    """Returns a `Binding` that wraps a `StructBinding` of `elements`, in order."""
    return computation_pb2.TensorFlow.Binding(
        struct=computation_pb2.TensorFlow.StructBinding(element=list(elements))
    )

  def test_import_tensorflow(self):
    # NOTE: Minimal test for now, since this is exercised by other components;
    # just a single example with a combo of all flavors of params and results.
    @tff.tf_computation(tff.SequenceType(tf.int64), tf.int64)
    def work(ds, x):
      return x + 1, ds.map(lambda a: a + x)

    with tf.Graph().as_default():
      ds = tf.data.experimental.to_variant(tf.data.Dataset.range(3))
      v = tf.constant(10, dtype=tf.int64)
      # With split_outputs=True the scalar result and the dataset result come
      # back separately; each is indexed with [0] below.
      y, ds2_variant = graph_helpers.import_tensorflow(
          'work', work, ([ds], [v]), split_outputs=True
      )
      ds2 = tf.data.experimental.from_variant(
          ds2_variant[0], tf.TensorSpec([], tf.int64)
      )
      # (0 + 10) + (1 + 10) + (2 + 10) == 33. The reducer's parameters are
      # named so they do not shadow the `y` result above.
      z = ds2.reduce(np.int64(0), lambda total, elem: total + elem)
      with tf.compat.v1.Session() as sess:
        self.assertEqual(sess.run(y[0]), 11)
        self.assertEqual(sess.run(z), 33)

  def test_import_tensorflow_with_session_token(self):
    @tff.tf_computation
    def return_value():
      return tff.framework.get_session_token()

    with tf.Graph().as_default():
      x = tf.compat.v1.placeholder(dtype=tf.string)
      output = graph_helpers.import_tensorflow(
          'return_value', comp=return_value, session_token_tensor=x
      )
      # The computation returns the session token, so whatever is fed into
      # `x` must come straight back out.
      with tf.compat.v1.Session() as sess:
        self.assertEqual(sess.run(output[0], feed_dict={x: 'value'}), b'value')

  def test_import_tensorflow_with_control_dep_remap(self):
    # Assert that importing the GraphDef remaps both regular and control-dep
    # inputs.
    @tff.tf_computation(tf.int64, tf.int64)
    def work(x, y):
      # Insert a control dependency to ensure it is remapped during import.
      with tf.compat.v1.control_dependencies([y]):
        return tf.identity(x)

    with tf.Graph().as_default():
      x = tf.compat.v1.placeholder(dtype=tf.int64)
      y = tf.compat.v1.placeholder(dtype=tf.int64)
      output = graph_helpers.import_tensorflow(
          'control_dep_graph', comp=work, args=[x, y]
      )
      with tf.compat.v1.Session() as sess:
        self.assertEqual(sess.run(output, feed_dict={x: 10, y: 20})[0], 10)

  def test_add_control_deps_for_init_op(self):
    # Creates a graph (double edges are regular dependencies, single edges are
    # control dependencies) like this:
    #
    #  ghi
    #   |
    #  def
    #   ||
    #  def:0         foo
    #   ||        //     ||
    #  abc      bar      ||
    #     \   //   \\    ||
    #      bak        baz
    #
    graph_def = tf.compat.v1.GraphDef(
        node=[
            tf.compat.v1.NodeDef(name='foo', input=[]),
            tf.compat.v1.NodeDef(name='bar', input=['foo']),
            tf.compat.v1.NodeDef(name='baz', input=['foo', 'bar']),
            tf.compat.v1.NodeDef(name='bak', input=['bar', '^abc']),
            tf.compat.v1.NodeDef(name='abc', input=['def:0']),
            tf.compat.v1.NodeDef(name='def', input=['^ghi']),
            tf.compat.v1.NodeDef(name='ghi', input=[]),
        ]
    )
    new_graph_def = graph_helpers.add_control_deps_for_init_op(graph_def, 'abc')
    # Every node other than `abc` and its transitive dependencies (`def`,
    # `ghi`) gains a `^abc` control input. `bak` already has one and does not
    # get a duplicate.
    self.assertEqual(
        ','.join(
            '{}({})'.format(node.name, ','.join(node.input))
            for node in new_graph_def.node
        ),
        (
            'foo(^abc),bar(foo,^abc),baz(foo,bar,^abc),'
            'bak(bar,^abc),abc(def:0),def(^ghi),ghi()'
        ),
    )

  def test_create_tensor_map_with_sequence_binding_and_variant(self):
    with tf.Graph().as_default():
      variant_tensor = tf.data.experimental.to_variant(tf.data.Dataset.range(3))
      input_map = graph_helpers.create_tensor_map(
          self._sequence_binding('foo'), [variant_tensor]
      )
      self.assertLen(input_map, 1)
      self.assertCountEqual(input_map.keys(), ['foo'])
      self.assertIs(input_map['foo'], variant_tensor)

  def test_create_tensor_map_with_sequence_binding_and_multiple_variants(self):
    with tf.Graph().as_default():
      variant_tensor = tf.data.experimental.to_variant(tf.data.Dataset.range(3))
      # A sequence binding names exactly one variant tensor, so two are
      # rejected.
      with self.assertRaises(ValueError):
        graph_helpers.create_tensor_map(
            self._sequence_binding('foo'), [variant_tensor, variant_tensor]
        )

  def test_create_tensor_map_with_sequence_binding_and_non_variant(self):
    with tf.Graph().as_default():
      non_variant_tensor = tf.constant(1)
      with self.assertRaises(TypeError):
        graph_helpers.create_tensor_map(
            self._sequence_binding('foo'), [non_variant_tensor]
        )

  def test_create_tensor_map_with_non_sequence_binding_and_vars(self):
    with tf.Graph().as_default():
      vars_list = variable_helpers.create_vars_for_tff_type(
          tff.to_type([('a', tf.int32), ('b', tf.int32)])
      )
      init_op = tf.compat.v1.global_variables_initializer()
      # Assign 1, 2, ... to the variables in order so the mapping is visible.
      assign_op = tf.group(
          *(v.assign(tf.constant(k)) for k, v in enumerate(vars_list, start=1))
      )
      input_map = graph_helpers.create_tensor_map(
          self._struct_binding(
              self._tensor_binding('foo'), self._tensor_binding('bar')
          ),
          vars_list,
      )
      with tf.compat.v1.Session() as sess:
        sess.run(init_op)
        sess.run(assign_op)
        self.assertDictEqual(sess.run(input_map), {'foo': 1, 'bar': 2})

  def test_get_deps_for_graph_node(self):
    # Creates a graph (double edges are regular dependencies, single edges are
    # control dependencies) like this:
    #                      foo
    #                   //      \\
    #               foo:0        foo:1
    #                  ||       //
    #       abc       bar      //
    #     //    \   //   \\   //
    #  abc:0     bak       baz
    #    ||
    #   def
    #    |
    #   ghi
    #
    graph_def = tf.compat.v1.GraphDef(
        node=[
            tf.compat.v1.NodeDef(name='foo', input=[]),
            tf.compat.v1.NodeDef(name='bar', input=['foo:0']),
            tf.compat.v1.NodeDef(name='baz', input=['foo:1', 'bar']),
            tf.compat.v1.NodeDef(name='bak', input=['bar', '^abc']),
            tf.compat.v1.NodeDef(name='abc', input=[]),
            tf.compat.v1.NodeDef(name='def', input=['abc:0']),
            tf.compat.v1.NodeDef(name='ghi', input=['^def']),
        ]
    )

    def _get_deps(node_name):
      # Sorted so the comparison does not depend on the returned collection's
      # iteration order.
      return ','.join(
          sorted(graph_helpers._get_deps_for_graph_node(graph_def, node_name))
      )

    # Dependencies are transitive and follow both regular and control edges.
    self.assertEqual(_get_deps('foo'), '')
    self.assertEqual(_get_deps('bar'), 'foo')
    self.assertEqual(_get_deps('baz'), 'bar,foo')
    self.assertEqual(_get_deps('bak'), 'abc,bar,foo')
    self.assertEqual(_get_deps('abc'), '')
    self.assertEqual(_get_deps('def'), 'abc')
    self.assertEqual(_get_deps('ghi'), 'abc,def')

  def test_list_tensor_names_in_binding(self):
    binding = self._struct_binding(
        self._tensor_binding('a'),
        self._struct_binding(
            self._tensor_binding('b'), self._tensor_binding('c')
        ),
        self._tensor_binding('d'),
        self._sequence_binding('e'),
    )
    # Nested structs are flattened in declaration order, and a sequence
    # binding contributes its variant tensor name.
    self.assertEqual(
        graph_helpers._list_tensor_names_in_binding(binding),
        ['a', 'b', 'c', 'd', 'e'],
    )
403
404
if __name__ == '__main__':
  # Install a TFF "runtime error" context on the context stack for the whole
  # test run. NOTE(review): presumably this makes any attempt to actually
  # execute a TFF computation raise, while still permitting the tracing done by
  # `tff.tf_computation` in the tests above — confirm against the TFF docs.
  _context_stack = tff.framework.get_context_stack()
  _runtime_error_context = tff.test.create_runtime_error_context()
  with _context_stack.install(_runtime_error_context):
    absltest.main()
409    absltest.main()
410