1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
17 #include "mlir/Transforms/Passes.h" // from @llvm-project
18 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
19 #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
20
21 namespace tensorflow {
22 namespace tfrt_compiler {
23 namespace {
24
FunctionHasSideEffect(mlir::func::FuncOp func_op,llvm::DenseMap<mlir::func::FuncOp,bool> & function_side_effect)25 bool FunctionHasSideEffect(
26 mlir::func::FuncOp func_op,
27 llvm::DenseMap<mlir::func::FuncOp, bool>& function_side_effect) {
28 auto iter = function_side_effect.find(func_op);
29 if (iter != function_side_effect.end()) return iter->second;
30
31 auto& block = func_op.front();
32
33 auto op_has_side_effect = [&](mlir::Operation* op) {
34 if (auto while_op = llvm::dyn_cast<mlir::TF::WhileOp>(op)) {
35 if (while_op.is_stateless()) return false;
36
37 return FunctionHasSideEffect(while_op.cond_function(),
38 function_side_effect) ||
39 FunctionHasSideEffect(while_op.body_function(),
40 function_side_effect);
41 }
42
43 if (auto if_op = llvm::dyn_cast<mlir::TF::IfOp>(op)) {
44 if (if_op.is_stateless()) return false;
45
46 return FunctionHasSideEffect(if_op.else_function(),
47 function_side_effect) ||
48 FunctionHasSideEffect(if_op.then_function(), function_side_effect);
49 }
50
51 // Though tf.Assert and tf.Timestamp are side-effecting, they do not
52 // interfere with any other side-effecting ops. For now, if control flow
53 // ops' callee functions contain them, we treat them as non-side-effecting.
54 if (llvm::isa<mlir::TF::AssertOp, mlir::TF::TimestampOp>(op)) return false;
55
56 return !mlir::MemoryEffectOpInterface::hasNoEffect(op);
57 };
58
59 // Speculatively setting the function to have no side effect to avoid infinite
60 // recursion. The correct side effect will be updated later once more
61 // operations in the block are checked.
62 function_side_effect[func_op] = false;
63
64 for (mlir::Operation& op : block) {
65 if (op_has_side_effect(&op)) {
66 function_side_effect[func_op] = true;
67 return true;
68 }
69 }
70
71 function_side_effect[func_op] = false;
72 return false;
73 }
74
// This pass sets the `is_stateless` attribute of tf.If and tf.While ops to true
// if their callee functions contain only non-side-effecting ops.
77 class OptimizeTfControlFlowSideEffectPass
78 : public mlir::PassWrapper<OptimizeTfControlFlowSideEffectPass,
79 mlir::OperationPass<mlir::ModuleOp>> {
80 public:
81 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
82 OptimizeTfControlFlowSideEffectPass)
83
84 private:
getArgument() const85 llvm::StringRef getArgument() const final {
86 return "tfrt-optimize-tf-control-flow-side-effect";
87 }
getDescription() const88 llvm::StringRef getDescription() const final {
89 return "Set tf control flow ops to stateless if their callee functions "
90 "contains only non-side-effecting ops";
91 }
runOnOperation()92 void runOnOperation() override {
93 auto module = getOperation();
94 llvm::DenseMap<mlir::func::FuncOp, bool> function_side_effect;
95
96 mlir::Builder builder(module.getContext());
97 module.walk([&](mlir::Operation* op) {
98 if (auto while_op = llvm::dyn_cast<mlir::TF::WhileOp>(op)) {
99 if (while_op.is_stateless()) return;
100
101 if (!FunctionHasSideEffect(while_op.cond_function(),
102 function_side_effect) &&
103 !FunctionHasSideEffect(while_op.body_function(),
104 function_side_effect)) {
105 while_op->setAttr("is_stateless", builder.getBoolAttr(true));
106 }
107 }
108
109 if (auto if_op = llvm::dyn_cast<mlir::TF::IfOp>(op)) {
110 if (if_op.is_stateless()) return;
111
112 if (!FunctionHasSideEffect(if_op.else_function(),
113 function_side_effect) &&
114 !FunctionHasSideEffect(if_op.then_function(),
115 function_side_effect)) {
116 if_op->setAttr("is_stateless", builder.getBoolAttr(true));
117 }
118 }
119 });
120 }
121 };
122
123 } // namespace
124
125 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateOptimizeTfControlFlowSideEffectPass()126 CreateOptimizeTfControlFlowSideEffectPass() {
127 return std::make_unique<OptimizeTfControlFlowSideEffectPass>();
128 }
129
130 static mlir::PassRegistration<OptimizeTfControlFlowSideEffectPass>
131 register_pass(CreateOptimizeTfControlFlowSideEffectPass);
132
133 } // namespace tfrt_compiler
134 } // namespace tensorflow
135