# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for graph_helpers.py."""

import collections

from absl.testing import absltest
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

from fcp.artifact_building import data_spec
from fcp.artifact_building import graph_helpers
from fcp.artifact_building import variable_helpers
from fcp.protos import plan_pb2
from tensorflow_federated.proto.v0 import computation_pb2

TRAIN_URI = 'boo'
TEST_URI = 'foo'
NUM_PIXELS = 784
FAKE_INPUT_DIRECTORY_TENSOR = tf.constant('/path/to/input_dir')


class EmbedDataLogicTest(absltest.TestCase):

  def assertTensorSpec(self, tensor, name, shape, dtype):
    self.assertIsInstance(tensor, tf.Tensor)
    self.assertEqual(tensor.name, name)
    self.assertEqual(tensor.shape, shape)
    self.assertEqual(tensor.dtype, dtype)

  def test_one_dataset_of_integers_w_dataspec(self):
    with tf.Graph().as_default():
      token_placeholder, data_values, placeholders = (
          graph_helpers.embed_data_logic(
              tff.SequenceType((tf.string)),
              data_spec.DataSpec(
                  plan_pb2.ExampleSelector(collection_uri='app://fake_uri')
              ),
          )
      )

    self.assertTensorSpec(token_placeholder, 'data_token:0', [], tf.string)
    self.assertLen(data_values, 1)
    self.assertTensorSpec(data_values[0], 'data_0/Identity:0', [], tf.variant)
    self.assertEmpty(placeholders)

  def test_two_datasets_of_integers_w_dataspec(self):
    with tf.Graph().as_default():
      token_placeholder, data_values, placeholders = (
          graph_helpers.embed_data_logic(
              collections.OrderedDict(
                  A=tff.SequenceType((tf.string)),
                  B=tff.SequenceType((tf.string)),
              ),
              collections.OrderedDict(
                  A=data_spec.DataSpec(
                      plan_pb2.ExampleSelector(collection_uri='app://foo')
                  ),
                  B=data_spec.DataSpec(
                      plan_pb2.ExampleSelector(collection_uri='app://bar')
                  ),
              ),
          )
      )

    self.assertTensorSpec(token_placeholder, 'data_token:0', [], tf.string)
    self.assertLen(data_values, 2)
    self.assertTensorSpec(data_values[0], 'data_0/Identity:0', [], tf.variant)
    self.assertTensorSpec(data_values[1], 'data_1/Identity:0', [], tf.variant)
    self.assertEmpty(placeholders)

  def test_nested_dataspec(self):
    with tf.Graph().as_default():
      token_placeholder, data_values, placeholders = (
          graph_helpers.embed_data_logic(
              collections.OrderedDict(
                  A=collections.OrderedDict(B=tff.SequenceType((tf.string)))
              ),
              collections.OrderedDict(
                  A=collections.OrderedDict(
                      B=data_spec.DataSpec(
                          plan_pb2.ExampleSelector(collection_uri='app://foo')
                      )
                  )
              ),
          )
      )

    self.assertTensorSpec(token_placeholder, 'data_token:0', [], tf.string)
    self.assertLen(data_values, 1)
    self.assertTensorSpec(data_values[0], 'data_0/Identity:0', [], tf.variant)
    self.assertEmpty(placeholders)

  def test_one_dataset_of_integers_without_dataspec(self):
    with tf.Graph().as_default():
      token_placeholder, data_values, placeholders = (
          graph_helpers.embed_data_logic(tff.SequenceType((tf.string)))
      )

    self.assertTensorSpec(token_placeholder, 'data_token:0', [], tf.string)
    self.assertLen(data_values, 1)
    self.assertTensorSpec(data_values[0], 'data_0/Identity:0', [], tf.variant)
    self.assertLen(placeholders, 1)
    self.assertEqual(placeholders[0].name, 'example_selector:0')

  def test_two_datasets_of_integers_without_dataspec(self):
    with tf.Graph().as_default():
      token_placeholder, data_values, placeholders = (
          graph_helpers.embed_data_logic(
              collections.OrderedDict(
                  A=tff.SequenceType((tf.string)),
                  B=tff.SequenceType((tf.string)),
              )
          )
      )

    self.assertTensorSpec(token_placeholder, 'data_token:0', [], tf.string)
    self.assertLen(data_values, 2)
    self.assertTensorSpec(data_values[0], 'data_0/Identity:0', [], tf.variant)
    self.assertTensorSpec(data_values[1], 'data_1/Identity:0', [], tf.variant)
    self.assertLen(placeholders, 2)
    self.assertEqual(placeholders[0].name, 'example_selector_0:0')
    self.assertEqual(placeholders[1].name, 'example_selector_1:0')

  def test_nested_input_without_dataspec(self):
    with tf.Graph().as_default():
      token_placeholder, data_values, placeholders = (
          graph_helpers.embed_data_logic(
              collections.OrderedDict(
                  A=collections.OrderedDict(B=tff.SequenceType((tf.string)))
              )
          )
      )

    self.assertTensorSpec(token_placeholder, 'data_token:0', [], tf.string)
    self.assertLen(data_values, 1)
    self.assertTensorSpec(data_values[0], 'data_0/Identity:0', [], tf.variant)
    self.assertLen(placeholders, 1)
    self.assertEqual(placeholders[0].name, 'example_selector_0_0:0')


