xref: /aosp_15_r20/external/federated-compute/fcp/demo/aggregations.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Action handlers for the Aggregations service."""
15
16import asyncio
17from collections.abc import Callable, Sequence
18import contextlib
19import dataclasses
20import enum
21import functools
22import http
23import queue
24import threading
25from typing import Optional
26import uuid
27
28from absl import logging
29
30from google.longrunning import operations_pb2
31from google.rpc import code_pb2
32from fcp.aggregation.protocol import aggregation_protocol_messages_pb2 as apm_pb2
33from fcp.aggregation.protocol import configuration_pb2
34from fcp.aggregation.protocol.python import aggregation_protocol
35from fcp.aggregation.tensorflow.python import aggregation_protocols
36from fcp.demo import http_actions
37from fcp.demo import media
38from fcp.protos import plan_pb2
39from fcp.protos.federatedcompute import aggregations_pb2
40from fcp.protos.federatedcompute import common_pb2
41from pybind11_abseil import status as absl_status
42
43
class AggregationStatus(enum.Enum):
  """The overall state of an aggregation session."""
  COMPLETED = enum.auto()  # 1
  PENDING = enum.auto()    # 2
  FAILED = enum.auto()     # 3
  ABORTED = enum.auto()    # 4
49
50
@dataclasses.dataclass
class SessionStatus:
  """Point-in-time snapshot of an aggregation session's progress."""
  # Overall state of the session (pending, completed, failed, or aborted).
  status: AggregationStatus = AggregationStatus.PENDING
  # Clients that started and finished the aggregation upload protocol
  # successfully.
  num_clients_completed: int = 0
  # Clients that started the upload protocol but dropped out before
  # completing it.
  num_clients_failed: int = 0
  # Clients currently mid-protocol; their outcome is not yet known.
  num_clients_pending: int = 0
  # Clients the server cut off before they could finish (e.g., when further
  # progress on the session was no longer needed).
  num_clients_aborted: int = 0
  # Uploads that have been folded into the final aggregate. A client counted
  # in num_clients_completed may not have its report included here yet.
  num_inputs_aggregated_and_included: int = 0
  # Uploads the server has received that are not yet part of the final
  # aggregate.
  num_inputs_aggregated_and_pending: int = 0
  # Uploads the server received but threw away.
  num_inputs_discarded: int = 0
79
80
@dataclasses.dataclass(frozen=True)
class AggregationRequirements:
  """Requirements an aggregation session must satisfy."""
  # Minimum number of clients whose inputs must be included before a result
  # may be released outside this service. Aggregation does not stop
  # automatically once this threshold is met; callers decide when to stop.
  minimum_clients_in_server_published_aggregate: int
  # The Plan to execute.
  plan: plan_pb2.Plan
90
91
@dataclasses.dataclass
class _ActiveClientData:
  """Bookkeeping for a client that is mid-upload."""
  # Id assigned to the client by the aggregation protocol.
  client_id: int
  # One-element queue that receives the connection's final status if/when the
  # aggregation protocol closes the client.
  close_status: queue.SimpleQueue[absl_status.Status]
  # Name of the upload resource the client writes its update to.
  resource_name: str
102
103
@dataclasses.dataclass(eq=False)
class _WaitData:
  """Bookkeeping for one in-flight wait operation."""
  # The wait completes once this many inputs have been aggregated and
  # included; None means there is no such condition.
  num_inputs_aggregated_and_included: Optional[int]
  # Event loop the caller is waiting on; used to resolve status_future from
  # other threads via call_soon_threadsafe.
  loop: asyncio.AbstractEventLoop = dataclasses.field(
      default_factory=asyncio.get_running_loop)
  # Resolved with the final SessionStatus once the wait is over.
  status_future: asyncio.Future[SessionStatus] = dataclasses.field(
      default_factory=asyncio.Future)
