# Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for server.""" import asyncio import gzip import http import http.client import os import threading import unittest from unittest import mock import urllib.parse import urllib.request from absl import flags from absl import logging from absl.testing import absltest import tensorflow as tf from google.longrunning import operations_pb2 from fcp.demo import plan_utils from fcp.demo import server from fcp.demo import test_utils from fcp.protos import plan_pb2 from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2 from fcp.protos.federatedcompute import task_assignments_pb2 from fcp.tensorflow import external_dataset _TaskAssignmentMode = ( eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo.TaskAssignmentMode ) POPULATION_NAME = 'test/population' CAP_TENSOR_NAME = 'cap' COUNT_TENSOR_NAME = 'count' TEST_SLICES = { 'id1': [b'1-1', b'1-2', b'1-3'], 'id2': [b'2-1', b'2-2'], } def create_plan() -> plan_pb2.Plan: """Creates a test plan that counts examples, with a per-client cap.""" with tf.compat.v1.Graph().as_default() as client_graph: dataset_token = tf.compat.v1.placeholder(tf.string, shape=()) input_filepath = tf.compat.v1.placeholder(tf.string, shape=()) output_filepath = tf.compat.v1.placeholder(tf.string, shape=()) ds = external_dataset.ExternalDataset(token=dataset_token, selector=b'') cap = tf.raw_ops.Restore( file_pattern=input_filepath, tensor_name=CAP_TENSOR_NAME, dt=tf.int32) count = ds.take(tf.cast(cap, dtype=tf.int64)).reduce(0, lambda x, _: x + 1) target_node = tf.raw_ops.Save( filename=output_filepath, tensor_names=[COUNT_TENSOR_NAME], data=[count]) with tf.compat.v1.Graph().as_default() as server_graph: filename = tf.compat.v1.placeholder(tf.string, shape=()) contribution_cap = tf.Variable(0, dtype=tf.int32) count = tf.Variable(0, dtype=tf.int32) load_initial_count = count.assign( tf.raw_ops.Restore( file_pattern=filename, tensor_name=COUNT_TENSOR_NAME, dt=tf.int32), read_value=False) load_contribution_cap = contribution_cap.assign( tf.raw_ops.Restore( file_pattern=filename, tensor_name=CAP_TENSOR_NAME, dt=tf.int32), read_value=False) with tf.control_dependencies([load_initial_count, load_contribution_cap]): restore_server_savepoint = tf.no_op() write_client_init = tf.raw_ops.Save( filename=filename, tensor_names=[CAP_TENSOR_NAME], data=[contribution_cap]) read_intermediate_update = count.assign_add( tf.raw_ops.Restore( file_pattern=filename, tensor_name=COUNT_TENSOR_NAME, dt=tf.int32)) save_count = tf.raw_ops.Save( filename=filename, tensor_names=[COUNT_TENSOR_NAME], data=[count]) plan = plan_pb2.Plan( phase=[ plan_pb2.Plan.Phase( client_phase=plan_pb2.ClientPhase( tensorflow_spec=plan_pb2.TensorflowSpec( dataset_token_tensor_name=dataset_token.op.name, input_tensor_specs=[ tf.TensorSpec.from_tensor( input_filepath).experimental_as_proto(), tf.TensorSpec.from_tensor( output_filepath).experimental_as_proto(), ], target_node_names=[target_node.name]), federated_compute=plan_pb2.FederatedComputeIORouter( input_filepath_tensor_name=input_filepath.op.name, output_filepath_tensor_name=output_filepath.op.name)), server_phase=plan_pb2.ServerPhase( write_client_init=plan_pb2.CheckpointOp( saver_def=tf.compat.v1.train.SaverDef( filename_tensor_name=filename.name, save_tensor_name=write_client_init.name)), read_intermediate_update=plan_pb2.CheckpointOp( saver_def=tf.compat.v1.train.SaverDef( filename_tensor_name=filename.name, restore_op_name=read_intermediate_update.name))), server_phase_v2=plan_pb2.ServerPhaseV2(aggregations=[ plan_pb2.ServerAggregationConfig( intrinsic_uri='federated_sum', intrinsic_args=[ plan_pb2.ServerAggregationConfig.IntrinsicArg( input_tensor=tf.TensorSpec( (), tf.int32, COUNT_TENSOR_NAME).experimental_as_proto()) ], output_tensors=[ tf.TensorSpec((), tf.int32, COUNT_TENSOR_NAME) .experimental_as_proto() ]) ])) ], server_savepoint=plan_pb2.CheckpointOp( saver_def=tf.compat.v1.train.SaverDef( filename_tensor_name=filename.name, save_tensor_name=save_count.name, restore_op_name=restore_server_savepoint.name)), version=1) plan.client_graph_bytes.Pack(client_graph.as_graph_def()) plan.server_graph_bytes.Pack(server_graph.as_graph_def()) return plan class ServerTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase): def setUp(self): super().setUp() self.server = server.InProcessServer( # pytype: disable=wrong-arg-types population_name=POPULATION_NAME, host='localhost', port=0) self._server_thread = threading.Thread(target=self.server.serve_forever) self._server_thread.start() self.conn = http.client.HTTPConnection( self.server.server_name, port=self.server.server_port) def tearDown(self): self.server.shutdown() self._server_thread.join() self.server.server_close() super().tearDown() async def wait_for_task(self) -> task_assignments_pb2.TaskAssignment: """Polls the server until a task is being served.""" pop = urllib.parse.quote(POPULATION_NAME, safe='') url = f'/v1/populations/{pop}/taskassignments/test:start?%24alt=proto' request = task_assignments_pb2.StartTaskAssignmentRequest() while True: self.conn.request('POST', url, request.SerializeToString()) http_response = self.conn.getresponse() if http_response.status == http.HTTPStatus.OK: op = operations_pb2.Operation.FromString(http_response.read()) response = task_assignments_pb2.StartTaskAssignmentResponse() op.response.Unpack(response) if response.HasField('task_assignment'): logging.info('wait_for_task received assignment to %s', response.task_assignment.task_name) return response.task_assignment await asyncio.sleep(0.5) async def test_run_computation(self): initial_count = 100 cap = 10 examples_per_client = [1, 5, 15] checkpoint = test_utils.create_checkpoint({ CAP_TENSOR_NAME: cap, COUNT_TENSOR_NAME: initial_count, }) run_computation_task = asyncio.create_task( self.server.run_computation( 'task/name', create_plan(), checkpoint, _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE, len(examples_per_client), ) ) # Wait for task assignment to return a task. wait_task = asyncio.create_task(self.wait_for_task()) await asyncio.wait([run_computation_task, wait_task], timeout=10, return_when=asyncio.FIRST_COMPLETED) self.assertTrue(wait_task.done()) # `run_computation` should not be done since no clients have reported. self.assertFalse(run_computation_task.done()) client_runner = os.path.join( flags.FLAGS.test_srcdir, 'com_google_fcp', 'fcp', 'client', 'client_runner_main') server_url = f'http://{self.server.server_name}:{self.server.server_port}/' clients = [] for num_examples in examples_per_client: subprocess = asyncio.create_subprocess_exec( client_runner, f'--server={server_url}', f'--population={POPULATION_NAME}', f'--num_empty_examples={num_examples}', '--sleep_after_round_secs=0', '--use_http_federated_compute_protocol') clients.append(asyncio.create_task((await subprocess).wait())) # Wait for the computation to complete. await asyncio.wait([run_computation_task] + clients, timeout=10) self.assertTrue(run_computation_task.done()) for client in clients: self.assertTrue(client.done()) self.assertEqual(client.result(), 0) # Verify the sum in the checkpoint. result = test_utils.read_tensor_from_checkpoint( run_computation_task.result(), COUNT_TENSOR_NAME, tf.int32) self.assertEqual( result, initial_count + sum([min(n, cap) for n in examples_per_client])) @mock.patch.object( plan_utils.Session, 'slices', new=property(lambda unused_self: TEST_SLICES), ) async def test_federated_select(self): checkpoint = test_utils.create_checkpoint({ CAP_TENSOR_NAME: 100, COUNT_TENSOR_NAME: 0, }) run_computation_task = asyncio.create_task( self.server.run_computation( 'task/name', create_plan(), checkpoint, _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE, 1, ) ) # Wait for task assignment to return a task. wait_task = asyncio.create_task(self.wait_for_task()) await asyncio.wait( [run_computation_task, wait_task], timeout=10, return_when=asyncio.FIRST_COMPLETED, ) self.assertTrue(wait_task.done()) uri_template = wait_task.result().federated_select_uri_info.uri_template self.assertNotEmpty(uri_template) # Check the contents of the slices. for served_at_id, slices in TEST_SLICES.items(): for i, slice_data in enumerate(slices): with urllib.request.urlopen( uri_template.format(served_at_id=served_at_id, key_base10=str(i)) ) as response: self.assertEqual( response.getheader('Content-Type'), 'application/octet-stream+gzip', ) self.assertEqual(gzip.decompress(response.read()), slice_data) if __name__ == '__main__': absltest.main()