# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14*14675a02SAndroid Build Coastguard Worker"""Action handlers for the Aggregations service.""" 15*14675a02SAndroid Build Coastguard Worker 16*14675a02SAndroid Build Coastguard Workerimport asyncio 17*14675a02SAndroid Build Coastguard Workerfrom collections.abc import Callable, Sequence 18*14675a02SAndroid Build Coastguard Workerimport contextlib 19*14675a02SAndroid Build Coastguard Workerimport dataclasses 20*14675a02SAndroid Build Coastguard Workerimport enum 21*14675a02SAndroid Build Coastguard Workerimport functools 22*14675a02SAndroid Build Coastguard Workerimport http 23*14675a02SAndroid Build Coastguard Workerimport queue 24*14675a02SAndroid Build Coastguard Workerimport threading 25*14675a02SAndroid Build Coastguard Workerfrom typing import Optional 26*14675a02SAndroid Build Coastguard Workerimport uuid 27*14675a02SAndroid Build Coastguard Worker 28*14675a02SAndroid Build Coastguard Workerfrom absl import logging 29*14675a02SAndroid Build Coastguard Worker 30*14675a02SAndroid Build Coastguard Workerfrom google.longrunning import operations_pb2 31*14675a02SAndroid Build Coastguard Workerfrom google.rpc import code_pb2 32*14675a02SAndroid Build Coastguard Workerfrom fcp.aggregation.protocol import aggregation_protocol_messages_pb2 as apm_pb2 33*14675a02SAndroid Build Coastguard Workerfrom fcp.aggregation.protocol import configuration_pb2 34*14675a02SAndroid Build Coastguard Workerfrom fcp.aggregation.protocol.python import aggregation_protocol 35*14675a02SAndroid Build Coastguard Workerfrom fcp.aggregation.tensorflow.python import aggregation_protocols 36*14675a02SAndroid Build Coastguard Workerfrom fcp.demo import http_actions 37*14675a02SAndroid Build Coastguard Workerfrom fcp.demo import media 38*14675a02SAndroid Build Coastguard Workerfrom fcp.protos import plan_pb2 39*14675a02SAndroid Build Coastguard Workerfrom fcp.protos.federatedcompute import aggregations_pb2 40*14675a02SAndroid Build Coastguard Workerfrom fcp.protos.federatedcompute import common_pb2 
from pybind11_abseil import status as absl_status


class AggregationStatus(enum.Enum):
  """Possible states of an aggregation session."""
  COMPLETED = 1
  PENDING = 2
  FAILED = 3
  ABORTED = 4


@dataclasses.dataclass
class SessionStatus:
  """Snapshot of an aggregation session's state and counters."""
  # Current state of the session.
  status: AggregationStatus = AggregationStatus.PENDING
  # Clients that started the aggregation upload protocol and finished it
  # successfully.
  num_clients_completed: int = 0
  # Clients that started the aggregation upload protocol but did not finish it
  # (e.g. they dropped out part-way through).
  num_clients_failed: int = 0
  # Clients that started the aggregation upload protocol and have not finished
  # yet, successfully or otherwise.
  num_clients_pending: int = 0
  # Clients that started the aggregation protocol but were aborted by the
  # server before completing (e.g. because further progress on the session was
  # no longer needed).
  num_clients_aborted: int = 0
  # Inputs that were aggregated and are part of the final aggregate. A client
  # counted in num_clients_completed is not guaranteed to have its report
  # included in the final aggregate yet.
  num_inputs_aggregated_and_included: int = 0
  # Inputs the server has received but not yet included in the final
  # aggregate.
  num_inputs_aggregated_and_pending: int = 0
  # Inputs the server received and then discarded.
  num_inputs_discarded: int = 0


@dataclasses.dataclass(frozen=True)
class AggregationRequirements:
  """Immutable requirements supplied when an aggregation session is created."""
  # Minimum number of clients needed before a result may be released outside
  # this service. Reaching this threshold does NOT stop aggregation
  # automatically; callers decide when to stop.
  minimum_clients_in_server_published_aggregate: int
  # The Plan to execute.
  plan: plan_pb2.Plan


@dataclasses.dataclass
class _ActiveClientData:
  """Bookkeeping for a client that is currently uploading."""
  # Identifier of the client within the aggregation protocol.
  client_id: int
  # Receives the client's final connection status if the aggregation protocol
  # closes the client. At most one value is ever written.
  close_status: queue.SimpleQueue[absl_status.Status]
  # Name of the resource the client should write its update to.
  resource_name: str


