1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file contains the analysis and transformation to rewrite kernel
17 // functions such that they use a single set of arguments for the strides and
18 // sizes of operands with equal shapes.
19 
20 #include <memory>
21 
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/DenseMapInfo.h"
24 #include "llvm/ADT/EquivalenceClasses.h"
25 #include "llvm/ADT/Hashing.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/Support/Debug.h"
28 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
29 #include "mlir/Dialect/GPU/IR/GPUDialect.h"  // from @llvm-project
30 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
31 #include "mlir/IR/AsmState.h"  // from @llvm-project
32 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
33 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
34 #include "mlir/IR/Value.h"  // from @llvm-project
35 #include "mlir/Support/LLVM.h"  // from @llvm-project
36 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
37 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
38 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
39 
40 #define DEBUG_TYPE "kernel-gen-shapes"
41 
42 namespace {
43 
44 using mlir::ArrayRef;
45 using mlir::SmallVector;
46 using mlir::Value;
47 
48 /// Represents a value or constant. Used to unify operands for operations that
49 /// take both ssa values and attributes.
50 struct ValueOrConst {
ValueOrConst__anon120b5d2c0111::ValueOrConst51   explicit ValueOrConst(Value v) : value_or_constant(v), is_constant(false) {}
ValueOrConst__anon120b5d2c0111::ValueOrConst52   explicit ValueOrConst(int64_t c) : value_or_constant(c), is_constant(true) {}
53 
value__anon120b5d2c0111::ValueOrConst54   Value value() const {
55     assert(!is_constant);
56     return value_or_constant.value;
57   }
58 
constant__anon120b5d2c0111::ValueOrConst59   int64_t constant() const {
60     assert(is_constant);
61     return value_or_constant.constant;
62   }
63 
isConstant__anon120b5d2c0111::ValueOrConst64   bool isConstant() const { return is_constant; }
65 
66  private:
67   union ValueOrConstStorage {
ValueOrConstStorage(Value v)68     explicit ValueOrConstStorage(Value v) : value(v) {}
ValueOrConstStorage(size_t c)69     explicit ValueOrConstStorage(size_t c) : constant(c) {}
70 
71     Value value;
72     int64_t constant;
73   } value_or_constant;
74 
75   bool is_constant;
76 };
77 
hash_value(ValueOrConst value)78 llvm::hash_code hash_value(ValueOrConst value) {
79   return value.isConstant() ? static_cast<llvm::hash_code>(value.constant())
80                             : mlir::hash_value(value.value());
81 }
82 
operator ==(ValueOrConst lhs,ValueOrConst rhs)83 bool operator==(ValueOrConst lhs, ValueOrConst rhs) {
84   if (lhs.isConstant()) {
85     return rhs.isConstant() && lhs.constant() == rhs.constant();
86   } else {
87     return !rhs.isConstant() && lhs.value() == rhs.value();
88   }
89 }
90 
operator <<(llvm::raw_ostream & os,const ValueOrConst & value)91 inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
92                                      const ValueOrConst &value) {
93   if (value.isConstant()) {
94     os << value.constant();
95   } else {
96     Value val = value.value();
97     mlir::AsmState asm_state(
98         val.getParentRegion()->getParentOfType<mlir::func::FuncOp>());
99     val.printAsOperand(os, asm_state);
100   }
101   return os;
102 }
103 
104 /// Represents a shape, as either a single SSA value that represents the entire
105 /// shape vector or as a vector of SSA values representing scalars.
106 struct ShapeValue {
ShapeValue__anon120b5d2c0111::ShapeValue107   explicit ShapeValue(Value vector)
108       : shape({ValueOrConst{vector}}), is_vector(true) {}
ShapeValue__anon120b5d2c0111::ShapeValue109   explicit ShapeValue(ValueOrConst vector) : shape({vector}), is_vector(true) {
110     assert(!vector.isConstant());
111   }
112   template <typename T>
ShapeValue__anon120b5d2c0111::ShapeValue113   explicit ShapeValue(T values)
114       : shape(values.begin(), values.end()), is_vector(false) {}
115 
vector__anon120b5d2c0111::ShapeValue116   ValueOrConst vector() const {
117     assert(is_vector);
118     return shape.front();
119   }
120 
scalars__anon120b5d2c0111::ShapeValue121   ArrayRef<ValueOrConst> scalars() const {
122     assert(!is_vector);
123     return llvm::makeArrayRef(shape);
124   }
125 
isVector__anon120b5d2c0111::ShapeValue126   bool isVector() const { return is_vector; }
127 
128  private:
129   SmallVector<ValueOrConst, 4> shape;
130   bool is_vector;
131 };
132 
hash_value(ShapeValue shape)133 llvm::hash_code hash_value(ShapeValue shape) {
134   return shape.isVector() ? hash_value(shape.vector())
135                           : hash_value(shape.scalars());
136 }
137 
operator ==(ShapeValue lhs,ShapeValue rhs)138 bool operator==(ShapeValue lhs, ShapeValue rhs) {
139   if (lhs.isVector()) {
140     return rhs.isVector() && lhs.vector() == rhs.vector();
141   } else {
142     return !rhs.isVector() && lhs.scalars() == rhs.scalars();
143   }
144 }
145 
operator <<(llvm::raw_ostream & os,const ShapeValue & shape)146 inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
147                                      const ShapeValue &shape) {
148   if (shape.isVector()) {
149     os << shape.vector();
150     return os;
151   }
152   os << "[";
153   bool first = true;
154   for (auto scalar : shape.scalars()) {
155     if (!first) {
156       os << ", ";
157     }
158     first = false;
159     os << scalar;
160   }
161   os << "]";
162   return os;
163 }
164 
165 }  // namespace
166 
167 namespace llvm {
168 
// Makes ShapeValue usable as a llvm::DenseMap key. The empty/tombstone
// sentinels are derived from mlir::Value's sentinels (ShapeValue's Value
// constructor wraps them), and hashing/equality reuse the free functions
// defined above.
template <>
struct DenseMapInfo<ShapeValue> {
  static ShapeValue getEmptyKey() {
    return ShapeValue(DenseMapInfo<mlir::Value>::getEmptyKey());
  }
  static ShapeValue getTombstoneKey() {
    return ShapeValue(DenseMapInfo<mlir::Value>::getTombstoneKey());
  }
  static unsigned getHashValue(ShapeValue shape) { return hash_value(shape); }
  static bool isEqual(ShapeValue LHS, ShapeValue RHS) { return LHS == RHS; }
};
180 
181 }  // namespace llvm
182 
183 namespace mlir {
184 namespace kernel_gen {
185 namespace transforms {
186 
187 namespace {
188 
189 #define GEN_PASS_CLASSES
190 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
191 
// A basic shape equality inference. This should be superseded by a proper
// inference once available. Until then, we just build this out to the needs of
// the kernel generator project.
//
// `build` walks a function once, derives a symbolic ShapeValue for each
// shape-producing op it understands, and groups values with matching symbolic
// shapes into equivalence classes queryable via `haveSameShape`.
class ShapeEqualityKnowledge {
 public:
  /// Checks all operations for potential shape equality of their respective
  /// results.
  void build(func::FuncOp function) {
    function.walk([&](Operation *op) {
      // memref.reshape: the result's shape is literally the `shape` operand.
      if (auto reshape = dyn_cast<memref::ReshapeOp>(op)) {
        registerAssociation(ShapeValue{reshape.getShape()},
                            reshape.getResult());
        return;
      }
      // memref.reinterpret_cast: the result's sizes are the size operands.
      if (auto cast = dyn_cast<memref::ReinterpretCastOp>(op)) {
        // Only support fully dynamic sizes for now.
        // TODO(herhut): Fix once the op has canonicalizers that break this.
        for (unsigned int p = 0, e = cast.getResultRank(); p < e; ++p) {
          if (!cast.isDynamicSize(p)) {
            return;
          }
        }
        registerAssociation(ShapeValue{cast.getSizes()}, cast.getResult());
        return;
      }
      // memref.alloc: combine static extents from the result type with the
      // dynamic size operands into a per-dimension shape.
      if (auto alloc = dyn_cast<memref::AllocOp>(op)) {
        SmallVector<ValueOrConst, 4> shape;
        ShapedType type = alloc.getResult().getType().cast<ShapedType>();
        fillShapeFromAllocLike(alloc.getDynamicSizes(), type, shape);
        registerAssociation(ShapeValue{shape}, alloc.getResult());
        return;
      }
      // tf_framework.alloc: same treatment as memref.alloc.
      if (auto alloc = dyn_cast<tf_framework::TFAllocOp>(op)) {
        // Construct a symbol representing the allocated shape.
        SmallVector<ValueOrConst, 4> shape;
        ShapedType type = alloc.getResult().getType().cast<ShapedType>();
        fillShapeFromAllocLike(alloc.dyn_sizes(), type, shape);
        registerAssociation(ShapeValue{shape}, alloc.getResult());
        return;
      }
    });
  }

  /// Checks whether `one` and `other` are known to have the same shape and
  /// strides.
  bool haveSameShape(Value one, Value other) {
    return equal_shapes_.isEquivalent(one.getAsOpaquePointer(),
                                      other.getAsOpaquePointer());
  }

 private:
  /// Appends one ValueOrConst per dimension of `type` to `shape`: static
  /// extents become constants, while dynamic extents consume the next entry
  /// of `operands` (which holds exactly the dynamic sizes, in order).
  static void fillShapeFromAllocLike(mlir::OperandRange operands,
                                     ShapedType type,
                                     SmallVectorImpl<ValueOrConst> &shape) {
    assert(type.hasRank());
    auto dynamic_sizes = operands.begin();
    for (auto extent : type.getShape()) {
      shape.push_back(ShapedType::isDynamic(extent)
                          ? ValueOrConst{*(dynamic_sizes++)}
                          : ValueOrConst{extent});
    }
  }

  /// Registers the value `value` to have the shape represented by `shape`. If
  /// `shape` has been registered before, place `value` into the same
  /// equivalence class. Otherwise register `value` as an equivalence class of
  /// its own.
  void registerAssociation(ShapeValue shape, Value value) {
    LLVM_DEBUG({ llvm::dbgs() << "Processing " << value << "\n"; });
    auto insert_symbolic = symbolic_shapes_.insert({shape, value});
    if (insert_symbolic.second) {
      LLVM_DEBUG({ llvm::dbgs() << "New symbolic shape " << shape << "\n"; });
      equal_shapes_.insert(value.getAsOpaquePointer());
      // We have seen this symbolic shape for the first time. Try to match it
      // with a vector or shape we already know and alias classes if possible.
      // This could be based on shape dialect if we weren't late in the
      // lowering.
      tryEvaluateShapeToRoot(shape, value);
    } else {
      // Shape seen before: alias `value` with the value that was registered
      // first for this shape (the class representative).
      auto rep = insert_symbolic.first->second;
      LLVM_DEBUG({ llvm::dbgs() << "Aliasing with rep " << rep << "\n"; });
      equal_shapes_.unionSets(rep.getAsOpaquePointer(),
                              value.getAsOpaquePointer());
    }
  }

  /// Follows the definition chains of the ShapeValue `shape` to identify cases
  /// where `shape` is derived from some other value's shape. In such case, the
  /// equivalence classes of that other value and `value` are unioned.
  /// This is based on pattern matching and not complete.
  void tryEvaluateShapeToRoot(ShapeValue shape, Value value) {
    // Just some pattern matching for common cases here.
    if (!shape.isVector()) {
      // Patterns that revolve around scalars.
      // Check whether the scalars are all dim operations for some other memref.
      Value candidate;
      bool all_are_dimops =
          llvm::all_of(llvm::enumerate(shape.scalars()), [&candidate](auto p) {
            ValueOrConst val = p.value();
            if (val.isConstant()) return false;
            auto dimOp = val.value().getDefiningOp<memref::DimOp>();
            if (!dimOp) return false;
            // All dims must come from the same source memref, and entry i of
            // the shape must be the constant dimension i of that source.
            if (!candidate) candidate = dimOp.getSource();
            auto index = dimOp.getConstantIndex();
            if (!index.has_value()) return false;
            return candidate == dimOp.getSource() &&
                   p.index() == index.getValue();
          });
      if (all_are_dimops && candidate) {
        equal_shapes_.unionSets(candidate.getAsOpaquePointer(),
                                value.getAsOpaquePointer());
      }
    }
  }

  // These are values with identical shapes (or rather their opaque pointers).
  llvm::EquivalenceClasses<void *> equal_shapes_;
  // A map from a value that encodes a shape to a value that has this shape.
  llvm::DenseMap<ShapeValue, Value> symbolic_shapes_;
};
312 
313 /// For arguments to kernels that have the same shape, use the stride and
314 /// shape information of the left-most argument inside of the kernel function.
315 /// That way, llvm can CSE index computations on same-shaped inputs.
struct PropagateShapeKnowledgeToKernels
    : public PropagateShapeKnowledgeToKernelsBase<
          PropagateShapeKnowledgeToKernels> {
  void runOnOperation() override {
    // First infer which values in this function are known to share a shape.
    ShapeEqualityKnowledge knowledge;

    knowledge.build(getOperation());

    getOperation().walk([&](gpu::LaunchFuncOp launch) {
      auto module = launch->getParentOfType<ModuleOp>();
      auto kernel = module.lookupSymbol<LLVM::LLVMFuncOp>(launch.kernel());

      if (!kernel || kernel.isExternal()) return;

      // Memref operands seen so far, paired with the index of their first
      // kernel argument.
      llvm::SmallVector<std::pair<Value, int>, 4> seen_memrefs;
      // Position of the kernel argument we are currently at.
      int kernel_p = 0;
      for (auto operand : launch.operands()) {
        auto memref = operand.getType().dyn_cast<MemRefType>();
        if (!memref) {
          // Scalar argument, advance kernel position by one.
          kernel_p++;
          continue;
        }
        for (auto previous : seen_memrefs) {
          if (!knowledge.haveSameShape(operand, previous.first)) {
            continue;
          }
          auto previous_type = previous.first.getType().cast<MemRefType>();
          // We use the first equality found and replace uses of corresponding
          // size and (potentially) stride information here.
          auto args_to_replace = memref.getRank();
          // If both memrefs have identity layouts, we can also reuse the
          // strides here, as they are the identity strides and hence fully
          // determined by the shape.
          if (previous_type.getLayout().isIdentity() &&
              memref.getLayout().isIdentity()) {
            args_to_replace *= 2;
          }
          int previous_args_pos = previous.second;
          // A lowered memref occupies base, aligned and offset (3 arguments)
          // followed by its sizes and strides (see the kernel_p advance
          // below); skip the first 3 to land on the size arguments.
          auto previous_args = kernel.getArguments()
                                   .drop_front(previous_args_pos + 3)
                                   .take_front(args_to_replace);
          auto current_args = kernel.getArguments()
                                  .drop_front(kernel_p + 3)
                                  .take_front(args_to_replace);
          // Rewire the current memref's size/stride arguments to reuse the
          // earlier, equal-shaped memref's arguments so LLVM can CSE index
          // computations.
          for (auto pair : llvm::zip(previous_args, current_args)) {
            mlir::BlockArgument prev, curr;
            std::tie(prev, curr) = pair;
            curr.replaceAllUsesWith(prev);
          }
          break;
        }
        seen_memrefs.push_back({operand, kernel_p});
        // Advance base, aligned, offset, strides and sizes many arguments.
        kernel_p += memref.getRank() * 2 + 3;
      }
    });
  }
};
376 
377 }  // namespace
378 
379 std::unique_ptr<OperationPass<func::FuncOp>>
CreatePropagateShapeKnowledgeToKernels()380 CreatePropagateShapeKnowledgeToKernels() {
381   return std::make_unique<PropagateShapeKnowledgeToKernels>();
382 }
383 
384 }  // namespace transforms
385 }  // namespace kernel_gen
386 }  // namespace mlir
387