# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for federated_computation."""

from unittest import mock

from absl.testing import absltest
import tensorflow as tf
import tensorflow_federated as tff

from fcp.demo import federated_computation as fc


@tff.tf_computation(tf.int32, tf.int32)
def add_values(x, y):
  """TF computation that returns the sum of two int32 values."""
  return x + y


@tff.federated_computation(
    tff.type_at_server(tf.int32),
    tff.type_at_clients(tff.SequenceType(tf.string)))
def count_clients(state, client_data):
  """Example TFF computation that counts clients.

  Args:
    state: Server-placed int32 to which the client count is added.
    client_data: Client-placed sequences of strings; discarded unused.

  Returns:
    A `(state + number of clients, metrics)` pair, where `metrics` is an empty
    struct placed at the server.
  """
  del client_data
  # Every client contributes the constant 1, so the federated sum is the number
  # of participating clients.
  client_value = tff.federated_value(1, tff.CLIENTS)
  aggregated_count = tff.federated_sum(client_value)
  metrics = tff.federated_value(tff.structure.Struct(()), tff.SERVER)
  return tff.federated_map(add_values, (state, aggregated_count)), metrics


@tff.federated_computation(
    tff.type_at_server(tf.int32),
    tff.type_at_clients(tff.SequenceType(tf.string)))
def count_examples(state, client_data):
  """Example TFF computation that counts client examples.

  Args:
    state: Server-placed int32 to which the example count is added.
    client_data: Client-placed sequences of strings whose elements are counted.

  Returns:
    A `(state + total number of examples, metrics)` pair, where `metrics` is an
    empty struct placed at the server.
  """

  @tff.tf_computation
  def client_work(client_data):
    # Fold over the client's sequence, adding one per element.
    return client_data.reduce(0, lambda x, _: x + 1)

  client_counts = tff.federated_map(client_work, client_data)
  aggregated_count = tff.federated_sum(client_counts)
  metrics = tff.federated_value(tff.structure.Struct(()), tff.SERVER)
  return tff.federated_map(add_values, (state, aggregated_count)), metrics


class FederatedComputationTest(absltest.TestCase):
  """Tests for `fc.FederatedComputation`."""

  def test_invalid_name(self):
    """A name of '^invalid^' is rejected with a ValueError."""
    with self.assertRaisesRegex(ValueError, r'name must match ".+"'):
      fc.FederatedComputation(count_clients, name='^invalid^')

  def test_incompatible_computation(self):
    """Wrapping a computation with the wrong return structure raises."""
    # This function doesn't have the return value structure required for MRF.
    @tff.federated_computation(tff.type_at_server(tf.int32))
    def add_one(value):
      return value + tff.federated_value(1, tff.SERVER)

    with self.assertRaises(TypeError):
      fc.FederatedComputation(add_one, name='comp')

  @tff.test.with_context(
      tff.backends.test.create_sync_test_cpp_execution_context
  )
  def test_map_reduce_form(self):
    """Different computations yield different, runnable MapReduceForms."""
    comp1 = fc.FederatedComputation(count_clients, name='comp1')
    comp2 = fc.FederatedComputation(count_examples, name='comp2')
    self.assertNotEqual(comp1.map_reduce_form, comp2.map_reduce_form)

    # While we treat the MRF contents as an implementation detail, we can verify
    # the invocation results of the corresponding computation.
    # comp1 should return the number of clients.
    self.assertEqual(
        tff.backends.mapreduce.get_computation_for_map_reduce_form(
            comp1.map_reduce_form
        )(0, [['', '']] * 3),
        (3, ()),
    )
    # comp2 should return the number of examples across all clients.
    self.assertEqual(
        tff.backends.mapreduce.get_computation_for_map_reduce_form(
            comp2.map_reduce_form)(0, [['', '']] * 3), (6, ()))

  # NOTE(review): this test installs the native local C++ context, whereas
  # test_map_reduce_form installs the test C++ context -- confirm that the
  # difference is intentional.
  @tff.test.with_context(
      tff.backends.native.create_sync_local_cpp_execution_context
  )
  def test_distribute_aggregate_form(self):
    """Different computations yield different, runnable DAFs."""
    comp1 = fc.FederatedComputation(count_clients, name='comp1')
    comp2 = fc.FederatedComputation(count_examples, name='comp2')
    self.assertNotEqual(
        comp1.distribute_aggregate_form, comp2.distribute_aggregate_form
    )

    # While we treat the DAF contents as an implementation detail, we can verify
    # the invocation results of the corresponding computation.
    # comp1 should return the number of clients.
    self.assertEqual(
        tff.backends.mapreduce.get_computation_for_distribute_aggregate_form(
            comp1.distribute_aggregate_form
        )(0, [['', '']] * 3),
        (3, ()),
    )
    # comp2 should return the number of examples across all clients.
    self.assertEqual(
        tff.backends.mapreduce.get_computation_for_distribute_aggregate_form(
            comp2.distribute_aggregate_form
        )(0, [['', '']] * 3),
        (6, ()),
    )

  def test_wrapped_computation(self):
    """`wrapped_computation` compares equal to the computation passed in."""
    comp = fc.FederatedComputation(count_clients, name='comp')
    self.assertEqual(comp.wrapped_computation, count_clients)

  def test_name(self):
    """`name` returns the name given at construction."""
    comp = fc.FederatedComputation(count_clients, name='comp')
    self.assertEqual(comp.name, 'comp')

  def test_type_signature(self):
    """`type_signature` matches the wrapped computation's signature."""
    comp = fc.FederatedComputation(count_clients, name='comp')
    self.assertEqual(comp.type_signature, count_clients.type_signature)

  def test_call(self):
    """Calling the computation forwards to the installed context's invoke."""
    comp = fc.FederatedComputation(count_clients, name='comp')
    ctx = mock.create_autospec(tff.program.FederatedContext, instance=True)
    ctx.invoke.return_value = 1234
    with tff.framework.get_context_stack().install(ctx):
      self.assertEqual(comp(1, 2, 3, kw1='a', kw2='b'), 1234)
    # Positional arguments are expected as unnamed struct fields and keyword
    # arguments as named ones.
    ctx.invoke.assert_called_once_with(
        comp,
        tff.structure.Struct([(None, 1), (None, 2), (None, 3), ('kw1', 'a'),
                              ('kw2', 'b')]))

  def test_hash(self):
    """Hashes agree for equivalent objects and differ otherwise."""
    comp = fc.FederatedComputation(count_clients, name='comp')
    # Equivalent objects should have equal hashes.
    self.assertEqual(
        hash(comp), hash(fc.FederatedComputation(count_clients, name='comp')))
    # Different computations or names should produce different hashes.
    self.assertNotEqual(
        hash(comp), hash(fc.FederatedComputation(count_clients, name='other')))
    self.assertNotEqual(
        hash(comp), hash(fc.FederatedComputation(count_examples, name='comp')))


if __name__ == '__main__':
  absltest.main()