1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef MLIR_HLO_ANALYSIS_SHAPE_COMPONENT_ANALYSIS_H
17 #define MLIR_HLO_ANALYSIS_SHAPE_COMPONENT_ANALYSIS_H
18 
19 #include "llvm/Support/raw_ostream.h"
20 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
21 #include "mlir/IR/AffineMap.h"
22 #include "mlir/IR/Value.h"
23 
24 namespace mlir {
25 
26 // Analysis to infer shape information.
27 //
28 // This lazily analyzes the individual components of a shape (e.g., the
29 // dimensions of a tensor) or value (e.g, the elements of a shape tensor).
30 // Results are cached but the cache is not consistent across IR mutations and
31 // needs to be reset in that case.
32 class ShapeComponentAnalysis {
33  public:
34   // Represents the analysis request for a specific value. We are either
35   // interested in the shape of a value or the value itself.
36   class ShapeOrValueInfo {
37     llvm::PointerIntPair<Value, 1, bool> p;
38 
ShapeOrValueInfo(decltype (p)p)39     explicit ShapeOrValueInfo(decltype(p) p) : p(p) {}
ShapeOrValueInfo(Value v,bool isValueInfo)40     ShapeOrValueInfo(Value v, bool isValueInfo) : p(v, isValueInfo) {}
41 
42    public:
getShapeInfoOf(Value v)43     static ShapeOrValueInfo getShapeInfoOf(Value v) { return {v, false}; }
getValueInfoOf(Value v)44     static ShapeOrValueInfo getValueInfoOf(Value v) { return {v, true}; }
value()45     Value value() const { return p.getPointer(); }
isValueInfo()46     bool isValueInfo() const { return p.getInt(); }
isShapeInfo()47     bool isShapeInfo() const { return !isValueInfo(); }
48 
49     bool operator==(ShapeOrValueInfo rhs) const { return p == rhs.p; }
50     bool operator!=(ShapeOrValueInfo rhs) const { return !(*this == rhs); }
51 
52     // Forward p's DenseMapInfo.
53     struct DenseMapInfo {
54       using PairInfo = llvm::DenseMapInfo<decltype(p)>;
getEmptyKeyDenseMapInfo55       static inline ShapeOrValueInfo getEmptyKey() {
56         return ShapeOrValueInfo(PairInfo::getEmptyKey());
57       }
getTombstoneKeyDenseMapInfo58       static inline ShapeOrValueInfo getTombstoneKey() {
59         return ShapeOrValueInfo(PairInfo::getTombstoneKey());
60       }
getHashValueDenseMapInfo61       static unsigned getHashValue(ShapeOrValueInfo val) {
62         return PairInfo::getHashValue(val.p);
63       }
isEqualDenseMapInfo64       static bool isEqual(ShapeOrValueInfo lhs, ShapeOrValueInfo rhs) {
65         return lhs == rhs;
66       }
67     };
68   };
69 
70   // Symbolically represents one component of a shape (e.g., the dimensions of a
71   // tensor) or value (e.g, the elements of a shape tensor). This is used to tie
72   // symbolic expressions to components of shapes or values.
73   struct Symbol {
74     ShapeOrValueInfo source;
75     size_t index;
76 
77     bool operator==(const Symbol &rhs) const {
78       return source == rhs.source && index == rhs.index;
79     }
80     bool operator!=(const Symbol &rhs) const { return !(*this == rhs); }
81   };
82 
83   // Represents the analysis result for a one component of a shape (e.g., the
84   // dimensions of a tensor) or value (e.g, the elements of a shape tensor).
85   // This can be a constant or an expression over symbols.
86   struct SymbolicExpr {
87     SmallVector<Symbol, 1> symbols;
88     AffineExpr expr;
89 
90     // Returns true if this symbolic expression is known to be a constant equal
91     // to `value`.
92     bool isConstant(int64_t value) const;
93     // Returns true if this symbolic expression is known to be different from
94     // `-1`. This is useful for reshapes.
95     bool isKnownNotNegativeOne() const;
96     // Returns true if thus symbolic expression is known to be different from
97     // `1`. This is useful for broadcasts.
98     bool isKnownNotOne() const;
99     // If this is a reference to a singular symbol, return it.
100     Optional<Symbol> singleton() const;
101 
102     bool operator==(const SymbolicExpr &rhs) const {
103       return expr == rhs.expr && symbols == rhs.symbols;
104     }
105     bool operator!=(const SymbolicExpr &rhs) const { return !(*this == rhs); }
106 
107     void dump(llvm::raw_ostream &os = llvm::outs()) const;
108   };
109 
110   using SymbolicExprsMap = DenseMap<ShapeOrValueInfo, std::vector<SymbolicExpr>,
111                                     ShapeOrValueInfo::DenseMapInfo>;
112   using SymbolicShapeConstraintsMap = DenseMap<int, Symbol>;
113 
114  private:
115   // Mapping from the analysis requests to the results, i.e. to an array of
116   // symbolic expressions. This is essentially a cache for all the results of
117   // this analysis.
118   SymbolicExprsMap symbolicExprsMap;
119 
120   // Mapping from symbolic shape constraints, derived from the argument
121   // attributes, to the symbols used in this analysis.
122   SymbolicShapeConstraintsMap symbolicShapeConstraintsMap;
123 
124   // Run the analysis to request either shape or value information.
125   void compute(ShapeOrValueInfo v);
126 
127  public:
128   // Return the computed components for the shape of a value, e.g., the
129   // dimensions of a tensor.
130   Optional<ArrayRef<SymbolicExpr>> GetShapeInfo(Value value);
131   // Return the computed components for the value of a value, e.g, the elements
132   // of a shape tensor.
133   Optional<ArrayRef<SymbolicExpr>> GetValueInfo(Value shape);
134 
135   // Clear analysis data structures.
136   void reset();
137 };
138 }  // namespace mlir
139 
140 namespace llvm {
141 
142 template <>
143 struct DenseMapInfo<mlir::ShapeComponentAnalysis::Symbol> {
144   static inline mlir::ShapeComponentAnalysis::Symbol getEmptyKey() {
145     return {mlir::ShapeComponentAnalysis::ShapeOrValueInfo::DenseMapInfo::
146                 getEmptyKey(),
147             llvm::DenseMapInfo<size_t>::getEmptyKey()};
148   }
149   static inline mlir::ShapeComponentAnalysis::Symbol getTombstoneKey() {
150     return {mlir::ShapeComponentAnalysis::ShapeOrValueInfo::DenseMapInfo::
151                 getTombstoneKey(),
152             llvm::DenseMapInfo<size_t>::getTombstoneKey()};
153   }
154   static unsigned getHashValue(mlir::ShapeComponentAnalysis::Symbol symbol) {
155     return llvm::hash_combine(
156         mlir::ShapeComponentAnalysis::ShapeOrValueInfo::DenseMapInfo::
157             getHashValue(symbol.source),
158         llvm::DenseMapInfo<size_t>::getHashValue(symbol.index));
159   }
160   static bool isEqual(mlir::ShapeComponentAnalysis::Symbol lhs,
161                       mlir::ShapeComponentAnalysis::Symbol rhs) {
162     return lhs == rhs;
163   }
164 };
165 
166 }  // namespace llvm
167 
168 #endif  // MLIR_HLO_ANALYSIS_SHAPE_COMPONENT_ANALYSIS_H
169