/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file canonicalizes reduction ops in the HLO dialect to match the
// capabilities of the codegen backend.

#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

namespace mlir {
namespace mhlo {
namespace {

// All reduce ops can be divided into the following four types:
// - a) column reduction: only the most significant dimensions are reduced.
// - b) row reduction: only the least significant dimensions are reduced.
// - c) reduce to scalar: all dimensions are reduced.
// - d) others: not supported yet; a transpose could potentially canonicalize
//      these (see the example below).
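//
// For illustration, a reduction over an interior dimension falls into
// case d) and is currently left untouched (a sketch mirroring the examples
// below):
// ```
//   %0 = "mhlo.reduce"(%arg0, ...) ({...})
//     {dimensions = dense<[1]> : tensor<1xi64>} :
//     (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// ```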
//
// Currently we perform the following canonicalizations to match the
// capabilities of the codegen backend.
//
// For case a):
// ====================================================================================
// We convert all column reductions to rank-2 column reductions.
// For example, suppose we have:
// ```
//   func @test(%arg0: tensor<?x?x?xf32>) -> tensor<?x?xf32> {
//     ...
//     %2 = "mhlo.reduce"(%arg0, ...) ({...})
//       {dimensions = dense<[0]> : tensor<1xi64>} :
//       (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
//     return %2 : tensor<?x?xf32>
//   }
// ```
// After conversion:
// ```
//   func @test(%arg0: tensor<?x?x?xf32>) -> tensor<?x?xf32> {
//     // [a, b, c] -> [a, b*c]
//     %1 = "mhlo.dynamic_reshape"(%arg0, ...) :
//       (tensor<?x?x?xf32>, tensor<2xi64>) -> tensor<?x?xf32>
//     %2 = "mhlo.reduce"(%1, ...) ({...})
//       {dimensions = dense<[0]> : tensor<1xi64>} :
//       (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
//     %3 = "mhlo.dynamic_reshape"(%2, ...) :
//       (tensor<?xf32>, tensor<2xi64>) -> tensor<?x?xf32>
//     return %3 : tensor<?x?xf32>
//   }
// ```
//
// For case b):
// ====================================================================================
// We convert all row reductions to rank-2 row reductions.
// For example, suppose we have:
// ```
//   func @test(%arg0: tensor<?x?x?xf32>) -> tensor<?x?xf32> {
//     ...
//     %2 = "mhlo.reduce"(%arg0, ...) ({...})
//       {dimensions = dense<[2]> : tensor<1xi64>} :
//       (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
//     return %2 : tensor<?x?xf32>
//   }
// ```
// After conversion:
// ```
//   func @test(%arg0: tensor<?x?x?xf32>) -> tensor<?x?xf32> {
//     // [a, b, c] -> [a*b, c]
//     %1 = "mhlo.dynamic_reshape"(%arg0, ...) :
//       (tensor<?x?x?xf32>, tensor<2xi64>) -> tensor<?x?xf32>
//     %2 = "mhlo.reduce"(%1, ...) ({...})
//       {dimensions = dense<[1]> : tensor<1xi64>} :
//       (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
//     %3 = "mhlo.dynamic_reshape"(%2, ...) :
//       (tensor<?xf32>, tensor<2xi64>) -> tensor<?x?xf32>
//     return %3 : tensor<?x?xf32>
//   }
// ```
//
// For case c):
// ====================================================================================
// We convert all reduce-to-scalar reductions to rank-2 column reductions.
//
// For example, suppose we have:
// ```
//   func @test(%arg0: tensor<?x?x?xf32>) -> tensor<f32> {
//     ...
//     %2 = "mhlo.reduce"(%arg0, ...) ({...})
//       {dimensions = dense<[0,1,2]> : tensor<3xi64>} :
//       (tensor<?x?x?xf32>, tensor<f32>) -> tensor<f32>
//     return %2 : tensor<f32>
//   }
// ```
// After conversion:
// ```
//   func @test(%arg0: tensor<?x?x?xf32>) -> tensor<f32> {
//     // [a, b, c] -> [a*b*c, 1]
//     %1 = "mhlo.dynamic_reshape"(%arg0, ...) :
//       (tensor<?x?x?xf32>, tensor<2xi64>) -> tensor<?x?xf32>
//     %2 = "mhlo.reduce"(%1, ...) ({...})
//       {dimensions = dense<[0]> : tensor<1xi64>} :
//       (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
//     %3 = "mhlo.reshape"(%2) : (tensor<?xf32>) -> tensor<f32>
//     return %3 : tensor<f32>
//   }
// ```

struct HloCanonicalizeReductionPass
    : HloCanonicalizeReductionPassBase<HloCanonicalizeReductionPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<tensor::TensorDialect>();
  }
  void runOnOperation() override {
    getOperation().walk([&](ReduceOp op) {
      SmallVector<int64_t, 4> dimsToReduce;
      DenseSet<int64_t> dimsToReduceSet;
      for (auto dim : op.dimensions().getValues<APInt>()) {
        dimsToReduce.push_back(dim.getSExtValue());
        dimsToReduceSet.insert(dimsToReduce.back());
      }

      // An empty reduction is a no-op; there is nothing to canonicalize.
      if (dimsToReduce.empty()) return;

      // The reduce input is expected to be a ranked tensor.
      auto ty = op.getOperand(0).getType().dyn_cast<RankedTensorType>();
      if (!ty) return signalPassFailure();
      int rank = ty.getRank();
      int ndimsToReduce = dimsToReduce.size();
      auto elemTy = ty.getElementType();
      llvm::sort(dimsToReduce);

      // Skip case d), which is not supported yet.
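      // The reduction is canonicalizable only if the (sorted) reduced
      // dimensions are contiguous and touch one end of the shape. For
      // example, with rank == 4: [0, 1] (column) and [2, 3] (row) pass,
      // while [1, 2] (interior) and [0, 3] (non-contiguous) are case d).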
      if ((dimsToReduce.back() - dimsToReduce[0]) != (ndimsToReduce - 1) ||
          (dimsToReduce[0] != 0 && dimsToReduce.back() != (rank - 1))) {
        return;
      }

      // Rank-2 row/column reductions are already supported; nothing to do.
      if (rank == 2 && ndimsToReduce == 1) {
        return;
      }

      SmallVector<int64_t, 4> dimsToKeep;
      for (int i = 0; i < rank; ++i) {
        if (!dimsToReduceSet.count(i)) dimsToKeep.push_back(i);
      }

      OpBuilder b(op);
      auto loc = op.getLoc();
      // TODO(disc): unify shape_scalar_type with shape_derivation.
      auto shapeScalarType = b.getIntegerType(32);
      auto one = b.create<arith::ConstantIntOp>(loc, 1ll, shapeScalarType);

      // Function computing the total number of elements in the selected
      // dimensions.
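      // For example, with dimsToReduce == [0, 1], dimProd(dimsToReduce)
      // emits IR computing dim(operand, 0) * dim(operand, 1) as an i32
      // value.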
      auto dimProd = [&](ArrayRef<int64_t> dims) {
        Value nelems = one;
        for (int64_t v : dims) {
          Value dimIndex = b.create<tensor::DimOp>(loc, op.getOperand(0), v);
          nelems = b.create<arith::MulIOp>(
              loc, nelems,
              b.create<arith::IndexCastOp>(loc, shapeScalarType, dimIndex));
        }
        return nelems;
      };

      SmallVector<Value, 2> newOperandDims;
      DenseIntElementsAttr attr;
      Value nelemToReduce = dimProd(dimsToReduce);
      Value nelemToKeep = dimProd(dimsToKeep);
      if (rank == ndimsToReduce) {
        // Case c): reduce to scalar.
        // Currently we don't support reducing to a scalar directly. As a
        // workaround, we convert the reduce-to-scalar into a rank-2 column
        // reduction: suppose nelems = product(shape(input)); we convert the
        // input into shape `[nelems, 1]`.
        // TODO(disc): this may have performance issues. Implement a
        // reduce-to-scalar schedule if necessary.
        newOperandDims.push_back(nelemToReduce);
        newOperandDims.push_back(nelemToKeep);
        attr = DenseIntElementsAttr::get(
            RankedTensorType::get({1}, b.getIntegerType(64)), {0ll});
      } else if (dimsToReduce[0] == 0) {
        // Case a): column reduction.
        newOperandDims.push_back(nelemToReduce);
        newOperandDims.push_back(nelemToKeep);
        attr = DenseIntElementsAttr::get(
            RankedTensorType::get({1}, b.getIntegerType(64)), {0ll});
      } else {
        // Case b): row reduction.
        newOperandDims.push_back(nelemToKeep);
        newOperandDims.push_back(nelemToReduce);
        attr = DenseIntElementsAttr::get(
            RankedTensorType::get({1}, b.getIntegerType(64)), {1ll});
      }
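      // At this point `newOperandDims` holds the two flattened extents. For
      // example, in case b) with an input of shape [a, b, c] and
      // dimensions = [2], the new operand shape is [a*b, c] and the new
      // reduction dimension is 1.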

      Value newOperandShape =
          b.create<tensor::FromElementsOp>(loc, newOperandDims);

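      // Reshape each operand to the flattened rank-2 shape. The static type
      // is fully dynamic, since the flattened extents are only known at
      // runtime.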
      SmallVector<Value, 4> newOperands;
      for (Value operand : op.operands()) {
        newOperands.push_back(b.create<DynamicReshapeOp>(
            loc,
            RankedTensorType::get(
                SmallVector<int64_t, 4>(newOperandDims.size(),
                                        ShapedType::kDynamicSize),
                elemTy),
            operand, newOperandShape));
      }
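      // Create the rank-2 reduce and move the original combiner region into
      // it; the results are reshaped back below.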
      auto newOp =
          b.create<ReduceOp>(loc, newOperands, op.init_values(), attr);
      newOp.body().takeBody(op.body());

      SmallVector<Value, 4> newResults;
      if (dimsToKeep.empty()) {
        // Case c): reduce to scalar.
        // Reshape the rank-1, size-1 result tensors to rank-0 tensors.
        for (Value result : newOp.getResults()) {
          newResults.push_back(b.create<ReshapeOp>(
              loc, RankedTensorType::get({}, elemTy), result));
        }
      } else {
        // Cases a) and b): reshape the results back to the original result
        // shapes.
        SmallVector<Value, 4> resultDims;
        for (int64_t i : dimsToKeep) {
          Value dimIndex = b.create<tensor::DimOp>(loc, op.getOperand(0), i);
          resultDims.push_back(
              b.create<arith::IndexCastOp>(loc, shapeScalarType, dimIndex));
        }
        Value resultShape = b.create<tensor::FromElementsOp>(loc, resultDims);
        for (auto&& e : llvm::zip(op.getResults(), newOp.getResults())) {
          newResults.push_back(b.create<DynamicReshapeOp>(
              loc, std::get<0>(e).getType(), std::get<1>(e), resultShape));
        }
      }
      for (auto&& e : llvm::zip(op.getResults(), newResults)) {
        std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
      }
      op.erase();
    });
  }
};

}  // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
createHloCanonicalizeReductionPass() {
  return std::make_unique<HloCanonicalizeReductionPass>();
}

}  // namespace mhlo
}  // namespace mlir