1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7    http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements logic for lowering HLO/LHLO dialect to scalar shape
17 // operations.
18 
19 #include <algorithm>
20 #include <numeric>
21 #include <string>
22 #include <utility>
23 
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SetVector.h"
26 #include "llvm/ADT/StringSet.h"
27 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
28 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
29 #include "mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h"
30 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
31 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
32 #include "mlir/Dialect/Func/IR/FuncOps.h"
33 #include "mlir/Dialect/Math/IR/Math.h"
34 #include "mlir/Dialect/Tensor/IR/Tensor.h"
35 #include "mlir/IR/Attributes.h"
36 #include "mlir/IR/Builders.h"
37 #include "mlir/IR/BuiltinAttributes.h"
38 #include "mlir/IR/BuiltinOps.h"
39 #include "mlir/IR/BuiltinTypes.h"
40 #include "mlir/IR/Location.h"
41 #include "mlir/IR/MLIRContext.h"
42 #include "mlir/IR/Matchers.h"
43 #include "mlir/IR/Operation.h"
44 #include "mlir/IR/OperationSupport.h"
45 #include "mlir/IR/PatternMatch.h"
46 #include "mlir/IR/TypeUtilities.h"
47 #include "mlir/Pass/Pass.h"
48 #include "mlir/Pass/PassManager.h"
49 #include "mlir/Support/LogicalResult.h"
50 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
51 
52 namespace mlir {
53 namespace {
54 
55 // We assume that if one of the operands is a FromElements operation that means
56 // it is a shape computation.
opIsShapeComputation(Operation * op)57 bool opIsShapeComputation(Operation *op) {
58   bool foundFromElements = false;
59   for (auto operand : op->getOperands()) {
60     auto shapedTy = operand.getType().template cast<ShapedType>();
61     if (!shapedTy.hasRank() || shapedTy.getRank() > 1) return false;
62     if (auto fromElements =
63             operand.template getDefiningOp<tensor::FromElementsOp>()) {
64       foundFromElements = true;
65       continue;
66     }
67   }
68   return foundFromElements;
69 }
70 
71 template <typename OpTy>
72 class MhloElementwiseConverter : public OpRewritePattern<OpTy> {
73  public:
74   using OpRewritePattern<OpTy>::OpRewritePattern;
75 
matchAndRewrite(OpTy op,PatternRewriter & rewriter) const76   LogicalResult matchAndRewrite(OpTy op,
77                                 PatternRewriter &rewriter) const final {
78     if (!opIsShapeComputation(op)) return failure();
79 
80     auto resultTy = op.getType().template cast<ShapedType>();
81 
82     Location loc = op.getLoc();
83     SmallVector<Value> operands;
84     for (int i = 0, s = resultTy.getNumElements(); i < s; i++) {
85       SmallVector<Value> extracts;
86       for (auto operand : op->getOperands()) {
87         ShapedType operandTy = operand.getType().template cast<ShapedType>();
88         if (operandTy.getRank() == 0) {
89           Value extract =
90               rewriter.create<tensor::ExtractOp>(loc, operand, ValueRange({}));
91           extracts.push_back(extract);
92         } else {
93           Value idx = rewriter.create<arith::ConstantIndexOp>(loc, i);
94           Value extract = rewriter.create<tensor::ExtractOp>(loc, operand, idx);
95           extracts.push_back(extract);
96         }
97       }
98 
99       Value scalarOp = mhlo::MhloOpToStdScalarOp::mapOp(
100           op, resultTy.getElementType(), extracts, &rewriter);
101       operands.push_back(scalarOp);
102     }
103 
104     rewriter.replaceOpWithNewOp<tensor::FromElementsOp>(op, resultTy, operands);
105 
106     return success();
107   }
108 };
109 
110 class ConcatenateConverter : public OpRewritePattern<mhlo::ConcatenateOp> {
111  public:
112   using OpRewritePattern<mhlo::ConcatenateOp>::OpRewritePattern;
113 
matchAndRewrite(mhlo::ConcatenateOp op,PatternRewriter & rewriter) const114   LogicalResult matchAndRewrite(mhlo::ConcatenateOp op,
115                                 PatternRewriter &rewriter) const final {
116     if (!opIsShapeComputation(op)) return failure();
117 
118     Location loc = op.getLoc();
119     auto resultTy = op.getType().cast<ShapedType>();
120     llvm::SmallVector<Value> elements;
121     elements.reserve(resultTy.getNumElements());
122 
123     for (auto operand : op->getOperands()) {
124       ShapedType operandTy = operand.getType().template cast<ShapedType>();
125       if (operandTy.getRank() == 0) {
126         Value extract =
127             rewriter.create<tensor::ExtractOp>(loc, operand, ValueRange({}));
128         elements.push_back(extract);
129       } else {
130         for (int i = 0, s = operandTy.getNumElements(); i < s; i++) {
131           Value idx = rewriter.create<arith::ConstantIndexOp>(loc, i);
132           Value extract = rewriter.create<tensor::ExtractOp>(loc, operand, idx);
133           elements.push_back(extract);
134         }
135       }
136     }
137 
138     rewriter.replaceOpWithNewOp<tensor::FromElementsOp>(op, resultTy, elements);
139     return success();
140   }
141 };
142 
143 class GetDimSizeConverter : public OpRewritePattern<mhlo::GetDimensionSizeOp> {
144  public:
145   using OpRewritePattern<mhlo::GetDimensionSizeOp>::OpRewritePattern;
146 
matchAndRewrite(mhlo::GetDimensionSizeOp op,PatternRewriter & rewriter) const147   LogicalResult matchAndRewrite(mhlo::GetDimensionSizeOp op,
148                                 PatternRewriter &rewriter) const final {
149     Location loc = op.getLoc();
150     auto resultTy = op.getType();
151     auto elementTy = getElementTypeOrSelf(resultTy);
152     auto dimAttr = rewriter.getIndexAttr(op.dimension());
153     auto dimConst = rewriter.create<arith::ConstantOp>(loc, dimAttr);
154 
155     Value dimOp = rewriter.create<tensor::DimOp>(loc, rewriter.getIndexType(),
156                                                  op.operand(), dimConst);
157 
158     // Cast to the correct element type and convert to a tensor.
159     Value cast = rewriter.create<arith::IndexCastOp>(loc, elementTy, dimOp);
160     rewriter.replaceOpWithNewOp<tensor::FromElementsOp>(op, resultTy, cast);
161     return success();
162   }
163 };
164 
165 class ReshapeConverter : public OpRewritePattern<mhlo::ReshapeOp> {
166  public:
167   using OpRewritePattern<mhlo::ReshapeOp>::OpRewritePattern;
168 
matchAndRewrite(mhlo::ReshapeOp op,PatternRewriter & rewriter) const169   LogicalResult matchAndRewrite(mhlo::ReshapeOp op,
170                                 PatternRewriter &rewriter) const final {
171     auto operand = op.operand();
172     auto shapedTy = operand.getType().template cast<ShapedType>();
173     if (!shapedTy.hasRank() || shapedTy.getRank() > 1) return failure();
174 
175     auto resultTy = op.getType().cast<ShapedType>();
176 
177     auto fromElements = op.operand().getDefiningOp<tensor::FromElementsOp>();
178     if (!fromElements) return failure();
179 
180     rewriter.replaceOpWithNewOp<tensor::FromElementsOp>(
181         op, resultTy, fromElements.getOperands());
182     return success();
183   }
184 };
185 
186 struct HloLegalizeShapeComputationsPass
187     : public mhlo::HloLegalizeShapeComputationsPassBase<
188           HloLegalizeShapeComputationsPass> {
getDependentDialectsmlir::__anonb70784cc0111::HloLegalizeShapeComputationsPass189   void getDependentDialects(DialectRegistry &registry) const override {
190     registry.insert<arith::ArithmeticDialect, math::MathDialect,
191                     func::FuncDialect, tensor::TensorDialect>();
192   }
193 
runOnOperationmlir::__anonb70784cc0111::HloLegalizeShapeComputationsPass194   void runOnOperation() override {
195     MLIRContext &ctx = getContext();
196     RewritePatternSet patterns(&ctx);
197 
198     auto func = getOperation();
199     mhlo::populateShapeComputationPatterns(&ctx, &patterns);
200     if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
201       signalPassFailure();
202     }
203   }
204 };
205 
206 }  // namespace
207 
208 namespace mhlo {
209 
populateShapeComputationPatterns(MLIRContext * context,RewritePatternSet * patterns)210 void populateShapeComputationPatterns(MLIRContext *context,
211                                       RewritePatternSet *patterns) {
212   patterns->add<MhloElementwiseConverter<mhlo::AbsOp>,
213                 MhloElementwiseConverter<mhlo::AddOp>,
214                 MhloElementwiseConverter<mhlo::AndOp>,
215                 MhloElementwiseConverter<mhlo::CeilOp>,
216                 MhloElementwiseConverter<mhlo::ConvertOp>,
217                 MhloElementwiseConverter<mhlo::DivOp>,
218                 MhloElementwiseConverter<mhlo::FloorOp>,
219                 MhloElementwiseConverter<mhlo::MaxOp>,
220                 MhloElementwiseConverter<mhlo::MinOp>,
221                 MhloElementwiseConverter<mhlo::MulOp>,
222                 MhloElementwiseConverter<mhlo::NegOp>,
223                 MhloElementwiseConverter<mhlo::RoundOp>,
224                 MhloElementwiseConverter<mhlo::RsqrtOp>,
225                 MhloElementwiseConverter<mhlo::SqrtOp>,
226                 MhloElementwiseConverter<mhlo::SubtractOp>,
227                 ConcatenateConverter, GetDimSizeConverter, ReshapeConverter>(
228       context);
229 }
230 
231 std::unique_ptr<OperationPass<func::FuncOp>>
createLegalizeShapeComputationsPass()232 createLegalizeShapeComputationsPass() {
233   return std::make_unique<HloLegalizeShapeComputationsPass>();
234 }
235 
236 }  // namespace mhlo
237 }  // namespace mlir
238