/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering TensorFlow dialect's collective
// ops (TF/XLA) to the HLO dialect.

#include <numeric>
#include <string>
#include <utility>

#include "absl/strings/string_view.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Dialect.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/utils.h"
#include "tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_passes_detail.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/utils/convert_op_folder.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/utils/hlo_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace mlir {
namespace mhlo {

namespace {

constexpr absl::string_view kGroupSizeAttrName =
    "tf2xla.collective_info.group_size";
constexpr absl::string_view kGroupKeyAttrName =
    "tf2xla.collective_info.group_key";

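// A module pass that lowers the TF dialect's collective ops (XlaAllReduce,
// CollectiveReduceV2) to MHLO collective ops.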
class LegalizeTFCollective
    : public LegalizeTFCollectiveBase<LegalizeTFCollective> {
 public:
  void runOnOperation() override;
};

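// Sets `attr_name` on the module to `attr_value`. Succeeds if the attribute
// was absent or already holds the same value; emits an error on `op` if the
// module holds a conflicting value.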
LogicalResult SetOnceModuleAttribute(StringRef attr_name,
                                     IntegerAttr attr_value, Operation* op,
                                     ModuleOp& module) {
  const auto ex_attr_value = module->getAttrOfType<IntegerAttr>(attr_name);
  if (ex_attr_value == nullptr) {
    module->setAttr(attr_name, attr_value);
    return success();
  }
  if (ex_attr_value == attr_value) {
    return success();
  }
  return op->emitOpError() << "module already contains an attribute "
                           << attr_name << "=" << ex_attr_value.getInt()
                           << ", overwriting it with a new value "
                           << attr_value.getInt() << " is not allowed.";
}

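// Records group_size and group_key on the module as the
// tf2xla.collective_info.* attributes, failing on conflicting values.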
LogicalResult SetCollectiveInfo(IntegerAttr group_size, IntegerAttr group_key,
                                Operation* op) {
  ModuleOp module = op->getParentOfType<ModuleOp>();
  // The StringRef cast is necessary before cxx14.
  if (failed(SetOnceModuleAttribute(
          StringRef(kGroupSizeAttrName.data(), kGroupSizeAttrName.size()),
          group_size, op, module))) {
    return failure();
  }
  if (failed(SetOnceModuleAttribute(
          StringRef(kGroupKeyAttrName.data(), kGroupKeyAttrName.size()),
          group_key, op, module))) {
    return failure();
  }
  return success();
}

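// Derives collective info from replica_groups: the group size is the total
// number of device ids in replica_groups, and the group key is the special
// value 0 described below.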
LogicalResult SetCollectiveInfo(OpBuilder& builder,
                                DenseIntElementsAttr replica_groups,
                                Operation* op) {
  // Use special group_key 0 to represent "all available devices". This
  // shall resolve to a DeviceAssignment that includes all devices intended for
  // replica_groups.
  IntegerAttr group_size = builder.getI32IntegerAttr(replica_groups.size());
  IntegerAttr group_key = builder.getI32IntegerAttr(0);
  return SetCollectiveInfo(group_size, group_key, op);
}

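// Requires group_assignment_value to be a compile-time constant and converts
// it into a rank-2 i64 DenseIntElementsAttr usable as replica_groups.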
LogicalResult ConvertReplicaGroups(OpBuilder& builder,
                                   Value group_assignment_value,
                                   DenseIntElementsAttr& replica_groups,
                                   Operation* op) {
  DenseIntElementsAttr group_assignment;
  if (!matchPattern(group_assignment_value, m_Constant(&group_assignment))) {
    return op->emitOpError() << "expects constant group_assignment";
  }
  replica_groups =
      hlo::convertElementsAttr(group_assignment, builder.getIntegerType(64))
          .cast<DenseIntElementsAttr>();
  if (replica_groups.getType().getRank() != 2) {
    return op->emitOpError() << "group_assignment should have rank 2, got "
                             << replica_groups.getType().getRank();
  }
  return success();
}

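// Returns a null attribute in "CrossReplica" mode (no channel handle needed)
// and a DEVICE_TO_DEVICE channel handle with the given id otherwise.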
ChannelHandleAttr ConvertChannel(OpBuilder& builder, int64_t channel_id,
                                 StringRef mode) {
  if (mode == "CrossReplica") {
    return ChannelHandleAttr();
  }
  return ChannelHandleAttr::get(builder.getContext(),
                                /*handle=*/channel_id,
                                /*type=*/xla::ChannelHandle::DEVICE_TO_DEVICE);
}

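// Replaces `op` with an mhlo::AllReduceOp. The reduction body is selected by
// merge_op (Add/Mul/Min/Max). final_op "Id" leaves the result unchanged;
// final_op "Div" additionally divides the result by the replica group size,
// which together with merge_op "Add" implements a mean.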
LogicalResult ConvertAllReduce(OpBuilder& builder, int64_t channel_id,
                               TensorType result_type,
                               DenseIntElementsAttr replica_groups,
                               StringRef mode, Value input, StringRef merge_op,
                               StringRef final_op, Operation* op) {
  builder.setInsertionPoint(op);
  ChannelHandleAttr channel_handle = ConvertChannel(builder, channel_id, mode);
  Location loc = op->getLoc();
  Type element_type = getElementTypeOrSelf(input.getType());
  auto all_reduce = builder.create<AllReduceOp>(
      loc, result_type, input, replica_groups, channel_handle, nullptr);
  if (merge_op == "Add") {
    BuildReduceBody<AddOp>(element_type, &all_reduce.computation(), &builder);
  } else if (merge_op == "Mul") {
    BuildReduceBody<MulOp>(element_type, &all_reduce.computation(), &builder);
  } else if (merge_op == "Min") {
    BuildReduceBody<MinOp>(element_type, &all_reduce.computation(), &builder);
  } else if (merge_op == "Max") {
    BuildReduceBody<MaxOp>(element_type, &all_reduce.computation(), &builder);
  } else {
    return op->emitOpError() << "invalid merge_op " << merge_op
                             << ", want one of [Add, Mul, Min, Max]";
  }

  Operation* result = all_reduce;
  // For "Div" final op, divide the merge result by group size.
  if (final_op == "Div") {
    int64_t replica_group_size = replica_groups.getType().getDimSize(1);
    if (replica_group_size == 0) {
      return op->emitOpError()
             << "Div final_op requires a non-empty replica_groups argument.";
    }
    auto divisor =
        GetScalarConstOfType(element_type, loc, replica_group_size, &builder);
    auto broadcast_dims = GetI64ElementsAttr({}, &builder);
    result = builder.create<chlo::BroadcastDivOp>(
        loc, all_reduce.getResult(), divisor.getResult(), broadcast_dims);
  } else if (final_op != "Id") {
    return op->emitOpError()
           << "invalid final_op " << final_op << ", want one of [Id, Div]";
  }
  op->replaceAllUsesWith(result);

  op->erase();
  return success();
}

template <typename T>
class CollectiveRewritePattern : public OpRewritePattern<T> {
 public:
  // Does not take any ownership. Caller must ensure channel_id is valid during
  // the life-cycle of this object.
  CollectiveRewritePattern(MLIRContext* context, int64_t* channel_id)
      : OpRewritePattern<T>(context), channel_id_(*channel_id) {}

 protected:
  int64_t& channel_id_;  // A unique channel_id shared by all rewrite patterns
                         // in this pass. Not thread-safe.
};

// Converts XlaAllReduce. Not thread-safe.
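// reduce_op determines both halves of the lowering: "Mean" becomes merge_op
// "Add" with final_op "Div"; Add/Mul/Min/Max map to themselves with final_op
// "Id".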
class ConvertXlaAllReduce
    : public CollectiveRewritePattern<TF::XlaAllReduceOp> {
 public:
  using CollectiveRewritePattern::CollectiveRewritePattern;

  LogicalResult matchAndRewrite(TF::XlaAllReduceOp all_reduce,
                                PatternRewriter& rewriter) const override {
    DenseIntElementsAttr replica_groups;
    if (failed(ConvertReplicaGroups(rewriter, all_reduce.group_assignment(),
                                    replica_groups, all_reduce))) {
      return failure();
    }

    // TODO(b/226201111): Stop emitting CollectiveInfo when it is no longer
    // needed.
    if (failed(SetCollectiveInfo(rewriter, replica_groups, all_reduce))) {
      return failure();
    }

    StringRef reduce_op = all_reduce.reduce_op();

    StringRef merge_op, final_op;
    if (reduce_op == "Add") {
      merge_op = "Add";
      final_op = "Id";
    } else if (reduce_op == "Mul") {
      merge_op = "Mul";
      final_op = "Id";
    } else if (reduce_op == "Min") {
      merge_op = "Min";
      final_op = "Id";
    } else if (reduce_op == "Max") {
      merge_op = "Max";
      final_op = "Id";
    } else if (reduce_op == "Mean") {
      merge_op = "Add";
      final_op = "Div";
    } else {
      return all_reduce->emitOpError()
             << "invalid reduce_op " << reduce_op
             << ", want one of [Add, Mul, Min, Max, Mean]";
    }

    int64_t channel_id = channel_id_++;
    return ConvertAllReduce(rewriter, channel_id, all_reduce.getType(),
                            replica_groups, all_reduce.mode(),
                            all_reduce.input(), merge_op, final_op, all_reduce);
  }
};

// Converts CollectiveReduceV2, with or without a preceding
// CollectiveAssignGroupV2. Not thread-safe.
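// Schematically (illustrative IR only; operand and result details elided), a
// pattern such as
//   %size, %key = "tf.CollectiveAssignGroupV2"(%group_assignment, ...)
//   %sum = "tf.CollectiveReduceV2"(%input, %size, %key, ...)
//            {merge_op = "Add", final_op = "Id"}
// lowers to a single mhlo.all_reduce whose replica_groups come from the
// constant %group_assignment.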
class ConvertCollectiveReduceV2
    : public CollectiveRewritePattern<TF::CollectiveReduceV2Op> {
 public:
  using CollectiveRewritePattern::CollectiveRewritePattern;

  LogicalResult matchAndRewrite(TF::CollectiveReduceV2Op all_reduce,
                                PatternRewriter& rewriter) const override {
    TF::CollectiveAssignGroupV2Op assign_group =
        all_reduce.group_size().getDefiningOp<TF::CollectiveAssignGroupV2Op>();

    if (assign_group) {
      // Found a group assignment. Use replica_groups to represent group
      // assignment.

      if (assign_group != all_reduce.group_key()
                              .getDefiningOp<TF::CollectiveAssignGroupV2Op>()) {
        return all_reduce->emitOpError()
               << "group_size and group_key are not from the "
                  "same CollectiveAssignGroupV2Op";
      }

      DenseIntElementsAttr replica_groups;
      if (failed(ConvertReplicaGroups(rewriter, assign_group.group_assignment(),
                                      replica_groups, all_reduce))) {
        return failure();
      }

      // TODO(b/226201111): Stop emitting CollectiveInfo when it is no longer
      // needed.
      if (failed(SetCollectiveInfo(rewriter, replica_groups, all_reduce))) {
        return failure();
      }

      int64_t channel_id = channel_id_++;
      // FIXME(b/226139061): The mode should be set to CrossReplicaAndPartition
      // in order to use XLA:GPU with more than one worker. It is set to
      // CrossReplica here to preserve the behavior for the primary user of
      // this optimized path, because CrossReplicaAndPartition conflicts with
      // the channel_id allocation in the communication lowering when both
      // sets of ops are used.
      return ConvertAllReduce(rewriter, channel_id, all_reduce.getType(),
                              replica_groups, /* mode=*/"CrossReplica",
                              all_reduce.input(), all_reduce.merge_op(),
                              all_reduce.final_op(), all_reduce);
    }

    // No group assignment, use separate channels per group_key.
    DenseIntElementsAttr group_size_attr;
    if (!matchPattern(all_reduce.group_size(), m_Constant(&group_size_attr))) {
      return all_reduce.emitOpError()
             << "group_size must be a compile time constant";
    }
    if (!group_size_attr.isSplat() || group_size_attr.size() != 1) {
      return all_reduce.emitOpError() << "group_size must be a scalar";
    }
    const auto group_size = group_size_attr.getSplatValue<IntegerAttr>();

    // Create a full group assignment; an empty group assignment causes an
    // error when final_op is "Div".
    llvm::SmallVector<int64_t> indices(group_size.getInt());
    std::iota(indices.begin(), indices.end(), 0);

    auto replica_groups = mlir::DenseIntElementsAttr::get(
        mlir::RankedTensorType::get({1, group_size.getInt()},
                                    rewriter.getI64Type()),
        indices);

    {
      // TODO(b/226201111): Stop emitting CollectiveInfo when it is no longer
      // needed.
      DenseIntElementsAttr group_key_attr;
      if (!matchPattern(all_reduce.group_key(), m_Constant(&group_key_attr))) {
        return all_reduce.emitOpError()
               << "group_key must be a compile time constant";
      }
      if (failed(SetCollectiveInfo(
              /* group_size=*/group_size,
              /* group_key=*/group_key_attr.getSplatValue<IntegerAttr>(),
              all_reduce))) {
        return failure();
      }
    }

    // CrossReplicaAndPartition:
    // Even though TF2XLA will set up the device assignment to include
    // devices in this group as replicas before launching this module,
    // "CrossReplica" mode (no channel) produces a deadlock when
    // not using XLA SPMD expansion.
    int64_t channel_id = channel_id_++;
    return ConvertAllReduce(
        rewriter, channel_id, all_reduce.getType(), replica_groups,
        /* mode= */ "CrossReplicaAndPartition", all_reduce.input(),
        all_reduce.merge_op(), all_reduce.final_op(), all_reduce);
  }
};

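// Registers the rewrite patterns above and applies them greedily over the
// module. All patterns share a single channel_id counter so that each lowered
// collective receives a distinct channel.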
void LegalizeTFCollective::runOnOperation() {
  // FIXME(b/226139061): Figure out a way to share the channel_id with
  // send/recv Ops.
  int64_t channel_id = 1;
  auto module = getOperation();
  MLIRContext* context = module->getContext();

  RewritePatternSet patterns(context);
  patterns.insert<ConvertCollectiveReduceV2>(context, &channel_id);
  patterns.insert<ConvertXlaAllReduce>(context, &channel_id);

  if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) {
    signalPassFailure();
  }
}
}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTFCollectivePass() {
  return std::make_unique<LegalizeTFCollective>();
}

}  // namespace mhlo
}  // namespace mlir