xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/thlo/IR/thlo_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "mlir-hlo/Dialect/thlo/IR/thlo_ops.h"
17 
18 #include <algorithm>
19 #include <iterator>
20 #include <memory>
21 #include <utility>
22 
23 #include "llvm/ADT/ArrayRef.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
29 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
30 #include "mlir/Dialect/MemRef/IR/MemRef.h"
31 #include "mlir/Dialect/SCF/IR/SCF.h"
32 #include "mlir/Dialect/Tensor/IR/Tensor.h"
33 #include "mlir/Dialect/Tensor/Utils/Utils.h"
34 #include "mlir/IR/BlockAndValueMapping.h"
35 #include "mlir/IR/BuiltinAttributes.h"
36 #include "mlir/IR/BuiltinTypes.h"
37 #include "mlir/IR/DialectImplementation.h"
38 #include "mlir/IR/OpImplementation.h"
39 #include "mlir/IR/PatternMatch.h"
40 #include "mlir/Interfaces/ViewLikeInterface.h"
41 
42 namespace mlir {
43 namespace {
44 
45 //===----------------------------------------------------------------------===//
46 // Destination-style ops tools
47 //===----------------------------------------------------------------------===//
48 
hasTensorSemantics(OperandRange operands,unsigned numOutputArgs)49 bool hasTensorSemantics(OperandRange operands, unsigned numOutputArgs) {
50   for (auto operand : operands.drop_back(numOutputArgs)) {
51     if (!operand.getType().isa<ShapedType>()) continue;
52     if (!operand.getType().isa<RankedTensorType>()) return false;
53   }
54   return llvm::all_of(operands.take_back(numOutputArgs), [](Value operand) {
55     return operand.getType().isa<RankedTensorType>();
56   });
57 }
58 
hasBufferSemantics(OperandRange operands)59 bool hasBufferSemantics(OperandRange operands) {
60   return llvm::all_of(operands, [](Value operand) {
61     return operand.getType().isa<MemRefType>();
62   });
63 }
64 
verifyDestinationStyleOp(Operation * op,unsigned numOutputArgs=1)65 LogicalResult verifyDestinationStyleOp(Operation *op,
66                                        unsigned numOutputArgs = 1) {
67   if (hasBufferSemantics(op->getOperands()))
68     return success(op->getNumResults() == 0);
69 
70   if (!hasTensorSemantics(op->getOperands(), numOutputArgs))
71     return op->emitOpError("expected either buffer or tensor semantics");
72 
73   if (op->getNumResults() != numOutputArgs) {
74     return op->emitOpError(
75         "expected the number of output args to match the number of results");
76   }
77   for (auto &en : llvm::enumerate(llvm::zip(
78            op->getResultTypes(), op->getOperands().take_back(numOutputArgs)))) {
79     size_t index = en.index();
80     Type resultType = std::get<0>(en.value());
81     Type outputOperandType = std::get<1>(en.value()).getType();
82     if (resultType != outputOperandType)
83       op->emitOpError() << "type " << resultType << " of result " << index
84                         << " does not match output operand type "
85                         << outputOperandType;
86   }
87   return success();
88 }
89 
90 template <typename DstOpTy>
printDstStyleOp(DstOpTy op,OpAsmPrinter & p)91 void printDstStyleOp(DstOpTy op, OpAsmPrinter &p) {
92   if (op.getNumInputs() != 0) {
93     p << " ins(";
94     llvm::interleaveComma(
95         op.getOperands().take_front(op.getNumInputs()), p,
96         [&](Value input) { p << input << " : " << input.getType(); });
97     p << ")";
98   }
99   p << " outs(";
100   llvm::interleaveComma(
101       op.getOperands().take_back(op.getNumOutputs()), p,
102       [&](Value output) { p << output << " : " << output.getType(); });
103   p << ")";
104 
105   p.printOptionalAttrDict(op->getAttrs());
106 }
107 
// Parses an optional operand list of the form `keyword(%op : type, ...)`.
// If `keyword` is absent, parses nothing and succeeds. On success, the parsed
// operands are resolved into `result.operands` and their types are appended
// to `*operandTypes`.
ParseResult parseKeywordOperandListWithTypes(
    OpAsmParser &parser, OperationState &result, StringRef keyword,
    SmallVectorImpl<Type> *operandTypes) {
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
  if (succeeded(parser.parseOptionalKeyword(keyword))) {
    // Remember the list location for operand-resolution diagnostics.
    SMLoc operandsOperandsLoc = parser.getCurrentLocation();

    // Parse the parenthesized, comma-separated `%operand : type` entries.
    // Each callback invocation appends one operand and one type.
    if (parser.parseCommaSeparatedList(
            AsmParser::Delimiter::Paren, [&]() -> ParseResult {
              if (parser.parseOperand(operands.emplace_back(),
                                      /*allowResultNumber=*/false) ||
                  parser.parseColon() ||
                  parser.parseType(operandTypes->emplace_back())) {
                return failure();
              }
              return success();
            }))
      return failure();

    // Resolve the parsed SSA names against the parsed types.
    if (parser.resolveOperands(operands, *operandTypes, operandsOperandsLoc,
                               result.operands))
      return failure();
  }
  return success();
}
133 
parseDstStyleOp(OpAsmParser & parser,OperationState & result)134 ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result) {
135   // Parse `ins` and `outs`.
136   SmallVector<Type, 4> inputTypes, outputTypes;
137   if (parseKeywordOperandListWithTypes(parser, result, "ins", &inputTypes) ||
138       parseKeywordOperandListWithTypes(parser, result, "outs", &outputTypes))
139     return failure();
140 
141   // Add result types.
142   for (Type outputType : outputTypes) {
143     if (outputType.isa<RankedTensorType>()) result.addTypes(outputType);
144   }
145 
146   // Parse optional attributes.
147   if (parser.parseOptionalAttrDict(result.attributes)) return failure();
148   return success();
149 }
150 
151 }  // namespace
152 }  // namespace mlir
153 
154 // Generated dialect definitions.
155 #include "mlir-hlo/Dialect/thlo/IR/thlo_dialect.cc.inc"
156 
157 namespace mlir {
158 namespace thlo {
159 
// Registers all THLO operations (the tablegen-generated op list) with the
// dialect at load time.
void THLODialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir-hlo/Dialect/thlo/IR/thlo_ops.cc.inc"
      >();
}
166 
167 //===----------------------------------------------------------------------===//
168 // ConcatenateOp
169 //===----------------------------------------------------------------------===//
170 
171 namespace {
172 
// Fuses a thlo.concatenate through a gml_st tile subset: for every operand it
// builds the (possibly runtime-empty) slice of `tile` that overlaps the
// operand along the concatenation dimension, materializes each slice, and
// emits a smaller concatenate over those slices.
Value fuseConcatenateOpThroughTile(ConcatenateOp op, OpBuilder &builder,
                                   Location loc, Value tile) {
  uint64_t concatDim = op.dimension();
  auto resultTy = op.getResult().getType().cast<RankedTensorType>();
  int64_t rank = resultTy.getRank();
  OperandRange allOperands = op.operands();
  // All operands agree in the non-concat dimensions, so any operand can
  // supply those dimension sizes.
  Value anyOperand = allOperands.front();

  // Create the shared tile strides, which are the exact same for every operand
  // tile. Also create a basis for the space sizes, tile offsets, and tile
  // sizes. These hold the shared values in all non-concat dimensions and can be
  // amended in the concat dimension to create the individual operand tiles.
  SmallVector<Value> sharedTileStrides(rank);
  SmallVector<Value> baseSpaceSizes(rank);
  SmallVector<Value> baseTileOffsets(rank);
  SmallVector<Value> baseTileSizes(rank);
  for (int64_t i = 0; i < rank; ++i) {
    Value iCst = builder.create<arith::ConstantIndexOp>(loc, i);
    sharedTileStrides[i] = builder.create<gml_st::StrideOp>(loc, tile, iCst);

    // The space sizes, tile offsets, and tile sizes differ in the concat
    // dimension. Do not populate these.
    if (i == concatDim) {
      continue;
    }

    baseSpaceSizes[i] =
        builder.createOrFold<tensor::DimOp>(loc, anyOperand, iCst);
    baseTileOffsets[i] = builder.create<gml_st::OffsetOp>(loc, tile, iCst);
    baseTileSizes[i] = builder.create<gml_st::SizeOp>(loc, tile, iCst);
  }

  // Some shared values.
  ArrayAttr allDynamicStridesOrOffsetsAttr = builder.getI64ArrayAttr(
      SmallVector<int64_t>(rank, ShapedType::kDynamicStrideOrOffset));
  ArrayAttr allDynamicSizesAttr = builder.getI64ArrayAttr(
      SmallVector<int64_t>(rank, ShapedType::kDynamicSize));
  Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value concatDimCst = builder.create<arith::ConstantIndexOp>(loc, concatDim);
  Value maxTileSizeInConcatDim =
      builder.create<gml_st::SizeOp>(loc, tile, concatDimCst);

  // The remaining tile offset in the concat dimension is subtracted by each
  // operand's size in that dimension. We maintain the invariant
  // remainingTileOffsetInConcatDim >= 0.
  Value remainingTileOffsetInConcatDim =
      builder.create<gml_st::OffsetOp>(loc, tile, concatDimCst);

  // Create the relevant subsets per operand. These tiles can be empty at
  // runtime.
  SmallVector<Value> subOperands;
  subOperands.reserve(allOperands.size());
  for (Value operand : allOperands) {
    // Create operand space.
    Value operandSizeInConcatDim =
        builder.create<tensor::DimOp>(loc, operand, concatDimCst);
    baseSpaceSizes[concatDim] = operandSizeInConcatDim;
    Value operandSpace = builder.create<gml_st::SpaceOp>(loc, baseSpaceSizes,
                                                         allDynamicSizesAttr);

    // Find the current operand's tile offset in the concat dimension. This is
    // the remaining offset clamped into the bounds of the operand. Note that
    // the remaining offset is always >= 0.
    Value operandTileOffsetInConcatDim = builder.create<arith::MinUIOp>(
        loc, remainingTileOffsetInConcatDim, operandSizeInConcatDim);
    baseTileOffsets[concatDim] = operandTileOffsetInConcatDim;

    // Find the current operand's tile size in the concat dimension.
    Value remainingOperandSizeInConcatDim = builder.create<arith::SubIOp>(
        loc, operandSizeInConcatDim, operandTileOffsetInConcatDim);
    baseTileSizes[concatDim] = builder.create<arith::MinUIOp>(
        loc, remainingOperandSizeInConcatDim, maxTileSizeInConcatDim);

    // Create the operand tile and materialize the subset for this operand.
    Value tile = builder.create<gml_st::TileOp>(
        loc, operandSpace, baseTileOffsets, baseTileSizes, sharedTileStrides,
        allDynamicStridesOrOffsetsAttr, allDynamicSizesAttr,
        allDynamicStridesOrOffsetsAttr);
    subOperands.push_back(
        builder.create<gml_st::MaterializeOp>(loc, operand, tile));

    // Unless it is the last operand, update the remaining tile offset in the
    // concat dimension. The remaining offset is subtracted by the operand's
    // size but must remain >= 0.
    if (operand != allOperands.back()) {
      remainingTileOffsetInConcatDim = builder.create<arith::SelectOp>(
          loc,
          builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
                                        remainingTileOffsetInConcatDim,
                                        operandSizeInConcatDim),
          zeroCst,
          builder.create<arith::SubIOp>(loc, remainingTileOffsetInConcatDim,
                                        operandSizeInConcatDim));
    }
  }

  // Create the tiled concat op over the materialized operand slices, writing
  // into the matching slice of the init operand.
  auto tileType = tile.getType().cast<gml_st::TileType>();
  Value subInit = builder.create<gml_st::MaterializeOp>(loc, op.init(), tile);
  auto subResultType =
      RankedTensorType::get(tileType.getShape(), resultTy.getElementType());
  return builder.create<thlo::ConcatenateOp>(loc, subResultType, subOperands,
                                             subInit, concatDim);
}
277 
// Recursively materializes the scalar value of a concatenate at a single
// point. At each step, an scf.if tests whether the point's offset in the
// concat dimension falls inside the leading operand; if so, it recurs on that
// operand alone, otherwise it shifts the offset and recurs on the remaining
// operands. Returns an empty Value when called with no operands.
Value fuseConcatenateOpThroughPointRecursively(
    OpBuilder &builder, Location loc, RankedTensorType rankedTy,
    uint64_t concatDim, SmallVector<Value> &remainingOffsets,
    ValueRange remainingOperands) {
  // Bail if called for no operands.
  if (remainingOperands.empty()) {
    return {};
  }
  Value leadingOperand = remainingOperands.front();

  // Terminal case of exactly one operand.
  if (remainingOperands.size() == 1) {
    // Create operand space.
    SmallVector<Value> dynamicDims =
        tensor::createDynamicDimValues(builder, loc, leadingOperand);
    ArrayAttr staticDims = builder.getI64ArrayAttr(rankedTy.getShape());
    Value operandSpace =
        builder.create<gml_st::SpaceOp>(loc, dynamicDims, staticDims);

    // Create operand point with all-dynamic offsets.
    SmallVector<int64_t> allDynamicOffsets(rankedTy.getRank(),
                                           ShapedType::kDynamicStrideOrOffset);
    Value operandPoint = builder.create<gml_st::PointOp>(
        loc, operandSpace, remainingOffsets,
        builder.getI64ArrayAttr(allDynamicOffsets));

    return builder.create<gml_st::MaterializeOp>(loc, leadingOperand,
                                                 operandPoint);
  }

  // For more than 1 operand, distinguish between the leading operand and the
  // remainder.
  assert(remainingOperands.size() > 1 &&
         "expect more than 1 operand at this point");
  Value leadingOperandConcatDim =
      builder.create<tensor::DimOp>(loc, leadingOperand, concatDim);
  Value leadingOperandPredicate = builder.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::ult, remainingOffsets[concatDim],
      leadingOperandConcatDim);
  auto ifOp = builder.create<scf::IfOp>(
      loc, rankedTy.getElementType(), leadingOperandPredicate,
      [&](OpBuilder &builder, Location loc) {
        // For the leading operand, recur with the current offsets.
        Value fused = fuseConcatenateOpThroughPointRecursively(
            builder, loc, rankedTy, concatDim, remainingOffsets,
            leadingOperand);
        builder.create<scf::YieldOp>(loc, fused);
      },
      [&](OpBuilder &builder, Location loc) {
        // For the remaining operands, subtract the leading operand's size from
        // the remaining offsets in the concatenation dimension.
        SmallVector<Value> thenRemainingOffsets(remainingOffsets.begin(),
                                                remainingOffsets.end());
        thenRemainingOffsets[concatDim] = builder.create<arith::SubIOp>(
            loc, remainingOffsets[concatDim], leadingOperandConcatDim);
        Value fused = fuseConcatenateOpThroughPointRecursively(
            builder, loc, rankedTy, concatDim, thenRemainingOffsets,
            remainingOperands.drop_front());
        builder.create<scf::YieldOp>(loc, fused);
      });
  return ifOp.getResults().front();
}
340 
fuseConcatenateOpThroughPoint(ConcatenateOp op,OpBuilder & builder,Location loc,Value subset)341 Value fuseConcatenateOpThroughPoint(ConcatenateOp op, OpBuilder &builder,
342                                     Location loc, Value subset) {
343   auto resultTy = op.getType().cast<RankedTensorType>();
344   int64_t resultRank = resultTy.getRank();
345   uint64_t concatDim = op.dimension();
346 
347   // Materialize initial offsets.
348   SmallVector<Value> initialOffsets;
349   initialOffsets.reserve(resultRank);
350   for (int64_t i = 0; i < resultRank; ++i) {
351     initialOffsets.push_back(builder.create<gml_st::OffsetOp>(
352         loc, subset, builder.create<arith::ConstantIndexOp>(loc, i)));
353   }
354 
355   ValueRange initialOperands = op.operands();
356   return fuseConcatenateOpThroughPointRecursively(
357       builder, loc, resultTy, concatDim, initialOffsets, initialOperands);
358 }
359 
360 }  // namespace
361 
fuse(Location loc,Value subset,OpBuilder & builder)362 Value ConcatenateOp::fuse(Location loc, Value subset, OpBuilder &builder) {
363   Type subsetTy = subset.getType();
364   if (subsetTy.isa<gml_st::TileType>()) {
365     return fuseConcatenateOpThroughTile(*this, builder, loc, subset);
366   }
367   if (subsetTy.isa<gml_st::PointType>()) {
368     return fuseConcatenateOpThroughPoint(*this, builder, loc, subset);
369   }
370   return {};
371 }
372 
// Custom assembly for thlo.concatenate: `ins(...) outs(...) {attrs}`.
ParseResult ConcatenateOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseDstStyleOp(parser, result);
}

