/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file contains the analysis and transformation to rewrite kernel
// functions such that information about alignment, aliasing and zero offsets
// stemming from the tf_framework uses is propagated.

#include <cstdint>
#include <memory>

#include "llvm/ADT/Bitfields.h"
#include "llvm/ADT/DenseMap.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/GPU/IR/GPUDialect.h"  // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
#include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"

namespace mlir {
namespace kernel_gen {
namespace transforms {
namespace {

#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"

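// A pass that propagates what is known about the ABI of tf_framework buffers
// (alignment, zero offset, unit inner stride, no aliasing) onto the arguments
// of the GPU kernels that receive those buffers.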
struct PropagateTfAbiKnowledgeToKernelsPass
    : public PropagateTfAbiKnowledgeToKernelsBase<
          PropagateTfAbiKnowledgeToKernelsPass> {
  void runOnOperation() override {
    func::FuncOp function = getOperation();
    llvm::SmallVector<Value, 4> worklist;
    // We currently only handle entry functions and do not propagate across
    // functions.
    if (function->getAttrOfType<mlir::UnitAttr>(
            tf_framework::TFFrameworkDialect::kTFEntryAttrName)) {
      // For all operands of this function, we know they are aligned. Also, by
      // construction of kernel generator, we know that there is no offset and
      // the inner stride is one.
      // TODO(herhut): Insert asserts in debug mode to check this.
      for (auto argument : function.getArguments()) {
        if (argument.getType().isa<BaseMemRefType>()) {
          worklist.push_back(argument);
          allocated_by_tf_runtime.insert(argument);
          offset_is_zero.insert(argument);
          inner_stride_is_constant.insert({argument, 1});
        }
      }
    }

    // For locally allocated values, we know they are aligned and have offset
    // zero. Further, they also do not alias with other memrefs, except in
    // benign ways. This is by construction and ensured by the reuse analysis.
    function.walk([&](tf_framework::TFAllocOp op) {
      Value allocated = op.getResult();
      worklist.push_back(allocated);
      no_alias.insert(allocated);
      allocated_by_tf_runtime.insert(allocated);
      offset_is_zero.insert(allocated);
      inner_stride_is_constant.insert({allocated, 1});
    });

    // Next, take what we have and propagate it through known operations.
    propagateThroughUses(worklist);

    // Now look at launches and make use of the knowledge we have.
    function.walk([&](gpu::LaunchFuncOp launch) {
      auto module = launch->getParentOfType<ModuleOp>();
      auto kernel = module.lookupSymbol<LLVM::LLVMFuncOp>(launch.kernel());

      if (!kernel || kernel.isExternal()) return;

      // Count the position of kernel operands independently, as they do not
      // coincide with launch operands: memref parameters get expanded when
      // lowered to llvm.
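      // When lowered to llvm, each ranked memref operand expands to
      // 2 * rank + 3 kernel arguments: the allocated pointer, the aligned
      // pointer, the offset, <rank> sizes and <rank> strides.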
      int kernel_p = 0;
      OpBuilder b = OpBuilder::atBlockBegin(&kernel.getBody().front());
      llvm::SmallDenseMap<int64_t, Value> constants;
      auto loc = kernel.getLoc();
      for (auto operand : launch.operands()) {
        auto memref = operand.getType().dyn_cast<MemRefType>();
        if (!memref) {
          // Scalar argument, advance kernel position by one.
          kernel_p++;
          continue;
        }
        if (allocated_by_tf_runtime.contains(operand)) {
          // This was allocated by the tf runtime, so the two pointers in the
          // descriptor coincide. Rewrite the kernel accordingly.
          Value alloc_ptr = kernel.getArgument(kernel_p);
          Value align_ptr = kernel.getArgument(kernel_p + 1);
          alloc_ptr.replaceAllUsesWith(align_ptr);
          kernel.setArgAttr(
              kernel_p + 1, LLVM::LLVMDialect::getAlignAttrName(),
              b.getIndexAttr(
                  tf_framework::TFFrameworkDialect::kAllocationAlignment));
        }
        if (offset_is_zero.contains(operand)) {
          Value offset = kernel.getArgument(kernel_p + 2);
          Value &zero = constants[0];
          if (!zero) {
            zero = b.create<LLVM::ConstantOp>(loc, offset.getType(),
                                              b.getIndexAttr(0));
          }
          offset.replaceAllUsesWith(zero);
        }
        auto const_stride = inner_stride_is_constant.find(operand);
        if (const_stride != inner_stride_is_constant.end()) {
          // The stride is the last argument belonging to this memref.
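          // In the expanded argument list it follows the offset, the <rank>
          // sizes and the preceding strides, i.e. it sits at index
          // kernel_p + 2 + 2 * rank.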
          Value inner_stride =
              kernel.getArgument(kernel_p + 2 + memref.getRank() * 2);
          Value &stride_val = constants[const_stride->second];
          if (!stride_val) {
            stride_val = b.create<LLVM::ConstantOp>(
                loc, inner_stride.getType(),
                b.getIndexAttr(const_stride->second));
          }
          inner_stride.replaceAllUsesWith(stride_val);
        }
        if (no_alias.contains(operand)) {
          // TODO(herhut): We also need to check whether any of the other args
          // are aliases. This is currently never the case by construction
          // but we could use the alias analysis from buffer placement here
          // to make sure.
          // Add the no_alias attribute to the corresponding pointer.
          kernel.setArgAttr(kernel_p + 1,
                            LLVM::LLVMDialect::getNoAliasAttrName(),
                            b.getUnitAttr());
        }
        // Advance base, aligned, offset, strides and sizes many arguments.
        kernel_p += memref.getRank() * 2 + 3;
      }
    });
  }

 private:
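  // Propagates the facts recorded for the values on the worklist to the
  // results of operations that preserve them (memref.cast, memref.reshape and
  // memref.reinterpret_cast).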
  void propagateThroughUses(SmallVectorImpl<Value> &worklist) {
    while (!worklist.empty()) {
      Value candidate = worklist.pop_back_val();
      for (auto user : candidate.getUsers()) {
        if (isa<memref::CastOp, memref::ReshapeOp>(user)) {
          // Reshape and Cast propagate alignment, offset and innermost stride.
          // TODO(herhut): This should be a trait.
          Value result = user->getResult(0);
          if (allocated_by_tf_runtime.contains(candidate)) {
            allocated_by_tf_runtime.insert(result);
          }
          auto const_stride = inner_stride_is_constant.find(candidate);
          if (const_stride != inner_stride_is_constant.end()) {
            inner_stride_is_constant.insert({result, const_stride->second});
          }
          if (offset_is_zero.contains(candidate)) {
            offset_is_zero.insert(result);
          }
          worklist.push_back(result);
        }
        if (auto cast = dyn_cast<memref::ReinterpretCastOp>(user)) {
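          // A reinterpret_cast defines its own offset and strides, so derive
          // the facts for its result from the op's static and dynamic operands
          // rather than forwarding them from the source memref.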
          // Check that we have offset 0.
          Value result = cast.getResult();
          if (!cast.isDynamicOffset(0) && cast.getStaticOffset(0) == 0) {
            offset_is_zero.insert(result);
          }
          if (allocated_by_tf_runtime.contains(candidate)) {
            allocated_by_tf_runtime.insert(result);
          }
          size_t last_stride = cast.getResultRank() - 1;
          // TODO(herhut): Remove this once canonicalization handles this.
          if (cast.isDynamicStride(last_stride)) {
            auto dyn_stride = cast.getDynamicStride(last_stride)
                                  .getDefiningOp<arith::ConstantIndexOp>();
            if (dyn_stride) {
              inner_stride_is_constant.insert({result, dyn_stride.value()});
            }
          } else {
            inner_stride_is_constant.insert(
                {result, cast.getStaticStride(last_stride)});
          }
          worklist.push_back(result);
        }
      }
    }
  }

  // Set of values that were allocated by the tf runtime and hence are aligned.
  llvm::SmallPtrSet<Value, 8> allocated_by_tf_runtime;
  // Set of values that are known to have an offset of 0.
  llvm::SmallPtrSet<Value, 8> offset_is_zero;
  // Mapping from values to their innermost stride, where it is known to be
  // constant.
  llvm::SmallDenseMap<Value, int64_t, 8> inner_stride_is_constant;
  // Set of values we know do not alias other values.
  llvm::SmallPtrSet<Value, 8> no_alias;
};

}  // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
CreatePropagateTfAbiKnowledgeToKernels() {
  return std::make_unique<PropagateTfAbiKnowledgeToKernelsPass>();
}

}  // namespace transforms
}  // namespace kernel_gen
}  // namespace mlir