# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for federated_context."""

import http
import http.client
import socket
import threading
import unittest
from unittest import mock

from absl.testing import absltest
import attr
import tensorflow as tf
import tensorflow_federated as tff

from fcp.artifact_building import artifact_constants
from fcp.artifact_building import federated_compute_plan_builder
from fcp.artifact_building import plan_utils
from fcp.artifact_building import variable_helpers
from fcp.demo import federated_computation
from fcp.demo import federated_context
from fcp.demo import federated_data_source
from fcp.demo import server
from fcp.demo import test_utils
from fcp.protos import plan_pb2

# Passed as `address_family` to every FederatedContext in this file (IPv4).
ADDRESS_FAMILY = socket.AddressFamily.AF_INET
POPULATION_NAME = 'test/population'
DATA_SOURCE = federated_data_source.FederatedDataSource(
    POPULATION_NAME, plan_pb2.ExampleSelector(collection_uri='app:/test'))


@tff.tf_computation(tf.int32)
def add_one(x):
  """Non-federated TFF computation that returns `x + 1`."""
  return x + 1


@tff.federated_computation(
    tff.type_at_server(tf.int32),
    tff.type_at_clients(tff.SequenceType(tf.string)))
def count_clients(state, client_data):
  """Example TFF computation that counts clients."""
  del client_data
  num_clients = tff.federated_sum(tff.federated_value(1, tff.CLIENTS))
  non_state = tff.federated_value((), tff.SERVER)
  return state + num_clients, non_state


@tff.federated_computation(
    tff.type_at_server(tff.StructType([('foo', tf.int32), ('bar', tf.int32)])),
    tff.type_at_clients(tff.SequenceType(tf.string)),
)
def irregular_arrays(state, client_data):
  """Example TFF computation that returns irregular data."""
  del client_data
  num_clients = tff.federated_sum(tff.federated_value(1, tff.CLIENTS))
  non_state = tff.federated_value(1, tff.SERVER)
  return state, non_state + num_clients


@attr.s(eq=False, frozen=True, slots=True)
class TestClass:
  """An attrs class."""

  field_one = attr.ib()
  field_two = attr.ib()


@tff.tf_computation
def init():
  """Returns a `TestClass` instance; used only to derive `attrs_type`."""
  return TestClass(field_one=1, field_two=2)


# TFF type of the attrs-class server state consumed by `attrs_computation`.
attrs_type = init.type_signature.result


@tff.federated_computation(
    tff.type_at_server(attrs_type),
    tff.type_at_clients(tff.SequenceType(tf.string)),
)
def attrs_computation(state, client_data):
  """Example TFF computation that returns an attrs class."""
  del client_data
  num_clients = tff.federated_sum(tff.federated_value(1, tff.CLIENTS))
  non_state = tff.federated_value(1, tff.SERVER)
  return state, non_state + num_clients


def build_result_checkpoint(state: int) -> bytes:
  """Helper function to build a result checkpoint for `count_clients`.

  Args:
    state: The value to store in the checkpoint's server-state variable.

  Returns:
    The checkpoint produced by `test_utils.create_checkpoint`, suitable as a
    mocked return value of `InProcessServer.run_computation`.
  """
  var_names = variable_helpers.variable_names_from_type(
      count_clients.type_signature.result[0],
      name=artifact_constants.SERVER_STATE_VAR_PREFIX)
  return test_utils.create_checkpoint({var_names[0]: state})


class FederatedContextTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase):
  """Tests for `federated_context.FederatedContext`."""

  def test_invalid_population_name(self):
    """A malformed population name is rejected at construction time."""
    with self.assertRaisesRegex(ValueError, 'population_name must match ".+"'):
      federated_context.FederatedContext(
          '^^invalid^^', address_family=ADDRESS_FAMILY)

  @mock.patch.object(server.InProcessServer, 'shutdown', autospec=True)
  @mock.patch.object(server.InProcessServer, 'serve_forever', autospec=True)
  def test_context_management(self, serve_forever, shutdown):
    """Entering the context starts serving; exiting it shuts the server down."""
    started = threading.Event()
    serve_forever.side_effect = lambda *args, **kwargs: started.set()

    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY)
    self.assertFalse(started.is_set())
    shutdown.assert_not_called()
    with ctx:
      self.assertTrue(started.wait(0.5))
      shutdown.assert_not_called()
    shutdown.assert_called_once()

  def test_http(self):
    """The context's server answers HTTP requests (404 for unknown paths)."""
    with federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY) as ctx:
      conn = http.client.HTTPConnection('localhost', port=ctx.server_port)
      try:
        conn.request('GET', '/does-not-exist')
        self.assertEqual(conn.getresponse().status, http.HTTPStatus.NOT_FOUND)
      finally:
        # Release the client socket before the server shuts down so the test
        # does not leak the connection.
        conn.close()

  def test_invoke_non_federated_with_base_context(self):
    """Non-federated computations run when a `base_context` is supplied."""
    base_context = tff.backends.native.create_sync_local_cpp_execution_context()
    ctx = federated_context.FederatedContext(
        POPULATION_NAME,
        address_family=ADDRESS_FAMILY,
        base_context=base_context)
    with tff.framework.get_context_stack().install(ctx):
      self.assertEqual(add_one(3), 4)

  def test_invoke_non_federated_without_base_context(self):
    """Without a `base_context`, only FederatedComputations may be invoked."""
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY)
    with tff.framework.get_context_stack().install(ctx):
      with self.assertRaisesRegex(TypeError,
                                  'computation must be a FederatedComputation'):
        add_one(3)

  def test_invoke_with_invalid_state_type(self):
    """A state argument that is not a value (here, a Plan proto) is rejected."""
    comp = federated_computation.FederatedComputation(count_clients, name='x')
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY)
    with tff.framework.get_context_stack().install(ctx):
      with self.assertRaisesRegex(
          TypeError, r'arg\[0\] must be a value or structure of values'
      ):
        comp(plan_pb2.Plan(), DATA_SOURCE.iterator().select(1))

  def test_invoke_with_invalid_data_source_type(self):
    """The data argument must come from a FederatedDataSource selection."""
    comp = federated_computation.FederatedComputation(count_clients, name='x')
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY)
    with tff.framework.get_context_stack().install(ctx):
      with self.assertRaisesRegex(
          TypeError, r'arg\[1\] must be the result of '
          r'FederatedDataSource.iterator\(\).select\(\)'):
        comp(0, plan_pb2.Plan())

  def test_invoke_succeeds_with_structure_state_type(self):
    """A structured (dict) server state is accepted."""
    comp = federated_computation.FederatedComputation(
        irregular_arrays, name='x'
    )
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY
    )
    with tff.framework.get_context_stack().install(ctx):
      state = {'foo': (3, 1), 'bar': (4, 5, 6)}
      comp(state, DATA_SOURCE.iterator().select(1))

  def test_invoke_succeeds_with_attrs_state_type(self):
    """An attrs-class server state is accepted."""
    comp = federated_computation.FederatedComputation(
        attrs_computation, name='x'
    )
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY
    )
    with tff.framework.get_context_stack().install(ctx):
      state = TestClass(field_one=1, field_two=2)
      comp(state, DATA_SOURCE.iterator().select(1))

  def test_invoke_with_mismatched_population_names(self):
    """A data source for a different population is rejected."""
    comp = federated_computation.FederatedComputation(count_clients, name='x')
    ds = federated_data_source.FederatedDataSource('other/name',
                                                   DATA_SOURCE.example_selector)
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY)
    with tff.framework.get_context_stack().install(ctx):
      with self.assertRaisesRegex(
          ValueError, 'FederatedDataSource and FederatedContext '
          'population_names must match'):
        comp(0, ds.iterator().select(1))

  @mock.patch.object(server.InProcessServer, 'run_computation', autospec=True)
  async def test_invoke_success(self, run_computation):
    """Invocation calls `run_computation` with a plan and the initial state."""
    run_computation.return_value = build_result_checkpoint(7)

    comp = federated_computation.FederatedComputation(count_clients, name='x')
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY)
    release_manager = tff.program.MemoryReleaseManager()
    with tff.framework.get_context_stack().install(ctx):
      state, _ = comp(3, DATA_SOURCE.iterator().select(10))
      await release_manager.release(
          state, tff.type_at_server(tf.int32), key='result')

    self.assertEqual(release_manager.values()['result'][0], 7)

    # With autospec, args[0] is the InProcessServer instance; args[2] is the
    # Plan and args[3] the checkpoint holding the initial server state.
    run_computation.assert_called_once_with(
        mock.ANY,
        comp.name,
        mock.ANY,
        mock.ANY,
        DATA_SOURCE.task_assignment_mode,
        10,
    )
    plan = run_computation.call_args.args[2]
    self.assertIsInstance(plan, plan_pb2.Plan)
    self.assertNotEmpty(plan.client_tflite_graph_bytes)
    input_var_names = variable_helpers.variable_names_from_type(
        count_clients.type_signature.parameter[0],
        name=artifact_constants.SERVER_STATE_VAR_PREFIX)
    self.assertLen(input_var_names, 1)
    self.assertEqual(
        test_utils.read_tensor_from_checkpoint(
            run_computation.call_args.args[3], input_var_names[0], tf.int32), 3)

  @mock.patch.object(server.InProcessServer, 'run_computation', autospec=True)
  async def test_invoke_with_value_reference(self, run_computation):
    """State returned by one invocation can be passed to the next."""
    run_computation.side_effect = [
        build_result_checkpoint(1234),
        build_result_checkpoint(5678)
    ]

    comp = federated_computation.FederatedComputation(count_clients, name='x')
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY)
    release_manager = tff.program.MemoryReleaseManager()
    with tff.framework.get_context_stack().install(ctx):
      state, _ = comp(3, DATA_SOURCE.iterator().select(10))
      state, _ = comp(state, DATA_SOURCE.iterator().select(10))
      await release_manager.release(
          state, tff.type_at_server(tf.int32), key='result')

    self.assertEqual(release_manager.values()['result'][0], 5678)

    input_var_names = variable_helpers.variable_names_from_type(
        count_clients.type_signature.parameter[0],
        name=artifact_constants.SERVER_STATE_VAR_PREFIX)
    self.assertLen(input_var_names, 1)
    # The second invocation should be passed the value returned by the first
    # invocation.
    self.assertEqual(run_computation.call_count, 2)
    self.assertEqual(
        test_utils.read_tensor_from_checkpoint(
            run_computation.call_args.args[3], input_var_names[0], tf.int32),
        1234)

  async def test_invoke_without_input_state(self):
    """`None` is not accepted as the initial state."""
    comp = federated_computation.FederatedComputation(count_clients, name='x')
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY)
    with tff.framework.get_context_stack().install(ctx):
      with self.assertRaisesRegex(
          TypeError, r'arg\[0\] must be a value or structure of values'
      ):
        comp(None, DATA_SOURCE.iterator().select(1))

  @mock.patch.object(server.InProcessServer, 'run_computation', autospec=True)
  async def test_invoke_with_run_computation_error(self, run_computation):
    """Errors from `run_computation` surface when the result is released."""
    run_computation.side_effect = ValueError('message')

    comp = federated_computation.FederatedComputation(count_clients, name='x')
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY)
    release_manager = tff.program.MemoryReleaseManager()
    with tff.framework.get_context_stack().install(ctx):
      state, _ = comp(0, DATA_SOURCE.iterator().select(10))
      with self.assertRaisesRegex(ValueError, 'message'):
        await release_manager.release(
            state, tff.type_at_server(tf.int32), key='result')


