xref: /aosp_15_r20/external/federated-compute/fcp/demo/federated_context_test.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker# Copyright 2022 Google LLC
2*14675a02SAndroid Build Coastguard Worker#
3*14675a02SAndroid Build Coastguard Worker# Licensed under the Apache License, Version 2.0 (the "License");
4*14675a02SAndroid Build Coastguard Worker# you may not use this file except in compliance with the License.
5*14675a02SAndroid Build Coastguard Worker# You may obtain a copy of the License at
6*14675a02SAndroid Build Coastguard Worker#
7*14675a02SAndroid Build Coastguard Worker#      http://www.apache.org/licenses/LICENSE-2.0
8*14675a02SAndroid Build Coastguard Worker#
9*14675a02SAndroid Build Coastguard Worker# Unless required by applicable law or agreed to in writing, software
10*14675a02SAndroid Build Coastguard Worker# distributed under the License is distributed on an "AS IS" BASIS,
11*14675a02SAndroid Build Coastguard Worker# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*14675a02SAndroid Build Coastguard Worker# See the License for the specific language governing permissions and
13*14675a02SAndroid Build Coastguard Worker# limitations under the License.
14*14675a02SAndroid Build Coastguard Worker"""Tests for federated_context."""
15*14675a02SAndroid Build Coastguard Worker
16*14675a02SAndroid Build Coastguard Workerimport http
17*14675a02SAndroid Build Coastguard Workerimport http.client
18*14675a02SAndroid Build Coastguard Workerimport socket
19*14675a02SAndroid Build Coastguard Workerimport threading
20*14675a02SAndroid Build Coastguard Workerimport unittest
21*14675a02SAndroid Build Coastguard Workerfrom unittest import mock
22*14675a02SAndroid Build Coastguard Worker
23*14675a02SAndroid Build Coastguard Workerfrom absl.testing import absltest
24*14675a02SAndroid Build Coastguard Workerimport attr
25*14675a02SAndroid Build Coastguard Workerimport tensorflow as tf
26*14675a02SAndroid Build Coastguard Workerimport tensorflow_federated as tff
27*14675a02SAndroid Build Coastguard Worker
28*14675a02SAndroid Build Coastguard Workerfrom fcp.artifact_building import artifact_constants
29*14675a02SAndroid Build Coastguard Workerfrom fcp.artifact_building import federated_compute_plan_builder
30*14675a02SAndroid Build Coastguard Workerfrom fcp.artifact_building import plan_utils
31*14675a02SAndroid Build Coastguard Workerfrom fcp.artifact_building import variable_helpers
32*14675a02SAndroid Build Coastguard Workerfrom fcp.demo import federated_computation
33*14675a02SAndroid Build Coastguard Workerfrom fcp.demo import federated_context
34*14675a02SAndroid Build Coastguard Workerfrom fcp.demo import federated_data_source
35*14675a02SAndroid Build Coastguard Workerfrom fcp.demo import server
36*14675a02SAndroid Build Coastguard Workerfrom fcp.demo import test_utils
37*14675a02SAndroid Build Coastguard Workerfrom fcp.protos import plan_pb2
38*14675a02SAndroid Build Coastguard Worker
# IPv4 is used for every FederatedContext server created in these tests.
ADDRESS_FAMILY = socket.AF_INET
# Population name shared by the default data source and the contexts below.
POPULATION_NAME = 'test/population'
# Default data source. Its population matches POPULATION_NAME, so contexts
# created for that population accept selections from it.
DATA_SOURCE = federated_data_source.FederatedDataSource(
    POPULATION_NAME,
    plan_pb2.ExampleSelector(collection_uri='app:/test'),
)
43*14675a02SAndroid Build Coastguard Worker
44*14675a02SAndroid Build Coastguard Worker
@tff.tf_computation(tf.int32)
def add_one(x):
  """Non-federated TF computation that returns `x + 1` for an int32 `x`."""
  incremented = x + 1
  return incremented
48*14675a02SAndroid Build Coastguard Worker
49*14675a02SAndroid Build Coastguard Worker
@tff.federated_computation(
    tff.type_at_server(tf.int32),
    tff.type_at_clients(tff.SequenceType(tf.string)),
)
def count_clients(state, client_data):
  """Example TFF computation that counts clients.

  Args:
    state: Server-placed int32 running total.
    client_data: Client-placed string sequences; ignored.

  Returns:
    A `(new_state, non_state)` pair. `new_state` is `state` plus the number of
    participating clients. `non_state` is an empty server-placed structure.
  """
  del client_data  # Only the number of participants matters.
  # Each client contributes a 1; the federated sum is the participant count.
  one_per_client = tff.federated_value(1, tff.CLIENTS)
  participant_count = tff.federated_sum(one_per_client)
  empty_server_output = tff.federated_value((), tff.SERVER)
  return state + participant_count, empty_server_output
