xref: /aosp_15_r20/external/federated-compute/fcp/demo/federated_context_test.py (revision 14675a029014e728ec732f129a32e299b2da0601)
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14"""Tests for federated_context."""
15
16import http
17import http.client
18import socket
19import threading
20import unittest
21from unittest import mock
22
23from absl.testing import absltest
24import attr
25import tensorflow as tf
26import tensorflow_federated as tff
27
28from fcp.artifact_building import artifact_constants
29from fcp.artifact_building import federated_compute_plan_builder
30from fcp.artifact_building import plan_utils
31from fcp.artifact_building import variable_helpers
32from fcp.demo import federated_computation
33from fcp.demo import federated_context
34from fcp.demo import federated_data_source
35from fcp.demo import server
36from fcp.demo import test_utils
37from fcp.protos import plan_pb2
38
# Address family (IPv4) passed to every FederatedContext created in this file.
ADDRESS_FAMILY = socket.AF_INET
# Population shared by the default context and data source below.
POPULATION_NAME = 'test/population'
# Default data source: same population as the contexts, fixed collection URI.
DATA_SOURCE = federated_data_source.FederatedDataSource(
    POPULATION_NAME,
    plan_pb2.ExampleSelector(collection_uri='app:/test'),
)
43
44
45@tff.tf_computation(tf.int32)
46def add_one(x):
47  return x + 1
48
49
50@tff.federated_computation(
51    tff.type_at_server(tf.int32),
52    tff.type_at_clients(tff.SequenceType(tf.string)))
53def count_clients(state, client_data):
54  """Example TFF computation that counts clients."""
55  del client_data
56  num_clients = tff.federated_sum(tff.federated_value(1, tff.CLIENTS))
57  non_state = tff.federated_value((), tff.SERVER)
58  return state + num_clients, non_state
59
60
61@tff.federated_computation(
62    tff.type_at_server(tff.StructType([('foo', tf.int32), ('bar', tf.int32)])),
63    tff.type_at_clients(tff.SequenceType(tf.string)),
64)
65def irregular_arrays(state, client_data):
66  """Example TFF computation that returns irregular data."""
67  del client_data
68  num_clients = tff.federated_sum(tff.federated_value(1, tff.CLIENTS))
69  non_state = tff.federated_value(1, tff.SERVER)
70  return state, non_state + num_clients
71
72
73@attr.s(eq=False, frozen=True, slots=True)
74class TestClass:
75  """An attrs class."""
76
77  field_one = attr.ib()
78  field_two = attr.ib()
79
80
81@tff.tf_computation
82def init():
83  return TestClass(field_one=1, field_two=2)
84
85
86attrs_type = init.type_signature.result
87
88
89@tff.federated_computation(
90    tff.type_at_server(attrs_type),
91    tff.type_at_clients(tff.SequenceType(tf.string)),
92)
93def attrs_computation(state, client_data):
94  """Example TFF computation that returns an attrs class."""
95  del client_data
96  num_clients = tff.federated_sum(tff.federated_value(1, tff.CLIENTS))
97  non_state = tff.federated_value(1, tff.SERVER)
98  return state, non_state + num_clients
99
100
101def build_result_checkpoint(state: int) -> bytes:
102  """Helper function to build a result checkpoint for `count_clients`."""
103  var_names = variable_helpers.variable_names_from_type(
104      count_clients.type_signature.result[0],
105      name=artifact_constants.SERVER_STATE_VAR_PREFIX)
106  return test_utils.create_checkpoint({var_names[0]: state})
107
108
109class FederatedContextTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase):
110
111  def test_invalid_population_name(self):
112    with self.assertRaisesRegex(ValueError, 'population_name must match ".+"'):
113      federated_context.FederatedContext(
114          '^^invalid^^', address_family=ADDRESS_FAMILY)
115
116  @mock.patch.object(server.InProcessServer, 'shutdown', autospec=True)
117  @mock.patch.object(server.InProcessServer, 'serve_forever', autospec=True)
118  def test_context_management(self, serve_forever, shutdown):
119    started = threading.Event()
120    serve_forever.side_effect = lambda *args, **kwargs: started.set()
121
122    ctx = federated_context.FederatedContext(
123        POPULATION_NAME, address_family=ADDRESS_FAMILY)
124    self.assertFalse(started.is_set())
125    shutdown.assert_not_called()
126    with ctx:
127      self.assertTrue(started.wait(0.5))
128      shutdown.assert_not_called()
129    shutdown.assert_called_once()
130
131  def test_http(self):
132    with federated_context.FederatedContext(
133        POPULATION_NAME, address_family=ADDRESS_FAMILY) as ctx:
134      conn = http.client.HTTPConnection('localhost', port=ctx.server_port)
135      conn.request('GET', '/does-not-exist')
136      self.assertEqual(conn.getresponse().status, http.HTTPStatus.NOT_FOUND)
137
138  def test_invoke_non_federated_with_base_context(self):
139    base_context = tff.backends.native.create_sync_local_cpp_execution_context()
140    ctx = federated_context.FederatedContext(
141        POPULATION_NAME,
142        address_family=ADDRESS_FAMILY,
143        base_context=base_context)
144    with tff.framework.get_context_stack().install(ctx):
145      self.assertEqual(add_one(3), 4)
146
147  def test_invoke_non_federated_without_base_context(self):
148    ctx = federated_context.FederatedContext(
149        POPULATION_NAME, address_family=ADDRESS_FAMILY)
150    with tff.framework.get_context_stack().install(ctx):
151      with self.assertRaisesRegex(TypeError,
152                                  'computation must be a FederatedComputation'):
153        add_one(3)
154
155  def test_invoke_with_invalid_state_type(self):
156    comp = federated_computation.FederatedComputation(count_clients, name='x')
157    ctx = federated_context.FederatedContext(
158        POPULATION_NAME, address_family=ADDRESS_FAMILY)
159    with tff.framework.get_context_stack().install(ctx):
160      with self.assertRaisesRegex(
161          TypeError, r'arg\[0\] must be a value or structure of values'
162      ):
163        comp(plan_pb2.Plan(), DATA_SOURCE.iterator().select(1))
164
165  def test_invoke_with_invalid_data_source_type(self):
166    comp = federated_computation.FederatedComputation(count_clients, name='x')
167    ctx = federated_context.FederatedContext(
168        POPULATION_NAME, address_family=ADDRESS_FAMILY)
169    with tff.framework.get_context_stack().install(ctx):
170      with self.assertRaisesRegex(
171          TypeError, r'arg\[1\] must be the result of '
172          r'FederatedDataSource.iterator\(\).select\(\)'):
173        comp(0, plan_pb2.Plan())
174
175  def test_invoke_succeeds_with_structure_state_type(self):
176    comp = federated_computation.FederatedComputation(
177        irregular_arrays, name='x'
178    )
179    ctx = federated_context.FederatedContext(
180        POPULATION_NAME, address_family=ADDRESS_FAMILY
181    )
182    with tff.framework.get_context_stack().install(ctx):
183      state = {'foo': (3, 1), 'bar': (4, 5, 6)}
184      comp(state, DATA_SOURCE.iterator().select(1))
185
186  def test_invoke_succeeds_with_attrs_state_type(self):
187    comp = federated_computation.FederatedComputation(
188        attrs_computation, name='x'
189    )
190    ctx = federated_context.FederatedContext(
191        POPULATION_NAME, address_family=ADDRESS_FAMILY
192    )
193    with tff.framework.get_context_stack().install(ctx):
194      state = TestClass(field_one=1, field_two=2)
195      comp(state, DATA_SOURCE.iterator().select(1))
196
197  def test_invoke_with_mismatched_population_names(self):
198    comp = federated_computation.FederatedComputation(count_clients, name='x')
199    ds = federated_data_source.FederatedDataSource('other/name',
200                                                   DATA_SOURCE.example_selector)
201    ctx = federated_context.FederatedContext(
202        POPULATION_NAME, address_family=ADDRESS_FAMILY)
203    with tff.framework.get_context_stack().install(ctx):
204      with self.assertRaisesRegex(
205          ValueError, 'FederatedDataSource and FederatedContext '
206          'population_names must match'):
207        comp(0, ds.iterator().select(1))
208
209  @mock.patch.object(server.InProcessServer, 'run_computation', autospec=True)
210  async def test_invoke_success(self, run_computation):
211    run_computation.return_value = build_result_checkpoint(7)
212
213    comp = federated_computation.FederatedComputation(count_clients, name='x')
214    ctx = federated_context.FederatedContext(
215        POPULATION_NAME, address_family=ADDRESS_FAMILY)
216    release_manager = tff.program.MemoryReleaseManager()
217    with tff.framework.get_context_stack().install(ctx):
218      state, _ = comp(3, DATA_SOURCE.iterator().select(10))
219      await release_manager.release(
220          state, tff.type_at_server(tf.int32), key='result')
221
222    self.assertEqual(release_manager.values()['result'][0], 7)
223
224    run_computation.assert_called_once_with(
225        mock.ANY,
226        comp.name,
227        mock.ANY,
228        mock.ANY,
229        DATA_SOURCE.task_assignment_mode,
230        10,
231    )
232    plan = run_computation.call_args.args[2]
233    self.assertIsInstance(plan, plan_pb2.Plan)
234    self.assertNotEmpty(plan.client_tflite_graph_bytes)
235    input_var_names = variable_helpers.variable_names_from_type(
236        count_clients.type_signature.parameter[0],
237        name=artifact_constants.SERVER_STATE_VAR_PREFIX)
238    self.assertLen(input_var_names, 1)
239    self.assertEqual(
240        test_utils.read_tensor_from_checkpoint(
241            run_computation.call_args.args[3], input_var_names[0], tf.int32), 3)
242
243  @mock.patch.object(server.InProcessServer, 'run_computation', autospec=True)
244  async def test_invoke_with_value_reference(self, run_computation):
245    run_computation.side_effect = [
246        build_result_checkpoint(1234),
247        build_result_checkpoint(5678)
248    ]
249
250    comp = federated_computation.FederatedComputation(count_clients, name='x')
251    ctx = federated_context.FederatedContext(
252        POPULATION_NAME, address_family=ADDRESS_FAMILY)
253    release_manager = tff.program.MemoryReleaseManager()
254    with tff.framework.get_context_stack().install(ctx):
255      state, _ = comp(3, DATA_SOURCE.iterator().select(10))
256      state, _ = comp(state, DATA_SOURCE.iterator().select(10))
257      await release_manager.release(
258          state, tff.type_at_server(tf.int32), key='result')
259
260    self.assertEqual(release_manager.values()['result'][0], 5678)
261
262    input_var_names = variable_helpers.variable_names_from_type(
263        count_clients.type_signature.parameter[0],
264        name=artifact_constants.SERVER_STATE_VAR_PREFIX)
265    self.assertLen(input_var_names, 1)
266    # The second invocation should be passed the value returned by the first
267    # invocation.
268    self.assertEqual(run_computation.call_count, 2)
269    self.assertEqual(
270        test_utils.read_tensor_from_checkpoint(
271            run_computation.call_args.args[3], input_var_names[0], tf.int32),
272        1234)
273
274  async def test_invoke_without_input_state(self):
275    comp = federated_computation.FederatedComputation(count_clients, name='x')
276    ctx = federated_context.FederatedContext(
277        POPULATION_NAME, address_family=ADDRESS_FAMILY)
278    with tff.framework.get_context_stack().install(ctx):
279      with self.assertRaisesRegex(
280          TypeError, r'arg\[0\] must be a value or structure of values'
281      ):
282        comp(None, DATA_SOURCE.iterator().select(1))
283
284  @mock.patch.object(server.InProcessServer, 'run_computation', autospec=True)
285  async def test_invoke_with_run_computation_error(self, run_computation):
286    run_computation.side_effect = ValueError('message')
287
288    comp = federated_computation.FederatedComputation(count_clients, name='x')
289    ctx = federated_context.FederatedContext(
290        POPULATION_NAME, address_family=ADDRESS_FAMILY)
291    release_manager = tff.program.MemoryReleaseManager()
292    with tff.framework.get_context_stack().install(ctx):
293      state, _ = comp(0, DATA_SOURCE.iterator().select(10))
294      with self.assertRaisesRegex(ValueError, 'message'):
295        await release_manager.release(
296            state, tff.type_at_server(tf.int32), key='result')
297
298
299class FederatedContextPlanCachingTest(absltest.TestCase,
300                                      unittest.IsolatedAsyncioTestCase):
301
302  async def asyncSetUp(self):
303    await super().asyncSetUp()
304
305    @tff.federated_computation(
306        tff.type_at_server(tf.int32),
307        tff.type_at_clients(tff.SequenceType(tf.string)))
308    def identity(state, client_data):
309      del client_data
310      return state, tff.federated_value((), tff.SERVER)
311
312    self.count_clients_comp1 = federated_computation.FederatedComputation(
313        count_clients, name='count_clients1')
314    self.count_clients_comp2 = federated_computation.FederatedComputation(
315        count_clients, name='count_clients2')
316    self.identity_comp = federated_computation.FederatedComputation(
317        identity, name='identity')
318
319    self.data_source1 = federated_data_source.FederatedDataSource(
320        POPULATION_NAME, plan_pb2.ExampleSelector(collection_uri='app:/1'))
321    self.data_source2 = federated_data_source.FederatedDataSource(
322        POPULATION_NAME, plan_pb2.ExampleSelector(collection_uri='app:/2'))
323
324    self.run_computation = self.enter_context(
325        mock.patch.object(
326            server.InProcessServer, 'run_computation', autospec=True))
327    self.run_computation.return_value = build_result_checkpoint(0)
328    self.build_plan = self.enter_context(
329        mock.patch.object(
330            federated_compute_plan_builder, 'build_plan', autospec=True))
331    self.build_plan.return_value = plan_pb2.Plan()
332    self.generate_and_add_flat_buffer_to_plan = self.enter_context(
333        mock.patch.object(
334            plan_utils, 'generate_and_add_flat_buffer_to_plan', autospec=True))
335    self.generate_and_add_flat_buffer_to_plan.side_effect = lambda plan: plan
336    self.enter_context(tff.framework.get_context_stack().install(
337        federated_context.FederatedContext(
338            POPULATION_NAME, address_family=ADDRESS_FAMILY)))
339    self.release_manager = tff.program.MemoryReleaseManager()
340
341    # Run (and therefore cache) count_clients_comp1 with data_source1.
342    await self.release_manager.release(
343        self.count_clients_comp1(0,
344                                 self.data_source1.iterator().select(1)),
345        self.count_clients_comp1.type_signature.result,
346        key='result')
347    self.build_plan.assert_called_once()
348    self.assertEqual(self.build_plan.call_args.args[0],
349                     self.count_clients_comp1.map_reduce_form)
350    self.assertEqual(
351        self.build_plan.call_args.args[1],
352        self.count_clients_comp1.distribute_aggregate_form,
353    )
354    self.assertEqual(
355        self.build_plan.call_args.args[2].example_selector_proto,
356        self.data_source1.example_selector,
357    )
358    self.run_computation.assert_called_once()
359    self.build_plan.reset_mock()
360    self.run_computation.reset_mock()
361
362  async def test_reuse_with_repeat_computation(self):
363    await self.release_manager.release(
364        self.count_clients_comp1(0,
365                                 self.data_source1.iterator().select(1)),
366        self.count_clients_comp1.type_signature.result,
367        key='result')
368    self.build_plan.assert_not_called()
369    self.run_computation.assert_called_once()
370
371  async def test_reuse_with_changed_num_clients(self):
372    await self.release_manager.release(
373        self.count_clients_comp1(0,
374                                 self.data_source1.iterator().select(10)),
375        self.count_clients_comp1.type_signature.result,
376        key='result')
377    self.build_plan.assert_not_called()
378    self.run_computation.assert_called_once()
379
380  async def test_reuse_with_changed_initial_state(self):
381    await self.release_manager.release(
382        self.count_clients_comp1(3,
383                                 self.data_source1.iterator().select(1)),
384        self.count_clients_comp1.type_signature.result,
385        key='result')
386    self.build_plan.assert_not_called()
387    self.run_computation.assert_called_once()
388
389  async def test_reuse_with_equivalent_map_reduce_form(self):
390    await self.release_manager.release(
391        self.count_clients_comp2(0,
392                                 self.data_source1.iterator().select(1)),
393        self.count_clients_comp2.type_signature.result,
394        key='result')
395    self.build_plan.assert_not_called()
396    self.run_computation.assert_called_once()
397
398  async def test_rebuild_with_different_computation(self):
399    await self.release_manager.release(
400        self.identity_comp(0,
401                           self.data_source1.iterator().select(1)),
402        self.identity_comp.type_signature.result,
403        key='result')
404    self.build_plan.assert_called_once()
405    self.assertEqual(self.build_plan.call_args.args[0],
406                     self.identity_comp.map_reduce_form)
407    self.assertEqual(
408        self.build_plan.call_args.args[1],
409        self.identity_comp.distribute_aggregate_form,
410    )
411    self.assertEqual(
412        self.build_plan.call_args.args[2].example_selector_proto,
413        self.data_source1.example_selector,
414    )
415    self.run_computation.assert_called_once()
416
417  async def test_rebuild_with_different_data_source(self):
418    await self.release_manager.release(
419        self.count_clients_comp1(0,
420                                 self.data_source2.iterator().select(1)),
421        self.count_clients_comp1.type_signature.result,
422        key='result')
423    self.build_plan.assert_called_once()
424    self.assertEqual(self.build_plan.call_args.args[0],
425                     self.count_clients_comp1.map_reduce_form)
426    self.assertEqual(
427        self.build_plan.call_args.args[1],
428        self.count_clients_comp1.distribute_aggregate_form,
429    )
430    self.assertEqual(
431        self.build_plan.call_args.args[2].example_selector_proto,
432        self.data_source2.example_selector,
433    )
434    self.run_computation.assert_called_once()
435
436
437if __name__ == '__main__':
438  absltest.main()
439