void ConcatenateOp::print(OpAsmPrinter &p) {
  printDstStyleOp(cast<ConcatenateOp>(getOperation()), p);
}

// Delegates to the shared destination-style verifier (buffer vs. tensor
// semantics, result/output matching).
LogicalResult ConcatenateOp::verify() {
  return verifyDestinationStyleOp(getOperation(), getNumOutputs());
}
384 
385 //===----------------------------------------------------------------------===//
386 // DynamicBroadcastInDimOp
387 //===----------------------------------------------------------------------===//
388 
// Custom assembly for thlo.dynamic_broadcast_in_dim:
// `ins(...) outs(...) {attrs}`.
ParseResult DynamicBroadcastInDimOp::parse(OpAsmParser &parser,
                                           OperationState &result) {
  return parseDstStyleOp(parser, result);
}

void DynamicBroadcastInDimOp::print(OpAsmPrinter &p) {
  printDstStyleOp(cast<DynamicBroadcastInDimOp>(getOperation()), p);
}

// Delegates to the shared destination-style verifier.
LogicalResult DynamicBroadcastInDimOp::verify() {
  return verifyDestinationStyleOp(getOperation(), getNumOutputs());
}
401 
// Fuses a thlo.dynamic_broadcast_in_dim through a gml_st subset (tile or
// point): maps the result-space subset back into operand space by dropping
// the non-broadcast dimensions and clamping offsets/sizes to 0/1 in expanding
// dimensions, then materializes a smaller broadcast (or a scalar, for points).
// Returns an empty Value for unsupported subset types.
Value DynamicBroadcastInDimOp::fuse(Location loc, Value subset,
                                    OpBuilder &builder) {
  Type subsetTy = subset.getType();
  auto operandTy = operand().getType().cast<RankedTensorType>();
  auto resultTy = getType(0).cast<RankedTensorType>();
  int64_t operandRank = operandTy.getRank();

  // Create the needed constants only once.
  DenseMap<uint64_t, Value> localIndexConstants;
  auto getIndexConstant = [&](uint64_t c) -> Value {
    auto it = localIndexConstants.find(c);
    if (it != localIndexConstants.end()) return it->second;
    auto cst = builder.create<arith::ConstantIndexOp>(loc, c);
    localIndexConstants[c] = cst;
    return cst;
  };

  // Materialize operand space.
  auto operandSpaceTy = builder.getType<gml_st::TileType>(operandTy.getShape());
  auto dynamicDims = tensor::createDynamicDimValues(builder, loc, operand());
  auto staticDims = builder.getI64ArrayAttr(operandTy.getShape());
  Value operandSpace = builder.create<gml_st::SpaceOp>(loc, operandSpaceTy,
                                                       dynamicDims, staticDims);

  // Materialize operand dimensions, taking dynamic sizes from the dim values
  // created above and static sizes as constants.
  SmallVector<Value> operandDims;
  int64_t dynamicDimsIdx = 0;
  operandDims.reserve(operandTy.getRank());
  for (const auto &it : llvm::enumerate(operandTy.getShape())) {
    int64_t d = it.value();
    Value dim = d == ShapedType::kDynamicSize ? dynamicDims[dynamicDimsIdx++]
                                              : getIndexConstant(d);
    operandDims.push_back(dim);
  }

  // Collapse the subset to operate only on corresponding dimensions.
  // TODO(frgossen): Only generate this when needed.
  auto collapsedSubset = builder.create<gml_st::DropDimsOp>(
      loc, subset, broadcast_dimensionsAttr());

  // Find the expanding dimensions. If corresponding operand and result
  // dimensions are different then the dimension is expanding.
  // TODO(frgossen): Use info from known expanding and known non-expanding
  // dimensions here.
  SmallVector<Value> operandExpandingDims;
  for (const auto &it : llvm::enumerate(broadcast_dimensions())) {
    auto operandDim = operandDims[it.index()];
    auto resultDim = builder.create<tensor::DimOp>(
        loc, init(), getIndexConstant(it.value()));
    operandExpandingDims.push_back(builder.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::ne, operandDim, resultDim));
  }

  // Compute operand offsets, which are needed for tile and point subsets.
  // An expanding dimension always reads the operand at offset 0.
  auto staticOffsets = builder.getI64ArrayAttr(
      SmallVector<int64_t>(operandRank, ShapedType::kDynamicStrideOrOffset));
  SmallVector<Value> offsets;
  Value zero = getIndexConstant(0);
  for (int i = 0; i < operandRank; ++i) {
    Value isExpanding = operandExpandingDims[i];
    Value collapsedSubsetOffset = builder.create<gml_st::OffsetOp>(
        loc, collapsedSubset, getIndexConstant(i));
    offsets.push_back(builder.create<arith::SelectOp>(loc, isExpanding, zero,
                                                      collapsedSubsetOffset));
  }

  // If the regarded subset is of point type, we can already construct the
  // operand point and materialize it.
  if (auto pointTy = subsetTy.dyn_cast<gml_st::PointType>()) {
    auto operandPoint = builder.create<gml_st::PointOp>(
        loc, pointTy, operandSpace, offsets, staticOffsets);
    return builder.create<gml_st::MaterializeOp>(
        loc, operandTy.getElementType(), operand(), operandPoint);
  }

  // If the regarded subset is of tile type, we still need the operand tile
  // sizes to materialize a fused broadcast.
  if (auto tileTy = subsetTy.dyn_cast<gml_st::TileType>()) {
    // Compute operand tile sizes; an expanding dimension contributes size 1.
    auto staticTileSizes = builder.getI64ArrayAttr(
        SmallVector<int64_t>(operandRank, ShapedType::kDynamicSize));
    SmallVector<Value> tileSizes;
    Value one = getIndexConstant(1);
    for (int i = 0; i < operandRank; ++i) {
      Value isExpanding = operandExpandingDims[i];
      Value tileSize = builder.create<gml_st::SizeOp>(loc, collapsedSubset,
                                                      getIndexConstant(i));
      tileSizes.push_back(
          builder.create<arith::SelectOp>(loc, isExpanding, one, tileSize));
    }

    // Create operand tile (strides are all the static value 1, so no dynamic
    // stride operands are needed).
    auto staticTileStrides =
        builder.getI64ArrayAttr(SmallVector<int64_t>(operandRank, 1));
    SmallVector<Value> tileStrides = {};
    auto operandTileTy = builder.getType<gml_st::TileType>(
        SmallVector<int64_t>(operandRank, ShapedType::kDynamicSize));
    auto operandTile = builder.create<gml_st::TileOp>(
        loc, operandTileTy, operandSpace, offsets, tileSizes, tileStrides,
        staticOffsets, staticTileSizes, staticTileStrides);

    // Materialize operand subsets.
    Value tiledInit =
        builder.create<gml_st::MaterializeOp>(loc, init(), subset);
    Value tiledOperand =
        builder.create<gml_st::MaterializeOp>(loc, operand(), operandTile);

    // Finally, materialize tiled broadcast.
    auto tiledResultTy =
        RankedTensorType::get(tileTy.getShape(), resultTy.getElementType());
    return builder
        .create<DynamicBroadcastInDimOp>(
            loc, TypeRange{tiledResultTy}, tiledOperand, tiledInit,
            broadcast_dimensionsAttr(), known_expanding_dimensionsAttr(),
            known_nonexpanding_dimensionsAttr())
        .getResult(0);
  }

  // Unsupported subset kind.
  return {};
}
522 
523 //===----------------------------------------------------------------------===//
524 // ScatterOp
525 //===----------------------------------------------------------------------===//
526 
// Custom assembly for thlo.scatter: `ins(...) outs(...) {attrs}`.
ParseResult ScatterOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseDstStyleOp(parser, result);
}

