/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering LHLO dialect to Affine dialect.

#include <utility>

#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/lhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/lhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace lmhlo {
namespace {

// Builds an affine loop nest iterating from zeros to "upper_bounds" with unit
// steps, and populates the body of the innermost loop using "body_builder".
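//
// Illustrative sketch only (assumed bounds, not verified pass output): for
// upperBounds = {2, 3} this builds a nest roughly equivalent to
//
//   affine.for %i = 0 to 2 {
//     affine.for %j = 0 to 3 {
//       // "body_builder" is invoked here with ivs = (%i, %j).
//     }
//   }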
static void buildBoundedAffineLoopNest(
    OpBuilder& builder, Location location, ArrayRef<int64_t> upperBounds,
    function_ref<void(OpBuilder&, Location, ValueRange)> bodyBuilder) {
  SmallVector<int64_t, 3> lowerBounds(upperBounds.size(), /*Value=*/0);
  SmallVector<int64_t, 3> steps(upperBounds.size(), /*Value=*/1);
  buildAffineLoopNest(builder, location, lowerBounds, upperBounds, steps,
                      bodyBuilder);
}

struct DotOpConverter : public OpRewritePattern<DotOp> {
  using OpRewritePattern<DotOp>::OpRewritePattern;

  // Supports only rank-2 tensors for LHS and RHS.
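  //
  // A minimal sketch of the emitted nest for a 4x3 times 3x5 dot (assumed
  // shapes and f32 element type, not verified pass output), where %lhs, %rhs
  // and %out are the op's memref operands:
  //
  //   affine.for %i = 0 to 4 {
  //     affine.for %j = 0 to 5 {
  //       affine.for %k = 0 to 3 {
  //         %a = affine.load %lhs[%i, %k] : memref<4x3xf32>
  //         %b = affine.load %rhs[%k, %j] : memref<3x5xf32>
  //         %c = affine.load %out[%i, %j] : memref<4x5xf32>
  //         // The scalar op mapping combines %a, %b and %c (multiply-add).
  //         affine.store %acc, %out[%i, %j] : memref<4x5xf32>
  //       }
  //     }
  //   }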
  LogicalResult matchAndRewrite(DotOp op,
                                PatternRewriter& rewriter) const override {
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();
    MemRefType lhsType = lhs.getType().cast<MemRefType>();
    MemRefType rhsType = rhs.getType().cast<MemRefType>();
    Type elementType = lhsType.getElementType();
    ArrayRef<int64_t> shapeLhs = lhsType.getShape();
    ArrayRef<int64_t> shapeRhs = rhsType.getShape();

    if ((lhsType.getRank() != 2) || (rhsType.getRank() != 2)) {
      return failure();
    }

    // We don't currently support batching dimensions, or multiple contraction
    // dimensions.
    mhlo::DotDimensionNumbersAttr dotDimensionNumbers =
        op.getDotDimensionNumbers();
    if (!dotDimensionNumbers.getLhsBatchingDimensions().empty() ||
        !dotDimensionNumbers.getRhsBatchingDimensions().empty())
      return failure();
    if (dotDimensionNumbers.getLhsContractingDimensions().size() != 1 ||
        *dotDimensionNumbers.getLhsContractingDimensions().begin() != 1 ||
        dotDimensionNumbers.getRhsContractingDimensions().size() != 1 ||
        *dotDimensionNumbers.getRhsContractingDimensions().begin() != 0) {
      return failure();
    }

    LogicalResult mapStatus = success();
    auto bodyBuilder = [&](OpBuilder& builder, Location loc, ValueRange ivs) {
      SmallVector<Value, 2> lhsIndices{ivs[0], ivs[2]},
          rhsIndices{ivs[2], ivs[1]}, resultIndices{ivs[0], ivs[1]};

      auto l = builder.create<AffineLoadOp>(loc, lhs, lhsIndices);
      auto r = builder.create<AffineLoadOp>(loc, rhs, rhsIndices);
      auto result =
          rewriter.create<AffineLoadOp>(loc, op.getOutput(), resultIndices);
      Value opResult = lmhlo::LhloOpToStdScalarOp::map<DotOp>(
          op, elementType, {l, r, result}, &builder);
      mapStatus = success(opResult != nullptr);
      if (failed(mapStatus)) return;
      builder.create<AffineStoreOp>(loc, opResult, op.getOutput(),
                                    resultIndices);
    };

    buildBoundedAffineLoopNest(rewriter, op.getLoc(),
                               {shapeLhs[0], shapeRhs[1], shapeRhs[0]},
                               bodyBuilder);
    if (failed(mapStatus)) return failure();

    rewriter.eraseOp(op);
    return success();
  }
};

/// Concat Operation Example (2D):
/// Given inpA[2][1], inpB[2][2], concat_dimension = 1.
/// Compute output[x1][x2].
/// Implementation Pseudocode:
/// s = 0
/// for a in range(0, 2):
///   for b in range(0, 1):
///     output[a][b] = inpA[a][b - s]
/// s = 1
/// for a in range(0, 2):
///   for b in range(1, 3):
///     output[a][b] = inpB[a][b - s]
///
/// Concatenate composes an array from multiple array operands. The array is of
/// the same rank as each of the input array operands (which must be of the
/// same rank as each other) and contains the arguments in the order that they
/// were specified.
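///
/// A minimal sketch of the loop nests this produces for the 2-D example above
/// (assumed f32 element type, not verified pass output):
///
///   affine.for %a = 0 to 2 {
///     affine.for %b = 0 to 1 {
///       %v = affine.load %inpA[%a, %b] : memref<2x1xf32>
///       affine.store %v, %out[%a, %b] : memref<2x3xf32>
///     }
///   }
///   affine.for %a = 0 to 2 {
///     affine.for %b = 1 to 3 {
///       %v = affine.load %inpB[%a, %b - 1] : memref<2x2xf32>
///       affine.store %v, %out[%a, %b] : memref<2x3xf32>
///     }
///   }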
struct ConcatOpConverter : public OpRewritePattern<ConcatenateOp> {
  using OpRewritePattern<ConcatenateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatenateOp op,
                                PatternRewriter& rewriter) const override {
    Location loc = op.getLoc();
    Value output = op.getOutput();
    MemRefType outputType = output.getType().cast<MemRefType>();
    unsigned outputRank = outputType.getRank();
    ArrayRef<int64_t> outputShape = outputType.getShape();

    ValueRange operands = op.getVal();
    uint64_t concatDim = op.getDimension();
    int64_t prevBound = 0;

    for (Value operand : operands) {
      MemRefType operandType = operand.getType().cast<MemRefType>();
      ArrayRef<int64_t> operandShape = operandType.getShape();

      // TODO(pashu123): Extend support for dynamic dimensions.
      if (!operandType.hasStaticShape()) return failure();

      // For the concatenation dimension only, the operand index expression is
      // the output index minus prevBound; all other dimensions map unchanged.
      SmallVector<AffineExpr, 4> expr;
      for (unsigned i = 0; i < outputRank; i++) {
        AffineExpr d0 = (i == concatDim)
                            ? rewriter.getAffineDimExpr(concatDim) - prevBound
                            : rewriter.getAffineDimExpr(i);

        expr.push_back(d0);
      }
      AffineMap map =
          AffineMap::get(outputRank, 0, expr, rewriter.getContext());

      // Create multiple for loop nests iterating along the concatenation
      // dimension.
      OpBuilder::InsertionGuard guard(rewriter);
      SmallVector<Value, 3> indices;
      AffineForOp forOp;
      for (unsigned i = 0; i < outputRank; i++) {
        if (i == concatDim) {
          forOp = rewriter.create<AffineForOp>(loc, prevBound,
                                               prevBound + operandShape[i]);
          prevBound += operandShape[i];
          indices.push_back(forOp.getInductionVar());
        } else {
          forOp = rewriter.create<AffineForOp>(loc, 0, outputShape[i]);
          indices.push_back(forOp.getInductionVar());
        }
        rewriter.setInsertionPointToStart(forOp.getBody());
      }
      Value storeVal =
          rewriter.create<AffineLoadOp>(loc, operand, map, indices);
      rewriter.create<AffineStoreOp>(loc, storeVal, output, indices);
    }
    rewriter.eraseOp(op);
    return success();
  }
};

/// Returns a zero value of type `type`. `type` is expected to be either
/// int or float.
static Value getZeroValue(Type type, Location loc, PatternRewriter& rewriter) {
  assert(type.isIntOrFloat() && "Expected int or float");

  if (IntegerType intType = type.dyn_cast<IntegerType>())
    return rewriter.create<mlir::arith::ConstantIntOp>(loc, 0,
                                                       intType.getWidth());

  FloatType floatType = type.cast<FloatType>();
  return rewriter.create<mlir::arith::ConstantFloatOp>(
      loc, APFloat::getZero(floatType.getFloatSemantics()), floatType);
}

/// Emits a nest to fill the given `buffer` of memref type with `fillValue`.
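///
/// A minimal sketch of what this emits for a memref<?x?xf32> buffer (assumed
/// shape and names, not verified pass output; %fill is the scalar fill value,
/// loaded first if `fillValue` is a 0-d memref):
///
///   %c0 = arith.constant 0 : index
///   %c1 = arith.constant 1 : index
///   %d0 = memref.dim %buffer, %c0 : memref<?x?xf32>
///   %d1 = memref.dim %buffer, %c1 : memref<?x?xf32>
///   affine.for %i = 0 to %d0 {
///     affine.for %j = 0 to %d1 {
///       affine.store %fill, %buffer[%i, %j] : memref<?x?xf32>
///     }
///   }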
static void fillBuffer(Location loc, Value buffer, Value fillValue,
                       PatternRewriter& builder) {
  OpBuilder::InsertionGuard guard(builder);
  MemRefType bufType = buffer.getType().cast<MemRefType>();
  unsigned rank = bufType.getRank();
  SmallVector<Value, 4> dimSizes;
  dimSizes.reserve(rank);
  for (unsigned i = 0; i < rank; ++i)
    dimSizes.push_back(builder.create<mlir::memref::DimOp>(loc, buffer, i));

  AffineMap idSymMap = builder.getSymbolIdentityMap();
  AffineMap lbMap = builder.getConstantAffineMap(0);
  SmallVector<Value, 4> ivs(rank);
  AffineForOp forOp;
  for (unsigned i = 0; i < rank; ++i) {
    forOp = builder.create<AffineForOp>(loc, llvm::None, lbMap, dimSizes[i],
                                        idSymMap);
    builder.setInsertionPointToStart(forOp.getBody());
    ivs[i] = forOp.getInductionVar();
  }
  Type fillValueType = fillValue.getType();
  auto fillMemRefType = fillValueType.dyn_cast<MemRefType>();
  assert(((fillMemRefType && fillMemRefType.getRank() == 0) ||
          fillValueType.isIntOrFloat()) &&
         "init value has to be a 0-d memref or int or fp");
  Value initVal = fillMemRefType ? builder.create<AffineLoadOp>(
                                       loc, fillValue, /*indices=*/llvm::None)
                                 : fillValue;
  builder.create<AffineStoreOp>(loc, initVal, buffer, ivs);
}

/// Converts GatherOp to Affine nest form.
/// Pseudocode:
///   1. Fill a temporary output tensor with 0.
///   2. Repeat the following for each batch dimension :-
///      1. For each index in 'operand' :-
///        1. Get hold of start indices from 'start_indices'.
///        2. Add offset to the start indices to get the final indices.
///        3. Load value from 'operand' tensor : 'operand_val'.
///        4. Load value from temporary output : 'prev_val'.
///        5. If the final indices match current indices of 'operand' :
///           'prev_val' = 'prev_val' + 'operand_val'
///        6. Store 'prev_val' back to the temporary output.
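///
/// A minimal sketch of the innermost body this emits, shown for a single
/// start_index_map entry (assumed names and f32 element type, not verified
/// pass output):
///
///   %operand_val = affine.load %operand[...]
///   %idx = arith.addi %start_component, %offset : index
///   %eq = arith.cmpi eq, %idx, %operand_iv : index
///   %val = arith.select %eq, %operand_val, %zero : f32
///   %prev_val = affine.load %output[...]
///   %sum = arith.addf %val, %prev_val : f32
///   affine.store %sum, %output[...]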
class GatherOpConverter : public OpRewritePattern<GatherOp> {
 public:
  using OpRewritePattern<GatherOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherOp op,
                                PatternRewriter& rewriter) const final {
    Location loc = op.getLoc();

    // Operand array.
    Value operand = op.getOperand();
    MemRefType operandType = operand.getType().cast<MemRefType>();
    unsigned operandRank = operandType.getRank();
    ArrayRef<int64_t> operandShape = operandType.getShape();

    // Start_indices array.
    Value startIndices = op.getStartIndices();
    MemRefType startIndicesType = startIndices.getType().cast<MemRefType>();
    unsigned startIndicesRank = startIndicesType.getRank();
    ArrayRef<int64_t> startIndicesShape = startIndicesType.getShape();

    // Output array.
    Value output = op.getOutput();
    MemRefType outputType = output.getType().cast<MemRefType>();
    ArrayRef<int64_t> outputShape = outputType.getShape();

    if (!operandType.hasStaticShape() || !startIndicesType.hasStaticShape() ||
        !outputType.hasStaticShape())
      return rewriter.notifyMatchFailure(op, "only static shaped type allowed");

    mhlo::GatherDimensionNumbersAttr gatherDim = op.getDimensionNumbers();

    auto collapsedSliceDims = gatherDim.getCollapsedSliceDims();
    auto offsetDims = gatherDim.getOffsetDims();
    auto startIndexMap = gatherDim.getStartIndexMap();
    int64_t indexVectorDim = gatherDim.getIndexVectorDim();

    // Slice_sizes.
    DenseIntElementsAttr sliceSizesAttr = op.getSliceSizesAttr();
    SmallVector<int64_t, 4> sliceSizes;
    for (const APInt& dim : sliceSizesAttr.getValues<APInt>())
      sliceSizes.push_back(dim.getSExtValue());

    // Create a zero constant with the element type of start_indices; the
    // indices are integers, so an integer-typed constant value is needed.
    Value zeroIntVal = rewriter.create<mlir::arith::ConstantIntOp>(
        loc, 0, startIndicesType.getElementType());
    Type elementType = outputType.getElementType();
    Value zeroLoadValue = getZeroValue(elementType, loc, rewriter);
    // Initializing the output buffer with 0.
    fillBuffer(loc, output, zeroLoadValue, rewriter);

    // We fetch the shape of start_indices at index_vector_dim. In case
    // index_vector_dim is equal to the rank of start_indices, we implicitly
    // consider start_indices to have a trailing 1 dimension.
    unsigned startIndicesNumbers = (indexVectorDim == startIndicesRank)
                                       ? 1
                                       : startIndicesShape[indexVectorDim];
    // We create integer constants 0 .. start_indices_numbers - 1; they are
    // used below to index into start_indices along index_vector_dim.
    SmallVector<Value, 4> startIndicesIndex;
    for (unsigned i = 0; i < startIndicesNumbers; i++) {
      Value iVal = rewriter.create<mlir::arith::ConstantIntOp>(
          loc, i, startIndicesType.getElementType());
      iVal = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
                                                 iVal);
      startIndicesIndex.push_back(iVal);
    }

    // s_in holds the per-dimension components of the starting index into the
    // operand tensor, and o_in holds the corresponding offsets of that
    // starting index. Both are initialized to 0.
    SmallVector<Value, 4> sIn;
    SmallVector<Value, 4> oIn;
    Value zeroIndexVal = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getIndexType(), zeroIntVal);
    for (unsigned i = 0; i < operandRank; i++) {
      sIn.push_back(zeroIndexVal);
      oIn.push_back(zeroIndexVal);
    }

    // batch_induction_vars stores the loop induction variables pertaining to
    // the batches of start indices.
    SmallVector<Value, 4> batchInductionVars;
    // output_induction_vars stores the loop induction variables pertaining to
    // both batches and offsets within the output tensor.
    SmallVector<Value, 4> outputInductionVars;
    // Create loops to iterate over each batch of starting index.
    for (unsigned i = 0; i < startIndicesRank; i++) {
      // ith dimension of start_indices doesn't form a batch if it is equal to
      // index_vector_dim.
      if (i == indexVectorDim) continue;
      AffineForOp forOp =
          rewriter.create<AffineForOp>(loc, 0, startIndicesShape[i]);
      batchInductionVars.push_back(forOp.getInductionVar());
      outputInductionVars.push_back(forOp.getInductionVar());
      rewriter.setInsertionPointToStart(forOp.getBody());
    }

    // Create loops to iterate over each offset dimension within the output
    // tensor.
    for (unsigned i = 0, k = 0, e = offsetDims.size(); i < e; i++) {
      AffineForOp forOp =
          rewriter.create<AffineForOp>(loc, 0, outputShape[offsetDims[i]]);
      rewriter.setInsertionPointToStart(forOp.getBody());
      // We try to fetch the first non-collapsed dimension.
      while (k < collapsedSliceDims.size() && collapsedSliceDims[k] == i) k++;
      // Remapping the offset_dim[i] to the non-collapsed dimension.
      oIn[k++] = forOp.getInductionVar();
      // We assume offset_dims to be sorted. So when inserted to
      // output_induction_vars the loop induction variable gets inserted at the
      // correct position.
      outputInductionVars.insert(outputInductionVars.begin() + offsetDims[i],
                                 forOp.getInductionVar());
    }

    // Create loops to iterate over all dimensions within the operand tensor.
    SmallVector<Value, 4> operandIndex;
    for (unsigned i = 0, k = 0; i < operandRank; i++) {
      // We assume start_index_map to have sorted dimensions. We only include
      // those dimensions of operand tensor which are present in
      // start_index_map.
      if (k < startIndexMap.size() && i == startIndexMap[k++]) {
        AffineForOp forOp =
            rewriter.create<AffineForOp>(loc, 0, operandShape[i]);
        operandIndex.push_back(forOp.getInductionVar());
        rewriter.setInsertionPointToStart(forOp.getBody());
      } else {
        operandIndex.push_back(oIn[i]);
      }
    }

    // If index_vector_dim is not equal to the rank of start_indices, iterate
    // over the components of the starting index, temporarily extending
    // batch_induction_vars to load each component from start_indices.
    if (indexVectorDim != startIndicesRank) {
      for (unsigned i = 0; i < startIndicesNumbers; i++) {
        batchInductionVars.insert(batchInductionVars.begin() + indexVectorDim,
                                  startIndicesIndex[i]);
        Value startIndex = rewriter.create<AffineLoadOp>(loc, startIndices,
                                                         batchInductionVars);
        startIndex = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), startIndex);
        sIn[startIndexMap[i]] = startIndex;
        batchInductionVars.erase(batchInductionVars.begin() + indexVectorDim);
      }
    } else {
      // Since index_vector_dim equals the rank of start_indices, the start
      // index can be fetched directly using batch_induction_vars.
      Value startIndex =
          rewriter.create<AffineLoadOp>(loc, startIndices, batchInductionVars);
      startIndex = rewriter.create<arith::IndexCastOp>(
          loc, rewriter.getIndexType(), startIndex);
      sIn[0] = startIndex;
    }

    // We load value at a particular operand index and populate the output
    // tensor if the index constraints match.
    Value loadValue = rewriter.create<AffineLoadOp>(loc, operand, operandIndex);
    SmallVector<Value, 4> predicates;
    // Adding offsets to the corresponding starting index and comparing it with
    // the corresponding operand index.
    for (unsigned k = 0, i = 0; k < startIndexMap.size(); k++) {
      i = startIndexMap[k];
      Value addStartIndexOffset = rewriter.create<mlir::arith::AddIOp>(
          loc, rewriter.getIndexType(), sIn[i], oIn[i]);
      Value predicate = rewriter.create<mlir::arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, addStartIndexOffset, operandIndex[i]);
      predicates.push_back(predicate);
    }

    // The number of predicates equals start_index_map.size(). We iterate over
    // pairs of predicates and join them with arith::AndIOp, accumulating the
    // combined predicate in result_predicate.
    Value resultPredicate = nullptr;
    for (unsigned i = 0; i < predicates.size() - 1; i += 2) {
      Value predicateA = predicates[i];
      Value predicateB = predicates[i + 1];
      Value andPredicate =
          rewriter.create<mlir::arith::AndIOp>(loc, predicateA, predicateB);
      resultPredicate = (i == 0) ? andPredicate
                                 : rewriter.create<mlir::arith::AndIOp>(
                                       loc, resultPredicate, andPredicate);
    }
    // We fetch the last predicate value. If it is the only predicate, let
    // result_predicate be equal to it. Otherwise, if there is an odd number of
    // predicates, join it to the other predicates using arith::AndIOp.
    Value predicate = predicates.back();
    if (!resultPredicate) resultPredicate = predicate;
    // If there is an odd number of predicates, join the last predicate to
    // result_predicate using arith::AndIOp.
    else if (startIndexMap.size() % 2 == 1)
      resultPredicate =
          rewriter.create<mlir::arith::AndIOp>(loc, resultPredicate, predicate);

    // We use the loaded value if the index computed by adding offsets to
    // starting index is equal to the current operand index. We use 0 as a
    // value otherwise.
    Value selectLoad = rewriter.create<mlir::arith::SelectOp>(
        loc, resultPredicate, loadValue, zeroLoadValue);
    // We load value at output array.
    Value outputValue =
        rewriter.create<AffineLoadOp>(loc, output, outputInductionVars);

    // The selected value is added to the previous value stored in output array.
    if (elementType.isa<FloatType>())
      outputValue = rewriter.create<arith::AddFOp>(loc, elementType, selectLoad,
                                                   outputValue);
    else
      outputValue = rewriter.create<arith::AddIOp>(loc, elementType, selectLoad,
                                                   outputValue);
    rewriter.create<AffineStoreOp>(loc, outputValue, output,
                                   outputInductionVars);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Converts PadOp to Affine nest form.
/// Pseudocode:
///   1. Fill `output` tensor with `padding_value`.
///   2. Compute AffineMap for store into `output`.
///      out_idx = edge_padding_low +
///                operand_idx * (1 + interior_padding)
///   3. Create nested loop from `operand` shape.
///      3.1 load from `operand`.
///      3.2 store into `output`.
/// NOTE: The lowering handles only ranked shapes and bails out if the rank of
/// the output type or the size of any of edge_padding_low/edge_padding_high/
/// interior_padding doesn't match the operand's rank.
/// Limitation for now:
///   interior_padding == 0 && edge_padding_* >= 0
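///
/// A minimal sketch for a 1-D memref<2xf32> operand with edge_padding_low = 1
/// and edge_padding_high = 1 (assumed shapes, not verified pass output):
///
///   %pad = affine.load %padding_value[] : memref<f32>
///   affine.for %i = 0 to 4 {
///     affine.store %pad, %out[%i] : memref<4xf32>
///   }
///   affine.for %i = 0 to 2 {
///     %v = affine.load %operand[%i] : memref<2xf32>
///     affine.store %v, %out[%i + 1] : memref<4xf32>
///   }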
struct PadOpConverter : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp op,
                                PatternRewriter& rewriter) const override {
    Value operand = op.getOperand();
    Value paddingValue = op.getPaddingValue();
    Value output = op.getOutput();

    auto operandType = operand.getType().dyn_cast<ShapedType>();
    auto outputType = output.getType().dyn_cast<ShapedType>();
    // We allow lowering for only ranked input/output.
    if (!(operandType && outputType && operandType.hasRank() &&
          outputType.hasRank()))
      return failure();
    unsigned rank = operandType.getRank();

    auto edgePadLowRanges = op.getEdgePaddingLow().getValues<int64_t>();
    auto edgePadHighRanges = op.getEdgePaddingHigh().getValues<int64_t>();
    auto interiorPadRanges = op.getInteriorPadding().getValues<int64_t>();
    // Check whether the constraints for the lowering are satisfied :-
    // 1. interior_padding[i] == 0
    // 2. edge_padding_*[i] >= 0
    for (auto paddings :
         llvm::zip(edgePadLowRanges, edgePadHighRanges, interiorPadRanges)) {
      // Only handle non-negative edge padding.
      if (std::get<0>(paddings) < 0 || std::get<1>(paddings) < 0)
        return rewriter.notifyMatchFailure(
            op, "expected non-negative edge padding");
      // Only handle interior padding being zero for now.
      if (std::get<2>(paddings) != 0)
        return rewriter.notifyMatchFailure(op,
                                           "expected zero interior padding");
    }

    SmallVector<int64_t> edgePaddingLow(edgePadLowRanges.begin(),
                                        edgePadLowRanges.end());
    SmallVector<int64_t> edgePaddingHigh(edgePadHighRanges.begin(),
                                         edgePadHighRanges.end());
    SmallVector<int64_t> interiorPadding(interiorPadRanges.begin(),
                                         interiorPadRanges.end());

    ArrayRef<int64_t> operandShape = operandType.getShape();
    ArrayRef<int64_t> outputShape = outputType.getShape();

    // Mapping the `operand` index to the `output` index.
    SmallVector<AffineExpr, 4> expr;
    for (unsigned i = 0; i < rank; i++) {
      AffineExpr dimExpr = rewriter.getAffineDimExpr(i);
      expr.push_back(dimExpr + edgePaddingLow[i]);
    }
    AffineMap map =
        AffineMap::get(rank, /*symbolCount=*/0, expr, rewriter.getContext());

    SmallVector<Value, 4> indices;

    Location loc = op.getLoc();
    // Set padding_value to output.
    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      Value scalarPaddingValue = rewriter.create<AffineLoadOp>(
          loc, paddingValue, SmallVector<Value, 4>());
      AffineForOp initForOp;
      for (unsigned i = 0; i < rank; i++) {
        initForOp = rewriter.create<AffineForOp>(loc, 0, outputShape[i]);
        indices.push_back(initForOp.getInductionVar());
        rewriter.setInsertionPointToStart(initForOp.getBody());
      }
      rewriter.create<AffineStoreOp>(loc, scalarPaddingValue, output, indices);
    }

    // Store `operand` into `output`, loop upper bounds from `operand` shape.
    indices.clear();
    AffineForOp padForOp;
    for (unsigned i = 0; i < rank; i++) {
      padForOp = rewriter.create<AffineForOp>(loc, 0, operandShape[i]);
      indices.push_back(padForOp.getInductionVar());
      rewriter.setInsertionPointToStart(padForOp.getBody());
    }
    Value storeVal = rewriter.create<AffineLoadOp>(loc, operand, indices);
    rewriter.create<AffineStoreOp>(loc, storeVal, output, map, indices);
    rewriter.eraseOp(op);
    return success();
  }
};

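/// Elementwise conversion for binary LHLO operations (instantiated below for
/// add, and, div, max, min, mul and subtract): builds one affine loop per
/// dimension of the (equal) operand shapes and applies the scalar op mapping
/// in the innermost body.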
template <typename LhloOpTy>
struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> {
  using OpRewritePattern<LhloOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(LhloOpTy op,
                                PatternRewriter& rewriter) const override {
    const auto& lhs = op.getLhs();
    const auto& rhs = op.getRhs();
    const auto& lhsType = lhs.getType().template cast<MemRefType>();
    const auto& rhsType = rhs.getType().template cast<MemRefType>();
    const auto& elementType = lhsType.getElementType();

    if (lhsType.getShape() != rhsType.getShape()) {
      return failure();
    }

    LogicalResult mapStatus = success();
    auto bodyBuilder = [&](OpBuilder& builder, Location loc,
                           ValueRange inductionVars) {
      auto l = builder.create<AffineLoadOp>(loc, lhs, inductionVars);
      auto r = builder.create<AffineLoadOp>(loc, rhs, inductionVars);
      Value opResult = lmhlo::LhloOpToStdScalarOp::map<LhloOpTy>(
          op, elementType, {l, r}, &builder);
      mapStatus = success(opResult != nullptr);
      if (failed(mapStatus)) return;
      rewriter.create<AffineStoreOp>(loc, opResult, op.getOut(), inductionVars);
    };

    buildBoundedAffineLoopNest(rewriter, op.getLoc(), lhsType.getShape(),
                               bodyBuilder);
    if (failed(mapStatus)) return failure();
    rewriter.eraseOp(op);
    return success();
  }
};

/// Conversion for unary operations, e.g. tanh, sin, cos, log, log1p.
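///
/// A minimal sketch of the emitted nest for a memref<4xf32> input lowered from
/// lmhlo.log (assumed shape, not verified pass output):
///
///   affine.for %i = 0 to 4 {
///     %x = affine.load %input[%i] : memref<4xf32>
///     %y = math.log %x : f32
///     affine.store %y, %output[%i] : memref<4xf32>
///   }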
template <typename LhloOpTy>
struct UnaryOpConverter : public OpRewritePattern<LhloOpTy> {
  using OpRewritePattern<LhloOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(LhloOpTy op,
                                PatternRewriter& rewriter) const override {
    Value input = op.getInput();
    auto inputType = input.getType().cast<MemRefType>();
    auto elementType = inputType.getElementType();
    ArrayRef<int64_t> shape = inputType.getShape();

    SmallVector<Value, 4> inductionVars;

    LogicalResult mapStatus = success();
    auto bodyBuilder = [&](OpBuilder& builder, Location loc,
                           ValueRange inductionVars) {
      Value loadInput = builder.create<AffineLoadOp>(loc, input, inductionVars);
      Value opResult = lmhlo::LhloOpToStdScalarOp::map<LhloOpTy>(
          op, elementType, {loadInput}, &builder);
      mapStatus = success(opResult != nullptr);
      if (failed(mapStatus)) return;
      rewriter.create<AffineStoreOp>(loc, opResult, op.getOutput(),
                                     inductionVars);
    };
    buildBoundedAffineLoopNest(rewriter, op.getLoc(), shape, bodyBuilder);
    if (failed(mapStatus)) return failure();
    rewriter.eraseOp(op);
    return success();
  }
};

void populateLHLOToAffineConversionPattern(MLIRContext* context,
                                           RewritePatternSet* patterns) {
  // clang-format off
  patterns->add<
      BinaryOpConverter<lmhlo::AddOp>,
      BinaryOpConverter<lmhlo::AndOp>,
      BinaryOpConverter<lmhlo::DivOp>,
      BinaryOpConverter<lmhlo::MaxOp>,
      BinaryOpConverter<lmhlo::MinOp>,
      BinaryOpConverter<lmhlo::MulOp>,
      BinaryOpConverter<lmhlo::SubtractOp>,
      ConcatOpConverter,
      DotOpConverter,
      GatherOpConverter,
      PadOpConverter,
      UnaryOpConverter<lmhlo::LogOp>>(context);
  // clang-format on
}

struct LhloLegalizeToAffinePass
    : public LhloLegalizeToAffinePassBase<LhloLegalizeToAffinePass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<AffineDialect, math::MathDialect>();
  }
  void runOnOperation() override {
    auto func = getOperation();
    RewritePatternSet patterns(&getContext());
    populateLHLOToAffineConversionPattern(&getContext(), &patterns);
    if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
      return signalPassFailure();
  }
};

}  // namespace

std::unique_ptr<OperationPass<func::FuncOp>> createLhloLegalizeToAffinePass() {
  return std::make_unique<LhloLegalizeToAffinePass>();
}

}  // namespace lmhlo
}  // namespace mlir