# Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Helper methods for working with demo server checkpoints.""" import collections from collections.abc import Callable, Iterable, Mapping from typing import Any, Optional, Union import numpy as np import tensorflow as tf import tensorflow_federated as tff from fcp.artifact_building import artifact_constants from fcp.artifact_building import tensor_utils from fcp.artifact_building import type_checks from fcp.artifact_building import variable_helpers from fcp.protos import plan_pb2 SAVE_SERVER_SAVEPOINT_NAME = 'save_server_savepoint' def create_server_checkpoint_vars_and_savepoint( *, server_state_type: tff.StructType, server_metrics_type: Optional[tff.StructType] = None, write_metrics_to_checkpoint: bool = True, additional_checkpoint_metadata_var_fn: Optional[ Callable[[tff.StructType, tff.StructType, bool], list[tf.Variable]] ] = None, ) -> tuple[ list[tf.Variable], list[tf.Variable], list[tf.Variable], plan_pb2.CheckpointOp, ]: """Creates tf.Variables for a server checkpoint and the associated savepoint. The variables and the associated saver are constructed in the default graph. For now, only `server_state_type` is required. If metrics are to be saved in the server checkpoint, `server_metrics_type` and `server_result_type` must be provided. `server_state_type` refers to the server state portion of the checkpoint and is used in the `Restore` op of the savepoint. The `server_metrics_type` refers to the metrics saved in the checkpoint, and is not used in the `Restore` op of the savepoint. `server_result_type` refers to the complete round result structure stored in the checkpoint for a round. Args: server_state_type: A `tff.Type` with the type signature of the state. This is used to construct the server state variable names stored in the checkpoint and is used to create the metadata variables for the checkpoint if `server_result_type` is not provided. server_metrics_type: Optional. A `tff.Type` with the type signature of the metrics. If provided, this is used to construct the metric variable names that are stored in the checkpoint. write_metrics_to_checkpoint: If False, revert to legacy behavior where metrics and other non-state values were handled by post-processing separate from the outputted checkpoint. additional_checkpoint_metadata_var_fn: An optional method that takes in the server_state_type, server_metrics_type, and write_metrics_to_checkpoint to produce additional metadata variables. Returns: A tuple `(state_vars, metric_vars, metadata_vars, savepoint)`: - `state_vars` is a Python `list` of variables that hold the state. - `metric_vars` is a Python `list` of variables that hold the metrics. - `metadata_vars` is a Python `list` of variables that hold optional metadata. - `savepoint` is the associated savepoint, i.e., an instance of `plan_pb2.CheckpointOp` with a saver configured for saving the `state_vars`, `metadata_vars`, and, if write_metrics_to_checkpoint is True, `metric_vars`, and restoring the `state_vars` and `metadata_vars`. """ has_metrics = False metric_vars = [] save_tensor_name = None type_checks.check_type(server_state_type, tff.Type, name='server_state_type') state_vars = variable_helpers.create_vars_for_tff_type( server_state_type, artifact_constants.SERVER_STATE_VAR_PREFIX ) var_names = list(map(tensor_utils.bare_name, state_vars)) metadata_vars = [] if server_metrics_type is not None: type_checks.check_type( server_metrics_type, tff.Type, name='server_metrics_type' ) metric_vars = variable_helpers.create_vars_for_tff_type( server_metrics_type, artifact_constants.SERVER_METRICS_VAR_PREFIX ) if additional_checkpoint_metadata_var_fn: metadata_vars = additional_checkpoint_metadata_var_fn( state_vars, metric_vars, write_metrics_to_checkpoint ) has_metrics = bool(tff.structure.flatten(server_metrics_type)) if has_metrics and write_metrics_to_checkpoint: var_names.extend(list(map(tensor_utils.bare_name, metric_vars))) temp_saver_for_all_vars = create_deterministic_saver( var_list=state_vars + metadata_vars + metric_vars, name=SAVE_SERVER_SAVEPOINT_NAME, ) temp_saver_def = temp_saver_for_all_vars.as_saver_def() save_tensor_name = temp_saver_def.save_tensor_name else: if additional_checkpoint_metadata_var_fn: metadata_vars = additional_checkpoint_metadata_var_fn( state_vars, None, write_metrics_to_checkpoint ) saver = create_deterministic_saver( var_list=state_vars + metadata_vars, name='{}_savepoint'.format(artifact_constants.SERVER_STATE_VAR_PREFIX), ) savepoint = plan_pb2.CheckpointOp() savepoint.saver_def.CopyFrom(saver.as_saver_def()) if save_tensor_name is not None: # Replace the save_tensor_name to the one in # temp_saver_for_all_vars so that we are additionally saving metrics vars # in the checkpoint that don't need to be restored as part of the input # computation state. # Once we create the server GraphDef, we will edit the GraphDef directly # to ensure the input filename links to the filename tensor from the # `savepoint`. savepoint.saver_def.save_tensor_name = save_tensor_name return state_vars, metric_vars, metadata_vars, savepoint def create_state_vars_and_savepoint( type_spec: variable_helpers.AllowedTffTypes, name: str ) -> tuple[list[tf.Variable], plan_pb2.CheckpointOp]: """Creates state variables and their savepoint as a `plan_pb2.CheckpointOp`. The variables and the associated saver are constructed in the default graph. Args: type_spec: An instance of `tff.Type` with the type signature of the state. name: The string to use as a basis for naming the vars and the saver. The vars will be under `${name}_state`, and saver under `${name}_savepoint`. Returns: A tuple `(vars, savepoint)`, where `vars` is a Python `list` of variables that hold the state, and `savepoint` is the associated savepoint, i.e., an instance of `plan_pb2.CheckpointOp` with a saver configured for saving and restoring the `vars`. Raises: ValueError: If the name is empty. """ state_vars, saver = create_state_vars_and_saver(type_spec, name) savepoint = plan_pb2.CheckpointOp() savepoint.saver_def.CopyFrom(saver.as_saver_def()) return state_vars, savepoint def create_state_vars_and_saver( type_spec: variable_helpers.AllowedTffTypes, name: str ) -> tuple[list[tf.Variable], tf.compat.v1.train.Saver]: """Creates state variables and the associated saver. The variables and the associated saver are constructed in the default graph. Args: type_spec: An instance of `tff.Type` with the type signature of the state. name: The string to use as a basis for naming the vars and the saver. The vars will be under `${name}_state`, and saver under `${name}_savepoint`. Returns: A tuple `(vars, savepoint)`, where `vars` is a Python `list` of variables that hold the state, and `savepoint` is the associated `tf.compat.v1.train.Saver`. Raises: ValueError: If the name is empty. """ type_checks.check_type(type_spec, tff.Type, name='type_spec') type_checks.check_type(name, str, name='name') if not name: raise ValueError('Name cannot be empty.') state_vars = variable_helpers.create_vars_for_tff_type(type_spec, name) saver = create_deterministic_saver( state_vars, name='{}_savepoint'.format(name) ) return state_vars, saver def restore_tensors_from_savepoint( tensor_specs: Iterable[tf.TensorSpec], filepath_tensor: tf.Tensor ) -> list[tf.Tensor]: """Restores tensors from a checkpoint designated by a tensor filepath. Args: tensor_specs: A `list` of `tf.TensorSpec`s with the names and dtypes of the tensors to restore. filepath_tensor: A placeholder tensor that contains file names with a given pattern. Returns: A list of restored tensors. """ return [ tensor_utils.restore( filepath_tensor, tensor_utils.bare_name(spec.name), spec.dtype ) for spec in tensor_specs ] def create_deterministic_saver( var_list: Union[Iterable[tf.Variable], Mapping[str, tf.Variable]], *args, **kwargs, ) -> tf.compat.v1.train.Saver: """Creates a `tf.compat.v1.Saver` that is deterministic. This method sorts the `var_list` to ensure a deterministic ordering which in turn ensures a deterministic checkpoint. Uses `tf.compat.v1.train.SaverDef.V1` version for writing checkpoints. Args: var_list: An `Iterable` or `str` keyed `Mapping` of `tf.Variables`. In the case of a `dict`, the keys become the names of the checkpoint variables (rather than reading the names off the `tf.Variable` values). *args: Positional arguments forwarded to the `tf.compat.v1.train.Saver` constructor. **kwargs: Keyword arguments forwarded to the `tf.compat.v1.train.Saver` constructor. Returns: A `tf.compat.v1.train.Saver` instance. """ if isinstance(var_list, collections.abc.Mapping): determinisic_names = collections.OrderedDict(sorted(var_list.items())) elif isinstance(var_list, collections.abc.Iterable): determinisic_names = sorted(var_list, key=lambda v: v.name) else: raise ValueError( 'Do not know how to make a deterministic saver for ' '`var_list` of type [{t}]. Must be a Mapping or Sequence'.format( t=type(var_list) ) ) return tf.compat.v1.train.Saver( determinisic_names, write_version=tf.compat.v1.train.SaverDef.V1, *args, **kwargs, ) def tff_type_to_dtype_list( tff_type: variable_helpers.AllowedTffTypes, ) -> list[tf.DType]: """Creates a flat list of `tf.DType`s for tensors in a `tff.Type`. Args: tff_type: Either a `tff.StructType`, `tff.FederatedType`, or a `tff.TensorType` object. Returns: A flat list of `tf.DType`s. """ type_checks.check_type( tff_type, (tff.TensorType, tff.FederatedType, tff.StructType) ) if isinstance(tff_type, tff.TensorType): return [tff_type.dtype] elif isinstance(tff_type, tff.FederatedType): return tff_type_to_dtype_list(tff_type.member) else: # tff.StructType elem_list = [] for elem_type in tff_type: elem_list.extend(tff_type_to_dtype_list(elem_type)) return elem_list def tff_type_to_tensor_spec_list( tff_type: variable_helpers.AllowedTffTypes, ) -> list[tf.TensorSpec]: """Creates a flat list of tensor specs for tensors in a `tff.Type`. Args: tff_type: Either a `tff.StructType`, `tff.FederatedType` or a `tff.TensorType` object. Returns: A flat list of `tf.TensorSpec`s. """ type_checks.check_type( tff_type, (tff.TensorType, tff.FederatedType, tff.StructType) ) if isinstance(tff_type, tff.TensorType): return [tf.TensorSpec(tff_type.shape, dtype=tff_type.dtype)] elif isinstance(tff_type, tff.FederatedType): return tff_type_to_tensor_spec_list(tff_type.member) else: # tff.StructType elem_list = [] for elem_type in tff_type: elem_list.extend(tff_type_to_tensor_spec_list(elem_type)) return elem_list def pack_tff_value( tff_type: variable_helpers.AllowedTffTypes, value_list: Any ) -> Any: """Packs a list of values into a shape specified by a `tff.Type`. Args: tff_type: Either a `tff.StructType`, `tff.FederatedType`, or a `tff.TensorType` object. value_list: A flat list of `tf.Tensor` or `CheckpointTensorReference`. Returns: A Python container with a structure consistent with a `tff.Type`. Raises: ValueError: If the number of leaves in `tff_type` does not match the length of `value_list`, or `tff_type` is of a disallowed type. """ type_checks.check_type( tff_type, (tff.TensorType, tff.FederatedType, tff.StructType) ) # We must "unwrap" any FederatedTypes because the # `tff.structure.pack_sequence_as` call below will fail to recurse into them. # Instead, we remove all the FederatedTypes, because we're only trying to # build up a Python tree structure that matches the struct/tensor types from a # list of values. def remove_federated_types( type_spec: tff.Type, ) -> Union[tff.StructType, tff.TensorType]: """Removes `FederatedType` from a type tree, returning a new tree.""" if type_spec.is_tensor(): return type_spec elif type_spec.is_federated(): return type_spec.member elif type_spec.is_struct(): return tff.StructType( (elem_name, remove_federated_types(elem_type)) for elem_name, elem_type in tff.structure.iter_elements(type_spec) ) else: raise ValueError( 'Must be either tff.TensorType, tff.FederatedType, or tff.StructType.' f' Got a {type(type_spec)}' ) try: tff_type = remove_federated_types(tff_type) except ValueError as e: raise ValueError( '`tff_type` is not packable, see earlier error. ' f'Attempted to pack type: {tff_type}' ) from e ordered_dtypes = tff_type_to_dtype_list(tff_type) if len(ordered_dtypes) != len(value_list): raise ValueError( 'The number of leaves in `tff_type` must equals the length' ' of `value_list`. Found `tff_type` with' f' {len(ordered_dtypes)} leaves and `value_list` of length' f' {len(value_list)}.' ) if tff_type.is_tensor(): return value_list[0] elif tff_type.is_struct(): return tff.structure.pack_sequence_as(tff_type, value_list) else: raise ValueError( '`tff_type` must be either tff.TensorType or ' 'tff.StructType, reaching here is an internal coding ' 'error, please file a bug.' ) def variable_names_from_structure( tff_structure: Union[tff.structure.Struct, tf.Tensor], name: str = 'v' ) -> list[str]: """Creates a flattened list of variable names for the given structure. If the `tff_structure` is a `tf.Tensor`, the name is the `name` parameter if specified, otheriwse a default name: `v`. If `tff_structure` is a `tff.structure.Struct` then '/' is used between inner and outer fields together with the tuple name or index of the element in the tuple. Some examples: 1. If the `tff_structure` is `<'a'=tf.constant(1.0), 'b'=tf.constant(0.0)>` and name is not specified, the returned variable name list is ['v/a', 'v/b']. 2. If the `tff_structure` is `` and `name` is `update`, the returned variable name list is ['update/0', 'update/1']. 3. If the `tff_structure` is `<'a'=<'b'=tf.constant(1.0), 'c'=tf.constant(0.0)>>` and `name` is `update`, the returned variable name list is ['update/a/b', 'update/a/c']. 4. If the `tff_structure` is `<'a'=<'b'=tf.constant(1.0), 'c'=tf.constant(1.0), tf.constant(0.0)>>` and `name` is `update`, the returned variable name list is ['update/a/b', 'update/a/c', 'update/a/2']. Args: tff_structure: Either a `tff.structure.Struct` or a `tf.Tensor` object. name: The preferred name to use at the top-most level (if not None, must be a string). If `tff_structure` is a `tff.structure.Struct`, the names of the inner fields will be scoped under `name`, e.g. `some_name/field_name`. Returns: A flat Python `list` of `str` names. Raises: TypeError: If either argument is of the wrong type. """ type_checks.check_type( tff_structure, (tff.structure.Struct, tf.Tensor), name='structure_type' ) type_checks.check_type(name, str, name='name') if isinstance(tff_structure, tf.Tensor): return [name] elif isinstance(tff_structure, tff.structure.Struct): result = [] fields = tff.structure.iter_elements(tff_structure) for index, (field_name, field_type) in enumerate(fields): # Default the name of the element to its index so that we don't wind up # with multiple child fields listed under `/v/` field_name = field_name or str(index) result.extend( variable_names_from_structure( field_type, name=name + '/' + field_name ) ) return result else: raise TypeError( 'Cannot create variable names from [{t}] type. Short-hand: {s}'.format( t=type(tff_structure), s=tff_structure ) ) def is_structure_of_allowed_types( structure: Union[ tff.structure.Struct, tf.Tensor, np.ndarray, np.number, int, float, str, bytes, ] ) -> bool: """Checks if each node in `structure` is an allowed type for serialization.""" flattened_structure = tff.structure.flatten(structure) for item in flattened_structure: if not ( tf.is_tensor(item) or isinstance(item, (np.ndarray, np.number, int, float, str, bytes)) ): return False return True def save_tff_structure_to_checkpoint( tff_structure: Union[tff.structure.Struct, tf.Tensor], ordered_var_names: list[str], output_checkpoint_path: str, ) -> None: """Saves a TFF structure to a checkpoint file. The input `tff_structure` is a either `tff.structure.Struct` or a single `tf.Tensor`. This function saves `tff_structure` to a checkpoint file using variable names supplied via the `ordered_var_names` argument. Args: tff_structure: A `tff.structure.Struct` of values or a single value. Each leaf in the structure must be a value serializable to a TensorFlow checkpoint. ordered_var_names: The list of variable names for the values that appear in `tff_structure` after calling `tff.structure.flatten()`. output_checkpoint_path: A string specifying the path to the output checkpoint file. Raises: TypeError: If not all leaves in `tff_structure` are of allowed types. ValueError: If the number of `tf.Tensor`s in `tff_structure` does not match the size of `ordered_var_names`. """ if not is_structure_of_allowed_types(tff_structure): raise TypeError( 'Not all leaves in `tff_structure` are `tf.Tensor`s, ' '`np.ndarray`s, `np.number`s, or Python scalars. Got: ' f'{tff.structure.map_structure(type, tff_structure)!r})' ) tensors = tff.structure.flatten(tff_structure) if len(tensors) != len(ordered_var_names): raise ValueError( 'The length of `ordered_var_names` does not match the ' 'number of tensors in `tff_structure`:' f'{len(ordered_var_names)} != {len(tensors)}' ) tensor_utils.save( output_checkpoint_path, tensor_names=ordered_var_names, tensors=tensors )