# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental support for defining XLA shardings."""

import numpy as _np  # Avoids becoming a part of public Tensorflow API.

from tensorflow.compiler.tf2xla.python import xla as tf2xla
from tensorflow.compiler.xla import xla_data_pb2
from tensorflow.core.framework import attr_value_pb2


class Sharding(object):
  """A class to support adding sharding attributes to Ops.

  Use the factory constructors and then call apply_to_tensor:
    Sharding.replicate().apply_to_tensor(tensor)
  """

  def __init__(self, proto=None):
    """Do not use this constructor; use the factory functions below."""
    self._proto = proto

  @classmethod
  def replicate(cls):
    """Returns a replicated sharding attribute.

    This causes an op to be computed in its entirety independently on all
    cores in the XLA device.
    """
    return Sharding(
        proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED))

  @classmethod
  def manual(cls):
    """Returns a manual sharding attribute.

    This means the op is manually partitioned by the user and XLA will not
    change the shapes.
    """
    return Sharding(
        proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.MANUAL))

  @classmethod
  def assign_device(cls, core):
    """Returns an AssignDevice sharding attribute.

    This causes an op to be computed in its entirety only on one core in
    the XLA device.

    Args:
      core: The core to assign this Op to.
    """
    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.MAXIMAL,
            tile_assignment_dimensions=[1],
            tile_assignment_devices=[core]))

  @classmethod
  def tile(cls, tile_assignment):
    """Returns a Tiled sharding attribute.

    This causes an op to be partially computed on multiple cores in the
    XLA device.

    Args:
      tile_assignment: An np.ndarray describing the topology of the tiling and
        which device will compute which part of the topology.

    Raises:
      TypeError: tile_assignment was not of np.array type.

    TODO(jmolloy): This concept is nefarious and is not
    something we really want to expose to users (especially as the
    contract for tile_assignment is very strict).
    """
    if not isinstance(tile_assignment, _np.ndarray):
      raise TypeError('Tile assignment must be of type np.ndarray')
    dims = list(tile_assignment.shape)
    flattened_devices = tile_assignment.reshape(-1, order='C')
    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.OTHER,
            tile_assignment_dimensions=dims,
            tile_assignment_devices=list(flattened_devices)))
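
  # Illustrative sketch (assumed usage, not part of the original comments):
  # tiling a rank-2 tensor over a 2x2 grid of logical devices 0..3, where
  # `tensor` stands for any tf.Tensor. The shape of the assignment array
  # gives the tile grid; its values are device ids.
  #
  #   assignment = _np.arange(4).reshape([2, 2])
  #   Sharding.tile(assignment).apply_to_tensor(tensor, use_sharding_op=True)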

  @classmethod
  def subgroup_tile(cls, tile_assignment, subgroup_modes):
    """Returns a subgroup manual sharding attribute.

    This is similar to tile(), but tile_assignment has one or more dimensions
    more than the tensor, and subgroup_modes defines the sharding types in the
    last dimensions of tile_assignment.

    Args:
      tile_assignment: An np.ndarray describing the topology of the tiling and
        which device will compute which part of the topology.
      subgroup_modes: Sharding types for the dimensions beyond the tensor
        shape rank.

    Raises:
      TypeError: tile_assignment was not of np.array type or subgroup_modes
        has unsupported sharding type.
    """
    if not isinstance(tile_assignment, _np.ndarray):
      raise TypeError('SubgroupTile assignment must be of type np.ndarray')

    if not isinstance(subgroup_modes, list):
      raise TypeError('subgroup_modes in subgroup manual must be of type list')

    if len(tile_assignment.shape) < len(subgroup_modes):
      raise TypeError('SubgroupTile assignment must have rank larger than'
                      ' length of subgroup_modes')

    for sharding_type in subgroup_modes:
      if sharding_type not in [
          xla_data_pb2.OpSharding.REPLICATED, xla_data_pb2.OpSharding.MANUAL
      ]:
        raise TypeError(
            'Each sharding_type in subgroup_modes in subgroup manual must '
            'be of type xla_data_pb2.OpSharding.REPLICATED'
            ' or xla_data_pb2.OpSharding.MANUAL')
    dims = list(tile_assignment.shape)
    flattened_devices = tile_assignment.reshape(-1, order='C')
    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.OTHER,
            tile_assignment_dimensions=dims,
            tile_assignment_devices=list(flattened_devices),
            last_tile_dims=list(subgroup_modes)))

  @classmethod
  def partial_tile(cls, tile_assignment):
    """Returns a partially tiled sharding attribute.

    This is similar to tile(), but tile_assignment has one more dimension than
    the tensor, and tiles in the last dimension of tile_assignment are
    replicated.

    Args:
      tile_assignment: An np.ndarray describing the topology of the tiling and
        which device will compute which part of the topology.

    Raises:
      TypeError: tile_assignment was not of np.array type.
    """
    if not isinstance(tile_assignment, _np.ndarray):
      raise TypeError('PartialTile assignment must be of type np.ndarray')
    dims = list(tile_assignment.shape)
    flattened_devices = tile_assignment.reshape(-1, order='C')
    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.OTHER,
            tile_assignment_dimensions=dims,
            tile_assignment_devices=list(flattened_devices),
            replicate_on_last_tile_dim=True))
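
  # Illustrative sketch (assumed usage): partially tiling over 8 devices,
  # split 4 ways along a rank-1 tensor's only dimension, with each tile
  # replicated on 2 devices. The trailing axis of the assignment holds the
  # replication group; `tensor` is a placeholder tf.Tensor.
  #
  #   assignment = _np.arange(8).reshape([4, 2])
  #   Sharding.partial_tile(assignment).apply_to_tensor(
  #       tensor, use_sharding_op=True)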

  @classmethod
  def split(cls, tensor, split_dimension, num_devices, input_shape=None):
    """Returns a Sharding that splits a tensor across a dimension.

    This creates a Tiled attribute, similar to tile(), but easier to use for
    the common case of tiling a tensor N ways in one dimension.

    Args:
      tensor: A tf.Tensor to split.
      split_dimension: The dimension number to split.
      num_devices: The number of cores to split `tensor` over.
      input_shape: The shape of the original tensor.

    Raises:
      ValueError: The tensor to split was smaller in the split dimension than
        the number of devices to split over.
    """
    if input_shape:
      shape = input_shape
    else:
      shape = tensor.shape.as_list()
    if (shape[split_dimension] is not None and
        shape[split_dimension] < num_devices):
      raise ValueError('Split dimension was smaller than the required number '
                       'of splits: shape=%r, dimension=%r, num_devices=%r' %
                       (shape, split_dimension, num_devices))

    tile_assignment_dims = [1] * len(shape)
    tile_assignment_dims[split_dimension] = num_devices

    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.OTHER,
            tile_assignment_dimensions=tile_assignment_dims,
            tile_assignment_devices=range(num_devices)))

  def apply_to_tensor(self,
                      tensor,
                      assign_tuple_sharding=False,
                      use_sharding_op=False,
                      unspecified_dims=None):
    """Applies this Sharding attribute to `tensor`.

    Args:
      tensor: A tf.Tensor to annotate with this sharding.
      assign_tuple_sharding: Whether the sharding type should be a tuple.
      use_sharding_op: Whether to create a sharding op on `tensor`.
      unspecified_dims: An optional list of dimensions whose sharding is left
        unspecified.

    Returns:
      The tensor with the Sharding attribute.
    """
    if unspecified_dims:
      assert use_sharding_op and not assign_tuple_sharding
    proto = self._proto
    if use_sharding_op:
      if assign_tuple_sharding:
        proto = self._create_tuple_proto(num_outputs=1)
        tensor = tf2xla.sharding(tensor, sharding=proto.SerializeToString())
      else:
        tensor = tf2xla.sharding(
            tensor,
            sharding=proto.SerializeToString(),
            unspecified_dims=unspecified_dims or [])
    elif assign_tuple_sharding or len(tensor.op.outputs) > 1:
      proto = self._get_or_create_tuple_proto(tensor.op)
      # We can't mutate an element of old_proto.tuple_shardings, so create
      # a new proto.
      tuple_shardings = list(proto.tuple_shardings)
      tuple_shardings[tensor.value_index] = self._proto
      proto = xla_data_pb2.OpSharding(
          type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=tuple_shardings)

    # TODO(jmolloy): This needs to be seriously revisited before declaring this
    # API available for public use.
    # pylint: disable=protected-access
    tensor.op._set_attr('_XlaSharding',
                        attr_value_pb2.AttrValue(s=proto.SerializeToString()))
    return tensor

  def apply_to_operation(self, operation):
    """Applies this Sharding attribute to `operation`.

    Args:
      operation: A tf.Operation to add the sharding annotation to.
    """
    attr_value = attr_value_pb2.AttrValue(s=self._proto.SerializeToString())
    # pylint: disable=protected-access
    operation._set_attr('_XlaSharding', attr_value)

  @property
  def proto(self):
    """Returns the sharding protobuf of type xla_data_pb2.OpSharding."""
    return self._proto

  def _get_or_create_tuple_proto(self, op):
    try:
      attr = op.get_attr('_XlaSharding')
      proto = xla_data_pb2.OpSharding()
      proto.ParseFromString(attr)
      return proto
    except ValueError:
      return self._create_tuple_proto(len(op.outputs))

  def _create_tuple_proto(self, num_outputs):
    shardings = [
        xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED)
    ] * num_outputs
    return xla_data_pb2.OpSharding(
        type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=shardings)
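
# Illustrative sketch (assumed usage, `some_tensor` is a placeholder for any
# tf.Tensor): the factory constructors above are combined with
# apply_to_tensor, e.g. to pin an op to core 0 or to replicate it:
#
#   some_tensor = Sharding.assign_device(0).apply_to_tensor(some_tensor)
#   some_tensor = Sharding.replicate().apply_to_tensor(
#       some_tensor, use_sharding_op=True)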


def copy_sharding(from_tensor, to_tensor, use_sharding_op=False):
  """Copies a tensor's sharding to another tensor.

  Args:
    from_tensor: Source tensor. Must be the sole output of an op.
    to_tensor: The tensor to annotate with the copied sharding.
    use_sharding_op: Whether to create a sharding op on `to_tensor`.

  Returns:
    A tensor with sharding annotation copied from `from_tensor`.
  """
  sharding = get_tensor_sharding(from_tensor)
  if sharding is None:
    return to_tensor

  if use_sharding_op:
    to_tensor = tf2xla.sharding(to_tensor, sharding=sharding)
  attr_value = attr_value_pb2.AttrValue(s=sharding)
  # pylint: disable=protected-access
  to_tensor.op._set_attr('_XlaSharding', attr_value)
  return to_tensor

# Helpers for the above factory functions that allow easy application of
# shardings, for example:
#   tensor = xla_sharding.replicate(tensor)


def replicate(tensor, assign_tuple_sharding=False, use_sharding_op=False):
  return Sharding.replicate().apply_to_tensor(
      tensor,
      assign_tuple_sharding=assign_tuple_sharding,
      use_sharding_op=use_sharding_op)


def assign_device(tensor,
                  device,
                  assign_tuple_sharding=False,
                  use_sharding_op=False):
  """Returns a tensor that has AssignDevice sharding attribute."""
  return Sharding.assign_device(device).apply_to_tensor(
      tensor,
      assign_tuple_sharding=assign_tuple_sharding,
      use_sharding_op=use_sharding_op)


def tile(tensor,
         tile_assignment,
         assign_tuple_sharding=False,
         use_sharding_op=False,
         unspecified_dims=None):
  """Returns a tensor that has tiled sharding.

  Args:
    tensor: A tf.Tensor to shard.
    tile_assignment: An np.ndarray describing the topology of the tiling and
      which device will compute which part of the topology.
    assign_tuple_sharding: Whether the sharding type should be a tuple.
    use_sharding_op: If true, adds a sharding op to set the sharding.
    unspecified_dims: An optional list of dimensions whose sharding is left
      unspecified.
  """
  return Sharding.tile(tile_assignment).apply_to_tensor(
      tensor,
      assign_tuple_sharding=assign_tuple_sharding,
      use_sharding_op=use_sharding_op,
      unspecified_dims=unspecified_dims or [])


def split(tensor,
          split_dimension,
          num_devices,
          assign_tuple_sharding=False,
          use_sharding_op=False,
          input_shape=None):
  """Returns a tensor that is split along the given dimension.

  Args:
    tensor: A tf.Tensor to split.
    split_dimension: The dimension to split.
    num_devices: The number of devices to partition the dimension over.
    assign_tuple_sharding: Whether the sharding type should be a tuple.
    use_sharding_op: If true, adds a sharding op to set the sharding.
    input_shape: The full shape of the input tensor.
  """
  return Sharding.split(tensor, split_dimension, num_devices,
                        input_shape).apply_to_tensor(
                            tensor,
                            assign_tuple_sharding=assign_tuple_sharding,
                            use_sharding_op=use_sharding_op)
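
# Illustrative sketch (assumed usage): splitting a [batch, features] tensor
# named `some_tensor` across 4 devices along its batch dimension with the
# helper above.
#
#   sharded = split(some_tensor, split_dimension=0, num_devices=4,
#                   use_sharding_op=True)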


def partial_tile(tensor,
                 tile_assignment,
                 use_sharding_op=False,
                 unspecified_dims=None):
  """Returns a tensor that has partially tiled sharding.

  Args:
    tensor: A tf.Tensor to shard.
    tile_assignment: An np.ndarray describing the topology of the tiling and
      which device will compute which part of the topology. It must have one
      more dimension than tensor, and the last dimension represents partially
      replicated tiles.
    use_sharding_op: If true, adds a sharding op to set the sharding.
    unspecified_dims: An optional list of dimensions whose sharding is left
      unspecified.
  """
  return Sharding.partial_tile(tile_assignment).apply_to_tensor(
      tensor,
      use_sharding_op=use_sharding_op,
      unspecified_dims=unspecified_dims or [])


def get_op_sharding(op):
  """Returns the sharding attribute of an op.

  Args:
    op: a TensorFlow op.

  Returns:
    The attribute representing XLA sharding on this op.
  """
  try:
    return op.get_attr('_XlaSharding')
  except ValueError:
    return None
  except AttributeError:
    # AttributeError: 'DistributedVarOp' object has no attribute 'get_attr'.
    return None


def get_tensor_sharding(tensor):
  """Returns the sharding attribute of a Tensor.

  Args:
    tensor: a Tensor.

  Returns:
    The attribute representing XLA sharding on the tensor's op.
  """
  try:
    return get_op_sharding(tensor.op)
  except AttributeError:
    # AttributeError: Tensor.op is meaningless when eager execution is enabled.
    return None


def get_sharding_tile_shape(sharding):
  """Returns the tile assignment shape for a sharded Tensor.

  Args:
    sharding: a serialized OpSharding message describing the layout of a
      sharded Tensor.

  Returns:
    A list, for each dimension of the sharded Tensor, of the number of shards
    into which it has been split. Returns None if the input indicates no tile
    assignments.
  """
  if sharding is None:
    return None
  sharding_message = xla_data_pb2.OpSharding()
  sharding_message.ParseFromString(sharding)
  if sharding_message.tile_assignment_dimensions:
    return sharding_message.tile_assignment_dimensions
  else:
    return None


def auto_to_manual_spmd_partition(tensor,
                                  manual_sharding,
                                  single_dim=-1,
                                  unspecified_dims=None):
  """Switches from automatic SPMD partitioning to manual partitioning.

  Converts a full-shaped tensor (to be automatically partitioned by the SPMD
  partitioner) to a shard-shaped tensor to be consumed by manually partitioned
  ops.

  Args:
    tensor: A tf.Tensor in full shape.
    manual_sharding: A serialized string of OpSharding to be used in manual
      partitioning.
    single_dim: If >= 0, the conversion will happen only on this dim in
      subgroups.
    unspecified_dims: An optional list of dimensions whose sharding is left
      unspecified.

  Returns:
    A shard-shaped tensor to be consumed by manually partitioned ops.
  """
  return tf2xla.spmd_full_to_shard_shape(
      tensor,
      manual_sharding=manual_sharding,
      dim=single_dim,
      unspecified_dims=unspecified_dims or [])
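
# Illustrative sketch (assumed usage): wrapping a manually partitioned
# computation. `full_input` and `shard_result` are placeholder tf.Tensors and
# `sharding_proto` is a serialized OpSharding string, e.g. obtained from
# Sharding.tile(...).proto.SerializeToString().
#
#   shard = auto_to_manual_spmd_partition(full_input, sharding_proto)
#   ... manually partitioned ops on `shard`, producing `shard_result` ...
#   full_output = manual_to_auto_spmd_partition(
#       shard_result, sharding_proto, full_shape=full_input.shape)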


def manual_to_auto_spmd_partition(tensor,
                                  manual_sharding,
                                  full_shape,
                                  single_dim=-1,
                                  unspecified_dims=None):
  """Switches from manual partitioning to automatic SPMD partitioning.

  Converts a shard-shaped tensor (manually partitioned in SPMD-style) to a
  full-shaped tensor to be partitioned automatically by the SPMD partitioner.

  Args:
    tensor: A tf.Tensor in shard shape.
    manual_sharding: A serialized string of OpSharding to be used in manual
      partitioning.
    full_shape: The shape of the tensor before partitioning.
    single_dim: If >= 0, the conversion will happen only on this dim in
      subgroups.
    unspecified_dims: An optional list of dimensions whose sharding is left
      unspecified.

  Returns:
    A full-shaped tensor to be partitioned automatically by the SPMD
    partitioner.
  """
  return tf2xla.spmd_shard_to_full_shape(
      tensor,
      manual_sharding=manual_sharding,
      full_shape=full_shape,
      dim=single_dim,
      unspecified_dims=unspecified_dims or [])


def mesh_split_sharding(device_mesh,
                        tensor_split_dims_mapping,
                        manual_mesh_dims=None):
  """Returns a Sharding object representing sharding along multiple dimensions.

  Args:
    device_mesh: An np.ndarray describing the topology of the device mesh and
      each element is the ID of the device in the topology.
    tensor_split_dims_mapping: A list of integers that map each tensor axis to
      the device mesh axis along which it is sharded. Its length is the tensor
      rank, and tensor_split_dims_mapping[i] is the device mesh axis for tensor
      dimension i. Use -1 for tensor dimensions that are not sharded.
    manual_mesh_dims: An optional list of mesh dims for manual subgroups.

  Raises:
    ValueError: The number of tensor split dimensions is larger than the
      device mesh rank.
  """
  manual_mesh_dims = manual_mesh_dims or []
  permutation = [d for d in tensor_split_dims_mapping if d >= 0
                ] + manual_mesh_dims
  if len(permutation) > len(device_mesh.shape):
    raise ValueError(
        'Number of tensor split dimensions (%r) is larger than device mesh '
        'rank (%r). tensor_split_dims_mapping: %r, device_mesh.shape: %r' %
        (len(permutation), len(
            device_mesh.shape), tensor_split_dims_mapping, device_mesh.shape))
  # Append replicated dimensions to the end.
  transpose_permutation = permutation + [
      d for d in range(len(device_mesh.shape)) if d not in permutation
  ]
  tile_assignment = _np.transpose(device_mesh, transpose_permutation)
  tile_shape = [
      1 if d < 0 else device_mesh.shape[d]
      for d in (tensor_split_dims_mapping + manual_mesh_dims)
  ]
  subgroup_modes = [xla_data_pb2.OpSharding.MANUAL] * len(manual_mesh_dims)
  partial = len(permutation) < len(device_mesh.shape)
  if partial:
    tile_shape.append(_np.prod(device_mesh.shape) // _np.prod(tile_shape))
    subgroup_modes.append(xla_data_pb2.OpSharding.REPLICATED)
  tile_assignment = _np.reshape(tile_assignment, tile_shape)

  if manual_mesh_dims:
    return Sharding.subgroup_tile(tile_assignment, subgroup_modes)

  if partial:
    return Sharding.partial_tile(tile_assignment)
  return Sharding.tile(tile_assignment)
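
# Illustrative sketch (assumed usage): a 2x4 device mesh where a rank-2
# tensor is split 2 ways on mesh axis 0 along its first dimension and left
# unsharded (hence partially replicated over mesh axis 1) in its second
# dimension.
#
#   mesh = _np.arange(8).reshape([2, 4])
#   sharding = mesh_split_sharding(mesh, tensor_split_dims_mapping=[0, -1])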


def mesh_split(tensor,
               device_mesh,
               tensor_split_dims_mapping,
               use_sharding_op=False,
               manual_mesh_dims=None,
               unspecified_dims=None):
  """Returns a tensor that is split along multiple dimensions in a device mesh.

  Args:
    tensor: A tf.Tensor to split.
    device_mesh: An np.ndarray describing the topology of the device mesh and
      each element is the ID of the device in the topology.
    tensor_split_dims_mapping: A list of integers that map each tensor axis to
      the device mesh axis along which it is sharded. Its length is the tensor
      rank, and tensor_split_dims_mapping[i] is the device mesh axis for tensor
      dimension i. Use -1 for tensor dimensions that are not sharded.
    use_sharding_op: If true, adds a sharding op to set the sharding.
    manual_mesh_dims: An optional list of mesh dims for manual subgroups.
    unspecified_dims: An optional list of dimensions whose sharding is left
      unspecified.

  Raises:
    ValueError: The number of tensor split dimensions is larger than the
      device mesh rank.
  """
  sharding = mesh_split_sharding(device_mesh, tensor_split_dims_mapping,
                                 manual_mesh_dims)
  return sharding.apply_to_tensor(
      tensor,
      use_sharding_op=use_sharding_op,
      unspecified_dims=unspecified_dims or [])
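
# Illustrative sketch (assumed usage, `activations` is a placeholder
# tf.Tensor of shape [batch, hidden]): annotating it for a 4x2 mesh so that
# the batch dimension is split over mesh axis 0 and the hidden dimension
# over mesh axis 1.
#
#   mesh = _np.arange(8).reshape([4, 2])
#   activations = mesh_split(
#       activations, mesh, [0, 1], use_sharding_op=True)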