# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for graph_helpers.py."""

import collections

from absl.testing import absltest

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

from fcp.artifact_building import data_spec
from fcp.artifact_building import graph_helpers
from fcp.artifact_building import variable_helpers
from fcp.protos import plan_pb2
from tensorflow_federated.proto.v0 import computation_pb2

# NOTE(review): none of the following constants are referenced by any test in
# this file; they may be leftovers. Confirm before removing (the tf.constant
# below is created eagerly at import time).
TRAIN_URI = 'boo'
TEST_URI = 'foo'
NUM_PIXELS = 784
FAKE_INPUT_DIRECTORY_TENSOR = tf.constant('/path/to/input_dir')


class EmbedDataLogicTest(absltest.TestCase):
  """Tests for `graph_helpers.embed_data_logic`.

  Each test builds a fresh graph and checks the three values returned by
  `embed_data_logic`:
    * the `data_token` string placeholder,
    * one scalar `tf.variant` tensor per sequence in the (possibly nested)
      client data type, named `data_<i>/Identity:0`,
    * the example-selector placeholders. The tests expect these to be empty
      when a `DataSpec` is supplied and one per sequence otherwise.

  NOTE: despite the `*_of_integers_*` test names, every sequence below is a
  `tff.SequenceType(tf.string)`.
  """

  def assertTensorSpec(self, tensor, name, shape, dtype):
    """Asserts `tensor` is a `tf.Tensor` with the given name, shape and dtype.

    Args:
      tensor: The object under test.
      name: Expected fully-qualified tensor name, e.g. `'data_token:0'`.
      shape: Expected shape (anything comparable to `tf.TensorShape`).
      dtype: Expected `tf.DType`.
    """
    self.assertIsInstance(tensor, tf.Tensor)
    self.assertEqual(tensor.name, name)
    self.assertEqual(tensor.shape, shape)
    self.assertEqual(tensor.dtype, dtype)

  def test_one_dataset_of_integers_w_dataspec(self):
    """A single sequence with an explicit DataSpec creates no placeholders."""
    with tf.Graph().as_default():
      token_placeholder, data_values, placeholders = (
          graph_helpers.embed_data_logic(
              tff.SequenceType(tf.string),
              data_spec.DataSpec(
                  plan_pb2.ExampleSelector(collection_uri='app://fake_uri')
              ),
          )
      )

    self.assertTensorSpec(token_placeholder, 'data_token:0', [], tf.string)
    self.assertLen(data_values, 1)
    self.assertTensorSpec(data_values[0], 'data_0/Identity:0', [], tf.variant)
    self.assertEmpty(placeholders)

  def test_two_datasets_of_integers_w_dataspec(self):
    """Two named sequences with DataSpecs yield two variants, no placeholders."""
    with tf.Graph().as_default():
      token_placeholder, data_values, placeholders = (
          graph_helpers.embed_data_logic(
              collections.OrderedDict(
                  A=tff.SequenceType(tf.string),
                  B=tff.SequenceType(tf.string),
              ),
              collections.OrderedDict(
                  A=data_spec.DataSpec(
                      plan_pb2.ExampleSelector(collection_uri='app://foo')
                  ),
                  B=data_spec.DataSpec(
                      plan_pb2.ExampleSelector(collection_uri='app://bar')
                  ),
              ),
          )
      )

    self.assertTensorSpec(token_placeholder, 'data_token:0', [], tf.string)

    self.assertLen(data_values, 2)
    self.assertTensorSpec(data_values[0], 'data_0/Identity:0', [], tf.variant)
    self.assertTensorSpec(data_values[1], 'data_1/Identity:0', [], tf.variant)
    self.assertEmpty(placeholders)

  def test_nested_dataspec(self):
    """A nested structure of one sequence is flattened to a single variant."""
    with tf.Graph().as_default():
      token_placeholder, data_values, placeholders = (
          graph_helpers.embed_data_logic(
              collections.OrderedDict(
                  A=collections.OrderedDict(B=tff.SequenceType(tf.string))
              ),
              collections.OrderedDict(
                  A=collections.OrderedDict(
                      B=data_spec.DataSpec(
                          plan_pb2.ExampleSelector(collection_uri='app://foo')
                      )
                  )
              ),
          )
      )

    self.assertTensorSpec(token_placeholder, 'data_token:0', [], tf.string)
    self.assertLen(data_values, 1)
    self.assertTensorSpec(data_values[0], 'data_0/Identity:0', [], tf.variant)
    self.assertEmpty(placeholders)

  def test_one_dataset_of_integers_without_dataspec(self):
    """Without a DataSpec, a single sequence gets one unsuffixed placeholder."""
    with tf.Graph().as_default():
      token_placeholder, data_values, placeholders = (
          graph_helpers.embed_data_logic(tff.SequenceType(tf.string))
      )

    self.assertTensorSpec(token_placeholder, 'data_token:0', [], tf.string)
    self.assertLen(data_values, 1)
    self.assertTensorSpec(data_values[0], 'data_0/Identity:0', [], tf.variant)
    self.assertLen(placeholders, 1)
    self.assertEqual(placeholders[0].name, 'example_selector:0')

  def test_two_datasets_of_integers_without_dataspec(self):
    """Without DataSpecs, each sequence gets an index-suffixed placeholder."""
    with tf.Graph().as_default():
      token_placeholder, data_values, placeholders = (
          graph_helpers.embed_data_logic(
              collections.OrderedDict(
                  A=tff.SequenceType(tf.string),
                  B=tff.SequenceType(tf.string),
              )
          )
      )

    self.assertTensorSpec(token_placeholder, 'data_token:0', [], tf.string)

    self.assertLen(data_values, 2)
    self.assertTensorSpec(data_values[0], 'data_0/Identity:0', [], tf.variant)
    self.assertTensorSpec(data_values[1], 'data_1/Identity:0', [], tf.variant)
    self.assertLen(placeholders, 2)
    self.assertEqual(placeholders[0].name, 'example_selector_0:0')
    self.assertEqual(placeholders[1].name, 'example_selector_1:0')

  def test_nested_input_without_dataspec(self):
    """Nested input without a DataSpec gets a path-suffixed placeholder name."""
    with tf.Graph().as_default():
      token_placeholder, data_values, placeholders = (
          graph_helpers.embed_data_logic(
              collections.OrderedDict(
                  A=collections.OrderedDict(B=tff.SequenceType(tf.string))
              )
          )
      )

    self.assertTensorSpec(token_placeholder, 'data_token:0', [], tf.string)
    self.assertLen(data_values, 1)
    self.assertTensorSpec(data_values[0], 'data_0/Identity:0', [], tf.variant)
    self.assertLen(placeholders, 1)
    self.assertEqual(placeholders[0].name, 'example_selector_0_0:0')


