/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/thlo/IR/thlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace mhlo {
namespace {

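// Returns true if `array` is the sequence 0, 1, ..., array.size() - 1. If
// `expectedSize` is set (i.e. not -1), the array must also have exactly that
// many elements.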
bool isIotaArray(llvm::ArrayRef<int64_t> array, int expectedSize = -1) {
  if (expectedSize != -1 && static_cast<int>(array.size()) != expectedSize)
    return false;
  for (int64_t i = 0, e = array.size(); i < e; ++i) {
    if (i != array[i]) return false;
  }
  return true;
}

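// Rewrites mhlo.concatenate as a destination-passing-style thlo.concatenate,
// materializing an init tensor of the result shape. Schematically (types and
// most attributes elided for illustration):
//
//   %result = "mhlo.concatenate"(%lhs, %rhs) {dimension = 1 : i64}
//
// becomes roughly
//
//   %init = linalg.init_tensor [...]   // result-shaped
//   %result = thlo.concatenate ins(%lhs, %rhs) outs(%init) {dimension = 1}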
struct ConcatenateOpPattern : public OpRewritePattern<mhlo::ConcatenateOp> {
  using OpRewritePattern<mhlo::ConcatenateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::ConcatenateOp op,
                                PatternRewriter& rewriter) const override {
    const uint64_t concatDim = op.dimension();
    const Location loc = op.getLoc();
    const Value anyOperand = op.val().front();

    auto resultTy = op.getResult().getType().cast<RankedTensorType>();
    const ArrayRef<int64_t> resultShape = resultTy.getShape();
    const int64_t rank = resultTy.getRank();

    // Determine the init tensor size.
    SmallVector<int64_t> staticInitSizes(resultShape.begin(),
                                         resultShape.end());
    SmallVector<Value> dynamicInitSizes;
    for (int64_t i = 0; i < rank; ++i) {
      // No need to materialize anything for static dimensions.
      if (staticInitSizes[i] != ShapedType::kDynamicSize) {
        continue;
      }

      // For all dimensions other than the concatenation dimension, we can
      // copy the size from any operand.
      if (i != concatDim) {
        dynamicInitSizes.push_back(
            rewriter.create<tensor::DimOp>(loc, anyOperand, i));
        continue;
      }

      // For the concatenation dimension, sum up the sizes of all operands in
      // that dimension.
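      // E.g., for operand sizes ?, 4, ?, 5 along the concatenation dimension
      // this builds (%dim0 + %dim1) + 9: dynamic sizes are accumulated one
      // by one, and all static sizes are folded into a single constant
      // summand.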
      int64_t staticSum = 0;
      Value dynamicSum;
      for (const Value operand : op.val()) {
        auto operandTy = operand.getType().cast<RankedTensorType>();
        if (operandTy.getDimSize(concatDim) == ShapedType::kDynamicSize) {
          const Value dynamicSummand =
              rewriter.create<tensor::DimOp>(loc, operand, concatDim);
          if (dynamicSum) {
            dynamicSum = rewriter.create<arith::AddIOp>(loc, dynamicSum,
                                                        dynamicSummand);
          } else {
            dynamicSum = dynamicSummand;
          }
        } else {
          staticSum += operandTy.getDimSize(concatDim);
        }
      }
      assert(dynamicSum && "expect at least one dynamic summand in this case");
      if (staticSum != 0) {
        dynamicSum = rewriter.create<arith::AddIOp>(
            loc, dynamicSum,
            rewriter.create<arith::ConstantIndexOp>(loc, staticSum));
      }
      dynamicInitSizes.push_back(dynamicSum);
    }

    // Create the init tensor and the new concat op.
    auto init = rewriter.create<linalg::InitTensorOp>(
        loc, dynamicInitSizes, staticInitSizes, resultTy.getElementType());
    rewriter.replaceOpWithNewOp<thlo::ConcatenateOp>(op, resultTy, op.val(),
                                                     init, concatDim);
    return success();
  }
};

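// Rewrites mhlo.dynamic_broadcast_in_dim as a destination-passing-style
// thlo.dynamic_broadcast_in_dim. Only broadcasts whose expansion behavior is
// not fully known at compile time are rewritten; the fully specified cases
// are covered by the lowering to linalg.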
struct DynamicBroadcastInDimOpPattern
    : public OpRewritePattern<mhlo::DynamicBroadcastInDimOp> {
  using OpRewritePattern<mhlo::DynamicBroadcastInDimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::DynamicBroadcastInDimOp op,
                                PatternRewriter& rewriter) const override {
    auto loc = op.getLoc();
    Value outputDimensions = op.output_dimensions();
    auto operandTy = op.operand().getType().cast<RankedTensorType>();
    auto resultTy = op.getType().cast<RankedTensorType>();

    // Only apply to broadcasts that cannot be lowered to linalg, i.e. those
    // whose expansion behavior is not fully known at compile time.
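    // E.g., a rank-2 operand with known_expanding_dimensions = [0] and
    // known_nonexpanding_dimensions = [1] is fully specified, so the linalg
    // lowering handles it and this pattern bails out.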
    int64_t countKnownExpansionBehavior = 0;
    if (auto expandingDims = op.known_expanding_dimensions()) {
      countKnownExpansionBehavior += expandingDims->size();
    }
    if (auto nonexpandingDims = op.known_nonexpanding_dimensions()) {
      countKnownExpansionBehavior += nonexpandingDims->size();
    }
    if (operandTy.getRank() == countKnownExpansionBehavior) return failure();

    // Create init tensor as none of the operands are reusable/updatable.
    SmallVector<Value> dynamicDims;
    SmallVector<int64_t> staticShapeInfo;
    for (int i = 0; i < resultTy.getRank(); i++) {
      dynamicDims.push_back(rewriter.create<tensor::ExtractOp>(
          loc, outputDimensions,
          ValueRange{rewriter.create<arith::ConstantIndexOp>(loc, i)}));
      staticShapeInfo.push_back(ShapedType::kDynamicSize);
    }
    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, dynamicDims, staticShapeInfo, resultTy.getElementType());

    // TODO(akuegel): Add a builder for getDenseI64ArrayAttr upstream.
    auto broadcastDims = DenseI64ArrayAttr::get(
        rewriter.getContext(),
        llvm::to_vector(
            llvm::map_range(op.broadcast_dimensions(), [](const auto& d) {
              return static_cast<int64_t>(d.getLimitedValue());
            })));
    DenseI64ArrayAttr knownExpandingDims;
    if (op.known_expanding_dimensions().has_value()) {
      knownExpandingDims = DenseI64ArrayAttr::get(
          rewriter.getContext(),
          llvm::to_vector(llvm::map_range(
              op.known_expanding_dimensionsAttr(), [](const auto& d) {
                return static_cast<int64_t>(d.getLimitedValue());
              })));
    }
    DenseI64ArrayAttr knownNonexpandingDims;
    if (op.known_nonexpanding_dimensions().has_value()) {
      knownNonexpandingDims = DenseI64ArrayAttr::get(
          rewriter.getContext(),
          llvm::to_vector(llvm::map_range(
              op.known_nonexpanding_dimensionsAttr(), [](const auto& d) {
                return static_cast<int64_t>(d.getLimitedValue());
              })));
    }

    rewriter.replaceOpWithNewOp<thlo::DynamicBroadcastInDimOp>(
        op, resultTy, op.operand(), initTensor, broadcastDims,
        knownExpandingDims, knownNonexpandingDims);
    return success();
  }
};

// Rewrites simple gather patterns (as checked below), i.e. gathers that
// extract single elements addressed by complete index vectors.
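//
// Schematically (types and attributes elided for illustration), a qualifying
//
//   %result = "mhlo.gather"(%operand, %start_indices) {...}
//
// becomes roughly
//
//   %init = linalg.init_tensor [...]   // statically shaped result
//   %result = thlo.gather ins(%operand, %start_indices) outs(%init)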
struct GatherPattern : public OpRewritePattern<mhlo::GatherOp> {
  using OpRewritePattern<mhlo::GatherOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::GatherOp op,
                                PatternRewriter& rewriter) const override {
    auto startIndicesType =
        op.start_indices().getType().dyn_cast<RankedTensorType>();
    auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();

    if (!startIndicesType || !operandType) return failure();

    // index_vector_dim must be the last dimension of start_indices.
    int indexVectorDim = op.dimension_numbers().getIndexVectorDim();
    if (startIndicesType.getRank() - 1 != indexVectorDim) return failure();

    // All slice_sizes must be 1.
    if (!llvm::all_of(op.slice_sizes(), [](auto size) { return size == 1; }))
      return failure();

    // offset_dims must be [].
    if (!op.dimension_numbers().getOffsetDims().empty()) return failure();

    // collapsed_slice_dims must be range(operand.rank).
    auto collapsedSliceDims = op.dimension_numbers().getCollapsedSliceDims();
    if (!isIotaArray(collapsedSliceDims, operandType.getRank()))
      return failure();

    // start_index_map must be range(start_indices.shape[index_vector_dim]).
    auto startIndexMap = op.dimension_numbers().getStartIndexMap();
    if (!isIotaArray(startIndexMap,
                     startIndicesType.getShape()[indexVectorDim]))
      return failure();

    // The shape of the result must be statically known.
    if (op.getType().getNumDynamicDims() > 0) return failure();

    auto loc = op.getLoc();
    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, mlir::ValueRange{}, op.getType().getShape(),
        op.getType().getElementType());
    rewriter.replaceOpWithNewOp<thlo::GatherOp>(op, op.getType(), op.operand(),
                                                op.start_indices(), initTensor);
    return success();
  }
};

// Rewrites simple scatter patterns, i.e. non-variadic point updates whose
// update computation is a plain addition.
struct ScatterPattern : public OpRewritePattern<mhlo::ScatterOp> {
  using OpRewritePattern<mhlo::ScatterOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::ScatterOp op,
                                PatternRewriter& rewriter) const override {
    // The variadic case is not supported.
    if (op.updates().size() != 1) return failure();

    // The update computation must be a sum of the two block arguments.
    if (matchUpdateComputation(op.update_computation()).failed())
      return failure();

    const auto& dims = op.scatter_dimension_numbers();
    auto scatterIndicesType =
        op.scatter_indices().getType().dyn_cast<RankedTensorType>();
    if (!scatterIndicesType) return failure();

    // Only point updates are supported.
    //  - update_window_dims is []
    //  - inserted_window_dims is range(operand.shape.rank)
    //  - scatter_dims_to_operand_dims is range(scatter_indices.shape.rank)
    //  - index_vector_dim is scatter_indices.shape.rank - 1
    if (!dims.getUpdateWindowDims().empty() ||
        !isIotaArray(dims.getInsertedWindowDims()) ||
        !isIotaArray(dims.getScatterDimsToOperandDims()) ||
        dims.getIndexVectorDim() != scatterIndicesType.getRank() - 1)
      return failure();

    auto opType = op.getType(0).dyn_cast<ShapedType>();
    if (!opType)
      return failure();  // The type is a tensor in the non-variadic case.

    rewriter.replaceOpWithNewOp<thlo::ScatterOp>(
        op, opType, op.scatter_indices(), op.updates().front(),
        op.operands().front());
    return success();
  }

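  // Matches an update computation that is a single addition of the two block
  // arguments, in either operand order. Schematically, with the element type
  // shown as f32 for illustration:
  //
  //   ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
  //     %sum = mhlo.add %lhs, %rhs : tensor<f32>
  //     mhlo.return %sum : tensor<f32>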
  LogicalResult matchUpdateComputation(mlir::Region& computation) const {
    Block& block = computation.front();
    if (block.getNumArguments() != 2) return failure();

    mhlo::ReturnOp returnOp = dyn_cast<mhlo::ReturnOp>(block.getTerminator());
    if (!returnOp || returnOp.getNumOperands() != 1) return failure();

    auto* returnOperand = returnOp.getOperand(0).getDefiningOp();
    auto addOp = dyn_cast<mhlo::AddOp>(returnOperand);
    if (!addOp || addOp->getNumOperands() != 2) return failure();

    auto lhs = addOp->getOperand(0);
    auto rhs = addOp->getOperand(1);
    auto arg0 = block.getArgument(0);
    auto arg1 = block.getArgument(1);

    return success((lhs == arg0 && rhs == arg1) ||
                   (lhs == arg1 && rhs == arg0));
  }
};

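// Pass that replaces the supported mhlo ops above with their
// destination-passing-style thlo counterparts.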
class LegalizeMHLOToTHLOPass
    : public LegalizeMHLOToTHLOPassBase<LegalizeMHLOToTHLOPass> {
  void getDependentDialects(DialectRegistry& registry) const final {
    registry.insert<thlo::THLODialect, linalg::LinalgDialect>();
  }

  void runOnOperation() final {
    MLIRContext* ctx = &getContext();
    RewritePatternSet patterns(ctx);

    // List of patterns.
    // clang-format off
    patterns.insert<
        ConcatenateOpPattern,
        DynamicBroadcastInDimOpPattern,
        GatherPattern,
        ScatterPattern>(ctx);
    // clang-format on

    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};

}  // namespace

std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeMHLOToTHLOPass() {
  return std::make_unique<LegalizeMHLOToTHLOPass>();
}

}  // namespace mhlo
}  // namespace mlir