xref: /aosp_15_r20/external/federated-compute/fcp/demo/server_test.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for server."""
15
16import asyncio
17import gzip
18import http
19import http.client
20import os
21import threading
22import unittest
23from unittest import mock
24import urllib.parse
25import urllib.request
26
27from absl import flags
28from absl import logging
29from absl.testing import absltest
30import tensorflow as tf
31
32from google.longrunning import operations_pb2
33from fcp.demo import plan_utils
34from fcp.demo import server
35from fcp.demo import test_utils
36from fcp.protos import plan_pb2
37from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2
38from fcp.protos.federatedcompute import task_assignments_pb2
39from fcp.tensorflow import external_dataset
40
# Convenience alias for the deeply nested TaskAssignmentMode enum.
_TaskAssignmentMode = (
    eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo.TaskAssignmentMode
)

# Population name the in-process server under test is created with.
POPULATION_NAME = 'test/population'
# Checkpoint tensor names used by the plan built in create_plan():
# 'cap' holds the per-client contribution cap, 'count' the running total.
CAP_TENSOR_NAME = 'cap'
COUNT_TENSOR_NAME = 'count'
# Fake federated-select slice data, keyed by served_at_id; each value is the
# list of uncompressed slice payloads, indexed by slice key. Served by the
# server (gzip-compressed) in test_federated_select via a mocked
# plan_utils.Session.slices property.
TEST_SLICES = {
    'id1': [b'1-1', b'1-2', b'1-3'],
    'id2': [b'2-1', b'2-2'],
}
52
53
def create_plan() -> plan_pb2.Plan:
  """Creates a test plan that counts examples, with a per-client cap.

  The client graph reads the cap from the input checkpoint, counts up to
  `cap` examples from an ExternalDataset, and writes the count to the output
  checkpoint. The server graph restores/saves the running count and cap,
  writes the client-init checkpoint (containing the cap), and accumulates
  intermediate updates into the count.

  NOTE: TensorFlow assigns op names by creation order, so the statement order
  inside each graph context determines the serialized graphs embedded in the
  plan; do not reorder these statements.

  Returns:
    A plan_pb2.Plan with a single phase (using the `federated_sum` intrinsic
    for the count) and packed client/server GraphDefs.
  """

  with tf.compat.v1.Graph().as_default() as client_graph:
    # Placeholders fed by the client runtime (wired up via the
    # FederatedComputeIORouter / TensorflowSpec below).
    dataset_token = tf.compat.v1.placeholder(tf.string, shape=())
    input_filepath = tf.compat.v1.placeholder(tf.string, shape=())
    output_filepath = tf.compat.v1.placeholder(tf.string, shape=())
    ds = external_dataset.ExternalDataset(token=dataset_token, selector=b'')
    # The per-client cap comes from the client-init checkpoint written by the
    # server's write_client_init op below.
    cap = tf.raw_ops.Restore(
        file_pattern=input_filepath, tensor_name=CAP_TENSOR_NAME, dt=tf.int32)
    # Count at most `cap` examples.
    count = ds.take(tf.cast(cap, dtype=tf.int64)).reduce(0, lambda x, _: x + 1)
    # Save the client's count as the result checkpoint.
    target_node = tf.raw_ops.Save(
        filename=output_filepath,
        tensor_names=[COUNT_TENSOR_NAME],
        data=[count])

  with tf.compat.v1.Graph().as_default() as server_graph:
    filename = tf.compat.v1.placeholder(tf.string, shape=())
    contribution_cap = tf.Variable(0, dtype=tf.int32)
    count = tf.Variable(0, dtype=tf.int32)
    # Restore both server variables from the server savepoint checkpoint.
    # read_value=False: only the assignment side effect is needed.
    load_initial_count = count.assign(
        tf.raw_ops.Restore(
            file_pattern=filename, tensor_name=COUNT_TENSOR_NAME, dt=tf.int32),
        read_value=False)
    load_contribution_cap = contribution_cap.assign(
        tf.raw_ops.Restore(
            file_pattern=filename, tensor_name=CAP_TENSOR_NAME, dt=tf.int32),
        read_value=False)
    # Single restore op for the savepoint that runs both assignments.
    with tf.control_dependencies([load_initial_count, load_contribution_cap]):
      restore_server_savepoint = tf.no_op()
    # Writes the client-init checkpoint consumed by the client graph above.
    write_client_init = tf.raw_ops.Save(
        filename=filename,
        tensor_names=[CAP_TENSOR_NAME],
        data=[contribution_cap])

    # Adds an aggregated intermediate update's count into the running total.
    read_intermediate_update = count.assign_add(
        tf.raw_ops.Restore(
            file_pattern=filename, tensor_name=COUNT_TENSOR_NAME, dt=tf.int32))
    # Saves the final count (used as the server savepoint's save op).
    save_count = tf.raw_ops.Save(
        filename=filename, tensor_names=[COUNT_TENSOR_NAME], data=[count])

  # Wire the graph op/tensor names captured above into the plan proto.
  plan = plan_pb2.Plan(
      phase=[
          plan_pb2.Plan.Phase(
              client_phase=plan_pb2.ClientPhase(
                  tensorflow_spec=plan_pb2.TensorflowSpec(
                      dataset_token_tensor_name=dataset_token.op.name,
                      input_tensor_specs=[
                          tf.TensorSpec.from_tensor(
                              input_filepath).experimental_as_proto(),
                          tf.TensorSpec.from_tensor(
                              output_filepath).experimental_as_proto(),
                      ],
                      target_node_names=[target_node.name]),
                  federated_compute=plan_pb2.FederatedComputeIORouter(
                      input_filepath_tensor_name=input_filepath.op.name,
                      output_filepath_tensor_name=output_filepath.op.name)),
              server_phase=plan_pb2.ServerPhase(
                  write_client_init=plan_pb2.CheckpointOp(
                      saver_def=tf.compat.v1.train.SaverDef(
                          filename_tensor_name=filename.name,
                          save_tensor_name=write_client_init.name)),
                  read_intermediate_update=plan_pb2.CheckpointOp(
                      saver_def=tf.compat.v1.train.SaverDef(
                          filename_tensor_name=filename.name,
                          restore_op_name=read_intermediate_update.name))),
              # Client counts are aggregated with a simple federated sum.
              server_phase_v2=plan_pb2.ServerPhaseV2(aggregations=[
                  plan_pb2.ServerAggregationConfig(
                      intrinsic_uri='federated_sum',
                      intrinsic_args=[
                          plan_pb2.ServerAggregationConfig.IntrinsicArg(
                              input_tensor=tf.TensorSpec(
                                  (), tf.int32,
                                  COUNT_TENSOR_NAME).experimental_as_proto())
                      ],
                      output_tensors=[
                          tf.TensorSpec((), tf.int32, COUNT_TENSOR_NAME)
                          .experimental_as_proto()
                      ])
              ]))
      ],
      server_savepoint=plan_pb2.CheckpointOp(
          saver_def=tf.compat.v1.train.SaverDef(
              filename_tensor_name=filename.name,
              save_tensor_name=save_count.name,
              restore_op_name=restore_server_savepoint.name)),
      version=1)
  plan.client_graph_bytes.Pack(client_graph.as_graph_def())
  plan.server_graph_bytes.Pack(server_graph.as_graph_def())
  return plan
144
145
class ServerTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase):
  """End-to-end tests for `server.InProcessServer`.

  Each test starts a real server on an ephemeral localhost port in a
  background thread and exercises it over HTTP, both directly (via
  `http.client`) and through the `client_runner_main` binary.
  """

  def setUp(self) -> None:
    super().setUp()
    # port=0 lets the OS choose a free ephemeral port.
    self.server = server.InProcessServer(  # pytype: disable=wrong-arg-types
        population_name=POPULATION_NAME,
        host='localhost',
        port=0)
    # serve_forever() blocks, so run it on a background thread; tearDown
    # stops it via shutdown().
    self._server_thread = threading.Thread(target=self.server.serve_forever)
    self._server_thread.start()
    self.conn = http.client.HTTPConnection(
        self.server.server_name, port=self.server.server_port)

  def tearDown(self) -> None:
    # Stop the serve_forever() loop, wait for the serving thread to exit,
    # then release the listening socket.
    self.server.shutdown()
    self._server_thread.join()
    self.server.server_close()
    super().tearDown()

  async def wait_for_task(self) -> task_assignments_pb2.TaskAssignment:
    """Polls the server until a task is being served.

    Repeatedly POSTs a StartTaskAssignmentRequest; returns the first
    TaskAssignment received.
    """
    pop = urllib.parse.quote(POPULATION_NAME, safe='')
    # '%24alt=proto' is the URL-encoded '$alt=proto' query parameter,
    # requesting a binary-proto response.
    url = f'/v1/populations/{pop}/taskassignments/test:start?%24alt=proto'
    request = task_assignments_pb2.StartTaskAssignmentRequest()
    while True:
      self.conn.request('POST', url, request.SerializeToString())
      http_response = self.conn.getresponse()
      if http_response.status == http.HTTPStatus.OK:
        # A successful response is a long-running Operation wrapping a
        # StartTaskAssignmentResponse.
        op = operations_pb2.Operation.FromString(http_response.read())
        response = task_assignments_pb2.StartTaskAssignmentResponse()
        op.response.Unpack(response)
        if response.HasField('task_assignment'):
          logging.info('wait_for_task received assignment to %s',
                       response.task_assignment.task_name)
          return response.task_assignment
      # No task yet (or non-OK status); retry after a short delay.
      await asyncio.sleep(0.5)

  async def test_run_computation(self) -> None:
    initial_count = 100
    cap = 10
    examples_per_client = [1, 5, 15]
    # Server savepoint: starting count plus the per-client contribution cap.
    checkpoint = test_utils.create_checkpoint({
        CAP_TENSOR_NAME: cap,
        COUNT_TENSOR_NAME: initial_count,
    })
    run_computation_task = asyncio.create_task(
        self.server.run_computation(
            'task/name',
            create_plan(),
            checkpoint,
            _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
            len(examples_per_client),
        )
    )

    # Wait for task assignment to return a task.
    wait_task = asyncio.create_task(self.wait_for_task())
    await asyncio.wait([run_computation_task, wait_task],
                       timeout=10,
                       return_when=asyncio.FIRST_COMPLETED)
    self.assertTrue(wait_task.done())
    # `run_computation` should not be done since no clients have reported.
    self.assertFalse(run_computation_task.done())

    # Path to the client binary under the test source tree.
    client_runner = os.path.join(
        flags.FLAGS.test_srcdir,
        'com_google_fcp',
        'fcp',
        'client',
        'client_runner_main')
    server_url = f'http://{self.server.server_name}:{self.server.server_port}/'
    clients = []
    # Launch one client subprocess per entry; each contributes
    # min(num_examples, cap) to the aggregate count.
    for num_examples in examples_per_client:
      subprocess = asyncio.create_subprocess_exec(
          client_runner, f'--server={server_url}',
          f'--population={POPULATION_NAME}',
          f'--num_empty_examples={num_examples}', '--sleep_after_round_secs=0',
          '--use_http_federated_compute_protocol')
      clients.append(asyncio.create_task((await subprocess).wait()))

    # Wait for the computation to complete.
    await asyncio.wait([run_computation_task] + clients, timeout=10)
    self.assertTrue(run_computation_task.done())
    for client in clients:
      self.assertTrue(client.done())
      # Exit code 0 means the client completed its round successfully.
      self.assertEqual(client.result(), 0)

    # Verify the sum in the checkpoint.
    result = test_utils.read_tensor_from_checkpoint(
        run_computation_task.result(), COUNT_TENSOR_NAME, tf.int32)
    self.assertEqual(
        result, initial_count + sum([min(n, cap) for n in examples_per_client]))

  # Replace the Session.slices property so the server serves TEST_SLICES as
  # the federated-select slice data.
  @mock.patch.object(
      plan_utils.Session,
      'slices',
      new=property(lambda unused_self: TEST_SLICES),
  )
  async def test_federated_select(self) -> None:
    checkpoint = test_utils.create_checkpoint({
        CAP_TENSOR_NAME: 100,
        COUNT_TENSOR_NAME: 0,
    })
    run_computation_task = asyncio.create_task(
        self.server.run_computation(
            'task/name',
            create_plan(),
            checkpoint,
            _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
            1,
        )
    )

    # Wait for task assignment to return a task.
    wait_task = asyncio.create_task(self.wait_for_task())
    await asyncio.wait(
        [run_computation_task, wait_task],
        timeout=10,
        return_when=asyncio.FIRST_COMPLETED,
    )
    self.assertTrue(wait_task.done())
    uri_template = wait_task.result().federated_select_uri_info.uri_template
    self.assertNotEmpty(uri_template)

    # Check the contents of the slices: each is served gzip-compressed at a
    # URI derived from the template's served_at_id and base-10 key.
    for served_at_id, slices in TEST_SLICES.items():
      for i, slice_data in enumerate(slices):
        with urllib.request.urlopen(
            uri_template.format(served_at_id=served_at_id, key_base10=str(i))
        ) as response:
          self.assertEqual(
              response.getheader('Content-Type'),
              'application/octet-stream+gzip',
          )
          self.assertEqual(gzip.decompress(response.read()), slice_data)
281
282
if __name__ == '__main__':
  # absl test runner entry point (parses absl flags, then runs the tests).
  absltest.main()
285