xref: /aosp_15_r20/external/federated-compute/fcp/demo/federated_computation_test.py (revision 14675a029014e728ec732f129a32e299b2da0601)
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14*14675a02SAndroid Build Coastguard Worker"""Tests for federated_computation."""
15*14675a02SAndroid Build Coastguard Worker
16*14675a02SAndroid Build Coastguard Workerfrom unittest import mock
17*14675a02SAndroid Build Coastguard Worker
18*14675a02SAndroid Build Coastguard Workerfrom absl.testing import absltest
19*14675a02SAndroid Build Coastguard Workerimport tensorflow as tf
20*14675a02SAndroid Build Coastguard Workerimport tensorflow_federated as tff
21*14675a02SAndroid Build Coastguard Worker
22*14675a02SAndroid Build Coastguard Workerfrom fcp.demo import federated_computation as fc
23*14675a02SAndroid Build Coastguard Worker
24*14675a02SAndroid Build Coastguard Worker
@tff.tf_computation(tf.int32, tf.int32)
def add_values(x, y):
  """Returns the sum of two int32 values.

  The counting computations in this file apply it via `tff.federated_map` to
  fold an aggregated count into the server state.
  """
  total = x + y
  return total
28*14675a02SAndroid Build Coastguard Worker
29*14675a02SAndroid Build Coastguard Worker
@tff.federated_computation(
    tff.type_at_server(tf.int32),
    tff.type_at_clients(tff.SequenceType(tf.string)),
)
def count_clients(state, client_data):
  """Example TFF computation that counts clients.

  Each client contributes the constant 1, so the federated sum equals the
  number of participating clients. That total is added to the server `state`.

  Args:
    state: Server-placed int32 running count.
    client_data: Client-placed sequences of strings; ignored.

  Returns:
    A `(new_state, metrics)` tuple where `metrics` is an empty server-placed
    structure.
  """
  del client_data  # Only the number of clients matters, not their data.
  num_clients = tff.federated_sum(tff.federated_value(1, tff.CLIENTS))
  empty_metrics = tff.federated_value(tff.structure.Struct(()), tff.SERVER)
  new_state = tff.federated_map(add_values, (state, num_clients))
  return new_state, empty_metrics
40*14675a02SAndroid Build Coastguard Worker
41*14675a02SAndroid Build Coastguard Worker
@tff.federated_computation(
    tff.type_at_server(tf.int32),
    tff.type_at_clients(tff.SequenceType(tf.string)),
)
def count_examples(state, client_data):
  """Example TFF computation that counts client examples.

  Each client counts the elements of its own string sequence. The per-client
  counts are summed across clients and the total is added to the server
  `state`.

  Args:
    state: Server-placed int32 running count.
    client_data: Client-placed sequences of strings whose elements are counted.

  Returns:
    A `(new_state, metrics)` tuple where `metrics` is an empty server-placed
    structure.
  """

  @tff.tf_computation
  def count_local_examples(examples):
    # Fold over the dataset, bumping the accumulator once per element.
    return examples.reduce(0, lambda running_count, _: running_count + 1)

  per_client_counts = tff.federated_map(count_local_examples, client_data)
  total_examples = tff.federated_sum(per_client_counts)
  empty_metrics = tff.federated_value(tff.structure.Struct(()), tff.SERVER)
  new_state = tff.federated_map(add_values, (state, total_examples))
  return new_state, empty_metrics
