/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file contains the analysis and transformation to rewrite kernel
// functions such that they use a single set of arguments for the strides and
// sizes of operands with equal shapes.

#include <memory>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/GPU/IR/GPUDialect.h"  // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
#include "mlir/IR/AsmState.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"

#define DEBUG_TYPE "kernel-gen-shapes"

namespace {

using mlir::ArrayRef;
using mlir::SmallVector;
using mlir::Value;

/// Represents a value or constant. Used to unify operands for operations that
/// take both ssa values and attributes.
50 struct ValueOrConst {
ValueOrConst__anon120b5d2c0111::ValueOrConst51 explicit ValueOrConst(Value v) : value_or_constant(v), is_constant(false) {}
ValueOrConst__anon120b5d2c0111::ValueOrConst52 explicit ValueOrConst(int64_t c) : value_or_constant(c), is_constant(true) {}
53
value__anon120b5d2c0111::ValueOrConst54 Value value() const {
55 assert(!is_constant);
56 return value_or_constant.value;
57 }
58
constant__anon120b5d2c0111::ValueOrConst59 int64_t constant() const {
60 assert(is_constant);
61 return value_or_constant.constant;
62 }
63
isConstant__anon120b5d2c0111::ValueOrConst64 bool isConstant() const { return is_constant; }
65
66 private:
67 union ValueOrConstStorage {
ValueOrConstStorage(Value v)68 explicit ValueOrConstStorage(Value v) : value(v) {}
ValueOrConstStorage(size_t c)69 explicit ValueOrConstStorage(size_t c) : constant(c) {}
70
71 Value value;
72 int64_t constant;
73 } value_or_constant;
74
75 bool is_constant;
76 };
hash_value(ValueOrConst value)78 llvm::hash_code hash_value(ValueOrConst value) {
79 return value.isConstant() ? static_cast<llvm::hash_code>(value.constant())
80 : mlir::hash_value(value.value());
81 }
operator ==(ValueOrConst lhs,ValueOrConst rhs)83 bool operator==(ValueOrConst lhs, ValueOrConst rhs) {
84 if (lhs.isConstant()) {
85 return rhs.isConstant() && lhs.constant() == rhs.constant();
86 } else {
87 return !rhs.isConstant() && lhs.value() == rhs.value();
88 }
89 }
operator <<(llvm::raw_ostream & os,const ValueOrConst & value)91 inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
92 const ValueOrConst &value) {
93 if (value.isConstant()) {
94 os << value.constant();
95 } else {
96 Value val = value.value();
97 mlir::AsmState asm_state(
98 val.getParentRegion()->getParentOfType<mlir::func::FuncOp>());
99 val.printAsOperand(os, asm_state);
100 }
101 return os;
102 }

/// Represents a shape, as either a single SSA value that represents the entire
/// shape vector or as a vector of SSA values representing scalars.
106 struct ShapeValue {
ShapeValue__anon120b5d2c0111::ShapeValue107 explicit ShapeValue(Value vector)
108 : shape({ValueOrConst{vector}}), is_vector(true) {}
ShapeValue__anon120b5d2c0111::ShapeValue109 explicit ShapeValue(ValueOrConst vector) : shape({vector}), is_vector(true) {
110 assert(!vector.isConstant());
111 }
112 template <typename T>
ShapeValue__anon120b5d2c0111::ShapeValue113 explicit ShapeValue(T values)
114 : shape(values.begin(), values.end()), is_vector(false) {}
115
vector__anon120b5d2c0111::ShapeValue116 ValueOrConst vector() const {
117 assert(is_vector);
118 return shape.front();
119 }
120
scalars__anon120b5d2c0111::ShapeValue121 ArrayRef<ValueOrConst> scalars() const {
122 assert(!is_vector);
123 return llvm::makeArrayRef(shape);
124 }
125
isVector__anon120b5d2c0111::ShapeValue126 bool isVector() const { return is_vector; }
127
128 private:
129 SmallVector<ValueOrConst, 4> shape;
130 bool is_vector;
131 };
hash_value(ShapeValue shape)133 llvm::hash_code hash_value(ShapeValue shape) {
134 return shape.isVector() ? hash_value(shape.vector())
135 : hash_value(shape.scalars());
136 }
operator ==(ShapeValue lhs,ShapeValue rhs)138 bool operator==(ShapeValue lhs, ShapeValue rhs) {
139 if (lhs.isVector()) {
140 return rhs.isVector() && lhs.vector() == rhs.vector();
141 } else {
142 return !rhs.isVector() && lhs.scalars() == rhs.scalars();
143 }
144 }
operator <<(llvm::raw_ostream & os,const ShapeValue & shape)146 inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
147 const ShapeValue &shape) {
148 if (shape.isVector()) {
149 os << shape.vector();
150 return os;
151 }
152 os << "[";
153 bool first = true;
154 for (auto scalar : shape.scalars()) {
155 if (!first) {
156 os << ", ";
157 }
158 first = false;
159 os << scalar;
160 }
161 os << "]";
162 return os;
163 }

} // namespace

