/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/UseDefLists.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"

namespace tensorflow {
namespace dtensor {
namespace {

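// Maximum number of times the rewrites below are re-applied before the pass
// stops iterating, even if the IR is still changing.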
constexpr int kMaxIteration = 10;

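// Walks backward through a chain of tf.Identity ops and returns the first
// value that is not produced by an Identity op.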
mlir::Value GetIdentitySkippedInputs(mlir::Value val) {
  mlir::Value input = val;
  while (auto identity = llvm::dyn_cast_or_null<mlir::TF::IdentityOp>(
             input.getDefiningOp())) {
    input = identity.input();
  }
  return input;
}

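// Returns true if `val` is (possibly behind a chain of Identity ops) a splat
// floating point constant with value zero.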
bool IsZeroConstant(mlir::Value val) {
  auto const_input = llvm::dyn_cast_or_null<mlir::TF::ConstOp>(
      GetIdentitySkippedInputs(val).getDefiningOp());
  if (!const_input) return false;
  mlir::DenseFPElementsAttr attr =
      const_input.value().dyn_cast<mlir::DenseFPElementsAttr>();
  // This uses the fact that constant attributes become splats, so we only need
  // to check one value.
  if (!attr || !attr.isSplat()) return false;
  return attr.getSplatValue<mlir::FloatAttr>().getValue().isZero();
}

// Extracts the inputs/ops required for the optimization and checks whether the
// graph meets the criteria for the reduction + sum optimization. The criteria
// are:
// a) All DTensorAllReduce operations must be sum operations.
// b) The group assignments of the DTensorAllReduceOps must be the same.
// c) All operands of the Add op must be DTensorAllReduce operations or zero
//    constants.
mlir::LogicalResult CheckReduceAndSumOptimizationCriteria(
    mlir::Operation* add_op,
    llvm::SmallVectorImpl<mlir::Value>* reduction_inputs,
    llvm::SmallVectorImpl<mlir::TF::DTensorAllReduceOp>* reduction_ops,
    bool* can_be_reordered) {
  for (mlir::Value operand : add_op->getOperands()) {
    if (IsZeroConstant(operand)) {
      reduction_inputs->emplace_back(operand);
      continue;
    }

    auto reduction_op = llvm::dyn_cast_or_null<mlir::TF::DTensorAllReduceOp>(
        operand.getDefiningOp());
    if (!reduction_op) {
      *can_be_reordered = false;
      return mlir::success();
    }

    reduction_ops->emplace_back(reduction_op);
  }

  llvm::SmallDenseSet<mlir::Attribute> reduction_group_assignments;
  for (mlir::TF::DTensorAllReduceOp reduction : *reduction_ops) {
    if (reduction.reduce_op().str() != kReduceOpAdd) {
      *can_be_reordered = false;
      return mlir::success();
    }

    mlir::DenseIntElementsAttr group_assignment;
    if (!matchPattern(reduction.group_assignment(),
                      m_Constant(&group_assignment))) {
      *can_be_reordered = false;
      return mlir::success();
    }

    reduction_group_assignments.insert(group_assignment);
    reduction_inputs->emplace_back(reduction.input());
  }

  *can_be_reordered = (reduction_group_assignments.size() == 1);
  return mlir::success();
}

// Applies an optimization that reorders AllReduce + Add operations.
// For example:
// %3 = DTensorAllReduce(%0)
// %4 = DTensorAllReduce(%1)
// %5 = Add(%3, %4)
//
// is transformed to:
// %2 = Add(%0, %1)
// %3 = DTensorAllReduce(%2)
//
// thereby reducing the number of reductions/cross-device communications.
mlir::LogicalResult OptimizeAllReduceAndSum(mlir::Operation* op,
                                            bool* changed) {
  bool can_be_reordered;
  llvm::SmallVector<mlir::TF::DTensorAllReduceOp, 4> reduction_ops;
  llvm::SmallVector<mlir::Value, 4> reduction_op_inputs;
  if (mlir::failed(CheckReduceAndSumOptimizationCriteria(
          op, &reduction_op_inputs, &reduction_ops, &can_be_reordered)))
    return mlir::failure();

  if (!can_be_reordered || reduction_ops.empty()) return mlir::success();

  // Forward the inputs from the DTensorAllReduce ops to the add op. Calling
  // getOperand(i).getDefiningOp() is safe since
  // CheckReduceAndSumOptimizationCriteria checks that each input is fed by a
  // DTensorAllReduce or a zero constant.
  for (int i = 0; i < op->getNumOperands(); ++i) {
    if (mlir::isa<mlir::TF::DTensorAllReduceOp>(
            op->getOperand(i).getDefiningOp()))
      op->setOperand(i, op->getOperand(i).getDefiningOp()->getOperand(0));
  }

  mlir::TF::DTensorAllReduceOp first_reduction_op = reduction_ops.front();

  // Invoke the reduction operation on the locally added tensor once.
  // From the check in `CheckReduceAndSumOptimizationCriteria()` above, we know
  // that all reduction operations that are fused use the same group assignment
  // value.
  // 1) Get the mlir::Value that represents the group assignment used for the
  //    reduction.
  mlir::Value group_assignment = first_reduction_op.group_assignment();

  // Create a single reduction operation that reduces the result of the locally
  // added tensor.
  mlir::OpBuilder builder(op);
  builder.setInsertionPointAfterValue(op->getResult(0));
  mlir::TF::DTensorAllReduceOp all_reduce =
      builder.create<mlir::TF::DTensorAllReduceOp>(
          op->getLoc(), op->getResult(0).getType(), op->getResult(0),
          group_assignment, builder.getStringAttr(std::string(kReduceOpAdd)),
          builder.getStringAttr(first_reduction_op.device_type()));

  const auto layout_or_status = ExtractSingleLayoutFromOp(first_reduction_op);
  if (!layout_or_status.ok())
    return first_reduction_op->emitOpError(llvm::formatv(
        "Malformed layout specification for DTensorAllReduce op found: {0}",
        layout_or_status.status().error_message()));

  if (!layout_or_status->has_value())
    return first_reduction_op->emitOpError(
        "DTensorAllReduce op must have layout specification.");

  // Set a target layout equivalent to the original DTensorAllReduce op in the
  // graph. This is used during later optimization passes.
  SetSingleLayoutOnOp(all_reduce, layout_or_status->value());

  // Replace usages of original tf.Add op with newly created output of
  // `all_reduce`.
  op->getResult(0).replaceAllUsesExcept(
      all_reduce.output(),
      llvm::SmallPtrSet<mlir::Operation*, 1>{all_reduce.getOperation()});

  // TODO(hongjunchoi, bfontain): Consider adding an optimization for the case
  // where a tree of Add operations with DTensorAllReduce ops as inputs exists.
  // If the DTensorAllReduce ops that originally fed `op` are no longer used
  // after the rewrite above, remove them as well.
  for (mlir::Operation* original_reduction_op : reduction_ops) {
    if (original_reduction_op->use_empty()) original_reduction_op->erase();
  }

  *changed = true;
  return mlir::success();
}

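// Follows a chain of single-use Cast/Reshape/Identity ops forward from `val`
// and returns the result of the last op in the chain.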
mlir::Value SkipIdentityLikeOpsOutputs(mlir::Value val) {
  while (val.hasOneUse() &&
         llvm::isa<mlir::TF::CastOp, mlir::TF::ReshapeOp, mlir::TF::IdentityOp>(
             *val.user_begin())) {
    val = val.user_begin()->getResult(0);
  }
  return val;
}

// TODO(hongjunchoi): Consider using tracing algorithm to virtually transform
// the IR and only apply optimizations when total number of DTensorAllReduce in
// the graph is reduced.
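// Returns true if the output of `op`, after skipping identity-like ops
// (Cast/Reshape/Identity), is consumed only by an Add/AddV2/AddN op, in which
// case pushing the DTensorAllReduce past `op` may allow it to be combined with
// other reductions.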
bool MayRemoveAllReduce(mlir::Operation* op) {
  mlir::Value op_output = op->getResult(0);
  mlir::Value value_after_identity_like_ops =
      SkipIdentityLikeOpsOutputs(op_output);
  if (value_after_identity_like_ops.hasOneUse() &&
      llvm::isa<mlir::TF::AddNOp, mlir::TF::AddV2Op, mlir::TF::AddOp>(
          *value_after_identity_like_ops.user_begin()))
    return true;

  return false;
}

// Moves a DTensorAllReduce op after identity-like operations if the operation
// is connected to an Add operation, which may enable further optimization.
// For example:
//
// %0 = "tf.Const"() {value = dense<0> : tensor<2x64xi32>}
// %2 = "tf.Const"() {value = dense<0.0> : tensor<8192x916xbf16>}
// %4 = "tf.DTensorAllReduce"(%2, %0) {reduce_op = "Add"}
// %5 = "tf.Cast"(%4) {Truncate = false, device = ""}
// %6 = "tf.Identity"(%5) {Truncate = false, device = ""}
// %7 = "tf.Const"() {value = dense<[916,8192]> : tensor<2xi32>}
// %8 = "tf.Reshape"(%6, %7)
//
// becomes:
//
// %0 = "tf.Const"()
// %2 = "tf.Const"()
// %3 = "tf.Cast"(%2)
// %4 = "tf.Identity"(%3)
// %7 = "tf.Const"()
// %8 = "tf.Reshape"(%4, %7)
// %9 = "tf.DTensorAllReduce"(%8, %0) {reduce_op = "Add"}
void OptimizeIdentityLikeOps(mlir::Operation* op, bool* changed) {
  auto dtensor_all_reduce =
      llvm::dyn_cast_or_null<mlir::TF::DTensorAllReduceOp>(
          op->getOperand(0).getDefiningOp());
  if (!dtensor_all_reduce) return;
  // TODO(hongjunchoi, bfontain): Consider allowing pushing DTensorAllReduce op
  // with multiple usages if it can lead to performance optimization.
  if (!dtensor_all_reduce->hasOneUse()) return;
  if (!MayRemoveAllReduce(op)) return;

  dtensor_all_reduce->moveAfter(op);
  mlir::Value input = dtensor_all_reduce.input();
  op->setOperand(0, input);

  mlir::Value op_output = op->getResult(0);
  dtensor_all_reduce.setOperand(0, op_output);
  dtensor_all_reduce.input().setType(op_output.getType());
  dtensor_all_reduce.output().setType(op_output.getType());

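  // Replace all uses of the identity-like op's result with the output of the
  // moved all-reduce, except the all-reduce itself, which now consumes that
  // result.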
  llvm::SmallPtrSet<mlir::Operation*, 4> exceptions{dtensor_all_reduce};
  op_output.replaceAllUsesExcept(dtensor_all_reduce.output(), exceptions);
  *changed = true;
}

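// Checks whether the `index`-th output of the while loop body meets the
// criteria for the lazy all-reduce optimization. On success, sets `add_op` to
// the matched Add op, `all_reduce_op` to the DTensorAllReduce feeding it, and
// `add_input` to the Add operand that comes from the loop block argument.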
bool CheckWhileLoopOptimizationCriteria(
    const int index, mlir::TF::WhileRegionOp while_op, mlir::Value while_output,
    mlir::Operation** add_op, mlir::TF::DTensorAllReduceOp* all_reduce_op,
    mlir::OpOperand** add_input) {
  // Loop variant input that is being optimized should not be used in loop
  // condition.
  mlir::Value loop_condition_input = while_op.cond().getArgument(index);
  if (!loop_condition_input.use_empty()) return false;

  // The while loop output should be connected to an Add op. If the operand of
  // the while loop body terminator comes from Identity ops, skip through them.
  mlir::Value output_value = GetIdentitySkippedInputs(while_output);
  mlir::Operation* output_defining_op = output_value.getDefiningOp();
  if (!output_defining_op) return false;

  // TODO(hongjunchoi): Handle AddN op as well.
  if (!output_defining_op ||
      !llvm::isa<mlir::TF::AddV2Op, mlir::TF::AddOp>(output_defining_op)) {
    return false;
  }

  // One operand of the add operation should be a DTensorAllReduce op and the
  // other should be a block argument of the while loop.
  mlir::OpOperand& first_operand = output_defining_op->getOpOperand(0);
  mlir::OpOperand& second_operand = output_defining_op->getOpOperand(1);
  mlir::BlockArgument block_arg;
  mlir::TF::DTensorAllReduceOp all_reduce =
      llvm::dyn_cast_or_null<mlir::TF::DTensorAllReduceOp>(
          first_operand.get().getDefiningOp());
  if (all_reduce) {
    block_arg = second_operand.get().dyn_cast<mlir::BlockArgument>();
    *add_input = &second_operand;
  } else {
    all_reduce = llvm::dyn_cast_or_null<mlir::TF::DTensorAllReduceOp>(
        second_operand.get().getDefiningOp());
    block_arg = first_operand.get().dyn_cast<mlir::BlockArgument>();
    *add_input = &first_operand;
  }
  if (!block_arg || !all_reduce) return false;

  // DTensorAllReduce should calculate sum across devices and group assignment
  // must be statically known.
  mlir::Operation* group_assignment =
      all_reduce.group_assignment().getDefiningOp();
  if (!group_assignment || !llvm::isa<mlir::TF::ConstOp>(group_assignment))
    return false;

  if (all_reduce.reduce_op().str() != kReduceOpAdd) return false;

  // The while loop input corresponding to the block argument connected to the
  // Add op should be a constant operation with zero value.
  const int block_arg_index = block_arg.getArgNumber();
  mlir::OpOperand& while_input = while_op->getOpOperand(block_arg_index);
  if (!IsZeroConstant(while_input.get())) return false;

  // TODO(hongjunchoi): Handle the case when input is from DTensorAllReduce op.
  // If group assignment is the same, then the input DTensorAllReduce op can
  // also be optimized away.

  *add_op = output_defining_op;
  *all_reduce_op = all_reduce;
  return true;
}

// Extracts the DTensorAllReduce operation out of the while op if
// a) the while op contains a DTensorAllReduce op followed by an Add operation,
// b) the remaining operand of the Add operation is a loop-variant input of the
//    while operation with zero initial value.
//
// For example:
//
// %0 = "tf.Const"() {value = dense<0> : tensor<2x64xi32>}
// %2 = "tf.Const"() {value = dense<0.0> : tensor<8192x916xbf16>}
// WhileRegionOp(%2) {
//   %0 = "tf.A"(%2)
//   "tf.Yield"(%0)
// }, {
// ^bb0(%barg0: tensor<8192x916xbf16>):
//   ...
//   %0 = "tf.Const"()
//   %1 = "tf.Const"()
//   %2 = "tf.DTensorAllReduce"(%1, %0) {reduce_op = "Add"}
//   %3 = "tf.Add"(%2, %barg0)
//   "tf.Yield"(%3)
// })
//
// becomes:
//
// %0 = "tf.Const"() {value = dense<0> : tensor<2x64xi32>}
// %2 = "tf.Const"() {value = dense<0.0> : tensor<8192x916xbf16>}
// %4 = WhileRegionOp(%2) {
//   %0 = "tf.A"(%2)
//   "tf.Yield"(%0)
// }, {
// ^bb0(%barg0: tensor<8192x916xbf16>):
//   ...
//   %0 = "tf.Const"()
//   %1 = "tf.Const"()
//   %3 = "tf.Add"(%1, %barg0)
//   "tf.Yield"(%3)
// })
// "tf.DTensorAllReduce"(%4, %0) {reduce_op = "Add"}
mlir::LogicalResult ExtractAllReduceFromWhileOp(
    const int output_index, mlir::TF::DTensorAllReduceOp all_reduce,
    mlir::TF::WhileRegionOp while_op, mlir::OpOperand& add_input,
    mlir::Operation* add_op, bool* changed) {
  // Set add input to input of all reduce.
  mlir::Value all_reduce_input = all_reduce.input();
  const int replacement_add_input_index =
      add_input.getOperandNumber() == 0 ? 1 : 0;
  add_op->setOperand(replacement_add_input_index, all_reduce_input);

  mlir::OpBuilder builder(while_op);
  builder.setInsertionPointAfter(while_op);

  mlir::Value while_output = while_op.getResult(output_index);
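  // Clone the constant group assignment at the insertion point so that the new
  // all-reduce created after the while op has its own group assignment value
  // to reference.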
  mlir::Operation* group_assignment_const =
      all_reduce.group_assignment().getDefiningOp();
  mlir::Operation* cloned_group_assignment =
      builder.clone(*group_assignment_const);

  // Create a single reduction operation that reduces the result of the locally
  // added tensor.
  auto new_all_reduce = builder.create<mlir::TF::DTensorAllReduceOp>(
      all_reduce.getLoc(), while_output.getType(), while_output,
      cloned_group_assignment->getResult(0),
      builder.getStringAttr(std::string(kReduceOpAdd)),
      builder.getStringAttr(all_reduce.device_type()));

  const auto layout_or_status = ExtractSingleLayoutFromOp(all_reduce);
  if (!layout_or_status.ok())
    return all_reduce->emitOpError(llvm::formatv(
        "Malformed layout specification for DTensorAllReduce op found: {0}",
        layout_or_status.status().error_message()));

  if (!layout_or_status->has_value())
    return all_reduce->emitOpError(
        "DTensorAllReduce op must have layout specification.");

  // Set a target layout equivalent to the original DTensorAllReduce op in the
  // graph. This is used during later optimization passes.
  SetSingleLayoutOnOp(new_all_reduce, layout_or_status->value());

  llvm::SmallPtrSet<mlir::Operation*, 4> exceptions;
  exceptions.insert(new_all_reduce.getOperation());
  while_output.replaceAllUsesExcept(new_all_reduce.output(), exceptions);

  if (all_reduce.use_empty()) all_reduce.erase();

  *changed = true;
  return mlir::success();
}

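// Scans the operands of the while body terminator and, wherever the criteria
// above hold, hoists the DTensorAllReduce out of the loop body so the
// reduction runs once after the loop instead of once per iteration.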
mlir::LogicalResult OptimizeWhileLoopLazyAllReduce(
    mlir::TF::WhileRegionOp while_op, bool* changed) {
  mlir::Operation* while_body_terminator =
      while_op.body().front().getTerminator();
  for (const auto& data :
       llvm::enumerate(while_body_terminator->getOpOperands())) {
    const int index = data.index();
    mlir::OpOperand& operand = data.value();

    mlir::Operation* add_op = nullptr;
    mlir::TF::DTensorAllReduceOp all_reduce;
    mlir::OpOperand* add_input = nullptr;
    if (!CheckWhileLoopOptimizationCriteria(index, while_op, operand.get(),
                                            &add_op, &all_reduce, &add_input))
      continue;

    // Perform the while loop lazy all-reduce optimization.
    if (mlir::failed(ExtractAllReduceFromWhileOp(index, all_reduce, while_op,
                                                 *add_input, add_op, changed)))
      return mlir::failure();
  }

  return mlir::success();
}

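// Runs one round of the rewrites over the previously collected candidate ops:
// AllReduce + Add folding on `add_ops`, pushing all-reduces past identity-like
// ops, and hoisting all-reduces out of while loops. Sets `*changed` if any
// rewrite was applied.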
mlir::LogicalResult ApplyOptimization(
    mlir::func::FuncOp function,
    const llvm::SmallVectorImpl<mlir::Operation*>& identity_like_ops,
    const llvm::SmallVectorImpl<mlir::TF::WhileRegionOp>& while_ops,
    const llvm::SmallVectorImpl<mlir::Operation*>& add_ops, bool* changed) {
  // Collect and fold the reduction operations within the function.
  for (mlir::Operation* add_op : add_ops)
    if (mlir::failed(OptimizeAllReduceAndSum(add_op, changed)))
      return mlir::failure();

  for (mlir::Operation* op : identity_like_ops)
    OptimizeIdentityLikeOps(op, changed);

  for (mlir::TF::WhileRegionOp op : while_ops)
    if (mlir::failed(OptimizeWhileLoopLazyAllReduce(op, changed)))
      return mlir::failure();

  return mlir::success();
}

// Finds all ops that could potentially lead to all-reduce optimizations.
// Those are:
// a) Identity-like ops (e.g. Identity/Reshape/Cast).
// b) WhileRegion ops.
// c) Add operations.
void CollectOptimizationCandidates(
    mlir::func::FuncOp func,
    llvm::SmallVectorImpl<mlir::Operation*>* identity_like_ops,
    llvm::SmallVectorImpl<mlir::Operation*>* add_ops,
    llvm::SmallVectorImpl<mlir::TF::WhileRegionOp>* while_ops) {
  func.walk([&](mlir::Operation* op) {
    if (llvm::isa<mlir::TF::IdentityOp, mlir::TF::CastOp, mlir::TF::ReshapeOp>(
            op))
      identity_like_ops->emplace_back(op);

    if (auto while_op = llvm::dyn_cast<mlir::TF::WhileRegionOp>(op))
      while_ops->emplace_back(while_op);

    if (llvm::isa<mlir::TF::AddOp, mlir::TF::AddV2Op, mlir::TF::AddNOp>(op))
      add_ops->emplace_back(op);
  });
}

// MLIR pass that reorders DTensorAllReduce and Add operations so that values
// are summed locally first and reduced across devices once, removing redundant
// all-reduce operations.
struct DTensorAllReduceSumOptimization
    : public DTensorAllReduceSumOptimizationBase<
          DTensorAllReduceSumOptimization> {
  void runOnOperation() override {
    mlir::func::FuncOp function = getOperation();
    bool changed = true;
    int iteration = 0;

    llvm::SmallVector<mlir::Operation*, 4> identity_like_ops;
    llvm::SmallVector<mlir::Operation*, 4> add_ops;
    llvm::SmallVector<mlir::TF::WhileRegionOp, 4> while_ops;
    CollectOptimizationCandidates(function, &identity_like_ops, &add_ops,
                                  &while_ops);
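    // Re-apply the rewrites until the IR stops changing or the iteration limit
    // is reached.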
    bool is_optimized = false;
    while (changed && iteration < kMaxIteration) {
      changed = false;
      if (mlir::failed(ApplyOptimization(function, identity_like_ops,
                                         while_ops, add_ops, &changed)))
        return signalPassFailure();
      iteration++;
      if (changed) is_optimized = true;
    }
  }
};

}  // namespace

std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateDTensorAllReduceSumOptimization() {
  return std::make_unique<DTensorAllReduceSumOptimization>();
}

}  // namespace dtensor
}  // namespace tensorflow