class FederatedContextPlanCachingTest(absltest.TestCase,
                                      unittest.IsolatedAsyncioTestCase):
  """Tests when `FederatedContext` rebuilds versus reuses a Plan."""

  async def asyncSetUp(self):
    """Patches plan building/execution and primes the cache with one run."""
    await super().asyncSetUp()

    @tff.federated_computation(
        tff.type_at_server(tf.int32),
        tff.type_at_clients(tff.SequenceType(tf.string)))
    def identity(state, client_data):
      del client_data
      return state, tff.federated_value((), tff.SERVER)

    self.count_clients_comp1 = federated_computation.FederatedComputation(
        count_clients, name='count_clients1')
    self.count_clients_comp2 = federated_computation.FederatedComputation(
        count_clients, name='count_clients2')
    self.identity_comp = federated_computation.FederatedComputation(
        identity, name='identity')

    self.data_source1 = federated_data_source.FederatedDataSource(
        POPULATION_NAME, plan_pb2.ExampleSelector(collection_uri='app:/1'))
    self.data_source2 = federated_data_source.FederatedDataSource(
        POPULATION_NAME, plan_pb2.ExampleSelector(collection_uri='app:/2'))

    self.run_computation = self.enter_context(
        mock.patch.object(
            server.InProcessServer, 'run_computation', autospec=True))
    self.run_computation.return_value = build_result_checkpoint(0)
    self.build_plan = self.enter_context(
        mock.patch.object(
            federated_compute_plan_builder, 'build_plan', autospec=True))
    self.build_plan.return_value = plan_pb2.Plan()
    self.generate_and_add_flat_buffer_to_plan = self.enter_context(
        mock.patch.object(
            plan_utils, 'generate_and_add_flat_buffer_to_plan', autospec=True))
    # Pass the plan through unchanged instead of generating a flatbuffer.
    self.generate_and_add_flat_buffer_to_plan.side_effect = lambda plan: plan
    self.enter_context(tff.framework.get_context_stack().install(
        federated_context.FederatedContext(
            POPULATION_NAME, address_family=ADDRESS_FAMILY)))
    self.release_manager = tff.program.MemoryReleaseManager()

    # Run (and therefore cache) count_clients_comp1 with data_source1.
    await self.release_manager.release(
        self.count_clients_comp1(0,
                                 self.data_source1.iterator().select(1)),
        self.count_clients_comp1.type_signature.result,
        key='result')
    self.build_plan.assert_called_once()
    self.assertEqual(self.build_plan.call_args.args[0],
                     self.count_clients_comp1.map_reduce_form)
    self.assertEqual(
        self.build_plan.call_args.args[1],
        self.count_clients_comp1.distribute_aggregate_form,
    )
    self.assertEqual(
        self.build_plan.call_args.args[2].example_selector_proto,
        self.data_source1.example_selector,
    )
    self.run_computation.assert_called_once()
    # Start each test with clean call records so it only sees its own calls.
    self.build_plan.reset_mock()
    self.run_computation.reset_mock()

  async def test_reuse_with_repeat_computation(self):
    """Repeating the same computation and data source reuses the plan."""
    await self.release_manager.release(
        self.count_clients_comp1(0,
                                 self.data_source1.iterator().select(1)),
        self.count_clients_comp1.type_signature.result,
        key='result')
    self.build_plan.assert_not_called()
    self.run_computation.assert_called_once()

  async def test_reuse_with_changed_num_clients(self):
    """Changing only the number of selected clients reuses the plan."""
    await self.release_manager.release(
        self.count_clients_comp1(0,
                                 self.data_source1.iterator().select(10)),
        self.count_clients_comp1.type_signature.result,
        key='result')
    self.build_plan.assert_not_called()
    self.run_computation.assert_called_once()

  async def test_reuse_with_changed_initial_state(self):
    """Changing only the initial state reuses the plan."""
    await self.release_manager.release(
        self.count_clients_comp1(3,
                                 self.data_source1.iterator().select(1)),
        self.count_clients_comp1.type_signature.result,
        key='result')
    self.build_plan.assert_not_called()
    self.run_computation.assert_called_once()

  async def test_reuse_with_equivalent_map_reduce_form(self):
    """A differently named wrapper of the same computation reuses the plan."""
    await self.release_manager.release(
        self.count_clients_comp2(0,
                                 self.data_source1.iterator().select(1)),
        self.count_clients_comp2.type_signature.result,
        key='result')
    self.build_plan.assert_not_called()
    self.run_computation.assert_called_once()

  async def test_rebuild_with_different_computation(self):
    """A different computation triggers a new plan build."""
    await self.release_manager.release(
        self.identity_comp(0,
                           self.data_source1.iterator().select(1)),
        self.identity_comp.type_signature.result,
        key='result')
    self.build_plan.assert_called_once()
    self.assertEqual(self.build_plan.call_args.args[0],
                     self.identity_comp.map_reduce_form)
    self.assertEqual(
        self.build_plan.call_args.args[1],
        self.identity_comp.distribute_aggregate_form,
    )
    self.assertEqual(
        self.build_plan.call_args.args[2].example_selector_proto,
        self.data_source1.example_selector,
    )
    self.run_computation.assert_called_once()

  async def test_rebuild_with_different_data_source(self):
    """A different data source triggers a new plan build."""
    await self.release_manager.release(
        self.count_clients_comp1(0,
                                 self.data_source2.iterator().select(1)),
        self.count_clients_comp1.type_signature.result,
        key='result')
    self.build_plan.assert_called_once()
    self.assertEqual(self.build_plan.call_args.args[0],
                     self.count_clients_comp1.map_reduce_form)
    self.assertEqual(
        self.build_plan.call_args.args[1],
        self.count_clients_comp1.distribute_aggregate_form,
    )
    self.assertEqual(
        self.build_plan.call_args.args[2].example_selector_proto,
        self.data_source2.example_selector,
    )
    self.run_computation.assert_called_once()


if __name__ == '__main__':
  absltest.main()