# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper methods for TensorFlow variables."""

from typing import Optional, Union

import tensorflow as tf
import tensorflow_federated as tff

from fcp.artifact_building import tensor_utils
from fcp.artifact_building import type_checks

# TFF types allowed for variables created at input/output serialization
# boundaries.
AllowedTffTypes = Union[tff.TensorType, tff.StructType, tff.FederatedType]


# The prefix for the name of the sidechannel for a securely-summed variable.
#
# This transformed name is used as the name of the Op which *reads* from the
# variable, rather than identifies the variable itself. Names with this prefix
# are used as the keys in the `side_channel_tensors` map entries corresponding
# with the variable of the unprefixed name.
SIDECHANNEL_NAME_PREFIX = 'sidechannel_'

# `variable_names_from_type` returns the `name` argument of `tf.Variable()`.
# However when the variable is created, the name of its tensor is actually
# `<name>:0`. This macro is created to match this behavior.
_TF_TENSOR_NAME_SUFFIX = ':0'


def _create_var_for_tff_tensor(
    tff_type: tff.TensorType, name: str, **kwargs
) -> tf.Variable:
  """Creates a TensorFlow variable to hold a value of the `tff.TensorType`.

  Args:
    tff_type: The `tff.TensorType` describing the dtype and shape to hold.
    name: The name passed to `tf.Variable()`.
    **kwargs: Optional arguments, if any, to pass to the `tf.Variable()` call.

  Returns:
    A zero-initialized `tf.Variable`.
  """
  type_checks.check_type(tff_type, tff.TensorType)
  type_checks.check_type(name, str)
  # `tff_type` can have shapes that contain `None` or `0`:
  # * `None` shape cannot be used in `tf.zeros` to create the initial value
  #   of a `tf.Variable`. Hence, we replace it with a `0` in `tf.zeros`.
  # * The dimension that has `0` shape may change its shape at run time. To
  #   support this, we use `None` for that dimension when creating the
  #   `tf.Variable`.
  initial_value_shape = []
  variable_shape = []
  for shape in tff_type.shape.as_list():
    if shape is None or shape == 0:
      initial_value_shape.append(0)
      variable_shape.append(None)
    else:
      initial_value_shape.append(shape)
      variable_shape.append(shape)
  return tf.Variable(
      initial_value=tf.zeros(shape=initial_value_shape, dtype=tff_type.dtype),
      name=name,
      dtype=tff_type.dtype,
      shape=variable_shape,
      **kwargs,
  )


# Build the TensorSpec for the values we will send to the client so that the
# client graph will know how to read the incoming values.
def tensorspec_from_var(var: tf.Variable) -> tf.TensorSpec:
  """Builds `tf.TensorSpec` from `tf.Variables`.

  Args:
    var: An instance of `tf.Variable`.

  Returns:
    An instance of `tf.TensorSpec` corresponding to the input `tf.Variable`.
  """
  return tf.TensorSpec(
      shape=var.shape, dtype=var.dtype, name=tensor_utils.bare_name(var.name)
  )


def create_vars_for_tff_type(
    tff_type: AllowedTffTypes, name: Optional[str] = None, **kwargs
) -> list[tf.Variable]:
  """Creates TensorFlow variables to hold a value of the given `tff_type`.

  The variables are created in the default graph and scope. The variables are
  automatically given `tf.zeros` initializers.

  Args:
    tff_type: Either a `tff.StructType`, SERVER-placed `tff.FederatedType` or a
      `tff.TensorType` object.
    name: The preferred name to use at the top-most level (if not None, must be
      a string). If `tff_type` is a `tff.StructType`, the names of the inner
      fields will be scoped under `name`, e.g. `some_name/field_name`.
    **kwargs: Optional arguments, if any, to pass to the `tf.Variable()` calls.

  Returns:
    A flat Python `list` of TensorFlow variable instances.

  Raises:
    TypeError: If the argument is of the wrong type or has the wrong placement.
  """
  type_checks.check_type(
      tff_type,
      (tff.TensorType, tff.StructType, tff.FederatedType),
      name='tff_type',
  )
  if name is not None:
    type_checks.check_type(name, str)
  else:
    name = 'v'
  if isinstance(tff_type, tff.TensorType):
    return [_create_var_for_tff_tensor(tff_type, name, **kwargs)]
  elif isinstance(tff_type, tff.FederatedType):
    if tff_type.placement != tff.SERVER:
      raise TypeError(
          'Can only create vars for unplaced types or types placed '
          'on the SERVER.'
      )
    return create_vars_for_tff_type(tff_type.member, name, **kwargs)
  else:  # tff.StructType
    result = []
    with tf.compat.v1.variable_scope(name):
      fields = tff.structure.to_elements(tff_type)
      for index, (field_name, field_type) in enumerate(fields):
        # Default the name of the element to its index so that we don't wind up
        # with multiple child fields listed under `/v/`
        if field_name is None:
          field_name = str(index)
        result.extend(
            create_vars_for_tff_type(field_type, name=field_name, **kwargs)
        )
    return result