115
116
class _AggregationProtocolCallback(
    aggregation_protocol.AggregationProtocol.Callback):
  """AggregationProtocol.Callback that surfaces protocol events via queues."""

  def __init__(self, on_abort: Callable[[], None]):
    """Initializes the callback.

    Args:
      on_abort: Invoked if/when the protocol aborts the session.
    """
    super().__init__()
    # Receives a (client_id, close-status queue) pair for every client
    # accepted after AggregationProtocol.AddClients. Each per-client queue is
    # used as a future: it receives exactly one element, the diagnostic
    # status delivered when that client is closed.
    self.accepted_clients: queue.SimpleQueue[tuple[
        int, queue.SimpleQueue[absl_status.Status]]] = queue.SimpleQueue()
    # Receives the session's single terminal outcome: the aggregated tensors
    # on success, or a failure status. Also used as a one-element future.
    self.result: queue.SimpleQueue[bytes | absl_status.Status] = (
        queue.SimpleQueue())

    self._on_abort = on_abort
    self._client_results_lock = threading.Lock()
    # Maps each still-open client id to the queue awaiting its close status.
    self._client_results: dict[int, queue.SimpleQueue[absl_status.Status]] = {}

  def OnAcceptClients(self, start_client_id: int, num_clients: int,
                      message: apm_pb2.AcceptanceMessage) -> None:
    with self._client_results_lock:
      for client_id in range(start_client_id, start_client_id + num_clients):
        close_queue: queue.SimpleQueue[absl_status.Status] = (
            queue.SimpleQueue())
        self._client_results[client_id] = close_queue
        self.accepted_clients.put((client_id, close_queue))

  def OnSendServerMessage(self, client_id: int,
                          message: apm_pb2.ServerMessage) -> None:
    raise NotImplementedError()

  def OnCloseClient(self, client_id: int,
                    diagnostic_status: absl_status.Status) -> None:
    with self._client_results_lock:
      close_queue = self._client_results.pop(client_id)
      close_queue.put(diagnostic_status)

  def OnComplete(self, result: bytes) -> None:
    self.result.put(result)

  def OnAbort(self, diagnostic_status: absl_status.Status) -> None:
    self.result.put(diagnostic_status)
    self._on_abort()
168
169
@dataclasses.dataclass(eq=False)
class _AggregationSessionState:
  """Internal state for an aggregation session."""
  # The session's aggregation requirements.
  requirements: AggregationRequirements
  # Receives the aggregation protocol's events (accepted clients, results).
  callback: _AggregationProtocolCallback
  # The protocol performing the aggregation. Service._sessions_lock must not
  # be held while AggregationProtocol methods run: calls may be slow, and the
  # callbacks they trigger may themselves need to acquire that lock.
  agg_protocol: aggregation_protocol.AggregationProtocol
  # The current status of the session.
  status: AggregationStatus = AggregationStatus.PENDING
  # Authorization tokens that have not yet been redeemed by a client.
  authorization_tokens: set[str] = dataclasses.field(default_factory=set)
  # Active clients, keyed by the per-client token issued when the client
  # started the data upload.
  active_clients: dict[str, _ActiveClientData] = dataclasses.field(
      default_factory=dict)
  # Bookkeeping for in-progress wait calls on this session.
  pending_waits: set[_WaitData] = dataclasses.field(default_factory=set)
