/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"

namespace mlir {
namespace kernel_gen {
namespace transforms {
namespace {
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"

// A pass to remove memref::AllocOps and memref::CopyOps ops.
//
// The idea behind this pass is to collect all patterns we are interested in in
// a single place. Eventually, this should be replaced by a generalized copy
// removal pass.

// Handles the pattern where an input operand of a linalg generic is copied
// even though the producer is not mutated.
RemoveCopyIfTargetOnlyRead(func::FuncOp func)42 void RemoveCopyIfTargetOnlyRead(func::FuncOp func) {
43   llvm::SmallVector<memref::AllocOp, 8> allocs_to_remove;
44   llvm::SmallVector<memref::CopyOp, 8> copies_to_remove;
45 
46   // Gather all allocs and copies which are only read and have an immutable
47   // source.
48   func->walk([&](memref::AllocOp op) {
49     memref::CopyOp copy;
50     MemoryEffectOpInterface reader;
51     bool at_most_one_copy = true;
52     bool at_most_one_read = true;
53     for (auto user : op->getUsers()) {
54       if (auto copy_user = dyn_cast<memref::CopyOp>(user)) {
55         if (copy) {
56           at_most_one_copy = false;
57         } else {
58           copy = copy_user;
59         }
60         continue;
61       }
62       if (auto effect_interface = cast<MemoryEffectOpInterface>(user)) {
63         if (reader) {
64           at_most_one_read = false;
65         } else {
66           reader = effect_interface;
67         }
68         SmallVector<MemoryEffects::EffectInstance, 2> effects;
69         effect_interface.getEffectsOnValue(op.getResult(), effects);
70         if (llvm::any_of(effects, [](MemoryEffects::EffectInstance it) {
71               return !isa<MemoryEffects::Read>(it.getEffect());
72             })) {
73           at_most_one_read = false;
74         }
75         continue;
76       }
77       // We don't understand this use, be conservative.
78       at_most_one_read = false;
79     }
80     if (!copy || !at_most_one_copy) return;
81     if (!reader || !at_most_one_read) return;
82     // The copy should have the alloc op as target.
83     if (copy.getTarget() != op.getResult()) return;
84 
85     // The copy should be before the reading use.
86     if (copy->getBlock() != reader->getBlock() ||
87         !copy->isBeforeInBlock(reader)) {
88       return;
89     }
90 
91     // No write effects between copy and use. With aliasing information, this
92     // could be made more precise but for now we have to be conservative. The
93     // only thing we allow are writes to values that are allocated after the
94     // copy, as the aliasing is clear in those cases.
95     bool source_is_mutated = false;
96     for (Operation *pos = copy->getNextNode(), *end = reader; pos != end;
97          pos = pos->getNextNode()) {
98       auto effect_interface = dyn_cast<MemoryEffectOpInterface>(pos);
99       if (!effect_interface) {
100         continue;
101       }
102       SmallVector<MemoryEffects::EffectInstance, 2> effects;
103       effect_interface.getEffects<MemoryEffects::Write>(effects);
104       for (auto effect : effects) {
105         if (auto alloc = effect.getValue().getDefiningOp<memref::AllocOp>()) {
106           if (alloc->getBlock() == copy->getBlock() &&
107               copy->isBeforeInBlock(alloc)) {
108             continue;
109           }
110         }
111         source_is_mutated = true;
112         break;
113       }
114     }
115     if (source_is_mutated) return;
116 
117     op->replaceAllUsesWith(ValueRange{copy.getSource()});
118     allocs_to_remove.push_back(op);
119     copies_to_remove.push_back(copy);
120   });
121   llvm::for_each(allocs_to_remove, [](Operation *op) { op->erase(); });
122   llvm::for_each(copies_to_remove, [](Operation *op) { op->erase(); });
123 }

// Handles the case where the last instructions of a function implement a copy
// back to a function argument.
RemoveCopyIfTargetIsFunctionArg(func::FuncOp func)127 void RemoveCopyIfTargetIsFunctionArg(func::FuncOp func) {
128   // For now only support this on functions with a single block.
129   if (!func.getBody().hasOneBlock()) return;
130 
131   llvm::SmallVector<memref::AllocOp> allocs_to_remove;
132   llvm::SmallVector<memref::CopyOp> copies_to_remove;
133 
134   Block &body = func.getBody().front();
135   for (auto &op : llvm::reverse(body.without_terminator())) {
136     if (auto copy = dyn_cast<memref::CopyOp>(op)) {
137       auto block_arg = copy.getTarget().dyn_cast<BlockArgument>();
138       if (!block_arg) break;
139       if (!isa<func::FuncOp>(block_arg.getOwner()->getParentOp()) ||
140           !block_arg.hasOneUse())
141         break;
142       auto alloc = copy.getSource().getDefiningOp<memref::AllocOp>();
143       if (!alloc) break;
144       alloc->replaceAllUsesWith(ValueRange{block_arg});
145       allocs_to_remove.push_back(alloc);
146       copies_to_remove.push_back(copy);
147       continue;
148     }
149     break;
150   }
151   llvm::for_each(allocs_to_remove, [](Operation *op) { op->erase(); });
152   llvm::for_each(copies_to_remove, [](Operation *op) { op->erase(); });
153 }

}  // namespace

157 struct CopyCleanupPass : public CopyCleanupPassBase<CopyCleanupPass> {
getDependentDialectsmlir::kernel_gen::transforms::CopyCleanupPass158   void getDependentDialects(DialectRegistry &registry) const override {
159     registry.insert<memref::MemRefDialect>();
160   }
161 
runOnOperationmlir::kernel_gen::transforms::CopyCleanupPass162   void runOnOperation() override {
163     RemoveCopyIfTargetOnlyRead(getOperation());
164     RemoveCopyIfTargetIsFunctionArg(getOperation());
165   }
166 };
CreateCopyCleanupPass()168 std::unique_ptr<OperationPass<func::FuncOp>> CreateCopyCleanupPass() {
169   return std::make_unique<CopyCleanupPass>();
170 }

}  // namespace transforms
}  // namespace kernel_gen
}  // namespace mlir