59*14675a02SAndroid Build Coastguard Worker
60*14675a02SAndroid Build Coastguard Worker
@tff.federated_computation(
    tff.type_at_server(tff.StructType([('foo', tf.int32), ('bar', tf.int32)])),
    tff.type_at_clients(tff.SequenceType(tf.string)),
)
def irregular_arrays(state, client_data):
  """Example TFF computation that returns irregular data.

  Args:
    state: Server-placed `<foo=int32, bar=int32>` struct, passed through as-is.
    client_data: Client-placed string sequences; ignored.

  Returns:
    A pair of the unchanged `state` and a server-placed value equal to 1 plus
    the number of participating clients.
  """
  del client_data  # Unused.
  participant_count = tff.federated_sum(tff.federated_value(1, tff.CLIENTS))
  server_one = tff.federated_value(1, tff.SERVER)
  second_output = server_one + participant_count
  return state, second_output
71*14675a02SAndroid Build Coastguard Worker
72*14675a02SAndroid Build Coastguard Worker
@attr.attrs(eq=False, frozen=True, slots=True)
class TestClass:
  """A frozen, slotted attrs class used as structured server state.

  Because `eq=False`, instances compare by identity, not by field values.
  """

  field_one = attr.attrib()
  field_two = attr.attrib()
80*14675a02SAndroid Build Coastguard Worker
@tff.tf_computation
def init():
  """Returns the initial attrs-typed state.

  The returned value is `TestClass(field_one=1, field_two=2)`.
  """
  initial_state = TestClass(field_one=1, field_two=2)
  return initial_state


# TFF type of `init`'s result (an attrs-backed struct), reused below as a
# server state type.
attrs_type = init.type_signature.result
87*14675a02SAndroid Build Coastguard Worker
88*14675a02SAndroid Build Coastguard Worker
@tff.federated_computation(
    tff.type_at_server(attrs_type),
    tff.type_at_clients(tff.SequenceType(tf.string)),
)
def attrs_computation(state, client_data):
  """Example TFF computation that returns an attrs class.

  Args:
    state: Server-placed value of `attrs_type`, returned unchanged.
    client_data: Client-placed string sequences; ignored.

  Returns:
    A pair of the unchanged attrs-typed `state` and a server-placed value equal
    to 1 plus the number of participating clients.
  """
  del client_data  # Unused.
  participant_count = tff.federated_sum(tff.federated_value(1, tff.CLIENTS))
  server_one = tff.federated_value(1, tff.SERVER)
  return state, server_one + participant_count
99*14675a02SAndroid Build Coastguard Worker
100*14675a02SAndroid Build Coastguard Worker
def build_result_checkpoint(state: int) -> bytes:
  """Builds a result checkpoint for `count_clients`.

  Args:
    state: Value to store as the computation's updated server state.

  Returns:
    Checkpoint bytes that map the first server-state variable name derived
    from `count_clients`' result state type to `state`.
  """
  result_state_type = count_clients.type_signature.result[0]
  state_var_names = variable_helpers.variable_names_from_type(
      result_state_type, name=artifact_constants.SERVER_STATE_VAR_PREFIX)
  first_var_name = state_var_names[0]
  return test_utils.create_checkpoint({first_var_name: state})
