xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/optimize_mhlo.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file provides optional optimization patterns for mhlo, canonicalizing
17 // operations to equivalent but potentially more efficient operations.
18 
19 #include <cstddef>
20 #include <cstdint>
21 #include <iterator>
22 #include <numeric>
23 
24 #include "llvm/ADT/STLExtras.h"
25 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
26 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
27 #include "mlir-hlo/utils/hlo_utils.h"
28 #include "mlir/IR/Attributes.h"
29 #include "mlir/IR/MLIRContext.h"
30 #include "mlir/IR/Operation.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
33 #include "mlir/IR/Types.h"
34 #include "mlir/Pass/Pass.h"
35 #include "mlir/Pass/PassRegistry.h"
36 
37 namespace mlir {
38 namespace mhlo {
39 namespace {
40 
41 // Returns 1D 64-bit dense elements attribute with the given values.
getI64ElementsAttr(ArrayRef<int64_t> values,Builder * builder)42 static DenseIntElementsAttr getI64ElementsAttr(ArrayRef<int64_t> values,
43                                                Builder* builder) {
44   RankedTensorType ty = RankedTensorType::get(
45       {static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
46   return DenseIntElementsAttr::get(ty, values);
47 }
48 
49 //===----------------------------------------------------------------------===//
50 // GatherOp
51 //===----------------------------------------------------------------------===//
52 
// Rewrites an mhlo.gather that semantically extracts a single slice of its
// operand into an mhlo.dynamic_slice followed by an mhlo.reshape. Only the
// restricted form is matched: index_vector_dim == 0, start_index_map == [0],
// and offset_dims equal to the identity [0, result.rank).
class GatherIsSlice : public OpRewritePattern<GatherOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter& rewriter) const override {
    auto dimensionNumbers = gather.dimension_numbers();

    // Inputs need to be ranked to lower. Both operand and start_indices must
    // additionally have fully static shapes, since the rewrite enumerates
    // start indices and slice sizes at pattern-application time.
    if (!gather.operand().getType().cast<ShapedType>().hasRank() ||
        !gather.operand().getType().cast<ShapedType>().hasStaticShape() ||
        !gather.start_indices().getType().cast<ShapedType>().hasRank() ||
        !gather.start_indices().getType().cast<ShapedType>().hasStaticShape()) {
      return rewriter.notifyMatchFailure(gather,
                                         "non-static operand or start_indices");
    }

    // The index vector must live in dimension 0 of start_indices; the loop
    // below that peels off per-dimension start indices relies on this.
    if (dimensionNumbers.getIndexVectorDim() != 0) {
      return rewriter.notifyMatchFailure(gather, "non-zero index_vector_dim");
    }

    // TODO(suderman): Handle start index map != {0}.
    if (dimensionNumbers.getStartIndexMap().empty() ||
        dimensionNumbers.getStartIndexMap().size() != 1 ||
        dimensionNumbers.getStartIndexMap()[0] != 0) {
      return rewriter.notifyMatchFailure(gather,
                                         "start_index_map not empty or [0]");
    }

    auto resultTy = gather.getResult().getType().dyn_cast<RankedTensorType>();

    if (!resultTy) {
      return rewriter.notifyMatchFailure(gather, "unranked result");
    }
    // Every result dimension must be an offset dimension.
    // NOTE(review): the failure message says "operand.rank" but the code
    // compares against the result rank — the message looks stale.
    if (static_cast<int64_t>(dimensionNumbers.getOffsetDims().size()) !=
        resultTy.getRank()) {
      return rewriter.notifyMatchFailure(gather,
                                         "offset_dims.size != operand.rank");
    }
    // offset_dims must be exactly [0, 1, ..., result.rank - 1].
    for (const auto& it : llvm::enumerate(dimensionNumbers.getOffsetDims())) {
      if (static_cast<int64_t>(it.index()) != it.value()) {
        return rewriter.notifyMatchFailure(gather,
                                           "offset_dims != [0, result.rank)");
      }
    }

    // slice_sizes must carry strictly more entries than the result has
    // dimensions (the message states the requirement, not the observed
    // condition).
    if (gather.slice_sizes().size() <= resultTy.getRank()) {
      return rewriter.notifyMatchFailure(gather,
                                         "slices_size.size > result.rank");
    }

    // Each result dimension must match the corresponding slice size, shifted
    // by one to skip the gathered (collapsed) leading dimension.
    // NOTE(review): this path returns a message-less failure(), unlike the
    // notifyMatchFailure calls above.
    for (const auto& it : llvm::enumerate(resultTy.getShape())) {
      if (gather.slice_sizes().getValues<int64_t>()[it.index() + 1] !=
          it.value()) {
        return failure();
      }
    }

    auto gatherStartIndices = gather.start_indices();
    auto gatherStartIndicesTy = gatherStartIndices.getType().cast<ShapedType>();

    // Scalar start indices for the DynamicSliceOp, one per operand dimension
    // (missing trailing entries are filled with zeros below).
    llvm::SmallVector<Value, 4> sliceStartIndices;

    if (gatherStartIndicesTy.getRank() == 0) {
      // A rank-0 start_indices is already a scalar index; use it directly.
      sliceStartIndices.push_back(gatherStartIndices);
    } else if (gatherStartIndicesTy.getRank() == 1) {
      // Slice out each element of the 1-D start_indices and reshape it to a
      // scalar, as DynamicSliceOp takes one scalar start index per dimension.
      for (int i = 0; i < gatherStartIndicesTy.getDimSize(0); i++) {
        auto start = getI64ElementsAttr({i}, &rewriter);
        auto limit = getI64ElementsAttr({i + 1}, &rewriter);
        auto stride = getI64ElementsAttr({1}, &rewriter);
        auto indicesSlice = rewriter.create<SliceOp>(
            gather.getLoc(), gatherStartIndices, start, limit, stride);
        auto reshaped = rewriter.create<ReshapeOp>(
            gather.getLoc(),
            RankedTensorType::get(
                {}, indicesSlice.getType().cast<ShapedType>().getElementType()),
            indicesSlice);
        sliceStartIndices.push_back(reshaped);
      }
    } else {
      return rewriter.notifyMatchFailure(gather, "start_indices.rank > 1");
    }

    auto sliceSizesTy = gather.slice_sizes().getType();

    // Start indices have implicit zeros when not specified. This is because
    // Gather occurs similar to slicing where full slices are inferred. Add any
    // missing zeros as necessary.
    auto zero = rewriter.create<ConstantOp>(
        gather.getLoc(), rewriter.getZeroAttr(RankedTensorType::get(
                             {}, gatherStartIndicesTy.getElementType())));
    while (static_cast<int64_t>(sliceStartIndices.size()) <
           sliceSizesTy.getDimSize(0)) {
      sliceStartIndices.push_back(zero);
    }

    // Materialize the static slice shape from the slice_sizes attribute.
    SmallVector<int64_t, 5> sliceShape;
    for (auto shapeValue : gather.slice_sizes().getValues<APInt>()) {
      sliceShape.push_back(shapeValue.getSExtValue());
    }

    // Emit the dynamic_slice, then reshape to the gather's original result
    // type (dropping the collapsed leading dimension).
    auto sliceTy = RankedTensorType::get(sliceShape, resultTy.getElementType());
    auto slice = rewriter.create<DynamicSliceOp>(
        gather.getLoc(), sliceTy, gather.operand(), sliceStartIndices,
        gather.slice_sizes());

    rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(), slice);

    return success();
  }
};
162 
163 }  // end anonymous namespace
164 
// Populates `patterns` with the optional mhlo optimization patterns defined
// in this file (currently only the gather-to-slice rewrite).
void populateOptimizeMhloPatterns(MLIRContext* context,
                                  RewritePatternSet* patterns) {
  patterns->add<GatherIsSlice>(context);
}
169 }  // end namespace mhlo
170 }  // end namespace mlir
171