xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/PhiloxRNGEngine.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 // define constants like M_PI and C keywords for MSVC
4 #ifdef _MSC_VER
5 #define _USE_MATH_DEFINES
6 #include <math.h>
7 #endif
8 
9 
10 #ifdef __CUDACC__
11 #include <cuda.h>
12 #endif
13 
14 #include <ATen/core/Array.h>
15 #include <c10/macros/Macros.h>
16 #include <c10/util/Exception.h>
17 #include <c10/util/Half.h>
18 #include <cmath>
19 #include <cstdint>
20 
21 namespace at {
22 
23 // typedefs for holding vector data
24 namespace detail {
25 
26 typedef at::detail::Array<uint32_t, 4> UINT4;
27 typedef at::detail::Array<uint32_t, 2> UINT2;
28 typedef at::detail::Array<double, 2> DOUBLE2;
29 typedef at::detail::Array<float, 2> FLOAT2;
30 
31 } // namespace detail
32 
33 /**
34  * Note [Philox Engine implementation]
35  * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
36  * Originally implemented in PyTorch's fusion compiler
37  * Refer to: http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
38  * for details regarding the engine.
39  *
40  * Note that currently this implementation of the philox engine is not used
41  * anywhere except for tests in cpu_generator_test.cpp. However, this engine
42  * will replace curandStatePhilox4_32_10_t in the future.
43  *
 * The philox engine takes a seed value, a subsequence
 * for starting the generation and an offset for the subsequence.
46  * Think of this engine as an algorithm producing a huge array. We are
47  * parallelizing this array by partitioning the huge array and assigning
48  * a thread index to each partition. In other words, each seed value
49  * (there are 2^64 possible seed values) gives a sub array of size
50  * 2^128 (each element in that array is a 128 bit number). Reasoning
51  * behind the array being of size 2^128 is, there are 2^64 possible
52  * thread index value and there is an array of size 2^64 for each of
53  * those thread index. Hence 2^64 * 2^64 = 2^128 for each seed value.
54  *
55  * In short, this generator can produce 2^64 (seed values) * 2^128 (number
56  * of elements in an array given by a seed value) = 2^192 values.
57  *
58  * Arguments:
59  * seed:        Seed values could be any number from 0 to 2^64-1.
60  * subsequence: Subsequence is just the cuda thread indexing with:
61  *              - blockIdx.x * blockDim.x + threadIdx.x
 * offset:      The offset variable in PhiloxEngine decides how many 128-bit
 *              random numbers to skip (i.e. how many groups of four 32-bit
 *              numbers to skip) and hence determines the total number of
 *              randoms that can be drawn from the given subsequence.
66  */
67 
68 class philox_engine {
69 public:
70 
71   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
72   C10_HOST_DEVICE inline explicit philox_engine(uint64_t seed = 67280421310721,
73                                  uint64_t subsequence = 0,
74                                  uint64_t offset = 0) {
75 
76     reset_state(seed, subsequence);
77     incr_n(offset);
78   }
79 
80   C10_HOST_DEVICE inline void reset_state(uint64_t seed = 67280421310721,
81                                  uint64_t subsequence = 0) {
82     key_[0] = static_cast<uint32_t>(seed);
83     key_[1] = static_cast<uint32_t>(seed >> 32);
84     counter_ = detail::UINT4(0);
85     counter_[2] = static_cast<uint32_t>(subsequence);
86     counter_[3] = static_cast<uint32_t>(subsequence >> 32);
87     STATE = 0;
88   }
89 
90   /**
91    * Set the offset field of Philox Generator to the desired offset.
92    */
set_offset(uint64_t offset)93   C10_HOST_DEVICE inline void set_offset(uint64_t offset) {
94     counter_[0] = static_cast<uint32_t>(offset);
95     counter_[1] = static_cast<uint32_t>(offset >> 32);
96   }
97 
98   /**
99    * Gets the current offset of the Philox Generator.
100    */
get_offset()101   C10_HOST_DEVICE uint64_t get_offset() const {
102     uint64_t lo = static_cast<uint64_t>(counter_[0]);
103     uint64_t hi = static_cast<uint64_t>(counter_[1]) << 32;
104     return lo | hi;
105   }
106 
107   /**
108    * Produces a unique 32-bit pseudo random number on every invocation. Bookeeps state to avoid waste.
109    */
operator()110   C10_HOST_DEVICE inline uint32_t operator()(int32_t n_rounds = 10) { // 10 here to preserve back-compat behavior
111     if(STATE == 0) {
112       detail::UINT4 counter = counter_;
113       detail::UINT2 key = key_;
114       output_ = rand(counter, key, n_rounds);
115       incr();
116     }
117     uint32_t ret = output_[static_cast<int>(STATE)];
118     STATE = (STATE + 1) & 3;
119     return ret;
120   }
121 
randn(uint32_t n_rounds)122   inline float randn(uint32_t n_rounds) {
123     #ifdef __CUDA_ARCH__
124     AT_ASSERT(false, "Unsupported invocation of randn on CUDA");
125     #endif
126     if(STATE == 0) {
127       detail::UINT4 counter = counter_;
128       detail::UINT2 key = key_;
129       output_ = rand(counter, key, n_rounds);
130       incr();
131     }
132     // TODO(min-jean-cho) change to Polar method, a more efficient version of Box-Muller method
133     // TODO(voz) We use std:: below, and thus need a separate impl for CUDA.
134     float u1 = 1 - uint32_to_uniform_float(output_[0]); // uint32_to_uniform_float returns [0,1), we need (0,1] to avoid passing 0 to log.
135     float u2 = 1 - uint32_to_uniform_float(output_[1]);
136     return static_cast<float>(std::sqrt(-2.0 * std::log(u1)) * std::cos(2.0 * M_PI * u2));
137   }
138 
139   /**
140    * Function that Skips N 128 bit numbers in a subsequence
141    */
incr_n(uint64_t n)142   C10_HOST_DEVICE inline void incr_n(uint64_t n) {
143     uint32_t nlo = static_cast<uint32_t>(n);
144     uint32_t nhi = static_cast<uint32_t>(n >> 32);
145     counter_[0] += nlo;
146     // if overflow in x has occurred, carry over to nhi
147     if (counter_[0] < nlo) {
148       nhi++;
149       // if overflow in nhi has occurred during carry over,
150       // propagate that overflow to y and exit to increment z
151       // otherwise return
152       counter_[1] += nhi;
153       if(nhi != 0) {
154         if (nhi <= counter_[1]) {
155           return;
156         }
157       }
158     } else {
159       // if overflow in y has occurred during addition,
160       // exit to increment z
161       // otherwise return
162       counter_[1] += nhi;
163       if (nhi <= counter_[1]) {
164         return;
165       }
166     }
167     if (++counter_[2])
168       return;
169     ++counter_[3];
170   }
171 
172   /**
173    * Function that Skips one 128 bit number in a subsequence
174    */
incr()175   C10_HOST_DEVICE inline void incr() {
176     if (++counter_[0])
177       return;
178     if (++counter_[1])
179       return;
180     if (++counter_[2]) {
181       return;
182     }
183     ++counter_[3];
184   }
185 
186 private:
187   detail::UINT4 counter_;
188   detail::UINT4 output_;
189   detail::UINT2 key_;
190   uint32_t STATE;
191 
mulhilo32(uint32_t a,uint32_t b,uint32_t * result_high)192   C10_HOST_DEVICE inline uint32_t mulhilo32(uint32_t a, uint32_t b,
193                                     uint32_t *result_high) {
194     #ifdef __CUDA_ARCH__
195       *result_high = __umulhi(a, b);
196       return a*b;
197     #else
198       const uint64_t product = static_cast<uint64_t>(a) * b;
199       *result_high = static_cast<uint32_t>(product >> 32);
200       return static_cast<uint32_t>(product);
201     #endif
202   }
203 
single_round(detail::UINT4 ctr,detail::UINT2 in_key)204   C10_HOST_DEVICE inline detail::UINT4 single_round(detail::UINT4 ctr, detail::UINT2 in_key) {
205     uint32_t hi0 = 0;
206     uint32_t hi1 = 0;
207     uint32_t lo0 = mulhilo32(kPhiloxSA, ctr[0], &hi0);
208     uint32_t lo1 = mulhilo32(kPhiloxSB, ctr[2], &hi1);
209     detail::UINT4 ret;
210     ret[0] = hi1 ^ ctr[1] ^ in_key[0];
211     ret[1] = lo1;
212     ret[2] = hi0 ^ ctr[3] ^ in_key[1];
213     ret[3] = lo0;
214     return ret;
215   }
216 
uint32_to_uniform_float(uint32_t value)217   C10_HOST_DEVICE constexpr float uint32_to_uniform_float(uint32_t value) {
218       // maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
219       constexpr float scale = 4.6566127342e-10;
220       return static_cast<float>(value & 0x7FFFFFFF) * scale;
221   }
222 
223 
224 
rand(detail::UINT4 & counter,detail::UINT2 & key,uint32_t n_rounds)225   C10_HOST_DEVICE inline detail::UINT4 rand(detail::UINT4& counter, detail::UINT2& key, uint32_t n_rounds) {
226     for (uint32_t round = 0; round < (n_rounds - 1); round++) {
227         counter = single_round(counter, key);
228         key[0] += (kPhilox10A); key[1] += (kPhilox10B);
229       }
230     return single_round(counter, key);
231   }
232 
233 
234   static const uint32_t kPhilox10A = 0x9E3779B9;
235   static const uint32_t kPhilox10B = 0xBB67AE85;
236   static const uint32_t kPhiloxSA = 0xD2511F53;
237   static const uint32_t kPhiloxSB = 0xCD9E8D57;
238 };
239 
240 typedef philox_engine Philox4_32;
241 
242 } // namespace at
243