56*14675a02SAndroid Build Coastguard Worker
57*14675a02SAndroid Build Coastguard Worker
class FederatedComputationTest(absltest.TestCase):
  """Tests for `fc.FederatedComputation`."""

  def _wrap(self, computation=count_clients, name='comp'):
    """Returns `computation` wrapped in a `fc.FederatedComputation`."""
    return fc.FederatedComputation(computation, name=name)

  def _assert_counts(self, to_computation, clients_form, examples_form):
    """Checks the results of the computations rebuilt from two forms.

    Each rebuilt computation is invoked with initial state 0 and three clients
    that each hold two empty-string examples.

    Args:
      to_computation: Converts a form back into an invocable computation.
      clients_form: Form derived from `count_clients`; expected to report the
        3 clients.
      examples_form: Form derived from `count_examples`; expected to report
        the 6 examples (3 clients x 2 examples).
    """
    self.assertEqual(
        to_computation(clients_form)(0, [['', '']] * 3), (3, ())
    )
    self.assertEqual(
        to_computation(examples_form)(0, [['', '']] * 3), (6, ())
    )

  def test_invalid_name(self):
    with self.assertRaisesRegex(ValueError, r'name must match ".+"'):
      self._wrap(name='^invalid^')

  def test_incompatible_computation(self):
    # `add_one` does not return the value structure MapReduceForm requires,
    # so wrapping it should be rejected with a TypeError.
    @tff.federated_computation(tff.type_at_server(tf.int32))
    def add_one(value):
      return value + tff.federated_value(1, tff.SERVER)

    with self.assertRaises(TypeError):
      self._wrap(add_one)

  @tff.test.with_context(
      tff.backends.test.create_sync_test_cpp_execution_context
  )
  def test_map_reduce_form(self):
    clients_comp = self._wrap(count_clients, name='comp1')
    examples_comp = self._wrap(count_examples, name='comp2')
    self.assertNotEqual(
        clients_comp.map_reduce_form, examples_comp.map_reduce_form
    )

    # The MRF contents are an implementation detail, so verify behavior by
    # invoking the computation each form converts back into.
    self._assert_counts(
        tff.backends.mapreduce.get_computation_for_map_reduce_form,
        clients_comp.map_reduce_form,
        examples_comp.map_reduce_form,
    )

  @tff.test.with_context(
      tff.backends.native.create_sync_local_cpp_execution_context
  )
  def test_distribute_aggregate_form(self):
    clients_comp = self._wrap(count_clients, name='comp1')
    examples_comp = self._wrap(count_examples, name='comp2')
    self.assertNotEqual(
        clients_comp.distribute_aggregate_form,
        examples_comp.distribute_aggregate_form,
    )

    # The DAF contents are an implementation detail, so verify behavior by
    # invoking the computation each form converts back into.
    self._assert_counts(
        tff.backends.mapreduce.get_computation_for_distribute_aggregate_form,
        clients_comp.distribute_aggregate_form,
        examples_comp.distribute_aggregate_form,
    )

  def test_wrapped_computation(self):
    self.assertEqual(self._wrap().wrapped_computation, count_clients)

  def test_name(self):
    self.assertEqual(self._wrap().name, 'comp')

  def test_type_signature(self):
    self.assertEqual(
        self._wrap().type_signature, count_clients.type_signature
    )

  def test_call(self):
    comp = self._wrap()
    fake_context = mock.create_autospec(
        tff.program.FederatedContext, instance=True
    )
    fake_context.invoke.return_value = 1234
    with tff.framework.get_context_stack().install(fake_context):
      self.assertEqual(comp(1, 2, 3, kw1='a', kw2='b'), 1234)

    # Positional and keyword arguments should reach the context packed, in
    # order, into a single Struct.
    expected_arg = tff.structure.Struct([
        (None, 1),
        (None, 2),
        (None, 3),
        ('kw1', 'a'),
        ('kw2', 'b'),
    ])
    fake_context.invoke.assert_called_once_with(comp, expected_arg)

  def test_hash(self):
    comp_hash = hash(self._wrap())
    # Equivalent objects should have equal hashes.
    self.assertEqual(comp_hash, hash(self._wrap()))
    # Changing either the name or the computation should change the hash.
    self.assertNotEqual(comp_hash, hash(self._wrap(name='other')))
    self.assertNotEqual(comp_hash, hash(self._wrap(count_examples)))
155*14675a02SAndroid Build Coastguard Worker
156*14675a02SAndroid Build Coastguard Worker
# Run the tests in this module via absltest's entry point when executed
# directly as a script.
if __name__ == '__main__':
  absltest.main()
159