1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
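// This file implements bufferization for a subset of MHLO ops: mhlo.custom_call
// is lowered to lmhlo.custom_call, while mhlo.reshape, mhlo.dynamic_reshape and
// mhlo.dynamic_broadcast_in_dim are lowered to memref dialect ops.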
17 
18 #include <functional>
19 #include <memory>
20 #include <utility>
21 
22 #include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
23 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
24 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
25 #include "mlir-hlo/Dialect/mhlo/transforms/bufferizable_op_interface_impl.h"
26 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
27 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
28 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
29 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
30 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
31 #include "mlir/Dialect/MemRef/IR/MemRef.h"
32 #include "mlir/IR/BuiltinOps.h"
33 #include "mlir/IR/BuiltinTypes.h"
34 #include "mlir/IR/Value.h"
35 #include "mlir/Pass/Pass.h"
36 
37 namespace mlir {
38 namespace mhlo {
39 namespace {
40 
41 using bufferization::AnalysisState;
42 using bufferization::BufferizableOpInterface;
43 using bufferization::BufferizationOptions;
44 using bufferization::BufferRelation;
45 using bufferization::replaceOpWithNewBufferizedOp;
46 
47 struct CustomCallOpInterface
48     : public BufferizableOpInterface::ExternalModel<CustomCallOpInterface,
49                                                     mhlo::CustomCallOp> {
bufferizesToMemoryReadmlir::mhlo::__anon7d94d1360111::CustomCallOpInterface50   bool bufferizesToMemoryRead(Operation *, OpOperand &,
51                               const AnalysisState &) const {
52     return true;
53   }
54 
bufferizesToMemoryWritemlir::mhlo::__anon7d94d1360111::CustomCallOpInterface55   bool bufferizesToMemoryWrite(Operation *, OpOperand &,
56                                const AnalysisState &) const {
57     return false;  // Arguments are read-only.
58   }
59 
getAliasingOpResultmlir::mhlo::__anon7d94d1360111::CustomCallOpInterface60   SmallVector<OpResult> getAliasingOpResult(Operation *, OpOperand &,
61                                             const AnalysisState &) const {
62     return {};
63   }
64 
bufferizemlir::mhlo::__anon7d94d1360111::CustomCallOpInterface65   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
66                           const BufferizationOptions &options) const {
67     auto customCallOp = cast<mhlo::CustomCallOp>(op);
68 
69     // Bufferize arguments.
70     SmallVector<Value> bufferArgs;
71     for (OpOperand &operand : customCallOp->getOpOperands()) {
72       if (!operand.get().getType().isa<TensorType>()) return failure();
73       FailureOr<Value> operandBuffer =
74           getBuffer(rewriter, operand.get(), options);
75       if (failed(operandBuffer)) return failure();
76       bufferArgs.push_back(*operandBuffer);
77     }
78 
79     // Allocate outputs.
80     for (OpResult result : customCallOp->getOpResults()) {
81       auto tensorType = result.getType().cast<RankedTensorType>();
82       if (!tensorType) return failure();
83       // TODO(springerm): Create alloc_tensor ops during TensorCopyInsertion.
84       AnalysisState analysisState(options);
85       FailureOr<Value> tensorAlloc =
86           bufferization::allocateTensorForShapedValue(
87               rewriter, op->getLoc(), result,
88               analysisState.isTensorYielded(result), options);
89       if (failed(tensorAlloc)) return failure();
90       auto memrefType =
91           MemRefType::get(tensorType.getShape(), tensorType.getElementType());
92       Value resultBuffer = rewriter.create<bufferization::ToMemrefOp>(
93           op->getLoc(), memrefType, *tensorAlloc);
94       bufferArgs.push_back(resultBuffer);
95     }
96 
97     auto lhloOp = rewriter.create<lmhlo::CustomCallOp>(
98         op->getLoc(), llvm::None, bufferArgs, op->getAttrs());
99     // lmhlo.custom_call uses a segment_size attribute to tell input from output
100     // arguments.
101     lhloOp->setAttr(lhloOp.getOperandSegmentSizeAttr(),
102                     rewriter.getDenseI32ArrayAttr(
103                         {static_cast<int32_t>(op->getNumOperands()),
104                          static_cast<int32_t>(op->getNumResults())}));
105     bufferization::replaceOpWithBufferizedValues(
106         rewriter, op, makeArrayRef(bufferArgs).slice(op->getNumOperands()));
107     return success();
108   }
109 };
110 
111 struct ReshapeOpInterface
112     : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
113                                                     mhlo::ReshapeOp> {
bufferizesToMemoryReadmlir::mhlo::__anon7d94d1360111::ReshapeOpInterface114   bool bufferizesToMemoryRead(Operation * /*op*/, OpOperand & /*opOperand*/,
115                               const AnalysisState & /*state*/) const {
116     return false;
117   }
118 
bufferizesToMemoryWritemlir::mhlo::__anon7d94d1360111::ReshapeOpInterface119   bool bufferizesToMemoryWrite(Operation * /*op*/, OpOperand & /*opOperand*/,
120                                const AnalysisState & /*state*/) const {
121     return false;
122   }
123 
getAliasingOpResultmlir::mhlo::__anon7d94d1360111::ReshapeOpInterface124   SmallVector<OpResult> getAliasingOpResult(
125       Operation *op, OpOperand & /*opOperand*/,
126       const AnalysisState & /*state*/) const {
127     return {op->getResult(0)};
128   }
129 
bufferRelationmlir::mhlo::__anon7d94d1360111::ReshapeOpInterface130   BufferRelation bufferRelation(Operation * /*op*/, OpResult /*opResult*/,
131                                 const AnalysisState & /*state*/) const {
132     return BufferRelation::Equivalent;
133   }
134 
bufferizemlir::mhlo::__anon7d94d1360111::ReshapeOpInterface135   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
136                           const BufferizationOptions &options) const {
137     auto reshapeOp = cast<mhlo::ReshapeOp>(op);
138     auto unrankedOperandType =
139         reshapeOp.operand().getType().dyn_cast<UnrankedTensorType>();
140     if (unrankedOperandType == nullptr) return success();
141 
142     // The buffer still has the old (pre-reshape) type.
143     FailureOr<Value> operandBuffer =
144         getBuffer(rewriter, reshapeOp.operand(), options);
145     if (failed(operandBuffer)) return failure();
146 
147     auto resultType = reshapeOp.getType().cast<RankedTensorType>();
148     auto destType =
149         MemRefType::get(resultType.getShape(), resultType.getElementType());
150     replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, destType,
151                                                  *operandBuffer);
152     return success();
153   }
154 };
155 
156 struct DynamicReshapeOpInterface
157     : public BufferizableOpInterface::ExternalModel<DynamicReshapeOpInterface,
158                                                     mhlo::DynamicReshapeOp> {
bufferizesToMemoryReadmlir::mhlo::__anon7d94d1360111::DynamicReshapeOpInterface159   bool bufferizesToMemoryRead(Operation * /*op*/, OpOperand & /*opOperand*/,
160                               const AnalysisState & /*state*/) const {
161     return false;
162   }
163 
bufferizesToMemoryWritemlir::mhlo::__anon7d94d1360111::DynamicReshapeOpInterface164   bool bufferizesToMemoryWrite(Operation * /*op*/, OpOperand & /*opOperand*/,
165                                const AnalysisState & /*state*/) const {
166     return false;
167   }
168 
getAliasingOpResultmlir::mhlo::__anon7d94d1360111::DynamicReshapeOpInterface169   SmallVector<OpResult> getAliasingOpResult(
170       Operation *op, OpOperand & /*opOperand*/,
171       const AnalysisState & /*state*/) const {
172     return {op->getResult(0)};
173   }
174 
bufferRelationmlir::mhlo::__anon7d94d1360111::DynamicReshapeOpInterface175   BufferRelation bufferRelation(Operation * /*op*/, OpResult /*opResult*/,
176                                 const AnalysisState & /*state*/) const {
177     return BufferRelation::Equivalent;
178   }
179 
bufferizemlir::mhlo::__anon7d94d1360111::DynamicReshapeOpInterface180   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
181                           const BufferizationOptions &options) const {
182     auto reshapeOp = cast<mhlo::DynamicReshapeOp>(op);
183 
184     // The buffer still has the old (pre-reshape) type.
185     FailureOr<Value> operandBuffer =
186         getBuffer(rewriter, reshapeOp.operand(), options);
187     FailureOr<Value> outputShapeBuffer =
188         getBuffer(rewriter, reshapeOp.output_shape(), options);
189     if (failed(operandBuffer) || failed(outputShapeBuffer)) return failure();
190 
191     ShapedType resultType;
192     TensorType opResultType = reshapeOp.getType();
193     if (auto rankedType = opResultType.dyn_cast<RankedTensorType>()) {
194       resultType =
195           MemRefType::get(rankedType.getShape(), rankedType.getElementType());
196     } else if (auto unrankedType =
197                    opResultType.dyn_cast<UnrankedTensorType>()) {
198       resultType = UnrankedMemRefType::get(unrankedType.getElementType(), 0);
199     }
200     auto operand = *operandBuffer;
201     // If the operand has a non-identity affine map, we will have to add a copy.
202     auto bufferType = operandBuffer->getType().dyn_cast<MemRefType>();
203     if (bufferType && !bufferType.getLayout().isIdentity()) {
204       // TODO(springerm): Create alloc_tensor ops during TensorCopyInsertion.
205       AnalysisState analysisState(options);
206       FailureOr<Value> tensorAlloc =
207           bufferization::allocateTensorForShapedValue(
208               rewriter, op->getLoc(), *operandBuffer,
209               analysisState.isTensorYielded(reshapeOp.getResult()), options);
210       if (failed(tensorAlloc)) return failure();
211       auto memrefType =
212           MemRefType::get(bufferType.getShape(), bufferType.getElementType());
213       operand = rewriter.create<bufferization::ToMemrefOp>(
214           op->getLoc(), memrefType, *tensorAlloc);
215     }
216     bufferization::replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
217         rewriter, op, resultType, operand, *outputShapeBuffer);
218     return success();
219   }
220 };
221 
222 // Inserts dynamic memref to change the layout of the memref to put 0-stride
223 // and size of the target dimension if size-1 dimension expansion is
224 // necessary.
insertDynamicMemrefCastOp(mhlo::DynamicBroadcastInDimOp op,Value operand,RewriterBase & rewriter,const BufferizationOptions & options)225 FailureOr<Value> insertDynamicMemrefCastOp(
226     mhlo::DynamicBroadcastInDimOp op, Value operand, RewriterBase &rewriter,
227     const BufferizationOptions &options) {
228   auto loc = op.getLoc();
229   auto operandType = operand.getType().cast<MemRefType>();
230   auto operandShape = operandType.getShape();
231   auto operandRank = operandType.getRank();
232 
233   auto resultType = op.getType().cast<RankedTensorType>();
234   auto resultRank = resultType.getRank();
235 
236   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
237   Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
238 
239   // Compute a reversed scan product. Compute the stride for the dimensions so
240   // far, working from minor to major dimensions. Additionally, save the
241   // operand shape Values to use in the next loop.
242   SmallVector<Value, 2> operandStrides(operandRank, one);
243   SmallVector<Value, 2> operandSizes(operandRank, one);
244   Value strideSoFar = one;
245   for (int i = operandRank - 1; i >= 0; --i) {
246     Value operandDimSize =
247         ShapedType::isDynamic(operandShape[i])
248             ? rewriter.create<memref::DimOp>(loc, operand, i).getResult()
249             : rewriter.create<arith::ConstantIndexOp>(loc, operandShape[i])
250                   .getResult();
251     operandSizes[i] = operandDimSize;
252 
253     operandStrides[i] = strideSoFar;
254     if (i > 0) {
255       strideSoFar =
256           rewriter.create<arith::MulIOp>(loc, strideSoFar, operandDimSize);
257     }
258   }
259 
260   SmallVector<OpFoldResult, 2> sizes, strides;
261   sizes.reserve(resultRank);
262   strides.reserve(resultRank);
263 
264   DenseMap<int, int> outputToInputDim;
265   for (const auto &dim : llvm::enumerate(op.broadcast_dimensions())) {
266     outputToInputDim[dim.value().getSExtValue()] = dim.index();
267   }
268   for (int i = 0; i < resultRank; ++i) {
269     Value iVal = rewriter.create<arith::ConstantIndexOp>(loc, i);
270     FailureOr<Value> outputDimsBuffer =
271         getBuffer(rewriter, op.output_dimensions(), options);
272     if (failed(outputDimsBuffer)) return failure();
273     Value resultDimSize =
274         rewriter.create<memref::LoadOp>(loc, *outputDimsBuffer, iVal);
275     if (!resultDimSize.getType().isIndex()) {
276       resultDimSize = rewriter.create<arith::IndexCastOp>(
277           loc, rewriter.getIndexType(), resultDimSize);
278     }
279     if (resultType.isDynamicDim(i)) {
280       sizes.push_back(resultDimSize);
281     } else {
282       sizes.push_back(rewriter.getIndexAttr(resultType.getDimSize(i)));
283     }
284 
285     auto it = outputToInputDim.find(i);
286     // If the rank of the output is greater than the rank of the input, i.e.
287     // there was no output dimension in the inverse broadcast_dimensions map
288     // we also set stride to 0 to emulate padding of the shape with 1s and the
289     // corresponding expansion.
290     if (it == outputToInputDim.end()) {
291       strides.push_back(zero);
292       continue;
293     }
294 
295     // There can be two cases:
296     // 1) Operand dim == result dim => expansion is not needed
297     //    => stride flattened buffer stride
298     // 2) Operand dim < result dim => expansion is needed => stride := 0.
299     int dim = it->second;
300     Value isExpansion = rewriter.create<arith::CmpIOp>(
301         loc, arith::CmpIPredicate::slt, operandSizes[dim], resultDimSize);
302     Value select = rewriter.create<mlir::arith::SelectOp>(
303         loc, isExpansion, zero, operandStrides[dim]);
304     strides.push_back(select);
305   }
306 
307   // Type-erased memref type with static rank and dynamic strides.
308   SmallVector<int64_t, 2> dynamicLayout(resultRank,
309                                         ShapedType::kDynamicStrideOrOffset);
310   auto typeErasedMemrefType = MemRefType::get(
311       resultType.getShape(), operandType.getElementType(),
312       makeStridedLinearLayoutMap(dynamicLayout,
313                                  /*offset=*/0, rewriter.getContext()));
314 
315   auto transformedOperand = rewriter.create<memref::ReinterpretCastOp>(
316       loc, typeErasedMemrefType, operand,
317       /*offset=*/rewriter.getI64IntegerAttr(0), sizes, strides);
318   return transformedOperand.getResult();
319 }
320 
321 struct DynamicBroadcastInDimOpInterface
322     : public BufferizableOpInterface::ExternalModel<
323           DynamicBroadcastInDimOpInterface, mhlo::DynamicBroadcastInDimOp> {
bufferizesToMemoryReadmlir::mhlo::__anon7d94d1360111::DynamicBroadcastInDimOpInterface324   bool bufferizesToMemoryRead(Operation * /*op*/, OpOperand & /*opOperand*/,
325                               const AnalysisState & /*state*/) const {
326     return true;
327   }
328 
bufferizesToMemoryWritemlir::mhlo::__anon7d94d1360111::DynamicBroadcastInDimOpInterface329   bool bufferizesToMemoryWrite(Operation * /*op*/, OpOperand & /*opOperand*/,
330                                const AnalysisState & /*state*/) const {
331     return false;
332   }
333 
getAliasingOpResultmlir::mhlo::__anon7d94d1360111::DynamicBroadcastInDimOpInterface334   SmallVector<OpResult> getAliasingOpResult(
335       Operation *op, OpOperand & /*opOperand*/,
336       const AnalysisState & /*state*/) const {
337     return {op->getResult(0)};
338   }
339 
bufferRelationmlir::mhlo::__anon7d94d1360111::DynamicBroadcastInDimOpInterface340   BufferRelation bufferRelation(Operation * /*op*/, OpResult /*opResult*/,
341                                 const AnalysisState & /*state*/) const {
342     // The op may allocate.
343     return BufferRelation::None;
344   }
345 
bufferizemlir::mhlo::__anon7d94d1360111::DynamicBroadcastInDimOpInterface346   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
347                           const BufferizationOptions &options) const {
348     auto broadcastInDimOp = cast<mhlo::DynamicBroadcastInDimOp>(op);
349     auto resultType = broadcastInDimOp.getType().dyn_cast<RankedTensorType>();
350     if (!resultType) return success();
351 
352     // The buffer still has the old (pre-reshape) type.
353     FailureOr<Value> operandBuffer =
354         getBuffer(rewriter, broadcastInDimOp.operand(), options);
355     if (failed(operandBuffer)) return failure();
356     FailureOr<Value> result = insertDynamicMemrefCastOp(
357         broadcastInDimOp, *operandBuffer, rewriter, options);
358     if (failed(result)) return failure();
359     bufferization::replaceOpWithBufferizedValues(rewriter, op, *result);
360     return success();
361   }
362 };
363 
364 struct HloLegalizeToMemrefPass
365     : public HloLegalizeToMemrefPassBase<HloLegalizeToMemrefPass> {
getDependentDialectsmlir::mhlo::__anon7d94d1360111::HloLegalizeToMemrefPass366   void getDependentDialects(DialectRegistry &registry) const override {
367     registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
368                     mhlo::MhloDialect, lmhlo::LmhloDialect>();
369     registerBufferizableOpInterfaceExternalModels(registry);
370   }
371 
372  public:
runOnOperationmlir::mhlo::__anon7d94d1360111::HloLegalizeToMemrefPass373   void runOnOperation() override {
374     bufferization::BufferizationOptions options =
375         bufferization::getPartialBufferizationOptions();
376     options.opFilter.allowDialect<mhlo::MhloDialect>();
377     if (failed(bufferizeOp(getOperation(), options))) signalPassFailure();
378   }
379 };
380 
381 }  // namespace
382 
createLegalizeToMemrefPass()383 std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToMemrefPass() {
384   return std::make_unique<HloLegalizeToMemrefPass>();
385 }
386 
registerBufferizableOpInterfaceExternalModels(DialectRegistry & registry)387 void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
388   registry.addExtension(+[](MLIRContext *ctx, MhloDialect * /*dialect*/) {
389     CustomCallOp::attachInterface<CustomCallOpInterface>(*ctx);
390     ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
391     DynamicReshapeOp::attachInterface<DynamicReshapeOpInterface>(*ctx);
392     DynamicBroadcastInDimOp::attachInterface<DynamicBroadcastInDimOpInterface>(
393         *ctx);
394   });
395 }
396 
397 }  // namespace mhlo
398 }  // namespace mlir
399