xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/flash_attn/philox.cuh (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
// PyTorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/8ca3c881db3e3510fcb7725389f6a0633c9b992c/torch/csrc/jit/tensorexpr/cuda_random.h
2 #pragma once
// Philox-4x32 counter-based random number generator for CUDA device code.
4 
5 #include <ATen/cuda/PhiloxUtils.cuh>
6 
namespace pytorch_flash {

// Pair of 64-bit unsigned words.
// No longer used by the functions in this namespace (they now split 64-bit values with
// explicit shifts). NOTE(review): retained in case other translation units reference
// pytorch_flash::ull2 -- confirm before removing.
struct ull2 {
    unsigned long long x;
    unsigned long long y;
};

// Full 32 x 32 -> 64-bit unsigned multiply of `a` and `b`.
//
// Returns the 64-bit product split into 32-bit words: .x = low word, .y = high word.
//
// Implemented as a plain widening multiply instead of inline PTX followed by reading the
// 64-bit result through a `uint2*`. That pointer pun violated C++ strict aliasing. nvcc
// typically lowers this form to the same `mul.wide.u32`.
// __host__ is added so host code can reproduce the exact same arithmetic.
__forceinline__ __host__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
    const unsigned long long product = static_cast<unsigned long long>(a) * b;
    return make_uint2(static_cast<unsigned int>(product),          // low 32 bits
                      static_cast<unsigned int>(product >> 32));   // high 32 bits
}

// One Philox-4x32 round (Salmon et al., "Parallel Random Numbers: As Easy as 1, 2, 3").
//
// ctr.x and ctr.z are multiplied by the two Philox multipliers. The high halves of those
// products are mixed with ctr.y / ctr.w and the round key:
//   out = { hi(SB*ctr.z) ^ ctr.y ^ key.x,  lo(SB*ctr.z),
//           hi(SA*ctr.x) ^ ctr.w ^ key.y,  lo(SA*ctr.x) }
__forceinline__ __host__ __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
    // 32-bit multipliers. They were `unsigned long` before, which is 64-bit on LP64 and was
    // implicitly narrowed when passed to mulhilo32; the value is unchanged.
    constexpr unsigned int kPhiloxSA = 0xD2511F53u;
    constexpr unsigned int kPhiloxSB = 0xCD9E8D57u;
    const uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);  // {lo, hi} of SA * ctr.x
    const uint2 res1 = mulhilo32(kPhiloxSB, ctr.z);  // {lo, hi} of SB * ctr.z
    return make_uint4(res1.y ^ ctr.y ^ key.x,
                      res1.x,
                      res0.y ^ ctr.w ^ key.y,
                      res0.x);
}

// Philox-4x32 counter-based RNG: returns four 32-bit random words for
// (seed, subsequence, offset).
//
//   seed        -> 64-bit key   (key.x = low word, key.y = high word)
//   offset      -> counter words 0-1 (counter.x = low, counter.y = high)
//   subsequence -> counter words 2-3 (counter.z = low, counter.w = high)
//
// Runs 7 rounds in total: 6 rounds each followed by a key bump, then one final round.
// The reference Philox4x32-10 uses 10 rounds. The round count is part of the generated
// stream, so changing it changes every output value.
__forceinline__ __device__ uint4 philox(unsigned long long seed,
                                        unsigned long long subsequence,
                                        unsigned long long offset) {
    // Weyl-sequence increments added to the key between rounds.
    constexpr unsigned int kPhilox10A = 0x9E3779B9u;
    constexpr unsigned int kPhilox10B = 0xBB67AE85u;

    // Split the 64-bit inputs into 32-bit words explicitly (low word first). The previous
    // version did this through reinterpret_cast of `seed` to uint2& and by writing through
    // an ull2* into a uint4. Both are strict-aliasing violations. The word order here
    // matches what those casts produced on the little-endian target, so the output is
    // unchanged.
    uint2 key = make_uint2(static_cast<unsigned int>(seed),
                           static_cast<unsigned int>(seed >> 32));
    uint4 counter = make_uint4(static_cast<unsigned int>(offset),
                               static_cast<unsigned int>(offset >> 32),
                               static_cast<unsigned int>(subsequence),
                               static_cast<unsigned int>(subsequence >> 32));

    #pragma unroll
    for (int i = 0; i < 6; i++) {
        counter = philox_single_round(counter, key);
        key.x += kPhilox10A;
        key.y += kPhilox10B;
    }
    // Final (7th) round; the key is not bumped afterwards.
    return philox_single_round(counter, key);
}

} // namespace pytorch_flash
54