/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
// This file implements logic for lowering the MHLO shape computation ops
// (mhlo.compute_reshape_shape, mhlo.cstr_reshapable) to the arith, shape,
// and tensor dialects.
17 
18 #include <algorithm>
19 #include <numeric>
20 #include <string>
21 #include <utility>
22 
23 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
24 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
25 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
26 #include "mlir-hlo/Dialect/mhlo/transforms/type_conversion.h"
27 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
28 #include "mlir/Dialect/Func/IR/FuncOps.h"
29 #include "mlir/Dialect/Shape/IR/Shape.h"
30 #include "mlir/Dialect/Tensor/IR/Tensor.h"
31 #include "mlir/IR/AffineExpr.h"
32 #include "mlir/IR/Attributes.h"
33 #include "mlir/IR/Builders.h"
34 #include "mlir/IR/BuiltinAttributes.h"
35 #include "mlir/IR/BuiltinOps.h"
36 #include "mlir/IR/BuiltinTypes.h"
37 #include "mlir/IR/Location.h"
38 #include "mlir/IR/MLIRContext.h"
39 #include "mlir/Pass/Pass.h"
40 #include "mlir/Pass/PassManager.h"
41 #include "mlir/Support/LogicalResult.h"
42 #include "mlir/Transforms/DialectConversion.h"
43 
44 namespace mlir {
45 namespace {
46 
47 struct ComputeReshapeShapeConversion
48     : public OpConversionPattern<mhlo::ComputeReshapeShapeOp> {
49   using OpConversionPattern<mhlo::ComputeReshapeShapeOp>::OpConversionPattern;
matchAndRewritemlir::__anon4c3ac5cc0111::ComputeReshapeShapeConversion50   LogicalResult matchAndRewrite(
51       mhlo::ComputeReshapeShapeOp op, OpAdaptor adaptor,
52       ConversionPatternRewriter& rewriter) const final {
53     auto loc = op.getLoc();
54     auto* ctx = op->getContext();
55     Value negOne = rewriter.create<arith::ConstantIndexOp>(loc, -1);
56     auto indexType = rewriter.getIndexType();
57     auto numElements = adaptor.getOperands()[0];
58     auto targetShapeType =
59         adaptor.getOperands()[1].getType().cast<ShapedType>();
60     auto extentType =
61         shape::getExtentTensorType(ctx, targetShapeType.getDimSize(0));
62 
63     // Calculate the computed actual extent for a possible dynamic extent.
64     auto newShape = targetShapeType.getElementType().isIndex()
65                         ? adaptor.getOperands()[1]
66                         : rewriter.create<arith::IndexCastOp>(
67                               loc, extentType, adaptor.getOperands()[1]);
68     Value newShapeRank =
69         rewriter.create<shape::RankOp>(loc, indexType, newShape);
70     // The product begins with a -1 seed which will cancel out a -1 extent in
71     // the input shape if there is one. If there is not, this computed result
72     // will never be used, so it's okay to compute a negative number of
73     // elements.
74     auto accountedNumEls =
75         rewriter.create<shape::ReduceOp>(loc, newShape, negOne);
76     {
77       PatternRewriter::InsertionGuard g(rewriter);
78       rewriter.setInsertionPointToEnd(accountedNumEls.getBody());
79       Value lhs = accountedNumEls.getBody()->getArgument(1);
80       Value rhs = accountedNumEls.getBody()->getArgument(2);
81       rewriter.create<shape::YieldOp>(
82           loc, rewriter.create<arith::MulIOp>(loc, lhs, rhs).getResult());
83     }
84     Value missingDimVal = rewriter.create<arith::DivUIOp>(
85         loc, numElements, accountedNumEls->getResult(0));
86 
87     // Create the final target shape with a possible dynamic extent replace with
88     // the calculated extent.
89     SmallVector<Value> dynamicExtent;
90     if (!targetShapeType.hasStaticShape())
91       dynamicExtent.push_back(newShapeRank);
92     auto gen = rewriter.create<tensor::GenerateOp>(
93         loc, targetShapeType, dynamicExtent,
94         [&](OpBuilder& b, Location loc, ValueRange indices) {
95           Value extent = b.create<shape::GetExtentOp>(loc, indexType, newShape,
96                                                       indices[0]);
97           Value useMissingDimVal = b.create<arith::CmpIOp>(
98               loc, arith::CmpIPredicate::eq, extent, negOne);
99           Value dimVal = b.create<arith::SelectOp>(loc, useMissingDimVal,
100                                                    missingDimVal, extent);
101           dimVal = targetShapeType.getElementType().isIndex()
102                        ? dimVal
103                        : b.create<arith::IndexCastOp>(
104                              loc, targetShapeType.getElementType(), dimVal);
105           b.create<tensor::YieldOp>(loc, dimVal);
106         });
107     rewriter.replaceOp(op, gen.getResult());
108 
109     return success();
110   }
111 };
112 
113 struct CstrReshapableConversion
114     : public OpConversionPattern<mhlo::CstrReshapableOp> {
115   using OpConversionPattern<mhlo::CstrReshapableOp>::OpConversionPattern;
matchAndRewritemlir::__anon4c3ac5cc0111::CstrReshapableConversion116   LogicalResult matchAndRewrite(
117       mhlo::CstrReshapableOp op, OpAdaptor adaptor,
118       ConversionPatternRewriter& rewriter) const final {
119     auto loc = op.getLoc();
120     auto* ctx = op->getContext();
121     Value negOne = rewriter.create<arith::ConstantIndexOp>(loc, -1);
122     Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
123     Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
124     auto numElements = adaptor.getOperands()[0];
125     auto targetShapeType =
126         adaptor.getOperands()[1].getType().cast<ShapedType>();
127     auto extentType =
128         shape::getExtentTensorType(ctx, targetShapeType.getDimSize(0));
129 
130     // Calculate the computed actual extent for a possible dynamic extent.
131     auto newShape = targetShapeType.getElementType().isIndex()
132                         ? adaptor.getOperands()[1]
133                         : rewriter.create<arith::IndexCastOp>(
134                               loc, extentType, adaptor.getOperands()[1]);
135     auto reduction = rewriter.create<shape::ReduceOp>(
136         loc, newShape, llvm::makeArrayRef({one, zero, zero}));
137     {
138       PatternRewriter::InsertionGuard g(rewriter);
139       auto* body = reduction.getBody();
140       rewriter.setInsertionPointToEnd(body);
141       Value extent = body->getArgument(1);
142       Value isDynamic = rewriter.create<arith::CmpIOp>(
143           loc, arith::CmpIPredicate::eq, negOne, extent);
144       Value isInvalid = rewriter.create<arith::CmpIOp>(
145           loc, arith::CmpIPredicate::slt, extent, negOne);
146       Value totalDynamic = rewriter.create<arith::AddIOp>(
147           loc, rewriter.create<arith::SelectOp>(loc, isDynamic, one, zero),
148           body->getArgument(3));
149       Value totalInvalid = rewriter.create<arith::AddIOp>(
150           loc, rewriter.create<arith::SelectOp>(loc, isInvalid, one, zero),
151           body->getArgument(4));
152       Value extentOrOne =
153           rewriter.create<arith::SelectOp>(loc, isDynamic, one, extent);
154       Value totalElements = rewriter.create<arith::MulIOp>(
155           loc, extentOrOne, body->getArgument(2));
156       rewriter.create<shape::YieldOp>(
157           loc, llvm::makeArrayRef({totalElements, totalDynamic, totalInvalid}));
158     }
159     // Avoid division by zero.
160     Value isZeroElements = rewriter.create<arith::CmpIOp>(
161         loc, arith::CmpIPredicate::eq, reduction->getResult(0), zero);
162     Value divisor = rewriter.create<arith::SelectOp>(loc, isZeroElements, one,
163                                                      reduction->getResult(0));
164     Value isDivisible = rewriter.create<arith::CmpIOp>(
165         loc, arith::CmpIPredicate::eq, zero,
166         rewriter.create<arith::RemSIOp>(loc, numElements, divisor));
167     // Must have 0 or 1 dynamic dimensions.
168     Value acceptablyDynamic = rewriter.create<arith::CmpIOp>(
169         loc, arith::CmpIPredicate::ule, reduction->getResult(1), one);
170     // Must have no invalid dimensions.
171     Value noInvalid = rewriter.create<arith::CmpIOp>(
172         loc, arith::CmpIPredicate::eq, reduction->getResult(2), zero);
173     // If there is no dynamic dimension then the number of elements must match.
174     Value hasOneDynamic = rewriter.create<arith::CmpIOp>(
175         loc, arith::CmpIPredicate::eq, reduction->getResult(1), one);
176     Value equalIfNotDynamic = rewriter.create<arith::OrIOp>(
177         loc, hasOneDynamic,
178         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
179                                        numElements, reduction->getResult(0)));
180 
181     Value allPassing = rewriter.create<arith::AndIOp>(
182         loc, isDivisible,
183         rewriter.create<arith::AndIOp>(
184             loc, acceptablyDynamic,
185             rewriter.create<arith::AndIOp>(loc, noInvalid, equalIfNotDynamic)));
186 
187     rewriter.replaceOpWithNewOp<shape::CstrRequireOp>(
188         op, allPassing, "Required valid reshape shape input");
189 
190     return success();
191   }
192 };
193 
194 struct HloLegalizeShapeOpsToStandardPass
195     : public mhlo::HloLegalizeShapeOpsToStandardPassBase<
196           HloLegalizeShapeOpsToStandardPass> {
getDependentDialectsmlir::__anon4c3ac5cc0111::HloLegalizeShapeOpsToStandardPass197   void getDependentDialects(DialectRegistry& registry) const override {
198     registry.insert<arith::ArithmeticDialect, shape::ShapeDialect,
199                     tensor::TensorDialect>();
200   }
201 
runOnOperationmlir::__anon4c3ac5cc0111::HloLegalizeShapeOpsToStandardPass202   void runOnOperation() override {
203     MLIRContext& ctx = getContext();
204     RewritePatternSet patterns(&ctx);
205     ConversionTarget target(ctx);
206     target.addLegalDialect<arith::ArithmeticDialect, tensor::TensorDialect,
207                            shape::ShapeDialect>();
208 
209     target.addLegalOp<UnrealizedConversionCastOp>();
210 
211     auto func = getOperation();
212     mhlo::RemoveSignTypeConverter typeConverter;
213     mhlo::populateHloShapeOpsToStandardConversionPattern(&ctx, typeConverter,
214                                                          &patterns);
215     if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
216       signalPassFailure();
217     }
218   }
219 };
220 
221 }  // namespace
222 
223 namespace mhlo {
224 
populateHloShapeOpsToStandardConversionPattern(MLIRContext * context,TypeConverter & typeConverter,RewritePatternSet * patterns)225 void populateHloShapeOpsToStandardConversionPattern(
226     MLIRContext* context, TypeConverter& typeConverter,
227     RewritePatternSet* patterns) {
228   // clang-format off
229   patterns->add<
230       ComputeReshapeShapeConversion,
231       CstrReshapableConversion>(typeConverter, context);
232   // clang-format on
233 }
234 
235 std::unique_ptr<OperationPass<func::FuncOp>>
createLegalizeHloShapeOpsToStandardPass()236 createLegalizeHloShapeOpsToStandardPass() {
237   return std::make_unique<HloLegalizeShapeOpsToStandardPass>();
238 }
239 
240 }  // namespace mhlo
241 }  // namespace mlir
242