@dataclasses.dataclass(eq=False)
class _WaitData:
  """Bookkeeping for one in-flight wait operation."""
  # Threshold that ends the wait.
  num_inputs_aggregated_and_included: Optional[int]
  # Event loop the waiting caller runs on; defaults to the loop that is running
  # when the instance is created.
  loop: asyncio.AbstractEventLoop = dataclasses.field(
      default_factory=asyncio.get_running_loop)
  # Future that receives the SessionStatus when the wait ends.
  status_future: asyncio.Future[SessionStatus] = dataclasses.field(
      default_factory=asyncio.Future)


class _AggregationProtocolCallback(
    aggregation_protocol.AggregationProtocol.Callback):
  """AggregationProtocol.Callback that forwards protocol events to queues."""

  def __init__(self, on_abort: Callable[[], None]):
    """Constructs a new _AggregationProtocolCallback.

    Args:
      on_abort: A callback invoked if/when Abort is called.
    """
    super().__init__()
    # Each accepted client (after AggregationProtocol.AddClients) is published
    # here as (client id, status queue). The status queue acts as a future: it
    # receives exactly one element, the diagnostic status, when the client is
    # closed.
    self.accepted_clients: queue.SimpleQueue[tuple[
        int, queue.SimpleQueue[absl_status.Status]]] = queue.SimpleQueue()
    # Final outcome of the session -- the aggregated tensors or a failure
    # status. Also used as a future; it receives a single element.
    self.result: queue.SimpleQueue[bytes | absl_status.Status] = (
        queue.SimpleQueue())

    self._on_abort = on_abort
    # Guards _client_results.
    self._client_results_lock = threading.Lock()
    # Client id -> queue that will receive that client's close status.
    self._client_results: dict[int, queue.SimpleQueue[absl_status.Status]] = {}

  def OnAcceptClients(self, start_client_id: int, num_clients: int,
                      message: apm_pb2.AcceptanceMessage) -> None:
    """Registers a close-status queue for each newly accepted client."""
    end_client_id = start_client_id + num_clients
    with self._client_results_lock:
      for accepted_id in range(start_client_id, end_client_id):
        close_queue = queue.SimpleQueue()
        self._client_results[accepted_id] = close_queue
        self.accepted_clients.put((accepted_id, close_queue))

  def OnSendServerMessage(self, client_id: int,
                          message: apm_pb2.ServerMessage) -> None:
    """Not supported by this callback; always raises NotImplementedError."""
    raise NotImplementedError()

  def OnCloseClient(self, client_id: int,
                    diagnostic_status: absl_status.Status) -> None:
    """Delivers the close status to the client's queue and forgets the client."""
    with self._client_results_lock:
      close_queue = self._client_results.pop(client_id)
      close_queue.put(diagnostic_status)

  def OnComplete(self, result: bytes) -> None:
    """Publishes the successful aggregation result."""
    self.result.put(result)

  def OnAbort(self, diagnostic_status: absl_status.Status) -> None:
    """Publishes the failure status, then runs the on_abort hook."""
    self.result.put(diagnostic_status)
    self._on_abort()
@dataclasses.dataclass(eq=False)
class _AggregationSessionState:
  """Internal state for an aggregation session."""
  # The session's aggregation requirements.
  requirements: AggregationRequirements
  # The AggregationProtocol.Callback object receiving protocol events.
  callback: _AggregationProtocolCallback
  # The protocol performing the aggregation. Service._sessions_lock should not
  # be held while AggregationProtocol methods are invoked -- both because
  # methods may be slow and because callbacks may also need to acquire the lock.
  agg_protocol: aggregation_protocol.AggregationProtocol
  # The current status of the session.
  status: AggregationStatus = AggregationStatus.PENDING
  # Unredeemed client authorization tokens.
  authorization_tokens: set[str] = dataclasses.field(default_factory=set)
  # Information about active clients, keyed by authorization token
  active_clients: dict[str, _ActiveClientData] = dataclasses.field(
      default_factory=dict)
  # Information for in-progress wait calls on this session.
  pending_waits: set[_WaitData] = dataclasses.field(default_factory=set)


