1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef MLIR_HLO_ANALYSIS_SHAPE_COMPONENT_ANALYSIS_H 17 #define MLIR_HLO_ANALYSIS_SHAPE_COMPONENT_ANALYSIS_H 18 19 #include "llvm/Support/raw_ostream.h" 20 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" 21 #include "mlir/IR/AffineMap.h" 22 #include "mlir/IR/Value.h" 23 24 namespace mlir { 25 26 // Analysis to infer shape information. 27 // 28 // This lazily analyzes the individual components of a shape (e.g., the 29 // dimensions of a tensor) or value (e.g, the elements of a shape tensor). 30 // Results are cached but the cache is not consistent across IR mutations and 31 // needs to be reset in that case. 32 class ShapeComponentAnalysis { 33 public: 34 // Represents the analysis request for a specific value. We are either 35 // interested in the shape of a value or the value itself. 36 class ShapeOrValueInfo { 37 llvm::PointerIntPair<Value, 1, bool> p; 38 ShapeOrValueInfo(decltype (p)p)39 explicit ShapeOrValueInfo(decltype(p) p) : p(p) {} ShapeOrValueInfo(Value v,bool isValueInfo)40 ShapeOrValueInfo(Value v, bool isValueInfo) : p(v, isValueInfo) {} 41 42 public: getShapeInfoOf(Value v)43 static ShapeOrValueInfo getShapeInfoOf(Value v) { return {v, false}; } getValueInfoOf(Value v)44 static ShapeOrValueInfo getValueInfoOf(Value v) { return {v, true}; } value()45 Value value() const { return p.getPointer(); } isValueInfo()46 bool isValueInfo() const { return p.getInt(); } isShapeInfo()47 bool isShapeInfo() const { return !isValueInfo(); } 48 49 bool operator==(ShapeOrValueInfo rhs) const { return p == rhs.p; } 50 bool operator!=(ShapeOrValueInfo rhs) const { return !(*this == rhs); } 51 52 // Forward p's DenseMapInfo. 53 struct DenseMapInfo { 54 using PairInfo = llvm::DenseMapInfo<decltype(p)>; getEmptyKeyDenseMapInfo55 static inline ShapeOrValueInfo getEmptyKey() { 56 return ShapeOrValueInfo(PairInfo::getEmptyKey()); 57 } getTombstoneKeyDenseMapInfo58 static inline ShapeOrValueInfo getTombstoneKey() { 59 return ShapeOrValueInfo(PairInfo::getTombstoneKey()); 60 } getHashValueDenseMapInfo61 static unsigned getHashValue(ShapeOrValueInfo val) { 62 return PairInfo::getHashValue(val.p); 63 } isEqualDenseMapInfo64 static bool isEqual(ShapeOrValueInfo lhs, ShapeOrValueInfo rhs) { 65 return lhs == rhs; 66 } 67 }; 68 }; 69 70 // Symbolically represents one component of a shape (e.g., the dimensions of a 71 // tensor) or value (e.g, the elements of a shape tensor). This is used to tie 72 // symbolic expressions to components of shapes or values. 73 struct Symbol { 74 ShapeOrValueInfo source; 75 size_t index; 76 77 bool operator==(const Symbol &rhs) const { 78 return source == rhs.source && index == rhs.index; 79 } 80 bool operator!=(const Symbol &rhs) const { return !(*this == rhs); } 81 }; 82 83 // Represents the analysis result for a one component of a shape (e.g., the 84 // dimensions of a tensor) or value (e.g, the elements of a shape tensor). 85 // This can be a constant or an expression over symbols. 86 struct SymbolicExpr { 87 SmallVector<Symbol, 1> symbols; 88 AffineExpr expr; 89 90 // Returns true if this symbolic expression is known to be a constant equal 91 // to `value`. 92 bool isConstant(int64_t value) const; 93 // Returns true if this symbolic expression is known to be different from 94 // `-1`. This is useful for reshapes. 95 bool isKnownNotNegativeOne() const; 96 // Returns true if thus symbolic expression is known to be different from 97 // `1`. This is useful for broadcasts. 98 bool isKnownNotOne() const; 99 // If this is a reference to a singular symbol, return it. 100 Optional<Symbol> singleton() const; 101 102 bool operator==(const SymbolicExpr &rhs) const { 103 return expr == rhs.expr && symbols == rhs.symbols; 104 } 105 bool operator!=(const SymbolicExpr &rhs) const { return !(*this == rhs); } 106 107 void dump(llvm::raw_ostream &os = llvm::outs()) const; 108 }; 109 110 using SymbolicExprsMap = DenseMap<ShapeOrValueInfo, std::vector<SymbolicExpr>, 111 ShapeOrValueInfo::DenseMapInfo>; 112 using SymbolicShapeConstraintsMap = DenseMap<int, Symbol>; 113 114 private: 115 // Mapping from the analysis requests to the results, i.e. to an array of 116 // symbolic expressions. This is essentially a cache for all the results of 117 // this analysis. 118 SymbolicExprsMap symbolicExprsMap; 119 120 // Mapping from symbolic shape constraints, derived from the argument 121 // attributes, to the symbols used in this analysis. 122 SymbolicShapeConstraintsMap symbolicShapeConstraintsMap; 123 124 // Run the analysis to request either shape or value information. 125 void compute(ShapeOrValueInfo v); 126 127 public: 128 // Return the computed components for the shape of a value, e.g., the 129 // dimensions of a tensor. 130 Optional<ArrayRef<SymbolicExpr>> GetShapeInfo(Value value); 131 // Return the computed components for the value of a value, e.g, the elements 132 // of a shape tensor. 133 Optional<ArrayRef<SymbolicExpr>> GetValueInfo(Value shape); 134 135 // Clear analysis data structures. 136 void reset(); 137 }; 138 } // namespace mlir 139 140 namespace llvm { 141 142 template <> 143 struct DenseMapInfo<mlir::ShapeComponentAnalysis::Symbol> { 144 static inline mlir::ShapeComponentAnalysis::Symbol getEmptyKey() { 145 return {mlir::ShapeComponentAnalysis::ShapeOrValueInfo::DenseMapInfo:: 146 getEmptyKey(), 147 llvm::DenseMapInfo<size_t>::getEmptyKey()}; 148 } 149 static inline mlir::ShapeComponentAnalysis::Symbol getTombstoneKey() { 150 return {mlir::ShapeComponentAnalysis::ShapeOrValueInfo::DenseMapInfo:: 151 getTombstoneKey(), 152 llvm::DenseMapInfo<size_t>::getTombstoneKey()}; 153 } 154 static unsigned getHashValue(mlir::ShapeComponentAnalysis::Symbol symbol) { 155 return llvm::hash_combine( 156 mlir::ShapeComponentAnalysis::ShapeOrValueInfo::DenseMapInfo:: 157 getHashValue(symbol.source), 158 llvm::DenseMapInfo<size_t>::getHashValue(symbol.index)); 159 } 160 static bool isEqual(mlir::ShapeComponentAnalysis::Symbol lhs, 161 mlir::ShapeComponentAnalysis::Symbol rhs) { 162 return lhs == rhs; 163 } 164 }; 165 166 } // namespace llvm 167 168 #endif // MLIR_HLO_ANALYSIS_SHAPE_COMPONENT_ANALYSIS_H 169