xref: /aosp_15_r20/external/federated-compute/fcp/demo/aggregations.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker# Copyright 2022 Google LLC
2*14675a02SAndroid Build Coastguard Worker#
3*14675a02SAndroid Build Coastguard Worker# Licensed under the Apache License, Version 2.0 (the "License");
4*14675a02SAndroid Build Coastguard Worker# you may not use this file except in compliance with the License.
5*14675a02SAndroid Build Coastguard Worker# You may obtain a copy of the License at
6*14675a02SAndroid Build Coastguard Worker#
7*14675a02SAndroid Build Coastguard Worker#      http://www.apache.org/licenses/LICENSE-2.0
8*14675a02SAndroid Build Coastguard Worker#
9*14675a02SAndroid Build Coastguard Worker# Unless required by applicable law or agreed to in writing, software
10*14675a02SAndroid Build Coastguard Worker# distributed under the License is distributed on an "AS IS" BASIS,
11*14675a02SAndroid Build Coastguard Worker# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*14675a02SAndroid Build Coastguard Worker# See the License for the specific language governing permissions and
13*14675a02SAndroid Build Coastguard Worker# limitations under the License.
14*14675a02SAndroid Build Coastguard Worker"""Action handlers for the Aggregations service."""
15*14675a02SAndroid Build Coastguard Worker
16*14675a02SAndroid Build Coastguard Workerimport asyncio
17*14675a02SAndroid Build Coastguard Workerfrom collections.abc import Callable, Sequence
18*14675a02SAndroid Build Coastguard Workerimport contextlib
19*14675a02SAndroid Build Coastguard Workerimport dataclasses
20*14675a02SAndroid Build Coastguard Workerimport enum
21*14675a02SAndroid Build Coastguard Workerimport functools
22*14675a02SAndroid Build Coastguard Workerimport http
23*14675a02SAndroid Build Coastguard Workerimport queue
24*14675a02SAndroid Build Coastguard Workerimport threading
25*14675a02SAndroid Build Coastguard Workerfrom typing import Optional
26*14675a02SAndroid Build Coastguard Workerimport uuid
27*14675a02SAndroid Build Coastguard Worker
28*14675a02SAndroid Build Coastguard Workerfrom absl import logging
29*14675a02SAndroid Build Coastguard Worker
30*14675a02SAndroid Build Coastguard Workerfrom google.longrunning import operations_pb2
31*14675a02SAndroid Build Coastguard Workerfrom google.rpc import code_pb2
32*14675a02SAndroid Build Coastguard Workerfrom fcp.aggregation.protocol import aggregation_protocol_messages_pb2 as apm_pb2
33*14675a02SAndroid Build Coastguard Workerfrom fcp.aggregation.protocol import configuration_pb2
34*14675a02SAndroid Build Coastguard Workerfrom fcp.aggregation.protocol.python import aggregation_protocol
35*14675a02SAndroid Build Coastguard Workerfrom fcp.aggregation.tensorflow.python import aggregation_protocols
36*14675a02SAndroid Build Coastguard Workerfrom fcp.demo import http_actions
37*14675a02SAndroid Build Coastguard Workerfrom fcp.demo import media
38*14675a02SAndroid Build Coastguard Workerfrom fcp.protos import plan_pb2
39*14675a02SAndroid Build Coastguard Workerfrom fcp.protos.federatedcompute import aggregations_pb2
40*14675a02SAndroid Build Coastguard Workerfrom fcp.protos.federatedcompute import common_pb2
41*14675a02SAndroid Build Coastguard Workerfrom pybind11_abseil import status as absl_status
42*14675a02SAndroid Build Coastguard Worker
43*14675a02SAndroid Build Coastguard Worker
@enum.unique
class AggregationStatus(enum.Enum):
  """Lifecycle states of an aggregation session."""
  # The session was completed and produced a result.
  COMPLETED = 1
  # The session has neither finished nor failed; this is the initial state.
  PENDING = 2
  # Completing the session failed or the aggregation protocol aborted itself.
  FAILED = 3
  # The session was cancelled through an explicit abort request.
  ABORTED = 4
49*14675a02SAndroid Build Coastguard Worker
50*14675a02SAndroid Build Coastguard Worker
@dataclasses.dataclass
class SessionStatus:
  """A snapshot of an aggregation session's state and progress counters.

  Attributes:
    status: The current state of the aggregation session.
    num_clients_completed: Number of clients that successfully started and
      completed the aggregation upload protocol.
    num_clients_failed: Number of clients that started the aggregation upload
      protocol but failed to complete it (e.g. dropped out in the middle of
      the protocol).
    num_clients_pending: Number of clients that started the aggregation upload
      protocol but have not yet finished, successfully or otherwise.
    num_clients_aborted: Number of clients that started the aggregation
      protocol but were aborted by the server before they could complete
      (e.g. because progress on the session was no longer needed).
    num_inputs_aggregated_and_included: Number of inputs that were
      successfully aggregated and included in the final aggregate. A client
      counted in num_clients_completed is not guaranteed to have its report
      included in the final aggregate yet.
    num_inputs_aggregated_and_pending: Number of inputs that were received by
      the server but have not been included in the final aggregate yet.
    num_inputs_discarded: Number of inputs that were received by the server
      but discarded.
  """
  status: AggregationStatus = AggregationStatus.PENDING
  num_clients_completed: int = 0
  num_clients_failed: int = 0
  num_clients_pending: int = 0
  num_clients_aborted: int = 0
  num_inputs_aggregated_and_included: int = 0
  num_inputs_aggregated_and_pending: int = 0
  num_inputs_discarded: int = 0
