/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering MHLO general dot to a regular dot.
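// The lowering transposes and reshapes each operand into a rank-2 matrix,
// emits a plain mhlo.dot, and reshapes the result back to the expected
// output shape; dot_general ops with batching dimensions are left untouched.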

#include <sys/types.h>

#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace mhlo {
namespace {

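// Transposes `arg` so that the dimensions listed in `leftDims` precede those
// in `rightDims`, then reshapes the result into a rank-2 tensor of shape
// [prod(leftDims), prod(rightDims)]. The reshape is skipped when the
// transposed value is already a plain matrix; dynamic dimensions are handled
// with a dynamic_reshape whose extents are computed at runtime.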
Value transposeReshape(Value arg, Location loc,
                       llvm::ArrayRef<int64_t> leftDims,
                       llvm::ArrayRef<int64_t> rightDims,
                       llvm::ArrayRef<int64_t> argShape,
                       PatternRewriter &rewriter) {
  auto elementType = getElementTypeOrSelf(arg.getType());

  // Compute the static size of each side of the reshape; a single dynamic
  // dimension makes the whole side dynamic.
  int64_t leftSize = 1;
  for (auto dim : leftDims) {
    leftSize = (ShapedType::isDynamic(argShape[dim]) || leftSize < 0)
                   ? ShapedType::kDynamicSize
                   : leftSize * argShape[dim];
  }

  int64_t rightSize = 1;
  for (auto dim : rightDims) {
    rightSize = (ShapedType::isDynamic(argShape[dim]) || rightSize < 0)
                    ? ShapedType::kDynamicSize
                    : rightSize * argShape[dim];
  }

  // Generate the transpose permutation attribute.
  llvm::SmallVector<int64_t, 5> transposePermutation(leftDims.begin(),
                                                     leftDims.end());
  transposePermutation.append(rightDims.begin(), rightDims.end());

  TensorType transposePermutationType =
      RankedTensorType::get({static_cast<int64_t>(transposePermutation.size())},
                            rewriter.getIntegerType(64));

  auto transposePermutationAttr =
      DenseIntElementsAttr::get(transposePermutationType,
                                llvm::makeArrayRef(transposePermutation))
          .cast<DenseIntElementsAttr>();

  // Compute the resulting shape.
  llvm::SmallVector<int64_t, 5> transposedShape;
  for (auto val : transposePermutation) {
    transposedShape.push_back(argShape[val]);
  }

  // If there is only a single pair of contracting dimensions and the output
  // rank is two, we can skip a needless reshape.
  bool noReshape = transposedShape.size() == 2 && leftDims.size() == 1 &&
                   rightDims.size() == 1;

  // Construct type. If no reshape is needed, the sparsity, if any, of the
  // input operand is propagated to the output to ensure this information is
  // not lost in the dot operation.
  auto enc = sparse_tensor::getSparseTensorEncoding(arg.getType());
  auto transposeType =
      (enc && noReshape)
          ? RankedTensorType::get(transposedShape, elementType, enc)
          : RankedTensorType::get(transposedShape, elementType);

  // Construct transpose. If no reshape is needed, we are done.
  Value transposeResult = rewriter.create<TransposeOp>(
      loc, transposeType, arg, transposePermutationAttr);
  if (noReshape) return transposeResult;

  // Return the final result.
  auto reshapedType = RankedTensorType::get({leftSize, rightSize}, elementType);

  if (reshapedType.hasStaticShape()) {
    return rewriter.create<ReshapeOp>(loc, reshapedType, transposeResult);
  }

  SmallVector<Value> reshapeDims;
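  // Compute one reshape extent at runtime as the product of the given
  // dimension sizes, each materialized as a rank-1 i32 tensor.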
  auto multiplyDynamicDims = [&](llvm::ArrayRef<int64_t> dims) -> Value {
    Value dynamicSize = rewriter.create<GetDimensionSizeOp>(
        loc, RankedTensorType::get({1}, rewriter.getI32Type()), arg,
        rewriter.getI64IntegerAttr(dims.front()));

    for (auto idx : dims.drop_front()) {
      Value dim = rewriter.create<GetDimensionSizeOp>(
          loc, RankedTensorType::get({1}, rewriter.getI32Type()), arg,
          rewriter.getI64IntegerAttr(idx));
      dynamicSize = rewriter.create<MulOp>(loc, dynamicSize, dim);
    }
    return dynamicSize;
  };

  if (leftSize < 0) {
    reshapeDims.push_back(multiplyDynamicDims(leftDims));
  } else {
    reshapeDims.push_back(
        rewriter.create<ConstantOp>(loc, rewriter.getI32TensorAttr(leftSize)));
  }

  if (rightSize < 0) {
    reshapeDims.push_back(multiplyDynamicDims(rightDims));
  } else {
    reshapeDims.push_back(
        rewriter.create<ConstantOp>(loc, rewriter.getI32TensorAttr(rightSize)));
  }

  Value reshapeDimsTensor = rewriter.create<ConcatenateOp>(
      loc, RankedTensorType::get({2}, rewriter.getI32Type()), reshapeDims,
      rewriter.getI64IntegerAttr(0));

  return rewriter.create<DynamicReshapeOp>(loc, reshapedType, transposeResult,
                                           reshapeDimsTensor);
}

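// Partitions the dimensions of `arg` into contracting and outer (parallel)
// dimensions, then transposes/reshapes it into a matrix via transposeReshape.
// `outerDimsFirst` selects whether the outer dimensions form the rows (lhs)
// or the columns (rhs) of the resulting matrix.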
Value processDotArg(Value arg, Location loc, ArrayRef<int64_t> contractDimsAttr,
                    bool outerDimsFirst, PatternRewriter &rewriter) {
  auto shape = arg.getType().cast<ShapedType>().getShape();

  llvm::SmallVector<bool, 5> isOuterDim;
  isOuterDim.resize(shape.size(), true);

  // Compute the contract dimension ordering.
  llvm::SmallVector<int64_t, 5> contractDims;
  for (auto dim : contractDimsAttr) {
    contractDims.push_back(dim);
    isOuterDim[dim] = false;
  }

  // Compute the outer dimension orderings.
  llvm::SmallVector<int64_t, 5> outerDims;
  for (const auto &it : llvm::enumerate(isOuterDim)) {
    if (it.value()) {
      outerDims.push_back(it.index());
    }
  }

  if (outerDimsFirst) {
    return transposeReshape(arg, loc, outerDims, contractDims, shape, rewriter);
  }

  return transposeReshape(arg, loc, contractDims, outerDims, shape, rewriter);
}

struct GeneralDotConvert : public OpRewritePattern<DotGeneralOp> {
  // Attempts to lower a General Dot operator to a standard Dot operator.
  // General dots include batching dimensions and can have collapsing
  // dimensions along any axis. Inserting correctly arranged transpose and
  // reshape operators organizes the tensors and allows the General Dot to be
  // replaced with the standard Dot operator.
  //
  // Note: This requires an empty list of batch dimensions.
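  //
  // For example (schematically; attributes elided), a dot_general that
  // contracts lhs dimension 0 with rhs dimension 1:
  //
  //   %0 = "mhlo.dot_general"(%lhs, %rhs) {...}
  //       : (tensor<3x4xf32>, tensor<5x3xf32>) -> tensor<4x5xf32>
  //
  // is rewritten into transposes of both operands followed by a plain dot:
  //
  //   %lhs_t = "mhlo.transpose"(%lhs) {...}
  //       : (tensor<3x4xf32>) -> tensor<4x3xf32>
  //   %rhs_t = "mhlo.transpose"(%rhs) {...}
  //       : (tensor<5x3xf32>) -> tensor<3x5xf32>
  //   %1 = "mhlo.dot"(%lhs_t, %rhs_t) {...}
  //       : (tensor<4x3xf32>, tensor<3x5xf32>) -> tensor<4x5xf32>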

  explicit GeneralDotConvert(MLIRContext *context)
      : OpRewritePattern(context) {}

  LogicalResult matchAndRewrite(DotGeneralOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    auto dotNumbers = op.dot_dimension_numbers();
    if (!dotNumbers.getLhsBatchingDimensions().empty() ||
        !dotNumbers.getRhsBatchingDimensions().empty()) {
      return failure();
    }

    auto lhsContractingDims = dotNumbers.getLhsContractingDimensions();
    auto rhsContractingDims = dotNumbers.getRhsContractingDimensions();

    auto lhs = op.lhs();
    auto rhs = op.rhs();

    RankedTensorType lhsTy = lhs.getType().dyn_cast<RankedTensorType>();
    RankedTensorType rhsTy = rhs.getType().dyn_cast<RankedTensorType>();
    if (!lhsTy || !rhsTy) return failure();

    lhs = processDotArg(op.lhs(), op.getLoc(),
                        dotNumbers.getLhsContractingDimensions(),
                        /*outerDimsFirst=*/true, rewriter);

    rhs = processDotArg(op.rhs(), op.getLoc(),
                        dotNumbers.getRhsContractingDimensions(),
                        /*outerDimsFirst=*/false, rewriter);

    // Accept only shaped operand types.
    auto lhsShapeType = lhs.getType().dyn_cast_or_null<ShapedType>();
    auto rhsShapeType = rhs.getType().dyn_cast_or_null<ShapedType>();
    if (!lhsShapeType || !rhsShapeType) return failure();

    ArrayAttr precisionConfig;
    if (op.precision_config()) precisionConfig = *op.precision_config();
    SmallVector<Type, 1> results;
    LogicalResult res =
        DotOp::inferReturnTypes(rewriter.getContext(), None, {lhs, rhs},
                                op->getAttrDictionary(), {}, results);
    (void)res;
    assert(succeeded(res) && "invalid input to dot");

    ShapedType resultTy = op.getType().cast<ShapedType>();
    ShapedType newTy =
        results.front().cast<ShapedType>().clone(resultTy.getElementType());
    Value newDotOp =
        rewriter.create<DotOp>(op.getLoc(), newTy, lhs, rhs, precisionConfig);
    // If each operand has exactly one outer dimension, the dot already yields
    // the expected rank-2 result and no reshape is needed.
    if (static_cast<int64_t>(lhsContractingDims.size()) ==
            lhsTy.getRank() - 1 &&
        static_cast<int64_t>(rhsContractingDims.size()) ==
            rhsTy.getRank() - 1) {
      rewriter.replaceOp(op, newDotOp);
      return success();
    }

    // We can avoid all the computation below if we know the static shape.
    if (resultTy.hasStaticShape()) {
      rewriter.replaceOpWithNewOp<ReshapeOp>(op, resultTy, newDotOp);
      return success();
    }

    llvm::SmallVector<int64_t> staticDims;
    llvm::SmallVector<Value> dynDims;

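    // Collect, for every non-contracting dimension of `arg`, its static size
    // (for the result type) and its runtime size as a rank-1 i32 tensor (for
    // the dynamic reshape below).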
    auto getDynamicDims = [&](Value arg,
                              llvm::ArrayRef<int64_t> contractingDims) {
      RankedTensorType ty = arg.getType().cast<RankedTensorType>();
      int index = 0;
      for (auto contractingDim : contractingDims) {
        for (; index < contractingDim; index++) {
          staticDims.push_back(ty.getDimSize(index));
          dynDims.push_back(rewriter.create<GetDimensionSizeOp>(
              loc, RankedTensorType::get({1}, rewriter.getI32Type()), arg,
              rewriter.getI64IntegerAttr(index)));
        }
        index++;
      }

      for (; index < ty.getRank(); index++) {
        staticDims.push_back(ty.getDimSize(index));
        dynDims.push_back(rewriter.create<GetDimensionSizeOp>(
            loc, RankedTensorType::get({1}, rewriter.getI32Type()), arg,
            rewriter.getI64IntegerAttr(index)));
      }
    };

    getDynamicDims(op.lhs(), lhsContractingDims);
    getDynamicDims(op.rhs(), rhsContractingDims);

    Value reshapeDimsTensor = rewriter.create<ConcatenateOp>(
        loc,
        RankedTensorType::get({static_cast<int64_t>(dynDims.size())},
                              rewriter.getI32Type()),
        dynDims, rewriter.getI64IntegerAttr(0));

    Value result = rewriter.create<DynamicReshapeOp>(
        op.getLoc(),
        RankedTensorType::get(staticDims, resultTy.getElementType()), newDotOp,
        reshapeDimsTensor);

    rewriter.replaceOp(op, result);
    return success();
  }
};

struct LegalizeGeneralDotPass
    : public LegalizeGeneralDotPassBase<LegalizeGeneralDotPass> {
  /// Lower all general dots that can be represented as a non-batched matmul.
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateGeneralDotOpLoweringPatterns(&patterns, &getContext());
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};

}  // namespace
}  // namespace mhlo
}  // namespace mlir

void mlir::mhlo::populateGeneralDotOpLoweringPatterns(
    RewritePatternSet *patterns, MLIRContext *ctx) {
  patterns->add<GeneralDotConvert>(ctx);
}

std::unique_ptr<::mlir::Pass> mlir::mhlo::createLegalizeGeneralDotPass() {
  return std::make_unique<LegalizeGeneralDotPass>();
}