1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_JIT_TF_JITRT_PASSES_H_ 17 #define TENSORFLOW_COMPILER_MLIR_TFRT_JIT_TF_JITRT_PASSES_H_ 18 19 #include <memory> 20 #include <string> 21 22 #include "mlir/Dialect/Func/IR/FuncOps.h" 23 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 24 #include "mlir/Dialect/Linalg/IR/Linalg.h" 25 #include "mlir/Dialect/MemRef/IR/MemRef.h" 26 #include "mlir/Dialect/SCF/IR/SCF.h" 27 #include "mlir/Dialect/Vector/IR/VectorOps.h" 28 #include "mlir/Pass/Pass.h" 29 #include "mlir/Pass/PassManager.h" 30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" 31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" 32 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" 33 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" 34 35 namespace tensorflow { 36 37 // Pass for trivial buffer forwarding for the linalg.generic operations. 38 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> 39 CreateLinalgTrivialBufferForwardingPass(); 40 41 // Pass for trivial copy removal of memref.copy operations. 42 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> 43 CreateLinalgTrivialCopyRemovalPass(); 44 45 // Pass to optimize padding in tiled loops by peeling the final loop iteration. 46 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> 47 CreatePeelTiledLoopsPass(); 48 49 // Pass to tile and fuse linalg.generic on tensors that models reduction. 50 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> 51 CreateTileReductionPass(); 52 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> 53 CreateTileReductionPass(int64_t reduction_vector_size, 54 int64_t reduction_1d_tile_size, 55 llvm::ArrayRef<int64_t> reduction_2d_tile_sizes); 56 57 // Pass to fuse `linalg.fill` into a tiled reduction. 58 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> 59 CreateFuseFillIntoTiledReductionPass(); 60 61 // Pass to replace 'i1' tensor types with 'i8' tensor types. This pass is a 62 // temporary workaround to avoid the problem of vectorizing 'i1' tensors (see 63 // b/205714705). 64 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 65 CreateJitRtLegalizeI1TypesPass(); 66 67 // Pass to vectorize linalg ops. 68 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> 69 CreateVectorizeTiledOpsPass(); 70 71 // Rewrite `vector.multi_reduction` into a sequence of `vector.reduction` ops. 72 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> 73 createRewriteVectorMultiReductionPass(); 74 75 // Code generation passes targeting transpose operations. 76 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> 77 CreateTileTransposePass(); 78 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> 79 CreateLowerVectorTransposePass(); 80 81 // Pass to tile elementwise linalg.generic on tensors. 82 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> CreateTileCWisePass(); 83 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> CreateTileCWisePass( 84 int64_t cwise_tile_size); 85 86 // Pass to tile linalg.fill on tensors. 87 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> CreateTileFillPass(); 88 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> CreateTileFillPass( 89 int64_t cwise_tile_size); 90 91 // Pass to split _Fused Tensorflow kernels into primitives. 92 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> CreateFissionPass(); 93 94 // Pass to fuse Linalg generic operations on Tensors. 95 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> CreateFusionPass(); 96 97 // Pass to optimize broadcasts based on the symbolic shape constraints. 98 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> 99 CreateSymbolicShapeOptimizationPass(bool constraints_only = false); 100 101 // Pass to replace 0-d tensor inputs to LinalgOp with extracted elements. 102 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> 103 CreateDetensorizeLinalgPass(); 104 105 // Creates `tf_device.cluster` operations according to the TF JitRt clustering 106 // policy. 107 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> 108 CreateTfJitRtClusteringPass(); 109 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> 110 CreateTfJitRtClusteringPass(llvm::ArrayRef<std::string> oplist, 111 int min_cluster_size); 112 113 // Pass to replace math ops with approximations. 114 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> 115 CreateMathApproximationPass(llvm::ArrayRef<std::string> oplist = {}); 116 117 // Returns true if the `value` type is a memref that is contiguous in memory. 118 bool IsContiguousMemref(mlir::Value value); 119 120 // Detects the combiner in the body of LinalgOp if any. Currently, only 121 // ops with a single combiner are supported. 122 mlir::FailureOr<mlir::Operation *> DetectCombiner( 123 mlir::linalg::LinalgOp linalg_op); 124 125 // Sets the attribute to the `op` that indicates that the op was transformed. 126 void setTransformationAttr(mlir::OpBuilder &b, mlir::Operation *op); 127 128 // Removes the attribute that indicates that it was transformed. 129 void removeTransformationAttr(mlir::Operation *op); 130 131 // Checks if `op` has the attribute that indicates that it was transformed. 132 bool hasTransformationAttr(mlir::Operation *op); 133 134 } // namespace tensorflow 135 136 #define GEN_PASS_REGISTRATION 137 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc" 138 139 #endif // TENSORFLOW_COMPILER_MLIR_TFRT_JIT_TF_JITRT_PASSES_H_ 140