// No "#pragma once" because this is a raw definition that can be copied by jit codegen.
// Eager mode clients should not include this file directly; instead,
// they should #include <ATen/cuda/PhiloxCudaState.h>, which has a #pragma once.
//
// This file intentionally has no #includes either: uint64_t, int64_t and
// uint32_t must already be in scope wherever this text is pasted or included.

// Stores RNG state values. Passed as a kernel argument.
// See Note [CUDA Graph-safe RNG states].
//
// The raw definition lives in its own file so jit codegen can easily copy it.
namespace at {

struct PhiloxCudaState {
  // Leaves seed_/offset_ zero-initialized (via their `{}` member initializers),
  // offset_intragraph_ == 0 and captured_ == false.
  PhiloxCudaState() = default;

  // Called if graph capture is not underway.
  // Stores the seed and offset by value: seed_.val / offset_.val become the
  // active union members and captured_ stays false.
  PhiloxCudaState(uint64_t seed,
                  uint64_t offset) {
    seed_.val = seed;
    offset_.val = offset;
  }

  // Called if graph capture is underway.
  // Stores pointers instead of values: seed_.ptr / offset_.ptr become the
  // active union members, offset_intragraph_ is recorded, and captured_ is
  // set to true. The pointees are not owned by this struct; the caller must
  // keep them alive for as long as this state is used.
  PhiloxCudaState(int64_t* seed,
                  int64_t* offset_extragraph,
                  uint32_t offset_intragraph) {
    seed_.ptr = seed;
    offset_.ptr = offset_extragraph;
    offset_intragraph_ = offset_intragraph;
    captured_ = true;
  }

  // Public members, directly accessible by at::cuda::philox::unpack.
  // If we made them private with getters/setters, the getters/setters
  // would have to be __device__, and we can't declare __device__ in ATen.
  //
  // Holds either a value (non-captured) or a pointer (captured). Which member
  // is active is recorded by captured_; reading the inactive member is
  // undefined behavior, so always branch on captured_ first.
  union Payload {
    uint64_t val;
    int64_t* ptr;
  };

  Payload seed_{};
  Payload offset_{};
  // Set only by the graph-capture constructor; 0 otherwise.
  uint32_t offset_intragraph_ = 0;
  // true => seed_.ptr / offset_.ptr are active; false => seed_.val / offset_.val.
  bool captured_ = false;
};

} // namespace at