xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/utils/ParamsHash.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <c10/util/irange.h>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <mutex>
#include <type_traits>
6 
7 namespace at::native {
8 
9 // Hashing machinery for Params
10 // Fowler–Noll–Vo hash function
11 // see
12 // https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
/// Hash functor that computes a 32-bit FNV-1a hash over the raw bytes of a
/// `Params` object.
///
/// All `sizeof(Params)` bytes are fed to the hash, padding included, so two
/// objects that differ only in their padding bytes may hash differently.
template <typename Params>
struct ParamsHash {
  // Params must be a POD because we read out its memory
  // contents as char* when hashing
  static_assert(std::is_standard_layout_v<Params>, "Params is not POD");

  /// @param params Object whose object representation is hashed.
  /// @return The 32-bit FNV-1a hash of the bytes, widened to `size_t`.
  size_t operator()(const Params& params) const {
    constexpr uint32_t kFnvOffsetBasis = 0x811C9DC5u;
    constexpr uint32_t kFnvPrime = 0x01000193u;

    const auto* const first = reinterpret_cast<const uint8_t*>(&params);
    const auto* const last = first + sizeof(Params);

    uint32_t hash = kFnvOffsetBasis;
    for (const auto* byte = first; byte != last; ++byte) {
      // FNV-1a order: xor the byte in first, then multiply by the prime.
      // The multiplication intentionally wraps modulo 2^32.
      hash ^= *byte;
      hash *= kFnvPrime;
    }
    return static_cast<size_t>(hash);
  }
};
29 
/// Equality functor that compares two `Params` objects byte for byte.
///
/// The comparison covers all `sizeof(Params)` bytes, padding included, so
/// two objects whose members are equal but whose padding bytes differ
/// compare unequal.
template <typename Params>
struct ParamsEqual {
  // Params must be a POD because we read out its memory
  // contents as char* when comparing
  static_assert(std::is_standard_layout_v<Params>, "Params is not POD");

  /// @return true iff the object representations of `a` and `b` are
  ///         identical.
  bool operator()(const Params& a, const Params& b) const {
    // memcmp takes const void*, so no pointer cast is needed.
    return std::memcmp(&a, &b, sizeof(Params)) == 0;
  }
};
42 
// Provide explicit byte-for-byte constructors to avoid unwittingly leaving
// padding bytes uninitialized (e.g., when passing Params by value)
/// Wraps a plain-data value `T` and defines construction, copy, move and
/// comparison in terms of whole-object byte operations.
///
/// Default construction zero-fills every byte of `pod` (padding included);
/// copies and moves transfer every byte; `operator==` compares every byte.
///
/// NOTE(review): the static_assert below only checks standard layout, while
/// memset/memcpy on `pod` additionally require `T` to be trivially copyable
/// to be well-defined. Confirm this holds for every instantiation.
template <typename T>
struct ParamsWrapper {
  T pod;
  static_assert(
      std::is_standard_layout_v<T>,
      "ParamsWrapper cannot wrap non-POD data");

  /// Zero-fills all bytes of `pod`.
  ParamsWrapper() {
    std::memset(&(this->pod), 0, sizeof(this->pod));
  }

  /// Copies all bytes of `other.pod`.
  ParamsWrapper(const ParamsWrapper& other) {
    std::memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
  }

  /// Same as copying: the bytes of `other.pod` are copied and `other` is
  /// left unchanged.
  ParamsWrapper(ParamsWrapper&& other) noexcept {
    std::memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
  }

  /// Copies all bytes of `other.pod`; self-assignment is a no-op.
  ParamsWrapper& operator=(const ParamsWrapper& other) {
    // memcpy requires non-overlapping regions, so skip self-assignment.
    if (this != &other) {
      std::memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
    }
    return *this;
  }

  /// Same as copy assignment; `other` is left unchanged and self-assignment
  /// is a no-op.
  ParamsWrapper& operator=(ParamsWrapper&& other) noexcept {
    // memcpy requires non-overlapping regions, so skip self-assignment.
    if (this != &other) {
      std::memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
    }
    return *this;
  }

  /// @return true iff the bytes of `lhs.pod` and `rhs.pod` are identical.
  inline friend bool operator==(
      const ParamsWrapper& lhs,
      const ParamsWrapper& rhs) noexcept {
    return std::memcmp(&(lhs.pod), &(rhs.pod), sizeof(lhs.pod)) == 0;
  }
};
82 
83 // Wrapped version: this allows the outer struct to have custom copy and move
84 // constructors for additional safety
/// Hash functor for wrapper types that expose their payload as a `pod`
/// member: computes a 32-bit FNV-1a hash over the raw bytes of that member.
///
/// All `sizeof(pod)` bytes are hashed, padding included.
template <typename ParamsWrapper>
struct ParamsWrapperHash {
  // Params must be a POD because we read out its memory
  // contents as char* when hashing
  static_assert(
      std::is_standard_layout_v<decltype(ParamsWrapper::pod)>,
      "ParamsWrapper cannot wrap non-POD data");

  /// @param params_wrapper Wrapper whose `pod` bytes are hashed.
  /// @return The 32-bit FNV-1a hash of the bytes, widened to `size_t`.
  size_t operator()(const ParamsWrapper& params_wrapper) const {
    constexpr uint32_t kFnvOffsetBasis = 0x811C9DC5u;
    constexpr uint32_t kFnvPrime = 0x01000193u;

    const auto* const first =
        reinterpret_cast<const uint8_t*>(&(params_wrapper.pod));
    const auto* const last = first + sizeof(params_wrapper.pod);

    uint32_t hash = kFnvOffsetBasis;
    for (const auto* byte = first; byte != last; ++byte) {
      // FNV-1a order: xor the byte in first, then multiply by the prime.
      // The multiplication intentionally wraps modulo 2^32.
      hash ^= *byte;
      hash *= kFnvPrime;
    }
    return static_cast<size_t>(hash);
  }
};
103 
104 } // namespace at::native
105