190
191
192class Service:
193  """Implements the Aggregations service."""
194
195  def __init__(self, forwarding_info: Callable[[], common_pb2.ForwardingInfo],
196               media_service: media.Service):
197    self._forwarding_info = forwarding_info
198    self._media_service = media_service
199    self._sessions: dict[str, _AggregationSessionState] = {}
200    self._sessions_lock = threading.Lock()
201
202  def create_session(self,
203                     aggregation_requirements: AggregationRequirements) -> str:
204    """Creates a new aggregation session and returns its id."""
205    session_id = str(uuid.uuid4())
206    callback = _AggregationProtocolCallback(
207        functools.partial(self._handle_protocol_abort, session_id))
208    if (len(aggregation_requirements.plan.phase) != 1 or
209        not aggregation_requirements.plan.phase[0].HasField('server_phase_v2')):
210      raise ValueError('Plan must contain exactly one server_phase_v2.')
211
212    # NOTE: For simplicity, this implementation only creates a single,
213    # in-process aggregation shard. In a production implementation, there should
214    # be multiple shards running on separate servers to enable high rates of
215    # client contributions. Utilities for combining results from separate shards
216    # are still in development as of Jan 2023.
217    agg_protocol = aggregation_protocols.create_simple_aggregation_protocol(
218        configuration_pb2.Configuration(aggregation_configs=[
219            self._translate_server_aggregation_config(aggregation_config)
220            for aggregation_config in
221            aggregation_requirements.plan.phase[0].server_phase_v2.aggregations
222        ]), callback)
223    agg_protocol.Start(0)
224
225    with self._sessions_lock:
226      self._sessions[session_id] = _AggregationSessionState(
227          requirements=aggregation_requirements,
228          callback=callback,
229          agg_protocol=agg_protocol)
230    return session_id
231
232  def complete_session(
233      self, session_id: str) -> tuple[SessionStatus, Optional[bytes]]:
234    """Completes the aggregation session and returns its results."""
235    with self._sessions_lock:
236      state = self._sessions.pop(session_id)
237
238    try:
239      # Only complete the AggregationProtocol if it's still pending. The most
240      # likely alternative is that it's ABORTED due to an error generated by the
241      # protocol itself.
242      status = self._get_session_status(state)
243      if status.status != AggregationStatus.PENDING:
244        return self._get_session_status(state), None
245
246      # Ensure privacy requirements have been met.
247      if (state.agg_protocol.GetStatus().num_inputs_aggregated_and_included <
248          state.requirements.minimum_clients_in_server_published_aggregate):
249        state.agg_protocol.Abort()
250        raise ValueError(
251            'minimum_clients_in_server_published_aggregate has not been met.')
252
253      state.agg_protocol.Complete()
254      result = state.callback.result.get(timeout=1)
255      if isinstance(result, absl_status.Status):
256        raise absl_status.StatusNotOk(result)
257      state.status = AggregationStatus.COMPLETED
258      return self._get_session_status(state), result
259    except (ValueError, absl_status.StatusNotOk, queue.Empty) as e:
260      logging.warning('Failed to complete aggregation session: %s', e)
261      state.status = AggregationStatus.FAILED
262      return self._get_session_status(state), None
263    finally:
264      self._cleanup_session(state)
265
266  def abort_session(self, session_id: str) -> SessionStatus:
267    """Aborts/cancels an aggregation session."""
268    with self._sessions_lock:
269      state = self._sessions.pop(session_id)
270
271    # Only abort the AggregationProtocol if it's still pending. The most likely
272    # alternative is that it's ABORTED due to an error generated by the protocol
273    # itself.
274    if state.status == AggregationStatus.PENDING:
275      state.status = AggregationStatus.ABORTED
276      state.agg_protocol.Abort()
277
278    self._cleanup_session(state)
279    return self._get_session_status(state)
280
281  def get_session_status(self, session_id: str) -> SessionStatus:
282    """Returns the status of an aggregation session."""
283    with self._sessions_lock:
284      return self._get_session_status(self._sessions[session_id])
285
286  async def wait(
287      self,
288      session_id: str,
289      num_inputs_aggregated_and_included: Optional[int] = None
290  ) -> SessionStatus:
291    """Blocks until all conditions are satisfied or the aggregation fails."""
292    with self._sessions_lock:
293      state = self._sessions[session_id]
294      # Check if any of the conditions are already satisfied.
295      status = self._get_session_status(state)
296      if (num_inputs_aggregated_and_included is None or
297          num_inputs_aggregated_and_included <=
298          status.num_inputs_aggregated_and_included):
299        return status
300
301      wait_data = _WaitData(num_inputs_aggregated_and_included)
302      state.pending_waits.add(wait_data)
303    return await wait_data.status_future
304
305  def pre_authorize_clients(self, session_id: str,
306                            num_tokens: int) -> Sequence[str]:
307    """Generates tokens authorizing clients to contribute to the session."""
308    tokens = set(str(uuid.uuid4()) for _ in range(num_tokens))
309    with self._sessions_lock:
310      self._sessions[session_id].authorization_tokens |= tokens
311    return list(tokens)
312
313  def _translate_intrinsic_arg(
314      self, intrinsic_arg: plan_pb2.ServerAggregationConfig.IntrinsicArg
315  ) -> configuration_pb2.Configuration.ServerAggregationConfig.IntrinsicArg:
316    """Transform an aggregation intrinsic arg for the aggregation service."""
317    if intrinsic_arg.HasField('input_tensor'):
318      return configuration_pb2.Configuration.ServerAggregationConfig.IntrinsicArg(
319          input_tensor=intrinsic_arg.input_tensor)
320    elif intrinsic_arg.HasField('state_tensor'):
321      raise ValueError(
322          'Non-client intrinsic args are not supported in this demo.'
323      )
324    else:
325      raise AssertionError(
326          'Cases should have exhausted all possible types of intrinsic args.')
327
328  def _translate_server_aggregation_config(
329      self, plan_aggregation_config: plan_pb2.ServerAggregationConfig
330  ) -> configuration_pb2.Configuration.ServerAggregationConfig:
331    """Transform the aggregation config for use by the aggregation service."""
332    if plan_aggregation_config.inner_aggregations:
333      raise AssertionError('Nested intrinsic structrues are not supported yet.')
334    return configuration_pb2.Configuration.ServerAggregationConfig(
335        intrinsic_uri=plan_aggregation_config.intrinsic_uri,
336        intrinsic_args=[
337            self._translate_intrinsic_arg(intrinsic_arg)
338            for intrinsic_arg in plan_aggregation_config.intrinsic_args
339        ],
340        output_tensors=plan_aggregation_config.output_tensors)
341
342  def _get_session_status(self,
343                          state: _AggregationSessionState) -> SessionStatus:
344    """Returns the SessionStatus for an _AggregationSessionState object."""
345    status = state.agg_protocol.GetStatus()
346    return SessionStatus(
347        status=state.status,
348        num_clients_completed=status.num_clients_completed,
349        num_clients_failed=status.num_clients_failed,
350        num_clients_pending=status.num_clients_pending,
351        num_clients_aborted=status.num_clients_aborted,
352        num_inputs_aggregated_and_included=(
353            status.num_inputs_aggregated_and_included),
354        num_inputs_aggregated_and_pending=(
355            status.num_inputs_aggregated_and_pending),
356        num_inputs_discarded=status.num_inputs_discarded)
357
358  def _get_http_status(self, code: absl_status.StatusCode) -> http.HTTPStatus:
359    """Returns the HTTPStatus code for an absl StatusCode."""
360    if (code == absl_status.StatusCode.INVALID_ARGUMENT or
361        code == absl_status.StatusCode.FAILED_PRECONDITION):
362      return http.HTTPStatus.BAD_REQUEST
363    elif code == absl_status.StatusCode.NOT_FOUND:
364      return http.HTTPStatus.NOT_FOUND
365    else:
366      return http.HTTPStatus.INTERNAL_SERVER_ERROR
367
368  def _cleanup_session(self, state: _AggregationSessionState) -> None:
369    """Cleans up the session and releases any resources.
370
371    Args:
372      state: The session state to clean up.
373    """
374    state.authorization_tokens.clear()
375    for client_data in state.active_clients.values():
376      self._media_service.finalize_upload(client_data.resource_name)
377    state.active_clients.clear()
378    # Anyone waiting on the session should be notified that it's finished.
379    if state.pending_waits:
380      status = self._get_session_status(state)
381      for data in state.pending_waits:
382        data.loop.call_soon_threadsafe(
383            functools.partial(data.status_future.set_result, status))
384      state.pending_waits.clear()
385
386  def _handle_protocol_abort(self, session_id: str) -> None:
387    """Notifies waiting clients when the protocol is aborted."""
388    with self._sessions_lock:
389      with contextlib.suppress(KeyError):
390        state = self._sessions[session_id]
391        state.status = AggregationStatus.FAILED
392        # Anyone waiting on the session should be notified it's been aborted.
393        if state.pending_waits:
394          status = self._get_session_status(state)
395          for data in state.pending_waits:
396            data.loop.call_soon_threadsafe(
397                functools.partial(data.status_future.set_result, status))
398          state.pending_waits.clear()
399
  @http_actions.proto_action(
      service='google.internal.federatedcompute.v1.Aggregations',
      method='StartAggregationDataUpload')
  def start_aggregation_data_upload(
      self, request: aggregations_pb2.StartAggregationDataUploadRequest
  ) -> operations_pb2.Operation:
    """Handles a StartAggregationDataUpload request.

    Redeems the client's one-time authorization token, registers the client
    with the aggregation protocol, and allocates the upload resource the
    client will write its result to.

    Args:
      request: The StartAggregationDataUploadRequest from the client.

    Returns:
      A completed Operation whose response holds the upload location and the
      client token for subsequent requests.

    Raises:
      http_actions.HttpError: NOT_FOUND if the aggregation session does not
        exist; UNAUTHORIZED if the authorization token is unknown or was
        already redeemed.
    """
    with self._sessions_lock:
      try:
        state = self._sessions[request.aggregation_id]
      except KeyError as e:
        raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND) from e
      try:
        # Tokens are single-use: remove() both validates and redeems.
        state.authorization_tokens.remove(request.authorization_token)
      except KeyError as e:
        raise http_actions.HttpError(http.HTTPStatus.UNAUTHORIZED) from e

    # AddClients is called outside the lock: AggregationProtocol methods may
    # be slow and their callbacks may need the lock themselves. The accepted
    # client's id arrives via the callback's accepted_clients queue; the
    # timeout guards against the protocol never accepting the client.
    state.agg_protocol.AddClients(1)
    client_token = str(uuid.uuid4())
    client_id, close_status = state.callback.accepted_clients.get(timeout=1)
    upload_name = self._media_service.register_upload()

    with self._sessions_lock:
      # NOTE(review): if the session was removed between the two locked
      # sections, this records the client on a dead state object — presumably
      # acceptable for this demo; confirm.
      state.active_clients[client_token] = _ActiveClientData(
          client_id, close_status, upload_name)

    forwarding_info = self._forwarding_info()
    response = aggregations_pb2.StartAggregationDataUploadResponse(
        aggregation_protocol_forwarding_info=forwarding_info,
        resource=common_pb2.ByteStreamResource(
            data_upload_forwarding_info=forwarding_info,
            resource_name=upload_name),
        client_token=client_token)

    # The response is wrapped in a long-running Operation, though it is
    # always already done in this implementation.
    op = operations_pb2.Operation(name=f'operations/{uuid.uuid4()}', done=True)
    op.metadata.Pack(aggregations_pb2.StartAggregationDataUploadMetadata())
    op.response.Pack(response)
    return op