79*14675a02SAndroid Build Coastguard Worker
80*14675a02SAndroid Build Coastguard Worker
@dataclasses.dataclass(frozen=True)
class AggregationRequirements:
  """Immutable requirements supplied when an aggregation session is created.

  Attributes:
    minimum_clients_in_server_published_aggregate: The minimum number of
      clients required before a result can be released outside this service.
      Aggregation does not stop automatically once this number is reached;
      callers decide when to stop aggregating.
    plan: The Plan to execute.
  """
  minimum_clients_in_server_published_aggregate: int
  plan: plan_pb2.Plan
90*14675a02SAndroid Build Coastguard Worker
91*14675a02SAndroid Build Coastguard Worker
@dataclasses.dataclass
class _ActiveClientData:
  """Bookkeeping for a client that is currently part of a session.

  Attributes:
    client_id: The client's identifier in the aggregation protocol.
    close_status: Queue receiving the final status of the client connection
      (if closed by the aggregation protocol). At most one value will be
      written.
    resource_name: The name of the resource to which the client should write
      its update.
  """
  client_id: int
  close_status: queue.SimpleQueue[absl_status.Status]
  resource_name: str
102*14675a02SAndroid Build Coastguard Worker
103*14675a02SAndroid Build Coastguard Worker
@dataclasses.dataclass(eq=False)
class _WaitData:
  """Bookkeeping for one in-progress wait call.

  Instances compare and hash by identity (eq=False), which lets them be kept
  in a set.

  Attributes:
    num_inputs_aggregated_and_included: The condition under which the wait
      should complete.
    loop: The event loop the caller is waiting on. Defaults to the loop that
      is running when the instance is created.
    status_future: The future to which the SessionStatus will be written once
      the wait is over.
  """
  num_inputs_aggregated_and_included: Optional[int]
  loop: asyncio.AbstractEventLoop = dataclasses.field(
      default_factory=asyncio.get_running_loop)
  status_future: asyncio.Future[SessionStatus] = dataclasses.field(
      default_factory=asyncio.Future)
115*14675a02SAndroid Build Coastguard Worker
116*14675a02SAndroid Build Coastguard Worker
class _AggregationProtocolCallback(
    aggregation_protocol.AggregationProtocol.Callback):
  """AggregationProtocol.Callback that forwards protocol events to queues."""

  def __init__(self, on_abort: Callable[[], None]):
    """Constructs a new _AggregationProtocolCallback.

    Args:
      on_abort: A zero-argument callable invoked whenever OnAbort is called,
        after the abort status has been placed on `result`.
    """
    super().__init__()
    self._on_abort = on_abort
    # Guards _client_results.
    self._client_results_lock = threading.Lock()
    # Maps each open client's id to the queue for its close status.
    self._client_results: dict[int, queue.SimpleQueue[absl_status.Status]] = {}

    # Receives one (client id, close-status queue) pair per accepted client.
    # The close-status queue acts as a future: it only ever receives a single
    # element, the diagnostic status reported when the client is closed.
    self.accepted_clients: queue.SimpleQueue[tuple[
        int, queue.SimpleQueue[absl_status.Status]]] = queue.SimpleQueue()
    # Receives the final outcome of the aggregation session: the aggregated
    # tensors on completion or a failure status on abort. Also used as a
    # future that only receives one element.
    self.result: queue.SimpleQueue[bytes | absl_status.Status] = (
        queue.SimpleQueue())

  def OnAcceptClients(self, start_client_id: int, num_clients: int,
                      message: apm_pb2.AcceptanceMessage) -> None:
    """Registers a close-status queue for each newly accepted client.

    Args:
      start_client_id: The id of the first accepted client.
      num_clients: How many consecutive client ids were accepted.
      message: The acceptance message; unused.
    """
    end_client_id = start_client_id + num_clients
    with self._client_results_lock:
      for accepted_id in range(start_client_id, end_client_id):
        close_status_queue = queue.SimpleQueue()
        self._client_results[accepted_id] = close_status_queue
        self.accepted_clients.put((accepted_id, close_status_queue))

  def OnSendServerMessage(self, client_id: int,
                          message: apm_pb2.ServerMessage) -> None:
    """Not supported; always raises NotImplementedError."""
    raise NotImplementedError()

  def OnCloseClient(self, client_id: int,
                    diagnostic_status: absl_status.Status) -> None:
    """Delivers a client's final status to its close-status queue.

    The client's entry is removed, so a KeyError is raised if `client_id` has
    no registered queue.

    Args:
      client_id: The id of the client that was closed.
      diagnostic_status: The status to hand to whoever holds the queue.
    """
    with self._client_results_lock:
      close_status_queue = self._client_results.pop(client_id)
      close_status_queue.put(diagnostic_status)

  def OnComplete(self, result: bytes) -> None:
    """Publishes the aggregation result on the `result` queue."""
    self.result.put(result)

  def OnAbort(self, diagnostic_status: absl_status.Status) -> None:
    """Publishes the abort status on `result`, then calls `on_abort`."""
    self.result.put(diagnostic_status)
    self._on_abort()
