1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cstddef>
17 #include <vector>
18 
19 #include "llvm/ADT/None.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "mlir/Analysis/BufferViewFlowAnalysis.h"  // from @llvm-project
23 #include "mlir/Analysis/Liveness.h"  // from @llvm-project
24 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
25 #include "mlir/Dialect/Linalg/IR/Linalg.h"  // from @llvm-project
26 #include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
27 #include "mlir/IR/AffineMap.h"  // from @llvm-project
28 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
29 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
30 #include "mlir/IR/Operation.h"  // from @llvm-project
31 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
32 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
33 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
34 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
35 
// Out-of-line definitions for the `static constexpr` attribute-name members
// declared in the TF framework dialect. Required so that ODR-uses of these
// members (e.g. taking their address) link correctly pre-C++17.
constexpr llvm::StringRef
    mlir::kernel_gen::tf_framework::TFAllocOp::kReuseOutputAttrName;
constexpr llvm::StringRef
    mlir::kernel_gen::tf_framework::TFAllocOp::kReuseInputCandidatesAttrName;
constexpr llvm::StringRef
    mlir::kernel_gen::tf_framework::TFFrameworkDialect::kTFEntryAttrName;
42 
43 namespace mlir {
44 namespace kernel_gen {
45 namespace transforms {
46 namespace {
47 
48 class BufferReuseAnalysis {
49  public:
BufferReuseAnalysis(func::FuncOp f)50   explicit BufferReuseAnalysis(func::FuncOp f) { build(f); }
51 
52   static constexpr int32_t kIndexAmbiguous = -1;
53 
get_reuse_candiates(memref::AllocOp op)54   Optional<SmallVector<int32_t, 2>> get_reuse_candiates(memref::AllocOp op) {
55     auto it = reuse_candidates_.find(op);
56     if (it == reuse_candidates_.end()) return llvm::None;
57     return it->second;
58   }
59 
get_output_index(memref::AllocOp op)60   Optional<int32_t> get_output_index(memref::AllocOp op) {
61     auto it = output_indices_.find(op);
62     if (it == output_indices_.end()) return llvm::None;
63     return it->second;
64   }
65 
66  private:
build(func::FuncOp & f)67   void build(func::FuncOp &f) {
68     BufferViewFlowAnalysis aliases(f);
69     find_output_indices(f, aliases);
70     find_reuse_candiates(f, aliases);
71   }
72 
find_output_indices(func::FuncOp & f,BufferViewFlowAnalysis & aliases)73   void find_output_indices(func::FuncOp &f, BufferViewFlowAnalysis &aliases) {
74     f.walk([&](memref::AllocOp alloc_op) {
75       int32_t output_index = kIndexAmbiguous;
76       int count_return_uses = 0;
77       auto buffer_aliases = aliases.resolve(alloc_op.getResult());
78       for (Value alias : buffer_aliases) {
79         for (auto &use : alias.getUses()) {
80           if (isa<func::ReturnOp>(use.getOwner())) {
81             int32_t index = use.getOperandNumber();
82             if (count_return_uses++ == 0)
83               output_index = index;
84             else if (output_index != index)
85               output_index = kIndexAmbiguous;
86           }
87         }
88       }
89       output_indices_[alloc_op] = output_index;
90     });
91   }
92 
find_reuse_candiates(func::FuncOp & f,BufferViewFlowAnalysis & aliases)93   void find_reuse_candiates(func::FuncOp &f, BufferViewFlowAnalysis &aliases) {
94     Liveness liveness(f);
95     f.walk([&](Block *block) {
96       find_reuse_candiates(block, aliases, liveness.getLiveness(block),
97                            f.getArguments());
98     });
99   }
100 
find_reuse_candiates(Block * block,BufferViewFlowAnalysis & aliases,const LivenessBlockInfo * liveness,ArrayRef<BlockArgument> arguments)101   void find_reuse_candiates(Block *block, BufferViewFlowAnalysis &aliases,
102                             const LivenessBlockInfo *liveness,
103                             ArrayRef<BlockArgument> arguments) {
104     for (Operation &op : *block) {
105       auto alloc_op = dyn_cast<memref::AllocOp>(op);
106       if (!alloc_op) continue;
107 
108       // Find first use of the newly allocated buffer within this block.
109       Value new_buffer = alloc_op.getResult();
110       Operation *first_reuse = find_first_use_in_block(new_buffer, block);
111       assert((first_reuse == nullptr || first_reuse->getBlock() == block) &&
112              "Expected first use in same block if found.");
113 
114       // Find reuse candidates for the regarded allocation.
115       SmallVector<int32_t, 2> local_reuse_candidates;
116       for (BlockArgument old_buffer : arguments) {
117         if (!old_buffer.getType().isa<BaseMemRefType>()) continue;
118 
119         // Lifetime criterion: Only reuse buffers that are no longer used on
120         // first reuse, i.e. they are no longer alive.
121         bool lifetimes_compatible = true;
122         for (Value old_buffer_alias : aliases.resolve(old_buffer)) {
123           if (first_reuse == nullptr) {
124             // If the first use is beyond the end of this block we look at the
125             // block end. An argument buffer that is already reusable there is
126             // certainly reusable at any later actual use. Otherwise, lifetimes
127             // are incompatible.
128             if (liveness->isLiveOut(old_buffer_alias)) {
129               lifetimes_compatible = false;
130               break;
131             }
132           } else {
133             // A buffer is reusable if
134             //   i)  its last use is before the point of reuse, or
135             //   ii) its last use is also its first reuse and the operation
136             //       allows for local reuse.
137             // Otherwise, lifetimes are incompatible.
138             Operation *last_use =
139                 liveness->getEndOperation(old_buffer_alias, &block->front());
140             assert(last_use != nullptr && last_use->getBlock() == block &&
141                    "Expected last use in same block.");
142             if (first_reuse->isBeforeInBlock(last_use)) {
143               lifetimes_compatible = false;
144               break;
145             }
146             if (first_reuse == last_use &&
147                 !can_reuse_locally(first_reuse, old_buffer_alias, new_buffer)) {
148               lifetimes_compatible = false;
149               break;
150             }
151           }
152         }
153 
154         if (lifetimes_compatible) {
155           // All criteria are fulfilled ��.
156           int32_t old_buffer_index = old_buffer.getArgNumber();
157           local_reuse_candidates.push_back(old_buffer_index);
158         }
159       }
160 
161       reuse_candidates_[&op] = local_reuse_candidates;
162     }
163   }
164 
find_first_use_in_block(Value value,Block * block)165   Operation *find_first_use_in_block(Value value, Block *block) {
166     Operation *first_use = nullptr;
167     for (Operation *op : value.getUsers()) {
168       Operation *ancestor_op = block->findAncestorOpInBlock(*op);
169       if (ancestor_op == nullptr) continue;
170       if (first_use == nullptr || ancestor_op->isBeforeInBlock(first_use))
171         first_use = ancestor_op;
172     }
173     return first_use;
174   }
175 
get_buffer_arguments(func::FuncOp & f)176   std::vector<Value> get_buffer_arguments(func::FuncOp &f) {
177     std::vector<Value> buffer_arguments;
178     for (BlockArgument arg : f.getArguments()) {
179       if (arg.getType().isa<BaseMemRefType>()) buffer_arguments.push_back(arg);
180     }
181     return buffer_arguments;
182   }
183 
can_reuse_locally(Operation * op,Value old_buffer,Value new_buffer)184   bool can_reuse_locally(Operation *op, Value old_buffer, Value new_buffer) {
185     // For now, we support only memrefs with the same memory layout.
186     auto old_buffer_ty = old_buffer.getType().dyn_cast<MemRefType>();
187     auto new_buffer_ty = old_buffer.getType().dyn_cast<MemRefType>();
188     if (!old_buffer_ty || !new_buffer_ty ||
189         old_buffer_ty.getLayout() != new_buffer_ty.getLayout())
190       return false;
191 
192     if (auto generic_op = dyn_cast<linalg::GenericOp>(op)) {
193       SmallVector<OpOperand *> op_operands =
194           generic_op.getInputAndOutputOperands();
195       auto old_it = llvm::find_if(op_operands, [&](OpOperand *op_operand) {
196         return op_operand->get() == old_buffer;
197       });
198       auto new_it = llvm::find_if(op_operands, [&](OpOperand *op_operand) {
199         return op_operand->get() == new_buffer;
200       });
201       assert(old_it != op_operands.end() && new_it != op_operands.end() &&
202              "Expect `old/new_buffer` to be operand of `op`.");
203 
204       auto is_projection = [](AffineMap map) {
205         // Allow dropping dimensions but no permutations.
206         int64_t i = -1;
207         for (AffineExpr expr : map.getResults()) {
208           auto dim_expr = expr.dyn_cast<AffineDimExpr>();
209           if (!dim_expr || dim_expr.getPosition() <= i) return false;
210           i = dim_expr.getPosition();
211         }
212         return true;
213       };
214 
215       // If `linalg.generic` indexing maps are the same for input and output
216       // buffer then the last use of the input buffer happens before its first
217       // reuse (per memory location). Since we know that the inputs and outputs
218       // have the same size we also know that when one side has an identity map
219       // and the other side only drops dimensions, these dimensions have to be
220       // of size 1.
221       AffineMap old_indexing_map = generic_op.getTiedIndexingMap(*old_it);
222       AffineMap new_indexing_map = generic_op.getTiedIndexingMap(*new_it);
223       return (old_indexing_map == new_indexing_map &&
224               old_indexing_map.isProjectedPermutation()) ||
225              (old_indexing_map.isIdentity() &&
226               is_projection(new_indexing_map)) ||
227              (is_projection(old_indexing_map) && new_indexing_map.isIdentity());
228     }
229     return false;
230   }
231 
232   DenseMap<Operation *, SmallVector<int32_t, 2>> reuse_candidates_;
233   DenseMap<Operation *, int32_t> output_indices_;
234 };
235 
236 #define GEN_PASS_CLASSES
237 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
238 
239 struct BufferReusePass : public BufferReusePassBase<BufferReusePass> {
runOnOperationmlir::kernel_gen::transforms::__anon010073860111::BufferReusePass240   void runOnOperation() override {
241     if (!getOperation()->getAttrOfType<UnitAttr>(
242             tf_framework::TFFrameworkDialect::kTFEntryAttrName))
243       return;
244 
245     BufferReuseAnalysis analysis(getOperation());
246 
247     // Annotate IR with reuse candidates and output indices per allocation.
248     Builder builder(&getContext());
249     getOperation().walk([&](memref::AllocOp op) {
250       if (auto output_index = analysis.get_output_index(op)) {
251         auto attr = builder.getI32IntegerAttr(*output_index);
252         op.getOperation()->setAttr(
253             tf_framework::TFAllocOp::kReuseOutputAttrName, attr);
254       }
255       if (auto reuse_candiates = analysis.get_reuse_candiates(op)) {
256         auto attr = builder.getI32ArrayAttr(*reuse_candiates);
257         op.getOperation()->setAttr(
258             tf_framework::TFAllocOp::kReuseInputCandidatesAttrName, attr);
259       }
260     });
261   }
262 };
263 
264 }  // namespace
265 
CreateBufferReusePass()266 std::unique_ptr<OperationPass<func::FuncOp>> CreateBufferReusePass() {
267   return std::make_unique<BufferReusePass>();
268 }
269 
270 }  // namespace transforms
271 }  // namespace kernel_gen
272 }  // namespace mlir
273