107*14675a02SAndroid Build Coastguard Worker
108*14675a02SAndroid Build Coastguard Worker
class FederatedContextTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase):
  """Tests for invoking computations through `FederatedContext`."""

  def test_invalid_population_name(self):
    with self.assertRaisesRegex(ValueError, 'population_name must match ".+"'):
      federated_context.FederatedContext(
          '^^invalid^^', address_family=ADDRESS_FAMILY)

  @mock.patch.object(server.InProcessServer, 'shutdown', autospec=True)
  @mock.patch.object(server.InProcessServer, 'serve_forever', autospec=True)
  def test_context_management(self, serve_forever, shutdown):
    """The server only serves inside the `with` block and shuts down on exit."""
    started = threading.Event()
    serve_forever.side_effect = lambda *args, **kwargs: started.set()

    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY)
    self.assertFalse(started.is_set())
    shutdown.assert_not_called()
    with ctx:
      # Wait briefly instead of checking immediately, since serving may be
      # started asynchronously (TODO confirm against FederatedContext).
      self.assertTrue(started.wait(0.5))
      shutdown.assert_not_called()
    shutdown.assert_called_once()

  def test_http(self):
    with federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY) as ctx:
      conn = http.client.HTTPConnection('localhost', port=ctx.server_port)
      # Always close the client socket, even if the request or assertion
      # fails, and do it before the context shuts the server down. Otherwise
      # the connection leaks and Python can emit a ResourceWarning.
      try:
        conn.request('GET', '/does-not-exist')
        self.assertEqual(conn.getresponse().status, http.HTTPStatus.NOT_FOUND)
      finally:
        conn.close()

  def test_invoke_non_federated_with_base_context(self):
    # Non-federated computations are delegated to the provided base context.
    base_context = tff.backends.native.create_sync_local_cpp_execution_context()
    ctx = federated_context.FederatedContext(
        POPULATION_NAME,
        address_family=ADDRESS_FAMILY,
        base_context=base_context)
    with tff.framework.get_context_stack().install(ctx):
      self.assertEqual(add_one(3), 4)

  def test_invoke_non_federated_without_base_context(self):
    # Without a base context, only FederatedComputations can be invoked.
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY)
    with tff.framework.get_context_stack().install(ctx):
      with self.assertRaisesRegex(TypeError,
                                  'computation must be a FederatedComputation'):
        add_one(3)

  def test_invoke_with_invalid_state_type(self):
    comp = federated_computation.FederatedComputation(count_clients, name='x')
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY)
    with tff.framework.get_context_stack().install(ctx):
      with self.assertRaisesRegex(
          TypeError, r'arg\[0\] must be a value or structure of values'
      ):
        comp(plan_pb2.Plan(), DATA_SOURCE.iterator().select(1))

  def test_invoke_with_invalid_data_source_type(self):
    comp = federated_computation.FederatedComputation(count_clients, name='x')
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY)
    with tff.framework.get_context_stack().install(ctx):
      with self.assertRaisesRegex(
          TypeError, r'arg\[1\] must be the result of '
          r'FederatedDataSource.iterator\(\).select\(\)'):
        comp(0, plan_pb2.Plan())

  def test_invoke_succeeds_with_structure_state_type(self):
    comp = federated_computation.FederatedComputation(
        irregular_arrays, name='x'
    )
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY
    )
    with tff.framework.get_context_stack().install(ctx):
      # Nested structure whose leaves have differing lengths.
      state = {'foo': (3, 1), 'bar': (4, 5, 6)}
      comp(state, DATA_SOURCE.iterator().select(1))

  def test_invoke_succeeds_with_attrs_state_type(self):
    comp = federated_computation.FederatedComputation(
        attrs_computation, name='x'
    )
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY
    )
    with tff.framework.get_context_stack().install(ctx):
      state = TestClass(field_one=1, field_two=2)
      comp(state, DATA_SOURCE.iterator().select(1))

  def test_invoke_with_mismatched_population_names(self):
    comp = federated_computation.FederatedComputation(count_clients, name='x')
    ds = federated_data_source.FederatedDataSource('other/name',
                                                   DATA_SOURCE.example_selector)
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY)
    with tff.framework.get_context_stack().install(ctx):
      with self.assertRaisesRegex(
          ValueError, 'FederatedDataSource and FederatedContext '
          'population_names must match'):
        comp(0, ds.iterator().select(1))

  @mock.patch.object(server.InProcessServer, 'run_computation', autospec=True)
  async def test_invoke_success(self, run_computation):
    run_computation.return_value = build_result_checkpoint(7)

    comp = federated_computation.FederatedComputation(count_clients, name='x')
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY)
    release_manager = tff.program.MemoryReleaseManager()
    with tff.framework.get_context_stack().install(ctx):
      state, _ = comp(3, DATA_SOURCE.iterator().select(10))
      await release_manager.release(
          state, tff.type_at_server(tf.int32), key='result')

    self.assertEqual(release_manager.values()['result'][0], 7)

    # With autospec, the first positional argument is the server instance.
    run_computation.assert_called_once_with(
        mock.ANY,
        comp.name,
        mock.ANY,
        mock.ANY,
        DATA_SOURCE.task_assignment_mode,
        10,
    )
    plan = run_computation.call_args.args[2]
    self.assertIsInstance(plan, plan_pb2.Plan)
    self.assertNotEmpty(plan.client_tflite_graph_bytes)
    input_var_names = variable_helpers.variable_names_from_type(
        count_clients.type_signature.parameter[0],
        name=artifact_constants.SERVER_STATE_VAR_PREFIX)
    self.assertLen(input_var_names, 1)
    # args[3] is the input checkpoint; it should hold the state passed in (3).
    self.assertEqual(
        test_utils.read_tensor_from_checkpoint(
            run_computation.call_args.args[3], input_var_names[0], tf.int32), 3)

  @mock.patch.object(server.InProcessServer, 'run_computation', autospec=True)
  async def test_invoke_with_value_reference(self, run_computation):
    run_computation.side_effect = [
        build_result_checkpoint(1234),
        build_result_checkpoint(5678)
    ]

    comp = federated_computation.FederatedComputation(count_clients, name='x')
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY)
    release_manager = tff.program.MemoryReleaseManager()
    with tff.framework.get_context_stack().install(ctx):
      state, _ = comp(3, DATA_SOURCE.iterator().select(10))
      state, _ = comp(state, DATA_SOURCE.iterator().select(10))
      await release_manager.release(
          state, tff.type_at_server(tf.int32), key='result')

    self.assertEqual(release_manager.values()['result'][0], 5678)

    input_var_names = variable_helpers.variable_names_from_type(
        count_clients.type_signature.parameter[0],
        name=artifact_constants.SERVER_STATE_VAR_PREFIX)
    self.assertLen(input_var_names, 1)
    # The second invocation should be passed the value returned by the first
    # invocation.
    self.assertEqual(run_computation.call_count, 2)
    self.assertEqual(
        test_utils.read_tensor_from_checkpoint(
            run_computation.call_args.args[3], input_var_names[0], tf.int32),
        1234)

  async def test_invoke_without_input_state(self):
    comp = federated_computation.FederatedComputation(count_clients, name='x')
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY)
    with tff.framework.get_context_stack().install(ctx):
      with self.assertRaisesRegex(
          TypeError, r'arg\[0\] must be a value or structure of values'
      ):
        comp(None, DATA_SOURCE.iterator().select(1))

  @mock.patch.object(server.InProcessServer, 'run_computation', autospec=True)
  async def test_invoke_with_run_computation_error(self, run_computation):
    run_computation.side_effect = ValueError('message')

    comp = federated_computation.FederatedComputation(count_clients, name='x')
    ctx = federated_context.FederatedContext(
        POPULATION_NAME, address_family=ADDRESS_FAMILY)
    release_manager = tff.program.MemoryReleaseManager()
    with tff.framework.get_context_stack().install(ctx):
      # The invocation itself succeeds. The server error only surfaces when
      # the resulting value is released.
      state, _ = comp(0, DATA_SOURCE.iterator().select(10))
      with self.assertRaisesRegex(ValueError, 'message'):
        await release_manager.release(
            state, tff.type_at_server(tf.int32), key='result')
