xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/hash_provider.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

#include <cstdint>
#include <cstring>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
9 
10 namespace torch::jit::tensorexpr {
11 
12 struct TORCH_API SimplifierHashType {
13   SimplifierHashType() = default;
SimplifierHashTypeSimplifierHashType14   explicit SimplifierHashType(size_t s) : _h(s) {}
15 
16   bool operator==(const SimplifierHashType& other) const;
17   bool operator!=(const SimplifierHashType& other) const;
18   bool operator<(const SimplifierHashType& other) const;
19   bool operator==(const size_t other) const;
20   bool operator!=(const size_t other) const;
21 
22   size_t _h{0};
23 };
24 
25 } // namespace torch::jit::tensorexpr
26 
27 namespace std {
28 template <>
29 struct hash<torch::jit::tensorexpr::SimplifierHashType> {
30   size_t operator()(const torch::jit::tensorexpr::SimplifierHashType& k) const {
31     return k._h;
32   }
33 };
34 
35 } // namespace std
36 
37 namespace torch::jit::tensorexpr {
38 
// Early-exit guard for visit() overloads: returns from the enclosing void
// function when `cachedHash(v)` reports that a hash is already memoized.
// The use site must have a variable named `v` and a callable `cachedHash`
// in scope.
// Wrapped in do/while(0) so `CACHE_GUARD();` expands to exactly one
// statement, which keeps it safe inside unbraced if/else branches.
#define CACHE_GUARD()    \
  do {                   \
    if (cachedHash(v)) { \
      return;            \
    }                    \
  } while (0)
43 
// Forward declarations only; the full definitions are not in this header.
class Polynomial;
class Term;
46 
47 /* Expression hasher providing comparable values representing sub-exprs.
48  * Uses memoization to avoid excessive recursion. */
49 class TORCH_API HashProvider : public IRVisitor {
50  public:
51   template <class T>
52   SimplifierHashType hash(T e) {
53     e->accept(this);
54     return hashOf(e);
55   }
56 
57   bool cachedHash(const ExprPtr& e) {
58     return exprToHash_.find(e) != exprToHash_.end();
59   }
60   bool cachedHash(const StmtPtr& s) {
61     return stmtToHash_.find(s) != stmtToHash_.end();
62   }
63 
64   void clearCache() {
65     exprToHash_.clear();
66     stmtToHash_.clear();
67   }
68 
69   void visit(const AddPtr& v) override;
70   void visit(const SubPtr& v) override;
71   void visit(const MulPtr& v) override;
72   void visit(const DivPtr& v) override;
73   void visit(const ModPtr& v) override;
74   void visit(const RoundOffPtr& v) override;
75   void visit(const MaxPtr& v) override;
76   void visit(const MinPtr& v) override;
77   void visit(const AndPtr& v) override;
78   void visit(const OrPtr& v) override;
79   void visit(const XorPtr& v) override;
80   void visit(const LshiftPtr& v) override;
81   void visit(const RshiftPtr& v) override;
82   void visit(const CompareSelectPtr& v) override;
83 
84 #define IMM_VISIT(Type, Name)                    \
85   void visit(const Name##ImmPtr& v) override {   \
86     CACHE_GUARD();                               \
87     putHash(v, hash_combine(#Name, v->value())); \
88   }
89   AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT);
90 #undef IMM_VISIT
91 
92   void visit(const CastPtr& v) override;
93   void visit(const VarPtr& v) override;
94   void visit(const RampPtr& v) override;
95   void visit(const LoadPtr& v) override;
96   void visit(const StorePtr& v) override;
97   void visit(const BlockPtr& v) override;
98   void visit(const ForPtr& v) override;
99   void visit(const BroadcastPtr& v) override;
100   void visit(const IfThenElsePtr& v) override;
101   void visit(const IntrinsicsPtr& v) override;
102   void visit(const AllocatePtr& v) override;
103   void visit(const FreePtr& v) override;
104   void visit(const CondPtr& v) override;
105   void visit(const TermPtr& v) override;
106   void visit(const PolynomialPtr& v) override;
107   void visit(const MaxTermPtr& v) override;
108   void visit(const MinTermPtr& v) override;
109 
110   template <typename... Types>
111   SimplifierHashType hash_combine(const Types&... args) {
112     SimplifierHashType seed;
113     _hash_combine(seed, args...);
114     return seed;
115   }
116 
117  private:
118   SimplifierHashType hashOf(const ExprPtr& e) {
119     auto it = exprToHash_.find(e);
120     if (it != exprToHash_.end()) {
121       return it->second;
122     }
123 
124     // As a failsafe fall back to IRPrinter.
125     std::stringstream ss;
126     IRPrinter printer(ss);
127     e->accept(&printer);
128     SimplifierHashType hash = SimplifierHashType(te_hash(ss.str()));
129     putHash(e, hash);
130 
131     return hash;
132   }
133 
134   SimplifierHashType hashOf(const StmtPtr& s) {
135     auto it = stmtToHash_.find(s);
136     if (it != stmtToHash_.end()) {
137       return it->second;
138     }
139 
140     // As a failsafe fall back to IRPrinter.
141     std::stringstream ss;
142     IRPrinter printer(ss);
143     s->accept(&printer);
144     SimplifierHashType hash = SimplifierHashType(te_hash(ss.str()));
145     putHash(s, hash);
146 
147     return hash;
148   }
149 
150   // Hash funcs for various types, numbers are random.
151   template <typename T>
152   void _hash_combine(SimplifierHashType& seed, const T& val) {
153     seed._h ^= te_hash(val) + 0x1f752c19 + (seed._h << 7) + (seed._h >> 4);
154   }
155 
156   void _hash_combine(SimplifierHashType& seed, const char* val) {
157     seed._h ^= te_hash(val) + 0x1f752c19 + (seed._h << 7) + (seed._h >> 4);
158   }
159 
160   // at:::Half doesn't have a prime_number_hash, so cast to short.
161   void _hash_combine(SimplifierHashType& seed, const at::Half& val) {
162     seed._h ^=
163         te_hash((uint16_t)val) + 0x1f752c19 + (seed._h << 7) + (seed._h >> 4);
164   }
165 
166   void _hash_combine(SimplifierHashType& seed, const Dtype& val) {
167     seed._h ^= te_hash(val.ToCppString()) + 0x1f752c19 + (seed._h << 7) +
168         (seed._h >> 4);
169   }
170 
171   void _hash_combine(SimplifierHashType& seed, ExprPtr e) {
172     _hash_combine(seed, hash(std::move(e)));
173   }
174 
175   template <typename T, typename... Types>
176   void _hash_combine(
177       SimplifierHashType& seed,
178       const T& val,
179       const Types&... args) {
180     _hash_combine(seed, val);
181     _hash_combine(seed, args...);
182   }
183 
184   void putHash(const ExprPtr& e, SimplifierHashType h) {
185     auto res = exprToHash_.emplace(e, h);
186     if (res.second == false) {
187       // This is always a logic bug since we should check the cache first.
188       throw std::runtime_error("hash collision");
189     }
190   }
191   void putHash(const StmtPtr& s, SimplifierHashType h) {
192     auto res = stmtToHash_.emplace(s, h);
193     if (res.second == false) {
194       // This is always a logic bug since we should check the cache first.
195       throw std::runtime_error("hash collision");
196     }
197   }
198 
199   std::unordered_map<ExprPtr, SimplifierHashType> exprToHash_;
200   std::unordered_map<StmtPtr, SimplifierHashType> stmtToHash_;
201   UniqueNameManager name_manager_;
202 
203   size_t te_hash(SimplifierHashType val) {
204     return val._h;
205   }
206 
207   size_t te_hash(int64_t val) {
208     // put the thing down.
209     size_t h = val ^ 0x647AA4D20C0B;
210     // bit flip it.
211     size_t h2 = ~h;
212     // and reverse byte order.
213     size_t h3 = 0;
214     for (unsigned int i = 0; i < 64; i += 8) {
215       h3 |= ((h2 >> i) & 0xFF) << (64 - i - 8);
216     }
217     return h3;
218   }
219 
220   size_t te_hash(int32_t val) {
221     int64_t v2 = val;
222     return te_hash(v2);
223   }
224 
225   size_t te_hash(uint32_t val) {
226     int64_t v2 = val;
227     return te_hash(v2);
228   }
229 
230   size_t te_hash(uint64_t val) {
231     int64_t v2 = val;
232     return te_hash(v2);
233   }
234 
235   size_t te_hash(int16_t val) {
236     int64_t v2 = val;
237     return te_hash(v2);
238   }
239 
240   size_t te_hash(std::string val) {
241     size_t hash{0};
242     int64_t intval{0};
243     int64_t s = val.size() - 1;
244     while (s >= 0) {
245       for (unsigned int i = 0; i < 8; ++i) {
246         if (s < 0)
247           break;
248         int64_t c = val[s];
249         intval |= (c << (i * 8));
250 
251         s--;
252       }
253       hash ^= te_hash(intval);
254       intval = 0;
255     }
256 
257     return hash;
258   }
259 
260   size_t te_hash(double d) {
261     int64_t* n = reinterpret_cast<int64_t*>(&d);
262     return te_hash(*n);
263   }
264 
265   size_t te_hash(float d) {
266     int32_t* n = reinterpret_cast<int32_t*>(&d);
267     return te_hash(*n);
268   }
269 
270   size_t te_hash(at::Half d) {
271     int16_t* n = reinterpret_cast<int16_t*>(&d);
272     return te_hash(*n);
273   }
274 
275   size_t te_hash(at::BFloat16 d) {
276     int16_t* n = reinterpret_cast<int16_t*>(&d);
277     return te_hash(*n);
278   }
279 };
280 
281 } // namespace torch::jit::tensorexpr
282