/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// HLO shardings describe how an HLO instruction is split across multiple
// computations.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_

#include <map>
#include <string>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {

// HLO shardings describe how an HLO instruction is split across multiple
// computations.
class HloSharding {
 public:
  // Creates a trivial sharding that replicates a maximal tile across all
  // devices.
  static HloSharding Replicate(absl::Span<const OpMetadata> metadata = {}) {
    return HloSharding(/*manual=*/false, /*replicated=*/true, metadata);
  }

  // Creates a sharding that represents an op that is manually partitioned.
  static HloSharding Manual(absl::Span<const OpMetadata> metadata = {}) {
    return HloSharding(/*manual=*/true, /*replicated=*/false, metadata);
  }

  // Creates a sharding that emulates device placement: a tile shape equal to
  // the input shape (one tile) assigned to a single device.
  static HloSharding AssignDevice(int64_t device_id,
                                  absl::Span<const OpMetadata> metadata = {});

  // Creates a new sharding which splits a shape into tiles amongst the
  // devices specified by `tile_assignment`.
  static HloSharding Tile(const Array<int64_t>& tile_assignment,
                          absl::Span<const OpMetadata> metadata = {}) {
    return HloSharding(tile_assignment, /*replicate_on_last_tile_dim=*/false,
                       metadata);
  }

  // Creates a new sharding where data is replicated within each replication
  // group, and sharded across replication groups according to
  // `group_tile_assignment`. Replication group members will be sorted.
  static HloSharding PartialTile(
      const Array<int64_t>& group_tile_assignment,
      absl::Span<const absl::Span<const int64_t>> replication_groups,
      absl::Span<const OpMetadata> metadata = {});

  // Creates a partially replicated tiled sharding with device-level tile
  // assignment, where the last dimension is the additional replication
  // dimension. Replication group members will be sorted.
  static HloSharding PartialTile(
      const Array<int64_t>& tile_assignment_last_dim_replicate,
      absl::Span<const OpMetadata> metadata = {});
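
  // A minimal usage sketch of the factories above (illustrative only, not
  // part of this header's API; assumes 4 devices and xla::Array's FillIota):
  //
  //   Array<int64_t> assignment({2, 2});
  //   assignment.FillIota(0);  // devices 0,1,2,3 in row-major order
  //   // Split dimensions 0 and 1 two ways each across devices 0..3.
  //   HloSharding tiled = HloSharding::Tile(assignment);
  //   // Split dimension 0 two ways; the last dimension of the assignment
  //   // groups replicas, so devices {0,1} and {2,3} each share one tile.
  //   HloSharding partial = HloSharding::PartialTile(assignment);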
  // Creates a subgroup sharding with device-level tile assignment; the
  // sharding type of each subgroup is defined by `subgroup_types`. When
  // creating the HloSharding, subgroup dims of the same type will be merged.
  static HloSharding Subgroup(const Array<int64_t>& tile_assignment,
                              absl::Span<const OpSharding::Type> subgroup_types,
                              absl::Span<const OpMetadata> metadata = {});

  // Creates a new sharding which splits a one-dimensional input shape into
  // `num_tiles` tiles.
  static HloSharding Tile1D(const Shape& input_shape, int64_t num_tiles,
                            absl::Span<const OpMetadata> metadata = {});

  // Creates a new sharding for a tuple type. The given ShapeTree must have
  // elements for every leaf shape contained in the tuple.
  static HloSharding Tuple(const ShapeTree<HloSharding>& sub_shardings);

  // Creates a new sharding for a tuple type. The number of elements in
  // shardings must match the number of leaf nodes in tuple_shape. For
  // empty tuples, the shardings array must have one element.
  static HloSharding Tuple(const Shape& tuple_shape,
                           absl::Span<const HloSharding> shardings);

  // Creates a new sharding for a tuple type, with a single input sharding
  // repeated on each leaf.
  static HloSharding SingleTuple(const Shape& tuple_shape,
                                 const HloSharding& sharding);

  // If shape is an array, returns sharding, otherwise returns the tuple
  // shaped sharding with all the leaf nodes having the same input sharding.
  static HloSharding Single(const Shape& shape, const HloSharding& sharding);

  // Creates a new sharding from a protobuf OpSharding.
  static StatusOr<HloSharding> FromProto(const OpSharding& proto);

  // Checks whether device is a reserved device number. A reserved device
  // number usually has a special meaning, with dedicated handling logic.
  static bool IsReservedDevice(int64_t device) { return device < 0; }

  OpSharding ToProto() const;

  // Note that this string canonically has outer curly braces, e.g.
  // "{replicated}".
  std::string ToString(bool include_metadata = false) const;

  // Validates that this sharding can be applied to a tensor with shape
  // `shape`.
  Status Validate(const Shape& shape, int64_t num_devices) const;

  // Returns true if the sharding has tuple type.
  bool IsTuple() const { return tuple_; }

  // Returns true if the sharding is trivial: replicated on all devices.
  bool IsReplicated() const {
    if (!IsTuple()) {
      return replicated_;
    }
    return absl::c_all_of(
        tuple_elements_,
        [](const HloSharding& s) { return s.IsReplicated(); });
  }

  // Returns true if the tile size is the same as the input size.
  bool IsTileMaximal() const {
    if (!IsTuple()) {
      return maximal_;
    }
    return absl::c_all_of(tuple_elements_, [](const HloSharding& s) {
      return s.IsTileMaximal();
    });
  }

  // Returns whether the sharding represents manual partitioning.
  bool IsManual() const {
    if (!IsTuple()) {
      return manual_;
    }
    return absl::c_all_of(tuple_elements_,
                          [](const HloSharding& s) { return s.IsManual(); });
  }
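
  // A hedged sketch of tuple-sharding construction (illustrative; the element
  // shapes and the ShapeUtil helper are assumptions, not defined here):
  //
  //   Shape tuple_shape =
  //       ShapeUtil::MakeTupleShape({elem0_shape, elem1_shape});
  //   HloSharding tuple_sharding = HloSharding::Tuple(
  //       tuple_shape,
  //       {HloSharding::Replicate(), HloSharding::AssignDevice(0)});
  //   CHECK(tuple_sharding.IsTuple());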
  // Returns whether the sharding represents manual subgroup sharding.
  bool IsManualSubgroup() const {
    if (!IsTuple()) {
      return absl::c_linear_search(subgroup_types_, OpSharding::MANUAL);
    }
    return absl::c_all_of(tuple_elements_, [](const HloSharding& s) {
      return s.IsManualSubgroup();
    });
  }

  // Returns whether the sharding represents a tiled sharding where the
  // mapping between devices and tiles is represented through
  // 'tile_assignment()'.
  bool IsTiled() const { return !IsTileMaximal() && !IsManual(); }

  // Returns whether the sharding has partial replication and partial
  // sharding. If true, data is sharded according to other dimensions of
  // tile_assignment(), but replicated across devices along the last
  // dimension.
  bool ReplicateOnLastTileDim() const { return replicate_on_last_tile_dim_; }

  // Returns whether there is any partial replication. This can be using
  // ReplicateOnLastTileDim or subgroups with REPLICATED.
  bool HasPartialReplication() const {
    return replicate_on_last_tile_dim_ ||
           absl::c_linear_search(subgroup_types_, OpSharding::REPLICATED);
  }

  // Returns true if the sharding defines an operation on the given device.
  bool UsesDevice(int64_t device) const;

  // Retrieves a histogram of the devices used by the sharding. The returned
  // map has the device number as key and the occurrence count as value.
  // If a sharding does not have a device, it will not be included in the
  // histogram. The count argument, if not nullptr, will receive the total
  // number of elements this sharding is made of (one for array, N leaves for
  // tuples).
  std::map<int64_t, int64_t> UsedDevices(int64_t* count) const;

  // Returns the tile that should be executed on the given device.
  // REQUIRES: !IsTuple()
  std::vector<int64_t> TileIndexForDevice(int64_t device) const;

  // Returns the device that should execute the given tile.
  // It is an error to call this if is_replicated() is true.
  // When ReplicateOnLastTileDim() == true, if index.size() == data rank, it
  // returns the first device in that replicated subgroup; otherwise,
  // index.size() should be the same as tile_assignment()'s rank and specifies
  // the member of the replication subgroup.
  // REQUIRES: !IsTuple()
  int64_t DeviceForTileIndex(absl::Span<const int64_t> index) const;

  // Given a device ID, returns the offset within the specified shape of the
  // tile that should be executed on the given core. This returns the lower
  // extent of the tile in the input space.
  // REQUIRES: !IsTuple()
  std::vector<int64_t> TileOffsetForDevice(const Shape& shape,
                                           int64_t device) const;

  // Given a device ID, returns the limit within the specified shape of the
  // tile that should be executed on the given core. This returns the upper
  // extent of the tile in the input space.
  // REQUIRES: !IsTuple()
  std::vector<int64_t> TileLimitForDevice(const Shape& shape,
                                          int64_t device) const;
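
  // Illustrative sketch of the tile-query accessors above (assumes a f32[8]
  // shape sharded as {devices=[2]0,1}; the results follow directly from the
  // lower/upper-extent semantics documented above):
  //
  //   sharding.TileIndexForDevice(1);          // => {1}
  //   sharding.TileOffsetForDevice(shape, 1);  // => {4}
  //   sharding.TileLimitForDevice(shape, 1);   // => {8}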
  // Returns the single device this op operates on. If the sharding does not
  // span a single device, the return value will be empty.
  // In order for a sharding to span a single device, every leaf sharding must
  // be maximal and not replicated, and the used device must match.
  std::optional<int64_t> UniqueDevice() const;

  // Retrieves the unique device or fails with a CHECK.
  int64_t GetUniqueDevice() const;

  // Returns true if this op only uses a single device.
  bool HasUniqueDevice() const { return UniqueDevice().has_value(); }

  // Returns the ShapeTree containing the shardings for each element of this
  // tuple, if IsTuple, or a ShapeTree with a single element containing this
  // sharding. Only the leaf elements are populated. This creates a new
  // ShapeTree object, so it is not cheap.
  StatusOr<ShapeTree<HloSharding>> AsShapeTree(const Shape& shape) const;
  ShapeTree<HloSharding> GetAsShapeTree(const Shape& shape) const {
    return AsShapeTree(shape).ValueOrDie();
  }

  // Retrieves the sub sharding at a given index, out of a tuple sharding.
  // REQUIRES: IsTuple()
  HloSharding GetSubSharding(const Shape& shape,
                             const ShapeIndex& index) const;

  // If the current sharding is a tuple sharding, returns itself as result.
  // Otherwise returns a tuple sharding for the input shape, with all the
  // leaves having this object's sharding.
  StatusOr<HloSharding> GetTupleSharding(const Shape& shape) const;

  // Extracts the sharding that is common within the current sharding.
  // If the current sharding is not a tuple sharding, the current sharding
  // will be returned. If it is a tuple, and all the tuple elements are
  // common, the common element will be returned. Otherwise the optional will
  // contain no value.
  std::optional<HloSharding> ExtractSingleSharding() const;

  // Returns a copy of the sharding with no metadata. If sharding is of tuple
  // type, sub shardings will have no metadata.
  HloSharding WithoutMetadata() const;

  // Returns a copy of the sharding with the specified metadata. If metadata
  // is already present, it will not be replaced unless `overwrite` is set to
  // true. If sharding is of tuple type, the sub shardings' metadata will be
  // assigned instead.
  HloSharding WithMetadata(absl::Span<const OpMetadata> metadata,
                           bool overwrite) const;

  bool operator==(const HloSharding& other) const {
    return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
           manual_ == other.manual_ &&
           tile_assignment_ == other.tile_assignment_ &&
           tuple_elements_ == other.tuple_elements_ &&
           replicate_on_last_tile_dim_ == other.replicate_on_last_tile_dim_ &&
           subgroup_types_ == other.subgroup_types_;
  }
  bool operator!=(const HloSharding& other) const {
    return !(*this == other);
  }

  template <typename H>
  friend H AbslHashValue(H h, const HloSharding& sharding) {
    if (sharding.tuple_) {
      return H::combine(std::move(h), sharding.tuple_elements_);
    }
    return H::combine(std::move(h), sharding.replicated_, sharding.manual_,
                      sharding.tile_assignment_,
                      sharding.replicate_on_last_tile_dim_);
  }

  // Gets the tile assignment tensor.
  // REQUIRES: !IsReplicated() && !IsTuple()
  const Array<int64_t>& tile_assignment() const { return tile_assignment_; }

  // Gets the subgroup types array.
  // REQUIRES: !IsTuple()
  const std::vector<OpSharding::Type>& subgroup_types() const {
    return subgroup_types_;
  }
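
  // Illustrative sketch of the tuple-sharding queries above (reuses the
  // hypothetical `tuple_shape`/`tuple_sharding` from the Tuple() sketch
  // earlier in this header):
  //
  //   HloSharding leaf =
  //       tuple_sharding.GetSubSharding(tuple_shape, ShapeIndex({0}));
  //   CHECK(leaf.IsReplicated());
  //   // The two leaves differ, so there is no common single sharding:
  //   CHECK(!tuple_sharding.ExtractSingleSharding().has_value());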
  // Returns the flattened list of all the leaf shardings in a tuple shape, by
  // pre-order walk (ShapeTree iterator order).
  // REQUIRES: IsTuple().
  std::vector<HloSharding>& tuple_elements() { return tuple_elements_; }
  const std::vector<HloSharding>& tuple_elements() const {
    return tuple_elements_;
  }

  // Gets the tile shape.
  // REQUIRES: !IsTuple()
  Shape TileShape(const Shape& shape) const;

  // Gets the tile shape on the device.
  // REQUIRES: !IsTuple()
  Shape TileShape(const Shape& shape, int64_t device) const;

  // Gets the number of tiles. If it has partial replication, this will not
  // equal the device count.
  int64_t NumTiles() const;
  // Like NumTiles() but considers only the specific dimensions passed as
  // argument.
  int64_t NumTiles(absl::Span<const int64_t> dims) const;

  // Gets metadata from sharding.
  std::vector<OpMetadata>& metadata() { return metadata_; }
  const std::vector<OpMetadata>& metadata() const { return metadata_; }

  // Returns the replication subgroup dim, or -1 if it doesn't exist.
  int64_t SubgroupReplicationDim() const {
    auto it = absl::c_find(subgroup_types_, OpSharding::REPLICATED);
    if (it != subgroup_types_.end()) {
      return (it - subgroup_types_.begin()) + TiledDataRank();
    }
    if (replicate_on_last_tile_dim_) {
      return tile_assignment_.num_dimensions() - 1;
    }
    return -1;
  }

  // Returns the manual subgroup dim, or -1 if it doesn't exist.
  int64_t SubgroupManualDim() const {
    auto it = absl::c_find(subgroup_types_, OpSharding::MANUAL);
    if (it != subgroup_types_.end()) {
      return (it - subgroup_types_.begin()) + TiledDataRank();
    }
    return -1;
  }

  // Returns the data rank for tiled sharding. It doesn't include subgroup
  // dims.
  int64_t TiledDataRank() const {
    CHECK(IsTiled());
    int64_t rank = tile_assignment_.num_dimensions();
    if (ReplicateOnLastTileDim()) {
      rank--;
    }
    rank -= subgroup_types_.size();
    return rank;
  }
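
  // Worked example for the subgroup-dim arithmetic above (illustrative): for
  // {devices=[2,2,2]0,1,...,7 last_tile_dims={manual}}, the tile assignment
  // has rank 3 with one trailing MANUAL subgroup dim, so TiledDataRank() == 2
  // and SubgroupManualDim() == 0 + 2 == 2, while SubgroupReplicationDim()
  // == -1.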
 private:
  explicit HloSharding(bool manual, bool replicated,
                       absl::Span<const OpMetadata> metadata)
      : replicated_(replicated),
        maximal_(replicated),
        tuple_(false),
        manual_(manual),
        tile_assignment_({0}),
        replicate_on_last_tile_dim_(false),
        metadata_(metadata.begin(), metadata.end()) {}
  // device_id values:
  // -2: magic number to mean unassigned device, used by spatial partitioning
  // -1: the id of the host
  //  0 or positive: the id of a device
  // NOTE(dimvar): -1 is needed for outside compilation. It can be removed
  // once we have fully switched to the side-effect tokens.
  explicit HloSharding(int64_t device_id,
                       absl::Span<const OpMetadata> metadata)
      : replicated_(false),
        maximal_(true),
        tuple_(false),
        manual_(false),
        tile_assignment_({1}, device_id),
        replicate_on_last_tile_dim_(false),
        metadata_(metadata.begin(), metadata.end()) {}
  explicit HloSharding(const Array<int64_t>& tile_assignment,
                       bool replicate_on_last_tile_dim,
                       absl::Span<const OpMetadata> metadata = {})
      : replicated_(false),
        maximal_(false),
        tuple_(false),
        manual_(false),
        tile_assignment_(tile_assignment),
        replicate_on_last_tile_dim_(replicate_on_last_tile_dim),
        metadata_(metadata.begin(), metadata.end()) {}
  explicit HloSharding(const Array<int64_t>& tile_assignment,
                       absl::Span<const OpSharding::Type> subgroup_types,
                       absl::Span<const OpMetadata> metadata = {})
      : replicated_(false),
        maximal_(false),
        tuple_(false),
        manual_(false),
        tile_assignment_(tile_assignment),
        replicate_on_last_tile_dim_(false),
        metadata_(metadata.begin(), metadata.end()),
        subgroup_types_(subgroup_types.begin(), subgroup_types.end()) {}
  explicit HloSharding(const std::vector<HloSharding>& tuple_shardings)
      : replicated_(false),
        maximal_(false),
        tuple_(true),
        manual_(false),
        tile_assignment_({0}),
        tuple_elements_(tuple_shardings),
        replicate_on_last_tile_dim_(false) {}

  // Checks that the number of elements in tuple_elements_ is consistent with
  // the tuple shape passed as argument.
  Status CheckLeafCount(const Shape& shape) const;

  // Internal helper to validate a tuple sharding.
  Status ValidateTuple(const Shape& shape, int64_t num_devices) const;

  // Internal helper to validate a non-tuple (leaf) sharding.
  Status ValidateNonTuple(const Shape& shape, int64_t num_devices) const;

  // Returns the number of tuple_elements_ entries to fit the shape.
  static int64_t RequiredLeaves(const Shape& shape);

  bool replicated_;
  bool maximal_;
  bool tuple_;
  bool manual_;
  // This field is only used if replicated_ is false. If maximal_ is true,
  // then the field contains a rank 1 array with a single element, which is
  // the device the HLO is assigned to. If maximal_ is false, the field
  // contains an array with the same rank as the corresponding HLO. The
  // dimension sizes of the array describe the number of ways the HLO is
  // partitioned along each dimension. The values of the array specify which
  // device each tile of the HLO is assigned to. The index of each value
  // determines which tile it takes.
  // For example, {{{2, 3}}, {{5, 7}}} (whose ToString representation is
  // "{devices=[2,1,2]2,3,5,7}") means that dimension 0 is split two ways and
  // dimension 2 is split two ways. Device 5, whose index is [1,0,0], will
  // take the tile that contains the second half of dimension 0 and the first
  // half of dimension 2.
  Array<int64_t> tile_assignment_;
  // Only non-empty when tuple_ is true. If a tuple is empty then one entry is
  // present for the root. This is a flattened list of all the leaf shardings
  // in a tuple shape, by pre-order walk (ShapeTree iterator order).
  std::vector<HloSharding> tuple_elements_;
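  // Worked example for the flag below (illustrative): with tile_assignment_
  // of shape [2,2] containing {{0,1},{2,3}} and replicate_on_last_tile_dim_
  // set, the data is split two ways along dimension 0; devices 0 and 1
  // replicate the first tile, and devices 2 and 3 replicate the second.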
  // This flag supports partial replication and partial sharding. If it is
  // true, tile_assignment_ will have an extra dimension in addition to the
  // data shape rank, and the added last dimension represents the subgroups of
  // replications, i.e., elements in slice [..., :] will be replicated.
  bool replicate_on_last_tile_dim_;
  // This field is used to track the source of this sharding, usually derived
  // from instructions. Multiple metadata may be populated if the sharding is
  // combined with other shardings. Metadata is not to be populated when
  // tuple_ == true; instead, metadata should be set on individual tuple
  // elements.
  std::vector<OpMetadata> metadata_;
  // This field is used to represent the sharding type of each subgroup.
  // For example, sharding={devices=[2,2,2,2]0,1,2,...,15 last_tile_dims={
  // replicate, manual, unreduced}} means that each of the last 3 dimensions
  // in [2,2,2,2] represents a subgrouping of the corresponding type, in that
  // order. When creating an HloSharding, subgroup dims of the same type will
  // be merged, so that there is at most one dim with a given type.
  std::vector<OpSharding::Type> subgroup_types_;
};

std::ostream& operator<<(std::ostream& out, const HloSharding& sharding);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_