# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for federated_context."""

import http
import http.client
import socket
import threading
import unittest
from unittest import mock

from absl.testing import absltest
import attr
import tensorflow as tf
import tensorflow_federated as tff

from fcp.artifact_building import artifact_constants
from fcp.artifact_building import federated_compute_plan_builder
from fcp.artifact_building import plan_utils
from fcp.artifact_building import variable_helpers
from fcp.demo import federated_computation
from fcp.demo import federated_context
from fcp.demo import federated_data_source
from fcp.demo import server
from fcp.demo import test_utils
from fcp.protos import plan_pb2

ADDRESS_FAMILY = socket.AddressFamily.AF_INET
POPULATION_NAME = 'test/population'
DATA_SOURCE = federated_data_source.FederatedDataSource(
    POPULATION_NAME, plan_pb2.ExampleSelector(collection_uri='app:/test'))


@tff.tf_computation(tf.int32)
def add_one(x):
  """Example non-federated TFF computation that adds one to its input."""
  return x + 1


@tff.federated_computation(
    tff.type_at_server(tf.int32),
    tff.type_at_clients(tff.SequenceType(tf.string)))
def count_clients(state, client_data):
  """Example TFF computation that counts clients."""
  del client_data
  num_clients = tff.federated_sum(tff.federated_value(1, tff.CLIENTS))
  non_state = tff.federated_value((), tff.SERVER)
  return state + num_clients, non_state


@tff.federated_computation(
    tff.type_at_server(tff.StructType([('foo', tf.int32), ('bar', tf.int32)])),
    tff.type_at_clients(tff.SequenceType(tf.string)),
)
def irregular_arrays(state, client_data):
  """Example TFF computation that returns irregular data."""
  del client_data
  num_clients = tff.federated_sum(tff.federated_value(1, tff.CLIENTS))
  non_state = tff.federated_value(1, tff.SERVER)
  return state, non_state + num_clients


@attr.s(eq=False, frozen=True, slots=True)
class TestClass:
  """An attrs class."""

  field_one = attr.ib()
  field_two = attr.ib()


@tff.tf_computation
def init():
  """Returns a `TestClass`; its result type is used as `attrs_type`."""
  return TestClass(field_one=1, field_two=2)


# TFF type of `TestClass`, taken from the type signature of `init`.
attrs_type = init.type_signature.result


@tff.federated_computation(
    tff.type_at_server(attrs_type),
    tff.type_at_clients(tff.SequenceType(tf.string)),
)
def attrs_computation(state, client_data):
  """Example TFF computation that returns an attrs class."""
  del client_data
  num_clients = tff.federated_sum(tff.federated_value(1, tff.CLIENTS))
  non_state = tff.federated_value(1, tff.SERVER)
  return state, non_state + num_clients


def build_result_checkpoint(state: int) -> bytes:
  """Helper function to build a result checkpoint for `count_clients`."""
  var_names = variable_helpers.variable_names_from_type(
      count_clients.type_signature.result[0],
      name=artifact_constants.SERVER_STATE_VAR_PREFIX)
  return test_utils.create_checkpoint({var_names[0]: state})


class FederatedContextTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase):
  """Tests for constructing, entering, and invoking a `FederatedContext`."""

  def test_invalid_population_name(self):
    """An invalid population name is rejected with a ValueError."""
    with self.assertRaisesRegex(ValueError, 'population_name must match ".+"'):
      federated_context.FederatedContext(
          '^^invalid^^', address_family=ADDRESS_FAMILY)

  @mock.patch.object(server.InProcessServer, 'shutdown', autospec=True)
  @mock.patch.object(server.InProcessServer, 'serve_forever', autospec=True)
  def test_context_management(self, serve_forever, shutdown):
    """Entering the context starts serving; exiting it shuts down once."""
    started = threading.Event()
    serve_forever.side_effect = lambda *args, **kwargs: started.set()

    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY)
    # Constructing the context alone must neither start nor stop the server.
    self.assertFalse(started.is_set())
    shutdown.assert_not_called()
    with ctx:
      self.assertTrue(started.wait(0.5))
      shutdown.assert_not_called()
    shutdown.assert_called_once()

  def test_http(self):
    """A GET for an unknown path on the context's port returns 404."""
    with federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY) as ctx:
      conn = http.client.HTTPConnection('localhost', port=ctx.server_port)
      # Close the client socket before the context exits so the connection
      # is not leaked when the test finishes or the assertion fails.
      try:
        conn.request('GET', '/does-not-exist')
        self.assertEqual(conn.getresponse().status, http.HTTPStatus.NOT_FOUND)
      finally:
        conn.close()

  def test_invoke_non_federated_with_base_context(self):
    """With a base context, a non-federated computation can be invoked."""
    base_context = tff.backends.native.create_sync_local_cpp_execution_context()
    ctx = federated_context.FederatedContext(
        POPULATION_NAME,
        address_family=ADDRESS_FAMILY,
        base_context=base_context)
    with tff.framework.get_context_stack().install(ctx):
      self.assertEqual(add_one(3), 4)

  def test_invoke_non_federated_without_base_context(self):
    """Without a base context, a non-federated computation is a TypeError."""
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY)
    with tff.framework.get_context_stack().install(ctx):
      with self.assertRaisesRegex(TypeError,
                                  'computation must be a FederatedComputation'):
        add_one(3)

  def test_invoke_with_invalid_state_type(self):
    """Passing a proto as the state argument raises a TypeError."""
    comp = federated_computation.FederatedComputation(count_clients, name='x')
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY)
    with tff.framework.get_context_stack().install(ctx):
      with self.assertRaisesRegex(
          TypeError, r'arg\[0\] must be a value or structure of values'
      ):
        comp(plan_pb2.Plan(), DATA_SOURCE.iterator().select(1))

  def test_invoke_with_invalid_data_source_type(self):
    """Passing a proto as the client data argument raises a TypeError."""
    comp = federated_computation.FederatedComputation(count_clients, name='x')
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY)
    with tff.framework.get_context_stack().install(ctx):
      with self.assertRaisesRegex(
          TypeError, r'arg\[1\] must be the result of '
          r'FederatedDataSource.iterator\(\).select\(\)'):
        comp(0, plan_pb2.Plan())

  def test_invoke_succeeds_with_structure_state_type(self):
    """Invoking with a dict of tuples as the state does not raise."""
    comp = federated_computation.FederatedComputation(
        irregular_arrays, name='x'
    )
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY
    )
    with tff.framework.get_context_stack().install(ctx):
      state = {'foo': (3, 1), 'bar': (4, 5, 6)}
      comp(state, DATA_SOURCE.iterator().select(1))

  def test_invoke_succeeds_with_attrs_state_type(self):
    """Invoking with an attrs instance as the state does not raise."""
    comp = federated_computation.FederatedComputation(
        attrs_computation, name='x'
    )
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY
    )
    with tff.framework.get_context_stack().install(ctx):
      state = TestClass(field_one=1, field_two=2)
      comp(state, DATA_SOURCE.iterator().select(1))

  def test_invoke_with_mismatched_population_names(self):
    """A data source for a different population raises a ValueError."""
    comp = federated_computation.FederatedComputation(count_clients, name='x')
    ds = federated_data_source.FederatedDataSource('other/name',
                                                   DATA_SOURCE.example_selector)
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY)
    with tff.framework.get_context_stack().install(ctx):
      with self.assertRaisesRegex(
          ValueError, 'FederatedDataSource and FederatedContext '
          'population_names must match'):
        comp(0, ds.iterator().select(1))

  @mock.patch.object(server.InProcessServer, 'run_computation', autospec=True)
  async def test_invoke_success(self, run_computation):
    """A successful invocation forwards its inputs and yields the result."""
    run_computation.return_value = build_result_checkpoint(7)

    comp = federated_computation.FederatedComputation(count_clients, name='x')
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY)
    release_manager = tff.program.MemoryReleaseManager()
    with tff.framework.get_context_stack().install(ctx):
      state, _ = comp(3, DATA_SOURCE.iterator().select(10))
      await release_manager.release(
          state, tff.type_at_server(tf.int32), key='result')

    self.assertEqual(release_manager.values()['result'][0], 7)

    run_computation.assert_called_once_with(
        mock.ANY,
        comp.name,
        mock.ANY,
        mock.ANY,
        DATA_SOURCE.task_assignment_mode,
        10,
    )
    # Positional args 2 and 3 are the plan and the input state checkpoint.
    plan = run_computation.call_args.args[2]
    self.assertIsInstance(plan, plan_pb2.Plan)
    self.assertNotEmpty(plan.client_tflite_graph_bytes)
    input_var_names = variable_helpers.variable_names_from_type(
        count_clients.type_signature.parameter[0],
        name=artifact_constants.SERVER_STATE_VAR_PREFIX)
    self.assertLen(input_var_names, 1)
    self.assertEqual(
        test_utils.read_tensor_from_checkpoint(
            run_computation.call_args.args[3], input_var_names[0], tf.int32), 3)

  @mock.patch.object(server.InProcessServer, 'run_computation', autospec=True)
  async def test_invoke_with_value_reference(self, run_computation):
    """A state returned by one invocation can be the input to the next."""
    run_computation.side_effect = [
        build_result_checkpoint(1234),
        build_result_checkpoint(5678)
    ]

    comp = federated_computation.FederatedComputation(count_clients, name='x')
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY)
    release_manager = tff.program.MemoryReleaseManager()
    with tff.framework.get_context_stack().install(ctx):
      state, _ = comp(3, DATA_SOURCE.iterator().select(10))
      state, _ = comp(state, DATA_SOURCE.iterator().select(10))
      await release_manager.release(
          state, tff.type_at_server(tf.int32), key='result')

    self.assertEqual(release_manager.values()['result'][0], 5678)

    input_var_names = variable_helpers.variable_names_from_type(
        count_clients.type_signature.parameter[0],
        name=artifact_constants.SERVER_STATE_VAR_PREFIX)
    self.assertLen(input_var_names, 1)
    # The second invocation should be passed the value returned by the first
    # invocation.
    self.assertEqual(run_computation.call_count, 2)
    self.assertEqual(
        test_utils.read_tensor_from_checkpoint(
            run_computation.call_args.args[3], input_var_names[0], tf.int32),
        1234)

  async def test_invoke_without_input_state(self):
    """Passing `None` as the state argument raises a TypeError."""
    comp = federated_computation.FederatedComputation(count_clients, name='x')
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY)
    with tff.framework.get_context_stack().install(ctx):
      with self.assertRaisesRegex(
          TypeError, r'arg\[0\] must be a value or structure of values'
      ):
        comp(None, DATA_SOURCE.iterator().select(1))

  @mock.patch.object(server.InProcessServer, 'run_computation', autospec=True)
  async def test_invoke_with_run_computation_error(self, run_computation):
    """An error from `run_computation` is raised when the value is released."""
    run_computation.side_effect = ValueError('message')

    comp = federated_computation.FederatedComputation(count_clients, name='x')
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY)
    release_manager = tff.program.MemoryReleaseManager()
    with tff.framework.get_context_stack().install(ctx):
      # The invocation itself is expected not to raise; only the release is.
      state, _ = comp(0, DATA_SOURCE.iterator().select(10))
      with self.assertRaisesRegex(ValueError, 'message'):
        await release_manager.release(
            state, tff.type_at_server(tf.int32), key='result')


