xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/gpu/all_reduce_blueconnect.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/all_reduce_blueconnect.h"
17 
18 #include <algorithm>
19 #include <iterator>
20 #include <optional>
21 #include <vector>
22 
23 #include "absl/algorithm/container.h"
24 #include "absl/container/btree_map.h"
25 #include "absl/types/span.h"
26 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
27 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
29 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
30 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
31 #include "tensorflow/compiler/xla/service/hlo_query.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/status_macros.h"
34 
35 namespace xla {
36 namespace {
37 
GetOutputs(HloInstruction & instruction)38 std::vector<HloInstruction*> GetOutputs(HloInstruction& instruction) {
39   if (!instruction.shape().IsTuple()) {
40     return {&instruction};
41   }
42 
43   std::vector<HloInstruction*> outputs;
44   outputs.reserve(instruction.shape().tuple_shapes_size());
45 
46   HloComputation& computation = *instruction.parent();  // never null
47   for (int i = 0; i < instruction.shape().tuple_shapes_size(); ++i) {
48     outputs.push_back(computation.AddInstruction(
49         HloInstruction::CreateGetTupleElement(&instruction, i)));
50   }
51   return outputs;
52 }
53 
54 struct DecomposedReplicaGroups {
55   std::vector<ReplicaGroup> scatter_gather_groups;
56   std::vector<ReplicaGroup> new_all_reduce_groups;
57 };
58 
TryDecomposeReplicaGroup(const ReplicaGroup & replica_group,const DeviceAssignment & device_assignment,size_t num_devices_per_host)59 StatusOr<std::optional<DecomposedReplicaGroups>> TryDecomposeReplicaGroup(
60     const ReplicaGroup& replica_group,
61     const DeviceAssignment& device_assignment, size_t num_devices_per_host) {
62   int group_size = replica_group.replica_ids_size();
63   TF_RET_CHECK(group_size > 0);
64 
65   absl::btree_map<int, std::vector<int64_t>> replica_ids_by_host;
66   for (int64_t replica_id : replica_group.replica_ids()) {
67     int device_id = device_assignment(replica_id, /*computation_id=*/0);
68     TF_RET_CHECK(device_id >= 0);
69     // We assume that devices are ordered by host.
70     int host_id = device_id / num_devices_per_host;
71     replica_ids_by_host[host_id].push_back(replica_id);
72   }
73 
74   size_t num_local_devices = replica_ids_by_host.begin()->second.size();
75   bool same_num_devices_on_each_host =
76       absl::c_all_of(replica_ids_by_host, [&](const auto& entry) {
77         return entry.second.size() == num_local_devices;
78       });
79 
80   if (!same_num_devices_on_each_host) {
81     return {std::nullopt};
82   }
83 
84   std::vector<int64_t> sorted_replica_group;
85   sorted_replica_group.reserve(group_size);
86   for (const auto& entry : replica_ids_by_host) {
87     absl::c_copy(entry.second, std::back_inserter(sorted_replica_group));
88   }
89 
90   size_t scatter_group_size = std::max(num_local_devices, size_t(2));
91   size_t num_scatter_groups = group_size / scatter_group_size;
92 
93   if ((group_size % scatter_group_size != 0) || (num_scatter_groups < 2)) {
94     return {std::nullopt};
95   }
96 
97   std::vector<ReplicaGroup> scatter_gather_groups(num_scatter_groups);
98   std::vector<ReplicaGroup> new_all_reduce_groups(scatter_group_size);
99 
100   for (size_t i = 0; i < group_size; ++i) {
101     int64_t replica_id = sorted_replica_group[i];
102     scatter_gather_groups[i / scatter_group_size].add_replica_ids(replica_id);
103     new_all_reduce_groups[i % scatter_group_size].add_replica_ids(replica_id);
104   }
105 
106   return {DecomposedReplicaGroups{std::move(scatter_gather_groups),
107                                   std::move(new_all_reduce_groups)}};
108 }
109 
TryDecomposeReplicaGroups(const HloAllReduceInstruction & all_reduce,size_t num_devices_per_host)110 StatusOr<std::optional<DecomposedReplicaGroups>> TryDecomposeReplicaGroups(
111     const HloAllReduceInstruction& all_reduce, size_t num_devices_per_host) {
112   const DeviceAssignment& device_assignment =
113       all_reduce.parent()->parent()->config().static_device_assignment();
114 
115   absl::Span<const ReplicaGroup> replica_groups = all_reduce.replica_groups();
116 
117   ReplicaGroup all_replicas;  // only populated if replica groups not present.
118   if (replica_groups.empty()) {
119     for (int i = 0; i < device_assignment.replica_count(); ++i) {
120       all_replicas.add_replica_ids(i);
121     }
122     replica_groups = absl::MakeSpan(&all_replicas, 1);
123   }
124 
125   std::vector<ReplicaGroup> scatter_gather_groups;
126   std::vector<ReplicaGroup> new_all_reduce_groups;
127 
128   // Try to find a valid decomposition for each replica group.
129   for (const ReplicaGroup& replica_group : replica_groups) {
130     TF_ASSIGN_OR_RETURN(
131         std::optional<DecomposedReplicaGroups> decomposed_groups,
132         TryDecomposeReplicaGroup(replica_group, device_assignment,
133                                  num_devices_per_host));
134 
135     if (!decomposed_groups) return {std::nullopt};
136 
137     int scatter_group_size =
138         decomposed_groups->scatter_gather_groups[0].replica_ids_size();
139 
140     if (scatter_gather_groups.empty()) {
141       // Check that every operand is exactly divisible by scatter group sizes.
142       for (const HloInstruction* operand : all_reduce.operands()) {
143         TF_RET_CHECK(operand->shape().IsArray());
144         int64_t num_elements = ShapeUtil::ElementsIn(operand->shape());
145         if (num_elements % scatter_group_size != 0) {
146           return {std::nullopt};
147         }
148       }
149 
150       scatter_gather_groups.reserve(
151           replica_groups.size() *
152           decomposed_groups->scatter_gather_groups.size());
153       new_all_reduce_groups.reserve(
154           replica_groups.size() *
155           decomposed_groups->new_all_reduce_groups.size());
156     } else if (scatter_group_size !=
157                scatter_gather_groups[0].replica_ids_size()) {
158       // Reduce-scatter would have different output shapes on different devices.
159       return {std::nullopt};
160     }
161 
162     absl::c_move(decomposed_groups->scatter_gather_groups,
163                  std::back_inserter(scatter_gather_groups));
164     absl::c_move(decomposed_groups->new_all_reduce_groups,
165                  std::back_inserter(new_all_reduce_groups));
166   }
167 
168   return {DecomposedReplicaGroups{std::move(scatter_gather_groups),
169                                   std::move(new_all_reduce_groups)}};
170 }
171 
172 // Attempts to decompose all-reduces as described by the BlueConnect paper.
173 //
174 // If possible, the all-reduce will be transformed into:
175 // 1. reduce-scatter
176 // 2. all-reduce
177 // 3. all-gather
178 //
179 // If the all-reduce replica groups have more than one device within the same
180 // host, the reduce-scatter will be performed over all devices with each host.
181 // Otherwise, the reduce-scatter will be performed between pairs of devices on
182 // different hosts.
183 //
184 // When applied repeatedly, this transformation will reproduce the same pattern
185 // as described in the BlueConnect paper.
TryDecomposeAllReduce(HloAllReduceInstruction * all_reduce,size_t num_devices_per_host)186 StatusOr<bool> TryDecomposeAllReduce(HloAllReduceInstruction* all_reduce,
187                                      size_t num_devices_per_host) {
188   TF_RET_CHECK(all_reduce);
189   TF_RET_CHECK(!all_reduce->has_sharding());
190 
191   HloComputation& computation = *all_reduce->parent();  // never null
192   PrimitiveType element_type = all_reduce->operand(0)->shape().element_type();
193 
194   TF_ASSIGN_OR_RETURN(
195       std::optional<DecomposedReplicaGroups> decomposed_groups,
196       TryDecomposeReplicaGroups(*all_reduce, num_devices_per_host));
197 
198   if (!decomposed_groups) return false;
199 
200   // Bitcast operands to 1D to guarantee that first dimension is divisible by
201   // scatter group size (we checked num elements was divisible above).
202   std::vector<HloInstruction*> flat_operands;
203   flat_operands.reserve(all_reduce->operand_count());
204   std::vector<Shape> flat_shapes;
205   flat_shapes.reserve(all_reduce->operand_count());
206   std::vector<Shape> scattered_shapes;
207   scattered_shapes.reserve(all_reduce->operand_count());
208 
209   int scatter_group_size =
210       decomposed_groups->scatter_gather_groups[0].replica_ids_size();
211 
212   for (HloInstruction* operand : all_reduce->operands()) {
213     TF_RET_CHECK(operand->shape().IsArray());
214     int64_t num_elements = ShapeUtil::ElementsIn(operand->shape());
215     Shape flat_shape = ShapeUtil::MakeShape(element_type, {num_elements});
216     flat_operands.push_back(computation.AddInstruction(
217         HloInstruction::CreateBitcast(flat_shape, operand)));
218     flat_shapes.push_back(std::move(flat_shape));
219     scattered_shapes.push_back(ShapeUtil::MakeShape(
220         element_type, {num_elements / scatter_group_size}));
221   }
222 
223   Shape reduce_scatter_shape = ShapeUtil::MakeMaybeTupleShape(scattered_shapes);
224 
225   HloInstruction* reduce_scatter =
226       computation.AddInstruction(HloInstruction::CreateReduceScatter(
227           reduce_scatter_shape, flat_operands, all_reduce->to_apply(),
228           decomposed_groups->scatter_gather_groups, /*constrain_layout=*/false,
229           all_reduce->channel_id(), all_reduce->use_global_device_ids(),
230           /*scatter_dimension=*/0));
231 
232   HloInstruction* new_all_reduce =
233       computation.AddInstruction(HloInstruction::CreateAllReduce(
234           reduce_scatter_shape, GetOutputs(*reduce_scatter),
235           all_reduce->to_apply(), decomposed_groups->new_all_reduce_groups,
236           /*constrain_layout=*/false, all_reduce->channel_id(),
237           all_reduce->use_global_device_ids()));
238 
239   HloInstruction* all_gather =
240       computation.AddInstruction(HloInstruction::CreateAllGather(
241           ShapeUtil::MakeMaybeTupleShape(flat_shapes),
242           GetOutputs(*new_all_reduce),
243           /*all_gather_dimension=*/0, decomposed_groups->scatter_gather_groups,
244           /*constrain_layout=*/false, all_reduce->channel_id(),
245           all_reduce->use_global_device_ids()));
246 
247   // Bitcast back to the original shapes and replace all-reduce with decomposed
248   // implementation.
249   std::vector<HloInstruction*> outputs = GetOutputs(*all_gather);
250   for (int64_t i = 0; i < outputs.size(); ++i) {
251     outputs[i] = computation.AddInstruction(HloInstruction::CreateBitcast(
252         all_reduce->operand(i)->shape(), outputs[i]));
253   }
254 
255   TF_RETURN_IF_ERROR(
256       computation.ReplaceInstruction(all_reduce, MaybeMakeTuple(outputs)));
257 
258   // Try to apply decomposition recursively.
259   TF_RETURN_IF_ERROR(
260       TryDecomposeAllReduce(Cast<HloAllReduceInstruction>(new_all_reduce),
261                             num_devices_per_host)
262           .status());
263   return true;
264 }
265 
266 }  // namespace
267 
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)268 StatusOr<bool> AllReduceBlueConnect::Run(
269     HloModule* module,
270     const absl::flat_hash_set<absl::string_view>& execution_threads) {
271   VLOG(1) << "Running AllReduceBlueConnect";
272 
273   if (hlo_query::ContainsLayoutConstrainedAllReduce(*module)) {
274     VLOG(1)
275         << "Skip AllReduceBlueConnect because the module contains all-reduce "
276            "with constrained layouts";
277     return false;
278   }
279   if (!module->config().has_static_device_assignment()) {
280     VLOG(1)
281         << "Skip AllReduceBlueConnect because the module doesn't have static "
282            "device assignment";
283     return false;
284   }
285 
286   std::vector<HloAllReduceInstruction*> all_reduces;
287   for (HloComputation* computation :
288        module->MakeNonfusionComputations(execution_threads)) {
289     for (HloInstruction* instruction : computation->instructions()) {
290       if (instruction->opcode() == HloOpcode::kAllReduce) {
291         all_reduces.push_back(Cast<HloAllReduceInstruction>(instruction));
292       }
293     }
294   }
295 
296   bool changed = false;
297   for (HloAllReduceInstruction* all_reduce : all_reduces) {
298     TF_ASSIGN_OR_RETURN(
299         bool all_reduce_changed,
300         TryDecomposeAllReduce(all_reduce, num_devices_per_host_));
301     changed |= all_reduce_changed;
302   }
303 
304   return changed;
305 }
306 
307 }  // namespace xla
308