# xref: /aosp_15_r20/external/federated-compute/fcp/demo/federated_context.py
# (revision 14675a029014e728ec732f129a32e299b2da0601)
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14"""TFF FederatedContext subclass for the demo Federated Computation platform."""
15
16from collections.abc import Awaitable
17import socket
18import ssl
19import threading
20from typing import Any, Optional, Union
21import uuid
22
23from absl import logging
24import attr
25import numpy as np
26import tensorflow as tf
27import tensorflow_federated as tff
28import tree
29
30from fcp.artifact_building import artifact_constants
31from fcp.artifact_building import checkpoint_utils
32from fcp.artifact_building import data_spec
33from fcp.artifact_building import federated_compute_plan_builder
34from fcp.artifact_building import plan_utils
35from fcp.artifact_building import variable_helpers
36from fcp.demo import checkpoint_tensor_reference
37from fcp.demo import federated_computation
38from fcp.demo import federated_data_source
39from fcp.demo import server
40from fcp.protos import plan_pb2
41
42
class FederatedContext(tff.program.FederatedContext):
  """A FederatedContext for use with the demo platform.

  Runs an in-process Federated Compute server for a single population and
  translates TFF `FederatedComputation` invocations into demo-server plan
  executions. Non-federated computations are delegated to `base_context`.
  """

  def __init__(self,
               population_name: str,
               *,
               base_context: Optional[tff.framework.SyncContext] = None,
               host: str = 'localhost',
               port: int = 0,
               certfile: Optional[str] = None,
               keyfile: Optional[str] = None,
               address_family: Optional[socket.AddressFamily] = None):
    """Initializes a `FederatedContext`.

    Args:
      population_name: The name of the population to execute computations on.
      base_context: The context used to run non-federated TFF computations
        (i.e., computations with a type other than FederatedComputation).
      host: The hostname the server should bind to.
      port: The port the server should listen on.
      certfile: The path to the certificate to use for https.
      keyfile: The path to the certificate's private key (if separate).
      address_family: An override for the HTTP server's address family.

    Raises:
      ValueError: If `population_name` does not match the required pattern.
    """
    # NOTE: The demo server only supports a single population, which must be
    # specified at startup. An implementation that supports multiple populations
    # should only use the population name from the PopulationDataSource.
    if not federated_data_source.POPULATION_NAME_REGEX.fullmatch(
        population_name):
      raise ValueError(
          'population_name must match '
          f'"{federated_data_source.POPULATION_NAME_REGEX.pattern}".')
    self._population_name = population_name
    self._base_context = base_context
    self._server = server.InProcessServer(
        population_name=population_name,
        host=host,
        port=port,
        address_family=address_family)
    if certfile is not None:
      # Serve over https: wrap the listening socket with TLS before the
      # server thread starts accepting connections.
      context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
      context.load_cert_chain(certfile, keyfile)
      self._server.socket = context.wrap_socket(
          self._server.socket, server_side=True)
    self._server_thread = threading.Thread(
        target=self._server.serve_forever, daemon=True)
    # Cache of built plans, keyed by (computation, id of example selector).
    self._cached_comps: dict[tuple[tff.Computation, int], plan_pb2.Plan] = {}

  @property
  def server_port(self) -> int:
    """The port on which the Federated Compute server is running."""
    return self._server.server_port

  def __enter__(self):
    """Starts the Federated Compute server thread."""
    self._server_thread.start()
    logging.log(logging.INFO, 'Federated Compute server running on %s:%s',
                self._server.server_name, self._server.server_port)
    return self

  def __exit__(self, exc_type, exc_value, tb):
    """Shuts down the server and waits for its thread to exit."""
    self._server.shutdown()
    self._server_thread.join()
    logging.log(logging.INFO, 'Federated Compute server stopped')

  def invoke(self, comp: tff.Computation, arg: Any) -> Any:
    """Invokes a computation.

    Args:
      comp: The computation being invoked.
      arg: The arguments of the call encoded in a computation-specific way. For
        FederatedComputations, this should be a `(state, config)` tuple, where
        the state is a possibly nested structure and the configuration is
        provided by a FederatedDataSource.

    Returns:
      A value reference structure representing the result of the computation.

    Raises:
      TypeError: If `comp` is not a FederatedComputation and no `base_context`
        was provided.
      ValueError: If the config's population doesn't match this context's.
    """
    # Pass other computation types to the next FederatedContext.
    if not isinstance(comp, federated_computation.FederatedComputation):
      if not self._base_context:
        raise TypeError('computation must be a FederatedComputation if no '
                        'base_context was provided.')
      return self._base_context.invoke(comp, arg)

    state, config = self._parse_arg(arg)
    if config.population_name != self._population_name:
      raise ValueError('FederatedDataSource and FederatedContext '
                       'population_names must match.')

    # Since building the plan can be slow, we cache the result to speed up
    # subsequent invocations.
    cache_key = (comp.wrapped_computation, id(config.example_selector))
    try:
      plan = self._cached_comps[cache_key]
    except KeyError:
      plan = federated_compute_plan_builder.build_plan(
          comp.map_reduce_form,
          comp.distribute_aggregate_form,
          self._get_nested_data_spec(config.example_selector),
          grappler_config=tf.compat.v1.ConfigProto(),
          generate_server_phase_v2=True,
      )
      # Add the TF Lite flatbuffer to the plan. If the conversion fails, the
      # flatbuffer will be silently omitted and the client will use the
      # TensorFlow graph in `plan.client_graph_bytes` instead.
      # NOTE: If conversion failures should not be silent, pass
      # `forgive_tflite_conversion_failure=False`.
      plan = plan_utils.generate_and_add_flat_buffer_to_plan(plan)
      self._cached_comps[cache_key] = plan

    # `_run_computation` is async; the coroutine is not awaited here but is
    # wrapped as a shared awaitable inside the result's tensor references.
    checkpoint_future = self._run_computation(comp.name, config, plan,
                                              comp.type_signature.parameter[0],
                                              state)
    result_value_ref = self._create_tensor_reference_struct(
        comp.type_signature.result, checkpoint_future)
    return tff.types.type_to_py_container(result_value_ref,
                                          comp.type_signature.result)

  def _is_state_structure_of_allowed_types(
      self,
      structure: Union[
          tff.structure.Struct,
          tf.Tensor,
          tff.program.MaterializableValue,
      ],
  ) -> bool:
    """Checks if each node in `structure` is an allowed type for `state`."""
    if isinstance(structure, tff.structure.Struct):
      structure = tff.structure.flatten(structure)
    else:
      structure = tree.flatten(structure)
    for item in structure:
      if not (
          tf.is_tensor(item)
          or isinstance(
              item,
              (
                  np.ndarray,
                  np.number,
                  int,
                  float,
                  str,
                  bytes,
                  tff.program.MaterializableValueReference,
              ),
          )
      ):
        return False
    return True

  def _parse_arg(
      self, arg: tff.structure.Struct
  ) -> tuple[Union[tff.structure.Struct, tf.Tensor,
                   tff.program.MaterializableValueReference],
             federated_data_source.DataSelectionConfig]:
    """Parses and validates the invoke arguments.

    Args:
      arg: The `(state, config)` structure passed to `invoke`.

    Returns:
      A `(state, config)` tuple with the state normalized (attrs classes
      converted to TFF structs, bare scalars converted to `tf.Tensor`s).

    Raises:
      ValueError: If `arg` does not have exactly two elements.
      TypeError: If the state or config has an unsupported type.
    """
    if len(arg) != 2:
      raise ValueError(f'The argument structure is unsupported: {arg}.')

    state, config = arg
    if attr.has(type(state)):
      state = tff.structure.from_container(state, recursive=True)
    if not self._is_state_structure_of_allowed_types(state):
      raise TypeError(
          'arg[0] must be a value or structure of values of '
          '`MaterializableValueReference`s, `tf.Tensor`s, '
          '`np.ndarray`s, `np.number`s, or Python scalars. Got: '
          f'{tf.nest.map_structure(type, state)!r})'
      )

    # Code below assumes single values are always `tf.Tensor`s.
    if isinstance(state, (int, float, str, bytes, np.ndarray, np.number)):
      state = tf.convert_to_tensor(state)

    if not isinstance(config, federated_data_source.DataSelectionConfig):
      raise TypeError('arg[1] must be the result of '
                      'FederatedDataSource.iterator().select().')
    return state, config

  def _get_nested_data_spec(self, example_selector) -> data_spec.NestedDataSpec:
    """Converts a NestedExampleSelector to a NestedDataSpec."""
    if isinstance(example_selector, dict):
      return {
          k: self._get_nested_data_spec(v) for k, v in example_selector.items()
      }
    return data_spec.DataSpec(example_selector)

  async def _run_computation(
      self, name: str, config: federated_data_source.DataSelectionConfig,
      plan: plan_pb2.Plan, input_type: tff.Type,
      input_state: Union[tff.structure.Struct, tf.Tensor,
                         tff.program.MaterializableValueReference]
  ) -> bytes:
    """Prepares and runs a computation using the demo server.

    Args:
      name: The name of the computation (used for logging).
      config: The data-selection configuration from the FederatedDataSource.
      plan: The plan to execute.
      input_type: The TFF type of the computation's state input.
      input_state: The (possibly reference-containing) input state.

    Returns:
      The result checkpoint produced by the server, as bytes.
    """
    input_checkpoint = self._state_to_checkpoint(
        input_type, await self._resolve_value_references(input_state))
    try:
      logging.log(logging.INFO, 'Started running %s', name)
      return await self._server.run_computation(
          name,
          plan,
          input_checkpoint,
          config.task_assignment_mode,
          config.num_clients,
      )
    finally:
      logging.log(logging.INFO, 'Finished running %s', name)

  async def _resolve_value_references(
      self, structure: Union[tff.structure.Struct, tf.Tensor,
                             tff.program.MaterializableValueReference]
  ) -> Union[tff.structure.Struct, tf.Tensor]:
    """Dereferences any MaterializableValueReferences in a struct."""
    if isinstance(structure, tff.program.MaterializableValueReference):
      return await structure.get_value()  # pytype: disable=bad-return-type  # numpy-scalars
    elif tf.is_tensor(structure):
      return structure
    elif isinstance(structure, tff.structure.Struct):
      # Each recursive call must be awaited here; without `await`, unawaited
      # coroutine objects (not resolved values) would be packed back into the
      # struct and later fail checkpoint serialization.
      s = [
          await self._resolve_value_references(x)
          for x in tff.structure.flatten(structure)
      ]
      return tff.structure.pack_sequence_as(structure, s)
    else:
      # The structure being resolved is the *state*, which `invoke` receives
      # as arg[0] (see `_parse_arg`).
      raise ValueError(
          'arg[0] must be a struct, Tensor, or MaterializableValueReference.')

  def _state_to_checkpoint(
      self, state_type: tff.Type, state: Union[tff.structure.Struct,
                                               tf.Tensor]) -> bytes:
    """Converts computation input state to a checkpoint file.

    The checkpoint file format is used to pass the state to
    InProcessServer.run_computation.

    Args:
      state_type: The TFF type of the state structure.
      state: A Tensor or TFF structure with input state for a computation.

    Returns:
      The state encoded as a checkpoint file.
    """
    var_names = variable_helpers.variable_names_from_type(
        state_type, name=artifact_constants.SERVER_STATE_VAR_PREFIX)

    # Write to a file in TensorFlow's RamFileSystem to avoid disk I/O.
    tmpfile = f'ram://{uuid.uuid4()}.ckpt'
    checkpoint_utils.save_tff_structure_to_checkpoint(state, var_names, tmpfile)
    try:
      with tf.io.gfile.GFile(tmpfile, 'rb') as f:
        return f.read()
    finally:
      # Always clean up the in-memory file, even if the read fails.
      tf.io.gfile.remove(tmpfile)

  def _create_tensor_reference_struct(
      self, result_type: tff.Type,
      checkpoint_future: Awaitable[bytes]) -> tff.structure.Struct:
    """Creates the CheckpointTensorReference struct for a result type."""
    # A SharedAwaitable lets every tensor reference await the same checkpoint
    # without re-running the computation.
    shared_checkpoint_future = tff.async_utils.SharedAwaitable(
        checkpoint_future)
    tensor_specs = checkpoint_utils.tff_type_to_tensor_spec_list(result_type)
    # The result type is a (state, metrics) pair; variable names are derived
    # with the corresponding server-side prefixes, in the same order as
    # `tensor_specs`.
    var_names = (
        variable_helpers.variable_names_from_type(
            result_type[0], name=artifact_constants.SERVER_STATE_VAR_PREFIX) +
        variable_helpers.variable_names_from_type(
            result_type[1], name=artifact_constants.SERVER_METRICS_VAR_PREFIX))
    tensor_refs = [
        checkpoint_tensor_reference.CheckpointTensorReference(
            var_name, spec.dtype, spec.shape, shared_checkpoint_future)
        for var_name, spec in zip(var_names, tensor_specs)
    ]
    return checkpoint_utils.pack_tff_value(result_type, tensor_refs)
315