/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file canonicalizes reduction ops in the HLO dialect to match the
// capabilities of the codegen backend.

#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

namespace mlir {
namespace mhlo {
namespace {

// All the reduce ops can be divided into the following four types:
//  - a) column reduction: only the most significant dimensions are reduced.
//  - b) row reduction: only the least significant dimensions are reduced.
//  - c) reduce to scalar: all dimensions are reduced.
//  - d) others: not supported yet (a transpose could potentially be used to
//       canonicalize them); see the example below.
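//       For example, reducing only a middle dimension of a rank-3 tensor is
//       neither a row nor a column reduction, so this pass leaves it
//       untouched:
//         %0 = "mhlo.reduce"(%arg0, ...) ({...})
//           {dimensions = dense<[1]> : tensor<1xi64>} :
//           (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?x?xf32>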
//
// Currently we perform the following canonicalizations to match the
// capabilities of the codegen backend.
//
// For case a):
// ====================================================================================
//   We convert all column reductions to rank-2 column reductions.
//   For example, suppose we have:
//   ```
//     func @test(%arg0: tensor<?x?x?xf32>) -> tensor<?x?xf32> {
//       ...
//       %2 = "mhlo.reduce"(%arg0, ...) ({...})
//         {dimensions = dense<[0]> : tensor<1xi64>} :
//         (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
//       return %2 : tensor<?x?xf32>
//     }
//   ```
//   After conversion:
//   ```
//     func @test(%arg0: tensor<?x?x?xf32>) -> tensor<?x?xf32> {
//       // [a, b, c] -> [a, b*c]
//       %1 = "mhlo.dynamic_reshape"(%arg0, ...) :
//         (tensor<?x?x?xf32>, tensor<2xi64>) -> tensor<?x?xf32>
//       %2 = "mhlo.reduce"(%1, ...) ({...})
//         {dimensions = dense<[0]> : tensor<1xi64>} :
//         (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
//       %3 = "mhlo.dynamic_reshape"(%2, ...) :
//         (tensor<?xf32>, tensor<2xi64>) -> tensor<?x?xf32>
//       return %3 : tensor<?x?xf32>
//     }
//   ```
//
// For case b):
// ====================================================================================
//   We convert all row reductions to rank-2 row reductions.
//   For example, suppose we have:
//   ```
//     func @test(%arg0: tensor<?x?x?xf32>) -> tensor<?x?xf32> {
//       ...
//       %2 = "mhlo.reduce"(%arg0, ...) ({...})
//         {dimensions = dense<[2]> : tensor<1xi64>} :
//         (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
//       return %2 : tensor<?x?xf32>
//     }
//   ```
//   After conversion:
//   ```
//     func @test(%arg0: tensor<?x?x?xf32>) -> tensor<?x?xf32> {
//       // [a, b, c] -> [a*b, c]
//       %1 = "mhlo.dynamic_reshape"(%arg0, ...) :
//         (tensor<?x?x?xf32>, tensor<2xi64>) -> tensor<?x?xf32>
//       %2 = "mhlo.reduce"(%1, ...) ({...})
//         {dimensions = dense<[1]> : tensor<1xi64>} :
//         (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
//       %3 = "mhlo.dynamic_reshape"(%2, ...) :
//         (tensor<?xf32>, tensor<2xi64>) -> tensor<?x?xf32>
//       return %3 : tensor<?x?xf32>
//     }
//   ```
//
// For case c):
// ====================================================================================
//   We convert every reduce-to-scalar to a rank-2 column reduction.
//
//   For example, suppose we have:
//   ```
//     func @test(%arg0: tensor<?x?x?xf32>) -> tensor<f32> {
//       ...
//       %2 = "mhlo.reduce"(%arg0, ...) ({...})
//         {dimensions = dense<[0,1,2]> : tensor<3xi64>} :
//         (tensor<?x?x?xf32>, tensor<f32>) -> tensor<f32>
//       return %2 : tensor<f32>
//     }
//   ```
//   After conversion:
//   ```
//     func @test(%arg0: tensor<?x?x?xf32>) -> tensor<f32> {
//       // [a, b, c] -> [a*b*c, 1]
//       %1 = "mhlo.dynamic_reshape"(%arg0, ...) :
//         (tensor<?x?x?xf32>, tensor<2xi64>) -> tensor<?x?xf32>
//       %2 = "mhlo.reduce"(%1, ...) ({...})
//         {dimensions = dense<[0]> : tensor<1xi64>} :
//         (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
//       %3 = "mhlo.reshape"(%2) : (tensor<?xf32>) -> tensor<f32>
//       return %3 : tensor<f32>
//     }
//   ```

struct HloCanonicalizeReductionPass
    : HloCanonicalizeReductionPassBase<HloCanonicalizeReductionPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<tensor::TensorDialect>();
  }
  void runOnOperation() override {
    getOperation().walk([&](ReduceOp op) {
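      // Collect the dimensions to reduce, both as a vector (so they can be
      // sorted below) and as a set (for fast membership tests).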
      SmallVector<int64_t, 4> dimsToReduce;
      DenseSet<int64_t> dimsToReduceSet;
      for (auto dim : op.dimensions().getValues<APInt>()) {
        dimsToReduce.push_back(dim.getSExtValue());
        dimsToReduceSet.insert(dimsToReduce.back());
      }

      // An empty reduction is just a no-op, so no codegen is needed.
      if (dimsToReduce.empty()) return;

      // The reduce input is expected to be a ranked tensor.
      auto ty = op.getOperand(0).getType().dyn_cast<RankedTensorType>();
      if (!ty) return signalPassFailure();
      int rank = ty.getRank();
      int ndimsToReduce = dimsToReduce.size();
      auto elemTy = ty.getElementType();
      llvm::sort(dimsToReduce);

      // Skip case d) since it is not supported: the reduced dimensions must
      // be contiguous and must include either the first or the last
      // dimension.
      if ((dimsToReduce.back() - dimsToReduce[0]) != (ndimsToReduce - 1) ||
          (dimsToReduce[0] != 0 && dimsToReduce.back() != (rank - 1))) {
        return;
      }

      // Rank-2 row/column reductions are already supported; nothing to do.
      if (rank == 2 && ndimsToReduce == 1) {
        return;
      }

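      // The dimensions to keep are the complement of the dimensions to
      // reduce.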
      SmallVector<int64_t, 4> dimsToKeep;
      for (int i = 0; i < rank; ++i) {
        if (!dimsToReduceSet.count(i)) dimsToKeep.push_back(i);
      }

      OpBuilder b(op);
      auto loc = op.getLoc();
      // TODO(disc): unify shape_scalar_type with shape_derivation.
      auto shapeScalarType = b.getIntegerType(32);
      auto one = b.create<arith::ConstantIntOp>(loc, 1ll, shapeScalarType);

      // Returns the total number of elements in the selected dimensions.
      auto dimProd = [&](ArrayRef<int64_t> dims) {
        Value nelems = one;
        for (int64_t v : dims) {
          Value dimIndex = b.create<tensor::DimOp>(loc, op.getOperand(0), v);
          nelems = b.create<arith::MulIOp>(
              loc, nelems,
              b.create<arith::IndexCastOp>(loc, shapeScalarType, dimIndex));
        }
        return nelems;
      };

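      // Compute the two extents of the canonical rank-2 operand shape, and
      // the attribute holding the single dimension the new reduce operates
      // on.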
      SmallVector<Value, 2> newOperandDims;
      DenseIntElementsAttr attr;
      Value nelemToReduce = dimProd(dimsToReduce);
      Value nelemToKeep = dimProd(dimsToKeep);
      if (rank == ndimsToReduce) {
        // case c) reduce to scalar.
        // Currently we don't support reducing to a scalar directly. As a
        // workaround, we convert the `reduce to scalar` into a rank-2 column
        // reduction: with nelems = product(shape(input)), the input is
        // reshaped to `[nelems, 1]`.
        // TODO(disc): this may have performance issues. Implement a dedicated
        // reduce-to-scalar schedule if necessary.
        newOperandDims.push_back(nelemToReduce);
        newOperandDims.push_back(nelemToKeep);
        attr = DenseIntElementsAttr::get(
            RankedTensorType::get({1}, b.getIntegerType(64)), {0ll});
      } else if (dimsToReduce[0] == 0) {
        // case a) column reduction
        newOperandDims.push_back(nelemToReduce);
        newOperandDims.push_back(nelemToKeep);
        attr = DenseIntElementsAttr::get(
            RankedTensorType::get({1}, b.getIntegerType(64)), {0ll});
      } else {
        // case b) row reduction
        newOperandDims.push_back(nelemToKeep);
        newOperandDims.push_back(nelemToReduce);
        attr = DenseIntElementsAttr::get(
            RankedTensorType::get({1}, b.getIntegerType(64)), {1ll});
      }

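      // Materialize the new operand shape as a rank-1 tensor, as expected by
      // mhlo.dynamic_reshape.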
      Value newOperandShape =
          b.create<tensor::FromElementsOp>(loc, newOperandDims);

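      // Reshape every operand of the (variadic) reduce to the canonical
      // rank-2 shape. The static result type is fully dynamic; the actual
      // extents are carried by newOperandShape.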
      SmallVector<Value, 4> newOperands;
      for (Value operand : op.operands()) {
        newOperands.push_back(b.create<DynamicReshapeOp>(
            loc,
            RankedTensorType::get(
                SmallVector<int64_t, 4>(newOperandDims.size(),
                                        ShapedType::kDynamicSize),
                elemTy),
            operand, newOperandShape));
      }
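      // Create the rank-2 reduce and move the original reduction body into
      // it.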
      auto newOp = b.create<ReduceOp>(loc, newOperands, op.init_values(), attr);
      newOp.body().takeBody(op.body());

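      // Reshape the results of the new reduce back to the shapes of the
      // original results.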
      SmallVector<Value, 4> newResults;
      if (dimsToKeep.empty()) {
        // case c) reduce to scalar:
        // reshape the rank-1 tensor of size 1 into a rank-0 tensor.
        for (Value result : newOp.getResults()) {
          newResults.push_back(b.create<ReshapeOp>(
              loc, RankedTensorType::get({}, elemTy), result));
        }
      } else {
        SmallVector<Value, 4> resultDims;
        for (int64_t i : dimsToKeep) {
          Value dimIndex = b.create<tensor::DimOp>(loc, op.getOperand(0), i);
          resultDims.push_back(
              b.create<arith::IndexCastOp>(loc, shapeScalarType, dimIndex));
        }
        Value resultShape = b.create<tensor::FromElementsOp>(loc, resultDims);
        for (auto&& e : llvm::zip(op.getResults(), newOp.getResults())) {
          newResults.push_back(b.create<DynamicReshapeOp>(
              loc, std::get<0>(e).getType(), std::get<1>(e), resultShape));
        }
      }
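      // Replace all uses of the original results and erase the original op.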
      for (auto&& e : llvm::zip(op.getResults(), newResults)) {
        std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
      }
      op.erase();
    });
  }
};

}  // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
createHloCanonicalizeReductionPass() {
  return std::make_unique<HloCanonicalizeReductionPass>();
}

}  // namespace mhlo
}  // namespace mlir