438
  @http_actions.proto_action(
      service='google.internal.federatedcompute.v1.Aggregations',
      method='SubmitAggregationResult')
  def submit_aggregation_result(
      self, request: aggregations_pb2.SubmitAggregationResultRequest
  ) -> aggregations_pb2.SubmitAggregationResultResponse:
    """Handles a SubmitAggregationResult request.

    Finalizes the client's upload, feeds the uploaded update into the
    aggregation protocol, and waits for the protocol to accept or reject it.

    Args:
      request: The SubmitAggregationResultRequest from the client.

    Returns:
      A SubmitAggregationResultResponse once the input has been processed.

    Raises:
      http_actions.HttpError: NOT_FOUND if the aggregation session does not
        exist; UNAUTHORIZED if the client token is invalid; BAD_REQUEST if
        the resource name does not match the one issued by
        StartAggregationDataUpload; otherwise a status derived from the
        protocol's close status or finalization failure.
    """
    with self._sessions_lock:
      try:
        state = self._sessions[request.aggregation_id]
      except KeyError as e:
        raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND) from e
      try:
        # Client tokens are single-use: the client's entry is removed from
        # the active set as soon as it submits a result.
        client_data = state.active_clients.pop(request.client_token)
      except KeyError as e:
        raise http_actions.HttpError(http.HTTPStatus.UNAUTHORIZED) from e

    # Ensure the client is using the resource name provided when they called
    # StartAggregationDataUpload.
    if request.resource_name != client_data.resource_name:
      raise http_actions.HttpError(http.HTTPStatus.BAD_REQUEST)

    # The aggregation protocol may have already closed the connection (e.g.,
    # if an error occurred). If so, clean up the upload and return the error.
    if not client_data.close_status.empty():
      with contextlib.suppress(KeyError):
        self._media_service.finalize_upload(request.resource_name)
      raise http_actions.HttpError(
          self._get_http_status(client_data.close_status.get().code()))

    # Finalize the upload. A missing (KeyError) or never-written (None)
    # upload is converted to a status error, the client's protocol connection
    # is closed, and the error is surfaced to the client.
    try:
      update = self._media_service.finalize_upload(request.resource_name)
      if update is None:
        raise absl_status.StatusNotOk(
            absl_status.invalid_argument_error(
                'Aggregation result never uploaded'))
    except (KeyError, absl_status.StatusNotOk) as e:
      if isinstance(e, KeyError):
        e = absl_status.StatusNotOk(
            absl_status.internal_error('Failed to finalize upload'))
      state.agg_protocol.CloseClient(client_data.client_id, e.status)
      # Since we're initiating the close, it's also necessary to notify the
      # _AggregationProtocolCallback so it can clean up resources.
      state.callback.OnCloseClient(client_data.client_id, e.status)
      raise http_actions.HttpError(self._get_http_status(
          e.status.code())) from e

    # Hand the uploaded bytes to the aggregation protocol.
    client_message = apm_pb2.ClientMessage(
        simple_aggregation=apm_pb2.ClientMessage.SimpleAggregation(
            input=apm_pb2.ClientResource(inline_bytes=update)))
    try:
      state.agg_protocol.ReceiveClientMessage(client_data.client_id,
                                              client_message)
    except absl_status.StatusNotOk as e:
      # ReceiveClientInput should only fail if the AggregationProtocol is in a
      # bad state -- likely leading to it being aborted.
      logging.warning('Failed to receive client input: %s', e)
      raise http_actions.HttpError(http.HTTPStatus.INTERNAL_SERVER_ERROR) from e

    # Wait for the client input to be processed; the protocol reports the
    # outcome by closing the client with a final status.
    close_status = client_data.close_status.get()
    if not close_status.ok():
      raise http_actions.HttpError(self._get_http_status(close_status.code()))

    # Check for any newly-satisfied pending wait operations.
    with self._sessions_lock:
      if state.pending_waits:
        completed_waits = set()
        status = self._get_session_status(state)
        for data in state.pending_waits:
          if (data.num_inputs_aggregated_and_included is not None and
              status.num_inputs_aggregated_and_included >=
              data.num_inputs_aggregated_and_included):
            data.loop.call_soon_threadsafe(
                functools.partial(data.status_future.set_result, status))
            completed_waits.add(data)
        state.pending_waits -= completed_waits
    return aggregations_pb2.SubmitAggregationResultResponse()
