1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <string>
17 #include <utility>
18
19 #include "mlir/Dialect/SCF/IR/SCF.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallSet.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/Debug.h"
25 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
26 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
27 #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
28 #include "mlir/IR/Block.h" // from @llvm-project
29 #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
30 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
31 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
32 #include "mlir/IR/Location.h" // from @llvm-project
33 #include "mlir/IR/MLIRContext.h" // from @llvm-project
34 #include "mlir/IR/OperationSupport.h" // from @llvm-project
35 #include "mlir/IR/PatternMatch.h" // from @llvm-project
36 #include "mlir/IR/TypeRange.h" // from @llvm-project
37 #include "mlir/Support/LogicalResult.h" // from @llvm-project
38 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
39 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
40 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
41 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
42 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
43
// Out-of-line definition of the static constexpr member
// JITCompileFromStrOp::kJITEntryFunctionName. The constant is used further
// down in this file as the symbol name of the function placed in the
// temporary JIT module.
constexpr llvm::StringRef
    mlir::kernel_gen::tf_framework ::JITCompileFromStrOp::kJITEntryFunctionName;
46
47 namespace mlir {
48 namespace kernel_gen {
49 namespace transforms {
50 namespace {
51
// Element-count threshold (4294967296 == 2^32). The large-tensor pattern in
// this file compares the number of elements of an op's operand against this
// value (signed greater-than) and only takes the JIT-compiled path when the
// operand is strictly larger.
// NOTE(review): the name suggests a 32-bit indexing limit; confirm that 2^32
// (rather than 2^31) is the intended boundary.
constexpr int64_t i32BitLimit = 4294967296;
using shape::ShapeOfOp;
54
IsSingleResultTFOperation(Operation * op)55 bool IsSingleResultTFOperation(Operation *op) {
56 assert(op != nullptr && "expect op");
57 if (op->getDialect() !=
58 op->getContext()->getLoadedDialect<TF::TensorFlowDialect>())
59 return false;
60 if (op->getNumResults() != 1) return false;
61 return true;
62 }
63
IsUnaryTFOperation(Operation * op)64 bool IsUnaryTFOperation(Operation *op) {
65 return IsSingleResultTFOperation(op) && op->getNumOperands() == 1;
66 }
67
68 struct TFToJITInvocationsPattern : public RewritePattern {
TFToJITInvocationsPatternmlir::kernel_gen::transforms::__anon5af4a4610111::TFToJITInvocationsPattern69 explicit TFToJITInvocationsPattern(MLIRContext *ctx)
70 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
71
matchAndRewritemlir::kernel_gen::transforms::__anon5af4a4610111::TFToJITInvocationsPattern72 LogicalResult matchAndRewrite(Operation *op,
73 PatternRewriter &rewriter) const override {
74 // Apply to all single result TF ops except those that are already in a
75 // JIT-compiled region.
76 if (!IsSingleResultTFOperation(op) ||
77 op->getParentOfType<tf_framework::JITCompileOp>())
78 return failure();
79
80 Location loc = op->getLoc();
81 Value op_result = op->getResults().front();
82
83 // Create the JIT compile op.
84 auto jit_compile_op = rewriter.create<tf_framework::JITCompileOp>(
85 loc, rewriter.getType<tf_framework::JITCallableType>(),
86 /*ctx=*/llvm::None);
87
88 // Move the TF operation into the body.
89 {
90 OpBuilder::InsertionGuard guard(rewriter);
91 llvm::SmallVector<Location> locs(op->getNumOperands(), loc);
92 Block *block = rewriter.createBlock(&jit_compile_op.body(), {},
93 op->getOperandTypes(), locs);
94
95 // Map operands.
96 BlockAndValueMapping bvm;
97 for (auto it : llvm::zip(op->getOperands(), block->getArguments()))
98 bvm.map(std::get<0>(it), std::get<1>(it));
99
100 rewriter.setInsertionPointToStart(block);
101 rewriter.clone(*op, bvm);
102 rewriter.create<tf_framework::JITCompileYieldOp>(loc,
103 bvm.lookup(op_result));
104 }
105
106 // Create JIT execute op.
107 rewriter.replaceOpWithNewOp<tf_framework::JITExecuteOp>(
108 op, op_result.getType(), /*ctx=*/Value(), jit_compile_op.result(),
109 op->getOperands());
110 return success();
111 }
112 };
113
114 struct TFToI64JITInvocationForLargeTensorsPattern : public RewritePattern {
TFToI64JITInvocationForLargeTensorsPatternmlir::kernel_gen::transforms::__anon5af4a4610111::TFToI64JITInvocationForLargeTensorsPattern115 explicit TFToI64JITInvocationForLargeTensorsPattern(MLIRContext *ctx)
116 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
117
matchAndRewritemlir::kernel_gen::transforms::__anon5af4a4610111::TFToI64JITInvocationForLargeTensorsPattern118 LogicalResult matchAndRewrite(Operation *op,
119 PatternRewriter &rewriter) const override {
120 if (!IsUnaryTFOperation(op) ||
121 !llvm::isa<func::FuncOp>(op->getParentOp())) {
122 return failure();
123 }
124
125 auto results = llvm::to_vector<16>(op->getResults());
126 auto operand_types = llvm::to_vector<16>(llvm::map_range(
127 op->getOperands(), [](Value v) { return v.getType(); }));
128 auto result_types = llvm::to_vector<16>(
129 llvm::map_range(results, [](Value v) { return v.getType(); }));
130
131 // Create the JIT compile op.
132 auto loc = op->getLoc();
133 Value shape_size_limit =
134 rewriter.create<arith::ConstantIndexOp>(loc, i32BitLimit);
135 auto arg = op->getOperands().front();
136 auto shape = rewriter.create<shape::ShapeOfOp>(loc, arg);
137 auto num_elems = rewriter.create<shape::NumElementsOp>(loc, shape);
138 Value coniditon_check_main = rewriter.create<arith::CmpIOp>(
139 loc, arith::CmpIPredicate::sgt, num_elems, shape_size_limit);
140
141 Value conditional_path =
142 rewriter
143 .create<scf::IfOp>(
144 loc, op->getResultTypes(), coniditon_check_main,
145 [&](OpBuilder &b, Location l) {
146 auto jit_compile_op =
147 rewriter.create<tf_framework::JITCompileOp>(
148 loc,
149 rewriter.getType<tf_framework::JITCallableType>(),
150 llvm::None);
151 BlockAndValueMapping bvm;
152 {
153 OpBuilder::InsertionGuard guard(rewriter);
154 Block *block = rewriter.createBlock(
155 &jit_compile_op.body(), {}, operand_types,
156 SmallVector<Location>(operand_types.size(), loc));
157 for (auto it :
158 llvm::zip(op->getOperands(), block->getArguments()))
159 bvm.map(std::get<0>(it), std::get<1>(it));
160 rewriter.setInsertionPointToStart(block);
161 rewriter.clone(*op, bvm);
162 auto new_op = rewriter.clone(*op, bvm);
163 rewriter.create<tf_framework::JITCompileYieldOp>(
164 loc, TypeRange{}, new_op->getResults());
165 }
166 auto jit_execute_op =
167 rewriter.create<tf_framework::JITExecuteOp>(
168 loc, result_types, Value(), jit_compile_op.result(),
169 op->getOperands());
170 b.create<scf::YieldOp>(l, jit_execute_op.result());
171 },
172 [&](OpBuilder &b, Location l) {
173 auto new_op = rewriter.clone(*op);
174 b.create<scf::YieldOp>(l, new_op->getResult(0));
175 })
176 .getResult(0);
177
178 rewriter.replaceOp(op, conditional_path);
179 return success();
180 }
181 };
182
183 struct PackJITCompileOpPattern
184 : public OpRewritePattern<tf_framework::JITCompileOp> {
185 using OpRewritePattern<tf_framework::JITCompileOp>::OpRewritePattern;
186
PackJITCompileOpPatternmlir::kernel_gen::transforms::__anon5af4a4610111::PackJITCompileOpPattern187 explicit PackJITCompileOpPattern(MLIRContext *ctx,
188 llvm::ArrayRef<int64_t> tile_sizes,
189 llvm::ArrayRef<int64_t> unroll_factors,
190 int64_t max_supported_rank, bool enable_ftz,
191 bool index_64bit_if_jit_compiling,
192 bool cpu_codegen)
193 : OpRewritePattern<tf_framework::JITCompileOp>(ctx),
194 tile_sizes(tile_sizes),
195 unroll_factors(unroll_factors),
196 max_supported_rank(max_supported_rank),
197 enable_ftz(enable_ftz),
198 index_64bit_if_jit_compiling(index_64bit_if_jit_compiling),
199 cpu_codegen(cpu_codegen) {}
200
matchAndRewritemlir::kernel_gen::transforms::__anon5af4a4610111::PackJITCompileOpPattern201 LogicalResult matchAndRewrite(tf_framework::JITCompileOp op,
202 PatternRewriter &rewriter) const override {
203 Block *body = op.getBody();
204 auto yield_op =
205 llvm::cast<tf_framework::JITCompileYieldOp>(body->getTerminator());
206
207 // Temporarily, build the module that would be JIT-compiled. This is only to
208 // obtain the serialized code attribute.
209 auto loc = op->getLoc();
210 OpBuilder tmp_module_builder(getContext(), rewriter.getListener());
211 auto jit_module = tmp_module_builder.create<ModuleOp>(loc);
212 tmp_module_builder.setInsertionPointToStart(jit_module.getBody());
213 auto jit_function = tmp_module_builder.create<func::FuncOp>(
214 loc, tf_framework::JITCompileFromStrOp::kJITEntryFunctionName,
215 tmp_module_builder.getFunctionType(body->getArgumentTypes(),
216 yield_op->getOperandTypes()));
217 jit_function->setAttr(tf_framework::TFFrameworkDialect::kTFEntryAttrName,
218 tmp_module_builder.getUnitAttr());
219 jit_function.getBody().takeBody(op.getBodyRegion());
220 tmp_module_builder.setInsertionPointToEnd(&jit_function.getBody().front());
221 tmp_module_builder.create<func::ReturnOp>(loc, yield_op.result());
222 rewriter.eraseOp(yield_op);
223
224 // Serialize JIT module.
225 std::string code;
226 llvm::raw_string_ostream ss(code);
227 jit_module.print(ss);
228
229 // Finally, create the new JIT compile op.
230 rewriter.replaceOpWithNewOp<tf_framework::JITCompileFromStrOp>(
231 op, op->getResultTypes(), op.ctx(), rewriter.getStringAttr(code),
232 rewriter.getI64ArrayAttr(tile_sizes),
233 rewriter.getI64ArrayAttr(unroll_factors),
234 rewriter.getI64IntegerAttr(max_supported_rank),
235 rewriter.getBoolAttr(enable_ftz),
236 rewriter.getBoolAttr(index_64bit_if_jit_compiling),
237 rewriter.getBoolAttr(cpu_codegen));
238
239 return success();
240 }
241
242 private:
243 llvm::ArrayRef<int64_t> tile_sizes;
244 llvm::ArrayRef<int64_t> unroll_factors;
245 int64_t max_supported_rank;
246 bool enable_ftz;
247 bool index_64bit_if_jit_compiling;
248 bool cpu_codegen;
249 };
250
251 #define GEN_PASS_CLASSES
252 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
253
254 struct TFToJITInvocationPass
255 : public TFToJITInvocationPassBase<TFToJITInvocationPass> {
getDependentDialectsmlir::kernel_gen::transforms::__anon5af4a4610111::TFToJITInvocationPass256 void getDependentDialects(DialectRegistry ®istry) const override {
257 registry.insert<mlir::kernel_gen::tf_framework::TFFrameworkDialect,
258 scf::SCFDialect, shape::ShapeDialect>();
259 }
TFToJITInvocationPassmlir::kernel_gen::transforms::__anon5af4a4610111::TFToJITInvocationPass260 explicit TFToJITInvocationPass(llvm::ArrayRef<int64_t> tile_sizes,
261 llvm::ArrayRef<int64_t> unroll_factors,
262 int64_t max_supported_rank, bool enable_ftz,
263 bool index_64bit, bool cpu_codegen,
264 bool jit_i64_indexed_for_large_tensors) {
265 tile_sizes_ = tile_sizes;
266 unroll_factors_ = unroll_factors;
267 max_supported_rank_ = max_supported_rank;
268 enable_ftz_ = enable_ftz;
269 index_64bit_ = index_64bit;
270 cpu_codegen_ = cpu_codegen;
271 jit_i64_indexed_for_large_tensors_ = jit_i64_indexed_for_large_tensors;
272 }
273
runOnOperationmlir::kernel_gen::transforms::__anon5af4a4610111::TFToJITInvocationPass274 void runOnOperation() override {
275 MLIRContext *ctx = &getContext();
276 RewritePatternSet patterns(ctx);
277 PopulateTFToJITInvocationPatterns(ctx, &patterns, tile_sizes_,
278 unroll_factors_, max_supported_rank_,
279 enable_ftz_, index_64bit_, cpu_codegen_,
280 jit_i64_indexed_for_large_tensors_);
281 if (failed(applyPatternsAndFoldGreedily(getOperation(),
282 std::move(patterns)))) {
283 return signalPassFailure();
284 }
285 }
286 };
287
288 } // namespace
289
PopulateTFToJITInvocationPatterns(MLIRContext * ctx,RewritePatternSet * patterns,llvm::ArrayRef<int64_t> tile_sizes,llvm::ArrayRef<int64_t> unroll_factors,int64_t max_supported_rank,bool enable_ftz,bool index_64bit,bool cpu_codegen,bool jit_i64_indexed_for_large_tensors)290 void PopulateTFToJITInvocationPatterns(
291 MLIRContext *ctx, RewritePatternSet *patterns,
292 llvm::ArrayRef<int64_t> tile_sizes, llvm::ArrayRef<int64_t> unroll_factors,
293 int64_t max_supported_rank, bool enable_ftz, bool index_64bit,
294 bool cpu_codegen, bool jit_i64_indexed_for_large_tensors) {
295 if (jit_i64_indexed_for_large_tensors) {
296 patterns->add<TFToI64JITInvocationForLargeTensorsPattern>(ctx);
297 } else {
298 patterns->add<TFToJITInvocationsPattern>(ctx);
299 }
300
301 bool index_64bit_if_jit_compiling =
302 jit_i64_indexed_for_large_tensors ? true : index_64bit;
303 patterns->add<PackJITCompileOpPattern>(
304 ctx, tile_sizes, unroll_factors, max_supported_rank, enable_ftz,
305 index_64bit_if_jit_compiling, cpu_codegen);
306 }
307
CreateTFToJITInvocationPass(llvm::ArrayRef<int64_t> tile_sizes,llvm::ArrayRef<int64_t> unroll_factors,int64_t max_supported_rank,bool enable_ftz,bool index_64bit,bool cpu_codegen,bool jit_i64_indexed_for_large_tensors)308 std::unique_ptr<OperationPass<func::FuncOp>> CreateTFToJITInvocationPass(
309 llvm::ArrayRef<int64_t> tile_sizes, llvm::ArrayRef<int64_t> unroll_factors,
310 int64_t max_supported_rank, bool enable_ftz, bool index_64bit,
311 bool cpu_codegen, bool jit_i64_indexed_for_large_tensors) {
312 return std::make_unique<TFToJITInvocationPass>(
313 tile_sizes, unroll_factors, max_supported_rank, enable_ftz, index_64bit,
314 cpu_codegen, jit_i64_indexed_for_large_tensors);
315 }
316
317 } // namespace transforms
318 } // namespace kernel_gen
319 } // namespace mlir
320