xref: /aosp_15_r20/external/federated-compute/fcp/artifact_building/variable_helpers.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Helper methods for TensorFlow variables."""
15
16from typing import Optional, Union
17
18import tensorflow as tf
19import tensorflow_federated as tff
20
21from fcp.artifact_building import tensor_utils
22from fcp.artifact_building import type_checks
23
# Union of the TFF types that may describe the values held by variables
# created at input/output serialization boundaries.
AllowedTffTypes = Union[
    tff.TensorType,
    tff.StructType,
    tff.FederatedType,
]
27
28
29# The prefix for the name of the sidechannel for a securely-summed variable.
30#
31# This transformed name is used as the name of the Op which *reads* from the
32# variable, rather than identifies the variable itself. Names with this prefix
33# are used as the keys in the `side_channel_tensors` map entries corresponding
34# with the variable of the unprefixed name.
35SIDECHANNEL_NAME_PREFIX = 'sidechannel_'
36
37# `variable_names_from_type` returns the `name` argument of `tf.Variable()`.
38# However when the variable is created, the name of its tensor is actually
39# `<name>:0`. This macro is created to match this behavior.
40_TF_TENSOR_NAME_SUFFIX = ':0'
41
42
def _create_var_for_tff_tensor(
    tff_type: tff.TensorType, name: str, **kwargs
) -> tf.Variable:
  """Returns a zero-initialized `tf.Variable` for a `tff.TensorType` value.

  Dimensions of `tff_type` that are `None` or `0` are treated as dynamic:
  the initial `tf.zeros` value uses `0` for them (since `tf.zeros` cannot take
  `None`), while the variable itself is declared with `None` there so the
  dimension may change size at run time.

  Args:
    tff_type: The `tff.TensorType` describing the value to hold.
    name: The `name` passed to `tf.Variable()`.
    **kwargs: Extra keyword arguments forwarded to `tf.Variable()`.

  Returns:
    The newly created `tf.Variable`.
  """
  type_checks.check_type(tff_type, tff.TensorType)
  type_checks.check_type(name, str)
  dims = tff_type.shape.as_list()
  is_dynamic = [dim is None or dim == 0 for dim in dims]
  zeros_shape = [0 if dyn else dim for dim, dyn in zip(dims, is_dynamic)]
  declared_shape = [None if dyn else dim for dim, dyn in zip(dims, is_dynamic)]
  dtype = tff_type.dtype
  return tf.Variable(
      initial_value=tf.zeros(shape=zeros_shape, dtype=dtype),
      name=name,
      dtype=dtype,
      shape=declared_shape,
      **kwargs,
  )
71
72
# The TensorSpecs built here describe the values sent to the client, so the
# client graph knows how to read those incoming values.
def tensorspec_from_var(var: tf.Variable) -> tf.TensorSpec:
  """Returns a `tf.TensorSpec` mirroring the given `tf.Variable`.

  Args:
    var: The `tf.Variable` to describe.

  Returns:
    A `tf.TensorSpec` with the variable's shape and dtype, named with the
    result of `tensor_utils.bare_name(var.name)`.
  """
  spec_name = tensor_utils.bare_name(var.name)
  return tf.TensorSpec(shape=var.shape, dtype=var.dtype, name=spec_name)
87
88
def create_vars_for_tff_type(
    tff_type: AllowedTffTypes, name: Optional[str] = None, **kwargs
) -> list[tf.Variable]:
  """Creates zero-initialized TensorFlow variables for a value of `tff_type`.

  Variables are created in the default graph. Every tensor leaf of `tff_type`
  gets one variable whose initial value is `tf.zeros`. Struct elements are
  created inside a `tf.compat.v1.variable_scope` named after their parent.

  Args:
    tff_type: A `tff.TensorType`, a `tff.StructType`, or a SERVER-placed
      `tff.FederatedType`.
    name: Name for the top-most level; must be a `str` when given, and
      defaults to `'v'` when None. For a `tff.StructType`, inner fields are
      scoped under it, e.g. `some_name/field_name`; unnamed fields use their
      index as the field name.
    **kwargs: Extra keyword arguments forwarded to every `tf.Variable()` call.

  Returns:
    A flat Python `list` of the created TensorFlow variables.

  Raises:
    TypeError: If an argument has the wrong type, or `tff_type` is a
      `tff.FederatedType` not placed at `tff.SERVER`.
  """
  type_checks.check_type(
      tff_type,
      (tff.TensorType, tff.StructType, tff.FederatedType),
      name='tff_type',
  )
  if name is None:
    name = 'v'
  else:
    type_checks.check_type(name, str)

  if isinstance(tff_type, tff.FederatedType):
    if tff_type.placement != tff.SERVER:
      raise TypeError(
          'Can only create vars for unplaced types or types placed '
          'on the SERVER.'
      )
    return create_vars_for_tff_type(tff_type.member, name, **kwargs)

  if isinstance(tff_type, tff.TensorType):
    return [_create_var_for_tff_tensor(tff_type, name, **kwargs)]

  # Remaining case: tff.StructType.
  variables = []
  with tf.compat.v1.variable_scope(name):
    elements = tff.structure.to_elements(tff_type)
    for position, (element_name, element_type) in enumerate(elements):
      # Unnamed elements are named by their position so that siblings do not
      # all land under the same `/v/` child name.
      child_name = str(position) if element_name is None else element_name
      variables.extend(
          create_vars_for_tff_type(element_type, name=child_name, **kwargs)
      )
  return variables
142
143
def variable_names_from_type(
    tff_type: AllowedTffTypes, name: str = 'v'
) -> list[str]:
  """Returns the flattened list of variable names for `tff_type`.

  A `tff.TensorType` yields just `name`. A `tff.FederatedType` is named like
  its member type. A `tff.StructType` joins outer and inner names with '/',
  using each element's name or, for unnamed elements, its index.

  Some examples:
  1. `<'a'=tf.int32, 'b'=tf.int32>` with the default `name` gives
    ['v/a', 'v/b'].
  2. `<tf.int32, tf.int32>` with `name='update'` gives
    ['update/0', 'update/1'].
  3. `<'a'=<'b'=tf.int32, 'c'=tf.int32>>` with `name='update'` gives
    ['update/a/b', 'update/a/c'].
  4. `<'a'=<'b'=tf.int32, 'c'=tf.int32, tf.int32>>` with `name='update'` gives
    ['update/a/b', 'update/a/c', 'update/a/2'].

  Args:
    tff_type: A `tff.StructType`, `tff.FederatedType`, or `tff.TensorType`.
    name: The `str` name for the top-most level (default `'v'`). For a
      `tff.StructType`, inner fields are scoped under it, e.g.
      `some_name/field_name`.

  Returns:
    A flat Python `list` of `str` names.

  Raises:
    TypeError: If an argument has the wrong type.
  """
  type_checks.check_type(
      tff_type,
      (tff.TensorType, tff.FederatedType, tff.StructType),
      name='tff_type',
  )
  type_checks.check_type(name, str, name='name')

  if isinstance(tff_type, tff.TensorType):
    return [name]
  if isinstance(tff_type, tff.FederatedType):
    # Placement does not affect naming; name the member type directly.
    return variable_names_from_type(tff_type.member, name)
  if not isinstance(tff_type, tff.StructType):
    raise TypeError(
        'Cannot create variable names from [{t}] TFF type. '
        'Short-hand: {s}'.format(t=type(tff_type), s=tff_type)
    )

  names = []
  for position, (element_name, element_type) in enumerate(
      tff.structure.iter_elements(tff_type)
  ):
    # Unnamed (or empty-named) elements fall back to their index so that
    # siblings do not collide under the same parent name.
    child_name = f'{name}/{element_name or position}'
    names.extend(variable_names_from_type(element_type, name=child_name))
  return names
204
205
def get_shared_secagg_tensor_names(
    intrinsic_name: str, tff_type: AllowedTffTypes
) -> list[str]:
  """Returns the secagg tensor names shared by the client and server graphs.

  This is the single source of truth for these names. The server uses them
  as keys to fetch, from the secagg server, values that originated in the
  client graph. If the client and server names differed, the server could
  not find the secagg tensors; routing both sides through this function keeps
  that implicit dependency explicit.

  Each name is built from `variable_names_from_type(tff_type,
  'secagg_<intrinsic_name>_update')`, prefixed with `SIDECHANNEL_NAME_PREFIX`
  and suffixed with `_TF_TENSOR_NAME_SUFFIX`.

  Args:
    intrinsic_name: Name of the secure aggregation intrinsic in use.
    tff_type: A `tff.StructType`, `tff.FederatedType`, or `tff.TensorType`.

  Returns:
    A list of tensor names derived from `tff_type`.
  """
  base_name = f'secagg_{intrinsic_name}_update'
  variable_names = variable_names_from_type(tff_type, base_name)
  return [
      f'{SIDECHANNEL_NAME_PREFIX}{variable_name}{_TF_TENSOR_NAME_SUFFIX}'
      for variable_name in variable_names
  ]
233
234
def get_flattened_tensor_specs(
    tff_type: AllowedTffTypes, name: str
) -> list[tf.TensorSpec]:
  """Returns `tf.TensorSpec`s for every tensor leaf of `tff_type`, in order.

  Spec names follow the same scheme as `variable_names_from_type`: a tensor is
  named `name`, a federated type is named like its member, and struct
  elements are joined to their parent with '/' using the element name or,
  for unnamed elements, the element index.

  Args:
    tff_type: A `tff.StructType`, `tff.FederatedType`, or `tff.TensorType`.
    name: The `str` name for the top-most level. For a `tff.StructType`, inner
      fields are scoped under it, e.g. `some_name/field_name`.

  Returns:
    A flat Python `list` of `tf.TensorSpec`s.

  Raises:
    TypeError: If an argument has the wrong type.
  """
  type_checks.check_type(
      tff_type,
      (tff.TensorType, tff.FederatedType, tff.StructType),
      name='tff_type',
  )
  type_checks.check_type(name, str, name='name')

  if isinstance(tff_type, tff.TensorType):
    return [tf.TensorSpec(tff_type.shape, tff_type.dtype, name=name)]
  if isinstance(tff_type, tff.FederatedType):
    # Placement does not affect naming; describe the member type directly.
    return get_flattened_tensor_specs(tff_type.member, name)
  if not isinstance(tff_type, tff.StructType):
    raise TypeError(
        'Cannot create TensorSpecs from [{t}] TFF type. Short-hand: {s}'.format(
            t=type(tff_type), s=tff_type
        )
    )

  specs = []
  for position, (element_name, element_type) in enumerate(
      tff.structure.iter_elements(tff_type)
  ):
    # Unnamed (or empty-named) elements fall back to their index so that
    # siblings do not collide under the same parent name.
    child_name = f'{name}/{element_name or position}'
    specs.extend(get_flattened_tensor_specs(element_type, name=child_name))
  return specs
283
284
def get_grouped_input_tensor_specs_for_aggregations(
    aggregation_comp: tff.framework.ComputationBuildingBlock,
    names: dict[int, str],
) -> list[list[list[tf.TensorSpec]]]:
  """Gets the input TensorSpecs for an aggregation computation.

  This function can be used to generate the TensorSpecs that are assigned to
  ServerAggregationConfig.IntrinsicArg messages to represent the aggregation
  intrinsic calls in DistributeAggregateForm.client_to_server_aggregation.

  It derives the tensor name(s) for each intrinsic input argument by following
  naming logic similar to `variable_names_from_type`. DistributeAggregateForm
  does guarantee that each intrinsic input argument will be a
  `building_block.Selection` or a (potentially nested) struct of
  `building_block.Selection`s. The first element of the selection path is
  mapped through `names` to the top-level name, which must match the
  top-level name used to construct the tensor consumed by this argument. The
  remaining path indices are appended, each preceded by '/'. A selection with
  no indices beyond the first (e.g. `arg[1]`) uses the top-level name alone,
  with no trailing '/'.

  Args:
    aggregation_comp: The aggregation computation.
    names: A dictionary describing how to map the first element of the path to a
      top-level name.

  Returns:
    A `list` where the ith entry represents the input tensor specs for the
    ith intrinsic in the aggregation computation. The ith entry is itself a list
    where the jth entry represents the input tensor specs for the jth argument
    of the ith intrinsic in the aggregation computation.

  Raises:
    TypeError: If the argument is of the wrong type.
    ValueError: If the argument contains an unexpected
      `building_block.Selection` index.
  """

  def _get_selection_path(
      selection: tff.framework.ComputationBuildingBlock,
  ) -> list[int]:
    """Returns the selection indices of `selection`, innermost source first."""
    path = []
    while selection.is_selection():
      path.append(selection.index)  # pytype: disable=attribute-error
      selection = selection.source  # pytype: disable=attribute-error
    # In ASTs like x[0][1], we'll see the last (outermost) selection first.
    path.reverse()
    return path

  def _get_input_tensor_specs_for_aggregation_arg(
      value: tff.framework.ComputationBuildingBlock, names: dict[int, str]
  ) -> list[tf.TensorSpec]:
    """Gets the input TensorSpecs for a single intrinsic argument."""

    # An intrinsic arg may be a `building_block.Selection` or a (potentially
    # nested) struct of `building_block.Selection`s. Start by creating a
    # flattened list of the `building_block.Selection`s.
    if value.is_struct():
      inner_values = tff.structure.flatten(value)
    else:
      inner_values = [value]

    # For each `building_block.Selection`, reconstruct the tensor name that
    # will be used to supply that value. The first index of the selection path
    # indicates whether the tensor will be coming from the intermediate state
    # checkpoint (0) or from the client checkpoint (1), since TFF condenses
    # daf.client_to_server_aggregation(temp_server_state, client_update)
    # into a 1-arg function. Since the tensors within the checkpoints
    # corresponding to temp_server_state and work_at_clients will be named using
    # variable_names_from_type, which uses a simple filepath-like naming pattern
    # to refer to the tensors within a struct, we can reconstruct the relevant
    # tensor name by concatenating together the remaining indices of each
    # selection path.
    tensor_specs = []
    for inner_value in inner_values:
      inner_value.check_selection()
      path = _get_selection_path(inner_value)
      arg_index = path[0]
      if arg_index not in names:
        raise ValueError('Unexpected arg index for aggregation selection')
      # Join every component with '/'. The previous `prefix += '/' + ...` left
      # a dangling '/' when `path` had a single element, producing names like
      # 'x/' or 'x//field' that no `variable_names_from_type` name matches.
      prefix = '/'.join([names[arg_index]] + [str(x) for x in path[1:]])
      tensor_specs.extend(
          get_flattened_tensor_specs(inner_value.type_signature, name=prefix)
      )

    return tensor_specs

  grouped_input_tensor_specs = []

  for _, local_value in aggregation_comp.result.locals:  # pytype: disable=attribute-error
    local_value.check_call()
    local_value.function.check_intrinsic()
    intrinsic_def = local_value.function.intrinsic_def()
    assert intrinsic_def.aggregation_kind

    # Collect the input TensorSpecs for each argument of this intrinsic. A
    # struct-typed intrinsic parameter means each child of the call argument
    # is a separate intrinsic argument.
    if intrinsic_def.type_signature.parameter.is_struct():
      arg_values = local_value.argument.children()
    else:
      arg_values = [local_value.argument]
    grouped_input_tensor_specs.append([
        _get_input_tensor_specs_for_aggregation_arg(arg_value, names)
        for arg_value in arg_values
    ])

  return grouped_input_tensor_specs
400
401
def get_grouped_output_tensor_specs_for_aggregations(
    aggregation_comp: tff.framework.ComputationBuildingBlock,
) -> list[list[tf.TensorSpec]]:
  """Returns the output TensorSpecs of each intrinsic in an aggregation.

  The result is suitable for the output_tensors value of
  ServerAggregationConfig messages describing the aggregation intrinsic calls
  in DistributeAggregateForm.client_to_server_aggregation.

  Output tensor names follow naming logic similar to
  `variable_names_from_type` and must match the tensor names expected by the
  post-aggregation computation.

  Args:
    aggregation_comp: The aggregation computation.

  Returns:
    A list whose ith entry holds the output tensor specs of the ith intrinsic
    in the aggregation computation.

  Raises:
    TypeError: If the argument is of the wrong type.
  """
  # Specs for all intrinsic results, in order. Their names mirror calling
  # variable_names_from_type on the output type of
  # DistributeAggregateForm.client_to_server_aggregation (equivalently, the
  # type of the aggregation-result input arg of
  # DistributeAggregateForm.server_result), wrapped in a one-element struct
  # under 'intermediate_update'.
  all_specs = get_flattened_tensor_specs(
      tff.StructType([aggregation_comp.type_signature.result]),
      name='intermediate_update',
  )
  cursor = 0
  grouped_output_tensor_specs = []

  for _, local_value in aggregation_comp.result.locals:  # pytype: disable=attribute-error
    local_value.check_call()
    local_value.function.check_intrinsic()
    local_value.type_signature.check_federated()
    assert local_value.function.intrinsic_def().aggregation_kind

    member_type = local_value.type_signature.member
    if member_type.is_struct():
      # A struct result consumes as many consecutive specs as it has leaves.
      spec_count = len(tff.structure.flatten(member_type))
      grouped_output_tensor_specs.append(
          all_specs[cursor : cursor + spec_count]
      )
    else:
      # A non-struct result consumes exactly one spec.
      spec_count = 1
      grouped_output_tensor_specs.append([all_specs[cursor]])
    cursor += spec_count

  return grouped_output_tensor_specs
461