1 // No "#pragma once" because this is a raw definition that can be copied by jit codegen.
2 // Eager mode clients should not include this file directly, instead,
3 // they should #include <ATen/cuda/PhiloxUtils.cuh>, which has a #pragma once.
4 
5 namespace at::cuda::philox {
6 
7 // In-kernel call to retrieve philox seed and offset from a PhiloxCudaState instance whether
8 // that instance was created with graph capture underway or not.
9 // See Note [CUDA Graph-safe RNG states].
10 //
11 // We can't write a __device__ function in CUDAGeneratorImpl.h, because it's in ATen.
12 // Also, whatever call unpacks PhiloxCudaState in consumer kernels must be inlineable.
13 // Easiest thing that comes to mind is, define a __device__ unpack helper here, in ATen/cuda.
14 //
15 // The raw definition lives in its own file so jit codegen can easily copy it.
16 __host__ __device__ __forceinline__ std::tuple<uint64_t, uint64_t>
unpack(at::PhiloxCudaState arg) {
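  // captured_ distinguishes the two cases: if this state was created during CUDA
  // graph capture, the seed and offset live in device memory (so the generator
  // can update them between graph replays) and must be read through pointers;
  // otherwise they were stored by value.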
  if (arg.captured_) {
    // static_cast avoids "warning: invalid narrowing conversion from 'long' to 'unsigned long'".
    // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
    // For most threads' reads it will hit in cache, so it shouldn't hurt performance.
    return std::make_tuple(static_cast<uint64_t>(*arg.seed_.ptr),
                           static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
  } else {
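    // Eager path: the seed and offset were stored by value when this state was
    // created, so they can be returned directly.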
    return std::make_tuple(arg.seed_.val, arg.offset_.val);
  }
}

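// Usage sketch (illustrative only, not part of this header): a consumer kernel
// typically unpacks the state once and seeds curand with a per-thread
// subsequence. The kernel below is hypothetical; only unpack() itself is
// defined in this file.
//
//   #include <curand_kernel.h>
//
//   __global__ void example_kernel(at::PhiloxCudaState philox_args, float* out, int n) {
//     const auto seeds = at::cuda::philox::unpack(philox_args);
//     const int idx = blockIdx.x * blockDim.x + threadIdx.x;
//     if (idx >= n) return;
//     curandStatePhilox4_32_10_t state;
//     curand_init(std::get<0>(seeds),  // seed
//                 idx,                 // per-thread subsequence
//                 std::get<1>(seeds),  // offset within the sequence
//                 &state);
//     out[idx] = curand_uniform(&state);
//   }
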
} // namespace at::cuda::philox