# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Action handlers for the Aggregations service."""

import asyncio
from collections.abc import Callable, Sequence
import contextlib
import dataclasses
import enum
import functools
import http
import queue
import threading
from typing import Optional
import uuid

from absl import logging

from google.longrunning import operations_pb2
from google.rpc import code_pb2
from fcp.aggregation.protocol import aggregation_protocol_messages_pb2 as apm_pb2
from fcp.aggregation.protocol import configuration_pb2
from fcp.aggregation.protocol.python import aggregation_protocol
from fcp.aggregation.tensorflow.python import aggregation_protocols
from fcp.demo import http_actions
from fcp.demo import media
from fcp.protos import plan_pb2
from fcp.protos.federatedcompute import aggregations_pb2
from fcp.protos.federatedcompute import common_pb2
from pybind11_abseil import status as absl_status


class AggregationStatus(enum.Enum):
  """The overall state of an aggregation session."""
  COMPLETED = 1
  PENDING = 2
  FAILED = 3
  ABORTED = 4


@dataclasses.dataclass
class SessionStatus:
  """The status of an aggregation session."""
  # The current state of the aggregation session.
  status: AggregationStatus = AggregationStatus.PENDING
  # Number of clients that successfully started and completed the aggregation
  # upload protocol.
  num_clients_completed: int = 0
  # Number of clients that started the aggregation upload protocol but failed
  # to complete (e.g dropped out in the middle of the protocol).
  num_clients_failed: int = 0
  # Number of clients that started the aggregation upload protocol but have not
  # yet finished (either successfully or not).
  num_clients_pending: int = 0
  # Number of clients that started the aggregation protocol but were aborted by
  # the server before they could complete (e.g., if progress on the session was
  # no longer needed).
  num_clients_aborted: int = 0
  # Number of inputs that were successfully aggregated and included in the
  # final aggregate. Note that even if a client successfully completes the
  # protocol (i.e., it is included in num_clients_completed), it is not
  # guaranteed that the uploaded report is included in the final aggregate yet.
  num_inputs_aggregated_and_included: int = 0
  # Number of inputs that were received by the server and are pending (i.e.,
  # the inputs have not been included in the final aggregate yet).
  num_inputs_aggregated_and_pending: int = 0
  # Number of inputs that were received by the server but discarded.
  num_inputs_discarded: int = 0


@dataclasses.dataclass(frozen=True)
class AggregationRequirements:
  """Requirements that must be satisfied before results are released."""
  # The minimum number of clients required before a result can be released
  # outside this service. Note that aggregation does not automatically stop if
  # minimum_clients_in_server_published_aggregate is met. It is up to callers
  # to stop aggregation when they want.
  minimum_clients_in_server_published_aggregate: int
  # The Plan to execute.
  plan: plan_pb2.Plan


@dataclasses.dataclass
class _ActiveClientData:
  """Information about an active client."""
  # The client's identifier in the aggregation protocol.
  client_id: int
  # Queue receiving the final status of the client connection (if closed by the
  # aggregation protocol). At most one value will be written.
  close_status: queue.SimpleQueue[absl_status.Status]
  # The name of the resource to which the client should write its update.
  resource_name: str


@dataclasses.dataclass(eq=False)
class _WaitData:
  """Information about a pending wait operation."""
  # The condition under which the wait should complete.
  num_inputs_aggregated_and_included: Optional[int]
  # The loop the caller is waiting on.
  loop: asyncio.AbstractEventLoop = dataclasses.field(
      default_factory=asyncio.get_running_loop)
  # The future to which the SessionStatus will be written once the wait is
  # over.
  status_future: asyncio.Future[SessionStatus] = dataclasses.field(
      default_factory=asyncio.Future)


