1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
17 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
18 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
19 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
20 #include "mlir/Dialect/Func/IR/FuncOps.h"
21 #include "mlir/IR/BuiltinOps.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 
25 namespace mlir {
26 
27 namespace mhlo {
28 namespace {
29 
30 struct GatherIsTorchIndexSelect : public OpRewritePattern<GatherOp> {
31   using OpRewritePattern<GatherOp>::OpRewritePattern;
32 
matchAndRewritemlir::mhlo::__anonc7c8b3a80111::GatherIsTorchIndexSelect33   LogicalResult matchAndRewrite(GatherOp gather,
34                                 PatternRewriter &rewriter) const override {
35     auto startIndices = gather.start_indices();
36     auto startIndicesTy = startIndices.getType().cast<ShapedType>();
37     if (!startIndicesTy.hasRank()) {
38       return rewriter.notifyMatchFailure(gather, "unranked start_indices");
39     }
40 
41     auto operand = gather.operand();
42     auto operandTy = operand.getType().cast<ShapedType>();
43     if (!operandTy.hasRank()) {
44       return rewriter.notifyMatchFailure(gather, "unranked operand");
45     }
46 
47     int64_t indexVectorDim = std::max<int64_t>(0, startIndicesTy.getRank() - 1);
48 
49     // We can use torch_index_select if the last dimension represents the
50     // gather indices.
51     auto dimensionNumbers = gather.dimension_numbers();
52     if (dimensionNumbers.getIndexVectorDim() != indexVectorDim) {
53       return rewriter.notifyMatchFailure(
54           gather, "index_vector_dim not last dimension of start_indices");
55     }
56 
57     // Index select only works across a single dimension.
58     if (!startIndicesTy.getShape().empty() &&
59         startIndicesTy.getShape().back() != 1) {
60       return rewriter.notifyMatchFailure(
61           gather, "start_indices index vector dimension not 1");
62     }
63 
64     // Only support the default case for start_index_map.
65     if (dimensionNumbers.getStartIndexMap().size() != 1 ||
66         dimensionNumbers.getStartIndexMap()[0] != 0) {
67       return rewriter.notifyMatchFailure(gather, "start_index_map != [0]");
68     }
69 
70     auto resultTy = gather.getResult().getType().dyn_cast<RankedTensorType>();
71     if (!resultTy) {
72       return rewriter.notifyMatchFailure(gather, "unranked result");
73     }
74 
75     // Offset dimensions should be the defaults.
76     if (static_cast<int64_t>(dimensionNumbers.getOffsetDims().size()) !=
77         resultTy.getRank() - indexVectorDim) {
78       return rewriter.notifyMatchFailure(
79           gather, "offset_dims.size not operand rank minus index_vector_dim");
80     }
81 
82     for (const auto &it : llvm::enumerate(dimensionNumbers.getOffsetDims())) {
83       if (static_cast<int64_t>(it.index() + indexVectorDim) != it.value()) {
84         return rewriter.notifyMatchFailure(
85             gather, "offset_dims != [index_vector_dim, result.rank)");
86       }
87     }
88 
89     for (const auto &it :
90          llvm::enumerate(gather.slice_sizes().getValues<APInt>())) {
91       // First shape value must be 1.
92       if (it.index() == 0) {
93         if (it.value().getSExtValue() != 1) {
94           return rewriter.notifyMatchFailure(gather, "slice_size[0] != 1");
95         }
96         continue;
97       }
98 
99       // The gather needs to index the entire slice for each other dimension.
100       if (it.value().getSExtValue() != operandTy.getDimSize(it.index())) {
101         return rewriter.notifyMatchFailure(
102             gather, "slice_size doesn't match operand dimension");
103       }
104     }
105 
106     llvm::SmallVector<int64_t, 4> indexSelectShape =
107         llvm::to_vector<4>(startIndicesTy.getShape());
108 
109     for (auto dim : operandTy.getShape().drop_front()) {
110       indexSelectShape.push_back(dim);
111     }
112 
113     if (dimensionNumbers.getCollapsedSliceDims().size() != 1 ||
114         dimensionNumbers.getCollapsedSliceDims()[0] != 0) {
115       return rewriter.notifyMatchFailure(gather, "collapsed_slice_dims != [0]");
116     }
117 
118     auto torchIndexSelect = rewriter.create<TorchIndexSelectOp>(
119         gather.getLoc(),
120         RankedTensorType::get(indexSelectShape, operandTy.getElementType()),
121         operand, gather.start_indices(), rewriter.getI64IntegerAttr(0),
122         rewriter.getI64IntegerAttr(0));
123 
124     rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(),
125                                            torchIndexSelect);
126 
127     return success();
128   }
129 };
130 
131 struct LegalizeGatherToTorchIndexSelectPass
132     : public LegalizeGatherToTorchIndexSelectPassBase<
133           LegalizeGatherToTorchIndexSelectPass> {
134   /// Perform the lowering of standard dialect operations to approximations.
runOnOperationmlir::mhlo::__anonc7c8b3a80111::LegalizeGatherToTorchIndexSelectPass135   void runOnOperation() override {
136     RewritePatternSet patterns(&getContext());
137     populateGatherToTorchIndexSelectPatterns(&getContext(), &patterns);
138     if (failed(
139             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
140       return signalPassFailure();
141   }
142 };
143 }  // namespace
144 
populateGatherToTorchIndexSelectPatterns(mlir::MLIRContext * context,RewritePatternSet * patterns)145 void populateGatherToTorchIndexSelectPatterns(mlir::MLIRContext *context,
146                                               RewritePatternSet *patterns) {
147   patterns->add<GatherIsTorchIndexSelect>(context);
148 }
149 
150 std::unique_ptr<OperationPass<func::FuncOp>>
createLegalizeGatherToTorchIndexSelectPass()151 createLegalizeGatherToTorchIndexSelectPass() {
152   return std::make_unique<LegalizeGatherToTorchIndexSelectPass>();
153 }
154 
155 }  // namespace mhlo
156 }  // namespace mlir
157