def variable_names_from_type(
    tff_type: AllowedTffTypes, name: str = 'v'
) -> list[str]:
  """Creates a flattened list of variables names for the given `tff_type`.

  If `tff_type` is a `tff.TensorType`, the name is the `name` parameter if
  specified, otherwise a default name: `v`. If `tff_type` is a
  `tff.StructType` then '/' is used between inner and outer fields together
  with the tuple name or index of the element in the tuple.

  Some examples:
  1. If the tff_type is `<'a'=tf.int32, 'b'=tf.int32>` and `name` is not
    specified, the returned variable name list is ['v/a', 'v/b'].
  2. If the tff_type is `<tf.int32, tf.int32>` and `name` is `update`, the
    returned variable name list is ['update/0', 'update/1'].
  3. If the tff_type is `<'a'=<'b'=tf.int32, 'c'=tf.int32>>` and `name` is
    `update`, the returned variable name list is ['update/a/b', 'update/a/c'].
  4. If the tff_type is `<'a'=<'b'=tf.int32, 'c'=tf.int32, tf.int32>>` and
    `name` is `update`, the returned variable name list is ['update/a/b',
    'update/a/c', 'update/a/2'].

  Args:
    tff_type: Either a `tff.StructType`, a `tff.FederatedType` or a
      `tff.TensorType` object.
    name: The preferred name to use at the top-most level (must be a string;
      defaults to `v`). If `tff_type` is a `tff.StructType`, the names of the
      inner fields will be scoped under `name`, e.g. `some_name/field_name`.

  Returns:
    A flat Python `list` of `str` names.

  Raises:
    TypeError: If the argument is of the wrong type.
  """
  type_checks.check_type(
      tff_type,
      (tff.TensorType, tff.FederatedType, tff.StructType),
      name='tff_type',
  )
  type_checks.check_type(name, str, name='name')
  if isinstance(tff_type, tff.TensorType):
    return [name]
  elif isinstance(tff_type, tff.FederatedType):
    return variable_names_from_type(tff_type.member, name)
  elif isinstance(tff_type, tff.StructType):
    result = []
    fields = tff.structure.iter_elements(tff_type)
    for index, (field_name, field_type) in enumerate(fields):
      # Default the name of the element to its index so that we don't wind up
      # with multiple child fields listed under `/v/`
      field_name = field_name or str(index)
      result.extend(
          variable_names_from_type(field_type, name=name + '/' + field_name)
      )
    return result
  else:
    raise TypeError(
        'Cannot create variable names from [{t}] TFF type. '
        'Short-hand: {s}'.format(t=type(tff_type), s=tff_type)
    )


def get_shared_secagg_tensor_names(
    intrinsic_name: str, tff_type: AllowedTffTypes
) -> list[str]:
  """Creates the shared name of secagg tensors in client and server graph.

  This is the canonical function for ensuring the secagg tensor names in the
  client and server graph are the same. The server uses secagg tensor
  names as the keys to retrieve values from secagg server which are originally
  from client graph, so if the secagg tensor names in the client and server
  graph are not the same, the server could not find secagg tensors. This
  function is created to ensure this implicit dependency.

  Args:
    intrinsic_name: The name of the secure aggregation intrinsic being used.
    tff_type: Either a `tff.StructType`, `tff.FederatedType` or a
      `tff.TensorType` object.

  Returns:
    A list of variable names created from the input TFF type.
  """
  tensor_names = variable_names_from_type(
      tff_type, f'secagg_{intrinsic_name}_update'
  )
  return [
      SIDECHANNEL_NAME_PREFIX + name + _TF_TENSOR_NAME_SUFFIX
      for name in tensor_names
  ]


