1 #pragma once 2 3 #include <torch/csrc/jit/tensorexpr/ir.h> 4 #include <torch/csrc/jit/tensorexpr/ir_printer.h> 5 #include <torch/csrc/jit/tensorexpr/ir_visitor.h> 6 #include <torch/csrc/jit/tensorexpr/tensor.h> 7 8 #include <utility> 9 10 namespace torch::jit::tensorexpr { 11 12 struct TORCH_API SimplifierHashType { 13 SimplifierHashType() = default; SimplifierHashTypeSimplifierHashType14 explicit SimplifierHashType(size_t s) : _h(s) {} 15 16 bool operator==(const SimplifierHashType& other) const; 17 bool operator!=(const SimplifierHashType& other) const; 18 bool operator<(const SimplifierHashType& other) const; 19 bool operator==(const size_t other) const; 20 bool operator!=(const size_t other) const; 21 22 size_t _h{0}; 23 }; 24 25 } // namespace torch::jit::tensorexpr 26 27 namespace std { 28 template <> 29 struct hash<torch::jit::tensorexpr::SimplifierHashType> { 30 size_t operator()(const torch::jit::tensorexpr::SimplifierHashType& k) const { 31 return k._h; 32 } 33 }; 34 35 } // namespace std 36 37 namespace torch::jit::tensorexpr { 38 39 #define CACHE_GUARD() \ 40 if (cachedHash(v)) { \ 41 return; \ 42 } 43 44 class Term; 45 class Polynomial; 46 47 /* Expression hasher providing comparable values representing sub-exprs. 48 * Uses memoization to avoid excessive recursion. */ 49 class TORCH_API HashProvider : public IRVisitor { 50 public: 51 template <class T> 52 SimplifierHashType hash(T e) { 53 e->accept(this); 54 return hashOf(e); 55 } 56 57 bool cachedHash(const ExprPtr& e) { 58 return exprToHash_.find(e) != exprToHash_.end(); 59 } 60 bool cachedHash(const StmtPtr& s) { 61 return stmtToHash_.find(s) != stmtToHash_.end(); 62 } 63 64 void clearCache() { 65 exprToHash_.clear(); 66 stmtToHash_.clear(); 67 } 68 69 void visit(const AddPtr& v) override; 70 void visit(const SubPtr& v) override; 71 void visit(const MulPtr& v) override; 72 void visit(const DivPtr& v) override; 73 void visit(const ModPtr& v) override; 74 void visit(const RoundOffPtr& v) override; 75 void visit(const MaxPtr& v) override; 76 void visit(const MinPtr& v) override; 77 void visit(const AndPtr& v) override; 78 void visit(const OrPtr& v) override; 79 void visit(const XorPtr& v) override; 80 void visit(const LshiftPtr& v) override; 81 void visit(const RshiftPtr& v) override; 82 void visit(const CompareSelectPtr& v) override; 83 84 #define IMM_VISIT(Type, Name) \ 85 void visit(const Name##ImmPtr& v) override { \ 86 CACHE_GUARD(); \ 87 putHash(v, hash_combine(#Name, v->value())); \ 88 } 89 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT); 90 #undef IMM_VISIT 91 92 void visit(const CastPtr& v) override; 93 void visit(const VarPtr& v) override; 94 void visit(const RampPtr& v) override; 95 void visit(const LoadPtr& v) override; 96 void visit(const StorePtr& v) override; 97 void visit(const BlockPtr& v) override; 98 void visit(const ForPtr& v) override; 99 void visit(const BroadcastPtr& v) override; 100 void visit(const IfThenElsePtr& v) override; 101 void visit(const IntrinsicsPtr& v) override; 102 void visit(const AllocatePtr& v) override; 103 void visit(const FreePtr& v) override; 104 void visit(const CondPtr& v) override; 105 void visit(const TermPtr& v) override; 106 void visit(const PolynomialPtr& v) override; 107 void visit(const MaxTermPtr& v) override; 108 void visit(const MinTermPtr& v) override; 109 110 template <typename... Types> 111 SimplifierHashType hash_combine(const Types&... args) { 112 SimplifierHashType seed; 113 _hash_combine(seed, args...); 114 return seed; 115 } 116 117 private: 118 SimplifierHashType hashOf(const ExprPtr& e) { 119 auto it = exprToHash_.find(e); 120 if (it != exprToHash_.end()) { 121 return it->second; 122 } 123 124 // As a failsafe fall back to IRPrinter. 125 std::stringstream ss; 126 IRPrinter printer(ss); 127 e->accept(&printer); 128 SimplifierHashType hash = SimplifierHashType(te_hash(ss.str())); 129 putHash(e, hash); 130 131 return hash; 132 } 133 134 SimplifierHashType hashOf(const StmtPtr& s) { 135 auto it = stmtToHash_.find(s); 136 if (it != stmtToHash_.end()) { 137 return it->second; 138 } 139 140 // As a failsafe fall back to IRPrinter. 141 std::stringstream ss; 142 IRPrinter printer(ss); 143 s->accept(&printer); 144 SimplifierHashType hash = SimplifierHashType(te_hash(ss.str())); 145 putHash(s, hash); 146 147 return hash; 148 } 149 150 // Hash funcs for various types, numbers are random. 151 template <typename T> 152 void _hash_combine(SimplifierHashType& seed, const T& val) { 153 seed._h ^= te_hash(val) + 0x1f752c19 + (seed._h << 7) + (seed._h >> 4); 154 } 155 156 void _hash_combine(SimplifierHashType& seed, const char* val) { 157 seed._h ^= te_hash(val) + 0x1f752c19 + (seed._h << 7) + (seed._h >> 4); 158 } 159 160 // at:::Half doesn't have a prime_number_hash, so cast to short. 161 void _hash_combine(SimplifierHashType& seed, const at::Half& val) { 162 seed._h ^= 163 te_hash((uint16_t)val) + 0x1f752c19 + (seed._h << 7) + (seed._h >> 4); 164 } 165 166 void _hash_combine(SimplifierHashType& seed, const Dtype& val) { 167 seed._h ^= te_hash(val.ToCppString()) + 0x1f752c19 + (seed._h << 7) + 168 (seed._h >> 4); 169 } 170 171 void _hash_combine(SimplifierHashType& seed, ExprPtr e) { 172 _hash_combine(seed, hash(std::move(e))); 173 } 174 175 template <typename T, typename... Types> 176 void _hash_combine( 177 SimplifierHashType& seed, 178 const T& val, 179 const Types&... args) { 180 _hash_combine(seed, val); 181 _hash_combine(seed, args...); 182 } 183 184 void putHash(const ExprPtr& e, SimplifierHashType h) { 185 auto res = exprToHash_.emplace(e, h); 186 if (res.second == false) { 187 // This is always a logic bug since we should check the cache first. 188 throw std::runtime_error("hash collision"); 189 } 190 } 191 void putHash(const StmtPtr& s, SimplifierHashType h) { 192 auto res = stmtToHash_.emplace(s, h); 193 if (res.second == false) { 194 // This is always a logic bug since we should check the cache first. 195 throw std::runtime_error("hash collision"); 196 } 197 } 198 199 std::unordered_map<ExprPtr, SimplifierHashType> exprToHash_; 200 std::unordered_map<StmtPtr, SimplifierHashType> stmtToHash_; 201 UniqueNameManager name_manager_; 202 203 size_t te_hash(SimplifierHashType val) { 204 return val._h; 205 } 206 207 size_t te_hash(int64_t val) { 208 // put the thing down. 209 size_t h = val ^ 0x647AA4D20C0B; 210 // bit flip it. 211 size_t h2 = ~h; 212 // and reverse byte order. 213 size_t h3 = 0; 214 for (unsigned int i = 0; i < 64; i += 8) { 215 h3 |= ((h2 >> i) & 0xFF) << (64 - i - 8); 216 } 217 return h3; 218 } 219 220 size_t te_hash(int32_t val) { 221 int64_t v2 = val; 222 return te_hash(v2); 223 } 224 225 size_t te_hash(uint32_t val) { 226 int64_t v2 = val; 227 return te_hash(v2); 228 } 229 230 size_t te_hash(uint64_t val) { 231 int64_t v2 = val; 232 return te_hash(v2); 233 } 234 235 size_t te_hash(int16_t val) { 236 int64_t v2 = val; 237 return te_hash(v2); 238 } 239 240 size_t te_hash(std::string val) { 241 size_t hash{0}; 242 int64_t intval{0}; 243 int64_t s = val.size() - 1; 244 while (s >= 0) { 245 for (unsigned int i = 0; i < 8; ++i) { 246 if (s < 0) 247 break; 248 int64_t c = val[s]; 249 intval |= (c << (i * 8)); 250 251 s--; 252 } 253 hash ^= te_hash(intval); 254 intval = 0; 255 } 256 257 return hash; 258 } 259 260 size_t te_hash(double d) { 261 int64_t* n = reinterpret_cast<int64_t*>(&d); 262 return te_hash(*n); 263 } 264 265 size_t te_hash(float d) { 266 int32_t* n = reinterpret_cast<int32_t*>(&d); 267 return te_hash(*n); 268 } 269 270 size_t te_hash(at::Half d) { 271 int16_t* n = reinterpret_cast<int16_t*>(&d); 272 return te_hash(*n); 273 } 274 275 size_t te_hash(at::BFloat16 d) { 276 int16_t* n = reinterpret_cast<int16_t*>(&d); 277 return te_hash(*n); 278 } 279 }; 280 281 } // namespace torch::jit::tensorexpr 282