1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements logic for lowering HLO dialect to LHLO dialect.
17
18 #include <functional>
19 #include <memory>
20 #include <utility>
21
22 #include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
23 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
24 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
25 #include "mlir-hlo/Dialect/mhlo/transforms/bufferizable_op_interface_impl.h"
26 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
27 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
28 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
29 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
30 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
31 #include "mlir/Dialect/MemRef/IR/MemRef.h"
32 #include "mlir/IR/BuiltinOps.h"
33 #include "mlir/IR/BuiltinTypes.h"
34 #include "mlir/IR/Value.h"
35 #include "mlir/Pass/Pass.h"
36
37 namespace mlir {
38 namespace mhlo {
39 namespace {
40
41 using bufferization::AnalysisState;
42 using bufferization::BufferizableOpInterface;
43 using bufferization::BufferizationOptions;
44 using bufferization::BufferRelation;
45 using bufferization::replaceOpWithNewBufferizedOp;
46
47 struct CustomCallOpInterface
48 : public BufferizableOpInterface::ExternalModel<CustomCallOpInterface,
49 mhlo::CustomCallOp> {
bufferizesToMemoryReadmlir::mhlo::__anon7d94d1360111::CustomCallOpInterface50 bool bufferizesToMemoryRead(Operation *, OpOperand &,
51 const AnalysisState &) const {
52 return true;
53 }
54
bufferizesToMemoryWritemlir::mhlo::__anon7d94d1360111::CustomCallOpInterface55 bool bufferizesToMemoryWrite(Operation *, OpOperand &,
56 const AnalysisState &) const {
57 return false; // Arguments are read-only.
58 }
59
getAliasingOpResultmlir::mhlo::__anon7d94d1360111::CustomCallOpInterface60 SmallVector<OpResult> getAliasingOpResult(Operation *, OpOperand &,
61 const AnalysisState &) const {
62 return {};
63 }
64
bufferizemlir::mhlo::__anon7d94d1360111::CustomCallOpInterface65 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
66 const BufferizationOptions &options) const {
67 auto customCallOp = cast<mhlo::CustomCallOp>(op);
68
69 // Bufferize arguments.
70 SmallVector<Value> bufferArgs;
71 for (OpOperand &operand : customCallOp->getOpOperands()) {
72 if (!operand.get().getType().isa<TensorType>()) return failure();
73 FailureOr<Value> operandBuffer =
74 getBuffer(rewriter, operand.get(), options);
75 if (failed(operandBuffer)) return failure();
76 bufferArgs.push_back(*operandBuffer);
77 }
78
79 // Allocate outputs.
80 for (OpResult result : customCallOp->getOpResults()) {
81 auto tensorType = result.getType().cast<RankedTensorType>();
82 if (!tensorType) return failure();
83 // TODO(springerm): Create alloc_tensor ops during TensorCopyInsertion.
84 AnalysisState analysisState(options);
85 FailureOr<Value> tensorAlloc =
86 bufferization::allocateTensorForShapedValue(
87 rewriter, op->getLoc(), result,
88 analysisState.isTensorYielded(result), options);
89 if (failed(tensorAlloc)) return failure();
90 auto memrefType =
91 MemRefType::get(tensorType.getShape(), tensorType.getElementType());
92 Value resultBuffer = rewriter.create<bufferization::ToMemrefOp>(
93 op->getLoc(), memrefType, *tensorAlloc);
94 bufferArgs.push_back(resultBuffer);
95 }
96
97 auto lhloOp = rewriter.create<lmhlo::CustomCallOp>(
98 op->getLoc(), llvm::None, bufferArgs, op->getAttrs());
99 // lmhlo.custom_call uses a segment_size attribute to tell input from output
100 // arguments.
101 lhloOp->setAttr(lhloOp.getOperandSegmentSizeAttr(),
102 rewriter.getDenseI32ArrayAttr(
103 {static_cast<int32_t>(op->getNumOperands()),
104 static_cast<int32_t>(op->getNumResults())}));
105 bufferization::replaceOpWithBufferizedValues(
106 rewriter, op, makeArrayRef(bufferArgs).slice(op->getNumOperands()));
107 return success();
108 }
109 };
110
111 struct ReshapeOpInterface
112 : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
113 mhlo::ReshapeOp> {
bufferizesToMemoryReadmlir::mhlo::__anon7d94d1360111::ReshapeOpInterface114 bool bufferizesToMemoryRead(Operation * /*op*/, OpOperand & /*opOperand*/,
115 const AnalysisState & /*state*/) const {
116 return false;
117 }
118
bufferizesToMemoryWritemlir::mhlo::__anon7d94d1360111::ReshapeOpInterface119 bool bufferizesToMemoryWrite(Operation * /*op*/, OpOperand & /*opOperand*/,
120 const AnalysisState & /*state*/) const {
121 return false;
122 }
123
getAliasingOpResultmlir::mhlo::__anon7d94d1360111::ReshapeOpInterface124 SmallVector<OpResult> getAliasingOpResult(
125 Operation *op, OpOperand & /*opOperand*/,
126 const AnalysisState & /*state*/) const {
127 return {op->getResult(0)};
128 }
129
bufferRelationmlir::mhlo::__anon7d94d1360111::ReshapeOpInterface130 BufferRelation bufferRelation(Operation * /*op*/, OpResult /*opResult*/,
131 const AnalysisState & /*state*/) const {
132 return BufferRelation::Equivalent;
133 }
134
bufferizemlir::mhlo::__anon7d94d1360111::ReshapeOpInterface135 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
136 const BufferizationOptions &options) const {
137 auto reshapeOp = cast<mhlo::ReshapeOp>(op);
138 auto unrankedOperandType =
139 reshapeOp.operand().getType().dyn_cast<UnrankedTensorType>();
140 if (unrankedOperandType == nullptr) return success();
141
142 // The buffer still has the old (pre-reshape) type.
143 FailureOr<Value> operandBuffer =
144 getBuffer(rewriter, reshapeOp.operand(), options);
145 if (failed(operandBuffer)) return failure();
146
147 auto resultType = reshapeOp.getType().cast<RankedTensorType>();
148 auto destType =
149 MemRefType::get(resultType.getShape(), resultType.getElementType());
150 replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, destType,
151 *operandBuffer);
152 return success();
153 }
154 };
155
156 struct DynamicReshapeOpInterface
157 : public BufferizableOpInterface::ExternalModel<DynamicReshapeOpInterface,
158 mhlo::DynamicReshapeOp> {
bufferizesToMemoryReadmlir::mhlo::__anon7d94d1360111::DynamicReshapeOpInterface159 bool bufferizesToMemoryRead(Operation * /*op*/, OpOperand & /*opOperand*/,
160 const AnalysisState & /*state*/) const {
161 return false;
162 }
163
bufferizesToMemoryWritemlir::mhlo::__anon7d94d1360111::DynamicReshapeOpInterface164 bool bufferizesToMemoryWrite(Operation * /*op*/, OpOperand & /*opOperand*/,
165 const AnalysisState & /*state*/) const {
166 return false;
167 }
168
getAliasingOpResultmlir::mhlo::__anon7d94d1360111::DynamicReshapeOpInterface169 SmallVector<OpResult> getAliasingOpResult(
170 Operation *op, OpOperand & /*opOperand*/,
171 const AnalysisState & /*state*/) const {
172 return {op->getResult(0)};
173 }
174
bufferRelationmlir::mhlo::__anon7d94d1360111::DynamicReshapeOpInterface175 BufferRelation bufferRelation(Operation * /*op*/, OpResult /*opResult*/,
176 const AnalysisState & /*state*/) const {
177 return BufferRelation::Equivalent;
178 }
179
bufferizemlir::mhlo::__anon7d94d1360111::DynamicReshapeOpInterface180 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
181 const BufferizationOptions &options) const {
182 auto reshapeOp = cast<mhlo::DynamicReshapeOp>(op);
183
184 // The buffer still has the old (pre-reshape) type.
185 FailureOr<Value> operandBuffer =
186 getBuffer(rewriter, reshapeOp.operand(), options);
187 FailureOr<Value> outputShapeBuffer =
188 getBuffer(rewriter, reshapeOp.output_shape(), options);
189 if (failed(operandBuffer) || failed(outputShapeBuffer)) return failure();
190
191 ShapedType resultType;
192 TensorType opResultType = reshapeOp.getType();
193 if (auto rankedType = opResultType.dyn_cast<RankedTensorType>()) {
194 resultType =
195 MemRefType::get(rankedType.getShape(), rankedType.getElementType());
196 } else if (auto unrankedType =
197 opResultType.dyn_cast<UnrankedTensorType>()) {
198 resultType = UnrankedMemRefType::get(unrankedType.getElementType(), 0);
199 }
200 auto operand = *operandBuffer;
201 // If the operand has a non-identity affine map, we will have to add a copy.
202 auto bufferType = operandBuffer->getType().dyn_cast<MemRefType>();
203 if (bufferType && !bufferType.getLayout().isIdentity()) {
204 // TODO(springerm): Create alloc_tensor ops during TensorCopyInsertion.
205 AnalysisState analysisState(options);
206 FailureOr<Value> tensorAlloc =
207 bufferization::allocateTensorForShapedValue(
208 rewriter, op->getLoc(), *operandBuffer,
209 analysisState.isTensorYielded(reshapeOp.getResult()), options);
210 if (failed(tensorAlloc)) return failure();
211 auto memrefType =
212 MemRefType::get(bufferType.getShape(), bufferType.getElementType());
213 operand = rewriter.create<bufferization::ToMemrefOp>(
214 op->getLoc(), memrefType, *tensorAlloc);
215 }
216 bufferization::replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
217 rewriter, op, resultType, operand, *outputShapeBuffer);
218 return success();
219 }
220 };
221
222 // Inserts dynamic memref to change the layout of the memref to put 0-stride
223 // and size of the target dimension if size-1 dimension expansion is
224 // necessary.
insertDynamicMemrefCastOp(mhlo::DynamicBroadcastInDimOp op,Value operand,RewriterBase & rewriter,const BufferizationOptions & options)225 FailureOr<Value> insertDynamicMemrefCastOp(
226 mhlo::DynamicBroadcastInDimOp op, Value operand, RewriterBase &rewriter,
227 const BufferizationOptions &options) {
228 auto loc = op.getLoc();
229 auto operandType = operand.getType().cast<MemRefType>();
230 auto operandShape = operandType.getShape();
231 auto operandRank = operandType.getRank();
232
233 auto resultType = op.getType().cast<RankedTensorType>();
234 auto resultRank = resultType.getRank();
235
236 Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
237 Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
238
239 // Compute a reversed scan product. Compute the stride for the dimensions so
240 // far, working from minor to major dimensions. Additionally, save the
241 // operand shape Values to use in the next loop.
242 SmallVector<Value, 2> operandStrides(operandRank, one);
243 SmallVector<Value, 2> operandSizes(operandRank, one);
244 Value strideSoFar = one;
245 for (int i = operandRank - 1; i >= 0; --i) {
246 Value operandDimSize =
247 ShapedType::isDynamic(operandShape[i])
248 ? rewriter.create<memref::DimOp>(loc, operand, i).getResult()
249 : rewriter.create<arith::ConstantIndexOp>(loc, operandShape[i])
250 .getResult();
251 operandSizes[i] = operandDimSize;
252
253 operandStrides[i] = strideSoFar;
254 if (i > 0) {
255 strideSoFar =
256 rewriter.create<arith::MulIOp>(loc, strideSoFar, operandDimSize);
257 }
258 }
259
260 SmallVector<OpFoldResult, 2> sizes, strides;
261 sizes.reserve(resultRank);
262 strides.reserve(resultRank);
263
264 DenseMap<int, int> outputToInputDim;
265 for (const auto &dim : llvm::enumerate(op.broadcast_dimensions())) {
266 outputToInputDim[dim.value().getSExtValue()] = dim.index();
267 }
268 for (int i = 0; i < resultRank; ++i) {
269 Value iVal = rewriter.create<arith::ConstantIndexOp>(loc, i);
270 FailureOr<Value> outputDimsBuffer =
271 getBuffer(rewriter, op.output_dimensions(), options);
272 if (failed(outputDimsBuffer)) return failure();
273 Value resultDimSize =
274 rewriter.create<memref::LoadOp>(loc, *outputDimsBuffer, iVal);
275 if (!resultDimSize.getType().isIndex()) {
276 resultDimSize = rewriter.create<arith::IndexCastOp>(
277 loc, rewriter.getIndexType(), resultDimSize);
278 }
279 if (resultType.isDynamicDim(i)) {
280 sizes.push_back(resultDimSize);
281 } else {
282 sizes.push_back(rewriter.getIndexAttr(resultType.getDimSize(i)));
283 }
284
285 auto it = outputToInputDim.find(i);
286 // If the rank of the output is greater than the rank of the input, i.e.
287 // there was no output dimension in the inverse broadcast_dimensions map
288 // we also set stride to 0 to emulate padding of the shape with 1s and the
289 // corresponding expansion.
290 if (it == outputToInputDim.end()) {
291 strides.push_back(zero);
292 continue;
293 }
294
295 // There can be two cases:
296 // 1) Operand dim == result dim => expansion is not needed
297 // => stride flattened buffer stride
298 // 2) Operand dim < result dim => expansion is needed => stride := 0.
299 int dim = it->second;
300 Value isExpansion = rewriter.create<arith::CmpIOp>(
301 loc, arith::CmpIPredicate::slt, operandSizes[dim], resultDimSize);
302 Value select = rewriter.create<mlir::arith::SelectOp>(
303 loc, isExpansion, zero, operandStrides[dim]);
304 strides.push_back(select);
305 }
306
307 // Type-erased memref type with static rank and dynamic strides.
308 SmallVector<int64_t, 2> dynamicLayout(resultRank,
309 ShapedType::kDynamicStrideOrOffset);
310 auto typeErasedMemrefType = MemRefType::get(
311 resultType.getShape(), operandType.getElementType(),
312 makeStridedLinearLayoutMap(dynamicLayout,
313 /*offset=*/0, rewriter.getContext()));
314
315 auto transformedOperand = rewriter.create<memref::ReinterpretCastOp>(
316 loc, typeErasedMemrefType, operand,
317 /*offset=*/rewriter.getI64IntegerAttr(0), sizes, strides);
318 return transformedOperand.getResult();
319 }
320
321 struct DynamicBroadcastInDimOpInterface
322 : public BufferizableOpInterface::ExternalModel<
323 DynamicBroadcastInDimOpInterface, mhlo::DynamicBroadcastInDimOp> {
bufferizesToMemoryReadmlir::mhlo::__anon7d94d1360111::DynamicBroadcastInDimOpInterface324 bool bufferizesToMemoryRead(Operation * /*op*/, OpOperand & /*opOperand*/,
325 const AnalysisState & /*state*/) const {
326 return true;
327 }
328
bufferizesToMemoryWritemlir::mhlo::__anon7d94d1360111::DynamicBroadcastInDimOpInterface329 bool bufferizesToMemoryWrite(Operation * /*op*/, OpOperand & /*opOperand*/,
330 const AnalysisState & /*state*/) const {
331 return false;
332 }
333
getAliasingOpResultmlir::mhlo::__anon7d94d1360111::DynamicBroadcastInDimOpInterface334 SmallVector<OpResult> getAliasingOpResult(
335 Operation *op, OpOperand & /*opOperand*/,
336 const AnalysisState & /*state*/) const {
337 return {op->getResult(0)};
338 }
339
bufferRelationmlir::mhlo::__anon7d94d1360111::DynamicBroadcastInDimOpInterface340 BufferRelation bufferRelation(Operation * /*op*/, OpResult /*opResult*/,
341 const AnalysisState & /*state*/) const {
342 // The op may allocate.
343 return BufferRelation::None;
344 }
345
bufferizemlir::mhlo::__anon7d94d1360111::DynamicBroadcastInDimOpInterface346 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
347 const BufferizationOptions &options) const {
348 auto broadcastInDimOp = cast<mhlo::DynamicBroadcastInDimOp>(op);
349 auto resultType = broadcastInDimOp.getType().dyn_cast<RankedTensorType>();
350 if (!resultType) return success();
351
352 // The buffer still has the old (pre-reshape) type.
353 FailureOr<Value> operandBuffer =
354 getBuffer(rewriter, broadcastInDimOp.operand(), options);
355 if (failed(operandBuffer)) return failure();
356 FailureOr<Value> result = insertDynamicMemrefCastOp(
357 broadcastInDimOp, *operandBuffer, rewriter, options);
358 if (failed(result)) return failure();
359 bufferization::replaceOpWithBufferizedValues(rewriter, op, *result);
360 return success();
361 }
362 };
363
364 struct HloLegalizeToMemrefPass
365 : public HloLegalizeToMemrefPassBase<HloLegalizeToMemrefPass> {
getDependentDialectsmlir::mhlo::__anon7d94d1360111::HloLegalizeToMemrefPass366 void getDependentDialects(DialectRegistry ®istry) const override {
367 registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
368 mhlo::MhloDialect, lmhlo::LmhloDialect>();
369 registerBufferizableOpInterfaceExternalModels(registry);
370 }
371
372 public:
runOnOperationmlir::mhlo::__anon7d94d1360111::HloLegalizeToMemrefPass373 void runOnOperation() override {
374 bufferization::BufferizationOptions options =
375 bufferization::getPartialBufferizationOptions();
376 options.opFilter.allowDialect<mhlo::MhloDialect>();
377 if (failed(bufferizeOp(getOperation(), options))) signalPassFailure();
378 }
379 };
380
381 } // namespace
382
createLegalizeToMemrefPass()383 std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToMemrefPass() {
384 return std::make_unique<HloLegalizeToMemrefPass>();
385 }
386
registerBufferizableOpInterfaceExternalModels(DialectRegistry & registry)387 void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) {
388 registry.addExtension(+[](MLIRContext *ctx, MhloDialect * /*dialect*/) {
389 CustomCallOp::attachInterface<CustomCallOpInterface>(*ctx);
390 ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
391 DynamicReshapeOp::attachInterface<DynamicReshapeOpInterface>(*ctx);
392 DynamicBroadcastInDimOp::attachInterface<DynamicBroadcastInDimOpInterface>(
393 *ctx);
394 });
395 }
396
397 } // namespace mhlo
398 } // namespace mlir
399