xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/mlir/group_assignment.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_DTENSOR_MLIR_GROUP_ASSIGNMENT_H_
17 #define TENSORFLOW_DTENSOR_MLIR_GROUP_ASSIGNMENT_H_
18 
19 #include <ostream>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/container/flat_hash_map.h"
25 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
26 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
27 #include "tensorflow/core/platform/errors.h"
28 #include "tensorflow/core/platform/logging.h"
29 #include "tensorflow/core/platform/status.h"
30 #include "tensorflow/core/platform/types.h"
31 #include "tensorflow/dtensor/cc/dstatus.h"
32 
33 namespace tensorflow {
34 namespace dtensor {
35 
36 // Arranges all replica IDs in a DTensor mesh in groups, used as an attribute
37 // on collective operations.
38 //
39 // A group assignment has two views:
40 //
41 // - The global mesh view contains replica IDs from all participant TPU slices.
42 //   These replica IDs are identical to global device IDs in a DTensor mesh.
43 // - The local slice view contains per-slice device IDs understood and used by
44 //   the TPU runtime on each slice. These device IDs are used to set replica
45 //   IDs on each slice.
46 //
47 // Some notable common cases:
48 //
49 // - In a single-slice case, `slice_size` is set to the actual slice size
50 //   (e.g., 32 for 4x4 DF). The global and local views are identical.
51 // - In a special topology case, `slice_size` is set to 8.
52 // - In a multi-topology case, `slice_size` is set to the size of a single
53 // topology.
54 //   All topologies must have the same size.
55 class GroupAssignment {
56  public:
57   using ReplicaId = int;
58 
59   struct DeviceId {
60    public:
61     int slice_id;
62     int core_id;  // within `slice_id`
63   };
64 
65   // Maps global replica IDs to local device IDs consisting of a slice ID and a
66   // core-on-slice ID.
67   class ReplicaToDeviceMap {
68    public:
69     // Creates a default map that orders devices according to TF task IDs
70     // followed by device ordinals.
71     static ReplicaToDeviceMap DefaultReplicaToDeviceMap(int num_slices,
72                                                         int slice_size);
73 
74     // Constructs a map directly, checking it's valid.
75     explicit ReplicaToDeviceMap(absl::flat_hash_map<ReplicaId, DeviceId> map);
76 
num_slices()77     int num_slices() { return num_slices_; }
num_cores()78     int num_cores() { return map_.size(); }
device_id(ReplicaId replica_id)79     DeviceId device_id(ReplicaId replica_id) { return map_[replica_id]; }
80 
81    private:
82     absl::flat_hash_map<ReplicaId, DeviceId> map_;
83     int num_slices_;
84   };
85 
86   // Creates a group assignment by converting from an MLIR attribute.
87   static StatusOr<GroupAssignment> FromMLIR(
88       const mlir::DenseIntElementsAttr& group_assignment_attr,
89       ReplicaToDeviceMap replica_to_device_map);
90 
91   // Creates an MLIR attribute using the global view.
GlobalToMLIR(mlir::MLIRContext & context)92   mlir::DenseIntElementsAttr GlobalToMLIR(mlir::MLIRContext& context) const {
93     return global_.ToMLIR(context);
94   }
95 
96   // Creates an MLIR attribute for a particular slice.
97   // Callers should make sure `slice_id` is >= 0 and < num_slices().
SliceToMLIR(mlir::MLIRContext & context,int slice_id)98   StatusOr<mlir::DenseIntElementsAttr> SliceToMLIR(mlir::MLIRContext& context,
99                                                    int slice_id) const {
100     if (slice_id < 0 || slice_id >= num_slices())
101       return errors::InvalidArgument("slide_id was not within bounds.");
102     return slices_[slice_id].ToMLIR(context);
103   }
104 
105   // Returns a string representation for debugging.
106   std::string ToString() const;
107 
108   // Returns true if every group in the global view only has replica IDs from
109   // the same slice.
110   bool IsWithinSlices() const;
111 
112   // Returns the number of slices in the local view.
num_slices()113   int num_slices() const { return slices_.size(); }
114 
115   // These methods return attributes of the global view.
num_groups()116   int num_groups() const { return global_.num_groups(); }
group_size()117   int group_size() const { return global_.group_size(); }
num_replica_ids()118   int num_replica_ids() const { return global_.num_replica_ids(); }
replica_ids()119   const std::vector<std::vector<int>>& replica_ids() const {
120     return global_.replica_ids();
121   }
122 
123   // These methods return attributes of a particular slice.
124   // Callers should make sure `slice_id` is >= 0 and < num_slices().
num_groups(int slice_id)125   StatusOr<int> num_groups(int slice_id) const {
126     if (slice_id < 0 || slice_id >= num_slices())
127       return errors::InvalidArgument("slide_id was not within bounds.");
128     return slices_[slice_id].num_groups();
129   }
group_size(int slice_id)130   StatusOr<int> group_size(int slice_id) const {
131     if (slice_id < 0 || slice_id >= num_slices())
132       return errors::InvalidArgument("slide_id was not within bounds.");
133     return slices_[slice_id].group_size();
134   }
replica_ids(int slice_id)135   const std::vector<std::vector<int>>& replica_ids(int slice_id) const {
136     return slices_[slice_id].replica_ids();
137   }
138 
139   // Returns the replica groups for collectives running on a particular host.
140   // Callers should make sure `slice_id` is >= 0 and < num_slices().
host_replica_ids(int slice_id)141   const std::vector<std::vector<int>>& host_replica_ids(int slice_id) const {
142     return hosts_[slice_id].replica_ids();
143   }
144 
145  private:
146   // Groups of consecutive replica IDs starting at 0.
147   class ReplicaGroups {
148    public:
149     // Creates an object, enforcing the requirements on `replica_ids_`.
150     explicit ReplicaGroups(std::vector<std::vector<int>> replica_ids);
151 
152     mlir::DenseIntElementsAttr ToMLIR(mlir::MLIRContext& context) const;
153 
154     std::string ToString() const;
155 
num_groups()156     int num_groups() const { return replica_ids_.size(); }
group_size()157     int group_size() const { return replica_ids_.front().size(); }
num_replica_ids()158     int num_replica_ids() const { return num_groups() * group_size(); }
replica_ids()159     const std::vector<std::vector<int>>& replica_ids() const {
160       return replica_ids_;
161     }
162 
163    private:
164     // N groups of replica IDs, N > 0. All groups have the same size G, G > 0.
165     // All replica IDs are distinct values >= 0;
166     std::vector<std::vector<int>> replica_ids_;  // replica ID order matters
167   };
168 
169   // Creates an object but leaves `slices_` empty. `GlobalToSlices` should be
170   // called next to fill in `slices_`.
GroupAssignment(ReplicaGroups global,ReplicaToDeviceMap replica_to_device_map)171   explicit GroupAssignment(ReplicaGroups global,
172                            ReplicaToDeviceMap replica_to_device_map)
173       : global_(std::move(global)),
174         replica_to_device_map_(std::move(replica_to_device_map)) {}
175 
176   // Divides the global view along slice boundaries and fill in the slice view.
177   Status GlobalToSlices();
178 
179   ReplicaGroups global_;
180   std::vector<ReplicaGroups> hosts_;   // sorted by increasing slice ID
181   std::vector<ReplicaGroups> slices_;  // sorted by increasing slice ID
182   ReplicaToDeviceMap replica_to_device_map_;
183 };
184 
185 }  // namespace dtensor
186 }  // namespace tensorflow
187 
188 #endif  // TENSORFLOW_DTENSOR_MLIR_GROUP_ASSIGNMENT_H_
189