# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ShardedVariable class."""
import copy
import math
from typing import Sequence
import weakref

import numpy as np

from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices as indexed_slices_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import save_context
from tensorflow.python.trackable import base as trackable
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


@tf_export('distribute.experimental.partitioners.Partitioner', v1=[])
class Partitioner(object):
  """Partitioner base class: all partitioners inherit from this class.

  Partitioners should implement a `__call__` method with the following
  signature:

  ```python
  def __call__(self, shape, dtype, axis=0):
    # Partitions the given `shape` and returns the partition results.
    # See docstring of `__call__` method for the format of partition results.
  ```
  """

  def __call__(self, shape, dtype, axis=0):
    """Partitions the given `shape` and returns the partition results.

    Example of a partitioner that allocates a fixed number of shards:

    ```python
    partitioner = FixedShardsPartitioner(num_shards=2)
    partitions = partitioner(tf.TensorShape([10, 3]), tf.float32, axis=0)
    print(partitions)  # [2, 1]
    ```

    Args:
      shape: a `tf.TensorShape`, the shape to partition.
      dtype: a `tf.dtypes.Dtype` indicating the type of the partition value.
      axis: The axis to partition along.  Default: outermost axis.

    Returns:
      A list of integers representing the number of partitions on each axis,
      where the i-th value corresponds to the i-th axis.
    """
    raise NotImplementedError

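# Illustrative sketch (not part of the original module): a minimal custom
# partitioner following the `__call__` contract above. `MaxRowsPartitioner` is
# a hypothetical name used only for this example.
#
#   class MaxRowsPartitioner(Partitioner):
#     """Allocates enough shards so that no shard exceeds `max_rows` rows."""
#
#     def __init__(self, max_rows):
#       self._max_rows = max_rows
#
#     def __call__(self, shape, dtype, axis=0):
#       del dtype  # Partitioning here depends only on the shape.
#       result = [1] * len(shape)
#       result[axis] = -(-shape.dims[axis].value // self._max_rows)  # ceil div
#       return result
#
#   # MaxRowsPartitioner(max_rows=4)(tf.TensorShape([10, 3]), tf.float32)
#   # returns [3, 1]: 3 shards of at most 4 rows along axis 0.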

@tf_export('distribute.experimental.partitioners.FixedShardsPartitioner', v1=[])
class FixedShardsPartitioner(Partitioner):
  """Partitioner that allocates a fixed number of shards.

  Examples:

  >>> # standalone usage:
  >>> partitioner = FixedShardsPartitioner(num_shards=2)
  >>> partitions = partitioner(tf.TensorShape([10, 3]), tf.float32)
  >>> partitions
  [2, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self, num_shards):
    """Creates a new `FixedShardsPartitioner`.

    Args:
      num_shards: `int`, number of shards to partition.
    """
    self._num_shards = num_shards

  def __call__(self, shape, dtype, axis=0):
    del dtype
    result = [1] * len(shape)
    result[axis] = min(self._num_shards, shape.dims[axis].value)
    return result

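# Illustrative note (not part of the original module): `FixedShardsPartitioner`
# caps the shard count at the size of the partitioned axis, so requesting more
# shards than there are rows degrades gracefully:
#
#   partitioner = FixedShardsPartitioner(num_shards=16)
#   partitioner(tf.TensorShape([10, 3]), tf.float32)          # -> [10, 1]
#   partitioner(tf.TensorShape([10, 3]), tf.float32, axis=1)  # -> [1, 3]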

@tf_export('distribute.experimental.partitioners.MinSizePartitioner', v1=[])
class MinSizePartitioner(Partitioner):
  """Partitioner that allocates a minimum size per shard.

  This partitioner ensures each shard has at least `min_shard_bytes`, and tries
  to allocate as many shards as possible, i.e., keeping shard size as small as
  possible. The maximum number of such shards (upper bound) is given by
  `max_shards`.

  Examples:

  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [2, 1]
  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=10)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [6, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self,
               min_shard_bytes=256 << 10,
               max_shards=1,
               bytes_per_string=16):
    """Creates a new `MinSizePartitioner`.

    Args:
      min_shard_bytes: Minimum bytes of each shard. Defaults to 256K.
      max_shards: Upper bound on the number of shards. Defaults to 1.
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.
    """
    if min_shard_bytes < 1:
      raise ValueError('Argument `min_shard_bytes` must be positive. '
                       f'Received: {min_shard_bytes}')
    if max_shards < 1:
      raise ValueError('Argument `max_shards` must be positive. '
                       f'Received: {max_shards}')
    if bytes_per_string < 1:
      raise ValueError('Argument `bytes_per_string` must be positive. '
                       f'Received: {bytes_per_string}')
    self._min_shard_bytes = min_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    return partitioned_variables.min_max_variable_partitioner(
        max_partitions=self._max_shards,
        axis=axis,
        min_slice_size=self._min_shard_bytes,
        bytes_per_string_element=self._bytes_per_string)(shape, dtype)

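# Illustrative arithmetic (not part of the original module), assuming the
# behavior of the wrapped `min_max_variable_partitioner`: the shard count along
# `axis` is roughly total_bytes / min_shard_bytes, capped by `max_shards` and
# by the axis size. For the doctests above, a [6, 1] float32 variable occupies
# 6 * 1 * 4 = 24 bytes, so min_shard_bytes=4 admits up to 24 / 4 = 6 shards;
# max_shards=2 then caps the result at [2, 1], while max_shards=10 leaves the
# full [6, 1].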

@tf_export('distribute.experimental.partitioners.MaxSizePartitioner', v1=[])
class MaxSizePartitioner(Partitioner):
  """Partitioner that keeps shards below `max_shard_bytes`.

  This partitioner ensures each shard has at most `max_shard_bytes`, and tries
  to allocate as few shards as possible, i.e., keeping shard size as large
  as possible.

  If the partitioner hits the `max_shards` limit, then each shard may end up
  larger than `max_shard_bytes`. By default `max_shards` equals `None` and no
  limit on the number of shards is enforced.

  Examples:

  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [6, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [2, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=1024)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [1, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self, max_shard_bytes, max_shards=None, bytes_per_string=16):
    """Creates a new `MaxSizePartitioner`.

    Args:
      max_shard_bytes: The maximum size any given shard is allowed to be.
      max_shards: `int`, the maximum number of shards to create; takes
        precedence over `max_shard_bytes`.
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.
    """
    if max_shard_bytes < 1:
      raise ValueError('Argument `max_shard_bytes` must be positive. '
                       f'Received {max_shard_bytes}')
    if max_shards and max_shards < 1:
      raise ValueError('Argument `max_shards` must be positive. '
                       f'Received {max_shards}')
    if bytes_per_string < 1:
      raise ValueError('Argument `bytes_per_string` must be positive. '
                       f'Received: {bytes_per_string}')

    self._max_shard_bytes = max_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    return partitioned_variables.variable_axis_size_partitioner(
        max_shard_bytes=self._max_shard_bytes,
        max_shards=self._max_shards,
        bytes_per_string_element=self._bytes_per_string,
        axis=axis)(shape, dtype)

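# Illustrative arithmetic (not part of the original module), assuming the
# behavior of the wrapped `variable_axis_size_partitioner`: the partitioner
# aims for roughly ceil(total_bytes / max_shard_bytes) shards, capped by the
# axis size and by `max_shards` when it is set. For the doctests above, a
# [6, 1] float32 variable occupies 24 bytes, so max_shard_bytes=4 needs 6
# shards, max_shards=2 caps that at 2, and max_shard_bytes=1024 needs only 1.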

class ShardedVariableSpec(type_spec.TypeSpec):
  """Type specification for a `ShardedVariable`."""

  __slots__ = ['_variable_specs']

  value_type = property(lambda self: ShardedVariable)

  def __init__(self, *variable_specs):
    self._variable_specs = tuple(variable_specs)

  def _serialize(self):
    return self._variable_specs

  @property
  def _component_specs(self):
    return self._variable_specs

  def _to_components(self, value):
    return value.variables

  def _from_components(self, variables):
    return ShardedVariable(variables)

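# Illustrative sketch (not part of the original module), using the private spec
# methods above only for exposition: the spec is what lets `tf.function` and
# other composite-tensor machinery take a `ShardedVariable` apart into its
# component variables and rebuild it.
#
#   sv = ShardedVariable([tf.Variable([0., 1.]), tf.Variable([2.])])
#   spec = sv._type_spec                  # a ShardedVariableSpec of two shards
#   components = spec._to_components(sv)  # the underlying list of Variables
#   rebuilt = spec._from_components(components)
#   assert rebuilt.shape.as_list() == [3]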

class ShardedVariableMixin(trackable.Trackable):
  """Mixin for ShardedVariable."""

  # TODO(b/170877138): Remove this mixin once fixed. This mixin is required
  # since TPUEmbeddingVariable can't be a CompositeTensor.

  def __init__(self, variables, name='ShardedVariable'):
    """Treats `variables` as shards of a larger Variable.

    Example:

    ```
    variables = [
      tf.Variable(..., shape=(10, 100), dtype=tf.float32),
      tf.Variable(..., shape=(15, 100), dtype=tf.float32),
      tf.Variable(..., shape=(5, 100), dtype=tf.float32)
    ]
    sharded_variable = ShardedVariableMixin(variables)
    assert sharded_variable.shape.as_list() == [30, 100]
    ```

    Args:
      variables: A list of `ResourceVariable`s that comprise this sharded
        variable. Variables should not be shared between different
        `ShardedVariableMixin` objects.
      name: String. Name of this container. Defaults to "ShardedVariable".
    """
    super(ShardedVariableMixin, self).__init__()
    self._variables = variables
    self._name = name

    if not isinstance(variables, Sequence) or not variables or any(
        not isinstance(v, variables_lib.Variable) for v in variables):
      raise TypeError('Argument `variables` should be a non-empty list of '
                      f'`variables.Variable`s. Received {variables}')

    var_dtypes = {v.dtype for v in variables}
    if len(var_dtypes) > 1:
      raise ValueError(
          'All elements in argument `variables` must have the same dtype. '
          f'Received dtypes: {[v.dtype for v in variables]}')

    first_var = variables[0]
    self._dtype = first_var.dtype

    # All variables must have the same shape for axes > 0.
    higher_dim_shapes = {tuple(v.shape.as_list()[1:]) for v in variables}
    if len(higher_dim_shapes) > 1:
      raise ValueError(
          'All elements in argument `variables` must have the same shapes '
          'except for the first axis. '
          f'Received shapes: {[v.shape for v in variables]}')
    first_dim = sum(int(v.shape.as_list()[0]) for v in variables)
    self._shape = tensor_shape.TensorShape([first_dim] +
                                           first_var.shape.as_list()[1:])

    for v in variables:
      v._sharded_container = weakref.ref(self)

    self._var_offsets = [
        [0 for _ in range(len(first_var.shape))] for _ in range(len(variables))
    ]
    for i in range(1, len(variables)):
      # Always partition on the first axis. Offsets on other axes are 0.
      self._var_offsets[i][0] += (
          self._var_offsets[i - 1][0] + variables[i - 1].shape.as_list()[0])

    save_slice_info = [v._get_save_slice_info() for v in variables]  # pylint: disable=protected-access
    if any(slice_info is not None for slice_info in save_slice_info):
      raise ValueError(
          '`SaveSliceInfo` should not be set for any element in argument '
          '`variables`. `ShardedVariable` infers `SaveSliceInfo` from the '
          'order of the elements in `variables`. '
          f'Received save slice info {save_slice_info}')

    # We create an uninitialized saving_variable with the full shape, which can
    # be later captured in signatures so that the signatures can treat this
    # ShardedVariable as one single variable.
    self._saving_variable = resource_variable_ops.UninitializedVariable(
        shape=self._shape, dtype=self._dtype, name=self._name,
        trainable=self._variables[0].trainable,
        synchronization=variables_lib.VariableSynchronization.NONE,
        aggregation=variables_lib.VariableAggregation.NONE)

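  # Illustrative note (not part of the original module): for the docstring
  # example above with shard first dimensions [10, 15, 5], `__init__` computes
  #
  #   self._shape       == TensorShape([30, 100])
  #   self._var_offsets == [[0, 0], [10, 0], [25, 0]]
  #
  # i.e. shard i starts at global row `self._var_offsets[i][0]`, and only the
  # first axis is ever partitioned.
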
  def __iter__(self):
    """Return an iterable for accessing the underlying sharded variables."""
    return iter(self._variables)

  def __getitem__(self, slice_spec):
    """Extracts the specified region as a Tensor from the sharded variable.

    The API contract is identical to `Tensor.__getitem__`. Assignment to the
    sliced range is not yet supported.

    Args:
      slice_spec: The arguments to __getitem__, specifying the global slicing of
        the sharded variable.

    Returns:
      The appropriate slice of tensor based on `slice_spec`.

    Raises:
      IndexError: If a slice index is out of bound.
      TypeError: If `slice_spec` contains a Tensor.
    """

    # TODO(b/177482728): Support tensor input.
    # TODO(b/177482728): Support slice assign, similar to variable slice assign.

    if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
                                         slice_spec.dtype == dtypes.bool) or
        (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool)):
      tensor = _var_to_tensor(self)
      return array_ops.boolean_mask(tensor=tensor, mask=slice_spec)

    if not isinstance(slice_spec, (list, tuple)):
      slice_spec = (slice_spec,)

    s = slice_spec[0]
    if isinstance(s, slice):
      first_dim_slice_specs = self._decompose_slice_spec(s)
      values = []
      for i, var in enumerate(self._variables):
        if first_dim_slice_specs[i] is not None:
          all_dim_slice_spec = (first_dim_slice_specs[i],) + slice_spec[1:]
          values.append(var[all_dim_slice_spec])
      if s.step is not None and s.step < 0:
        values.reverse()
      if not values:
        return constant_op.constant([],
                                    dtype=self._dtype,
                                    shape=((0,) + self._shape[1:]))
      return array_ops.concat(values, axis=0)
    elif s is Ellipsis:
      return array_ops.concat([var[slice_spec] for var in self._variables],
                              axis=0)
    elif s is array_ops.newaxis:
      return array_ops.concat([var[slice_spec[1:]] for var in self._variables],
                              axis=0)[array_ops.newaxis]
    else:
      if isinstance(s, ops.Tensor):
        raise TypeError(
            'ShardedVariable: using Tensor for indexing is not allowed.')
      if s < 0:
        s += self._shape[0]
      if s < 0 or s >= self._shape[0]:
        raise IndexError(
            f'ShardedVariable: slice index {s} of dimension 0 out of bounds.')
      for i in range(len(self._variables)):
        # Find the shard whose range [offsets[i], offsets[i + 1]) contains `s`
        # and index into it relative to the shard's offset.
        if i == len(self._variables) - 1 or (s >= self._var_offsets[i][0] and
                                             s < self._var_offsets[i + 1][0]):
          return self._variables[i][(s - self._var_offsets[i][0],) +
                                    slice_spec[1:]]

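  # Illustrative sketch (not part of the original module): `__getitem__`
  # stitches the requested region together from the shards, so global indexing
  # behaves as if the shards were one concatenated tensor.
  #
  #   sv = ShardedVariable([tf.Variable([0., 1., 2.]), tf.Variable([3., 4.])])
  #   sv[3]      # -> 3.0, read from the second shard
  #   sv[1:4]    # -> [1., 2., 3.], concatenated across both shards
  #   sv[::-1]   # -> [4., 3., 2., 1., 0.]
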
  def _decompose_slice_spec(self, slice_spec):
    """Decompose a global slice_spec into a list of per-variable slice_spec.

    `ShardedVariable` only supports first dimension partitioning, thus
    `slice_spec` must be for the first dimension.

    Args:
      slice_spec: A python `slice` object that specifies the global slicing.

    Returns:
      A list of python `slice` objects or None specifying the local slicing for
      each component variable. None means no slicing.

    For example, given component variables:
      v0 = [0, 1, 2]
      v1 = [3, 4, 5]
      v2 = [6, 7, 8, 9]

    If `slice_spec` is slice(start=None, stop=None, step=None), we will have:
      v0[returned[0]] = [0, 1, 2]
      v1[returned[1]] = [3, 4, 5]
      v2[returned[2]] = [6, 7, 8, 9]
    If `slice_spec` is slice(start=2, stop=8, step=3), we will have:
      v0[returned[0]] = [2]
      v1[returned[1]] = [5]
      returned[2] == None
    If `slice_spec` is slice(start=9, stop=3, step=-2), we will have:
      returned[0] == None
      v1[returned[1]] = [5]
      v2[returned[2]] = [9, 7]
    """
    if isinstance(slice_spec.start, ops.Tensor) or isinstance(
        slice_spec.stop, ops.Tensor) or isinstance(slice_spec.step, ops.Tensor):
      raise TypeError(
          'ShardedVariable: using Tensor in slice_spec is not allowed. Please '
          'file a feature request with the TensorFlow team.')

    result = []
    # Normalize start, stop and step.
    slice_step = slice_spec.step if slice_spec.step is not None else 1
    if slice_step == 0:
      raise ValueError('slice step cannot be zero')
    slice_start = slice_spec.start
    if slice_start is None:
      slice_start = 0 if slice_step > 0 else self._shape[0] - 1
    elif slice_start < 0:
      slice_start += self._shape[0]
    slice_end = slice_spec.stop
    if slice_end is None:
      # After the normalization, we no longer interpret negative indices, thus
      # "-1" conceptually refers to the element before the first one, which
      # doesn't exist. This is to ease the decomposition code.
      slice_end = self._shape[0] if slice_step > 0 else -1
    elif slice_end < 0:
      slice_end += self._shape[0]

    # To find the local slice_spec of each component variable, we start from
    # the start of the global slice, and iterate through each variable.
    # When iterating on a variable, we move the cursor (`cur`) to the first
    # index that falls into the variable's range, which becomes the start of
    # the variable's local slice_spec. The end of the local_spec is determined
    # by whichever is smaller between the global slice end and the variable
    # range end.
    cur = slice_start
    if slice_step > 0:
      for i in range(len(self._var_offsets)):
        var_start = self._var_offsets[i][0]
        var_end = (
            self._var_offsets[i + 1][0]
            if i < len(self._var_offsets) - 1 else self._shape[0])
        if cur < var_start:
          cur += slice_step * int(math.ceil((var_start - cur) / slice_step))
        if cur >= var_end or cur >= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          end = min(slice_end, var_end) - var_start
          result.append(slice(start, end, slice_step))
    else:  # slice_step < 0
      for i in range(len(self._var_offsets) - 1, -1, -1):
        var_start = self._var_offsets[i][0]
        var_end = (
            self._var_offsets[i + 1][0]
            if i < len(self._var_offsets) - 1 else self._shape[0])
        if cur >= var_end:
          cur += slice_step * int(math.ceil((var_end - cur - 1) / slice_step))
        if cur < var_start or cur <= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          if slice_end >= var_start:
            end = slice_end - var_start
          else:
            end = None  # no explicit end: slice until hitting the boundary.
          result.append(slice(start, end, slice_step))

      result.reverse()

    return result

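  # Illustrative check (not part of the original module) of the docstring
  # example above, with shard first dimensions [3, 3, 4] (offsets 0, 3, 6) and
  # slice(start=2, stop=8, step=3), i.e. global indices 2 and 5:
  #
  #   shard 0 covers [0, 3): cursor starts at 2 -> slice(2, 3, 3) -> [2]
  #   shard 1 covers [3, 6): cursor moves to 5  -> slice(2, 3, 3) -> [5]
  #   shard 2 covers [6, 10): cursor moves to 8, which is >= stop -> None
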
  @property
  def _type_spec(self):
    return ShardedVariableSpec(
        *(resource_variable_ops.VariableSpec(v.shape, v.dtype)
          for v in self._variables))

  @property
  def variables(self):
    """The list of `Variable`s that make up the shards of this object."""
    if save_context.in_save_context():
      return [self._saving_variable]
    return self._variables

  @property
  def name(self):
    """The name of this object. Used for checkpointing."""
    return self._name

  @property
  def dtype(self):
    """The dtype of all `Variable`s in this object."""
    return self._dtype

  @property
  def shape(self):
    """The overall shape, combining all shards along axis `0`."""
    return self._shape

  def assign(self, value, use_locking=None, name=None, read_value=True):
    for i, v in enumerate(self._variables):
      v.assign(array_ops.slice(value, self._var_offsets[i], v.shape.as_list()))
    return self

  def assign_add(self, delta, use_locking=False, name=None, read_value=True):
    for i, v in enumerate(self._variables):
      v.assign_add(
          array_ops.slice(delta, self._var_offsets[i], v.shape.as_list()))
    return self

  def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
    for i, v in enumerate(self._variables):
      v.assign_sub(
          array_ops.slice(delta, self._var_offsets[i], v.shape.as_list()))
    return self

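  # Illustrative sketch (not part of the original module): the assign-style ops
  # above split the full-shape `value`/`delta` row-wise at the shard offsets,
  # so each shard receives exactly its own rows.
  #
  #   sv = ShardedVariable([tf.Variable([1., 2.]), tf.Variable([3.])])
  #   sv.assign(tf.constant([10., 20., 30.]))     # shards: [10., 20.] and [30.]
  #   sv.assign_add(tf.constant([1., 1., 1.]))    # shards: [11., 21.] and [31.]
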
  def _decompose_indices(self, indices):
    """Decompose global 1D indices into a list of per-variable indices."""
    if indices.shape.rank != 1:
      raise ValueError(
          'ShardedVariable: indices must be 1D Tensor for sparse operations. '
          f'Received shape: {indices.shape}')

    base = self._shape[0] // len(self._variables)
    extra = self._shape[0] % len(self._variables)

    # Assert that sharding conforms to "div" sharding
    expect_first_dim = [base] * len(self._variables)
    for i in range(extra):
      expect_first_dim[i] = expect_first_dim[i] + 1
    actual_first_dim = [v.shape.as_list()[0] for v in self._variables]
    if expect_first_dim != actual_first_dim:
      raise NotImplementedError(
          'scatter_xxx ops are not supported in ShardedVariable that does not '
          'conform to "div" sharding')

    # For an index that falls into a partition that has an extra element, the
    # assignment is `index // (base + 1)` (which is no less than
    # `(index - extra) // base`). For an index that falls into a partition
    # without an extra element, the assignment is `(index - extra) // base`
    # (which is no less than `index // (base + 1)`).
    #
    # Example:
    #   base = 10, extra = 2, partitions: [0, 11), [11, 22), [22, 32)
    #   index = 10 -> partition_assignment = 0
    #   index = 22 -> partition_assignment = 2
    partition_assignments = math_ops.maximum(indices // (base + 1),
                                             (indices - extra) // base)
    local_indices = array_ops.where(partition_assignments < extra,
                                    indices % (base + 1),
                                    (indices - extra) % base)
    # `dynamic_partition` only supports int32 partition assignments.
    partition_assignments = math_ops.cast(partition_assignments, dtypes.int32)
    per_var_indices = data_flow_ops.dynamic_partition(local_indices,
                                                      partition_assignments,
                                                      len(self._variables))

    return per_var_indices, partition_assignments

  def _decompose_indexed_slices(self, indexed_slices):
    """Decompose a global `IndexedSlices` into a list of per-variable ones."""
    per_var_indices, partition_assignments = self._decompose_indices(
        indexed_slices.indices)
    per_var_values = data_flow_ops.dynamic_partition(indexed_slices.values,
                                                     partition_assignments,
                                                     len(self._variables))

    return [
        indexed_slices_lib.IndexedSlices(
            values=per_var_values[i], indices=per_var_indices[i])
        for i in range(len(self._variables))
    ]

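  # Illustrative check (not part of the original module) of the assignment
  # formula above, with base = 10, extra = 2, partitions [0, 11), [11, 22),
  # [22, 32):
  #
  #   index = 10: max(10 // 11, (10 - 2) // 10) = max(0, 0) = 0
  #   index = 11: max(11 // 11, (11 - 2) // 10) = max(1, 0) = 1
  #   index = 22: max(22 // 11, (22 - 2) // 10) = max(2, 2) = 2
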
  # ==================== scatter ops implementations ======================== #

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_add."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.scatter_add(per_var_sparse_delta[i], name=new_name)
    return self

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_div."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.scatter_div(per_var_sparse_delta[i], name=new_name)
    return self

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_max."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.scatter_max(per_var_sparse_delta[i], name=new_name)
    return self

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_min."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.scatter_min(per_var_sparse_delta[i], name=new_name)
    return self

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_mul."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.scatter_mul(per_var_sparse_delta[i], name=new_name)
    return self

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_sub."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.scatter_sub(per_var_sparse_delta[i], name=new_name)
    return self

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_update."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.scatter_update(per_var_sparse_delta[i], name=new_name)
    return self

  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.batch_scatter_update."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.batch_scatter_update(per_var_sparse_delta[i], name=new_name)
    return self

  # ================== scatter ops implementations END ====================== #

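  # Illustrative sketch (not part of the original module): the scatter ops
  # above take a global `tf.IndexedSlices`, decompose it with the "div"
  # sharding helpers, and route each update to the shard that owns the index.
  #
  #   sv = ShardedVariable([tf.Variable([0., 0.]), tf.Variable([0., 0.])])
  #   sv.scatter_add(tf.IndexedSlices(values=tf.constant([1., 5.]),
  #                                   indices=tf.constant([0, 3])))
  #   # shard 0 becomes [1., 0.] and shard 1 becomes [0., 5.]
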
  def sparse_read(self, indices, name=None):
    """Implements tf.Variable.sparse_read."""
    per_var_indices, _ = self._decompose_indices(indices)
    result = []
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      result.append(v.sparse_read(per_var_indices[i], name=new_name))
    return array_ops.concat(result, axis=0)

  def _gather_saveables_for_checkpoint(self):
    """Return a `Saveable` for each shard. See `Trackable`."""

    def _saveable_factory(name=self.name):
      """Creates `SaveableObject`s for this `ShardedVariable`."""
      saveables = []
      dims = len(self._variables[0].shape)
      var_offset = [0 for _ in range(dims)]
      for v in self._variables:
        save_slice_info = variables_lib.Variable.SaveSliceInfo(
            full_name=self.name,
            full_shape=self.shape.as_list(),
            var_offset=copy.copy(var_offset),
            var_shape=v.shape.as_list())
        saveables.append(
            saveable_object_util.ResourceVariableSaveable(
                v, save_slice_info.spec, name))
        var_offset[0] += int(v.shape[0])
      return saveables

    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

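  # Illustrative note (not part of the original module): each shard above is
  # saved as a slice of a single logical checkpoint entry. For two shards of
  # shapes [10, 100] and [20, 100] under the full shape [30, 100], the slice
  # specs are roughly
  #
  #   "30 100 0,10:0,100"    (shard 0: offset [0, 0],  shape [10, 100])
  #   "30 100 10,20:0,100"   (shard 1: offset [10, 0], shape [20, 100])
  #
  # which is what allows a checkpoint to be restored into a different number
  # of shards.
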
  def _map_resources(self, save_options):
    """For implementing `Trackable`."""
    obj_map, resource_map = {}, {}
    for v in self._variables + [self._saving_variable]:
      v_obj_map, v_resource_map = v._map_resources(save_options)  # pylint:disable=protected-access
      obj_map.update(v_obj_map)
      resource_map.update(v_resource_map)
    obj_map[self] = ShardedVariable([obj_map[self._saving_variable]],
                                    name=self.name)
    return obj_map, resource_map

  @property
  def _unique_id(self):
    # String-replace to ensure uniqueness for checkpoint tracking
    return self.variables[0]._unique_id.replace('part_0', 'sharded')  # pylint: disable=protected-access

  @property
  def _distribute_strategy(self):
    return self.variables[0]._distribute_strategy  # pylint: disable=protected-access

  @property
  def _shared_name(self):
    return self._name

  @property
  def is_sharded_variable(self):
    return True

  def numpy(self):
    """Copies the values in this ShardedVariable to a NumPy array.

    First converts to a single Tensor using the registered conversion function,
    which concatenates the shards, then uses Tensor.numpy() to convert to
    a NumPy array.

    Returns:
      A NumPy array of the same shape and dtype.
    """
    return _var_to_tensor(self).numpy()


@tf_export('__internal__.distribute.ShardedVariable', v1=[])
class ShardedVariable(ShardedVariableMixin, composite_tensor.CompositeTensor):
  """A container for `Variables` that should be treated as shards.

  Variables that are too large to fit on a single device (e.g., large
  embeddings) may need to be sharded over multiple devices. This class
  maintains a list of smaller variables that can be independently stored on
  separate devices (e.g., multiple parameter servers), and saves and restores
  those variables as if they were a single larger variable.

  Objects of this class can be saved with a given number of shards and then
  restored from a checkpoint into a different number of shards.

  Objects of this class can be saved to SavedModel format using
  `tf.saved_model.save`. The SavedModel can be used by programs like TF serving
  APIs. Loading the SavedModel with `tf.saved_model.load` is not yet supported.

  A `ShardedVariable` can be saved and then restored with a different number of
  shards depending on the restore environment; for example, TF serving APIs
  restore to a single shard for serving efficiency. Therefore, when using
  `ShardedVariable` in a `tf.function`, one should generally not assume it has
  the same number of shards across save and load.

  Sharding is only supported along the first dimension.

  >>> class Model(tf.Module):
  ...   def __init__(self):
  ...     self.sharded_variable = ShardedVariable([
  ...       tf.Variable([3.0], dtype=tf.float32),
  ...       tf.Variable([2.0], dtype=tf.float32)
  ...     ])
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def serve_fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  >>>
  >>> model = Model()
  >>> model.fn(1).numpy()
  2.0
  >>> tf.saved_model.save(model, export_dir='/tmp/saved_model',
  ...   signatures=model.serve_fn)
  """

  @property
  def _type_spec(self):
    return ShardedVariableSpec(
        *(resource_variable_ops.VariableSpec(v.shape, v.dtype)
          for v in self._variables))

  @classmethod
  def _overload_all_operators(cls):
    """Register overloads for all operators."""
    for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
      if operator == '__getitem__':
        continue

      cls._overload_operator(operator)

  @classmethod
  def _overload_operator(cls, operator):
    """Delegate an operator overload to `ops.Tensor`."""
    tensor_operator = getattr(ops.Tensor, operator)

    def _operator(v, *args, **kwargs):
      return tensor_operator(_var_to_tensor(v), *args, **kwargs)

    setattr(cls, operator, _operator)

  def __tf_experimental_restore_capture__(self, concrete_function,
                                          internal_capture):
    # Avoid restoring captures for functions that use ShardedVariable; the
    # layer will be recreated during Keras model loading.
    # TODO(jmullenbach): support loading models with ShardedVariables using
    # tf.saved_model.load
    return None

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    return True

  def _write_object_proto(self, proto, options):
    resource_variable_ops.write_object_proto_for_resource_variable(
        self._saving_variable, proto, options, enforce_naming=False)

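# Illustrative sketch (not part of the original module): because each shard is
# checkpointed as a slice of one logical variable (see
# `_gather_saveables_for_checkpoint` above), a checkpoint written under one
# sharding can be restored under another. The path below is arbitrary.
#
#   v = ShardedVariable([tf.Variable([1., 2.]), tf.Variable([3., 4.])])
#   path = tf.train.Checkpoint(v=v).save('/tmp/sharded_ckpt')
#
#   w = ShardedVariable(
#       [tf.Variable([0.]), tf.Variable([0., 0.]), tf.Variable([0.])])
#   tf.train.Checkpoint(v=w).restore(path)
#   # w now holds [1.], [2., 3.], [4.] across its three shards.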

def _var_to_tensor(var, dtype=None, name=None, as_ref=False):
  """Converts a `ShardedVariable` to a `Tensor`."""
  del name
  if dtype is not None and not dtype.is_compatible_with(var.dtype):
    raise ValueError(
        'Incompatible type conversion requested to type {!r} for variable '
        'of type {!r}'.format(dtype.name, var.dtype.name))
  if as_ref:
    raise NotImplementedError(
        "ShardedVariable doesn't support being used as a reference.")
  # We use the op dispatch mechanism to override embedding_lookup ops when
  # called with a ShardedVariable. This requires embedding_lookup ops to raise
  # TypeError when called with a ShardedVariable. However, since a
  # ShardedVariable can be converted to a tensor via concat, embedding_lookup
  # ops would silently do the conversion and never raise a TypeError. To be
  # able to properly raise a TypeError, the name scope is used to detect
  # whether this method is called within an embedding_lookup op.
  # NOTE: This doesn't work in eager mode since the op name scope is always
  # cleared in eager. This also breaks if the user sets the name of the
  # embedding_lookup op to something that doesn't contain the string
  # "embedding_lookup".
  #
  # TODO(chenkai): Find a more robust way to do this, which should not rely
  # on the name scope.
  if 'embedding_lookup' in ops.get_name_scope():
    raise TypeError('Converting ShardedVariable to tensor in embedding lookup'
                    ' ops is disallowed.')
  return array_ops.concat(var.variables, axis=0)


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
ops.register_tensor_conversion_function(ShardedVariable, _var_to_tensor)

ShardedVariable._overload_all_operators()  # pylint: disable=protected-access

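# Illustrative sketch (not part of the original module): with the conversion
# function and operator overloads registered above, a ShardedVariable can be
# used like a dense tensor outside of embedding lookups.
#
#   sv = ShardedVariable([tf.Variable([1., 2.]), tf.Variable([3.])])
#   tf.convert_to_tensor(sv)   # -> [1., 2., 3.] (shards concatenated on axis 0)
#   (sv + 1.).numpy()          # -> array([2., 3., 4.], dtype=float32)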

# Override the behavior of embedding_lookup(sharded_variable, ...)
@dispatch.dispatch_for_types(embedding_ops.embedding_lookup, ShardedVariable)
def embedding_lookup(params,
                     ids,
                     partition_strategy='mod',
                     name=None,
                     validate_indices=True,
                     max_norm=None):
  if isinstance(params, list):
    params = params[0]
  return embedding_ops.embedding_lookup(params.variables, ids,
                                        partition_strategy, name,
                                        validate_indices, max_norm)

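# Illustrative sketch (not part of the original module): with the dispatcher
# above, an embedding lookup on a ShardedVariable delegates to the list of
# shards, routing each id to the shard that owns it under the lookup's
# partition strategy instead of first concatenating the shards into one dense
# tensor.
#
#   sv = ShardedVariable([tf.Variable([[0., 0.], [1., 1.]]),
#                         tf.Variable([[2., 2.]])])
#   tf.nn.embedding_lookup(sv, tf.constant([0, 2]))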