1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "mlir-hlo/Dialect/gml_st/transforms/bufferizable_op_interface_impl.h"
17
18 #include <iterator>
19 #include <tuple>
20
21 #include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
22 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
23 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
24 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
25 #include "mlir/Dialect/MemRef/IR/MemRef.h"
26 #include "mlir/Interfaces/ViewLikeInterface.h"
27 #include "mlir/Support/LogicalResult.h"
28
29 using mlir::bufferization::AnalysisState;
30 using mlir::bufferization::BufferizableOpInterface;
31 using mlir::bufferization::BufferizationOptions;
32 using mlir::bufferization::BufferRelation;
33 using mlir::bufferization::ToMemrefOp;
34 using mlir::bufferization::ToTensorOp;
35
36 namespace mlir {
37 namespace gml_st {
38 namespace {
39
40 /// Bufferization of gml_st.loop. Replace with a new gml_st.loop
41 /// that operates entirely on memrefs.
42 struct LoopOpInterface
43 : public BufferizableOpInterface::ExternalModel<LoopOpInterface, LoopOp> {
bufferizesToMemoryReadmlir::gml_st::__anon421317500111::LoopOpInterface44 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
45 const AnalysisState &state) const {
46 auto loopOp = cast<LoopOp>(op);
47
48 // gml_st.loop operands alone do not bufferize to a memory read, but
49 // one of the uses of their matching bbArgs may.
50 return state.isValueRead(loopOp.getTiedBlockArgument(opOperand));
51 }
52
bufferizesToMemoryWritemlir::gml_st::__anon421317500111::LoopOpInterface53 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
54 const AnalysisState &state) const {
55 // Only operands with an aliasing OpResult (i.e., output operands) bufferize
56 // to a memory write.
57 auto bufferizableOp = cast<BufferizableOpInterface>(op);
58 return !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
59 }
60
getAliasingOpResultmlir::gml_st::__anon421317500111::LoopOpInterface61 SmallVector<OpResult> getAliasingOpResult(
62 Operation *op, OpOperand &opOperand,
63 const AnalysisState & /*state*/) const {
64 auto loopOp = cast<LoopOp>(op);
65
66 // Output operands are tied to their corresponding OpResults.
67 OpResult opResult = loopOp.getTiedOpResult(opOperand);
68 if (!opResult) return {};
69 return {opResult};
70 }
71
bufferRelationmlir::gml_st::__anon421317500111::LoopOpInterface72 BufferRelation bufferRelation(Operation * /*op*/, OpResult /*opResult*/,
73 const AnalysisState & /*state*/) const {
74 return BufferRelation::Equivalent;
75 }
76
isWritablemlir::gml_st::__anon421317500111::LoopOpInterface77 bool isWritable(Operation * /*op*/, Value /*value*/,
78 const AnalysisState & /*state*/) const {
79 // Interestingly, LoopOp's bbArgs can **always** be viewed
80 // inplace from the perspective of nested ops:
81 // 1. Either the matching iter operand is not bufferized inplace and an
82 // alloc + optional copy makes the bbArg itself inplaceable.
83 // 2. Or the matching iter operand is bufferized inplace and bbArg just
84 // bufferizes to that too.
85 return true;
86 }
87
getBufferTypemlir::gml_st::__anon421317500111::LoopOpInterface88 FailureOr<BaseMemRefType> getBufferType(
89 Operation *op, BlockArgument bbArg,
90 const BufferizationOptions &options) const {
91 auto loopOp = cast<LoopOp>(op);
92 return bufferization::getBufferType(loopOp.getTiedOperand(bbArg).get(),
93 options);
94 }
95
bufferizemlir::gml_st::__anon421317500111::LoopOpInterface96 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
97 const BufferizationOptions &options) const {
98 auto loopOp = cast<LoopOp>(op);
99
100 // Compute new inputs, outputs and results.
101 SmallVector<Value> newInputs, newOutputs, newResults;
102 for (unsigned i = loopOp.getNumControlOperands();
103 i < loopOp->getNumOperands(); ++i) {
104 OpOperand &operand = loopOp->getOpOperand(i);
105 Value rewrittenValue = operand.get();
106 if (rewrittenValue.getType().isa<TensorType>()) {
107 FailureOr<Value> maybeBuffer =
108 getBuffer(rewriter, operand.get(), options);
109 if (failed(maybeBuffer)) return failure();
110 rewrittenValue = *maybeBuffer;
111 }
112 if (i < loopOp.getNumControlOperands() + loopOp.getNumInputs()) {
113 newInputs.push_back(rewrittenValue);
114 } else {
115 newOutputs.push_back(rewrittenValue);
116 if (operand.get().getType().isa<TensorType>())
117 newResults.push_back(rewrittenValue);
118 }
119 }
120
121 // Create new TiledLoopOp.
122 auto newLoopOp = rewriter.create<LoopOp>(
123 loopOp.getLoc(), loopOp.lowerBound(), loopOp.upperBound(),
124 loopOp.step(), newInputs, newOutputs, loopOp.iterator_types(),
125 loopOp.distribution_types());
126
127 // Remove terminator.
128 if (!newLoopOp.getBody()->empty())
129 rewriter.eraseOp(loopOp.getBody()->getTerminator());
130
131 // Compute new loop body arguments.
132 SmallVector<Value> newBlockArgs, newRegionInOutArgs, oldRegionInOutArgs;
133 ValueRange newInductionVars = newLoopOp.getInductionVars();
134 newBlockArgs.append(newInductionVars.begin(), newInductionVars.end());
135
136 ValueRange newRegionInArgs = newLoopOp.getRegionInputArgs();
137 ValueRange newRegionOutArgs = newLoopOp.getRegionOutputArgs();
138 newRegionInOutArgs.append(newRegionInArgs.begin(), newRegionInArgs.end());
139 newRegionInOutArgs.append(newRegionOutArgs.begin(), newRegionOutArgs.end());
140
141 ValueRange oldRegionInArgs = loopOp.getRegionInputArgs();
142 ValueRange oldRegionOutArgs = loopOp.getRegionOutputArgs();
143 oldRegionInOutArgs.append(oldRegionInArgs.begin(), oldRegionInArgs.end());
144 oldRegionInOutArgs.append(oldRegionOutArgs.begin(), oldRegionOutArgs.end());
145 assert(newRegionInArgs.size() == oldRegionInArgs.size() &&
146 "expected same number of input args");
147 assert(newRegionOutArgs.size() == oldRegionOutArgs.size() &&
148 "expected same number of output args");
149
150 for (auto it : llvm::zip(oldRegionInOutArgs, newRegionInOutArgs)) {
151 Value oldArg = std::get<0>(it);
152 Value newArg = std::get<1>(it);
153 rewriter.setInsertionPointToStart(newLoopOp.getBody());
154 if (oldArg.getType().isa<TensorType>()) {
155 newBlockArgs.push_back(rewriter.create<bufferization::ToTensorOp>(
156 oldArg.getLoc(), newArg));
157 } else {
158 newBlockArgs.push_back(newArg);
159 }
160 }
161
162 // Move old body into new loop.
163 rewriter.mergeBlocks(loopOp.getBody(), newLoopOp.getBody(), newBlockArgs);
164
165 // Replace previous terminator with a new one that does not yield anything.
166 auto oldTerminator =
167 cast<gml_st::YieldOp>(newLoopOp.getBody()->getTerminator());
168 rewriter.setInsertionPointToEnd(newLoopOp.getBody());
169 auto newTerminator =
170 rewriter.create<gml_st::YieldOp>(oldTerminator->getLoc());
171
172 // Copy buffer of yielded tensor to output buffer. If everything bufferized
173 // inplace, this copy will fold away.
174 rewriter.setInsertionPoint(newTerminator);
175 for (auto it : llvm::zip(oldTerminator.values(), newOutputs)) {
176 Value output = std::get<1>(it);
177 Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
178 newTerminator.getLoc(), output.getType(), std::get<0>(it));
179 if (failed(options.createMemCpy(rewriter, newTerminator.getLoc(),
180 toMemrefOp, output)))
181 return failure();
182 }
183
184 // Erase old terminator.
185 rewriter.eraseOp(oldTerminator);
186
187 // Replace results and delete old op.
188 bufferization::replaceOpWithBufferizedValues(rewriter, op, newResults);
189
190 return success();
191 }
192 };
193
194 // Returns the set chain in reverse order, i.e. from set to space.
195 // The space operation itself is not included.
findSetChain(Value set)196 FailureOr<SmallVector<Operation *>> findSetChain(Value set) {
197 SmallVector<Operation *> sets;
198 Operation *current = set.getDefiningOp();
199 while (current) {
200 if (auto space = dyn_cast<SpaceOp>(*current)) break;
201
202 sets.push_back(current);
203 // TODO(pifon): It might be useful to have a set interface.
204 if (auto tile = dyn_cast<TileOp>(*current)) {
205 current = tile.superset().getDefiningOp();
206 continue;
207 }
208 if (auto point = dyn_cast<PointOp>(*current)) {
209 current = point.superset().getDefiningOp();
210 continue;
211 }
212 return failure();
213 }
214 return sets;
215 }
216
217 // TODO(pifon): Clean this up, for example, by using ViewLikeInterface.
getPointIndicesValues(OpBuilder & b,PointOp pointOp)218 SmallVector<Value> getPointIndicesValues(OpBuilder &b, PointOp pointOp) {
219 SmallVector<Value> indices;
220 unsigned rank = pointOp.getRank();
221 indices.reserve(rank);
222 unsigned numDynamic = 0;
223 for (auto staticIndex : pointOp.static_indices().getAsRange<IntegerAttr>()) {
224 if (ShapedType::isDynamicStrideOrOffset(staticIndex.getInt())) {
225 indices.push_back(pointOp.dynamic_indices()[numDynamic++]);
226 } else {
227 Value indexValue = b.create<arith::ConstantIndexOp>(pointOp.getLoc(),
228 staticIndex.getInt());
229 indices.push_back(indexValue);
230 }
231 }
232 return indices;
233 }
234
235 // Returns a scalar or a memref type result of `gml_st.materialize` op after
236 // bufferization.
materializeExtraction(OpBuilder & b,Value memref,Value set)237 FailureOr<Value> materializeExtraction(OpBuilder &b, Value memref, Value set) {
238 auto setsOr = findSetChain(set);
239 if (failed(setsOr)) return failure();
240
241 // Find set use-def chain from space to the set.
242 // Create subview or load ops for the set computation.
243 OpBuilder::InsertionGuard g(b);
244 Value result = memref;
245 for (auto *set : llvm::reverse(*setsOr)) {
246 Location loc = set->getLoc();
247 b.setInsertionPointAfter(set);
248 if (auto tile = dyn_cast<TileOp>(*set)) {
249 result = b.create<memref::SubViewOp>(loc, result, tile.getMixedOffsets(),
250 tile.getMixedSizes(),
251 tile.getMixedStrides());
252 continue;
253 }
254 if (auto point = dyn_cast<PointOp>(*set)) {
255 result = b.create<memref::LoadOp>(loc, result,
256 getPointIndicesValues(b, point));
257 continue;
258 }
259 return failure();
260 }
261 return result;
262 }
263
materializeInsertion(OpBuilder & b,Value update,Value set,Value memref,const BufferizationOptions & options)264 LogicalResult materializeInsertion(OpBuilder &b, Value update, Value set,
265 Value memref,
266 const BufferizationOptions &options) {
267 auto sets = findSetChain(set);
268 if (failed(sets)) return failure();
269
270 if (sets->empty())
271 return options.createMemCpy(b, update.getLoc(), update, memref);
272
273 // Create subviews or store ops for the set computation.
274 OpBuilder::InsertionGuard g(b);
275 auto *it = std::prev(sets->end());
276 // The first element for the use-def chain is the `gml_st.space` op and it
277 // should be ignored for now.
278 for (; it != sets->begin(); --it) {
279 Location loc = (*it)->getLoc();
280 b.setInsertionPointAfter(*it);
281
282 auto tile = dyn_cast<TileOp>(*it);
283 if (!tile) return failure();
284
285 memref = b.create<memref::SubViewOp>(loc, memref, tile.getMixedOffsets(),
286 tile.getMixedSizes(),
287 tile.getMixedStrides());
288 }
289 Location loc = (*it)->getLoc();
290 if (auto point = dyn_cast<PointOp>(*it)) {
291 b.create<memref::StoreOp>(loc, update, memref,
292 getPointIndicesValues(b, point));
293 return success();
294 }
295 if (auto tile = dyn_cast<TileOp>(*it)) {
296 memref = b.create<memref::SubViewOp>(loc, memref, tile.getMixedOffsets(),
297 tile.getMixedSizes(),
298 tile.getMixedOffsets());
299 return success();
300 }
301 llvm_unreachable("Unknown set type");
302 }
303
304 struct MaterializeOpInterface
305 : public BufferizableOpInterface::ExternalModel<MaterializeOpInterface,
306 MaterializeOp> {
bufferizesToMemoryReadmlir::gml_st::__anon421317500111::MaterializeOpInterface307 bool bufferizesToMemoryRead(Operation * /*op*/, OpOperand & /*opOperand*/,
308 const AnalysisState & /*state*/) const {
309 return false;
310 }
311
bufferizesToMemoryWritemlir::gml_st::__anon421317500111::MaterializeOpInterface312 bool bufferizesToMemoryWrite(Operation * /*op*/, OpOperand & /*opOperand*/,
313 const AnalysisState & /*state*/) const {
314 return false;
315 }
316
getAliasingOpResultmlir::gml_st::__anon421317500111::MaterializeOpInterface317 SmallVector<OpResult> getAliasingOpResult(
318 Operation *op, OpOperand &opOperand,
319 const AnalysisState & /*state*/) const {
320 auto result = op->getOpResult(0);
321 if (result.getType().isa<RankedTensorType>() &&
322 opOperand.getOperandNumber() == 0)
323 return {result};
324 return {};
325 }
326
bufferRelationmlir::gml_st::__anon421317500111::MaterializeOpInterface327 BufferRelation bufferRelation(Operation * /*op*/, OpResult /*opResult*/,
328 const AnalysisState & /*state*/) const {
329 return BufferRelation::Equivalent;
330 }
331
bufferizemlir::gml_st::__anon421317500111::MaterializeOpInterface332 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
333 const BufferizationOptions &options) const {
334 auto materializeOp = cast<MaterializeOp>(op);
335
336 FailureOr<Value> bufferOr =
337 getBuffer(rewriter, materializeOp->getOpOperand(0).get(), options);
338 if (failed(bufferOr)) return failure();
339
340 FailureOr<Value> resultOr =
341 materializeExtraction(rewriter, *bufferOr, materializeOp.set());
342
343 if (failed(resultOr)) return failure();
344
345 bufferization::replaceOpWithBufferizedValues(rewriter, op, *resultOr);
346 return success();
347 }
348 };
349
350 struct ParallelOpInterface
351 : public BufferizableOpInterface::ExternalModel<ParallelOpInterface,
352 ParallelOp> {
getAliasingOpOperandmlir::gml_st::__anon421317500111::ParallelOpInterface353 SmallVector<OpOperand *> getAliasingOpOperand(
354 Operation *op, OpResult opResult, const AnalysisState & /*state*/) const {
355 auto parallelOp = cast<ParallelOp>(op);
356 return {
357 parallelOp.getTerminator().getDstOperand(opResult.getResultNumber())};
358 }
359
isMemoryWritemlir::gml_st::__anon421317500111::ParallelOpInterface360 bool isMemoryWrite(Operation *, OpResult, const AnalysisState &) const {
361 // This op is a memory write. Stop lookup here to avoid finding false
362 // conflicts involving this op and one of the ops in the region. This is
363 // similar to how scf.if ops are analyzed.
364 return true;
365 }
366
bufferRelationmlir::gml_st::__anon421317500111::ParallelOpInterface367 BufferRelation bufferRelation(Operation * /*op*/, OpResult /*opResult*/,
368 const AnalysisState & /*state*/) const {
369 return BufferRelation::Equivalent;
370 }
371
isWritablemlir::gml_st::__anon421317500111::ParallelOpInterface372 bool isWritable(Operation * /*op*/, Value /*value*/,
373 const AnalysisState & /*state*/) const {
374 return true;
375 }
376
bufferizemlir::gml_st::__anon421317500111::ParallelOpInterface377 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
378 const BufferizationOptions & /*options*/) const {
379 auto loopOp = cast<ParallelOp>(op);
380
381 // Create new TiledLoopOp.
382 auto newLoopOp = rewriter.create<ParallelOp>(
383 loopOp.getLoc(), TypeRange{llvm::None}, loopOp.lowerBound(),
384 loopOp.upperBound(), loopOp.step(), nullptr);
385
386 // Move the old body into the new loop.
387 rewriter.mergeBlocks(loopOp.getBody(), newLoopOp.getBody(),
388 newLoopOp.getInductionVars());
389
390 // Remove the old op.
391 rewriter.eraseOp(op);
392 return success();
393 }
394 };
395
396 struct ForOpInterface
397 : public BufferizableOpInterface::ExternalModel<ForOpInterface, ForOp> {
getAliasingOpOperandmlir::gml_st::__anon421317500111::ForOpInterface398 SmallVector<OpOperand *> getAliasingOpOperand(
399 Operation *op, OpResult opResult, const AnalysisState & /*state*/) const {
400 auto forOp = cast<gml_st::ForOp>(op);
401 return {&forOp.getOpOperandForResult(opResult)};
402 }
403
isWritablemlir::gml_st::__anon421317500111::ForOpInterface404 bool isWritable(Operation * /*op*/, Value /*value*/,
405 const AnalysisState & /*state*/) const {
406 // Interestingly, ForOp's bbArg can **always** be viewed
407 // inplace from the perspective of ops nested under:
408 // 1. Either the matching iter operand is not bufferized inplace and an
409 // alloc + optional copy makes the bbArg itself inplaceable.
410 // 2. Or the matching iter operand is bufferized inplace and bbArg just
411 // bufferizes to that too.
412 return true;
413 }
414
bufferizesToMemoryReadmlir::gml_st::__anon421317500111::ForOpInterface415 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
416 const AnalysisState &state) const {
417 auto forOp = cast<gml_st::ForOp>(op);
418 return state.isValueRead(forOp.getRegionOutputArgForOpOperand(opOperand));
419 }
420
bufferizesToMemoryWritemlir::gml_st::__anon421317500111::ForOpInterface421 bool bufferizesToMemoryWrite(Operation * /*op*/, OpOperand & /*opOperand*/,
422 const AnalysisState & /*state*/) const {
423 return true;
424 }
425
getAliasingOpResultmlir::gml_st::__anon421317500111::ForOpInterface426 SmallVector<OpResult> getAliasingOpResult(
427 Operation *op, OpOperand &opOperand,
428 const AnalysisState & /*state*/) const {
429 auto forOp = cast<gml_st::ForOp>(op);
430 return {forOp.getResultForOpOperand(opOperand)};
431 }
432
bufferRelationmlir::gml_st::__anon421317500111::ForOpInterface433 BufferRelation bufferRelation(Operation * /*op*/, OpResult /*opResult*/,
434 const AnalysisState & /*state*/) const {
435 return BufferRelation::Equivalent;
436 }
437
bufferizemlir::gml_st::__anon421317500111::ForOpInterface438 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
439 const BufferizationOptions &options) const {
440 auto forOp = cast<ForOp>(op);
441 Location loc = forOp.getLoc();
442
443 // Get the bufferized output arguments.
444 SmallVector<Value> bufferizedOutputs;
445 bufferizedOutputs.reserve(forOp.getNumOutputs());
446 for (Value output : forOp.outputs()) {
447 FailureOr<Value> maybeBuffer = getBuffer(rewriter, output, options);
448 if (failed(maybeBuffer)) return failure();
449 bufferizedOutputs.push_back(*maybeBuffer);
450 }
451
452 // Create new ForOp.
453 auto newForOp = rewriter.create<ForOp>(loc, TypeRange{}, forOp.lowerBound(),
454 forOp.upperBound(), forOp.step(),
455 ValueRange{}, nullptr);
456 Block *loopBody = newForOp.getBody();
457
458 // Add conversions to tensor so that we can reuse the old loop body.
459 rewriter.setInsertionPointToStart(loopBody);
460 SmallVector<Value> outputsToTensors;
461 for (auto buf : bufferizedOutputs) {
462 Value tensor = rewriter.create<bufferization::ToTensorOp>(loc, buf);
463 outputsToTensors.push_back(tensor);
464 }
465 SmallVector<Value> blockArgs = newForOp.getInductionVars();
466 blockArgs.append(outputsToTensors);
467
468 // Move old body into new for loop.
469 rewriter.mergeBlocks(forOp.getBody(), loopBody, blockArgs);
470
471 // Replace results and delete old op.
472 bufferization::replaceOpWithBufferizedValues(rewriter, op,
473 bufferizedOutputs);
474 return success();
475 }
476 };
477
478 struct SetYieldOpInterface
479 : public BufferizableOpInterface::ExternalModel<SetYieldOpInterface,
480 SetYieldOp> {
getAliasingOpResultmlir::gml_st::__anon421317500111::SetYieldOpInterface481 SmallVector<OpResult> getAliasingOpResult(
482 Operation *op, OpOperand &opOperand,
483 const AnalysisState & /*state*/) const {
484 auto yieldOp = cast<SetYieldOp>(op);
485 if (!yieldOp.isDstOperand(opOperand)) return {};
486
487 auto loopResult = yieldOp.getTiedOpResult(opOperand);
488 assert(succeeded(loopResult) && "didn't find a corresponding loop result");
489 return {*loopResult};
490 }
491
bufferizemlir::gml_st::__anon421317500111::SetYieldOpInterface492 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
493 const BufferizationOptions &options) const {
494 auto yieldOp = cast<SetYieldOp>(op);
495 Operation *loop = yieldOp->getParentOp();
496 if (!isa<ForOp, ParallelOp>(loop))
497 return yieldOp->emitError("unsupported gml_st::SetYieldOp parent");
498
499 for (const auto &it :
500 llvm::enumerate(llvm::zip(yieldOp.srcs(), yieldOp.dsts(),
501 yieldOp.sets(), loop->getResults()))) {
502 Value src, dst, set, loopResult;
503 std::tie(src, dst, set, loopResult) = it.value();
504
505 // `src` can be a scalar, that's `getBuffer()` should be called only for
506 // tensor types.
507 if (src.getType().isa<RankedTensorType>()) {
508 FailureOr<Value> srcBufferOr = getBuffer(rewriter, src, options);
509 if (failed(srcBufferOr)) return failure();
510
511 src = *srcBufferOr;
512 }
513
514 FailureOr<Value> dstBufferOr = getBuffer(rewriter, dst, options);
515 if (failed(dstBufferOr)) return failure();
516 Value dstBuffer = *dstBufferOr;
517
518 if (failed(materializeInsertion(rewriter, src, set, dstBuffer, options)))
519 return failure();
520 if (auto parallelOp =
521 dyn_cast<gml_st::ParallelOp>(yieldOp->getParentOp())) {
522 // Replace results of the enclosing loop with `to_tensor(dst)`.
523 OpBuilder::InsertionGuard g(rewriter);
524 rewriter.setInsertionPointAfter(loop);
525
526 Value resultToTensor =
527 rewriter.create<ToTensorOp>(loop->getLoc(), dstBuffer);
528 for (OpOperand &use : loopResult.getUses()) {
529 rewriter.updateRootInPlace(use.getOwner(),
530 [&]() { use.set(resultToTensor); });
531 }
532 }
533 }
534 rewriter.replaceOpWithNewOp<SetYieldOp>(op);
535 return success();
536 }
537
bufferizesToMemoryReadmlir::gml_st::__anon421317500111::SetYieldOpInterface538 bool bufferizesToMemoryRead(Operation * /*op*/, OpOperand & /*opOperand*/,
539 const AnalysisState & /*state*/) const {
540 return true;
541 }
542
bufferizesToMemoryWritemlir::gml_st::__anon421317500111::SetYieldOpInterface543 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
544 const AnalysisState & /*state*/) const {
545 return cast<SetYieldOp>(op).isDstOperand(opOperand);
546 }
547
isNotConflictingmlir::gml_st::__anon421317500111::SetYieldOpInterface548 bool isNotConflicting(Operation * /*op*/, OpOperand * /*uRead*/,
549 OpOperand * /*uConflictingWrite*/,
550 const AnalysisState & /*state*/) const {
551 // TODO(pifon): Implement proper analysis here similar to InsertSliceOp.
552 return true;
553 }
554 };
555
556 } // namespace
557 } // namespace gml_st
558 } // namespace mlir
559
registerBufferizableOpInterfaceExternalModels(DialectRegistry & registry)560 void mlir::gml_st::registerBufferizableOpInterfaceExternalModels(
561 DialectRegistry ®istry) {
562 registry.addExtension(
563 +[](MLIRContext *ctx, gml_st::GmlStDialect * /*dialect*/) {
564 ForOp::attachInterface<ForOpInterface>(*ctx);
565 LoopOp::attachInterface<LoopOpInterface>(*ctx);
566 MaterializeOp::attachInterface<MaterializeOpInterface>(*ctx);
567 ParallelOp::attachInterface<ParallelOpInterface>(*ctx);
568 SetYieldOp::attachInterface<SetYieldOpInterface>(*ctx);
569 });
570 }
571