/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/UseDefLists.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"

namespace tensorflow {
namespace dtensor {
namespace {

constexpr int kMaxIteration = 10;

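// Walks up chained tf.Identity ops and returns the first non-Identity input.
// For example (an illustrative IR sketch; the values are hypothetical):
//   %1 = "tf.Identity"(%0)
//   %2 = "tf.Identity"(%1)
// GetIdentitySkippedInputs(%2) would return %0.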
mlir::Value GetIdentitySkippedInputs(mlir::Value val) {
  mlir::Value input = val;
  while (auto identity = llvm::dyn_cast_or_null<mlir::TF::IdentityOp>(
             input.getDefiningOp())) {
    input = identity.input();
  }
  return input;
}

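// Returns true iff `val` is (possibly through tf.Identity ops) a splat float
// constant whose value is zero. For example (illustrative; the shape is
// hypothetical), this returns true for:
//   %0 = "tf.Const"() {value = dense<0.0> : tensor<8192x916xbf16>}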
bool IsZeroConstant(mlir::Value val) {
  auto const_input = llvm::dyn_cast_or_null<mlir::TF::ConstOp>(
      GetIdentitySkippedInputs(val).getDefiningOp());
  if (!const_input) return false;
  mlir::DenseFPElementsAttr attr =
      const_input.value().dyn_cast<mlir::DenseFPElementsAttr>();
  // This relies on the fact that uniform constant attributes are stored as
  // splats, so we only need to check a single value.
  if (!attr || !attr.isSplat()) return false;
  return attr.getSplatValue<mlir::FloatAttr>().getValue().isZero();
}

// Extracts the inputs/ops required for the optimization and checks whether the
// graph meets the criteria for the reduction + sum optimization. The criteria
// are:
// a) All DTensorAllReduce operations must be sum operations.
// b) All DTensorAllReduce operations must use the same group assignment.
// c) Every operand of the Add op must be a DTensorAllReduce op or a zero
//    constant.
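// For example (an illustrative sketch; values and group assignments are
// hypothetical), the following would be rejected because the two
// DTensorAllReduce ops use different group assignments:
//   %2 = "tf.DTensorAllReduce"(%0, %group_assignment_a) {reduce_op = "Add"}
//   %3 = "tf.DTensorAllReduce"(%1, %group_assignment_b) {reduce_op = "Add"}
//   %4 = "tf.AddV2"(%2, %3)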
mlir::LogicalResult CheckReduceAndSumOptimizationCriteria(
    mlir::Operation* add_op,
    llvm::SmallVectorImpl<mlir::Value>* reduction_inputs,
    llvm::SmallVectorImpl<mlir::TF::DTensorAllReduceOp>* reduction_ops,
    bool* can_be_reordered) {
  for (mlir::Value operand : add_op->getOperands()) {
    if (IsZeroConstant(operand)) {
      reduction_inputs->emplace_back(operand);
      continue;
    }

    auto reduction_op = llvm::dyn_cast_or_null<mlir::TF::DTensorAllReduceOp>(
        operand.getDefiningOp());
    if (!reduction_op) {
      *can_be_reordered = false;
      return mlir::success();
    }

    reduction_ops->emplace_back(reduction_op);
  }

  llvm::SmallDenseSet<mlir::Attribute> reduction_group_assignments;
  for (mlir::TF::DTensorAllReduceOp reduction : *reduction_ops) {
    if (reduction.reduce_op().str() != kReduceOpAdd) {
      *can_be_reordered = false;
      return mlir::success();
    }

    mlir::DenseIntElementsAttr group_assignment;
    if (!matchPattern(reduction.group_assignment(),
                      m_Constant(&group_assignment))) {
      *can_be_reordered = false;
      return mlir::success();
    }

    reduction_group_assignments.insert(group_assignment);
    reduction_inputs->emplace_back(reduction.input());
  }

  *can_be_reordered = (reduction_group_assignments.size() == 1);
  return mlir::success();
}

// Applies an optimization that reorders AllReduce + Add operations.
// For example:
//   %3 = DTensorAllReduce(%0)
//   %4 = DTensorAllReduce(%1)
//   %5 = Add(%3, %4)
//
// Is transformed to:
//   %2 = Add(%0, %1)
//   %3 = DTensorAllReduce(%2)
//
// This reduces the number of reductions and thus the amount of cross-device
// communication.
mlir::LogicalResult OptimizeAllReduceAndSum(mlir::Operation* op,
                                            bool* changed) {
  bool can_be_reordered;
  llvm::SmallVector<mlir::TF::DTensorAllReduceOp, 4> reduction_ops;
  llvm::SmallVector<mlir::Value, 4> reduction_op_inputs;
  if (mlir::failed(CheckReduceAndSumOptimizationCriteria(
          op, &reduction_op_inputs, &reduction_ops, &can_be_reordered)))
    return mlir::failure();

  if (!can_be_reordered || reduction_ops.empty()) return mlir::success();

  // Forward the inputs from the DTensorAllReduce to the add op. Calling
  // getOperand(i).getDefiningOp() is safe here since
  // CheckReduceAndSumOptimizationCriteria checks that each input is fed by a
  // DTensorAllReduce or a zero constant.
  for (int i = 0; i < op->getNumOperands(); ++i) {
    if (mlir::isa<mlir::TF::DTensorAllReduceOp>(
            op->getOperand(i).getDefiningOp()))
      op->setOperand(i, op->getOperand(i).getDefiningOp()->getOperand(0));
  }

  mlir::TF::DTensorAllReduceOp first_reduction_op = reduction_ops.front();

  // Invoke the reduction operation on the locally added tensor once.
  // From the check in `CheckReduceAndSumOptimizationCriteria()` above, we know
  // that all fused reduction operations use the same group assignment value.
  // 1) Get the mlir::Value that represents the group assignment used for the
  //    reduction.
  mlir::Value group_assignment = first_reduction_op.group_assignment();

  // Create a single reduction operation that reduces the result of the locally
  // added tensor.
  mlir::OpBuilder builder(op);
  builder.setInsertionPointAfterValue(op->getResult(0));
  mlir::TF::DTensorAllReduceOp all_reduce =
      builder.create<mlir::TF::DTensorAllReduceOp>(
          op->getLoc(), op->getResult(0).getType(), op->getResult(0),
          group_assignment, builder.getStringAttr(std::string(kReduceOpAdd)),
          builder.getStringAttr(first_reduction_op.device_type()));

  const auto layout_or_status = ExtractSingleLayoutFromOp(first_reduction_op);
  if (!layout_or_status.ok())
    return first_reduction_op->emitOpError(llvm::formatv(
        "Malformed layout specification for DTensorAllReduce op found: {0}",
        layout_or_status.status().error_message()));

  if (!layout_or_status->has_value())
    return first_reduction_op->emitOpError(
        "DTensorAllReduce op must have layout specification.");

  // Set a target layout equivalent to that of the original DTensorAllReduce op
  // in the graph. This is used during later optimization passes.
  SetSingleLayoutOnOp(all_reduce, layout_or_status->value());

  // Replace usages of the original tf.Add op with the newly created output of
  // `all_reduce`.
  op->getResult(0).replaceAllUsesExcept(
      all_reduce.output(),
      llvm::SmallPtrSet<mlir::Operation*, 1>{all_reduce.getOperation()});

  // TODO(hongjunchoi, bfontain): Consider adding optimization for the case
  // when a `tree` of Add operations with DTensorAllReduce ops as inputs
  // exists.
  // If the original DTensorAllReduce ops feeding `op` are no longer used (that
  // is, `op` was their only user), remove them.
  for (mlir::Operation* original_reduction_op : reduction_ops) {
    if (original_reduction_op->use_empty()) original_reduction_op->erase();
  }

  *changed = true;
  return mlir::success();
}

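// Follows a single-use chain of identity-like consumers (Cast/Reshape/
// Identity) of `val` and returns the final value in the chain. For example
// (an illustrative IR sketch; types and values are hypothetical), given:
//   %1 = "tf.Cast"(%0)
//   %2 = "tf.Reshape"(%1, %shape)
// SkipIdentityLikeOpsOutputs(%0) would return %2.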
mlir::Value SkipIdentityLikeOpsOutputs(mlir::Value val) {
  while (val.hasOneUse() &&
         llvm::isa<mlir::TF::CastOp, mlir::TF::ReshapeOp, mlir::TF::IdentityOp>(
             *val.user_begin())) {
    val = val.user_begin()->getResult(0);
  }
  return val;
}

// TODO(hongjunchoi): Consider using a tracing algorithm to virtually transform
// the IR and only apply optimizations when the total number of
// DTensorAllReduce ops in the graph is reduced.
bool MayRemoveAllReduce(mlir::Operation* op) {
  mlir::Value op_output = op->getResult(0);
  mlir::Value value_after_identity_like_ops =
      SkipIdentityLikeOpsOutputs(op_output);
  return value_after_identity_like_ops.hasOneUse() &&
         llvm::isa<mlir::TF::AddNOp, mlir::TF::AddV2Op, mlir::TF::AddOp>(
             *value_after_identity_like_ops.user_begin());
}

// Moves DTensorAllReduce ops after identity-like operations when the operation
// is connected to an Add operation, which may enable further optimization.
// For example:
//
//  %0 = "tf.Const"() {value = dense<0> : tensor<2x64xi32>}
//  %2 = "tf.Const"() {value = dense<0.0> : tensor<8192x916xbf16>}
//  %4 = "tf.DTensorAllReduce"(%2, %0) {reduce_op = "Add"}
//  %5 = "tf.Cast"(%4) {Truncate = false, device = ""}
//  %6 = "tf.Identity"(%5) {Truncate = false, device = ""}
//  %7 = "tf.Const"() {value = dense<[916,8192]> : tensor<2xi32>}
//  %8 = "tf.Reshape"(%6, %7)
//
// Becomes:
//
//  %0 = "tf.Const"()
//  %2 = "tf.Const"()
//  %3 = "tf.Cast"(%2)
//  %4 = "tf.Identity"(%3)
//  %7 = "tf.Const"()
//  %8 = "tf.Reshape"(%4, %7)
//  %9 = "tf.DTensorAllReduce"(%8, %0) {reduce_op = "Add"}
void OptimizeIdentityLikeOps(mlir::Operation* op, bool* changed) {
  auto dtensor_all_reduce =
      llvm::dyn_cast_or_null<mlir::TF::DTensorAllReduceOp>(
          op->getOperand(0).getDefiningOp());
  if (!dtensor_all_reduce) return;
  // TODO(hongjunchoi, bfontain): Consider allowing pushing DTensorAllReduce op
  // with multiple usages if it can lead to performance optimization.
  if (!dtensor_all_reduce->hasOneUse()) return;
  if (!MayRemoveAllReduce(op)) return;

  dtensor_all_reduce->moveAfter(op);
  mlir::Value input = dtensor_all_reduce.input();
  op->setOperand(0, input);

  mlir::Value op_output = op->getResult(0);
  dtensor_all_reduce.setOperand(0, op_output);
  dtensor_all_reduce.input().setType(op_output.getType());
  dtensor_all_reduce.output().setType(op_output.getType());

  llvm::SmallPtrSet<mlir::Operation*, 4> exceptions{dtensor_all_reduce};
  op_output.replaceAllUsesExcept(dtensor_all_reduce.output(), exceptions);
  *changed = true;
}

bool CheckWhileLoopOptimizationCriteria(
    const int index, mlir::TF::WhileRegionOp while_op, mlir::Value while_output,
    mlir::Operation** add_op, mlir::TF::DTensorAllReduceOp* all_reduce_op,
    mlir::OpOperand** add_input) {
  // The loop-variant input being optimized must not be used in the loop
  // condition.
  mlir::Value loop_condition_input = while_op.cond().getArgument(index);
  if (!loop_condition_input.use_empty()) return false;

  // The while loop output should be connected to an Add op. If the operand of
  // the while loop body terminator comes from Identity ops, skip through them.
  mlir::Value output_value = GetIdentitySkippedInputs(while_output);
  mlir::Operation* output_defining_op = output_value.getDefiningOp();
  if (!output_defining_op) return false;

  // TODO(hongjunchoi): Handle AddN op as well.
  if (!llvm::isa<mlir::TF::AddV2Op, mlir::TF::AddOp>(output_defining_op)) {
    return false;
  }

  // One input operand of the add operation should be a DTensorAllReduce op and
  // the other should be a block argument of the while loop.
  mlir::OpOperand& first_operand = output_defining_op->getOpOperand(0);
  mlir::OpOperand& second_operand = output_defining_op->getOpOperand(1);
  mlir::BlockArgument block_arg;
  mlir::TF::DTensorAllReduceOp all_reduce =
      llvm::dyn_cast_or_null<mlir::TF::DTensorAllReduceOp>(
          first_operand.get().getDefiningOp());
  if (all_reduce) {
    block_arg = second_operand.get().dyn_cast<mlir::BlockArgument>();
    *add_input = &second_operand;
  } else {
    all_reduce = llvm::dyn_cast_or_null<mlir::TF::DTensorAllReduceOp>(
        second_operand.get().getDefiningOp());
    block_arg = first_operand.get().dyn_cast<mlir::BlockArgument>();
    *add_input = &first_operand;
  }
  if (!block_arg || !all_reduce) return false;

  // DTensorAllReduce should calculate sum across devices and group assignment
  // must be statically known.
  mlir::Operation* group_assignment =
      all_reduce.group_assignment().getDefiningOp();
  if (!group_assignment || !llvm::isa<mlir::TF::ConstOp>(group_assignment))
    return false;

  if (all_reduce.reduce_op().str() != kReduceOpAdd) return false;

  // The while loop input feeding the block argument connected to the Add op
  // should be a zero-valued constant.
  const int block_arg_index = block_arg.getArgNumber();
  mlir::OpOperand& while_input = while_op->getOpOperand(block_arg_index);
  if (!IsZeroConstant(while_input.get())) return false;

  // TODO(hongjunchoi): Handle the case when input is from DTensorAllReduce op.
  // If group assignment is the same, then the input DTensorAllReduce op can
  // also be optimized away.

  *add_op = output_defining_op;
  *all_reduce_op = all_reduce;
  return true;
}

// Extracts the DTensorAllReduce operation out of a while op if
// a) the while op contains a DTensorAllReduce op followed by an Add operation,
//    and
// b) the remaining operand of the Add operation is a loop-variant input of the
//    while operation with a zero initial value.
//
// For example:
//
//  %0 = "tf.Const"() {value = dense<0> : tensor<2x64xi32>}
//  %2 = "tf.Const"() {value = dense<0.0> : tensor<8192x916xbf16>}
//  WhileRegionOp(%2) {
//    %0 = "tf.A"(%2)
//    "tf.Yield"(%0)
//  }, {
//  ^bb0(%barg0: tensor<8192x916xbf16>):
//    ...
//    %0 = "tf.Const"()
//    %1 = "tf.Const"()
//    %2 = "tf.DTensorAllReduce"(%1, %0) {reduce_op = "Add"}
//    %3 = "tf.Add"(%2, %barg0)
//    "tf.Yield"(%3)
//  })
//
// Becomes:
//
//  %0 = "tf.Const"() {value = dense<0> : tensor<2x64xi32>}
//  %2 = "tf.Const"() {value = dense<0.0> : tensor<8192x916xbf16>}
//  %4 = WhileRegionOp(%2) {
//    %0 = "tf.A"(%2)
//    "tf.Yield"(%0)
//  }, {
//  ^bb0(%barg0: tensor<8192x916xbf16>):
//    ...
//    %0 = "tf.Const"()
//    %1 = "tf.Const"()
//    %3 = "tf.Add"(%1, %barg0)
//    "tf.Yield"(%3)
//  })
//  "tf.DTensorAllReduce"(%4, %0) {reduce_op = "Add"}
mlir::LogicalResult ExtractAllReduceFromWhileOp(
    const int output_index, mlir::TF::DTensorAllReduceOp all_reduce,
    mlir::TF::WhileRegionOp while_op, mlir::OpOperand& add_input,
    mlir::Operation* add_op, bool* changed) {
  // Set add input to input of all reduce.
  mlir::Value all_reduce_input = all_reduce.input();
  const int replacement_add_input_index =
      add_input.getOperandNumber() == 0 ? 1 : 0;
  add_op->setOperand(replacement_add_input_index, all_reduce_input);

  mlir::OpBuilder builder(while_op);
  builder.setInsertionPointAfter(while_op);

  mlir::Value while_output = while_op.getResult(output_index);
  mlir::Operation* group_assignment_const =
      all_reduce.group_assignment().getDefiningOp();
  mlir::Operation* cloned_group_assignment =
      builder.clone(*group_assignment_const);

  // Create a single reduction operation that reduces the result of the locally
  // added tensor.
  auto new_all_reduce = builder.create<mlir::TF::DTensorAllReduceOp>(
      all_reduce.getLoc(), while_output.getType(), while_output,
      cloned_group_assignment->getResult(0),
      builder.getStringAttr(std::string(kReduceOpAdd)),
      builder.getStringAttr(all_reduce.device_type()));

  const auto layout_or_status = ExtractSingleLayoutFromOp(all_reduce);
  if (!layout_or_status.ok())
    return all_reduce->emitOpError(llvm::formatv(
        "Malformed layout specification for DTensorAllReduce op found: {0}",
        layout_or_status.status().error_message()));

  if (!layout_or_status->has_value())
    return all_reduce->emitOpError(
        "DTensorAllReduce op must have layout specification.");

  // Set a target layout equivalent to that of the original DTensorAllReduce op
  // in the graph. This is used during later optimization passes.
  SetSingleLayoutOnOp(new_all_reduce, layout_or_status->value());

  llvm::SmallPtrSet<mlir::Operation*, 4> exceptions;
  exceptions.insert(new_all_reduce.getOperation());
  while_output.replaceAllUsesExcept(new_all_reduce.output(), exceptions);

  if (all_reduce.use_empty()) all_reduce.erase();

  *changed = true;
  return mlir::success();
}

mlir::LogicalResult OptimizeWhileLoopLazyAllReduce(
    mlir::TF::WhileRegionOp while_op, bool* changed) {
  mlir::Operation* while_body_terminator =
      while_op.body().front().getTerminator();
  for (const auto& data :
       llvm::enumerate(while_body_terminator->getOpOperands())) {
    const int index = data.index();
    mlir::OpOperand& operand = data.value();

    mlir::Operation* add_op = nullptr;
    mlir::TF::DTensorAllReduceOp all_reduce;
    mlir::OpOperand* add_input = nullptr;
    if (!CheckWhileLoopOptimizationCriteria(index, while_op, operand.get(),
                                            &add_op, &all_reduce, &add_input))
      continue;

    // Perform while loop lazy all reduce optimization.
    if (mlir::failed(ExtractAllReduceFromWhileOp(index, all_reduce, while_op,
                                                 *add_input, add_op, changed)))
      return mlir::failure();
  }

  return mlir::success();
}

mlir::LogicalResult ApplyOptimization(
    mlir::func::FuncOp function,
    const llvm::SmallVectorImpl<mlir::Operation*>& identity_like_ops,
    const llvm::SmallVectorImpl<mlir::TF::WhileRegionOp>& while_ops,
    const llvm::SmallVectorImpl<mlir::Operation*>& add_ops, bool* changed) {
  // Collect and fold the reduction operations within the function.
  for (mlir::Operation* add_op : add_ops)
    if (mlir::failed(OptimizeAllReduceAndSum(add_op, changed)))
      return mlir::failure();

  for (mlir::Operation* op : identity_like_ops)
    OptimizeIdentityLikeOps(op, changed);

  for (mlir::TF::WhileRegionOp op : while_ops)
    if (mlir::failed(OptimizeWhileLoopLazyAllReduce(op, changed)))
      return mlir::failure();

  return mlir::success();
}

// Finds all ops that could enable all-reduce optimizations. Those are:
//   a) Identity-like ops (e.g. Identity/Reshape/Cast).
//   b) WhileRegion ops.
//   c) Add operations.
void CollectOptimizationCandidates(
    mlir::func::FuncOp func,
    llvm::SmallVectorImpl<mlir::Operation*>* identity_like_ops,
    llvm::SmallVectorImpl<mlir::Operation*>* add_ops,
    llvm::SmallVectorImpl<mlir::TF::WhileRegionOp>* while_ops) {
  func.walk([&](mlir::Operation* op) {
    if (llvm::isa<mlir::TF::IdentityOp, mlir::TF::CastOp, mlir::TF::ReshapeOp>(
            op))
      identity_like_ops->emplace_back(op);

    if (auto while_op = llvm::dyn_cast<mlir::TF::WhileRegionOp>(op))
      while_ops->emplace_back(while_op);

    if (llvm::isa<mlir::TF::AddOp, mlir::TF::AddV2Op, mlir::TF::AddNOp>(op))
      add_ops->emplace_back(op);
  });
}

// MLIR pass that reorders and hoists DTensorAllReduce (sum) operations around
// Add, identity-like, and while ops to reduce cross-device communication.
struct DTensorAllReduceSumOptimization
    : public DTensorAllReduceSumOptimizationBase<
          DTensorAllReduceSumOptimization> {
  void runOnOperation() override {
    mlir::func::FuncOp function = getOperation();
    bool changed = true;
    int iteration = 0;

    llvm::SmallVector<mlir::Operation*, 4> identity_like_ops;
    llvm::SmallVector<mlir::Operation*, 4> add_ops;
    llvm::SmallVector<mlir::TF::WhileRegionOp, 4> while_ops;
    CollectOptimizationCandidates(function, &identity_like_ops, &add_ops,
                                  &while_ops);
    // Iterate to a fixed point, bounded by kMaxIteration.
    while (changed && iteration < kMaxIteration) {
      changed = false;
      if (mlir::failed(ApplyOptimization(function, identity_like_ops, while_ops,
                                         add_ops, &changed)))
        return signalPassFailure();
      iteration++;
    }
  }
};

}  // namespace

std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateDTensorAllReduceSumOptimization() {
  return std::make_unique<DTensorAllReduceSumOptimization>();
}
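// A minimal usage sketch (illustrative only, assuming a standard MLIR pass
// manager setup; `context` and `module` are hypothetical):
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::func::FuncOp>(
//       CreateDTensorAllReduceSumOptimization());
//   (void)pm.run(module);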

}  // namespace dtensor
}  // namespace tensorflow