/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering the TensorFlow dialect's collective
// ops (TF/XLA) to the HLO dialect.

#include <memory>
#include <numeric>
#include <string>
#include <utility>

#include "absl/strings/string_view.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Dialect.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/utils.h"
#include "tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_passes_detail.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/utils/convert_op_folder.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/utils/hlo_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace mlir {
namespace mhlo {

namespace {

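// The pass records the collective's group size and group key on the enclosing
// module under the attribute names below. E.g. (an illustrative sketch of the
// resulting IR, not a verbatim dump) a module may end up carrying:
//   module attributes {tf2xla.collective_info.group_key = 0 : i32,
//                      tf2xla.collective_info.group_size = 2 : i32}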
constexpr absl::string_view kGroupSizeAttrName =
    "tf2xla.collective_info.group_size";
constexpr absl::string_view kGroupKeyAttrName =
    "tf2xla.collective_info.group_key";

class LegalizeTFCollective
    : public LegalizeTFCollectiveBase<LegalizeTFCollective> {
 public:
  void runOnOperation() override;
};

LogicalResult SetOnceModuleAttribute(StringRef attr_name,
                                     IntegerAttr attr_value, Operation* op,
                                     ModuleOp& module) {
  const auto ex_attr_value = module->getAttrOfType<IntegerAttr>(attr_name);
  if (ex_attr_value == nullptr) {
    module->setAttr(attr_name, attr_value);
    return success();
  }
  if (ex_attr_value == attr_value) {
    return success();
  }
  return op->emitOpError() << "module already contains an attribute "
                           << attr_name << "=" << ex_attr_value.getInt()
                           << "; overwriting it with a new value "
                           << attr_value.getInt() << " is not allowed.";
}

LogicalResult SetCollectiveInfo(IntegerAttr group_size, IntegerAttr group_key,
                                Operation* op) {
  ModuleOp module = op->getParentOfType<ModuleOp>();
  // The StringRef cast is necessary before C++14.
  if (failed(SetOnceModuleAttribute(
          StringRef(kGroupSizeAttrName.data(), kGroupSizeAttrName.size()),
          group_size, op, module))) {
    return failure();
  }
  if (failed(SetOnceModuleAttribute(
          StringRef(kGroupKeyAttrName.data(), kGroupKeyAttrName.size()),
          group_key, op, module))) {
    return failure();
  }
  return success();
}

LogicalResult SetCollectiveInfo(OpBuilder& builder,
                                DenseIntElementsAttr replica_groups,
                                Operation* op) {
  // Use the special group_key 0 to represent "all available devices". This is
  // expected to resolve to a DeviceAssignment that includes all devices
  // intended for replica_groups.
  IntegerAttr group_size = builder.getI32IntegerAttr(replica_groups.size());
  IntegerAttr group_key = builder.getI32IntegerAttr(0);
  return SetCollectiveInfo(group_size, group_key, op);
}

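// Matches a constant group assignment feeding `group_assignment_value` and
// converts it into a rank-2 i64 `replica_groups` attribute. E.g. a group
// assignment of [[0, 2], [1, 3]] describes two groups: devices {0, 2} and
// devices {1, 3} each reduce among themselves.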
LogicalResult ConvertReplicaGroups(OpBuilder& builder,
                                   Value group_assignment_value,
                                   DenseIntElementsAttr& replica_groups,
                                   Operation* op) {
  DenseIntElementsAttr group_assignment;
  if (!matchPattern(group_assignment_value, m_Constant(&group_assignment))) {
    return op->emitOpError() << "expects constant group_assignment";
  }
  replica_groups =
      hlo::convertElementsAttr(group_assignment, builder.getIntegerType(64))
          .cast<DenseIntElementsAttr>();
  if (replica_groups.getType().getRank() != 2) {
    return op->emitOpError() << "group_assignment should have rank 2, got "
                             << replica_groups.getType().getRank();
  }
  return success();
}

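// Returns a null ChannelHandleAttr for "CrossReplica" mode, so the resulting
// mhlo.all_reduce carries no channel handle; any other mode gets a
// DEVICE_TO_DEVICE channel with the given channel_id.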
ChannelHandleAttr ConvertChannel(OpBuilder& builder, int64_t channel_id,
                                 StringRef mode) {
  if (mode == "CrossReplica") {
    return ChannelHandleAttr();
  }
  return ChannelHandleAttr::get(builder.getContext(),
                                /*handle=*/channel_id,
                                /*type=*/xla::ChannelHandle::DEVICE_TO_DEVICE);
}

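// Replaces `op` with an mhlo.all_reduce whose reduction body is built from
// `merge_op`; a "Div" `final_op` then divides by the replica group size.
// E.g. (an illustrative sketch, types and attributes elided) merge_op = "Add"
// with final_op = "Div" on groups of size 2 lowers roughly to:
//   %sum = "mhlo.all_reduce"(%input) ({ ...mhlo.add... }) {...}
//   %mean = chlo.broadcast_divide %sum, %c2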
LogicalResult ConvertAllReduce(OpBuilder& builder, int64_t channel_id,
                               TensorType result_type,
                               DenseIntElementsAttr replica_groups,
                               StringRef mode, Value input, StringRef merge_op,
                               StringRef final_op, Operation* op) {
  builder.setInsertionPoint(op);
  ChannelHandleAttr channel_handle = ConvertChannel(builder, channel_id, mode);
  Location loc = op->getLoc();
  Type element_type = getElementTypeOrSelf(input.getType());
  auto all_reduce = builder.create<AllReduceOp>(
      loc, result_type, input, replica_groups, channel_handle, nullptr);
  if (merge_op == "Add") {
    BuildReduceBody<AddOp>(element_type, &all_reduce.computation(), &builder);
  } else if (merge_op == "Mul") {
    BuildReduceBody<MulOp>(element_type, &all_reduce.computation(), &builder);
  } else if (merge_op == "Min") {
    BuildReduceBody<MinOp>(element_type, &all_reduce.computation(), &builder);
  } else if (merge_op == "Max") {
    BuildReduceBody<MaxOp>(element_type, &all_reduce.computation(), &builder);
  } else {
    return op->emitOpError() << "invalid merge_op " << merge_op
                             << ", want one of [Add, Mul, Min, Max]";
  }

  Operation* result = all_reduce;
  // For the "Div" final op, divide the merge result by the group size.
  if (final_op == "Div") {
    int64_t replica_group_size = replica_groups.getType().getDimSize(1);
    if (replica_group_size == 0) {
      return op->emitOpError()
             << "Div final_op requires a non-empty replica_groups argument.";
    }
    auto divisor =
        GetScalarConstOfType(element_type, loc, replica_group_size, &builder);
    auto broadcast_dims = GetI64ElementsAttr({}, &builder);
    result = builder.create<chlo::BroadcastDivOp>(
        loc, all_reduce.getResult(), divisor.getResult(), broadcast_dims);
  } else if (final_op != "Id") {
    return op->emitOpError()
           << "invalid final_op " << final_op << ", want one of [Id, Div]";
  }
  op->replaceAllUsesWith(result);

  op->erase();
  return success();
}

template <typename T>
class CollectiveRewritePattern : public OpRewritePattern<T> {
 public:
  // Does not take any ownership. The caller must ensure channel_id stays
  // valid for the lifetime of this object.
  CollectiveRewritePattern(MLIRContext* context, int64_t* channel_id)
      : OpRewritePattern<T>(context), channel_id_(*channel_id) {}

 protected:
  int64_t& channel_id_;  // A unique channel_id shared by all rewrite patterns
                         // in this pass. Not thread-safe.
};

// Converts XlaAllReduce. Not thread-safe.
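// E.g. reduce_op = "Mean" is expressed as merge_op = "Add" plus
// final_op = "Div", i.e. an add all-reduce followed by division by the
// replica group size.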
class ConvertXlaAllReduce
    : public CollectiveRewritePattern<TF::XlaAllReduceOp> {
 public:
  using CollectiveRewritePattern::CollectiveRewritePattern;

  LogicalResult matchAndRewrite(TF::XlaAllReduceOp all_reduce,
                                PatternRewriter& rewriter) const override {
    DenseIntElementsAttr replica_groups;
    if (failed(ConvertReplicaGroups(rewriter, all_reduce.group_assignment(),
                                    replica_groups, all_reduce))) {
      return failure();
    }

    // TODO(b/226201111): Stop emitting CollectiveInfo when it is no longer
    // needed.
    if (failed(SetCollectiveInfo(rewriter, replica_groups, all_reduce))) {
      return failure();
    }

    StringRef reduce_op = all_reduce.reduce_op();

    StringRef merge_op, final_op;
    if (reduce_op == "Add") {
      merge_op = "Add";
      final_op = "Id";
    } else if (reduce_op == "Mul") {
      merge_op = "Mul";
      final_op = "Id";
    } else if (reduce_op == "Min") {
      merge_op = "Min";
      final_op = "Id";
    } else if (reduce_op == "Max") {
      merge_op = "Max";
      final_op = "Id";
    } else if (reduce_op == "Mean") {
      merge_op = "Add";
      final_op = "Div";
    } else {
      return all_reduce->emitOpError()
             << "invalid reduce_op " << reduce_op
             << ", want one of [Add, Mul, Min, Max, Mean]";
    }

    int64_t channel_id = channel_id_++;
    return ConvertAllReduce(rewriter, channel_id, all_reduce.getType(),
                            replica_groups, all_reduce.mode(),
                            all_reduce.input(), merge_op, final_op, all_reduce);
  }
};

// Converts CollectiveReduceV2, with or without a preceding
// CollectiveAssignGroupV2. Not thread-safe.
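// Two shapes of input are handled: if group_size/group_key are produced by a
// CollectiveAssignGroupV2Op, its constant group_assignment supplies the
// replica groups; otherwise a full group assignment [[0, ..., group_size-1]]
// is synthesized from a constant group_size.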
class ConvertCollectiveReduceV2
    : public CollectiveRewritePattern<TF::CollectiveReduceV2Op> {
 public:
  using CollectiveRewritePattern::CollectiveRewritePattern;

  LogicalResult matchAndRewrite(TF::CollectiveReduceV2Op all_reduce,
                                PatternRewriter& rewriter) const override {
    TF::CollectiveAssignGroupV2Op assign_group =
        all_reduce.group_size().getDefiningOp<TF::CollectiveAssignGroupV2Op>();

    if (assign_group) {
      // Found a group assignment. Use replica_groups to represent the group
      // assignment.

      if (assign_group != all_reduce.group_key()
                              .getDefiningOp<TF::CollectiveAssignGroupV2Op>()) {
        return all_reduce->emitOpError()
               << "group_size and group_key are not from the "
                  "same CollectiveAssignGroupV2Op";
      }

      DenseIntElementsAttr replica_groups;
      if (failed(ConvertReplicaGroups(rewriter,
                                      assign_group.group_assignment(),
                                      replica_groups, all_reduce))) {
        return failure();
      }

      // TODO(b/226201111): Stop emitting CollectiveInfo when it is no longer
      // needed.
      if (failed(SetCollectiveInfo(rewriter, replica_groups, all_reduce))) {
        return failure();
      }

      int64_t channel_id = channel_id_++;
      // FIXME(b/226139061): The mode should be set to CrossReplicaAndPartition
      // in order to use XLA:GPU with more than one worker. It is set to
      // CrossReplica here to preserve the behavior for the primary user of
      // this optimized path, because CrossReplicaAndPartition conflicts with
      // the channel_id allocation in the communication lowering when both
      // sets of ops are used.
      return ConvertAllReduce(rewriter, channel_id, all_reduce.getType(),
                              replica_groups, /*mode=*/"CrossReplica",
                              all_reduce.input(), all_reduce.merge_op(),
                              all_reduce.final_op(), all_reduce);
    }

    // No group assignment; use separate channels per group_key.
    DenseIntElementsAttr group_size_attr;
    if (!matchPattern(all_reduce.group_size(), m_Constant(&group_size_attr))) {
      return all_reduce.emitOpError()
             << "group_size must be a compile time constant";
    }
    if (!group_size_attr.isSplat() || group_size_attr.size() != 1) {
      return all_reduce.emitOpError() << "group_size must be a scalar";
    }
    const auto group_size = group_size_attr.getSplatValue<IntegerAttr>();

    // Create a full group assignment; an empty group assignment is an error
    // when final_op is "Div".
    llvm::SmallVector<int64_t> indices(group_size.getInt());
    std::iota(indices.begin(), indices.end(), 0);

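    // E.g. (illustrative) group_size == 4 gives indices [0, 1, 2, 3] and thus
    // replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> below.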
    auto replica_groups = mlir::DenseIntElementsAttr::get(
        mlir::RankedTensorType::get({1, group_size.getInt()},
                                    rewriter.getI64Type()),
        indices);

    {
      // TODO(b/226201111): Stop emitting CollectiveInfo when it is no longer
      // needed.
      DenseIntElementsAttr group_key_attr;
      if (!matchPattern(all_reduce.group_key(), m_Constant(&group_key_attr))) {
        return all_reduce.emitOpError()
               << "group_key must be a compile time constant";
      }
      if (failed(SetCollectiveInfo(
              /*group_size=*/group_size,
              /*group_key=*/group_key_attr.getSplatValue<IntegerAttr>(),
              all_reduce))) {
        return failure();
      }
    }

    // CrossReplicaAndPartition:
    // Even though TF2XLA will set up the device assignment to include the
    // devices in this group as replicas before launching this module,
    // "CrossReplica" mode (no channel) produces a deadlock when not using
    // XLA SPMD expansion.
    int64_t channel_id = channel_id_++;
    return ConvertAllReduce(
        rewriter, channel_id, all_reduce.getType(), replica_groups,
        /*mode=*/"CrossReplicaAndPartition", all_reduce.input(),
        all_reduce.merge_op(), all_reduce.final_op(), all_reduce);
  }
};

void LegalizeTFCollective::runOnOperation() {
  // FIXME(b/226139061): Figure out a way to share the channel_id with
  // send/recv ops.
  int64_t channel_id = 1;
  auto module = getOperation();
  MLIRContext* context = module->getContext();

  RewritePatternSet patterns(context);
  patterns.insert<ConvertCollectiveReduceV2>(context, &channel_id);
  patterns.insert<ConvertXlaAllReduce>(context, &channel_id);

  if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) {
    signalPassFailure();
  }
}
}  // namespace

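// A minimal usage sketch (assuming a standard mlir::PassManager setup; names
// here are illustrative, not from this file):
//   mlir::PassManager pm(module->getContext());
//   pm.addPass(mhlo::CreateLegalizeTFCollectivePass());
//   if (failed(pm.run(module))) { /* handle the error */ }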
std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTFCollectivePass() {
  return std::make_unique<LegalizeTFCollective>();
}

}  // namespace mhlo
}  // namespace mlir