/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {

namespace mhlo {
namespace {

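// Rewrites an mhlo.gather that takes whole slices of its operand along
// dimension 0 into an mhlo.torch_index_select followed by an mhlo.reshape.
// A hand-written illustration (assumed shapes, not taken from a test):
//
//   %g = "mhlo.gather"(%operand, %indices) {...}
//       : (tensor<5x4xf32>, tensor<3x1xi32>) -> tensor<3x4xf32>
//
// becomes, roughly,
//
//   %s = "mhlo.torch_index_select"(%operand, %indices)
//       {dim = 0 : i64, batch_dims = 0 : i64}
//       : (tensor<5x4xf32>, tensor<3x1xi32>) -> tensor<3x1x4xf32>
//   %r = "mhlo.reshape"(%s) : (tensor<3x1x4xf32>) -> tensor<3x4xf32>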
struct GatherIsTorchIndexSelect : public OpRewritePattern<GatherOp> {
  using OpRewritePattern<GatherOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter &rewriter) const override {
    auto startIndices = gather.start_indices();
    auto startIndicesTy = startIndices.getType().cast<ShapedType>();
    if (!startIndicesTy.hasRank()) {
      return rewriter.notifyMatchFailure(gather, "unranked start_indices");
    }

    auto operand = gather.operand();
    auto operandTy = operand.getType().cast<ShapedType>();
    if (!operandTy.hasRank()) {
      return rewriter.notifyMatchFailure(gather, "unranked operand");
    }

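    // The index vector is implicitly the trailing dimension of start_indices;
    // for a rank-0 start_indices tensor the index vector dimension is 0.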
    int64_t indexVectorDim = std::max<int64_t>(0, startIndicesTy.getRank() - 1);

    // We can use torch_index_select if the last dimension represents the
    // gather indices.
    auto dimensionNumbers = gather.dimension_numbers();
    if (dimensionNumbers.getIndexVectorDim() != indexVectorDim) {
      return rewriter.notifyMatchFailure(
          gather, "index_vector_dim not last dimension of start_indices");
    }

    // Index select only works across a single dimension.
    if (!startIndicesTy.getShape().empty() &&
        startIndicesTy.getShape().back() != 1) {
      return rewriter.notifyMatchFailure(
          gather, "start_indices index vector dimension not 1");
    }

    // Only support the default case for start_index_map.
    if (dimensionNumbers.getStartIndexMap().size() != 1 ||
        dimensionNumbers.getStartIndexMap()[0] != 0) {
      return rewriter.notifyMatchFailure(gather, "start_index_map != [0]");
    }

    auto resultTy = gather.getResult().getType().dyn_cast<RankedTensorType>();
    if (!resultTy) {
      return rewriter.notifyMatchFailure(gather, "unranked result");
    }

    // Offset dimensions should be the defaults.
    if (static_cast<int64_t>(dimensionNumbers.getOffsetDims().size()) !=
        resultTy.getRank() - indexVectorDim) {
      return rewriter.notifyMatchFailure(
          gather, "offset_dims.size not result rank minus index_vector_dim");
    }

    for (const auto &it : llvm::enumerate(dimensionNumbers.getOffsetDims())) {
      if (static_cast<int64_t>(it.index() + indexVectorDim) != it.value()) {
        return rewriter.notifyMatchFailure(
            gather, "offset_dims != [index_vector_dim, result.rank)");
      }
    }

    for (const auto &it :
         llvm::enumerate(gather.slice_sizes().getValues<APInt>())) {
      // First shape value must be 1.
      if (it.index() == 0) {
        if (it.value().getSExtValue() != 1) {
          return rewriter.notifyMatchFailure(gather, "slice_size[0] != 1");
        }
        continue;
      }

      // The gather needs to index the entire slice for each other dimension.
      if (it.value().getSExtValue() != operandTy.getDimSize(it.index())) {
        return rewriter.notifyMatchFailure(
            gather, "slice_size doesn't match operand dimension");
      }
    }

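    // The index_select result keeps the full start_indices shape (including
    // the trailing unit index vector dimension, if any) followed by every
    // non-collapsed operand dimension; the reshape below folds this back to
    // the gather's original result type.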
    llvm::SmallVector<int64_t, 4> indexSelectShape =
        llvm::to_vector<4>(startIndicesTy.getShape());

    for (auto dim : operandTy.getShape().drop_front()) {
      indexSelectShape.push_back(dim);
    }

    if (dimensionNumbers.getCollapsedSliceDims().size() != 1 ||
        dimensionNumbers.getCollapsedSliceDims()[0] != 0) {
      return rewriter.notifyMatchFailure(gather, "collapsed_slice_dims != [0]");
    }

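    // Select along operand dimension 0 with no batch dimensions, then reshape
    // to the gather's result type.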
    auto torchIndexSelect = rewriter.create<TorchIndexSelectOp>(
        gather.getLoc(),
        RankedTensorType::get(indexSelectShape, operandTy.getElementType()),
        operand, gather.start_indices(), rewriter.getI64IntegerAttr(0),
        rewriter.getI64IntegerAttr(0));

    rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(),
                                           torchIndexSelect);

    return success();
  }
};

struct LegalizeGatherToTorchIndexSelectPass
    : public LegalizeGatherToTorchIndexSelectPassBase<
          LegalizeGatherToTorchIndexSelectPass> {
  /// Greedily rewrite gather ops that are equivalent to an index select into
  /// torch_index_select ops.
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateGatherToTorchIndexSelectPatterns(&getContext(), &patterns);
    if (failed(
            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};
}  // namespace

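// Collects the gather-to-torch_index_select pattern so it can be mixed into a
// larger rewrite set.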
void populateGatherToTorchIndexSelectPatterns(mlir::MLIRContext *context,
                                              RewritePatternSet *patterns) {
  patterns->add<GatherIsTorchIndexSelect>(context);
}

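// Creates a function pass that greedily applies the pattern above.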
std::unique_ptr<OperationPass<func::FuncOp>>
createLegalizeGatherToTorchIndexSelectPass() {
  return std::make_unique<LegalizeGatherToTorchIndexSelectPass>();
}

}  // namespace mhlo
}  // namespace mlir