class Service:
  """Implements the Aggregations service."""

  def __init__(self, forwarding_info: Callable[[], common_pb2.ForwardingInfo],
               media_service: media.Service):
    """Constructs a new Service.

    Args:
      forwarding_info: Callable returning the ForwardingInfo to hand to clients.
      media_service: The media service used to register and finalize uploads.
    """
    self._forwarding_info = forwarding_info
    self._media_service = media_service
    # Active sessions keyed by session id; guarded by _sessions_lock.
    self._sessions: dict[str, _AggregationSessionState] = {}
    self._sessions_lock = threading.Lock()

  def create_session(self,
                     aggregation_requirements: AggregationRequirements) -> str:
    """Creates a new aggregation session and returns its id.

    Args:
      aggregation_requirements: The requirements (including the Plan) for the
        new session.

    Returns:
      The id of the newly created session.

    Raises:
      ValueError: If the plan does not contain exactly one phase with a
        server_phase_v2.
    """
    # Validate the plan up front, before any session objects are created.
    plan = aggregation_requirements.plan
    if len(plan.phase) != 1 or not plan.phase[0].HasField('server_phase_v2'):
      raise ValueError('Plan must contain exactly one server_phase_v2.')

    session_id = str(uuid.uuid4())
    callback = _AggregationProtocolCallback(
        functools.partial(self._handle_protocol_abort, session_id))

    # NOTE: For simplicity, this implementation only creates a single,
    # in-process aggregation shard. In a production implementation, there should
    # be multiple shards running on separate servers to enable high rates of
    # client contributions. Utilities for combining results from separate shards
    # are still in development as of Jan 2023.
    agg_protocol = aggregation_protocols.create_simple_aggregation_protocol(
        configuration_pb2.Configuration(aggregation_configs=[
            self._translate_server_aggregation_config(aggregation_config)
            for aggregation_config in plan.phase[0].server_phase_v2.aggregations
        ]), callback)
    agg_protocol.Start(0)

    with self._sessions_lock:
      self._sessions[session_id] = _AggregationSessionState(
          requirements=aggregation_requirements,
          callback=callback,
          agg_protocol=agg_protocol)
    return session_id

  def complete_session(
      self, session_id: str) -> tuple[SessionStatus, Optional[bytes]]:
    """Completes the aggregation session and returns its results.

    The session is removed from the service regardless of the outcome.

    Args:
      session_id: The id of the session to complete.

    Returns:
      A (status, result) tuple. `result` is None unless the session completed
      successfully.

    Raises:
      KeyError: If no session with the given id exists.
    """
    with self._sessions_lock:
      state = self._sessions.pop(session_id)

    try:
      # Only complete the AggregationProtocol if it's still pending. The most
      # likely alternative is that it's ABORTED due to an error generated by the
      # protocol itself.
      status = self._get_session_status(state)
      if status.status != AggregationStatus.PENDING:
        return status, None

      # Ensure privacy requirements have been met.
      if (state.agg_protocol.GetStatus().num_inputs_aggregated_and_included <
          state.requirements.minimum_clients_in_server_published_aggregate):
        state.agg_protocol.Abort()
        raise ValueError(
            'minimum_clients_in_server_published_aggregate has not been met.')

      state.agg_protocol.Complete()
      # The callback publishes either the aggregate bytes or a failure status.
      result = state.callback.result.get(timeout=1)
      if isinstance(result, absl_status.Status):
        raise absl_status.StatusNotOk(result)
      state.status = AggregationStatus.COMPLETED
      return self._get_session_status(state), result
    except (ValueError, absl_status.StatusNotOk, queue.Empty) as e:
      logging.warning('Failed to complete aggregation session: %s', e)
      state.status = AggregationStatus.FAILED
      return self._get_session_status(state), None
    finally:
      self._cleanup_session(state)

  def abort_session(self, session_id: str) -> SessionStatus:
    """Aborts/cancels an aggregation session.

    Args:
      session_id: The id of the session to abort.

    Returns:
      The status of the session after it has been cleaned up.

    Raises:
      KeyError: If no session with the given id exists.
    """
    with self._sessions_lock:
      state = self._sessions.pop(session_id)

    # Only abort the AggregationProtocol if it's still pending. The most likely
    # alternative is that it's ABORTED due to an error generated by the protocol
    # itself.
    if state.status == AggregationStatus.PENDING:
      state.status = AggregationStatus.ABORTED
      state.agg_protocol.Abort()

    self._cleanup_session(state)
    return self._get_session_status(state)

  def get_session_status(self, session_id: str) -> SessionStatus:
    """Returns the status of an aggregation session.

    Raises:
      KeyError: If no session with the given id exists.
    """
    with self._sessions_lock:
      return self._get_session_status(self._sessions[session_id])

  async def wait(
      self,
      session_id: str,
      num_inputs_aggregated_and_included: Optional[int] = None
  ) -> SessionStatus:
    """Blocks until all conditions are satisfied or the aggregation fails.

    Returns immediately if no condition is given, if the condition is already
    satisfied, or if the session is no longer pending.

    Args:
      session_id: The id of the session to wait on.
      num_inputs_aggregated_and_included: If set, the number of included inputs
        to wait for.

    Returns:
      The SessionStatus at the time the wait ended.

    Raises:
      KeyError: If no session with the given id exists.
    """
    with self._sessions_lock:
      state = self._sessions[session_id]
      status = self._get_session_status(state)
      # A session that already failed stays registered until it is completed or
      # aborted; its earlier waiters were notified when it failed, so a new
      # waiter must not be queued behind a condition that may never be met.
      if status.status != AggregationStatus.PENDING:
        return status
      # Check if any of the conditions are already satisfied.
      if (num_inputs_aggregated_and_included is None or
          num_inputs_aggregated_and_included <=
          status.num_inputs_aggregated_and_included):
        return status

      wait_data = _WaitData(num_inputs_aggregated_and_included)
      state.pending_waits.add(wait_data)
    # Await outside the lock so other threads can make progress and resolve the
    # future.
    return await wait_data.status_future

  def pre_authorize_clients(self, session_id: str,
                            num_tokens: int) -> Sequence[str]:
    """Generates tokens authorizing clients to contribute to the session.

    Args:
      session_id: The id of the session the tokens are valid for.
      num_tokens: The number of tokens to generate.

    Returns:
      The newly generated authorization tokens.

    Raises:
      KeyError: If no session with the given id exists.
    """
    tokens = {str(uuid.uuid4()) for _ in range(num_tokens)}
    with self._sessions_lock:
      self._sessions[session_id].authorization_tokens |= tokens
    return list(tokens)

  def _translate_intrinsic_arg(
      self, intrinsic_arg: plan_pb2.ServerAggregationConfig.IntrinsicArg
  ) -> configuration_pb2.Configuration.ServerAggregationConfig.IntrinsicArg:
    """Transform an aggregation intrinsic arg for the aggregation service.

    Raises:
      ValueError: If the arg is a state_tensor, which is unsupported here.
      AssertionError: If the arg is neither an input_tensor nor a state_tensor.
    """
    if intrinsic_arg.HasField('input_tensor'):
      return configuration_pb2.Configuration.ServerAggregationConfig.IntrinsicArg(
          input_tensor=intrinsic_arg.input_tensor)
    elif intrinsic_arg.HasField('state_tensor'):
      raise ValueError(
          'Non-client intrinsic args are not supported in this demo.'
      )
    else:
      raise AssertionError(
          'Cases should have exhausted all possible types of intrinsic args.')

  def _translate_server_aggregation_config(
      self, plan_aggregation_config: plan_pb2.ServerAggregationConfig
  ) -> configuration_pb2.Configuration.ServerAggregationConfig:
    """Transform the aggregation config for use by the aggregation service.

    Raises:
      AssertionError: If the config contains inner (nested) aggregations.
    """
    if plan_aggregation_config.inner_aggregations:
      raise AssertionError('Nested intrinsic structrues are not supported yet.')
    return configuration_pb2.Configuration.ServerAggregationConfig(
        intrinsic_uri=plan_aggregation_config.intrinsic_uri,
        intrinsic_args=[
            self._translate_intrinsic_arg(intrinsic_arg)
            for intrinsic_arg in plan_aggregation_config.intrinsic_args
        ],
        output_tensors=plan_aggregation_config.output_tensors)
