/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 // This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.
17
18 #include <algorithm>
19 #include <numeric>
20 #include <string>
21 #include <utility>
22
23 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
24 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
25 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
26 #include "mlir-hlo/Dialect/mhlo/transforms/type_conversion.h"
27 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
28 #include "mlir/Dialect/Func/IR/FuncOps.h"
29 #include "mlir/Dialect/Shape/IR/Shape.h"
30 #include "mlir/Dialect/Tensor/IR/Tensor.h"
31 #include "mlir/IR/AffineExpr.h"
32 #include "mlir/IR/Attributes.h"
33 #include "mlir/IR/Builders.h"
34 #include "mlir/IR/BuiltinAttributes.h"
35 #include "mlir/IR/BuiltinOps.h"
36 #include "mlir/IR/BuiltinTypes.h"
37 #include "mlir/IR/Location.h"
38 #include "mlir/IR/MLIRContext.h"
39 #include "mlir/Pass/Pass.h"
40 #include "mlir/Pass/PassManager.h"
41 #include "mlir/Support/LogicalResult.h"
42 #include "mlir/Transforms/DialectConversion.h"
43
44 namespace mlir {
45 namespace {
46
47 struct ComputeReshapeShapeConversion
48 : public OpConversionPattern<mhlo::ComputeReshapeShapeOp> {
49 using OpConversionPattern<mhlo::ComputeReshapeShapeOp>::OpConversionPattern;
matchAndRewritemlir::__anon4c3ac5cc0111::ComputeReshapeShapeConversion50 LogicalResult matchAndRewrite(
51 mhlo::ComputeReshapeShapeOp op, OpAdaptor adaptor,
52 ConversionPatternRewriter& rewriter) const final {
53 auto loc = op.getLoc();
54 auto* ctx = op->getContext();
55 Value negOne = rewriter.create<arith::ConstantIndexOp>(loc, -1);
56 auto indexType = rewriter.getIndexType();
57 auto numElements = adaptor.getOperands()[0];
58 auto targetShapeType =
59 adaptor.getOperands()[1].getType().cast<ShapedType>();
60 auto extentType =
61 shape::getExtentTensorType(ctx, targetShapeType.getDimSize(0));
62
63 // Calculate the computed actual extent for a possible dynamic extent.
64 auto newShape = targetShapeType.getElementType().isIndex()
65 ? adaptor.getOperands()[1]
66 : rewriter.create<arith::IndexCastOp>(
67 loc, extentType, adaptor.getOperands()[1]);
68 Value newShapeRank =
69 rewriter.create<shape::RankOp>(loc, indexType, newShape);
70 // The product begins with a -1 seed which will cancel out a -1 extent in
71 // the input shape if there is one. If there is not, this computed result
72 // will never be used, so it's okay to compute a negative number of
73 // elements.
74 auto accountedNumEls =
75 rewriter.create<shape::ReduceOp>(loc, newShape, negOne);
76 {
77 PatternRewriter::InsertionGuard g(rewriter);
78 rewriter.setInsertionPointToEnd(accountedNumEls.getBody());
79 Value lhs = accountedNumEls.getBody()->getArgument(1);
80 Value rhs = accountedNumEls.getBody()->getArgument(2);
81 rewriter.create<shape::YieldOp>(
82 loc, rewriter.create<arith::MulIOp>(loc, lhs, rhs).getResult());
83 }
84 Value missingDimVal = rewriter.create<arith::DivUIOp>(
85 loc, numElements, accountedNumEls->getResult(0));
86
87 // Create the final target shape with a possible dynamic extent replace with
88 // the calculated extent.
89 SmallVector<Value> dynamicExtent;
90 if (!targetShapeType.hasStaticShape())
91 dynamicExtent.push_back(newShapeRank);
92 auto gen = rewriter.create<tensor::GenerateOp>(
93 loc, targetShapeType, dynamicExtent,
94 [&](OpBuilder& b, Location loc, ValueRange indices) {
95 Value extent = b.create<shape::GetExtentOp>(loc, indexType, newShape,
96 indices[0]);
97 Value useMissingDimVal = b.create<arith::CmpIOp>(
98 loc, arith::CmpIPredicate::eq, extent, negOne);
99 Value dimVal = b.create<arith::SelectOp>(loc, useMissingDimVal,
100 missingDimVal, extent);
101 dimVal = targetShapeType.getElementType().isIndex()
102 ? dimVal
103 : b.create<arith::IndexCastOp>(
104 loc, targetShapeType.getElementType(), dimVal);
105 b.create<tensor::YieldOp>(loc, dimVal);
106 });
107 rewriter.replaceOp(op, gen.getResult());
108
109 return success();
110 }
111 };
112
113 struct CstrReshapableConversion
114 : public OpConversionPattern<mhlo::CstrReshapableOp> {
115 using OpConversionPattern<mhlo::CstrReshapableOp>::OpConversionPattern;
matchAndRewritemlir::__anon4c3ac5cc0111::CstrReshapableConversion116 LogicalResult matchAndRewrite(
117 mhlo::CstrReshapableOp op, OpAdaptor adaptor,
118 ConversionPatternRewriter& rewriter) const final {
119 auto loc = op.getLoc();
120 auto* ctx = op->getContext();
121 Value negOne = rewriter.create<arith::ConstantIndexOp>(loc, -1);
122 Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
123 Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
124 auto numElements = adaptor.getOperands()[0];
125 auto targetShapeType =
126 adaptor.getOperands()[1].getType().cast<ShapedType>();
127 auto extentType =
128 shape::getExtentTensorType(ctx, targetShapeType.getDimSize(0));
129
130 // Calculate the computed actual extent for a possible dynamic extent.
131 auto newShape = targetShapeType.getElementType().isIndex()
132 ? adaptor.getOperands()[1]
133 : rewriter.create<arith::IndexCastOp>(
134 loc, extentType, adaptor.getOperands()[1]);
135 auto reduction = rewriter.create<shape::ReduceOp>(
136 loc, newShape, llvm::makeArrayRef({one, zero, zero}));
137 {
138 PatternRewriter::InsertionGuard g(rewriter);
139 auto* body = reduction.getBody();
140 rewriter.setInsertionPointToEnd(body);
141 Value extent = body->getArgument(1);
142 Value isDynamic = rewriter.create<arith::CmpIOp>(
143 loc, arith::CmpIPredicate::eq, negOne, extent);
144 Value isInvalid = rewriter.create<arith::CmpIOp>(
145 loc, arith::CmpIPredicate::slt, extent, negOne);
146 Value totalDynamic = rewriter.create<arith::AddIOp>(
147 loc, rewriter.create<arith::SelectOp>(loc, isDynamic, one, zero),
148 body->getArgument(3));
149 Value totalInvalid = rewriter.create<arith::AddIOp>(
150 loc, rewriter.create<arith::SelectOp>(loc, isInvalid, one, zero),
151 body->getArgument(4));
152 Value extentOrOne =
153 rewriter.create<arith::SelectOp>(loc, isDynamic, one, extent);
154 Value totalElements = rewriter.create<arith::MulIOp>(
155 loc, extentOrOne, body->getArgument(2));
156 rewriter.create<shape::YieldOp>(
157 loc, llvm::makeArrayRef({totalElements, totalDynamic, totalInvalid}));
158 }
159 // Avoid division by zero.
160 Value isZeroElements = rewriter.create<arith::CmpIOp>(
161 loc, arith::CmpIPredicate::eq, reduction->getResult(0), zero);
162 Value divisor = rewriter.create<arith::SelectOp>(loc, isZeroElements, one,
163 reduction->getResult(0));
164 Value isDivisible = rewriter.create<arith::CmpIOp>(
165 loc, arith::CmpIPredicate::eq, zero,
166 rewriter.create<arith::RemSIOp>(loc, numElements, divisor));
167 // Must have 0 or 1 dynamic dimensions.
168 Value acceptablyDynamic = rewriter.create<arith::CmpIOp>(
169 loc, arith::CmpIPredicate::ule, reduction->getResult(1), one);
170 // Must have no invalid dimensions.
171 Value noInvalid = rewriter.create<arith::CmpIOp>(
172 loc, arith::CmpIPredicate::eq, reduction->getResult(2), zero);
173 // If there is no dynamic dimension then the number of elements must match.
174 Value hasOneDynamic = rewriter.create<arith::CmpIOp>(
175 loc, arith::CmpIPredicate::eq, reduction->getResult(1), one);
176 Value equalIfNotDynamic = rewriter.create<arith::OrIOp>(
177 loc, hasOneDynamic,
178 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
179 numElements, reduction->getResult(0)));
180
181 Value allPassing = rewriter.create<arith::AndIOp>(
182 loc, isDivisible,
183 rewriter.create<arith::AndIOp>(
184 loc, acceptablyDynamic,
185 rewriter.create<arith::AndIOp>(loc, noInvalid, equalIfNotDynamic)));
186
187 rewriter.replaceOpWithNewOp<shape::CstrRequireOp>(
188 op, allPassing, "Required valid reshape shape input");
189
190 return success();
191 }
192 };
193
194 struct HloLegalizeShapeOpsToStandardPass
195 : public mhlo::HloLegalizeShapeOpsToStandardPassBase<
196 HloLegalizeShapeOpsToStandardPass> {
getDependentDialectsmlir::__anon4c3ac5cc0111::HloLegalizeShapeOpsToStandardPass197 void getDependentDialects(DialectRegistry& registry) const override {
198 registry.insert<arith::ArithmeticDialect, shape::ShapeDialect,
199 tensor::TensorDialect>();
200 }
201
runOnOperationmlir::__anon4c3ac5cc0111::HloLegalizeShapeOpsToStandardPass202 void runOnOperation() override {
203 MLIRContext& ctx = getContext();
204 RewritePatternSet patterns(&ctx);
205 ConversionTarget target(ctx);
206 target.addLegalDialect<arith::ArithmeticDialect, tensor::TensorDialect,
207 shape::ShapeDialect>();
208
209 target.addLegalOp<UnrealizedConversionCastOp>();
210
211 auto func = getOperation();
212 mhlo::RemoveSignTypeConverter typeConverter;
213 mhlo::populateHloShapeOpsToStandardConversionPattern(&ctx, typeConverter,
214 &patterns);
215 if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
216 signalPassFailure();
217 }
218 }
219 };
220
221 } // namespace
222
223 namespace mhlo {
224
populateHloShapeOpsToStandardConversionPattern(MLIRContext * context,TypeConverter & typeConverter,RewritePatternSet * patterns)225 void populateHloShapeOpsToStandardConversionPattern(
226 MLIRContext* context, TypeConverter& typeConverter,
227 RewritePatternSet* patterns) {
228 // clang-format off
229 patterns->add<
230 ComputeReshapeShapeConversion,
231 CstrReshapableConversion>(typeConverter, context);
232 // clang-format on
233 }
234
235 std::unique_ptr<OperationPass<func::FuncOp>>
createLegalizeHloShapeOpsToStandardPass()236 createLegalizeHloShapeOpsToStandardPass() {
237 return std::make_unique<HloLegalizeShapeOpsToStandardPass>();
238 }
239
240 } // namespace mhlo
241 } // namespace mlir
242