class _AggregationProtocolCallback(
    aggregation_protocol.AggregationProtocol.Callback):
  """AggregationProtocol.Callback that writes events to queues."""

  def __init__(self, on_abort: Callable[[], None]):
    """Constructs a new _AggregationProtocolCallback.

    Args:
      on_abort: A callback invoked if/when Abort is called.
    """
    super().__init__()
    # When a client is accepted after calling AggregationProtocol.AddClients,
    # this queue receives the new client's id as well as a queue that will
    # provide the diagnostic status when the client is closed. (The status
    # queue is being used as a future and will only receive one element.)
    self.accepted_clients: queue.SimpleQueue[tuple[
        int, queue.SimpleQueue[absl_status.Status]]] = queue.SimpleQueue()
    # A queue receiving the final result of the aggregation session: either the
    # aggregated tensors or a failure status. This queue is being used as a
    # future and will only receive one element.
    self.result: queue.SimpleQueue[bytes | absl_status.Status] = (
        queue.SimpleQueue())

    self._on_abort = on_abort
    self._client_results_lock = threading.Lock()
    # A map from client id to the queue for each client's close status.
    self._client_results: dict[int, queue.SimpleQueue[absl_status.Status]] = {}

  def OnAcceptClients(self, start_client_id: int, num_clients: int,
                      message: apm_pb2.AcceptanceMessage) -> None:
    with self._client_results_lock:
      for client_id in range(start_client_id, start_client_id + num_clients):
        q = queue.SimpleQueue()
        self._client_results[client_id] = q
        self.accepted_clients.put((client_id, q))

  def OnSendServerMessage(self, client_id: int,
                          message: apm_pb2.ServerMessage) -> None:
    raise NotImplementedError()

  def OnCloseClient(self, client_id: int,
                    diagnostic_status: absl_status.Status) -> None:
    with self._client_results_lock:
      self._client_results.pop(client_id).put(diagnostic_status)

  def OnComplete(self, result: bytes) -> None:
    self.result.put(result)

  def OnAbort(self, diagnostic_status: absl_status.Status) -> None:
    self.result.put(diagnostic_status)
    self._on_abort()


@dataclasses.dataclass(eq=False)
class _AggregationSessionState:
  """Internal state for an aggregation session."""
  # The session's aggregation requirements.
  requirements: AggregationRequirements
  # The AggregationProtocol.Callback object receiving protocol events.
  callback: _AggregationProtocolCallback
  # The protocol performing the aggregation. Service._sessions_lock should not
  # be held while AggregationProtocol methods are invoked -- both because
  # methods may be slow and because callbacks may also need to acquire the
  # lock.
  agg_protocol: aggregation_protocol.AggregationProtocol
  # The current status of the session.
  status: AggregationStatus = AggregationStatus.PENDING
  # Unredeemed client authorization tokens.
  authorization_tokens: set[str] = dataclasses.field(default_factory=set)
  # Information about active clients, keyed by authorization token.
  active_clients: dict[str, _ActiveClientData] = dataclasses.field(
      default_factory=dict)
  # Information for in-progress wait calls on this session.
  pending_waits: set[_WaitData] = dataclasses.field(default_factory=set)


