1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_DTENSOR_MLIR_GROUP_ASSIGNMENT_H_ 17 #define TENSORFLOW_DTENSOR_MLIR_GROUP_ASSIGNMENT_H_ 18 19 #include <ostream> 20 #include <string> 21 #include <utility> 22 #include <vector> 23 24 #include "absl/container/flat_hash_map.h" 25 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project 26 #include "mlir/IR/MLIRContext.h" // from @llvm-project 27 #include "tensorflow/core/platform/errors.h" 28 #include "tensorflow/core/platform/logging.h" 29 #include "tensorflow/core/platform/status.h" 30 #include "tensorflow/core/platform/types.h" 31 #include "tensorflow/dtensor/cc/dstatus.h" 32 33 namespace tensorflow { 34 namespace dtensor { 35 36 // Arranges all replica IDs in a DTensor mesh in groups, used as an attribute 37 // on collective operations. 38 // 39 // A group assignment has two views: 40 // 41 // - The global mesh view contains replica IDs from all participant TPU slices. 42 // These replica IDs are identical to global device IDs in a DTensor mesh. 43 // - The local slice view contains per-slice device IDs understood and used by 44 // the TPU runtime on each slice. These device IDs are used to set replica 45 // IDs on each slice. 46 // 47 // Some notable common cases: 48 // 49 // - In a single-slice case, `slice_size` is set to the actual slice size 50 // (e.g., 32 for 4x4 DF). The global and local views are identical. 51 // - In a special topology case, `slice_size` is set to 8. 52 // - In a multi-topology case, `slice_size` is set to the size of a single 53 // topology. 54 // All topologies must have the same size. 55 class GroupAssignment { 56 public: 57 using ReplicaId = int; 58 59 struct DeviceId { 60 public: 61 int slice_id; 62 int core_id; // within `slice_id` 63 }; 64 65 // Maps global replica IDs to local device IDs consisting of a slice ID and a 66 // core-on-slice ID. 67 class ReplicaToDeviceMap { 68 public: 69 // Creates a default map that orders devices according to TF task IDs 70 // followed by device ordinals. 71 static ReplicaToDeviceMap DefaultReplicaToDeviceMap(int num_slices, 72 int slice_size); 73 74 // Constructs a map directly, checking it's valid. 75 explicit ReplicaToDeviceMap(absl::flat_hash_map<ReplicaId, DeviceId> map); 76 num_slices()77 int num_slices() { return num_slices_; } num_cores()78 int num_cores() { return map_.size(); } device_id(ReplicaId replica_id)79 DeviceId device_id(ReplicaId replica_id) { return map_[replica_id]; } 80 81 private: 82 absl::flat_hash_map<ReplicaId, DeviceId> map_; 83 int num_slices_; 84 }; 85 86 // Creates a group assignment by converting from an MLIR attribute. 87 static StatusOr<GroupAssignment> FromMLIR( 88 const mlir::DenseIntElementsAttr& group_assignment_attr, 89 ReplicaToDeviceMap replica_to_device_map); 90 91 // Creates an MLIR attribute using the global view. GlobalToMLIR(mlir::MLIRContext & context)92 mlir::DenseIntElementsAttr GlobalToMLIR(mlir::MLIRContext& context) const { 93 return global_.ToMLIR(context); 94 } 95 96 // Creates an MLIR attribute for a particular slice. 97 // Callers should make sure `slice_id` is >= 0 and < num_slices(). SliceToMLIR(mlir::MLIRContext & context,int slice_id)98 StatusOr<mlir::DenseIntElementsAttr> SliceToMLIR(mlir::MLIRContext& context, 99 int slice_id) const { 100 if (slice_id < 0 || slice_id >= num_slices()) 101 return errors::InvalidArgument("slide_id was not within bounds."); 102 return slices_[slice_id].ToMLIR(context); 103 } 104 105 // Returns a string representation for debugging. 106 std::string ToString() const; 107 108 // Returns true if every group in the global view only has replica IDs from 109 // the same slice. 110 bool IsWithinSlices() const; 111 112 // Returns the number of slices in the local view. num_slices()113 int num_slices() const { return slices_.size(); } 114 115 // These methods return attributes of the global view. num_groups()116 int num_groups() const { return global_.num_groups(); } group_size()117 int group_size() const { return global_.group_size(); } num_replica_ids()118 int num_replica_ids() const { return global_.num_replica_ids(); } replica_ids()119 const std::vector<std::vector<int>>& replica_ids() const { 120 return global_.replica_ids(); 121 } 122 123 // These methods return attributes of a particular slice. 124 // Callers should make sure `slice_id` is >= 0 and < num_slices(). num_groups(int slice_id)125 StatusOr<int> num_groups(int slice_id) const { 126 if (slice_id < 0 || slice_id >= num_slices()) 127 return errors::InvalidArgument("slide_id was not within bounds."); 128 return slices_[slice_id].num_groups(); 129 } group_size(int slice_id)130 StatusOr<int> group_size(int slice_id) const { 131 if (slice_id < 0 || slice_id >= num_slices()) 132 return errors::InvalidArgument("slide_id was not within bounds."); 133 return slices_[slice_id].group_size(); 134 } replica_ids(int slice_id)135 const std::vector<std::vector<int>>& replica_ids(int slice_id) const { 136 return slices_[slice_id].replica_ids(); 137 } 138 139 // Returns the replica groups for collectives running on a particular host. 140 // Callers should make sure `slice_id` is >= 0 and < num_slices(). host_replica_ids(int slice_id)141 const std::vector<std::vector<int>>& host_replica_ids(int slice_id) const { 142 return hosts_[slice_id].replica_ids(); 143 } 144 145 private: 146 // Groups of consecutive replica IDs starting at 0. 147 class ReplicaGroups { 148 public: 149 // Creates an object, enforcing the requirements on `replica_ids_`. 150 explicit ReplicaGroups(std::vector<std::vector<int>> replica_ids); 151 152 mlir::DenseIntElementsAttr ToMLIR(mlir::MLIRContext& context) const; 153 154 std::string ToString() const; 155 num_groups()156 int num_groups() const { return replica_ids_.size(); } group_size()157 int group_size() const { return replica_ids_.front().size(); } num_replica_ids()158 int num_replica_ids() const { return num_groups() * group_size(); } replica_ids()159 const std::vector<std::vector<int>>& replica_ids() const { 160 return replica_ids_; 161 } 162 163 private: 164 // N groups of replica IDs, N > 0. All groups have the same size G, G > 0. 165 // All replica IDs are distinct values >= 0; 166 std::vector<std::vector<int>> replica_ids_; // replica ID order matters 167 }; 168 169 // Creates an object but leaves `slices_` empty. `GlobalToSlices` should be 170 // called next to fill in `slices_`. GroupAssignment(ReplicaGroups global,ReplicaToDeviceMap replica_to_device_map)171 explicit GroupAssignment(ReplicaGroups global, 172 ReplicaToDeviceMap replica_to_device_map) 173 : global_(std::move(global)), 174 replica_to_device_map_(std::move(replica_to_device_map)) {} 175 176 // Divides the global view along slice boundaries and fill in the slice view. 177 Status GlobalToSlices(); 178 179 ReplicaGroups global_; 180 std::vector<ReplicaGroups> hosts_; // sorted by increasing slice ID 181 std::vector<ReplicaGroups> slices_; // sorted by increasing slice ID 182 ReplicaToDeviceMap replica_to_device_map_; 183 }; 184 185 } // namespace dtensor 186 } // namespace tensorflow 187 188 #endif // TENSORFLOW_DTENSOR_MLIR_GROUP_ASSIGNMENT_H_ 189