def get_flattened_tensor_specs(
    tff_type: AllowedTffTypes, name: str
) -> list[tf.TensorSpec]:
  """Generates TensorSpecs for a flattened version of the given `tff_type`.

  This function uses the same naming logic as `variable_names_from_type`. Please
  see that function's docstring.

  Args:
    tff_type: Either a `tff.StructType`, a `tff.FederatedType` or a
      `tff.TensorType` object.
    name: The preferred name to use at the top-most level (must be a string).
      If `tff_type` is a `tff.StructType`, the names of the inner fields will
      be scoped under `name`, e.g. `some_name/field_name`.

  Returns:
    A flat Python `list` of `TensorSpec`s.

  Raises:
    TypeError: If the argument is of the wrong type.
  """
  type_checks.check_type(
      tff_type,
      (tff.TensorType, tff.FederatedType, tff.StructType),
      name='tff_type',
  )
  type_checks.check_type(name, str, name='name')
  if isinstance(tff_type, tff.TensorType):
    return [tf.TensorSpec(tff_type.shape, tff_type.dtype, name=name)]
  elif isinstance(tff_type, tff.FederatedType):
    return get_flattened_tensor_specs(tff_type.member, name)
  elif isinstance(tff_type, tff.StructType):
    result = []
    fields = tff.structure.iter_elements(tff_type)
    for index, (field_name, field_type) in enumerate(fields):
      # Default the name of the element to its index so that we don't wind up
      # with multiple child fields listed under `/v/`
      field_name = field_name or str(index)
      result.extend(
          get_flattened_tensor_specs(field_type, name=name + '/' + field_name)
      )
    return result
  else:
    raise TypeError(
        'Cannot create TensorSpecs from [{t}] TFF type. '
        'Short-hand: {s}'.format(t=type(tff_type), s=tff_type)
    )


def get_grouped_input_tensor_specs_for_aggregations(
    aggregation_comp: tff.framework.ComputationBuildingBlock,
    names: dict[int, str],
) -> list[list[list[tf.TensorSpec]]]:
  """Gets the input TensorSpecs for an aggregation computation.

  This function can be used to generate the TensorSpecs that are assigned to
  ServerAggregationConfig.IntrinsicArg messages to represent the aggregation
  intrinsic calls in DistributeAggregateForm.client_to_server_aggregation.

  It derives the tensor name(s) for each intrinsic input argument by following
  naming logic similar to `variable_names_from_type`. DistributeAggregateForm
  does guarantee that each intrinsic input argument will be a
  `building_block.Selection` or a (potentially nested) struct of
  `building_block.Selection`s. The first element of the path is used to
  determine the top-level name, which must match the top-level name that was
  used to construct the tensor that will be getting consumed by this argument.

  Args:
    aggregation_comp: The aggregation computation.
    names: A dictionary describing how to map the first element of the path to a
      top-level name.

  Returns:
    A `list` where the ith entry represents the input tensor specs for the
    ith intrinsic in the aggregation computation. The ith entry is itself a list
    where the jth entry represents the input tensor specs for the jth argument
    of the ith intrinsic in the aggregation computation.

  Raises:
    TypeError: If the argument is of the wrong type.
    ValueError: If the argument contains an unexpected
      `building_block.Selection` index.
  """

  def _get_selection_path(
      selection: tff.framework.ComputationBuildingBlock,
  ) -> list[int]:
    """Gets the list of selection indices for a building_blocks.Selection."""

    path = []
    while selection.is_selection():
      path.append(selection.index)  # pytype: disable=attribute-error
      selection = selection.source  # pytype: disable=attribute-error
    # In ASTs like x[0][1], we'll see the last (outermost) selection first.
    path.reverse()
    return path

  def _get_input_tensor_specs_for_aggregation_arg(
      value: tff.framework.ComputationBuildingBlock, names: dict[int, str]
  ) -> list[tf.TensorSpec]:
    """Gets the input TensorSpecs for a single intrinsic argument."""

    # An intrinsic arg may be a `building_block.Selection` or a (potentially
    # nested) struct of `building_block.Selection`s. Start by creating a
    # flattened list of the `building_block.Selection`s.
    if value.is_struct():
      inner_values = tff.structure.flatten(value)
    else:
      inner_values = [value]

    # For each `building_block.Selection`, reconstruct the tensor name that
    # will be used to supply that value. The first index of the selection path
    # indicates whether the tensor will be coming from the intermediate state
    # checkpoint (0) or from the client checkpoint (1), since TFF condenses
    # daf.client_to_server_aggregation(temp_server_state, client_update)
    # into a 1-arg function. Since the tensors within the checkpoints
    # corresponding to temp_server_state and work_at_clients will be named using
    # variable_names_from_type, which uses a simple filepath-like naming pattern
    # to refer to the tensors within a struct, we can reconstruct the relevant
    # tensor name by concatenating together the remaining indices of each
    # selection path.
    tensor_specs = []
    for inner_value in inner_values:
      inner_value.check_selection()
      path = _get_selection_path(inner_value)
      arg_index = path[0]
      if arg_index not in names:
        raise ValueError('Unexpected arg index for aggregation selection')
      # Build the name with a single join so that a one-element selection path
      # yields the bare top-level name rather than a name with a trailing '/'
      # (which would then expand to e.g. `name//0` for struct-typed values).
      tensor_name = '/'.join([names[arg_index], *(str(x) for x in path[1:])])
      tensor_specs.extend(
          get_flattened_tensor_specs(
              inner_value.type_signature, name=tensor_name
          )
      )

    return tensor_specs

  grouped_input_tensor_specs = []

  for _, local_value in aggregation_comp.result.locals:  # pytype: disable=attribute-error
    local_value.check_call()
    local_value.function.check_intrinsic()
    assert local_value.function.intrinsic_def().aggregation_kind

    # Collect the input TensorFlowSpecs for each argument for this intrinsic.
    input_tensor_specs_for_intrinsic = []
    if (
        local_value.function.intrinsic_def().type_signature.parameter.is_struct()
    ):
      for element in local_value.argument.children():
        input_tensor_specs_for_intrinsic.append(
            _get_input_tensor_specs_for_aggregation_arg(element, names)
        )
    else:
      input_tensor_specs_for_intrinsic.append(
          _get_input_tensor_specs_for_aggregation_arg(
              local_value.argument, names
          )
      )

    grouped_input_tensor_specs.append(input_tensor_specs_for_intrinsic)

  return grouped_input_tensor_specs


