1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
17 #include "mlir/IR/Builders.h" // from @llvm-project
18 #include "mlir/IR/Operation.h" // from @llvm-project
19 #include "mlir/Transforms/FoldUtils.h" // from @llvm-project
20 #include "mlir/Transforms/Passes.h" // from @llvm-project
21 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
22 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
23 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
24
25 namespace tensorflow {
26 namespace dtensor {
27 namespace {
28
29 constexpr int kMaxIteration = 10;
30
FoldConstantOp(mlir::OperationFolder & folder,mlir::TF::ConstOp op)31 mlir::LogicalResult FoldConstantOp(mlir::OperationFolder& folder,
32 mlir::TF::ConstOp op) {
33 bool changed = false;
34 int i = 0;
35 // Iterate until convergence or until maxIterations. Deletion of the op as
36 // a result of being dead or folded is convergence.
37 do {
38 changed = false;
39
40 // If the operation is trivially dead - remove it.
41 if (isOpTriviallyDead(op)) {
42 op->erase();
43 return mlir::success();
44 }
45
46 // Try to fold this op.
47 bool inPlaceUpdate;
48 if (succeeded(folder.tryToFold(op,
49 /*processGeneratedConstants=*/nullptr,
50 /*preReplaceAction=*/nullptr,
51 &inPlaceUpdate))) {
52 changed = true;
53 if (!inPlaceUpdate) {
54 return mlir::success();
55 }
56 }
57 } while (changed && ++i < kMaxIteration);
58 return mlir::success();
59 }
60
61 // MLIR pass that folds constants that can be removed or deduplicated away.
62 struct DTensorConstantFolding
63 : public DTensorConstantFoldingBase<DTensorConstantFolding> {
runOnOperationtensorflow::dtensor::__anon131e6f1c0111::DTensorConstantFolding64 void runOnOperation() override {
65 mlir::MLIRContext& context = getContext();
66 mlir::OperationFolder helper(&context);
67
68 // Collect and fold the operations within the function.
69 llvm::SmallVector<mlir::TF::ConstOp, 8> const_ops;
70 getOperation().walk([&](mlir::TF::ConstOp op) { const_ops.push_back(op); });
71
72 // Attempt to fold the specified operation, including handling unused or
73 // duplicated constants.
74 for (mlir::TF::ConstOp op : llvm::reverse(const_ops))
75 if (mlir::failed(FoldConstantOp(helper, op))) return signalPassFailure();
76 }
77 };
78
79 } // namespace
80
81 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateDTensorConstantFolding()82 CreateDTensorConstantFolding() {
83 return std::make_unique<DTensorConstantFolding>();
84 }
85
86 } // namespace dtensor
87 } // namespace tensorflow
88