# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================
"""Experimental support for defining XLA shardings."""

import numpy as _np  # Avoids becoming a part of public Tensorflow API.

from tensorflow.compiler.tf2xla.python import xla as tf2xla
from tensorflow.compiler.xla import xla_data_pb2
from tensorflow.core.framework import attr_value_pb2


class Sharding(object):
  """A class to support adding sharding attributes to Ops.

  Use the factory constructors and then call apply_to_tensor:
    Sharding.replicate().apply_to_tensor(tensor)
  """

  def __init__(self, proto=None):
    """Do not use this constructor; use the factory functions below."""
    self._proto = proto

  @classmethod
  def replicate(cls):
    """Returns a replicated sharding attribute.

    This causes an op to be computed in its entirety independently on all
    cores in the XLA device.
    """
    return Sharding(
        proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED))

  @classmethod
  def manual(cls):
    """Returns a manual sharding attribute.

    This means the op is manually partitioned by the user and XLA will not
    change the shapes.
    """
    return Sharding(
        proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.MANUAL))

  @classmethod
  def assign_device(cls, core):
    """Returns an AssignDevice sharding attribute.

    This causes an op to be computed in its entirety only on one core in
    the XLA device.

    Args:
      core: The core to assign this Op to.
    """
    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.MAXIMAL,
            tile_assignment_dimensions=[1],
            tile_assignment_devices=[core]))

  @classmethod
  def tile(cls, tile_assignment):
    """Returns a Tiled sharding attribute.

    This causes an op to be partially computed on multiple cores in the
    XLA device.

    Args:
      tile_assignment: An np.ndarray describing the topology of the tiling and
        which device will compute which part of the topology.

    Raises:
      TypeError: tile_assignment was not of np.array type.

    TODO(jmolloy): This concept is nefarious and is not
    something we really want to expose to users (especially as the
    contract for tile_assignment is very strict).
    """
    if not isinstance(tile_assignment, _np.ndarray):
      raise TypeError('Tile assignment must be of type np.ndarray')
    dims = list(tile_assignment.shape)
    flattened_devices = tile_assignment.reshape(-1, order='C')
    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.OTHER,
            tile_assignment_dimensions=dims,
            tile_assignment_devices=list(flattened_devices)))

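  # Usage sketch (illustrative; the 2x2 device grid with IDs 0..3 is an
  # assumption, not part of the original documentation): tile a rank-2 tensor
  # 2 ways along each dimension, one tile per device.
  #
  #   tile_assignment = _np.arange(4).reshape(2, 2)  # [[0, 1], [2, 3]]
  #   sharded = Sharding.tile(tile_assignment).apply_to_tensor(tensor)
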
  @classmethod
  def subgroup_tile(cls, tile_assignment, subgroup_modes):
    """Returns a subgroup tiled sharding attribute.

    This is similar to tile(), but tile_assignment has one or more dimensions
    more than the tensor, and subgroup_modes defines the sharding types in the
    last dimensions of tile_assignment.

    Args:
      tile_assignment: An np.ndarray describing the topology of the tiling and
        which device will compute which part of the topology.
      subgroup_modes: sharding types for the tile_assignment dimensions beyond
        the tensor shape rank.

    Raises:
      TypeError: tile_assignment was not of np.array type or subgroup_modes
        has unsupported sharding type.
    """
    if not isinstance(tile_assignment, _np.ndarray):
      raise TypeError('SubgroupTile assignment must be of type np.ndarray')

    if not isinstance(subgroup_modes, list):
      raise TypeError('subgroup_modes in subgroup manual must be of type list')

    if len(tile_assignment.shape) < len(subgroup_modes):
      raise TypeError('SubgroupTile assignment must have rank at least as'
                      ' large as the length of subgroup_modes')

    for sharding_type in subgroup_modes:
      if sharding_type not in [
          xla_data_pb2.OpSharding.REPLICATED, xla_data_pb2.OpSharding.MANUAL
      ]:
        raise TypeError(
            'Each sharding_type in subgroup_modes in subgroup manual must '
            'be of type xla_data_pb2.OpSharding.REPLICATED'
            ' or xla_data_pb2.OpSharding.MANUAL')
    dims = list(tile_assignment.shape)
    flattened_devices = tile_assignment.reshape(-1, order='C')
    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.OTHER,
            tile_assignment_dimensions=dims,
            tile_assignment_devices=list(flattened_devices),
            last_tile_dims=list(subgroup_modes)))

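  # Usage sketch (illustrative; the 2x2 device layout and MANUAL subgroup are
  # assumptions): shard a rank-1 tensor 2 ways, with a trailing subgroup
  # dimension of size 2 that is manually partitioned.
  #
  #   tile_assignment = _np.arange(4).reshape(2, 2)
  #   Sharding.subgroup_tile(
  #       tile_assignment,
  #       [xla_data_pb2.OpSharding.MANUAL]).apply_to_tensor(tensor)
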
  @classmethod
  def partial_tile(cls, tile_assignment):
    """Returns a partially tiled sharding attribute.

    This is similar to tile(), but tile_assignment has one more dimension than
    the tensor, and tiles in the last dimension of tile_assignment are
    replicated.

    Args:
      tile_assignment: An np.ndarray describing the topology of the tiling and
        which device will compute which part of the topology.

    Raises:
      TypeError: tile_assignment was not of np.array type.
    """
    if not isinstance(tile_assignment, _np.ndarray):
      raise TypeError('PartialTile assignment must be of type np.ndarray')
    dims = list(tile_assignment.shape)
    flattened_devices = tile_assignment.reshape(-1, order='C')
    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.OTHER,
            tile_assignment_dimensions=dims,
            tile_assignment_devices=list(flattened_devices),
            replicate_on_last_tile_dim=True))

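  # Usage sketch (illustrative; the 4-device layout is an assumption): shard a
  # rank-1 tensor 2 ways, replicating each shard across 2 devices.
  #
  #   tile_assignment = _np.arange(4).reshape(2, 2)
  #   Sharding.partial_tile(tile_assignment).apply_to_tensor(tensor)
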
  @classmethod
  def split(cls, tensor, split_dimension, num_devices, input_shape=None):
    """Returns a Sharding that splits a tensor across a dimension.

    This creates a Tiled attribute, similar to tile(), but easier to use for the
    common case of tiling a tensor N ways in one dimension.

    Args:
      tensor: A tf.Tensor to split.
      split_dimension: The dimension number to split.
      num_devices: The number of cores to split `tensor` over.
      input_shape: The shape of the original tensor.

    Raises:
      ValueError: The tensor to split was smaller in the split dimension than
        the number of devices to split over.
    """
    if input_shape:
      shape = input_shape
    else:
      shape = tensor.shape.as_list()
    if (shape[split_dimension] is not None and
        shape[split_dimension] < num_devices):
      raise ValueError('Split dimension was smaller than the required number '
                       'of splits: shape=%r, dimension=%r, num_devices=%r' %
                       (shape, split_dimension, num_devices))

    tile_assignment_dims = [1] * len(shape)
    tile_assignment_dims[split_dimension] = num_devices

    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.OTHER,
            tile_assignment_dimensions=tile_assignment_dims,
            tile_assignment_devices=range(num_devices)))

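  # Usage sketch (illustrative; the 4-way batch split is an assumption): split
  # a tensor along dimension 0 across 4 cores via an explicit sharding op.
  #
  #   tensor = Sharding.split(tensor, split_dimension=0,
  #                           num_devices=4).apply_to_tensor(
  #                               tensor, use_sharding_op=True)
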
  def apply_to_tensor(self,
                      tensor,
                      assign_tuple_sharding=False,
                      use_sharding_op=False,
                      unspecified_dims=None):
    """Applies this Sharding attribute to `tensor`.

    Args:
      tensor: A tf.Tensor to annotate with this sharding.
      assign_tuple_sharding: If the sharding type should be a tuple.
      use_sharding_op: Whether to create a sharding op on `tensor`.
      unspecified_dims: An optional list of dimensions whose sharding is left
        unspecified.

    Returns:
      The tensor with the sharding attribute.
    """
    if unspecified_dims:
      assert use_sharding_op and not assign_tuple_sharding
    proto = self._proto
    if use_sharding_op:
      if assign_tuple_sharding:
        proto = self._create_tuple_proto(num_outputs=1)
        tensor = tf2xla.sharding(tensor, sharding=proto.SerializeToString())
      else:
        tensor = tf2xla.sharding(
            tensor,
            sharding=proto.SerializeToString(),
            unspecified_dims=unspecified_dims or [])
    elif assign_tuple_sharding or len(tensor.op.outputs) > 1:
      proto = self._get_or_create_tuple_proto(tensor.op)
      # We can't mutate an element of the existing proto's tuple_shardings, so
      # create a new proto.
      tuple_shardings = list(proto.tuple_shardings)
      tuple_shardings[tensor.value_index] = self._proto
      proto = xla_data_pb2.OpSharding(
          type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=tuple_shardings)

    # TODO(jmolloy): This needs to be seriously revisited before declaring this
    # API available for public use.
    # pylint: disable=protected-access
    tensor.op._set_attr('_XlaSharding',
                        attr_value_pb2.AttrValue(s=proto.SerializeToString()))
    return tensor

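  # Usage sketch (illustrative; `tensor` stands for any graph-mode tf.Tensor):
  #
  #   # Annotate the producing op in place.
  #   Sharding.replicate().apply_to_tensor(tensor)
  #
  #   # Or insert an explicit sharding op so the producer is left untouched.
  #   tensor = Sharding.assign_device(0).apply_to_tensor(
  #       tensor, use_sharding_op=True)
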
  def apply_to_operation(self, operation):
    """Applies this Sharding attribute to `operation`.

    Args:
      operation: A tf.Operation to annotate with this sharding.
    """
    attr_value = attr_value_pb2.AttrValue(s=self._proto.SerializeToString())
    # pylint: disable=protected-access
    operation._set_attr('_XlaSharding', attr_value)

  @property
  def proto(self):
    """Return the sharding protobuf of type xla_data_pb2.OpSharding."""
    return self._proto

  def _get_or_create_tuple_proto(self, op):
    try:
      attr = op.get_attr('_XlaSharding')
      proto = xla_data_pb2.OpSharding()
      proto.ParseFromString(attr)
      return proto
    except ValueError:
      return self._create_tuple_proto(len(op.outputs))

  def _create_tuple_proto(self, num_outputs):
    shardings = [
        xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED)
    ] * num_outputs
    return xla_data_pb2.OpSharding(
        type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=shardings)


def copy_sharding(from_tensor, to_tensor, use_sharding_op=False):
  """Copies a tensor's sharding to another tensor.

  Args:
    from_tensor: Source tensor. Must be the sole output of an op.
    to_tensor: the tensor to annotate with the copy.
    use_sharding_op: whether to create a sharding op on `to_tensor`.

  Returns:
    A tensor with sharding annotation copied from `from_tensor`.
  """
  sharding = get_tensor_sharding(from_tensor)
  if sharding is None:
    return to_tensor

  if use_sharding_op:
    to_tensor = tf2xla.sharding(to_tensor, sharding=sharding)
  attr_value = attr_value_pb2.AttrValue(s=sharding)
  # pylint: disable=protected-access
  to_tensor.op._set_attr('_XlaSharding', attr_value)
  return to_tensor

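# Usage sketch (illustrative; `x` and `y` stand for graph-mode tf.Tensors):
#
#   Sharding.replicate().apply_to_tensor(x)
#   y = copy_sharding(x, y, use_sharding_op=True)
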
# Helpers for the above factory functions that allow easy application of
# shardings, for example:
#   tensor = xla_sharding.replicate(tensor)


def replicate(tensor, assign_tuple_sharding=False, use_sharding_op=False):
  return Sharding.replicate().apply_to_tensor(
      tensor,
      assign_tuple_sharding=assign_tuple_sharding,
      use_sharding_op=use_sharding_op)


def assign_device(tensor,
                  device,
                  assign_tuple_sharding=False,
                  use_sharding_op=False):
  """Returns a tensor that has AssignDevice sharding attribute."""
  return Sharding.assign_device(device).apply_to_tensor(
      tensor,
      assign_tuple_sharding=assign_tuple_sharding,
      use_sharding_op=use_sharding_op)


def tile(tensor,
         tile_assignment,
         assign_tuple_sharding=False,
         use_sharding_op=False,
         unspecified_dims=None):
  """Returns a tensor that has tiled sharding.

  Args:
    tensor: A tf.Tensor to shard.
    tile_assignment: An np.ndarray describing the topology of the tiling and
      which device will compute which part of the topology.
    assign_tuple_sharding: If the sharding type should be a tuple.
    use_sharding_op: If true, adds a sharding op to set the sharding.
    unspecified_dims: An optional list of dimensions whose sharding is left
      unspecified.
  """
  return Sharding.tile(tile_assignment).apply_to_tensor(
      tensor,
      assign_tuple_sharding=assign_tuple_sharding,
      use_sharding_op=use_sharding_op,
      unspecified_dims=unspecified_dims or [])

347
348def split(tensor,
349          split_dimension,
350          num_devices,
351          assign_tuple_sharding=False,
352          use_sharding_op=False,
353          input_shape=None):
354  """Returns a tensor that is split along the given dimension.
355
356  Args:
357    tensor: A tf.Tensor to split.
358    split_dimension: The dimension to split.
359    num_devices: The number of devices to partition the dimension.
360    assign_tuple_sharding: If the sharding type should be a tuple.
361    use_sharding_op: If true, adds a sharding op to set the sharding.
362    input_shape: The full shape of the input tensor.
363  """
364  return Sharding.split(tensor, split_dimension, num_devices,
365                        input_shape).apply_to_tensor(
366                            tensor,
367                            assign_tuple_sharding=assign_tuple_sharding,
368                            use_sharding_op=use_sharding_op)
369
370
371def partial_tile(tensor,
372                 tile_assignment,
373                 use_sharding_op=False,
374                 unspecified_dims=None):
375  """Returns a tensor that has tiled sharding.
376
377  Args:
378    tensor: A tf.Tensor to shard.
379    tile_assignment: An np.ndarray describing the topology of the tiling and
380      which device will compute which part of the topology. It must have one
381      more dimension than tensor, and the last dimension represents partially
382      replicated tiles.
383    use_sharding_op: If true, adds a sharding op to set the sharding.
384    unspecified_dims: An optional list of dimensions unspecified.
385  """
386  return Sharding.partial_tile(tile_assignment).apply_to_tensor(
387      tensor,
388      use_sharding_op=use_sharding_op,
389      unspecified_dims=unspecified_dims or [])
390

def get_op_sharding(op):
  """Returns sharding attribute of an op.

  Args:
    op: a TensorFlow op.

  Returns:
    The attribute representing XLA sharding on this op.
  """
  try:
    return op.get_attr('_XlaSharding')
  except ValueError:
    return None
  except AttributeError:
    # AttributeError: 'DistributedVarOp' object has no attribute 'get_attr'.
    return None


def get_tensor_sharding(tensor):
  """Returns sharding attribute of a Tensor.

  Args:
    tensor: a Tensor.

  Returns:
    The attribute representing XLA sharding on tensor's op.
  """
  try:
    return get_op_sharding(tensor.op)
  except AttributeError:
    # AttributeError: Tensor.op is meaningless when eager execution is enabled.
    return None


def get_sharding_tile_shape(sharding):
  """Returns the tile assignment shape for a sharded Tensor.

  Args:
    sharding: a serialized OpSharding message describing the layout of a
      sharded Tensor.

  Returns:
    A list, for each dimension of the sharded Tensor, of the number of shards
      into which it has been split. Returns None if the input indicates no tile
      assignments.
  """
  if sharding is None:
    return None
  sharding_message = xla_data_pb2.OpSharding()
  sharding_message.ParseFromString(sharding)
  if sharding_message.tile_assignment_dimensions:
    return sharding_message.tile_assignment_dimensions
  else:
    return None

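# Usage sketch (illustrative; `x` stands for a graph-mode tf.Tensor previously
# annotated, e.g. via split()):
#
#   serialized = get_tensor_sharding(x)  # serialized OpSharding bytes or None
#   tile_shape = get_sharding_tile_shape(serialized)
#   # e.g. [4, 1] for a rank-2 tensor split 4 ways along dimension 0.
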

def auto_to_manual_spmd_partition(tensor,
                                  manual_sharding,
                                  single_dim=-1,
                                  unspecified_dims=None):
  """Switches from automatic SPMD partitioning to manual partitioning.

  Converts a full-shaped tensor (to be automatically partitioned by SPMD
  partitioner) to a shard-shaped tensor to be consumed by manually partitioned
  ops.

  Args:
    tensor: A tf.Tensor in full shape.
    manual_sharding: A serialized string of OpSharding to be used in manual
      partitioning.
    single_dim: If >= 0, the conversion will happen only on this dim in
      subgroups.
    unspecified_dims: An optional list of dimensions whose sharding is left
      unspecified.

  Returns:
    A shard-shaped tensor to be consumed by manually partitioned ops.
  """
  return tf2xla.spmd_full_to_shard_shape(
      tensor,
      manual_sharding=manual_sharding,
      dim=single_dim,
      unspecified_dims=unspecified_dims or [])


def manual_to_auto_spmd_partition(tensor,
                                  manual_sharding,
                                  full_shape,
                                  single_dim=-1,
                                  unspecified_dims=None):
  """Switches from manual partitioning to automatic SPMD partitioning.

  Converts a shard-shaped tensor (manually partitioned in SPMD-style) to a
  full-shaped tensor to be partitioned automatically by the SPMD partitioner.

  Args:
    tensor: A tf.Tensor in shard shape.
    manual_sharding: a serialized string of OpSharding to be used in manual
      partitioning.
    full_shape: the shape of tensor before partitioning.
    single_dim: If >= 0, the conversion will happen only on this dim in
      subgroups.
    unspecified_dims: An optional list of dimensions whose sharding is left
      unspecified.

  Returns:
    A full-shaped tensor to be partitioned automatically by the SPMD
    partitioner.
  """
  return tf2xla.spmd_shard_to_full_shape(
      tensor,
      manual_sharding=manual_sharding,
      full_shape=full_shape,
      dim=single_dim,
      unspecified_dims=unspecified_dims or [])

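# Usage sketch (illustrative; `x`, the 4-way split, and `manual_op` are
# assumptions, not part of this module): round-trip between automatic and
# manual partitioning.
#
#   serialized = Sharding.split(x, split_dimension=0,
#                               num_devices=4).proto.SerializeToString()
#   shard = auto_to_manual_spmd_partition(x, serialized)
#   shard = manual_op(shard)  # user code operating on per-device shards
#   x = manual_to_auto_spmd_partition(shard, serialized, x.shape.as_list())
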

def mesh_split_sharding(device_mesh,
                        tensor_split_dims_mapping,
                        manual_mesh_dims=None):
  """Returns a Sharding object representing sharding along multiple dimensions.

  Args:
    device_mesh: An np.ndarray describing the topology of the device mesh and
      each element is the ID of the device in the topology.
    tensor_split_dims_mapping: A list of integers that map each tensor axis to
      the device mesh axis along which it is sharded. Its length is the tensor
      rank, and tensor_split_dims_mapping[i] is the device mesh axis for tensor
      dimension i. Use -1 for tensor dimensions that are not sharded.
    manual_mesh_dims: An optional list of mesh dims for manual subgroups.

  Raises:
    ValueError: The number of tensor split dimensions is larger than device mesh
      rank.
  """
  manual_mesh_dims = manual_mesh_dims or []
  permutation = [d for d in tensor_split_dims_mapping if d >= 0
                ] + manual_mesh_dims
  if len(permutation) > len(device_mesh.shape):
    raise ValueError(
        'Number of tensor split dimensions (%r) is larger than device mesh '
        'rank (%r). tensor_split_dims_mapping: %r, device_mesh.shape: %r' %
        (len(permutation), len(
            device_mesh.shape), tensor_split_dims_mapping, device_mesh.shape))
  # Append replicated dimensions to the end.
  transpose_permutation = permutation + [
      d for d in range(len(device_mesh.shape)) if d not in permutation
  ]
  tile_assignment = _np.transpose(device_mesh, transpose_permutation)
  tile_shape = [
      1 if d < 0 else device_mesh.shape[d]
      for d in (tensor_split_dims_mapping + manual_mesh_dims)
  ]
  subgroup_modes = [xla_data_pb2.OpSharding.MANUAL] * len(manual_mesh_dims)
  partial = len(permutation) < len(device_mesh.shape)
  if partial:
    tile_shape.append(_np.prod(device_mesh.shape) // _np.prod(tile_shape))
    subgroup_modes.append(xla_data_pb2.OpSharding.REPLICATED)
  tile_assignment = _np.reshape(tile_assignment, tile_shape)

  if manual_mesh_dims:
    return Sharding.subgroup_tile(tile_assignment, subgroup_modes)

  if partial:
    return Sharding.partial_tile(tile_assignment)
  return Sharding.tile(tile_assignment)

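# Usage sketch (illustrative; the 2x4 mesh is an assumption): build a sharding
# for a rank-2 tensor whose dimension 0 is split along mesh axis 0 and whose
# dimension 1 is replicated.
#
#   device_mesh = _np.arange(8).reshape(2, 4)
#   sharding = mesh_split_sharding(device_mesh,
#                                  tensor_split_dims_mapping=[0, -1])
#   # Equivalent to Sharding.partial_tile with a (2, 1, 4) tile_assignment.
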

def mesh_split(tensor,
               device_mesh,
               tensor_split_dims_mapping,
               use_sharding_op=False,
               manual_mesh_dims=None,
               unspecified_dims=None):
  """Returns a tensor that is split along multiple dimensions in a device mesh.

  Args:
    tensor: A tf.Tensor to split.
    device_mesh: An np.ndarray describing the topology of the device mesh and
      each element is the ID of the device in the topology.
    tensor_split_dims_mapping: A list of integers that map each tensor axis to
      the device mesh axis along which it is sharded. Its length is the tensor
      rank, and tensor_split_dims_mapping[i] is the device mesh axis for tensor
      dimension i. Use -1 for tensor dimensions that are not sharded.
    use_sharding_op: If true, adds a sharding op to set the sharding.
    manual_mesh_dims: An optional list of mesh dims for manual subgroups.
    unspecified_dims: An optional list of dimensions whose sharding is left
      unspecified.

  Raises:
    ValueError: The number of tensor split dimensions is larger than device mesh
      rank.
  """
  sharding = mesh_split_sharding(device_mesh, tensor_split_dims_mapping,
                                 manual_mesh_dims)
  return sharding.apply_to_tensor(
      tensor,
      use_sharding_op=use_sharding_op,
      unspecified_dims=unspecified_dims or [])

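
# Usage sketch (illustrative; the 2x4 mesh and graph-mode tensor `x` are
# assumptions): shard dimension 0 of `x` over mesh axis 1 (4 ways) and
# dimension 1 over mesh axis 0 (2 ways).
#
#   device_mesh = _np.arange(8).reshape(2, 4)
#   x = mesh_split(x, device_mesh, tensor_split_dims_mapping=[1, 0],
#                  use_sharding_op=True)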