/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"

namespace tensorflow {
namespace {

// This pass fuses a tf._TPUCompileMlirOp and its corresponding tf.TPUExecuteOp
// into a single tf.TPUCompileMlirAndExecuteOp. It also removes the
// tf.TPUCompileSucceededAssertOp, which becomes unnecessary after the fusion.
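//
// For illustration only (operand/result types and most attributes elided; the
// exact textual form is defined by the op definitions), the rewrite is
// roughly:
//
//   %status, %program = "tf._TPUCompileMlir"() {metadata, mlir_module}
//   "tf.TPUCompileSucceededAssert"(%status)
//   %out = "tf.TPUExecute"(%arg, %program)
//
//     ==>
//
//   %rendezvous_key_base, %out = "tf.TPUCompileMlirAndExecute"(%arg)
//       {metadata, mlir_module, producer_name}
//
// Remaining uses of %program are rewritten to use %rendezvous_key_base.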
class FuseTpuCompileAndExecutePass
    : public mlir::PassWrapper<FuseTpuCompileAndExecutePass,
                               mlir::OperationPass<mlir::func::FuncOp>> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FuseTpuCompileAndExecutePass)

  llvm::StringRef getArgument() const final {
    return "tfrt-fuse-tpu-compile-and-execute-ops";
  }
  llvm::StringRef getDescription() const final {
    return "Fuse TPU Ops according to TFRT's requirements.";
  }

  void runOnOperation() override {
    auto func = getOperation();

    // Remove the TPUCompileSucceededAssertOps; the compile status assertion is
    // unnecessary once compile and execute are fused below.
    func.walk([&](mlir::Operation *op) {
      if (llvm::isa<mlir::TF::TPUCompileSucceededAssertOp>(op)) {
        op->erase();
      }
    });

    // Maps each TPUExecuteOp to the SetStaticDimensionBoundsOps that define
    // its operands, keyed by operand index.
    llvm::SmallDenseMap<
        mlir::TF::TPUExecuteOp,
        llvm::SmallDenseMap<int, mlir::TF::SetStaticDimensionBoundsOp>>
        exec_to_static_shaped_operands_map;
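    // For example (names hypothetical): if operand 2 of an exec op is defined
    // by a SetStaticDimensionBoundsOp `bounds_op`, the map entry for that exec
    // op will be {2: bounds_op}.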

    llvm::SmallVector<mlir::TF::TPUExecuteOp, 4> tpu_execute_ops;
    func.walk([&](mlir::Operation *op) {
      if (auto exec_op = llvm::dyn_cast<mlir::TF::TPUExecuteOp>(op)) {
        tpu_execute_ops.push_back(exec_op);
        // Collect any operands of this tf.TPUExecute op that are defined by a
        // SetStaticDimensionBoundsOp, along with their operand indices.
        for (const auto &operand : llvm::enumerate(exec_op.getOperands())) {
          if (auto defining_op =
                  operand.value()
                      .getDefiningOp<mlir::TF::SetStaticDimensionBoundsOp>()) {
            exec_to_static_shaped_operands_map[exec_op][operand.index()] =
                defining_op;
          }
        }
      }
    });

    mlir::OpBuilder builder(&func.getBody());

    for (auto exec_op : tpu_execute_ops) {
      auto compile_cache_entry = exec_op.key();
      auto compile_op = ::llvm::dyn_cast<mlir::TF::_TPUCompileMlirOp>(
          compile_cache_entry.getDefiningOp());
      if (!compile_op) {
        exec_op.emitOpError("could not get the _TPUCompileMlirOp");
        signalPassFailure();
        return;
      }

      builder.setInsertionPointAfter(compile_op);
      // The fused op returns the rendezvous key base (a 3-element string
      // tensor) followed by the original execute op's results.
      llvm::SmallVector<mlir::Type, 4> output_types;
      output_types.push_back(mlir::RankedTensorType::get(
          {3}, builder.getType<mlir::TF::StringType>()));
      output_types.insert(output_types.end(), exec_op.getResultTypes().begin(),
                          exec_op.getResultTypes().end());
      llvm::SmallVector<int> static_shaped_operand_indices_attr;
      llvm::SmallVector<mlir::Value> static_shape_tensors;
      llvm::SmallVector<mlir::Value> exec_op_args;
      exec_op_args.resize(exec_op.args().size());

      auto &static_shaped_operands =
          exec_to_static_shaped_operands_map[exec_op];
      for (int i = 0; i < exec_op.args().size(); ++i) {
        auto iter = static_shaped_operands.find(i);
        if (iter != static_shaped_operands.end()) {
          static_shaped_operand_indices_attr.push_back(iter->first);
          static_shape_tensors.push_back(iter->second.static_shape());
          exec_op_args[i] = iter->second.input();
          // The SetStaticDimensionBoundsOp is folded into the fused op: its
          // input tensor becomes the operand and its static shape tensor is
          // passed separately, so replace its uses with the input and erase it.
          iter->second->replaceAllUsesWith(
              mlir::ValueRange({iter->second.input()}));
          iter->second->erase();
        } else {
          exec_op_args[i] = exec_op->getOperand(i);
        }
      }

      auto producer_name =
          exec_op->getAttrOfType<mlir::StringAttr>("_producer_name");
      if (!producer_name)
        producer_name = mlir::StringAttr::get(&getContext(), "default");
      auto compile_and_execute_op =
          builder.create<mlir::TF::TPUCompileMlirAndExecuteOp>(
              exec_op.getLoc(), output_types, exec_op_args,
              static_shape_tensors,
              builder.getI32ArrayAttr(static_shaped_operand_indices_attr),
              compile_op.mlir_module(), compile_op.metadata(), producer_name);

      exec_op.replaceAllUsesWith(compile_and_execute_op.results());
      // Redirect remaining uses of the compile op's program handles to the
      // fused op's rendezvous key base so the compile op can be erased.
      for (auto program_result : compile_op.program()) {
        program_result.replaceAllUsesWith(
            compile_and_execute_op.rendezvous_key_base());
      }

      assert(exec_op.use_empty());
      exec_op.erase();
      assert(compile_op.use_empty());
      compile_op.erase();
    }
  }
};

}  // namespace

namespace tfrt_compiler {

std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateFuseTpuCompileAndExecutePass() {
  return std::make_unique<FuseTpuCompileAndExecutePass>();
}

static mlir::PassRegistration<FuseTpuCompileAndExecutePass>
    fuse_tpu_compile_and_execute_ops_pass;
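
// With the registration above, the pass can be selected by its registered
// argument in an mlir-opt style driver that links in this pass (the tool name
// below is illustrative only):
//
//   tf-tfrt-opt -tfrt-fuse-tpu-compile-and-execute-ops input.mlir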

}  // namespace tfrt_compiler

}  // namespace tensorflow