class FederatedContextPlanCachingTest(absltest.TestCase,
                                      unittest.IsolatedAsyncioTestCase):
  """Tests for when `FederatedContext` rebuilds vs. reuses a plan.

  `asyncSetUp` runs `count_clients_comp1` with `data_source1` once with
  `build_plan` and `run_computation` mocked, then resets both mocks, so each
  test observes only the calls made by its own invocation.
  """

  async def asyncSetUp(self):
    """Installs mocks and a context, then runs one warm-up computation."""
    await super().asyncSetUp()

    @tff.federated_computation(
        tff.type_at_server(tf.int32),
        tff.type_at_clients(tff.SequenceType(tf.string)))
    def identity(state, client_data):
      """Example TFF computation that returns its state unchanged."""
      del client_data
      return state, tff.federated_value((), tff.SERVER)

    # Two wrappers around the same TFF computation under different names,
    # plus one wrapper around a different computation.
    self.count_clients_comp1 = federated_computation.FederatedComputation(
        count_clients, name='count_clients1')
    self.count_clients_comp2 = federated_computation.FederatedComputation(
        count_clients, name='count_clients2')
    self.identity_comp = federated_computation.FederatedComputation(
        identity, name='identity')

    self.data_source1 = federated_data_source.FederatedDataSource(
        POPULATION_NAME, plan_pb2.ExampleSelector(collection_uri='app:/1'))
    self.data_source2 = federated_data_source.FederatedDataSource(
        POPULATION_NAME, plan_pb2.ExampleSelector(collection_uri='app:/2'))

    self.run_computation = self.enter_context(
        mock.patch.object(
            server.InProcessServer, 'run_computation', autospec=True))
    self.run_computation.return_value = build_result_checkpoint(0)
    self.build_plan = self.enter_context(
        mock.patch.object(
            federated_compute_plan_builder, 'build_plan', autospec=True))
    self.build_plan.return_value = plan_pb2.Plan()
    self.generate_and_add_flat_buffer_to_plan = self.enter_context(
        mock.patch.object(
            plan_utils, 'generate_and_add_flat_buffer_to_plan', autospec=True))
    self.generate_and_add_flat_buffer_to_plan.side_effect = lambda plan: plan
    self.enter_context(tff.framework.get_context_stack().install(
        federated_context.FederatedContext(
            POPULATION_NAME, address_family=ADDRESS_FAMILY)))
    self.release_manager = tff.program.MemoryReleaseManager()

    # Run (and therefore cache) count_clients_comp1 with data_source1.
    await self.release_manager.release(
        self.count_clients_comp1(0,
                                 self.data_source1.iterator().select(1)),
        self.count_clients_comp1.type_signature.result,
        key='result')
    self.build_plan.assert_called_once()
    self.assertEqual(self.build_plan.call_args.args[0],
                     self.count_clients_comp1.map_reduce_form)
    self.assertEqual(
        self.build_plan.call_args.args[1],
        self.count_clients_comp1.distribute_aggregate_form,
    )
    self.assertEqual(
        self.build_plan.call_args.args[2].example_selector_proto,
        self.data_source1.example_selector,
    )
    self.run_computation.assert_called_once()
    self.build_plan.reset_mock()
    self.run_computation.reset_mock()

  async def test_reuse_with_repeat_computation(self):
    """Repeating the warm-up invocation does not rebuild the plan."""
    await self.release_manager.release(
        self.count_clients_comp1(0,
                                 self.data_source1.iterator().select(1)),
        self.count_clients_comp1.type_signature.result,
        key='result')
    self.build_plan.assert_not_called()
    self.run_computation.assert_called_once()

  async def test_reuse_with_changed_num_clients(self):
    """Selecting a different number of clients does not rebuild the plan."""
    await self.release_manager.release(
        self.count_clients_comp1(0,
                                 self.data_source1.iterator().select(10)),
        self.count_clients_comp1.type_signature.result,
        key='result')
    self.build_plan.assert_not_called()
    self.run_computation.assert_called_once()

  async def test_reuse_with_changed_initial_state(self):
    """Passing a different initial state does not rebuild the plan."""
    await self.release_manager.release(
        self.count_clients_comp1(3,
                                 self.data_source1.iterator().select(1)),
        self.count_clients_comp1.type_signature.result,
        key='result')
    self.build_plan.assert_not_called()
    self.run_computation.assert_called_once()

  async def test_reuse_with_equivalent_map_reduce_form(self):
    """A differently named wrapper of the same computation reuses the plan."""
    await self.release_manager.release(
        self.count_clients_comp2(0,
                                 self.data_source1.iterator().select(1)),
        self.count_clients_comp2.type_signature.result,
        key='result')
    self.build_plan.assert_not_called()
    self.run_computation.assert_called_once()

  async def test_rebuild_with_different_computation(self):
    """A different computation triggers exactly one new plan build."""
    await self.release_manager.release(
        self.identity_comp(0,
                           self.data_source1.iterator().select(1)),
        self.identity_comp.type_signature.result,
        key='result')
    self.build_plan.assert_called_once()
    self.assertEqual(self.build_plan.call_args.args[0],
                     self.identity_comp.map_reduce_form)
    self.assertEqual(
        self.build_plan.call_args.args[1],
        self.identity_comp.distribute_aggregate_form,
    )
    self.assertEqual(
        self.build_plan.call_args.args[2].example_selector_proto,
        self.data_source1.example_selector,
    )
    self.run_computation.assert_called_once()

  async def test_rebuild_with_different_data_source(self):
    """A different data source triggers exactly one new plan build."""
    await self.release_manager.release(
        self.count_clients_comp1(0,
                                 self.data_source2.iterator().select(1)),
        self.count_clients_comp1.type_signature.result,
        key='result')
    self.build_plan.assert_called_once()
    self.assertEqual(self.build_plan.call_args.args[0],
                     self.count_clients_comp1.map_reduce_form)
    self.assertEqual(
        self.build_plan.call_args.args[1],
        self.count_clients_comp1.distribute_aggregate_form,
    )
    self.assertEqual(
        self.build_plan.call_args.args[2].example_selector_proto,
        self.data_source2.example_selector,
    )
    self.run_computation.assert_called_once()


if __name__ == '__main__':
  absltest.main()