/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/thlo/IR/thlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace mhlo {
namespace {

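// Returns true if `array` is [0, 1, ..., array.size() - 1]. If `expectedSize`
// is given (i.e. not -1), the array must additionally have exactly that many
// elements, e.g. isIotaArray({0, 1, 2}, 3) is true, isIotaArray({0, 2}) is
// false.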
bool isIotaArray(llvm::ArrayRef<int64_t> array, int expectedSize = -1) {
  if (expectedSize != -1 && static_cast<int>(array.size()) != expectedSize)
    return false;
  for (int64_t i = 0, e = array.size(); i < e; ++i) {
    if (i != array[i]) return false;
  }
  return true;
}

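// Rewrites mhlo.concatenate into a thlo.concatenate that writes into a
// linalg.init_tensor destination. The size of the init tensor in the
// concatenation dimension is the sum of the operand sizes in that dimension;
// static summands are folded into a single constant, and dynamic summands are
// materialized as arith ops. Sketch of the rewrite (IR for illustration only;
// the exact thlo.concatenate assembly format may differ):
//
//   "mhlo.concatenate"(%a, %b) {dimension = 0 : i64}
//       : (tensor<?xf32>, tensor<4xf32>) -> tensor<?xf32>
//
// becomes a thlo.concatenate of %a and %b into an init tensor of size
// dim(%a, 0) + 4.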
struct ConcatenateOpPattern : public OpRewritePattern<mhlo::ConcatenateOp> {
  using OpRewritePattern<mhlo::ConcatenateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::ConcatenateOp op,
                                PatternRewriter& rewriter) const override {
    const uint64_t concatDim = op.dimension();
    const Location loc = op.getLoc();
    const Value anyOperand = op.val().front();

    auto resultTy = op.getResult().getType().cast<RankedTensorType>();
    const ArrayRef<int64_t> resultShape = resultTy.getShape();
    const int64_t rank = resultTy.getRank();

    // Determine init tensor size.
    SmallVector<int64_t> staticInitSizes(resultShape.begin(),
                                         resultShape.end());
    SmallVector<Value> dynamicInitSizes;
    for (int64_t i = 0; i < rank; ++i) {
      // No need to materialize anything for static dimensions.
      if (staticInitSizes[i] != ShapedType::kDynamicSize) {
        continue;
      }

      // For all dimensions other than the concatenation dimension, we can copy
      // the size from any operand.
      if (i != concatDim) {
        dynamicInitSizes.push_back(
            rewriter.create<tensor::DimOp>(loc, anyOperand, i));
        continue;
      }
      // For the concatenation dimension, sum up the sizes of all operands in
      // that dimension.
      int64_t staticSum = 0;
      Value dynamicSum;
      for (const Value operand : op.val()) {
        auto operandTy = operand.getType().cast<RankedTensorType>();
        if (operandTy.getDimSize(concatDim) == ShapedType::kDynamicSize) {
          const Value dynamicSummand =
              rewriter.create<tensor::DimOp>(loc, operand, concatDim);
          if (dynamicSum) {
            dynamicSum =
                rewriter.create<arith::AddIOp>(loc, dynamicSum, dynamicSummand);
          } else {
            dynamicSum = dynamicSummand;
          }
        } else {
          staticSum += operandTy.getDimSize(concatDim);
        }
      }
      assert(dynamicSum && "expect at least one dynamic summand in this case");
      if (staticSum != 0) {
        dynamicSum = rewriter.create<arith::AddIOp>(
            loc, dynamicSum,
            rewriter.create<arith::ConstantIndexOp>(loc, staticSum));
      }
      dynamicInitSizes.push_back(dynamicSum);
    }

    // Create init tensor and the new concat op.
    auto init = rewriter.create<linalg::InitTensorOp>(
        loc, dynamicInitSizes, staticInitSizes, resultTy.getElementType());
    rewriter.replaceOpWithNewOp<thlo::ConcatenateOp>(op, resultTy, op.val(),
                                                     init, concatDim);
    return success();
  }
};

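// Rewrites mhlo.dynamic_broadcast_in_dim to thlo.dynamic_broadcast_in_dim.
// The pattern only fires when the expansion behavior of at least one operand
// dimension is unknown at compile time; fully annotated broadcasts are left
// for the lowering to linalg.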
struct DynamicBroadcastInDimOpPattern
    : public OpRewritePattern<mhlo::DynamicBroadcastInDimOp> {
  using OpRewritePattern<mhlo::DynamicBroadcastInDimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::DynamicBroadcastInDimOp op,
                                PatternRewriter& rewriter) const override {
    auto loc = op.getLoc();
    Value outputDimensions = op.output_dimensions();
    auto operandTy = op.operand().getType().cast<RankedTensorType>();
    auto resultTy = op.getType().cast<RankedTensorType>();
    // Only apply to broadcasts that cannot be lowered to linalg, i.e., those
    // whose expansion behavior is not fully known at compile time.
    int64_t countKnownExpansionBehavior = 0;
    if (auto expandingDims = op.known_expanding_dimensions()) {
      countKnownExpansionBehavior += expandingDims->size();
    }
    if (auto nonexpandingDims = op.known_nonexpanding_dimensions()) {
      countKnownExpansionBehavior += nonexpandingDims->size();
    }
    if (operandTy.getRank() == countKnownExpansionBehavior) return failure();

    // Create init tensor as none of the operands are reusable/updatable.
    SmallVector<Value> dynamicDims;
    SmallVector<int64_t> staticShapeInfo;
    for (int i = 0; i < resultTy.getRank(); i++) {
      dynamicDims.push_back(rewriter.create<tensor::ExtractOp>(
          loc, outputDimensions,
          ValueRange{rewriter.create<arith::ConstantIndexOp>(loc, i)}));
      staticShapeInfo.push_back(ShapedType::kDynamicSize);
    }
    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, dynamicDims, staticShapeInfo, resultTy.getElementType());

    // TODO(akuegel): Add a builder for getDenseI64ArrayAttr upstream.
    auto broadcastDims = DenseI64ArrayAttr::get(
        rewriter.getContext(),
        llvm::to_vector(
            llvm::map_range(op.broadcast_dimensions(), [](const auto& d) {
              return static_cast<int64_t>(d.getLimitedValue());
            })));
    DenseI64ArrayAttr knownExpandingDims;
    if (op.known_expanding_dimensions().has_value()) {
      knownExpandingDims = DenseI64ArrayAttr::get(
          rewriter.getContext(),
          llvm::to_vector(llvm::map_range(
              op.known_expanding_dimensionsAttr(), [](const auto& d) {
                return static_cast<int64_t>(d.getLimitedValue());
              })));
    }
    DenseI64ArrayAttr knownNonexpandingDims;
    if (op.known_nonexpanding_dimensions().has_value()) {
      knownNonexpandingDims = DenseI64ArrayAttr::get(
          rewriter.getContext(),
          llvm::to_vector(llvm::map_range(
              op.known_nonexpanding_dimensionsAttr(), [](const auto& d) {
                return static_cast<int64_t>(d.getLimitedValue());
              })));
    }

    rewriter.replaceOpWithNewOp<thlo::DynamicBroadcastInDimOp>(
        op, resultTy, op.operand(), initTensor, broadcastDims,
        knownExpandingDims, knownNonexpandingDims);
    return success();
  }
};

// Rewrites simple gather patterns (as checked below).
struct GatherPattern : public OpRewritePattern<mhlo::GatherOp> {
  using OpRewritePattern<mhlo::GatherOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::GatherOp op,
                                PatternRewriter& rewriter) const override {
    auto startIndicesType =
        op.start_indices().getType().dyn_cast<RankedTensorType>();
    auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();

    if (!startIndicesType || !operandType) return failure();

    // index_vector_dim must be the last dimension of start_indices.
    int indexVectorDim = op.dimension_numbers().getIndexVectorDim();
    if (startIndicesType.getRank() - 1 != indexVectorDim) return failure();

    // All slice_sizes must be 1.
    if (!llvm::all_of(op.slice_sizes(), [](auto size) { return size == 1; }))
      return failure();

    // offset_dims must be [].
    if (!op.dimension_numbers().getOffsetDims().empty()) return failure();

    // collapsed_slice_dims must be range(operand.rank).
    auto collapsedSliceDims = op.dimension_numbers().getCollapsedSliceDims();
    if (!isIotaArray(collapsedSliceDims, operandType.getRank()))
      return failure();

    // start_index_map must be range(start_indices.shape[index_vector_dim]).
    auto startIndexMap = op.dimension_numbers().getStartIndexMap();
    if (!isIotaArray(startIndexMap,
                     startIndicesType.getShape()[indexVectorDim]))
      return failure();

    // The shape of the result must be statically known.
    if (op.getType().getNumDynamicDims() > 0) return failure();

    auto loc = op.getLoc();
    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, mlir::ValueRange{}, op.getType().getShape(),
        op.getType().getElementType());
    rewriter.replaceOpWithNewOp<thlo::GatherOp>(op, op.getType(), op.operand(),
                                                op.start_indices(), initTensor);
    return success();
  }
};

// Rewrites simple scatter patterns: non-variadic point updates whose update
// computation is a plain addition of the two block arguments.
struct ScatterPattern : public OpRewritePattern<mhlo::ScatterOp> {
  using OpRewritePattern<mhlo::ScatterOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::ScatterOp op,
                                PatternRewriter& rewriter) const override {
    // The variadic case is not supported.
    if (op.updates().size() != 1) return failure();

    // update_computation must be a sum.
    if (matchUpdateComputation(op.update_computation()).failed())
      return failure();

    const auto& dims = op.scatter_dimension_numbers();
    auto scatterIndicesType =
        op.scatter_indices().getType().dyn_cast<RankedTensorType>();
    if (!scatterIndicesType) return failure();

    // Only point updates are supported.
    //  - update_window_dims is []
    //  - inserted_window_dims is range(operand.shape.rank)
    //  - scatter_dims_to_operand_dims is range(scatter_indices.shape.rank)
    //  - index_vector_dim is scatter_indices.shape.rank - 1
    if (!dims.getUpdateWindowDims().empty() ||
        !isIotaArray(dims.getInsertedWindowDims()) ||
        !isIotaArray(dims.getScatterDimsToOperandDims()) ||
        dims.getIndexVectorDim() != scatterIndicesType.getRank() - 1)
      return failure();

    auto opType = op.getType(0).dyn_cast<ShapedType>();
    if (!opType)
      return failure();  // Type is a tensor in the non-variadic case.

    rewriter.replaceOpWithNewOp<thlo::ScatterOp>(
        op, opType, op.scatter_indices(), op.updates().front(),
        op.operands().front());
    return success();
  }

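  // Returns success if `computation` is a region of the form
  //   ^bb0(%arg0: ..., %arg1: ...):
  //     %sum = mhlo.add %arg0, %arg1
  //     mhlo.return %sum
  // i.e. a plain (commutative) addition of the two block arguments.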
  LogicalResult matchUpdateComputation(mlir::Region& computation) const {
    Block& block = computation.front();
    if (block.getNumArguments() != 2) return failure();

    mhlo::ReturnOp returnOp = dyn_cast<mhlo::ReturnOp>(block.getTerminator());
    if (!returnOp || returnOp.getNumOperands() != 1) return failure();

    // The returned value may be a block argument, in which case there is no
    // defining op; dyn_cast_or_null handles that nullptr gracefully.
    auto* returnOperand = returnOp.getOperand(0).getDefiningOp();
    auto addOp = dyn_cast_or_null<mhlo::AddOp>(returnOperand);
    if (!addOp || addOp->getNumOperands() != 2) return failure();

    auto lhs = addOp->getOperand(0);
    auto rhs = addOp->getOperand(1);
    auto arg0 = block.getArgument(0);
    auto arg1 = block.getArgument(1);

    return success((lhs == arg0 && rhs == arg1) ||
                   (lhs == arg1 && rhs == arg0));
  }
};

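// Pass that replaces the supported subset of mhlo ops with their thlo
// counterparts, using the patterns above.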
class LegalizeMHLOToTHLOPass
    : public LegalizeMHLOToTHLOPassBase<LegalizeMHLOToTHLOPass> {
  void getDependentDialects(DialectRegistry& registry) const final {
    // The patterns create arith and tensor ops in addition to thlo and
    // linalg ones, so all four dialects are registered as dependencies.
    registry.insert<thlo::THLODialect, linalg::LinalgDialect,
                    arith::ArithmeticDialect, tensor::TensorDialect>();
  }

  void runOnOperation() final {
    MLIRContext* ctx = &getContext();
    RewritePatternSet patterns(ctx);

    // List of patterns.
    // clang-format off
    patterns.insert<
        ConcatenateOpPattern,
        DynamicBroadcastInDimOpPattern,
        GatherPattern,
        ScatterPattern>(ctx);
    // clang-format on

    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};

}  // namespace

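// Typical usage (sketch, assuming a standard mlir::PassManager pipeline):
//   pm.addNestedPass<func::FuncOp>(createLegalizeMHLOToTHLOPass());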
std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeMHLOToTHLOPass() {
  return std::make_unique<LegalizeMHLOToTHLOPass>();
}

}  // namespace mhlo
}  // namespace mlir