class GraphHelperTest(absltest.TestCase):
  """Tests for graph import, control-dependency and tensor-binding helpers."""

  def test_import_tensorflow(self):
    # NOTE: Minimal test for now, since this is exercised by other components,
    # just a single example with a combo of all flavors of params and results.
    @tff.tf_computation(tff.SequenceType(tf.int64), tf.int64)
    def work(ds, x):
      return x + 1, ds.map(lambda a: a + x)

    with tf.Graph().as_default():
      ds = tf.data.experimental.to_variant(tf.data.Dataset.range(3))
      v = tf.constant(10, dtype=tf.int64)
      y, ds2_variant = graph_helpers.import_tensorflow(
          'work', work, ([ds], [v]), split_outputs=True
      )
      ds2 = tf.data.experimental.from_variant(
          ds2_variant[0], tf.TensorSpec([], tf.int64)
      )
      # Lambda parameters are named so they do not shadow the outer `y`.
      z = ds2.reduce(np.int64(0), lambda acc, elem: acc + elem)
      with tf.compat.v1.Session() as sess:
        # x + 1 with x = 10.
        self.assertEqual(sess.run(y[0]), 11)
        # range(3) mapped by `a + 10` is [10, 11, 12], whose sum is 33.
        self.assertEqual(sess.run(z), 33)

  def test_import_tensorflow_with_session_token(self):
    """The imported computation's session token is wired to the given tensor."""

    @tff.tf_computation
    def return_value():
      return tff.framework.get_session_token()

    with tf.Graph().as_default():
      x = tf.compat.v1.placeholder(dtype=tf.string)
      output = graph_helpers.import_tensorflow(
          'return_value', comp=return_value, session_token_tensor=x
      )
      with tf.compat.v1.Session() as sess:
        self.assertEqual(sess.run(output[0], feed_dict={x: 'value'}), b'value')

  def test_import_tensorflow_with_control_dep_remap(self):
    # Assert that importing graphdef remaps both regular and control dep inputs.
    @tff.tf_computation(tf.int64, tf.int64)
    def work(x, y):
      # Insert a control dependency to ensure it is remapped during import.
      with tf.compat.v1.control_dependencies([y]):
        return tf.identity(x)

    with tf.Graph().as_default():
      x = tf.compat.v1.placeholder(dtype=tf.int64)
      y = tf.compat.v1.placeholder(dtype=tf.int64)
      output = graph_helpers.import_tensorflow(
          'control_dep_graph', comp=work, args=[x, y]
      )
      with tf.compat.v1.Session() as sess:
        self.assertEqual(sess.run(output, feed_dict={x: 10, y: 20})[0], 10)

  def test_add_control_deps_for_init_op(self):
    # Creates a graph (double edges are regular dependencies, single edges are
    # control dependencies) like this:
    #
    #   ghi
    #    |
    #   def
    #    ||
    #   def:0     foo
    #    ||     //   ||
    #   abc   bar    ||
    #     \  //  \\  ||
    #      bak     baz
    #
    # The expected output shows that every node except `abc` and its
    # transitive inputs (`def`, `ghi`) gains a `^abc` control input, unless it
    # already has one (`bak`).
    graph_def = tf.compat.v1.GraphDef(
        node=[
            tf.compat.v1.NodeDef(name='foo', input=[]),
            tf.compat.v1.NodeDef(name='bar', input=['foo']),
            tf.compat.v1.NodeDef(name='baz', input=['foo', 'bar']),
            tf.compat.v1.NodeDef(name='bak', input=['bar', '^abc']),
            tf.compat.v1.NodeDef(name='abc', input=['def:0']),
            tf.compat.v1.NodeDef(name='def', input=['^ghi']),
            tf.compat.v1.NodeDef(name='ghi', input=[]),
        ]
    )
    new_graph_def = graph_helpers.add_control_deps_for_init_op(graph_def, 'abc')
    self.assertEqual(
        ','.join(
            '{}({})'.format(node.name, ','.join(node.input))
            for node in new_graph_def.node
        ),
        (
            'foo(^abc),bar(foo,^abc),baz(foo,bar,^abc),'
            'bak(bar,^abc),abc(def:0),def(^ghi),ghi()'
        ),
    )

  def test_create_tensor_map_with_sequence_binding_and_variant(self):
    """A sequence binding maps its variant tensor name to the given tensor."""
    with tf.Graph().as_default():
      variant_tensor = tf.data.experimental.to_variant(tf.data.Dataset.range(3))
      input_map = graph_helpers.create_tensor_map(
          computation_pb2.TensorFlow.Binding(
              sequence=computation_pb2.TensorFlow.SequenceBinding(
                  variant_tensor_name='foo'
              )
          ),
          [variant_tensor],
      )
      self.assertLen(input_map, 1)
      # assertCountEqual accepts any iterable; no need to copy into a list.
      self.assertCountEqual(input_map.keys(), ['foo'])
      self.assertIs(input_map['foo'], variant_tensor)

  def test_create_tensor_map_with_sequence_binding_and_multiple_variants(self):
    """A sequence binding given more than one tensor raises ValueError."""
    with tf.Graph().as_default():
      variant_tensor = tf.data.experimental.to_variant(tf.data.Dataset.range(3))
      with self.assertRaises(ValueError):
        graph_helpers.create_tensor_map(
            computation_pb2.TensorFlow.Binding(
                sequence=computation_pb2.TensorFlow.SequenceBinding(
                    variant_tensor_name='foo'
                )
            ),
            [variant_tensor, variant_tensor],
        )

  def test_create_tensor_map_with_sequence_binding_and_non_variant(self):
    """A sequence binding given a non-variant tensor raises TypeError."""
    with tf.Graph().as_default():
      non_variant_tensor = tf.constant(1)
      with self.assertRaises(TypeError):
        graph_helpers.create_tensor_map(
            computation_pb2.TensorFlow.Binding(
                sequence=computation_pb2.TensorFlow.SequenceBinding(
                    variant_tensor_name='foo'
                )
            ),
            [non_variant_tensor],
        )

  def test_create_tensor_map_with_non_sequence_binding_and_vars(self):
    """A struct binding maps each tensor name to the matching variable."""
    with tf.Graph().as_default():
      vars_list = variable_helpers.create_vars_for_tff_type(
          tff.to_type([('a', tf.int32), ('b', tf.int32)])
      )
      init_op = tf.compat.v1.global_variables_initializer()
      # Assign 1 to the first variable and 2 to the second.
      assign_op = tf.group(
          *(v.assign(tf.constant(k + 1)) for k, v in enumerate(vars_list))
      )
      input_map = graph_helpers.create_tensor_map(
          computation_pb2.TensorFlow.Binding(
              struct=computation_pb2.TensorFlow.StructBinding(
                  element=[
                      computation_pb2.TensorFlow.Binding(
                          tensor=computation_pb2.TensorFlow.TensorBinding(
                              tensor_name='foo'
                          )
                      ),
                      computation_pb2.TensorFlow.Binding(
                          tensor=computation_pb2.TensorFlow.TensorBinding(
                              tensor_name='bar'
                          )
                      ),
                  ]
              )
          ),
          vars_list,
      )
      with tf.compat.v1.Session() as sess:
        sess.run(init_op)
        sess.run(assign_op)
        self.assertDictEqual(sess.run(input_map), {'foo': 1, 'bar': 2})

  def test_get_deps_for_graph_node(self):
    # Creates a graph (double edges are regular dependencies, single edges are
    # control dependencies) like this:
    #
    #                foo
    #              //   \\
    #          foo:0     foo:1
    #           ||      //
    #   abc    bar     //
    #   //  \  //  \\ //
    # abc:0  bak    baz
    #   ||
    #  def
    #   |
    #  ghi
    #
    # The expected values show that dependencies are transitive, follow both
    # regular and control edges, and strip output indices (`foo:0` -> `foo`).
    graph_def = tf.compat.v1.GraphDef(
        node=[
            tf.compat.v1.NodeDef(name='foo', input=[]),
            tf.compat.v1.NodeDef(name='bar', input=['foo:0']),
            tf.compat.v1.NodeDef(name='baz', input=['foo:1', 'bar']),
            tf.compat.v1.NodeDef(name='bak', input=['bar', '^abc']),
            tf.compat.v1.NodeDef(name='abc', input=[]),
            tf.compat.v1.NodeDef(name='def', input=['abc:0']),
            tf.compat.v1.NodeDef(name='ghi', input=['^def']),
        ]
    )

    def _get_deps(node_name):
      """Returns the node's dependencies as a sorted, comma-joined string."""
      # sorted() accepts any iterable, so no intermediate list is needed.
      return ','.join(
          sorted(graph_helpers._get_deps_for_graph_node(graph_def, node_name))
      )

    self.assertEqual(_get_deps('foo'), '')
    self.assertEqual(_get_deps('bar'), 'foo')
    self.assertEqual(_get_deps('baz'), 'bar,foo')
    self.assertEqual(_get_deps('bak'), 'abc,bar,foo')
    self.assertEqual(_get_deps('abc'), '')
    self.assertEqual(_get_deps('def'), 'abc')
    self.assertEqual(_get_deps('ghi'), 'abc,def')

  def test_list_tensor_names_in_binding(self):
    """Nested bindings are flattened depth-first, including sequence names."""
    binding = computation_pb2.TensorFlow.Binding(
        struct=computation_pb2.TensorFlow.StructBinding(
            element=[
                computation_pb2.TensorFlow.Binding(
                    tensor=computation_pb2.TensorFlow.TensorBinding(
                        tensor_name='a'
                    )
                ),
                computation_pb2.TensorFlow.Binding(
                    struct=computation_pb2.TensorFlow.StructBinding(
                        element=[
                            computation_pb2.TensorFlow.Binding(
                                tensor=computation_pb2.TensorFlow.TensorBinding(
                                    tensor_name='b'
                                )
                            ),
                            computation_pb2.TensorFlow.Binding(
                                tensor=computation_pb2.TensorFlow.TensorBinding(
                                    tensor_name='c'
                                )
                            ),
                        ]
                    )
                ),
                computation_pb2.TensorFlow.Binding(
                    tensor=computation_pb2.TensorFlow.TensorBinding(
                        tensor_name='d'
                    )
                ),
                computation_pb2.TensorFlow.Binding(
                    sequence=computation_pb2.TensorFlow.SequenceBinding(
                        variant_tensor_name='e'
                    )
                ),
            ]
        )
    )
    self.assertEqual(
        graph_helpers._list_tensor_names_in_binding(binding),
        ['a', 'b', 'c', 'd', 'e'],
    )


if __name__ == '__main__':
  # The TFF computations above are only traced/serialized, never executed, so
  # the tests run under TFF's runtime-error context (presumably it raises if a
  # computation is actually invoked — confirm against TFF docs).
  with tff.framework.get_context_stack().install(
      tff.test.create_runtime_error_context()
  ):
    absltest.main()