#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/native/cuda/DistributionTemplates.h>
#include <ATen/native/Resize.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty_like.h>
#include <ATen/ops/leaky_relu.h>
#include <ATen/ops/rrelu_with_noise_native.h>
#endif

namespace at::native {
template <typename scalar_t, int unroll_factor, typename F>
#if __CUDA_ARCH__ >= 350 || defined USE_ROCM
C10_LAUNCH_BOUNDS_2(256, 4)
#endif
__global__ void rrelu_with_noise_cuda_kernel(
    int numel,
    PhiloxCudaState philox_args,
    scalar_t* output,
    const scalar_t* input,
    scalar_t* noise,
    double lower,
    double upper,
    const F& random_func) {
  auto seeds = at::cuda::philox::unpack(philox_args);
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(std::get<0>(seeds),
              idx,
              std::get<1>(seeds),
              &state);

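  // Grid-stride loop, unrolled by `unroll_factor`. `rounded_size` rounds numel
  // up to a whole number of strides so every thread runs the same number of
  // iterations (and makes the same number of RNG draws), which also keeps the
  // trailing __syncthreads() free of divergence across the block.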
  int grid_stride = blockDim.x * gridDim.x * unroll_factor;
  int rounded_size = ((numel - 1) / grid_stride + 1) * grid_stride;
  double range = upper - lower;

  for (int linear_index = idx; linear_index < rounded_size; linear_index += grid_stride) {
    auto rand = random_func(&state);

    // ensure that (&rand.x)[ii] is safe
    static_assert(sizeof(rand)/sizeof(rand.x) == unroll_factor, "");

    #pragma unroll
    for (int ii = 0; ii < unroll_factor; ii++) {
      int li = linear_index + blockDim.x * gridDim.x * ii;
      if (li >= numel) {
        continue;
      }
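      // Map the uniform sample into [lower, upper]. Inputs <= 0 are scaled by
      // the sampled slope r (saved in `noise` for the backward pass); positive
      // inputs pass through with a recorded slope of 1.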
      scalar_t r = static_cast<scalar_t>((&rand.x)[ii]);
      r = r * range + lower;
      if (input[li] <= 0) {
        output[li] = input[li] * r;
        noise[li] = r;
      } else {
        output[li] = input[li];
        noise[li] = static_cast<scalar_t>(1);
      }
    }
    __syncthreads();
  }
}

template <typename scalar_t>
inline void _rrelu_with_noise_cuda_train(
    Tensor& output,
    const Tensor& input_,
    const Tensor& noise_,
    const Scalar& lower_,
    const Scalar& upper_,
    std::optional<Generator> generator) {
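  // The kernel indexes raw pointers linearly, so work on contiguous views here
  // and copy results back to any non-contiguous destination at the end.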
  auto input = input_.contiguous();
  auto noise = noise_.contiguous();
  Tensor tmp_output = output.contiguous();

  int64_t numel = input.numel();
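  // curand_uniform2_double yields 2 doubles per call and curand_uniform4
  // yields 4 floats per call, so the unroll factor matches the width of one
  // RNG draw (checked by the static_assert in the kernel).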
  const int unroll_factor = std::is_same<scalar_t, double>::value ? 2 : 4;
  auto execution_policy = calc_execution_policy(numel, unroll_factor);

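  // calc_execution_policy returns (Philox counter increment, grid, block); the
  // increment reserves enough of the Philox stream for this launch so later
  // launches do not reuse the same random numbers.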
  auto counter_offset = std::get<0>(execution_policy);
  auto grid = std::get<1>(execution_policy);
  auto block = std::get<2>(execution_policy);

  auto gen = get_generator_or_default<CUDAGeneratorImpl>(
      generator, cuda::detail::getDefaultCUDAGenerator());
  PhiloxCudaState rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_cuda_state(counter_offset);
  }

  const scalar_t* input_data = input.const_data_ptr<scalar_t>();
  scalar_t* noise_data = noise.mutable_data_ptr<scalar_t>();
  scalar_t* output_data = tmp_output.mutable_data_ptr<scalar_t>();

  double lower = lower_.to<double>();
  double upper = upper_.to<double>();

  auto stream = at::cuda::getCurrentCUDAStream();

  if (std::is_same<scalar_t, double>::value) {
    rrelu_with_noise_cuda_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
        numel,
        rng_engine_inputs,
        output_data,
        input_data,
        noise_data,
        lower,
        upper,
        [] __device__ (curandStatePhilox4_32_10_t* state) {
          return curand_uniform2_double(state);
        });
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else {
    // half, float, and bfloat16
    rrelu_with_noise_cuda_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
        numel,
        rng_engine_inputs,
        output_data,
        input_data,
        noise_data,
        lower, upper,
        [] __device__ (curandStatePhilox4_32_10_t* state) {
          return curand_uniform4(state);
        });
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }

  if (!output.is_contiguous()) {
    output.copy_(tmp_output);
  }
  if (!noise_.is_contiguous()) {
    // The kernel wrote into a contiguous copy of noise_; propagate the sampled
    // slopes back so callers (and the backward pass) observe them.
    noise_.copy_(noise);
  }
}

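// Out variant; the functional and in-place wrappers below forward here.
// A minimal host-side usage sketch (illustrative only; shapes, bounds, and the
// randn initialization are arbitrary):
//
//   at::Tensor x = at::randn({8}, at::kCUDA);
//   at::Tensor noise = at::empty_like(x);
//   at::Tensor y = at::rrelu_with_noise(
//       x, noise, /*lower=*/0.125, /*upper=*/1.0 / 3.0, /*training=*/true);
//   // In training mode `noise` now holds the per-element slopes.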
Tensor& rrelu_with_noise_out_cuda(const Tensor& self,
    const Tensor& noise,
    const Scalar& lower,
    const Scalar& upper,
    bool training,
    std::optional<Generator> generator,
    Tensor& output) {
  at::native::resize_output(output, self.sizes());

  if (self.numel() == 0) {
    return output;
  }

  TensorArg self_arg{self, "self", 1}, noise_arg{noise, "noise", 2},
      output_arg{output, "output", 3};
  checkAllSameGPU("rrelu_with_noise_out_cuda", {self_arg, noise_arg, output_arg});

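  // Training mode samples a fresh slope per element; in eval mode RReLU is
  // deterministic and equivalent to leaky_relu with the mean slope
  // (lower + upper) / 2.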
  if (training) {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
        self.scalar_type(), "rrelu_with_noise_out_cuda", [&] {
          _rrelu_with_noise_cuda_train<scalar_t>(
              output, self, noise, lower, upper, generator);
        });
  } else {
    auto lower_double = lower.to<double>();
    auto upper_double = upper.to<double>();
    Scalar negative_slope = (lower_double + upper_double) / 2;
    at::leaky_relu_out(output, self, negative_slope);
  }
  return output;
}

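// Functional and in-place wrappers; both forward to the out variant above
// (the in-place version writes the result back into `self`).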
Tensor rrelu_with_noise_cuda(
    const Tensor& self,
    const Tensor& noise,
    const Scalar& lower,
    const Scalar& upper,
    bool training,
    std::optional<Generator> generator) {
  Tensor output = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  return at::native::rrelu_with_noise_out_cuda(self, noise, lower, upper, training, generator, output);
}

Tensor& rrelu_with_noise_cuda_(
    Tensor& self,
    const Tensor& noise,
    const Scalar& lower,
    const Scalar& upper,
    bool training,
    std::optional<Generator> generator) {
  return at::native::rrelu_with_noise_out_cuda(
      self, noise, lower, upper, training, generator, self);
}

}  // namespace at::native