297*14675a02SAndroid Build Coastguard Worker
298*14675a02SAndroid Build Coastguard Worker
299*14675a02SAndroid Build Coastguard Workerclass FederatedContextPlanCachingTest(absltest.TestCase,
300*14675a02SAndroid Build Coastguard Worker                                      unittest.IsolatedAsyncioTestCase):
301*14675a02SAndroid Build Coastguard Worker
  async def asyncSetUp(self):
    """Installs mocks and a FederatedContext, then primes the plan cache.

    When setup finishes, `count_clients_comp1` has run once with
    `data_source1`. The `build_plan` and `run_computation` mocks have then been
    reset, so each test only observes its own calls.
    """
    await super().asyncSetUp()

    @tff.federated_computation(
        tff.type_at_server(tf.int32),
        tff.type_at_clients(tff.SequenceType(tf.string)))
    def identity(state, client_data):
      # Same type signature as count_clients, but returns the state unchanged.
      del client_data
      return state, tff.federated_value((), tff.SERVER)

    # Two wrappers around the same TFF computation (differing only by name)
    # and one wrapper around a different computation.
    self.count_clients_comp1 = federated_computation.FederatedComputation(
        count_clients, name='count_clients1')
    self.count_clients_comp2 = federated_computation.FederatedComputation(
        count_clients, name='count_clients2')
    self.identity_comp = federated_computation.FederatedComputation(
        identity, name='identity')

    # Two data sources in the same population with different example selectors.
    self.data_source1 = federated_data_source.FederatedDataSource(
        POPULATION_NAME, plan_pb2.ExampleSelector(collection_uri='app:/1'))
    self.data_source2 = federated_data_source.FederatedDataSource(
        POPULATION_NAME, plan_pb2.ExampleSelector(collection_uri='app:/2'))

    # Every server run returns a count_clients-shaped result with state 0.
    self.run_computation = self.enter_context(
        mock.patch.object(
            server.InProcessServer, 'run_computation', autospec=True))
    self.run_computation.return_value = build_result_checkpoint(0)
    # Plan building is mocked so the tests can count how often a plan is
    # (re)built.
    self.build_plan = self.enter_context(
        mock.patch.object(
            federated_compute_plan_builder, 'build_plan', autospec=True))
    self.build_plan.return_value = plan_pb2.Plan()
    # Return the plan unchanged instead of adding a flatbuffer to it.
    self.generate_and_add_flat_buffer_to_plan = self.enter_context(
        mock.patch.object(
            plan_utils, 'generate_and_add_flat_buffer_to_plan', autospec=True))
    self.generate_and_add_flat_buffer_to_plan.side_effect = lambda plan: plan
    # Install the context only after the mocks are active.
    self.enter_context(tff.framework.get_context_stack().install(
        federated_context.FederatedContext(
            POPULATION_NAME, address_family=ADDRESS_FAMILY)))
    self.release_manager = tff.program.MemoryReleaseManager()

    # Run (and therefore cache) count_clients_comp1 with data_source1.
    await self.release_manager.release(
        self.count_clients_comp1(0,
                                 self.data_source1.iterator().select(1)),
        self.count_clients_comp1.type_signature.result,
        key='result')
    # The first run builds exactly one plan. It is built from the
    # computation's MapReduceForm and DistributeAggregateForm and from the
    # data source's example selector.
    self.build_plan.assert_called_once()
    self.assertEqual(self.build_plan.call_args.args[0],
                     self.count_clients_comp1.map_reduce_form)
    self.assertEqual(
        self.build_plan.call_args.args[1],
        self.count_clients_comp1.distribute_aggregate_form,
    )
    self.assertEqual(
        self.build_plan.call_args.args[2].example_selector_proto,
        self.data_source1.example_selector,
    )
    self.run_computation.assert_called_once()
    # Clear call history so tests only see calls made after setup.
    self.build_plan.reset_mock()
    self.run_computation.reset_mock()
