/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstddef>
#include <vector>

#include "llvm/ADT/None.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Analysis/BufferViewFlowAnalysis.h"  // from @llvm-project
#include "mlir/Analysis/Liveness.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/IR/Linalg.h"  // from @llvm-project
#include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
#include "mlir/IR/AffineMap.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
constexpr llvm::StringRef
    mlir::kernel_gen::tf_framework::TFAllocOp::kReuseOutputAttrName;
constexpr llvm::StringRef
    mlir::kernel_gen::tf_framework::TFAllocOp::kReuseInputCandidatesAttrName;
constexpr llvm::StringRef
    mlir::kernel_gen::tf_framework::TFFrameworkDialect::kTFEntryAttrName;

namespace mlir {
namespace kernel_gen {
namespace transforms {
namespace {
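// Analysis that, for every `memref.alloc` in a function, determines (i) the
// index of the function result into which the allocated buffer flows, if that
// index is unique, and (ii) the indices of the buffer arguments whose
// lifetimes permit reusing their memory for the allocation.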
class BufferReuseAnalysis {
 public:
  explicit BufferReuseAnalysis(func::FuncOp f) { build(f); }

  // Sentinel index for buffers that alias more than one returned value.
  static constexpr int32_t kIndexAmbiguous = -1;

  Optional<SmallVector<int32_t, 2>> get_reuse_candidates(memref::AllocOp op) {
    auto it = reuse_candidates_.find(op);
    if (it == reuse_candidates_.end()) return llvm::None;
    return it->second;
  }

  Optional<int32_t> get_output_index(memref::AllocOp op) {
    auto it = output_indices_.find(op);
    if (it == output_indices_.end()) return llvm::None;
    return it->second;
  }

 private:
  void build(func::FuncOp &f) {
    BufferViewFlowAnalysis aliases(f);
    find_output_indices(f, aliases);
    find_reuse_candidates(f, aliases);
  }

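  // For every allocation, record the index of the (unique) function result
  // that aliases the allocated buffer, or `kIndexAmbiguous` if the buffer
  // flows into more than one distinct return operand.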
  void find_output_indices(func::FuncOp &f, BufferViewFlowAnalysis &aliases) {
    f.walk([&](memref::AllocOp alloc_op) {
      int32_t output_index = kIndexAmbiguous;
      int count_return_uses = 0;
      auto buffer_aliases = aliases.resolve(alloc_op.getResult());
      for (Value alias : buffer_aliases) {
        for (auto &use : alias.getUses()) {
          if (isa<func::ReturnOp>(use.getOwner())) {
            int32_t index = use.getOperandNumber();
            if (count_return_uses++ == 0)
              output_index = index;
            else if (output_index != index)
              output_index = kIndexAmbiguous;
          }
        }
      }
      output_indices_[alloc_op] = output_index;
    });
  }

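  // Determine reuse candidates for every allocation, block by block, using
  // liveness information to check lifetime compatibility.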
  void find_reuse_candidates(func::FuncOp &f,
                             BufferViewFlowAnalysis &aliases) {
    Liveness liveness(f);
    f.walk([&](Block *block) {
      find_reuse_candidates(block, aliases, liveness.getLiveness(block),
                            f.getArguments());
    });
  }

  void find_reuse_candidates(Block *block, BufferViewFlowAnalysis &aliases,
                             const LivenessBlockInfo *liveness,
                             ArrayRef<BlockArgument> arguments) {
    for (Operation &op : *block) {
      auto alloc_op = dyn_cast<memref::AllocOp>(op);
      if (!alloc_op) continue;

      // Find the first use of the newly allocated buffer within this block.
      Value new_buffer = alloc_op.getResult();
      Operation *first_reuse = find_first_use_in_block(new_buffer, block);
      assert((first_reuse == nullptr || first_reuse->getBlock() == block) &&
             "Expected first use in same block if found.");

      // Find reuse candidates for the regarded allocation.
      SmallVector<int32_t, 2> local_reuse_candidates;
      for (BlockArgument old_buffer : arguments) {
        if (!old_buffer.getType().isa<BaseMemRefType>()) continue;

        // Lifetime criterion: only reuse buffers that are no longer in use at
        // the point of first reuse, i.e. that are no longer alive.
        bool lifetimes_compatible = true;
        for (Value old_buffer_alias : aliases.resolve(old_buffer)) {
          if (first_reuse == nullptr) {
            // If the first use is beyond the end of this block we look at the
            // block end. An argument buffer that is already reusable there is
            // certainly reusable at any later actual use. Otherwise, lifetimes
            // are incompatible.
            if (liveness->isLiveOut(old_buffer_alias)) {
              lifetimes_compatible = false;
              break;
            }
          } else {
            // A buffer is reusable if
            //   i) its last use is before the point of reuse, or
            //   ii) its last use is also its first reuse and the operation
            //       allows for local reuse.
            // Otherwise, lifetimes are incompatible.
            Operation *last_use =
                liveness->getEndOperation(old_buffer_alias, &block->front());
            assert(last_use != nullptr && last_use->getBlock() == block &&
                   "Expected last use in same block.");
            if (first_reuse->isBeforeInBlock(last_use)) {
              lifetimes_compatible = false;
              break;
            }
            if (first_reuse == last_use &&
                !can_reuse_locally(first_reuse, old_buffer_alias,
                                   new_buffer)) {
              lifetimes_compatible = false;
              break;
            }
          }
        }

        if (lifetimes_compatible) {
          // All criteria are fulfilled.
          int32_t old_buffer_index = old_buffer.getArgNumber();
          local_reuse_candidates.push_back(old_buffer_index);
        }
      }

      reuse_candidates_[&op] = local_reuse_candidates;
    }
  }

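  // Returns the first operation in `block` that transitively uses `value`, or
  // nullptr if there is no such use within the block.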
  Operation *find_first_use_in_block(Value value, Block *block) {
    Operation *first_use = nullptr;
    for (Operation *op : value.getUsers()) {
      Operation *ancestor_op = block->findAncestorOpInBlock(*op);
      if (ancestor_op == nullptr) continue;
      if (first_use == nullptr || ancestor_op->isBeforeInBlock(first_use))
        first_use = ancestor_op;
    }
    return first_use;
  }

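  // Collects all function arguments of memref type, i.e. the buffers that may
  // be reused.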
  std::vector<Value> get_buffer_arguments(func::FuncOp &f) {
    std::vector<Value> buffer_arguments;
    for (BlockArgument arg : f.getArguments()) {
      if (arg.getType().isa<BaseMemRefType>()) buffer_arguments.push_back(arg);
    }
    return buffer_arguments;
  }

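  // Returns true if `op` can read `old_buffer` and write through `new_buffer`
  // in place, i.e. if every memory location of `old_buffer` is read before it
  // is overwritten via `new_buffer`.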
  bool can_reuse_locally(Operation *op, Value old_buffer, Value new_buffer) {
    // For now, we support only memrefs with the same memory layout.
    auto old_buffer_ty = old_buffer.getType().dyn_cast<MemRefType>();
    auto new_buffer_ty = new_buffer.getType().dyn_cast<MemRefType>();
    if (!old_buffer_ty || !new_buffer_ty ||
        old_buffer_ty.getLayout() != new_buffer_ty.getLayout())
      return false;

    if (auto generic_op = dyn_cast<linalg::GenericOp>(op)) {
      SmallVector<OpOperand *> op_operands =
          generic_op.getInputAndOutputOperands();
      auto old_it = llvm::find_if(op_operands, [&](OpOperand *op_operand) {
        return op_operand->get() == old_buffer;
      });
      auto new_it = llvm::find_if(op_operands, [&](OpOperand *op_operand) {
        return op_operand->get() == new_buffer;
      });
      assert(old_it != op_operands.end() && new_it != op_operands.end() &&
             "Expect `old_buffer` and `new_buffer` to be operands of `op`.");

      auto is_projection = [](AffineMap map) {
        // Allow dropping dimensions but no permutations.
        int64_t i = -1;
        for (AffineExpr expr : map.getResults()) {
          auto dim_expr = expr.dyn_cast<AffineDimExpr>();
          if (!dim_expr || dim_expr.getPosition() <= i) return false;
          i = dim_expr.getPosition();
        }
        return true;
      };

      // If the `linalg.generic` indexing maps are the same for the input and
      // the output buffer, then the last use of the input buffer happens
      // before its first reuse (per memory location). Since we know that the
      // inputs and outputs have the same size, we also know that when one
      // side has an identity map and the other side only drops dimensions,
      // these dimensions have to be of size 1.
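      // Illustrative example (not from the original code): with indexing maps
      //   old_buffer: (d0, d1) -> (d0, d1)   // identity
      //   new_buffer: (d0, d1) -> (d1)       // projection dropping d0
      // equal buffer sizes force the d0 dimension to be of size 1, so each
      // element of `old_buffer` is read in the same iteration that writes the
      // corresponding element of `new_buffer`.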
      AffineMap old_indexing_map = generic_op.getTiedIndexingMap(*old_it);
      AffineMap new_indexing_map = generic_op.getTiedIndexingMap(*new_it);
      return (old_indexing_map == new_indexing_map &&
              old_indexing_map.isProjectedPermutation()) ||
             (old_indexing_map.isIdentity() &&
              is_projection(new_indexing_map)) ||
             (is_projection(old_indexing_map) && new_indexing_map.isIdentity());
    }
    return false;
  }

  DenseMap<Operation *, SmallVector<int32_t, 2>> reuse_candidates_;
  DenseMap<Operation *, int32_t> output_indices_;
};

#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"

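// Annotates every `memref.alloc` in a TF-framework entry function with the
// reuse information computed by `BufferReuseAnalysis`. A minimal sketch of
// the resulting IR (the exact attribute spellings come from the constants in
// tf_framework_ops.h and may differ):
//
//   %buf = memref.alloc(%d) {reuse_input_candidates = [0 : i32],
//                            reuse_output = 0 : i32} : memref<?xf32>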
struct BufferReusePass : public BufferReusePassBase<BufferReusePass> {
  void runOnOperation() override {
    if (!getOperation()->getAttrOfType<UnitAttr>(
            tf_framework::TFFrameworkDialect::kTFEntryAttrName))
      return;

    BufferReuseAnalysis analysis(getOperation());

    // Annotate the IR with reuse candidates and output indices per
    // allocation.
    Builder builder(&getContext());
    getOperation().walk([&](memref::AllocOp op) {
      if (auto output_index = analysis.get_output_index(op)) {
        auto attr = builder.getI32IntegerAttr(*output_index);
        op.getOperation()->setAttr(
            tf_framework::TFAllocOp::kReuseOutputAttrName, attr);
      }
      if (auto reuse_candidates = analysis.get_reuse_candidates(op)) {
        auto attr = builder.getI32ArrayAttr(*reuse_candidates);
        op.getOperation()->setAttr(
            tf_framework::TFAllocOp::kReuseInputCandidatesAttrName, attr);
      }
    });
  }
};

}  // namespace

std::unique_ptr<OperationPass<func::FuncOp>> CreateBufferReusePass() {
  return std::make_unique<BufferReusePass>();
}

}  // namespace transforms
}  // namespace kernel_gen
}  // namespace mlir