168*14675a02SAndroid Build Coastguard Worker
169*14675a02SAndroid Build Coastguard Worker
@dataclasses.dataclass(eq=False)
class _AggregationSessionState:
  """Everything the service tracks for a single aggregation session.

  Attributes:
    requirements: The session's aggregation requirements.
    callback: The AggregationProtocol.Callback object receiving protocol
      events.
    agg_protocol: The protocol performing the aggregation.
      Service._sessions_lock should not be held while AggregationProtocol
      methods are invoked -- both because methods may be slow and because
      callbacks may also need to acquire the lock.
    status: The current status of the session.
    authorization_tokens: Unredeemed client authorization tokens.
    active_clients: Information about active clients, keyed by authorization
      token.
    pending_waits: Information for in-progress wait calls on this session.
  """
  requirements: AggregationRequirements
  callback: _AggregationProtocolCallback
  agg_protocol: aggregation_protocol.AggregationProtocol
  status: AggregationStatus = AggregationStatus.PENDING
  authorization_tokens: set[str] = dataclasses.field(default_factory=set)
  active_clients: dict[str, _ActiveClientData] = dataclasses.field(
      default_factory=dict)
  pending_waits: set[_WaitData] = dataclasses.field(default_factory=set)
190*14675a02SAndroid Build Coastguard Worker
191*14675a02SAndroid Build Coastguard Worker
192*14675a02SAndroid Build Coastguard Workerclass Service:
193*14675a02SAndroid Build Coastguard Worker  """Implements the Aggregations service."""
194*14675a02SAndroid Build Coastguard Worker
195*14675a02SAndroid Build Coastguard Worker  def __init__(self, forwarding_info: Callable[[], common_pb2.ForwardingInfo],
196*14675a02SAndroid Build Coastguard Worker               media_service: media.Service):
197*14675a02SAndroid Build Coastguard Worker    self._forwarding_info = forwarding_info
198*14675a02SAndroid Build Coastguard Worker    self._media_service = media_service
199*14675a02SAndroid Build Coastguard Worker    self._sessions: dict[str, _AggregationSessionState] = {}
200*14675a02SAndroid Build Coastguard Worker    self._sessions_lock = threading.Lock()
201*14675a02SAndroid Build Coastguard Worker
202*14675a02SAndroid Build Coastguard Worker  def create_session(self,
203*14675a02SAndroid Build Coastguard Worker                     aggregation_requirements: AggregationRequirements) -> str:
204*14675a02SAndroid Build Coastguard Worker    """Creates a new aggregation session and returns its id."""
205*14675a02SAndroid Build Coastguard Worker    session_id = str(uuid.uuid4())
206*14675a02SAndroid Build Coastguard Worker    callback = _AggregationProtocolCallback(
207*14675a02SAndroid Build Coastguard Worker        functools.partial(self._handle_protocol_abort, session_id))
208*14675a02SAndroid Build Coastguard Worker    if (len(aggregation_requirements.plan.phase) != 1 or
209*14675a02SAndroid Build Coastguard Worker        not aggregation_requirements.plan.phase[0].HasField('server_phase_v2')):
210*14675a02SAndroid Build Coastguard Worker      raise ValueError('Plan must contain exactly one server_phase_v2.')
211*14675a02SAndroid Build Coastguard Worker
212*14675a02SAndroid Build Coastguard Worker    # NOTE: For simplicity, this implementation only creates a single,
213*14675a02SAndroid Build Coastguard Worker    # in-process aggregation shard. In a production implementation, there should
214*14675a02SAndroid Build Coastguard Worker    # be multiple shards running on separate servers to enable high rates of
215*14675a02SAndroid Build Coastguard Worker    # client contributions. Utilities for combining results from separate shards
216*14675a02SAndroid Build Coastguard Worker    # are still in development as of Jan 2023.
217*14675a02SAndroid Build Coastguard Worker    agg_protocol = aggregation_protocols.create_simple_aggregation_protocol(
218*14675a02SAndroid Build Coastguard Worker        configuration_pb2.Configuration(aggregation_configs=[
219*14675a02SAndroid Build Coastguard Worker            self._translate_server_aggregation_config(aggregation_config)
220*14675a02SAndroid Build Coastguard Worker            for aggregation_config in
221*14675a02SAndroid Build Coastguard Worker            aggregation_requirements.plan.phase[0].server_phase_v2.aggregations
222*14675a02SAndroid Build Coastguard Worker        ]), callback)
223*14675a02SAndroid Build Coastguard Worker    agg_protocol.Start(0)
224*14675a02SAndroid Build Coastguard Worker
225*14675a02SAndroid Build Coastguard Worker    with self._sessions_lock:
226*14675a02SAndroid Build Coastguard Worker      self._sessions[session_id] = _AggregationSessionState(
227*14675a02SAndroid Build Coastguard Worker          requirements=aggregation_requirements,
228*14675a02SAndroid Build Coastguard Worker          callback=callback,
229*14675a02SAndroid Build Coastguard Worker          agg_protocol=agg_protocol)
230*14675a02SAndroid Build Coastguard Worker    return session_id
231*14675a02SAndroid Build Coastguard Worker
232*14675a02SAndroid Build Coastguard Worker  def complete_session(
233*14675a02SAndroid Build Coastguard Worker      self, session_id: str) -> tuple[SessionStatus, Optional[bytes]]:
234*14675a02SAndroid Build Coastguard Worker    """Completes the aggregation session and returns its results."""
235*14675a02SAndroid Build Coastguard Worker    with self._sessions_lock:
236*14675a02SAndroid Build Coastguard Worker      state = self._sessions.pop(session_id)
237*14675a02SAndroid Build Coastguard Worker
238*14675a02SAndroid Build Coastguard Worker    try:
239*14675a02SAndroid Build Coastguard Worker      # Only complete the AggregationProtocol if it's still pending. The most
240*14675a02SAndroid Build Coastguard Worker      # likely alternative is that it's ABORTED due to an error generated by the
241*14675a02SAndroid Build Coastguard Worker      # protocol itself.
242*14675a02SAndroid Build Coastguard Worker      status = self._get_session_status(state)
243*14675a02SAndroid Build Coastguard Worker      if status.status != AggregationStatus.PENDING:
244*14675a02SAndroid Build Coastguard Worker        return self._get_session_status(state), None
245*14675a02SAndroid Build Coastguard Worker
246*14675a02SAndroid Build Coastguard Worker      # Ensure privacy requirements have been met.
247*14675a02SAndroid Build Coastguard Worker      if (state.agg_protocol.GetStatus().num_inputs_aggregated_and_included <
248*14675a02SAndroid Build Coastguard Worker          state.requirements.minimum_clients_in_server_published_aggregate):
249*14675a02SAndroid Build Coastguard Worker        state.agg_protocol.Abort()
250*14675a02SAndroid Build Coastguard Worker        raise ValueError(
251*14675a02SAndroid Build Coastguard Worker            'minimum_clients_in_server_published_aggregate has not been met.')
252*14675a02SAndroid Build Coastguard Worker
253*14675a02SAndroid Build Coastguard Worker      state.agg_protocol.Complete()
254*14675a02SAndroid Build Coastguard Worker      result = state.callback.result.get(timeout=1)
255*14675a02SAndroid Build Coastguard Worker      if isinstance(result, absl_status.Status):
256*14675a02SAndroid Build Coastguard Worker        raise absl_status.StatusNotOk(result)
257*14675a02SAndroid Build Coastguard Worker      state.status = AggregationStatus.COMPLETED
258*14675a02SAndroid Build Coastguard Worker      return self._get_session_status(state), result
259*14675a02SAndroid Build Coastguard Worker    except (ValueError, absl_status.StatusNotOk, queue.Empty) as e:
260*14675a02SAndroid Build Coastguard Worker      logging.warning('Failed to complete aggregation session: %s', e)
261*14675a02SAndroid Build Coastguard Worker      state.status = AggregationStatus.FAILED
262*14675a02SAndroid Build Coastguard Worker      return self._get_session_status(state), None
263*14675a02SAndroid Build Coastguard Worker    finally:
264*14675a02SAndroid Build Coastguard Worker      self._cleanup_session(state)
265*14675a02SAndroid Build Coastguard Worker
266*14675a02SAndroid Build Coastguard Worker  def abort_session(self, session_id: str) -> SessionStatus:
267*14675a02SAndroid Build Coastguard Worker    """Aborts/cancels an aggregation session."""
268*14675a02SAndroid Build Coastguard Worker    with self._sessions_lock:
269*14675a02SAndroid Build Coastguard Worker      state = self._sessions.pop(session_id)
270*14675a02SAndroid Build Coastguard Worker
271*14675a02SAndroid Build Coastguard Worker    # Only abort the AggregationProtocol if it's still pending. The most likely
272*14675a02SAndroid Build Coastguard Worker    # alternative is that it's ABORTED due to an error generated by the protocol
273*14675a02SAndroid Build Coastguard Worker    # itself.
274*14675a02SAndroid Build Coastguard Worker    if state.status == AggregationStatus.PENDING:
275*14675a02SAndroid Build Coastguard Worker      state.status = AggregationStatus.ABORTED
276*14675a02SAndroid Build Coastguard Worker      state.agg_protocol.Abort()
277*14675a02SAndroid Build Coastguard Worker
278*14675a02SAndroid Build Coastguard Worker    self._cleanup_session(state)
279*14675a02SAndroid Build Coastguard Worker    return self._get_session_status(state)
280*14675a02SAndroid Build Coastguard Worker
281*14675a02SAndroid Build Coastguard Worker  def get_session_status(self, session_id: str) -> SessionStatus:
282*14675a02SAndroid Build Coastguard Worker    """Returns the status of an aggregation session."""
283*14675a02SAndroid Build Coastguard Worker    with self._sessions_lock:
284*14675a02SAndroid Build Coastguard Worker      return self._get_session_status(self._sessions[session_id])
285*14675a02SAndroid Build Coastguard Worker
286*14675a02SAndroid Build Coastguard Worker  async def wait(
287*14675a02SAndroid Build Coastguard Worker      self,
288*14675a02SAndroid Build Coastguard Worker      session_id: str,
289*14675a02SAndroid Build Coastguard Worker      num_inputs_aggregated_and_included: Optional[int] = None
290*14675a02SAndroid Build Coastguard Worker  ) -> SessionStatus:
291*14675a02SAndroid Build Coastguard Worker    """Blocks until all conditions are satisfied or the aggregation fails."""
292*14675a02SAndroid Build Coastguard Worker    with self._sessions_lock:
293*14675a02SAndroid Build Coastguard Worker      state = self._sessions[session_id]
294*14675a02SAndroid Build Coastguard Worker      # Check if any of the conditions are already satisfied.
295*14675a02SAndroid Build Coastguard Worker      status = self._get_session_status(state)
296*14675a02SAndroid Build Coastguard Worker      if (num_inputs_aggregated_and_included is None or
297*14675a02SAndroid Build Coastguard Worker          num_inputs_aggregated_and_included <=
298*14675a02SAndroid Build Coastguard Worker          status.num_inputs_aggregated_and_included):
299*14675a02SAndroid Build Coastguard Worker        return status
300*14675a02SAndroid Build Coastguard Worker
301*14675a02SAndroid Build Coastguard Worker      wait_data = _WaitData(num_inputs_aggregated_and_included)
302*14675a02SAndroid Build Coastguard Worker      state.pending_waits.add(wait_data)
303*14675a02SAndroid Build Coastguard Worker    return await wait_data.status_future
304*14675a02SAndroid Build Coastguard Worker
305*14675a02SAndroid Build Coastguard Worker  def pre_authorize_clients(self, session_id: str,
306*14675a02SAndroid Build Coastguard Worker                            num_tokens: int) -> Sequence[str]:
307*14675a02SAndroid Build Coastguard Worker    """Generates tokens authorizing clients to contribute to the session."""
308*14675a02SAndroid Build Coastguard Worker    tokens = set(str(uuid.uuid4()) for _ in range(num_tokens))
309*14675a02SAndroid Build Coastguard Worker    with self._sessions_lock:
310*14675a02SAndroid Build Coastguard Worker      self._sessions[session_id].authorization_tokens |= tokens
311*14675a02SAndroid Build Coastguard Worker    return list(tokens)
312*14675a02SAndroid Build Coastguard Worker
313*14675a02SAndroid Build Coastguard Worker  def _translate_intrinsic_arg(
314*14675a02SAndroid Build Coastguard Worker      self, intrinsic_arg: plan_pb2.ServerAggregationConfig.IntrinsicArg
315*14675a02SAndroid Build Coastguard Worker  ) -> configuration_pb2.Configuration.ServerAggregationConfig.IntrinsicArg:
316*14675a02SAndroid Build Coastguard Worker    """Transform an aggregation intrinsic arg for the aggregation service."""
317*14675a02SAndroid Build Coastguard Worker    if intrinsic_arg.HasField('input_tensor'):
318*14675a02SAndroid Build Coastguard Worker      return configuration_pb2.Configuration.ServerAggregationConfig.IntrinsicArg(
319*14675a02SAndroid Build Coastguard Worker          input_tensor=intrinsic_arg.input_tensor)
320*14675a02SAndroid Build Coastguard Worker    elif intrinsic_arg.HasField('state_tensor'):
321*14675a02SAndroid Build Coastguard Worker      raise ValueError(
322*14675a02SAndroid Build Coastguard Worker          'Non-client intrinsic args are not supported in this demo.'
323*14675a02SAndroid Build Coastguard Worker      )
324*14675a02SAndroid Build Coastguard Worker    else:
325*14675a02SAndroid Build Coastguard Worker      raise AssertionError(
326*14675a02SAndroid Build Coastguard Worker          'Cases should have exhausted all possible types of intrinsic args.')
327*14675a02SAndroid Build Coastguard Worker
328*14675a02SAndroid Build Coastguard Worker  def _translate_server_aggregation_config(
329*14675a02SAndroid Build Coastguard Worker      self, plan_aggregation_config: plan_pb2.ServerAggregationConfig
330*14675a02SAndroid Build Coastguard Worker  ) -> configuration_pb2.Configuration.ServerAggregationConfig:
331*14675a02SAndroid Build Coastguard Worker    """Transform the aggregation config for use by the aggregation service."""
332*14675a02SAndroid Build Coastguard Worker    if plan_aggregation_config.inner_aggregations:
333*14675a02SAndroid Build Coastguard Worker      raise AssertionError('Nested intrinsic structrues are not supported yet.')
334*14675a02SAndroid Build Coastguard Worker    return configuration_pb2.Configuration.ServerAggregationConfig(
335*14675a02SAndroid Build Coastguard Worker        intrinsic_uri=plan_aggregation_config.intrinsic_uri,
336*14675a02SAndroid Build Coastguard Worker        intrinsic_args=[
337*14675a02SAndroid Build Coastguard Worker            self._translate_intrinsic_arg(intrinsic_arg)
338*14675a02SAndroid Build Coastguard Worker            for intrinsic_arg in plan_aggregation_config.intrinsic_args
339*14675a02SAndroid Build Coastguard Worker        ],
340*14675a02SAndroid Build Coastguard Worker        output_tensors=plan_aggregation_config.output_tensors)
341*14675a02SAndroid Build Coastguard Worker
342*14675a02SAndroid Build Coastguard Worker  def _get_session_status(self,
343*14675a02SAndroid Build Coastguard Worker                          state: _AggregationSessionState) -> SessionStatus:
344*14675a02SAndroid Build Coastguard Worker    """Returns the SessionStatus for an _AggregationSessionState object."""
345*14675a02SAndroid Build Coastguard Worker    status = state.agg_protocol.GetStatus()
346*14675a02SAndroid Build Coastguard Worker    return SessionStatus(
347*14675a02SAndroid Build Coastguard Worker        status=state.status,
348*14675a02SAndroid Build Coastguard Worker        num_clients_completed=status.num_clients_completed,
349*14675a02SAndroid Build Coastguard Worker        num_clients_failed=status.num_clients_failed,
350*14675a02SAndroid Build Coastguard Worker        num_clients_pending=status.num_clients_pending,
351*14675a02SAndroid Build Coastguard Worker        num_clients_aborted=status.num_clients_aborted,
352*14675a02SAndroid Build Coastguard Worker        num_inputs_aggregated_and_included=(
353*14675a02SAndroid Build Coastguard Worker            status.num_inputs_aggregated_and_included),
354*14675a02SAndroid Build Coastguard Worker        num_inputs_aggregated_and_pending=(
355*14675a02SAndroid Build Coastguard Worker            status.num_inputs_aggregated_and_pending),
356*14675a02SAndroid Build Coastguard Worker        num_inputs_discarded=status.num_inputs_discarded)
357*14675a02SAndroid Build Coastguard Worker
358*14675a02SAndroid Build Coastguard Worker  def _get_http_status(self, code: absl_status.StatusCode) -> http.HTTPStatus:
359*14675a02SAndroid Build Coastguard Worker    """Returns the HTTPStatus code for an absl StatusCode."""
360*14675a02SAndroid Build Coastguard Worker    if (code == absl_status.StatusCode.INVALID_ARGUMENT or
361*14675a02SAndroid Build Coastguard Worker        code == absl_status.StatusCode.FAILED_PRECONDITION):
362*14675a02SAndroid Build Coastguard Worker      return http.HTTPStatus.BAD_REQUEST
363*14675a02SAndroid Build Coastguard Worker    elif code == absl_status.StatusCode.NOT_FOUND:
364*14675a02SAndroid Build Coastguard Worker      return http.HTTPStatus.NOT_FOUND
365*14675a02SAndroid Build Coastguard Worker    else:
366*14675a02SAndroid Build Coastguard Worker      return http.HTTPStatus.INTERNAL_SERVER_ERROR
367*14675a02SAndroid Build Coastguard Worker
368*14675a02SAndroid Build Coastguard Worker  def _cleanup_session(self, state: _AggregationSessionState) -> None:
369*14675a02SAndroid Build Coastguard Worker    """Cleans up the session and releases any resources.
370*14675a02SAndroid Build Coastguard Worker
371*14675a02SAndroid Build Coastguard Worker    Args:
372*14675a02SAndroid Build Coastguard Worker      state: The session state to clean up.
373*14675a02SAndroid Build Coastguard Worker    """
374*14675a02SAndroid Build Coastguard Worker    state.authorization_tokens.clear()
375*14675a02SAndroid Build Coastguard Worker    for client_data in state.active_clients.values():
376*14675a02SAndroid Build Coastguard Worker      self._media_service.finalize_upload(client_data.resource_name)
377*14675a02SAndroid Build Coastguard Worker    state.active_clients.clear()
378*14675a02SAndroid Build Coastguard Worker    # Anyone waiting on the session should be notified that it's finished.
379*14675a02SAndroid Build Coastguard Worker    if state.pending_waits:
380*14675a02SAndroid Build Coastguard Worker      status = self._get_session_status(state)
381*14675a02SAndroid Build Coastguard Worker      for data in state.pending_waits:
382*14675a02SAndroid Build Coastguard Worker        data.loop.call_soon_threadsafe(
383*14675a02SAndroid Build Coastguard Worker            functools.partial(data.status_future.set_result, status))
384*14675a02SAndroid Build Coastguard Worker      state.pending_waits.clear()
385*14675a02SAndroid Build Coastguard Worker
386*14675a02SAndroid Build Coastguard Worker  def _handle_protocol_abort(self, session_id: str) -> None:
387*14675a02SAndroid Build Coastguard Worker    """Notifies waiting clients when the protocol is aborted."""
388*14675a02SAndroid Build Coastguard Worker    with self._sessions_lock:
389*14675a02SAndroid Build Coastguard Worker      with contextlib.suppress(KeyError):
390*14675a02SAndroid Build Coastguard Worker        state = self._sessions[session_id]
391*14675a02SAndroid Build Coastguard Worker        state.status = AggregationStatus.FAILED
392*14675a02SAndroid Build Coastguard Worker        # Anyone waiting on the session should be notified it's been aborted.
393*14675a02SAndroid Build Coastguard Worker        if state.pending_waits:
394*14675a02SAndroid Build Coastguard Worker          status = self._get_session_status(state)
395*14675a02SAndroid Build Coastguard Worker          for data in state.pending_waits:
396*14675a02SAndroid Build Coastguard Worker            data.loop.call_soon_threadsafe(
397*14675a02SAndroid Build Coastguard Worker                functools.partial(data.status_future.set_result, status))
398*14675a02SAndroid Build Coastguard Worker          state.pending_waits.clear()
399*14675a02SAndroid Build Coastguard Worker
400*14675a02SAndroid Build Coastguard Worker  @http_actions.proto_action(
401*14675a02SAndroid Build Coastguard Worker      service='google.internal.federatedcompute.v1.Aggregations',
402*14675a02SAndroid Build Coastguard Worker      method='StartAggregationDataUpload')
403*14675a02SAndroid Build Coastguard Worker  def start_aggregation_data_upload(
404*14675a02SAndroid Build Coastguard Worker      self, request: aggregations_pb2.StartAggregationDataUploadRequest
405*14675a02SAndroid Build Coastguard Worker  ) -> operations_pb2.Operation:
406*14675a02SAndroid Build Coastguard Worker    """Handles a StartAggregationDataUpload request."""
407*14675a02SAndroid Build Coastguard Worker    with self._sessions_lock:
408*14675a02SAndroid Build Coastguard Worker      try:
409*14675a02SAndroid Build Coastguard Worker        state = self._sessions[request.aggregation_id]
410*14675a02SAndroid Build Coastguard Worker      except KeyError as e:
411*14675a02SAndroid Build Coastguard Worker        raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND) from e
412*14675a02SAndroid Build Coastguard Worker      try:
413*14675a02SAndroid Build Coastguard Worker        state.authorization_tokens.remove(request.authorization_token)
414*14675a02SAndroid Build Coastguard Worker      except KeyError as e:
415*14675a02SAndroid Build Coastguard Worker        raise http_actions.HttpError(http.HTTPStatus.UNAUTHORIZED) from e
416*14675a02SAndroid Build Coastguard Worker
417*14675a02SAndroid Build Coastguard Worker    state.agg_protocol.AddClients(1)
418*14675a02SAndroid Build Coastguard Worker    client_token = str(uuid.uuid4())
419*14675a02SAndroid Build Coastguard Worker    client_id, close_status = state.callback.accepted_clients.get(timeout=1)
420*14675a02SAndroid Build Coastguard Worker    upload_name = self._media_service.register_upload()
421*14675a02SAndroid Build Coastguard Worker
422*14675a02SAndroid Build Coastguard Worker    with self._sessions_lock:
423*14675a02SAndroid Build Coastguard Worker      state.active_clients[client_token] = _ActiveClientData(
424*14675a02SAndroid Build Coastguard Worker          client_id, close_status, upload_name)
425*14675a02SAndroid Build Coastguard Worker
426*14675a02SAndroid Build Coastguard Worker    forwarding_info = self._forwarding_info()
427*14675a02SAndroid Build Coastguard Worker    response = aggregations_pb2.StartAggregationDataUploadResponse(
428*14675a02SAndroid Build Coastguard Worker        aggregation_protocol_forwarding_info=forwarding_info,
429*14675a02SAndroid Build Coastguard Worker        resource=common_pb2.ByteStreamResource(
430*14675a02SAndroid Build Coastguard Worker            data_upload_forwarding_info=forwarding_info,
431*14675a02SAndroid Build Coastguard Worker            resource_name=upload_name),
432*14675a02SAndroid Build Coastguard Worker        client_token=client_token)
433*14675a02SAndroid Build Coastguard Worker
434*14675a02SAndroid Build Coastguard Worker    op = operations_pb2.Operation(name=f'operations/{uuid.uuid4()}', done=True)
435*14675a02SAndroid Build Coastguard Worker    op.metadata.Pack(aggregations_pb2.StartAggregationDataUploadMetadata())
436*14675a02SAndroid Build Coastguard Worker    op.response.Pack(response)
437*14675a02SAndroid Build Coastguard Worker    return op
438*14675a02SAndroid Build Coastguard Worker
  @http_actions.proto_action(
      service='google.internal.federatedcompute.v1.Aggregations',
      method='SubmitAggregationResult')
  def submit_aggregation_result(
      self, request: aggregations_pb2.SubmitAggregationResultRequest
  ) -> aggregations_pb2.SubmitAggregationResultResponse:
    """Handles a SubmitAggregationResult request.

    Finalizes the client's upload, passes the uploaded bytes to the session's
    aggregation protocol as a simple-aggregation ClientMessage, waits for the
    client's close status, and then resolves any pending waits whose
    num_inputs_aggregated_and_included threshold has been reached.

    Note that the client token is removed from the session's active clients
    before any further validation, so it cannot be reused even when this
    request fails.

    Args:
      request: The request naming the aggregation session, the client token,
        and the resource name the result was uploaded to.

    Returns:
      An empty SubmitAggregationResultResponse.

    Raises:
      http_actions.HttpError: NOT_FOUND if the session is unknown,
        UNAUTHORIZED if the client token is not active in the session,
        BAD_REQUEST if the resource name differs from the one issued to the
        client, INTERNAL_SERVER_ERROR if the protocol rejects the message, or
        the status mapped by _get_http_status when the upload cannot be
        finalized or the client is closed with a non-OK status.
    """
    with self._sessions_lock:
      try:
        state = self._sessions[request.aggregation_id]
      except KeyError as e:
        raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND) from e
      try:
        # Popping consumes the client token regardless of what happens next.
        client_data = state.active_clients.pop(request.client_token)
      except KeyError as e:
        raise http_actions.HttpError(http.HTTPStatus.UNAUTHORIZED) from e

    # Ensure the client is using the resource name provided when they called
    # StartAggregationDataUpload.
    if request.resource_name != client_data.resource_name:
      raise http_actions.HttpError(http.HTTPStatus.BAD_REQUEST)

    # The aggregation protocol may have already closed the connection (e.g.,
    # if an error occurred). If so, clean up the upload and return the error.
    if not client_data.close_status.empty():
      with contextlib.suppress(KeyError):
        self._media_service.finalize_upload(request.resource_name)
      raise http_actions.HttpError(
          self._get_http_status(client_data.close_status.get().code()))

    # Finalize the upload.
    try:
      update = self._media_service.finalize_upload(request.resource_name)
      if update is None:
        raise absl_status.StatusNotOk(
            absl_status.invalid_argument_error(
                'Aggregation result never uploaded'))
    except (KeyError, absl_status.StatusNotOk) as e:
      # A KeyError from finalize_upload is converted into an internal-error
      # status so both failure modes can be reported through `e.status` below.
      if isinstance(e, KeyError):
        e = absl_status.StatusNotOk(
            absl_status.internal_error('Failed to finalize upload'))
      state.agg_protocol.CloseClient(client_data.client_id, e.status)
      # Since we're initiating the close, it's also necessary to notify the
      # _AggregationProtocolCallback so it can clean up resources.
      state.callback.OnCloseClient(client_data.client_id, e.status)
      raise http_actions.HttpError(self._get_http_status(
          e.status.code())) from e

    client_message = apm_pb2.ClientMessage(
        simple_aggregation=apm_pb2.ClientMessage.SimpleAggregation(
            input=apm_pb2.ClientResource(inline_bytes=update)))
    try:
      state.agg_protocol.ReceiveClientMessage(client_data.client_id,
                                              client_message)
    except absl_status.StatusNotOk as e:
      # ReceiveClientMessage should only fail if the AggregationProtocol is in
      # a bad state -- likely leading to it being aborted.
      logging.warning('Failed to receive client input: %s', e)
      raise http_actions.HttpError(http.HTTPStatus.INTERNAL_SERVER_ERROR) from e

    # Wait for the client input to be processed.
    close_status = client_data.close_status.get()
    if not close_status.ok():
      raise http_actions.HttpError(self._get_http_status(close_status.code()))

    # Check for any newly-satisfied pending wait operations.
    with self._sessions_lock:
      if state.pending_waits:
        completed_waits = set()
        status = self._get_session_status(state)
        for data in state.pending_waits:
          # Waits without a threshold are not resolved here.
          if (data.num_inputs_aggregated_and_included is not None and
              status.num_inputs_aggregated_and_included >=
              data.num_inputs_aggregated_and_included):
            data.loop.call_soon_threadsafe(
                functools.partial(data.status_future.set_result, status))
            completed_waits.add(data)
        # Satisfied waits are removed only after the loop so the set is not
        # modified while it is being iterated.
        state.pending_waits -= completed_waits
    return aggregations_pb2.SubmitAggregationResultResponse()
