# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for server."""

import asyncio
import gzip
import http
import http.client
import os
import threading
import unittest
from unittest import mock
import urllib.parse
import urllib.request

from absl import flags
from absl import logging
from absl.testing import absltest
import tensorflow as tf

from google.longrunning import operations_pb2
from fcp.demo import plan_utils
from fcp.demo import server
from fcp.demo import test_utils
from fcp.protos import plan_pb2
from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2
from fcp.protos.federatedcompute import task_assignments_pb2
from fcp.tensorflow import external_dataset

_TaskAssignmentMode = (
    eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo.TaskAssignmentMode
)

POPULATION_NAME = 'test/population'
CAP_TENSOR_NAME = 'cap'
COUNT_TENSOR_NAME = 'count'
TEST_SLICES = {
    'id1': [b'1-1', b'1-2', b'1-3'],
    'id2': [b'2-1', b'2-2'],
}


def create_plan() -> plan_pb2.Plan:
  """Creates a test plan that counts examples, with a per-client cap."""

  # Client graph: read the per-client cap from the input checkpoint, count up
  # to that many examples from the client's dataset, and write the count to
  # the output checkpoint.
  with tf.compat.v1.Graph().as_default() as client_graph:
    dataset_token = tf.compat.v1.placeholder(tf.string, shape=())
    input_filepath = tf.compat.v1.placeholder(tf.string, shape=())
    output_filepath = tf.compat.v1.placeholder(tf.string, shape=())
    ds = external_dataset.ExternalDataset(token=dataset_token, selector=b'')
    cap = tf.raw_ops.Restore(
        file_pattern=input_filepath, tensor_name=CAP_TENSOR_NAME, dt=tf.int32)
    count = ds.take(tf.cast(cap, dtype=tf.int64)).reduce(0, lambda x, _: x + 1)
    target_node = tf.raw_ops.Save(
        filename=output_filepath,
        tensor_names=[COUNT_TENSOR_NAME],
        data=[count])

  # Server graph: restore the running count and contribution cap, write the
  # cap into each client's initialization checkpoint, accumulate intermediate
  # updates into the count, and save the final count.
  with tf.compat.v1.Graph().as_default() as server_graph:
    filename = tf.compat.v1.placeholder(tf.string, shape=())
    contribution_cap = tf.Variable(0, dtype=tf.int32)
    count = tf.Variable(0, dtype=tf.int32)
    load_initial_count = count.assign(
        tf.raw_ops.Restore(
            file_pattern=filename, tensor_name=COUNT_TENSOR_NAME, dt=tf.int32),
        read_value=False)
    load_contribution_cap = contribution_cap.assign(
        tf.raw_ops.Restore(
            file_pattern=filename, tensor_name=CAP_TENSOR_NAME, dt=tf.int32),
        read_value=False)
    with tf.control_dependencies([load_initial_count, load_contribution_cap]):
      restore_server_savepoint = tf.no_op()
    write_client_init = tf.raw_ops.Save(
        filename=filename,
        tensor_names=[CAP_TENSOR_NAME],
        data=[contribution_cap])

    read_intermediate_update = count.assign_add(
        tf.raw_ops.Restore(
            file_pattern=filename, tensor_name=COUNT_TENSOR_NAME, dt=tf.int32))
    save_count = tf.raw_ops.Save(
        filename=filename, tensor_names=[COUNT_TENSOR_NAME], data=[count])

  plan = plan_pb2.Plan(
      phase=[
          plan_pb2.Plan.Phase(
              client_phase=plan_pb2.ClientPhase(
                  tensorflow_spec=plan_pb2.TensorflowSpec(
                      dataset_token_tensor_name=dataset_token.op.name,
                      input_tensor_specs=[
                          tf.TensorSpec.from_tensor(
                              input_filepath).experimental_as_proto(),
                          tf.TensorSpec.from_tensor(
                              output_filepath).experimental_as_proto(),
                      ],
                      target_node_names=[target_node.name]),
                  federated_compute=plan_pb2.FederatedComputeIORouter(
                      input_filepath_tensor_name=input_filepath.op.name,
                      output_filepath_tensor_name=output_filepath.op.name)),
              server_phase=plan_pb2.ServerPhase(
                  write_client_init=plan_pb2.CheckpointOp(
                      saver_def=tf.compat.v1.train.SaverDef(
                          filename_tensor_name=filename.name,
                          save_tensor_name=write_client_init.name)),
                  read_intermediate_update=plan_pb2.CheckpointOp(
                      saver_def=tf.compat.v1.train.SaverDef(
                          filename_tensor_name=filename.name,
                          restore_op_name=read_intermediate_update.name))),
              server_phase_v2=plan_pb2.ServerPhaseV2(aggregations=[
                  plan_pb2.ServerAggregationConfig(
                      intrinsic_uri='federated_sum',
                      intrinsic_args=[
                          plan_pb2.ServerAggregationConfig.IntrinsicArg(
                              input_tensor=tf.TensorSpec(
                                  (), tf.int32,
                                  COUNT_TENSOR_NAME).experimental_as_proto())
                      ],
                      output_tensors=[
                          tf.TensorSpec((), tf.int32, COUNT_TENSOR_NAME)
                          .experimental_as_proto()
                      ])
              ]))
      ],
      server_savepoint=plan_pb2.CheckpointOp(
          saver_def=tf.compat.v1.train.SaverDef(
              filename_tensor_name=filename.name,
              save_tensor_name=save_count.name,
              restore_op_name=restore_server_savepoint.name)),
      version=1)
  plan.client_graph_bytes.Pack(client_graph.as_graph_def())
  plan.server_graph_bytes.Pack(server_graph.as_graph_def())
  return plan


class ServerTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase):

  def setUp(self):
    super().setUp()
    self.server = server.InProcessServer(  # pytype: disable=wrong-arg-types
        population_name=POPULATION_NAME,
        host='localhost',
        port=0)
    self._server_thread = threading.Thread(target=self.server.serve_forever)
    self._server_thread.start()
    self.conn = http.client.HTTPConnection(
        self.server.server_name, port=self.server.server_port)

  def tearDown(self):
    self.server.shutdown()
    self._server_thread.join()
    self.server.server_close()
    super().tearDown()

  async def wait_for_task(self) -> task_assignments_pb2.TaskAssignment:
    """Polls the server until a task is being served."""
    pop = urllib.parse.quote(POPULATION_NAME, safe='')
    url = f'/v1/populations/{pop}/taskassignments/test:start?%24alt=proto'
    request = task_assignments_pb2.StartTaskAssignmentRequest()
    while True:
      self.conn.request('POST', url, request.SerializeToString())
      http_response = self.conn.getresponse()
      if http_response.status == http.HTTPStatus.OK:
        op = operations_pb2.Operation.FromString(http_response.read())
        response = task_assignments_pb2.StartTaskAssignmentResponse()
        op.response.Unpack(response)
        if response.HasField('task_assignment'):
          logging.info('wait_for_task received assignment to %s',
                       response.task_assignment.task_name)
          return response.task_assignment
      await asyncio.sleep(0.5)

  async def test_run_computation(self):
    initial_count = 100
    cap = 10
    examples_per_client = [1, 5, 15]
    checkpoint = test_utils.create_checkpoint({
        CAP_TENSOR_NAME: cap,
        COUNT_TENSOR_NAME: initial_count,
    })
    run_computation_task = asyncio.create_task(
        self.server.run_computation(
            'task/name',
            create_plan(),
            checkpoint,
            _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
            len(examples_per_client),
        )
    )

    # Wait for task assignment to return a task.
    wait_task = asyncio.create_task(self.wait_for_task())
    await asyncio.wait([run_computation_task, wait_task],
                       timeout=10,
                       return_when=asyncio.FIRST_COMPLETED)
    self.assertTrue(wait_task.done())
    # `run_computation` should not be done since no clients have reported.
    self.assertFalse(run_computation_task.done())

    client_runner = os.path.join(
        flags.FLAGS.test_srcdir,
        'com_google_fcp',
        'fcp',
        'client',
        'client_runner_main')
    server_url = f'http://{self.server.server_name}:{self.server.server_port}/'
    clients = []
    for num_examples in examples_per_client:
      subprocess = asyncio.create_subprocess_exec(
          client_runner, f'--server={server_url}',
          f'--population={POPULATION_NAME}',
          f'--num_empty_examples={num_examples}', '--sleep_after_round_secs=0',
          '--use_http_federated_compute_protocol')
      clients.append(asyncio.create_task((await subprocess).wait()))

    # Wait for the computation to complete.
    await asyncio.wait([run_computation_task] + clients, timeout=10)
    self.assertTrue(run_computation_task.done())
    for client in clients:
      self.assertTrue(client.done())
      self.assertEqual(client.result(), 0)

    # Verify the sum in the checkpoint.
    result = test_utils.read_tensor_from_checkpoint(
        run_computation_task.result(), COUNT_TENSOR_NAME, tf.int32)
    self.assertEqual(
        result, initial_count + sum([min(n, cap) for n in examples_per_client]))

  @mock.patch.object(
      plan_utils.Session,
      'slices',
      new=property(lambda unused_self: TEST_SLICES),
  )
  async def test_federated_select(self):
    checkpoint = test_utils.create_checkpoint({
        CAP_TENSOR_NAME: 100,
        COUNT_TENSOR_NAME: 0,
    })
    run_computation_task = asyncio.create_task(
        self.server.run_computation(
            'task/name',
            create_plan(),
            checkpoint,
            _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
            1,
        )
    )

    # Wait for task assignment to return a task.
    wait_task = asyncio.create_task(self.wait_for_task())
    await asyncio.wait(
        [run_computation_task, wait_task],
        timeout=10,
        return_when=asyncio.FIRST_COMPLETED,
    )
    self.assertTrue(wait_task.done())
    uri_template = wait_task.result().federated_select_uri_info.uri_template
    self.assertNotEmpty(uri_template)

    # Check the contents of the slices.
    for served_at_id, slices in TEST_SLICES.items():
      for i, slice_data in enumerate(slices):
        with urllib.request.urlopen(
            uri_template.format(served_at_id=served_at_id, key_base10=str(i))
        ) as response:
          self.assertEqual(
              response.getheader('Content-Type'),
              'application/octet-stream+gzip',
          )
          self.assertEqual(gzip.decompress(response.read()), slice_data)


if __name__ == '__main__':
  absltest.main()