# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An in-process federated compute server."""

import contextlib
import gzip
import http.server
import socket
import socketserver
import ssl
from typing import Optional

from absl import logging

from fcp.demo import aggregations
from fcp.demo import eligibility_eval_tasks
from fcp.demo import http_actions
from fcp.demo import media
from fcp.demo import plan_utils
from fcp.demo import task_assignments
from fcp.protos import plan_pb2
from fcp.protos.federatedcompute import common_pb2
from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2

# Short module-private alias for the deeply nested TaskAssignmentMode enum type,
# so that type annotations in this module stay readable.
_TaskAssignmentMode = (
    eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo.TaskAssignmentMode
)

# Template for file name for federated select slices.
# See `FederatedSelectUriInfo.uri_template` for the meaning of the
# "{served_at_id}" and "{key_base10}" substrings.
_FEDERATED_SELECT_NAME_TEMPLATE = '{served_at_id}_{key_base10}'

# Content type used for serialized and compressed Plan messages.
_PLAN_CONTENT_TYPE = 'application/x-protobuf+gzip'

# Content type used for serialized and compressed TensorFlow checkpoints.
_CHECKPOINT_CONTENT_TYPE = 'application/octet-stream+gzip'


class InProcessServer(socketserver.ThreadingMixIn, http.server.HTTPServer):
  """An in-process HTTP server implementing the Federated Compute protocol.

  The server wires together the media, aggregations, task assignments and
  eligibility eval tasks services and exposes them through a single HTTP
  request handler. `socketserver.ThreadingMixIn` handles each request in its
  own thread.
  """

  def __init__(self,
               *,
               population_name: str,
               host: str,
               port: int,
               address_family: Optional[socket.AddressFamily] = None):
    """Creates the services and binds the HTTP server to `(host, port)`.

    Args:
      population_name: Population name passed to the task assignment and
        eligibility eval task services.
      host: The host name or address to bind to.
      port: The port to bind to.
      address_family: If set, overrides the socket address family (for example
        `socket.AF_INET6`). Otherwise the `HTTPServer` default is used.
    """
    # The services receive the bound method `_get_forwarding_info`. It reads
    # `self.socket` lazily, so it is safe to pass it before
    # `HTTPServer.__init__` has created the socket.
    self._media_service = media.Service(self._get_forwarding_info)
    self._aggregations_service = aggregations.Service(self._get_forwarding_info,
                                                      self._media_service)
    self._task_assignments_service = task_assignments.Service(
        population_name, self._get_forwarding_info, self._aggregations_service)
    self._eligibility_eval_tasks_service = eligibility_eval_tasks.Service(
        population_name, self._get_forwarding_info
    )
    handler = http_actions.create_handler(
        self._media_service,
        self._aggregations_service,
        self._task_assignments_service,
        self._eligibility_eval_tasks_service,
    )
    if address_family is not None:
      # Must be set before HTTPServer.__init__, which creates the socket.
      self.address_family = address_family
    http.server.HTTPServer.__init__(self, (host, port), handler)

  async def run_computation(
      self,
      task_name: str,
      plan: plan_pb2.Plan,
      server_checkpoint: bytes,
      task_assignment_mode: _TaskAssignmentMode,
      number_of_clients: int,
  ) -> bytes:
    """Runs a computation, returning the resulting checkpoint.

    If there's already a computation in progress, the new computation will
    not start until the previous one has completed (either successfully or not).

    The client plan, the client checkpoint and any federated select slices are
    published gzip-compressed through a media download group. The task stays
    registered with the eligibility eval and task assignment services until
    the aggregation session has aggregated `number_of_clients` inputs.

    NOTE(review): this method takes no lock itself. The serialization described
    above presumably relies on the underlying services; confirm.

    Args:
      task_name: The name of the task.
      plan: The Plan proto containing the client and server computations.
      server_checkpoint: The starting server checkpoint.
      task_assignment_mode: The task assignment mode to use for the computation.
      number_of_clients: The minimum number of clients to include.

    Returns:
      A TensorFlow checkpoint containing the aggregated results.

    Raises:
      ValueError: If the aggregation fails.
    """
    requirements = aggregations.AggregationRequirements(
        minimum_clients_in_server_published_aggregate=number_of_clients,
        plan=plan)
    session_id = self._aggregations_service.create_session(requirements)
    with contextlib.ExitStack() as stack:
      # Abort the aggregation session if anything below fails. The callback is
      # discarded (not run) by `stack.pop_all()` once the session is ready to
      # be completed.
      stack.callback(
          lambda: self._aggregations_service.abort_session(session_id))
      with plan_utils.Session(plan, server_checkpoint) as session:
        with self._media_service.create_download_group() as group:
          plan_url = group.add(
              'plan',
              gzip.compress(session.client_plan),
              content_type=_PLAN_CONTENT_TYPE,
          )
          checkpoint_url = group.add(
              'checkpoint',
              gzip.compress(session.client_checkpoint),
              content_type=_CHECKPOINT_CONTENT_TYPE,
          )
          for served_at_id, slices in session.slices.items():
            for i, slice_data in enumerate(slices):
              group.add(
                  _FEDERATED_SELECT_NAME_TEMPLATE.format(
                      served_at_id=served_at_id, key_base10=str(i)
                  ),
                  gzip.compress(slice_data),
                  content_type=_CHECKPOINT_CONTENT_TYPE,
              )

          self._eligibility_eval_tasks_service.add_task(
              task_name, task_assignment_mode
          )
          # The try/finally blocks are nested so that the eligibility eval
          # task is always removed. This holds even if registering or removing
          # the task assignment raises. Cleanup order is unchanged: the task
          # assignment is removed first, then the eligibility eval task.
          try:
            self._task_assignments_service.add_task(
                task_name,
                task_assignment_mode,
                session_id,
                common_pb2.Resource(uri=plan_url),
                common_pb2.Resource(uri=checkpoint_url),
                group.prefix + _FEDERATED_SELECT_NAME_TEMPLATE,
            )
            try:
              status = await self._aggregations_service.wait(
                  session_id,
                  num_inputs_aggregated_and_included=number_of_clients)
              if status.status != aggregations.AggregationStatus.PENDING:
                raise ValueError('Aggregation failed.')
            finally:
              self._task_assignments_service.remove_task(session_id)
          finally:
            self._eligibility_eval_tasks_service.remove_task(task_name)

        stack.pop_all()
        status, intermediate_update = (
            self._aggregations_service.complete_session(session_id))
        if (status.status != aggregations.AggregationStatus.COMPLETED or
            intermediate_update is None):
          raise ValueError('Aggregation failed.')
        logging.debug('%s aggregation complete: %s', task_name, status)
        return session.finalize(intermediate_update)

  def _get_forwarding_info(self) -> common_pb2.ForwardingInfo:
    """Returns the URI prefix that clients should use to reach this server.

    The scheme is `https` when the listening socket is an `ssl.SSLSocket`,
    and `http` otherwise.
    """
    protocol = 'https' if isinstance(self.socket, ssl.SSLSocket) else 'http'
    host = self.server_name
    # IPv6 literals must be enclosed in brackets in a URI (RFC 3986, section
    # 3.2.2). Host names and IPv4 addresses never contain ':'.
    if ':' in host and not host.startswith('['):
      host = f'[{host}]'
    return common_pb2.ForwardingInfo(
        target_uri_prefix=f'{protocol}://{host}:{self.server_port}/')