# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TFF FederatedContext subclass for the demo Federated Computation platform."""

from collections.abc import Awaitable
import socket
import ssl
import threading
from typing import Any, Optional, Union
import uuid

from absl import logging
import attr
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
import tree

from fcp.artifact_building import artifact_constants
from fcp.artifact_building import checkpoint_utils
from fcp.artifact_building import data_spec
from fcp.artifact_building import federated_compute_plan_builder
from fcp.artifact_building import plan_utils
from fcp.artifact_building import variable_helpers
from fcp.demo import checkpoint_tensor_reference
from fcp.demo import federated_computation
from fcp.demo import federated_data_source
from fcp.demo import server
from fcp.protos import plan_pb2


class FederatedContext(tff.program.FederatedContext):
  """A FederatedContext for use with the demo platform."""

  def __init__(self,
               population_name: str,
               *,
               base_context: Optional[tff.framework.SyncContext] = None,
               host: str = 'localhost',
               port: int = 0,
               certfile: Optional[str] = None,
               keyfile: Optional[str] = None,
               address_family: Optional[socket.AddressFamily] = None):
    """Initializes a `FederatedContext`.

    Args:
      population_name: The name of the population to execute computations on.
      base_context: The context used to run non-federated TFF computations
        (i.e., computations with a type other than FederatedComputation).
      host: The hostname the server should bind to.
      port: The port the server should listen on.
      certfile: The path to the certificate to use for https.
      keyfile: The path to the certificate's private key (if separate).
      address_family: An override for the HTTP server's address family.
    """
    # NOTE: The demo server only supports a single population, which must be
    # specified at startup. An implementation that supports multiple
    # populations should only use the population name from the
    # PopulationDataSource.
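    # For illustration only: a path-like name such as 'test/population' is the
    # shape of value used elsewhere in the demo; the authoritative syntax is
    # whatever POPULATION_NAME_REGEX in federated_data_source accepts.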
    if not federated_data_source.POPULATION_NAME_REGEX.fullmatch(
        population_name):
      raise ValueError(
          'population_name must match '
          f'"{federated_data_source.POPULATION_NAME_REGEX.pattern}".')
    self._population_name = population_name
    self._base_context = base_context
    self._server = server.InProcessServer(
        population_name=population_name,
        host=host,
        port=port,
        address_family=address_family)
    if certfile is not None:
      context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
      context.load_cert_chain(certfile, keyfile)
      self._server.socket = context.wrap_socket(
          self._server.socket, server_side=True)
    self._server_thread = threading.Thread(
        target=self._server.serve_forever, daemon=True)
    self._cached_comps: dict[tuple[tff.Computation, int], plan_pb2.Plan] = {}

  @property
  def server_port(self) -> int:
    """The port on which the Federated Compute server is running."""
    return self._server.server_port

  def __enter__(self):
    self._server_thread.start()
    logging.log(logging.INFO, 'Federated Compute server running on %s:%s',
                self._server.server_name, self._server.server_port)
    return self

  def __exit__(self, exc_type, exc_value, tb):
    self._server.shutdown()
    self._server_thread.join()
    logging.log(logging.INFO, 'Federated Compute server stopped')

  def invoke(self, comp: tff.Computation, arg: Any) -> Any:
    """Invokes a computation.

    Args:
      comp: The computation being invoked.
      arg: The arguments of the call encoded in a computation-specific way. For
        FederatedComputations, this should be a `(state, config)` tuple, where
        the state is a possibly nested structure and the configuration is
        provided by a FederatedDataSource.

    Returns:
      A value reference structure representing the result of the computation.
    """
    # Pass other computation types to the next FederatedContext.
    if not isinstance(comp, federated_computation.FederatedComputation):
      if not self._base_context:
        raise TypeError('computation must be a FederatedComputation if no '
                        'base_context was provided.')
      return self._base_context.invoke(comp, arg)

    state, config = self._parse_arg(arg)
    if config.population_name != self._population_name:
      raise ValueError('FederatedDataSource and FederatedContext '
                       'population_names must match.')

    # Since building the plan can be slow, we cache the result to speed up
    # subsequent invocations.
    cache_key = (comp.wrapped_computation, id(config.example_selector))
    try:
      plan = self._cached_comps[cache_key]
    except KeyError:
      plan = federated_compute_plan_builder.build_plan(
          comp.map_reduce_form,
          comp.distribute_aggregate_form,
          self._get_nested_data_spec(config.example_selector),
          grappler_config=tf.compat.v1.ConfigProto(),
          generate_server_phase_v2=True,
      )
      # Add the TF Lite flatbuffer to the plan. If the conversion fails, the
      # flatbuffer will be silently omitted and the client will use the
      # TensorFlow graph in `plan.client_graph_bytes` instead.
      # NOTE: If conversion failures should not be silent, pass
      # `forgive_tflite_conversion_failure=False`.
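      # For illustration, a stricter variant of the call below (assuming the
      # helper accepts the keyword named in the NOTE above) would be:
      #   plan = plan_utils.generate_and_add_flat_buffer_to_plan(
      #       plan, forgive_tflite_conversion_failure=False)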
      plan = plan_utils.generate_and_add_flat_buffer_to_plan(plan)
      self._cached_comps[cache_key] = plan

    checkpoint_future = self._run_computation(comp.name, config, plan,
                                              comp.type_signature.parameter[0],
                                              state)
    result_value_ref = self._create_tensor_reference_struct(
        comp.type_signature.result, checkpoint_future)
    return tff.types.type_to_py_container(result_value_ref,
                                          comp.type_signature.result)

  def _is_state_structure_of_allowed_types(
      self,
      structure: Union[
          tff.structure.Struct,
          tf.Tensor,
          tff.program.MaterializableValue,
      ],
  ) -> bool:
    """Checks if each node in `structure` is an allowed type for `state`."""
    if isinstance(structure, tff.structure.Struct):
      structure = tff.structure.flatten(structure)
    else:
      structure = tree.flatten(structure)
    for item in structure:
      if not (
          tf.is_tensor(item)
          or isinstance(
              item,
              (
                  np.ndarray,
                  np.number,
                  int,
                  float,
                  str,
                  bytes,
                  tff.program.MaterializableValueReference,
              ),
          )
      ):
        return False
    return True

  def _parse_arg(
      self, arg: tff.structure.Struct
  ) -> tuple[Union[tff.structure.Struct, tf.Tensor,
                   tff.program.MaterializableValueReference],
             federated_data_source.DataSelectionConfig]:
    """Parses and validates the invoke arguments."""
    if len(arg) != 2:
      raise ValueError(f'The argument structure is unsupported: {arg}.')

    state, config = arg
    if attr.has(type(state)):
      state = tff.structure.from_container(state, recursive=True)
    if not self._is_state_structure_of_allowed_types(state):
      raise TypeError(
          'arg[0] must be a value or structure of values of '
          '`MaterializableValueReference`s, `tf.Tensor`s, '
          '`np.ndarray`s, `np.number`s, or Python scalars. Got: '
          f'{tf.nest.map_structure(type, state)!r}'
      )

    # Code below assumes single values are always `tf.Tensor`s.
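    # For example, a bare Python or NumPy scalar (e.g. 3 or np.float32(0.5))
    # is wrapped in a `tf.Tensor` here so that the checkpoint-writing code
    # later on only has to handle tensors or structures of tensors.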
    if isinstance(state, (int, float, str, bytes, np.ndarray, np.number)):
      state = tf.convert_to_tensor(state)

    if not isinstance(config, federated_data_source.DataSelectionConfig):
      raise TypeError('arg[1] must be the result of '
                      'FederatedDataSource.iterator().select().')
    return state, config

  def _get_nested_data_spec(self, example_selector) -> data_spec.NestedDataSpec:
    """Converts a NestedExampleSelector to a NestedDataSpec."""
    if isinstance(example_selector, dict):
      return {
          k: self._get_nested_data_spec(v) for k, v in example_selector.items()
      }
    return data_spec.DataSpec(example_selector)

  async def _run_computation(
      self, name: str, config: federated_data_source.DataSelectionConfig,
      plan: plan_pb2.Plan, input_type: tff.Type,
      input_state: Union[tff.structure.Struct, tf.Tensor,
                         tff.program.MaterializableValueReference]
  ) -> bytes:
    """Prepares and runs a computation using the demo server."""
    input_checkpoint = self._state_to_checkpoint(
        input_type, await self._resolve_value_references(input_state))
    try:
      logging.log(logging.INFO, 'Started running %s', name)
      return await self._server.run_computation(
          name,
          plan,
          input_checkpoint,
          config.task_assignment_mode,
          config.num_clients,
      )
    finally:
      logging.log(logging.INFO, 'Finished running %s', name)

  async def _resolve_value_references(
      self, structure: Union[tff.structure.Struct, tf.Tensor,
                             tff.program.MaterializableValueReference]
  ) -> Union[tff.structure.Struct, tf.Tensor]:
    """Dereferences any MaterializableValueReferences in a struct."""
    if isinstance(structure, tff.program.MaterializableValueReference):
      return await structure.get_value()  # pytype: disable=bad-return-type  # numpy-scalars
    elif tf.is_tensor(structure):
      return structure
    elif isinstance(structure, tff.structure.Struct):
      s = [
          await self._resolve_value_references(x)
          for x in tff.structure.flatten(structure)
      ]
      return tff.structure.pack_sequence_as(structure, s)
    else:
      raise ValueError(
          'arg[0] must be a struct, Tensor, or MaterializableValueReference.')

  def _state_to_checkpoint(
      self, state_type: tff.Type, state: Union[tff.structure.Struct,
                                               tf.Tensor]) -> bytes:
    """Converts computation input state to a checkpoint file.

    The checkpoint file format is used to pass the state to
    InProcessServer.run_computation.

    Args:
      state_type: The TFF type of the state structure.
      state: A Tensor or TFF structure with input state for a computation.

    Returns:
      The state encoded as a checkpoint file.
    """
    var_names = variable_helpers.variable_names_from_type(
        state_type, name=artifact_constants.SERVER_STATE_VAR_PREFIX)

    # Write to a file in TensorFlow's RamFileSystem to avoid disk I/O.
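    # The 'ram://' prefix keeps the checkpoint entirely in memory; the bytes
    # are read back below and the in-memory file is then removed.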
    tmpfile = f'ram://{uuid.uuid4()}.ckpt'
    checkpoint_utils.save_tff_structure_to_checkpoint(state, var_names, tmpfile)
    try:
      with tf.io.gfile.GFile(tmpfile, 'rb') as f:
        return f.read()
    finally:
      tf.io.gfile.remove(tmpfile)

  def _create_tensor_reference_struct(
      self, result_type: tff.Type,
      checkpoint_future: Awaitable[bytes]) -> tff.structure.Struct:
    """Creates the CheckpointTensorReference struct for a result type."""
    shared_checkpoint_future = tff.async_utils.SharedAwaitable(
        checkpoint_future)
    tensor_specs = checkpoint_utils.tff_type_to_tensor_spec_list(result_type)
    var_names = (
        variable_helpers.variable_names_from_type(
            result_type[0], name=artifact_constants.SERVER_STATE_VAR_PREFIX) +
        variable_helpers.variable_names_from_type(
            result_type[1], name=artifact_constants.SERVER_METRICS_VAR_PREFIX))
    tensor_refs = [
        checkpoint_tensor_reference.CheckpointTensorReference(
            var_name, spec.dtype, spec.shape, shared_checkpoint_future)
        for var_name, spec in zip(var_names, tensor_specs)
    ]
    return checkpoint_utils.pack_tff_value(result_type, tensor_refs)
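

# Example usage (a minimal sketch, not part of this module's API): the
# population name, example selector, and `my_federated_program` below are
# hypothetical, and the context-stack call assumes the standard
# `tff.framework.get_context_stack().install()` API.
#
#   ctx = FederatedContext(
#       'test/population',
#       base_context=tff.framework.get_context_stack().current)
#   data_source = federated_data_source.FederatedDataSource(
#       'test/population', example_selector)
#   with ctx, tff.framework.get_context_stack().install(ctx):
#     asyncio.run(my_federated_program(ctx, data_source))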