def get_grouped_output_tensor_specs_for_aggregations(
    aggregation_comp: tff.framework.ComputationBuildingBlock,
) -> list[list[tf.TensorSpec]]:
  """Gets the output TensorSpecs for an aggregation computation.

  This function can be used to generate the TensorSpecs that are assigned
  to the output_tensors value in ServerAggregationConfig messages to represent
  the aggregation intrinsic calls in
  DistributeAggregateForm.client_to_server_aggregation.

  It derives the tensor name(s) for each intrinsic output argument by following
  naming logic similar to `variable_names_from_type`. It must produce tensor
  names that match the tensor names that are expected by the post-aggregation
  computation.

  Args:
    aggregation_comp: The aggregation computation.

  Returns:
    A list where the ith entry represents the output tensor specs for the ith
    intrinsic in the aggregation computation.

  Raises:
    TypeError: If the argument is of the wrong type.
  """
  # TensorflowSpecs for all the intrinsic results. These TensorflowSpecs must
  # have names that mirror the result of calling variable_names_from_type on
  # the output type of DistributeAggregateForm.client_to_server_aggregation
  # (which is the same as the type of the aggregation result input arg in
  # DistributeAggregateForm.server_result).
  output_tensor_specs = get_flattened_tensor_specs(
      tff.StructType([aggregation_comp.type_signature.result]),
      name='intermediate_update',
  )
  # Cursor into `output_tensor_specs`; each intrinsic consumes the next
  # contiguous run of specs in the order the intrinsics appear.
  output_tensor_spec_index = 0

  grouped_output_tensor_specs = []

  for _, local_value in aggregation_comp.result.locals:  # pytype: disable=attribute-error
    local_value.check_call()
    local_value.function.check_intrinsic()
    local_value.type_signature.check_federated()
    assert local_value.function.intrinsic_def().aggregation_kind

    tensor_specs = []
    # If the output is a struct, select the appropriate number of
    # TensorflowSpecs.
    if local_value.type_signature.member.is_struct():
      num_specs = len(tff.structure.flatten(local_value.type_signature.member))
      tensor_specs = output_tensor_specs[
          output_tensor_spec_index : output_tensor_spec_index + num_specs
      ]
      output_tensor_spec_index += num_specs
    else:
      tensor_specs.append(output_tensor_specs[output_tensor_spec_index])
      output_tensor_spec_index += 1
    grouped_output_tensor_specs.append(tensor_specs)

  return grouped_output_tensor_specs