518
519  @http_actions.proto_action(
520      service='google.internal.federatedcompute.v1.Aggregations',
521      method='AbortAggregation')
522  def abort_aggregation(
523      self, request: aggregations_pb2.AbortAggregationRequest
524  ) -> aggregations_pb2.AbortAggregationResponse:
525    """Handles an AbortAggregation request."""
526    with self._sessions_lock:
527      try:
528        state = self._sessions[request.aggregation_id]
529      except KeyError as e:
530        raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND) from e
531      try:
532        client_data = state.active_clients.pop(request.client_token)
533      except KeyError as e:
534        raise http_actions.HttpError(http.HTTPStatus.UNAUTHORIZED) from e
535
536    # Attempt to finalize the in-progress upload to free up resources.
537    with contextlib.suppress(KeyError):
538      self._media_service.finalize_upload(client_data.resource_name)
539
540    # Notify the aggregation protocol that the client has left.
541    if request.status.code == code_pb2.Code.OK:
542      status = absl_status.Status.OkStatus()
543    else:
544      status = absl_status.BuildStatusNotOk(
545          absl_status.StatusCodeFromInt(request.status.code),
546          request.status.message)
547    state.agg_protocol.CloseClient(client_data.client_id, status)
548    # Since we're initiating the close, it's also necessary to notify the
549    # _AggregationProtocolCallback so it can clean up resources.
550    state.callback.OnCloseClient(client_data.client_id, status)
551
552    logging.debug('[%s] AbortAggregation: %s', request.aggregation_id,
553                  request.status)
554    return aggregations_pb2.AbortAggregationResponse()
555