1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This file provides optional optimization patterns for mhlo, canonicalizing
// operations to equivalent but potentially more efficient operations.
18
19 #include <cstddef>
20 #include <cstdint>
21 #include <iterator>
22 #include <numeric>
23
24 #include "llvm/ADT/STLExtras.h"
25 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
26 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
27 #include "mlir-hlo/utils/hlo_utils.h"
28 #include "mlir/IR/Attributes.h"
29 #include "mlir/IR/MLIRContext.h"
30 #include "mlir/IR/Operation.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
33 #include "mlir/IR/Types.h"
34 #include "mlir/Pass/Pass.h"
35 #include "mlir/Pass/PassRegistry.h"
36
37 namespace mlir {
38 namespace mhlo {
39 namespace {
40
41 // Returns 1D 64-bit dense elements attribute with the given values.
getI64ElementsAttr(ArrayRef<int64_t> values,Builder * builder)42 static DenseIntElementsAttr getI64ElementsAttr(ArrayRef<int64_t> values,
43 Builder* builder) {
44 RankedTensorType ty = RankedTensorType::get(
45 {static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
46 return DenseIntElementsAttr::get(ty, values);
47 }
48
49 //===----------------------------------------------------------------------===//
50 // GatherOp
51 //===----------------------------------------------------------------------===//
52
53 class GatherIsSlice : public OpRewritePattern<GatherOp> {
54 using OpRewritePattern::OpRewritePattern;
matchAndRewrite(GatherOp gather,PatternRewriter & rewriter) const55 LogicalResult matchAndRewrite(GatherOp gather,
56 PatternRewriter& rewriter) const override {
57 auto dimensionNumbers = gather.dimension_numbers();
58
59 // Inputs need to be ranked to lower.
60 if (!gather.operand().getType().cast<ShapedType>().hasRank() ||
61 !gather.operand().getType().cast<ShapedType>().hasStaticShape() ||
62 !gather.start_indices().getType().cast<ShapedType>().hasRank() ||
63 !gather.start_indices().getType().cast<ShapedType>().hasStaticShape()) {
64 return rewriter.notifyMatchFailure(gather,
65 "non-static operand or start_indices");
66 }
67
68 if (dimensionNumbers.getIndexVectorDim() != 0) {
69 return rewriter.notifyMatchFailure(gather, "non-zero index_vector_dim");
70 }
71
72 // TODO(suderman): Handle start index map != {0}.
73 if (dimensionNumbers.getStartIndexMap().empty() ||
74 dimensionNumbers.getStartIndexMap().size() != 1 ||
75 dimensionNumbers.getStartIndexMap()[0] != 0) {
76 return rewriter.notifyMatchFailure(gather,
77 "start_index_map not empty or [0]");
78 }
79
80 auto resultTy = gather.getResult().getType().dyn_cast<RankedTensorType>();
81
82 if (!resultTy) {
83 return rewriter.notifyMatchFailure(gather, "unranked result");
84 }
85 if (static_cast<int64_t>(dimensionNumbers.getOffsetDims().size()) !=
86 resultTy.getRank()) {
87 return rewriter.notifyMatchFailure(gather,
88 "offset_dims.size != operand.rank");
89 }
90 for (const auto& it : llvm::enumerate(dimensionNumbers.getOffsetDims())) {
91 if (static_cast<int64_t>(it.index()) != it.value()) {
92 return rewriter.notifyMatchFailure(gather,
93 "offset_dims != [0, result.rank)");
94 }
95 }
96
97 if (gather.slice_sizes().size() <= resultTy.getRank()) {
98 return rewriter.notifyMatchFailure(gather,
99 "slices_size.size > result.rank");
100 }
101
102 for (const auto& it : llvm::enumerate(resultTy.getShape())) {
103 if (gather.slice_sizes().getValues<int64_t>()[it.index() + 1] !=
104 it.value()) {
105 return failure();
106 }
107 }
108
109 auto gatherStartIndices = gather.start_indices();
110 auto gatherStartIndicesTy = gatherStartIndices.getType().cast<ShapedType>();
111
112 llvm::SmallVector<Value, 4> sliceStartIndices;
113
114 if (gatherStartIndicesTy.getRank() == 0) {
115 sliceStartIndices.push_back(gatherStartIndices);
116 } else if (gatherStartIndicesTy.getRank() == 1) {
117 for (int i = 0; i < gatherStartIndicesTy.getDimSize(0); i++) {
118 auto start = getI64ElementsAttr({i}, &rewriter);
119 auto limit = getI64ElementsAttr({i + 1}, &rewriter);
120 auto stride = getI64ElementsAttr({1}, &rewriter);
121 auto indicesSlice = rewriter.create<SliceOp>(
122 gather.getLoc(), gatherStartIndices, start, limit, stride);
123 auto reshaped = rewriter.create<ReshapeOp>(
124 gather.getLoc(),
125 RankedTensorType::get(
126 {}, indicesSlice.getType().cast<ShapedType>().getElementType()),
127 indicesSlice);
128 sliceStartIndices.push_back(reshaped);
129 }
130 } else {
131 return rewriter.notifyMatchFailure(gather, "start_indices.rank > 1");
132 }
133
134 auto sliceSizesTy = gather.slice_sizes().getType();
135
136 // Start indices have implicit zeros when not specified. This is because
137 // Gather occurs similar to slicing where full slices are inferred. Add any
138 // missing zeros as necessary.
139 auto zero = rewriter.create<ConstantOp>(
140 gather.getLoc(), rewriter.getZeroAttr(RankedTensorType::get(
141 {}, gatherStartIndicesTy.getElementType())));
142 while (static_cast<int64_t>(sliceStartIndices.size()) <
143 sliceSizesTy.getDimSize(0)) {
144 sliceStartIndices.push_back(zero);
145 }
146
147 SmallVector<int64_t, 5> sliceShape;
148 for (auto shapeValue : gather.slice_sizes().getValues<APInt>()) {
149 sliceShape.push_back(shapeValue.getSExtValue());
150 }
151
152 auto sliceTy = RankedTensorType::get(sliceShape, resultTy.getElementType());
153 auto slice = rewriter.create<DynamicSliceOp>(
154 gather.getLoc(), sliceTy, gather.operand(), sliceStartIndices,
155 gather.slice_sizes());
156
157 rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(), slice);
158
159 return success();
160 }
161 };
162
163 } // end anonymous namespace
164
// Populates |patterns| with the optional mhlo optimization patterns defined in
// this file (currently only the gather-to-dynamic_slice rewrite).
void populateOptimizeMhloPatterns(MLIRContext* context,
                                  RewritePatternSet* patterns) {
  patterns->add<GatherIsSlice>(context);
}
169 } // end namespace mhlo
170 } // end namespace mlir
171