// aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu
#define TORCH_ASSERT_NO_OPERATORS
#define _USE_MATH_DEFINES

#include <ATen/native/Activation.h>

#include <cmath>

#include <thrust/tuple.h>

#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/core/TensorBase.h>
#include <c10/core/Scalar.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/cuda/ApplyGridUtils.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/Loops.cuh>

namespace at::native {
namespace {

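// Backward of hardtanh(x) = clamp(x, min, max): the gradient is 1 where
// min < x < max and 0 elsewhere (the subgradient at the clamp boundaries is
// taken to be 0), so grad_input = grad_output strictly inside the interval
// and 0 outside it.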
void hardtanh_backward_kernel(
    TensorIterator& iter,
    const Scalar& min,
    const Scalar& max) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      iter.dtype(), "hardtanh_backward_cuda", [&]() {
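        // Do the comparison and zeroing in opmath_t (float for Half/BFloat16,
        // the element type itself for float/double) rather than in the
        // possibly low-precision scalar_t.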
        using opmath_t = at::opmath_type<scalar_t>;
        auto min_val = min.to<opmath_t>();
        auto max_val = max.to<opmath_t>();
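        // Elementwise over the iterator: a is grad_output and b is the
        // hardtanh input; pass the gradient through only where b lies
        // strictly inside (min_val, max_val).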
        gpu_kernel(
            iter,
            [min_val, max_val] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
              opmath_t aop = static_cast<opmath_t>(a);
              opmath_t bop = static_cast<opmath_t>(b);
              return (bop <= min_val) || (bop >= max_val) ? opmath_t(0) : aop;
            });
      });
}
} // namespace

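// Register this kernel as the CUDA backend of the hardtanh_backward dispatch
// stub declared in ATen/native/Activation.h.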
REGISTER_DISPATCH(hardtanh_backward_stub, &hardtanh_backward_kernel);

} // namespace at::native