/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <utility>

#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h"

namespace tensorflow {
namespace {

#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc"

using mlir::failure;
using mlir::LogicalResult;
using mlir::MLIRContext;
using mlir::PatternRewriter;
using mlir::RewritePatternSet;
using mlir::success;
using mlir::arith::ConstantIndexOp;
using mlir::gml_st::LoopOp;
using mlir::linalg::FillOp;
using mlir::linalg::GenericOp;
using mlir::tensor::ExpandShapeOp;
using mlir::vector::TransferReadOp;
using mlir::vector::TransferWriteOp;

// The upper limit for vectorization of untiled `linalg.fill`. If a tensor has
// a static shape with more elements, then `linalg.fill` won't be vectorized.
// It is expected that such operations are tiled to get to small static shapes.
constexpr int64_t kNumElementsThreshold = 1024;

// Rewrite `vector.transfer_read(tensor.expand_shape)` as
// `vector.shape_cast(vector.transfer_read)`.
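//
// For example (illustrative IR, assuming a 1D-to-2D expansion):
//
//   %e = tensor.expand_shape %src [[0, 1]]
//       : tensor<8xf32> into tensor<1x8xf32>
//   %v = vector.transfer_read %e[%c0, %c0], %pad
//       : tensor<1x8xf32>, vector<1x8xf32>
//
// is rewritten into
//
//   %r = vector.transfer_read %src[%c0], %pad
//       : tensor<8xf32>, vector<8xf32>
//   %v = vector.shape_cast %r : vector<8xf32> to vector<1x8xf32>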
struct TransferReadOfOneDimExpandShape
    : public mlir::OpRewritePattern<TransferReadOp> {
  using OpRewritePattern<TransferReadOp>::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      TransferReadOp vector_read,
      mlir::PatternRewriter &rewriter) const override {
    auto expand = vector_read.getSource().getDefiningOp<ExpandShapeOp>();
    if (!expand) return failure();

    auto expand_src = expand.getSrc();
    auto expand_src_type = expand.getSrcType();
    auto expand_dst_type = expand.getResultType();
    if (expand_src_type.getRank() != 1 || expand_dst_type.getRank() != 2)
      return failure();

    auto result_type = vector_read.getType().dyn_cast<mlir::ShapedType>();
    if (!result_type || result_type.getShape() != expand_dst_type.getShape())
      return failure();

    auto zero = rewriter.create<ConstantIndexOp>(vector_read.getLoc(), 0);
    auto map = mlir::AffineMap::get(1, 0, {rewriter.getAffineDimExpr(0)},
                                    vector_read.getContext());
    // TODO(pifon): Also support canonicalization in case the map is not an
    // identity.
    if (!map.isIdentity()) return failure();

    auto new_read = rewriter.create<TransferReadOp>(
        vector_read.getLoc(),
        mlir::VectorType::get(expand_src_type.getShape(),
                              expand_src_type.getElementType()),
        expand_src, mlir::ValueRange{zero}, mlir::AffineMapAttr::get(map),
        vector_read.getPadding(),
        /*mask=*/mlir::Value(), rewriter.getBoolArrayAttr({true}));
    rewriter.replaceOpWithNewOp<mlir::vector::ShapeCastOp>(
        vector_read, vector_read.getType(), new_read);
    return success();
  }
};

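// Vectorizes an op of type `OpTy` with `mlir::linalg::vectorize` if the
// user-provided `match_fn` predicate accepts it.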
template <typename OpTy>
struct VectorizationPattern : public mlir::OpRewritePattern<OpTy> {
  VectorizationPattern(MLIRContext *context,
                       llvm::function_ref<bool(OpTy)> match_fn,
                       mlir::PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<OpTy>(context, benefit), match_fn(match_fn) {}

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    if (!match_fn(op)) return failure();
    return mlir::linalg::vectorize(rewriter, op);
  }

 private:
  llvm::function_ref<bool(OpTy)> match_fn;
};

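// Returns the patterns that are applied alongside the op-specific
// vectorization patterns: transfer-op permutation map lowering, reduction to
// contraction, copy forwarding, and transfer_read/transfer_write
// canonicalizations.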
RewritePatternSet getDefaultVectorizationPatterns(MLIRContext *ctx) {
  RewritePatternSet patterns(ctx);
  mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
  mlir::vector::populateVectorReductionToContractPatterns(patterns);
  patterns.add<mlir::linalg::LinalgCopyVTRForwardingPattern,
               mlir::linalg::LinalgCopyVTWForwardingPattern>(ctx,
                                                             /*benefit=*/2);
  TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
  TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
  return patterns;
}

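// Returns true if `fill` has already been tiled, i.e. it is nested inside a
// `gml_st.loop`, or if it writes a static shape with fewer than
// `kNumElementsThreshold` elements.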
bool isFillTiledOrSmall(FillOp fill) {
  if (fill->getParentOfType<LoopOp>()) return true;

  // Allow vectorization for static shapes with a low number of elements.
  auto output_type = fill.output().getType().cast<mlir::RankedTensorType>();
  return output_type.hasStaticShape() &&
         output_type.getNumElements() < kNumElementsThreshold;
}

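// Returns true if `generic` has already been tiled, i.e. it is nested inside a
// `gml_st.loop`, or if it is a 1D reduction.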
bool isGenericOpTiledOrOneDimReduction(GenericOp generic) {
  if (generic->getParentOfType<LoopOp>()) return true;

  // Allow vectorization of 1D reductions.
  return generic.getNumLoops() == 1 && generic.getNumReductionLoops() == 1;
}

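// Pass that vectorizes tiled `linalg.fill` and `linalg.generic` ops (as well
// as small untiled fills and 1D reductions) and cleans up the resulting
// vector transfers.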
struct VectorizeTiledOpsPass
    : public VectorizeTiledOpsBase<VectorizeTiledOpsPass> {
  void getDependentDialects(mlir::DialectRegistry &registry) const override {
    registry.insert<mlir::vector::VectorDialect>();
  }

  void runOnOperation() override {
    auto func = getOperation();
    auto ctx = func.getContext();

    RewritePatternSet patterns = getDefaultVectorizationPatterns(ctx);
    patterns.add<TransferReadOfOneDimExpandShape>(func.getContext());
    patterns.add<VectorizationPattern<FillOp>>(ctx, isFillTiledOrSmall);
    patterns.add<VectorizationPattern<GenericOp>>(
        ctx, isGenericOpTiledOrOneDimReduction);
    (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
  }
};

}  // namespace

std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateVectorizeTiledOpsPass() {
  return std::make_unique<VectorizeTiledOpsPass>();
}

}  // namespace tensorflow