/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering MHLO general dot to a regular dot.
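//
// For example (illustrative shapes only), a dot_general with no batching
// dimensions such as
//
//   %0 = "mhlo.dot_general"(%lhs, %rhs)
//       : (tensor<3x4xf32>, tensor<3x5xf32>) -> tensor<4x5xf32>
//
// with lhs_contracting_dimensions = [0] and rhs_contracting_dimensions = [0]
// is rewritten into transposes of the operands feeding a plain mhlo.dot,
// with reshapes inserted when an operand has more than one outer or
// contracting dimension.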

#include <sys/types.h>

#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace mhlo {
namespace {

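// Transposes `arg` so that the dimensions listed in `leftDims` come before
// those in `rightDims`, then reshapes the result into a rank-2 tensor whose
// two extents are the products of the `leftDims` and `rightDims` sizes,
// emitting a dynamic reshape when any involved dimension is dynamic. The
// reshape is skipped when the transposed result is already a matrix with a
// single dimension on each side.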
Value transposeReshape(Value arg, Location loc,
                       llvm::ArrayRef<int64_t> leftDims,
                       llvm::ArrayRef<int64_t> rightDims,
                       llvm::ArrayRef<int64_t> argShape,
                       PatternRewriter &rewriter) {
  auto elementType = getElementTypeOrSelf(arg.getType());

  int64_t leftSize = 1;
  for (auto dim : leftDims) {
    leftSize = (ShapedType::isDynamic(argShape[dim]) || leftSize < 0)
                   ? ShapedType::kDynamicSize
                   : leftSize * argShape[dim];
  }

  int64_t rightSize = 1;
  for (auto dim : rightDims) {
    rightSize = (ShapedType::isDynamic(argShape[dim]) || rightSize < 0)
                    ? ShapedType::kDynamicSize
                    : rightSize * argShape[dim];
  }

  // Generate the transpose permutation attribute.
  llvm::SmallVector<int64_t, 5> transposePermutation(leftDims.begin(),
                                                     leftDims.end());
  transposePermutation.append(rightDims.begin(), rightDims.end());

  TensorType transposePermutationType =
      RankedTensorType::get({static_cast<int64_t>(transposePermutation.size())},
                            rewriter.getIntegerType(64));

  auto transposePermutationAttr =
      DenseIntElementsAttr::get(transposePermutationType,
                                llvm::makeArrayRef(transposePermutation))
          .cast<DenseIntElementsAttr>();

  // Compute the resulting shape.
  llvm::SmallVector<int64_t, 5> transposedShape;
  for (auto val : transposePermutation) {
    transposedShape.push_back(argShape[val]);
  }
  // If there is only a single pair of contracting dimensions and the output
  // rank is two, we can skip a needless reshape.
  bool noReshape = transposedShape.size() == 2 && leftDims.size() == 1 &&
                   rightDims.size() == 1;

  // Construct the type. If no reshape is needed, the sparsity (if any) of the
  // input operand is propagated to the output to ensure this information is
  // not lost in the dot operation.
  auto enc = sparse_tensor::getSparseTensorEncoding(arg.getType());
  auto transposeType =
      (enc && noReshape)
          ? RankedTensorType::get(transposedShape, elementType, enc)
          : RankedTensorType::get(transposedShape, elementType);

  // Construct transpose. If no reshape is needed, we are done.
  Value transposeResult = rewriter.create<TransposeOp>(
      loc, transposeType, arg, transposePermutationAttr);
  if (noReshape) return transposeResult;

  // Return the final result.
  auto reshapedType = RankedTensorType::get({leftSize, rightSize}, elementType);

  if (reshapedType.hasStaticShape()) {
    return rewriter.create<ReshapeOp>(loc, reshapedType, transposeResult);
  }

  SmallVector<Value> reshapeDims;
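  // Computes the runtime product of the sizes of the given dimensions of
  // `arg` as a 1-element i32 tensor; the result supplies one extent of the
  // dynamic reshape below.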
  auto multiplyDynamicDims = [&](llvm::ArrayRef<int64_t> dims) -> Value {
    Value dynamicSize = rewriter.create<GetDimensionSizeOp>(
        loc, RankedTensorType::get({1}, rewriter.getI32Type()), arg,
        rewriter.getI64IntegerAttr(dims.front()));

    for (auto idx : dims.drop_front()) {
      Value dim = rewriter.create<GetDimensionSizeOp>(
          loc, RankedTensorType::get({1}, rewriter.getI32Type()), arg,
          rewriter.getI64IntegerAttr(idx));
      dynamicSize = rewriter.create<MulOp>(loc, dynamicSize, dim);
    }
    return dynamicSize;
  };

  if (leftSize < 0) {
    reshapeDims.push_back(multiplyDynamicDims(leftDims));
  } else {
    reshapeDims.push_back(
        rewriter.create<ConstantOp>(loc, rewriter.getI32TensorAttr(leftSize)));
  }

  if (rightSize < 0) {
    reshapeDims.push_back(multiplyDynamicDims(rightDims));
  } else {
    reshapeDims.push_back(
        rewriter.create<ConstantOp>(loc, rewriter.getI32TensorAttr(rightSize)));
  }

  Value reshapeDimsTensor = rewriter.create<ConcatenateOp>(
      loc, RankedTensorType::get({2}, rewriter.getI32Type()), reshapeDims,
      rewriter.getI64IntegerAttr(0));

  return rewriter.create<DynamicReshapeOp>(loc, reshapedType, transposeResult,
                                           reshapeDimsTensor);
}

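// Canonicalizes one dot operand into matrix form by partitioning its
// dimensions into contracting and outer (free) dimensions and delegating to
// transposeReshape. For the LHS the outer dimensions are placed first; for
// the RHS the contracting dimensions come first, matching the operand layout
// expected by mhlo.dot.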
Value processDotArg(Value arg, Location loc, ArrayRef<int64_t> contractDimsAttr,
                    bool outerDimsFirst, PatternRewriter &rewriter) {
  auto shape = arg.getType().cast<ShapedType>().getShape();

  llvm::SmallVector<bool, 5> isOuterDim;
  isOuterDim.resize(shape.size(), true);

  // Compute the contract dimension ordering.
  llvm::SmallVector<int64_t, 5> contractDims;
  for (auto dim : contractDimsAttr) {
    contractDims.push_back(dim);
    isOuterDim[dim] = false;
  }

  // Compute the outer dimension ordering.
  llvm::SmallVector<int64_t, 5> outerDims;
  for (const auto &it : llvm::enumerate(isOuterDim)) {
    if (it.value()) {
      outerDims.push_back(it.index());
    }
  }

  if (outerDimsFirst) {
    return transposeReshape(arg, loc, outerDims, contractDims, shape, rewriter);
  }

  return transposeReshape(arg, loc, contractDims, outerDims, shape, rewriter);
}

struct GeneralDotConvert : public OpRewritePattern<DotGeneralOp> {
  // Attempts to lower a general dot operation to a standard dot operation.
  // General dots include batching dimensions and can have contracting
  // dimensions along any axis. Inserting correctly arranged transpose and
  // reshape operations organizes the tensors and allows the general dot to
  // be replaced with the standard dot operation.
  //
  // Note: This requires an empty list of batching dimensions.
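  //
  // For example (shapes chosen purely for illustration), a dot_general with
  // lhs : tensor<1x2x3xf32>, lhs_contracting_dimensions = [1] and
  // rhs : tensor<2x4xf32>, rhs_contracting_dimensions = [0] is lowered to
  // roughly:
  //
  //   %0 = mhlo.transpose %lhs, dims [0, 2, 1] : tensor<1x3x2xf32>
  //   %1 = mhlo.reshape %0 : tensor<3x2xf32>
  //   %2 = mhlo.transpose %rhs, dims [0, 1] : tensor<2x4xf32>
  //   %3 = mhlo.dot %1, %2 : tensor<3x4xf32>
  //   %4 = mhlo.reshape %3 : tensor<1x3x4xf32>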

  explicit GeneralDotConvert(MLIRContext *context)
      : OpRewritePattern(context) {}

  LogicalResult matchAndRewrite(DotGeneralOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    auto dotNumbers = op.dot_dimension_numbers();
    if (!dotNumbers.getLhsBatchingDimensions().empty() ||
        !dotNumbers.getRhsBatchingDimensions().empty()) {
      return failure();
    }

    auto lhsContractingDims = dotNumbers.getLhsContractingDimensions();
    auto rhsContractingDims = dotNumbers.getRhsContractingDimensions();

    auto lhs = op.lhs();
    auto rhs = op.rhs();

    RankedTensorType lhsTy = lhs.getType().dyn_cast<RankedTensorType>();
    RankedTensorType rhsTy = rhs.getType().dyn_cast<RankedTensorType>();
    if (!lhsTy || !rhsTy) return failure();

    lhs = processDotArg(op.lhs(), op.getLoc(),
                        dotNumbers.getLhsContractingDimensions(),
                        /*outerDimsFirst=*/true, rewriter);

    rhs = processDotArg(op.rhs(), op.getLoc(),
                        dotNumbers.getRhsContractingDimensions(),
                        /*outerDimsFirst=*/false, rewriter);

    // The processed operands must have shaped types.
    auto lhsShapeType = lhs.getType().dyn_cast_or_null<ShapedType>();
    auto rhsShapeType = rhs.getType().dyn_cast_or_null<ShapedType>();
    if (!lhsShapeType || !rhsShapeType) return failure();

    ArrayAttr precisionConfig;
    if (op.precision_config()) precisionConfig = *op.precision_config();
    SmallVector<Type, 1> results;
    LogicalResult res =
        DotOp::inferReturnTypes(rewriter.getContext(), None, {lhs, rhs},
                                op->getAttrDictionary(), {}, results);
    (void)res;
    assert(succeeded(res) && "invalid input to dot");

    ShapedType resultTy = op.getType().cast<ShapedType>();
    ShapedType newTy =
        results.front().cast<ShapedType>().clone(resultTy.getElementType());
    Value newDotOp =
        rewriter.create<DotOp>(op.getLoc(), newTy, lhs, rhs, precisionConfig);
    if (static_cast<int64_t>(lhsContractingDims.size()) ==
            lhsTy.getRank() - 1 &&
        static_cast<int64_t>(rhsContractingDims.size()) ==
            rhsTy.getRank() - 1) {
      rewriter.replaceOp(op, newDotOp);
      return success();
    }

    // We can avoid all the computation below if we know the static shape.
    if (resultTy.hasStaticShape()) {
      rewriter.replaceOpWithNewOp<ReshapeOp>(op, resultTy, newDotOp);
      return success();
    }

    llvm::SmallVector<int64_t> staticDims;
    llvm::SmallVector<Value> dynDims;

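    // Collects, for each non-contracting dimension of `arg` in order, its
    // static size into `staticDims` and a 1-element i32 tensor holding its
    // runtime size into `dynDims`; the latter feed the dynamic reshape of
    // the dot result below.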
    auto getDynamicDims = [&](Value arg,
                              llvm::ArrayRef<int64_t> contractingDims) {
      RankedTensorType ty = arg.getType().cast<RankedTensorType>();
      int index = 0;
      for (auto contractingDim : contractingDims) {
        for (; index < contractingDim; index++) {
          staticDims.push_back(ty.getDimSize(index));
          dynDims.push_back(rewriter.create<GetDimensionSizeOp>(
              loc, RankedTensorType::get({1}, rewriter.getI32Type()), arg,
              rewriter.getI64IntegerAttr(index)));
        }
        index++;
      }

      for (; index < ty.getRank(); index++) {
        staticDims.push_back(ty.getDimSize(index));
        dynDims.push_back(rewriter.create<GetDimensionSizeOp>(
            loc, RankedTensorType::get({1}, rewriter.getI32Type()), arg,
            rewriter.getI64IntegerAttr(index)));
      }
    };

    getDynamicDims(op.lhs(), lhsContractingDims);
    getDynamicDims(op.rhs(), rhsContractingDims);

    Value reshapeDimsTensor = rewriter.create<ConcatenateOp>(
        loc,
        RankedTensorType::get({static_cast<int64_t>(dynDims.size())},
                              rewriter.getI32Type()),
        dynDims, rewriter.getI64IntegerAttr(0));

    Value result = rewriter.create<DynamicReshapeOp>(
        op.getLoc(),
        RankedTensorType::get(staticDims, resultTy.getElementType()), newDotOp,
        reshapeDimsTensor);

    rewriter.replaceOp(op, result);
    return success();
  }
};

struct LegalizeGeneralDotPass
    : public LegalizeGeneralDotPassBase<LegalizeGeneralDotPass> {
  /// Lower all general dots that can be represented as a non-batched matmul.
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateGeneralDotOpLoweringPatterns(&patterns, &getContext());
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};

}  // namespace
}  // namespace mhlo
}  // namespace mlir

void mlir::mhlo::populateGeneralDotOpLoweringPatterns(
    RewritePatternSet *patterns, MLIRContext *ctx) {
  patterns->add<GeneralDotConvert>(ctx);
}

std::unique_ptr<::mlir::Pass> mlir::mhlo::createLegalizeGeneralDotPass() {
  return std::make_unique<LegalizeGeneralDotPass>();
}