// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/ActivationThresholdKernel.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_NO_OPERATORS
#define _USE_MATH_DEFINES

#include <ATen/native/Activation.h>

#include <cmath>

#include <thrust/tuple.h>

#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/core/TensorBase.h>
#include <c10/core/Scalar.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/cuda/ApplyGridUtils.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/Loops.cuh>

namespace at::native {
namespace {
// Elementwise threshold: out = (self <= threshold) ? value : other.
// `iter` supplies two inputs (self, other) and one output; the second input
// may be a CPU scalar, which gpu_kernel_with_scalars handles by lifting it
// into the lambda capture instead of reading it per-element.
template <typename scalar_t>
void threshold_kernel_impl(
    TensorIteratorBase& iter,
    scalar_t threshold,
    scalar_t value) {
  gpu_kernel_with_scalars(
      iter, [=] GPU_LAMBDA(scalar_t self_val, scalar_t other_val) -> scalar_t {
        if (self_val <= threshold) {
          return value;
        }
        return other_val;
      });
}

// Type-dispatching entry point registered on threshold_stub: selects the
// concrete element type from the iterator's dtype (all standard types plus
// Half and BFloat16), converts the Scalar arguments once, and launches the
// templated kernel.
static void threshold_kernel_cuda(
    TensorIteratorBase& iter,
    const Scalar& threshold,
    const Scalar& value) {
  AT_DISPATCH_ALL_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.dtype(),
      "threshold_cuda",
      [&] {
        const auto threshold_val = threshold.to<scalar_t>();
        const auto replacement = value.to<scalar_t>();
        threshold_kernel_impl<scalar_t>(iter, threshold_val, replacement);
      });
}
} // namespace

REGISTER_DISPATCH(threshold_stub, &threshold_kernel_cuda);

} // namespace at::native