void ScatterOp::print(OpAsmPrinter &p) {
  printDstStyleOp(cast<ScatterOp>(getOperation()), p);
}

// Delegates to the shared destination-style verifier.
LogicalResult ScatterOp::verify() {
  return verifyDestinationStyleOp(getOperation(), getNumOutputs());
}
538 
539 //===----------------------------------------------------------------------===//
540 // GatherOp
541 //===----------------------------------------------------------------------===//
542 
// Custom assembly for thlo.gather: `ins(...) outs(...) {attrs}`.
ParseResult GatherOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseDstStyleOp(parser, result);
}

void GatherOp::print(OpAsmPrinter &p) {
  printDstStyleOp(cast<GatherOp>(getOperation()), p);
}

// Delegates to the shared destination-style verifier.
LogicalResult GatherOp::verify() {
  return verifyDestinationStyleOp(getOperation(), getNumOutputs());
}
554 
555 //===----------------------------------------------------------------------===//
556 // TransposeOp
557 //===----------------------------------------------------------------------===//
558 
// Custom assembly for thlo.transpose: `ins(...) outs(...) {attrs}`.
ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseDstStyleOp(parser, result);
}

void TransposeOp::print(OpAsmPrinter &p) {
  printDstStyleOp(cast<TransposeOp>(getOperation()), p);
}
566 
isValidPermutation(ArrayRef<int64_t> permutation)567 bool isValidPermutation(ArrayRef<int64_t> permutation) {
568   SmallVector<bool> seen(permutation.size(), false);
569   for (auto p : permutation) {
570     // Verify that each element is in [0..n-1] range and is present only once.
571     if (p < 0 || p >= permutation.size() || seen[p]) return false;
572 
573     seen[p] = true;
574   }
575   return true;
576 }
577 
verify()578 LogicalResult TransposeOp::verify() {
579   ArrayRef<int64_t> permutationRef = permutation();
580 
581   if (!isValidPermutation(permutationRef))
582     return emitOpError("permutation is not valid");
583 
584   auto inputType = input().getType().cast<ShapedType>();
585   auto initType = init().getType().cast<ShapedType>();
586 
587   int64_t rank = inputType.getRank();
588 
589   if (rank != initType.getRank())
590     return emitOpError() << "input rank " << rank
591                          << " does not match init rank " << initType.getRank();
592 
593   if (rank != permutationRef.size())
594     return emitOpError() << "size of permutation " << permutationRef.size()
595                          << " does not match the argument rank " << rank;
596 
597   auto inputDims = inputType.getShape();
598   auto initDims = initType.getShape();
599 
600   for (size_t i = 0; i < rank; ++i) {
601     int64_t inputDim = inputDims[permutationRef[i]];
602     int64_t initDim = initDims[i];
603 
604     if (inputDim != ShapedType::kDynamicSize &&
605         initDim != ShapedType::kDynamicSize && inputDim != initDim) {
606       return emitOpError() << "dim(result, " << i << ") = " << initDim
607                            << " doesn't match dim(input, permutation[" << i
608                            << "]) = " << inputDim;
609     }
610   }
611 
612   return verifyDestinationStyleOp(getOperation(), getNumOutputs());
613 }
614 
615 }  // namespace thlo
616 }  // namespace mlir
617 
618 // Generated op classes.
619 #define GET_OP_CLASSES
620 #include "mlir-hlo/Dialect/thlo/IR/thlo_ops.cc.inc"
621