518*14675a02SAndroid Build Coastguard Worker
519*14675a02SAndroid Build Coastguard Worker  @http_actions.proto_action(
520*14675a02SAndroid Build Coastguard Worker      service='google.internal.federatedcompute.v1.Aggregations',
521*14675a02SAndroid Build Coastguard Worker      method='AbortAggregation')
522*14675a02SAndroid Build Coastguard Worker  def abort_aggregation(
523*14675a02SAndroid Build Coastguard Worker      self, request: aggregations_pb2.AbortAggregationRequest
524*14675a02SAndroid Build Coastguard Worker  ) -> aggregations_pb2.AbortAggregationResponse:
525*14675a02SAndroid Build Coastguard Worker    """Handles an AbortAggregation request."""
526*14675a02SAndroid Build Coastguard Worker    with self._sessions_lock:
527*14675a02SAndroid Build Coastguard Worker      try:
528*14675a02SAndroid Build Coastguard Worker        state = self._sessions[request.aggregation_id]
529*14675a02SAndroid Build Coastguard Worker      except KeyError as e:
530*14675a02SAndroid Build Coastguard Worker        raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND) from e
531*14675a02SAndroid Build Coastguard Worker      try:
532*14675a02SAndroid Build Coastguard Worker        client_data = state.active_clients.pop(request.client_token)
533*14675a02SAndroid Build Coastguard Worker      except KeyError as e:
534*14675a02SAndroid Build Coastguard Worker        raise http_actions.HttpError(http.HTTPStatus.UNAUTHORIZED) from e
535*14675a02SAndroid Build Coastguard Worker
536*14675a02SAndroid Build Coastguard Worker    # Attempt to finalize the in-progress upload to free up resources.
537*14675a02SAndroid Build Coastguard Worker    with contextlib.suppress(KeyError):
538*14675a02SAndroid Build Coastguard Worker      self._media_service.finalize_upload(client_data.resource_name)
539*14675a02SAndroid Build Coastguard Worker
540*14675a02SAndroid Build Coastguard Worker    # Notify the aggregation protocol that the client has left.
541*14675a02SAndroid Build Coastguard Worker    if request.status.code == code_pb2.Code.OK:
542*14675a02SAndroid Build Coastguard Worker      status = absl_status.Status.OkStatus()
543*14675a02SAndroid Build Coastguard Worker    else:
544*14675a02SAndroid Build Coastguard Worker      status = absl_status.BuildStatusNotOk(
545*14675a02SAndroid Build Coastguard Worker          absl_status.StatusCodeFromInt(request.status.code),
546*14675a02SAndroid Build Coastguard Worker          request.status.message)
547*14675a02SAndroid Build Coastguard Worker    state.agg_protocol.CloseClient(client_data.client_id, status)
548*14675a02SAndroid Build Coastguard Worker    # Since we're initiating the close, it's also necessary to notify the
549*14675a02SAndroid Build Coastguard Worker    # _AggregationProtocolCallback so it can clean up resources.
550*14675a02SAndroid Build Coastguard Worker    state.callback.OnCloseClient(client_data.client_id, status)
551*14675a02SAndroid Build Coastguard Worker
552*14675a02SAndroid Build Coastguard Worker    logging.debug('[%s] AbortAggregation: %s', request.aggregation_id,
553*14675a02SAndroid Build Coastguard Worker                  request.status)
554*14675a02SAndroid Build Coastguard Worker    return aggregations_pb2.AbortAggregationResponse()
555