/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/all_reduce_blueconnect.h"

#include <algorithm>
#include <iterator>
#include <optional>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/btree_map.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {
namespace {

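// Returns the flattened outputs of `instruction`: the instruction itself if
// its shape is an array, or one get-tuple-element per element if its shape is
// a tuple (e.g. a (f32[8], s32[4]) result yields two get-tuple-elements).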
std::vector<HloInstruction*> GetOutputs(HloInstruction& instruction) {
  if (!instruction.shape().IsTuple()) {
    return {&instruction};
  }

  std::vector<HloInstruction*> outputs;
  outputs.reserve(instruction.shape().tuple_shapes_size());

  HloComputation& computation = *instruction.parent();  // never null
  for (int i = 0; i < instruction.shape().tuple_shapes_size(); ++i) {
    outputs.push_back(computation.AddInstruction(
        HloInstruction::CreateGetTupleElement(&instruction, i)));
  }
  return outputs;
}

struct DecomposedReplicaGroups {
  std::vector<ReplicaGroup> scatter_gather_groups;
  std::vector<ReplicaGroup> new_all_reduce_groups;
};

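// Splits a single replica group into intra-host scatter/gather groups and
// inter-host all-reduce groups, assuming device IDs are assigned to hosts in
// contiguous blocks of `num_devices_per_host`. Illustrative example: with
// replicas {0,...,7}, an identity device assignment, and 4 devices per host,
// this returns scatter_gather_groups = {{0,1,2,3}, {4,5,6,7}} and
// new_all_reduce_groups = {{0,4}, {1,5}, {2,6}, {3,7}}. Returns std::nullopt
// if the group cannot be decomposed (e.g. hosts contribute different numbers
// of devices, or the group does not span at least two scatter groups).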
StatusOr<std::optional<DecomposedReplicaGroups>> TryDecomposeReplicaGroup(
    const ReplicaGroup& replica_group,
    const DeviceAssignment& device_assignment, size_t num_devices_per_host) {
  int group_size = replica_group.replica_ids_size();
  TF_RET_CHECK(group_size > 0);

  absl::btree_map<int, std::vector<int64_t>> replica_ids_by_host;
  for (int64_t replica_id : replica_group.replica_ids()) {
    int device_id = device_assignment(replica_id, /*computation_id=*/0);
    TF_RET_CHECK(device_id >= 0);
    // We assume that devices are ordered by host.
    int host_id = device_id / num_devices_per_host;
    replica_ids_by_host[host_id].push_back(replica_id);
  }

  size_t num_local_devices = replica_ids_by_host.begin()->second.size();
  bool same_num_devices_on_each_host =
      absl::c_all_of(replica_ids_by_host, [&](const auto& entry) {
        return entry.second.size() == num_local_devices;
      });

  if (!same_num_devices_on_each_host) {
    return {std::nullopt};
  }

  std::vector<int64_t> sorted_replica_group;
  sorted_replica_group.reserve(group_size);
  for (const auto& entry : replica_ids_by_host) {
    absl::c_copy(entry.second, std::back_inserter(sorted_replica_group));
  }

  size_t scatter_group_size = std::max(num_local_devices, size_t(2));
  size_t num_scatter_groups = group_size / scatter_group_size;

  if ((group_size % scatter_group_size != 0) || (num_scatter_groups < 2)) {
    return {std::nullopt};
  }

  std::vector<ReplicaGroup> scatter_gather_groups(num_scatter_groups);
  std::vector<ReplicaGroup> new_all_reduce_groups(scatter_group_size);

  for (size_t i = 0; i < group_size; ++i) {
    int64_t replica_id = sorted_replica_group[i];
    scatter_gather_groups[i / scatter_group_size].add_replica_ids(replica_id);
    new_all_reduce_groups[i % scatter_group_size].add_replica_ids(replica_id);
  }

  return {DecomposedReplicaGroups{std::move(scatter_gather_groups),
                                  std::move(new_all_reduce_groups)}};
}

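// Attempts to decompose every replica group of `all_reduce` using
// TryDecomposeReplicaGroup above. An empty replica_groups() is treated as one
// group containing all replicas. Returns std::nullopt unless every group
// decomposes, all scatter groups have the same size, and each operand's
// element count is divisible by that size (so the reduce-scatter produces the
// same shape on every device).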
StatusOr<std::optional<DecomposedReplicaGroups>> TryDecomposeReplicaGroups(
    const HloAllReduceInstruction& all_reduce, size_t num_devices_per_host) {
  const DeviceAssignment& device_assignment =
      all_reduce.parent()->parent()->config().static_device_assignment();

  absl::Span<const ReplicaGroup> replica_groups = all_reduce.replica_groups();

  ReplicaGroup all_replicas;  // only populated if replica groups not present.
  if (replica_groups.empty()) {
    for (int i = 0; i < device_assignment.replica_count(); ++i) {
      all_replicas.add_replica_ids(i);
    }
    replica_groups = absl::MakeSpan(&all_replicas, 1);
  }

  std::vector<ReplicaGroup> scatter_gather_groups;
  std::vector<ReplicaGroup> new_all_reduce_groups;

  // Try to find a valid decomposition for each replica group.
  for (const ReplicaGroup& replica_group : replica_groups) {
    TF_ASSIGN_OR_RETURN(
        std::optional<DecomposedReplicaGroups> decomposed_groups,
        TryDecomposeReplicaGroup(replica_group, device_assignment,
                                 num_devices_per_host));

    if (!decomposed_groups) return {std::nullopt};

    int scatter_group_size =
        decomposed_groups->scatter_gather_groups[0].replica_ids_size();

    if (scatter_gather_groups.empty()) {
      // Check that every operand's element count is exactly divisible by the
      // scatter group size.
      for (const HloInstruction* operand : all_reduce.operands()) {
        TF_RET_CHECK(operand->shape().IsArray());
        int64_t num_elements = ShapeUtil::ElementsIn(operand->shape());
        if (num_elements % scatter_group_size != 0) {
          return {std::nullopt};
        }
      }

      scatter_gather_groups.reserve(
          replica_groups.size() *
          decomposed_groups->scatter_gather_groups.size());
      new_all_reduce_groups.reserve(
          replica_groups.size() *
          decomposed_groups->new_all_reduce_groups.size());
    } else if (scatter_group_size !=
               scatter_gather_groups[0].replica_ids_size()) {
      // Reduce-scatter would have different output shapes on different
      // devices.
      return {std::nullopt};
    }

    absl::c_move(decomposed_groups->scatter_gather_groups,
                 std::back_inserter(scatter_gather_groups));
    absl::c_move(decomposed_groups->new_all_reduce_groups,
                 std::back_inserter(new_all_reduce_groups));
  }

  return {DecomposedReplicaGroups{std::move(scatter_gather_groups),
                                  std::move(new_all_reduce_groups)}};
}

// Attempts to decompose all-reduces as described by the BlueConnect paper.
//
// If possible, the all-reduce will be transformed into:
// 1. reduce-scatter
// 2. all-reduce
// 3. all-gather
//
// If the all-reduce replica groups have more than one device within the same
// host, the reduce-scatter will be performed over all devices within each
// host. Otherwise, the reduce-scatter will be performed between pairs of
// devices on different hosts.
//
// When applied repeatedly, this transformation will reproduce the same pattern
// as described in the BlueConnect paper.
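//
// Illustrative example (assuming replica i is assigned to device i): with 8
// replicas spread over 2 hosts of 4 devices each and a single f32[1024]
// operand, the decomposition produces
//   1. reduce-scatter over {{0,1,2,3},{4,5,6,7}}  -> f32[256] per device
//   2. all-reduce over {{0,4},{1,5},{2,6},{3,7}}  -> f32[256]
//   3. all-gather over {{0,1,2,3},{4,5,6,7}}      -> f32[1024]
// which is then bitcast back to the original operand shape.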
StatusOr<bool> TryDecomposeAllReduce(HloAllReduceInstruction* all_reduce,
                                     size_t num_devices_per_host) {
  TF_RET_CHECK(all_reduce);
  TF_RET_CHECK(!all_reduce->has_sharding());

  HloComputation& computation = *all_reduce->parent();  // never null
  PrimitiveType element_type = all_reduce->operand(0)->shape().element_type();

  TF_ASSIGN_OR_RETURN(
      std::optional<DecomposedReplicaGroups> decomposed_groups,
      TryDecomposeReplicaGroups(*all_reduce, num_devices_per_host));

  if (!decomposed_groups) return false;

  // Bitcast operands to 1D to guarantee that the first dimension is divisible
  // by the scatter group size (we checked above that the element counts are
  // divisible).
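  // For example (illustrative), an f32[4,256] operand becomes f32[1024]; with
  // a scatter group size of 4, the reduce-scatter below then yields f32[256].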
  std::vector<HloInstruction*> flat_operands;
  flat_operands.reserve(all_reduce->operand_count());
  std::vector<Shape> flat_shapes;
  flat_shapes.reserve(all_reduce->operand_count());
  std::vector<Shape> scattered_shapes;
  scattered_shapes.reserve(all_reduce->operand_count());

  int scatter_group_size =
      decomposed_groups->scatter_gather_groups[0].replica_ids_size();

  for (HloInstruction* operand : all_reduce->operands()) {
    TF_RET_CHECK(operand->shape().IsArray());
    int64_t num_elements = ShapeUtil::ElementsIn(operand->shape());
    Shape flat_shape = ShapeUtil::MakeShape(element_type, {num_elements});
    flat_operands.push_back(computation.AddInstruction(
        HloInstruction::CreateBitcast(flat_shape, operand)));
    flat_shapes.push_back(std::move(flat_shape));
    scattered_shapes.push_back(ShapeUtil::MakeShape(
        element_type, {num_elements / scatter_group_size}));
  }

  Shape reduce_scatter_shape = ShapeUtil::MakeMaybeTupleShape(scattered_shapes);

  HloInstruction* reduce_scatter =
      computation.AddInstruction(HloInstruction::CreateReduceScatter(
          reduce_scatter_shape, flat_operands, all_reduce->to_apply(),
          decomposed_groups->scatter_gather_groups, /*constrain_layout=*/false,
          all_reduce->channel_id(), all_reduce->use_global_device_ids(),
          /*scatter_dimension=*/0));

  HloInstruction* new_all_reduce =
      computation.AddInstruction(HloInstruction::CreateAllReduce(
          reduce_scatter_shape, GetOutputs(*reduce_scatter),
          all_reduce->to_apply(), decomposed_groups->new_all_reduce_groups,
          /*constrain_layout=*/false, all_reduce->channel_id(),
          all_reduce->use_global_device_ids()));

  HloInstruction* all_gather =
      computation.AddInstruction(HloInstruction::CreateAllGather(
          ShapeUtil::MakeMaybeTupleShape(flat_shapes),
          GetOutputs(*new_all_reduce),
          /*all_gather_dimension=*/0, decomposed_groups->scatter_gather_groups,
          /*constrain_layout=*/false, all_reduce->channel_id(),
          all_reduce->use_global_device_ids()));

  // Bitcast back to the original shapes and replace the all-reduce with the
  // decomposed implementation.
  std::vector<HloInstruction*> outputs = GetOutputs(*all_gather);
  for (int64_t i = 0; i < outputs.size(); ++i) {
    outputs[i] = computation.AddInstruction(HloInstruction::CreateBitcast(
        all_reduce->operand(i)->shape(), outputs[i]));
  }

  TF_RETURN_IF_ERROR(
      computation.ReplaceInstruction(all_reduce, MaybeMakeTuple(outputs)));

  // Try to apply the decomposition recursively to the new inter-host
  // all-reduce.
  TF_RETURN_IF_ERROR(
      TryDecomposeAllReduce(Cast<HloAllReduceInstruction>(new_all_reduce),
                            num_devices_per_host)
          .status());
  return true;
}

}  // namespace

StatusOr<bool> AllReduceBlueConnect::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  VLOG(1) << "Running AllReduceBlueConnect";

  if (hlo_query::ContainsLayoutConstrainedAllReduce(*module)) {
    VLOG(1)
        << "Skip AllReduceBlueConnect because the module contains all-reduce "
           "with constrained layouts";
    return false;
  }
  if (!module->config().has_static_device_assignment()) {
    VLOG(1)
        << "Skip AllReduceBlueConnect because the module doesn't have static "
           "device assignment";
    return false;
  }

  std::vector<HloAllReduceInstruction*> all_reduces;
  for (HloComputation* computation :
       module->MakeNonfusionComputations(execution_threads)) {
    for (HloInstruction* instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kAllReduce) {
        all_reduces.push_back(Cast<HloAllReduceInstruction>(instruction));
      }
    }
  }

  bool changed = false;
  for (HloAllReduceInstruction* all_reduce : all_reduces) {
    TF_ASSIGN_OR_RETURN(
        bool all_reduce_changed,
        TryDecomposeAllReduce(all_reduce, num_devices_per_host_));
    changed |= all_reduce_changed;
  }

  return changed;
}

}  // namespace xla