// Source: aten/src/ATen/cuda/detail/PhiloxCudaStateRaw.cuh
// (AOSP 15 external/pytorch, revision da0073e96a02ea20f0ac840b70461e3646d07c45).
//
// No "#pragma once" because this is a raw definition that can be copied by jit codegen.
// Eager-mode clients should not include this file directly; instead,
// they should #include <ATen/cuda/PhiloxCudaState.h>, which has a #pragma once.

// Stores RNG state values. Passed as a kernel argument.
// See Note [CUDA Graph-safe RNG states].
//
// The raw definition lives in its own file so jit codegen can easily copy it.
namespace at {

// Philox RNG state handed to kernels by value.
//
// The two `Payload` unions are tagged by `captured_`:
//   * captured_ == false: `seed_.val` and `offset_.val` are the active members
//     and hold the seed and offset directly.
//   * captured_ == true:  `seed_.ptr` and `offset_.ptr` are the active members;
//     they point at int64_t storage for the seed and the "extragraph" offset,
//     and `offset_intragraph_` holds an additional offset by value.
// Readers must consult `captured_` before touching either union member.
// NOTE(review): how the pointed-to offset and `offset_intragraph_` are combined
// is decided by the reader (at::cuda::philox::unpack), not by this struct.
struct PhiloxCudaState {
  // Non-captured state with both payloads value-initialized (`val == 0`)
  // and an intragraph offset of 0.
  PhiloxCudaState() = default;

  // Called if graph capture is not underway: stores seed and offset by value.
  PhiloxCudaState(uint64_t seed,
                  uint64_t offset) {
    seed_.val = seed;
    offset_.val = offset;
  }

  // Called if graph capture is underway: stores the seed pointer, the
  // extragraph offset pointer, and the intragraph offset, and marks the
  // state as captured.
  // NOTE(review): the pointers are copied as-is (no ownership is taken);
  // the caller presumably keeps the pointees alive while any kernel that
  // received this state may still read them — confirm at call sites.
  PhiloxCudaState(int64_t* seed,
                  int64_t* offset_extragraph,
                  uint32_t offset_intragraph) {
    seed_.ptr = seed;
    offset_.ptr = offset_extragraph;
    offset_intragraph_ = offset_intragraph;
    captured_ = true;
  }

  // Public members, directly accessible by at::cuda::philox::unpack.
  // If we made them private with getters/setters, the getters/setters
  // would have to be __device__, and we can't declare __device__ in ATen.
  union Payload {
    uint64_t val;  // active when captured_ == false
    int64_t* ptr;  // active when captured_ == true
  };

  // `{}` value-initializes each union, zeroing its first member (`val`).
  Payload seed_{};
  Payload offset_{};
  uint32_t offset_intragraph_ = 0;  // only meaningful when captured_ == true
  bool captured_ = false;           // selects the active member of seed_/offset_
};

} // namespace at