1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "mlir-hlo/Dialect/thlo/IR/thlo_ops.h"
17
18 #include <algorithm>
19 #include <iterator>
20 #include <memory>
21 #include <utility>
22
23 #include "llvm/ADT/ArrayRef.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
29 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
30 #include "mlir/Dialect/MemRef/IR/MemRef.h"
31 #include "mlir/Dialect/SCF/IR/SCF.h"
32 #include "mlir/Dialect/Tensor/IR/Tensor.h"
33 #include "mlir/Dialect/Tensor/Utils/Utils.h"
34 #include "mlir/IR/BlockAndValueMapping.h"
35 #include "mlir/IR/BuiltinAttributes.h"
36 #include "mlir/IR/BuiltinTypes.h"
37 #include "mlir/IR/DialectImplementation.h"
38 #include "mlir/IR/OpImplementation.h"
39 #include "mlir/IR/PatternMatch.h"
40 #include "mlir/Interfaces/ViewLikeInterface.h"
41
42 namespace mlir {
43 namespace {
44
45 //===----------------------------------------------------------------------===//
46 // Destination-style ops tools
47 //===----------------------------------------------------------------------===//
48
hasTensorSemantics(OperandRange operands,unsigned numOutputArgs)49 bool hasTensorSemantics(OperandRange operands, unsigned numOutputArgs) {
50 for (auto operand : operands.drop_back(numOutputArgs)) {
51 if (!operand.getType().isa<ShapedType>()) continue;
52 if (!operand.getType().isa<RankedTensorType>()) return false;
53 }
54 return llvm::all_of(operands.take_back(numOutputArgs), [](Value operand) {
55 return operand.getType().isa<RankedTensorType>();
56 });
57 }
58
hasBufferSemantics(OperandRange operands)59 bool hasBufferSemantics(OperandRange operands) {
60 return llvm::all_of(operands, [](Value operand) {
61 return operand.getType().isa<MemRefType>();
62 });
63 }
64
verifyDestinationStyleOp(Operation * op,unsigned numOutputArgs=1)65 LogicalResult verifyDestinationStyleOp(Operation *op,
66 unsigned numOutputArgs = 1) {
67 if (hasBufferSemantics(op->getOperands()))
68 return success(op->getNumResults() == 0);
69
70 if (!hasTensorSemantics(op->getOperands(), numOutputArgs))
71 return op->emitOpError("expected either buffer or tensor semantics");
72
73 if (op->getNumResults() != numOutputArgs) {
74 return op->emitOpError(
75 "expected the number of output args to match the number of results");
76 }
77 for (auto &en : llvm::enumerate(llvm::zip(
78 op->getResultTypes(), op->getOperands().take_back(numOutputArgs)))) {
79 size_t index = en.index();
80 Type resultType = std::get<0>(en.value());
81 Type outputOperandType = std::get<1>(en.value()).getType();
82 if (resultType != outputOperandType)
83 op->emitOpError() << "type " << resultType << " of result " << index
84 << " does not match output operand type "
85 << outputOperandType;
86 }
87 return success();
88 }
89
90 template <typename DstOpTy>
printDstStyleOp(DstOpTy op,OpAsmPrinter & p)91 void printDstStyleOp(DstOpTy op, OpAsmPrinter &p) {
92 if (op.getNumInputs() != 0) {
93 p << " ins(";
94 llvm::interleaveComma(
95 op.getOperands().take_front(op.getNumInputs()), p,
96 [&](Value input) { p << input << " : " << input.getType(); });
97 p << ")";
98 }
99 p << " outs(";
100 llvm::interleaveComma(
101 op.getOperands().take_back(op.getNumOutputs()), p,
102 [&](Value output) { p << output << " : " << output.getType(); });
103 p << ")";
104
105 p.printOptionalAttrDict(op->getAttrs());
106 }
107
parseKeywordOperandListWithTypes(OpAsmParser & parser,OperationState & result,StringRef keyword,SmallVectorImpl<Type> * operandTypes)108 ParseResult parseKeywordOperandListWithTypes(
109 OpAsmParser &parser, OperationState &result, StringRef keyword,
110 SmallVectorImpl<Type> *operandTypes) {
111 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
112 if (succeeded(parser.parseOptionalKeyword(keyword))) {
113 SMLoc operandsOperandsLoc = parser.getCurrentLocation();
114
115 if (parser.parseCommaSeparatedList(
116 AsmParser::Delimiter::Paren, [&]() -> ParseResult {
117 if (parser.parseOperand(operands.emplace_back(),
118 /*allowResultNumber=*/false) ||
119 parser.parseColon() ||
120 parser.parseType(operandTypes->emplace_back())) {
121 return failure();
122 }
123 return success();
124 }))
125 return failure();
126
127 if (parser.resolveOperands(operands, *operandTypes, operandsOperandsLoc,
128 result.operands))
129 return failure();
130 }
131 return success();
132 }
133
parseDstStyleOp(OpAsmParser & parser,OperationState & result)134 ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result) {
135 // Parse `ins` and `outs`.
136 SmallVector<Type, 4> inputTypes, outputTypes;
137 if (parseKeywordOperandListWithTypes(parser, result, "ins", &inputTypes) ||
138 parseKeywordOperandListWithTypes(parser, result, "outs", &outputTypes))
139 return failure();
140
141 // Add result types.
142 for (Type outputType : outputTypes) {
143 if (outputType.isa<RankedTensorType>()) result.addTypes(outputType);
144 }
145
146 // Parse optional attributes.
147 if (parser.parseOptionalAttrDict(result.attributes)) return failure();
148 return success();
149 }
150
151 } // namespace
152 } // namespace mlir
153
154 // Generated dialect definitions.
155 #include "mlir-hlo/Dialect/thlo/IR/thlo_dialect.cc.inc"
156
157 namespace mlir {
158 namespace thlo {
159
// Registers all tablegen-generated tHLO operations with the dialect.
void THLODialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir-hlo/Dialect/thlo/IR/thlo_ops.cc.inc"
      >();
}
166
167 //===----------------------------------------------------------------------===//
168 // ConcatenateOp
169 //===----------------------------------------------------------------------===//
170
171 namespace {
172
// Fuses a thlo.concatenate through a gml_st tile subset: for every operand,
// a tile is derived from the result tile (clamped to that operand's extent
// in the concat dimension, possibly empty at runtime), the sub-operands are
// materialized, and a smaller concatenate over them is created.
Value fuseConcatenateOpThroughTile(ConcatenateOp op, OpBuilder &builder,
                                   Location loc, Value tile) {
  uint64_t concatDim = op.dimension();
  auto resultTy = op.getResult().getType().cast<RankedTensorType>();
  int64_t rank = resultTy.getRank();
  OperandRange allOperands = op.operands();
  // All operands share the non-concat dimension sizes, so any operand can be
  // queried for them.
  Value anyOperand = allOperands.front();

  // Create the shared tile strides, which are the exact same for every operand
  // tile. Also create a basis for the space sizes, tile offsets, and tile
  // sizes. These hold the shared values in all non-concat dimensions and can be
  // amended in the concat dimension to create the individual operand tiles.
  SmallVector<Value> sharedTileStrides(rank);
  SmallVector<Value> baseSpaceSizes(rank);
  SmallVector<Value> baseTileOffsets(rank);
  SmallVector<Value> baseTileSizes(rank);
  for (int64_t i = 0; i < rank; ++i) {
    Value iCst = builder.create<arith::ConstantIndexOp>(loc, i);
    sharedTileStrides[i] = builder.create<gml_st::StrideOp>(loc, tile, iCst);

    // The space sizes, tile offsets, and tile sizes differ in the concat
    // dimension. Do not populate these.
    if (i == concatDim) {
      continue;
    }

    baseSpaceSizes[i] =
        builder.createOrFold<tensor::DimOp>(loc, anyOperand, iCst);
    baseTileOffsets[i] = builder.create<gml_st::OffsetOp>(loc, tile, iCst);
    baseTileSizes[i] = builder.create<gml_st::SizeOp>(loc, tile, iCst);
  }

  // Some shared values.
  ArrayAttr allDynamicStridesOrOffsetsAttr = builder.getI64ArrayAttr(
      SmallVector<int64_t>(rank, ShapedType::kDynamicStrideOrOffset));
  ArrayAttr allDynamicSizesAttr = builder.getI64ArrayAttr(
      SmallVector<int64_t>(rank, ShapedType::kDynamicSize));
  Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value concatDimCst = builder.create<arith::ConstantIndexOp>(loc, concatDim);
  Value maxTileSizeInConcatDim =
      builder.create<gml_st::SizeOp>(loc, tile, concatDimCst);

  // The remaining tile offset in the concat dimension is subtracted by each
  // operand's size in that dimension. We maintain the invariant
  // remainingTileOffsetInConcatDim >= 0.
  Value remainingTileOffsetInConcatDim =
      builder.create<gml_st::OffsetOp>(loc, tile, concatDimCst);

  // Create the relevant subsets per operand. These tiles can be empty at
  // runtime.
  SmallVector<Value> subOperands;
  subOperands.reserve(allOperands.size());
  for (Value operand : allOperands) {
    // Create operand space.
    Value operandSizeInConcatDim =
        builder.create<tensor::DimOp>(loc, operand, concatDimCst);
    baseSpaceSizes[concatDim] = operandSizeInConcatDim;
    Value operandSpace = builder.create<gml_st::SpaceOp>(loc, baseSpaceSizes,
                                                         allDynamicSizesAttr);

    // Find the current operand's tile offset in the concat dimension. This is
    // the remaining offset clamped into the bounds of the operand. Note that
    // the remaining offset is always >= 0.
    Value operandTileOffsetInConcatDim = builder.create<arith::MinUIOp>(
        loc, remainingTileOffsetInConcatDim, operandSizeInConcatDim);
    baseTileOffsets[concatDim] = operandTileOffsetInConcatDim;

    // Find the current operand's tile size in the concat dimension.
    Value remainingOperandSizeInConcatDim = builder.create<arith::SubIOp>(
        loc, operandSizeInConcatDim, operandTileOffsetInConcatDim);
    baseTileSizes[concatDim] = builder.create<arith::MinUIOp>(
        loc, remainingOperandSizeInConcatDim, maxTileSizeInConcatDim);

    // Create the operand tile and materialize the subset for this operand.
    Value tile = builder.create<gml_st::TileOp>(
        loc, operandSpace, baseTileOffsets, baseTileSizes, sharedTileStrides,
        allDynamicStridesOrOffsetsAttr, allDynamicSizesAttr,
        allDynamicStridesOrOffsetsAttr);
    subOperands.push_back(
        builder.create<gml_st::MaterializeOp>(loc, operand, tile));

    // Unless it is the last operand, update the remaining tile offset in the
    // concat dimension. The remaining offset is subtracted by the operand's
    // size but must remain >= 0.
    if (operand != allOperands.back()) {
      remainingTileOffsetInConcatDim = builder.create<arith::SelectOp>(
          loc,
          builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
                                        remainingTileOffsetInConcatDim,
                                        operandSizeInConcatDim),
          zeroCst,
          builder.create<arith::SubIOp>(loc, remainingTileOffsetInConcatDim,
                                        operandSizeInConcatDim));
    }
  }

  // Create the tiled concat op over the materialized sub-operands, writing
  // into the matching tile of the init operand.
  auto tileType = tile.getType().cast<gml_st::TileType>();
  Value subInit = builder.create<gml_st::MaterializeOp>(loc, op.init(), tile);
  auto subResultType =
      RankedTensorType::get(tileType.getShape(), resultTy.getElementType());
  return builder.create<thlo::ConcatenateOp>(loc, subResultType, subOperands,
                                             subInit, concatDim);
}
277
// Recursively fuses a concatenate through a point subset. If the point's
// offset in the concat dimension lies inside the leading operand, the
// element is materialized from it; otherwise the offset is shifted by the
// leading operand's size and the remaining operands are handled the same
// way, yielding a cascade of scf.if ops. Returns an empty Value when called
// with no operands.
Value fuseConcatenateOpThroughPointRecursively(
    OpBuilder &builder, Location loc, RankedTensorType rankedTy,
    uint64_t concatDim, SmallVector<Value> &remainingOffsets,
    ValueRange remainingOperands) {
  // Bail if called for no operands.
  if (remainingOperands.empty()) {
    return {};
  }
  Value leadingOperand = remainingOperands.front();

  // Terminal case of exactly one operand: materialize the element directly.
  if (remainingOperands.size() == 1) {
    // Create operand space.
    SmallVector<Value> dynamicDims =
        tensor::createDynamicDimValues(builder, loc, leadingOperand);
    ArrayAttr staticDims = builder.getI64ArrayAttr(rankedTy.getShape());
    Value operandSpace =
        builder.create<gml_st::SpaceOp>(loc, dynamicDims, staticDims);

    // Create operand point.
    SmallVector<int64_t> allDynamicOffsets(rankedTy.getRank(),
                                           ShapedType::kDynamicStrideOrOffset);
    Value operandPoint = builder.create<gml_st::PointOp>(
        loc, operandSpace, remainingOffsets,
        builder.getI64ArrayAttr(allDynamicOffsets));

    return builder.create<gml_st::MaterializeOp>(loc, leadingOperand,
                                                 operandPoint);
  }

  // For more than 1 operand, distinguish between the leading operand and the
  // remainder.
  assert(remainingOperands.size() > 1 &&
         "expect more than 1 operand at this point");
  Value leadingOperandConcatDim =
      builder.create<tensor::DimOp>(loc, leadingOperand, concatDim);
  Value leadingOperandPredicate = builder.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::ult, remainingOffsets[concatDim],
      leadingOperandConcatDim);
  auto ifOp = builder.create<scf::IfOp>(
      loc, rankedTy.getElementType(), leadingOperandPredicate,
      [&](OpBuilder &builder, Location loc) {
        // For the leading operand, recur with the current offsets.
        Value fused = fuseConcatenateOpThroughPointRecursively(
            builder, loc, rankedTy, concatDim, remainingOffsets,
            leadingOperand);
        builder.create<scf::YieldOp>(loc, fused);
      },
      [&](OpBuilder &builder, Location loc) {
        // For the remaining operands, subtract the leading operand's size from
        // the remaining offsets in the concatenation dimension.
        SmallVector<Value> thenRemainingOffsets(remainingOffsets.begin(),
                                                remainingOffsets.end());
        thenRemainingOffsets[concatDim] = builder.create<arith::SubIOp>(
            loc, remainingOffsets[concatDim], leadingOperandConcatDim);
        Value fused = fuseConcatenateOpThroughPointRecursively(
            builder, loc, rankedTy, concatDim, thenRemainingOffsets,
            remainingOperands.drop_front());
        builder.create<scf::YieldOp>(loc, fused);
      });
  return ifOp.getResults().front();
}
340
fuseConcatenateOpThroughPoint(ConcatenateOp op,OpBuilder & builder,Location loc,Value subset)341 Value fuseConcatenateOpThroughPoint(ConcatenateOp op, OpBuilder &builder,
342 Location loc, Value subset) {
343 auto resultTy = op.getType().cast<RankedTensorType>();
344 int64_t resultRank = resultTy.getRank();
345 uint64_t concatDim = op.dimension();
346
347 // Materialize initial offsets.
348 SmallVector<Value> initialOffsets;
349 initialOffsets.reserve(resultRank);
350 for (int64_t i = 0; i < resultRank; ++i) {
351 initialOffsets.push_back(builder.create<gml_st::OffsetOp>(
352 loc, subset, builder.create<arith::ConstantIndexOp>(loc, i)));
353 }
354
355 ValueRange initialOperands = op.operands();
356 return fuseConcatenateOpThroughPointRecursively(
357 builder, loc, resultTy, concatDim, initialOffsets, initialOperands);
358 }
359
360 } // namespace
361
fuse(Location loc,Value subset,OpBuilder & builder)362 Value ConcatenateOp::fuse(Location loc, Value subset, OpBuilder &builder) {
363 Type subsetTy = subset.getType();
364 if (subsetTy.isa<gml_st::TileType>()) {
365 return fuseConcatenateOpThroughTile(*this, builder, loc, subset);
366 }
367 if (subsetTy.isa<gml_st::PointType>()) {
368 return fuseConcatenateOpThroughPoint(*this, builder, loc, subset);
369 }
370 return {};
371 }
372
// Parses a thlo.concatenate op using the shared destination-style
// `ins(...) outs(...)` assembly format.
ParseResult ConcatenateOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseDstStyleOp(parser, result);
}
376
// Prints a thlo.concatenate op using the shared destination-style format.
void ConcatenateOp::print(OpAsmPrinter &p) {
  printDstStyleOp(cast<ConcatenateOp>(getOperation()), p);
}
380
// Verifies the shared destination-style invariants (buffer vs. tensor
// semantics, result/output type agreement).
LogicalResult ConcatenateOp::verify() {
  return verifyDestinationStyleOp(getOperation(), getNumOutputs());
}
384
385 //===----------------------------------------------------------------------===//
386 // DynamicBroadcastInDimOp
387 //===----------------------------------------------------------------------===//
388
// Parses a thlo.dynamic_broadcast_in_dim op using the shared
// destination-style `ins(...) outs(...)` assembly format.
ParseResult DynamicBroadcastInDimOp::parse(OpAsmParser &parser,
                                           OperationState &result) {
  return parseDstStyleOp(parser, result);
}
393
// Prints a thlo.dynamic_broadcast_in_dim op using the shared
// destination-style format.
void DynamicBroadcastInDimOp::print(OpAsmPrinter &p) {
  printDstStyleOp(cast<DynamicBroadcastInDimOp>(getOperation()), p);
}
397
// Verifies the shared destination-style invariants (buffer vs. tensor
// semantics, result/output type agreement).
LogicalResult DynamicBroadcastInDimOp::verify() {
  return verifyDestinationStyleOp(getOperation(), getNumOutputs());
}
401
// Fuses a thlo.dynamic_broadcast_in_dim through a gml_st point or tile
// subset: the result subset is collapsed onto the operand's dimensions,
// expanding dimensions are detected at runtime (operand dim != result dim),
// and a corresponding operand point/tile is built with offset 0 (and size 1
// for tiles) in expanding dimensions. Returns an empty Value for
// unsupported subset types.
Value DynamicBroadcastInDimOp::fuse(Location loc, Value subset,
                                    OpBuilder &builder) {
  Type subsetTy = subset.getType();
  auto operandTy = operand().getType().cast<RankedTensorType>();
  auto resultTy = getType(0).cast<RankedTensorType>();
  int64_t operandRank = operandTy.getRank();

  // Create the needed constants only once.
  DenseMap<uint64_t, Value> localIndexConstants;
  auto getIndexConstant = [&](uint64_t c) -> Value {
    auto it = localIndexConstants.find(c);
    if (it != localIndexConstants.end()) return it->second;
    auto cst = builder.create<arith::ConstantIndexOp>(loc, c);
    localIndexConstants[c] = cst;
    return cst;
  };

  // Materialize operand space.
  auto operandSpaceTy = builder.getType<gml_st::TileType>(operandTy.getShape());
  auto dynamicDims = tensor::createDynamicDimValues(builder, loc, operand());
  auto staticDims = builder.getI64ArrayAttr(operandTy.getShape());
  Value operandSpace = builder.create<gml_st::SpaceOp>(loc, operandSpaceTy,
                                                       dynamicDims, staticDims);

  // Materialize operand dimensions, consuming the dynamic dim values in
  // order for dynamic extents and constants for static ones.
  SmallVector<Value> operandDims;
  int64_t dynamicDimsIdx = 0;
  operandDims.reserve(operandTy.getRank());
  for (const auto &it : llvm::enumerate(operandTy.getShape())) {
    int64_t d = it.value();
    Value dim = d == ShapedType::kDynamicSize ? dynamicDims[dynamicDimsIdx++]
                                              : getIndexConstant(d);
    operandDims.push_back(dim);
  }

  // Collapse the subset to operate only on corresponding dimensions.
  // TODO(frgossen): Only generate this when needed.
  auto collapsedSubset = builder.create<gml_st::DropDimsOp>(
      loc, subset, broadcast_dimensionsAttr());

  // Find the expanding dimensions. If corresponding operand and result
  // dimensions are different then the dimension is expanding.
  // TODO(frgossen): Use info from known expanding and known non-expanding
  // dimensions here.
  SmallVector<Value> operandExpandingDims;
  for (const auto &it : llvm::enumerate(broadcast_dimensions())) {
    auto operandDim = operandDims[it.index()];
    auto resultDim = builder.create<tensor::DimOp>(
        loc, init(), getIndexConstant(it.value()));
    operandExpandingDims.push_back(builder.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::ne, operandDim, resultDim));
  }

  // Compute operand offsets, which are needed for tile and point subsets.
  // Expanding dimensions read at offset 0; others use the collapsed
  // subset's offset.
  auto staticOffsets = builder.getI64ArrayAttr(
      SmallVector<int64_t>(operandRank, ShapedType::kDynamicStrideOrOffset));
  SmallVector<Value> offsets;
  Value zero = getIndexConstant(0);
  for (int i = 0; i < operandRank; ++i) {
    Value isExpanding = operandExpandingDims[i];
    Value collapsedSubsetOffset = builder.create<gml_st::OffsetOp>(
        loc, collapsedSubset, getIndexConstant(i));
    offsets.push_back(builder.create<arith::SelectOp>(loc, isExpanding, zero,
                                                      collapsedSubsetOffset));
  }

  // If the regarded subset is of point type, we can already construct the
  // operand point and materialize it.
  if (auto pointTy = subsetTy.dyn_cast<gml_st::PointType>()) {
    auto operandPoint = builder.create<gml_st::PointOp>(
        loc, pointTy, operandSpace, offsets, staticOffsets);
    return builder.create<gml_st::MaterializeOp>(
        loc, operandTy.getElementType(), operand(), operandPoint);
  }

  // If the regarded subset is of tile type, we still need the operand tile
  // sizes to materialize a fused broadcast.
  if (auto tileTy = subsetTy.dyn_cast<gml_st::TileType>()) {
    // Compute operand tile sizes. Expanding dimensions have size 1; others
    // take the collapsed subset's size.
    auto staticTileSizes = builder.getI64ArrayAttr(
        SmallVector<int64_t>(operandRank, ShapedType::kDynamicSize));
    SmallVector<Value> tileSizes;
    Value one = getIndexConstant(1);
    for (int i = 0; i < operandRank; ++i) {
      Value isExpanding = operandExpandingDims[i];
      Value tileSize = builder.create<gml_st::SizeOp>(loc, collapsedSubset,
                                                      getIndexConstant(i));
      tileSizes.push_back(
          builder.create<arith::SelectOp>(loc, isExpanding, one, tileSize));
    }

    // Create operand tile. All strides are statically 1, so no dynamic
    // stride values are needed.
    auto staticTileStrides =
        builder.getI64ArrayAttr(SmallVector<int64_t>(operandRank, 1));
    SmallVector<Value> tileStrides = {};
    auto operandTileTy = builder.getType<gml_st::TileType>(
        SmallVector<int64_t>(operandRank, ShapedType::kDynamicSize));
    auto operandTile = builder.create<gml_st::TileOp>(
        loc, operandTileTy, operandSpace, offsets, tileSizes, tileStrides,
        staticOffsets, staticTileSizes, staticTileStrides);

    // Materialize operand subsets.
    Value tiledInit =
        builder.create<gml_st::MaterializeOp>(loc, init(), subset);
    Value tiledOperand =
        builder.create<gml_st::MaterializeOp>(loc, operand(), operandTile);

    // Finally, materialize tiled broadcast.
    auto tiledResultTy =
        RankedTensorType::get(tileTy.getShape(), resultTy.getElementType());
    return builder
        .create<DynamicBroadcastInDimOp>(
            loc, TypeRange{tiledResultTy}, tiledOperand, tiledInit,
            broadcast_dimensionsAttr(), known_expanding_dimensionsAttr(),
            known_nonexpanding_dimensionsAttr())
        .getResult(0);
  }

  // Unsupported subset type.
  return {};
}
522
523 //===----------------------------------------------------------------------===//
524 // ScatterOp
525 //===----------------------------------------------------------------------===//
526
// Parses a thlo.scatter op using the shared destination-style
// `ins(...) outs(...)` assembly format.
ParseResult ScatterOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseDstStyleOp(parser, result);
}
530
// Prints a thlo.scatter op using the shared destination-style format.
void ScatterOp::print(OpAsmPrinter &p) {
  printDstStyleOp(cast<ScatterOp>(getOperation()), p);
}
534
// Verifies the shared destination-style invariants (buffer vs. tensor
// semantics, result/output type agreement).
LogicalResult ScatterOp::verify() {
  return verifyDestinationStyleOp(getOperation(), getNumOutputs());
}
538
539 //===----------------------------------------------------------------------===//
540 // GatherOp
541 //===----------------------------------------------------------------------===//
542
// Parses a thlo.gather op using the shared destination-style
// `ins(...) outs(...)` assembly format.
ParseResult GatherOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseDstStyleOp(parser, result);
}
546
// Prints a thlo.gather op using the shared destination-style format.
void GatherOp::print(OpAsmPrinter &p) {
  printDstStyleOp(cast<GatherOp>(getOperation()), p);
}
550
// Verifies the shared destination-style invariants (buffer vs. tensor
// semantics, result/output type agreement).
LogicalResult GatherOp::verify() {
  return verifyDestinationStyleOp(getOperation(), getNumOutputs());
}
554
555 //===----------------------------------------------------------------------===//
556 // TransposeOp
557 //===----------------------------------------------------------------------===//
558
// Parses a thlo.transpose op using the shared destination-style
// `ins(...) outs(...)` assembly format.
ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseDstStyleOp(parser, result);
}
562
// Prints a thlo.transpose op using the shared destination-style format.
void TransposeOp::print(OpAsmPrinter &p) {
  printDstStyleOp(cast<TransposeOp>(getOperation()), p);
}
566
isValidPermutation(ArrayRef<int64_t> permutation)567 bool isValidPermutation(ArrayRef<int64_t> permutation) {
568 SmallVector<bool> seen(permutation.size(), false);
569 for (auto p : permutation) {
570 // Verify that each element is in [0..n-1] range and is present only once.
571 if (p < 0 || p >= permutation.size() || seen[p]) return false;
572
573 seen[p] = true;
574 }
575 return true;
576 }
577
verify()578 LogicalResult TransposeOp::verify() {
579 ArrayRef<int64_t> permutationRef = permutation();
580
581 if (!isValidPermutation(permutationRef))
582 return emitOpError("permutation is not valid");
583
584 auto inputType = input().getType().cast<ShapedType>();
585 auto initType = init().getType().cast<ShapedType>();
586
587 int64_t rank = inputType.getRank();
588
589 if (rank != initType.getRank())
590 return emitOpError() << "input rank " << rank
591 << " does not match init rank " << initType.getRank();
592
593 if (rank != permutationRef.size())
594 return emitOpError() << "size of permutation " << permutationRef.size()
595 << " does not match the argument rank " << rank;
596
597 auto inputDims = inputType.getShape();
598 auto initDims = initType.getShape();
599
600 for (size_t i = 0; i < rank; ++i) {
601 int64_t inputDim = inputDims[permutationRef[i]];
602 int64_t initDim = initDims[i];
603
604 if (inputDim != ShapedType::kDynamicSize &&
605 initDim != ShapedType::kDynamicSize && inputDim != initDim) {
606 return emitOpError() << "dim(result, " << i << ") = " << initDim
607 << " doesn't match dim(input, permutation[" << i
608 << "]) = " << inputDim;
609 }
610 }
611
612 return verifyDestinationStyleOp(getOperation(), getNumOutputs());
613 }
614
615 } // namespace thlo
616 } // namespace mlir
617
618 // Generated op classes.
619 #define GET_OP_CLASSES
620 #include "mlir-hlo/Dialect/thlo/IR/thlo_ops.cc.inc"
621