xref: /aosp_15_r20/external/federated-compute/fcp/demo/server.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker# Copyright 2022 Google LLC
2*14675a02SAndroid Build Coastguard Worker#
3*14675a02SAndroid Build Coastguard Worker# Licensed under the Apache License, Version 2.0 (the "License");
4*14675a02SAndroid Build Coastguard Worker# you may not use this file except in compliance with the License.
5*14675a02SAndroid Build Coastguard Worker# You may obtain a copy of the License at
6*14675a02SAndroid Build Coastguard Worker#
7*14675a02SAndroid Build Coastguard Worker#      http://www.apache.org/licenses/LICENSE-2.0
8*14675a02SAndroid Build Coastguard Worker#
9*14675a02SAndroid Build Coastguard Worker# Unless required by applicable law or agreed to in writing, software
10*14675a02SAndroid Build Coastguard Worker# distributed under the License is distributed on an "AS IS" BASIS,
11*14675a02SAndroid Build Coastguard Worker# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*14675a02SAndroid Build Coastguard Worker# See the License for the specific language governing permissions and
13*14675a02SAndroid Build Coastguard Worker# limitations under the License.
14*14675a02SAndroid Build Coastguard Worker"""An in-process federated compute server."""
15*14675a02SAndroid Build Coastguard Worker
16*14675a02SAndroid Build Coastguard Workerimport contextlib
17*14675a02SAndroid Build Coastguard Workerimport gzip
18*14675a02SAndroid Build Coastguard Workerimport http.server
19*14675a02SAndroid Build Coastguard Workerimport socket
20*14675a02SAndroid Build Coastguard Workerimport socketserver
21*14675a02SAndroid Build Coastguard Workerimport ssl
22*14675a02SAndroid Build Coastguard Workerfrom typing import Optional
23*14675a02SAndroid Build Coastguard Worker
24*14675a02SAndroid Build Coastguard Workerfrom absl import logging
25*14675a02SAndroid Build Coastguard Worker
26*14675a02SAndroid Build Coastguard Workerfrom fcp.demo import aggregations
27*14675a02SAndroid Build Coastguard Workerfrom fcp.demo import eligibility_eval_tasks
28*14675a02SAndroid Build Coastguard Workerfrom fcp.demo import http_actions
29*14675a02SAndroid Build Coastguard Workerfrom fcp.demo import media
30*14675a02SAndroid Build Coastguard Workerfrom fcp.demo import plan_utils
31*14675a02SAndroid Build Coastguard Workerfrom fcp.demo import task_assignments
32*14675a02SAndroid Build Coastguard Workerfrom fcp.protos import plan_pb2
33*14675a02SAndroid Build Coastguard Workerfrom fcp.protos.federatedcompute import common_pb2
34*14675a02SAndroid Build Coastguard Workerfrom fcp.protos.federatedcompute import eligibility_eval_tasks_pb2
35*14675a02SAndroid Build Coastguard Worker
# Shorthand alias for the nested `TaskAssignmentMode` proto type; used to
# annotate `task_assignment_mode` parameters in this module.
_TaskAssignmentMode = (eligibility_eval_tasks_pb2
                       .PopulationEligibilitySpec
                       .TaskInfo
                       .TaskAssignmentMode)
39*14675a02SAndroid Build Coastguard Worker
# Template for the names under which federated select slices are published.
# See `FederatedSelectUriInfo.uri_template` for the meaning of the
# "{served_at_id}" and "{key_base10}" placeholders.
_FEDERATED_SELECT_NAME_TEMPLATE = '{served_at_id}_{key_base10}'

# Content type used for serialized, gzip-compressed Plan messages.
_PLAN_CONTENT_TYPE = 'application/x-protobuf+gzip'

