/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering the LHLO dialect to the Affine
// dialect.

#include <utility>

#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/lhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/lhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace lmhlo {
namespace {

// Builds an affine loop nest iterating from zeros to "upperBounds" with unit
// steps, and populates the body of the innermost loop using "bodyBuilder".
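// For example, with upperBounds = {M, N} the generated nest is,
// schematically:
//   affine.for %i = 0 to M {
//     affine.for %j = 0 to N {
//       <body emitted by bodyBuilder with ivs = (%i, %j)>
//     }
//   }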
static void buildBoundedAffineLoopNest(
    OpBuilder& builder, Location location, ArrayRef<int64_t> upperBounds,
    function_ref<void(OpBuilder&, Location, ValueRange)> bodyBuilder) {
  SmallVector<int64_t, 3> lowerBounds(upperBounds.size(), /*Value=*/0);
  SmallVector<int64_t, 3> steps(upperBounds.size(), /*Value=*/1);
  buildAffineLoopNest(builder, location, lowerBounds, upperBounds, steps,
                      bodyBuilder);
}

struct DotOpConverter : public OpRewritePattern<DotOp> {
  using OpRewritePattern<DotOp>::OpRewritePattern;

  // Supports only rank-2 memrefs for LHS and RHS.
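  // The emitted nest is, roughly (eliding the exact scalar ops produced by
  // the LHLO-to-scalar mapping):
  //   for i in [0, M):        // M = lhs shape[0]
  //     for j in [0, N):      // N = rhs shape[1]
  //       for k in [0, K):    // K = rhs shape[0], the contraction dim
  //         out[i, j] = out[i, j] + lhs[i, k] * rhs[k, j]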
  LogicalResult matchAndRewrite(DotOp op,
                                PatternRewriter& rewriter) const override {
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();
    MemRefType lhsType = lhs.getType().cast<MemRefType>();
    MemRefType rhsType = rhs.getType().cast<MemRefType>();
    Type elementType = lhsType.getElementType();
    ArrayRef<int64_t> shapeLhs = lhsType.getShape();
    ArrayRef<int64_t> shapeRhs = rhsType.getShape();

    if ((lhsType.getRank() != 2) || (rhsType.getRank() != 2)) {
      return failure();
    }

    // We don't currently support batching dimensions or multiple contraction
    // dimensions.
    mhlo::DotDimensionNumbersAttr dotDimensionNumbers =
        op.getDotDimensionNumbers();
    if (!dotDimensionNumbers.getLhsBatchingDimensions().empty() ||
        !dotDimensionNumbers.getRhsBatchingDimensions().empty())
      return failure();
    if (dotDimensionNumbers.getLhsContractingDimensions().size() != 1 ||
        *dotDimensionNumbers.getLhsContractingDimensions().begin() != 1 ||
        dotDimensionNumbers.getRhsContractingDimensions().size() != 1 ||
        *dotDimensionNumbers.getRhsContractingDimensions().begin() != 0) {
      return failure();
    }

    LogicalResult mapStatus = success();
    auto bodyBuilder = [&](OpBuilder& builder, Location loc, ValueRange ivs) {
      SmallVector<Value, 2> lhsIndices{ivs[0], ivs[2]},
          rhsIndices{ivs[2], ivs[1]}, resultIndices{ivs[0], ivs[1]};

      auto l = builder.create<AffineLoadOp>(loc, lhs, lhsIndices);
      auto r = builder.create<AffineLoadOp>(loc, rhs, rhsIndices);
      auto result =
          builder.create<AffineLoadOp>(loc, op.getOutput(), resultIndices);
      Value opResult = lmhlo::LhloOpToStdScalarOp::map<DotOp>(
          op, elementType, {l, r, result}, &builder);
      mapStatus = success(opResult != nullptr);
      if (failed(mapStatus)) return;
      builder.create<AffineStoreOp>(loc, opResult, op.getOutput(),
                                    resultIndices);
    };

    buildBoundedAffineLoopNest(rewriter, op.getLoc(),
                               {shapeLhs[0], shapeRhs[1], shapeRhs[0]},
                               bodyBuilder);
    if (failed(mapStatus)) return failure();

    rewriter.eraseOp(op);
    return success();
  }
};

/// Concat Operation Example (2D):
/// Given inpA[2][1], inpB[2][2], concat_dimension = 1.
/// Compute output[x1][x2].
/// Implementation Pseudocode:
/// s = 0
/// for a in range(0, 2):
///   for b in range(0, 1):
///     output[a][b] = inpA[a][b - s]
/// s = 1
/// for a in range(0, 2):
///   for b in range(1, 3):
///     output[a][b] = inpB[a][b - s]
///
/// Concatenate composes an array from multiple array operands. The array has
/// the same rank as each of the input array operands (which must all have the
/// same rank as each other) and contains the arguments in the order that they
/// were specified.
struct ConcatOpConverter : public OpRewritePattern<ConcatenateOp> {
  using OpRewritePattern<ConcatenateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatenateOp op,
                                PatternRewriter& rewriter) const override {
    Location loc = op.getLoc();
    Value output = op.getOutput();
    MemRefType outputType = output.getType().cast<MemRefType>();
    unsigned outputRank = outputType.getRank();
    ArrayRef<int64_t> outputShape = outputType.getShape();

    ValueRange operands = op.getVal();
    uint64_t concatDim = op.getDimension();
    int64_t prevBound = 0;

    for (Value operand : operands) {
      MemRefType operandType = operand.getType().cast<MemRefType>();
      ArrayRef<int64_t> operandShape = operandType.getShape();

      // TODO(pashu123): Extend support for dynamic dimensions.
      if (!operandType.hasStaticShape()) return failure();

      // For the concatenation dimension only, the index is remapped to
      // (dimension - prevBound); every other dimension maps to itself.
      SmallVector<AffineExpr, 4> expr;
      for (unsigned i = 0; i < outputRank; i++) {
        AffineExpr d0 = (i == concatDim)
                            ? rewriter.getAffineDimExpr(concatDim) - prevBound
                            : rewriter.getAffineDimExpr(i);

        expr.push_back(d0);
      }
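      // E.g., with outputRank == 2 and concatDim == 1, this builds the map
      // (d0, d1) -> (d0, d1 - prevBound).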
      AffineMap map =
          AffineMap::get(outputRank, 0, expr, rewriter.getContext());

      // Create a loop nest for this operand; along the concatenation
      // dimension, the loop ranges over this operand's slice of the output.
      OpBuilder::InsertionGuard guard(rewriter);
      SmallVector<Value, 3> indices;
      AffineForOp forOp;
      for (unsigned i = 0; i < outputRank; i++) {
        if (i == concatDim) {
          forOp = rewriter.create<AffineForOp>(loc, prevBound,
                                               prevBound + operandShape[i]);
          prevBound += operandShape[i];
          indices.push_back(forOp.getInductionVar());
        } else {
          forOp = rewriter.create<AffineForOp>(loc, 0, outputShape[i]);
          indices.push_back(forOp.getInductionVar());
        }
        rewriter.setInsertionPointToStart(forOp.getBody());
      }
      Value storeVal =
          rewriter.create<AffineLoadOp>(loc, operand, map, indices);
      rewriter.create<AffineStoreOp>(loc, storeVal, output, indices);
    }
    rewriter.eraseOp(op);
    return success();
  }
};

/// Returns a zero value of type `type`. `type` is expected to be either
/// an int or a float.
static Value getZeroValue(Type type, Location loc, PatternRewriter& rewriter) {
  assert(type.isIntOrFloat() && "Expected int or float");

  if (IntegerType intType = type.dyn_cast<IntegerType>())
    return rewriter.create<mlir::arith::ConstantIntOp>(loc, 0,
                                                       intType.getWidth());

  FloatType floatType = type.cast<FloatType>();
  return rewriter.create<mlir::arith::ConstantFloatOp>(
      loc, APFloat::getZero(floatType.getFloatSemantics()), floatType);
}

/// Emits a loop nest that fills the given `buffer` of memref type with
/// `fillValue`.
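/// Schematically, for a rank-2 buffer (pseudo-IR):
///   %d0 = memref.dim %buffer, 0
///   %d1 = memref.dim %buffer, 1
///   affine.for %i = 0 to %d0 {
///     affine.for %j = 0 to %d1 {
///       affine.store %fillValue, %buffer[%i, %j]
///     }
///   }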
static void fillBuffer(Location loc, Value buffer, Value fillValue,
                       PatternRewriter& builder) {
  OpBuilder::InsertionGuard guard(builder);
  MemRefType bufType = buffer.getType().cast<MemRefType>();
  unsigned rank = bufType.getRank();
  SmallVector<Value, 4> dimSizes;
  dimSizes.reserve(rank);
  for (unsigned i = 0; i < rank; ++i)
    dimSizes.push_back(builder.create<mlir::memref::DimOp>(loc, buffer, i));

  AffineMap idSymMap = builder.getSymbolIdentityMap();
  AffineMap lbMap = builder.getConstantAffineMap(0);
  SmallVector<Value, 4> ivs(rank);
  AffineForOp forOp;
  for (unsigned i = 0; i < rank; ++i) {
    forOp = builder.create<AffineForOp>(loc, llvm::None, lbMap, dimSizes[i],
                                        idSymMap);
    builder.setInsertionPointToStart(forOp.getBody());
    ivs[i] = forOp.getInductionVar();
  }
  Type fillValueType = fillValue.getType();
  auto fillMemRefType = fillValueType.dyn_cast<MemRefType>();
  assert(((fillMemRefType && fillMemRefType.getRank() == 0) ||
          fillValueType.isIntOrFloat()) &&
         "init value has to be a 0-d memref or int or fp");
  Value initVal = fillMemRefType ? builder.create<AffineLoadOp>(
                                       loc, fillValue, /*indices=*/llvm::None)
                                 : fillValue;
  builder.create<AffineStoreOp>(loc, initVal, buffer, ivs);
}

/// Converts GatherOp to Affine nest form.
/// Pseudocode:
///   1. Fill a temporary output tensor with 0.
///   2. Repeat the following for each batch dimension:
///      1. For each index in 'operand':
///        1. Get hold of the start indices from 'start_indices'.
///        2. Add the offset to the start indices to get the final indices.
///        3. Load a value from the 'operand' tensor: 'operand_val'.
///        4. Load a value from the temporary output: 'prev_val'.
///        5. If the final indices match the current indices of 'operand':
///             'prev_val' = 'prev_val' + 'operand_val'
///        6. Store 'prev_val' back to the temporary output.
class GatherOpConverter : public OpRewritePattern<GatherOp> {
 public:
  using OpRewritePattern<GatherOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherOp op,
                                PatternRewriter& rewriter) const final {
    Location loc = op.getLoc();

    // Operand array.
    Value operand = op.getOperand();
    MemRefType operandType = operand.getType().cast<MemRefType>();
    unsigned operandRank = operandType.getRank();
    ArrayRef<int64_t> operandShape = operandType.getShape();

    // Start_indices array.
    Value startIndices = op.getStartIndices();
    MemRefType startIndicesType = startIndices.getType().cast<MemRefType>();
    unsigned startIndicesRank = startIndicesType.getRank();
    ArrayRef<int64_t> startIndicesShape = startIndicesType.getShape();

    // Output array.
    Value output = op.getOutput();
    MemRefType outputType = output.getType().cast<MemRefType>();
    ArrayRef<int64_t> outputShape = outputType.getShape();

    if (!operandType.hasStaticShape() || !startIndicesType.hasStaticShape() ||
        !outputType.hasStaticShape())
      return rewriter.notifyMatchFailure(op, "only static shaped type allowed");

    mhlo::GatherDimensionNumbersAttr gatherDim = op.getDimensionNumbers();

    auto collapsedSliceDims = gatherDim.getCollapsedSliceDims();
    auto offsetDims = gatherDim.getOffsetDims();
    auto startIndexMap = gatherDim.getStartIndexMap();
    int64_t indexVectorDim = gatherDim.getIndexVectorDim();

    // Slice_sizes.
    DenseIntElementsAttr sliceSizesAttr = op.getSliceSizesAttr();
    SmallVector<int64_t, 4> sliceSizes;
    for (const APInt& dim : sliceSizesAttr.getValues<APInt>())
      sliceSizes.push_back(dim.getSExtValue());

    // Create constants with value 0. We need integer-typed constants because
    // the indices type will be an integer type.
    Value zeroIntVal = rewriter.create<mlir::arith::ConstantIntOp>(
        loc, 0, startIndicesType.getElementType());
    Type elementType = outputType.getElementType();
    Value zeroLoadValue = getZeroValue(elementType, loc, rewriter);
    // Initialize the output buffer with 0.
    fillBuffer(loc, output, zeroLoadValue, rewriter);

    // We fetch the shape of start_indices at index_vector_dim. In case
    // index_vector_dim is equal to the rank of start_indices, we implicitly
    // consider start_indices to have a trailing 1 dimension.
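    // E.g., start_indices of shape [4, 3] with index_vector_dim == 1 gives
    // startIndicesNumbers == 3; with index_vector_dim == 2 the shape is
    // treated as [4, 3, 1] and startIndicesNumbers == 1.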
    unsigned startIndicesNumbers = (indexVectorDim == startIndicesRank)
                                       ? 1
                                       : startIndicesShape[indexVectorDim];
    // We create integer constants up to startIndicesNumbers, which help us
    // fetch elements of start_indices in the affine transformation.
    SmallVector<Value, 4> startIndicesIndex;
    for (unsigned i = 0; i < startIndicesNumbers; i++) {
      Value iVal = rewriter.create<mlir::arith::ConstantIntOp>(
          loc, i, startIndicesType.getElementType());
      iVal = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
                                                 iVal);
      startIndicesIndex.push_back(iVal);
    }

    // sIn contains the multiple indices that form a starting index used in
    // the input/operand tensor. oIn contains the multiple offsets of the
    // corresponding starting index used in the input/operand tensor. We
    // initialize both of them with 0.
    SmallVector<Value, 4> sIn;
    SmallVector<Value, 4> oIn;
    Value zeroIndexVal = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getIndexType(), zeroIntVal);
    for (unsigned i = 0; i < operandRank; i++) {
      sIn.push_back(zeroIndexVal);
      oIn.push_back(zeroIndexVal);
    }

    // batchInductionVars stores the loop induction variables pertaining to
    // the batches of start indices.
    SmallVector<Value, 4> batchInductionVars;
    // outputInductionVars stores the loop induction variables pertaining to
    // both batches and offsets within the output tensor.
    SmallVector<Value, 4> outputInductionVars;
    // Create loops to iterate over each batch of starting index.
    for (unsigned i = 0; i < startIndicesRank; i++) {
      // The ith dimension of start_indices doesn't form a batch if it is
      // equal to index_vector_dim.
      if (i == indexVectorDim) continue;
      AffineForOp forOp =
          rewriter.create<AffineForOp>(loc, 0, startIndicesShape[i]);
      batchInductionVars.push_back(forOp.getInductionVar());
      outputInductionVars.push_back(forOp.getInductionVar());
      rewriter.setInsertionPointToStart(forOp.getBody());
    }

    // Create loops to iterate over each offset dimension within the output
    // tensor.
    for (unsigned i = 0, k = 0, e = offsetDims.size(); i < e; i++) {
      AffineForOp forOp =
          rewriter.create<AffineForOp>(loc, 0, outputShape[offsetDims[i]]);
      rewriter.setInsertionPointToStart(forOp.getBody());
      // We try to fetch the first non-collapsed dimension.
      while (k < collapsedSliceDims.size() && collapsedSliceDims[k] == i) k++;
      // Remap offset_dims[i] to the non-collapsed dimension.
      oIn[k++] = forOp.getInductionVar();
      // We assume offset_dims to be sorted, so when inserted into
      // outputInductionVars the loop induction variable lands at the correct
      // position.
      outputInductionVars.insert(outputInductionVars.begin() + offsetDims[i],
                                 forOp.getInductionVar());
    }

    // Create loops to iterate over all dimensions within the operand tensor.
    SmallVector<Value, 4> operandIndex;
    for (unsigned i = 0, k = 0; i < operandRank; i++) {
      // We assume start_index_map to have sorted dimensions. We only include
      // those dimensions of the operand tensor which are present in
      // start_index_map.
      if (k < startIndexMap.size() && i == startIndexMap[k++]) {
        AffineForOp forOp =
            rewriter.create<AffineForOp>(loc, 0, operandShape[i]);
        operandIndex.push_back(forOp.getInductionVar());
        rewriter.setInsertionPointToStart(forOp.getBody());
      } else {
        operandIndex.push_back(oIn[i]);
      }
    }

    // In case index_vector_dim is not equal to the rank of start_indices, we
    // create another loop to iterate over the starting index and update
    // batchInductionVars.
    if (indexVectorDim != startIndicesRank) {
      for (unsigned i = 0; i < startIndicesNumbers; i++) {
        batchInductionVars.insert(batchInductionVars.begin() + indexVectorDim,
                                  startIndicesIndex[i]);
        Value startIndex = rewriter.create<AffineLoadOp>(loc, startIndices,
                                                         batchInductionVars);
        startIndex = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), startIndex);
        sIn[startIndexMap[i]] = startIndex;
        batchInductionVars.erase(batchInductionVars.begin() + indexVectorDim);
      }
    } else {
      // Since index_vector_dim is equal to the rank of start_indices, we can
      // directly fetch the start index from batchInductionVars.
      Value startIndex =
          rewriter.create<AffineLoadOp>(loc, startIndices, batchInductionVars);
      startIndex = rewriter.create<arith::IndexCastOp>(
          loc, rewriter.getIndexType(), startIndex);
      sIn[0] = startIndex;
    }

    // We load the value at a particular operand index and populate the output
    // tensor if the index constraints match.
    Value loadValue = rewriter.create<AffineLoadOp>(loc, operand, operandIndex);
    SmallVector<Value, 4> predicates;
    // Add offsets to the corresponding starting index and compare the result
    // with the corresponding operand index.
    for (unsigned k = 0, i = 0; k < startIndexMap.size(); k++) {
      i = startIndexMap[k];
      Value addStartIndexOffset = rewriter.create<mlir::arith::AddIOp>(
          loc, rewriter.getIndexType(), sIn[i], oIn[i]);
      Value predicate = rewriter.create<mlir::arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, addStartIndexOffset, operandIndex[i]);
      predicates.push_back(predicate);
    }

    // Since the number of predicates is equal to startIndexMap.size(), we
    // iterate over pairs of predicates and join them with arith::AndIOp.
    // We store the final predicate, formed by joining the other predicates
    // with arith::AndIOp, in resultPredicate.
    Value resultPredicate = nullptr;
    for (unsigned i = 0; i < predicates.size() - 1; i += 2) {
      Value predicateA = predicates[i];
      Value predicateB = predicates[i + 1];
      Value andPredicate =
          rewriter.create<mlir::arith::AndIOp>(loc, predicateA, predicateB);
      resultPredicate = (i == 0) ? andPredicate
                                 : rewriter.create<mlir::arith::AndIOp>(
                                       loc, resultPredicate, andPredicate);
    }
    // We fetch the last predicate value. In case it is the only predicate,
    // we let resultPredicate be equal to it. Otherwise, if there is an odd
    // number of predicates, we join it to the other predicates using
    // arith::AndIOp.
    Value predicate = predicates.back();
    if (!resultPredicate) resultPredicate = predicate;
    // In case there is an odd number of predicates, we join the last
    // predicate to resultPredicate using arith::AndIOp.
    else if (startIndexMap.size() % 2 == 1)
      resultPredicate =
          rewriter.create<mlir::arith::AndIOp>(loc, resultPredicate, predicate);

    // We use the loaded value if the index computed by adding the offsets to
    // the starting index is equal to the current operand index, and 0
    // otherwise.
    Value selectLoad = rewriter.create<mlir::arith::SelectOp>(
        loc, resultPredicate, loadValue, zeroLoadValue);
    // We load the value from the output array.
    Value outputValue =
        rewriter.create<AffineLoadOp>(loc, output, outputInductionVars);

    // The selected value is added to the previous value stored in the output
    // array.
    if (elementType.isa<FloatType>())
      outputValue = rewriter.create<arith::AddFOp>(loc, elementType, selectLoad,
                                                   outputValue);
    else
      outputValue = rewriter.create<arith::AddIOp>(loc, elementType, selectLoad,
                                                   outputValue);
    rewriter.create<AffineStoreOp>(loc, outputValue, output,
                                   outputInductionVars);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Converts PadOp to Affine nest form.
/// Pseudocode:
///   1. Fill the `output` tensor with `padding_value`.
///   2. Compute the AffineMap for the store into `output`:
///        out_idx = edge_padding_low +
///                  operand_idx * (1 + interior_padding)
///   3. Create a loop nest over the `operand` shape:
///      3.1 load from `operand`.
///      3.2 store into `output`.
/// NOTE: The lowering handles only ranked shapes and bails out in case the
///       size of any of output type/edge_padding_low/edge_padding_high/
///       interior_padding doesn't match the operand's rank.
/// Limitation for now:
///   interior_padding == 0 && edge_padding_* >= 0
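/// E.g., with edge_padding_low = [1, 2] and interior_padding = [0, 0], the
/// operand index (i, j) is stored to the output index (i + 1, j + 2).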
struct PadOpConverter : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp op,
                                PatternRewriter& rewriter) const override {
    Value operand = op.getOperand();
    Value paddingValue = op.getPaddingValue();
    Value output = op.getOutput();

    auto operandType = operand.getType().dyn_cast<ShapedType>();
    auto outputType = output.getType().dyn_cast<ShapedType>();
    // We allow lowering only for ranked input/output.
    if (!(operandType && outputType && operandType.hasRank() &&
          outputType.hasRank()))
      return failure();
    unsigned rank = operandType.getRank();

    auto edgePadLowRanges = op.getEdgePaddingLow().getValues<int64_t>();
    auto edgePadHighRanges = op.getEdgePaddingHigh().getValues<int64_t>();
    auto interiorPadRanges = op.getInteriorPadding().getValues<int64_t>();
    // Check whether the constraints for the lowering are satisfied:
    //   1. interior_padding[i] == 0
    //   2. edge_padding_*[i] >= 0
    for (auto paddings :
         llvm::zip(edgePadLowRanges, edgePadHighRanges, interiorPadRanges)) {
      // Only handle non-negative edge padding.
      if (std::get<0>(paddings) < 0 || std::get<1>(paddings) < 0)
        return rewriter.notifyMatchFailure(
            op, "expected non-negative edge padding");
      // Only handle interior padding being zero for now.
      if (std::get<2>(paddings) != 0)
        return rewriter.notifyMatchFailure(op,
                                           "expected zero interior padding");
    }

    SmallVector<int64_t> edgePaddingLow(edgePadLowRanges.begin(),
                                        edgePadLowRanges.end());
    SmallVector<int64_t> edgePaddingHigh(edgePadHighRanges.begin(),
                                         edgePadHighRanges.end());
    SmallVector<int64_t> interiorPadding(interiorPadRanges.begin(),
                                         interiorPadRanges.end());

    ArrayRef<int64_t> operandShape = operandType.getShape();
    ArrayRef<int64_t> outputShape = outputType.getShape();

    // Map the `operand` index to the `output` index.
    SmallVector<AffineExpr, 4> expr;
    for (unsigned i = 0; i < rank; i++) {
      AffineExpr dimExpr = rewriter.getAffineDimExpr(i);
      expr.push_back(dimExpr + edgePaddingLow[i]);
    }
    AffineMap map =
        AffineMap::get(rank, /*symbolCount=*/0, expr, rewriter.getContext());

    SmallVector<Value, 4> indices;

    Location loc = op.getLoc();
    // Store padding_value into the output.
    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      Value scalarPaddingValue = rewriter.create<AffineLoadOp>(
          loc, paddingValue, SmallVector<Value, 4>());
      AffineForOp initForOp;
      for (unsigned i = 0; i < rank; i++) {
        initForOp = rewriter.create<AffineForOp>(loc, 0, outputShape[i]);
        indices.push_back(initForOp.getInductionVar());
        rewriter.setInsertionPointToStart(initForOp.getBody());
      }
      rewriter.create<AffineStoreOp>(loc, scalarPaddingValue, output, indices);
    }

    // Store `operand` into `output`, with loop upper bounds taken from the
    // `operand` shape.
    indices.clear();
    AffineForOp padForOp;
    for (unsigned i = 0; i < rank; i++) {
      padForOp = rewriter.create<AffineForOp>(loc, 0, operandShape[i]);
      indices.push_back(padForOp.getInductionVar());
      rewriter.setInsertionPointToStart(padForOp.getBody());
    }
    Value storeVal = rewriter.create<AffineLoadOp>(loc, operand, indices);
    rewriter.create<AffineStoreOp>(loc, storeVal, output, map, indices);
    rewriter.eraseOp(op);
    return success();
  }
};

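/// Converts element-wise binary LHLO ops with identically shaped operands.
/// For lmhlo.add on f32 memrefs the emitted body is, roughly, an affine.for
/// nest containing affine.load of both operands, an arith.addf, and an
/// affine.store into the output.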
template <typename LhloOpTy>
struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> {
  using OpRewritePattern<LhloOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(LhloOpTy op,
                                PatternRewriter& rewriter) const override {
    const auto& lhs = op.getLhs();
    const auto& rhs = op.getRhs();
    const auto& lhsType = lhs.getType().template cast<MemRefType>();
    const auto& rhsType = rhs.getType().template cast<MemRefType>();
    const auto& elementType = lhsType.getElementType();

    if (lhsType.getShape() != rhsType.getShape()) {
      return failure();
    }

    LogicalResult mapStatus = success();
    auto bodyBuilder = [&](OpBuilder& builder, Location loc,
                           ValueRange inductionVars) {
      auto l = builder.create<AffineLoadOp>(loc, lhs, inductionVars);
      auto r = builder.create<AffineLoadOp>(loc, rhs, inductionVars);
      Value opResult = lmhlo::LhloOpToStdScalarOp::map<LhloOpTy>(
          op, elementType, {l, r}, &builder);
      mapStatus = success(opResult != nullptr);
      if (failed(mapStatus)) return;
      builder.create<AffineStoreOp>(loc, opResult, op.getOut(), inductionVars);
    };

    buildBoundedAffineLoopNest(rewriter, op.getLoc(), lhsType.getShape(),
                               bodyBuilder);
    if (failed(mapStatus)) return failure();
    rewriter.eraseOp(op);
    return success();
  }
};

/// Conversion for unary operations, e.g. tanh, sin, cos, log, and log1p.
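/// For lmhlo.log on an f32 memref the innermost loop body is, roughly:
///   %v = affine.load %input[...]
///   %r = math.log %v
///   affine.store %r, %output[...]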
template <typename LhloOpTy>
struct UnaryOpConverter : public OpRewritePattern<LhloOpTy> {
  using OpRewritePattern<LhloOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(LhloOpTy op,
                                PatternRewriter& rewriter) const override {
    Value input = op.getInput();
    auto inputType = input.getType().cast<MemRefType>();
    auto elementType = inputType.getElementType();
    ArrayRef<int64_t> shape = inputType.getShape();

    LogicalResult mapStatus = success();
    auto bodyBuilder = [&](OpBuilder& builder, Location loc,
                           ValueRange inductionVars) {
      Value loadInput = builder.create<AffineLoadOp>(loc, input, inductionVars);
      Value opResult = lmhlo::LhloOpToStdScalarOp::map<LhloOpTy>(
          op, elementType, {loadInput}, &builder);
      mapStatus = success(opResult != nullptr);
      if (failed(mapStatus)) return;
      builder.create<AffineStoreOp>(loc, opResult, op.getOutput(),
                                    inductionVars);
    };
    buildBoundedAffineLoopNest(rewriter, op.getLoc(), shape, bodyBuilder);
    if (failed(mapStatus)) return failure();
    rewriter.eraseOp(op);
    return success();
  }
};

void populateLHLOToAffineConversionPattern(MLIRContext* context,
                                           RewritePatternSet* patterns) {
  // clang-format off
  patterns->add<
      BinaryOpConverter<lmhlo::AddOp>,
      BinaryOpConverter<lmhlo::AndOp>,
      BinaryOpConverter<lmhlo::DivOp>,
      BinaryOpConverter<lmhlo::MaxOp>,
      BinaryOpConverter<lmhlo::MinOp>,
      BinaryOpConverter<lmhlo::MulOp>,
      BinaryOpConverter<lmhlo::SubtractOp>,
      ConcatOpConverter,
      DotOpConverter,
      GatherOpConverter,
      PadOpConverter,
      UnaryOpConverter<lmhlo::LogOp>>(context);
  // clang-format on
}

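/// Pass that lowers the supported LHLO ops within a function to the Affine
/// dialect by greedily applying the patterns above. (The command-line
/// argument under which the pass is registered is defined in the pass
/// tablegen file; presumably `lhlo-legalize-to-affine`.)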
struct LhloLegalizeToAffinePass
    : public LhloLegalizeToAffinePassBase<LhloLegalizeToAffinePass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<AffineDialect, math::MathDialect>();
  }
  void runOnOperation() override {
    auto func = getOperation();
    RewritePatternSet patterns(&getContext());
    populateLHLOToAffineConversionPattern(&getContext(), &patterns);
    if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
      return signalPassFailure();
  }
};

}  // namespace

std::unique_ptr<OperationPass<func::FuncOp>> createLhloLegalizeToAffinePass() {
  return std::make_unique<LhloLegalizeToAffinePass>();
}

}  // namespace lmhlo
}  // namespace mlir
669