361*14675a02SAndroid Build Coastguard Worker
362*14675a02SAndroid Build Coastguard Worker  async def test_reuse_with_repeat_computation(self):
363*14675a02SAndroid Build Coastguard Worker    await self.release_manager.release(
364*14675a02SAndroid Build Coastguard Worker        self.count_clients_comp1(0,
365*14675a02SAndroid Build Coastguard Worker                                 self.data_source1.iterator().select(1)),
366*14675a02SAndroid Build Coastguard Worker        self.count_clients_comp1.type_signature.result,
367*14675a02SAndroid Build Coastguard Worker        key='result')
368*14675a02SAndroid Build Coastguard Worker    self.build_plan.assert_not_called()
369*14675a02SAndroid Build Coastguard Worker    self.run_computation.assert_called_once()
370*14675a02SAndroid Build Coastguard Worker
371*14675a02SAndroid Build Coastguard Worker  async def test_reuse_with_changed_num_clients(self):
372*14675a02SAndroid Build Coastguard Worker    await self.release_manager.release(
373*14675a02SAndroid Build Coastguard Worker        self.count_clients_comp1(0,
374*14675a02SAndroid Build Coastguard Worker                                 self.data_source1.iterator().select(10)),
375*14675a02SAndroid Build Coastguard Worker        self.count_clients_comp1.type_signature.result,
376*14675a02SAndroid Build Coastguard Worker        key='result')
377*14675a02SAndroid Build Coastguard Worker    self.build_plan.assert_not_called()
378*14675a02SAndroid Build Coastguard Worker    self.run_computation.assert_called_once()
379*14675a02SAndroid Build Coastguard Worker
380*14675a02SAndroid Build Coastguard Worker  async def test_reuse_with_changed_initial_state(self):
381*14675a02SAndroid Build Coastguard Worker    await self.release_manager.release(
382*14675a02SAndroid Build Coastguard Worker        self.count_clients_comp1(3,
383*14675a02SAndroid Build Coastguard Worker                                 self.data_source1.iterator().select(1)),
384*14675a02SAndroid Build Coastguard Worker        self.count_clients_comp1.type_signature.result,
385*14675a02SAndroid Build Coastguard Worker        key='result')
386*14675a02SAndroid Build Coastguard Worker    self.build_plan.assert_not_called()
387*14675a02SAndroid Build Coastguard Worker    self.run_computation.assert_called_once()
388*14675a02SAndroid Build Coastguard Worker
389*14675a02SAndroid Build Coastguard Worker  async def test_reuse_with_equivalent_map_reduce_form(self):
390*14675a02SAndroid Build Coastguard Worker    await self.release_manager.release(
391*14675a02SAndroid Build Coastguard Worker        self.count_clients_comp2(0,
392*14675a02SAndroid Build Coastguard Worker                                 self.data_source1.iterator().select(1)),
393*14675a02SAndroid Build Coastguard Worker        self.count_clients_comp2.type_signature.result,
394*14675a02SAndroid Build Coastguard Worker        key='result')
395*14675a02SAndroid Build Coastguard Worker    self.build_plan.assert_not_called()
396*14675a02SAndroid Build Coastguard Worker    self.run_computation.assert_called_once()
397*14675a02SAndroid Build Coastguard Worker
398*14675a02SAndroid Build Coastguard Worker  async def test_rebuild_with_different_computation(self):
399*14675a02SAndroid Build Coastguard Worker    await self.release_manager.release(
400*14675a02SAndroid Build Coastguard Worker        self.identity_comp(0,
401*14675a02SAndroid Build Coastguard Worker                           self.data_source1.iterator().select(1)),
402*14675a02SAndroid Build Coastguard Worker        self.identity_comp.type_signature.result,
403*14675a02SAndroid Build Coastguard Worker        key='result')
404*14675a02SAndroid Build Coastguard Worker    self.build_plan.assert_called_once()
405*14675a02SAndroid Build Coastguard Worker    self.assertEqual(self.build_plan.call_args.args[0],
406*14675a02SAndroid Build Coastguard Worker                     self.identity_comp.map_reduce_form)
407*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
408*14675a02SAndroid Build Coastguard Worker        self.build_plan.call_args.args[1],
409*14675a02SAndroid Build Coastguard Worker        self.identity_comp.distribute_aggregate_form,
410*14675a02SAndroid Build Coastguard Worker    )
411*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
412*14675a02SAndroid Build Coastguard Worker        self.build_plan.call_args.args[2].example_selector_proto,
413*14675a02SAndroid Build Coastguard Worker        self.data_source1.example_selector,
414*14675a02SAndroid Build Coastguard Worker    )
415*14675a02SAndroid Build Coastguard Worker    self.run_computation.assert_called_once()
416*14675a02SAndroid Build Coastguard Worker
417*14675a02SAndroid Build Coastguard Worker  async def test_rebuild_with_different_data_source(self):
418*14675a02SAndroid Build Coastguard Worker    await self.release_manager.release(
419*14675a02SAndroid Build Coastguard Worker        self.count_clients_comp1(0,
420*14675a02SAndroid Build Coastguard Worker                                 self.data_source2.iterator().select(1)),
421*14675a02SAndroid Build Coastguard Worker        self.count_clients_comp1.type_signature.result,
422*14675a02SAndroid Build Coastguard Worker        key='result')
423*14675a02SAndroid Build Coastguard Worker    self.build_plan.assert_called_once()
424*14675a02SAndroid Build Coastguard Worker    self.assertEqual(self.build_plan.call_args.args[0],
425*14675a02SAndroid Build Coastguard Worker                     self.count_clients_comp1.map_reduce_form)
426*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
427*14675a02SAndroid Build Coastguard Worker        self.build_plan.call_args.args[1],
428*14675a02SAndroid Build Coastguard Worker        self.count_clients_comp1.distribute_aggregate_form,
429*14675a02SAndroid Build Coastguard Worker    )
430*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
431*14675a02SAndroid Build Coastguard Worker        self.build_plan.call_args.args[2].example_selector_proto,
432*14675a02SAndroid Build Coastguard Worker        self.data_source2.example_selector,
433*14675a02SAndroid Build Coastguard Worker    )
434*14675a02SAndroid Build Coastguard Worker    self.run_computation.assert_called_once()
435*14675a02SAndroid Build Coastguard Worker
436*14675a02SAndroid Build Coastguard Worker
if __name__ == '__main__':
  # Run this module's tests through absltest when executed as a script.
  absltest.main()
439