namespace llvm {

169 template <>
170 struct DenseMapInfo<ShapeValue> {
getEmptyKeyllvm::DenseMapInfo171 static ShapeValue getEmptyKey() {
172 return ShapeValue(DenseMapInfo<mlir::Value>::getEmptyKey());
173 }
getTombstoneKeyllvm::DenseMapInfo174 static ShapeValue getTombstoneKey() {
175 return ShapeValue(DenseMapInfo<mlir::Value>::getTombstoneKey());
176 }
getHashValuellvm::DenseMapInfo177 static unsigned getHashValue(ShapeValue shape) { return hash_value(shape); }
isEqualllvm::DenseMapInfo178 static bool isEqual(ShapeValue LHS, ShapeValue RHS) { return LHS == RHS; }
179 };

} // namespace llvm

namespace mlir {
namespace kernel_gen {
namespace transforms {

namespace {

#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"

// A basic shape equality inference. This should be superseded by a proper
// inference once available. Until then, we just build this out to the needs of
// the kernel generator project.
195 class ShapeEqualityKnowledge {
196 public:
197 /// Checks all operations for potential shape equality of their respective
198 /// results.
build(func::FuncOp function)199 void build(func::FuncOp function) {
200 function.walk([&](Operation *op) {
201 if (auto reshape = dyn_cast<memref::ReshapeOp>(op)) {
202 registerAssociation(ShapeValue{reshape.getShape()},
203 reshape.getResult());
204 return;
205 }
206 if (auto cast = dyn_cast<memref::ReinterpretCastOp>(op)) {
207 // Only support fully dynamic sizes for now.
208 // TODO(herhut): Fix once the op has canonicalizers that break this.
209 for (unsigned int p = 0, e = cast.getResultRank(); p < e; ++p) {
210 if (!cast.isDynamicSize(p)) {
211 return;
212 }
213 }
214 registerAssociation(ShapeValue{cast.getSizes()}, cast.getResult());
215 return;
216 }
217 if (auto alloc = dyn_cast<memref::AllocOp>(op)) {
218 SmallVector<ValueOrConst, 4> shape;
219 ShapedType type = alloc.getResult().getType().cast<ShapedType>();
220 fillShapeFromAllocLike(alloc.getDynamicSizes(), type, shape);
221 registerAssociation(ShapeValue{shape}, alloc.getResult());
222 return;
223 }
224 if (auto alloc = dyn_cast<tf_framework::TFAllocOp>(op)) {
225 // Construct a symbol representing the allocated shape.
226 SmallVector<ValueOrConst, 4> shape;
227 ShapedType type = alloc.getResult().getType().cast<ShapedType>();
228 fillShapeFromAllocLike(alloc.dyn_sizes(), type, shape);
229 registerAssociation(ShapeValue{shape}, alloc.getResult());
230 return;
231 }
232 });
233 }
234
235 /// Checks whether `one` and `other` are known to have the same shape and
236 /// strides.
haveSameShape(Value one,Value other)237 bool haveSameShape(Value one, Value other) {
238 return equal_shapes_.isEquivalent(one.getAsOpaquePointer(),
239 other.getAsOpaquePointer());
240 }
241
242 private:
fillShapeFromAllocLike(mlir::OperandRange operands,ShapedType type,SmallVectorImpl<ValueOrConst> & shape)243 static void fillShapeFromAllocLike(mlir::OperandRange operands,
244 ShapedType type,
245 SmallVectorImpl<ValueOrConst> &shape) {
246 assert(type.hasRank());
247 auto dynamic_sizes = operands.begin();
248 for (auto extent : type.getShape()) {
249 shape.push_back(ShapedType::isDynamic(extent)
250 ? ValueOrConst{*(dynamic_sizes++)}
251 : ValueOrConst{extent});
252 }
253 }
254
255 /// Registers the value `value` to have the shape represented by `shape`. If
256 /// `shape` has been registered before, place `value` into the same
257 /// equivalence class. Otherwise register `value` as an equivalence class of
258 /// its own.
registerAssociation(ShapeValue shape,Value value)259 void registerAssociation(ShapeValue shape, Value value) {
260 LLVM_DEBUG({ llvm::dbgs() << "Processing " << value << "\n"; });
261 auto insert_symbolic = symbolic_shapes_.insert({shape, value});
262 if (insert_symbolic.second) {
263 LLVM_DEBUG({ llvm::dbgs() << "New symbolic shape " << shape << "\n"; });
264 equal_shapes_.insert(value.getAsOpaquePointer());
265 // We have seen this symbolic shape for the first time. Try to match it
266 // with a vector or shape we already know and alias classes if possible.
267 // This could be based on shape dialect if we weren't late in the
268 // lowering.
269 tryEvaluateShapeToRoot(shape, value);
270 } else {
271 auto rep = insert_symbolic.first->second;
272 LLVM_DEBUG({ llvm::dbgs() << "Aliasing with rep " << rep << "\n"; });
273 equal_shapes_.unionSets(rep.getAsOpaquePointer(),
274 value.getAsOpaquePointer());
275 }
276 }
277
278 /// Follows the definition chains of the ShapeValue `shape` to identify cases
279 /// where `shape` is derived from some other value's shape. In such case, the
280 /// equivalence classes of that other value and `value` are unioned.
281 /// This is based on pattern matching and not complete.
tryEvaluateShapeToRoot(ShapeValue shape,Value value)282 void tryEvaluateShapeToRoot(ShapeValue shape, Value value) {
283 // Just some pattern matching for common cases here.
284 if (!shape.isVector()) {
285 // Patterns that revolve around scalars.
286 // Check whether the scalars are all dim operations for some other memref.
287 Value candidate;
288 bool all_are_dimops =
289 llvm::all_of(llvm::enumerate(shape.scalars()), [&candidate](auto p) {
290 ValueOrConst val = p.value();
291 if (val.isConstant()) return false;
292 auto dimOp = val.value().getDefiningOp<memref::DimOp>();
293 if (!dimOp) return false;
294 if (!candidate) candidate = dimOp.getSource();
295 auto index = dimOp.getConstantIndex();
296 if (!index.has_value()) return false;
297 return candidate == dimOp.getSource() &&
298 p.index() == index.getValue();
299 });
300 if (all_are_dimops && candidate) {
301 equal_shapes_.unionSets(candidate.getAsOpaquePointer(),
302 value.getAsOpaquePointer());
303 }
304 }
305 }
306
307 // These are values with identical shapes (or rather their opaque pointers).
308 llvm::EquivalenceClasses<void *> equal_shapes_;
309 // A map from a value that encodes a shape to a value that has this shape.
310 llvm::DenseMap<ShapeValue, Value> symbolic_shapes_;
311 };

/// For arguments to kernels that have the same shape, use the stride and
/// shape information of the left-most argument inside of the kernel function.
/// That way, llvm can CSE index computations on same-shaped inputs.
316 struct PropagateShapeKnowledgeToKernels
317 : public PropagateShapeKnowledgeToKernelsBase<
318 PropagateShapeKnowledgeToKernels> {
runOnOperationmlir::kernel_gen::transforms::__anon120b5d2c0211::PropagateShapeKnowledgeToKernels319 void runOnOperation() override {
320 ShapeEqualityKnowledge knowledge;
321
322 knowledge.build(getOperation());
323
324 getOperation().walk([&](gpu::LaunchFuncOp launch) {
325 auto module = launch->getParentOfType<ModuleOp>();
326 auto kernel = module.lookupSymbol<LLVM::LLVMFuncOp>(launch.kernel());
327
328 if (!kernel || kernel.isExternal()) return;
329
330 llvm::SmallVector<std::pair<Value, int>, 4> seen_memrefs;
331 // Position of the kernel argument we are currently at.
332 int kernel_p = 0;
333 for (auto operand : launch.operands()) {
334 auto memref = operand.getType().dyn_cast<MemRefType>();
335 if (!memref) {
336 // Scalar argument, advance kernel position by one.
337 kernel_p++;
338 continue;
339 }
340 for (auto previous : seen_memrefs) {
341 if (!knowledge.haveSameShape(operand, previous.first)) {
342 continue;
343 }
344 auto previous_type = previous.first.getType().cast<MemRefType>();
345 // We use the first equality found and replace uses of corresponding
346 // size and (potentially) stride information here.
347 auto args_to_replace = memref.getRank();
348 // If both memrefs have identity layouts, we can also reuse the
349 // strides here, as they are the identity strides and hence fully
350 // determinded by the shape.
351 if (previous_type.getLayout().isIdentity() &&
352 memref.getLayout().isIdentity()) {
353 args_to_replace *= 2;
354 }
355 int previous_args_pos = previous.second;
356 auto previous_args = kernel.getArguments()
357 .drop_front(previous_args_pos + 3)
358 .take_front(args_to_replace);
359 auto current_args = kernel.getArguments()
360 .drop_front(kernel_p + 3)
361 .take_front(args_to_replace);
362 for (auto pair : llvm::zip(previous_args, current_args)) {
363 mlir::BlockArgument prev, curr;
364 std::tie(prev, curr) = pair;
365 curr.replaceAllUsesWith(prev);
366 }
367 break;
368 }
369 seen_memrefs.push_back({operand, kernel_p});
370 // Advance base, aligned, offset, strides and sizes many arguments.
371 kernel_p += memref.getRank() * 2 + 3;
372 }
373 });
374 }
375 };

} // namespace

379 std::unique_ptr<OperationPass<func::FuncOp>>
CreatePropagateShapeKnowledgeToKernels()380 CreatePropagateShapeKnowledgeToKernels() {
381 return std::make_unique<PropagateShapeKnowledgeToKernels>();
382 }

} // namespace transforms
} // namespace kernel_gen
} // namespace mlir