xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/ActivationHardswishKernel.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_NO_OPERATORS
2 #define _USE_MATH_DEFINES
3 
4 #include <ATen/native/Activation.h>
5 
6 #include <cmath>
7 
8 #include <thrust/tuple.h>
9 
10 #include <ATen/AccumulateType.h>
11 #include <ATen/Dispatch.h>
12 #include <ATen/core/TensorBase.h>
13 #include <c10/core/Scalar.h>
14 #include <c10/cuda/CUDAMathCompat.h>
15 #include <ATen/cuda/ApplyGridUtils.cuh>
16 #include <ATen/cuda/detail/OffsetCalculator.cuh>
17 #include <ATen/native/cuda/Loops.cuh>
18 
19 namespace at::native {
20 namespace {
21 
hardswish_kernel(TensorIterator & iter)22 void hardswish_kernel(TensorIterator& iter) {
23   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardswish_cuda", [&]() {
24     using opmath_t = at::opmath_type<scalar_t>;
25     const opmath_t zero(0.0f);
26     const opmath_t one_sixth(1.0f / 6.0f);
27     const opmath_t three(3.0f);
28     const opmath_t six(6.0f);
29     gpu_kernel(iter, [zero, one_sixth, three, six]GPU_LAMBDA(scalar_t self_val) -> scalar_t {
30       opmath_t x = static_cast<opmath_t>(self_val);
31       return x * std::min(std::max(x + three, zero), six) * one_sixth;
32     });
33   });
34 }
35 
hardswish_backward_kernel(TensorIterator & iter)36 void hardswish_backward_kernel(TensorIterator& iter) {
37   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardswish_backward_cuda", [&]() {
38     using opmath_t = at::opmath_type<scalar_t>;
39     const opmath_t zero(0.0f);
40     const opmath_t three(3.0f);
41     const opmath_t neg_three(-3.0f);
42     const opmath_t one_half(0.5f);
43     gpu_kernel(
44       iter,
45       [zero, three, neg_three, one_half]GPU_LAMBDA(scalar_t grad_val_, scalar_t self_val_) -> scalar_t {
46         opmath_t grad_val = static_cast<opmath_t>(grad_val_);
47         opmath_t self_val = static_cast<opmath_t>(self_val_);
48         if (self_val < neg_three) {
49           return zero;
50         } else if (self_val <= three) {
51           return grad_val * ((self_val / three) + one_half);
52         } else {
53           return grad_val;
54         }
55     });
56   });
57 }
58 } // namespace
59 
// Hook the two kernels from this file up to their dispatch stubs:
// hardswish_stub -> hardswish_kernel (forward) and
// hardswish_backward_stub -> hardswish_backward_kernel (backward).
// NOTE(review): REGISTER_DISPATCH and the stubs are declared elsewhere
// (presumably via ATen/native/Activation.h) — confirm against that header.
60 REGISTER_DISPATCH(hardswish_stub, &hardswish_kernel);
61 REGISTER_DISPATCH(hardswish_backward_stub, &hardswish_backward_kernel);
62 
63 } // namespace at::native
64