# Content type used for serialized, gzip-compressed TensorFlow checkpoints
# (also used for the gzip-compressed federated select slices).
_CHECKPOINT_CONTENT_TYPE = 'application/octet-stream+gzip'
50*14675a02SAndroid Build Coastguard Worker
51*14675a02SAndroid Build Coastguard Worker
class InProcessServer(socketserver.ThreadingMixIn, http.server.HTTPServer):
  """An in-process HTTP server implementing the Federated Compute protocol.

  Requests are handled on per-request threads (via
  `socketserver.ThreadingMixIn`) by a handler built from the media,
  aggregations, task assignment and eligibility eval task services owned by
  this server.
  """

  def __init__(self,
               *,
               population_name: str,
               host: str,
               port: int,
               address_family: Optional[socket.AddressFamily] = None):
    """Creates the backing services and binds the server to `(host, port)`.

    Args:
      population_name: Population name passed to the task assignment and
        eligibility eval task services.
      host: The host/address to bind to.
      port: The port to bind to.
      address_family: If set, overrides the socket address family (e.g.
        `socket.AF_INET6`); otherwise the `HTTPServer` default is used.
    """
    self._media_service = media.Service(self._get_forwarding_info)
    self._aggregations_service = aggregations.Service(self._get_forwarding_info,
                                                      self._media_service)
    self._task_assignments_service = task_assignments.Service(
        population_name, self._get_forwarding_info, self._aggregations_service)
    self._eligibility_eval_tasks_service = eligibility_eval_tasks.Service(
        population_name, self._get_forwarding_info
    )
    handler = http_actions.create_handler(
        self._media_service,
        self._aggregations_service,
        self._task_assignments_service,
        self._eligibility_eval_tasks_service,
    )
    if address_family is not None:
      # Must be set before the base constructor creates the listening socket.
      self.address_family = address_family
    super().__init__((host, port), handler)

  async def run_computation(
      self,
      task_name: str,
      plan: plan_pb2.Plan,
      server_checkpoint: bytes,
      task_assignment_mode: _TaskAssignmentMode,
      number_of_clients: int,
  ) -> bytes:
    """Runs a computation, returning the resulting checkpoint.

    The client plan, client checkpoint and any federated select slices are
    published through the media service. The task is then registered with the
    eligibility eval and task assignment services, and this coroutine waits
    for `number_of_clients` inputs to be aggregated before completing the
    aggregation session and finalizing the result.

    NOTE(review): nothing in this method serializes concurrent calls (its only
    `await` is the aggregation wait), so whether overlapping computations are
    supported depends on the underlying services -- confirm.

    Args:
      task_name: The name of the task.
      plan: The Plan proto containing the client and server computations.
      server_checkpoint: The starting server checkpoint.
      task_assignment_mode: The task assignment mode to use for the computation.
      number_of_clients: The minimum number of clients to include.

    Returns:
      A TensorFlow checkpoint containing the aggregated results.

    Raises:
      ValueError: If the aggregation session does not reach the expected
        status or produces no intermediate update.
    """
    requirements = aggregations.AggregationRequirements(
        minimum_clients_in_server_published_aggregate=number_of_clients,
        plan=plan)
    session_id = self._aggregations_service.create_session(requirements)
    with contextlib.ExitStack() as stack:
      # Abort the aggregation session on any exit before it is completed.
      stack.callback(self._aggregations_service.abort_session, session_id)
      with plan_utils.Session(plan, server_checkpoint) as session:
        with self._media_service.create_download_group() as group:
          plan_url = group.add(
              'plan',
              gzip.compress(session.client_plan),
              content_type=_PLAN_CONTENT_TYPE,
          )
          checkpoint_url = group.add(
              'checkpoint',
              gzip.compress(session.client_checkpoint),
              content_type=_CHECKPOINT_CONTENT_TYPE,
          )
          # Publish every federated select slice under a name that matches the
          # URI template passed to the task assignment service below.
          for served_at_id, slices in session.slices.items():
            for i, slice_data in enumerate(slices):
              group.add(
                  _FEDERATED_SELECT_NAME_TEMPLATE.format(
                      served_at_id=served_at_id, key_base10=str(i)
                  ),
                  gzip.compress(slice_data),
                  content_type=_CHECKPOINT_CONTENT_TYPE,
              )

          # Each registration schedules its own removal right away, so a
          # failure part-way through (including in the second `add_task`)
          # never leaves a stale task registered. Removal runs in reverse
          # order: task assignments first, then eligibility eval tasks.
          with contextlib.ExitStack() as task_stack:
            self._eligibility_eval_tasks_service.add_task(
                task_name, task_assignment_mode
            )
            task_stack.callback(
                self._eligibility_eval_tasks_service.remove_task, task_name
            )
            self._task_assignments_service.add_task(
                task_name,
                task_assignment_mode,
                session_id,
                common_pb2.Resource(uri=plan_url),
                common_pb2.Resource(uri=checkpoint_url),
                group.prefix + _FEDERATED_SELECT_NAME_TEMPLATE,
            )
            task_stack.callback(
                self._task_assignments_service.remove_task, session_id
            )
            status = await self._aggregations_service.wait(
                session_id,
                num_inputs_aggregated_and_included=number_of_clients)
            # Anything other than PENDING is treated as failure. Presumably
            # PENDING means enough inputs arrived and the session is awaiting
            # completion -- confirm against `aggregations.Service.wait`.
            if status.status != aggregations.AggregationStatus.PENDING:
              raise ValueError('Aggregation failed.')

          # The session is about to be completed, so it must no longer be
          # aborted on exit: drop the pending abort callback.
          stack.pop_all()
          status, intermediate_update = (
              self._aggregations_service.complete_session(session_id))
          if (status.status != aggregations.AggregationStatus.COMPLETED or
              intermediate_update is None):
            raise ValueError('Aggregation failed.')
        logging.debug('%s aggregation complete: %s', task_name, status)
        return session.finalize(intermediate_update)

  def _get_forwarding_info(self) -> common_pb2.ForwardingInfo:
    """Returns the URI prefix at which this server can be reached.

    The scheme is `https` when the listening socket has been wrapped in an
    `ssl.SSLSocket`, and `http` otherwise.
    """
    protocol = 'https' if isinstance(self.socket, ssl.SSLSocket) else 'http'
    host = self.server_name
    # IPv6 literals must be enclosed in brackets inside a URI (RFC 3986,
    # section 3.2.2); otherwise their colons clash with the port separator.
    if ':' in host:
      host = f'[{host}]'
    return common_pb2.ForwardingInfo(
        target_uri_prefix=f'{protocol}://{host}:{self.server_port}/')
165