341*14675a02SAndroid Build Coastguard Worker 342*14675a02SAndroid Build Coastguard Worker def _get_session_status(self, 343*14675a02SAndroid Build Coastguard Worker state: _AggregationSessionState) -> SessionStatus: 344*14675a02SAndroid Build Coastguard Worker """Returns the SessionStatus for an _AggregationSessionState object.""" 345*14675a02SAndroid Build Coastguard Worker status = state.agg_protocol.GetStatus() 346*14675a02SAndroid Build Coastguard Worker return SessionStatus( 347*14675a02SAndroid Build Coastguard Worker status=state.status, 348*14675a02SAndroid Build Coastguard Worker num_clients_completed=status.num_clients_completed, 349*14675a02SAndroid Build Coastguard Worker num_clients_failed=status.num_clients_failed, 350*14675a02SAndroid Build Coastguard Worker num_clients_pending=status.num_clients_pending, 351*14675a02SAndroid Build Coastguard Worker num_clients_aborted=status.num_clients_aborted, 352*14675a02SAndroid Build Coastguard Worker num_inputs_aggregated_and_included=( 353*14675a02SAndroid Build Coastguard Worker status.num_inputs_aggregated_and_included), 354*14675a02SAndroid Build Coastguard Worker num_inputs_aggregated_and_pending=( 355*14675a02SAndroid Build Coastguard Worker status.num_inputs_aggregated_and_pending), 356*14675a02SAndroid Build Coastguard Worker num_inputs_discarded=status.num_inputs_discarded) 357*14675a02SAndroid Build Coastguard Worker 358*14675a02SAndroid Build Coastguard Worker def _get_http_status(self, code: absl_status.StatusCode) -> http.HTTPStatus: 359*14675a02SAndroid Build Coastguard Worker """Returns the HTTPStatus code for an absl StatusCode.""" 360*14675a02SAndroid Build Coastguard Worker if (code == absl_status.StatusCode.INVALID_ARGUMENT or 361*14675a02SAndroid Build Coastguard Worker code == absl_status.StatusCode.FAILED_PRECONDITION): 362*14675a02SAndroid Build Coastguard Worker return http.HTTPStatus.BAD_REQUEST 363*14675a02SAndroid Build Coastguard Worker elif code == absl_status.StatusCode.NOT_FOUND: 
364*14675a02SAndroid Build Coastguard Worker return http.HTTPStatus.NOT_FOUND 365*14675a02SAndroid Build Coastguard Worker else: 366*14675a02SAndroid Build Coastguard Worker return http.HTTPStatus.INTERNAL_SERVER_ERROR 367*14675a02SAndroid Build Coastguard Worker 368*14675a02SAndroid Build Coastguard Worker def _cleanup_session(self, state: _AggregationSessionState) -> None: 369*14675a02SAndroid Build Coastguard Worker """Cleans up the session and releases any resources. 370*14675a02SAndroid Build Coastguard Worker 371*14675a02SAndroid Build Coastguard Worker Args: 372*14675a02SAndroid Build Coastguard Worker state: The session state to clean up. 373*14675a02SAndroid Build Coastguard Worker """ 374*14675a02SAndroid Build Coastguard Worker state.authorization_tokens.clear() 375*14675a02SAndroid Build Coastguard Worker for client_data in state.active_clients.values(): 376*14675a02SAndroid Build Coastguard Worker self._media_service.finalize_upload(client_data.resource_name) 377*14675a02SAndroid Build Coastguard Worker state.active_clients.clear() 378*14675a02SAndroid Build Coastguard Worker # Anyone waiting on the session should be notified that it's finished. 
379*14675a02SAndroid Build Coastguard Worker if state.pending_waits: 380*14675a02SAndroid Build Coastguard Worker status = self._get_session_status(state) 381*14675a02SAndroid Build Coastguard Worker for data in state.pending_waits: 382*14675a02SAndroid Build Coastguard Worker data.loop.call_soon_threadsafe( 383*14675a02SAndroid Build Coastguard Worker functools.partial(data.status_future.set_result, status)) 384*14675a02SAndroid Build Coastguard Worker state.pending_waits.clear() 385*14675a02SAndroid Build Coastguard Worker 386*14675a02SAndroid Build Coastguard Worker def _handle_protocol_abort(self, session_id: str) -> None: 387*14675a02SAndroid Build Coastguard Worker """Notifies waiting clients when the protocol is aborted.""" 388*14675a02SAndroid Build Coastguard Worker with self._sessions_lock: 389*14675a02SAndroid Build Coastguard Worker with contextlib.suppress(KeyError): 390*14675a02SAndroid Build Coastguard Worker state = self._sessions[session_id] 391*14675a02SAndroid Build Coastguard Worker state.status = AggregationStatus.FAILED 392*14675a02SAndroid Build Coastguard Worker # Anyone waiting on the session should be notified it's been aborted. 
393*14675a02SAndroid Build Coastguard Worker if state.pending_waits: 394*14675a02SAndroid Build Coastguard Worker status = self._get_session_status(state) 395*14675a02SAndroid Build Coastguard Worker for data in state.pending_waits: 396*14675a02SAndroid Build Coastguard Worker data.loop.call_soon_threadsafe( 397*14675a02SAndroid Build Coastguard Worker functools.partial(data.status_future.set_result, status)) 398*14675a02SAndroid Build Coastguard Worker state.pending_waits.clear() 399*14675a02SAndroid Build Coastguard Worker 400*14675a02SAndroid Build Coastguard Worker @http_actions.proto_action( 401*14675a02SAndroid Build Coastguard Worker service='google.internal.federatedcompute.v1.Aggregations', 402*14675a02SAndroid Build Coastguard Worker method='StartAggregationDataUpload') 403*14675a02SAndroid Build Coastguard Worker def start_aggregation_data_upload( 404*14675a02SAndroid Build Coastguard Worker self, request: aggregations_pb2.StartAggregationDataUploadRequest 405*14675a02SAndroid Build Coastguard Worker ) -> operations_pb2.Operation: 406*14675a02SAndroid Build Coastguard Worker """Handles a StartAggregationDataUpload request.""" 407*14675a02SAndroid Build Coastguard Worker with self._sessions_lock: 408*14675a02SAndroid Build Coastguard Worker try: 409*14675a02SAndroid Build Coastguard Worker state = self._sessions[request.aggregation_id] 410*14675a02SAndroid Build Coastguard Worker except KeyError as e: 411*14675a02SAndroid Build Coastguard Worker raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND) from e 412*14675a02SAndroid Build Coastguard Worker try: 413*14675a02SAndroid Build Coastguard Worker state.authorization_tokens.remove(request.authorization_token) 414*14675a02SAndroid Build Coastguard Worker except KeyError as e: 415*14675a02SAndroid Build Coastguard Worker raise http_actions.HttpError(http.HTTPStatus.UNAUTHORIZED) from e 416*14675a02SAndroid Build Coastguard Worker 417*14675a02SAndroid Build Coastguard Worker 
state.agg_protocol.AddClients(1) 418*14675a02SAndroid Build Coastguard Worker client_token = str(uuid.uuid4()) 419*14675a02SAndroid Build Coastguard Worker client_id, close_status = state.callback.accepted_clients.get(timeout=1) 420*14675a02SAndroid Build Coastguard Worker upload_name = self._media_service.register_upload() 421*14675a02SAndroid Build Coastguard Worker 422*14675a02SAndroid Build Coastguard Worker with self._sessions_lock: 423*14675a02SAndroid Build Coastguard Worker state.active_clients[client_token] = _ActiveClientData( 424*14675a02SAndroid Build Coastguard Worker client_id, close_status, upload_name) 425*14675a02SAndroid Build Coastguard Worker 426*14675a02SAndroid Build Coastguard Worker forwarding_info = self._forwarding_info() 427*14675a02SAndroid Build Coastguard Worker response = aggregations_pb2.StartAggregationDataUploadResponse( 428*14675a02SAndroid Build Coastguard Worker aggregation_protocol_forwarding_info=forwarding_info, 429*14675a02SAndroid Build Coastguard Worker resource=common_pb2.ByteStreamResource( 430*14675a02SAndroid Build Coastguard Worker data_upload_forwarding_info=forwarding_info, 431*14675a02SAndroid Build Coastguard Worker resource_name=upload_name), 432*14675a02SAndroid Build Coastguard Worker client_token=client_token) 433*14675a02SAndroid Build Coastguard Worker 434*14675a02SAndroid Build Coastguard Worker op = operations_pb2.Operation(name=f'operations/{uuid.uuid4()}', done=True) 435*14675a02SAndroid Build Coastguard Worker op.metadata.Pack(aggregations_pb2.StartAggregationDataUploadMetadata()) 436*14675a02SAndroid Build Coastguard Worker op.response.Pack(response) 437*14675a02SAndroid Build Coastguard Worker return op 438*14675a02SAndroid Build Coastguard Worker 439*14675a02SAndroid Build Coastguard Worker @http_actions.proto_action( 440*14675a02SAndroid Build Coastguard Worker service='google.internal.federatedcompute.v1.Aggregations', 441*14675a02SAndroid Build Coastguard Worker method='SubmitAggregationResult') 
442*14675a02SAndroid Build Coastguard Worker def submit_aggregation_result( 443*14675a02SAndroid Build Coastguard Worker self, request: aggregations_pb2.SubmitAggregationResultRequest 444*14675a02SAndroid Build Coastguard Worker ) -> aggregations_pb2.SubmitAggregationResultResponse: 445*14675a02SAndroid Build Coastguard Worker """Handles a SubmitAggregationResult request.""" 446*14675a02SAndroid Build Coastguard Worker with self._sessions_lock: 447*14675a02SAndroid Build Coastguard Worker try: 448*14675a02SAndroid Build Coastguard Worker state = self._sessions[request.aggregation_id] 449*14675a02SAndroid Build Coastguard Worker except KeyError as e: 450*14675a02SAndroid Build Coastguard Worker raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND) from e 451*14675a02SAndroid Build Coastguard Worker try: 452*14675a02SAndroid Build Coastguard Worker client_data = state.active_clients.pop(request.client_token) 453*14675a02SAndroid Build Coastguard Worker except KeyError as e: 454*14675a02SAndroid Build Coastguard Worker raise http_actions.HttpError(http.HTTPStatus.UNAUTHORIZED) from e 455*14675a02SAndroid Build Coastguard Worker 456*14675a02SAndroid Build Coastguard Worker # Ensure the client is using the resource name provided when they called 457*14675a02SAndroid Build Coastguard Worker # StartAggregationDataUpload. 458*14675a02SAndroid Build Coastguard Worker if request.resource_name != client_data.resource_name: 459*14675a02SAndroid Build Coastguard Worker raise http_actions.HttpError(http.HTTPStatus.BAD_REQUEST) 460*14675a02SAndroid Build Coastguard Worker 461*14675a02SAndroid Build Coastguard Worker # The aggregation protocol may have already closed the connect (e.g., if 462*14675a02SAndroid Build Coastguard Worker # an error occurred). If so, clean up the upload and return the error. 
463*14675a02SAndroid Build Coastguard Worker if not client_data.close_status.empty(): 464*14675a02SAndroid Build Coastguard Worker with contextlib.suppress(KeyError): 465*14675a02SAndroid Build Coastguard Worker self._media_service.finalize_upload(request.resource_name) 466*14675a02SAndroid Build Coastguard Worker raise http_actions.HttpError( 467*14675a02SAndroid Build Coastguard Worker self._get_http_status(client_data.close_status.get().code())) 468*14675a02SAndroid Build Coastguard Worker 469*14675a02SAndroid Build Coastguard Worker # Finalize the upload. 470*14675a02SAndroid Build Coastguard Worker try: 471*14675a02SAndroid Build Coastguard Worker update = self._media_service.finalize_upload(request.resource_name) 472*14675a02SAndroid Build Coastguard Worker if update is None: 473*14675a02SAndroid Build Coastguard Worker raise absl_status.StatusNotOk( 474*14675a02SAndroid Build Coastguard Worker absl_status.invalid_argument_error( 475*14675a02SAndroid Build Coastguard Worker 'Aggregation result never uploaded')) 476*14675a02SAndroid Build Coastguard Worker except (KeyError, absl_status.StatusNotOk) as e: 477*14675a02SAndroid Build Coastguard Worker if isinstance(e, KeyError): 478*14675a02SAndroid Build Coastguard Worker e = absl_status.StatusNotOk( 479*14675a02SAndroid Build Coastguard Worker absl_status.internal_error('Failed to finalize upload')) 480*14675a02SAndroid Build Coastguard Worker state.agg_protocol.CloseClient(client_data.client_id, e.status) 481*14675a02SAndroid Build Coastguard Worker # Since we're initiating the close, it's also necessary to notify the 482*14675a02SAndroid Build Coastguard Worker # _AggregationProtocolCallback so it can clean up resources. 
483*14675a02SAndroid Build Coastguard Worker state.callback.OnCloseClient(client_data.client_id, e.status) 484*14675a02SAndroid Build Coastguard Worker raise http_actions.HttpError(self._get_http_status( 485*14675a02SAndroid Build Coastguard Worker e.status.code())) from e 486*14675a02SAndroid Build Coastguard Worker 487*14675a02SAndroid Build Coastguard Worker client_message = apm_pb2.ClientMessage( 488*14675a02SAndroid Build Coastguard Worker simple_aggregation=apm_pb2.ClientMessage.SimpleAggregation( 489*14675a02SAndroid Build Coastguard Worker input=apm_pb2.ClientResource(inline_bytes=update))) 490*14675a02SAndroid Build Coastguard Worker try: 491*14675a02SAndroid Build Coastguard Worker state.agg_protocol.ReceiveClientMessage(client_data.client_id, 492*14675a02SAndroid Build Coastguard Worker client_message) 493*14675a02SAndroid Build Coastguard Worker except absl_status.StatusNotOk as e: 494*14675a02SAndroid Build Coastguard Worker # ReceiveClientInput should only fail if the AggregationProtocol is in a 495*14675a02SAndroid Build Coastguard Worker # bad state -- likely leading to it being aborted. 496*14675a02SAndroid Build Coastguard Worker logging.warning('Failed to receive client input: %s', e) 497*14675a02SAndroid Build Coastguard Worker raise http_actions.HttpError(http.HTTPStatus.INTERNAL_SERVER_ERROR) from e 498*14675a02SAndroid Build Coastguard Worker 499*14675a02SAndroid Build Coastguard Worker # Wait for the client input to be processed. 500*14675a02SAndroid Build Coastguard Worker close_status = client_data.close_status.get() 501*14675a02SAndroid Build Coastguard Worker if not close_status.ok(): 502*14675a02SAndroid Build Coastguard Worker raise http_actions.HttpError(self._get_http_status(close_status.code())) 503*14675a02SAndroid Build Coastguard Worker 504*14675a02SAndroid Build Coastguard Worker # Check for any newly-satisfied pending wait operations. 
505*14675a02SAndroid Build Coastguard Worker with self._sessions_lock: 506*14675a02SAndroid Build Coastguard Worker if state.pending_waits: 507*14675a02SAndroid Build Coastguard Worker completed_waits = set() 508*14675a02SAndroid Build Coastguard Worker status = self._get_session_status(state) 509*14675a02SAndroid Build Coastguard Worker for data in state.pending_waits: 510*14675a02SAndroid Build Coastguard Worker if (data.num_inputs_aggregated_and_included is not None and 511*14675a02SAndroid Build Coastguard Worker status.num_inputs_aggregated_and_included >= 512*14675a02SAndroid Build Coastguard Worker data.num_inputs_aggregated_and_included): 513*14675a02SAndroid Build Coastguard Worker data.loop.call_soon_threadsafe( 514*14675a02SAndroid Build Coastguard Worker functools.partial(data.status_future.set_result, status)) 515*14675a02SAndroid Build Coastguard Worker completed_waits.add(data) 516*14675a02SAndroid Build Coastguard Worker state.pending_waits -= completed_waits 517*14675a02SAndroid Build Coastguard Worker return aggregations_pb2.SubmitAggregationResultResponse() 518*14675a02SAndroid Build Coastguard Worker 519*14675a02SAndroid Build Coastguard Worker @http_actions.proto_action( 520*14675a02SAndroid Build Coastguard Worker service='google.internal.federatedcompute.v1.Aggregations', 521*14675a02SAndroid Build Coastguard Worker method='AbortAggregation') 522*14675a02SAndroid Build Coastguard Worker def abort_aggregation( 523*14675a02SAndroid Build Coastguard Worker self, request: aggregations_pb2.AbortAggregationRequest 524*14675a02SAndroid Build Coastguard Worker ) -> aggregations_pb2.AbortAggregationResponse: 525*14675a02SAndroid Build Coastguard Worker """Handles an AbortAggregation request.""" 526*14675a02SAndroid Build Coastguard Worker with self._sessions_lock: 527*14675a02SAndroid Build Coastguard Worker try: 528*14675a02SAndroid Build Coastguard Worker state = self._sessions[request.aggregation_id] 529*14675a02SAndroid Build Coastguard Worker 
except KeyError as e: 530*14675a02SAndroid Build Coastguard Worker raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND) from e 531*14675a02SAndroid Build Coastguard Worker try: 532*14675a02SAndroid Build Coastguard Worker client_data = state.active_clients.pop(request.client_token) 533*14675a02SAndroid Build Coastguard Worker except KeyError as e: 534*14675a02SAndroid Build Coastguard Worker raise http_actions.HttpError(http.HTTPStatus.UNAUTHORIZED) from e 535*14675a02SAndroid Build Coastguard Worker 536*14675a02SAndroid Build Coastguard Worker # Attempt to finalize the in-progress upload to free up resources. 537*14675a02SAndroid Build Coastguard Worker with contextlib.suppress(KeyError): 538*14675a02SAndroid Build Coastguard Worker self._media_service.finalize_upload(client_data.resource_name) 539*14675a02SAndroid Build Coastguard Worker 540*14675a02SAndroid Build Coastguard Worker # Notify the aggregation protocol that the client has left. 541*14675a02SAndroid Build Coastguard Worker if request.status.code == code_pb2.Code.OK: 542*14675a02SAndroid Build Coastguard Worker status = absl_status.Status.OkStatus() 543*14675a02SAndroid Build Coastguard Worker else: 544*14675a02SAndroid Build Coastguard Worker status = absl_status.BuildStatusNotOk( 545*14675a02SAndroid Build Coastguard Worker absl_status.StatusCodeFromInt(request.status.code), 546*14675a02SAndroid Build Coastguard Worker request.status.message) 547*14675a02SAndroid Build Coastguard Worker state.agg_protocol.CloseClient(client_data.client_id, status) 548*14675a02SAndroid Build Coastguard Worker # Since we're initiating the close, it's also necessary to notify the 549*14675a02SAndroid Build Coastguard Worker # _AggregationProtocolCallback so it can clean up resources. 
550*14675a02SAndroid Build Coastguard Worker state.callback.OnCloseClient(client_data.client_id, status) 551*14675a02SAndroid Build Coastguard Worker 552*14675a02SAndroid Build Coastguard Worker logging.debug('[%s] AbortAggregation: %s', request.aggregation_id, 553*14675a02SAndroid Build Coastguard Worker request.status) 554*14675a02SAndroid Build Coastguard Worker return aggregations_pb2.AbortAggregationResponse() 555