/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering MHLO ops that compute shapes to
// scalar tensor/arith operations.

#include <algorithm>
#include <numeric>
#include <string>
#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringSet.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace {

// An op is treated as a shape computation if all of its operands are ranked
// with rank <= 1 and at least one operand is produced by a
// tensor.from_elements op.
bool opIsShapeComputation(Operation *op) {
  bool foundFromElements = false;
  for (auto operand : op->getOperands()) {
    auto shapedTy = operand.getType().template cast<ShapedType>();
    if (!shapedTy.hasRank() || shapedTy.getRank() > 1) return false;
    if (auto fromElements =
            operand.template getDefiningOp<tensor::FromElementsOp>()) {
      foundFromElements = true;
      continue;
    }
  }
  return foundFromElements;
}

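// Rewrites an elementwise MHLO op that participates in a shape computation
// into per-element scalar ops plus a tensor.from_elements that rebuilds the
// result. A rough sketch (IR shown schematically; names and exact syntax are
// illustrative only):
//
//   %s = tensor.from_elements %a, %b : tensor<2xi32>
//   %r = "mhlo.add"(%s, %s) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
//
// becomes, for result element 0,
//
//   %c0 = arith.constant 0 : index
//   %e0 = tensor.extract %s[%c0] : tensor<2xi32>
//   %v0 = arith.addi %e0, %e0 : i32
//
// (and similarly for element 1), followed by
//
//   %r = tensor.from_elements %v0, %v1 : tensor<2xi32>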
template <typename OpTy>
class MhloElementwiseConverter : public OpRewritePattern<OpTy> {
 public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const final {
    if (!opIsShapeComputation(op)) return failure();

    auto resultTy = op.getType().template cast<ShapedType>();

    Location loc = op.getLoc();
    SmallVector<Value> operands;
    for (int i = 0, s = resultTy.getNumElements(); i < s; i++) {
      SmallVector<Value> extracts;
      for (auto operand : op->getOperands()) {
        ShapedType operandTy = operand.getType().template cast<ShapedType>();
        if (operandTy.getRank() == 0) {
          Value extract =
              rewriter.create<tensor::ExtractOp>(loc, operand, ValueRange({}));
          extracts.push_back(extract);
        } else {
          Value idx = rewriter.create<arith::ConstantIndexOp>(loc, i);
          Value extract = rewriter.create<tensor::ExtractOp>(loc, operand, idx);
          extracts.push_back(extract);
        }
      }

      Value scalarOp = mhlo::MhloOpToStdScalarOp::mapOp(
          op, resultTy.getElementType(), extracts, &rewriter);
      operands.push_back(scalarOp);
    }

    rewriter.replaceOpWithNewOp<tensor::FromElementsOp>(op, resultTy, operands);

    return success();
  }
};

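// Rewrites mhlo.concatenate over rank-0/rank-1 shape tensors into a single
// tensor.from_elements that collects every element of every operand, in
// operand order, via tensor.extract.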
class ConcatenateConverter : public OpRewritePattern<mhlo::ConcatenateOp> {
 public:
  using OpRewritePattern<mhlo::ConcatenateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::ConcatenateOp op,
                                PatternRewriter &rewriter) const final {
    if (!opIsShapeComputation(op)) return failure();

    Location loc = op.getLoc();
    auto resultTy = op.getType().cast<ShapedType>();
    llvm::SmallVector<Value> elements;
    elements.reserve(resultTy.getNumElements());

    for (auto operand : op->getOperands()) {
      ShapedType operandTy = operand.getType().template cast<ShapedType>();
      if (operandTy.getRank() == 0) {
        Value extract =
            rewriter.create<tensor::ExtractOp>(loc, operand, ValueRange({}));
        elements.push_back(extract);
      } else {
        for (int i = 0, s = operandTy.getNumElements(); i < s; i++) {
          Value idx = rewriter.create<arith::ConstantIndexOp>(loc, i);
          Value extract = rewriter.create<tensor::ExtractOp>(loc, operand, idx);
          elements.push_back(extract);
        }
      }
    }

    rewriter.replaceOpWithNewOp<tensor::FromElementsOp>(op, resultTy, elements);
    return success();
  }
};

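// Rewrites mhlo.get_dimension_size into a tensor.dim on the operand, an
// arith.index_cast to the result element type, and a tensor.from_elements
// that wraps the scalar into the 0-d result tensor.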
class GetDimSizeConverter : public OpRewritePattern<mhlo::GetDimensionSizeOp> {
 public:
  using OpRewritePattern<mhlo::GetDimensionSizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::GetDimensionSizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    auto resultTy = op.getType();
    auto elementTy = getElementTypeOrSelf(resultTy);
    auto dimAttr = rewriter.getIndexAttr(op.dimension());
    auto dimConst = rewriter.create<arith::ConstantOp>(loc, dimAttr);

    Value dimOp = rewriter.create<tensor::DimOp>(loc, rewriter.getIndexType(),
                                                 op.operand(), dimConst);

    // Cast to the correct element type and convert to a tensor.
    Value cast = rewriter.create<arith::IndexCastOp>(loc, elementTy, dimOp);
    rewriter.replaceOpWithNewOp<tensor::FromElementsOp>(op, resultTy, cast);
    return success();
  }
};

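// Rewrites mhlo.reshape of a rank-0/rank-1 tensor produced by
// tensor.from_elements into a new tensor.from_elements with the reshaped
// result type, forwarding the original scalar elements.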
class ReshapeConverter : public OpRewritePattern<mhlo::ReshapeOp> {
 public:
  using OpRewritePattern<mhlo::ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::ReshapeOp op,
                                PatternRewriter &rewriter) const final {
    auto operand = op.operand();
    auto shapedTy = operand.getType().template cast<ShapedType>();
    if (!shapedTy.hasRank() || shapedTy.getRank() > 1) return failure();

    auto resultTy = op.getType().cast<ShapedType>();

    auto fromElements = op.operand().getDefiningOp<tensor::FromElementsOp>();
    if (!fromElements) return failure();

    rewriter.replaceOpWithNewOp<tensor::FromElementsOp>(
        op, resultTy, fromElements.getOperands());
    return success();
  }
};

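// Pass that greedily applies the shape-computation scalarization patterns
// above to each function.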
struct HloLegalizeShapeComputationsPass
    : public mhlo::HloLegalizeShapeComputationsPassBase<
          HloLegalizeShapeComputationsPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<arith::ArithmeticDialect, math::MathDialect,
                    func::FuncDialect, tensor::TensorDialect>();
  }

  void runOnOperation() override {
    MLIRContext &ctx = getContext();
    RewritePatternSet patterns(&ctx);

    auto func = getOperation();
    mhlo::populateShapeComputationPatterns(&ctx, &patterns);
    if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

}  // namespace

namespace mhlo {

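// Populates `patterns` with the shape-computation scalarization patterns
// defined above.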
void populateShapeComputationPatterns(MLIRContext *context,
                                      RewritePatternSet *patterns) {
  patterns->add<MhloElementwiseConverter<mhlo::AbsOp>,
                MhloElementwiseConverter<mhlo::AddOp>,
                MhloElementwiseConverter<mhlo::AndOp>,
                MhloElementwiseConverter<mhlo::CeilOp>,
                MhloElementwiseConverter<mhlo::ConvertOp>,
                MhloElementwiseConverter<mhlo::DivOp>,
                MhloElementwiseConverter<mhlo::FloorOp>,
                MhloElementwiseConverter<mhlo::MaxOp>,
                MhloElementwiseConverter<mhlo::MinOp>,
                MhloElementwiseConverter<mhlo::MulOp>,
                MhloElementwiseConverter<mhlo::NegOp>,
                MhloElementwiseConverter<mhlo::RoundOp>,
                MhloElementwiseConverter<mhlo::RsqrtOp>,
                MhloElementwiseConverter<mhlo::SqrtOp>,
                MhloElementwiseConverter<mhlo::SubtractOp>,
                ConcatenateConverter, GetDimSizeConverter, ReshapeConverter>(
      context);
}

std::unique_ptr<OperationPass<func::FuncOp>>
createLegalizeShapeComputationsPass() {
  return std::make_unique<HloLegalizeShapeComputationsPass>();
}

}  // namespace mhlo
}  // namespace mlir