1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <cstddef>

#include "llvm/ADT/DenseMap.h"  // from @llvm-project
#include "llvm/ADT/SmallVector.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
20 
21 namespace tensorflow {
22 namespace tfrt_compiler {
23 namespace {
24 
FunctionHasSideEffect(mlir::func::FuncOp func_op,llvm::DenseMap<mlir::func::FuncOp,bool> & function_side_effect)25 bool FunctionHasSideEffect(
26     mlir::func::FuncOp func_op,
27     llvm::DenseMap<mlir::func::FuncOp, bool>& function_side_effect) {
28   auto iter = function_side_effect.find(func_op);
29   if (iter != function_side_effect.end()) return iter->second;
30 
31   auto& block = func_op.front();
32 
33   auto op_has_side_effect = [&](mlir::Operation* op) {
34     if (auto while_op = llvm::dyn_cast<mlir::TF::WhileOp>(op)) {
35       if (while_op.is_stateless()) return false;
36 
37       return FunctionHasSideEffect(while_op.cond_function(),
38                                    function_side_effect) ||
39              FunctionHasSideEffect(while_op.body_function(),
40                                    function_side_effect);
41     }
42 
43     if (auto if_op = llvm::dyn_cast<mlir::TF::IfOp>(op)) {
44       if (if_op.is_stateless()) return false;
45 
46       return FunctionHasSideEffect(if_op.else_function(),
47                                    function_side_effect) ||
48              FunctionHasSideEffect(if_op.then_function(), function_side_effect);
49     }
50 
51     // Though tf.Assert and tf.Timestamp are side-effecting, they do not
52     // interfere with any other side-effecting ops. For now, if control flow
53     // ops' callee functions contain them, we treat them as non-side-effecting.
54     if (llvm::isa<mlir::TF::AssertOp, mlir::TF::TimestampOp>(op)) return false;
55 
56     return !mlir::MemoryEffectOpInterface::hasNoEffect(op);
57   };
58 
59   // Speculatively setting the function to have no side effect to avoid infinite
60   // recursion. The correct side effect will be updated later once more
61   // operations in the block are checked.
62   function_side_effect[func_op] = false;
63 
64   for (mlir::Operation& op : block) {
65     if (op_has_side_effect(&op)) {
66       function_side_effect[func_op] = true;
67       return true;
68     }
69   }
70 
71   function_side_effect[func_op] = false;
72   return false;
73 }
74 
75 // This pass sets `is_stateless` attribute of tf.If and tf.While ops to true if
76 // their callee functions contains only non-side-effecting ops.
77 class OptimizeTfControlFlowSideEffectPass
78     : public mlir::PassWrapper<OptimizeTfControlFlowSideEffectPass,
79                                mlir::OperationPass<mlir::ModuleOp>> {
80  public:
81   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
82       OptimizeTfControlFlowSideEffectPass)
83 
84  private:
getArgument() const85   llvm::StringRef getArgument() const final {
86     return "tfrt-optimize-tf-control-flow-side-effect";
87   }
getDescription() const88   llvm::StringRef getDescription() const final {
89     return "Set tf control flow ops to stateless if their callee functions "
90            "contains only non-side-effecting ops";
91   }
runOnOperation()92   void runOnOperation() override {
93     auto module = getOperation();
94     llvm::DenseMap<mlir::func::FuncOp, bool> function_side_effect;
95 
96     mlir::Builder builder(module.getContext());
97     module.walk([&](mlir::Operation* op) {
98       if (auto while_op = llvm::dyn_cast<mlir::TF::WhileOp>(op)) {
99         if (while_op.is_stateless()) return;
100 
101         if (!FunctionHasSideEffect(while_op.cond_function(),
102                                    function_side_effect) &&
103             !FunctionHasSideEffect(while_op.body_function(),
104                                    function_side_effect)) {
105           while_op->setAttr("is_stateless", builder.getBoolAttr(true));
106         }
107       }
108 
109       if (auto if_op = llvm::dyn_cast<mlir::TF::IfOp>(op)) {
110         if (if_op.is_stateless()) return;
111 
112         if (!FunctionHasSideEffect(if_op.else_function(),
113                                    function_side_effect) &&
114             !FunctionHasSideEffect(if_op.then_function(),
115                                    function_side_effect)) {
116           if_op->setAttr("is_stateless", builder.getBoolAttr(true));
117         }
118       }
119     });
120   }
121 };
122 
123 }  // namespace
124 
125 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateOptimizeTfControlFlowSideEffectPass()126 CreateOptimizeTfControlFlowSideEffectPass() {
127   return std::make_unique<OptimizeTfControlFlowSideEffectPass>();
128 }
129 
130 static mlir::PassRegistration<OptimizeTfControlFlowSideEffectPass>
131     register_pass(CreateOptimizeTfControlFlowSideEffectPass);
132 
133 }  // namespace tfrt_compiler
134 }  // namespace tensorflow
135