1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/DenseMap.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
19 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
20 #include "mlir/IR/Operation.h"  // from @llvm-project
21 #include "mlir/IR/Value.h"  // from @llvm-project
22 #include "mlir/Transforms/Passes.h"  // from @llvm-project
23 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
24 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
25 
26 namespace tensorflow {
27 namespace {
28 
29 // This pass rewrites tf._TPUCompileMlirOp and tf.TPUExecuteOp into a single
30 // tf.TPUCompileMlirAndExecuteOp. Also it removes the unnecessary
31 // TPUCompileSucceededAssertOp.
32 class FuseTpuCompileAndExecutePass
33     : public mlir::PassWrapper<FuseTpuCompileAndExecutePass,
34                                mlir::OperationPass<mlir::func::FuncOp>> {
35  public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FuseTpuCompileAndExecutePass)36   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FuseTpuCompileAndExecutePass)
37 
38   llvm::StringRef getArgument() const final {
39     return "tfrt-fuse-tpu-compile-and-execute-ops";
40   }
getDescription() const41   llvm::StringRef getDescription() const final {
42     return "Fuse TPU Ops according to TFRT's requirements.";
43   }
44 
runOnOperation()45   void runOnOperation() override {
46     auto func = getOperation();
47 
48     // remove TPUCompileSucceededAssertOp
49     func.walk([&](mlir::Operation *op) {
50       if (llvm::isa<mlir::TF::TPUCompileSucceededAssertOp>(op)) {
51         op->erase();
52       }
53     });
54 
55     // A map from an exec op to a struct containing the static shape tensor from
56     // a SetDynamicDimensionBoundsOp and the operand index.
57     llvm::SmallDenseMap<
58         mlir::TF::TPUExecuteOp,
59         llvm::SmallDenseMap<int, mlir::TF::SetStaticDimensionBoundsOp>>
60         exec_to_static_shaped_operands_map;
61 
62     llvm::SmallVector<mlir::TF::TPUExecuteOp, 4> tpu_execute_ops;
63     func.walk([&](mlir::Operation *op) {
64       if (auto exec_op = llvm::dyn_cast<mlir::TF::TPUExecuteOp>(op)) {
65         tpu_execute_ops.push_back(exec_op);
66         // Collect any operands to this tf.Execute op that are defined by a
67         // SetStaticDimensionBoundsOp along with the operand index.
68         for (const auto &operand : llvm::enumerate(exec_op.getOperands())) {
69           if (auto defining_op =
70                   operand.value()
71                       .getDefiningOp<mlir::TF::SetStaticDimensionBoundsOp>()) {
72             exec_to_static_shaped_operands_map[exec_op][operand.index()] =
73                 defining_op;
74           }
75         }
76       }
77     });
78 
79     mlir::OpBuilder builder(&func.getBody());
80 
81     for (auto exec_op : tpu_execute_ops) {
82       auto compile_cache_entry = exec_op.key();
83       auto compile_op = ::llvm::dyn_cast<mlir::TF::_TPUCompileMlirOp>(
84           compile_cache_entry.getDefiningOp());
85       if (!compile_op) {
86         exec_op.emitOpError("could not get the _TPUCompileMlirOp");
87         signalPassFailure();
88         return;
89       }
90 
91       builder.setInsertionPointAfter(compile_op);
92       llvm::SmallVector<mlir::Type, 4> output_types;
93       output_types.push_back(mlir::RankedTensorType::get(
94           {3}, builder.getType<mlir::TF::StringType>()));
95       output_types.insert(output_types.end(), exec_op.getResultTypes().begin(),
96                           exec_op.getResultTypes().end());
97       llvm::SmallVector<int> static_shaped_operand_indices_attr;
98       llvm::SmallVector<mlir::Value> static_shape_tensors;
99       llvm::SmallVector<mlir::Value> exec_op_args;
100       exec_op_args.resize(exec_op.args().size());
101 
102       auto &static_shaped_operands =
103           exec_to_static_shaped_operands_map[exec_op];
104       for (int i = 0; i < exec_op.args().size(); ++i) {
105         auto iter = static_shaped_operands.find(i);
106         if (iter != static_shaped_operands.end()) {
107           static_shaped_operand_indices_attr.push_back(iter->first);
108           static_shape_tensors.push_back(iter->second.static_shape());
109           exec_op_args[i] = iter->second.input();
110           // The first operand is the input tensor, while the second operand is
111           // the static shape tensor, hence the drop_back here.
112           iter->second->replaceAllUsesWith(
113               mlir::ValueRange({iter->second.input()}));
114           iter->second->erase();
115         } else {
116           exec_op_args[i] = exec_op->getOperand(i);
117         }
118       }
119 
120       auto producer_name =
121           exec_op->getAttrOfType<mlir::StringAttr>("_producer_name");
122       if (!producer_name)
123         producer_name = mlir::StringAttr::get(&getContext(), "default");
124       auto compile_and_execute_op =
125           builder.create<mlir::TF::TPUCompileMlirAndExecuteOp>(
126               exec_op.getLoc(), output_types, exec_op_args,
127               static_shape_tensors,
128               builder.getI32ArrayAttr(static_shaped_operand_indices_attr),
129               compile_op.mlir_module(), compile_op.metadata(), producer_name);
130 
131       exec_op.replaceAllUsesWith(compile_and_execute_op.results());
132       for (auto program_result : compile_op.program()) {
133         program_result.replaceAllUsesWith(
134             compile_and_execute_op.rendezvous_key_base());
135       }
136 
137       assert(exec_op.use_empty());
138       exec_op.erase();
139       assert(compile_op.use_empty());
140       compile_op.erase();
141     }
142   }
143 };
144 
145 }  // namespace
146 
147 namespace tfrt_compiler {
148 
149 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateFuseTpuCompileAndExecutePass()150 CreateFuseTpuCompileAndExecutePass() {
151   return std::make_unique<FuseTpuCompileAndExecutePass>();
152 }
153 
// Static registration object: constructing it registers
// FuseTpuCompileAndExecutePass with MLIR's pass registry, so the pass can be
// looked up by the argument string returned from its getArgument().
static mlir::PassRegistration<FuseTpuCompileAndExecutePass>
    fuse_tpu_compile_and_execute_ops_pass;
156 
157 }  // namespace tfrt_compiler
158 
159 }  // namespace tensorflow
160