1 #pragma once 2 3 #include <c10/util/irange.h> 4 #include <memory> 5 #include <mutex> 6 7 namespace at::native { 8 9 // Hashing machinery for Params 10 // Fowler–Noll–Vo hash function 11 // see 12 // https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function 13 template <typename Params> 14 struct ParamsHash { 15 // Params must be a POD because we read out its memory 16 // contents as char* when hashing 17 static_assert(std::is_standard_layout_v<Params>, "Params is not POD"); 18 operatorParamsHash19 size_t operator()(const Params& params) const { 20 auto ptr = reinterpret_cast<const uint8_t*>(¶ms); 21 uint32_t value = 0x811C9DC5; 22 for (const auto i : c10::irange(sizeof(Params))) { 23 value ^= ptr[i]; 24 value *= 0x01000193; 25 } 26 return (size_t)value; 27 } 28 }; 29 30 template <typename Params> 31 struct ParamsEqual { 32 // Params must be a POD because we read out its memory 33 // contents as char* when comparing 34 static_assert(std::is_standard_layout_v<Params>, "Params is not POD"); 35 operatorParamsEqual36 bool operator()(const Params& a, const Params& b) const { 37 auto ptr1 = reinterpret_cast<const uint8_t*>(&a); 38 auto ptr2 = reinterpret_cast<const uint8_t*>(&b); 39 return memcmp(ptr1, ptr2, sizeof(Params)) == 0; 40 } 41 }; 42 43 // Provide explicit byte-for-byte constructors to avoid uwittingly leaving 44 // padding bytes unitialized (e.g., when passing Params by value) 45 template <typename T> 46 struct ParamsWrapper { 47 T pod; 48 static_assert( 49 std::is_standard_layout_v<T>, 50 "ParamsWrapper cannot wrap non-POD data"); 51 ParamsWrapperParamsWrapper52 ParamsWrapper() { 53 memset(&(this->pod), 0, sizeof(this->pod)); 54 } 55 ParamsWrapperParamsWrapper56 ParamsWrapper(const ParamsWrapper& other) { 57 memcpy(&(this->pod), &(other.pod), sizeof(this->pod)); 58 } 59 ParamsWrapperParamsWrapper60 ParamsWrapper(ParamsWrapper&& other) noexcept { 61 memcpy(&(this->pod), &(other.pod), sizeof(this->pod)); 62 } 63 64 ParamsWrapper& operator=(const ParamsWrapper& other) { 65 memcpy(&(this->pod), &(other.pod), sizeof(this->pod)); 66 return *this; 67 } 68 69 ParamsWrapper& operator=(ParamsWrapper&& other) noexcept { 70 memcpy(&(this->pod), &(other.pod), sizeof(this->pod)); 71 return *this; 72 } 73 74 inline friend bool operator==( 75 const ParamsWrapper& lhs, 76 const ParamsWrapper& rhs) noexcept { 77 auto ptr1 = reinterpret_cast<const uint8_t*>(&(lhs.pod)); 78 auto ptr2 = reinterpret_cast<const uint8_t*>(&(rhs.pod)); 79 return memcmp(ptr1, ptr2, sizeof(lhs.pod)) == 0; 80 } 81 }; 82 83 // Wrapped version: this allows the outer struct to have custom copy and move 84 // constructors for additional safety 85 template <typename ParamsWrapper> 86 struct ParamsWrapperHash { 87 // Params must be a POD because we read out its memory 88 // contents as char* when hashing 89 static_assert( 90 std::is_standard_layout_v<decltype(ParamsWrapper::pod)>, 91 "ParamsWrapper cannot wrap non-POD data"); 92 operatorParamsWrapperHash93 size_t operator()(const ParamsWrapper& params_wrapper) const { 94 auto ptr = reinterpret_cast<const uint8_t*>(&(params_wrapper.pod)); 95 uint32_t value = 0x811C9DC5; 96 for (const auto i : c10::irange(sizeof(params_wrapper.pod))) { 97 value ^= ptr[i]; 98 value *= 0x01000193; 99 } 100 return (size_t)value; 101 } 102 }; 103 104 } // namespace at::native 105