1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <string>
17 #include <utility>
18 
19 #include "mlir/Dialect/SCF/IR/SCF.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallSet.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/Debug.h"
25 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
26 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
27 #include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
28 #include "mlir/IR/Block.h"  // from @llvm-project
29 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
30 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
31 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
32 #include "mlir/IR/Location.h"  // from @llvm-project
33 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
34 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
35 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
36 #include "mlir/IR/TypeRange.h"  // from @llvm-project
37 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
38 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
39 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
40 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
41 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
42 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
43 
// Out-of-line, namespace-scope definition of the static constexpr data member
// `JITCompileFromStrOp::kJITEntryFunctionName`. Before C++17 (where static
// constexpr members became implicitly inline) such a definition is required
// for the member to be ODR-used, e.g. when it is bound to a reference
// parameter as done when building the JIT entry function below.
constexpr llvm::StringRef
    mlir::kernel_gen::tf_framework ::JITCompileFromStrOp::kJITEntryFunctionName;
46 
47 namespace mlir {
48 namespace kernel_gen {
49 namespace transforms {
50 namespace {
51 
// Element-count threshold (2^32 == 4294967296) used by the large-tensor
// pattern in this file: operands with strictly more elements than this are
// routed to the JIT-compiled code path.
constexpr int64_t i32BitLimit = int64_t{1} << 32;
53 using shape::ShapeOfOp;
54 
IsSingleResultTFOperation(Operation * op)55 bool IsSingleResultTFOperation(Operation *op) {
56   assert(op != nullptr && "expect op");
57   if (op->getDialect() !=
58       op->getContext()->getLoadedDialect<TF::TensorFlowDialect>())
59     return false;
60   if (op->getNumResults() != 1) return false;
61   return true;
62 }
63 
IsUnaryTFOperation(Operation * op)64 bool IsUnaryTFOperation(Operation *op) {
65   return IsSingleResultTFOperation(op) && op->getNumOperands() == 1;
66 }
67 
68 struct TFToJITInvocationsPattern : public RewritePattern {
TFToJITInvocationsPatternmlir::kernel_gen::transforms::__anon5af4a4610111::TFToJITInvocationsPattern69   explicit TFToJITInvocationsPattern(MLIRContext *ctx)
70       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
71 
matchAndRewritemlir::kernel_gen::transforms::__anon5af4a4610111::TFToJITInvocationsPattern72   LogicalResult matchAndRewrite(Operation *op,
73                                 PatternRewriter &rewriter) const override {
74     // Apply to all single result TF ops except those that are already in a
75     // JIT-compiled region.
76     if (!IsSingleResultTFOperation(op) ||
77         op->getParentOfType<tf_framework::JITCompileOp>())
78       return failure();
79 
80     Location loc = op->getLoc();
81     Value op_result = op->getResults().front();
82 
83     // Create the JIT compile op.
84     auto jit_compile_op = rewriter.create<tf_framework::JITCompileOp>(
85         loc, rewriter.getType<tf_framework::JITCallableType>(),
86         /*ctx=*/llvm::None);
87 
88     // Move the TF operation into the body.
89     {
90       OpBuilder::InsertionGuard guard(rewriter);
91       llvm::SmallVector<Location> locs(op->getNumOperands(), loc);
92       Block *block = rewriter.createBlock(&jit_compile_op.body(), {},
93                                           op->getOperandTypes(), locs);
94 
95       // Map operands.
96       BlockAndValueMapping bvm;
97       for (auto it : llvm::zip(op->getOperands(), block->getArguments()))
98         bvm.map(std::get<0>(it), std::get<1>(it));
99 
100       rewriter.setInsertionPointToStart(block);
101       rewriter.clone(*op, bvm);
102       rewriter.create<tf_framework::JITCompileYieldOp>(loc,
103                                                        bvm.lookup(op_result));
104     }
105 
106     // Create JIT execute op.
107     rewriter.replaceOpWithNewOp<tf_framework::JITExecuteOp>(
108         op, op_result.getType(), /*ctx=*/Value(), jit_compile_op.result(),
109         op->getOperands());
110     return success();
111   }
112 };
113 
114 struct TFToI64JITInvocationForLargeTensorsPattern : public RewritePattern {
TFToI64JITInvocationForLargeTensorsPatternmlir::kernel_gen::transforms::__anon5af4a4610111::TFToI64JITInvocationForLargeTensorsPattern115   explicit TFToI64JITInvocationForLargeTensorsPattern(MLIRContext *ctx)
116       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
117 
matchAndRewritemlir::kernel_gen::transforms::__anon5af4a4610111::TFToI64JITInvocationForLargeTensorsPattern118   LogicalResult matchAndRewrite(Operation *op,
119                                 PatternRewriter &rewriter) const override {
120     if (!IsUnaryTFOperation(op) ||
121         !llvm::isa<func::FuncOp>(op->getParentOp())) {
122       return failure();
123     }
124 
125     auto results = llvm::to_vector<16>(op->getResults());
126     auto operand_types = llvm::to_vector<16>(llvm::map_range(
127         op->getOperands(), [](Value v) { return v.getType(); }));
128     auto result_types = llvm::to_vector<16>(
129         llvm::map_range(results, [](Value v) { return v.getType(); }));
130 
131     // Create the JIT compile op.
132     auto loc = op->getLoc();
133     Value shape_size_limit =
134         rewriter.create<arith::ConstantIndexOp>(loc, i32BitLimit);
135     auto arg = op->getOperands().front();
136     auto shape = rewriter.create<shape::ShapeOfOp>(loc, arg);
137     auto num_elems = rewriter.create<shape::NumElementsOp>(loc, shape);
138     Value coniditon_check_main = rewriter.create<arith::CmpIOp>(
139         loc, arith::CmpIPredicate::sgt, num_elems, shape_size_limit);
140 
141     Value conditional_path =
142         rewriter
143             .create<scf::IfOp>(
144                 loc, op->getResultTypes(), coniditon_check_main,
145                 [&](OpBuilder &b, Location l) {
146                   auto jit_compile_op =
147                       rewriter.create<tf_framework::JITCompileOp>(
148                           loc,
149                           rewriter.getType<tf_framework::JITCallableType>(),
150                           llvm::None);
151                   BlockAndValueMapping bvm;
152                   {
153                     OpBuilder::InsertionGuard guard(rewriter);
154                     Block *block = rewriter.createBlock(
155                         &jit_compile_op.body(), {}, operand_types,
156                         SmallVector<Location>(operand_types.size(), loc));
157                     for (auto it :
158                          llvm::zip(op->getOperands(), block->getArguments()))
159                       bvm.map(std::get<0>(it), std::get<1>(it));
160                     rewriter.setInsertionPointToStart(block);
161                     rewriter.clone(*op, bvm);
162                     auto new_op = rewriter.clone(*op, bvm);
163                     rewriter.create<tf_framework::JITCompileYieldOp>(
164                         loc, TypeRange{}, new_op->getResults());
165                   }
166                   auto jit_execute_op =
167                       rewriter.create<tf_framework::JITExecuteOp>(
168                           loc, result_types, Value(), jit_compile_op.result(),
169                           op->getOperands());
170                   b.create<scf::YieldOp>(l, jit_execute_op.result());
171                 },
172                 [&](OpBuilder &b, Location l) {
173                   auto new_op = rewriter.clone(*op);
174                   b.create<scf::YieldOp>(l, new_op->getResult(0));
175                 })
176             .getResult(0);
177 
178     rewriter.replaceOp(op, conditional_path);
179     return success();
180   }
181 };
182 
183 struct PackJITCompileOpPattern
184     : public OpRewritePattern<tf_framework::JITCompileOp> {
185   using OpRewritePattern<tf_framework::JITCompileOp>::OpRewritePattern;
186 
PackJITCompileOpPatternmlir::kernel_gen::transforms::__anon5af4a4610111::PackJITCompileOpPattern187   explicit PackJITCompileOpPattern(MLIRContext *ctx,
188                                    llvm::ArrayRef<int64_t> tile_sizes,
189                                    llvm::ArrayRef<int64_t> unroll_factors,
190                                    int64_t max_supported_rank, bool enable_ftz,
191                                    bool index_64bit_if_jit_compiling,
192                                    bool cpu_codegen)
193       : OpRewritePattern<tf_framework::JITCompileOp>(ctx),
194         tile_sizes(tile_sizes),
195         unroll_factors(unroll_factors),
196         max_supported_rank(max_supported_rank),
197         enable_ftz(enable_ftz),
198         index_64bit_if_jit_compiling(index_64bit_if_jit_compiling),
199         cpu_codegen(cpu_codegen) {}
200 
matchAndRewritemlir::kernel_gen::transforms::__anon5af4a4610111::PackJITCompileOpPattern201   LogicalResult matchAndRewrite(tf_framework::JITCompileOp op,
202                                 PatternRewriter &rewriter) const override {
203     Block *body = op.getBody();
204     auto yield_op =
205         llvm::cast<tf_framework::JITCompileYieldOp>(body->getTerminator());
206 
207     // Temporarily, build the module that would be JIT-compiled. This is only to
208     // obtain the serialized code attribute.
209     auto loc = op->getLoc();
210     OpBuilder tmp_module_builder(getContext(), rewriter.getListener());
211     auto jit_module = tmp_module_builder.create<ModuleOp>(loc);
212     tmp_module_builder.setInsertionPointToStart(jit_module.getBody());
213     auto jit_function = tmp_module_builder.create<func::FuncOp>(
214         loc, tf_framework::JITCompileFromStrOp::kJITEntryFunctionName,
215         tmp_module_builder.getFunctionType(body->getArgumentTypes(),
216                                            yield_op->getOperandTypes()));
217     jit_function->setAttr(tf_framework::TFFrameworkDialect::kTFEntryAttrName,
218                           tmp_module_builder.getUnitAttr());
219     jit_function.getBody().takeBody(op.getBodyRegion());
220     tmp_module_builder.setInsertionPointToEnd(&jit_function.getBody().front());
221     tmp_module_builder.create<func::ReturnOp>(loc, yield_op.result());
222     rewriter.eraseOp(yield_op);
223 
224     // Serialize JIT module.
225     std::string code;
226     llvm::raw_string_ostream ss(code);
227     jit_module.print(ss);
228 
229     // Finally, create the new JIT compile op.
230     rewriter.replaceOpWithNewOp<tf_framework::JITCompileFromStrOp>(
231         op, op->getResultTypes(), op.ctx(), rewriter.getStringAttr(code),
232         rewriter.getI64ArrayAttr(tile_sizes),
233         rewriter.getI64ArrayAttr(unroll_factors),
234         rewriter.getI64IntegerAttr(max_supported_rank),
235         rewriter.getBoolAttr(enable_ftz),
236         rewriter.getBoolAttr(index_64bit_if_jit_compiling),
237         rewriter.getBoolAttr(cpu_codegen));
238 
239     return success();
240   }
241 
242  private:
243   llvm::ArrayRef<int64_t> tile_sizes;
244   llvm::ArrayRef<int64_t> unroll_factors;
245   int64_t max_supported_rank;
246   bool enable_ftz;
247   bool index_64bit_if_jit_compiling;
248   bool cpu_codegen;
249 };
250 
251 #define GEN_PASS_CLASSES
252 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
253 
254 struct TFToJITInvocationPass
255     : public TFToJITInvocationPassBase<TFToJITInvocationPass> {
getDependentDialectsmlir::kernel_gen::transforms::__anon5af4a4610111::TFToJITInvocationPass256   void getDependentDialects(DialectRegistry &registry) const override {
257     registry.insert<mlir::kernel_gen::tf_framework::TFFrameworkDialect,
258                     scf::SCFDialect, shape::ShapeDialect>();
259   }
TFToJITInvocationPassmlir::kernel_gen::transforms::__anon5af4a4610111::TFToJITInvocationPass260   explicit TFToJITInvocationPass(llvm::ArrayRef<int64_t> tile_sizes,
261                                  llvm::ArrayRef<int64_t> unroll_factors,
262                                  int64_t max_supported_rank, bool enable_ftz,
263                                  bool index_64bit, bool cpu_codegen,
264                                  bool jit_i64_indexed_for_large_tensors) {
265     tile_sizes_ = tile_sizes;
266     unroll_factors_ = unroll_factors;
267     max_supported_rank_ = max_supported_rank;
268     enable_ftz_ = enable_ftz;
269     index_64bit_ = index_64bit;
270     cpu_codegen_ = cpu_codegen;
271     jit_i64_indexed_for_large_tensors_ = jit_i64_indexed_for_large_tensors;
272   }
273 
runOnOperationmlir::kernel_gen::transforms::__anon5af4a4610111::TFToJITInvocationPass274   void runOnOperation() override {
275     MLIRContext *ctx = &getContext();
276     RewritePatternSet patterns(ctx);
277     PopulateTFToJITInvocationPatterns(ctx, &patterns, tile_sizes_,
278                                       unroll_factors_, max_supported_rank_,
279                                       enable_ftz_, index_64bit_, cpu_codegen_,
280                                       jit_i64_indexed_for_large_tensors_);
281     if (failed(applyPatternsAndFoldGreedily(getOperation(),
282                                             std::move(patterns)))) {
283       return signalPassFailure();
284     }
285   }
286 };
287 
288 }  // namespace
289 
PopulateTFToJITInvocationPatterns(MLIRContext * ctx,RewritePatternSet * patterns,llvm::ArrayRef<int64_t> tile_sizes,llvm::ArrayRef<int64_t> unroll_factors,int64_t max_supported_rank,bool enable_ftz,bool index_64bit,bool cpu_codegen,bool jit_i64_indexed_for_large_tensors)290 void PopulateTFToJITInvocationPatterns(
291     MLIRContext *ctx, RewritePatternSet *patterns,
292     llvm::ArrayRef<int64_t> tile_sizes, llvm::ArrayRef<int64_t> unroll_factors,
293     int64_t max_supported_rank, bool enable_ftz, bool index_64bit,
294     bool cpu_codegen, bool jit_i64_indexed_for_large_tensors) {
295   if (jit_i64_indexed_for_large_tensors) {
296     patterns->add<TFToI64JITInvocationForLargeTensorsPattern>(ctx);
297   } else {
298     patterns->add<TFToJITInvocationsPattern>(ctx);
299   }
300 
301   bool index_64bit_if_jit_compiling =
302       jit_i64_indexed_for_large_tensors ? true : index_64bit;
303   patterns->add<PackJITCompileOpPattern>(
304       ctx, tile_sizes, unroll_factors, max_supported_rank, enable_ftz,
305       index_64bit_if_jit_compiling, cpu_codegen);
306 }
307 
CreateTFToJITInvocationPass(llvm::ArrayRef<int64_t> tile_sizes,llvm::ArrayRef<int64_t> unroll_factors,int64_t max_supported_rank,bool enable_ftz,bool index_64bit,bool cpu_codegen,bool jit_i64_indexed_for_large_tensors)308 std::unique_ptr<OperationPass<func::FuncOp>> CreateTFToJITInvocationPass(
309     llvm::ArrayRef<int64_t> tile_sizes, llvm::ArrayRef<int64_t> unroll_factors,
310     int64_t max_supported_rank, bool enable_ftz, bool index_64bit,
311     bool cpu_codegen, bool jit_i64_indexed_for_large_tensors) {
312   return std::make_unique<TFToJITInvocationPass>(
313       tile_sizes, unroll_factors, max_supported_rank, enable_ftz, index_64bit,
314       cpu_codegen, jit_i64_indexed_for_large_tensors);
315 }
316 
317 }  // namespace transforms
318 }  // namespace kernel_gen
319 }  // namespace mlir
320