class GraphHelperTest(absltest.TestCase):

  def test_import_tensorflow(self):
    # NOTE: Minimal test for now, since this is exercised by other components,
    # just a single example with a combo of all flavors of params and results.
    @tff.tf_computation(tff.SequenceType(tf.int64), tf.int64)
    def work(ds, x):
      return x + 1, ds.map(lambda a: a + x)

    with tf.Graph().as_default():
      ds = tf.data.experimental.to_variant(tf.data.Dataset.range(3))
      v = tf.constant(10, dtype=tf.int64)
      y, ds2_variant = graph_helpers.import_tensorflow(
          'work', work, ([ds], [v]), split_outputs=True
      )
      ds2 = tf.data.experimental.from_variant(
          ds2_variant[0], tf.TensorSpec([], tf.int64)
      )
      z = ds2.reduce(np.int64(0), lambda x, y: x + y)
      with tf.compat.v1.Session() as sess:
        self.assertEqual(sess.run(y[0]), 11)
        self.assertEqual(sess.run(z), 33)

  def test_import_tensorflow_with_session_token(self):
    @tff.tf_computation
    def return_value():
      return tff.framework.get_session_token()

    with tf.Graph().as_default():
      x = tf.compat.v1.placeholder(dtype=tf.string)
      output = graph_helpers.import_tensorflow(
          'return_value', comp=return_value, session_token_tensor=x
      )
      with tf.compat.v1.Session() as sess:
        self.assertEqual(sess.run(output[0], feed_dict={x: 'value'}), b'value')

  def test_import_tensorflow_with_control_dep_remap(self):
    # Assert that importing graphdef remaps both regular and control dep inputs.
    @tff.tf_computation(tf.int64, tf.int64)
    def work(x, y):
      # Insert a control dependency to ensure it is remapped during import.
      with tf.compat.v1.control_dependencies([y]):
        return tf.identity(x)

    with tf.Graph().as_default():
      x = tf.compat.v1.placeholder(dtype=tf.int64)
      y = tf.compat.v1.placeholder(dtype=tf.int64)
      output = graph_helpers.import_tensorflow(
          'control_dep_graph', comp=work, args=[x, y]
      )
      with tf.compat.v1.Session() as sess:
        self.assertEqual(sess.run(output, feed_dict={x: 10, y: 20})[0], 10)

  def test_add_control_deps_for_init_op(self):
    # Creates a graph (double edges are regular dependencies, single edges are
    # control dependencies) like this:
    #
    #   ghi
    #    |
    #   def
    #    ||
    #  def:0     foo
    #    ||     //  ||
    #   abc   bar   ||
    #      \  //   \\||
    #      bak      baz
    #
    graph_def = tf.compat.v1.GraphDef(
        node=[
            tf.compat.v1.NodeDef(name='foo', input=[]),
            tf.compat.v1.NodeDef(name='bar', input=['foo']),
            tf.compat.v1.NodeDef(name='baz', input=['foo', 'bar']),
            tf.compat.v1.NodeDef(name='bak', input=['bar', '^abc']),
            tf.compat.v1.NodeDef(name='abc', input=['def:0']),
            tf.compat.v1.NodeDef(name='def', input=['^ghi']),
            tf.compat.v1.NodeDef(name='ghi', input=[]),
        ]
    )
    new_graph_def = graph_helpers.add_control_deps_for_init_op(
        graph_def, 'abc'
    )
    self.assertEqual(
        ','.join(
            '{}({})'.format(node.name, ','.join(node.input))
            for node in new_graph_def.node
        ),
        (
            'foo(^abc),bar(foo,^abc),baz(foo,bar,^abc),'
            'bak(bar,^abc),abc(def:0),def(^ghi),ghi()'
        ),
    )

  def test_create_tensor_map_with_sequence_binding_and_variant(self):
    with tf.Graph().as_default():
      variant_tensor = tf.data.experimental.to_variant(
          tf.data.Dataset.range(3)
      )
      input_map = graph_helpers.create_tensor_map(
          computation_pb2.TensorFlow.Binding(
              sequence=computation_pb2.TensorFlow.SequenceBinding(
                  variant_tensor_name='foo'
              )
          ),
          [variant_tensor],
      )
      self.assertLen(input_map, 1)
      self.assertCountEqual(list(input_map.keys()), ['foo'])
      self.assertIs(input_map['foo'], variant_tensor)

  def test_create_tensor_map_with_sequence_binding_and_multiple_variants(self):
    with tf.Graph().as_default():
      variant_tensor = tf.data.experimental.to_variant(
          tf.data.Dataset.range(3)
      )
      with self.assertRaises(ValueError):
        graph_helpers.create_tensor_map(
            computation_pb2.TensorFlow.Binding(
                sequence=computation_pb2.TensorFlow.SequenceBinding(
                    variant_tensor_name='foo'
                )
            ),
            [variant_tensor, variant_tensor],
        )

  def test_create_tensor_map_with_sequence_binding_and_non_variant(self):
    with tf.Graph().as_default():
      non_variant_tensor = tf.constant(1)
      with self.assertRaises(TypeError):
        graph_helpers.create_tensor_map(
            computation_pb2.TensorFlow.Binding(
                sequence=computation_pb2.TensorFlow.SequenceBinding(
                    variant_tensor_name='foo'
                )
            ),
            [non_variant_tensor],
        )

  def test_create_tensor_map_with_non_sequence_binding_and_vars(self):
    with tf.Graph().as_default():
      vars_list = variable_helpers.create_vars_for_tff_type(
          tff.to_type([('a', tf.int32), ('b', tf.int32)])
      )
      init_op = tf.compat.v1.global_variables_initializer()
      assign_op = tf.group(
          *(v.assign(tf.constant(k + 1)) for k, v in enumerate(vars_list))
      )
      input_map = graph_helpers.create_tensor_map(
          computation_pb2.TensorFlow.Binding(
              struct=computation_pb2.TensorFlow.StructBinding(
                  element=[
                      computation_pb2.TensorFlow.Binding(
                          tensor=computation_pb2.TensorFlow.TensorBinding(
                              tensor_name='foo'
                          )
                      ),
                      computation_pb2.TensorFlow.Binding(
                          tensor=computation_pb2.TensorFlow.TensorBinding(
                              tensor_name='bar'
                          )
                      ),
                  ]
              )
          ),
          vars_list,
      )
      with tf.compat.v1.Session() as sess:
        sess.run(init_op)
        sess.run(assign_op)
        self.assertDictEqual(sess.run(input_map), {'foo': 1, 'bar': 2})

  def test_get_deps_for_graph_node(self):
    # Creates a graph (double edges are regular dependencies, single edges are
    # control dependencies) like this:
    #            foo
    #           //  \\
    #       foo:0    foo:1
    #         ||       //
    #  abc   bar      //
    #   //  \  //    \\ //
    #  abc:0  bak     baz
    #   ||
    #  def
    #   |
    #  ghi
    #
    graph_def = tf.compat.v1.GraphDef(
        node=[
            tf.compat.v1.NodeDef(name='foo', input=[]),
            tf.compat.v1.NodeDef(name='bar', input=['foo:0']),
            tf.compat.v1.NodeDef(name='baz', input=['foo:1', 'bar']),
            tf.compat.v1.NodeDef(name='bak', input=['bar', '^abc']),
            tf.compat.v1.NodeDef(name='abc', input=[]),
            tf.compat.v1.NodeDef(name='def', input=['abc:0']),
            tf.compat.v1.NodeDef(name='ghi', input=['^def']),
        ]
    )

    def _get_deps(x):
      return ','.join(
          sorted(list(graph_helpers._get_deps_for_graph_node(graph_def, x)))
      )

    self.assertEqual(_get_deps('foo'), '')
    self.assertEqual(_get_deps('bar'), 'foo')
    self.assertEqual(_get_deps('baz'), 'bar,foo')
    self.assertEqual(_get_deps('bak'), 'abc,bar,foo')
    self.assertEqual(_get_deps('abc'), '')
    self.assertEqual(_get_deps('def'), 'abc')
    self.assertEqual(_get_deps('ghi'), 'abc,def')

  def test_list_tensor_names_in_binding(self):
    binding = computation_pb2.TensorFlow.Binding(
        struct=computation_pb2.TensorFlow.StructBinding(
            element=[
                computation_pb2.TensorFlow.Binding(
                    tensor=computation_pb2.TensorFlow.TensorBinding(
                        tensor_name='a'
                    )
                ),
                computation_pb2.TensorFlow.Binding(
                    struct=computation_pb2.TensorFlow.StructBinding(
                        element=[
                            computation_pb2.TensorFlow.Binding(
                                tensor=computation_pb2.TensorFlow.TensorBinding(
                                    tensor_name='b'
                                )
                            ),
                            computation_pb2.TensorFlow.Binding(
                                tensor=computation_pb2.TensorFlow.TensorBinding(
                                    tensor_name='c'
                                )
                            ),
                        ]
                    )
                ),
                computation_pb2.TensorFlow.Binding(
                    tensor=computation_pb2.TensorFlow.TensorBinding(
                        tensor_name='d'
                    )
                ),
                computation_pb2.TensorFlow.Binding(
                    sequence=computation_pb2.TensorFlow.SequenceBinding(
                        variant_tensor_name='e'
                    )
                ),
            ]
        )
    )
    self.assertEqual(
        graph_helpers._list_tensor_names_in_binding(binding),
        ['a', 'b', 'c', 'd', 'e'],
    )


if __name__ == '__main__':
  with tff.framework.get_context_stack().install(
      tff.test.create_runtime_error_context()
  ):
    absltest.main()