1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <vector>
17
18 #include "llvm/ADT/ArrayRef.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallPtrSet.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project
22 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
23 #include "mlir/IR/Value.h" // from @llvm-project
24 #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
25 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
26
27 namespace mlir {
28 namespace kernel_gen {
29 namespace transforms {
30 namespace {
31 #define GEN_PASS_CLASSES
32 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
33
// A pass to remove redundant memref::AllocOp and memref::CopyOp ops.
//
// The idea behind this pass is to collect all patterns of interest in a single
// place. Eventually, this should be replaced by a generalized copy removal
// pass.
39
40 // Handles the pattern where an input operand of a linalg generic is copied
41 // even though the producer is not mutated.
RemoveCopyIfTargetOnlyRead(func::FuncOp func)42 void RemoveCopyIfTargetOnlyRead(func::FuncOp func) {
43 llvm::SmallVector<memref::AllocOp, 8> allocs_to_remove;
44 llvm::SmallVector<memref::CopyOp, 8> copies_to_remove;
45
46 // Gather all allocs and copies which are only read and have an immutable
47 // source.
48 func->walk([&](memref::AllocOp op) {
49 memref::CopyOp copy;
50 MemoryEffectOpInterface reader;
51 bool at_most_one_copy = true;
52 bool at_most_one_read = true;
53 for (auto user : op->getUsers()) {
54 if (auto copy_user = dyn_cast<memref::CopyOp>(user)) {
55 if (copy) {
56 at_most_one_copy = false;
57 } else {
58 copy = copy_user;
59 }
60 continue;
61 }
62 if (auto effect_interface = cast<MemoryEffectOpInterface>(user)) {
63 if (reader) {
64 at_most_one_read = false;
65 } else {
66 reader = effect_interface;
67 }
68 SmallVector<MemoryEffects::EffectInstance, 2> effects;
69 effect_interface.getEffectsOnValue(op.getResult(), effects);
70 if (llvm::any_of(effects, [](MemoryEffects::EffectInstance it) {
71 return !isa<MemoryEffects::Read>(it.getEffect());
72 })) {
73 at_most_one_read = false;
74 }
75 continue;
76 }
77 // We don't understand this use, be conservative.
78 at_most_one_read = false;
79 }
80 if (!copy || !at_most_one_copy) return;
81 if (!reader || !at_most_one_read) return;
82 // The copy should have the alloc op as target.
83 if (copy.getTarget() != op.getResult()) return;
84
85 // The copy should be before the reading use.
86 if (copy->getBlock() != reader->getBlock() ||
87 !copy->isBeforeInBlock(reader)) {
88 return;
89 }
90
91 // No write effects between copy and use. With aliasing information, this
92 // could be made more precise but for now we have to be conservative. The
93 // only thing we allow are writes to values that are allocated after the
94 // copy, as the aliasing is clear in those cases.
95 bool source_is_mutated = false;
96 for (Operation *pos = copy->getNextNode(), *end = reader; pos != end;
97 pos = pos->getNextNode()) {
98 auto effect_interface = dyn_cast<MemoryEffectOpInterface>(pos);
99 if (!effect_interface) {
100 continue;
101 }
102 SmallVector<MemoryEffects::EffectInstance, 2> effects;
103 effect_interface.getEffects<MemoryEffects::Write>(effects);
104 for (auto effect : effects) {
105 if (auto alloc = effect.getValue().getDefiningOp<memref::AllocOp>()) {
106 if (alloc->getBlock() == copy->getBlock() &&
107 copy->isBeforeInBlock(alloc)) {
108 continue;
109 }
110 }
111 source_is_mutated = true;
112 break;
113 }
114 }
115 if (source_is_mutated) return;
116
117 op->replaceAllUsesWith(ValueRange{copy.getSource()});
118 allocs_to_remove.push_back(op);
119 copies_to_remove.push_back(copy);
120 });
121 llvm::for_each(allocs_to_remove, [](Operation *op) { op->erase(); });
122 llvm::for_each(copies_to_remove, [](Operation *op) { op->erase(); });
123 }
124
125 // Handles the case where the last instructions of a function implements a copy
126 // back to a function argument.
RemoveCopyIfTargetIsFunctionArg(func::FuncOp func)127 void RemoveCopyIfTargetIsFunctionArg(func::FuncOp func) {
128 // For now only support this on functions with a single block.
129 if (!func.getBody().hasOneBlock()) return;
130
131 llvm::SmallVector<memref::AllocOp> allocs_to_remove;
132 llvm::SmallVector<memref::CopyOp> copies_to_remove;
133
134 Block &body = func.getBody().front();
135 for (auto &op : llvm::reverse(body.without_terminator())) {
136 if (auto copy = dyn_cast<memref::CopyOp>(op)) {
137 auto block_arg = copy.getTarget().dyn_cast<BlockArgument>();
138 if (!block_arg) break;
139 if (!isa<func::FuncOp>(block_arg.getOwner()->getParentOp()) ||
140 !block_arg.hasOneUse())
141 break;
142 auto alloc = copy.getSource().getDefiningOp<memref::AllocOp>();
143 if (!alloc) break;
144 alloc->replaceAllUsesWith(ValueRange{block_arg});
145 allocs_to_remove.push_back(alloc);
146 copies_to_remove.push_back(copy);
147 continue;
148 }
149 break;
150 }
151 llvm::for_each(allocs_to_remove, [](Operation *op) { op->erase(); });
152 llvm::for_each(copies_to_remove, [](Operation *op) { op->erase(); });
153 }
154
155 } // namespace
156
157 struct CopyCleanupPass : public CopyCleanupPassBase<CopyCleanupPass> {
getDependentDialectsmlir::kernel_gen::transforms::CopyCleanupPass158 void getDependentDialects(DialectRegistry ®istry) const override {
159 registry.insert<memref::MemRefDialect>();
160 }
161
runOnOperationmlir::kernel_gen::transforms::CopyCleanupPass162 void runOnOperation() override {
163 RemoveCopyIfTargetOnlyRead(getOperation());
164 RemoveCopyIfTargetIsFunctionArg(getOperation());
165 }
166 };
167
CreateCopyCleanupPass()168 std::unique_ptr<OperationPass<func::FuncOp>> CreateCopyCleanupPass() {
169 return std::make_unique<CopyCleanupPass>();
170 }
171
172 } // namespace transforms
173 } // namespace kernel_gen
174 } // namespace mlir
175