xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_sharding.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
// HLO shardings describe how an HLO instruction is split across multiple
// computations.
18 
19 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_
20 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_
21 
#include <cstdint>
#include <map>
#include <optional>
#include <ostream>
#include <string>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
31 
32 namespace xla {
33 
// HLO shardings describe how an HLO instruction is split across multiple
// computations.
36 class HloSharding {
37  public:
38   // Creates a trivial sharding that replicates a maximal tile across all
39   // devices.
40   static HloSharding Replicate(absl::Span<const OpMetadata> metadata = {}) {
41     return HloSharding(/*manual=*/false, /*replicated=*/true, metadata);
42   }
43 
44   // Creates a sharding that represents the op is manually partitioned.
45   static HloSharding Manual(absl::Span<const OpMetadata> metadata = {}) {
46     return HloSharding(/*manual=*/true, /*replicated=*/false, metadata);
47   }
48 
49   // Creates a sharding that emulates device placement; a tile shape equal to
50   // the input shape (one tile) assigned to a single device.
51   static HloSharding AssignDevice(int64_t device_id,
52                                   absl::Span<const OpMetadata> metadata = {});
53 
54   // Creates a new sharding which splits a shape into tiles amongst the devices
55   // specified by `tile_assignment`.
56   static HloSharding Tile(const Array<int64_t>& tile_assignment,
57                           absl::Span<const OpMetadata> metadata = {}) {
58     return HloSharding(tile_assignment, /*replicate_on_last_tile_dim=*/false,
59                        metadata);
60   }
61 
62   // Creates a new sharding where data is replicated within each replication
63   // group, and sharded across replication groups according to
64   // group_tile_assignment. Replication group members will be sorted.
65   static HloSharding PartialTile(
66       const Array<int64_t>& group_tile_assignment,
67       absl::Span<const absl::Span<const int64_t>> replication_groups,
68       absl::Span<const OpMetadata> metadata = {});
69 
70   // Creates a partially replicated tiled sharding with device-level tile
71   // assignment, where the last dimension is the additional replication
72   // dimension. Replication group members will be sorted.
73   static HloSharding PartialTile(
74       const Array<int64_t>& tile_assignment_last_dim_replicate,
75       absl::Span<const OpMetadata> metadata = {});
76 
77   // Creates a subgroup sharding with device-level tile assignment, the
78   // sharding type of each subgroup is defined by subgroup_types. When creating
79   // the HloSharding, subgroup dims of the same type will be merged.
80   static HloSharding Subgroup(const Array<int64_t>& tile_assignment,
81                               absl::Span<const OpSharding::Type> subgroup_types,
82                               absl::Span<const OpMetadata> metadata = {});
83 
84   // Creates a new sharding which splits a one-dimensional input shape into
85   // `num_tiles` tiles.
86   static HloSharding Tile1D(const Shape& input_shape, int64_t num_tiles,
87                             absl::Span<const OpMetadata> metadata = {});
88 
89   // Creates a new sharding for a tuple type. The given ShapeTree must have
90   // elements for every leaf shape contained in the tuple.
91   static HloSharding Tuple(const ShapeTree<HloSharding>& sub_shardings);
92 
93   // Creates a new sharding for a tuple type. The number of elements in
94   // shardings must match the number of leaf nodes in tuple_shape. For
95   // empty tuples, the shardings array must have one element.
96   static HloSharding Tuple(const Shape& tuple_shape,
97                            absl::Span<const HloSharding> shardings);
98 
99   // Creates a new sharding for a tuple type, with a single input sharding
100   // repeated on each leaf.
101   static HloSharding SingleTuple(const Shape& tuple_shape,
102                                  const HloSharding& sharding);
103 
104   // If shape is an array, returns sharding, otherwise returns the tuple shaped
105   // sharding with all the leaf nodes having the same input sharding.
106   static HloSharding Single(const Shape& shape, const HloSharding& sharding);
107 
108   // Create a new sharding from a protobuf OpSharding.
109   static StatusOr<HloSharding> FromProto(const OpSharding& proto);
110 
111   // Checks whether device is a reserved device number. A reserved device number
112   // has usually a special meaning, with dedicated handling logic.
IsReservedDevice(int64_t device)113   static bool IsReservedDevice(int64_t device) { return device < 0; }
114 
115   OpSharding ToProto() const;
116 
117   // Note that this string canonically has outer curly braces, e.g.
118   // "{replicated}".
119   std::string ToString(bool include_metadata = false) const;
120 
121   // Validate that this sharding can be applied to a tensor with shape `shape`.
122   Status Validate(const Shape& shape, int64_t num_devices) const;
123 
124   // Returns true if the sharding has tuple type.
IsTuple()125   bool IsTuple() const { return tuple_; }
126 
127   // Returns true if the sharding is trivial: replicate on all devices.
IsReplicated()128   bool IsReplicated() const {
129     if (!IsTuple()) {
130       return replicated_;
131     }
132     return absl::c_all_of(
133         tuple_elements_, [](const HloSharding& s) { return s.IsReplicated(); });
134   }
135 
136   // Returns true if the tile size is the same as the input size.
IsTileMaximal()137   bool IsTileMaximal() const {
138     if (!IsTuple()) {
139       return maximal_;
140     }
141     return absl::c_all_of(tuple_elements_, [](const HloSharding& s) {
142       return s.IsTileMaximal();
143     });
144   }
145 
146   // Returns whether the sharding represents manual partitioning.
IsManual()147   bool IsManual() const {
148     if (!IsTuple()) {
149       return manual_;
150     }
151     return absl::c_all_of(tuple_elements_,
152                           [](const HloSharding& s) { return s.IsManual(); });
153   }
154 
155   // Returns whether the sharding represents manual subgroup sharding.
IsManualSubgroup()156   bool IsManualSubgroup() const {
157     if (!IsTuple()) {
158       return absl::c_linear_search(subgroup_types_, OpSharding::MANUAL);
159     }
160     return absl::c_all_of(tuple_elements_, [](const HloSharding& s) {
161       return s.IsManualSubgroup();
162     });
163   }
164 
165   // Returns weather the sharding represents a tiled sharding where the mapping
166   // between devices and tiles is represented through 'tile_assignment()'.
IsTiled()167   bool IsTiled() const { return !IsTileMaximal() && !IsManual(); }
168 
169   // Returns if the sharding has partial replication and partial sharding. If
170   // true, data is sharded according to other dimensions of tile_assignment(),
171   // but replicated across devices along the last dimension.
ReplicateOnLastTileDim()172   bool ReplicateOnLastTileDim() const { return replicate_on_last_tile_dim_; }
173 
174   // Returns whether there is any partial replication. This can be using
175   // ReplicateOnLastTileDim or subgroups with REPLICATED.
HasPartialReplication()176   bool HasPartialReplication() const {
177     return replicate_on_last_tile_dim_ ||
178            absl::c_linear_search(subgroup_types_, OpSharding::REPLICATED);
179   }
180 
181   // Returns true if the sharding defines an operation on the given device.
182   bool UsesDevice(int64_t device) const;
183 
184   // Retrieves a histogram of the devices used by the sharding. The returned
185   // map has the device number as key, and the occurrence count as value.
186   // If a sharding does not have a device, it will not be included in the
187   // histogram. The count argument, if not nullptr, will receive the total
188   // number of elements this sharding is made of (one for array, N leaves for
189   // tuples).
190   std::map<int64_t, int64_t> UsedDevices(int64_t* count) const;
191 
192   // Returns the tile that should be executed on the given device.
193   // REQUIRES: !IsTuple()
194   std::vector<int64_t> TileIndexForDevice(int64_t device) const;
195 
196   // Returns the device that should execute the given tile.
197   // It is an error to call this if is_replicated() is true.
198   // When ReplicateOnLastTileDim() == true, if index.size() == data rank, it
199   // returns the first device in that replicated subgroup; otherwise,
200   // index.size() should be the same as tile_assignment()'s rank and specifies
201   // the member of the replication subgroup.
202   // REQUIRES: !IsTuple()
203   int64_t DeviceForTileIndex(absl::Span<const int64_t> index) const;
204 
205   // Given a device ID, returns the offset within the specified shape of the
206   // tile that should be executed on the given core. This returns the lower
207   // extent of the tile in the input space.
208   // REQUIRES: !IsTuple()
209   std::vector<int64_t> TileOffsetForDevice(const Shape& shape,
210                                            int64_t device) const;
211 
212   // Given a device ID, returns the limit within the specified shape of the
213   // tile that should be executed on the given core. This returns the upper
214   // extent of the tile in the input space.
215   // REQUIRES: !IsTuple()
216   std::vector<int64_t> TileLimitForDevice(const Shape& shape,
217                                           int64_t device) const;
218 
219   // Returns the single device this op operates on. If the sharding does not
220   // span a single device, the return value will be empty.
221   // In order for a sharding to span a single device, every leaf sharding must
222   // be maximal and not replicated, and the used device must match.
223   std::optional<int64_t> UniqueDevice() const;
224 
225   // Retrieves the unique device or fails with a CHECK.
226   int64_t GetUniqueDevice() const;
227 
228   // Returns true if this op only uses a single device.
HasUniqueDevice()229   bool HasUniqueDevice() const { return UniqueDevice().has_value(); }
230 
231   // Returns the ShapeTree containing the shardings for each element of this
232   // tuple, if IsTuple, or a ShapeTree with a single element containing this
233   // sharding. Only the leaf elements are populated. This creates a new
234   // ShapeTree object so is not cheap.
235   StatusOr<ShapeTree<HloSharding>> AsShapeTree(const Shape& shape) const;
GetAsShapeTree(const Shape & shape)236   ShapeTree<HloSharding> GetAsShapeTree(const Shape& shape) const {
237     return AsShapeTree(shape).ValueOrDie();
238   }
239 
240   // Retrieves the sub sharding at a given index, out of a tuple sharding.
241   // REQUIRES: IsTuple()
242   HloSharding GetSubSharding(const Shape& shape, const ShapeIndex& index) const;
243 
244   // If the current sharding is a tuple sharding, return itself as result.
245   // Otherwise returns a tuple sharding for the input shape, with all the leaves
246   // having this object sharding.
247   StatusOr<HloSharding> GetTupleSharding(const Shape& shape) const;
248 
249   // Extracts the sharding that is common within the current sharding.
250   // If the current sharding is not a tuple sharding, the current sharding will
251   // be returned. If it is a tuple, and all the tuple elements are common, the
252   // common element will be returned. Otherwise the optional will contain no
253   // value.
254   std::optional<HloSharding> ExtractSingleSharding() const;
255 
256   // Returns a copy of the sharding with no metadata. If sharding is of tuple
257   // type, sub shardings will have no metadata.
258   HloSharding WithoutMetadata() const;
259 
260   // Returns a copy of the sharding with specified metadata. If metadata is
261   // already present, that metadata will not be replaced unless `overwrite` is
262   // set to true. If sharding is of tuple type, sub shardings metadata will be
263   // assigned instead.
264   HloSharding WithMetadata(absl::Span<const OpMetadata> metadata,
265                            bool overwrite) const;
266 
267   bool operator==(const HloSharding& other) const {
268     return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
269            manual_ == other.manual_ &&
270            tile_assignment_ == other.tile_assignment_ &&
271            tuple_elements_ == other.tuple_elements_ &&
272            replicate_on_last_tile_dim_ == other.replicate_on_last_tile_dim_ &&
273            subgroup_types_ == other.subgroup_types_;
274   }
275   bool operator!=(const HloSharding& other) const { return !(*this == other); }
276 
277   template <typename H>
AbslHashValue(H h,const HloSharding & sharding)278   friend H AbslHashValue(H h, const HloSharding& sharding) {
279     if (sharding.tuple_) {
280       return H::combine(std::move(h), sharding.tuple_elements_);
281     }
282     return H::combine(std::move(h), sharding.replicated_, sharding.manual_,
283                       sharding.tile_assignment_,
284                       sharding.replicate_on_last_tile_dim_);
285   }
286 
287   // Gets the tile assignment tensor.
288   // REQUIRES: !IsReplicated() && !IsTuple()
tile_assignment()289   const Array<int64_t>& tile_assignment() const { return tile_assignment_; }
290 
291   // Gets the subgroup types array.
292   // REQUIRES: !IsTuple()
subgroup_types()293   const std::vector<OpSharding::Type>& subgroup_types() const {
294     return subgroup_types_;
295   }
296 
297   // Returns the flattened list of all the leaf shardings in a tuple shape, by
298   // pre-order walk (ShapeTree iterator order).
299   // REQUIRES: IsTuple().
tuple_elements()300   std::vector<HloSharding>& tuple_elements() { return tuple_elements_; }
tuple_elements()301   const std::vector<HloSharding>& tuple_elements() const {
302     return tuple_elements_;
303   }
304 
305   // Gets the tile shape.
306   // REQUIRES: !IsTuple()
307   Shape TileShape(const Shape& shape) const;
308 
309   // Gets the tile shape on the device.
310   // REQUIRES: !IsTuple()
311   Shape TileShape(const Shape& shape, int64_t device) const;
312 
313   // Gets the number of tiles. If it has partial replication, this will not
314   // equal the device count.
315   int64_t NumTiles() const;
316   // Like NumTiles() but considers only some specific dimensions passed as
317   // argument
318   int64_t NumTiles(absl::Span<const int64_t> dims) const;
319 
320   // Gets metadata from sharding.
metadata()321   std::vector<OpMetadata>& metadata() { return metadata_; }
metadata()322   const std::vector<OpMetadata>& metadata() const { return metadata_; }
323 
324   // Returns the replication subgroiup dim, or -1 if it doesn't exist.
SubgroupReplicationDim()325   int64_t SubgroupReplicationDim() const {
326     auto it = absl::c_find(subgroup_types_, OpSharding::REPLICATED);
327     if (it != subgroup_types_.end()) {
328       return (it - subgroup_types_.begin()) + TiledDataRank();
329     }
330     if (replicate_on_last_tile_dim_) {
331       return tile_assignment_.num_dimensions() - 1;
332     }
333     return -1;
334   }
335 
336   // Returns the manual subgroup dim, or -1 if it doesn't exist.
SubgroupManualDim()337   int64_t SubgroupManualDim() const {
338     auto it = absl::c_find(subgroup_types_, OpSharding::MANUAL);
339     if (it != subgroup_types_.end()) {
340       return (it - subgroup_types_.begin()) + TiledDataRank();
341     }
342     return -1;
343   }
344 
345   // Returns the data rank for tiled sharding. It doesn't include subgroup dims.
TiledDataRank()346   int64_t TiledDataRank() const {
347     CHECK(IsTiled());
348     int64_t rank = tile_assignment_.num_dimensions();
349     if (ReplicateOnLastTileDim()) {
350       rank--;
351     }
352     rank -= subgroup_types_.size();
353     return rank;
354   }
355 
356  private:
HloSharding(bool manual,bool replicated,absl::Span<const OpMetadata> metadata)357   explicit HloSharding(bool manual, bool replicated,
358                        absl::Span<const OpMetadata> metadata)
359       : replicated_(replicated),
360         maximal_(replicated),
361         tuple_(false),
362         manual_(manual),
363         tile_assignment_({0}),
364         replicate_on_last_tile_dim_(false),
365         metadata_(metadata.begin(), metadata.end()) {}
366   // device_id values:
367   // -2: magic number to mean unassigned device, used by spatial partitioning
368   // -1: the id of the host
369   //  0 or positive: the id of a device
370   // NOTE(dimvar): -1 is needed for outside compilation. It can be removed once
371   // we have fully switched to the side-effect tokens.
HloSharding(int64_t device_id,absl::Span<const OpMetadata> metadata)372   explicit HloSharding(int64_t device_id, absl::Span<const OpMetadata> metadata)
373       : replicated_(false),
374         maximal_(true),
375         tuple_(false),
376         manual_(false),
377         tile_assignment_({1}, device_id),
378         replicate_on_last_tile_dim_(false),
379         metadata_(metadata.begin(), metadata.end()) {}
380   explicit HloSharding(const Array<int64_t>& tile_assignment,
381                        bool replicate_on_last_tile_dim,
382                        absl::Span<const OpMetadata> metadata = {})
replicated_(false)383       : replicated_(false),
384         maximal_(false),
385         tuple_(false),
386         manual_(false),
387         tile_assignment_(tile_assignment),
388         replicate_on_last_tile_dim_(replicate_on_last_tile_dim),
389         metadata_(metadata.begin(), metadata.end()) {}
390   explicit HloSharding(const Array<int64_t>& tile_assignment,
391                        absl::Span<const OpSharding::Type> subgroup_types,
392                        absl::Span<const OpMetadata> metadata = {})
replicated_(false)393       : replicated_(false),
394         maximal_(false),
395         tuple_(false),
396         manual_(false),
397         tile_assignment_(tile_assignment),
398         replicate_on_last_tile_dim_(false),
399         metadata_(metadata.begin(), metadata.end()),
400         subgroup_types_(subgroup_types.begin(), subgroup_types.end()) {}
HloSharding(const std::vector<HloSharding> & tuple_shardings)401   explicit HloSharding(const std::vector<HloSharding>& tuple_shardings)
402       : replicated_(false),
403         maximal_(false),
404         tuple_(true),
405         manual_(false),
406         tile_assignment_({0}),
407         tuple_elements_(tuple_shardings),
408         replicate_on_last_tile_dim_(false) {}
409 
410   // Checks that the number of elements in tuple_elements_ is consistent with
411   // the tuple shape passes as argument.
412   Status CheckLeafCount(const Shape& shape) const;
413 
414   // Internal helper to validate a tuple sharding.
415   Status ValidateTuple(const Shape& shape, int64_t num_devices) const;
416 
417   // Internal helper to validate a non-tuple (leaf) sharding.
418   Status ValidateNonTuple(const Shape& shape, int64_t num_devices) const;
419 
420   // Returns the number of tuple_elements_ entries to fit the shape.
421   static int64_t RequiredLeaves(const Shape& shape);
422 
423   bool replicated_;
424   bool maximal_;
425   bool tuple_;
426   bool manual_;
427   // This field is only used if replicated_ is false. If maximal_ is true, then
428   // the field contains a rank 1 array with a single element, which is the
429   // device the HLO is assigned to. If maximal_ is false, the field contains an
430   // array with the same rank as the corresponding HLO. The dimension sizes of
431   // the array describe the number of ways the HLO is partitioned along each
432   // dimension. The values of the array specify which device each tile of
433   // the HLO is assigned to. The index of each value determines which tile it
434   // takes.
435   // For example, {{{2, 3}}, {{5, 7}}} (whose ToString representation is
436   // "{devices=[2,1,2]2,3,5,7}"), means that dimension 1 is split two way and
437   // dimension 3 is split 2 way. Core 5, whose index is [2,1,1] will take the
438   // tile that contains the 2nd half of dimension 1 and the 1st half of
439   // dimension 3.
440   Array<int64_t> tile_assignment_;
441   // Only non-empty when tuple_ is true. If a tuple is empty then one entry is
442   // present for the root. This is a flattened list of all the leaf shardings in
443   // a tuple shape, by pre-order walk (ShapeTree iterator order).
444   std::vector<HloSharding> tuple_elements_;
445   // This flag is to support partial replication and partial sharding. If it is
446   // true, tile_assignment_ will have an extra dimension in addition to the data
447   // shape rank, and the added last dimension represents the subgroups of
448   // replications, i.e., elements in slice [..., :] will be replicated.
449   bool replicate_on_last_tile_dim_;
450   // This field is used to track the source of this sharding, usually derived
451   // from instructions. Multiple metadata may be populated if sharding is
452   // combined with other shardings. Metadata are to not be populated when
453   // tuple_ == true and instead metadata should be set on individual tuple
454   // elements.
455   std::vector<OpMetadata> metadata_;
456   // This field is used to represented the sharding type of each subgroup.
457   // For example, sharding={devices=[2,2,2,2]0,1,2,...,15 last_tile_dims={
458   // replicate, manual, unreduced}} means that each of the last 3 dimensions
459   // in [2,2,2,2] represents a subgrouping in replicate, manual.
460   // When creating HloSharding, subgroup dims of the same type will be merged,
461   // so that there is at most one dim with a given type.
462   std::vector<OpSharding::Type> subgroup_types_;
463 };
464 
465 std::ostream& operator<<(std::ostream& out, const HloSharding& sharding);
466 
467 }  // namespace xla
468 
469 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_
470