/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir-hlo/Dialect/gml_st/transforms/bufferizable_op_interface_impl.h"

#include <iterator>
#include <tuple>

#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LogicalResult.h"

using mlir::bufferization::AnalysisState;
using mlir::bufferization::BufferizableOpInterface;
using mlir::bufferization::BufferizationOptions;
using mlir::bufferization::BufferRelation;
using mlir::bufferization::ToMemrefOp;
using mlir::bufferization::ToTensorOp;

namespace mlir {
namespace gml_st {
namespace {

40 /// Bufferization of gml_st.loop. Replace with a new gml_st.loop
41 /// that operates entirely on memrefs.
42 struct LoopOpInterface
43     : public BufferizableOpInterface::ExternalModel<LoopOpInterface, LoopOp> {
bufferizesToMemoryReadmlir::gml_st::__anon421317500111::LoopOpInterface44   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
45                               const AnalysisState &state) const {
46     auto loopOp = cast<LoopOp>(op);
47 
48     // gml_st.loop operands alone do not bufferize to a memory read, but
49     // one of the uses of their matching bbArgs may.
50     return state.isValueRead(loopOp.getTiedBlockArgument(opOperand));
51   }
52 
bufferizesToMemoryWritemlir::gml_st::__anon421317500111::LoopOpInterface53   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
54                                const AnalysisState &state) const {
55     // Only operands with an aliasing OpResult (i.e., output operands) bufferize
56     // to a memory write.
57     auto bufferizableOp = cast<BufferizableOpInterface>(op);
58     return !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
59   }
60 
getAliasingOpResultmlir::gml_st::__anon421317500111::LoopOpInterface61   SmallVector<OpResult> getAliasingOpResult(
62       Operation *op, OpOperand &opOperand,
63       const AnalysisState & /*state*/) const {
64     auto loopOp = cast<LoopOp>(op);
65 
66     // Output operands are tied to their corresponding OpResults.
67     OpResult opResult = loopOp.getTiedOpResult(opOperand);
68     if (!opResult) return {};
69     return {opResult};
70   }
71 
bufferRelationmlir::gml_st::__anon421317500111::LoopOpInterface72   BufferRelation bufferRelation(Operation * /*op*/, OpResult /*opResult*/,
73                                 const AnalysisState & /*state*/) const {
74     return BufferRelation::Equivalent;
75   }
76 
isWritablemlir::gml_st::__anon421317500111::LoopOpInterface77   bool isWritable(Operation * /*op*/, Value /*value*/,
78                   const AnalysisState & /*state*/) const {
79     // Interestingly, LoopOp's bbArgs can **always** be viewed
80     // inplace from the perspective of nested ops:
81     //   1. Either the matching iter operand is not bufferized inplace and an
82     //      alloc + optional copy makes the bbArg itself inplaceable.
83     //   2. Or the matching iter operand is bufferized inplace and bbArg just
84     //      bufferizes to that too.
85     return true;
86   }
87 
getBufferTypemlir::gml_st::__anon421317500111::LoopOpInterface88   FailureOr<BaseMemRefType> getBufferType(
89       Operation *op, BlockArgument bbArg,
90       const BufferizationOptions &options) const {
91     auto loopOp = cast<LoopOp>(op);
92     return bufferization::getBufferType(loopOp.getTiedOperand(bbArg).get(),
93                                         options);
94   }
95 
bufferizemlir::gml_st::__anon421317500111::LoopOpInterface96   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
97                           const BufferizationOptions &options) const {
98     auto loopOp = cast<LoopOp>(op);
99 
100     // Compute new inputs, outputs and results.
101     SmallVector<Value> newInputs, newOutputs, newResults;
102     for (unsigned i = loopOp.getNumControlOperands();
103          i < loopOp->getNumOperands(); ++i) {
104       OpOperand &operand = loopOp->getOpOperand(i);
105       Value rewrittenValue = operand.get();
106       if (rewrittenValue.getType().isa<TensorType>()) {
107         FailureOr<Value> maybeBuffer =
108             getBuffer(rewriter, operand.get(), options);
109         if (failed(maybeBuffer)) return failure();
110         rewrittenValue = *maybeBuffer;
111       }
112       if (i < loopOp.getNumControlOperands() + loopOp.getNumInputs()) {
113         newInputs.push_back(rewrittenValue);
114       } else {
115         newOutputs.push_back(rewrittenValue);
116         if (operand.get().getType().isa<TensorType>())
117           newResults.push_back(rewrittenValue);
118       }
119     }
120 
121     // Create new TiledLoopOp.
122     auto newLoopOp = rewriter.create<LoopOp>(
123         loopOp.getLoc(), loopOp.lowerBound(), loopOp.upperBound(),
124         loopOp.step(), newInputs, newOutputs, loopOp.iterator_types(),
125         loopOp.distribution_types());
126 
127     // Remove terminator.
128     if (!newLoopOp.getBody()->empty())
129       rewriter.eraseOp(loopOp.getBody()->getTerminator());
130 
131     // Compute new loop body arguments.
132     SmallVector<Value> newBlockArgs, newRegionInOutArgs, oldRegionInOutArgs;
133     ValueRange newInductionVars = newLoopOp.getInductionVars();
134     newBlockArgs.append(newInductionVars.begin(), newInductionVars.end());
135 
136     ValueRange newRegionInArgs = newLoopOp.getRegionInputArgs();
137     ValueRange newRegionOutArgs = newLoopOp.getRegionOutputArgs();
138     newRegionInOutArgs.append(newRegionInArgs.begin(), newRegionInArgs.end());
139     newRegionInOutArgs.append(newRegionOutArgs.begin(), newRegionOutArgs.end());
140 
141     ValueRange oldRegionInArgs = loopOp.getRegionInputArgs();
142     ValueRange oldRegionOutArgs = loopOp.getRegionOutputArgs();
143     oldRegionInOutArgs.append(oldRegionInArgs.begin(), oldRegionInArgs.end());
144     oldRegionInOutArgs.append(oldRegionOutArgs.begin(), oldRegionOutArgs.end());
145     assert(newRegionInArgs.size() == oldRegionInArgs.size() &&
146            "expected same number of input args");
147     assert(newRegionOutArgs.size() == oldRegionOutArgs.size() &&
148            "expected same number of output args");
149 
150     for (auto it : llvm::zip(oldRegionInOutArgs, newRegionInOutArgs)) {
151       Value oldArg = std::get<0>(it);
152       Value newArg = std::get<1>(it);
153       rewriter.setInsertionPointToStart(newLoopOp.getBody());
154       if (oldArg.getType().isa<TensorType>()) {
155         newBlockArgs.push_back(rewriter.create<bufferization::ToTensorOp>(
156             oldArg.getLoc(), newArg));
157       } else {
158         newBlockArgs.push_back(newArg);
159       }
160     }
161 
162     // Move old body into new loop.
163     rewriter.mergeBlocks(loopOp.getBody(), newLoopOp.getBody(), newBlockArgs);
164 
165     // Replace previous terminator with a new one that does not yield anything.
166     auto oldTerminator =
167         cast<gml_st::YieldOp>(newLoopOp.getBody()->getTerminator());
168     rewriter.setInsertionPointToEnd(newLoopOp.getBody());
169     auto newTerminator =
170         rewriter.create<gml_st::YieldOp>(oldTerminator->getLoc());
171 
172     // Copy buffer of yielded tensor to output buffer. If everything bufferized
173     // inplace, this copy will fold away.
174     rewriter.setInsertionPoint(newTerminator);
175     for (auto it : llvm::zip(oldTerminator.values(), newOutputs)) {
176       Value output = std::get<1>(it);
177       Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
178           newTerminator.getLoc(), output.getType(), std::get<0>(it));
179       if (failed(options.createMemCpy(rewriter, newTerminator.getLoc(),
180                                       toMemrefOp, output)))
181         return failure();
182     }
183 
184     // Erase old terminator.
185     rewriter.eraseOp(oldTerminator);
186 
187     // Replace results and delete old op.
188     bufferization::replaceOpWithBufferizedValues(rewriter, op, newResults);
189 
190     return success();
191   }
192 };
193 
194 // Returns the set chain in reverse order, i.e. from set to space.
195 // The space operation itself is not included.
findSetChain(Value set)196 FailureOr<SmallVector<Operation *>> findSetChain(Value set) {
197   SmallVector<Operation *> sets;
198   Operation *current = set.getDefiningOp();
199   while (current) {
200     if (auto space = dyn_cast<SpaceOp>(*current)) break;
201 
202     sets.push_back(current);
203     // TODO(pifon): It might be useful to have a set interface.
204     if (auto tile = dyn_cast<TileOp>(*current)) {
205       current = tile.superset().getDefiningOp();
206       continue;
207     }
208     if (auto point = dyn_cast<PointOp>(*current)) {
209       current = point.superset().getDefiningOp();
210       continue;
211     }
212     return failure();
213   }
214   return sets;
215 }
216 
217 // TODO(pifon): Clean this up, for example, by using ViewLikeInterface.
getPointIndicesValues(OpBuilder & b,PointOp pointOp)218 SmallVector<Value> getPointIndicesValues(OpBuilder &b, PointOp pointOp) {
219   SmallVector<Value> indices;
220   unsigned rank = pointOp.getRank();
221   indices.reserve(rank);
222   unsigned numDynamic = 0;
223   for (auto staticIndex : pointOp.static_indices().getAsRange<IntegerAttr>()) {
224     if (ShapedType::isDynamicStrideOrOffset(staticIndex.getInt())) {
225       indices.push_back(pointOp.dynamic_indices()[numDynamic++]);
226     } else {
227       Value indexValue = b.create<arith::ConstantIndexOp>(pointOp.getLoc(),
228                                                           staticIndex.getInt());
229       indices.push_back(indexValue);
230     }
231   }
232   return indices;
233 }
234 
235 // Returns a scalar or a memref type result of `gml_st.materialize` op after
236 // bufferization.
materializeExtraction(OpBuilder & b,Value memref,Value set)237 FailureOr<Value> materializeExtraction(OpBuilder &b, Value memref, Value set) {
238   auto setsOr = findSetChain(set);
239   if (failed(setsOr)) return failure();
240 
241   // Find set use-def chain from space to the set.
242   // Create subview or load ops for the set computation.
243   OpBuilder::InsertionGuard g(b);
244   Value result = memref;
245   for (auto *set : llvm::reverse(*setsOr)) {
246     Location loc = set->getLoc();
247     b.setInsertionPointAfter(set);
248     if (auto tile = dyn_cast<TileOp>(*set)) {
249       result = b.create<memref::SubViewOp>(loc, result, tile.getMixedOffsets(),
250                                            tile.getMixedSizes(),
251                                            tile.getMixedStrides());
252       continue;
253     }
254     if (auto point = dyn_cast<PointOp>(*set)) {
255       result = b.create<memref::LoadOp>(loc, result,
256                                         getPointIndicesValues(b, point));
257       continue;
258     }
259     return failure();
260   }
261   return result;
262 }
263 
materializeInsertion(OpBuilder & b,Value update,Value set,Value memref,const BufferizationOptions & options)264 LogicalResult materializeInsertion(OpBuilder &b, Value update, Value set,
265                                    Value memref,
266                                    const BufferizationOptions &options) {
267   auto sets = findSetChain(set);
268   if (failed(sets)) return failure();
269 
270   if (sets->empty())
271     return options.createMemCpy(b, update.getLoc(), update, memref);
272 
273   // Create subviews or store ops for the set computation.
274   OpBuilder::InsertionGuard g(b);
275   auto *it = std::prev(sets->end());
276   // The first element for the use-def chain is the `gml_st.space` op and it
277   // should be ignored for now.
278   for (; it != sets->begin(); --it) {
279     Location loc = (*it)->getLoc();
280     b.setInsertionPointAfter(*it);
281 
282     auto tile = dyn_cast<TileOp>(*it);
283     if (!tile) return failure();
284 
285     memref = b.create<memref::SubViewOp>(loc, memref, tile.getMixedOffsets(),
286                                          tile.getMixedSizes(),
287                                          tile.getMixedStrides());
288   }
289   Location loc = (*it)->getLoc();
290   if (auto point = dyn_cast<PointOp>(*it)) {
291     b.create<memref::StoreOp>(loc, update, memref,
292                               getPointIndicesValues(b, point));
293     return success();
294   }
295   if (auto tile = dyn_cast<TileOp>(*it)) {
296     memref = b.create<memref::SubViewOp>(loc, memref, tile.getMixedOffsets(),
297                                          tile.getMixedSizes(),
298                                          tile.getMixedOffsets());
299     return success();
300   }
301   llvm_unreachable("Unknown set type");
302 }
303 
304 struct MaterializeOpInterface
305     : public BufferizableOpInterface::ExternalModel<MaterializeOpInterface,
306                                                     MaterializeOp> {
bufferizesToMemoryReadmlir::gml_st::__anon421317500111::MaterializeOpInterface307   bool bufferizesToMemoryRead(Operation * /*op*/, OpOperand & /*opOperand*/,
308                               const AnalysisState & /*state*/) const {
309     return false;
310   }
311 
bufferizesToMemoryWritemlir::gml_st::__anon421317500111::MaterializeOpInterface312   bool bufferizesToMemoryWrite(Operation * /*op*/, OpOperand & /*opOperand*/,
313                                const AnalysisState & /*state*/) const {
314     return false;
315   }
316 
getAliasingOpResultmlir::gml_st::__anon421317500111::MaterializeOpInterface317   SmallVector<OpResult> getAliasingOpResult(
318       Operation *op, OpOperand &opOperand,
319       const AnalysisState & /*state*/) const {
320     auto result = op->getOpResult(0);
321     if (result.getType().isa<RankedTensorType>() &&
322         opOperand.getOperandNumber() == 0)
323       return {result};
324     return {};
325   }
326 
bufferRelationmlir::gml_st::__anon421317500111::MaterializeOpInterface327   BufferRelation bufferRelation(Operation * /*op*/, OpResult /*opResult*/,
328                                 const AnalysisState & /*state*/) const {
329     return BufferRelation::Equivalent;
330   }
331 
bufferizemlir::gml_st::__anon421317500111::MaterializeOpInterface332   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
333                           const BufferizationOptions &options) const {
334     auto materializeOp = cast<MaterializeOp>(op);
335 
336     FailureOr<Value> bufferOr =
337         getBuffer(rewriter, materializeOp->getOpOperand(0).get(), options);
338     if (failed(bufferOr)) return failure();
339 
340     FailureOr<Value> resultOr =
341         materializeExtraction(rewriter, *bufferOr, materializeOp.set());
342 
343     if (failed(resultOr)) return failure();
344 
345     bufferization::replaceOpWithBufferizedValues(rewriter, op, *resultOr);
346     return success();
347   }
348 };
349 
350 struct ParallelOpInterface
351     : public BufferizableOpInterface::ExternalModel<ParallelOpInterface,
352                                                     ParallelOp> {
getAliasingOpOperandmlir::gml_st::__anon421317500111::ParallelOpInterface353   SmallVector<OpOperand *> getAliasingOpOperand(
354       Operation *op, OpResult opResult, const AnalysisState & /*state*/) const {
355     auto parallelOp = cast<ParallelOp>(op);
356     return {
357         parallelOp.getTerminator().getDstOperand(opResult.getResultNumber())};
358   }
359 
isMemoryWritemlir::gml_st::__anon421317500111::ParallelOpInterface360   bool isMemoryWrite(Operation *, OpResult, const AnalysisState &) const {
361     // This op is a memory write. Stop lookup here to avoid finding false
362     // conflicts involving this op and one of the ops in the region. This is
363     // similar to how scf.if ops are analyzed.
364     return true;
365   }
366 
bufferRelationmlir::gml_st::__anon421317500111::ParallelOpInterface367   BufferRelation bufferRelation(Operation * /*op*/, OpResult /*opResult*/,
368                                 const AnalysisState & /*state*/) const {
369     return BufferRelation::Equivalent;
370   }
371 
isWritablemlir::gml_st::__anon421317500111::ParallelOpInterface372   bool isWritable(Operation * /*op*/, Value /*value*/,
373                   const AnalysisState & /*state*/) const {
374     return true;
375   }
376 
bufferizemlir::gml_st::__anon421317500111::ParallelOpInterface377   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
378                           const BufferizationOptions & /*options*/) const {
379     auto loopOp = cast<ParallelOp>(op);
380 
381     // Create new TiledLoopOp.
382     auto newLoopOp = rewriter.create<ParallelOp>(
383         loopOp.getLoc(), TypeRange{llvm::None}, loopOp.lowerBound(),
384         loopOp.upperBound(), loopOp.step(), nullptr);
385 
386     // Move the old body into the new loop.
387     rewriter.mergeBlocks(loopOp.getBody(), newLoopOp.getBody(),
388                          newLoopOp.getInductionVars());
389 
390     // Remove the old op.
391     rewriter.eraseOp(op);
392     return success();
393   }
394 };
395 
396 struct ForOpInterface
397     : public BufferizableOpInterface::ExternalModel<ForOpInterface, ForOp> {
getAliasingOpOperandmlir::gml_st::__anon421317500111::ForOpInterface398   SmallVector<OpOperand *> getAliasingOpOperand(
399       Operation *op, OpResult opResult, const AnalysisState & /*state*/) const {
400     auto forOp = cast<gml_st::ForOp>(op);
401     return {&forOp.getOpOperandForResult(opResult)};
402   }
403 
isWritablemlir::gml_st::__anon421317500111::ForOpInterface404   bool isWritable(Operation * /*op*/, Value /*value*/,
405                   const AnalysisState & /*state*/) const {
406     // Interestingly, ForOp's bbArg can **always** be viewed
407     // inplace from the perspective of ops nested under:
408     //   1. Either the matching iter operand is not bufferized inplace and an
409     //      alloc + optional copy makes the bbArg itself inplaceable.
410     //   2. Or the matching iter operand is bufferized inplace and bbArg just
411     //      bufferizes to that too.
412     return true;
413   }
414 
bufferizesToMemoryReadmlir::gml_st::__anon421317500111::ForOpInterface415   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
416                               const AnalysisState &state) const {
417     auto forOp = cast<gml_st::ForOp>(op);
418     return state.isValueRead(forOp.getRegionOutputArgForOpOperand(opOperand));
419   }
420 
bufferizesToMemoryWritemlir::gml_st::__anon421317500111::ForOpInterface421   bool bufferizesToMemoryWrite(Operation * /*op*/, OpOperand & /*opOperand*/,
422                                const AnalysisState & /*state*/) const {
423     return true;
424   }
425 
getAliasingOpResultmlir::gml_st::__anon421317500111::ForOpInterface426   SmallVector<OpResult> getAliasingOpResult(
427       Operation *op, OpOperand &opOperand,
428       const AnalysisState & /*state*/) const {
429     auto forOp = cast<gml_st::ForOp>(op);
430     return {forOp.getResultForOpOperand(opOperand)};
431   }
432 
bufferRelationmlir::gml_st::__anon421317500111::ForOpInterface433   BufferRelation bufferRelation(Operation * /*op*/, OpResult /*opResult*/,
434                                 const AnalysisState & /*state*/) const {
435     return BufferRelation::Equivalent;
436   }
437 
bufferizemlir::gml_st::__anon421317500111::ForOpInterface438   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
439                           const BufferizationOptions &options) const {
440     auto forOp = cast<ForOp>(op);
441     Location loc = forOp.getLoc();
442 
443     // Get the bufferized output arguments.
444     SmallVector<Value> bufferizedOutputs;
445     bufferizedOutputs.reserve(forOp.getNumOutputs());
446     for (Value output : forOp.outputs()) {
447       FailureOr<Value> maybeBuffer = getBuffer(rewriter, output, options);
448       if (failed(maybeBuffer)) return failure();
449       bufferizedOutputs.push_back(*maybeBuffer);
450     }
451 
452     // Create new ForOp.
453     auto newForOp = rewriter.create<ForOp>(loc, TypeRange{}, forOp.lowerBound(),
454                                            forOp.upperBound(), forOp.step(),
455                                            ValueRange{}, nullptr);
456     Block *loopBody = newForOp.getBody();
457 
458     // Add conversions to tensor so that we can reuse the old loop body.
459     rewriter.setInsertionPointToStart(loopBody);
460     SmallVector<Value> outputsToTensors;
461     for (auto buf : bufferizedOutputs) {
462       Value tensor = rewriter.create<bufferization::ToTensorOp>(loc, buf);
463       outputsToTensors.push_back(tensor);
464     }
465     SmallVector<Value> blockArgs = newForOp.getInductionVars();
466     blockArgs.append(outputsToTensors);
467 
468     // Move old body into new for loop.
469     rewriter.mergeBlocks(forOp.getBody(), loopBody, blockArgs);
470 
471     // Replace results and delete old op.
472     bufferization::replaceOpWithBufferizedValues(rewriter, op,
473                                                  bufferizedOutputs);
474     return success();
475   }
476 };
477 
478 struct SetYieldOpInterface
479     : public BufferizableOpInterface::ExternalModel<SetYieldOpInterface,
480                                                     SetYieldOp> {
getAliasingOpResultmlir::gml_st::__anon421317500111::SetYieldOpInterface481   SmallVector<OpResult> getAliasingOpResult(
482       Operation *op, OpOperand &opOperand,
483       const AnalysisState & /*state*/) const {
484     auto yieldOp = cast<SetYieldOp>(op);
485     if (!yieldOp.isDstOperand(opOperand)) return {};
486 
487     auto loopResult = yieldOp.getTiedOpResult(opOperand);
488     assert(succeeded(loopResult) && "didn't find a corresponding loop result");
489     return {*loopResult};
490   }
491 
bufferizemlir::gml_st::__anon421317500111::SetYieldOpInterface492   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
493                           const BufferizationOptions &options) const {
494     auto yieldOp = cast<SetYieldOp>(op);
495     Operation *loop = yieldOp->getParentOp();
496     if (!isa<ForOp, ParallelOp>(loop))
497       return yieldOp->emitError("unsupported gml_st::SetYieldOp parent");
498 
499     for (const auto &it :
500          llvm::enumerate(llvm::zip(yieldOp.srcs(), yieldOp.dsts(),
501                                    yieldOp.sets(), loop->getResults()))) {
502       Value src, dst, set, loopResult;
503       std::tie(src, dst, set, loopResult) = it.value();
504 
505       // `src` can be a scalar, that's `getBuffer()` should be called only for
506       // tensor types.
507       if (src.getType().isa<RankedTensorType>()) {
508         FailureOr<Value> srcBufferOr = getBuffer(rewriter, src, options);
509         if (failed(srcBufferOr)) return failure();
510 
511         src = *srcBufferOr;
512       }
513 
514       FailureOr<Value> dstBufferOr = getBuffer(rewriter, dst, options);
515       if (failed(dstBufferOr)) return failure();
516       Value dstBuffer = *dstBufferOr;
517 
518       if (failed(materializeInsertion(rewriter, src, set, dstBuffer, options)))
519         return failure();
520       if (auto parallelOp =
521               dyn_cast<gml_st::ParallelOp>(yieldOp->getParentOp())) {
522         // Replace results of the enclosing loop with `to_tensor(dst)`.
523         OpBuilder::InsertionGuard g(rewriter);
524         rewriter.setInsertionPointAfter(loop);
525 
526         Value resultToTensor =
527             rewriter.create<ToTensorOp>(loop->getLoc(), dstBuffer);
528         for (OpOperand &use : loopResult.getUses()) {
529           rewriter.updateRootInPlace(use.getOwner(),
530                                      [&]() { use.set(resultToTensor); });
531         }
532       }
533     }
534     rewriter.replaceOpWithNewOp<SetYieldOp>(op);
535     return success();
536   }
537 
bufferizesToMemoryReadmlir::gml_st::__anon421317500111::SetYieldOpInterface538   bool bufferizesToMemoryRead(Operation * /*op*/, OpOperand & /*opOperand*/,
539                               const AnalysisState & /*state*/) const {
540     return true;
541   }
542 
bufferizesToMemoryWritemlir::gml_st::__anon421317500111::SetYieldOpInterface543   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
544                                const AnalysisState & /*state*/) const {
545     return cast<SetYieldOp>(op).isDstOperand(opOperand);
546   }
547 
isNotConflictingmlir::gml_st::__anon421317500111::SetYieldOpInterface548   bool isNotConflicting(Operation * /*op*/, OpOperand * /*uRead*/,
549                         OpOperand * /*uConflictingWrite*/,
550                         const AnalysisState & /*state*/) const {
551     // TODO(pifon): Implement proper analysis here similar to InsertSliceOp.
552     return true;
553   }
554 };
555 
}  // namespace
}  // namespace gml_st
}  // namespace mlir

registerBufferizableOpInterfaceExternalModels(DialectRegistry & registry)560 void mlir::gml_st::registerBufferizableOpInterfaceExternalModels(
561     DialectRegistry &registry) {
562   registry.addExtension(
563       +[](MLIRContext *ctx, gml_st::GmlStDialect * /*dialect*/) {
564         ForOp::attachInterface<ForOpInterface>(*ctx);
565         LoopOp::attachInterface<LoopOpInterface>(*ctx);
566         MaterializeOp::attachInterface<MaterializeOpInterface>(*ctx);
567         ParallelOp::attachInterface<ParallelOpInterface>(*ctx);
568         SetYieldOp::attachInterface<SetYieldOpInterface>(*ctx);
569       });
570 }
571