class Service:
  """Implements the Aggregations service."""

  def __init__(self, forwarding_info: Callable[[], common_pb2.ForwardingInfo],
               media_service: media.Service):
    self._forwarding_info = forwarding_info
    self._media_service = media_service
    self._sessions: dict[str, _AggregationSessionState] = {}
    self._sessions_lock = threading.Lock()

  def create_session(self,
                     aggregation_requirements: AggregationRequirements) -> str:
    """Creates a new aggregation session and returns its id."""
    session_id = str(uuid.uuid4())
    callback = _AggregationProtocolCallback(
        functools.partial(self._handle_protocol_abort, session_id))
    if (len(aggregation_requirements.plan.phase) != 1 or
        not aggregation_requirements.plan.phase[0].HasField('server_phase_v2')):
      raise ValueError('Plan must contain exactly one server_phase_v2.')

    # NOTE: For simplicity, this implementation only creates a single,
    # in-process aggregation shard. In a production implementation, there
    # should be multiple shards running on separate servers to enable high
    # rates of client contributions. Utilities for combining results from
    # separate shards are still in development as of Jan 2023.
    agg_protocol = aggregation_protocols.create_simple_aggregation_protocol(
        configuration_pb2.Configuration(aggregation_configs=[
            self._translate_server_aggregation_config(aggregation_config)
            for aggregation_config in
            aggregation_requirements.plan.phase[0].server_phase_v2.aggregations
        ]), callback)
    agg_protocol.Start(0)

    with self._sessions_lock:
      self._sessions[session_id] = _AggregationSessionState(
          requirements=aggregation_requirements,
          callback=callback,
          agg_protocol=agg_protocol)
    return session_id

  def complete_session(
      self, session_id: str) -> tuple[SessionStatus, Optional[bytes]]:
    """Completes the aggregation session and returns its results."""
    with self._sessions_lock:
      state = self._sessions.pop(session_id)

    try:
      # Only complete the AggregationProtocol if it's still pending. The most
      # likely alternative is that it's ABORTED due to an error generated by
      # the protocol itself.
      status = self._get_session_status(state)
      if status.status != AggregationStatus.PENDING:
        return status, None

      # Ensure privacy requirements have been met.
      if (state.agg_protocol.GetStatus().num_inputs_aggregated_and_included <
          state.requirements.minimum_clients_in_server_published_aggregate):
        state.agg_protocol.Abort()
        raise ValueError(
            'minimum_clients_in_server_published_aggregate has not been met.')

      state.agg_protocol.Complete()
      result = state.callback.result.get(timeout=1)
      if isinstance(result, absl_status.Status):
        raise absl_status.StatusNotOk(result)
      state.status = AggregationStatus.COMPLETED
      return self._get_session_status(state), result
    except (ValueError, absl_status.StatusNotOk, queue.Empty) as e:
      logging.warning('Failed to complete aggregation session: %s', e)
      state.status = AggregationStatus.FAILED
      return self._get_session_status(state), None
    finally:
      self._cleanup_session(state)

  def abort_session(self, session_id: str) -> SessionStatus:
    """Aborts/cancels an aggregation session."""
    with self._sessions_lock:
      state = self._sessions.pop(session_id)

    # Only abort the AggregationProtocol if it's still pending. The most likely
    # alternative is that it's ABORTED due to an error generated by the
    # protocol itself.
    if state.status == AggregationStatus.PENDING:
      state.status = AggregationStatus.ABORTED
      state.agg_protocol.Abort()

    self._cleanup_session(state)
    return self._get_session_status(state)

  def get_session_status(self, session_id: str) -> SessionStatus:
    """Returns the status of an aggregation session."""
    with self._sessions_lock:
      return self._get_session_status(self._sessions[session_id])

  async def wait(
      self,
      session_id: str,
      num_inputs_aggregated_and_included: Optional[int] = None
  ) -> SessionStatus:
    """Blocks until all conditions are satisfied or the aggregation fails."""
    with self._sessions_lock:
      state = self._sessions[session_id]
      # Check if any of the conditions are already satisfied.
      status = self._get_session_status(state)
      if (num_inputs_aggregated_and_included is None or
          num_inputs_aggregated_and_included <=
          status.num_inputs_aggregated_and_included):
        return status

      wait_data = _WaitData(num_inputs_aggregated_and_included)
      state.pending_waits.add(wait_data)
    return await wait_data.status_future

  def pre_authorize_clients(self, session_id: str,
                            num_tokens: int) -> Sequence[str]:
    """Generates tokens authorizing clients to contribute to the session."""
    tokens = {str(uuid.uuid4()) for _ in range(num_tokens)}
    with self._sessions_lock:
      self._sessions[session_id].authorization_tokens |= tokens
    return list(tokens)

  def _translate_intrinsic_arg(
      self, intrinsic_arg: plan_pb2.ServerAggregationConfig.IntrinsicArg
  ) -> configuration_pb2.Configuration.ServerAggregationConfig.IntrinsicArg:
    """Transform an aggregation intrinsic arg for the aggregation service."""
    if intrinsic_arg.HasField('input_tensor'):
      return configuration_pb2.Configuration.ServerAggregationConfig.IntrinsicArg(
          input_tensor=intrinsic_arg.input_tensor)
    elif intrinsic_arg.HasField('state_tensor'):
      raise ValueError(
          'Non-client intrinsic args are not supported in this demo.'
      )
    else:
      raise AssertionError(
          'Cases should have exhausted all possible types of intrinsic args.')

  def _translate_server_aggregation_config(
      self, plan_aggregation_config: plan_pb2.ServerAggregationConfig
  ) -> configuration_pb2.Configuration.ServerAggregationConfig:
    """Transform the aggregation config for use by the aggregation service."""
    if plan_aggregation_config.inner_aggregations:
      raise AssertionError('Nested intrinsic structures are not supported yet.')
    return configuration_pb2.Configuration.ServerAggregationConfig(
        intrinsic_uri=plan_aggregation_config.intrinsic_uri,
        intrinsic_args=[
            self._translate_intrinsic_arg(intrinsic_arg)
            for intrinsic_arg in plan_aggregation_config.intrinsic_args
        ],
        output_tensors=plan_aggregation_config.output_tensors)

  def _get_session_status(self,
                          state: _AggregationSessionState) -> SessionStatus:
    """Returns the SessionStatus for an _AggregationSessionState object."""
    status = state.agg_protocol.GetStatus()
    return SessionStatus(
        status=state.status,
        num_clients_completed=status.num_clients_completed,
        num_clients_failed=status.num_clients_failed,
        num_clients_pending=status.num_clients_pending,
        num_clients_aborted=status.num_clients_aborted,
        num_inputs_aggregated_and_included=(
            status.num_inputs_aggregated_and_included),
        num_inputs_aggregated_and_pending=(
            status.num_inputs_aggregated_and_pending),
        num_inputs_discarded=status.num_inputs_discarded)

  def _get_http_status(self, code: absl_status.StatusCode) -> http.HTTPStatus:
    """Returns the HTTPStatus code for an absl StatusCode."""
    if (code == absl_status.StatusCode.INVALID_ARGUMENT or
        code == absl_status.StatusCode.FAILED_PRECONDITION):
      return http.HTTPStatus.BAD_REQUEST
    elif code == absl_status.StatusCode.NOT_FOUND:
      return http.HTTPStatus.NOT_FOUND
    else:
      return http.HTTPStatus.INTERNAL_SERVER_ERROR

  def _cleanup_session(self, state: _AggregationSessionState) -> None:
    """Cleans up the session and releases any resources.

    Args:
      state: The session state to clean up.
    """
    state.authorization_tokens.clear()
    for client_data in state.active_clients.values():
      # Finalization is best-effort, matching the other handlers: the upload
      # may already have been finalized by a concurrent client request.
      with contextlib.suppress(KeyError):
        self._media_service.finalize_upload(client_data.resource_name)
    state.active_clients.clear()
    # Anyone waiting on the session should be notified that it's finished.
    if state.pending_waits:
      status = self._get_session_status(state)
      for data in state.pending_waits:
        data.loop.call_soon_threadsafe(
            functools.partial(data.status_future.set_result, status))
      state.pending_waits.clear()

  def _handle_protocol_abort(self, session_id: str) -> None:
    """Notifies waiting clients when the protocol is aborted."""
    with self._sessions_lock:
      with contextlib.suppress(KeyError):
        state = self._sessions[session_id]
        state.status = AggregationStatus.FAILED
        # Anyone waiting on the session should be notified it's been aborted.
        if state.pending_waits:
          status = self._get_session_status(state)
          for data in state.pending_waits:
            data.loop.call_soon_threadsafe(
                functools.partial(data.status_future.set_result, status))
          state.pending_waits.clear()

  @http_actions.proto_action(
      service='google.internal.federatedcompute.v1.Aggregations',
      method='StartAggregationDataUpload')
  def start_aggregation_data_upload(
      self, request: aggregations_pb2.StartAggregationDataUploadRequest
  ) -> operations_pb2.Operation:
    """Handles a StartAggregationDataUpload request."""
    with self._sessions_lock:
      try:
        state = self._sessions[request.aggregation_id]
      except KeyError as e:
        raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND) from e
      try:
        state.authorization_tokens.remove(request.authorization_token)
      except KeyError as e:
        raise http_actions.HttpError(http.HTTPStatus.UNAUTHORIZED) from e

    state.agg_protocol.AddClients(1)
    client_token = str(uuid.uuid4())
    client_id, close_status = state.callback.accepted_clients.get(timeout=1)
    upload_name = self._media_service.register_upload()

    with self._sessions_lock:
      state.active_clients[client_token] = _ActiveClientData(
          client_id, close_status, upload_name)

    forwarding_info = self._forwarding_info()
    response = aggregations_pb2.StartAggregationDataUploadResponse(
        aggregation_protocol_forwarding_info=forwarding_info,
        resource=common_pb2.ByteStreamResource(
            data_upload_forwarding_info=forwarding_info,
            resource_name=upload_name),
        client_token=client_token)

    op = operations_pb2.Operation(name=f'operations/{uuid.uuid4()}', done=True)
    op.metadata.Pack(aggregations_pb2.StartAggregationDataUploadMetadata())
    op.response.Pack(response)
    return op

  @http_actions.proto_action(
      service='google.internal.federatedcompute.v1.Aggregations',
      method='SubmitAggregationResult')
  def submit_aggregation_result(
      self, request: aggregations_pb2.SubmitAggregationResultRequest
  ) -> aggregations_pb2.SubmitAggregationResultResponse:
    """Handles a SubmitAggregationResult request."""
    with self._sessions_lock:
      try:
        state = self._sessions[request.aggregation_id]
      except KeyError as e:
        raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND) from e
      try:
        client_data = state.active_clients.pop(request.client_token)
      except KeyError as e:
        raise http_actions.HttpError(http.HTTPStatus.UNAUTHORIZED) from e

    # Ensure the client is using the resource name provided when they called
    # StartAggregationDataUpload.
    if request.resource_name != client_data.resource_name:
      raise http_actions.HttpError(http.HTTPStatus.BAD_REQUEST)

    # The aggregation protocol may have already closed the connection (e.g., if
    # an error occurred). If so, clean up the upload and return the error.
    if not client_data.close_status.empty():
      with contextlib.suppress(KeyError):
        self._media_service.finalize_upload(request.resource_name)
      raise http_actions.HttpError(
          self._get_http_status(client_data.close_status.get().code()))

    # Finalize the upload.
    try:
      update = self._media_service.finalize_upload(request.resource_name)
      if update is None:
        raise absl_status.StatusNotOk(
            absl_status.invalid_argument_error(
                'Aggregation result never uploaded'))
    except (KeyError, absl_status.StatusNotOk) as e:
      if isinstance(e, KeyError):
        e = absl_status.StatusNotOk(
            absl_status.internal_error('Failed to finalize upload'))
      state.agg_protocol.CloseClient(client_data.client_id, e.status)
      # Since we're initiating the close, it's also necessary to notify the
      # _AggregationProtocolCallback so it can clean up resources.
      state.callback.OnCloseClient(client_data.client_id, e.status)
      raise http_actions.HttpError(self._get_http_status(
          e.status.code())) from e

    client_message = apm_pb2.ClientMessage(
        simple_aggregation=apm_pb2.ClientMessage.SimpleAggregation(
            input=apm_pb2.ClientResource(inline_bytes=update)))
    try:
      state.agg_protocol.ReceiveClientMessage(client_data.client_id,
                                              client_message)
    except absl_status.StatusNotOk as e:
      # ReceiveClientInput should only fail if the AggregationProtocol is in a
      # bad state -- likely leading to it being aborted.
      logging.warning('Failed to receive client input: %s', e)
      raise http_actions.HttpError(http.HTTPStatus.INTERNAL_SERVER_ERROR) from e

    # Wait for the client input to be processed.
    close_status = client_data.close_status.get()
    if not close_status.ok():
      raise http_actions.HttpError(self._get_http_status(close_status.code()))

    # Check for any newly-satisfied pending wait operations.
    with self._sessions_lock:
      if state.pending_waits:
        completed_waits = set()
        status = self._get_session_status(state)
        for data in state.pending_waits:
          if (data.num_inputs_aggregated_and_included is not None and
              status.num_inputs_aggregated_and_included >=
              data.num_inputs_aggregated_and_included):
            data.loop.call_soon_threadsafe(
                functools.partial(data.status_future.set_result, status))
            completed_waits.add(data)
        state.pending_waits -= completed_waits
    return aggregations_pb2.SubmitAggregationResultResponse()

  @http_actions.proto_action(
      service='google.internal.federatedcompute.v1.Aggregations',
      method='AbortAggregation')
  def abort_aggregation(
      self, request: aggregations_pb2.AbortAggregationRequest
  ) -> aggregations_pb2.AbortAggregationResponse:
    """Handles an AbortAggregation request."""
    with self._sessions_lock:
      try:
        state = self._sessions[request.aggregation_id]
      except KeyError as e:
        raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND) from e
      try:
        client_data = state.active_clients.pop(request.client_token)
      except KeyError as e:
        raise http_actions.HttpError(http.HTTPStatus.UNAUTHORIZED) from e

    # Attempt to finalize the in-progress upload to free up resources.
    with contextlib.suppress(KeyError):
      self._media_service.finalize_upload(client_data.resource_name)

    # Notify the aggregation protocol that the client has left.
    if request.status.code == code_pb2.Code.OK:
      status = absl_status.Status.OkStatus()
    else:
      status = absl_status.BuildStatusNotOk(
          absl_status.StatusCodeFromInt(request.status.code),
          request.status.message)
    state.agg_protocol.CloseClient(client_data.client_id, status)
    # Since we're initiating the close, it's also necessary to notify the
    # _AggregationProtocolCallback so it can clean up resources.
    state.callback.OnCloseClient(client_data.client_id, status)

    logging.debug('[%s] AbortAggregation: %s', request.aggregation_id,
                  request.status)
    return aggregations_pb2.AbortAggregationResponse()